From c4c515ae1f6045e3cd0f155e19f97a460a799412 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 19 Dec 2025 12:31:36 -0500 Subject: [PATCH 1/3] make __signature__ a lazily computed property --- docs/source/beam.py | 1 - effectful/ops/types.py | 24 +++++++++++++++++------- tests/test_ops_syntax.py | 15 +++++++++++++++ 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/docs/source/beam.py b/docs/source/beam.py index 87e9362f..843572b0 100644 --- a/docs/source/beam.py +++ b/docs/source/beam.py @@ -6,7 +6,6 @@ import functools import heapq import random -import typing from collections.abc import Callable from dataclasses import dataclass from pprint import pprint diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 9a670f1e..68a0f754 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -73,17 +73,12 @@ class Operation[**Q, V]: """ - __signature__: inspect.Signature __name__: str __default__: Callable[Q, V] __apply__: typing.ClassVar["Operation"] - def __init__( - self, signature: inspect.Signature, name: str, default: Callable[Q, V] - ): + def __init__(self, name: str, default: Callable[Q, V]): functools.update_wrapper(self, default) - - self.__signature__ = signature self.__name__ = name self.__default__ = default @@ -252,7 +247,7 @@ def func(*args, **kwargs): op = cls.define(func, name=name) else: name = name or t.__name__ - op = cls(inspect.signature(t), name, t) # type: ignore[arg-type] + op = cls(name, t) # type: ignore[arg-type] return op # type: ignore[return-value] @@ -324,6 +319,21 @@ def func(*args, **kwargs): op.register = default._registry.register # type: ignore[attr-defined] return op + @functools.cached_property + def __signature__(self): + annots = typing.get_type_hints(self.__default__, include_extras=True) + sig = inspect.signature(self.__default__) + + updated_params = [ + p.replace(annotation=annots[p.name]) if p.name in annots else p + for p in sig.parameters.values() + ] + updated_ret = annots.get("return", sig.return_annotation) + updated_sig = sig.replace( + parameters=updated_params, return_annotation=updated_ret + ) + return updated_sig + @typing.final def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]": """The default rule is used when the operation is not handled. diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index fb52a392..95b58655 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -1119,3 +1119,18 @@ def id[T](base: T) -> T: raise NotHandled assert isinstance(id(A(0)).x, Term) + + +# Forward references in types only work on module-level definitions. +@defop +def forward_ref_op() -> "A": + raise NotHandled + + +class A: ... + + +def test_defop_forward_ref(): + term = forward_ref_op() + assert term.op == forward_ref_op + assert typeof(term) is A From 2f8a944d83950880a0fc0f2cce0cbd9ea9c32f26 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 19 Dec 2025 12:50:07 -0500 Subject: [PATCH 2/3] extend test --- tests/test_ops_syntax.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 95b58655..e680819f 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -1134,3 +1134,12 @@ def test_defop_forward_ref(): term = forward_ref_op() assert term.op == forward_ref_op assert typeof(term) is A + + @defop + def local_forward_ref_op() -> "B": + raise NotHandled + + class B: ... + + with pytest.raises(NameError): + term2 = local_forward_ref_op() From aa4c49cdff365923bf688bff19c18d6705140d25 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 19 Dec 2025 12:56:17 -0500 Subject: [PATCH 3/3] lint --- tests/test_ops_syntax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index e680819f..f133f646 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -1142,4 +1142,4 @@ def local_forward_ref_op() -> "B": class B: ... with pytest.raises(NameError): - term2 = local_forward_ref_op() + local_forward_ref_op()