main+ir: refactor ir from being entirely guided from main to a dedicated class in ir module

This commit is contained in:
Antoine Viallon 2024-01-09 17:16:11 +01:00
parent 692ee511eb
commit 8067d8178e
Signed by: aviallon
GPG key ID: 186FC35EDEB25716
2 changed files with 45 additions and 23 deletions

View file

@ -2,7 +2,6 @@ from __future__ import annotations
import argparse import argparse
import sys import sys
import typing
from pprint import pprint from pprint import pprint
from . import semantic, ir, optimizations from . import semantic, ir, optimizations
@ -53,22 +52,17 @@ def main():
context.check() context.check()
intermediate_representation = typing.cast(ast.intermediate_representation(), list[ir.IRAction]) intermediate_representation = ir.IR(ast, source=data)
pseudo_asm = [x.codegen() for x in intermediate_representation] intermediate_representation.update_location()
for i, action in enumerate(intermediate_representation):
action.ir_location.line = i
action.ir_location.ir = pseudo_asm
action.location.source = data
print("\n---\n", repr(context)) print("\n---\n", repr(context))
# architecture = optimizations.ArchitectureConstraints(registers=["A", "B", "C"], direct_memory_store=True) # architecture = optimizations.ArchitectureConstraints(registers=["A", "B", "C"], direct_memory_store=True)
register_alloc = optimizations.RegisterAllocation(intermediate_representation) register_alloc = optimizations.RegisterAllocation(intermediate_representation.intermediate_representation)
register_alloc.analyze() register_alloc.analyze()
print_ir(intermediate_representation) intermediate_representation.pretty_print()
print(ir.IRRegister.get_registers()) print(ir.IRRegister.get_registers())
@ -81,19 +75,6 @@ def main():
print(f"Caused by:\n{e.__cause__.__class__.__name__}: {e.__cause__}", flush=True) print(f"Caused by:\n{e.__cause__.__class__.__name__}: {e.__cause__}", flush=True)
finally: finally:
def print_ir(intermediate_representation: list[ir.IRAction]):
messages = []
for i, ir_item in enumerate(intermediate_representation):
prefix = f"{str(ir_item.location) + ':':<30}"
source_info = ir_item.location.source_substring.splitlines(keepends=False)
messages += [f"# {prefix} {source_info.pop(0)}"]
while len(source_info) > 0:
messages += [f"# {' ' * len(prefix)} {source_info.pop(0)}"]
messages += [f"{repr(ir_item)}\n"]
print("\n".join(messages))
CompilationWarning.show_warnings(data, file=sys.stdout) CompilationWarning.show_warnings(data, file=sys.stdout)

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import abc import abc
from abc import abstractmethod from abc import abstractmethod
from . import nodes
from .errors import OverrideMandatoryError from .errors import OverrideMandatoryError
from .logger import Logger from .logger import Logger
from .source import SourceLocation, IRLocation as Location from .source import SourceLocation, IRLocation as Location
@ -251,3 +252,43 @@ class IRInvert(IRAction):
def codegen(self) -> str: def codegen(self) -> str:
return f"INVERT {self.source} -> {self.dest}" return f"INVERT {self.source} -> {self.dest}"
def _to_text(ir: list[IRItem]) -> list[str]:
return [x.codegen() for x in ir]
class IR:
def __init__(self, ast: nodes.Node, source: str):
node_ir = ast.intermediate_representation()
assert all((isinstance(ir, IRAction) for ir in node_ir))
# noinspection PyTypeChecker
self.intermediate_representation: list[IRAction] = node_ir
self.source = source
def code(self) -> list[str]:
return [x.codegen() for x in self.intermediate_representation]
def update_location(self):
code = self.code()
ir_item: IRAction
for i, ir_item in enumerate(self.intermediate_representation):
ir_item.ir_location = Location(line=i, ir=code)
ir_item.location.source = self.source
def pretty_print(self):
messages = []
ir_item: IRAction
for i, ir_item in enumerate(self.intermediate_representation):
prefix = f"{str(ir_item.location) + ':':<30}"
source_info = ir_item.location.source_substring.splitlines(keepends=False)
messages += [f"# {prefix} {source_info.pop(0)}"]
while len(source_info) > 0:
messages += [f"# {' ' * len(prefix)} {source_info.pop(0)}"]
messages += [f"{repr(ir_item)}\n"]
print("\n".join(messages))