Note: LLM-generated issue, overly focused on the XLA nudge; by no means a ready proposal.
...
Problem/Goal
Currently, to broadcast a tensor (e.g., adding a new dimension or expanding a dimension of size 1), users often rely on operations like `alloc` or `broadcast_to`, which require specifying the full target shape. For dimensions that are simply inherited from the input without modification, the user must explicitly fetch their sizes (e.g., `x.shape[i]`) and pass them to the Op. This introduces unnecessary `Shape` and `Subtensor` (indexing) nodes to the graph purely to reconstruct known information. These redundant nodes clutter the graph, making it harder to reason about, optimize, and debug (see the sketch below).
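For illustration, a minimal sketch of the status quo, assuming a PyTensor-style API (`pt.matrix`, `pt.broadcast_to`, symbolic `x.shape`); exact names may differ:

```python
import pytensor.tensor as pt

x = pt.matrix("x")  # shape (A, B), sizes unknown at graph-construction time

# Inserting a new middle dimension of size 3 requires spelling out the full
# target shape, pulling A and B back out of the graph via Shape/Subtensor
# nodes even though those sizes are already carried by the input:
y = pt.broadcast_to(x[:, None, :], (x.shape[0], 3, x.shape[1]))
```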
The goal is to implement a new Op that allows broadcasting specific dimensions without requiring the user to explicitly provide the sizes of unmodified dimensions. This aligns with XLA's capabilities (specifically `BroadcastInDim`) and simplifies graph construction.
Technical Details / Proposed Changes
Implement a new Op (potentially named `BroadcastInDim` or similar) that:
- Accepts the input tensor.
- Accepts a specification for the new or modified dimensions only, rather than the full target shape.
  - This could be a mapping of output dimension indices to their new sizes (e.g., `{1: new_size}`), or a list of sizes for only the broadcast dimensions.
- Accepts a `broadcast_dimensions` argument (similar to XLA/JAX), which maps input dimensions to their corresponding output dimensions.
  - Example: If input `x` has shape `(A, B)` and we want output `(A, C, B)`, `broadcast_dimensions` might be `(0, 2)`, implying input dim 0 maps to output dim 0 and input dim 1 maps to output dim 2. Output dim 1 is the new dimension `C` (see the JAX example after this list).
- Infers the full output shape internally.
  - For dimensions mapped from the input, the Op should symbolically use the input's shape properties without creating explicit `Shape`/`Subtensor` nodes in the graph.
  - For new dimensions, it uses the provided sizes.
- Backend Integration: Ensure the Op can be lowered efficiently to XLA's `BroadcastInDim` (or equivalent in other backends), where the full shape is constructed during compilation/lowering rather than in the high-level graph.
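For reference, a runnable JAX snippet showing the `broadcast_dimensions` semantics described above; note that JAX's `lax.broadcast_in_dim` still takes the full target shape, whereas the proposed Op would infer the inherited sizes internally:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 5))  # input shape (A, B) with A=4, B=5

# broadcast_dimensions=(0, 2): input dim 0 -> output dim 0,
# input dim 1 -> output dim 2; output dim 1 is the new dimension C=3.
y = lax.broadcast_in_dim(x, shape=(4, 3, 5), broadcast_dimensions=(0, 2))
assert y.shape == (4, 3, 5)
```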
Benefits
- Graph Simplification: Eliminates redundant `Shape`, `MakeVector`, and indexing nodes associated with retrieving input dimensions.
- Improved Readability: The intent (broadcasting/inserting specific dimensions) is much clearer than a generic `alloc` call.
- Better Shape Inference: The Op explicitly encodes the relationship "output dim `i` == input dim `j`", which aids static analysis and shape propagation (see the helper sketch after this list).
- Optimization: Facilitates pattern matching for rewrites and lowering to backends (like XLA/JAX) that support `BroadcastInDim` natively.
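As a concrete illustration of that shape-inference rule, here is a hypothetical, self-contained helper (all names are illustrative, not an existing API) that fills inherited output slots directly from the input's sizes, symbolic or concrete, and spends user-provided sizes only on the new dimensions:

```python
def infer_output_shape(in_shape, broadcast_dimensions, new_sizes):
    """Map input dims to output slots via broadcast_dimensions; fill the
    remaining slots from new_sizes in order. Sizes may be symbolic."""
    ndim_out = len(broadcast_dimensions) + len(new_sizes)
    out = [None] * ndim_out
    for in_dim, out_dim in enumerate(broadcast_dimensions):
        out[out_dim] = in_shape[in_dim]  # output dim out_dim == input dim in_dim
    remaining = iter(new_sizes)
    out = [size if size is not None else next(remaining) for size in out]
    return tuple(out)

# Matches the (A, B) -> (A, C, B) example above:
assert infer_output_shape((4, 5), broadcast_dimensions=(0, 2), new_sizes=(3,)) == (4, 3, 5)
```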