source: wrap input with a caching reader, allowing the source to be referenced even from REPLInput.

Will be especially useful when we support files as input.
This commit is contained in:
Antoine Viallon 2024-04-12 16:33:06 +02:00
parent a98f38d43f
commit 67c2ed3a26
Signed by: aviallon
GPG key ID: 186FC35EDEB25716
7 changed files with 129 additions and 25 deletions

View file

@ -4,12 +4,15 @@ import argparse
import io import io
import sys import sys
import typing import typing
from pathlib import Path
from . import semantic, ir, optimizations from . import semantic, ir, optimizations
from .errors import CompilationError, CompilationWarning from .errors import CompilationError, CompilationWarning
from .interpreter import virtual_machine from .interpreter import virtual_machine
from .lexer import Lexer, Tokens from .lexer import Lexer, Tokens
from .parser import Parser from .parser import Parser
from .source import TextIOWithMemory
from .tests import mock
class REPLInput(io.TextIOWrapper): class REPLInput(io.TextIOWrapper):
@ -45,11 +48,14 @@ def main():
display(byte, 3) + 1; display(byte, 3) + 1;
""" """
input_stream = REPLInput() _input_stream = REPLInput()
if args.mock: if args.mock:
print("Source:\n", data) print("Source:\n", data)
input_stream = io.StringIO(data) input_stream = io.StringIO(data)
_input_stream = io.StringIO(data)
input_stream = TextIOWithMemory(_input_stream)
tokens = \ tokens = \
Lexer(input_stream, Lexer(input_stream,
@ -66,7 +72,7 @@ def main():
context.check() context.check()
intermediate_representation = ir.IR(ast) intermediate_representation = ir.IR(ast)
intermediate_representation.update_location(source=tokens.data) intermediate_representation.update_location()
print("\n---\n", repr(context)) print("\n---\n", repr(context))
@ -80,11 +86,11 @@ def main():
ir.IRRegister.pprint() ir.IRRegister.pprint()
except CompilationError as e: except CompilationError as e:
e.location.source = tokens.data e.location.source = tokens.input
e.pretty_print() e.pretty_print()
finally: finally:
CompilationWarning.show_warnings(tokens.data, file=sys.stdout) CompilationWarning.show_warnings(tokens.input, file=sys.stdout)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -9,6 +9,7 @@ from termcolor import termcolor
from . import source, lexer from . import source, lexer
from .logger import rootLogger, LogLevel from .logger import rootLogger, LogLevel
from .source import TextIOWithMemory
class LevelType: class LevelType:
@ -58,7 +59,7 @@ class CompilationWarning(Warning):
self.location = location self.location = location
@classmethod @classmethod
def show_warnings(cls, _source: str, file: TextIO = sys.stderr): def show_warnings(cls, _source: TextIOWithMemory, file: TextIO = sys.stderr):
for warning in CompilationWarning._pending_warnings: for warning in CompilationWarning._pending_warnings:
warning.location.source = _source warning.location.source = _source
print(f"{warning}\n{warning.location.show_in_source()}", file=file) print(f"{warning}\n{warning.location.show_in_source()}", file=file)

View file

@ -15,6 +15,7 @@ logger = Logger(__name__)
class IRItem: class IRItem:
def __init__(self, location: SourceLocation): def __init__(self, location: SourceLocation):
assert location.source is not None
self.location = location self.location = location
@abstractmethod @abstractmethod
@ -277,12 +278,11 @@ class IR:
def code(self) -> list[str]: def code(self) -> list[str]:
return [x.codegen() for x in self.intermediate_representation] return [x.codegen() for x in self.intermediate_representation]
def update_location(self, source: str): def update_location(self):
code = self.code() code = self.code()
ir_item: IRAction ir_item: IRAction
for i, ir_item in enumerate(self.intermediate_representation): for i, ir_item in enumerate(self.intermediate_representation):
ir_item.ir_location = Location(line=i, ir=code) ir_item.ir_location = Location(line=i, ir=code)
ir_item.location.source = source
def pretty_print(self): def pretty_print(self):
messages = [] messages = []
@ -290,7 +290,7 @@ class IR:
for i, ir_item in enumerate(self.intermediate_representation): for i, ir_item in enumerate(self.intermediate_representation):
prefix = f"{str(ir_item.location) + ':':<30}" prefix = f"{str(ir_item.location) + ':':<30}"
source_info = ir_item.location.source_substring.splitlines(keepends=False) source_info = ir_item.location.source_substring.splitlines(keepends=False)
logger.debug(f"source: {repr(ir_item.location.source)}") logger.trace(f"source: {repr(ir_item.location.source_name)}")
if len(source_info) == 0: if len(source_info) == 0:
source_info = ["<NO SOURCE>"] source_info = ["<NO SOURCE>"]
messages += [f"# {prefix} {source_info.pop(0)}"] messages += [f"# {prefix} {source_info.pop(0)}"]

View file

@ -8,7 +8,7 @@ from dataclasses import dataclass, field
from typing import cast from typing import cast
from .logger import Logger from .logger import Logger
from .source import SourceLocation, Location from .source import SourceLocation, Location, TextIOWithMemory
from .typechecking import typecheck from .typechecking import typecheck
from .utils import implies from .utils import implies
@ -62,9 +62,8 @@ class Tokens(enum.Enum):
class Lexer(collections.abc.Iterator): class Lexer(collections.abc.Iterator):
def __init__(self, input_stream: typing.TextIO, token_filter: typing.Callable[[Token], bool] | None = None): def __init__(self, input_stream: TextIOWithMemory, token_filter: typing.Callable[[Token], bool] | None = None):
self.input = input_stream self.input = input_stream
self.data: str = ""
self.tokens = [] self.tokens = []
self.length: int | None = None self.length: int | None = None
self.begin: int = 0 self.begin: int = 0
@ -91,27 +90,28 @@ class Lexer(collections.abc.Iterator):
tok = Token(Tokens.BEGIN, tok = Token(Tokens.BEGIN,
loc=SourceLocation( loc=SourceLocation(
Location(line=0, character=0), Location(line=0, character=0),
source=self.data source=self.input
), ),
value=None) value=None)
self.tokens.append(tok) self.tokens.append(tok)
return tok return tok
if self.tokens[-1].kind in [Tokens.BEGIN, Tokens.Newline]: if self.tokens[-1].kind in [Tokens.BEGIN, Tokens.Newline]:
self.data += self.input.readline() self.input.readline()
if self.begin == len(self.data): if self.begin == len(self.input.stream_cache):
eof_token = Token(Tokens.EOF, value=None, loc=SourceLocation( eof_token = Token(Tokens.EOF, value=None, loc=SourceLocation(
Location(line=self.line, character=0), Location(line=self.line, character=0),
source=self.input
)) ))
self.tokens += [eof_token] self.tokens += [eof_token]
self.length = len(self.tokens) self.length = len(self.tokens)
return eof_token return eof_token
elif self.begin < len(self.data): elif self.begin < len(self.input.stream_cache):
best_result: Token = Token(Tokens.Unknown, best_result: Token = Token(Tokens.Unknown,
loc=SourceLocation( loc=SourceLocation(
Location(line=self.line, character=self.character), Location(line=self.line, character=self.character),
source=self.data), source=self.input),
value="" value=""
) )
token_kind: Tokens token_kind: Tokens
@ -119,7 +119,7 @@ class Lexer(collections.abc.Iterator):
if token_kind == Tokens.Unknown: if token_kind == Tokens.Unknown:
continue continue
regex = cast(re.Pattern, token_kind.value) regex = cast(re.Pattern, token_kind.value)
match = regex.match(self.data, self.begin) match = regex.match(self.input.stream_cache, self.begin)
if match is not None: if match is not None:
logger.trace(f"Got match: {match}") logger.trace(f"Got match: {match}")
result = match.group(0) result = match.group(0)
@ -130,7 +130,8 @@ class Lexer(collections.abc.Iterator):
continue continue
loc = SourceLocation( loc = SourceLocation(
begin=Location(line=self.line, character=self.character), begin=Location(line=self.line, character=self.character),
end=Location(line=self.line, character=self.character + len(result)) end=Location(line=self.line, character=self.character + len(result)),
source=self.input
) )
best_result = Token(token_kind, value=result, loc=loc) best_result = Token(token_kind, value=result, loc=loc)
logger.trace(f"New best match: {best_result}") logger.trace(f"New best match: {best_result}")
@ -144,7 +145,7 @@ class Lexer(collections.abc.Iterator):
elif best_result.kind == Tokens.Newline: elif best_result.kind == Tokens.Newline:
self.line += 1 self.line += 1
self.character = 0 self.character = 0
best_result.loc.end = Location(line=self.line, character=0) best_result.loc.end = Location(line=self.line, character=0, file=self.input.name)
logger.debug(f"Added token {best_result}") logger.debug(f"Added token {best_result}")

View file

@ -26,10 +26,10 @@ class Node:
@functools.cache @functools.cache
def location(self) -> SourceLocation: def location(self) -> SourceLocation:
assert len(self._values()) > 0 assert len(self._values()) > 0
locations = [v.location() for v in self._values()] locations: list[SourceLocation] = [v.location() for v in self._values()]
begin = min([loc.begin for loc in locations]) begin = min([loc.begin for loc in locations])
end = max([loc.end for loc in locations]) end = max([loc.end for loc in locations])
loc = SourceLocation(begin=begin, end=end) loc = SourceLocation(begin=begin, end=end, source=locations[0].source)
return loc return loc
def __repr__(self): def __repr__(self):
@ -98,6 +98,7 @@ class Literal(Node, ABC):
super().__init__() super().__init__()
self.value = value self.value = value
self.loc = location self.loc = location
assert location.source is not None
def location(self) -> SourceLocation: def location(self) -> SourceLocation:
return self.loc return self.loc

View file

@ -15,7 +15,7 @@ builtin_node = nodes.PseudoNode(
lexer.Token( lexer.Token(
kind=lexer.Tokens.Unknown, kind=lexer.Tokens.Unknown,
value="__compiler_internal__", value="__compiler_internal__",
loc=source.SourceLocation(begin=source.Location(internal=True), source="__compiler_internal__") loc=source.COMPILER_INTERNAL
) )
) )

View file

@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
import io
import math import math
import typing
from dataclasses import dataclass from dataclasses import dataclass
from .logger import Logger from .logger import Logger
@ -9,6 +11,84 @@ from .typechecking import typecheck
logger = Logger(__name__) logger = Logger(__name__)
@typecheck
class TextIOWithMemory(typing.TextIO):
    """Text stream wrapper that remembers everything read through it.

    Every piece of text obtained via ``read``, ``readline``, ``readlines`` or
    iteration is appended to ``stream_cache``, so already-consumed input can
    still be looked up later even when the wrapped stream cannot be rewound.
    All other operations are forwarded to the wrapped stream.
    """

    def __init__(self, io_stream: io.TextIOBase):
        assert isinstance(io_stream, io.TextIOBase)
        self._stream = io_stream
        # Concatenation of all text handed out by the read methods so far.
        self.stream_cache: str = ""

    def _remember(self, text):
        """Append *text* to the cache and hand it back unchanged."""
        self.stream_cache += text
        return text

    @property
    def name(self):
        """Name of the wrapped stream, or ``"<none>"`` when it has none."""
        return getattr(self._stream, "name", "<none>")

    def __enter__(self):
        # Enter the wrapped stream, but return the wrapper so that reads made
        # inside the ``with`` block still go through the cache.
        self._stream.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return self._stream.__exit__(exc_type, exc_value, traceback)

    def close(self):
        return self._stream.close()

    def fileno(self):
        return self._stream.fileno()

    def flush(self):
        return self._stream.flush()

    def isatty(self):
        return self._stream.isatty()

    def read(self, size=-1):
        """Read up to *size* characters (everything if negative), caching them."""
        # The size is passed positionally: io streams do not accept it as a
        # keyword argument.
        return self._remember(self._stream.read(size))

    def readable(self):
        return self._stream.readable()

    def readline(self, *args, **kwargs):
        """Read one line from the wrapped stream, cache it and return it."""
        return self._remember(self._stream.readline(*args, **kwargs))

    def readlines(self, hint=-1):
        """Read the remaining lines (bounded by *hint*), caching all of them."""
        lines = self._stream.readlines(hint)
        self.stream_cache += "".join(lines)
        return lines

    def seek(self, offset, whence=io.SEEK_SET):
        """Always raise: seeking would desynchronise the cache from the stream.

        Raises:
            io.UnsupportedOperation: unconditionally, matching ``seekable()``.
        """
        raise io.UnsupportedOperation("TextIOWithMemory is not seekable")

    def seekable(self):
        # return self.stream.seekable()
        return False

    def tell(self):
        return self._stream.tell()

    def truncate(self, size=None):
        return self._stream.truncate(size)

    def writable(self):
        return self._stream.writable()

    def write(self, s):
        return self._stream.write(s)

    def writelines(self, lines):
        return self._stream.writelines(lines)

    def __next__(self):
        # StopIteration from the wrapped stream propagates unchanged.
        return self._remember(next(self._stream))

    def __iter__(self):
        # Iterate through the wrapper itself so that every line gets cached.
        return self
@typecheck @typecheck
@dataclass @dataclass
class Location: class Location:
@ -68,13 +148,19 @@ class Location:
@typecheck @typecheck
class SourceLocation: class SourceLocation:
def __init__(self, begin: Location, end: Location | None = None, source: str | None = None): def __init__(self, begin: Location, end: Location | None = None, source: TextIOWithMemory | None = None):
self.begin = begin self.begin = begin
self.end = end self.end = end
if self.end is None: if self.end is None:
self.end = self.begin self.end = self.begin
self.source = source self.source = source
self.source_name = "<none>"
if source is not None:
self.source_name = source.name
self.begin.file = self.source_name
self.end.file = self.source_name
assert (self.begin.line, self.begin.character) <= (self.end.line, self.end.character) assert (self.begin.line, self.begin.character) <= (self.end.line, self.end.character)
assert self.begin.file == self.end.file assert self.begin.file == self.end.file
@ -103,7 +189,10 @@ class SourceLocation:
@property @property
def source_substring(self) -> str: def source_substring(self) -> str:
source = self.source.splitlines(keepends=False) if self.source is None:
return "__compiler_internal__"
source = self.source.stream_cache.splitlines(keepends=False)
source_lines = source[self.begin.line:self.end.line + 1] source_lines = source[self.begin.line:self.end.line + 1]
if len(source_lines) == 1: if len(source_lines) == 1:
source_lines[0] = source_lines[0][self.begin.character:self.end.character + 1] source_lines[0] = source_lines[0][self.begin.character:self.end.character + 1]
@ -125,7 +214,10 @@ class SourceLocation:
return "\n".join(result) return "\n".join(result)
def show_in_source(self) -> str: def show_in_source(self) -> str:
source = self.source.splitlines(keepends=False) if self.source is None:
return "__compiler_internal__"
source = self.source.stream_cache.splitlines(keepends=False)
line_number_maxlen = 1 if not len(source) else int(math.log10(len(source)) + 1) line_number_maxlen = 1 if not len(source) else int(math.log10(len(source)) + 1)
lines = source[self.begin.line:self.end.line + 1] lines = source[self.begin.line:self.end.line + 1]
result = [] result = []
@ -146,6 +238,9 @@ class SourceLocation:
return "\n".join(result) return "\n".join(result)
COMPILER_INTERNAL = SourceLocation(begin=Location(internal=True), end=Location(internal=True),
source=TextIOWithMemory(io.StringIO("__compiler_internal__")))
@typecheck @typecheck
class IRLocation: class IRLocation:
def __init__(self, line: int = -1, ir: list[str] = None): def __init__(self, line: int = -1, ir: list[str] = None):