diff --git a/compiler/__main__.py b/compiler/__main__.py index 936dbc8..d464172 100644 --- a/compiler/__main__.py +++ b/compiler/__main__.py @@ -1,13 +1,12 @@ from __future__ import annotations import sys +from pprint import pprint from .logger import rootLogger, LogLevel from .parser import Parser, ParsingError from .tokenizer import Tokenizer, Tokens -from .parser import Parser -data = "2 * (32.9 + 1)" def main(): data = """ @@ -18,11 +17,11 @@ def main(): tokens = tokenizer.tokenize(data) tokens = [token for token in tokens if token.kind not in [Tokens.Blank, Tokens.Newline]] - print(tokens) + pprint(tokens) parser = Parser(tokens) try: - print(parser.parse()) + parser.parse().pprint(depth=3) except ParsingError as e: e.location.source = data print(f"{e}\n{e.location.show_in_source()}", file=sys.stderr) diff --git a/compiler/nodes.py b/compiler/nodes.py index e1d2625..9b8d6c5 100644 --- a/compiler/nodes.py +++ b/compiler/nodes.py @@ -1,75 +1,117 @@ from __future__ import annotations +from abc import abstractmethod +from typing import Any + from beartype import beartype +from .logger import Logger + +logger = Logger(__name__) + class Node: - pass + + @abstractmethod + def _values(self) -> Any | list[Node]: + raise NotImplementedError(f"Please override {__name__}") + + def __repr__(self): + vals = self._values() + if type(vals) == list: + vals = ", ".join(repr(val) for val in vals) + return f"{self.__class__.__name__}({vals})" + + def pprint(self, depth: int | None = None, indent: str = "\t"): + print("\n".join(self._pprint(depth=depth, indent=indent))) + + def _pprint(self, depth: int | None, indent: str, _depth: int = 0) -> list[str]: + if depth is not None and _depth >= depth: + return [f"{indent}{self.__class__.__name__} {{ ... }}"] + + vals = self._values() + try: + vals = (val for val in vals) + except TypeError as e: + vals = (vals,) + + result = [f"{self.__class__.__name__} {{"] + for val in vals: + if isinstance(val, Node): + result += val._pprint(depth=depth, indent=indent, _depth=_depth + 1) + else: + result += [f"{indent}{repr(val)}"] + result += [f"}} // {self.__class__.__name__}"] + for i, line in enumerate(result): + result[i] = indent + line + + return result -class Operator(Node): - op: str +class Literal(Node): + def __init__(self, value: Any): + self.value = value + + def _values(self) -> Any | list[Node]: + return self.value + + def _pprint(self, depth: int | None, indent: str, _depth: int = 0) -> list[str]: + return [f"{indent}{repr(self)}"] class Sum(Node): def __init__(self, *values: Value): self.values = values - def __repr__(self): - return f"{self.__class__.__name__}({', '.join(repr(v) for v in self.values)})" + def _values(self) -> Any | list[Node]: + return self.values class Sub(Node): def __init__(self, *values: Value): self.values = values - def __repr__(self): - return f"{self.__class__.__name__}({', '.join(repr(v) for v in self.values)})" + def _values(self) -> Any | list[Node]: + return self.values class Product(Node): def __init__(self, *values: Value): self.values = values - def __repr__(self): - return f"{self.__class__.__name__}({', '.join(repr(v) for v in self.values)})" + def _values(self) -> Any | list[Node]: + return self.values class Division(Node): def __init__(self, *values: Value): self.values = values - def __repr__(self): - return f"{self.__class__.__name__}({', '.join(repr(v) for v in self.values)})" + def _values(self) -> Any | list[Node]: + return self.values BinaryOperation = Sum | Sub | Product | Division @beartype -class Float(Node): +class Float(Literal): def __init__(self, value: float): - self.value = value - - def __repr__(self): - return f"{self.__class__.__name__}({self.value})" + super().__init__(value) @beartype -class Integer(Node): +class Integer(Literal): def __init__(self, value: int): - self.value = value - - def __repr__(self): - return f"{self.__class__.__name__}({self.value})" + super().__init__(value) class Expression(Node): def __init__(self, node: Node): self.node = node - def __repr__(self): - return f"{self.__class__.__name__}({self.node})" + def _values(self) -> Any | list[Node]: + return self.node Number = Float | Integer diff --git a/compiler/parser.py b/compiler/parser.py index bb091bb..13e2ed5 100644 --- a/compiler/parser.py +++ b/compiler/parser.py @@ -105,12 +105,8 @@ class Parser: summation = self.summation() return Expression(summation) - def root(self): - blocks = [ - self.expression(), - self.expect(Tokens.EOF), - ] - return blocks + def root(self) -> Node: + return self.expression() def parse(self) -> Node: try: