Description
As discussed in #1842 (comment)
Instead of forcing the input axes of `join_dims` to be consecutive, and the output axes of `split_dims` to come out consecutive, we could generalize both so the axes can map anywhere while the two functions remain inverses of each other.
```python
from collections.abc import Sequence

import pytensor.tensor as pt

# JoinDims / SplitDims are the existing PyTensor Ops (import omitted here).


def join_dims(x, axes: int | Sequence[int] | None = None, output_axis: int = 0):
    if axes is None:
        axes = tuple(range(x.ndim))
    if isinstance(axes, int):
        axes = (axes,)  # No dims actually joined
    # I am not sure I want to do this, but we could make output_axis default to the first of `axes`
    # if these are consecutive, or to zero otherwise, much like NumPy advanced indexing decides where
    # to place the advanced-indexed dimensions. This would be more backward-compatible with the old
    # behavior of `join_dims`, which only supported consecutive axes.
    # moveaxis needs one destination per source axis: put the joined axes side by side at output_axis
    dest = tuple(range(output_axis, output_axis + len(axes)))
    return JoinDims(output_axis, len(axes))(pt.moveaxis(x, axes, dest))


def split_dims(x, shape, axis: int = 0, output_axes: int | Sequence[int] | None = None):
    if output_axes is None:
        # Default is the same location as the input axis
        output_axes = axis
    if isinstance(output_axes, int):
        output_axes = tuple(range(output_axes, output_axes + len(shape)))
    # After the split, the new axes sit consecutively starting at `axis`
    axes = tuple(range(axis, axis + len(shape)))
    # Assumes SplitDims receives the target shape as a second input
    return pt.moveaxis(SplitDims(axis)(x, shape), axes, output_axes)
```

We should decide soon, since this is a breaking change and we want to bring these Ops into the spotlight.
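For the optional default floated in the `join_dims` comments, here is a minimal sketch of the placement rule next to NumPy's own advanced-indexing behavior. `default_output_axis` is a hypothetical helper, and treating only increasing runs such as `(1, 2)` as "consecutive" (so a reversed run like `(2, 1)` falls back to 0) is an assumption; the proposal only says "consecutive".

```python
import numpy as np
from collections.abc import Sequence


def default_output_axis(axes: Sequence[int], ndim: int) -> int:
    """Hypothetical default for `join_dims`' `output_axis` (assumes `axes` is non-empty).

    If the joined axes are consecutive and increasing, keep the result where the
    first of them was; otherwise place it at position 0.
    """
    axes = [a % ndim for a in axes]  # normalize negative axes
    consecutive = all(b == a + 1 for a, b in zip(axes, axes[1:]))
    return axes[0] if consecutive else 0


print(default_output_axis((1, 2), ndim=4))  # 1
print(default_output_axis((3, 1), ndim=4))  # 0

# NumPy's advanced-indexing placement rule, for comparison
x = np.zeros((2, 3, 4, 5))
idx = np.zeros(6, dtype=int)
print(x[:, idx, idx, :].shape)  # (2, 6, 5): adjacent index arrays stay in place
print(x[:, idx, :, idx].shape)  # (6, 2, 4): separated index arrays move to the front
```

With this default, calls that only join consecutive axes behave like the old `join_dims`, while non-consecutive joins follow the same "move to the front" convention NumPy users already know.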