
chore: Trainer: Specialized Trainers #308

Open

szaher wants to merge 5 commits into kubeflow:main from szaher:KEP-285

Conversation


@szaher szaher commented Feb 19, 2026

What this PR does / why we need it:

Which issue(s) this PR fixes (optional, in Fixes #<issue number>, #<issue number>, ... format, will close the issue(s) when PR gets merged):

Fixes #

Checklist:

  • Docs included if any changes are user facing

szaher and others added 5 commits February 11, 2026 19:17
Proposing framework-aware trainer classes (TorchTrainer,
MPITrainer, JAXTrainer, XGBoostTrainer) with automatic runtime discovery
via the trainer.kubeflow.org/framework label, and a RuntimeConfig
dataclass to separate per-job environment settings from training logic.

Issue: kubeflow#285

Signed-off-by: Saad Zaher <szaher@redhat.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Antonin Stefanutti <astefanutti@users.noreply.github.com>
Copilot AI review requested due to automatic review settings February 19, 2026 22:43
@google-oss-prow
Contributor

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:
Once this PR has been reviewed and has the lgtm label, please assign kramaranya for approval. For more information see the Kubernetes Code Review Process.

The full list of commands accepted by this bot can be found here.

Details: Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment


Copilot AI left a comment

Pull request overview

This PR adds a comprehensive design proposal for specialized trainer abstractions and a RuntimeConfig dataclass to the Kubeflow SDK. The proposal addresses current limitations in the SDK's trainer subsystem by introducing framework-aware trainer classes that bridge the gap between the generic CustomTrainer and the highly specific BuiltinTrainer.

Changes:

  • Adds a detailed design proposal document describing a new BaseTrainer abstract interface and specialized framework trainers (TorchTrainer, MPITrainer, JAXTrainer, XGBoostTrainer)
  • Proposes a RuntimeConfig dataclass to cleanly separate runtime environment settings from training logic
  • Includes comprehensive documentation covering motivation, design details, API examples, migration strategy, test plan, and alternatives considered

3. **Deprecating `CustomTrainer` or `BuiltinTrainer`.** Both remain supported.
Specialized trainers are an additional option, not a replacement.
4. **Tier 2 trainer implementations.** This proposal defines the extension mechanism
and interface. Concrete Tier 2 implementations (HuggingFace, DeepSpeed, Unsloth,

Copilot AI Feb 19, 2026


The company name should be spelled "Hugging Face" (with a space) rather than "HuggingFace" throughout the document. This applies to references in text and comments, though the class name "HuggingFaceTrainer" would be correct as Python class names don't use spaces.

Suggested change
- and interface. Concrete Tier 2 implementations (HuggingFace, DeepSpeed, Unsloth,
+ and interface. Concrete Tier 2 implementations (Hugging Face, DeepSpeed, Unsloth,

Comment on lines +483 to +489
# Example: future HuggingFaceTrainer (NOT part of this proposal's implementation scope)

@dataclass
class TransformersTrainer(BaseTrainer):
    """Trainer for HuggingFace Transformers training.

    Wraps HuggingFace's Trainer API and maps to a PyTorch runtime.

Copilot AI Feb 19, 2026


The company name should be spelled "Hugging Face" (with a space) rather than "HuggingFace" in the comment and docstring text.

Suggested change
- # Example: future HuggingFaceTrainer (NOT part of this proposal's implementation scope)
- @dataclass
- class TransformersTrainer(BaseTrainer):
- """Trainer for HuggingFace Transformers training.
- Wraps HuggingFace's Trainer API and maps to a PyTorch runtime.
+ # Example: future Hugging Face trainer (NOT part of this proposal's implementation scope)
+ @dataclass
+ class TransformersTrainer(BaseTrainer):
+ """Trainer for Hugging Face Transformers training.
+ Wraps Hugging Face's Trainer API and maps to a PyTorch runtime.

      ┌───────┴────────┐
      │                │
 HuggingFace       DeepSpeed

Copilot AI Feb 19, 2026


The company name should be spelled "Hugging Face" (with a space) rather than "HuggingFace" in the diagram text.

Suggested change
- HuggingFace       DeepSpeed
+ Hugging Face      DeepSpeed

@krishdef7
Contributor

@szaher @andreyvelich — this proposal is really well thought out, especially the separation between BaseTrainer and framework-specific trainers along with RuntimeConfig.

I had a question regarding TorchTrainer extensibility and runtime selection:

Given that multiple torch-based runtimes may coexist (as discussed earlier in #287), how do you envision selecting the appropriate runtime for a given TorchTrainer instance?

One possible approach could be:

  • Allow an optional runtime_name in RuntimeConfig (explicit selection), and
  • Fall back to a priority-based selection among compatible runtimes (e.g., via annotations or ordering) when not specified.

This might help keep the API simple while still supporting multiple backends (e.g., TorchTune vs custom PEFT/TRL runtimes for LLM workflows).

Curious if something along these lines aligns with the intended direction.

Happy to explore this further or prototype once the design is clearer.
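The fallback described in the two bullets above could be sketched as follows. The `runtime_name` field, the `priority` attribute, and the runtime objects are all hypothetical, sketching the commenter's suggestion rather than anything the proposal specifies.

```python
from dataclasses import dataclass


@dataclass
class Runtime:
    name: str
    framework: str
    priority: int = 0  # hypothetical: e.g. parsed from a runtime annotation


def select_runtime(runtimes, framework, runtime_name=None):
    """Pick an explicitly named runtime, else the highest-priority compatible one."""
    compatible = [r for r in runtimes if r.framework == framework]
    if runtime_name is not None:
        # Explicit selection: the name must refer to a compatible runtime.
        for r in compatible:
            if r.name == runtime_name:
                return r
        raise ValueError(f"no runtime named {runtime_name!r} for framework {framework!r}")
    if not compatible:
        raise ValueError(f"no runtime supports framework {framework!r}")
    # Fallback: priority-based selection among compatible runtimes.
    return max(compatible, key=lambda r: r.priority)


runtimes = [
    Runtime("torch-distributed", "torch", priority=10),
    Runtime("torchtune", "torch", priority=5),
    Runtime("mpi-base", "mpi"),
]
```

Under this sketch, `select_runtime(runtimes, "torch")` falls back to the highest-priority torch runtime, while passing `runtime_name="torchtune"` overrides the fallback explicitly.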

Contributor

@kramaranya kramaranya left a comment


Thanks @szaher!
Looks great to me, and it should be a great improvement to the user experience in Kubeflow SDK!

/assign @andreyvelich @astefanutti @briangallagher @Fiona-Waters @MStokluska

Comment on lines +81 to +84
2. **`RuntimeConfig` dataclass** — A dedicated configuration object that cleanly separates
per-job runtime environment settings (packages, pip config, environment variables) from
training logic and scaling parameters. This replaces the current pattern where
`CustomTrainer` conflates runtime concerns with trainer concerns.

Would this require runtime/controller changes?

Comment on lines +143 to +144
3. **Deprecating `CustomTrainer` or `BuiltinTrainer`.** Both remain supported.
Specialized trainers are an additional option, not a replacement.

Is the plan to eventually deprecate those or do we want to always maintain both options?

Comment on lines +297 to +302
        if runtime.trainer.framework not in self.supported_frameworks:
            raise ValueError(
                f"{type(self).__name__} supports frameworks "
                f"{self.supported_frameworks}, but runtime '{runtime.name}' "
                f"has framework '{runtime.trainer.framework}'"
            )

We also would need to validate runtime.trainer.trainer_type too
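Extending the quoted check to cover `trainer_type` as well could look like this sketch. The attribute paths mirror the quoted snippet, but the surrounding classes and the `supported_trainer_types` values are hypothetical stand-ins.

```python
from dataclasses import dataclass


# Hypothetical stand-ins mirroring the attribute paths in the quoted snippet.
@dataclass
class RuntimeTrainer:
    framework: str
    trainer_type: str


@dataclass
class Runtime:
    name: str
    trainer: RuntimeTrainer


@dataclass
class TorchTrainer:
    supported_frameworks: tuple = ("torch",)
    supported_trainer_types: tuple = ("CustomTrainer",)  # assumed value

    def validate_runtime(self, runtime: Runtime) -> None:
        if runtime.trainer.framework not in self.supported_frameworks:
            raise ValueError(
                f"{type(self).__name__} supports frameworks "
                f"{self.supported_frameworks}, but runtime '{runtime.name}' "
                f"has framework '{runtime.trainer.framework}'"
            )
        # The additional check suggested in review: validate the trainer type too.
        if runtime.trainer.trainer_type not in self.supported_trainer_types:
            raise ValueError(
                f"{type(self).__name__} supports trainer types "
                f"{self.supported_trainer_types}, but runtime '{runtime.name}' "
                f"has trainer type '{runtime.trainer.trainer_type}'"
            )
```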

Comment on lines +362 to +368
    def get_framework_args(self) -> dict:
        args = {}
        if self.max_restarts is not None:
            args["max-restarts"] = str(self.max_restarts)
        if self.monitor_interval is not None:
            args["monitor-interval"] = str(self.monitor_interval)
        return args

where these new args go in the TrainJob spec?

@google-oss-prow
Contributor

@kramaranya: GitHub didn't allow me to assign the following users: MStokluska.

Note that only kubeflow members with read permissions, repo collaborators and people who have commented on this issue/PR can be assigned. Additionally, issues/PRs can only have 10 assignees at the same time.
For more information please see the contributor guide

Details

In response to this:

Thanks @szaher!
Looks great to me, and it should be a great improvement to the user experience in Kubeflow SDK!

/assign @andreyvelich @astefanutti @briangallagher @Fiona-Waters @MStokluska

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.

@krishdef7
Contributor

+1 on the points around validation and argument placement; I had a related question while reading through this.

For the framework-specific args (e.g. max_restarts, monitor_interval), where do you envision these being materialized in the resulting TrainJob spec?

  • Do they map directly to existing fields in the underlying CRDs (e.g. TorchJob/MPIJob), or
  • Are they intended to flow through a more generic extension mechanism (e.g. annotations / plugin args)?

This also seems tied to whether RuntimeConfig and the specialized trainers remain purely SDK-layer abstractions, or if they imply corresponding changes in the controller/runtime layer.

Clarifying this mapping would help understand how far the abstraction goes (SDK-only vs API/CRD impact).
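Purely to illustrate the first bullet's SDK-only reading of the question, the dict returned by `get_framework_args()` could be rendered into command-line flags on the trainer container, leaving the CRDs untouched. Everything below is hypothetical; the proposal does not commit to this mapping, and `torchrun` is used only because it happens to accept `--max-restarts` and `--monitor-interval` flags.

```python
def render_torchrun_command(entrypoint: str, framework_args: dict) -> list[str]:
    """Hypothetical SDK-side rendering: fold framework args into torchrun flags."""
    cmd = ["torchrun"]
    for key, value in sorted(framework_args.items()):
        cmd.append(f"--{key}={value}")
    cmd.append(entrypoint)
    return cmd


# Using args shaped like the output of the quoted get_framework_args():
args = {"max-restarts": "3", "monitor-interval": "5"}
cmd = render_torchrun_command("train.py", args)
# cmd == ["torchrun", "--max-restarts=3", "--monitor-interval=5", "train.py"]
```

The alternative reading (the second bullet) would instead surface these values as fields or annotations on the TrainJob itself, which is exactly the SDK-only vs CRD-impact question raised above.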


8 participants