Describe the bug
I hit a reproducible crash in DeepSpeed ZeRO Stage 3 when repeatedly entering and exiting deepspeed.zero.GatheredParameters over multiple parameters ([emb.weight, l1.weight, l2.weight]) with modifier_rank=None, touching small parameter slices inside the gathered context.
The crash occurs on context exit, in the ZeRO-3 re-partition/free logic, and raises:
AssertionError: assert not param.ds_active_sub_modules, param.ds_summary()
This is a DeepSpeed internal invariant failure in free_param() (deepspeed/runtime/zero/partition_parameters.py). Even if this usage is considered invalid, it should fail with a clear user-facing exception rather than an internal assert.
To Reproduce
import os
import sys
import json
import traceback
import random


def _skip(msg: str):
    print(f"SKIP_ENV: {msg}", flush=True)
    sys.exit(0)


def env_int(k: str, d: int) -> int:
    v = os.environ.get(k, "").strip()
    try:
        return int(v) if v else d
    except Exception:
        return d


def env_bool(k: str, d: bool) -> bool:
    v = os.environ.get(k, "").strip().lower()
    if not v:
        return d
    return v in ("1", "true", "yes", "y", "on")


def main():
    try:
        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        import torch.distributed as dist
    except Exception as e:
        _skip(f"missing torch/dist: {type(e).__name__}: {e}")
    try:
        import deepspeed
        from deepspeed import zero
    except Exception as e:
        _skip(f"missing deepspeed: {type(e).__name__}: {e}")

    if not dist.is_available():
        _skip("torch.distributed not available")
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")

    rank = dist.get_rank()
    world = dist.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    def barrier():
        dist.barrier()

    ds_config = os.environ.get("DEEPSPEED_CONFIG", "ds_config_zero3_stress.json")
    VOCAB = env_int("VOCAB", 32768)
    D_MODEL = env_int("D_MODEL", 2048)
    ITERS = env_int("ITERS", 200)
    TILE = env_int("TILE", 8)
    DO_BWD = env_bool("DO_BWD", True)
    FORCE_EDIT = env_bool("FORCE_GATHER_EDIT", True)

    try:
        cfg = json.load(open(ds_config, "r", encoding="utf-8"))
    except Exception as e:
        _skip(f"cannot load {ds_config}: {type(e).__name__}: {e}")

    micro = int(cfg.get("train_micro_batch_size_per_gpu", 1))
    gas = int(cfg.get("gradient_accumulation_steps", 1))
    tbs = int(cfg.get("train_batch_size", micro * gas * world))
    if tbs != micro * gas * world:
        _skip(f"bad DS batch params: train_batch_size {tbs} != micro {micro} * gas {gas} * world {world}")

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(VOCAB, D_MODEL)
            self.l1 = nn.Linear(D_MODEL, 4 * D_MODEL, bias=False)
            self.l2 = nn.Linear(4 * D_MODEL, D_MODEL, bias=False)

        def forward(self, x):
            h = self.emb(x)
            h = self.l1(h)
            h = F.gelu(h)
            h = self.l2(h)
            return h.sum()

    model = M().to(device).train()
    try:
        engine, _, _, _ = deepspeed.initialize(
            model=model,
            model_parameters=[p for p in model.parameters() if p.requires_grad],
            config=ds_config,
        )
    except Exception as e:
        _skip(f"deepspeed.initialize failed: {type(e).__name__}: {e}")

    wrapped = getattr(engine, "module", engine)
    p_emb = wrapped.emb.weight
    p_l1 = wrapped.l1.weight
    p_l2 = wrapped.l2.weight
    params = [p_emb, p_l1, p_l2]

    if rank == 0:
        print(
            f"[rank0] CONFIG world={world} device={device} vocab={VOCAB} d_model={D_MODEL} "
            f"iters={ITERS} tile={TILE} do_bwd={DO_BWD} force_edit={FORCE_EDIT}",
            flush=True,
        )

    rnd = random.Random(2026 + rank)
    any_exc = None
    try:
        for i in range(1, ITERS + 1):
            x = torch.randint(0, VOCAB, (1, 8), device=device)
            # Gather all three full parameters on every rank; no modifier rank,
            # so nothing is broadcast back when the context exits.
            with zero.GatheredParameters(params, modifier_rank=None):
                for p in params:
                    t = p.data
                    r = rnd.randrange(0, t.shape[0])
                    if t.ndim >= 2:
                        cmax = max(int(t.shape[1]) - TILE, 1)
                        c0 = rnd.randrange(0, cmax)
                        # Read a small slice of the gathered weight ...
                        _ = t[r, c0 : c0 + TILE].sum().item()
                        if FORCE_EDIT:
                            # ... and optionally touch it in place (no-op add).
                            t[r, c0 : c0 + TILE].add_(0.0)
            # Regular forward/backward/step.
            loss = engine(x)
            if DO_BWD:
                engine.backward(loss)
                engine.step()
            else:
                engine.zero_grad()
            if rank == 0 and (i % 25 == 0 or i == ITERS):
                print(f"[rank0] step={i} ok", flush=True)
    except Exception as e:
        any_exc = f"{type(e).__name__}: {e}"
        if rank == 0:
            print("[rank0] EXC TRACEBACK:", flush=True)
            traceback.print_exc()

    # Did any rank hit an exception?
    flag = torch.tensor([1 if any_exc else 0], device=device, dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    hit = bool(int(flag.item()) == 1)
    if rank == 0:
        print(f"[rank0] RESULT hit={hit} exc={any_exc}", flush=True)

    decision = torch.tensor([1 if hit else 0], device=device, dtype=torch.int32)
    dist.broadcast(decision, src=0)
    barrier()
    if int(decision.item()) == 1:
        if rank == 0:
            print("Test Passed ✅", flush=True)
    else:
        if rank == 0:
            print("Test Failed ❌", flush=True)
    barrier()
    try:
        dist.destroy_process_group()
    except Exception:
        pass
    sys.exit(0)


if __name__ == "__main__":
    main()
ZeRO-3 config file (ds_config_zero3_stress.json):
{
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "train_batch_size": 2,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4,
            "betas": [0.9, 0.999],
            "eps": 1e-8,
            "weight_decay": 0.0
        }
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "stage3_param_persistence_threshold": 0
    },
    "fp16": { "enabled": false },
    "bf16": { "enabled": false },
    "steps_per_print": 0,
    "wall_clock_breakdown": false
}
Run on 2 GPUs
export CUDA_VISIBLE_DEVICES=0,1
export NCCL_DEBUG=WARN
export TORCH_DISTRIBUTED_DEBUG=DETAIL
export DEEPSPEED_CONFIG=ds_config_zero3_stress.json
export VOCAB=32768
export D_MODEL=2048
export ITERS=200
export TILE=8
export DO_BWD=1
export FORCE_GATHER_EDIT=1
torchrun --nproc_per_node=2 ds_zero3_gather_assert_tri.py 2>&1 | tee FINAL_out_gather_tri.log
Observed output
[rank0] CONFIG world=2 device=cuda:0 vocab=32768 d_model=2048 iters=200 tile=8 do_bwd=True force_edit=True
[rank0] EXC TRACEBACK:
Traceback (most recent call last):
File ".../ds_zero3_gather_assert_tri.py", line 125, in main
with zero.GatheredParameters(params, modifier_rank=None):
File ".../deepspeed/runtime/zero/partition_parameters.py", line 2344, in __exit__
self.params[0].partition(param_list=self.params, has_been_updated=False)
File ".../deepspeed/runtime/zero/partition_parameters.py", line 1487, in partition
self._partition(param_list, has_been_updated=has_been_updated, free_data=True)
File ".../deepspeed/runtime/zero/partition_parameters.py", line 1636, in _partition
self._partition_param(param, has_been_updated=has_been_updated, free_data=True)
File ".../deepspeed/runtime/zero/partition_parameters.py", line 1670, in _partition_param
free_param(param)
File ".../deepspeed/runtime/zero/partition_parameters.py", line 302, in free_param
assert not param.ds_active_sub_modules, param.ds_summary()
AssertionError: {'id': 0, 'status': 'AVAILABLE', 'numel': 67108864, 'ds_numel': 67108864, 'shape': (32768, 2048), 'ds_shape': (32768, 2048), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': {1}, 'ds_tensor.shape': torch.Size([33554432])}
[rank0] RESULT hit=True exc=AssertionError: {'id': 0, ... 'active_sub_modules': {1}, ...}
Test Passed ✅
Expected behavior
Exiting zero.GatheredParameters should not crash. Either:
- DeepSpeed should correctly re-partition and free gathered parameter storage, or
- if this usage is invalid (e.g., any in-place touch inside the gathered context when modifier_rank=None), DeepSpeed should raise a clear runtime error explaining the misuse (see the sketch below), not an internal invariant assert.
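For illustration, a guard along the following lines at the point where free_param() currently asserts would turn the failure into something actionable. This is only a sketch: the helper name and the message wording are my own assumptions, not existing DeepSpeed code.

# Hypothetical guard (illustrative only, not DeepSpeed's actual code): replace the
# bare assert in free_param() with a descriptive error that points at the misuse.
def _check_can_free(param):
    if param.ds_active_sub_modules:
        raise RuntimeError(
            "Cannot re-partition/free a ZeRO-3 parameter "
            f"(ds_id={getattr(param, 'ds_id', '?')}, ds_shape={tuple(param.ds_shape)}) "
            f"while it is still registered as active by submodules {sorted(param.ds_active_sub_modules)}. "
            "This can happen when a parameter is modified inside "
            "zero.GatheredParameters(..., modifier_rank=None); pass an explicit "
            "modifier_rank if you intend to mutate gathered parameters."
        )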
Launcher context
Launched using torchrun --nproc_per_node=2 (PyTorch distributed launcher) with NCCL backend.
Not using the deepspeed launcher for this reproducer.
System info
**OS**: Ubuntu 24.04.3 LTS (Noble Numbat), kernel 6.8.0-31-generic
**Python**: 3.13.5 (venv: /home/talha/.venvs/dl_testing/bin/python)
**PyTorch**: 2.6.0+cu124 (CUDA runtime: 12.4)
**DeepSpeed**: 0.18.4 (pip install in venv; install path: /home/talha/.venvs/dl_testing/lib/python3.13/site-packages/deepspeed)
**GPU count & type**: 4 × NVIDIA GeForce RTX 3090 (24GB each)
**Driver / CUDA**: NVIDIA driver 550.78, CUDA driver/runtime 12.4
**Interconnect**: Single-node (PCIe). No multi-node IB used.
**Launcher**: torchrun --nproc_per_node=2 (NCCL backend)
Additional context
- The reproducer gathers multiple parameters in a single GatheredParameters context and touches random slices; this appears to stress the ZeRO-3 partition/free bookkeeping.
- The crash is triggered inside GatheredParameters.__exit__() during partition(..., free_data=True), when free_param() asserts that param.ds_active_sub_modules is empty but it is not (active_sub_modules: {1}).
- I understand that modifying parameters inside a GatheredParameters(..., modifier_rank=None) context may be considered invalid. However, the current behavior is still problematic because it crashes via an internal assert rather than producing a descriptive exception with guidance (e.g., requiring an explicit modifier_rank for mutations); see the sketch after these notes.
- The environment uses Python 3.13.5 and pip-installed DeepSpeed 0.18.4. ds_report reports "deepspeed info: 0.18.4, unknown, unknown" and "deepspeed wheel compiled w. torch 0.0, cuda 0.0" (possibly missing build metadata).
- The repro does not rely on optional ops (async_io/gds/sparse_attn). The crash occurs in the ZeRO-3 parameter partition/free path inside the Python runtime (partition_parameters.py).
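For comparison, what I understand to be the intended way to mutate ZeRO-3 parameters is to gather with an explicit modifier rank and edit on that rank only, so that (as I understand it) the context exit propagates the update before re-partitioning. A minimal sketch below, reusing the reproducer's module attribute names; whether the modifier_rank=None touch above is supposed to be legal at all is exactly what I would like clarified.

# Sketch of the (presumably) supported mutation pattern: gather with a modifier
# rank, edit on that rank only, and let GatheredParameters handle propagation on exit.
import torch.distributed as dist
from deepspeed import zero

def nudge_first_rows(module, tile=8):
    params = [module.emb.weight, module.l1.weight, module.l2.weight]
    with zero.GatheredParameters(params, modifier_rank=0):
        if dist.get_rank() == 0:
            for p in params:
                p.data[0, :tile].add_(0.0)  # mutation confined to the modifier rank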
DeepSpeed C++/CUDA extension op report:
JIT compiled ops requires ninja
ninja .................. [OKAY]
[WARNING] async_io requires libaio-dev (not found)
async_io ............... [NO] ....... [NO]
[WARNING] Please specify CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
/home/talha/miniconda3/compiler_compat/ld: cannot find -lcufile: No such file or directory
gds .................... [NO] ....... [NO]
[WARNING] sparse_attn requires torch >=1.5 and <2.0 but detected 2.6
[WARNING] using untested triton version (3.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
DeepSpeed general environment info:
torch install path ............... ['/home/talha/.venvs/dl_testing/lib/python3.13/site-packages/torch']
torch version .................... 2.6.0+cu124
deepspeed install path ........... ['/home/talha/.venvs/dl_testing/lib/python3.13/site-packages/deepspeed']
deepspeed info ................... 0.18.4, unknown, unknown
torch cuda version ............... 12.4
nvcc version ..................... 12.0
deepspeed wheel compiled w. ...... torch 0.0, cuda 0.0
shared memory (/dev/shm) size .... 62.88 GB