Data does not seem to be sharded across processes in multi-host single-slice setting #267
Description
Hi, thanks for the great work. We are running the sample script train.py on a v4-16 with the following mesh:
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 python torchprime/torch_xla_models/train.py \
  model=llama-3-8b \
  global_batch_size=8 \
  block_size=4096 \
  max_steps=1000 \
  ici_mesh.fsdp=8 \
  ici_mesh.tensor=1 \
  ici_mesh.data=1 \
  ici_mesh.expert=1
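For context on what we expected: with minibatch dataloading enabled, each host should only load its own disjoint shard of the global batch. A minimal sketch of that arithmetic for this run (the numbers come from our config; the sharding scheme shown is our assumption of the intended behavior, not torchprime code):

```python
# Expected per-host minibatch arithmetic for this run (illustrative only).
GLOBAL_BATCH_SIZE = 8
NUM_HOSTS = 2  # a v4-16 slice has 2 hosts (8 chips total)

# With minibatch dataloading, each host should load a disjoint shard.
per_host_batch = GLOBAL_BATCH_SIZE // NUM_HOSTS

# Under a simple contiguous split, host h would see rows
# [h * per_host_batch, (h + 1) * per_host_batch) of the global batch.
shards = {
    h: list(range(h * per_host_batch, (h + 1) * per_host_batch))
    for h in range(NUM_HOSTS)
}
print(per_host_batch)  # expected local batch size: 4, not the 8 we observe
print(shards)          # {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]}
```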
We run the above script via gcloud:
gcloud alpha compute tpus tpu-vm ssh xxx-v4-16 \
  --zone=xxx \
  --project=xxx \
  --tunnel-through-iap \
  --worker=all \
  --command="bash train.sh"
When printing out the batch on each host/process, we get the following logs:
[2025-05-30 23:22:00,939][main][INFO] - Logical mesh shape: OrderedDict([('data', 1), ('fsdp', 8), ('tensor', 1), ('expert', 1)])
[2025-05-30 23:22:00,939][main][INFO] - Logical mesh device assignments: [0 1 2 3 4 5 6 7]
[2025-05-30 23:22:00,940][main][INFO] - Minibatch dataloading: True
[2025-05-30 23:22:51,649][main][INFO] - All processes synchronized, starting training
[2025-05-30 23:22:51,649][main][INFO] - Num replicas: 2
Host: 0, batch: {'input_ids': tensor([[ 284, 23947, 1231, ..., 220, 19, 51857],
[ 364, 82, 11094, ..., 1365, 448, 45556],
[ 323, 30739, 311, ..., 9886, 87152, 8933],
...,
[79323, 14336, 304, ..., 1393, 24917, 307],
[ 284, 715, 10751, ..., 315, 98098, 62940],
[ 389, 5813, 220, ..., 11678, 39294, 892]], device='xla:0'), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
...,
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]], device='xla:0'), 'labels': tensor([[ 284, 23947, 1231, ..., 220, 19, 51857],
[ 364, 82, 11094, ..., 1365, 448, 45556],
[ 323, 30739, 311, ..., 9886, 87152, 8933],
...,
[79323, 14336, 304, ..., 1393, 24917, 307],
[ 284, 715, 10751, ..., 315, 98098, 62940],
[ 389, 5813, 220, ..., 11678, 39294, 892]], device='xla:0')}, shape: torch.Size([8, 4096])
Host: 1, batch: {'input_ids': tensor([[ 284, 23947, 1231, ..., 220, 19, 51857],
[ 364, 82, 11094, ..., 1365, 448, 45556],
[ 323, 30739, 311, ..., 9886, 87152, 8933],
...,
[79323, 14336, 304, ..., 1393, 24917, 307],
[ 284, 715, 10751, ..., 315, 98098, 62940],
[ 389, 5813, 220, ..., 11678, 39294, 892]], device='xla:0'), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
...,
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]], device='xla:0'), 'labels': tensor([[ 284, 23947, 1231, ..., 220, 19, 51857],
[ 364, 82, 11094, ..., 1365, 448, 45556],
[ 323, 30739, 311, ..., 9886, 87152, 8933],
...,
[79323, 14336, 304, ..., 1393, 24917, 307],
[ 284, 715, 10751, ..., 315, 98098, 62940],
[ 389, 5813, 220, ..., 11678, 39294, 892]], device='xla:0')}, shape: torch.Size([8, 4096])
This seems to imply that each host is receiving the full global batch (8 samples) even though minibatch dataloading is enabled. Any idea why this happens? Also, the printed loss is identical on the two hosts. Many thanks!
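For what it's worth, the way we confirmed the two hosts' batches are identical was by comparing a digest of the batch contents across hosts. A small sketch of that check, shown here on plain token-id lists (in the actual run we hashed the bytes of the `input_ids` tensor after copying it off the device; the helper name is ours, not a torchprime API):

```python
import hashlib


def batch_digest(rows):
    """Hash a batch (a list of token-id rows) into a short hex digest.

    On TPU you would hash the raw tensor bytes instead, e.g. the result
    of moving input_ids to CPU; this list version just shows the
    comparison idea.
    """
    h = hashlib.sha256()
    for row in rows:
        h.update(",".join(map(str, row)).encode())
        h.update(b";")  # row separator so [[1,2],[3]] != [[1],[2,3]]
    return h.hexdigest()[:16]


# If sharding worked, the digests printed on each host should differ.
host0_batch = [[284, 23947, 1231], [364, 82, 11094]]
host1_batch = [[284, 23947, 1231], [364, 82, 11094]]  # identical rows, as in our logs
print(batch_digest(host0_batch) == batch_digest(host1_batch))  # True => duplicated data
```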