Description
Hey, I'm getting the following error in my conversion. It seems to come from the buffers, which is odd, since params (also a dict) were moved without any issues. I'm wondering whether I need to use a more explicit API.
```
  File "/orcd/home/002/mkotak/y/envs/matbench-speed/lib/python3.11/site-packages/nequix/calculator.py", line 185, in calculate
    energy_per_atom, forces, stress = self.model_func(
                                      ^^^^^^^^^^^^^^^^
  File "/orcd/home/002/mkotak/y/envs/matbench-speed/lib/python3.11/site-packages/torchax/interop.py", line 219, in call_jax
    args, kwargs = jax_view((args, kwargs))
                   ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/orcd/home/002/mkotak/y/envs/matbench-speed/lib/python3.11/site-packages/jax/_src/tree_util.py", line 361, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/orcd/home/002/mkotak/y/envs/matbench-speed/lib/python3.11/site-packages/jax/_src/tree_util.py", line 361, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
                             ^^^^^^
  File "/orcd/home/002/mkotak/y/envs/matbench-speed/lib/python3.11/site-packages/torchax/interop.py", line 202, in _jax_view
    assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type(t)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: <class 'torch.Tensor'>
```

Code:

```python
import functools

import torch
from torchax.interop import JittableModule, jax_jit, call_jax

# Build the input graph (dict_to_pytorch_geometric is a helper from the calling code)
graph = dict_to_pytorch_geometric(processed_graph)
graph.n_graph = torch.zeros(graph.x.shape[0], dtype=torch.int32).to(self.device)
graph = graph.to(self.device)

# Wrap the model and JIT-compile its forward pass through JAX
model_jittable = JittableModule(self.model).to(self.device)
self.model_func = jax_jit(
    functools.partial(model_jittable.functional_call, 'forward'),
    # Donates argument 0 (the weights); on GPU/TPU the donated arrays may be
    # invalidated, so reusing self.weights on a later call could fail
    kwargs_for_jax_jit={'donate_argnums': (0,)},
)
# Dicts of parameter and buffer tensors extracted from the module
self.weights = model_jittable.params
self.buffers = model_jittable.buffers

energy_per_atom, forces, stress = self.model_func(
    self.weights,
    self.buffers,
    (
        graph.x,
        graph.positions,
        graph.edge_attr,
        graph.edge_index,
        getattr(graph, "cell", None),
        graph.n_node,
        graph.n_edge,
        graph.n_graph,
    ),
)
```
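Because `jax_view` runs the `isinstance` assertion over every leaf of `(args, kwargs)`, the offending `torch.Tensor` could sit in the buffers dict or in the graph input tuple. A minimal sketch for locating it, assuming the inputs are plain nested dicts, lists, and tuples: walk the tree and report every leaf that is not one of the allowed types. `find_unconverted_leaves` and the `Fake*` classes are hypothetical stand-ins so the snippet runs without torch; with the real objects one would presumably pass `(torchax.tensor.Tensor, torchax.tensor.View)`, the two types named in the assertion, as `allowed`.

```python
from typing import Any, Iterator, Tuple


def find_unconverted_leaves(
    tree: Any, allowed: Tuple[type, ...], path: str = ""
) -> Iterator[Tuple[str, type]]:
    """Yield (path, type) for each leaf that is not an instance of `allowed`.

    Handles only dicts, lists, and tuples. None is skipped because JAX treats
    it as an empty subtree rather than a leaf.
    """
    if tree is None:
        return
    if isinstance(tree, dict):
        for key, value in tree.items():
            yield from find_unconverted_leaves(value, allowed, f"{path}[{key!r}]")
    elif isinstance(tree, (list, tuple)):
        for i, value in enumerate(tree):
            yield from find_unconverted_leaves(value, allowed, f"{path}[{i}]")
    elif not isinstance(tree, allowed):
        yield path, type(tree)


class FakeTorchaxTensor:  # hypothetical stand-in for torchax.tensor.Tensor
    pass


class FakeTorchTensor:  # hypothetical stand-in for a plain torch.Tensor
    pass


# Hypothetical example: params converted, one buffer left as a plain tensor
params = {"linear.weight": FakeTorchaxTensor()}
buffers = {"buffer_a": FakeTorchTensor()}
inputs = (FakeTorchaxTensor(), None)  # None mimics getattr(graph, "cell", None)

bad = list(find_unconverted_leaves((params, buffers, inputs), (FakeTorchaxTensor,)))
print(bad)
```

Run against `(self.weights, self.buffers, inputs)`, a check like this would show whether the plain tensors come from the buffers or from the graph inputs.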