
Parametrize Alloc by the dimensions that are actually broadcasted #1886

@ricardoV94

Description

Note

LLM-generated issue, overly focused on the XLA nudge; by no means a ready proposal

...

Problem/Goal
Currently, to broadcast a tensor (e.g., adding a new dimension or expanding a dimension of size 1), users often rely on operations like alloc or broadcast_to, which require specifying the full target shape. For dimensions that are simply inherited from the input unchanged, the user must explicitly fetch their sizes (e.g., x.shape[i]) and pass them to the Op. This introduces unnecessary Shape and Subtensor (indexing) nodes into the graph purely to reconstruct information it already contains. These redundant nodes clutter the graph, making it harder to reason about, optimize, and debug.
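
For concreteness, a minimal PyTensor sketch of the status quo (the concrete size 5 and the variable names are illustrative):

```python
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")  # shape (A, B), sizes unknown when the graph is built

# To produce (A, C, B) with C = 5 today, the unchanged sizes A and B must be
# read back off the input and passed in; each x.shape[i] adds Shape/Subtensor
# nodes whose only job is to reconstruct information the graph already has.
out = pt.alloc(pt.expand_dims(x, 1), x.shape[0], 5, x.shape[1])

pytensor.dprint(out)  # the printed graph shows the extra shape-plumbing nodes
```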

The goal is to implement a new Op that allows broadcasting specific dimensions without requiring the user to explicitly provide the sizes of unmodified dimensions. This aligns with XLA's capabilities (specifically BroadcastInDim) and simplifies graph construction.
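
For reference, JAX exposes this pattern directly as jax.lax.broadcast_in_dim (the shapes below are illustrative; JAX spells out the full target shape only because its shapes are concrete at trace time):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 3))  # (A, B)

# Input dim 0 -> output dim 0, input dim 1 -> output dim 2;
# output dim 1 is the new dimension C = 5.
y = lax.broadcast_in_dim(x, shape=(4, 5, 3), broadcast_dimensions=(0, 2))
assert y.shape == (4, 5, 3)  # (A, C, B)
```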

Technical Details / Proposed Changes
Implement a new Op (potentially named BroadcastInDim or similar) that behaves as follows (a runnable sketch of these semantics appears after the list):

  1. Accepts the input tensor.
  2. Accepts a specification for the new or modified dimensions only, rather than the full target shape.
    • This could be a mapping of output dimension indices to their new sizes (e.g., {1: new_size}), or a list of sizes for only the broadcast dimensions.
  3. Accepts a broadcast_dimensions argument (similar to XLA/JAX), which maps input dimensions to their corresponding output dimensions.
    • Example: If input x has shape (A, B) and we want output (A, C, B), broadcast_dimensions might be (0, 2), implying input dim 0 maps to output dim 0, and input dim 1 maps to output dim 2. Output dim 1 is the new dimension C.
  4. Infers the full output shape internally.
    • For dimensions mapped from the input, the Op should symbolically use the input's shape properties without creating explicit Shape/Subtensor nodes in the graph.
    • For new dimensions, it uses the provided sizes.
  5. Backend Integration:
    • Ensure the Op can be lowered efficiently to XLA's BroadcastInDim (or equivalent in other backends), where the full shape is constructed during compilation/lowering rather than in the high-level graph.
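
A runnable NumPy model of the semantics described in points 1–4, purely illustrative: broadcast_in_dim, out_ndim, new_sizes, and broadcast_dimensions are this proposal's hypothetical names, not an existing API.

```python
import numpy as np

def broadcast_in_dim(x, out_ndim, broadcast_dimensions, new_sizes):
    """Hypothetical reference semantics for the proposed Op (NumPy model).

    broadcast_dimensions[i] is the output dim that input dim i maps to;
    new_sizes maps each remaining (new or modified) output dim to its size.
    """
    out_shape = [None] * out_ndim
    for in_dim, out_dim in enumerate(broadcast_dimensions):
        out_shape[out_dim] = x.shape[in_dim]  # inherited: never user-supplied
    for out_dim, size in new_sizes.items():
        out_shape[out_dim] = size             # only new/modified dims get a size
    # Insert length-1 axes at the unmapped output positions, then broadcast.
    new_axes = [d for d in range(out_ndim) if d not in broadcast_dimensions]
    return np.broadcast_to(np.expand_dims(x, new_axes), tuple(out_shape))

x = np.ones((4, 3))  # (A, B)
y = broadcast_in_dim(x, out_ndim=3, broadcast_dimensions=(0, 2), new_sizes={1: 5})
assert y.shape == (4, 5, 3)  # (A, C, B)
```

Because new_sizes is applied after the inherited sizes, it can also override a mapped size-1 dimension, which covers the "modified dimensions" case from point 2. In the real Op the inherited entries would be the input's symbolic shape rather than concrete integers, so no Shape/Subtensor nodes appear in the user-facing graph.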

Benefits

  • Graph Simplification: Eliminates redundant Shape, MakeVector, and indexing nodes associated with retrieving input dimensions.
  • Improved Readability: The intent (broadcasting/inserting specific dimensions) is much clearer than a generic alloc call.
  • Better Shape Inference: The Op explicitly encodes the relationship "Output dim i == Input dim j", which aids static analysis and shape propagation.
  • Optimization: Facilitates pattern matching for rewrites and lowering to backends (like XLA/JAX) that support BroadcastInDim natively.
