Generalize (join|split)_dims to work with arbitrary axes locations #1844

@ricardoV94

Description

As discussed in #1842 (comment)

Instead of forcing the input axes to be consecutive in join_dims, and the output axes to come out consecutive in split_dims, we could generalize both so the axes can map anywhere while the two functions remain functional inverses of each other.

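from collections.abc import Sequence

import pytensor.tensor as pt

# JoinDims and SplitDims are the existing Ops behind join_dims / split_dims

def join_dims(x, axes: int | Sequence[int] | None = None, output_axis: int = 0):
  if axes is None:
    axes = tuple(range(x.ndim))
  if isinstance(axes, int):
    axes = (axes,)  # No dims actually joined

  # I am not sure I want to do this, but we could make output_axis default to the first of `axes` if these are consecutive,
  # or zero otherwise. Much like numpy advanced indexing decides where to place the advanced views of the array.
  # This would be more backward-compatible with the old behavior of `join_dims`, which only supported consecutive axes.

  # moveaxis needs one destination per source axis: gather `axes` at output_axis, then merge them
  destination = tuple(range(output_axis, output_axis + len(axes)))
  return JoinDims(output_axis, len(axes))(pt.moveaxis(x, axes, destination))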


def split_dims(x, shape, axis: int = 0, output_axes: int | Sequence[int] | None = None):
  if output_axes is None:
    # Default is the same location as the input axis
    output_axes = axis
  if isinstance(output_axes, int):
    output_axes = tuple(range(output_axes, output_axes + len(shape)))

  # SplitDims leaves the new axes at axis, ..., axis + len(shape) - 1; move them to output_axes
  axes = tuple(range(axis, axis + len(shape)))
  return pt.moveaxis(SplitDims(axis)(x, shape), axes, output_axes)
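
To make the intended semantics concrete, here is a minimal NumPy sketch of the same idea, with a round trip showing the two functions acting as inverses. The names `join_dims_np` and `split_dims_np` are hypothetical stand-ins written for illustration; they are not PyTensor API, and they assume non-negative axes.

```python
import math

import numpy as np


def join_dims_np(x, axes=None, output_axis=0):
    """Hypothetical NumPy stand-in: merge `axes` into one axis at `output_axis`."""
    if axes is None:
        axes = tuple(range(x.ndim))
    if isinstance(axes, int):
        axes = (axes,)
    n = len(axes)
    # Gather the joined axes, in the order given, starting at output_axis.
    moved = np.moveaxis(x, axes, tuple(range(output_axis, output_axis + n)))
    joined_size = math.prod(moved.shape[output_axis:output_axis + n])
    new_shape = moved.shape[:output_axis] + (joined_size,) + moved.shape[output_axis + n:]
    return moved.reshape(new_shape)


def split_dims_np(x, shape, axis=0, output_axes=None):
    """Hypothetical NumPy stand-in: split `axis` into `shape`, placed at `output_axes`."""
    shape = tuple(shape)
    n = len(shape)
    if output_axes is None:
        output_axes = axis
    if isinstance(output_axes, int):
        output_axes = tuple(range(output_axes, output_axes + n))
    # Split in place first, then scatter the new axes to where they belong.
    split = x.reshape(x.shape[:axis] + shape + x.shape[axis + 1:])
    return np.moveaxis(split, tuple(range(axis, axis + n)), output_axes)


x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# Join the non-consecutive axes 0 and 2 into a single axis at position 1.
joined = join_dims_np(x, axes=(0, 2), output_axis=1)

# Undo it: split axis 1 back into (2, 4) and send the pieces to axes 0 and 2.
restored = split_dims_np(joined, shape=(2, 4), axis=1, output_axes=(0, 2))
```

Here `joined` has shape `(3, 8, 5)`, and `restored` matches `x` exactly: the arguments of the split call mirror those of the join call (`axes` becomes `output_axes`, `output_axis` becomes `axis`).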

We should decide pretty soon, as it's a breaking change, and we want to bring these ops to the spotlight.
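
The default floated in the code comment (first of `axes` when they are consecutive, zero otherwise) could be sketched as below. `default_output_axis` is a hypothetical helper used only to illustrate the rule; it assumes the axes are already normalized to non-negative integers and treats only increasing runs as consecutive.

```python
def default_output_axis(axes):
    """Hypothetical rule: first of `axes` if they form a consecutive increasing run, else 0."""
    axes = tuple(axes)
    consecutive = all(b - a == 1 for a, b in zip(axes, axes[1:]))
    return axes[0] if axes and consecutive else 0


# Consecutive axes keep the joined axis where the run started (the old behavior).
old_style = default_output_axis((1, 2, 3))

# Non-consecutive axes fall back to the front.
scattered = default_output_axis((0, 2))
```

With this rule, calls that already pass consecutive axes would keep their current output layout, so only the new non-consecutive case picks up the zero default.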
