train_llama_torchtitan example: Multiple issues running on TPU v4-32
Summary
I encountered several issues while setting up and running the train_llama_torchtitan example on a TPU v4-32. This issue documents each problem and its fix.
Environment
- Hardware: TPU v4-32 (4 workers)
- Python: 3.11
Issue 1: Missing splash_attn.py file
The example references splash_attn.py for TPU flash attention, but the file is not automatically included when copying the example files.
Fix: Manually download splash_attn.py from the repository.
Issue 2: Python version compatibility
The TPU VM defaulted to an older Python version incompatible with the dependencies.
Fix: Install and use Python 3.11 explicitly in install_deps.sh.
Issue 3: pip install -e . fails for splash_attn
The install script tried to run pip install -e . for splash_attn, but it's a standalone file, not a package.
Fix: Remove the pip install -e . line for splash_attn from install_deps.sh.
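One way to apply the fix is a sed one-liner. The script contents below are a stand-in for the real install_deps.sh, and the exact wording of the offending line is an assumption; adjust the pattern to match your copy.

```shell
# Stand-in install_deps.sh; run the sed against the real script instead.
printf '%s\n' 'pip install torch torchtitan' 'pip install -e ./splash_attn' > install_deps.sh
# Delete the editable-install line for splash_attn (it is a single file, not a package).
sed -i '/pip install -e.*splash_attn/d' install_deps.sh
cat install_deps.sh
```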
Issue 4: freqs_cis KeyError / state_dict mismatch
Error:
KeyError: 'freqs_cis'
or
RuntimeError: Unexpected key(s) in state_dict: "freqs_cis"
Cause: The code assumes scan-mode handling of freqs_cis. When use_scan=False, the model structure differs, so the checkpoint keys no longer match the module's state_dict.
Fix: Conditionally handle freqs_cis based on use_scan:
state_dict = create_sharded_weights(gpt, mesh, sharding_map)
if "freqs_cis" in state_dict:
    state_dict.pop("freqs_cis")
if use_scan:
    state_dict["freqs_cis"] = freqs_cis.to("jax").apply_jax(jax.device_put, replicated)
gpt.load_state_dict(state_dict, assign=True)
if not use_scan:
    gpt.freqs_cis = freqs_cis.to("jax").apply_jax(jax.device_put, replicated)
Issue 5: Embedding sharding error with tp_parallelism=1
Error:
jax._src.core.ShardingTypeError: Use `.at[...].get(out_sharding=)` to specify the output sharding of a gather from a sharded source.
Got operand=ShapedArray(bfloat16[128256@fsdp,4096@tp]), indices=ShapedArray(int32[16@fsdp,2048,1])
Cause: sharding_map_original shards tok_embeddings.weight as ("fsdp", "tp"). With tp_parallelism=1, the embedding lookup creates an ambiguous gather that JAX cannot resolve.
Fix: Add sharding_map_original_fsdp for non-scan mode with tp_parallelism=1:
sharding_map_original_fsdp = {
    "tok_embeddings.weight": (None, "fsdp"),  # vocab replicated, hidden sharded
    "output.weight": (None, "fsdp"),
    # ... other weights with FSDP-only sharding
}
Update the selection logic to use this map when use_scan=False and tp_parallelism=1.
Issue 6: Missing sharding map selection for scan mode with tp_parallelism=1
The code selected sharding_map_scan regardless of the tp_parallelism value.
Fix: Update selection logic:
if use_scan:
    if tp_parallelism == 1:
        sharding_map = sharding_map_scan_fsdp
    else:
        sharding_map = sharding_map_scan
else:
    if tp_parallelism == 1:
        sharding_map = sharding_map_original_fsdp
    else:
        sharding_map = sharding_map_original
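The same selection can be wrapped in a small helper, which keeps the four-way choice testable. The function name and the `maps` dict are illustrative, not from the example's source; in the real code the values would be the four sharding-map dicts defined above.

```python
def select_sharding_map(use_scan, tp_parallelism, maps):
    """Return the sharding map for the given model layout and TP degree."""
    if use_scan:
        return maps["scan_fsdp"] if tp_parallelism == 1 else maps["scan"]
    return maps["original_fsdp"] if tp_parallelism == 1 else maps["original"]

# Stand-in values; in the example these are the actual sharding-map dicts.
maps = {"scan": "scan", "scan_fsdp": "scan_fsdp",
        "original": "original", "original_fsdp": "original_fsdp"}
print(select_sharding_map(use_scan=False, tp_parallelism=1, maps=maps))  # original_fsdp
```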