Conversation
Proposing framework-aware trainer classes (TorchTrainer, MPITrainer, JAXTrainer, XGBoostTrainer) with automatic runtime discovery via the trainer.kubeflow.org/framework label, and a RuntimeConfig dataclass to separate per-job environment settings from training logic. Issue: kubeflow#285 Signed-off-by: Saad Zaher <szaher@redhat.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Saad Zaher <szaher@redhat.com>
Co-authored-by: Antonin Stefanutti <astefanutti@users.noreply.github.com> Signed-off-by: Saad Zaher <szaher@redhat.com>
[APPROVALNOTIFIER] This PR is NOT APPROVED. This pull-request has been approved by: The full list of commands accepted by this bot can be found here. Needs approval from an approver in each of these files. Approvers can indicate their approval by writing
Pull request overview
This PR adds a comprehensive design proposal for specialized trainer abstractions and a RuntimeConfig dataclass to the Kubeflow SDK. The proposal addresses current limitations in the SDK's trainer subsystem by introducing framework-aware trainer classes that bridge the gap between the generic CustomTrainer and the highly specific BuiltinTrainer.
Changes:
- Adds a detailed design proposal document describing a new BaseTrainer abstract interface and specialized framework trainers (TorchTrainer, MPITrainer, JAXTrainer, XGBoostTrainer)
- Proposes a RuntimeConfig dataclass to cleanly separate runtime environment settings from training logic
- Includes comprehensive documentation covering motivation, design details, API examples, migration strategy, test plan, and alternatives considered
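To ground the overview, here is a rough sketch of what the proposed hierarchy could look like. The class names (`BaseTrainer`, `TorchTrainer`) follow the proposal text, but the exact fields and signatures below are illustrative assumptions, not the final API:

```python
# Illustrative sketch only — field names and signatures are assumptions,
# not the SDK's final API surface.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional


class BaseTrainer(ABC):
    """Abstract interface all specialized trainers would implement."""

    @property
    @abstractmethod
    def supported_frameworks(self) -> tuple:
        """Framework values this trainer accepts, matched against the
        trainer.kubeflow.org/framework runtime label."""

    @abstractmethod
    def get_framework_args(self) -> dict:
        """Framework-specific launcher arguments for this trainer."""


@dataclass
class TorchTrainer(BaseTrainer):
    """Hypothetical shape of the framework-aware PyTorch trainer."""

    func: Optional[Callable] = None
    num_nodes: int = 1
    max_restarts: Optional[int] = None

    @property
    def supported_frameworks(self) -> tuple:
        return ("torch",)

    def get_framework_args(self) -> dict:
        # Only emit flags the user actually set.
        args = {}
        if self.max_restarts is not None:
            args["max-restarts"] = str(self.max_restarts)
        return args
```

The key design point is that each specialized trainer declares which runtime frameworks it supports, so the SDK can validate (or discover) a matching runtime instead of leaving that to the user.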
> 3. **Deprecating `CustomTrainer` or `BuiltinTrainer`.** Both remain supported.
>    Specialized trainers are an additional option, not a replacement.
> 4. **Tier 2 trainer implementations.** This proposal defines the extension mechanism
>    and interface. Concrete Tier 2 implementations (HuggingFace, DeepSpeed, Unsloth,
The company name should be spelled "Hugging Face" (with a space) rather than "HuggingFace" throughout the document. This applies to references in text and comments, though the class name "HuggingFaceTrainer" would be correct as Python class names don't use spaces.
> - and interface. Concrete Tier 2 implementations (HuggingFace, DeepSpeed, Unsloth,
> + and interface. Concrete Tier 2 implementations (Hugging Face, DeepSpeed, Unsloth,
> # Example: future HuggingFaceTrainer (NOT part of this proposal's implementation scope)
> @dataclass
> class TransformersTrainer(BaseTrainer):
>     """Trainer for HuggingFace Transformers training.
>
>     Wraps HuggingFace's Trainer API and maps to a PyTorch runtime.
The company name should be spelled "Hugging Face" (with a space) rather than "HuggingFace" in the comment and docstring text.
> - # Example: future HuggingFaceTrainer (NOT part of this proposal's implementation scope)
> + # Example: future Hugging Face trainer (NOT part of this proposal's implementation scope)
>   @dataclass
>   class TransformersTrainer(BaseTrainer):
> - """Trainer for HuggingFace Transformers training.
> + """Trainer for Hugging Face Transformers training.
> - Wraps HuggingFace's Trainer API and maps to a PyTorch runtime.
> + Wraps Hugging Face's Trainer API and maps to a PyTorch runtime.
>            │
>     ┌─────┴──────────┐
>     │                │
> HuggingFace      DeepSpeed
The company name should be spelled "Hugging Face" (with a space) rather than "HuggingFace" in the diagram text.
> - HuggingFace      DeepSpeed
> + Hugging Face     DeepSpeed
@szaher @andreyvelich — this proposal is really well thought out, especially the separation between `BaseTrainer` and framework-specific trainers along with `RuntimeConfig`. I had a question regarding `TorchTrainer` extensibility and runtime selection: given that multiple torch-based runtimes may coexist (as discussed earlier in #287), how do you envision selecting the appropriate runtime for a given `TorchTrainer` instance? One possible approach could be:

This might help keep the API simple while still supporting multiple backends (e.g., TorchTune vs. custom PEFT/TRL runtimes for LLM workflows). Curious if something along these lines aligns with the intended direction. Happy to explore this further or prototype once the design is clearer.
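For reference, the label-based discovery described in the proposal could be combined with an explicit user override roughly as follows. The `Runtime` dataclass and `preferred` parameter here are hypothetical illustrations, not the SDK's actual types:

```python
# Hypothetical sketch of label-based runtime discovery with an explicit
# override. Only the label key comes from the proposal; everything else
# is an assumption for illustration.
from dataclasses import dataclass
from typing import Dict, List, Optional

FRAMEWORK_LABEL = "trainer.kubeflow.org/framework"


@dataclass
class Runtime:
    name: str
    labels: Dict[str, str]


def candidate_runtimes(
    runtimes: List[Runtime], framework: str, preferred: Optional[str] = None
) -> List[Runtime]:
    """Filter runtimes by framework label; an explicit `preferred` runtime
    name narrows the result further, overriding automatic discovery."""
    matching = [r for r in runtimes if r.labels.get(FRAMEWORK_LABEL) == framework]
    if preferred is not None:
        return [r for r in matching if r.name == preferred]
    return matching
```

Under this sketch, a `TorchTrainer` would default to any runtime labeled `torch`, while a user who needs a specific backend (e.g. a TorchTune runtime) could name it explicitly.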
kramaranya left a comment
Thanks @szaher!
Looks great to me, and it should be a great improvement to the user experience in Kubeflow SDK!
/assign @andreyvelich @astefanutti @briangallagher @Fiona-Waters @MStokluska
> 2. **`RuntimeConfig` dataclass** — A dedicated configuration object that cleanly separates
>    per-job runtime environment settings (packages, pip config, environment variables) from
>    training logic and scaling parameters. This replaces the current pattern where
>    `CustomTrainer` conflates runtime concerns with trainer concerns.
Would this require runtime/controller changes?
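If `RuntimeConfig` is a purely client-side container, as the quoted text suggests, it would not by itself require controller changes. A minimal sketch, with field names inferred from the quote ("packages, pip config, environment variables") rather than taken from the proposal's final schema:

```python
# Minimal sketch of a client-side RuntimeConfig. Field names are
# assumptions inferred from the quoted description, not the final schema.
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class RuntimeConfig:
    packages_to_install: List[str] = field(default_factory=list)
    pip_index_url: Optional[str] = None
    env: Dict[str, str] = field(default_factory=dict)
```

The SDK would translate such a config into fields the TrainJob API already carries, so whether controller changes are needed depends on where the framework-specific arguments ultimately land.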
> 3. **Deprecating `CustomTrainer` or `BuiltinTrainer`.** Both remain supported.
>    Specialized trainers are an additional option, not a replacement.
Is the plan to eventually deprecate those or do we want to always maintain both options?
> if runtime.trainer.framework not in self.supported_frameworks:
>     raise ValueError(
>         f"{type(self).__name__} supports frameworks "
>         f"{self.supported_frameworks}, but runtime '{runtime.name}' "
>         f"has framework '{runtime.trainer.framework}'"
>     )
We would also need to validate `runtime.trainer.trainer_type`.
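A sketch of how the quoted check could be extended to cover `trainer_type` as well; the objects here are simplified stand-ins for the real SDK models:

```python
# Simplified stand-in for the SDK's validation path, extended to check
# trainer_type in addition to framework, as suggested above.
from types import SimpleNamespace


def validate_runtime(trainer, runtime):
    """Reject runtimes whose framework or trainer_type the trainer
    does not support."""
    if runtime.trainer.framework not in trainer.supported_frameworks:
        raise ValueError(
            f"{type(trainer).__name__} supports frameworks "
            f"{trainer.supported_frameworks}, but runtime '{runtime.name}' "
            f"has framework '{runtime.trainer.framework}'"
        )
    if runtime.trainer.trainer_type not in trainer.supported_trainer_types:
        raise ValueError(
            f"{type(trainer).__name__} supports trainer types "
            f"{trainer.supported_trainer_types}, but runtime '{runtime.name}' "
            f"has trainer type '{runtime.trainer.trainer_type}'"
        )
```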
> def get_framework_args(self) -> dict:
>     args = {}
>     if self.max_restarts is not None:
>         args["max-restarts"] = str(self.max_restarts)
>     if self.monitor_interval is not None:
>         args["monitor-interval"] = str(self.monitor_interval)
>     return args
Where do these new args go in the TrainJob spec?
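One possibility, since `torchrun` itself accepts `--max-restarts` and `--monitor-interval` flags, is for the SDK to fold the `get_framework_args()` output into the launcher command of the trainer container in the TrainJob spec. This mapping is an assumption for illustration, not the decided design:

```python
# Hypothetical mapping of framework args onto the launcher command.
# The helper name and the placement in the TrainJob spec are assumptions.
from typing import Dict, List


def build_torchrun_command(entrypoint: str, framework_args: Dict[str, str]) -> List[str]:
    """Render framework args as torchrun CLI flags ahead of the entrypoint."""
    cmd = ["torchrun"]
    for key, value in sorted(framework_args.items()):
        cmd.append(f"--{key}={value}")
    cmd.append(entrypoint)
    return cmd
```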
@kramaranya: GitHub didn't allow me to assign the following users: MStokluska. Note that only kubeflow members with read permissions, repo collaborators, and people who have commented on this issue/PR can be assigned. Additionally, issues/PRs can only have 10 assignees at the same time.

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.
+1 on the points around validation and argument placement; I had a related question while reading through this. For the framework-specific args (e.g.

This also seems tied to whether

Clarifying this mapping would help understand how far the abstraction goes (SDK-only vs. API/CRD impact).