diff --git a/compiler/logger.py b/compiler/logger.py index 2783bde..8e228b4 100644 --- a/compiler/logger.py +++ b/compiler/logger.py @@ -82,7 +82,7 @@ class Logger: T = typing.TypeVar("T") - +P = typing.ParamSpec("P") class Tracer: def __init__(self, logger: Logger, level: LogLevel = LogLevel.Trace, max_depth: int = 15): @@ -91,13 +91,13 @@ class Tracer: self.logger = logger self.max_depth: int = max_depth - def trace_method(self, func_or_nothing: Callable[..., T] | None = None, **kwargs): + def trace_method(self, func_or_nothing: Callable[P, T] | None = None, /, **kwargs) -> Callable[P, T]: if func_or_nothing is not None: return self._trace_method(func=func_or_nothing, **kwargs) return lambda func: self._trace_method(func, **kwargs) - def _trace_method(self, func: Callable[..., T], level: LogLevel = None) -> Callable[..., T]: + def _trace_method(self, func: Callable[P, T], level: LogLevel = None) -> Callable[P, T]: if level is None: level = self.level @@ -106,7 +106,7 @@ class Tracer: func_info = inspect.signature(func) - def _wrapped(*args, **kwargs) -> T: + def _wrapped(*args: P.args, **kwargs: P.kwargs) -> T: bound_args = func_info.bind(*args, **kwargs) signature = ", ".join( key if key == "self" else f"{key}={repr(value)}"