Describe the bug
`single_grad_partition_groups.append(flat_fp32_avg_grads)` collects `flat_fp32_avg_grads` for `unscale_and_clip_grad`. However, when `cpu_offload` is enabled, `self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)` moves the grad tensor to CPU, so it no longer aliases the device tensor. As a result, grad clipping only acts on the device tensors collected in `single_grad_partition_groups`, while the CPU grad actually used for the optimizer step still holds the pre-clipping values, which leads to abnormal loss.
```python
single_grad_partition_groups.append(flat_fp32_avg_grads)
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)

# unscale and clip grads
# get the global norm
global_norm_groups = {}
if self._clip_grad_norm > 0:
    for group_name, norm in norms.items():
        global_norm_groups[group_name] = norm**0.5

# the following operations are performed only on the rank to which parameters are assigned.
if gpc.config.model.dtype is not torch.float32:
    if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
        self._unscale_and_clip_grads(
            single_grad_partition_groups,
            list(global_norm_groups.values()),
            loss_scale,
        )
```
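A minimal CPU-only sketch of the aliasing problem (a dtype change stands in for the GPU-to-CPU move, since `.to()` across devices or dtypes returns a new tensor rather than an alias):

```python
import torch

# Stand-in for flat_fp32_avg_grads living on the training device.
device_grad = torch.tensor([3.0, 4.0])

# Stand-in for `.to(device)` under cpu_offload: a cross-device (here,
# cross-dtype) .to() returns a *copy*, so the two tensors no longer alias.
offloaded_grad = device_grad.to(torch.float64)

# In-place unscale/clip on the collected device tensor...
device_grad.mul_(0.5)

# ...leaves the offloaded copy (the tensor the optimizer actually reads) stale.
print(device_grad)     # tensor([1.5000, 2.0000])
print(offloaded_grad)  # tensor([3., 4.], dtype=torch.float64)
```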
Environment
internevo: 5ad2eb0
The problem can be reproduced by enabling `offload_cpu` for `HybridZeroOptimizer`.
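One possible fix (a hypothetical sketch, not the project's actual patch) is to perform the `.to(device)` move first and then append the same tensor that is assigned to `.grad`, so `_unscale_and_clip_grads` mutates the storage the optimizer step actually reads:

```python
import torch

def collect_offloaded_grad(flat_fp32_avg_grads, fp32_flat_param, single_grad_partition_groups):
    """Hypothetical fix sketch: collect the post-offload tensor, not the device one."""
    device = fp32_flat_param.device            # CPU when cpu_offload is enabled
    grad = flat_fp32_avg_grads.to(device)      # the tensor the optimizer will read
    fp32_flat_param.grad = grad
    single_grad_partition_groups.append(grad)  # clipping now acts on the same storage as .grad
    return grad
```

With this ordering, in-place clipping of the entries in `single_grad_partition_groups` is visible through `.grad`, because both names refer to one tensor.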
Other information
No response