Trainer: Add JAX distributed training guide to Kubeflow Trainer docs #4305
Amir380-A wants to merge 5 commits into kubeflow:master
Conversation
kevo-1 left a comment
Thanks for putting this together! It's great to see the new JAX distributed training guide. It covers the core SPMD concepts and built-in runtime perfectly.
To align this guide with the other training framework documentation (like PyTorch, DeepSpeed, and MLX), I have suggested some changes.
Everything else looks fantastic.
```python
from kubeflow.trainer import TrainerClient, TrainJob

client = TrainerClient()

job = TrainJob(
    name="jax-sdk-example",
    runtime="jax-distributed",
    num_nodes=2,
    container={
        "image": "nvcr.io/nvidia/jax:25.10-py3",
        "command": ["python", "train.py"],
    },
)

client.create_trainjob(job)
```
Suggestion: Let's update this Python SDK example to use the new CustomTrainer API wrapper rather than manually constructing the TrainJob pod spec. This makes it consistent with how we show PyTorch, DeepSpeed, and MLX Python SDK usage.
```python
from kubeflow.trainer import TrainerClient, CustomTrainer

def train_jax():
    import os
    import jax
    import jax.distributed as dist

    dist.initialize(
        num_processes=int(os.environ["JAX_NUM_PROCESSES"]),
        process_id=int(os.environ["JAX_PROCESS_ID"]),
        coordinator_address=os.environ["JAX_COORDINATOR_ADDRESS"],
    )

    print("JAX Distributed Environment")
    print("Global devices:", jax.devices())
    print("Local devices:", jax.local_devices())

job_id = TrainerClient().train(
    runtime=TrainerClient().get_runtime("jax-distributed"),
    trainer=CustomTrainer(
        func=train_jax,
        num_nodes=2,
        resources_per_node={
            "cpu": 2,
        },
    ),
)
```
Edited the code. Thank you for the review.
## Next Steps

- Check out [the MNIST JAX example](https://github.com/kaisoz/trainer/blob/ca27f54971070a1f65f2d9bf3a1b643f92736448/examples/jax/image-classification/mnist.ipynb).
Suggestion: We should link to the official kubeflow repository on master instead of the personal fork, to prevent broken links later. I also added a link to the TrainerClient SDK documentation!
- Check out [the MNIST JAX example](https://github.com/kubeflow/trainer/blob/master/examples/jax/image-classification/mnist.ipynb).
- Learn more about `TrainerClient()` APIs [in the Kubeflow SDK](https://github.com/kubeflow/sdk/blob/main/kubeflow/trainer/api/trainer_client.py).
### Get the TrainJob Results

You can use the `get_job_logs()` API to see your TrainJob logs. For JAX distributed training, logs are typically available on all nodes. You can inspect node 0:

```py
print("\n".join(TrainerClient().get_job_logs(name=job_id, step="node-0")))
```
|
andreyvelich left a comment
Sorry for the late review @Amir380-A!
I left a few comments in addition to @kevo-1.
```toml
+++
title = "JAX Guide"
description = "How to run JAX training on Kubernetes with Kubeflow Trainer"
```

Suggested change:

```toml
description = "How to run JAX on Kubernetes with Kubeflow Trainer"
```
Sorry for the late reply, I reviewed all the comments and committed all the proposed changes.
Noted and added, thank you for the review.
```toml
+++
title = "JAX Guide"
description = "How to run JAX training on Kubernetes with Kubeflow Trainer"
weight = 10
```

Let's move it under PyTorch:

```toml
weight = 15
```
TPU workloads are not supported because installing both `jax[cuda]`
and `jax[tpu]` in the same image leads to backend and plugin conflicts.
A separate TPU-specific runtime is required.
{{% /alert %}}
@kaisoz Do we have a tracking issue to support TPUs?
```python
TrainerClient().get_runtime_packages(
    runtime=TrainerClient().get_runtime("jax-distributed")
)
```
Can you show output of this command, like we did for DeepSpeed: https://deploy-preview-4305--competent-brattain-de2d6d.netlify.app/docs/components/trainer/user-guides/deepspeed/#get-deepspeed-runtime-packages
Added the output, please check.
Your training script must explicitly initialize the JAX distributed runtime before performing any JAX computation.

### Example: train.py
Can you modify this example to define the JAX script under training function, and showcase the example with calling train() API like here: https://deploy-preview-4305--competent-brattain-de2d6d.netlify.app/docs/components/trainer/user-guides/deepspeed/#deepspeed-distributed-environment
Updated the example to define the JAX logic inside a train() function and added the entrypoint call. Please let me know if you’d like any further adjustments.
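Since the body of the `train.py` example isn't reproduced in this hunk, here is a minimal sketch of the kind of script discussed above. It only shows how the Trainer-injected `JAX_*` variables (documented in the guide's environment-variable table) would be turned into `jax.distributed.initialize()` arguments; the helper name is illustrative:

```python
import os

def distributed_init_kwargs(env=None):
    """Map the environment variables Kubeflow Trainer injects into the
    keyword arguments expected by jax.distributed.initialize().
    Variable names are taken from the guide's table."""
    env = os.environ if env is None else env
    return {
        "num_processes": int(env["JAX_NUM_PROCESSES"]),
        "process_id": int(env["JAX_PROCESS_ID"]),
        "coordinator_address": env["JAX_COORDINATOR_ADDRESS"],
    }

# Inside the real train.py, the kwargs would feed the initializer:
#   import jax.distributed as dist
#   dist.initialize(**distributed_init_kwargs())
```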
## Initializing the JAX Distributed Runtime

Suggested change:

## JAX Distributed Environment
---

## Creating a TrainJob with JAX Runtime
can you refactor this section to use Kubeflow SDK to submit jobs, to be aligned with other examples: https://deploy-preview-4305--competent-brattain-de2d6d.netlify.app/docs/components/trainer/user-guides/deepspeed/#deepspeed-distributed-environment
Refactored it and merged it into one example for the Python SDK.
Kubeflow Trainer automatically injects the following environment variables into each trainer container:

| Variable | Description |
|----------|-------------|
| `JAX_NUM_PROCESSES` | Total number of JAX processes |
| `JAX_PROCESS_ID` | Global process index (0-based) |
| `JAX_COORDINATOR_ADDRESS` | Address of the coordinator (process 0) |
This can be moved to the JAX Distributed Environment section.
Moved the environment variables table to the JAX Distributed Environment section as suggested. Please let me know if the placement looks correct.
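One small addition that might complement the table: a training script can fail fast with a clear error when the injected variables are missing or inconsistent, instead of surfacing an obscure initialization failure. A sketch (variable names from the table above; the function name and messages are illustrative):

```python
import os

REQUIRED_VARS = ("JAX_NUM_PROCESSES", "JAX_PROCESS_ID", "JAX_COORDINATOR_ADDRESS")

def validate_jax_env(env=None):
    """Check the Trainer-injected variables before calling
    jax.distributed.initialize(); returns (num_processes, process_id, address)."""
    env = os.environ if env is None else env
    missing = [v for v in REQUIRED_VARS if v not in env]
    if missing:
        raise RuntimeError(f"missing Trainer-injected variables: {missing}")
    num = int(env["JAX_NUM_PROCESSES"])
    pid = int(env["JAX_PROCESS_ID"])
    if not 0 <= pid < num:
        raise RuntimeError(f"process id {pid} out of range for {num} processes")
    return num, pid, env["JAX_COORDINATOR_ADDRESS"]
```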
## Limitations

Current limitations of the JAX runtime include:

- No TPU support
- No elastic or dynamic scaling
- Homogeneous node and device configurations are assumed
- All processes must start and finish together
I don't think that is needed.
Removed the Limitations section as suggested. Thanks!
## Parallelism with JAX Primitives

Once initialized, you can use JAX SPMD primitives normally:

- `pmap` — data-parallel execution
- `pjit` — explicit global sharding
- `shard_map` — low-level SPMD control

Kubeflow Trainer does not alter JAX semantics; it only provides the distributed execution environment.
Same suggestion to move it to JAX Distributed Environment section.
Done. Please let me know if there is anything else I can edit or review!
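A side note for readers of this thread: the primitives listed in that section behave the same on a single host, so they are easy to try locally before submitting a TrainJob. A minimal `pmap` sketch (assumes `jax` is installed; it maps one slice of the input to each visible local device):

```python
import jax
import jax.numpy as jnp

# pmap replicates the function across local devices and maps it over the
# leading axis of the input, one slice per device.
n = jax.local_device_count()
out = jax.pmap(lambda x: x * 2)(jnp.arange(n))
print(out)
```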
Hi @Amir380-A, did you get a chance to review the proposed changes, or should we ask @kevo-1 to take over this work?
Updated the JAX guide to improve clarity. Edited and added examples for using the Python SDK. Adjusted weight and description for better organization. Signed-off-by: Amir Ibrahim <62997533+Amir380-A@users.noreply.github.com>
Description of Changes
This PR adds a new JAX user guide describing how to run distributed JAX
training jobs with Kubeflow Trainer.
The guide covers:
Related Issues
Closes: kubeflow/trainer#3183
Update screenshot preview

Checklist