interpreter+main: add a working interpreter

This commit is contained in:
Antoine Viallon 2024-04-12 16:44:32 +02:00
parent 4f63fb9dfc
commit ea4a0c2ff1
Signed by: aviallon
GPG key ID: 186FC35EDEB25716
3 changed files with 217 additions and 0 deletions

View file

@ -78,6 +78,9 @@ def main():
ir.IRRegister.pprint()
vm = virtual_machine.VirtualMachine()
vm.execute_stream(intermediate_representation.intermediate_representation)
except CompilationError as e:
e.location.source = tokens.input
e.pretty_print()

View file

View file

@ -0,0 +1,214 @@
import functools
import operator
import pprint
import typing
from collections import defaultdict
from operator import mul
from .. import ir, errors
from ..logger import Logger
logger = Logger(__name__)
class Register:
def __init__(self, byte_size: int = 4):
self.byte_size = byte_size
self.bytes = bytearray([0] * byte_size)
def set(self, value: bytearray):
if len(value) > self.byte_size:
raise RuntimeError(f"Can't fit value of {len(value)} bytes into register of {self.byte_size} in size")
self.bytes = value
def __int__(self) -> int:
return int(self.bytes)
def __bytes__(self) -> bytes:
return bytes(self.bytes)
def __repr__(self):
return hex(int.from_bytes(self.bytes, byteorder="big", signed=False))
def product(iterable: typing.Iterable):
return functools.reduce(mul, iterable)
def divide(iterable: typing.Iterable):
operands = list(iterable)
assert len(operands) == 2
return operands[0] // operands[1]
def invert(iterable: typing.Iterable):
operands = list(iterable)
assert len(operands) == 1
return operator.invert(operands[0])
def negate(iterable: typing.Iterable):
operands = list(iterable)
assert len(operands) == 1
return -operands[0]
operations = {
ir.IRAdd: sum,
ir.IRMul: product,
ir.IRDiv: divide,
ir.IRInvert: invert,
ir.IRNegation: negate,
}
def _builtin_display(inputs):
converted_inputs = (repr(x) for x in inputs)
print(f"PROGRAM OUTPUT: {', '.join(converted_inputs)}")
return 0
builtin_functions: callable = {
"builtins_0.display": _builtin_display
}
class VirtualMachine:
def __init__(self, initial_state: dict[str, Register] = None, entrypoint: int = 0):
if initial_state is None:
initial_state = {}
self.registers: defaultdict[str, Register] = defaultdict(Register)
self.registers.update(initial_state)
self.stack: list[Register] = []
self.stack_pointer: int = 0
self.instruction_pointer: int = 0
self.variables: dict[str, bytearray] = {}
self.word_size = 4
self.carry = False
# self.memory: bytearray = bytearray(b"\x00" * memory_size)
def get_value(self, value: ir.IRValue) -> bytearray:
match type(value):
case ir.IRImmediate:
# noinspection PyTypeChecker
immediate: ir.IRImmediate = value
if type(immediate.value) is int:
return bytearray(int(immediate.value).to_bytes(signed=False, byteorder="big"))
else:
raise NotImplementedError("Only integers immediate values are handled for now.")
case ir.IRVariable:
# noinspection PyTypeChecker
variable: ir.IRVariable = value
return self.variables.get(variable.codegen(), bytearray())
case ir.IRRegister:
# noinspection PyTypeChecker
register: ir.IRRegister = value
return self.registers[register.codegen()].bytes
case _:
raise TypeError(f"{repr(value)} is not an IRValue")
def int_to_bytes(self, value: int) -> bytes:
return value.to_bytes(byteorder="big", length=self.word_size, signed=False)
def get_values(self, values: typing.Iterable[ir.IRValue]) -> typing.Iterator[int]:
for value in values:
yield int.from_bytes(self.get_value(value), byteorder="big", signed=False)
def pprint_values(self, values: typing.Iterable[ir.IRValue]) -> list[str]:
return [hex(value) for value in self.get_values(values)]
def set_value(self, key: ir.IRAssignable, value: bytearray):
match type(key):
case ir.IRVariable:
# noinspection PyTypeChecker
variable: ir.IRVariable = key
self.variables[variable.codegen()] = value
case ir.IRRegister:
# noinspection PyTypeChecker
register: ir.IRRegister = key
self.registers[register.codegen()].set(value)
case _:
raise TypeError(f"{repr(key)} is not an IRAssignable")
def execute_one(self, instruction: ir.IRAction):
match type(instruction):
case ir.IRAdd | ir.IRMul | ir.IRDiv | ir.IRInvert | ir.IRNegation:
_operation = operations.get(instruction.__class__,
lambda _: None)
# noinspection PyTypeChecker
ins: ir.IRAdd | ir.IRMul = instruction
pprinted_input_values = self.pprint_values(ins.reads)
values = list(self.get_values(ins.reads))
result = _operation(values)
self.carry = int(result).bit_length() > self.word_size
if self.carry:
errors.CompilationWarning(instruction.location,
message=f"Result of operation is overflowing",
level=errors.Levels.Note).raise_warning()
# Overflow
result %= (2 ** (self.word_size * 8) - 1)
result_bytes = self.int_to_bytes(result)
self.set_value(ins.dest, bytearray(result_bytes))
logger.debug(
f"\t{_operation.__name__}: {pprinted_input_values} = {hex(result)} : store into {ins.dest.codegen()}")
case ir.IRMove:
# noinspection PyTypeChecker
ins: ir.IRMove = instruction
pprinted_input_value = self.pprint_values((ins.source,))
_source = self.get_value(ins.source)
self.set_value(ins.dest, _source)
logger.debug(f"\tmove: {pprinted_input_value} : store into {ins.dest.codegen()}")
case ir.IRCall:
ins: ir.IRCall = instruction
logger.debug(
f"\tFully-qualified function: {ins.function.fq_identifier}, location: {ins.function.location}")
pprinted_input_value = self.pprint_values(ins.arguments)
if ins.function.fq_identifier in builtin_functions:
builtin_func = builtin_functions[ins.function.fq_identifier]
inputs = list(self.get_values(ins.arguments))
output = self.int_to_bytes(builtin_func(inputs))
self.set_value(ins.dest, output)
logger.debug(
f"\tcall (builtin): {ins.function.fq_identifier} ( {pprinted_input_value} ) : store into {ins.dest.codegen()}")
else:
raise NotImplementedError(f"Calling non-bultin functions is not yet supported")
case _:
if __debug__:
raise NotImplementedError(f"{instruction.__class__.__name__} not yet supported")
logger.debug(f"{instruction.__class__.__name__} not yet supported")
def execute_stream(self, instruction_stream: typing.Iterable[ir.IRAction]):
for instruction in instruction_stream:
logger.debug(instruction.codegen())
try:
self.execute_one(instruction)
logger.debug(f"\tRegisters: {pprint.pformat(dict(self.registers))}")
except errors.CompilationError as e:
e.pretty_print()
except Exception as e:
message = f"""
Error while executing instruction: {instruction.codegen()}
Inputs values: {list(self.get_values(instruction.reads))}
Outputs: {list(instruction.writes)}
"""
_e = errors.CompilationError(instruction.location,
message)
_e.with_traceback(e.__traceback__)
raise _e from e