diff --git a/compiler/__main__.py b/compiler/__main__.py index fb2e246..28ca82c 100644 --- a/compiler/__main__.py +++ b/compiler/__main__.py @@ -2,7 +2,6 @@ from __future__ import annotations import argparse import sys -import typing from pprint import pprint from . import semantic, ir, optimizations @@ -53,22 +52,17 @@ def main(): context.check() - intermediate_representation = typing.cast(ast.intermediate_representation(), list[ir.IRAction]) - pseudo_asm = [x.codegen() for x in intermediate_representation] - for i, action in enumerate(intermediate_representation): - action.ir_location.line = i - action.ir_location.ir = pseudo_asm - - action.location.source = data + intermediate_representation = ir.IR(ast, source=data) + intermediate_representation.update_location() print("\n---\n", repr(context)) # architecture = optimizations.ArchitectureConstraints(registers=["A", "B", "C"], direct_memory_store=True) - register_alloc = optimizations.RegisterAllocation(intermediate_representation) + register_alloc = optimizations.RegisterAllocation(intermediate_representation.intermediate_representation) register_alloc.analyze() - print_ir(intermediate_representation) + intermediate_representation.pretty_print() print(ir.IRRegister.get_registers()) @@ -81,19 +75,6 @@ def main(): print(f"Caused by:\n{e.__cause__.__class__.__name__}: {e.__cause__}", flush=True) finally: - - -def print_ir(intermediate_representation: list[ir.IRAction]): - messages = [] - for i, ir_item in enumerate(intermediate_representation): - prefix = f"{str(ir_item.location) + ':':<30}" - source_info = ir_item.location.source_substring.splitlines(keepends=False) - messages += [f"# {prefix} {source_info.pop(0)}"] - while len(source_info) > 0: - messages += [f"# {' ' * len(prefix)} {source_info.pop(0)}"] - - messages += [f"{repr(ir_item)}\n"] - print("\n".join(messages)) CompilationWarning.show_warnings(data, file=sys.stdout) diff --git a/compiler/ir.py b/compiler/ir.py index 823dc19..da83c43 100644 --- a/compiler/ir.py +++ b/compiler/ir.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc from abc import abstractmethod +from . import nodes from .errors import OverrideMandatoryError from .logger import Logger from .source import SourceLocation, IRLocation as Location @@ -251,3 +252,43 @@ class IRInvert(IRAction): def codegen(self) -> str: return f"INVERT {self.source} -> {self.dest}" + + +def _to_text(ir: list[IRItem]) -> list[str]: + return [x.codegen() for x in ir] + + +class IR: + + def __init__(self, ast: nodes.Node, source: str): + node_ir = ast.intermediate_representation() + + assert all((isinstance(ir, IRAction) for ir in node_ir)) + + # noinspection PyTypeChecker + self.intermediate_representation: list[IRAction] = node_ir + + self.source = source + + def code(self) -> list[str]: + return [x.codegen() for x in self.intermediate_representation] + + def update_location(self): + code = self.code() + ir_item: IRAction + for i, ir_item in enumerate(self.intermediate_representation): + ir_item.ir_location = Location(line=i, ir=code) + ir_item.location.source = self.source + + def pretty_print(self): + messages = [] + ir_item: IRAction + for i, ir_item in enumerate(self.intermediate_representation): + prefix = f"{str(ir_item.location) + ':':<30}" + source_info = ir_item.location.source_substring.splitlines(keepends=False) + messages += [f"# {prefix} {source_info.pop(0)}"] + while len(source_info) > 0: + messages += [f"# {' ' * len(prefix)} {source_info.pop(0)}"] + + messages += [f"{repr(ir_item)}\n"] + print("\n".join(messages))