diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py
index 1c5f32d1db..dc213e568e 100644
--- a/ding/policy/base_policy.py
+++ b/ding/policy/base_policy.py
@@ -451,6 +451,7 @@ def sync_gradients(self, model: torch.nn.Module) -> None:
         else:
             synchronize()
 
+    # Subclasses do not need to implement the default_model method.
     def default_model(self) -> Tuple[str, List[str]]:
         """
diff --git a/ding/utils/pytorch_ddp_dist_helper.py b/ding/utils/pytorch_ddp_dist_helper.py
index fcea6c81e8..480bfe8c4f 100644
--- a/ding/utils/pytorch_ddp_dist_helper.py
+++ b/ding/utils/pytorch_ddp_dist_helper.py
@@ -46,6 +46,25 @@ def allreduce(x: torch.Tensor) -> None:
     dist.all_reduce(x)
     x.div_(get_world_size())
 
+def allreduce_with_indicator(grad: torch.Tensor, indicator: torch.Tensor) -> None:
+    """
+    Overview:
+        Custom allreduce: sum both the gradient and indicator tensors across all processes.
+        Then, if at least one process contributed (i.e., the summed indicator > 0),
+        divide the gradient by the summed indicator. This ensures that if only a subset of
+        GPUs contributed a gradient, the averaging is performed over the actual number
+        of contributors rather than the total number of GPUs.
+    Arguments:
+        - grad (torch.Tensor): Local gradient tensor to be reduced.
+        - indicator (torch.Tensor): A tensor flag (1 if the gradient is computed, 0 otherwise).
+    """
+    # Allreduce (sum) the gradient and indicator.
+    dist.all_reduce(grad)
+    dist.all_reduce(indicator)
+
+    # Avoid division by zero. If the summed indicator is close to 0 (extreme case), grad remains all zeros.
+    if not torch.isclose(indicator, torch.tensor(0.0)):
+        grad.div_(indicator.item())
 
 def allreduce_with_indicator(grad: torch.Tensor, indicator: torch.Tensor) -> None:
     """
diff --git a/ding/worker/learner/base_learner.py b/ding/worker/learner/base_learner.py
index a070e7a58d..487e36c400 100644
--- a/ding/worker/learner/base_learner.py
+++ b/ding/worker/learner/base_learner.py
@@ -119,7 +119,8 @@ def __init__(
         self._logger, _ = build_logger(
             './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
         )
-        self._tb_logger = None
+        self._tb_logger = tb_logger
+
         self._log_buffer = {
             'scalar': build_log_buffer(),
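
For reviewers, here is a minimal, hypothetical usage sketch of the new `allreduce_with_indicator` helper in a gradient-sync step, so that only the ranks that actually produced a gradient are counted in the average. The function `sync_partial_gradients` and the `has_grad` flag are illustrative assumptions and not part of this PR; only the import path `ding.utils.pytorch_ddp_dist_helper.allreduce_with_indicator` comes from the diff above.

```python
# Hypothetical sketch (not part of this PR): average gradients only over the
# ranks that actually computed them, using the new allreduce_with_indicator helper.
# Assumes torch.distributed has already been initialized (init_process_group).
import torch

from ding.utils.pytorch_ddp_dist_helper import allreduce_with_indicator


def sync_partial_gradients(model: torch.nn.Module, has_grad: bool) -> None:
    """Average gradients across ranks, dividing by the number of contributing ranks."""
    for param in model.parameters():
        if param.grad is None:
            # Every rank must join the collective, so ranks without a gradient
            # contribute zeros and an indicator of 0.
            param.grad = torch.zeros_like(param)
        if not has_grad:
            param.grad.zero_()
        indicator = torch.tensor(1.0 if has_grad else 0.0, device=param.grad.device)
        # Sums grad and indicator over all ranks, then divides grad by the
        # summed indicator (the count of contributing ranks) instead of world_size.
        allreduce_with_indicator(param.grad, indicator)
```

This mirrors the intent stated in the helper's docstring: the divisor is the number of contributing ranks rather than `get_world_size()`.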