Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/source/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import functools
import heapq
import random
import typing
from collections.abc import Callable
from dataclasses import dataclass
from pprint import pprint
Expand Down
24 changes: 17 additions & 7 deletions effectful/ops/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,12 @@ class Operation[**Q, V]:

"""

__signature__: inspect.Signature
__name__: str
__default__: Callable[Q, V]
__apply__: typing.ClassVar["Operation"]

def __init__(
self, signature: inspect.Signature, name: str, default: Callable[Q, V]
):
def __init__(self, name: str, default: Callable[Q, V]):
functools.update_wrapper(self, default)

self.__signature__ = signature
self.__name__ = name
self.__default__ = default

Expand Down Expand Up @@ -252,7 +247,7 @@ def func(*args, **kwargs):
op = cls.define(func, name=name)
else:
name = name or t.__name__
op = cls(inspect.signature(t), name, t) # type: ignore[arg-type]
op = cls(name, t) # type: ignore[arg-type]

return op # type: ignore[return-value]

Expand Down Expand Up @@ -324,6 +319,21 @@ def func(*args, **kwargs):
op.register = default._registry.register # type: ignore[attr-defined]
return op

@functools.cached_property
def __signature__(self):
annots = typing.get_type_hints(self.__default__, include_extras=True)
sig = inspect.signature(self.__default__)

updated_params = [
p.replace(annotation=annots[p.name]) if p.name in annots else p
for p in sig.parameters.values()
]
updated_ret = annots.get("return", sig.return_annotation)
updated_sig = sig.replace(
parameters=updated_params, return_annotation=updated_ret
)
return updated_sig

@typing.final
def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]":
"""The default rule is used when the operation is not handled.
Expand Down
24 changes: 24 additions & 0 deletions tests/test_ops_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,3 +1119,27 @@ def id[T](base: T) -> T:
raise NotHandled

assert isinstance(id(A(0)).x, Term)


# Forward references in types only work on module-level definitions.
@defop
def forward_ref_op() -> "A":
raise NotHandled


class A: ...


def test_defop_forward_ref():
term = forward_ref_op()
assert term.op == forward_ref_op
assert typeof(term) is A

@defop
def local_forward_ref_op() -> "B":
raise NotHandled

class B: ...

with pytest.raises(NameError):
local_forward_ref_op()
Loading