diff --git a/docs/source/beam.py b/docs/source/beam.py index 87e9362f..843572b0 100644 --- a/docs/source/beam.py +++ b/docs/source/beam.py @@ -6,7 +6,6 @@ import functools import heapq import random -import typing from collections.abc import Callable from dataclasses import dataclass from pprint import pprint diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 9a670f1e..68a0f754 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -73,17 +73,12 @@ class Operation[**Q, V]: """ - __signature__: inspect.Signature __name__: str __default__: Callable[Q, V] __apply__: typing.ClassVar["Operation"] - def __init__( - self, signature: inspect.Signature, name: str, default: Callable[Q, V] - ): + def __init__(self, name: str, default: Callable[Q, V]): functools.update_wrapper(self, default) - - self.__signature__ = signature self.__name__ = name self.__default__ = default @@ -252,7 +247,7 @@ def func(*args, **kwargs): op = cls.define(func, name=name) else: name = name or t.__name__ - op = cls(inspect.signature(t), name, t) # type: ignore[arg-type] + op = cls(name, t) # type: ignore[arg-type] return op # type: ignore[return-value] @@ -324,6 +319,21 @@ def func(*args, **kwargs): op.register = default._registry.register # type: ignore[attr-defined] return op + @functools.cached_property + def __signature__(self): + annots = typing.get_type_hints(self.__default__, include_extras=True) + sig = inspect.signature(self.__default__) + + updated_params = [ + p.replace(annotation=annots[p.name]) if p.name in annots else p + for p in sig.parameters.values() + ] + updated_ret = annots.get("return", sig.return_annotation) + updated_sig = sig.replace( + parameters=updated_params, return_annotation=updated_ret + ) + return updated_sig + @typing.final def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]": """The default rule is used when the operation is not handled. diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index fb52a392..f133f646 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -1119,3 +1119,27 @@ def id[T](base: T) -> T: raise NotHandled assert isinstance(id(A(0)).x, Term) + + +# Forward references in types only work on module-level definitions. +@defop +def forward_ref_op() -> "A": + raise NotHandled + + +class A: ... + + +def test_defop_forward_ref(): + term = forward_ref_op() + assert term.op == forward_ref_op + assert typeof(term) is A + + @defop + def local_forward_ref_op() -> "B": + raise NotHandled + + class B: ... + + with pytest.raises(NameError): + local_forward_ref_op()