From a832cd1214249d57e1d70e4e361ada939180ce18 Mon Sep 17 00:00:00 2001 From: Antoine Viallon Date: Mon, 15 May 2023 00:33:25 +0200 Subject: [PATCH] treewide: make beartype optional --- compiler/ir.py | 15 +++++++-------- compiler/lexer.py | 12 ++++++------ compiler/nodes.py | 9 ++++----- compiler/parser.py | 8 ++++---- compiler/source.py | 16 +++++++--------- compiler/typechecking.py | 18 ++++++++++++++++++ 6 files changed, 46 insertions(+), 32 deletions(-) create mode 100644 compiler/typechecking.py diff --git a/compiler/ir.py b/compiler/ir.py index 9041cc1..beafd7c 100644 --- a/compiler/ir.py +++ b/compiler/ir.py @@ -3,11 +3,10 @@ from __future__ import annotations import abc from abc import abstractmethod -from beartype import beartype - from .errors import OverrideMandatoryError from .logger import Logger from .source import SourceLocation +from .typechecking import typecheck logger = Logger(__name__) @@ -39,7 +38,7 @@ class IRValue(IRItem, abc.ABC): class IRMove(IRAction): - @beartype + @typecheck def __init__(self, location: SourceLocation, dest: IRAssignable, source: IRValue): super().__init__(location) self.dest = dest @@ -53,7 +52,7 @@ class IRMove(IRAction): class IRImmediate(IRValue): - @beartype + @typecheck def __init__(self, location: SourceLocation, value: int | float | str): super().__init__(location) self.value = value @@ -89,7 +88,7 @@ class IRVariable(IRAssignable): class IRAdd(IRAction): - @beartype + @typecheck def __init__(self, location: SourceLocation, dest: IRAssignable, *values: IRValue): super().__init__(location) assert all(isinstance(v, IRValue) for v in values) @@ -107,7 +106,7 @@ class IRAdd(IRAction): class IRMul(IRAction): - @beartype + @typecheck def __init__(self, location: SourceLocation, dest: IRAssignable, *values: IRValue): super().__init__(location) assert all(isinstance(v, IRValue) for v in values) @@ -125,7 +124,7 @@ class IRMul(IRAction): class IRNegation(IRAction): - @beartype + @typecheck def __init__(self, location: SourceLocation, dest: IRAssignable, source: IRValue): super().__init__(location) @@ -141,7 +140,7 @@ class IRNegation(IRAction): class IRInvert(IRAction): - @beartype + @typecheck def __init__(self, location: SourceLocation, dest: IRAssignable, source: IRValue): super().__init__(location) diff --git a/compiler/lexer.py b/compiler/lexer.py index 87fa1d3..55c67cc 100644 --- a/compiler/lexer.py +++ b/compiler/lexer.py @@ -5,21 +5,19 @@ import enum import re from dataclasses import dataclass, field -from beartype import beartype -from beartype.typing import Optional, List - from .logger import Logger from .source import SourceLocation, Location +from .typechecking import typecheck logger = Logger(__name__) -@beartype +@typecheck @dataclass class Token: kind: Tokens loc: SourceLocation = field(compare=False, hash=False, default=None) - value: Optional[str] = field(compare=False, hash=False, default=None) + value: str | None = field(compare=False, hash=False, default=None) def __repr__(self): return f"{self.kind.name}({repr(self.value)})" @@ -78,7 +76,9 @@ class Lexer(collections.abc.Sequence): actual_result: Token if self.begin < len(self.data): best_result: Token = Token(Tokens.Unknown, - loc=SourceLocation(Location(line=self.line, character=self.character), source=self.data), + loc=SourceLocation( + Location(line=self.line, character=self.character), + source=self.data), value="" ) for token_kind in list(Tokens): diff --git a/compiler/nodes.py b/compiler/nodes.py index 090cecf..38614c1 100644 --- a/compiler/nodes.py +++ b/compiler/nodes.py @@ -4,12 +4,11 @@ import functools from abc import abstractmethod, ABC from typing import Any, Iterable -from beartype import beartype - from . import ir, semantic, lexer from .errors import SemanticAnalysisError, OverrideMandatoryError from .logger import Logger from .source import SourceLocation +from .typechecking import typecheck logger = Logger(__name__) @@ -54,7 +53,7 @@ class Node: vals = self._values() try: vals = (val for val in vals) - except TypeError as e: + except TypeError: vals = (vals,) result = [f"{self.__class__.__name__} {{"] @@ -223,13 +222,13 @@ class Division(Node): BinaryOperation = Sum | Sub | Product | Division -@beartype +@typecheck class Float(Literal): def __init__(self, location: SourceLocation, value: float): super().__init__(location, value) -@beartype +@typecheck class Integer(Literal): def __init__(self, location: SourceLocation, value: int): diff --git a/compiler/parser.py b/compiler/parser.py index ccafa09..50907ec 100644 --- a/compiler/parser.py +++ b/compiler/parser.py @@ -1,19 +1,19 @@ from __future__ import annotations -from beartype.typing import List, Dict, Callable +from typing import Callable from .errors import CompilationError, UnexpectedTokenError +from .lexer import Tokens, Token from .logger import Logger, Tracer, LogLevel from .nodes import Float, Sum, Value, Product, Node, Division, Sub, Integer, Expression, Identifier, Assignment, \ Variable, Statement, PseudoNode, Block -from .lexer import Tokens, Token logger = Logger(__name__) tracer = Tracer(logger, level=LogLevel.Debug) class Parser: - def __init__(self, tokens: List[Token]): + def __init__(self, tokens: list[Token]): self.tokens = tokens self.pos = 0 self._last_accepted_token: Tokens | None = None @@ -92,7 +92,7 @@ class Parser: raise UnexpectedTokenError(self.token, "variable identifier") @tracer.trace_method - def binary_op(self, operand_func: Callable[[bool], Value], operators: Dict[Tokens, Value], mandatory: bool = True): + def binary_op(self, operand_func: Callable[[bool], Value], operators: dict[Tokens, Value], mandatory: bool = True): operand = operand_func(mandatory) if not operand and mandatory: raise UnexpectedTokenError(operand, "operand") diff --git a/compiler/source.py b/compiler/source.py index e302a55..a765a54 100644 --- a/compiler/source.py +++ b/compiler/source.py @@ -3,15 +3,13 @@ from __future__ import annotations import math from dataclasses import dataclass -from beartype import beartype -from beartype.typing import Optional - from .logger import Logger +from .typechecking import typecheck logger = Logger(__name__) -@beartype +@typecheck @dataclass class Location: line: int @@ -21,7 +19,7 @@ class Location: def __str__(self) -> str: return f"{self.file}:{self.line}:{self.character}" - @beartype + @typecheck def __lt__(self, other: Location) -> bool: if self.file != other.file: logger.trace(f"{self} is not in the same file as {other}") @@ -41,11 +39,11 @@ class Location: return False - @beartype + @typecheck def __ge__(self, other: Location) -> bool: return not self.__lt__(other) - @beartype + @typecheck def __eq__(self, other: Location) -> bool: return all(( self.file == other.file, @@ -54,9 +52,9 @@ class Location: )) -@beartype +@typecheck class SourceLocation: - def __init__(self, begin: Location, end: Optional[Location] = None, source: Optional[str] = None): + def __init__(self, begin: Location, end: Location | None = None, source: str | None = None): self.begin = begin self.end = end if self.end is None: diff --git a/compiler/typechecking.py b/compiler/typechecking.py new file mode 100644 index 0000000..26e207e --- /dev/null +++ b/compiler/typechecking.py @@ -0,0 +1,18 @@ +from typing import Callable, TypeVar, ParamSpec + +T = TypeVar("T") +P = ParamSpec("P") + + +def _typecheck_stub(func: Callable[P, T]) -> Callable[P, T]: + return func + + +typecheck: Callable[P, T] = _typecheck_stub + +try: + import beartype + + typecheck = beartype.beartype +except ImportError: + pass