15 changes: 15 additions & 0 deletions deepspeed/comm/comm.py
@@ -736,6 +736,21 @@ def get_local_rank():
return get_local_rank_from_launcher()


def get_backend(group=None):
"""
Returns the backend of the given process group.
Args:
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
Returns:
The backend of the given process group as a string.
"""
global cdb
assert cdb is not None and cdb.is_initialized(
), 'DeepSpeed backend not set, please initialize it using init_process_group()'
return cdb.get_backend(group)


def get_global_rank(group=None, group_rank=0):
global cdb
assert cdb is not None and cdb.is_initialized(
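A minimal usage sketch (not part of the diff) for the new `deepspeed.comm.get_backend` wrapper, assuming a process group has already been set up via `deepspeed.init_distributed()`. It mirrors how the backend string is used later in `partition_parameters.py`: NCCL can only reduce accelerator tensors, so the backend decides where a small flag tensor should live.

```python
# Hypothetical sketch; assumes deepspeed.init_distributed() has been called
# (e.g. under a deepspeed launcher) before get_backend() is queried.
import torch
import deepspeed
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

deepspeed.init_distributed()

backend = dist.get_backend()  # e.g. "nccl" or "gloo" for the default group
# NCCL reduces accelerator tensors only; fall back to CPU for other backends.
device = torch.device(get_accelerator().current_device_name()) if backend == "nccl" else torch.device("cpu")
flag = torch.ones(1, dtype=torch.int32, device=device)
dist.all_reduce(flag, op=dist.ReduceOp.MAX)
print(f"rank {dist.get_rank()}: backend={backend}, flag={int(flag.item())}")
```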
26 changes: 25 additions & 1 deletion deepspeed/runtime/zero/partition_parameters.py
@@ -299,7 +299,10 @@ def recurse(cl):
@instrument_w_nvtx
def free_param(param: Parameter) -> None:
"""Free underlying storage of a parameter."""
assert not param.ds_active_sub_modules, param.ds_summary()
if param.ds_active_sub_modules:
raise RuntimeError("Cannot free a ZeRO-3 parameter while it is still active in submodules. "
"Ensure all submodules have released the parameter before it is freed. "
f"param={param.ds_summary()}")
if get_accelerator().on_accelerator(param.data):
# need to make sure that we don't free the parameter while it is still
# being used for computation
@@ -2316,11 +2319,32 @@ def __enter__(self):
if not self.enabled:
return
self.params[0].all_gather(param_list=self.params)
if self.src_rank is None:
self._param_versions = {p: p._version for p in self.params}

def __exit__(self, *exc):
if not self.enabled:
return
if self.src_rank is None:
check_mutation = True
if self.params and dist.is_initialized():
if dist.get_world_size(group=self.params[0].ds_process_group) <= 1:
check_mutation = False
if check_mutation:
mutated = [p for p in self.params if p._version != self._param_versions.get(p, p._version)]
mutated_any = bool(mutated)
if self.params and dist.is_initialized():
backend = dist.get_backend(self.params[0].ds_process_group)
device = torch.device(
get_accelerator().current_device_name()) if backend == "nccl" else torch.device("cpu")
flag = torch.tensor([1 if mutated_any else 0], device=device, dtype=torch.int32)
dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=self.params[0].ds_process_group)
mutated_any = bool(int(flag.item()))
if mutated_any:
raise RuntimeError(
"Detected in-place modification of parameters inside `zero.GatheredParameters` "
"with `modifier_rank=None`. Use `modifier_rank=<rank>` when mutating parameters "
"so updates are broadcast consistently across ranks.")
self.params[0].partition(param_list=self.params, has_been_updated=False)
return

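A sketch (not from the PR) of the usage contract the new `__exit__` check enforces: with `modifier_rank=None` the gathered parameters are read-only, and any in-place write now raises; mutation requires naming a modifier rank so the update is broadcast from that rank on exit. `model` is assumed to be any module whose parameters were partitioned under `deepspeed.zero.Init`.

```python
# Hypothetical sketch of the enforced contract; `model` is an assumed
# ZeRO-3 module built under deepspeed.zero.Init.
import torch
import deepspeed
import deepspeed.comm as dist

# Read-only access: gather, inspect, release. In-place writes here now raise.
with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=None):
    param_norms = [p.norm().item() for p in model.parameters()]

# Mutating access: the change made on the modifier rank is broadcast on exit.
with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=0):
    if dist.get_rank() == 0:
        with torch.no_grad():
            for p in model.parameters():
                p.mul_(0.5)
```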
23 changes: 23 additions & 0 deletions tests/unit/runtime/zero/test_zero_context.py
@@ -107,6 +107,29 @@ def __init__(self, hidden_dim):
assert model.l1.weight.numel() == 0, "outside of GatheredParameters the param should go back to be 0-sized"


class TestZeroFreeParamActiveSubmodule(DistributedTest):
world_size = 2

def test(self):
config_dict = {"train_micro_batch_size_per_gpu": 1, "zero_optimization": {"stage": 3}}
hidden_dim = 10

class MyModel(torch.nn.Module):

def __init__(self, hidden_dim):
super(MyModel, self).__init__()
self.l1 = torch.nn.Linear(hidden_dim, hidden_dim)
self.l2 = torch.nn.Linear(hidden_dim, hidden_dim)

with deepspeed.zero.Init(config_dict_or_path=config_dict):
model = MyModel(hidden_dim)

with pytest.raises(RuntimeError, match="in-place modification"):
with deepspeed.zero.GatheredParameters([model.l1.weight, model.l2.weight], modifier_rank=None):
with torch.no_grad():
model.l1.weight.add_(0.0)


class TestSerialContext(DistributedTest):
world_size = 1
init_distributed = False
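The test above leans on torch's per-tensor version counter, which is bumped by every in-place op even under `torch.no_grad()`; that is why `add_(0.0)` is caught despite leaving the values unchanged. A standalone sketch of that mechanism, independent of DeepSpeed:

```python
# Standalone sketch: torch advances a tensor's `_version` counter on
# in-place ops (including under no_grad); out-of-place ops leave it alone.
import torch

p = torch.nn.Parameter(torch.zeros(4))
v0 = p._version
with torch.no_grad():
    p.add_(0.0)          # in-place: counter advances even though values are unchanged
assert p._version > v0

v1 = p._version
q = p + 1.0              # out-of-place: new tensor, p's counter untouched
assert p._version == v1
```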
4 changes: 2 additions & 2 deletions tests/unit/v1/zero/test_zero.py
@@ -232,7 +232,7 @@ def forward(self, x, y):
orig_state_dict[name] = param.detach().cpu()

if zero_stage == 3:
with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=None):
with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=0):
fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
fp32_state_dict = fp32_model.state_dict()
else:
@@ -339,7 +339,7 @@ def forward(self, x, y):
orig_state_dict[name] = param.detach().cpu()

if zero_stage == 3:
with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=None):
with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=0):
fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir)
fp32_state_dict = fp32_model.state_dict()
else: