This repository was archived by the owner on Mar 3, 2026. It is now read-only.

Bug in torch_xla._XLAC._xla_tensors_from_aten #353


Description

@liurupeng
It seems sharding on the sequence length is not supported when FSDP and DP are both enabled: the sharding logic cannot be resolved once three sharding specs are in play.
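For illustration, the intended three-way partition (fsdp, dp, sequence) can be sketched with NumPy. The mesh shape, axis names, and tensor shapes below are assumptions for the sketch, not the actual mesh or partition spec used in this repo:

```python
import numpy as np

# Hypothetical global activation: (batch, seq_len, hidden).
batch, seq_len, hidden = 8, 16, 4
x = np.arange(batch * seq_len * hidden, dtype=np.float32).reshape(batch, seq_len, hidden)

# Assumed 3-axis device mesh: fsdp=2, dp=2, sequence=2 -> 8 devices.
mesh = {"fsdp": 2, "dp": 2, "sequence": 2}

def shard(x, fsdp_idx, dp_idx, seq_idx):
    """Return the local shard a single device would hold.

    Analogous to a partition spec of (("fsdp", "dp"), "sequence", None):
    batch is split across fsdp*dp, seq_len across the sequence axis,
    hidden is replicated.
    """
    b_shards = mesh["fsdp"] * mesh["dp"]
    b_size = x.shape[0] // b_shards
    s_size = x.shape[1] // mesh["sequence"]
    b0 = (fsdp_idx * mesh["dp"] + dp_idx) * b_size
    s0 = seq_idx * s_size
    return x[b0:b0 + b_size, s0:s0 + s_size, :]

local = shard(x, fsdp_idx=1, dp_idx=0, seq_idx=1)
print(local.shape)  # (2, 8, 4)
```

The point is that each device's local shard is determined by all three mesh axes at once, which is exactly the case the lowering currently fails to handle.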

The issue is here: https://github.com/pytorch/xla/blob/cd3bd91f1b959c27047196855649a6a933023428/torch_xla/core/xla_model.py#L1297

After printing the sharding inside convert_fn, it is correct, but the torch-xla function still cannot complete the sharding; the root cause is at a lower level, in the _xla_tensors_from_aten function.


To use context parallelism, we need to bypass the parallel_loader and feed the model activations that are already sharded correctly.
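A minimal sketch of that bypass, in plain Python: wrap an ordinary iterator and apply a user-supplied sharding callback to each batch before the training step. The names below (sharded_batches, shard_fn) are illustrative stand-ins, not the project's actual API; in practice shard_fn would be something like torch_xla SPMD's mark_sharding applied after moving the batch to device:

```python
def sharded_batches(loader, shard_fn):
    """Yield batches with the sharding callback already applied,
    instead of relying on the framework's parallel loader."""
    for batch in loader:
        yield shard_fn(batch)

# Dummy loader and shard_fn for demonstration only.
raw_loader = [list(range(4)), list(range(4, 8))]

def shard_fn(batch):
    # Pretend each device keeps the first half of the sequence dimension.
    return batch[: len(batch) // 2]

locals_seen = list(sharded_batches(raw_loader, shard_fn))
print(locals_seen)  # [[0, 1], [4, 5]]
```

The design point is that sharding happens explicitly in user code, so the loader never has to infer a three-spec sharding on its own.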
