diff --git a/test/test_libraries.py b/test/test_libraries.py index cbac6a6..a527f72 100644 --- a/test/test_libraries.py +++ b/test/test_libraries.py @@ -16,7 +16,7 @@ import torch import torch.nn.functional as F -from torch.library import Library, impl, impl_abstract +from torch.library import Library, impl, register_fake import torchax import torchax.export @@ -44,7 +44,7 @@ def _mylib_scaled_dot_product_attention(q, k, v): return y.transpose(1, 2) -@impl_abstract("mylib::scaled_dot_product_attention") +@register_fake("mylib::scaled_dot_product_attention") def _mylib_scaled_dot_product_attention_meta(q, k, v): return torch.empty_like(q)