This repository was archived by the owner on Mar 3, 2026. It is now read-only.

Data seems not sharded across processes in multi-host single-slice setting #267

Description

@weirayao

Hi, thanks for the great work. We are running the sample script train.py on a v4-16 with the mesh configured as follows:

XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 python torchprime/torch_xla_models/train.py \
    model=llama-3-8b \
    global_batch_size=8 \
    block_size=4096 \
    max_steps=1000 \
    ici_mesh.fsdp=8 \
    ici_mesh.tensor=1 \
    ici_mesh.data=1 \
    ici_mesh.expert=1

We launch the above script on all workers with gcloud:

gcloud alpha compute tpus tpu-vm ssh xxx-v4-16 \
    --zone=xxx \
    --project=xxx \
    --tunnel-through-iap \
    --worker=all \
    --command="bash train.sh"

When printing out the batch on each host/process, we get the following logs:

[2025-05-30 23:22:00,939][main][INFO] - Logical mesh shape: OrderedDict([('data', 1), ('fsdp', 8), ('tensor', 1), ('expert', 1)])
[2025-05-30 23:22:00,939][main][INFO] - Logical mesh device assignments: [0 1 2 3 4 5 6 7]
[2025-05-30 23:22:00,940][main][INFO] - Minibatch dataloading: True
[2025-05-30 23:22:51,649][main][INFO] - All processes synchronized, starting training
[2025-05-30 23:22:51,649][main][INFO] - Num replicas: 2
[2025-05-30 23:22:51,649][main][INFO] - Num replicas: 2

Host: 0, batch: {'input_ids': tensor([[ 284, 23947, 1231, ..., 220, 19, 51857],
[ 364, 82, 11094, ..., 1365, 448, 45556],
[ 323, 30739, 311, ..., 9886, 87152, 8933],
...,
[79323, 14336, 304, ..., 1393, 24917, 307],
[ 284, 715, 10751, ..., 315, 98098, 62940],
[ 389, 5813, 220, ..., 11678, 39294, 892]], device='xla:0'), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
...,
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]], device='xla:0'), 'labels': tensor([[ 284, 23947, 1231, ..., 220, 19, 51857],
[ 364, 82, 11094, ..., 1365, 448, 45556],
[ 323, 30739, 311, ..., 9886, 87152, 8933],
...,
[79323, 14336, 304, ..., 1393, 24917, 307],
[ 284, 715, 10751, ..., 315, 98098, 62940],
[ 389, 5813, 220, ..., 11678, 39294, 892]], device='xla:0')}, shape: torch.Size([8, 4096])
Host: 1, batch: {'input_ids': tensor([[ 284, 23947, 1231, ..., 220, 19, 51857],
[ 364, 82, 11094, ..., 1365, 448, 45556],
[ 323, 30739, 311, ..., 9886, 87152, 8933],
...,
[79323, 14336, 304, ..., 1393, 24917, 307],
[ 284, 715, 10751, ..., 315, 98098, 62940],
[ 389, 5813, 220, ..., 11678, 39294, 892]], device='xla:0'), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
...,
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]], device='xla:0'), 'labels': tensor([[ 284, 23947, 1231, ..., 220, 19, 51857],
[ 364, 82, 11094, ..., 1365, 448, 45556],
[ 323, 30739, 311, ..., 9886, 87152, 8933],
...,
[79323, 14336, 304, ..., 1393, 24917, 307],
[ 284, 715, 10751, ..., 315, 98098, 62940],
[ 389, 5813, 220, ..., 11678, 39294, 892]], device='xla:0')}, shape: torch.Size([8, 4096])

This seems to imply that each host is receiving the full global batch (8 samples) even though minibatch dataloading is enabled, and the two hosts are loading identical data. Any idea why this happens? Also, the printed loss is identical on the two hosts. Many thanks!!!
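For reference, here is a minimal arithmetic sketch of what we expected, assuming a v4-16 has 2 TPU VM workers with 4 chips each (8 devices total, matching the `fsdp=8` mesh axis and the "Num replicas: 2" log line). The variable names are ours, not torchprime's:

```python
# Expected per-host minibatch size for the run described above.
# Assumption: v4-16 = 2 hosts x 4 chips; fsdp axis spans all 8 devices.
global_batch_size = 8
num_hosts = 2          # matches the "Num replicas: 2" log
devices_per_host = 4
fsdp = 8               # ici_mesh.fsdp from the command above

num_devices = num_hosts * devices_per_host
assert fsdp == num_devices  # fsdp shards data across all 8 devices

# With minibatch dataloading, each host should only load its own slice
# of the global batch:
per_host_batch = global_batch_size // num_hosts
print(per_host_batch)  # we expected 4 samples per host
```

Instead, the logs show `torch.Size([8, 4096])` on both hosts, i.e. every host is loading the full global batch with identical contents.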
