nodes+parser: refactor parsing of terms and factors

Make a distinction between Summation, Subtraction, Product and Division.
Also distinguish Integers and Floats
This commit is contained in:
Antoine Viallon 2023-05-08 17:43:34 +02:00
parent fd13900e9b
commit be9f389159
Signed by: aviallon
GPG key ID: D126B13AB555E16F
2 changed files with 83 additions and 40 deletions

View file

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
from beartype import beartype
class Node: class Node:
pass pass
@ -10,7 +12,15 @@ class Operator(Node):
class Sum(Node): class Sum(Node):
def __init__(self, *values: Expression): def __init__(self, *values: Value):
self.values = values
def __repr__(self):
return f"{self.__class__.__name__}({', '.join(repr(v) for v in self.values)})"
class Sub(Node):
def __init__(self, *values: Value):
self.values = values self.values = values
def __repr__(self): def __repr__(self):
@ -18,14 +28,35 @@ class Sum(Node):
class Product(Node): class Product(Node):
def __init__(self, *values: Expression): def __init__(self, *values: Value):
self.values = values self.values = values
def __repr__(self): def __repr__(self):
return f"{self.__class__.__name__}({', '.join(repr(v) for v in self.values)})" return f"{self.__class__.__name__}({', '.join(repr(v) for v in self.values)})"
class Number(Node): class Division(Node):
def __init__(self, *values: Value):
self.values = values
def __repr__(self):
return f"{self.__class__.__name__}({', '.join(repr(v) for v in self.values)})"
BinaryOperation = Sum | Sub | Product | Division
@beartype
class Float(Node):
def __init__(self, value: float):
self.value = value
def __repr__(self):
return f"{self.__class__.__name__}({self.value})"
@beartype
class Integer(Node):
def __init__(self, value: int): def __init__(self, value: int):
self.value = value self.value = value
@ -33,4 +64,13 @@ class Number(Node):
return f"{self.__class__.__name__}({self.value})" return f"{self.__class__.__name__}({self.value})"
Expression = Sum | Product | Number class Expression(Node):
def __init__(self, node: Node):
self.node = node
def __repr__(self):
return f"{self.__class__.__name__}({self.node})"
Number = Float | Integer
Value = BinaryOperation | Number

View file

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
from beartype.typing import List, Dict, Callable from beartype.typing import List, Dict, Callable
from .logger import Logger from .logger import Logger
@ -21,6 +20,7 @@ class Parser:
def __init__(self, tokens: List[Token]): def __init__(self, tokens: List[Token]):
self.tokens = tokens self.tokens = tokens
self.pos = 0 self.pos = 0
self._last_accepted_token: Tokens | None = None
@property @property
def token(self) -> Token: def token(self) -> Token:
@ -34,11 +34,12 @@ class Parser:
def next_symbol(self): def next_symbol(self):
self.pos += 1 self.pos += 1
logger.debug("%s", f"Advancing to token {self.pos} {self.token}") logger.debug(f"Advancing to token {self.pos} {self.token}")
def accept(self, *token_types: Tokens) -> False | Token: def accept(self, *token_types: Tokens) -> False | Token:
tok = self.token tok = self.token
if self.token.kind in token_types: if self.token.kind in token_types:
self._last_accepted_token = self.token
self.next_symbol() self.next_symbol()
return tok return tok
return False return False
@ -51,56 +52,58 @@ class Parser:
def expect(self, token_type: Tokens) -> Token: def expect(self, token_type: Tokens) -> Token:
r = self.accept(token_type) r = self.accept(token_type)
logger.debug("%s", f"Expecting {token_type}, got {r}") logger.debug(f"Expecting {token_type}, got {r}")
if r is False: if r is False:
raise ParsingError(self.token.loc, f"Unexpected token '{self.token}', wanted {token_type}") raise ParsingError(self.token.loc, f"Unexpected token '{self.token}', wanted {token_type}")
return r return r
def factor(self) -> Expression: def number(self, mandatory: bool = False):
if tok := self.accept(Tokens.Float):
logger.debug(f"Found float {tok}")
return Float(value=float(tok.value))
elif tok := self.accept(Tokens.Integer):
logger.debug(f"Found integer {tok}")
return Integer(value=int(tok.value))
elif mandatory:
raise ParsingError(self.token.loc, f"Unexpected token '{self.token}', wanted integer or float")
def binary_op(self, operand_func: Callable[[], Value], operators: Dict[Tokens, Value]):
operand = operand_func()
while operator := self.accept(*list(operators.keys())):
node_type = operators[operator.kind]
operand2 = operand_func()
operand = node_type(operand, operand2)
logger.debug(f"{node_type.__name__} of the following operands: {operand} and {operand2}")
return operand
def factor(self) -> Value:
if self.accept(Tokens.Parens_Left): if self.accept(Tokens.Parens_Left):
v = self.expression() v = self.expression()
self.expect(Tokens.Parens_Right) self.expect(Tokens.Parens_Right)
return v return v
elif tok := self.accept(Tokens.Number): elif num := self.number():
logger.debug("%s", f"Found number {self.prev_token}") return num
return Number(value=int(tok.value))
else: else:
raise ParsingError(self.token.loc, f"Unexpected token '{self.token}', wanted parenthesized expression or " raise ParsingError(self.token.loc, f"Unexpected token '{self.token}', wanted parenthesized expression or "
f"number") f"number")
def term(self) -> Expression: def term(self) -> Value:
operations = [] return self.binary_op(self.factor, operators={
operand = self.factor() Tokens.Op_Multiply: Product,
operations += [operand] Tokens.Op_Divide: Division,
})
while operator := self.accept(Tokens.Op_Multiply, Tokens.Op_Divide):
operand = self.factor()
operations += [operand]
if len(operations) == 1:
return operations[0]
logger.debug("%s", f"Product of the following terms: {operations}")
return Product(*operations)
def summation(self) -> Sum: def summation(self) -> Sum:
operations = [] return self.binary_op(self.term, operators={
operand = self.term() Tokens.Op_Plus: Sum,
operations += [operand] Tokens.Op_Minus: Sub,
})
while operator := self.accept(Tokens.Op_Plus, Tokens.Op_Minus): def expression(self) -> Value:
operand = self.term()
operations += [operand]
if len(operations) == 1:
return operations[0]
logger.debug("%s", f"Sum of the following terms: {operations}")
return Sum(*operations)
def expression(self) -> Expression:
summation = self.summation() summation = self.summation()
return summation return Expression(summation)
def root(self): def root(self):
blocks = [ blocks = [