From a9265ae97334d0bea33947fe52342e73e4c504c2 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka
Date: Sun, 25 Jan 2026 23:52:46 -0800
Subject: [PATCH 1/6] improve error raised by modifying param in GatheredParameters context

Signed-off-by: Masahiro Tanaka
---
 tests/unit/runtime/zero/test_zero_context.py | 23 ++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/tests/unit/runtime/zero/test_zero_context.py b/tests/unit/runtime/zero/test_zero_context.py
index 1c065bb791f1..71e762e5819f 100644
--- a/tests/unit/runtime/zero/test_zero_context.py
+++ b/tests/unit/runtime/zero/test_zero_context.py
@@ -107,6 +107,29 @@ def __init__(self, hidden_dim):
         assert model.l1.weight.numel() == 0, "outside of GatheredParameters the param should go back to be 0-sized"
 
 
+class TestZeroFreeParamActiveSubmodule(DistributedTest):
+    world_size = 2
+
+    def test(self):
+        config_dict = {"train_micro_batch_size_per_gpu": 1, "zero_optimization": {"stage": 3}}
+        hidden_dim = 10
+
+        class MyModel(torch.nn.Module):
+
+            def __init__(self, hidden_dim):
+                super(MyModel, self).__init__()
+                self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)
+                self.l2 = torch.nn.Linear(hidden_dim, hidden_dim)
+
+        with deepspeed.zero.Init(config_dict_or_path=config_dict):
+            model = MyModel(hidden_dim)
+
+        with pytest.raises(RuntimeError, match="in-place modification"):
+            with deepspeed.zero.GatheredParameters([model.l1.weight, model.l2.weight], modifier_rank=None):
+                with torch.no_grad():
+                    model.l1.weight.add_(0.0)
+
+
 class TestSerialContext(DistributedTest):
     world_size = 1
     init_distributed = False

From b95cd4237161636cecdc9d6462365220e18d415d Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka
Date: Sun, 25 Jan 2026 23:55:22 -0800
Subject: [PATCH 2/6] check in fix in partition_parameters

Signed-off-by: Masahiro Tanaka
---
 deepspeed/runtime/zero/partition_parameters.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 47aa4f41bcbc..7832b4ba3cb6 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -299,7 +299,10 @@ def recurse(cl):
 @instrument_w_nvtx
 def free_param(param: Parameter) -> None:
     """Free underlying storage of a parameter."""
-    assert not param.ds_active_sub_modules, param.ds_summary()
+    if param.ds_active_sub_modules:
+        raise RuntimeError("Cannot free a ZeRO-3 parameter while it is still active in submodules. "
+                           "Ensure all submodules have released the parameter before it is freed. "
" + f"param={param.ds_summary()}") if get_accelerator().on_accelerator(param.data): # need to make sure that we don't free the parameter while it is still # being used for computation @@ -2316,11 +2319,24 @@ def __enter__(self): if not self.enabled: return self.params[0].all_gather(param_list=self.params) + if self.src_rank is None: + self._param_versions = {p: p._version for p in self.params} def __exit__(self, *exc): if not self.enabled: return if self.src_rank is None: + mutated = [p for p in self.params if p._version != self._param_versions.get(p, p._version)] + mutated_any = bool(mutated) + if self.params and dist.is_initialized(): + device = self.params[0].device + flag = torch.tensor([1 if mutated_any else 0], device=device, dtype=torch.int32) + dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=self.params[0].ds_process_group) + mutated_any = bool(int(flag.item())) + if mutated_any: + raise RuntimeError("Detected in-place modification of parameters inside `zero.GatheredParameters` " + "with `modifier_rank=None`. Use `modifier_rank=` when mutating parameters " + "so updates are broadcast consistently across ranks.") self.params[0].partition(param_list=self.params, has_been_updated=False) return From b178f5ae4566eafbb4863c1050a3808df8c9815c Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 26 Jan 2026 00:40:29 -0800 Subject: [PATCH 3/6] get device from accelerator Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/zero/partition_parameters.py | 2 +- tests/unit/v1/zero/test_zero.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 7832b4ba3cb6..b3adf5adedb2 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -2329,7 +2329,7 @@ def __exit__(self, *exc): mutated = [p for p in self.params if p._version != self._param_versions.get(p, p._version)] mutated_any = bool(mutated) if self.params and dist.is_initialized(): - device = self.params[0].device + device = torch.device(get_accelerator().current_device_name()) flag = torch.tensor([1 if mutated_any else 0], device=device, dtype=torch.int32) dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=self.params[0].ds_process_group) mutated_any = bool(int(flag.item())) diff --git a/tests/unit/v1/zero/test_zero.py b/tests/unit/v1/zero/test_zero.py index fb0e393dd5da..9e43e6e4c041 100644 --- a/tests/unit/v1/zero/test_zero.py +++ b/tests/unit/v1/zero/test_zero.py @@ -232,7 +232,7 @@ def forward(self, x, y): orig_state_dict[name] = param.detach().cpu() if zero_stage == 3: - with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=None): + with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=0): fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir) fp32_state_dict = fp32_model.state_dict() else: @@ -339,7 +339,7 @@ def forward(self, x, y): orig_state_dict[name] = param.detach().cpu() if zero_stage == 3: - with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=None): + with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=0): fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir) fp32_state_dict = fp32_model.state_dict() else: From 0213957354aa0cd1d448d20d7951cb860ea199fc Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 26 Jan 2026 08:59:50 -0800 Subject: [PATCH 4/6] check backend to match device Signed-off-by: Masahiro Tanaka --- 
 deepspeed/runtime/zero/partition_parameters.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index b3adf5adedb2..1a37febd300b 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -2329,7 +2329,9 @@ def __exit__(self, *exc):
             mutated = [p for p in self.params if p._version != self._param_versions.get(p, p._version)]
             mutated_any = bool(mutated)
             if self.params and dist.is_initialized():
-                device = torch.device(get_accelerator().current_device_name())
+                backend = dist.get_backend(self.params[0].ds_process_group)
+                device = torch.device(
+                    get_accelerator().current_device_name()) if backend == "nccl" else torch.device("cpu")
                 flag = torch.tensor([1 if mutated_any else 0], device=device, dtype=torch.int32)
                 dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=self.params[0].ds_process_group)
                 mutated_any = bool(int(flag.item()))

From 2ab4c498b7f4cd9c0d892b573a8b4faa95de86dc Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka
Date: Mon, 26 Jan 2026 09:39:37 -0800
Subject: [PATCH 5/6] check mutation only when process group size > 1

Signed-off-by: Masahiro Tanaka
---
 .../runtime/zero/partition_parameters.py      | 30 +++++++++++--------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 1a37febd300b..df8ebca47623 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -2326,19 +2326,25 @@ def __exit__(self, *exc):
         if not self.enabled:
             return
         if self.src_rank is None:
-            mutated = [p for p in self.params if p._version != self._param_versions.get(p, p._version)]
-            mutated_any = bool(mutated)
+            check_mutation = True
             if self.params and dist.is_initialized():
-                backend = dist.get_backend(self.params[0].ds_process_group)
-                device = torch.device(
-                    get_accelerator().current_device_name()) if backend == "nccl" else torch.device("cpu")
-                flag = torch.tensor([1 if mutated_any else 0], device=device, dtype=torch.int32)
-                dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=self.params[0].ds_process_group)
-                mutated_any = bool(int(flag.item()))
-            if mutated_any:
-                raise RuntimeError("Detected in-place modification of parameters inside `zero.GatheredParameters` "
-                                   "with `modifier_rank=None`. Use `modifier_rank=<rank>` when mutating parameters "
-                                   "so updates are broadcast consistently across ranks.")
+                if dist.get_world_size(group=self.params[0].ds_process_group) <= 1:
+                    check_mutation = False
+            if check_mutation:
+                mutated = [p for p in self.params if p._version != self._param_versions.get(p, p._version)]
+                mutated_any = bool(mutated)
+                if self.params and dist.is_initialized():
+                    backend = dist.get_backend(self.params[0].ds_process_group)
+                    device = torch.device(
+                        get_accelerator().current_device_name()) if backend == "nccl" else torch.device("cpu")
+                    flag = torch.tensor([1 if mutated_any else 0], device=device, dtype=torch.int32)
+                    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=self.params[0].ds_process_group)
+                    mutated_any = bool(int(flag.item()))
+                if mutated_any:
+                    raise RuntimeError(
+                        "Detected in-place modification of parameters inside `zero.GatheredParameters` "
+                        "with `modifier_rank=None`. Use `modifier_rank=<rank>` when mutating parameters "
+                        "so updates are broadcast consistently across ranks.")
             self.params[0].partition(param_list=self.params, has_been_updated=False)
             return
 

From 9582f49924faf9f7cc53e34225e42192ceb3b69e Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka
Date: Mon, 26 Jan 2026 22:13:47 -0800
Subject: [PATCH 6/6] add api to get backend

Signed-off-by: Masahiro Tanaka
---
 deepspeed/comm/comm.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py
index f9e94f0175e2..310e2832ffb6 100755
--- a/deepspeed/comm/comm.py
+++ b/deepspeed/comm/comm.py
@@ -736,6 +736,21 @@ def get_local_rank():
     return get_local_rank_from_launcher()
 
 
+def get_backend(group=None):
+    """
+    Returns the backend of the given process group.
+    Args:
+        group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+    Returns:
+        The backend of the given process group as a string.
+    """
+    global cdb
+    assert cdb is not None and cdb.is_initialized(
+    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
+    return cdb.get_backend(group)
+
+
 def get_global_rank(group=None, group_rank=0):
     global cdb
     assert cdb is not None and cdb.is_initialized(
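
For reference, a minimal usage sketch of the contract this series enforces. It is illustrative only and mirrors the new unit test above (the Linear model and config are assumptions, not shipped examples): reading gathered parameters is fine with modifier_rank=None, but in-place updates should go through an explicit modifier_rank so the change is broadcast from that rank when the context exits.

# Illustrative sketch; assumes a multi-rank run under the deepspeed launcher.
import torch
import deepspeed

config = {"train_micro_batch_size_per_gpu": 1, "zero_optimization": {"stage": 3}}

with deepspeed.zero.Init(config_dict_or_path=config):
    model = torch.nn.Linear(10, 10)

# Read-only access: with modifier_rank=None, an in-place change to the
# gathered weight now raises RuntimeError instead of silently diverging
# across ranks.
with deepspeed.zero.GatheredParameters(model.weight, modifier_rank=None):
    print(model.weight.norm())

# Mutation: name a modifier rank; the update made on that rank is broadcast
# to all ranks when the context exits.
with deepspeed.zero.GatheredParameters(model.weight, modifier_rank=0):
    with torch.no_grad():
        model.weight.add_(0.01)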