-
Notifications
You must be signed in to change notification settings - Fork 60
Add Image Classification Implementation #608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 0194499b11
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
src/lightly_train/_task_models/image_classification/task_model.py
Outdated
Show resolved
Hide resolved
src/lightly_train/_task_models/image_classification/task_model.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Adds an image-classification task to LightlyTrain, including both multiclass and multilabel fine-tuning support, along with dataset handling, transforms, task model, and end-to-end training tests.
Changes:
- Introduces a new
image_classificationtask model + train loop (losses/metrics/exports) with multiclass and multilabel modes. - Adds image-classification dataset parsing (folder + CSV) and a dedicated collate function.
- Extends the public training API (
train_image_classification) and adds integration tests.
Reviewed changes
Copilot reviewed 23 out of 23 changed files in this pull request and generated 12 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/helpers.py | Renames/extends dataset test helpers for multiclass and multilabel classification datasets. |
| tests/_data/test_image_classification_dataset.py | Updates dataset tests to use new multiclass/multilabel data-args types and helper naming. |
| tests/_commands/test_train_task.py | Adds end-to-end training tests for image classification (multiclass + multilabel). |
| src/lightly_train/types.py | Updates typed batch/item shapes for image classification labels representation. |
| src/lightly_train/_transforms/view_transform.py | Updates link comment related to RandomResizedCrop interpolation choice. |
| src/lightly_train/_transforms/image_classification_transform.py | Introduces an Albumentations-based transform pipeline for classification train/val. |
| src/lightly_train/_task_models/image_classification/transforms.py | Adds default transform-args presets for classification train/val. |
| src/lightly_train/_task_models/image_classification/train_model.py | Adds the train loop (optimizer/scheduler/loss/metrics) for classification. |
| src/lightly_train/_task_models/image_classification/task_model.py | Adds the classification task model (head, predict, ONNX/TRT export, checkpoint head handling). |
| src/lightly_train/_task_models/image_classification/init.py | Adds package marker for the new task model submodule. |
| src/lightly_train/_task_models/dinov3_eomt_semantic_segmentation/task_model.py | Improves backbone-weight load log message to include the path. |
| src/lightly_train/_task_models/dinov3_eomt_panoptic_segmentation/task_model.py | Improves backbone-weight load log message to include the path. |
| src/lightly_train/_task_models/dinov3_eomt_instance_segmentation/task_model.py | Improves backbone-weight load log message to include the path. |
| src/lightly_train/_task_models/dinov2_ltdetr_object_detection/task_model.py | Improves backbone-weight load log message to include the path. |
| src/lightly_train/_task_models/dinov2_linear_semantic_segmentation/task_model.py | Improves backbone-weight load log message to include the path. |
| src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/task_model.py | Improves backbone-weight load log message to include the path. |
| src/lightly_train/_task_checkpoint.py | Adds a resolve_auto hook to checkpoint-args base for task-specific defaults. |
| src/lightly_train/_models/package_helpers.py | Adds load_weights plumbing when constructing wrapped models. |
| src/lightly_train/_data/task_batch_collation.py | Adds an image-classification collate function producing ImageClassificationBatch. |
| src/lightly_train/_data/image_classification_dataset.py | Adds task type selection (multiclass vs multilabel) and CSV validation behavior. |
| src/lightly_train/_commands/train_task_helpers.py | Updates checkpoint-args validation to use per-task checkpoint args class + auto-resolution. |
| src/lightly_train/_commands/train_task.py | Adds train_image_classification entry point and task config types. |
| src/lightly_train/init.py | Exposes train_image_classification in the public package API. |
What has changed and why?
Adds support for fine-tuning any model supported by LightlyTrain for classification.
Will make a follow-up PR with docs and changelog updates. The current PR is for testing.
How has it been tested?
Did you update CHANGELOG.md?
Did you update the documentation?