
Trainer: Add JAX distributed training guide to Kubeflow Trainer docs#4305

Open
Amir380-A wants to merge 5 commits into kubeflow:master from Amir380-A:docs/jax-guide

Conversation

@Amir380-A

Description of Changes

This PR adds a new JAX user guide describing how to run distributed JAX
training jobs with Kubeflow Trainer.

The guide covers:

  • JAX SPMD execution model overview
  • Built-in jax-distributed runtime
  • Required JAX distributed initialization
  • TrainJob configuration and example (YAML and Python SDK)
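For context, the "required JAX distributed initialization" the guide describes boils down to a single `jax.distributed.initialize()` call driven by Trainer-injected environment variables. A minimal sketch, assuming the variable names used elsewhere in this thread (`init_jax_distributed` is an illustrative helper, not part of the SDK):

```python
import os


def init_jax_distributed() -> None:
    """Connect this process to the JAX cluster before any computation.

    Assumes Kubeflow Trainer injects the three variables below into
    each trainer container, as the guide describes.
    """
    import jax.distributed as dist  # lazy import; requires jax to be installed

    dist.initialize(
        coordinator_address=os.environ["JAX_COORDINATOR_ADDRESS"],
        num_processes=int(os.environ["JAX_NUM_PROCESSES"]),
        process_id=int(os.environ["JAX_PROCESS_ID"]),
    )


if __name__ == "__main__":
    init_jax_distributed()
```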

Related Issues

Closes: kubeflow/trainer#3183


Checklist

@google-oss-prow

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:
Once this PR has been reviewed and has the lgtm label, please assign jeffwan for approval. For more information see the Kubernetes Code Review Process.

The full list of commands accepted by this bot can be found here.

Details Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

@google-oss-prow google-oss-prow bot added the area/trainer AREA: Kubeflow Trainer / Kubeflow Training Operator label Feb 9, 2026
@google-oss-prow

Hi @Amir380-A. Thanks for your PR.

I'm waiting for a kubeflow member to verify that this patch is reasonable to test. If it is, they should reply with /ok-to-test on its own line. Until that is done, I will not automatically test new commits in this PR, but the usual testing commands by org members will still work. Regular contributors should join the org to skip this step.

Once the patch is verified, the new status will be reflected by the ok-to-test label.

I understand the commands that are listed here.

Details

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.

@github-actions

github-actions bot commented Feb 9, 2026

🚫 This command cannot be processed. Only organization members or owners can use the commands.

Signed-off-by: Amir380-A <62997533+Amir380-A@users.noreply.github.com>
Member

@Arhell Arhell left a comment


/ok-to-test

Signed-off-by: Amir380-A <62997533+Amir380-A@users.noreply.github.com>

@kevo-1 kevo-1 left a comment


Thanks for putting this together! It's great to see the new JAX distributed training guide. It covers the core SPMD concepts and built-in runtime perfectly.

To align this guide with the other training framework documentation (like PyTorch, DeepSpeed, and MLX), I have suggested some changes.

Everything else looks fantastic.

Comment on lines 172 to 189
```python
from kubeflow.trainer import TrainerClient, TrainJob

client = TrainerClient()

job = TrainJob(
    name="jax-sdk-example",
    runtime="jax-distributed",
    num_nodes=2,
    container={
        "image": "nvcr.io/nvidia/jax:25.10-py3",
        "command": ["python", "train.py"],
    },
)

client.create_trainjob(job)
```

Suggestion: Let's update this Python SDK example to use the new CustomTrainer API wrapper rather than manually constructing the TrainJob pod spec. This makes it consistent with how we show PyTorch, DeepSpeed, and MLX Python SDK usage.

Suggested change
```python
from kubeflow.trainer import TrainerClient, TrainJob

client = TrainerClient()

job = TrainJob(
    name="jax-sdk-example",
    runtime="jax-distributed",
    num_nodes=2,
    container={
        "image": "nvcr.io/nvidia/jax:25.10-py3",
        "command": ["python", "train.py"],
    },
)

client.create_trainjob(job)
```
```python
from kubeflow.trainer import TrainerClient, CustomTrainer


def train_jax():
    import os

    import jax
    import jax.distributed as dist

    dist.initialize(
        num_processes=int(os.environ["JAX_NUM_PROCESSES"]),
        process_id=int(os.environ["JAX_PROCESS_ID"]),
        coordinator_address=os.environ["JAX_COORDINATOR_ADDRESS"],
    )

    print("JAX Distributed Environment")
    print("Global devices:", jax.devices())
    print("Local devices:", jax.local_devices())


job_id = TrainerClient().train(
    runtime=TrainerClient().get_runtime("jax-distributed"),
    trainer=CustomTrainer(
        func=train_jax,
        num_nodes=2,
        resources_per_node={
            "cpu": 2,
        },
    ),
)
```

Author

@Amir380-A Amir380-A Feb 26, 2026

Edited the code. Thank you for the review.


## Next Steps

- Check out [the MNIST JAX example](https://github.com/kaisoz/trainer/blob/ca27f54971070a1f65f2d9bf3a1b643f92736448/examples/jax/image-classification/mnist.ipynb).

Suggestion: We should link to the official kubeflow repository on master instead of the personal fork, to prevent broken links later. I also added a link to the TrainerClient SDK documentation!

Suggested change
- Check out [the MNIST JAX example](https://github.com/kaisoz/trainer/blob/ca27f54971070a1f65f2d9bf3a1b643f92736448/examples/jax/image-classification/mnist.ipynb).
- Check out [the MNIST JAX example](https://github.com/kubeflow/trainer/blob/master/examples/jax/image-classification/mnist.ipynb).
- Learn more about `TrainerClient()` APIs [in the Kubeflow SDK](https://github.com/kubeflow/sdk/blob/main/kubeflow/trainer/api/trainer_client.py).

Author

Noted and added.


Suggested change
### Get the TrainJob Results
You can use the `get_job_logs()` API to see your TrainJob logs. For JAX distributed training, logs are typically available on all nodes. You can inspect node 0:
```py
print("\n".join(TrainerClient().get_job_logs(name=job_id, step="node-0")))
```

Author

Noted and added. Thank you.

@google-oss-prow

@kevo-1: changing LGTM is restricted to collaborators

Details

In response to this:

Thanks for putting this together! It's great to see the new JAX distributed training guide. It covers the core SPMD concepts and built-in runtime perfectly.

To align this guide with the other training framework documentation (like PyTorch, DeepSpeed, and MLX), I have suggested some changes.

Everything else looks fantastic.

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.

Member

@andreyvelich andreyvelich left a comment


Sorry for the late review @Amir380-A!
I left a few comments in addition to @kevo-1.

@@ -0,0 +1,231 @@
+++
title = "JAX Guide"
description = "How to run JAX training on Kubernetes with Kubeflow Trainer"
Member

Suggested change
description = "How to run JAX training on Kubernetes with Kubeflow Trainer"
description = "How to run JAX on Kubernetes with Kubeflow Trainer"

Author

Sorry for the late reply. I reviewed all the comments and committed all the proposed changes.
Noted and added, thank you for the review.

+++
title = "JAX Guide"
description = "How to run JAX training on Kubernetes with Kubeflow Trainer"
weight = 10
Member

Let's move it under PyTorch

Suggested change
weight = 10
weight = 15

Author

OK, edited the weight.

Comment on lines +41 to +44
TPU workloads are not supported because installing both `jax[cuda]`
and `jax[tpu]` in the same image leads to backend and plugin conflicts.
A separate TPU-specific runtime is required.
{{% /alert %}}
Member

@kaisoz Do we have a tracking issue to support TPUs?

Comment on lines +75 to +78
TrainerClient().get_runtime_packages(
runtime=TrainerClient().get_runtime("jax-distributed")
)

Member

Author

Added the output, please check.


Your training script must explicitly initialize the JAX distributed runtime before performing any JAX computation.

### Example: train.py
Member

Can you modify this example to define the JAX script under training function, and showcase the example with calling train() API like here: https://deploy-preview-4305--competent-brattain-de2d6d.netlify.app/docs/components/trainer/user-guides/deepspeed/#deepspeed-distributed-environment

Author

Updated the example to define the JAX logic inside a train() function and added the entrypoint call. Please let me know if you'd like any further adjustments.

## Initializing the JAX Distributed Runtime
Member

Suggested change
## Initializing the JAX Distributed Runtime
## JAX Distributed Environment

Author

Added. Thank you.


---

## Creating a TrainJob with JAX Runtime
Member

Can you refactor this section to use the Kubeflow SDK to submit jobs, so it is aligned with the other examples: https://deploy-preview-4305--competent-brattain-de2d6d.netlify.app/docs/components/trainer/user-guides/deepspeed/#deepspeed-distributed-environment

Author

Refactored it and merged it into one example for the Python SDK.

Comment on lines 196 to 202
Kubeflow Trainer automatically injects the following environment variables into each trainer container:

| Variable | Description |
|--------|-------------|
| `JAX_NUM_PROCESSES` | Total number of JAX processes |
| `JAX_PROCESS_ID` | Global process index (0-based) |
| `JAX_COORDINATOR_ADDRESS` | Address of the coordinator (process 0) |
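
As a rough illustration of how a training script might consume this table, here is a small validation helper that fails fast if any variable is missing (a sketch; `read_cluster_env` is a hypothetical name, only the variable names come from the table above):

```python
import os

# Variable names from the table above; assumed to be injected by Trainer.
REQUIRED = ("JAX_NUM_PROCESSES", "JAX_PROCESS_ID", "JAX_COORDINATOR_ADDRESS")


def read_cluster_env() -> dict:
    """Read the injected variables, raising a clear error if any is absent."""
    missing = [name for name in REQUIRED if name not in os.environ]
    if missing:
        raise RuntimeError(f"missing Trainer-injected variables: {missing}")
    return {
        "num_processes": int(os.environ["JAX_NUM_PROCESSES"]),
        "process_id": int(os.environ["JAX_PROCESS_ID"]),
        "coordinator_address": os.environ["JAX_COORDINATOR_ADDRESS"],
    }
```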
Member

This can be moved to the JAX Distributed Environment section.

Author

Moved the environment variables table to the JAX Distributed Environment section as suggested. Please let me know if the placement looks correct.

Comment on lines 218 to 225
## Limitations

Current limitations of the JAX runtime include:

- No TPU support
- No elastic or dynamic scaling
- Homogeneous node and device configurations are assumed
- All processes must start and finish together
Member

I don't think that is needed.

Suggested change
## Limitations
Current limitations of the JAX runtime include:
- No TPU support
- No elastic or dynamic scaling
- Homogeneous node and device configurations are assumed
- All processes must start and finish together

Author

Removed the Limitations section as suggested. Thanks!

Comment on lines 206 to 214
## Parallelism with JAX Primitives

Once initialized, you can use JAX SPMD primitives normally:

- `pmap` — data-parallel execution
- `pjit` — explicit global sharding
- `shard_map` — low-level SPMD control

Kubeflow Trainer does not alter JAX semantics, it only provides the distributed execution environment.
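
For instance, the data-parallel case from that list can be sketched with `pmap` (a generic JAX snippet, not Trainer-specific; it runs across however many local devices are visible):

```python
import jax
import jax.numpy as jnp


# Data-parallel reduction: each device computes the mean square of its shard.
@jax.pmap
def mean_square(x):
    return jnp.mean(x * x)


n = jax.local_device_count()
# Leading axis = one shard per local device, as pmap requires.
batch = jnp.arange(float(n * 4)).reshape(n, 4)
print(mean_square(batch))  # one result per local device
```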
Member

Same suggestion to move it to JAX Distributed Environment section.

Author

Done. Please let me know if there is anything else I can edit or review!

@andreyvelich
Member

Hi @Amir380-A, did you get a chance to review the proposed changes, or should we ask @kevo-1 to take over this work?

Updated the JAX guide to improve clarity. Edited and added examples for using the Python SDK. Adjusted weight and description for better organization.

Signed-off-by: Amir Ibrahim <62997533+Amir380-A@users.noreply.github.com>
Signed-off-by: Amir Ibrahim <62997533+Amir380-A@users.noreply.github.com>

Labels

area/trainer AREA: Kubeflow Trainer / Kubeflow Training Operator ok-to-test size/L

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Documentation for Distributed JAX with TrainJob

4 participants