Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
381c765
Initial plan
Copilot Sep 28, 2025
687349c
Implement core refactoring of AdvancedSubtensor and AdvancedIncSubtensor
Copilot Sep 28, 2025
e9675d8
Complete refactoring with improved factory functions and proper slice…
Copilot Sep 28, 2025
a75c904
Final fix: use as_index_variable consistently with original implement…
Copilot Sep 28, 2025
a28103a
Refactor newaxis handling: move to __getitem__ level, unify with Subt…
Copilot Sep 28, 2025
18982eb
Update dispatch functions and rewrite rules for new AdvancedSubtensor…
Copilot Sep 29, 2025
4d0daca
Finish Copilot code
jaanerik Dec 4, 2025
acaf059
Replace np.newaxis with None, remove NoneConst from indexing
jaanerik Dec 16, 2025
8e4a39f
Fix rewriting, use existing functions, respect subclasses
jaanerik Dec 18, 2025
4c0c5f9
Fix tests
jaanerik Dec 19, 2025
05162f9
Implement BaseSubtensor
jaanerik Dec 19, 2025
bd171ba
Fix rebase
jaanerik Dec 30, 2025
2c736a1
Remove one test that incorrectly extends materialized Ops
jaanerik Jan 2, 2026
fa0e23d
Rename tensor_inputs to index_variables
jaanerik Jan 2, 2026
f863957
Fix JAX dispatch
jaanerik Jan 4, 2026
2a7bb42
Refactor AdvancedSubtensor1 to use idx_list, add comments
jaanerik Jan 4, 2026
39d2e6a
Implement simpler idx_list
jaanerik Jan 6, 2026
9110a08
Revert unrelated code, remove deprecated code from a test
jaanerik Jan 7, 2026
d0ba66c
Remove symbolic slices
jaanerik Jan 15, 2026
b896c64
Fix leaky test
jaanerik Jan 30, 2026
5c7e1fa
Simplify
jaanerik Jan 30, 2026
c918c0c
Add xtensor hashing and restore is_full_slice for pymc-extras compati…
jaanerik Feb 1, 2026
c27841b
Remove comments and useless code
jaanerik Feb 4, 2026
d6c9153
Small refactor
jaanerik Feb 6, 2026
6359718
Revert XTensor refactor
jaanerik Feb 6, 2026
a52ce00
Implement AdvancedSubtensors in bool_idx_to_nonzero
jaanerik Feb 6, 2026
c905a21
Simplify numba dispatch
jaanerik Feb 9, 2026
efd5392
Add helper fun to not violate against DRY
jaanerik Feb 9, 2026
1eba923
Fix failing test with symbolic slice
jaanerik Feb 9, 2026
daad63f
Remove redundant code/comments and simplify
jaanerik Feb 10, 2026
4ec9707
Unify x,y,*index_variables naming convention
jaanerik Feb 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pytensor/graph/destroyhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,9 +771,9 @@ def orderings(self, fgraph, ordered=True):
}
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(
app.op, "destroyhandler_tolerate_aliased", []
app.op, "destroyhandler_tolerate_aliased", ()
)
assert isinstance(tolerate_aliased, list)
assert isinstance(tolerate_aliased, tuple | list)
ignored = {
idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx
}
Expand Down
36 changes: 3 additions & 33 deletions pytensor/link/jax/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice


BOOLEAN_MASK_ERROR = """JAX does not support resizing arrays with boolean
Expand All @@ -35,10 +34,8 @@
@jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)

def subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
indices = indices_from_subtensor(ilists, op.idx_list)
if len(indices) == 1:
indices = indices[0]

Expand All @@ -48,10 +45,9 @@ def subtensor(x, *ilists):


@jax_funcify.register(IncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor1)
def jax_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)

if getattr(op, "set_instead_of_inc", False):

def jax_fn(x, indices, y):
Expand All @@ -62,7 +58,7 @@ def jax_fn(x, indices, y):
def jax_fn(x, indices, y):
return x.at[indices].add(y)

def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=op.idx_list):
indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1:
indices = indices[0]
Expand All @@ -73,29 +69,3 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
return jax_fn(x, indices, y)

return incsubtensor


@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
if getattr(op, "set_instead_of_inc", False):

def jax_fn(x, indices, y):
return x.at[indices].set(y)

else:

def jax_fn(x, indices, y):
return x.at[indices].add(y)

def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)

return advancedincsubtensor


@jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
return slice(*x)

return makeslice
23 changes: 5 additions & 18 deletions pytensor/link/mlx/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice


@mlx_funcify.register(Subtensor)
def mlx_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)

def subtensor(x, *ilists):
indices = indices_from_subtensor([int(element) for element in ilists], idx_list)
indices = indices_from_subtensor(
[int(element) for element in ilists], op.idx_list
)
if len(indices) == 1:
indices = indices[0]

Expand All @@ -30,10 +29,8 @@ def subtensor(x, *ilists):
@mlx_funcify.register(AdvancedSubtensor)
@mlx_funcify.register(AdvancedSubtensor1)
def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)

def advanced_subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
indices = indices_from_subtensor(ilists, op.idx_list)
if len(indices) == 1:
indices = indices[0]

Expand All @@ -45,8 +42,6 @@ def advanced_subtensor(x, *ilists):
@mlx_funcify.register(IncSubtensor)
@mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)

if getattr(op, "set_instead_of_inc", False):

def mlx_fn(x, indices, y):
Expand All @@ -63,7 +58,7 @@ def mlx_fn(x, indices, y):
x[indices] += y
return x

def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list):
indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1:
indices = indices[0]
Expand Down Expand Up @@ -95,11 +90,3 @@ def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn):
return mlx_fn(x, ilist, y)

return advancedincsubtensor


@mlx_funcify.register(MakeSlice)
def mlx_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
return slice(*x)

return makeslice
Loading