diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py
index f9e94f0175e2..310e2832ffb6 100755
--- a/deepspeed/comm/comm.py
+++ b/deepspeed/comm/comm.py
@@ -736,6 +736,21 @@ def get_local_rank():
     return get_local_rank_from_launcher()


+def get_backend(group=None):
+    """
+    Returns the backend of the given process group.
+    Args:
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+    Returns:
+        The backend of the given process group as a string.
+    """
+    global cdb
+    assert cdb is not None and cdb.is_initialized(
+    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+    return cdb.get_backend(group)
+
+
 def get_global_rank(group=None, group_rank=0):
     global cdb
     assert cdb is not None and cdb.is_initialized(
diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 47aa4f41bcbc..df8ebca47623 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -299,7 +299,10 @@ def recurse(cl):
 @instrument_w_nvtx
 def free_param(param: Parameter) -> None:
     """Free underlying storage of a parameter."""
-    assert not param.ds_active_sub_modules, param.ds_summary()
+    if param.ds_active_sub_modules:
+        raise RuntimeError("Cannot free a ZeRO-3 parameter while it is still active in submodules. "
+                           "Ensure all submodules have released the parameter before it is freed. "
+                           f"param={param.ds_summary()}")
     if get_accelerator().on_accelerator(param.data):
         # need to make sure that we don't free the parameter while it is still
         # being used for computation
@@ -2316,11 +2319,32 @@ def __enter__(self):
         if not self.enabled:
             return
         self.params[0].all_gather(param_list=self.params)
+        if self.src_rank is None:
+            self._param_versions = {p: p._version for p in self.params}

     def __exit__(self, *exc):
         if not self.enabled:
             return
         if self.src_rank is None:
+            check_mutation = True
+            if self.params and dist.is_initialized():
+                if dist.get_world_size(group=self.params[0].ds_process_group) <= 1:
+                    check_mutation = False
+            if check_mutation:
+                mutated = [p for p in self.params if p._version != self._param_versions.get(p, p._version)]
+                mutated_any = bool(mutated)
+                if self.params and dist.is_initialized():
+                    backend = dist.get_backend(self.params[0].ds_process_group)
+                    device = torch.device(
+                        get_accelerator().current_device_name()) if backend == "nccl" else torch.device("cpu")
+                    flag = torch.tensor([1 if mutated_any else 0], device=device, dtype=torch.int32)
+                    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=self.params[0].ds_process_group)
+                    mutated_any = bool(int(flag.item()))
+                if mutated_any:
+                    raise RuntimeError(
+                        "Detected in-place modification of parameters inside `zero.GatheredParameters` "
+                        "with `modifier_rank=None`. Use an explicit `modifier_rank` when mutating parameters "
+                        "so updates are broadcast consistently across ranks.")
             self.params[0].partition(param_list=self.params, has_been_updated=False)
             return

diff --git a/tests/unit/runtime/zero/test_zero_context.py b/tests/unit/runtime/zero/test_zero_context.py
index 1c065bb791f1..71e762e5819f 100644
--- a/tests/unit/runtime/zero/test_zero_context.py
+++ b/tests/unit/runtime/zero/test_zero_context.py
@@ -107,6 +107,29 @@ def __init__(self, hidden_dim):
         assert model.l1.weight.numel() == 0, "outside of GatheredParameters the param should go back to be 0-sized"


+class TestZeroGatheredParametersMutationCheck(DistributedTest):
+    world_size = 2
+
+    def test(self):
+        config_dict = {"train_micro_batch_size_per_gpu": 1, "zero_optimization": {"stage": 3}}
+        hidden_dim = 10
+
+        class MyModel(torch.nn.Module):
+
+            def __init__(self, hidden_dim):
+                super(MyModel, self).__init__()
+                self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)
+                self.l2 = torch.nn.Linear(hidden_dim, hidden_dim)
+
+        with deepspeed.zero.Init(config_dict_or_path=config_dict):
+            model = MyModel(hidden_dim)
+
+        with pytest.raises(RuntimeError, match="in-place modification"):
+            with deepspeed.zero.GatheredParameters([model.l1.weight, model.l2.weight], modifier_rank=None):
+                with torch.no_grad():
+                    model.l1.weight.add_(0.0)
+
+
 class TestSerialContext(DistributedTest):
     world_size = 1
     init_distributed = False
diff --git a/tests/unit/v1/zero/test_zero.py b/tests/unit/v1/zero/test_zero.py
index fb0e393dd5da..9e43e6e4c041 100644
--- a/tests/unit/v1/zero/test_zero.py
+++ b/tests/unit/v1/zero/test_zero.py
@@ -232,7 +232,7 @@ def forward(self, x, y):
                 orig_state_dict[name] = param.detach().cpu()

         if zero_stage == 3:
-            with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=None):
+            with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=0):
                 fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
                 fp32_state_dict = fp32_model.state_dict()
         else:
@@ -339,7 +339,7 @@ def forward(self, x, y):
                 orig_state_dict[name] = param.detach().cpu()

         if zero_stage == 3:
-            with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=None):
+            with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=0):
                 fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
                 fp32_state_dict = fp32_model.state_dict()
         else:
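
Note (not part of the patch): a minimal usage sketch of the behavior this diff enforces. With `modifier_rank=None` the gathered parameters are treated as read-only and re-partitioned without an update (`has_been_updated=False`), so in-place edits now raise; mutation requires an explicit `modifier_rank` so the change made on that rank is broadcast before re-partitioning. The `engine` object below is an assumption, standing in for a ZeRO stage-3 model engine returned by `deepspeed.initialize`.

import torch
import deepspeed
import deepspeed.comm as dist

# Assumes `engine` is a ZeRO stage-3 model engine from deepspeed.initialize(...).

# Read-only access: gather, inspect, re-partition unchanged.
with deepspeed.zero.GatheredParameters(list(engine.module.parameters()), modifier_rank=None):
    param_norms = [p.norm().item() for p in engine.module.parameters()]

# Mutation: nominate a modifier rank; its in-place update is broadcast to all
# ranks before the parameters are re-partitioned.
with deepspeed.zero.GatheredParameters(list(engine.module.parameters()), modifier_rank=0):
    if dist.get_rank() == 0:
        with torch.no_grad():
            for p in engine.module.parameters():
                p.mul_(0.5)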
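
Note (not part of the patch): the mutation check in `__exit__` reduces a flag across the process group, and the new `deepspeed.comm.get_backend` helper is what lets it place that flag where the backend can reduce it (accelerator memory for NCCL, host memory otherwise). Below is a standalone sketch of that pattern; `any_rank_true` is an illustrative name, not a DeepSpeed API, and `deepspeed.comm` is assumed to be initialized.

import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

def any_rank_true(local_flag: bool, group=None) -> bool:
    """Return True if `local_flag` is True on at least one rank."""
    if not dist.is_initialized() or dist.get_world_size(group=group) <= 1:
        return local_flag
    # NCCL reduces tensors that live on the accelerator; CPU backends such as
    # Gloo need the tensor on the host. `get_backend` is the helper added above.
    backend = dist.get_backend(group)
    device = torch.device(get_accelerator().current_device_name()) if backend == "nccl" else torch.device("cpu")
    flag = torch.tensor([int(local_flag)], device=device, dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
    return bool(int(flag.item()))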