Release/v1.0 #225
Changes from all commits: f6db58d, 8741d7f, 3e7dd7d, 267e99f, 99edaaf, e0d47ad
First changed file:

```diff
@@ -1,4 +1,5 @@
-device: 0
+accelerator: auto
+device: auto
 cpu_num: 16

 image_size: [640, 640]
```
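Aside, not part of the PR: a minimal sketch of how config keys like these are typically passed to a PyTorch Lightning `Trainer`. The YAML file name, the loading code, and the exact keyword mapping are assumptions; the repository's own wiring may differ.

```python
# Hypothetical sketch: mapping the new config keys onto a Lightning Trainer.
# The YAML path and the keyword mapping below are assumptions, not the repo's code.
import yaml
from lightning import Trainer

with open("train_config.yaml") as f:      # hypothetical file name
    cfg = yaml.safe_load(f)

trainer = Trainer(
    accelerator=cfg["accelerator"],       # "auto" lets Lightning pick CPU/GPU/MPS
    devices=cfg["device"],                # "auto" uses all available devices
)
```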
Second changed file (three hunks):

```diff
@@ -13,7 +13,13 @@
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler

-from yolo.config.config import IDX_TO_ID, NMSConfig, OptimizerConfig, SchedulerConfig
+from yolo.config.config import (
+    IDX_TO_ID,
+    DataConfig,
+    NMSConfig,
+    OptimizerConfig,
+    SchedulerConfig,
+)
 from yolo.model.yolo import YOLO
 from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, transform_bbox
 from yolo.utils.logger import logger
```
```diff
@@ -44,6 +50,7 @@ def __init__(self, decay: float = 0.9999, tau: float = 2000):
         self.decay = decay
         self.tau = tau
         self.step = 0
+        self.batch_step_counter = 0
         self.ema_state_dict = None

     def setup(self, trainer, pl_module, stage):
```
```diff
@@ -53,18 +60,53 @@ def setup(self, trainer, pl_module, stage):
             param.requires_grad = False

+    def on_validation_start(self, trainer: "Trainer", pl_module: "LightningModule"):
+        self.batch_step_counter = 0
+        if self.ema_state_dict is None:
+            self.ema_state_dict = deepcopy(pl_module.model.state_dict())
+        pl_module.ema.load_state_dict(self.ema_state_dict)
+
     @no_grad()
     def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
+        self.batch_step_counter += 1
+        if self.batch_step_counter % trainer.accumulate_grad_batches:
+            return
         self.step += 1
         decay_factor = self.decay * (1 - exp(-self.step / self.tau))
         for key, param in pl_module.model.state_dict().items():
             self.ema_state_dict[key] = lerp(param.detach(), self.ema_state_dict[key], decay_factor)


+class GradientAccumulation(Callback):
+    def __init__(self, data_cfg: DataConfig, scheduler_cfg: SchedulerConfig):
+        super().__init__()
+        self.equivalent_batch_size = data_cfg.equivalent_batch_size
+        self.actual_batch_size = data_cfg.batch_size
+        self.warmup_epochs = getattr(scheduler_cfg.warmup, "epochs", 0)
+        self.current_batch = 0
+        self.max_accumulation = 1
+        self.warmup_batches = 0
+        logger.info(":arrows_counterclockwise: Enable Gradient Accumulation")
+
+    def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
+        effective_batch_size = self.actual_batch_size * trainer.world_size
+        self.max_accumulation = max(1, round(self.equivalent_batch_size / effective_batch_size))
+        batches_per_epoch = int(len(pl_module.train_loader) / trainer.world_size)
+        self.warmup_batches = int(self.warmup_epochs * batches_per_epoch)
+
+    def on_train_epoch_start(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
+        self.current_batch = trainer.global_step
+
+    def on_train_batch_start(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
+        if self.current_batch < self.warmup_batches:
+            current_accumulation = round(lerp(1, self.max_accumulation, self.current_batch, self.warmup_batches))
+        else:
+            current_accumulation = self.max_accumulation
+        trainer.accumulate_grad_batches = current_accumulation
+
+    def on_train_batch_end(self, trainer: "Trainer", pl_module: "LightningModule", *args, **kwargs) -> None:
+        self.current_batch += 1
+
+
 def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
     """Create an optimizer for the given model parameters based on the configuration.
```
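Aside, not part of the diff: the new `batch_step_counter % trainer.accumulate_grad_batches` guard makes the EMA advance only once per optimizer step when gradients are accumulated. The update itself is a lerp with weight `decay_factor = decay * (1 - exp(-step / tau))`, which ramps from roughly 0 toward `decay`, so early EMA weights stay close to the raw model weights. A small sanity check of the lerp formulation, assuming `lerp` here behaves like `torch.lerp` (the repository may use its own helper):

```python
# Sketch only: verifies that a lerp-style EMA update matches the classic formulation.
# Assumes lerp(param, ema, w) == param + w * (ema - param), as torch.lerp does.
from math import exp

import torch

decay, tau = 0.9999, 2000
param = torch.tensor([1.0, 2.0])   # current model weights
ema = torch.tensor([0.5, 1.5])     # shadow (EMA) weights

for step in (1, 100, 2000):
    decay_factor = decay * (1 - exp(-step / tau))           # ~0 at step 1, approaches decay later
    via_lerp = torch.lerp(param, ema, decay_factor)         # param + w * (ema - param)
    via_formula = decay_factor * ema + (1 - decay_factor) * param
    assert torch.allclose(via_lerp, via_formula)
```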
Review comment on lines +79 to +108 (the new `GradientAccumulation` class):
The GradientAccumulation class lacks a docstring explaining its purpose, parameters, and behavior. Consider adding documentation to describe how this callback manages gradient accumulation during training warmup.
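A possible docstring for the reviewer's suggestion is sketched below. The wording is not from the PR, the import path is only for the sketch, and the described behavior is inferred from the diff above:

```python
# Suggested docstring only; phrasing is the editor's, inferred from the diff above.
from lightning.pytorch.callbacks import Callback


class GradientAccumulation(Callback):
    """Gradually enable gradient accumulation to emulate a larger batch size.

    In ``setup`` the callback picks ``max_accumulation`` so that
    ``batch_size * world_size * max_accumulation`` approximates
    ``data_cfg.equivalent_batch_size``. During the scheduler's warmup epochs the
    accumulation count is linearly interpolated from 1 up to ``max_accumulation``;
    afterwards it stays at ``max_accumulation``. The current value is written to
    ``trainer.accumulate_grad_batches`` before every training batch.

    Args:
        data_cfg: Supplies ``batch_size`` and ``equivalent_batch_size``.
        scheduler_cfg: Supplies the warmup length in epochs (``warmup.epochs``).
    """
```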