parser: only require an iterator of tokens instead of a list

This commit is contained in:
Antoine Viallon 2024-01-11 00:52:17 +01:00
parent 46c7907165
commit 530214b254
Signed by: aviallon
GPG key ID: 186FC35EDEB25716

View file

@@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import collections.abc
from typing import Callable from typing import Callable
from .errors import CompilationError, UnexpectedTokenError from .errors import CompilationError, UnexpectedTokenError
@@ -13,20 +14,40 @@ tracer = Tracer(logger, level=LogLevel.Debug)
class Parser: class Parser:
def __init__(self, tokens: collections.abc.Iterator[Token]):
    """Create a parser over a lazy stream of tokens.

    :param tokens: iterator yielding ``Token`` objects; the stream is
        expected to be terminated by a token of kind ``Tokens.EOF``.
    """
    self.tokens = tokens
    # Index of the current token within the (virtual) token stream.
    self.pos = 0
    # Tokens already pulled from the iterator; kept so earlier
    # positions remain addressable for peeking/backtracking.
    self._token_cache: list[Token] = []
    # Becomes True once an EOF-kind token has been pulled from the source.
    self._EOF = False
    self._last_accepted_token: Tokens | None = None
@property
def token(self) -> Token:
    """Return the token at the current position, fetching lazily.

    Tokens are pulled from the underlying iterator on demand and cached
    so that earlier positions stay addressable. Once the source has
    produced EOF, the EOF token is returned for any position past the
    end of the cache.
    """
    # Serve from the cache first: lookahead (e.g. via _fetch_until) may
    # already have pulled tokens — possibly including EOF — past the
    # current position. In that case the token at `pos` is still a real,
    # unconsumed token and must be returned instead of EOF.
    if self.pos < len(self._token_cache):
        return self._token_cache[self.pos]
    # Position is beyond the cache and the source is exhausted:
    # only the trailing EOF token remains.
    if self._EOF:
        return self._token_cache[-1]
    # Pull tokens until the cache covers the current position.
    while len(self._token_cache) <= self.pos:
        tok = next(self.tokens)
        self._token_cache.append(tok)
        if tok.kind == Tokens.EOF:
            self._EOF = True
            return tok
    return self._token_cache[self.pos]
def _fetch_until(self, desired_length: int) -> int:
    """Grow the token cache to more than *desired_length* entries.

    Pulls tokens from the source until the cache length exceeds
    ``desired_length`` or an EOF token is encountered (which also sets
    ``self._EOF``).

    :return: the resulting size of the token cache.
    """
    cache = self._token_cache
    while not self._EOF and len(cache) <= desired_length:
        fetched = next(self.tokens)
        cache.append(fetched)
        if fetched.kind == Tokens.EOF:
            self._EOF = True
            break
    return len(cache)
@property
def prev_token(self) -> Token:
    """The token immediately before the current position (last consumed)."""
    previous_index = self.pos - 1
    return self._token_cache[previous_index]
def next_symbol(self): def next_symbol(self):
self.pos += 1 self.pos += 1
@@ -48,10 +69,11 @@ class Parser:
return False return False
def peek_several(self, *tokens_types: Tokens) -> False | list[Token]: def peek_several(self, *tokens_types: Tokens) -> False | list[Token]:
if self.pos + len(tokens_types) >= len(self.tokens): desired_pos = self.pos + len(tokens_types)
if desired_pos >= self._fetch_until(desired_pos + 1):
return False return False
toks = self.tokens[self.pos:self.pos + len(tokens_types)] toks = self._token_cache[self.pos:self.pos + len(tokens_types)]
for i, token in enumerate(toks): for i, token in enumerate(toks):
if token.kind != tokens_types[i]: if token.kind != tokens_types[i]:
return False return False