diff --git a/compiler/__main__.py b/compiler/__main__.py index 4da8d37..1ba44ec 100644 --- a/compiler/__main__.py +++ b/compiler/__main__.py @@ -4,12 +4,15 @@ import argparse import io import sys import typing +from pathlib import Path from . import semantic, ir, optimizations from .errors import CompilationError, CompilationWarning from .interpreter import virtual_machine from .lexer import Lexer, Tokens from .parser import Parser +from .source import TextIOWithMemory +from .tests import mock class REPLInput(io.TextIOWrapper): @@ -45,11 +48,14 @@ def main(): display(byte, 3) + 1; """ - input_stream = REPLInput() + _input_stream = REPLInput() if args.mock: print("Source:\n", data) input_stream = io.StringIO(data) + _input_stream = io.StringIO(data) + + input_stream = TextIOWithMemory(_input_stream) tokens = \ Lexer(input_stream, @@ -66,7 +72,7 @@ def main(): context.check() intermediate_representation = ir.IR(ast) - intermediate_representation.update_location(source=tokens.data) + intermediate_representation.update_location() print("\n---\n", repr(context)) @@ -80,11 +86,11 @@ def main(): ir.IRRegister.pprint() except CompilationError as e: - e.location.source = tokens.data + e.location.source = tokens.input e.pretty_print() finally: - CompilationWarning.show_warnings(tokens.data, file=sys.stdout) + CompilationWarning.show_warnings(tokens.input, file=sys.stdout) if __name__ == "__main__": diff --git a/compiler/errors.py b/compiler/errors.py index c571a89..099340d 100644 --- a/compiler/errors.py +++ b/compiler/errors.py @@ -9,6 +9,7 @@ from termcolor import termcolor from . 
import source, lexer from .logger import rootLogger, LogLevel +from .source import TextIOWithMemory class LevelType: @@ -58,7 +59,7 @@ class CompilationWarning(Warning): self.location = location @classmethod - def show_warnings(cls, _source: str, file: TextIO = sys.stderr): + def show_warnings(cls, _source: TextIOWithMemory, file: TextIO = sys.stderr): for warning in CompilationWarning._pending_warnings: warning.location.source = _source print(f"{warning}\n{warning.location.show_in_source()}", file=file) diff --git a/compiler/ir.py b/compiler/ir.py index 52005b2..55e9e44 100644 --- a/compiler/ir.py +++ b/compiler/ir.py @@ -15,6 +15,7 @@ logger = Logger(__name__) class IRItem: def __init__(self, location: SourceLocation): + assert location.source is not None self.location = location @abstractmethod @@ -277,12 +278,11 @@ class IR: def code(self) -> list[str]: return [x.codegen() for x in self.intermediate_representation] - def update_location(self, source: str): + def update_location(self): code = self.code() ir_item: IRAction for i, ir_item in enumerate(self.intermediate_representation): ir_item.ir_location = Location(line=i, ir=code) - ir_item.location.source = source def pretty_print(self): messages = [] @@ -290,7 +290,7 @@ class IR: for i, ir_item in enumerate(self.intermediate_representation): prefix = f"{str(ir_item.location) + ':':<30}" source_info = ir_item.location.source_substring.splitlines(keepends=False) - logger.debug(f"source: {repr(ir_item.location.source)}") + logger.trace(f"source: {repr(ir_item.location.source_name)}") if len(source_info) == 0: source_info = [""] messages += [f"# {prefix} {source_info.pop(0)}"] diff --git a/compiler/lexer.py b/compiler/lexer.py index 3cd05b9..8bc37ef 100644 --- a/compiler/lexer.py +++ b/compiler/lexer.py @@ -8,7 +8,7 @@ from dataclasses import dataclass, field from typing import cast from .logger import Logger -from .source import SourceLocation, Location +from .source import SourceLocation, Location, 
TextIOWithMemory from .typechecking import typecheck from .utils import implies @@ -62,9 +62,8 @@ class Tokens(enum.Enum): class Lexer(collections.abc.Iterator): - def __init__(self, input_stream: typing.TextIO, token_filter: typing.Callable[[Token], bool] | None = None): + def __init__(self, input_stream: TextIOWithMemory, token_filter: typing.Callable[[Token], bool] | None = None): self.input = input_stream - self.data: str = "" self.tokens = [] self.length: int | None = None self.begin: int = 0 @@ -91,27 +90,28 @@ class Lexer(collections.abc.Iterator): tok = Token(Tokens.BEGIN, loc=SourceLocation( Location(line=0, character=0), - source=self.data + source=self.input ), value=None) self.tokens.append(tok) return tok if self.tokens[-1].kind in [Tokens.BEGIN, Tokens.Newline]: - self.data += self.input.readline() + self.input.readline() - if self.begin == len(self.data): + if self.begin == len(self.input.stream_cache): eof_token = Token(Tokens.EOF, value=None, loc=SourceLocation( Location(line=self.line, character=0), + source=self.input )) self.tokens += [eof_token] self.length = len(self.tokens) return eof_token - elif self.begin < len(self.data): + elif self.begin < len(self.input.stream_cache): best_result: Token = Token(Tokens.Unknown, loc=SourceLocation( Location(line=self.line, character=self.character), - source=self.data), + source=self.input), value="" ) token_kind: Tokens @@ -119,7 +119,7 @@ class Lexer(collections.abc.Iterator): if token_kind == Tokens.Unknown: continue regex = cast(re.Pattern, token_kind.value) - match = regex.match(self.data, self.begin) + match = regex.match(self.input.stream_cache, self.begin) if match is not None: logger.trace(f"Got match: {match}") result = match.group(0) @@ -130,7 +130,8 @@ class Lexer(collections.abc.Iterator): continue loc = SourceLocation( begin=Location(line=self.line, character=self.character), - end=Location(line=self.line, character=self.character + len(result)) + end=Location(line=self.line, 
character=self.character + len(result)), + source=self.input ) best_result = Token(token_kind, value=result, loc=loc) logger.trace(f"New best match: {best_result}") @@ -144,7 +145,7 @@ class Lexer(collections.abc.Iterator): elif best_result.kind == Tokens.Newline: self.line += 1 self.character = 0 - best_result.loc.end = Location(line=self.line, character=0) + best_result.loc.end = Location(line=self.line, character=0, file=self.input.name) logger.debug(f"Added token {best_result}") diff --git a/compiler/nodes.py b/compiler/nodes.py index dd607a3..84f426f 100644 --- a/compiler/nodes.py +++ b/compiler/nodes.py @@ -26,10 +26,10 @@ class Node: @functools.cache def location(self) -> SourceLocation: assert len(self._values()) > 0 - locations = [v.location() for v in self._values()] + locations: list[SourceLocation] = [v.location() for v in self._values()] begin = min([loc.begin for loc in locations]) end = max([loc.end for loc in locations]) - loc = SourceLocation(begin=begin, end=end) + loc = SourceLocation(begin=begin, end=end, source=locations[0].source) return loc def __repr__(self): @@ -98,6 +98,7 @@ class Literal(Node, ABC): super().__init__() self.value = value self.loc = location + assert location.source is not None def location(self) -> SourceLocation: return self.loc diff --git a/compiler/semantic.py b/compiler/semantic.py index 4135cc4..8bea3e3 100644 --- a/compiler/semantic.py +++ b/compiler/semantic.py @@ -15,7 +15,7 @@ builtin_node = nodes.PseudoNode( lexer.Token( kind=lexer.Tokens.Unknown, value="__compiler_internal__", - loc=source.SourceLocation(begin=source.Location(internal=True), source="__compiler_internal__") + loc=source.COMPILER_INTERNAL ) ) diff --git a/compiler/source.py b/compiler/source.py index 8a90bc2..ab27e73 100644 --- a/compiler/source.py +++ b/compiler/source.py @@ -1,6 +1,8 @@ from __future__ import annotations +import io import math +import typing from dataclasses import dataclass from .logger import Logger @@ -9,6 +11,84 @@ from 
@typecheck
class TextIOWithMemory(typing.TextIO):
    """Wrap a text stream and remember everything that has been read from it.

    Every character consumed through ``read()``, ``readline()``,
    ``readlines()`` or iteration is appended to ``stream_cache`` so that
    error reporting can re-display source text already consumed from a
    non-seekable stream (e.g. a REPL or a pipe).
    """

    def __init__(self, io_stream: io.TextIOBase):
        assert isinstance(io_stream, io.TextIOBase)

        self._stream = io_stream
        # Concatenation of all text read from the underlying stream so far.
        self.stream_cache: str = ""

    @property
    def name(self) -> str:
        # io.StringIO and interactive wrappers may lack .name; fall back to "".
        try:
            return self._stream.name
        except AttributeError:
            return ""

    def __enter__(self):
        return self._stream.__enter__()

    def close(self):
        return self._stream.close()

    def fileno(self):
        return self._stream.fileno()

    def flush(self):
        return self._stream.flush()

    def isatty(self):
        return self._stream.isatty()

    def read(self, __n=-1):
        # BUG FIX: the underlying io read() takes its size positionally —
        # passing __size= raised TypeError — and the Ellipsis default was
        # forwarded as the size; also the text was cached but never returned.
        r = self._stream.read(-1 if __n is ... else __n)
        self.stream_cache += r
        return r

    def readable(self):
        return self._stream.readable()

    def readline(self, *args, **kwargs):
        # BUG FIX: the line was cached but never returned to the caller.
        r = self._stream.readline(*args, **kwargs)
        self.stream_cache += r
        return r

    def readlines(self, __hint=-1):
        # BUG FIX: hint is positional-only on the underlying stream, and
        # text consumed this way must be cached like any other read.
        lines = self._stream.readlines(-1 if __hint is ... else __hint)
        self.stream_cache += "".join(lines)
        return lines

    def seek(self, __offset, __whence=0):
        # BUG FIX: whence is positional-only; the keyword form raised TypeError.
        return self._stream.seek(__offset, 0 if __whence is ... else __whence)

    def seekable(self):
        # The cache is only meaningful for forward-only consumption.
        return False

    def tell(self):
        return self._stream.tell()

    def truncate(self, __size=None):
        # BUG FIX: size is positional-only on the underlying stream.
        return self._stream.truncate(None if __size is ... else __size)

    def writable(self):
        return self._stream.writable()

    def write(self, __s):
        return self._stream.write(__s)

    def writelines(self, __lines):
        return self._stream.writelines(__lines)

    def __next__(self):
        # BUG FIX: iteration also consumes source text, so cache it too.
        line = self._stream.__next__()
        self.stream_cache += line
        return line

    def __iter__(self):
        # Iterate through self so that __next__'s caching applies.
        return self

    def __exit__(self, __t, __value, __traceback):
        return self._stream.__exit__(__t, __value, __traceback)
TextIOWithMemory | None = None): self.begin = begin self.end = end if self.end is None: self.end = self.begin self.source = source + self.source_name = "" + if source is not None: + self.source_name = source.name + + self.begin.file = self.source_name + self.end.file = self.source_name assert (self.begin.line, self.begin.character) <= (self.end.line, self.end.character) assert self.begin.file == self.end.file @@ -103,7 +189,10 @@ class SourceLocation: @property def source_substring(self) -> str: - source = self.source.splitlines(keepends=False) + if self.source is None: + return "__compiler_internal__" + + source = self.source.stream_cache.splitlines(keepends=False) source_lines = source[self.begin.line:self.end.line + 1] if len(source_lines) == 1: source_lines[0] = source_lines[0][self.begin.character:self.end.character + 1] @@ -125,7 +214,10 @@ class SourceLocation: return "\n".join(result) def show_in_source(self) -> str: - source = self.source.splitlines(keepends=False) + if self.source is None: + return "__compiler_internal__" + + source = self.source.stream_cache.splitlines(keepends=False) line_number_maxlen = 1 if not len(source) else int(math.log10(len(source)) + 1) lines = source[self.begin.line:self.end.line + 1] result = [] @@ -146,6 +238,9 @@ class SourceLocation: return "\n".join(result) +COMPILER_INTERNAL = SourceLocation(begin=Location(internal=True), end=Location(internal=True), + source=TextIOWithMemory(io.StringIO("__compiler_internal__"))) + @typecheck class IRLocation: def __init__(self, line: int = -1, ir: list[str] = None):