nodes: implement pretty printing

This commit is contained in:
Antoine Viallon 2023-05-08 19:25:54 +02:00
parent be9f389159
commit 1f21bcc89f
Signed by: aviallon
GPG key ID: D126B13AB555E16F
3 changed files with 70 additions and 33 deletions

View file

@ -1,13 +1,12 @@
from __future__ import annotations from __future__ import annotations
import sys import sys
from pprint import pprint
from .logger import rootLogger, LogLevel from .logger import rootLogger, LogLevel
from .parser import Parser, ParsingError from .parser import Parser, ParsingError
from .tokenizer import Tokenizer, Tokens from .tokenizer import Tokenizer, Tokens
from .parser import Parser
data = "2 * (32.9 + 1)"
def main(): def main():
data = """ data = """
@ -18,11 +17,11 @@ def main():
tokens = tokenizer.tokenize(data) tokens = tokenizer.tokenize(data)
tokens = [token for token in tokens if token.kind not in [Tokens.Blank, Tokens.Newline]] tokens = [token for token in tokens if token.kind not in [Tokens.Blank, Tokens.Newline]]
print(tokens) pprint(tokens)
parser = Parser(tokens) parser = Parser(tokens)
try: try:
print(parser.parse()) parser.parse().pprint(depth=3)
except ParsingError as e: except ParsingError as e:
e.location.source = data e.location.source = data
print(f"{e}\n{e.location.show_in_source()}", file=sys.stderr) print(f"{e}\n{e.location.show_in_source()}", file=sys.stderr)

View file

@ -1,75 +1,117 @@
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod
from typing import Any
from beartype import beartype from beartype import beartype
from .logger import Logger
logger = Logger(__name__)
class Node: class Node:
pass
@abstractmethod
def _values(self) -> Any | list[Node]:
raise NotImplementedError(f"Please override {__name__}")
def __repr__(self):
vals = self._values()
if type(vals) == list:
vals = ", ".join(repr(val) for val in vals)
return f"{self.__class__.__name__}({vals})"
def pprint(self, depth: int | None = None, indent: str = "\t"):
print("\n".join(self._pprint(depth=depth, indent=indent)))
def _pprint(self, depth: int | None, indent: str, _depth: int = 0) -> list[str]:
if depth is not None and _depth >= depth:
return [f"{indent}{self.__class__.__name__} {{ ... }}"]
vals = self._values()
try:
vals = (val for val in vals)
except TypeError as e:
vals = (vals,)
result = [f"{self.__class__.__name__} {{"]
for val in vals:
if isinstance(val, Node):
result += val._pprint(depth=depth, indent=indent, _depth=_depth + 1)
else:
result += [f"{indent}{repr(val)}"]
result += [f"}} // {self.__class__.__name__}"]
for i, line in enumerate(result):
result[i] = indent + line
return result
class Operator(Node): class Literal(Node):
op: str def __init__(self, value: Any):
self.value = value
def _values(self) -> Any | list[Node]:
return self.value
def _pprint(self, depth: int | None, indent: str, _depth: int = 0) -> list[str]:
return [f"{indent}{repr(self)}"]
class Sum(Node): class Sum(Node):
def __init__(self, *values: Value): def __init__(self, *values: Value):
self.values = values self.values = values
def __repr__(self): def _values(self) -> Any | list[Node]:
return f"{self.__class__.__name__}({', '.join(repr(v) for v in self.values)})" return self.values
class Sub(Node): class Sub(Node):
def __init__(self, *values: Value): def __init__(self, *values: Value):
self.values = values self.values = values
def __repr__(self): def _values(self) -> Any | list[Node]:
return f"{self.__class__.__name__}({', '.join(repr(v) for v in self.values)})" return self.values
class Product(Node): class Product(Node):
def __init__(self, *values: Value): def __init__(self, *values: Value):
self.values = values self.values = values
def __repr__(self): def _values(self) -> Any | list[Node]:
return f"{self.__class__.__name__}({', '.join(repr(v) for v in self.values)})" return self.values
class Division(Node): class Division(Node):
def __init__(self, *values: Value): def __init__(self, *values: Value):
self.values = values self.values = values
def __repr__(self): def _values(self) -> Any | list[Node]:
return f"{self.__class__.__name__}({', '.join(repr(v) for v in self.values)})" return self.values
BinaryOperation = Sum | Sub | Product | Division BinaryOperation = Sum | Sub | Product | Division
@beartype @beartype
class Float(Node): class Float(Literal):
def __init__(self, value: float): def __init__(self, value: float):
self.value = value super().__init__(value)
def __repr__(self):
return f"{self.__class__.__name__}({self.value})"
@beartype @beartype
class Integer(Node): class Integer(Literal):
def __init__(self, value: int): def __init__(self, value: int):
self.value = value super().__init__(value)
def __repr__(self):
return f"{self.__class__.__name__}({self.value})"
class Expression(Node): class Expression(Node):
def __init__(self, node: Node): def __init__(self, node: Node):
self.node = node self.node = node
def __repr__(self): def _values(self) -> Any | list[Node]:
return f"{self.__class__.__name__}({self.node})" return self.node
Number = Float | Integer Number = Float | Integer

View file

@ -105,12 +105,8 @@ class Parser:
summation = self.summation() summation = self.summation()
return Expression(summation) return Expression(summation)
def root(self): def root(self) -> Node:
blocks = [ return self.expression()
self.expression(),
self.expect(Tokens.EOF),
]
return blocks
def parse(self) -> Node: def parse(self) -> Node:
try: try: