treewide: make beartype optional

This commit is contained in:
Antoine Viallon 2023-05-15 00:33:25 +02:00
parent bac49d20d7
commit a832cd1214
Signed by: aviallon
GPG key ID: D126B13AB555E16F
6 changed files with 46 additions and 32 deletions

View file

@ -3,11 +3,10 @@ from __future__ import annotations
import abc import abc
from abc import abstractmethod from abc import abstractmethod
from beartype import beartype
from .errors import OverrideMandatoryError from .errors import OverrideMandatoryError
from .logger import Logger from .logger import Logger
from .source import SourceLocation from .source import SourceLocation
from .typechecking import typecheck
logger = Logger(__name__) logger = Logger(__name__)
@ -39,7 +38,7 @@ class IRValue(IRItem, abc.ABC):
class IRMove(IRAction): class IRMove(IRAction):
@beartype @typecheck
def __init__(self, location: SourceLocation, dest: IRAssignable, source: IRValue): def __init__(self, location: SourceLocation, dest: IRAssignable, source: IRValue):
super().__init__(location) super().__init__(location)
self.dest = dest self.dest = dest
@ -53,7 +52,7 @@ class IRMove(IRAction):
class IRImmediate(IRValue): class IRImmediate(IRValue):
@beartype @typecheck
def __init__(self, location: SourceLocation, value: int | float | str): def __init__(self, location: SourceLocation, value: int | float | str):
super().__init__(location) super().__init__(location)
self.value = value self.value = value
@ -89,7 +88,7 @@ class IRVariable(IRAssignable):
class IRAdd(IRAction): class IRAdd(IRAction):
@beartype @typecheck
def __init__(self, location: SourceLocation, dest: IRAssignable, *values: IRValue): def __init__(self, location: SourceLocation, dest: IRAssignable, *values: IRValue):
super().__init__(location) super().__init__(location)
assert all(isinstance(v, IRValue) for v in values) assert all(isinstance(v, IRValue) for v in values)
@ -107,7 +106,7 @@ class IRAdd(IRAction):
class IRMul(IRAction): class IRMul(IRAction):
@beartype @typecheck
def __init__(self, location: SourceLocation, dest: IRAssignable, *values: IRValue): def __init__(self, location: SourceLocation, dest: IRAssignable, *values: IRValue):
super().__init__(location) super().__init__(location)
assert all(isinstance(v, IRValue) for v in values) assert all(isinstance(v, IRValue) for v in values)
@ -125,7 +124,7 @@ class IRMul(IRAction):
class IRNegation(IRAction): class IRNegation(IRAction):
@beartype @typecheck
def __init__(self, location: SourceLocation, dest: IRAssignable, source: IRValue): def __init__(self, location: SourceLocation, dest: IRAssignable, source: IRValue):
super().__init__(location) super().__init__(location)
@ -141,7 +140,7 @@ class IRNegation(IRAction):
class IRInvert(IRAction): class IRInvert(IRAction):
@beartype @typecheck
def __init__(self, location: SourceLocation, dest: IRAssignable, source: IRValue): def __init__(self, location: SourceLocation, dest: IRAssignable, source: IRValue):
super().__init__(location) super().__init__(location)

View file

@ -5,21 +5,19 @@ import enum
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from beartype import beartype
from beartype.typing import Optional, List
from .logger import Logger from .logger import Logger
from .source import SourceLocation, Location from .source import SourceLocation, Location
from .typechecking import typecheck
logger = Logger(__name__) logger = Logger(__name__)
@beartype @typecheck
@dataclass @dataclass
class Token: class Token:
kind: Tokens kind: Tokens
loc: SourceLocation = field(compare=False, hash=False, default=None) loc: SourceLocation = field(compare=False, hash=False, default=None)
value: Optional[str] = field(compare=False, hash=False, default=None) value: str | None = field(compare=False, hash=False, default=None)
def __repr__(self): def __repr__(self):
return f"{self.kind.name}({repr(self.value)})" return f"{self.kind.name}({repr(self.value)})"
@ -78,7 +76,9 @@ class Lexer(collections.abc.Sequence):
actual_result: Token actual_result: Token
if self.begin < len(self.data): if self.begin < len(self.data):
best_result: Token = Token(Tokens.Unknown, best_result: Token = Token(Tokens.Unknown,
loc=SourceLocation(Location(line=self.line, character=self.character), source=self.data), loc=SourceLocation(
Location(line=self.line, character=self.character),
source=self.data),
value="" value=""
) )
for token_kind in list(Tokens): for token_kind in list(Tokens):

View file

@ -4,12 +4,11 @@ import functools
from abc import abstractmethod, ABC from abc import abstractmethod, ABC
from typing import Any, Iterable from typing import Any, Iterable
from beartype import beartype
from . import ir, semantic, lexer from . import ir, semantic, lexer
from .errors import SemanticAnalysisError, OverrideMandatoryError from .errors import SemanticAnalysisError, OverrideMandatoryError
from .logger import Logger from .logger import Logger
from .source import SourceLocation from .source import SourceLocation
from .typechecking import typecheck
logger = Logger(__name__) logger = Logger(__name__)
@ -54,7 +53,7 @@ class Node:
vals = self._values() vals = self._values()
try: try:
vals = (val for val in vals) vals = (val for val in vals)
except TypeError as e: except TypeError:
vals = (vals,) vals = (vals,)
result = [f"{self.__class__.__name__} {{"] result = [f"{self.__class__.__name__} {{"]
@ -223,13 +222,13 @@ class Division(Node):
BinaryOperation = Sum | Sub | Product | Division BinaryOperation = Sum | Sub | Product | Division
@beartype @typecheck
class Float(Literal): class Float(Literal):
def __init__(self, location: SourceLocation, value: float): def __init__(self, location: SourceLocation, value: float):
super().__init__(location, value) super().__init__(location, value)
@beartype @typecheck
class Integer(Literal): class Integer(Literal):
def __init__(self, location: SourceLocation, value: int): def __init__(self, location: SourceLocation, value: int):

View file

@ -1,19 +1,19 @@
from __future__ import annotations from __future__ import annotations
from beartype.typing import List, Dict, Callable from typing import Callable
from .errors import CompilationError, UnexpectedTokenError from .errors import CompilationError, UnexpectedTokenError
from .lexer import Tokens, Token
from .logger import Logger, Tracer, LogLevel from .logger import Logger, Tracer, LogLevel
from .nodes import Float, Sum, Value, Product, Node, Division, Sub, Integer, Expression, Identifier, Assignment, \ from .nodes import Float, Sum, Value, Product, Node, Division, Sub, Integer, Expression, Identifier, Assignment, \
Variable, Statement, PseudoNode, Block Variable, Statement, PseudoNode, Block
from .lexer import Tokens, Token
logger = Logger(__name__) logger = Logger(__name__)
tracer = Tracer(logger, level=LogLevel.Debug) tracer = Tracer(logger, level=LogLevel.Debug)
class Parser: class Parser:
def __init__(self, tokens: List[Token]): def __init__(self, tokens: list[Token]):
self.tokens = tokens self.tokens = tokens
self.pos = 0 self.pos = 0
self._last_accepted_token: Tokens | None = None self._last_accepted_token: Tokens | None = None
@ -92,7 +92,7 @@ class Parser:
raise UnexpectedTokenError(self.token, "variable identifier") raise UnexpectedTokenError(self.token, "variable identifier")
@tracer.trace_method @tracer.trace_method
def binary_op(self, operand_func: Callable[[bool], Value], operators: Dict[Tokens, Value], mandatory: bool = True): def binary_op(self, operand_func: Callable[[bool], Value], operators: dict[Tokens, Value], mandatory: bool = True):
operand = operand_func(mandatory) operand = operand_func(mandatory)
if not operand and mandatory: if not operand and mandatory:
raise UnexpectedTokenError(operand, "operand") raise UnexpectedTokenError(operand, "operand")

View file

@ -3,15 +3,13 @@ from __future__ import annotations
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from beartype import beartype
from beartype.typing import Optional
from .logger import Logger from .logger import Logger
from .typechecking import typecheck
logger = Logger(__name__) logger = Logger(__name__)
@beartype @typecheck
@dataclass @dataclass
class Location: class Location:
line: int line: int
@ -21,7 +19,7 @@ class Location:
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.file}:{self.line}:{self.character}" return f"{self.file}:{self.line}:{self.character}"
@beartype @typecheck
def __lt__(self, other: Location) -> bool: def __lt__(self, other: Location) -> bool:
if self.file != other.file: if self.file != other.file:
logger.trace(f"{self} is not in the same file as {other}") logger.trace(f"{self} is not in the same file as {other}")
@ -41,11 +39,11 @@ class Location:
return False return False
@beartype @typecheck
def __ge__(self, other: Location) -> bool: def __ge__(self, other: Location) -> bool:
return not self.__lt__(other) return not self.__lt__(other)
@beartype @typecheck
def __eq__(self, other: Location) -> bool: def __eq__(self, other: Location) -> bool:
return all(( return all((
self.file == other.file, self.file == other.file,
@ -54,9 +52,9 @@ class Location:
)) ))
@beartype @typecheck
class SourceLocation: class SourceLocation:
def __init__(self, begin: Location, end: Optional[Location] = None, source: Optional[str] = None): def __init__(self, begin: Location, end: Location | None = None, source: str | None = None):
self.begin = begin self.begin = begin
self.end = end self.end = end
if self.end is None: if self.end is None:

18
compiler/typechecking.py Normal file
View file

@ -0,0 +1,18 @@
from typing import Callable, TypeVar, ParamSpec
T = TypeVar("T")
P = ParamSpec("P")
def _typecheck_stub(func: Callable[P, T]) -> Callable[P, T]:
return func
typecheck: Callable[P, T] = _typecheck_stub
try:
import beartype
typecheck = beartype.beartype
except ImportError:
pass