compiler/compiler/parser.py

256 lines
9.6 KiB
Python

from __future__ import annotations

import collections.abc
from typing import Callable, Literal

from .errors import CompilationError, UnexpectedTokenError
from .lexer import Tokens, Token
from .logger import Logger, Tracer, LogLevel
from .nodes import Float, Sum, Value, Product, Node, Division, Sub, Integer, Expression, Identifier, Assignment, \
    Variable, Statement, PseudoNode, Block, Definition, Call
# Module-level logger, named after this module.
logger = Logger(__name__)
# Tracer bound to `logger` and constructed with LogLevel.Debug as its level;
# its `trace_method` decorator is applied to the Parser methods below.
tracer = Tracer(logger, level=LogLevel.Debug)
class Parser:
    """Recursive-descent parser over a lazily consumed token stream.

    Tokens are pulled from the iterator on demand and cached, so the parser
    can look ahead (see `peek_several`) without losing anything.

    Grammar methods take a ``mandatory`` flag: when the construct is not
    present they return ``None`` if ``mandatory`` is false and raise
    ``UnexpectedTokenError`` otherwise.
    """

    def __init__(self, tokens: collections.abc.Iterator[Token]):
        """Wrap *tokens*; nothing is read from the iterator until needed."""
        self.tokens = tokens
        # Every token read so far, indexed by position.
        self._token_cache: list[Token] = []
        # Set once the EOF token has been appended to the cache.
        self._EOF = False
        # Index of the current token in `_token_cache`.
        self.pos = 0
        # Most recent token consumed by `accept`; used by `parse` to locate
        # unexpected (non-compilation) errors.
        self._last_accepted_token: Token | None = None

    @property
    def token(self) -> Token:
        """Return the token at the current position, fetching it if needed.

        Positions past the end of the stream keep yielding the EOF token.
        """
        self._fetch_until(self.pos)
        # Index by position even after EOF has been cached: lookahead may
        # have read EOF while `pos` still points at an earlier token.
        last_index = len(self._token_cache) - 1
        return self._token_cache[min(self.pos, last_index)]

    def _fetch_until(self, desired_length: int) -> int:
        """Read tokens until index *desired_length* is cached or EOF is hit.

        Returns the number of cached tokens afterwards.
        """
        while len(self._token_cache) <= desired_length and not self._EOF:
            tok = next(self.tokens)
            self._token_cache.append(tok)
            if tok.kind == Tokens.EOF:
                self._EOF = True
                break
        return len(self._token_cache)

    @property
    def prev_token(self) -> Token:
        """Return the cached token just before the current position.

        NOTE: at ``pos == 0`` the index is ``-1``, which wraps around to the
        last cached token (or raises IndexError if nothing is cached yet).
        """
        return self._token_cache[self.pos - 1]

    def next_symbol(self):
        """Advance to the next position and log the token found there."""
        self.pos += 1
        logger.debug(f"Advancing to token {self.pos} {self.token}")

    @tracer.trace_method(level=LogLevel.Trace)
    def accept(self, *token_types: Tokens) -> Token | Literal[False]:
        """Consume and return the current token if its kind is in *token_types*.

        Returns ``False`` (consuming nothing) otherwise.
        """
        tok = self.token
        if tok.kind not in token_types:
            return False
        self._last_accepted_token = tok
        self.next_symbol()
        return tok

    def peek(self, token_type: Tokens) -> Token | Literal[False]:
        """Return the current token if it is of *token_type*, else ``False``.

        The token is not consumed.
        """
        tok = self.token
        return tok if tok.kind == token_type else False

    def peek_several(self, *tokens_types: Tokens) -> list[Token] | Literal[False]:
        """Check, without consuming, that the upcoming tokens match in order.

        Returns the matching tokens as a list, or ``False`` if the stream is
        too short or any kind differs.
        """
        end = self.pos + len(tokens_types)
        # The window [pos, end) needs index end - 1 to be cached; fetching
        # further than that is unnecessary.
        if self._fetch_until(end - 1) < end:
            return False
        window = self._token_cache[self.pos:end]
        for tok, wanted in zip(window, tokens_types):
            if tok.kind != wanted:
                return False
        return window

    @tracer.trace_method(level=LogLevel.Trace)
    def expect(self, token_type: Tokens) -> Token:
        """Consume and return a token of *token_type*.

        Raises:
            UnexpectedTokenError: if the current token is of another kind.
        """
        r = self.accept(token_type)
        logger.debug(f"Expecting {token_type}, got {r}")
        if r is False:
            raise UnexpectedTokenError(self.token, token_type)
        return r

    @tracer.trace_method
    def number(self, mandatory: bool = False) -> Float | Integer | None:
        """Parse a float or integer literal into a `Float` / `Integer` node."""
        if tok := self.accept(Tokens.Float):
            logger.debug(f"Found float {tok}")
            return Float(location=tok.loc, value=float(tok.value))
        if tok := self.accept(Tokens.Integer):
            logger.debug(f"Found integer {tok}")
            return Integer(location=tok.loc, value=int(tok.value))
        if mandatory:
            raise UnexpectedTokenError(self.token, "integer or float")
        return None

    @tracer.trace_method
    def identifier(self, mandatory: bool = False) -> Identifier | None:
        """Parse an identifier token into an `Identifier` node."""
        if ident := self.accept(Tokens.Identifier):
            return Identifier(location=ident.loc, name=str(ident.value))
        if mandatory:
            raise UnexpectedTokenError(self.token, "identifier")
        return None

    @tracer.trace_method
    def variable(self, mandatory: bool = False) -> Variable | None:
        """Parse an identifier and wrap it in a `Variable` node."""
        # The identifier is parsed as optional so that the error raised on
        # absence carries the variable-specific message below.
        if ident := self.identifier(mandatory=False):
            return Variable(identifier=ident)
        if mandatory:
            raise UnexpectedTokenError(self.token, "variable identifier")
        return None

    @tracer.trace_method
    def binary_op(self, operand_func: Callable[[bool], Value],
                  operators: dict[Tokens, type[Value]], mandatory: bool = True):
        """Parse ``operand (operator operand)*`` into a left-associative tree.

        Args:
            operand_func: parser for one operand; called with its
                ``mandatory`` flag as the only argument.
            operators: maps each accepted operator token kind to the node
                class that is called as ``node_class(left, right)``.
            mandatory: whether the first operand must be present.

        Returns:
            The resulting node, or the (falsy) result of *operand_func* when
            the first operand is absent and not mandatory.

        Raises:
            UnexpectedTokenError: if a required operand is missing.
        """
        operand = operand_func(mandatory)
        if not operand:
            if mandatory:
                raise UnexpectedTokenError(self.token, "operand")
            # Without a left operand there is nothing to combine; do not
            # consume operators and build a node around a missing operand.
            return operand
        while operator := self.accept(*operators):
            node_type = operators[operator.kind]
            # Once an operator has been consumed its right operand is required.
            rhs = operand_func(True)
            # Log before `operand` is rebound so both inputs are reported.
            logger.debug(f"{node_type.__name__} of the following operands: {operand} and {rhs}")
            operand = node_type(operand, rhs)
        return operand

    @tracer.trace_method
    def factor(self, mandatory: bool = False) -> Value | None:
        """Parse a parenthesized expression, number, call or variable.

        The alternatives are tried in that order; `call` is tried before
        `variable` because both start with an identifier.
        """
        if par_expression := self.parenthesized_expression(mandatory=False):
            return par_expression
        if num := self.number(mandatory=False):
            return num
        if call := self.call(mandatory=False):
            return call
        if variable := self.variable(mandatory=False):
            return variable
        if mandatory:
            raise UnexpectedTokenError(self.token, "parenthesized expression, number, function call or variable")
        return None

    @tracer.trace_method
    def term(self, mandatory: bool = False) -> Value | None:
        """Parse factors joined by multiplication / division operators."""
        return self.binary_op(self.factor, operators={
            Tokens.Op_Multiply: Product,
            Tokens.Op_Divide: Division,
        }, mandatory=mandatory)

    @tracer.trace_method
    def summation(self, mandatory: bool = True) -> Value | None:
        """Parse terms joined by plus / minus operators."""
        return self.binary_op(self.term, operators={
            Tokens.Op_Plus: Sum,
            Tokens.Op_Minus: Sub,
        }, mandatory=mandatory)

    @tracer.trace_method
    def assignment(self, mandatory: bool = False) -> Assignment | None:
        """Parse ``identifier = expression`` into an `Assignment` node."""
        if ident := self.identifier(mandatory):
            self.expect(Tokens.Equal)
            expr = self.expression(mandatory=True)
            return Assignment(ident, expr)
        if mandatory:
            raise UnexpectedTokenError(self.token, "assignment")
        return None

    @tracer.trace_method
    def call(self, mandatory: bool = False) -> Call | None:
        """Parse ``identifier ( [expression {, expression}] )`` into a `Call`.

        Two tokens of lookahead are used so that a plain identifier is left
        untouched for other rules.
        """
        if self.peek_several(Tokens.Identifier, Tokens.Parens_Left):
            ident = self.identifier(mandatory=True)
            lparens = self.expect(Tokens.Parens_Left)
            arguments: list[Expression] = []
            # The first argument is optional (empty argument list); after a
            # comma another argument is required.
            if first := self.expression(mandatory=False):
                arguments.append(first)
                while self.accept(Tokens.Comma):
                    arguments.append(self.expression(mandatory=True))
            rparens = self.expect(Tokens.Parens_Right)
            return Call(identifier=ident, arguments=arguments,
                        pseudo_nodes=[PseudoNode(lparens), PseudoNode(rparens)])
        if mandatory:
            raise UnexpectedTokenError(self.token, "function call")
        return None

    @tracer.trace_method
    def definition(self, mandatory: bool = False) -> Definition | None:
        """Parse ``let identifier : identifier [= expression]``."""
        if let_kw := self.accept(Tokens.KwLet):
            ident = self.identifier(mandatory=True)
            self.expect(Tokens.Colon)
            type_ident = self.identifier(mandatory=True)
            expr = None
            if self.accept(Tokens.Equal):
                # Once `=` has been consumed an initializer must follow;
                # otherwise a dangling `=` would be silently accepted.
                expr = self.expression(mandatory=True)
            return Definition(ident, type_ident, expr, PseudoNode(let_kw))
        if mandatory:
            raise UnexpectedTokenError(self.token, "definition")
        return None

    @tracer.trace_method
    def parenthesized_expression(self, mandatory: bool = False) -> Expression | None:
        """Parse ``( expression )`` and attach the parentheses as pseudo nodes."""
        if lparens := self.accept(Tokens.Parens_Left):
            expression = self.expression(mandatory=True)
            rparens = self.expect(Tokens.Parens_Right)
            expression.pseudo_nodes = [PseudoNode(lparens), PseudoNode(rparens)]
            return expression
        if mandatory:
            raise UnexpectedTokenError(self.token, "parenthesized expression")
        return None

    @tracer.trace_method
    def expression(self, mandatory: bool = False) -> Expression | None:
        """Parse a definition, an assignment or a summation.

        Lookahead picks the alternative; the result is wrapped in an
        `Expression` node.
        """
        node: Node | None = None
        if self.peek(Tokens.KwLet) and (definition := self.definition(mandatory)):
            node = definition
        elif self.peek_several(Tokens.Identifier, Tokens.Equal) and (assignment := self.assignment(mandatory)):
            node = assignment
        elif summation := self.summation(mandatory):
            node = summation
        elif mandatory:
            raise UnexpectedTokenError(self.token, "expression")
        if node is None:
            return None
        return Expression(node)

    @tracer.trace_method
    def statement(self, mandatory: bool = False) -> Statement | None:
        """Parse either ``{ block }`` or ``expression ;`` into a `Statement`."""
        if lbrace := self.accept(Tokens.Brace_Left):
            block = self.block(name="anon")
            rbrace = self.expect(Tokens.Brace_Right)
            block.pseudo_nodes = [PseudoNode(lbrace), PseudoNode(rbrace)]
            return Statement(block)
        if expr := self.expression(mandatory):
            semicolon = PseudoNode(self.expect(Tokens.Semicolon))
            return Statement(expr, pseudo_nodes=[semicolon])
        if mandatory:
            # Report the offending token, not the (empty) parse result.
            raise UnexpectedTokenError(self.token, "expression")
        return None

    def block(self, name: str, pseudo_nodes: list[PseudoNode] | None = None) -> Block:
        """Parse statements until none matches and collect them in a `Block`."""
        nodes: list[Statement] = []
        while stmt := self.statement(mandatory=False):
            nodes.append(stmt)
        return Block(name, *nodes, pseudo_nodes=pseudo_nodes)

    @tracer.trace_method
    def root(self) -> Node:
        """Parse the whole stream: ``BEGIN`` token, root block, ``EOF`` token."""
        begin = self.expect(Tokens.BEGIN)
        root_block = self.block(name="root")
        end = self.expect(Tokens.EOF)
        # Wrap the delimiter tokens like every other pseudo-node site does.
        root_block.pseudo_nodes = [PseudoNode(begin), PseudoNode(end)]
        return root_block

    def parse(self) -> Node:
        """Parse the token stream and return the root node.

        Raises:
            CompilationError: propagated unchanged when raised by the parser;
                any other exception is wrapped in a `CompilationError`
                located at the last accepted token (or at the current token
                if nothing was accepted yet) and chained to the original.
        """
        try:
            return self.root()
        except CompilationError:
            raise
        except Exception as e:
            tok = self._last_accepted_token
            if tok is None:
                tok = self.token
            raise CompilationError(tok.loc) from e