256 lines
9.6 KiB
Python
256 lines
9.6 KiB
Python
from __future__ import annotations

import collections.abc
from typing import Callable, Literal

from .errors import CompilationError, UnexpectedTokenError
from .lexer import Tokens, Token
from .logger import Logger, Tracer, LogLevel
from .nodes import Float, Sum, Value, Product, Node, Division, Sub, Integer, Expression, Identifier, Assignment, \
    Variable, Statement, PseudoNode, Block, Definition, Call
|
|
|
|
# Module-level logger, named after this module.
logger = Logger(__name__)
# Tracer bound to the module logger; it decorates the parser methods below.
tracer = Tracer(logger, level=LogLevel.Debug)
|
|
|
|
|
|
class Parser:
    """Recursive-descent parser that builds a node tree from a token stream.

    Tokens are pulled lazily from the iterator and cached, so multi-token
    lookahead (``peek_several``) never loses input. Most grammar methods take
    a ``mandatory`` flag: when the construct is absent they return ``None`` if
    ``mandatory`` is false and raise ``UnexpectedTokenError`` otherwise.
    """

    def __init__(self, tokens: collections.abc.Iterator[Token]):
        """Create a parser reading from ``tokens``.

        :param tokens: iterator yielding the tokens to parse.
        """
        self.tokens = tokens
        # Every token pulled from ``tokens`` so far, in stream order.
        self._token_cache: list[Token] = []
        # Set once the EOF token has been cached; nothing is fetched afterwards.
        self._EOF = False
        # Index into ``_token_cache`` of the current token.
        self.pos = 0
        # Most recently accepted token; ``parse`` uses it to locate failures.
        self._last_accepted_token: Token | None = None

    @property
    def token(self) -> Token:
        """Return the token at ``pos``, fetching from the stream as needed.

        If ``pos`` lies beyond the cached EOF token, the EOF token is returned.
        """
        self._fetch_until(self.pos)
        # Lookahead may have cached EOF while ``pos`` is still behind it, so
        # index by position and only clamp to the last cached token.
        last_index = len(self._token_cache) - 1
        return self._token_cache[min(self.pos, last_index)]

    def _fetch_until(self, desired_length: int) -> int:
        """Fetch tokens until the cache holds more than ``desired_length`` entries.

        Fetching stops early once the EOF token has been cached.

        :param desired_length: highest cache index that should become available.
        :return: the number of cached tokens after fetching.
        """
        while len(self._token_cache) <= desired_length and not self._EOF:
            tok = next(self.tokens)
            self._token_cache.append(tok)
            if tok.kind == Tokens.EOF:
                self._EOF = True
                break

        return len(self._token_cache)

    @property
    def prev_token(self) -> Token:
        """Return the token immediately before the current position.

        :raises IndexError: if the parser is still at the first token.
        """
        if self.pos <= 0:
            # Without this guard a negative index would silently wrap around
            # to the end of the cache.
            raise IndexError("no token precedes the first token")
        return self._token_cache[self.pos - 1]

    def next_symbol(self):
        """Advance the current position by one token."""
        self.pos += 1
        logger.debug(f"Advancing to token {self.pos} {self.token}")

    @tracer.trace_method(level=LogLevel.Trace)
    def accept(self, *token_types: Tokens) -> Literal[False] | Token:
        """Consume the current token if its kind is one of ``token_types``.

        :return: the consumed token, or ``False`` when the kind does not match
            (in which case nothing is consumed).
        """
        tok = self.token
        if tok.kind not in token_types:
            return False
        self._last_accepted_token = tok
        self.next_symbol()
        return tok

    def peek(self, token_type: Tokens) -> Literal[False] | Token:
        """Return the current token if it is of kind ``token_type``, else ``False``.

        The token is never consumed.
        """
        tok = self.token
        if tok.kind == token_type:
            return tok
        return False

    def peek_several(self, *tokens_types: Tokens) -> Literal[False] | list[Token]:
        """Check, without consuming, that the upcoming tokens have the given kinds.

        :param tokens_types: expected kinds, starting at the current token.
        :return: the matching tokens, or ``False`` if any kind differs or the
            stream ends before the window is complete.
        """
        end = self.pos + len(tokens_types)
        # Only cache indices pos .. end - 1 are needed; do not read further.
        if self._fetch_until(end - 1) < end:
            return False

        toks = self._token_cache[self.pos:end]
        for tok, wanted in zip(toks, tokens_types):
            if tok.kind != wanted:
                return False

        return toks

    @tracer.trace_method(level=LogLevel.Trace)
    def expect(self, token_type: Tokens) -> Token:
        """Consume and return a token of kind ``token_type``.

        :raises UnexpectedTokenError: if the current token has another kind.
        """
        r = self.accept(token_type)
        logger.debug(f"Expecting {token_type}, got {r}")
        if r is False:
            raise UnexpectedTokenError(self.token, token_type)
        return r

    @tracer.trace_method
    def number(self, mandatory: bool = False) -> Float | Integer | None:
        """Parse a float or integer literal.

        :raises UnexpectedTokenError: if ``mandatory`` and no literal is present.
        """
        if tok := self.accept(Tokens.Float):
            logger.debug(f"Found float {tok}")
            return Float(location=tok.loc, value=float(tok.value))
        elif tok := self.accept(Tokens.Integer):
            logger.debug(f"Found integer {tok}")
            return Integer(location=tok.loc, value=int(tok.value))
        elif mandatory:
            raise UnexpectedTokenError(self.token, "integer or float")
        return None

    @tracer.trace_method
    def identifier(self, mandatory: bool = False) -> Identifier | None:
        """Parse an identifier token into an ``Identifier`` node.

        :raises UnexpectedTokenError: if ``mandatory`` and no identifier is present.
        """
        if ident := self.accept(Tokens.Identifier):
            return Identifier(location=ident.loc, name=str(ident.value))
        elif mandatory:
            raise UnexpectedTokenError(self.token, "identifier")
        return None

    @tracer.trace_method
    def variable(self, mandatory: bool = False) -> Variable | None:
        """Parse an identifier and wrap it in a ``Variable`` node.

        :raises UnexpectedTokenError: if ``mandatory`` and no identifier is present.
        """
        if ident := self.identifier(mandatory=False):
            return Variable(identifier=ident)
        elif mandatory:
            raise UnexpectedTokenError(self.token, "variable identifier")
        return None

    @tracer.trace_method
    def binary_op(self, operand_func: Callable[[bool], Value | None], operators: dict[Tokens, type[Value]],
                  mandatory: bool = True) -> Value | None:
        """Parse a left-associative chain ``operand (operator operand)*``.

        :param operand_func: parses one operand; receives the ``mandatory`` flag.
        :param operators: maps each operator token kind to the node class that
            is called with the left and right operands.
        :param mandatory: whether the first operand must be present.
        :return: the resulting node, or ``None`` if there is no first operand
            and ``mandatory`` is false.
        :raises UnexpectedTokenError: if a required operand is missing.
        """
        operand = operand_func(mandatory)
        if not operand:
            if mandatory:
                raise UnexpectedTokenError(self.token, "operand")
            # Without a left operand an operator must not be consumed,
            # otherwise a node with a missing operand would be built.
            return None

        while operator := self.accept(*operators):
            node_type = operators[operator.kind]
            operand2 = operand_func(True)
            # Log before combining so the message shows the actual left operand.
            logger.debug(f"{node_type.__name__} of the following operands: {operand} and {operand2}")
            operand = node_type(operand, operand2)

        return operand

    @tracer.trace_method
    def factor(self, mandatory: bool = False) -> Value | None:
        """Parse a parenthesized expression, number, function call or variable.

        The alternatives are tried in that order.

        :raises UnexpectedTokenError: if ``mandatory`` and none of them matches.
        """
        if par_expression := self.parenthesized_expression(mandatory=False):
            return par_expression
        elif num := self.number(mandatory=False):
            return num
        elif call := self.call(mandatory=False):
            return call
        elif variable := self.variable(mandatory=False):
            return variable
        elif mandatory:
            raise UnexpectedTokenError(self.token, "parenthesized expression, number, function call or variable")
        return None

    @tracer.trace_method
    def term(self, mandatory: bool = False) -> Value | None:
        """Parse factors joined by multiplication or division operators."""
        return self.binary_op(self.factor, operators={
            Tokens.Op_Multiply: Product,
            Tokens.Op_Divide: Division,
        }, mandatory=mandatory)

    @tracer.trace_method
    def summation(self, mandatory: bool = True) -> Value | None:
        """Parse terms joined by plus or minus operators."""
        return self.binary_op(self.term, operators={
            Tokens.Op_Plus: Sum,
            Tokens.Op_Minus: Sub,
        }, mandatory=mandatory)

    @tracer.trace_method
    def assignment(self, mandatory: bool = False) -> Assignment | None:
        """Parse ``identifier = expression``.

        :raises UnexpectedTokenError: if the identifier is present but the
            ``=`` or the expression is missing, or if ``mandatory`` and there
            is no identifier.
        """
        if ident := self.identifier(mandatory):
            self.expect(Tokens.Equal)
            expr = self.expression(mandatory=True)
            return Assignment(ident, expr)
        elif mandatory:
            raise UnexpectedTokenError(self.token, "assignment")
        return None

    @tracer.trace_method
    def call(self, mandatory: bool = False) -> Call | None:
        """Parse ``identifier ( [expression {, expression}] )``.

        A call is only recognised when an identifier is directly followed by
        a left parenthesis.

        :raises UnexpectedTokenError: if the argument list is malformed, or if
            ``mandatory`` and no call starts at the current token.
        """
        if self.peek_several(Tokens.Identifier, Tokens.Parens_Left):
            ident = self.identifier(mandatory=True)
            lparens = self.expect(Tokens.Parens_Left)
            expressions: list[Expression] = []
            if expr := self.expression(mandatory=False):
                expressions.append(expr)
                # A comma is only meaningful after a first argument.
                while self.accept(Tokens.Comma):
                    expressions.append(self.expression(mandatory=True))
            rparens = self.expect(Tokens.Parens_Right)
            return Call(identifier=ident, arguments=expressions,
                        pseudo_nodes=[PseudoNode(lparens), PseudoNode(rparens)])
        elif mandatory:
            raise UnexpectedTokenError(self.token, "function call")
        return None

    @tracer.trace_method
    def definition(self, mandatory: bool = False) -> Definition | None:
        """Parse ``let identifier : type_identifier [= expression]``.

        :raises UnexpectedTokenError: if the definition is malformed, or if
            ``mandatory`` and the ``let`` keyword is absent.
        """
        if let_kw := self.accept(Tokens.KwLet):
            ident = self.identifier(mandatory=True)
            self.expect(Tokens.Colon)
            type_ident = self.identifier(mandatory=True)
            expr = None
            if self.accept(Tokens.Equal):
                # Once "=" has been consumed an initializer must follow,
                # exactly as in ``assignment``.
                expr = self.expression(mandatory=True)
            return Definition(ident, type_ident, expr, PseudoNode(let_kw))
        elif mandatory:
            raise UnexpectedTokenError(self.token, "definition")
        return None

    @tracer.trace_method
    def parenthesized_expression(self, mandatory: bool = False) -> Expression | None:
        """Parse ``( expression )`` and attach both parentheses as pseudo nodes.

        :raises UnexpectedTokenError: if the inner expression or the closing
            parenthesis is missing, or if ``mandatory`` and there is no
            opening parenthesis.
        """
        if lparens := self.accept(Tokens.Parens_Left):
            expression = self.expression(mandatory=True)
            rparens = self.expect(Tokens.Parens_Right)
            expression.pseudo_nodes = [PseudoNode(lparens), PseudoNode(rparens)]
            return expression
        elif mandatory:
            raise UnexpectedTokenError(self.token, "parenthesized expression")
        return None

    @tracer.trace_method
    def expression(self, mandatory: bool = False) -> Expression | None:
        """Parse a definition, an assignment or a summation.

        Lookahead selects the alternative: a ``let`` keyword starts a
        definition, an identifier followed by ``=`` starts an assignment, and
        anything else is tried as a summation.

        :return: the parsed node wrapped in ``Expression``, or ``None``.
        :raises UnexpectedTokenError: if ``mandatory`` and nothing matches.
        """
        node: Node | None = None
        if self.peek(Tokens.KwLet) and (definition := self.definition(mandatory)):
            node = definition
        elif self.peek_several(Tokens.Identifier, Tokens.Equal) and (assignment := self.assignment(mandatory)):
            node = assignment
        elif summation := self.summation(mandatory):
            node = summation
        elif mandatory:
            raise UnexpectedTokenError(self.token, "expression")

        if node is not None:
            return Expression(node)
        return None

    @tracer.trace_method
    def statement(self, mandatory: bool = False) -> Statement | None:
        """Parse a brace-delimited block or an expression followed by ``;``.

        :raises UnexpectedTokenError: if a closing brace or semicolon is
            missing, or if ``mandatory`` and no statement is present.
        """
        if lbrace := self.accept(Tokens.Brace_Left):
            block = self.block(name="anon")
            rbrace = self.expect(Tokens.Brace_Right)
            block.pseudo_nodes = [PseudoNode(lbrace), PseudoNode(rbrace)]
            return Statement(block)
        elif expr := self.expression(mandatory):
            semicolon = PseudoNode(self.expect(Tokens.Semicolon))
            return Statement(expr, pseudo_nodes=[semicolon])
        elif mandatory:
            # Report the offending token; ``expr`` is empty in this branch.
            raise UnexpectedTokenError(self.token, "expression")
        return None

    def block(self, name: str, pseudo_nodes: list[PseudoNode] | None = None) -> Block:
        """Parse statements until none is found and collect them in a ``Block``.

        :param name: name given to the block.
        :param pseudo_nodes: pseudo nodes passed through to the block.
        """
        nodes: list[Statement] = []
        while stmt := self.statement(mandatory=False):
            nodes.append(stmt)
        return Block(name, *nodes, pseudo_nodes=pseudo_nodes)

    @tracer.trace_method
    def root(self) -> Node:
        """Parse the whole input: ``BEGIN``, the root block, then ``EOF``.

        :raises UnexpectedTokenError: if either delimiter token is missing.
        """
        begin = self.expect(Tokens.BEGIN)
        root_block = self.block(name="root")
        end = self.expect(Tokens.EOF)
        # Wrap the delimiter tokens like every other pseudo node in this parser.
        root_block.pseudo_nodes = [PseudoNode(begin), PseudoNode(end)]
        return root_block

    def parse(self) -> Node:
        """Parse the token stream and return the root node.

        ``CompilationError`` is re-raised unchanged. Any other exception is
        re-raised as a ``CompilationError`` located at the last accepted
        token, or at the current token if nothing has been accepted yet.
        """
        try:
            return self.root()
        except CompilationError:
            raise
        except Exception as e:
            tok = self._last_accepted_token
            if tok is None:
                tok = self.token
            raise CompilationError(tok.loc) from e
|