
train_llama_torchtitan example: Multiple issues running on TPU v4-32 #67

@KarthikRevanuru

Description


Summary

I ran into multiple issues while setting up and running the train_llama_torchtitan example on TPU v4-32. This issue documents each problem and its fix.

Environment

  • Hardware: TPU v4-32 (4 workers)
  • Python: 3.11

Issue 1: Missing splash_attn.py file

The example references splash_attn.py for TPU flash attention, but the file is not automatically included when the example files are copied.

Fix: Manually download splash_attn.py from the repository.


Issue 2: Python version compatibility

The TPU VM defaulted to an older Python version incompatible with the dependencies.

Fix: Install and use Python 3.11 explicitly in install_deps.sh.
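A minimal sketch of the install_deps.sh change. The deadsnakes PPA lines are an assumption about how Python 3.11 gets onto an Ubuntu-based TPU VM; the key point is invoking an explicit python3.11 binary rather than the VM's default python3:

```shell
# Sketch for install_deps.sh: use Python 3.11 explicitly.
# On an Ubuntu TPU VM the deadsnakes PPA can supply it (assumption):
#   sudo add-apt-repository -y ppa:deadsnakes/ppa
#   sudo apt-get update && sudo apt-get install -y python3.11 python3.11-venv

# Prefer python3.11 when present; fall back to the default python3.
PYTHON_BIN="$(command -v python3.11 || command -v python3)"
"$PYTHON_BIN" --version
```

All later pip installs in the script should then go through "$PYTHON_BIN" -m pip so they target the same interpreter.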


Issue 3: pip install -e . fails for splash_attn

The install script tried to run pip install -e . for splash_attn, but it's a standalone file, not a package.

Fix: Remove the pip install -e . line for splash_attn from install_deps.sh.


Issue 4: freqs_cis KeyError / state_dict mismatch

Error:

KeyError: 'freqs_cis'

or

RuntimeError: Unexpected key(s) in state_dict: "freqs_cis"

Cause: The code assumes the scan-mode handling of freqs_cis. When use_scan=False, the model structure differs, so loading the same state_dict fails.

Fix: Conditionally handle freqs_cis based on use_scan:

state_dict = create_sharded_weights(gpt, mesh, sharding_map)
# Drop any stale freqs_cis entry; it is re-added only where the model expects it.
state_dict.pop("freqs_cis", None)
if use_scan:
    # In scan mode the model carries freqs_cis inside its state_dict.
    state_dict["freqs_cis"] = freqs_cis.to("jax").apply_jax(jax.device_put, replicated)
gpt.load_state_dict(state_dict, assign=True)
if not use_scan:
    # In non-scan mode freqs_cis is a plain module attribute instead.
    gpt.freqs_cis = freqs_cis.to("jax").apply_jax(jax.device_put, replicated)
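The key-handling part of the fix can be checked device-independently with plain dictionaries (a minimal sketch; prepare_state_dict is a hypothetical helper and the real tensors are replaced by placeholder values):

```python
def prepare_state_dict(state_dict, freqs_cis, use_scan):
    """Mirror the fix: freqs_cis lives in the state_dict only in scan mode."""
    state_dict = dict(state_dict)          # don't mutate the caller's dict
    state_dict.pop("freqs_cis", None)      # drop any stale entry unconditionally
    if use_scan:
        state_dict["freqs_cis"] = freqs_cis
    return state_dict

base = {"tok_embeddings.weight": "w", "freqs_cis": "stale"}
assert "freqs_cis" in prepare_state_dict(base, "fresh", use_scan=True)
assert "freqs_cis" not in prepare_state_dict(base, "fresh", use_scan=False)
```

Either way, the stale entry is removed, which avoids both the KeyError and the "Unexpected key(s)" RuntimeError shown above.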

Issue 5: Embedding sharding error with tp_parallelism=1

Error:

jax._src.core.ShardingTypeError: Use `.at[...].get(out_sharding=)` to specify the output sharding of a gather from a sharded source.
Got operand=ShapedArray(bfloat16[128256@fsdp,4096@tp]), indices=ShapedArray(int32[16@fsdp,2048,1])

Cause: sharding_map_original shards tok_embeddings.weight as ("fsdp", "tp"). With tp_parallelism=1, the embedding lookup creates an ambiguous gather that JAX cannot resolve.

Fix: Add sharding_map_original_fsdp for non-scan mode with tp_parallelism=1:

sharding_map_original_fsdp = {
    "tok_embeddings.weight": (None, "fsdp"),  # vocab replicated, hidden sharded
    "output.weight": (None, "fsdp"),
    # ... other weights with FSDP-only sharding
}
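To make the spec change concrete, here is a plain-Python illustration (a sketch, not part of the example: None means a dimension is replicated, an axis name means it is sharded over that mesh axis):

```python
# Old vs. new spec for tok_embeddings.weight, shaped (vocab, hidden).
old_spec = ("fsdp", "tp")   # vocab sharded over fsdp, hidden over tp
new_spec = (None, "fsdp")   # vocab replicated, hidden sharded over fsdp

def sharded_axes(spec):
    """Return the mesh axes a tensor is actually sharded over."""
    return [ax for ax in spec if ax is not None]

print(sharded_axes(old_spec))  # ['fsdp', 'tp']
print(sharded_axes(new_spec))  # ['fsdp']
```

With the new spec, the gather for the embedding lookup reads from a vocab dimension that is replicated, so JAX no longer needs an out_sharding hint to resolve it.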

Update selection logic to use this map when use_scan=False and tp_parallelism=1.


Issue 6: Missing sharding map selection for scan mode with tp_parallelism=1

The code always selected sharding_map_scan, regardless of the tp_parallelism value.

Fix: Update selection logic:

# Choose the sharding map from both the scan mode and the tensor-parallel degree.
if use_scan:
    if tp_parallelism == 1:
        sharding_map = sharding_map_scan_fsdp
    else:
        sharding_map = sharding_map_scan
else:
    if tp_parallelism == 1:
        sharding_map = sharding_map_original_fsdp
    else:
        sharding_map = sharding_map_original
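The selection logic is small enough to unit-test. A hedged sketch, with select_sharding_map as a hypothetical helper and the maps replaced by sentinel strings:

```python
def select_sharding_map(use_scan, tp_parallelism, maps):
    """Pick a sharding map from scan mode and tensor-parallel degree."""
    if use_scan:
        return maps["scan_fsdp"] if tp_parallelism == 1 else maps["scan"]
    return maps["original_fsdp"] if tp_parallelism == 1 else maps["original"]

# Sentinel values stand in for the real sharding-map dicts.
maps = {k: k for k in ("scan", "scan_fsdp", "original", "original_fsdp")}
assert select_sharding_map(True, 1, maps) == "scan_fsdp"
assert select_sharding_map(True, 4, maps) == "scan"
assert select_sharding_map(False, 1, maps) == "original_fsdp"
assert select_sharding_map(False, 4, maps) == "original"
```

Covering all four combinations up front would have caught Issues 5 and 6 before any TPU run.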
