source: wrap the input in a caching reader, so that the source text can be referenced even when it comes from REPLInput.

This will be especially useful once we support files as input.
This commit is contained in:
Antoine Viallon 2024-04-12 16:33:06 +02:00
parent a98f38d43f
commit 67c2ed3a26
Signed by: aviallon
GPG key ID: 186FC35EDEB25716
7 changed files with 129 additions and 25 deletions

View file

@ -4,12 +4,15 @@ import argparse
import io
import sys
import typing
from pathlib import Path
from . import semantic, ir, optimizations
from .errors import CompilationError, CompilationWarning
from .interpreter import virtual_machine
from .lexer import Lexer, Tokens
from .parser import Parser
from .source import TextIOWithMemory
from .tests import mock
class REPLInput(io.TextIOWrapper):
@ -45,11 +48,14 @@ def main():
display(byte, 3) + 1;
"""
input_stream = REPLInput()
_input_stream = REPLInput()
if args.mock:
print("Source:\n", data)
input_stream = io.StringIO(data)
_input_stream = io.StringIO(data)
input_stream = TextIOWithMemory(_input_stream)
tokens = \
Lexer(input_stream,
@ -66,7 +72,7 @@ def main():
context.check()
intermediate_representation = ir.IR(ast)
intermediate_representation.update_location(source=tokens.data)
intermediate_representation.update_location()
print("\n---\n", repr(context))
@ -80,11 +86,11 @@ def main():
ir.IRRegister.pprint()
except CompilationError as e:
e.location.source = tokens.data
e.location.source = tokens.input
e.pretty_print()
finally:
CompilationWarning.show_warnings(tokens.data, file=sys.stdout)
CompilationWarning.show_warnings(tokens.input, file=sys.stdout)
if __name__ == "__main__":

View file

@ -9,6 +9,7 @@ from termcolor import termcolor
from . import source, lexer
from .logger import rootLogger, LogLevel
from .source import TextIOWithMemory
class LevelType:
@ -58,7 +59,7 @@ class CompilationWarning(Warning):
self.location = location
@classmethod
def show_warnings(cls, _source: str, file: TextIO = sys.stderr):
def show_warnings(cls, _source: TextIOWithMemory, file: TextIO = sys.stderr):
for warning in CompilationWarning._pending_warnings:
warning.location.source = _source
print(f"{warning}\n{warning.location.show_in_source()}", file=file)

View file

@ -15,6 +15,7 @@ logger = Logger(__name__)
class IRItem:
def __init__(self, location: SourceLocation):
assert location.source is not None
self.location = location
@abstractmethod
@ -277,12 +278,11 @@ class IR:
def code(self) -> list[str]:
return [x.codegen() for x in self.intermediate_representation]
def update_location(self, source: str):
def update_location(self):
code = self.code()
ir_item: IRAction
for i, ir_item in enumerate(self.intermediate_representation):
ir_item.ir_location = Location(line=i, ir=code)
ir_item.location.source = source
def pretty_print(self):
messages = []
@ -290,7 +290,7 @@ class IR:
for i, ir_item in enumerate(self.intermediate_representation):
prefix = f"{str(ir_item.location) + ':':<30}"
source_info = ir_item.location.source_substring.splitlines(keepends=False)
logger.debug(f"source: {repr(ir_item.location.source)}")
logger.trace(f"source: {repr(ir_item.location.source_name)}")
if len(source_info) == 0:
source_info = ["<NO SOURCE>"]
messages += [f"# {prefix} {source_info.pop(0)}"]

View file

@ -8,7 +8,7 @@ from dataclasses import dataclass, field
from typing import cast
from .logger import Logger
from .source import SourceLocation, Location
from .source import SourceLocation, Location, TextIOWithMemory
from .typechecking import typecheck
from .utils import implies
@ -62,9 +62,8 @@ class Tokens(enum.Enum):
class Lexer(collections.abc.Iterator):
def __init__(self, input_stream: typing.TextIO, token_filter: typing.Callable[[Token], bool] | None = None):
def __init__(self, input_stream: TextIOWithMemory, token_filter: typing.Callable[[Token], bool] | None = None):
self.input = input_stream
self.data: str = ""
self.tokens = []
self.length: int | None = None
self.begin: int = 0
@ -91,27 +90,28 @@ class Lexer(collections.abc.Iterator):
tok = Token(Tokens.BEGIN,
loc=SourceLocation(
Location(line=0, character=0),
source=self.data
source=self.input
),
value=None)
self.tokens.append(tok)
return tok
if self.tokens[-1].kind in [Tokens.BEGIN, Tokens.Newline]:
self.data += self.input.readline()
self.input.readline()
if self.begin == len(self.data):
if self.begin == len(self.input.stream_cache):
eof_token = Token(Tokens.EOF, value=None, loc=SourceLocation(
Location(line=self.line, character=0),
source=self.input
))
self.tokens += [eof_token]
self.length = len(self.tokens)
return eof_token
elif self.begin < len(self.data):
elif self.begin < len(self.input.stream_cache):
best_result: Token = Token(Tokens.Unknown,
loc=SourceLocation(
Location(line=self.line, character=self.character),
source=self.data),
source=self.input),
value=""
)
token_kind: Tokens
@ -119,7 +119,7 @@ class Lexer(collections.abc.Iterator):
if token_kind == Tokens.Unknown:
continue
regex = cast(re.Pattern, token_kind.value)
match = regex.match(self.data, self.begin)
match = regex.match(self.input.stream_cache, self.begin)
if match is not None:
logger.trace(f"Got match: {match}")
result = match.group(0)
@ -130,7 +130,8 @@ class Lexer(collections.abc.Iterator):
continue
loc = SourceLocation(
begin=Location(line=self.line, character=self.character),
end=Location(line=self.line, character=self.character + len(result))
end=Location(line=self.line, character=self.character + len(result)),
source=self.input
)
best_result = Token(token_kind, value=result, loc=loc)
logger.trace(f"New best match: {best_result}")
@ -144,7 +145,7 @@ class Lexer(collections.abc.Iterator):
elif best_result.kind == Tokens.Newline:
self.line += 1
self.character = 0
best_result.loc.end = Location(line=self.line, character=0)
best_result.loc.end = Location(line=self.line, character=0, file=self.input.name)
logger.debug(f"Added token {best_result}")

View file

@ -26,10 +26,10 @@ class Node:
@functools.cache
def location(self) -> SourceLocation:
assert len(self._values()) > 0
locations = [v.location() for v in self._values()]
locations: list[SourceLocation] = [v.location() for v in self._values()]
begin = min([loc.begin for loc in locations])
end = max([loc.end for loc in locations])
loc = SourceLocation(begin=begin, end=end)
loc = SourceLocation(begin=begin, end=end, source=locations[0].source)
return loc
def __repr__(self):
@ -98,6 +98,7 @@ class Literal(Node, ABC):
super().__init__()
self.value = value
self.loc = location
assert location.source is not None
def location(self) -> SourceLocation:
return self.loc

View file

@ -15,7 +15,7 @@ builtin_node = nodes.PseudoNode(
lexer.Token(
kind=lexer.Tokens.Unknown,
value="__compiler_internal__",
loc=source.SourceLocation(begin=source.Location(internal=True), source="__compiler_internal__")
loc=source.COMPILER_INTERNAL
)
)

View file

@ -1,6 +1,8 @@
from __future__ import annotations
import io
import math
import typing
from dataclasses import dataclass
from .logger import Logger
@ -9,6 +11,84 @@ from .typechecking import typecheck
logger = Logger(__name__)
@typecheck
class TextIOWithMemory(typing.TextIO):
    """Text stream wrapper that remembers everything read through it.

    Every piece of text obtained via ``read``, ``readline``, ``readlines`` or
    iteration is appended to ``stream_cache`` before being handed back to the
    caller, so the text consumed so far stays available even when the wrapped
    stream cannot be rewound. All other operations are delegated unchanged to
    the wrapped stream.

    Attributes:
        stream_cache: Concatenation, in read order, of all text read through
            this wrapper.
    """

    def __init__(self, io_stream: io.TextIOBase):
        """Wrap *io_stream*.

        Raises:
            TypeError: if *io_stream* is not an ``io.TextIOBase`` instance.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(io_stream, io.TextIOBase):
            raise TypeError(
                f"io_stream must be an io.TextIOBase, got {type(io_stream).__name__}"
            )
        self._stream = io_stream
        self.stream_cache: str = ""

    def _remember(self, text: str) -> str:
        """Append *text* to the cache and return it unchanged."""
        self.stream_cache += text
        return text

    @property
    def name(self):
        """Name of the wrapped stream, or ``"<none>"`` if it has none."""
        return getattr(self._stream, "name", "<none>")

    def __enter__(self):
        # Enter the wrapped stream's context but hand back the wrapper, so
        # reads performed inside the ``with`` block are still cached.
        self._stream.__enter__()
        return self

    def __exit__(self, __t, __value, __traceback):
        return self._stream.__exit__(__t, __value, __traceback)

    def close(self):
        return self._stream.close()

    def fileno(self):
        return self._stream.fileno()

    def flush(self):
        return self._stream.flush()

    def isatty(self):
        return self._stream.isatty()

    def read(self, size=-1, /):
        """Read up to *size* characters (all remaining if negative), cache and return them."""
        return self._remember(self._stream.read(size))

    def readable(self):
        return self._stream.readable()

    def readline(self, *args, **kwargs):
        """Read one line from the wrapped stream, cache it and return it."""
        return self._remember(self._stream.readline(*args, **kwargs))

    def readlines(self, hint=-1, /):
        """Read a list of lines from the wrapped stream, cache them and return them."""
        lines = self._stream.readlines(hint)
        self.stream_cache += "".join(lines)
        return lines

    def seek(self, offset, whence=io.SEEK_SET, /):
        # The io API takes these positionally; forwarding them as keywords fails.
        return self._stream.seek(offset, whence)

    def seekable(self):
        # Always reports non-seekable instead of delegating to the wrapped
        # stream. NOTE(review): presumably because seeking would make
        # ``stream_cache`` diverge from the read order -- confirm.
        return False

    def tell(self):
        return self._stream.tell()

    def truncate(self, size=None, /):
        return self._stream.truncate(size)

    def writable(self):
        return self._stream.writable()

    def write(self, __s):
        return self._stream.write(__s)

    def writelines(self, __lines):
        return self._stream.writelines(__lines)

    def __iter__(self):
        # Iterate through the wrapper so that each yielded line is cached.
        return self

    def __next__(self):
        # StopIteration raised by the wrapped stream propagates unchanged.
        return self._remember(self._stream.__next__())
@typecheck
@dataclass
class Location:
@ -68,13 +148,19 @@ class Location:
@typecheck
class SourceLocation:
def __init__(self, begin: Location, end: Location | None = None, source: str | None = None):
def __init__(self, begin: Location, end: Location | None = None, source: TextIOWithMemory | None = None):
self.begin = begin
self.end = end
if self.end is None:
self.end = self.begin
self.source = source
self.source_name = "<none>"
if source is not None:
self.source_name = source.name
self.begin.file = self.source_name
self.end.file = self.source_name
assert (self.begin.line, self.begin.character) <= (self.end.line, self.end.character)
assert self.begin.file == self.end.file
@ -103,7 +189,10 @@ class SourceLocation:
@property
def source_substring(self) -> str:
source = self.source.splitlines(keepends=False)
if self.source is None:
return "__compiler_internal__"
source = self.source.stream_cache.splitlines(keepends=False)
source_lines = source[self.begin.line:self.end.line + 1]
if len(source_lines) == 1:
source_lines[0] = source_lines[0][self.begin.character:self.end.character + 1]
@ -125,7 +214,10 @@ class SourceLocation:
return "\n".join(result)
def show_in_source(self) -> str:
source = self.source.splitlines(keepends=False)
if self.source is None:
return "__compiler_internal__"
source = self.source.stream_cache.splitlines(keepends=False)
line_number_maxlen = 1 if not len(source) else int(math.log10(len(source)) + 1)
lines = source[self.begin.line:self.end.line + 1]
result = []
@ -146,6 +238,9 @@ class SourceLocation:
return "\n".join(result)
# Shared SourceLocation for items that do not originate from user source:
# both ends are `Location(internal=True)` and the source is an in-memory stream.
# NOTE(review): the wrapped StringIO is never read here, so the wrapper's
# `stream_cache` is still empty at this point and "__compiler_internal__" is not
# part of the cached text; io.StringIO also has no `name` attribute, so the
# wrapper's name falls back to "<none>" -- confirm this is intended.
COMPILER_INTERNAL = SourceLocation(begin=Location(internal=True), end=Location(internal=True),
source=TextIOWithMemory(io.StringIO("__compiler_internal__")))
@typecheck
class IRLocation:
def __init__(self, line: int = -1, ir: list[str] = None):