
[BUG] ZeRO-3: zero.GatheredParameters([multiple params], modifier_rank=None) + in-place slice touch triggers assert not param.ds_active_sub_modules in free_param() #7811

@griffinstalha

Describe the bug
I hit a reproducible crash in DeepSpeed ZeRO Stage 3 when repeatedly entering/exiting deepspeed.zero.GatheredParameters over multiple parameters ([emb.weight, l1.weight, l2.weight]) with modifier_rank=None and touching small parameter slices inside the gather context.

The crash occurs on context exit in ZeRO-3 repartition/free logic and raises:

AssertionError: assert not param.ds_active_sub_modules, param.ds_summary()

This is a DeepSpeed internal invariant failure in free_param() (deepspeed/runtime/zero/partition_parameters.py). Even if this usage is considered invalid, it should fail with a clear user-facing exception rather than an internal assert.
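
In condensed form, the pattern that triggers the assert looks like this (a sketch only; the full, runnable reproducer is below):

for step in range(iters):
    # gather full copies of several params on every rank, with no modifier rank
    with zero.GatheredParameters(params, modifier_rank=None):
        for p in params:
            # read and in-place touch a small slice of the gathered tensor
            p.data[0, :8].add_(0.0)
    # AssertionError is raised here, on context __exit__, inside free_param()
    loss = engine(x)
    engine.backward(loss)
    engine.step()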

To Reproduce

import os
import sys
import json
import traceback
import random


def _skip(msg: str):
    print(f"SKIP_ENV: {msg}", flush=True)
    sys.exit(0)


def env_int(k: str, d: int) -> int:
    v = os.environ.get(k, "").strip()
    try:
        return int(v) if v else d
    except Exception:
        return d


def env_bool(k: str, d: bool) -> bool:
    v = os.environ.get(k, "").strip().lower()
    if not v:
        return d
    return v in ("1", "true", "yes", "y", "on")


def main():
    try:
        import torch
        import torch.nn as nn
        import torch.nn.functional as F
        import torch.distributed as dist
    except Exception as e:
        _skip(f"missing torch/dist: {type(e).__name__}: {e}")

    try:
        import deepspeed
        from deepspeed import zero
    except Exception as e:
        _skip(f"missing deepspeed: {type(e).__name__}: {e}")

    if not dist.is_available():
        _skip("torch.distributed not available")

    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")

    rank = dist.get_rank()
    world = dist.get_world_size()

    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    def barrier():
        dist.barrier()

    ds_config = os.environ.get("DEEPSPEED_CONFIG", "ds_config_zero3_stress.json")

    VOCAB = env_int("VOCAB", 32768)
    D_MODEL = env_int("D_MODEL", 2048)
    ITERS = env_int("ITERS", 200)
    TILE = env_int("TILE", 8)
    DO_BWD = env_bool("DO_BWD", True)
    FORCE_EDIT = env_bool("FORCE_GATHER_EDIT", True)

    try:
        cfg = json.load(open(ds_config, "r", encoding="utf-8"))
    except Exception as e:
        _skip(f"cannot load {ds_config}: {type(e).__name__}: {e}")

    micro = int(cfg.get("train_micro_batch_size_per_gpu", 1))
    gas = int(cfg.get("gradient_accumulation_steps", 1))
    tbs = int(cfg.get("train_batch_size", micro * gas * world))
    if tbs != micro * gas * world:
        _skip(f"bad DS batch params: train_batch_size {tbs} != micro {micro} * gas {gas} * world {world}")

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(VOCAB, D_MODEL)
            self.l1 = nn.Linear(D_MODEL, 4 * D_MODEL, bias=False)
            self.l2 = nn.Linear(4 * D_MODEL, D_MODEL, bias=False)

        def forward(self, x):
            h = self.emb(x)
            h = self.l1(h)
            h = F.gelu(h)
            h = self.l2(h)
            return h.sum()

    model = M().to(device).train()

    try:
        engine, _, _, _ = deepspeed.initialize(
            model=model,
            model_parameters=[p for p in model.parameters() if p.requires_grad],
            config=ds_config,
        )
    except Exception as e:
        _skip(f"deepspeed.initialize failed: {type(e).__name__}: {e}")

    wrapped = getattr(engine, "module", engine)

    p_emb = wrapped.emb.weight
    p_l1 = wrapped.l1.weight
    p_l2 = wrapped.l2.weight
    params = [p_emb, p_l1, p_l2]

    if rank == 0:
        print(
            f"[rank0] CONFIG world={world} device={device} vocab={VOCAB} d_model={D_MODEL} "
            f"iters={ITERS} tile={TILE} do_bwd={DO_BWD} force_edit={FORCE_EDIT}",
            flush=True,
        )

    rnd = random.Random(2026 + rank)
    any_exc = None

    try:
        for i in range(1, ITERS + 1):
            x = torch.randint(0, VOCAB, (1, 8), device=device)

            # Gather full copies of all three params on every rank
            # (no modifier rank), then read and in-place touch a small
            # random slice of each inside the context.
            with zero.GatheredParameters(params, modifier_rank=None):
                for p in params:
                    t = p.data
                    r = rnd.randrange(0, t.shape[0])
                    if t.ndim >= 2:
                        cmax = max(int(t.shape[1]) - TILE, 1)
                        c0 = rnd.randrange(0, cmax)
                        _ = t[r, c0 : c0 + TILE].sum().item()
                        if FORCE_EDIT:
                            # in-place no-op write on the gathered storage
                            t[r, c0 : c0 + TILE].add_(0.0)

            loss = engine(x)
            if DO_BWD:
                engine.backward(loss)
                engine.step()
            else:
                engine.zero_grad()

            if rank == 0 and (i % 25 == 0 or i == ITERS):
                print(f"[rank0] step={i} ok", flush=True)

    except Exception as e:
        any_exc = f"{type(e).__name__}: {e}"
        if rank == 0:
            print("[rank0] EXC TRACEBACK:", flush=True)
            traceback.print_exc()

    flag = torch.tensor([1 if any_exc else 0], device=device, dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    hit = bool(int(flag.item()) == 1)

    if rank == 0:
        print(f"[rank0] RESULT hit={hit} exc={any_exc}", flush=True)

    decision = torch.tensor([1 if hit else 0], device=device, dtype=torch.int32)
    dist.broadcast(decision, src=0)
    barrier()

    # Note: "Passed" here means the assert was reproduced on at least one rank.
    if int(decision.item()) == 1:
        if rank == 0:
            print("Test Passed ✅", flush=True)
    else:
        if rank == 0:
            print("Test Failed ❌", flush=True)

    barrier()
    try:
        dist.destroy_process_group()
    except Exception:
        pass
    sys.exit(0)


if __name__ == "__main__":
    main()

ZeRO-3 config file:

{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 1,
  "train_batch_size": 2,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 1e-4,
      "betas": [0.9, 0.999],
      "eps": 1e-8,
      "weight_decay": 0.0
    }
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "stage3_param_persistence_threshold": 0
  },
  "fp16": { "enabled": false },
  "bf16": { "enabled": false },
  "steps_per_print": 0,
  "wall_clock_breakdown": false
}

Run on 2 GPUs

export CUDA_VISIBLE_DEVICES=0,1
export NCCL_DEBUG=WARN
export TORCH_DISTRIBUTED_DEBUG=DETAIL
export DEEPSPEED_CONFIG=ds_config_zero3_stress.json

export VOCAB=32768
export D_MODEL=2048
export ITERS=200
export TILE=8
export DO_BWD=1
export FORCE_GATHER_EDIT=1

torchrun --nproc_per_node=2 ds_zero3_gather_assert_tri.py 2>&1 | tee FINAL_out_gather_tri.log

Observed output

[rank0] CONFIG world=2 device=cuda:0 vocab=32768 d_model=2048 iters=200 tile=8 do_bwd=True force_edit=True
[rank0] EXC TRACEBACK:
Traceback (most recent call last):
  File ".../ds_zero3_gather_assert_tri.py", line 125, in main
    with zero.GatheredParameters(params, modifier_rank=None):
  File ".../deepspeed/runtime/zero/partition_parameters.py", line 2344, in __exit__
    self.params[0].partition(param_list=self.params, has_been_updated=False)
  File ".../deepspeed/runtime/zero/partition_parameters.py", line 1487, in partition
    self._partition(param_list, has_been_updated=has_been_updated, free_data=True)
  File ".../deepspeed/runtime/zero/partition_parameters.py", line 1636, in _partition
    self._partition_param(param, has_been_updated=has_been_updated, free_data=True)
  File ".../deepspeed/runtime/zero/partition_parameters.py", line 1670, in _partition_param
    free_param(param)
  File ".../deepspeed/runtime/zero/partition_parameters.py", line 302, in free_param
    assert not param.ds_active_sub_modules, param.ds_summary()
AssertionError: {'id': 0, 'status': 'AVAILABLE', 'numel': 67108864, 'ds_numel': 67108864, 'shape': (32768, 2048), 'ds_shape': (32768, 2048), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': {1}, 'ds_tensor.shape': torch.Size([33554432])}

[rank0] RESULT hit=True exc=AssertionError: {'id': 0, ... 'active_sub_modules': {1}, ...}
Test Passed ✅

Expected behavior
Exiting zero.GatheredParameters should not crash. Either:

  • DeepSpeed should correctly re-partition and free gathered parameter storage, or
  • if this usage is invalid (e.g., any in-place touch inside the gathered context when modifier_rank=None), DeepSpeed should raise a clear runtime error explaining the misuse rather than an internal invariant assert; a hypothetical guard is sketched below.
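
For illustration, a hypothetical guard inside free_param() could surface the condition as a descriptive error instead of a bare assert. This is only a sketch of what such a check might look like, reusing the attribute names visible in the traceback (ds_active_sub_modules, ds_summary), not the actual DeepSpeed implementation:

def free_param(param):
    # Hypothetical replacement for the bare assert in partition_parameters.py:
    # refuse to free while sub-modules still hold the gathered parameter,
    # and tell the user how to correct the usage.
    if param.ds_active_sub_modules:
        raise RuntimeError(
            "Cannot re-partition/free a ZeRO-3 parameter that is still "
            f"registered by active sub-modules {param.ds_active_sub_modules}. "
            "If the parameter was modified inside zero.GatheredParameters, "
            "pass a non-None modifier_rank so the update can be broadcast "
            f"and re-partitioned safely. Parameter state: {param.ds_summary()}"
        )
    # ... existing freeing logic would follow unchanged ...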

Launcher context

Launched using torchrun --nproc_per_node=2 (PyTorch distributed launcher) with NCCL backend.
Not using the deepspeed launcher for this reproducer.

System info

**OS**: Ubuntu 24.04.3 LTS (Noble Numbat), kernel 6.8.0-31-generic
**Python**: 3.13.5 (venv: /home/talha/.venvs/dl_testing/bin/python)
**PyTorch**: 2.6.0+cu124 (CUDA runtime: 12.4)
**DeepSpeed**: 0.18.4 (pip install in venv; install path: /home/talha/.venvs/dl_testing/lib/python3.13/site-packages/deepspeed)
**GPU count & type**: 4 × NVIDIA GeForce RTX 3090 (24GB each)
**Driver / CUDA**: NVIDIA driver 550.78, CUDA driver/runtime 12.4
**Interconnect**: Single-node (PCIe). No multi-node IB used.
**Launcher**: torchrun --nproc_per_node=2 (NCCL backend)

Additional context

  • The reproducer gathers multiple parameters in a single GatheredParameters context and touches random slices; this seems to stress ZeRO-3 partition/free bookkeeping.

  • The crash is triggered inside GatheredParameters.__exit__() during partition(..., free_data=True), when free_param() asserts that param.ds_active_sub_modules is empty, but it is not (active_sub_modules: {1}).

  • I understand that modifying parameters inside a GatheredParameters(..., modifier_rank=None) context may be considered invalid. However, the current behavior is still problematic because it crashes via an internal assert rather than producing a descriptive exception and guidance (e.g., requiring a non-None modifier_rank for mutations; see the usage sketch after this list).

  • Environment uses Python 3.13.5 and pip-installed DeepSpeed 0.18.4. ds_report reports deepspeed info: 0.18.4, unknown, unknown and deepspeed wheel compiled w. torch 0.0, cuda 0.0 (possibly missing build metadata).

  • Repro does not rely on optional ops (async_io/gds/sparse_attn). Crash occurs in ZeRO-3 parameter partition/free path inside Python runtime (partition_parameters.py).
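
For reference, the usage I believe is intended for in-place edits is to name a modifier rank so that only that rank mutates the gathered tensors and the change is broadcast on context exit. A sketch of that pattern with the same params list (I have not verified that it avoids the assert in this exact stress loop):

with zero.GatheredParameters(params, modifier_rank=0):
    if dist.get_rank() == 0:
        # only the modifier rank touches the gathered full tensors
        for p in params:
            p.data[0, :8].add_(0.0)
# on __exit__, rank 0's updates are broadcast, then params are re-partitioned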

DeepSpeed C++/CUDA extension op report:

JIT compiled ops requires ninja
ninja .................. [OKAY]

[WARNING] async_io requires libaio-dev (not found)
async_io ............... [NO] ....... [NO]

[WARNING] Please specify CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]

/home/talha/miniconda3/compiler_compat/ld: cannot find -lcufile: No such file or directory
gds .................... [NO] ....... [NO]

[WARNING] sparse_attn requires torch >=1.5 and <2.0 but detected 2.6
[WARNING] using untested triton version (3.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]

DeepSpeed general environment info:
torch install path ............... ['/home/talha/.venvs/dl_testing/lib/python3.13/site-packages/torch']
torch version .................... 2.6.0+cu124
deepspeed install path ........... ['/home/talha/.venvs/dl_testing/lib/python3.13/site-packages/deepspeed']
deepspeed info ................... 0.18.4, unknown, unknown
torch cuda version ............... 12.4
nvcc version ..................... 12.0
deepspeed wheel compiled w. ...... torch 0.0, cuda 0.0
shared memory (/dev/shm) size .... 62.88 GB

Labels

bug, training
