From c8a72ab45f9bc392892e6a65e0ff9cc32fd09f1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Mon, 15 Dec 2025 11:34:13 +0800 Subject: [PATCH 01/15] Introduce Megatron-style parallel state management Signed-off-by: Jikang Mo Signed-off-by: Junjie Mao --- deepspeed/utils/parallel_state.py | 1037 ++++++++++++ deepspeed/utils/parallel_state_deepspeed.py | 555 ++++++ tests/unit/utils/test_mpu.py | 1692 +++++++++++++++++++ 3 files changed, 3284 insertions(+) create mode 100644 deepspeed/utils/parallel_state.py create mode 100644 deepspeed/utils/parallel_state_deepspeed.py create mode 100644 tests/unit/utils/test_mpu.py diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py new file mode 100644 index 000000000000..df9906d2fcee --- /dev/null +++ b/deepspeed/utils/parallel_state.py @@ -0,0 +1,1037 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) DeepSpeed Team + +# DeepSpeed Team + +# The file has been adapted from https://github.com/NVIDIA/Megatron-LM and retains the following license from the original file + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Refactored Model and data parallel groups with class-based design.""" + +import logging +from datetime import timedelta +from typing import Callable, List, Optional + +import numpy as np +import torch + +from deepspeed.accelerator import get_accelerator +import deepspeed.comm as dist + +logger = logging.getLogger(__name__) + +try: + import einops + HAVE_EINOPS = True +except ImportError: + HAVE_EINOPS = False + + +def is_torch_min_version(version: str, check_equality: bool = True) -> bool: + """Check if PyTorch version meets minimum requirement. + + Args: + version: Version string to check (e.g., "2.4.0") + check_equality: If True, also check for equality + + Returns: + True if version requirement is met + """ + try: + from packaging.version import Version as PkgVersion + torch_version = PkgVersion(torch.__version__) + required_version = PkgVersion(version) + if check_equality: + return torch_version >= required_version + return torch_version > required_version + except Exception: + return False + + +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations.""" + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name, mem_alloc_context=None): + """Returns a sub-tensor from the buffer for the given shape.""" + from functools import reduce + import operator + + required_len = reduce(operator.mul, tensor_shape, 1) + if (self.buffer.get((name, dtype), None) is None or self.buffer[(name, dtype)].numel() < required_len): + from contextlib import nullcontext + mem_alloc_context = mem_alloc_context if mem_alloc_context else nullcontext + with mem_alloc_context(): + self.buffer[(name, dtype)] = torch.empty( + required_len, + dtype=dtype, + device=get_accelerator().current_device(), + requires_grad=False, + ) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) + + +def generate_masked_orthogonal_rank_groups(world_size: int, parallel_size: List[int], + mask: List[bool]) -> List[List[int]]: + r"""Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + parallel_size (List[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + mask (List[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. + + Algorithm: + For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and + local_rank satisfy the following equation: + global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size + """ + + def prefix_product(a: List[int], init=1) -> List[int]: + r = [init] + for v in a: + init = init * v + r.append(init) + return r + + def inner_product(a: List[int], b: List[int]) -> int: + return sum([x * y for x, y in zip(a, b)]) + + def decompose(index, shape, stride=None): + """Solve: index = sum(idx[i] * stride[i])""" + if stride is None: + stride = prefix_product(shape) + idx = [(index // d) % s for s, d in zip(shape, stride)] + assert (sum([x * y for x, y in zip(idx, stride[:-1])]) == index), f"idx {index} with shape {shape} mismatch" + return idx + + masked_shape = [s for s, m in zip(parallel_size, mask) if m] + unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m] + + global_stride = prefix_product(parallel_size) + masked_stride = [d for d, m in zip(global_stride, mask) if m] + unmasked_stride = [d for d, m in zip(global_stride, mask) if not m] + + group_size = prefix_product(masked_shape)[-1] + num_of_group = world_size // group_size + + ranks = [] + for group_index in range(num_of_group): + decomposed_group_idx = decompose(group_index, unmasked_shape) + rank = [] + for rank_in_group in range(group_size): + decomposed_rank_idx = decompose(rank_in_group, masked_shape) + rank.append( + inner_product(decomposed_rank_idx, masked_stride) + + inner_product(decomposed_group_idx, unmasked_stride)) + ranks.append(rank) + return ranks + + +class RankGenerator: + """A class for generating rank groups for different modes of parallelism.""" + + def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str, rank_offset: int = 0) -> None: + assert (ep == 1 or cp == 1), "Both EP and CP > 1 is not allowed in one rank generator." + + self.tp = tp + self.ep = ep + self.dp = dp + self.pp = pp + self.cp = cp + self.rank_offset = rank_offset + self.world_size = tp * dp * pp * cp * ep + + self.name_to_size = { + "tp": self.tp, + "pp": self.pp, + "dp": self.dp, + "ep": self.ep, + "cp": self.cp, + } + self.order = order + order = order.lower() + + for name in self.name_to_size.keys(): + if name not in order and self.name_to_size[name] != 1: + raise RuntimeError(f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't" + f"specified the order ({self.order}).") + elif name not in order: + order = order + "-" + name + + self.order = order + self.ordered_size = [] + + for token in order.split("-"): + self.ordered_size.append(self.name_to_size[token]) + + def get_mask(self, order: str, token: str): + """Create a mask for the specified tokens based on the given order.""" + ordered_token = order.split("-") + token_list = token.split("-") + mask = [False] * len(ordered_token) + for t in token_list: + mask[ordered_token.index(t)] = True + return mask + + def get_ranks(self, token): + """Get rank group by input token. + + Args: + token (str): Specify the ranks type (e.g., 'tp-dp') + """ + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, mask) + if self.rank_offset > 0: + for rank_group in ranks: + for i in range(len(rank_group)): + rank_group[i] += self.rank_offset + return ranks + + +class ParallelState: + """Encapsulates all parallel state and operations. + + This class replaces the global variables and functions from the original + parallel_state.py, providing a cleaner, more maintainable interface. + """ + + def __init__(self): + # Process groups + self.tensor_model_parallel_group = None + self.pipeline_model_parallel_group = None + self.model_parallel_group = None + self.embedding_group = None + self.position_embedding_group = None + self.data_parallel_group = None + self.data_parallel_group_gloo = None + self.tensor_and_data_parallel_group = None + self.context_parallel_group = None + self.tensor_and_context_parallel_group = None + self.tensor_and_data_parallel_group_with_cp = None + self.data_parallel_group_with_cp = None + self.data_parallel_group_with_cp_gloo = None + + # Expert-related groups + self.expert_model_parallel_group = None + self.expert_tensor_parallel_group = None + self.expert_tensor_and_model_parallel_group = None + self.expert_tensor_model_pipeline_parallel_group = None + self.expert_data_parallel_group = None + self.expert_data_parallel_group_gloo = None + self.intra_partial_expert_data_parallel_group = None + self.intra_partial_expert_data_parallel_group_gloo = None + self.inter_partial_expert_data_parallel_group = None + + # Global ranks lists + self.embedding_global_ranks = None + self.position_embedding_global_ranks = None + self.pipeline_global_ranks = None + self.data_parallel_global_ranks = None + self.tensor_model_parallel_global_ranks = None + self.model_parallel_global_ranks = None + self.context_parallel_global_ranks = None + self.data_parallel_global_ranks_with_cp = None + self.hierarchical_context_parallel_groups = None + + # Parallel state values + self.virtual_pipeline_model_parallel_rank = None + self.virtual_pipeline_model_parallel_world_size = None + self.mpu_tensor_model_parallel_world_size = None + self.mpu_pipeline_model_parallel_world_size = None + self.mpu_data_parallel_world_size = None + self.mpu_data_parallel_rank = None + self.mpu_tensor_model_parallel_rank = None + self.mpu_pipeline_model_parallel_rank = None + + # Expert parallel state values + self.mpu_expert_model_parallel_world_size = None + self.mpu_expert_model_parallel_rank = None + self.mpu_expert_tensor_parallel_world_size = None + self.mpu_expert_tensor_parallel_rank = None + + # Other + self.global_memory_buffer = None + self.global_process_group_list = None + self.intra_partial_data_parallel_group_with_cp = None + self.intra_partial_data_parallel_group_with_cp_gloo = None + self.intra_distributed_optimizer_instance_group = None + + # Rank generators + self.decoder_rank_generator = None + self.expert_decoder_rank_generator = None + + def _get_nccl_options(self, pg_name: str, nccl_comm_cfgs: dict): + """Set the NCCL process group options.""" + if pg_name in nccl_comm_cfgs: + # FIXME: deepspeed.comm does not provide a way to set NCCL options yet. + nccl_options = torch.distributed.ProcessGroupNCCL.Options( + is_high_priority_stream=nccl_comm_cfgs[pg_name].get("is_high_priority_stream", False)) + if "cga_cluster_size" in nccl_comm_cfgs[pg_name]: + nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name]["cga_cluster_size"] + if "max_ctas" in nccl_comm_cfgs[pg_name]: + nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name]["max_ctas"] + if "min_ctas" in nccl_comm_cfgs[pg_name]: + nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name]["min_ctas"] + if "net_name" in nccl_comm_cfgs[pg_name]: + nccl_options.config.net_name = nccl_comm_cfgs[pg_name]["net_name"] + if nccl_options.config.net_name.lower() not in ["ib", "socket"]: + raise RuntimeError(f"net_name ({nccl_options.config.net_name}) is not supported." + f"Accepted values: 'IB' or 'socket'.") + return nccl_options + return None + + def _create_group( + self, + ranks, + timeout=None, + backend=None, + pg_options=None, + use_local_synchronization=False, + group_desc=None, + ): + """Creates a ProcessGroup.""" + kwargs = { + "ranks": ranks, + "timeout": timeout, + "backend": backend, + "pg_options": pg_options, + "use_local_synchronization": use_local_synchronization, + "group_desc": group_desc, + } + if not is_torch_min_version("2.4.0"): + kwargs.pop("group_desc") + if timeout is None: + kwargs.pop("timeout") + + group = dist.new_group(**kwargs) + if self.global_process_group_list is None: + self.global_process_group_list = [None] + if dist.get_rank() in ranks: + self.global_process_group_list.append(group) + return group + + def _create_hierarchical_groups( + self, + rank, + ranks, + hierarchical_group_sizes, + create_gloo_process_groups=False, + pg_options=None, + timeout=None, + group_desc=None, + ): + """Create hierarchical groups for a set of ranks.""" + if not HAVE_EINOPS: + raise ImportError("einops is not installed. Please install it with `pip install einops`.") + + hierarchical_groups = [] + hierarchical_groups_gloo = [] + if not isinstance(pg_options, list): + pg_options = [pg_options] * len(hierarchical_group_sizes) + + for level in range(len(hierarchical_group_sizes)): + rearranged_ranks = einops.rearrange( + np.array(ranks), + "(l s u) -> (l u) s", + u=int(np.prod(hierarchical_group_sizes[:level])), + s=hierarchical_group_sizes[level], + l=int(np.prod(hierarchical_group_sizes[level + 1:])), + ).tolist() + for sub_ranks in rearranged_ranks: + sub_group = self._create_group( + sub_ranks, + timeout=timeout, + pg_options=pg_options[level], + group_desc=f"HIERARCHICAL_{group_desc}_L{level}", + ) + if create_gloo_process_groups: + sub_group_gloo = self._create_group( + sub_ranks, + timeout=timeout, + backend="gloo", + pg_options=pg_options[level], + group_desc=f"HIERARCHICAL_{group_desc}_GLOO_L{level}", + ) + else: + sub_group_gloo = None + if rank in sub_ranks: + hierarchical_groups.append(sub_group) + hierarchical_groups_gloo.append(sub_group_gloo) + + assert rank not in ranks or len(hierarchical_groups) == len(hierarchical_group_sizes) + assert rank not in ranks or len(hierarchical_groups_gloo) == len(hierarchical_group_sizes) + return hierarchical_groups, hierarchical_groups_gloo + + def initialize_model_parallel( + self, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + virtual_pipeline_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_comm_backend: Optional[str] = None, + context_parallel_size: int = 1, + hierarchical_context_parallel_sizes: Optional[List[int]] = None, + expert_model_parallel_size: int = 1, + num_distributed_optimizer_instances: int = 1, + expert_tensor_parallel_size: Optional[int] = None, + nccl_communicator_config_path: Optional[str] = None, + distributed_timeout_minutes: int = 30, + order: str = "tp-cp-ep-dp-pp", + get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, + get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, + create_gloo_process_groups: bool = True, + high_priority_stream_groups: Optional[List[str]] = None, + ) -> None: + """Initialize model data parallel groups. + + This is the main initialization method that sets up all parallel groups. + """ + + def default_embedding_ranks(pp_ranks): + """Return the default ranks that constitute the stages on which the word embeddings live.""" + if len(pp_ranks) == 1: + return [pp_ranks[0]] + else: + return [pp_ranks[0], pp_ranks[-1]] + + def default_position_embedding_ranks(pp_ranks): + """Return the default ranks that constitute the stages on which the position embeddings live.""" + return [pp_ranks[0]] + + if get_embedding_ranks is None: + get_embedding_ranks = default_embedding_ranks + if get_position_embedding_ranks is None: + get_position_embedding_ranks = default_position_embedding_ranks + + # Get world size and rank + assert dist.is_initialized() + world_size: int = dist.get_world_size() + rank = dist.get_rank() + + model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size + if world_size % model_size != 0: + raise RuntimeError(f"world_size ({world_size}) is not divisible by {model_size}") + + data_parallel_size: int = world_size // model_size + + if virtual_pipeline_model_parallel_size is not None: + if not pipeline_model_parallel_size > 1: + raise RuntimeError("pipeline-model-parallel size should be greater than 1 with interleaved schedule") + self.virtual_pipeline_model_parallel_rank = 0 + self.virtual_pipeline_model_parallel_world_size = virtual_pipeline_model_parallel_size + + # Load NCCL configs + nccl_comm_cfgs = {} + if nccl_communicator_config_path is not None: + try: + import yaml + except ImportError: + raise RuntimeError("Cannot import `yaml`. Setting custom nccl communicator configs " + "requires the yaml package.") + with open(nccl_communicator_config_path, "r") as stream: + nccl_comm_cfgs = yaml.safe_load(stream) + + # Set high priority stream groups + high_priority_stream_groups = high_priority_stream_groups or [] + for pg_name in high_priority_stream_groups: + if pg_name not in nccl_comm_cfgs: + nccl_comm_cfgs[pg_name] = {} + nccl_comm_cfgs[pg_name]["is_high_priority_stream"] = True + + # Create rank generators + self.decoder_rank_generator = RankGenerator( + tp=tensor_model_parallel_size, + ep=1, + dp=data_parallel_size, + pp=pipeline_model_parallel_size, + cp=context_parallel_size, + order=order, + rank_offset=0, + ) + + # Build expert rank generator + if expert_tensor_parallel_size is None: + expert_tensor_parallel_size = tensor_model_parallel_size + expert_tensor_model_pipeline_parallel_size = (expert_tensor_parallel_size * expert_model_parallel_size * + pipeline_model_parallel_size) + expert_data_parallel_size = world_size // expert_tensor_model_pipeline_parallel_size + if world_size % expert_tensor_model_pipeline_parallel_size != 0: + raise RuntimeError( + f"world_size ({world_size}) is not divisible by expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})" + ) + + self.expert_decoder_rank_generator = RankGenerator( + tp=expert_tensor_parallel_size, + ep=expert_model_parallel_size, + dp=expert_data_parallel_size, + pp=pipeline_model_parallel_size, + cp=1, + order=order, + rank_offset=0, + ) + + timeout = timedelta(minutes=distributed_timeout_minutes) + + # Build data-parallel groups with context parallel + assert self.data_parallel_group is None, "data parallel group is already initialized" + assert (data_parallel_size * context_parallel_size) % num_distributed_optimizer_instances == 0, ( + "Data parallel size should be divisible by partial DistOpt shard factor") + intra_partial_data_parallel_size = (data_parallel_size * + context_parallel_size) // num_distributed_optimizer_instances + + for ranks_with_cp in self.decoder_rank_generator.get_ranks('dp-cp'): + group_with_cp = self._create_group( + ranks_with_cp, + timeout=timeout, + pg_options=self._get_nccl_options("dp_cp", nccl_comm_cfgs), + group_desc="DATA_PARALLEL_GROUP_WITH_CP", + ) + if create_gloo_process_groups: + group_with_cp_gloo = self._create_group( + ranks_with_cp, + timeout=timeout, + backend="gloo", + group_desc="DATA_PARALLEL_GROUP_WITH_CP_GLOO", + ) + else: + group_with_cp_gloo = None + if rank in ranks_with_cp: + self.data_parallel_group_with_cp = group_with_cp + self.data_parallel_group_with_cp_gloo = group_with_cp_gloo + self.data_parallel_global_ranks_with_cp = ranks_with_cp + + if num_distributed_optimizer_instances > 1: + for i in range(num_distributed_optimizer_instances): + intra_partial_dp_ranks_with_cp = ranks_with_cp[( + i * intra_partial_data_parallel_size):((i + 1) * intra_partial_data_parallel_size)] + intra_partial_dp_group_with_cp = self._create_group( + intra_partial_dp_ranks_with_cp, + timeout=timeout, + pg_options=self._get_nccl_options("intra_dp_cp", nccl_comm_cfgs), + group_desc="INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP", + ) + if create_gloo_process_groups: + intra_partial_dp_group_with_cp_gloo = self._create_group( + intra_partial_dp_ranks_with_cp, + timeout=timeout, + backend="gloo", + group_desc="INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO", + ) + else: + intra_partial_dp_group_with_cp_gloo = None + if rank in intra_partial_dp_ranks_with_cp: + self.intra_partial_data_parallel_group_with_cp = intra_partial_dp_group_with_cp + self.intra_partial_data_parallel_group_with_cp_gloo = (intra_partial_dp_group_with_cp_gloo) + else: + self.intra_partial_data_parallel_group_with_cp = self.data_parallel_group_with_cp + self.intra_partial_data_parallel_group_with_cp_gloo = self.data_parallel_group_with_cp_gloo + + # Build data-parallel groups + for ranks in self.decoder_rank_generator.get_ranks('dp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("dp", nccl_comm_cfgs), + group_desc="DATA_PARALLEL_GROUP", + ) + if create_gloo_process_groups: + group_gloo = self._create_group(ranks, + timeout=timeout, + backend="gloo", + group_desc="DATA_PARALLEL_GROUP_GLOO") + else: + group_gloo = None + if rank in ranks: + self.data_parallel_group = group + self.data_parallel_group_gloo = group_gloo + self.data_parallel_global_ranks = ranks + + # Build context-parallel groups + assert self.context_parallel_group is None, 'context parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('cp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("cp", nccl_comm_cfgs), + group_desc="CONTEXT_PARALLEL_GROUP", + ) + if rank in ranks: + self.context_parallel_group = group + self.context_parallel_global_ranks = ranks + if hierarchical_context_parallel_sizes: + assert np.prod(hierarchical_context_parallel_sizes) == context_parallel_size + hierarchical_groups, _ = self._create_hierarchical_groups( + rank, + ranks, + hierarchical_context_parallel_sizes, + create_gloo_process_groups=False, + pg_options=self._get_nccl_options("hcp", nccl_comm_cfgs), + timeout=timeout, + group_desc="CONTEXT_PARALLEL_GROUP", + ) + if rank in ranks: + self.hierarchical_context_parallel_groups = hierarchical_groups + + # Build model-parallel groups + assert self.model_parallel_group is None, 'model parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('tp-pp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("mp", nccl_comm_cfgs), + group_desc="MODEL_PARALLEL_GROUP", + ) + if rank in ranks: + self.model_parallel_group = group + self.model_parallel_global_ranks = ranks + + # Build tensor model-parallel groups + assert self.tensor_model_parallel_group is None, 'tensor model parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('tp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("tp", nccl_comm_cfgs), + group_desc="TENSOR_MODEL_PARALLEL_GROUP", + ) + if rank in ranks: + self.tensor_model_parallel_group = group + self.tensor_model_parallel_global_ranks = ranks + + # Build pipeline model-parallel groups and embedding groups + assert self.pipeline_model_parallel_group is None, "pipeline model parallel group is already initialized" + assert self.embedding_group is None, "embedding group is already initialized" + assert self.position_embedding_group is None, "position embedding group is already initialized" + + for ranks in self.decoder_rank_generator.get_ranks('pp'): + group = self._create_group( + ranks, + timeout=timeout, + backend=pipeline_model_parallel_comm_backend, + pg_options=(None if pipeline_model_parallel_comm_backend == "ucc" else self._get_nccl_options( + "pp", nccl_comm_cfgs)), + group_desc="PIPELINE_MODEL_PARALLEL_GROUP", + ) + assert ( + pipeline_model_parallel_comm_backend == None or pipeline_model_parallel_comm_backend == "nccl" + or pipeline_model_parallel_comm_backend == "ucc" + ), f'"{pipeline_model_parallel_comm_backend}" backend for PP communication is currently not supported' + + if rank in ranks: + if self.pipeline_model_parallel_group is None: + self.pipeline_model_parallel_group = group + self.pipeline_global_ranks = ranks + elif isinstance(self.pipeline_global_ranks[0], list): + if not isinstance(self.pipeline_model_parallel_group, list): + self.pipeline_model_parallel_group = [self.pipeline_model_parallel_group] + self.pipeline_model_parallel_group.append(group) + self.pipeline_global_ranks.append(ranks) + else: + self.pipeline_model_parallel_group = [self.pipeline_model_parallel_group, group] + self.pipeline_global_ranks = [self.pipeline_global_ranks, ranks] + + embedding_ranks = get_embedding_ranks(ranks) + group = self._create_group( + embedding_ranks, + timeout=timeout, + pg_options=self._get_nccl_options("embd", nccl_comm_cfgs), + group_desc="EMBEDDING_GROUP", + ) + if rank in embedding_ranks: + self.embedding_group = group + self.embedding_global_ranks = embedding_ranks + + position_embedding_ranks = get_position_embedding_ranks(ranks) + group = self._create_group( + position_embedding_ranks, + timeout=timeout, + pg_options=self._get_nccl_options("pos_embd", nccl_comm_cfgs), + group_desc="POSITION_EMBEDDING_GROUP", + ) + if rank in position_embedding_ranks: + self.position_embedding_group = group + self.position_embedding_global_ranks = position_embedding_ranks + + # Build tensor + data parallel groups + assert self.tensor_and_data_parallel_group is None, 'Tensor + data parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('tp-dp-cp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("tp_dp_cp", nccl_comm_cfgs), + group_desc="TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP", + ) + if rank in ranks: + self.tensor_and_data_parallel_group_with_cp = group + for ranks in self.decoder_rank_generator.get_ranks('tp-dp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("tp_dp", nccl_comm_cfgs), + group_desc="TENSOR_AND_DATA_PARALLEL_GROUP", + ) + if rank in ranks: + self.tensor_and_data_parallel_group = group + + assert self.tensor_and_context_parallel_group is None, 'Tensor + context parallel group is already initialized' + for ranks in self.decoder_rank_generator.get_ranks('tp-cp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("tp_cp", nccl_comm_cfgs), + group_desc="TENSOR_AND_CONTEXT_PARALLEL_GROUP", + ) + if rank in ranks: + self.tensor_and_context_parallel_group = group + + # Build expert-related parallel groups + assert self.expert_model_parallel_group is None, 'Expert parallel group is already initialized' + for ranks in self.expert_decoder_rank_generator.get_ranks('ep'): + group = self._create_group( + ranks, + pg_options=self._get_nccl_options("ep", nccl_comm_cfgs), + group_desc="EXPERT_MODEL_PARALLEL_GROUP", + ) + if rank in ranks: + self.expert_model_parallel_group = group + + assert self.expert_tensor_parallel_group is None, 'Expert tensor model parallel group is already initialized' + for ranks in self.expert_decoder_rank_generator.get_ranks('tp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("ep_tp", nccl_comm_cfgs), + group_desc="EXPERT_TENSOR_PARALLEL_GROUP", + ) + if rank in ranks: + self.expert_tensor_parallel_group = group + + assert self.expert_tensor_and_model_parallel_group is None, 'Expert tensor + model parallel group is already initialized' + for ranks in self.expert_decoder_rank_generator.get_ranks('tp-ep'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("tp_ep_mp", nccl_comm_cfgs), + group_desc="EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP", + ) + if rank in ranks: + self.expert_tensor_and_model_parallel_group = group + + assert self.expert_tensor_model_pipeline_parallel_group is None, 'The expert_tensor_model_pipeline parallel group is already initialized' + for ranks in self.expert_decoder_rank_generator.get_ranks('tp-ep-pp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("tp_ep_pp", nccl_comm_cfgs), + group_desc="EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP", + ) + if rank in ranks: + self.expert_tensor_model_pipeline_parallel_group = group + + assert self.expert_data_parallel_group is None, "Expert data group is already initialized" + assert self.expert_data_parallel_group_gloo is None, "Expert data group-gloo is already initialized" + assert self.intra_partial_expert_data_parallel_group is None, "Intra partial expert data group is already initialized" + assert self.intra_partial_expert_data_parallel_group_gloo is None, "Intra partial expert data group-gloo is already initialized" + assert self.inter_partial_expert_data_parallel_group is None, "Inter partial expert data group is already initialized" + + assert (expert_data_parallel_size % num_distributed_optimizer_instances == 0 + ), "Expert data parallel size should be divisible by partial DistOpt shard factor" + intra_partial_expert_data_parallel_size = (expert_data_parallel_size // num_distributed_optimizer_instances) + + for ranks in self.expert_decoder_rank_generator.get_ranks('dp'): + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_nccl_options("ep_dp", nccl_comm_cfgs), + group_desc="EXPERT_DATA_PARALLEL_GROUP", + ) + if create_gloo_process_groups: + group_gloo = self._create_group(ranks, backend="gloo", group_desc="EXPERT_DATA_PARALLEL_GROUP_GLOO") + else: + group_gloo = None + if rank in ranks: + self.expert_data_parallel_group = group + self.expert_data_parallel_group_gloo = group_gloo + + if num_distributed_optimizer_instances > 1: + hierarchical_groups, hierarchical_groups_gloo = self._create_hierarchical_groups( + rank, + ranks, + [intra_partial_expert_data_parallel_size, num_distributed_optimizer_instances], + create_gloo_process_groups=create_gloo_process_groups, + pg_options=[ + self._get_nccl_options("intra_ep_dp", nccl_comm_cfgs), + self._get_nccl_options("inter_ep_dp", nccl_comm_cfgs), + ], + timeout=timeout, + group_desc="EXPERT_DATA_PARALLEL_GROUP", + ) + if rank in ranks: + self.intra_partial_expert_data_parallel_group = hierarchical_groups[0] + self.intra_partial_expert_data_parallel_group_gloo = hierarchical_groups_gloo[0] + self.inter_partial_expert_data_parallel_group = hierarchical_groups[1] + else: + self.intra_partial_expert_data_parallel_group = self.expert_data_parallel_group + self.intra_partial_expert_data_parallel_group_gloo = self.expert_data_parallel_group_gloo + + # Build intra distributed optimizer instance group + assert self.intra_distributed_optimizer_instance_group is None, "Intra distributed optimizer instance group is already initialized" + model_parallel_group_id = 0 + intra_dist_opt_ranks = [] + for ranks in self.expert_decoder_rank_generator.get_ranks('tp-ep-pp'): + model_parallel_group_id += 1 + intra_dist_opt_ranks.extend(ranks) + if model_parallel_group_id % intra_partial_expert_data_parallel_size == 0: + intra_dist_opt_instance_group = self._create_group( + intra_dist_opt_ranks, + timeout=timeout, + pg_options=self._get_nccl_options("intra_dist_opt_instance", nccl_comm_cfgs), + group_desc="INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP", + ) + if rank in intra_dist_opt_ranks: + self.intra_distributed_optimizer_instance_group = intra_dist_opt_instance_group + intra_dist_opt_ranks = [] + + # Initialize global memory buffer + self._set_global_memory_buffer() + + def _set_global_memory_buffer(self): + """Initialize global buffer.""" + assert self.global_memory_buffer is None, "global memory buffer is already initialized" + self.global_memory_buffer = GlobalMemoryBuffer() + + # Getter methods for process groups + def get_model_parallel_group(self, check_initialized=True): + """Get the model-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.model_parallel_group is not None, "model parallel group is not initialized" + return self.model_parallel_group + + def get_tensor_model_parallel_group(self, check_initialized=True): + """Get the tensor-model-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.tensor_model_parallel_group is not None, "tensor model parallel group is not initialized" + return self.tensor_model_parallel_group + + def get_pipeline_model_parallel_group(self, check_initialized=True): + """Get the pipeline-model-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.pipeline_model_parallel_group is not None, "pipeline_model parallel group is not initialized" + return self.pipeline_model_parallel_group + + def get_data_parallel_group(self, with_context_parallel=False, partial_data_parallel=False): + """Get the data-parallel group the caller rank belongs to.""" + if with_context_parallel: + if partial_data_parallel: + assert self.intra_partial_data_parallel_group_with_cp is not None, "Intra partial data parallel group is not initialized" + return self.intra_partial_data_parallel_group_with_cp + assert self.data_parallel_group_with_cp is not None, "data parallel group with context parallel combined is not initialized" + return self.data_parallel_group_with_cp + else: + assert self.data_parallel_group is not None, "data parallel group is not initialized" + assert partial_data_parallel == False, "Partial DP for Optimizer needs to include CP" + return self.data_parallel_group + + def get_context_parallel_group(self, check_initialized=True): + """Get the context-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.context_parallel_group is not None, "context parallel group is not initialized" + return self.context_parallel_group + + def get_embedding_group(self, check_initialized=True): + """Get the embedding group the caller rank belongs to.""" + if check_initialized: + assert self.embedding_group is not None, "embedding group is not initialized" + return self.embedding_group + + def get_tensor_and_data_parallel_group(self, check_initialized=True, with_context_parallel=False): + """Get the tensor- and data-parallel group the caller rank belongs to.""" + if with_context_parallel: + if check_initialized: + assert self.tensor_and_data_parallel_group_with_cp is not None, 'tensor and data parallel group is not initialized' + return self.tensor_and_data_parallel_group_with_cp + else: + if check_initialized: + assert self.tensor_and_data_parallel_group is not None, 'tensor and data parallel group is not initialized' + return self.tensor_and_data_parallel_group + + def get_tensor_and_context_parallel_group(self, check_initialized=True): + """Get the tensor- and context-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.tensor_and_context_parallel_group is not None, "tensor and context parallel group is not initialized" + return self.tensor_and_context_parallel_group + + # Getter methods for world sizes and ranks + def get_tensor_model_parallel_world_size(self): + """Return world size for the tensor-model-parallel group.""" + if self.mpu_tensor_model_parallel_world_size is not None: + return self.mpu_tensor_model_parallel_world_size + return self.get_tensor_model_parallel_group().size() + + def get_pipeline_model_parallel_world_size(self): + """Return world size for the pipeline-model-parallel group.""" + if self.mpu_pipeline_model_parallel_world_size is not None: + return self.mpu_pipeline_model_parallel_world_size + return self.get_pipeline_model_parallel_group().size() + + def get_tensor_model_parallel_rank(self): + """Return caller's rank for the tensor-model-parallel group.""" + if self.mpu_tensor_model_parallel_rank is not None: + return self.mpu_tensor_model_parallel_rank + return self.get_tensor_model_parallel_group().rank() + + def get_pipeline_model_parallel_rank(self): + """Return caller's rank for the pipeline-model-parallel group.""" + if self.mpu_pipeline_model_parallel_rank is not None: + return self.mpu_pipeline_model_parallel_rank + return dist.get_rank(group=self.get_pipeline_model_parallel_group()) + + def get_data_parallel_world_size(self, with_context_parallel=False, partial_data_parallel=False): + """Return world size for the data parallel group.""" + if self.mpu_data_parallel_world_size is not None: + return self.mpu_data_parallel_world_size + if dist.is_available() and dist.is_initialized(): + return self.get_data_parallel_group(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel).size() + else: + return 0 + + def get_data_parallel_rank(self, with_context_parallel=False, partial_data_parallel=False): + """Return caller's rank in the data-parallel group.""" + if self.mpu_data_parallel_rank is not None: + return self.mpu_data_parallel_rank + if dist.is_available() and dist.is_initialized(): + return self.get_data_parallel_group(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel).rank() + else: + return 0 + + def get_context_parallel_world_size(self): + """Return world size for the context parallel group.""" + if dist.is_available() and dist.is_initialized(): + return self.get_context_parallel_group().size() + else: + return 0 + + def get_context_parallel_rank(self): + """Return caller's rank in the context-parallel group.""" + if dist.is_available() and dist.is_initialized(): + return self.get_context_parallel_group().rank() + else: + return 0 + + def is_initialized(self): + """Check if parallel state has been initialized""" + return self.data_parallel_group is not None + + def get_global_memory_buffer(self): + """Return the global GlobalMemoryBuffer object""" + assert self.global_memory_buffer is not None, "global memory buffer is not initialized" + return self.global_memory_buffer + + # Expert-related getter methods + def get_expert_model_parallel_group(self, check_initialized=True): + """Get the expert-model-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.expert_model_parallel_group is not None, "expert model parallel group is not initialized" + return self.expert_model_parallel_group + + def get_expert_model_parallel_world_size(self): + """Return world size for the expert-model-parallel group.""" + if self.mpu_expert_model_parallel_world_size is not None: + return self.mpu_expert_model_parallel_world_size + if dist.is_available() and dist.is_initialized(): + return self.get_expert_model_parallel_group().size() + else: + return 0 + + def get_expert_model_parallel_rank(self): + """Return caller's rank in the expert-model-parallel group.""" + if self.mpu_expert_model_parallel_rank is not None: + return self.mpu_expert_model_parallel_rank + if dist.is_available() and dist.is_initialized(): + return self.get_expert_model_parallel_group().rank() + else: + return 0 + + def get_expert_tensor_parallel_group(self, check_initialized=True): + """Get the expert-tensor-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.expert_tensor_parallel_group is not None, "Expert tensor parallel group is not initialized" + return self.expert_tensor_parallel_group + + def get_expert_tensor_parallel_world_size(self): + """Return world size for the expert tensor parallel group.""" + if self.mpu_expert_tensor_parallel_world_size is not None: + return self.mpu_expert_tensor_parallel_world_size + if not self.expert_tensor_parallel_group: + return self.mpu_tensor_model_parallel_world_size + else: + return self.get_expert_tensor_parallel_group().size() + + def get_expert_tensor_parallel_rank(self): + """Return my rank for the expert tensor parallel group.""" + if self.mpu_expert_tensor_parallel_rank is not None: + return self.mpu_expert_tensor_parallel_rank + if not self.expert_tensor_parallel_group: + return self.mpu_tensor_model_parallel_rank + else: + return self.get_expert_tensor_parallel_group().rank() + + def get_expert_data_parallel_group(self, check_initialized=True, partial_expert_data_parallel=False): + """Get expert data parallel group.""" + if partial_expert_data_parallel: + if check_initialized: + assert self.intra_partial_expert_data_parallel_group is not None, "Intra partial expert data parallel group is not initialized" + return self.intra_partial_expert_data_parallel_group + else: + if check_initialized: + assert self.expert_data_parallel_group is not None, "Expert data parallel group is not initialized" + return self.expert_data_parallel_group + + def get_expert_data_parallel_rank(self, partial_expert_data_parallel=False): + """Return caller's rank in the expert data parallel group.""" + if dist.is_available() and dist.is_initialized(): + return self.get_expert_data_parallel_group( + partial_expert_data_parallel=partial_expert_data_parallel).rank() + else: + return 0 + + def get_expert_data_parallel_world_size(self, partial_expert_data_parallel=False): + """Return world size for the expert data parallel group.""" + if dist.is_available() and dist.is_initialized(): + return self.get_expert_data_parallel_group( + partial_expert_data_parallel=partial_expert_data_parallel).size() + else: + return 0 + + +# Convenience function to create a singleton instance +_parallel_state_instance = None + + +def get_parallel_state() -> ParallelState: + """Get or create the global ParallelState instance.""" + global _parallel_state_instance + if _parallel_state_instance is None: + _parallel_state_instance = ParallelState() + return _parallel_state_instance diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py new file mode 100644 index 000000000000..bf3a346de194 --- /dev/null +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -0,0 +1,555 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) DeepSpeed Team + +# DeepSpeed Team + +# The file has been adapted from https://github.com/NVIDIA/Megatron-LM and retains the following license from the original file + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +DeepSpeed Compatibility Layer for parallel_state. + +This module provides module-level functions compatible with DeepSpeed's +groups.py API, allowing code written for DeepSpeed to work with the +refactored parallel_state module. + +Key Features: +- Supports multiple parallel state instances (for RL scenarios with different models) +- Backward compatible with single global instance +- Context manager for switching between different parallel configurations + +Usage: + # Basic usage (single global instance): + from parallel_state_deepspeed import get_data_parallel_group + dp_group = get_data_parallel_group() + + # Multi-instance usage (for RL scenarios): + from parallel_state_deepspeed import ( + get_parallel_state_instance, + set_current_parallel_state, + get_data_parallel_group, + ) + + # Create different instances for different models + actor_state = get_parallel_state_instance("actor") + critic_state = get_parallel_state_instance("critic") + + # Initialize with different DP sizes + actor_state.initialize_model_parallel(tensor_model_parallel_size=2, data_parallel_size=4) + critic_state.initialize_model_parallel(tensor_model_parallel_size=1, data_parallel_size=8) + + # Use context manager to switch + with set_current_parallel_state("actor"): + actor_dp_group = get_data_parallel_group() # Uses actor's DP group + + with set_current_parallel_state("critic"): + critic_dp_group = get_data_parallel_group() # Uses critic's DP group +""" + +from contextlib import contextmanager +from typing import Optional +from parallel_state import ParallelState, get_parallel_state as _get_default_parallel_state + +# Registry for multiple parallel state instances +_parallel_state_registry = {} +_default_instance_name = "__default__" + +# Current active instance name (thread-local would be better, but using global for simplicity) +_current_instance_name = _default_instance_name + + +def get_parallel_state_instance(name: Optional[str] = None) -> ParallelState: + """Get or create a named ParallelState instance. + + Args: + name: Name of the instance. If None, returns the default global instance. + Use different names for different models in RL scenarios. + + Returns: + ParallelState instance + + Example: + # For RL with actor and critic models + actor_state = get_parallel_state_instance("actor") + critic_state = get_parallel_state_instance("critic") + """ + if name is None: + return _get_default_parallel_state() + + if name not in _parallel_state_registry: + _parallel_state_registry[name] = ParallelState() + + return _parallel_state_registry[name] + + +def set_current_parallel_state(name: Optional[str] = None): + """Set the current active parallel state instance. + + Args: + name: Name of the instance to activate. If None, uses the default instance. + + Returns: + Context manager for temporarily switching the active instance + + Example: + with set_current_parallel_state("actor"): + dp_group = get_data_parallel_group() # Uses actor's DP group + """ + + @contextmanager + def _context(): + global _current_instance_name + old_name = _current_instance_name + _current_instance_name = name if name is not None else _default_instance_name + try: + yield + finally: + _current_instance_name = old_name + + return _context() + + +def get_current_parallel_state() -> ParallelState: + """Get the currently active parallel state instance. + + Returns: + The currently active ParallelState instance + """ + return get_parallel_state_instance(_current_instance_name) + + +def get_parallel_state(name: Optional[str] = None) -> ParallelState: + """Get parallel state instance (backward compatible). + + If name is provided, returns the named instance. + Otherwise, returns the currently active instance. + + Args: + name: Optional name of the instance. If None, returns current active instance. + + Returns: + ParallelState instance + """ + if name is not None: + return get_parallel_state_instance(name) + return get_current_parallel_state() + + +# ============================================================================ +# Core Tensor/Model/Data Parallel Functions +# ============================================================================ + + +def get_tensor_model_parallel_group(name: Optional[str] = None): + """Get the tensor model parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + Use this in RL scenarios to specify which model's parallel groups to use. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_group() + + +def get_model_parallel_group(name: Optional[str] = None): + """Get the model parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_model_parallel_group() + + +def get_data_parallel_group(name: Optional[str] = None, + with_context_parallel: bool = False, + partial_data_parallel: bool = False): + """Get the data parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + Use this in RL scenarios to specify which model's DP group to use. + For example, "actor" vs "critic" may have different DP sizes. + with_context_parallel: Whether to include context parallel in the group. + partial_data_parallel: Whether to use partial data parallel group. + + DeepSpeed-compatible interface. + + Example: + # In RL scenario with different DP sizes: + actor_dp = get_data_parallel_group("actor") # Actor's DP group + critic_dp = get_data_parallel_group("critic") # Critic's DP group + + # Or use context manager: + with set_current_parallel_state("actor"): + dp_group = get_data_parallel_group() # Uses actor's DP group + """ + return get_parallel_state(name).get_data_parallel_group(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel) + + +def get_tensor_model_parallel_world_size(name: Optional[str] = None): + """Return world size for the tensor model parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_world_size() + + +def get_model_parallel_world_size(name: Optional[str] = None): + """Return world size for the model parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_world_size() + + +def get_tensor_model_parallel_rank(name: Optional[str] = None): + """Return caller's rank for the tensor-model-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_rank() + + +def get_model_parallel_rank(name: Optional[str] = None): + """Return caller's rank for the model parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_model_parallel_rank() + + +def get_data_parallel_world_size(name: Optional[str] = None, + with_context_parallel: bool = False, + partial_data_parallel: bool = False): + """Return world size for the data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + with_context_parallel: Whether to include context parallel. + partial_data_parallel: Whether to use partial data parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_data_parallel_world_size(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel) + + +def get_data_parallel_rank(name: Optional[str] = None, + with_context_parallel: bool = False, + partial_data_parallel: bool = False): + """Return caller's rank in the data-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + with_context_parallel: Whether to include context parallel. + partial_data_parallel: Whether to use partial data parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_data_parallel_rank(with_context_parallel=with_context_parallel, + partial_data_parallel=partial_data_parallel) + + +def get_tensor_model_parallel_src_rank(name: Optional[str] = None): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + import torch.distributed as dist + global_rank = dist.get_rank() + local_world_size = get_tensor_model_parallel_world_size(name) + return (global_rank // local_world_size) * local_world_size + + +def set_tensor_model_parallel_world_size(world_size, name: Optional[str] = None): + """Set the tensor model parallel size. + + Args: + world_size: World size to set. + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + ps = get_parallel_state(name) + ps.mpu_tensor_model_parallel_world_size = world_size + + +def set_tensor_model_parallel_rank(rank, name: Optional[str] = None): + """Set tensor model parallel rank. + + Args: + rank: Rank to set. + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + ps = get_parallel_state(name) + ps.mpu_tensor_model_parallel_rank = rank + + +# ============================================================================ +# Pipeline Parallel Functions +# ============================================================================ + + +def get_pipeline_model_parallel_group(name: Optional[str] = None): + """Get the pipeline-model-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_pipeline_model_parallel_group() + + +def get_pipeline_model_parallel_world_size(name: Optional[str] = None): + """Return world size for the pipeline-model-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_pipeline_model_parallel_world_size() + + +def get_pipeline_model_parallel_rank(name: Optional[str] = None): + """Return caller's rank for the pipeline-model-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_pipeline_model_parallel_rank() + + +# ============================================================================ +# Context Parallel Functions +# ============================================================================ + + +def get_context_parallel_group(name: Optional[str] = None): + """Get the context-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_context_parallel_group() + + +def get_context_parallel_world_size(name: Optional[str] = None): + """Return world size for the context parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_context_parallel_world_size() + + +def get_context_parallel_rank(name: Optional[str] = None): + """Return caller's rank in the context-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_context_parallel_rank() + + +# ============================================================================ +# Expert Parallel Functions +# ============================================================================ + + +def get_expert_model_parallel_group(name: Optional[str] = None): + """Get the expert-model-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_model_parallel_group() + + +def get_expert_model_parallel_world_size(name: Optional[str] = None): + """Return world size for the expert-model-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_model_parallel_world_size() + + +def get_expert_model_parallel_rank(name: Optional[str] = None): + """Return caller's rank in the expert-model-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_model_parallel_rank() + + +def get_expert_tensor_parallel_group(name: Optional[str] = None): + """Get the expert-tensor-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_tensor_parallel_group() + + +def get_expert_tensor_parallel_world_size(name: Optional[str] = None): + """Return world size for the expert tensor parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_tensor_parallel_world_size() + + +def get_expert_tensor_parallel_rank(name: Optional[str] = None): + """Return my rank for the expert tensor parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_tensor_parallel_rank() + + +def get_expert_data_parallel_group(name: Optional[str] = None, partial_expert_data_parallel: bool = False): + """Get expert data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + partial_expert_data_parallel: Whether to use partial expert data parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_data_parallel_group( + partial_expert_data_parallel=partial_expert_data_parallel) + + +def get_expert_data_parallel_world_size(name: Optional[str] = None, partial_expert_data_parallel: bool = False): + """Return world size for the expert data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + partial_expert_data_parallel: Whether to use partial expert data parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_data_parallel_world_size( + partial_expert_data_parallel=partial_expert_data_parallel) + + +def get_expert_data_parallel_rank(name: Optional[str] = None, partial_expert_data_parallel: bool = False): + """Return caller's rank in the expert data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + partial_expert_data_parallel: Whether to use partial expert data parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_expert_data_parallel_rank( + partial_expert_data_parallel=partial_expert_data_parallel) + + +# ============================================================================ +# Additional Helper Functions +# ============================================================================ + + +def get_embedding_group(name: Optional[str] = None): + """Get the embedding group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_embedding_group() + + +def get_tensor_and_data_parallel_group(name: Optional[str] = None, with_context_parallel: bool = False): + """Get the tensor- and data-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + with_context_parallel: Whether to include context parallel. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_and_data_parallel_group(with_context_parallel=with_context_parallel) + + +def get_tensor_and_context_parallel_group(name: Optional[str] = None): + """Get the tensor- and context-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_tensor_and_context_parallel_group() + + +def is_initialized(name: Optional[str] = None): + """Check if parallel state has been initialized. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).is_initialized() diff --git a/tests/unit/utils/test_mpu.py b/tests/unit/utils/test_mpu.py new file mode 100644 index 000000000000..11ed585c92b3 --- /dev/null +++ b/tests/unit/utils/test_mpu.py @@ -0,0 +1,1692 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) DeepSpeed Team + +# DeepSpeed Team +""" +Automated testing of parallel strategy combinations using random configurations. + +This test automatically generates random parallel configurations and tests +both parallel_state_refactored and DeepSpeed to see if they produce compatible results. +""" + +import pytest +import random +from typing import Dict, List, Tuple, Optional +from collections import defaultdict + +# Try to import both libraries +try: + from deepspeed.utils.parallel_state import RankGenerator + PARALLEL_STATE_AVAILABLE = True +except ImportError as e: + PARALLEL_STATE_AVAILABLE = False + print(f"Warning: Could not import Megatron parallel_state_refactored: {e}") + +try: + from deepspeed.utils import groups as ds_groups + from deepspeed.runtime.sequence_parallel import parallel_state_sp as ds_sp + DEEPSPEED_AVAILABLE = True +except ImportError as e: + DEEPSPEED_AVAILABLE = False + print(f"Warning: Could not import DeepSpeed: {e}") + + +class ParallelConfigGenerator: + """Generate random parallel configurations for testing.""" + + def __init__(self, seed=None): + if seed is not None: + random.seed(seed) + self.tested_configs = [] + self.failed_configs = [] + + def generate_random_config(self, max_size=1024, min_parallel_size=1, max_parallel_size=32): + """Generate a random parallel configuration. + + Args: + max_size: Maximum world size to consider + min_parallel_size: Minimum parallel size for each dimension + max_parallel_size: Maximum parallel size for each dimension + + Returns: + Dict with tp, dp, pp, cp, ep values and order + """ + # Generate random sizes for each dimension + # Don't filter invalid configurations - we want to test and report all cases + tp = random.randint(min_parallel_size, max_parallel_size) + dp = random.randint(min_parallel_size, max_parallel_size) + pp = random.randint(min_parallel_size, max_parallel_size) + cp = random.randint(min_parallel_size, max_parallel_size) + ep = random.randint(min_parallel_size, max_parallel_size) + + # Calculate world size + world_size = tp * dp * pp * cp * ep + + # If world size is too large, scale down proportionally + # But try to keep at least one dimension > 1 + if world_size > max_size: + # Scale down proportionally + scale_factor = (max_size / world_size)**0.25 + tp = max(1, int(tp * scale_factor)) + dp = max(1, int(dp * scale_factor)) + pp = max(1, int(pp * scale_factor)) + cp = max(1, int(cp * scale_factor)) + ep = max(1, int(ep * scale_factor)) + world_size = tp * dp * pp * cp * ep + + # Ensure at least one dimension is > 1 + if world_size == 1: + tp = 2 + world_size = 2 + + # Generate random order (but must include all non-1 dimensions) + dimensions = [] + if tp > 1: + dimensions.append('tp') + if dp > 1: + dimensions.append('dp') + if pp > 1: + dimensions.append('pp') + if cp > 1: + dimensions.append('cp') + if ep > 1: + dimensions.append('ep') + + # Shuffle to get random order + random.shuffle(dimensions) + order = '-'.join(dimensions) if dimensions else 'tp' + + # If no dimensions > 1, use default + if not dimensions: + order = 'tp-dp' + tp = 2 + dp = 2 + + config = { + "tp": tp, + "dp": dp, + "pp": pp, + "cp": cp, + "ep": ep, + "order": order, + "world_size": tp * dp * pp * cp * ep, + } + + return config + + def generate_systematic_configs(self, max_world_size=512): + """Generate systematic configurations covering common cases. + + Args: + max_world_size: Maximum world size to consider + + Returns: + List of configurations + """ + configs = [] + + # Single parallelism - test larger sizes + for size in [2, 4, 8, 16, 32, 64, 128, 256]: + if size <= max_world_size: + configs.append({"tp": size, "dp": 1, "pp": 1, "cp": 1, "ep": 1, "order": "tp", "world_size": size}) + configs.append({"tp": 1, "dp": size, "pp": 1, "cp": 1, "ep": 1, "order": "dp", "world_size": size}) + configs.append({"tp": 1, "dp": 1, "pp": size, "cp": 1, "ep": 1, "order": "pp", "world_size": size}) + + # Two-way combinations - more variations + for tp, dp in [(2, 2), (2, 4), (4, 2), (2, 8), (8, 2), (4, 4), (2, 16), (16, 2), (4, 8), (8, 4)]: + if tp * dp <= max_world_size: + configs.append({ + "tp": tp, + "dp": dp, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "tp-dp", + "world_size": tp * dp + }) + configs.append({ + "tp": tp, + "dp": dp, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "dp-tp", + "world_size": tp * dp + }) + + for tp, pp in [(2, 2), (2, 4), (4, 2), (2, 8), (8, 2), (4, 4)]: + if tp * pp <= max_world_size: + configs.append({ + "tp": tp, + "dp": 1, + "pp": pp, + "cp": 1, + "ep": 1, + "order": "tp-pp", + "world_size": tp * pp + }) + + for tp, cp in [(2, 2), (2, 4), (4, 2), (2, 8)]: + if tp * cp <= max_world_size: + configs.append({ + "tp": tp, + "dp": 1, + "pp": 1, + "cp": cp, + "ep": 1, + "order": "tp-cp", + "world_size": tp * cp + }) + + for tp, ep in [(2, 2), (2, 4), (4, 2), (2, 8)]: + if tp * ep <= max_world_size: + configs.append({ + "tp": tp, + "dp": 1, + "pp": 1, + "cp": 1, + "ep": ep, + "order": "tp-ep", + "world_size": tp * ep + }) + + # Three-way combinations - more variations + for tp, pp, dp in [(2, 2, 2), (2, 2, 4), (2, 4, 2), (4, 2, 2), (2, 2, 8), (2, 4, 4), (4, 4, 2)]: + if tp * pp * dp <= max_world_size: + configs.append({ + "tp": tp, + "dp": dp, + "pp": pp, + "cp": 1, + "ep": 1, + "order": "tp-pp-dp", + "world_size": tp * pp * dp + }) + configs.append({ + "tp": tp, + "dp": dp, + "pp": pp, + "cp": 1, + "ep": 1, + "order": "tp-dp-pp", + "world_size": tp * pp * dp + }) + + for tp, cp, dp in [(2, 2, 2), (2, 2, 4), (2, 4, 2)]: + if tp * cp * dp <= max_world_size: + configs.append({ + "tp": tp, + "dp": dp, + "pp": 1, + "cp": cp, + "ep": 1, + "order": "tp-cp-dp", + "world_size": tp * cp * dp + }) + + for tp, ep, dp in [(2, 2, 2), (2, 2, 4), (2, 4, 2)]: + if tp * ep * dp <= max_world_size: + configs.append({ + "tp": tp, + "dp": dp, + "pp": 1, + "cp": 1, + "ep": ep, + "order": "tp-ep-dp", + "world_size": tp * ep * dp + }) + + # Four-way combinations - more variations + for tp, pp, dp, cp in [(2, 2, 2, 2), (2, 2, 2, 4), (2, 2, 4, 2), (2, 4, 2, 2)]: + if tp * pp * dp * cp <= max_world_size: + configs.append({ + "tp": tp, + "dp": dp, + "pp": pp, + "cp": cp, + "ep": 1, + "order": "tp-pp-dp-cp", + "world_size": tp * pp * dp * cp + }) + + for tp, ep, pp, dp in [(2, 2, 2, 2), (2, 2, 2, 4), (2, 2, 4, 2)]: + if tp * ep * pp * dp <= max_world_size: + configs.append({ + "tp": tp, + "dp": dp, + "pp": pp, + "cp": 1, + "ep": ep, + "order": "tp-ep-pp-dp", + "world_size": tp * ep * pp * dp + }) + + return configs + + def generate_random_configs(self, count=1000, max_size=1024): + """Generate multiple random configurations. + + Args: + count: Number of random configurations to generate + max_size: Maximum world size + + Returns: + List of configurations + """ + configs = [] + seen = set() + + for _ in range(count): + config = self.generate_random_config(max_size=max_size) + # Create a unique key for this configuration + key = (config["tp"], config["dp"], config["pp"], config["cp"], config["ep"], config["order"]) + if key not in seen: + seen.add(key) + configs.append(config) + + return configs + + def generate_random_config_by_dimension(self, + dimension_count: int, + max_size=1024, + min_parallel_size=2, + max_parallel_size=32): + """Generate a random configuration with exactly the specified number of dimensions > 1. + + Args: + dimension_count: Number of dimensions that should be > 1 (1-5) + max_size: Maximum world size + min_parallel_size: Minimum parallel size for each dimension + max_parallel_size: Maximum parallel size for each dimension + + Returns: + Dict with tp, dp, pp, cp, ep values and order + """ + # All possible dimensions + all_dims = ['tp', 'dp', 'pp', 'cp', 'ep'] + + # Randomly select which dimensions to activate + active_dims = random.sample(all_dims, min(dimension_count, len(all_dims))) + + # Initialize all dimensions to 1 + config = { + "tp": 1, + "dp": 1, + "pp": 1, + "cp": 1, + "ep": 1, + } + + # Set active dimensions to random values + for dim in active_dims: + config[dim] = random.randint(min_parallel_size, max_parallel_size) + + # Calculate world size + world_size = config["tp"] * config["dp"] * config["pp"] * config["cp"] * config["ep"] + + # If world size is too large, scale down proportionally + if world_size > max_size: + scale_factor = (max_size / world_size)**(1.0 / dimension_count) + for dim in active_dims: + config[dim] = max(min_parallel_size, int(config[dim] * scale_factor)) + world_size = config["tp"] * config["dp"] * config["pp"] * config["cp"] * config["ep"] + + # Generate random order from active dimensions + random.shuffle(active_dims) + order = '-'.join(active_dims) + + config["order"] = order + config["world_size"] = world_size + + return config + + def generate_random_configs_by_dimension(self, + counts_by_dimension: Dict[int, int], + max_size=1024, + min_parallel_size=2, + max_parallel_size=32): + """Generate random configurations for each dimension separately. + + Args: + counts_by_dimension: Dict mapping dimension count (1-5) to number of configs to generate + e.g., {1: 100, 2: 200, 3: 150, 4: 100, 5: 50} + max_size: Maximum world size + min_parallel_size: Minimum parallel size for each dimension + max_parallel_size: Maximum parallel size for each dimension + + Returns: + List of configurations grouped by dimension count + """ + all_configs = [] + seen = set() + + for dim_count, count in counts_by_dimension.items(): + if dim_count < 1 or dim_count > 5: + continue + + dim_configs = [] + attempts = 0 + # Increased max_attempts for larger test sets (20x more configs) + max_attempts = count * 20 # Prevent infinite loops, allow more attempts for uniqueness + + while len(dim_configs) < count and attempts < max_attempts: + attempts += 1 + config = self.generate_random_config_by_dimension(dim_count, max_size, min_parallel_size, + max_parallel_size) + + # Create a unique key for this configuration + key = (config["tp"], config["dp"], config["pp"], config["cp"], config["ep"], config["order"]) + + if key not in seen: + seen.add(key) + dim_configs.append(config) + all_configs.append(config) + + if len(dim_configs) < count: + print( + f"Warning: Only generated {len(dim_configs)}/{count} configs for {dim_count}D combinations (attempted {attempts} times)" + ) + + return all_configs + + +class ErrorCategorizer: + """Categorize and aggregate errors by type.""" + + def __init__(self): + self.error_categories = defaultdict(list) + self.combination_stats = defaultdict(int) + + def categorize_error(self, error_msg: str, config: Dict) -> str: + """Categorize an error message into a category.""" + error_lower = error_msg.lower() + + if "ep and cp cannot both be > 1" in error_lower: + return "EP_CP_CONFLICT" + elif "cp not supported" in error_lower: + return "CP_NOT_SUPPORTED" + elif "pp requires" in error_lower or "pipeline" in error_lower: + return "PP_REQUIRES_MPU" + elif "not divisible" in error_lower: + return "DIVISIBILITY_ERROR" + elif "order" in error_lower and "specified" in error_lower: + return "ORDER_MISMATCH" + elif "not available" in error_lower: + return "FEATURE_NOT_AVAILABLE" + else: + return "OTHER_ERROR" + + def get_combination_type(self, config: Dict) -> str: + """Get the combination type string for a configuration.""" + dims = [] + if config["tp"] > 1: + dims.append("TP") + if config["dp"] > 1: + dims.append("DP") + if config["pp"] > 1: + dims.append("PP") + if config["cp"] > 1: + dims.append("CP") + if config["ep"] > 1: + dims.append("EP") + + if not dims: + return "NONE" + + return "+".join(sorted(dims)) + + def record_error(self, error_msg: str, config: Dict, library: str): + """Record an error with categorization.""" + category = self.categorize_error(error_msg, config) + combo_type = self.get_combination_type(config) + + self.error_categories[category].append({ + "error": error_msg, + "config": config, + "library": library, + "combination": combo_type, + }) + + self.combination_stats[combo_type] += 1 + + def get_error_summary(self) -> Dict: + """Get summary of errors by category.""" + summary = {} + for category, errors in self.error_categories.items(): + summary[category] = { + "count": len(errors), + "examples": errors[:5], # First 5 examples + "unique_combinations": len(set(e["combination"] for e in errors)), + } + return summary + + +class ParallelCompatibilityTester: + """Test compatibility between Megatron and DeepSpeed for parallel configurations.""" + + def __init__(self): + self.results = { + "megatron_success": [], + "megatron_failures": [], + "deepspeed_success": [], + "deepspeed_failures": [], + "compatible": [], + "incompatible": [], + "megatron_only": [], + "deepspeed_only": [], + } + self.error_categorizer = ErrorCategorizer() + self.combination_stats = defaultdict( + lambda: { + "total": 0, + "megatron_success": 0, + "megatron_failures": 0, + "deepspeed_success": 0, + "deepspeed_failures": 0, + "compatible": 0, + "megatron_only": 0, + "deepspeed_only": 0, + "incompatible": 0, + }) + + def test_megatron_config(self, config: Dict) -> Tuple[bool, Optional[str], Optional[Dict]]: + """Test if a configuration works with Megatron. + + Returns: + (success, error_message, result_data) + """ + if not PARALLEL_STATE_AVAILABLE: + return False, "Megatron not available", None + + try: + # Check EP and CP constraint + if config["ep"] > 1 and config["cp"] > 1: + return False, "EP and CP cannot both be > 1 in Megatron", None + + # Create RankGenerator + rg = RankGenerator(tp=config["tp"], + ep=config["ep"], + dp=config["dp"], + pp=config["pp"], + cp=config["cp"], + order=config["order"]) + + # Test getting ranks for each dimension + result_data = { + "world_size": rg.world_size, + "tp_groups": rg.get_ranks("tp") if config["tp"] > 1 else [], + "dp_groups": rg.get_ranks("dp") if config["dp"] > 1 else [], + "pp_groups": rg.get_ranks("pp") if config["pp"] > 1 else [], + "cp_groups": rg.get_ranks("cp") if config["cp"] > 1 else [], + "ep_groups": rg.get_ranks("ep") if config["ep"] > 1 else [], + } + + # Test combined groups + if len([d for d in ["tp", "dp", "pp", "cp", "ep"] if config[d] > 1]) > 1: + combined_token = config["order"] + result_data["combined_groups"] = rg.get_ranks(combined_token) + + return True, None, result_data + + except Exception as e: + return False, str(e), None + + def test_deepspeed_config(self, config: Dict) -> Tuple[bool, Optional[str], Optional[Dict]]: + """Test if a configuration is supported by DeepSpeed. + + Returns: + (supported, error_message, support_info) + """ + if not DEEPSPEED_AVAILABLE: + return False, "DeepSpeed not available", None + + support_info = { + "tp_supported": False, + "dp_supported": False, + "pp_supported": False, + "cp_supported": False, + "ep_supported": False, + "sp_supported": False, + "notes": [], + } + + # Check TP support + if config["tp"] > 1: + support_info["tp_supported"] = hasattr(ds_groups, 'get_tensor_model_parallel_group') + + # Check DP support + if config["dp"] > 1: + support_info["dp_supported"] = hasattr(ds_groups, 'get_data_parallel_group') + + # Check PP support + if config["pp"] > 1: + # DeepSpeed supports PP via mpu or pipe module + support_info["pp_supported"] = (hasattr(ds_groups, 'bwc_pipeline_parallel_world_size') + or self._check_module_exists('deepspeed.pipe')) + if not support_info["pp_supported"]: + support_info["notes"].append("PP requires mpu object or deepspeed.pipe module") + + # Check CP support + if config["cp"] > 1: + support_info["cp_supported"] = hasattr(ds_groups, 'get_context_parallel_group') + if not support_info["cp_supported"]: + support_info["notes"].append("CP not supported in DeepSpeed") + + # Check EP support + if config["ep"] > 1: + support_info["ep_supported"] = (hasattr(ds_groups, '_create_expert_and_data_parallel') + or hasattr(ds_groups, '_create_expert_data_and_model_parallel')) + + # Check SP support (DeepSpeed-specific) + support_info["sp_supported"] = hasattr(ds_sp, 'initialize_sequence_parallel') + + # Determine overall support + required_dims = [d for d in ["tp", "dp", "pp", "cp", "ep"] if config[d] > 1] + supported_dims = [] + if config["tp"] > 1 and support_info["tp_supported"]: + supported_dims.append("tp") + if config["dp"] > 1 and support_info["dp_supported"]: + supported_dims.append("dp") + if config["pp"] > 1 and support_info["pp_supported"]: + supported_dims.append("pp") + if config["cp"] > 1 and support_info["cp_supported"]: + supported_dims.append("cp") + if config["ep"] > 1 and support_info["ep_supported"]: + supported_dims.append("ep") + + fully_supported = len(supported_dims) == len(required_dims) + + return fully_supported, None, support_info + + def _check_module_exists(self, module_name): + """Check if a module exists.""" + try: + __import__(module_name) + return True + except ImportError: + return False + + def _simulate_deepspeed_rank_generation(self, config: Dict) -> Optional[Dict]: + """Simulate DeepSpeed's rank generation logic based on code analysis. + + This attempts to replicate DeepSpeed's rank assignment logic for comparison. + """ + if not DEEPSPEED_AVAILABLE: + return None + + try: + world_size = config["world_size"] + result = {} + + # For TP+DP: DeepSpeed uses mesh_device which creates groups in a specific way + if config["tp"] > 1 and config["dp"] > 1 and config["pp"] == 1 and config["cp"] == 1 and config["ep"] == 1: + # DeepSpeed's _init_tp_mesh_device creates: + # TP groups: [0,1], [2,3], [4,5], ... (consecutive) + # DP groups: [0,2,4,...], [1,3,5,...] (strided) + tp_size = config["tp"] + dp_size = config["dp"] + + tp_groups = [] + for i in range(world_size // tp_size): + group = list(range(i * tp_size, (i + 1) * tp_size)) + tp_groups.append(group) + + dp_groups = [] + for i in range(tp_size): + group = list(range(i, world_size, tp_size)) + dp_groups.append(group) + + result["tp_groups"] = tp_groups + result["dp_groups"] = dp_groups + result["world_size"] = world_size + return result + + # For other combinations, we can't easily simulate without actual distributed setup + # But we can note that DeepSpeed supports it + return {"supported": True, "note": "Rank generation requires actual distributed setup"} + + except Exception as e: + return {"error": str(e)} + + def _compare_rank_groups(self, megatron_groups: List[List[int]], deepspeed_groups: List[List[int]]) -> Dict: + """Compare rank groups from Megatron and DeepSpeed. + + Returns: + Dict with comparison results + """ + comparison = {"same_structure": False, "same_ranks": False, "differences": []} + + if not megatron_groups or not deepspeed_groups: + return comparison + + # Check if same number of groups + if len(megatron_groups) != len(deepspeed_groups): + comparison["differences"].append( + f"Group count mismatch: Megatron={len(megatron_groups)}, DeepSpeed={len(deepspeed_groups)}") + return comparison + + # Check if same group sizes + megatron_sizes = [len(g) for g in megatron_groups] + deepspeed_sizes = [len(g) for g in deepspeed_groups] + if megatron_sizes != deepspeed_sizes: + comparison["differences"].append( + f"Group size mismatch: Megatron={megatron_sizes}, DeepSpeed={deepspeed_sizes}") + return comparison + + # Check if same ranks (order may differ) + megatron_ranks = set() + for group in megatron_groups: + megatron_ranks.update(group) + + deepspeed_ranks = set() + for group in deepspeed_groups: + deepspeed_ranks.update(group) + + if megatron_ranks != deepspeed_ranks: + comparison["differences"].append( + f"Rank set mismatch: Megatron={sorted(megatron_ranks)}, DeepSpeed={sorted(deepspeed_ranks)}") + return comparison + + # Check if same structure (same groups, possibly different order) + megatron_sets = [set(g) for g in megatron_groups] + deepspeed_sets = [set(g) for g in deepspeed_groups] + + if sorted(megatron_sets, key=lambda x: min(x)) == sorted(deepspeed_sets, key=lambda x: min(x)): + comparison["same_structure"] = True + comparison["same_ranks"] = True + else: + comparison["differences"].append("Group structure differs (same ranks but different grouping)") + + return comparison + + def test_config_compatibility(self, config: Dict): + """Test compatibility of a configuration between both libraries.""" + # Get combination type for statistics + combo_type = self.error_categorizer.get_combination_type(config) + self.combination_stats[combo_type]["total"] += 1 + + # Test Megatron + megatron_success, megatron_error, megatron_result = self.test_megatron_config(config) + + # Test DeepSpeed + deepspeed_success, deepspeed_error, deepspeed_support = self.test_deepspeed_config(config) + + # Record errors in categorizer + if not megatron_success and megatron_error: + self.error_categorizer.record_error(megatron_error, config, "Megatron") + self.combination_stats[combo_type]["megatron_failures"] += 1 + else: + self.combination_stats[combo_type]["megatron_success"] += 1 + + if not deepspeed_success: + # Get error message from support_info notes + error_msg = deepspeed_support.get("notes", ["Not supported"])[0] if deepspeed_support else "Not supported" + self.error_categorizer.record_error(error_msg, config, "DeepSpeed") + self.combination_stats[combo_type]["deepspeed_failures"] += 1 + else: + self.combination_stats[combo_type]["deepspeed_success"] += 1 + + # Try to simulate DeepSpeed rank generation for comparison + deepspeed_simulated = None + if megatron_success and deepspeed_success: + deepspeed_simulated = self._simulate_deepspeed_rank_generation(config) + + # Compare rank generation if both succeeded and we have simulated results + rank_comparison = None + if megatron_success and deepspeed_success and deepspeed_simulated and "tp_groups" in deepspeed_simulated: + # Compare TP groups + if config["tp"] > 1 and "tp_groups" in megatron_result: + rank_comparison = self._compare_rank_groups(megatron_result["tp_groups"], + deepspeed_simulated.get("tp_groups", [])) + # Compare DP groups + if config["dp"] > 1 and "dp_groups" in megatron_result and not rank_comparison: + rank_comparison = self._compare_rank_groups(megatron_result["dp_groups"], + deepspeed_simulated.get("dp_groups", [])) + + # Record results + config_key = f"tp={config['tp']},dp={config['dp']},pp={config['pp']},cp={config['cp']},ep={config['ep']},order={config['order']}" + + if megatron_success: + self.results["megatron_success"].append(config_key) + else: + self.results["megatron_failures"].append({ + "config": config_key, + "error": megatron_error, + "combination": combo_type, + }) + + if deepspeed_success: + self.results["deepspeed_success"].append(config_key) + else: + self.results["deepspeed_failures"].append({ + "config": config_key, + "error": deepspeed_error, + "support_info": deepspeed_support, + "combination": combo_type, + }) + + # Determine compatibility and update stats + if megatron_success and deepspeed_success: + compat_entry = { + "config": config_key, + "megatron_result": megatron_result, + "deepspeed_support": deepspeed_support, + "combination": combo_type, + } + if rank_comparison: + compat_entry["rank_comparison"] = rank_comparison + if rank_comparison.get("same_structure"): + compat_entry["rank_match"] = True + else: + compat_entry["rank_match"] = False + compat_entry["rank_differences"] = rank_comparison.get("differences", []) + self.results["compatible"].append(compat_entry) + self.combination_stats[combo_type]["compatible"] += 1 + elif megatron_success and not deepspeed_success: + self.results["megatron_only"].append({ + "config": + config_key, + "megatron_result": + megatron_result, + "deepspeed_issue": + deepspeed_support.get("notes", []) if deepspeed_support else [], + "combination": + combo_type, + }) + self.combination_stats[combo_type]["megatron_only"] += 1 + elif not megatron_success and deepspeed_success: + self.results["deepspeed_only"].append({ + "config": config_key, + "megatron_error": megatron_error, + "deepspeed_support": deepspeed_support, + "combination": combo_type, + }) + self.combination_stats[combo_type]["deepspeed_only"] += 1 + else: + self.results["incompatible"].append({ + "config": + config_key, + "megatron_error": + megatron_error, + "deepspeed_issue": + deepspeed_support.get("notes", []) if deepspeed_support else [], + "combination": + combo_type, + }) + self.combination_stats[combo_type]["incompatible"] += 1 + + +class TestAutomatedParallelCombinations: + """Automated tests for parallel strategy combinations.""" + + def test_systematic_configurations(self): + """Test systematic configurations covering common cases.""" + generator = ParallelConfigGenerator(seed=42) + tester = ParallelCompatibilityTester() + + configs = generator.generate_systematic_configs(max_world_size=16) + + print("\n" + "=" * 80) + print("SYSTEMATIC CONFIGURATION TESTING") + print("=" * 80) + print(f"\nTesting {len(configs)} systematic configurations...") + + for i, config in enumerate(configs, 1): + print(f"\n[{i}/{len(configs)}] Testing: {config}") + tester.test_config_compatibility(config) + + self._print_results(tester, "Systematic") + self._generate_comprehensive_report(tester, "Systematic") + + def test_random_configurations(self): + """Test random configurations.""" + generator = ParallelConfigGenerator(seed=123) + tester = ParallelCompatibilityTester() + + configs = generator.generate_random_configs(count=1000, max_size=1024) + + print("\n" + "=" * 80) + print("RANDOM CONFIGURATION TESTING") + print("=" * 80) + print(f"\nTesting {len(configs)} random configurations...") + print(f"Max world size: 1024, Max parallel size per dimension: 32") + + for i, config in enumerate(configs, 1): + if i % 100 == 0: + print(f"Progress: {i}/{len(configs)} ({(i/len(configs)*100):.1f}%)") + tester.test_config_compatibility(config) + + self._print_results(tester, "Random") + self._generate_comprehensive_report(tester, "Random") + + def test_random_configurations_by_dimension(self): + """Test random configurations generated separately for each dimension.""" + generator = ParallelConfigGenerator(seed=789) + tester = ParallelCompatibilityTester() + + # Generate configs for each dimension separately + # This ensures balanced coverage across all dimensions + # Increased by 20x for comprehensive testing + counts_by_dimension = { + 1: 4000, # 1D: 4000 configs (200 * 20) + 2: 6000, # 2D: 6000 configs (300 * 20) - more because there are more 2D combinations + 3: 5000, # 3D: 5000 configs (250 * 20) + 4: 3000, # 4D: 3000 configs (150 * 20) + 5: 2000, # 5D: 2000 configs (100 * 20) + } + + print("\n" + "=" * 80) + print("RANDOM CONFIGURATION TESTING BY DIMENSION") + print("=" * 80) + print(f"\nGenerating configurations by dimension:") + for dim, count in counts_by_dimension.items(): + print(f" {dim}D: {count} configurations") + + configs = generator.generate_random_configs_by_dimension(counts_by_dimension=counts_by_dimension, + max_size=1024, + min_parallel_size=2, + max_parallel_size=32) + + print(f"\nTotal unique configurations generated: {len(configs)}") + print(f"Max world size: 1024, Parallel size range: 2-32") + + # Count configs by dimension + dim_counts = defaultdict(int) + for config in configs: + dim_count = len([d for d in ["tp", "dp", "pp", "cp", "ep"] if config[d] > 1]) + dim_counts[dim_count] += 1 + + print("\nActual distribution:") + for dim in sorted(dim_counts.keys()): + print(f" {dim}D: {dim_counts[dim]} configurations") + + print(f"\nTesting {len(configs)} configurations...") + + for i, config in enumerate(configs, 1): + # Update progress more frequently for large test sets + if i % 1000 == 0 or i == len(configs): + print(f"Progress: {i}/{len(configs)} ({(i/len(configs)*100):.1f}%)") + tester.test_config_compatibility(config) + + self._print_results(tester, "Random by Dimension") + self._generate_comprehensive_report(tester, "Random by Dimension") + + def test_edge_cases(self): + """Test edge cases and boundary conditions.""" + generator = ParallelConfigGenerator(seed=456) + tester = ParallelCompatibilityTester() + + # Edge cases - including larger sizes + edge_configs = [ + # Maximum dimensions - larger sizes + { + "tp": 8, + "dp": 8, + "pp": 8, + "cp": 1, + "ep": 1, + "order": "tp-dp-pp", + "world_size": 512 + }, + { + "tp": 16, + "dp": 16, + "pp": 4, + "cp": 1, + "ep": 1, + "order": "tp-dp-pp", + "world_size": 1024 + }, + # EP and CP conflict + { + "tp": 2, + "dp": 2, + "pp": 1, + "cp": 2, + "ep": 2, + "order": "tp-ep-dp", + "world_size": 8 + }, + { + "tp": 4, + "dp": 4, + "pp": 1, + "cp": 4, + "ep": 4, + "order": "tp-ep-dp", + "world_size": 64 + }, + # Single dimension - larger sizes + { + "tp": 1, + "dp": 1, + "pp": 64, + "cp": 1, + "ep": 1, + "order": "pp", + "world_size": 64 + }, + { + "tp": 128, + "dp": 1, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "tp", + "world_size": 128 + }, + { + "tp": 1, + "dp": 256, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "dp", + "world_size": 256 + }, + # All dimensions - larger sizes + { + "tp": 2, + "dp": 2, + "pp": 2, + "cp": 2, + "ep": 1, + "order": "tp-pp-dp-cp", + "world_size": 16 + }, + { + "tp": 4, + "dp": 4, + "pp": 4, + "cp": 4, + "ep": 1, + "order": "tp-pp-dp-cp", + "world_size": 256 + }, + # Different orders + { + "tp": 2, + "dp": 4, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "dp-tp", + "world_size": 8 + }, + { + "tp": 2, + "dp": 4, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "tp-dp", + "world_size": 8 + }, + { + "tp": 8, + "dp": 16, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "dp-tp", + "world_size": 128 + }, + { + "tp": 8, + "dp": 16, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "tp-dp", + "world_size": 128 + }, + # Large multi-dimensional + { + "tp": 8, + "dp": 8, + "pp": 4, + "cp": 1, + "ep": 1, + "order": "tp-pp-dp", + "world_size": 256 + }, + { + "tp": 4, + "dp": 8, + "pp": 8, + "cp": 1, + "ep": 1, + "order": "tp-pp-dp", + "world_size": 256 + }, + ] + + print("\n" + "=" * 80) + print("EDGE CASE TESTING") + print("=" * 80) + print(f"\nTesting {len(edge_configs)} edge case configurations...") + + for i, config in enumerate(edge_configs, 1): + print(f"\n[{i}/{len(edge_configs)}] Testing: {config}") + tester.test_config_compatibility(config) + + self._print_results(tester, "Edge Cases") + self._generate_comprehensive_report(tester, "Edge Cases") + + def _print_results(self, tester: ParallelCompatibilityTester, test_type: str): + """Print test results.""" + results = tester.results + + print("\n" + "=" * 80) + print(f"{test_type} TEST RESULTS") + print("=" * 80) + + print(f"\n✓ Megatron Success: {len(results['megatron_success'])}") + print(f"✗ Megatron Failures: {len(results['megatron_failures'])}") + if results['megatron_failures']: + print("\nMegatron Failures:") + for failure in results['megatron_failures'][:10]: # Show first 10 + print(f" - {failure['config']}: {failure['error']}") + if len(results['megatron_failures']) > 10: + print(f" ... and {len(results['megatron_failures']) - 10} more") + + print(f"\n✓ DeepSpeed Success: {len(results['deepspeed_success'])}") + print(f"✗ DeepSpeed Failures: {len(results['deepspeed_failures'])}") + if results['deepspeed_failures']: + print("\nDeepSpeed Failures:") + for failure in results['deepspeed_failures'][:10]: # Show first 10 + print(f" - {failure['config']}") + if failure.get('support_info'): + notes = failure['support_info'].get('notes', []) + if notes: + print(f" Notes: {', '.join(notes)}") + if len(results['deepspeed_failures']) > 10: + print(f" ... and {len(results['deepspeed_failures']) - 10} more") + + print(f"\n✓ Compatible (Both Support): {len(results['compatible'])}") + if results['compatible']: + print(" Examples:") + rank_matches = 0 + rank_mismatches = 0 + for item in results['compatible'][:10]: + if isinstance(item, dict): + config = item.get('config', 'Unknown') + rank_comp = item.get('rank_comparison') + if rank_comp: + if rank_comp.get('same_structure'): + print(f" - {config} ✓ Rank groups match") + rank_matches += 1 + else: + print(f" - {config} ⚠ Rank groups differ") + rank_mismatches += 1 + if rank_comp.get('differences'): + for diff in rank_comp['differences'][:2]: + print(f" {diff}") + else: + print(f" - {config}") + else: + print(f" - {item}") + if len(results['compatible']) > 10: + print(f" ... and {len(results['compatible']) - 10} more") + + if rank_matches > 0 or rank_mismatches > 0: + print(f"\n Rank Comparison Summary:") + print(f" Matches: {rank_matches}") + print(f" Mismatches: {rank_mismatches}") + print(f" (Note: Comparison only available for TP+DP combinations)") + + print(f"\n⚠ Megatron Only: {len(results['megatron_only'])}") + if results['megatron_only']: + print(" Examples:") + for item in results['megatron_only'][:5]: + print(f" - {item['config']}") + if item.get('deepspeed_issue'): + print(f" DeepSpeed issue: {', '.join(item['deepspeed_issue'])}") + if len(results['megatron_only']) > 5: + print(f" ... and {len(results['megatron_only']) - 5} more") + + print(f"\n→ DeepSpeed Only: {len(results['deepspeed_only'])}") + if results['deepspeed_only']: + print(" Examples:") + for item in results['deepspeed_only'][:5]: + print(f" - {item['config']}") + print(f" Megatron error: {item['megatron_error']}") + if len(results['deepspeed_only']) > 5: + print(f" ... and {len(results['deepspeed_only']) - 5} more") + + print(f"\n✗ Incompatible (Neither Support): {len(results['incompatible'])}") + if results['incompatible']: + print(" Examples:") + for item in results['incompatible'][:5]: + print(f" - {item['config']}") + print(f" Megatron: {item['megatron_error']}") + if len(results['incompatible']) > 5: + print(f" ... and {len(results['incompatible']) - 5} more") + + print("\n" + "=" * 80) + + def _generate_comprehensive_report(self, tester: ParallelCompatibilityTester, test_type: str): + """Generate comprehensive test report with error categorization and combination statistics.""" + results = tester.results + error_summary = tester.error_categorizer.get_error_summary() + combo_stats = tester.combination_stats + + print("\n" + "=" * 80) + print(f"{test_type} COMPREHENSIVE TEST REPORT") + print("=" * 80) + + # Overall statistics + print("\n" + "-" * 80) + print("OVERALL STATISTICS") + print("-" * 80) + total_tested = (len(results['megatron_success']) + len(results['megatron_failures']) + + len(results['deepspeed_success']) + len(results['deepspeed_failures'])) + print(f"Total Configurations Tested: {total_tested}") + print( + f" Megatron Success: {len(results['megatron_success'])} ({len(results['megatron_success'])/total_tested*100:.1f}%)" + ) + print( + f" Megatron Failures: {len(results['megatron_failures'])} ({len(results['megatron_failures'])/total_tested*100:.1f}%)" + ) + print( + f" DeepSpeed Success: {len(results['deepspeed_success'])} ({len(results['deepspeed_success'])/total_tested*100:.1f}%)" + ) + print( + f" DeepSpeed Failures: {len(results['deepspeed_failures'])} ({len(results['deepspeed_failures'])/total_tested*100:.1f}%)" + ) + print(f" Compatible: {len(results['compatible'])} ({len(results['compatible'])/total_tested*100:.1f}%)") + print( + f" Megatron Only: {len(results['megatron_only'])} ({len(results['megatron_only'])/total_tested*100:.1f}%)" + ) + print( + f" DeepSpeed Only: {len(results['deepspeed_only'])} ({len(results['deepspeed_only'])/total_tested*100:.1f}%)" + ) + print(f" Incompatible: {len(results['incompatible'])} ({len(results['incompatible'])/total_tested*100:.1f}%)") + + # Error categorization + print("\n" + "-" * 80) + print("ERROR CATEGORIZATION (Aggregated by Type)") + print("-" * 80) + for category, summary in sorted(error_summary.items(), key=lambda x: x[1]['count'], reverse=True): + print(f"\n{category}: {summary['count']} occurrences") + print(f" Affects {summary['unique_combinations']} unique combination types") + print(f" Examples:") + for example in summary['examples'][:3]: + combo = example.get('combination', 'Unknown') + lib = example.get('library', 'Unknown') + print(f" - {combo} ({lib}): {example['error'][:80]}") + if len(summary['examples']) > 3: + print(f" ... and {len(summary['examples']) - 3} more examples") + + # Combination type statistics + print("\n" + "-" * 80) + print("COMBINATION TYPE STATISTICS") + print("-" * 80) + print( + f"{'Combination':<20} {'Total':<8} {'M-Succ':<8} {'M-Fail':<8} {'DS-Succ':<8} {'DS-Fail':<8} {'Compat':<8} {'M-Only':<8} {'DS-Only':<8} {'Incomp':<8}" + ) + print("-" * 100) + + # Sort by total count + sorted_combos = sorted(combo_stats.items(), key=lambda x: x[1]['total'], reverse=True) + for combo_type, stats in sorted_combos: + if stats['total'] > 0: + print(f"{combo_type:<20} {stats['total']:<8} {stats['megatron_success']:<8} " + f"{stats['megatron_failures']:<8} {stats['deepspeed_success']:<8} " + f"{stats['deepspeed_failures']:<8} {stats['compatible']:<8} " + f"{stats['megatron_only']:<8} {stats['deepspeed_only']:<8} " + f"{stats['incompatible']:<8}") + + # Detailed combination analysis + print("\n" + "-" * 80) + print("DETAILED COMBINATION ANALYSIS") + print("-" * 80) + + # Group by number of dimensions + by_dimension_count = defaultdict(list) + for combo_type, stats in combo_stats.items(): + dim_count = len([c for c in combo_type.split('+') if c != 'NONE']) + by_dimension_count[dim_count].append((combo_type, stats)) + + for dim_count in sorted(by_dimension_count.keys()): + print(f"\n{dim_count}-Dimensional Combinations:") + combos = sorted(by_dimension_count[dim_count], key=lambda x: x[1]['total'], reverse=True) + for combo_type, stats in combos[:10]: # Show top 10 + if stats['total'] > 0: + compat_rate = (stats['compatible'] / stats['total'] * 100) if stats['total'] > 0 else 0 + print(f" {combo_type}:") + print(f" Total: {stats['total']}, Compatible: {stats['compatible']} ({compat_rate:.1f}%)") + print(f" Megatron: {stats['megatron_success']} success, {stats['megatron_failures']} failures") + print( + f" DeepSpeed: {stats['deepspeed_success']} success, {stats['deepspeed_failures']} failures") + if len(combos) > 10: + print(f" ... and {len(combos) - 10} more {dim_count}-dimensional combinations") + + print("\n" + "=" * 80) + + def test_cp_vs_sp_compatibility_by_dimension(self): + """Test CP vs SP compatibility using the same config generation as test_random_configurations_by_dimension. + + This test: + 1. Uses parallel_state_refactored with CP + 2. Uses DeepSpeed with SP + 3. Compares CP rank groups with SP rank groups to see if they match + """ + generator = ParallelConfigGenerator(seed=789) + + # Use the same configuration generation as test_random_configurations_by_dimension + counts_by_dimension = { + 1: 4000, # 1D: 4000 configs + 2: 6000, # 2D: 6000 configs + 3: 5000, # 3D: 5000 configs + 4: 3000, # 4D: 3000 configs + 5: 2000, # 5D: 2000 configs + } + + print("\n" + "=" * 80) + print("CP vs SP COMPATIBILITY TESTING BY DIMENSION") + print("=" * 80) + print(f"\nGenerating configurations by dimension:") + for dim, count in counts_by_dimension.items(): + print(f" {dim}D: {count} configurations") + + configs = generator.generate_random_configs_by_dimension(counts_by_dimension=counts_by_dimension, + max_size=1024, + min_parallel_size=2, + max_parallel_size=32) + + # Filter to only include configs with CP > 1 and EP == 1 (EP and CP cannot both be > 1) + configs_with_cp = [c for c in configs if c["cp"] > 1 and c["ep"] == 1] + + print(f"\nTotal unique configurations generated: {len(configs)}") + print(f"Configurations with CP > 1 and EP == 1: {len(configs_with_cp)}") + print(f"Max world size: 1024, Parallel size range: 2-32") + + # Test CP vs SP compatibility + results = { + "total_tested": 0, + "cp_groups_generated": 0, + "sp_groups_generated": 0, + "rank_groups_match": 0, + "rank_groups_differ": 0, + "errors": 0, + "match_details": [], + "differ_details": [], + } + + combination_stats = defaultdict(lambda: { + "total": 0, + "match": 0, + "differ": 0, + "errors": 0, + }) + + print(f"\nTesting {len(configs_with_cp)} configurations for CP vs SP compatibility...") + + for i, config in enumerate(configs_with_cp, 1): + if i % 1000 == 0 or i == len(configs_with_cp): + print(f"Progress: {i}/{len(configs_with_cp)} ({(i/len(configs_with_cp)*100):.1f}%)") + + results["total_tested"] += 1 + + # Get combination type + combo_type = self._get_combination_type_for_cp_sp(config) + combination_stats[combo_type]["total"] += 1 + + try: + # Get CP rank groups from Megatron + if not PARALLEL_STATE_AVAILABLE: + results["errors"] += 1 + combination_stats[combo_type]["errors"] += 1 + continue + + rg = RankGenerator(tp=config["tp"], + ep=config["ep"], + dp=config["dp"], + pp=config["pp"], + cp=config["cp"], + order=config["order"]) + + cp_groups = rg.get_ranks("cp") + if cp_groups: + results["cp_groups_generated"] += 1 + + # Simulate SP rank groups from DeepSpeed + # DeepSpeed SP creates consecutive rank groups + sp_groups = self._simulate_deepspeed_sp_groups(config["world_size"], config["cp"]) + if sp_groups: + results["sp_groups_generated"] += 1 + + # Compare CP and SP groups + if self._compare_cp_sp_groups(cp_groups, sp_groups): + results["rank_groups_match"] += 1 + combination_stats[combo_type]["match"] += 1 + results["match_details"].append(config) + else: + results["rank_groups_differ"] += 1 + combination_stats[combo_type]["differ"] += 1 + results["differ_details"].append({ + "config": config, + "cp_groups": cp_groups, + "sp_groups": sp_groups, + }) + + except Exception as e: + results["errors"] += 1 + combination_stats[combo_type]["errors"] += 1 + + # Generate report + self._generate_cp_vs_sp_report(results, combination_stats) + + def _simulate_deepspeed_sp_groups(self, world_size: int, sp_size: int) -> List[List[int]]: + """Simulate DeepSpeed's SP rank group generation. + + DeepSpeed SP creates groups as consecutive ranks: + - Group 0: [0, 1, ..., sp_size-1] + - Group 1: [sp_size, sp_size+1, ..., 2*sp_size-1] + - etc. + """ + if sp_size <= 1 or world_size % sp_size != 0: + return [] + + num_groups = world_size // sp_size + groups = [] + for i in range(num_groups): + group = list(range(i * sp_size, (i + 1) * sp_size)) + groups.append(group) + + return groups + + def _compare_cp_sp_groups(self, cp_groups: List[List[int]], sp_groups: List[List[int]]) -> bool: + """Compare CP and SP rank groups to see if they match.""" + if not cp_groups and not sp_groups: + return True + + if not cp_groups or not sp_groups: + return False + + if len(cp_groups) != len(sp_groups): + return False + + # Check if all CP groups have a matching SP group (order may differ) + cp_sets = [set(g) for g in cp_groups] + sp_sets = [set(g) for g in sp_groups] + + # Check if all CP groups match SP groups + for cp_set in cp_sets: + found = False + for sp_set in sp_sets: + if cp_set == sp_set: + found = True + break + if not found: + return False + + # Check if all SP groups match CP groups + for sp_set in sp_sets: + found = False + for cp_set in cp_sets: + if sp_set == cp_set: + found = True + break + if not found: + return False + + return True + + def _get_combination_type_for_cp_sp(self, config: Dict) -> str: + """Get combination type string for CP vs SP testing.""" + dims = [] + if config["tp"] > 1: + dims.append("TP") + if config["dp"] > 1: + dims.append("DP") + if config["pp"] > 1: + dims.append("PP") + if config["cp"] > 1: + dims.append("CP") + # Note: EP is always 1 in this test + + if not dims: + return "NONE" + + return "+".join(sorted(dims)) + + def _generate_cp_vs_sp_report(self, results: Dict, combination_stats: Dict): + """Generate comprehensive CP vs SP compatibility report.""" + print("\n" + "=" * 80) + print("CP vs SP COMPATIBILITY TEST REPORT") + print("=" * 80) + + # Overall statistics + print("\n" + "-" * 80) + print("OVERALL STATISTICS") + print("-" * 80) + print(f"Total Configurations Tested: {results['total_tested']}") + print(f" CP Groups Generated: {results['cp_groups_generated']}") + print(f" SP Groups Generated: {results['sp_groups_generated']}") + print(f" Rank Groups Match: {results['rank_groups_match']}") + print(f" Rank Groups Differ: {results['rank_groups_differ']}") + print(f" Errors: {results['errors']}") + + if results['total_tested'] > 0: + match_rate = (results['rank_groups_match'] / results['total_tested']) * 100 + print(f"\n Match Rate: {match_rate:.2f}%") + print(f" CP can replace SP in {match_rate:.2f}% of tested configurations") + + # Combination type statistics + print("\n" + "-" * 80) + print("COMBINATION TYPE STATISTICS") + print("-" * 80) + print(f"{'Combination':<20} {'Total':<8} {'Match':<8} {'Differ':<8} {'Errors':<8} {'Match Rate':<12}") + print("-" * 80) + + sorted_combos = sorted(combination_stats.items(), key=lambda x: x[1]['total'], reverse=True) + for combo_type, stats in sorted_combos: + if stats['total'] > 0: + match_rate = (stats['match'] / stats['total'] * 100) if stats['total'] > 0 else 0 + print(f"{combo_type:<20} {stats['total']:<8} {stats['match']:<8} " + f"{stats['differ']:<8} {stats['errors']:<8} {match_rate:.1f}%") + + # Examples of matching configurations + print("\n" + "-" * 80) + print("EXAMPLES OF MATCHING CONFIGURATIONS (CP can replace SP)") + print("-" * 80) + for i, config in enumerate(results['match_details'][:10], 1): + print(f"{i}. {config}") + print(f" CP size: {config['cp']}, Order: {config['order']}") + + if len(results['match_details']) > 10: + print(f"\n... and {len(results['match_details']) - 10} more matching configurations") + + # Examples of differing configurations + if results['differ_details']: + print("\n" + "-" * 80) + print("EXAMPLES OF DIFFERING CONFIGURATIONS (CP cannot replace SP)") + print("-" * 80) + for i, item in enumerate(results['differ_details'][:10], 1): + config = item['config'] + cp_groups = item['cp_groups'] + sp_groups = item['sp_groups'] + print(f"{i}. {config}") + print(f" CP size: {config['cp']}, Order: {config['order']}") + print(f" CP groups count: {len(cp_groups)}, SP groups count: {len(sp_groups)}") + if cp_groups and sp_groups: + print(f" CP first group: {cp_groups[0]}") + print(f" SP first group: {sp_groups[0]}") + + if len(results['differ_details']) > 10: + print(f"\n... and {len(results['differ_details']) - 10} more differing configurations") + + # Conclusion + print("\n" + "=" * 80) + print("CONCLUSION") + print("=" * 80) + if results['rank_groups_match'] > 0: + match_rate = (results['rank_groups_match'] / results['total_tested']) * 100 + print(f"\n✓ CP can replace SP in {match_rate:.2f}% of tested configurations") + print( + f" - {results['rank_groups_match']} out of {results['total_tested']} configurations have matching rank groups" + ) + else: + print("\n✗ CP cannot replace SP in any of the tested configurations") + + if results['rank_groups_differ'] > 0: + print(f"\n⚠ {results['rank_groups_differ']} configurations have different rank groups") + print(" - These configurations may require special handling when migrating from CP to SP") + + print("\n" + "=" * 80) + + def test_comprehensive_automated_testing(self): + """Comprehensive automated testing with all test types.""" + print("\n" + "=" * 80) + print("COMPREHENSIVE AUTOMATED PARALLEL COMBINATION TESTING") + print("=" * 80) + + # Create a combined tester for overall report + combined_tester = ParallelCompatibilityTester() + + # Run all test types and accumulate results + print("\n[1/3] Running systematic configurations...") + generator1 = ParallelConfigGenerator(seed=42) + configs1 = generator1.generate_systematic_configs(max_world_size=512) + print(f"Testing {len(configs1)} systematic configurations...") + for i, config in enumerate(configs1, 1): + if i % 50 == 0 or i == len(configs1): + print(f" Progress: {i}/{len(configs1)}") + combined_tester.test_config_compatibility(config) + + print("\n[2/4] Running random configurations by dimension...") + generator2 = ParallelConfigGenerator(seed=789) + # Increased by 20x for comprehensive testing + counts_by_dimension = { + 1: 4000, # 1D: 4000 configs (200 * 20) + 2: 6000, # 2D: 6000 configs (300 * 20) + 3: 5000, # 3D: 5000 configs (250 * 20) + 4: 3000, # 4D: 3000 configs (150 * 20) + 5: 2000, # 5D: 2000 configs (100 * 20) + } + configs2 = generator2.generate_random_configs_by_dimension(counts_by_dimension=counts_by_dimension, + max_size=1024, + min_parallel_size=2, + max_parallel_size=32) + print(f"Testing {len(configs2)} random configurations (balanced by dimension)...") + print(f"Max world size: 1024, Parallel size range: 2-32") + for i, config in enumerate(configs2, 1): + # Update progress more frequently for large test sets + if i % 1000 == 0 or i == len(configs2): + print(f" Progress: {i}/{len(configs2)} ({(i/len(configs2)*100):.1f}%)") + combined_tester.test_config_compatibility(config) + + print("\n[3/4] Running additional random configurations...") + generator3 = ParallelConfigGenerator(seed=123) + # Increased by 20x: 500 * 20 = 10000 + configs3 = generator3.generate_random_configs(count=10000, max_size=1024) + print(f"Testing {len(configs3)} additional random configurations...") + for i, config in enumerate(configs3, 1): + # Update progress more frequently for large test sets + if i % 1000 == 0 or i == len(configs3): + print(f" Progress: {i}/{len(configs3)} ({(i/len(configs3)*100):.1f}%)") + combined_tester.test_config_compatibility(config) + + print("\n[4/4] Running edge cases...") + edge_configs = [ + { + "tp": 8, + "dp": 8, + "pp": 8, + "cp": 1, + "ep": 1, + "order": "tp-dp-pp", + "world_size": 512 + }, + { + "tp": 16, + "dp": 16, + "pp": 4, + "cp": 1, + "ep": 1, + "order": "tp-dp-pp", + "world_size": 1024 + }, + { + "tp": 2, + "dp": 2, + "pp": 1, + "cp": 2, + "ep": 2, + "order": "tp-ep-dp", + "world_size": 8 + }, + { + "tp": 4, + "dp": 4, + "pp": 1, + "cp": 4, + "ep": 4, + "order": "tp-ep-dp", + "world_size": 64 + }, + { + "tp": 1, + "dp": 1, + "pp": 64, + "cp": 1, + "ep": 1, + "order": "pp", + "world_size": 64 + }, + { + "tp": 128, + "dp": 1, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "tp", + "world_size": 128 + }, + { + "tp": 1, + "dp": 256, + "pp": 1, + "cp": 1, + "ep": 1, + "order": "dp", + "world_size": 256 + }, + { + "tp": 2, + "dp": 2, + "pp": 2, + "cp": 2, + "ep": 1, + "order": "tp-pp-dp-cp", + "world_size": 16 + }, + { + "tp": 4, + "dp": 4, + "pp": 4, + "cp": 4, + "ep": 1, + "order": "tp-pp-dp-cp", + "world_size": 256 + }, + ] + print(f"Testing {len(edge_configs)} edge case configurations...") + for config in edge_configs: + combined_tester.test_config_compatibility(config) + + # Generate comprehensive report + print("\n" + "=" * 80) + print("COMPREHENSIVE FINAL REPORT") + print("=" * 80) + self._generate_comprehensive_report(combined_tester, "COMPREHENSIVE") + + print("\n" + "=" * 80) + print("ALL TESTS COMPLETED") + print("=" * 80) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) From d97cfb854ba0e58cd8059b6ace8aadca144a0fb4 Mon Sep 17 00:00:00 2001 From: Junjie Mao Date: Wed, 7 Jan 2026 11:43:37 +0800 Subject: [PATCH 02/15] parallel_state: Cleanup dependency on ProcessGroupNCCL.Options Signed-off-by: Junjie Mao --- deepspeed/utils/parallel_state.py | 98 ++++++++++++------------------- 1 file changed, 39 insertions(+), 59 deletions(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index df9906d2fcee..495241daa523 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -280,24 +280,19 @@ def __init__(self): self.decoder_rank_generator = None self.expert_decoder_rank_generator = None - def _get_nccl_options(self, pg_name: str, nccl_comm_cfgs: dict): - """Set the NCCL process group options.""" - if pg_name in nccl_comm_cfgs: - # FIXME: deepspeed.comm does not provide a way to set NCCL options yet. - nccl_options = torch.distributed.ProcessGroupNCCL.Options( - is_high_priority_stream=nccl_comm_cfgs[pg_name].get("is_high_priority_stream", False)) - if "cga_cluster_size" in nccl_comm_cfgs[pg_name]: - nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name]["cga_cluster_size"] - if "max_ctas" in nccl_comm_cfgs[pg_name]: - nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name]["max_ctas"] - if "min_ctas" in nccl_comm_cfgs[pg_name]: - nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name]["min_ctas"] - if "net_name" in nccl_comm_cfgs[pg_name]: - nccl_options.config.net_name = nccl_comm_cfgs[pg_name]["net_name"] - if nccl_options.config.net_name.lower() not in ["ib", "socket"]: - raise RuntimeError(f"net_name ({nccl_options.config.net_name}) is not supported." - f"Accepted values: 'IB' or 'socket'.") - return nccl_options + def _get_pg_options(self, pg_name: str, pg_comm_cfgs: dict): + """Get the options for a specific process group.""" + # TODO: construct process group options from json config + # + # As of PyTorch 2.9, the only backend that supports pg options is nccl, + # and a nccl-specific class, namely ProcessGroupNCCL.Options, is + # required to construct the options. + # + # To enable configuring such options in DeepSpeed, we need to define the + # interface for users to specify them and also figure out whether we + # want to export ProcessGroupNCCL.Options in deepspeed.comm or allow + # using torch distributed for this specific case in check-torchdist.py. + # Those are left as future work. return None def _create_group( @@ -393,13 +388,11 @@ def initialize_model_parallel( expert_model_parallel_size: int = 1, num_distributed_optimizer_instances: int = 1, expert_tensor_parallel_size: Optional[int] = None, - nccl_communicator_config_path: Optional[str] = None, distributed_timeout_minutes: int = 30, order: str = "tp-cp-ep-dp-pp", get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, create_gloo_process_groups: bool = True, - high_priority_stream_groups: Optional[List[str]] = None, ) -> None: """Initialize model data parallel groups. @@ -439,23 +432,10 @@ def default_position_embedding_ranks(pp_ranks): self.virtual_pipeline_model_parallel_rank = 0 self.virtual_pipeline_model_parallel_world_size = virtual_pipeline_model_parallel_size - # Load NCCL configs - nccl_comm_cfgs = {} - if nccl_communicator_config_path is not None: - try: - import yaml - except ImportError: - raise RuntimeError("Cannot import `yaml`. Setting custom nccl communicator configs " - "requires the yaml package.") - with open(nccl_communicator_config_path, "r") as stream: - nccl_comm_cfgs = yaml.safe_load(stream) - - # Set high priority stream groups - high_priority_stream_groups = high_priority_stream_groups or [] - for pg_name in high_priority_stream_groups: - if pg_name not in nccl_comm_cfgs: - nccl_comm_cfgs[pg_name] = {} - nccl_comm_cfgs[pg_name]["is_high_priority_stream"] = True + # TODO: Collect process group options from configs + # + # Check _get_pg_options for details. + pg_comm_cfgs = {} # Create rank generators self.decoder_rank_generator = RankGenerator( @@ -502,7 +482,7 @@ def default_position_embedding_ranks(pp_ranks): group_with_cp = self._create_group( ranks_with_cp, timeout=timeout, - pg_options=self._get_nccl_options("dp_cp", nccl_comm_cfgs), + pg_options=self._get_pg_options("dp_cp", pg_comm_cfgs), group_desc="DATA_PARALLEL_GROUP_WITH_CP", ) if create_gloo_process_groups: @@ -526,7 +506,7 @@ def default_position_embedding_ranks(pp_ranks): intra_partial_dp_group_with_cp = self._create_group( intra_partial_dp_ranks_with_cp, timeout=timeout, - pg_options=self._get_nccl_options("intra_dp_cp", nccl_comm_cfgs), + pg_options=self._get_pg_options("intra_dp_cp", pg_comm_cfgs), group_desc="INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP", ) if create_gloo_process_groups: @@ -550,7 +530,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("dp", nccl_comm_cfgs), + pg_options=self._get_pg_options("dp", pg_comm_cfgs), group_desc="DATA_PARALLEL_GROUP", ) if create_gloo_process_groups: @@ -571,7 +551,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("cp", nccl_comm_cfgs), + pg_options=self._get_pg_options("cp", pg_comm_cfgs), group_desc="CONTEXT_PARALLEL_GROUP", ) if rank in ranks: @@ -584,7 +564,7 @@ def default_position_embedding_ranks(pp_ranks): ranks, hierarchical_context_parallel_sizes, create_gloo_process_groups=False, - pg_options=self._get_nccl_options("hcp", nccl_comm_cfgs), + pg_options=self._get_pg_options("hcp", pg_comm_cfgs), timeout=timeout, group_desc="CONTEXT_PARALLEL_GROUP", ) @@ -597,7 +577,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("mp", nccl_comm_cfgs), + pg_options=self._get_pg_options("mp", pg_comm_cfgs), group_desc="MODEL_PARALLEL_GROUP", ) if rank in ranks: @@ -610,7 +590,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("tp", nccl_comm_cfgs), + pg_options=self._get_pg_options("tp", pg_comm_cfgs), group_desc="TENSOR_MODEL_PARALLEL_GROUP", ) if rank in ranks: @@ -627,8 +607,8 @@ def default_position_embedding_ranks(pp_ranks): ranks, timeout=timeout, backend=pipeline_model_parallel_comm_backend, - pg_options=(None if pipeline_model_parallel_comm_backend == "ucc" else self._get_nccl_options( - "pp", nccl_comm_cfgs)), + pg_options=(None if pipeline_model_parallel_comm_backend == "ucc" else self._get_pg_options( + "pp", pg_comm_cfgs)), group_desc="PIPELINE_MODEL_PARALLEL_GROUP", ) assert ( @@ -653,7 +633,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( embedding_ranks, timeout=timeout, - pg_options=self._get_nccl_options("embd", nccl_comm_cfgs), + pg_options=self._get_pg_options("embd", pg_comm_cfgs), group_desc="EMBEDDING_GROUP", ) if rank in embedding_ranks: @@ -664,7 +644,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( position_embedding_ranks, timeout=timeout, - pg_options=self._get_nccl_options("pos_embd", nccl_comm_cfgs), + pg_options=self._get_pg_options("pos_embd", pg_comm_cfgs), group_desc="POSITION_EMBEDDING_GROUP", ) if rank in position_embedding_ranks: @@ -677,7 +657,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("tp_dp_cp", nccl_comm_cfgs), + pg_options=self._get_pg_options("tp_dp_cp", pg_comm_cfgs), group_desc="TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP", ) if rank in ranks: @@ -686,7 +666,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("tp_dp", nccl_comm_cfgs), + pg_options=self._get_pg_options("tp_dp", pg_comm_cfgs), group_desc="TENSOR_AND_DATA_PARALLEL_GROUP", ) if rank in ranks: @@ -697,7 +677,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("tp_cp", nccl_comm_cfgs), + pg_options=self._get_pg_options("tp_cp", pg_comm_cfgs), group_desc="TENSOR_AND_CONTEXT_PARALLEL_GROUP", ) if rank in ranks: @@ -708,7 +688,7 @@ def default_position_embedding_ranks(pp_ranks): for ranks in self.expert_decoder_rank_generator.get_ranks('ep'): group = self._create_group( ranks, - pg_options=self._get_nccl_options("ep", nccl_comm_cfgs), + pg_options=self._get_pg_options("ep", pg_comm_cfgs), group_desc="EXPERT_MODEL_PARALLEL_GROUP", ) if rank in ranks: @@ -719,7 +699,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("ep_tp", nccl_comm_cfgs), + pg_options=self._get_pg_options("ep_tp", pg_comm_cfgs), group_desc="EXPERT_TENSOR_PARALLEL_GROUP", ) if rank in ranks: @@ -730,7 +710,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("tp_ep_mp", nccl_comm_cfgs), + pg_options=self._get_pg_options("tp_ep_mp", pg_comm_cfgs), group_desc="EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP", ) if rank in ranks: @@ -741,7 +721,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("tp_ep_pp", nccl_comm_cfgs), + pg_options=self._get_pg_options("tp_ep_pp", pg_comm_cfgs), group_desc="EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP", ) if rank in ranks: @@ -761,7 +741,7 @@ def default_position_embedding_ranks(pp_ranks): group = self._create_group( ranks, timeout=timeout, - pg_options=self._get_nccl_options("ep_dp", nccl_comm_cfgs), + pg_options=self._get_pg_options("ep_dp", pg_comm_cfgs), group_desc="EXPERT_DATA_PARALLEL_GROUP", ) if create_gloo_process_groups: @@ -779,8 +759,8 @@ def default_position_embedding_ranks(pp_ranks): [intra_partial_expert_data_parallel_size, num_distributed_optimizer_instances], create_gloo_process_groups=create_gloo_process_groups, pg_options=[ - self._get_nccl_options("intra_ep_dp", nccl_comm_cfgs), - self._get_nccl_options("inter_ep_dp", nccl_comm_cfgs), + self._get_pg_options("intra_ep_dp", pg_comm_cfgs), + self._get_pg_options("inter_ep_dp", pg_comm_cfgs), ], timeout=timeout, group_desc="EXPERT_DATA_PARALLEL_GROUP", @@ -804,7 +784,7 @@ def default_position_embedding_ranks(pp_ranks): intra_dist_opt_instance_group = self._create_group( intra_dist_opt_ranks, timeout=timeout, - pg_options=self._get_nccl_options("intra_dist_opt_instance", nccl_comm_cfgs), + pg_options=self._get_pg_options("intra_dist_opt_instance", pg_comm_cfgs), group_desc="INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP", ) if rank in intra_dist_opt_ranks: From a146fcff80b41bea5cfa9ebbcb21b9a72b6a5707 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Tue, 23 Dec 2025 20:41:37 +0800 Subject: [PATCH 03/15] feat: add config-based parallel state initialization with validation Add comprehensive config.json support for parallelism configuration with smart priority handling and context parallel validation. Key features: - Support all parallelism dimensions via config.json - Config priority: config file > function params > defaults - Conflict detection with warning logs - Context parallel validation (CP must be 1) - Backward compatible with existing code Changes: - Add 14 optional parameters to initialize_parallel_state_from_config() - Implement 3-tier priority system with conflict detection - Add CP validation: raise NotImplementedError if CP > 1 - Update default order from "tp-cp-ep-dp-pp" to "tp-ep-dp-pp" - Add detailed docstrings and usage examples This allows users to configure all parallel dimensions in config.json instead of reading documentation and manually calling initialize_model_parallel. Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state_deepspeed.py | 252 +++++++++++++++++++- 1 file changed, 249 insertions(+), 3 deletions(-) diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index bf3a346de194..2d4cbf93915a 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -29,6 +29,7 @@ - Supports multiple parallel state instances (for RL scenarios with different models) - Backward compatible with single global instance - Context manager for switching between different parallel configurations +- Configuration-based initialization from config.json Usage: # Basic usage (single global instance): @@ -56,11 +57,16 @@ with set_current_parallel_state("critic"): critic_dp_group = get_data_parallel_group() # Uses critic's DP group + + # Initialize from config.json: + from deepspeed import DeepSpeedConfig + ds_config = DeepSpeedConfig("config.json") + initialize_parallel_state_from_config(ds_config) """ from contextlib import contextmanager -from typing import Optional -from parallel_state import ParallelState, get_parallel_state as _get_default_parallel_state +from typing import Optional, Union, Dict, Any, List +from .parallel_state import ParallelState, get_parallel_state as _get_default_parallel_state # Registry for multiple parallel state instances _parallel_state_registry = {} @@ -287,7 +293,7 @@ def get_tensor_model_parallel_src_rank(name: Optional[str] = None): DeepSpeed-compatible interface. """ - import torch.distributed as dist + import deepspeed.comm as dist global_rank = dist.get_rank() local_world_size = get_tensor_model_parallel_world_size(name) return (global_rank // local_world_size) * local_world_size @@ -553,3 +559,243 @@ def is_initialized(name: Optional[str] = None): DeepSpeed-compatible interface. """ return get_parallel_state(name).is_initialized() + + +# ============================================================================ +# Configuration-based Initialization +# ============================================================================ + + +def initialize_parallel_state_from_config( + config: Union[Dict[str, Any], Any], + name: Optional[str] = None, + config_key: str = "parallelism", + # Optional parameters to override config values + tensor_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_size: Optional[int] = None, + virtual_pipeline_model_parallel_size: Optional[int] = None, + pipeline_model_parallel_comm_backend: Optional[str] = None, + context_parallel_size: Optional[int] = None, + hierarchical_context_parallel_sizes: Optional[List[int]] = None, + expert_model_parallel_size: Optional[int] = None, + num_distributed_optimizer_instances: Optional[int] = None, + expert_tensor_parallel_size: Optional[int] = None, + nccl_communicator_config_path: Optional[str] = None, + distributed_timeout_minutes: Optional[int] = None, + order: Optional[str] = None, + create_gloo_process_groups: Optional[bool] = None, + high_priority_stream_groups: Optional[List[str]] = None, +) -> None: + """Initialize parallel state from DeepSpeed config.json with optional parameter overrides. + + This function reads parallelism configuration from the DeepSpeed config file + and automatically initializes the ParallelState instance. This allows users + to configure all parallelism dimensions in a single place (config.json) + rather than having to read documentation and manually call initialize_model_parallel. + + Configuration priority: config file (if explicitly set) > function parameters > default values + + Note: If a value is explicitly set in config file, it takes precedence over function + parameters. A warning will be logged if there's a conflict. To override config file + values, remove them from the config file first. + + Args: + config: Either a DeepSpeedConfig object or a config dictionary. + If DeepSpeedConfig, will access its _param_dict attribute. + If dict, will use it directly. + name: Optional name of the parallel state instance to initialize. + If None, initializes the default global instance. + config_key: Key in the config dictionary where parallelism config is stored. + Default is "parallelism". + + # Parallelism dimension parameters (override config if provided): + tensor_model_parallel_size: Size of tensor model parallel group. Default: 1 + pipeline_model_parallel_size: Size of pipeline model parallel group. Default: 1 + virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size. Default: None + pipeline_model_parallel_comm_backend: Communication backend for pipeline. Default: None + context_parallel_size: Size of context parallel group. Default: 1 (MUST be 1, CP not supported) + hierarchical_context_parallel_sizes: Hierarchical context parallel sizes. Default: None (NOT supported) + expert_model_parallel_size: Size of expert model parallel group. Default: 1 + num_distributed_optimizer_instances: Number of distributed optimizer instances. Default: 1 + expert_tensor_parallel_size: Size of expert tensor parallel group. Default: None + nccl_communicator_config_path: Path to NCCL communicator config. Default: None + distributed_timeout_minutes: Timeout for distributed operations. Default: 30 + order: Order of parallelism dimensions. Default: "tp-ep-dp-pp" + create_gloo_process_groups: Whether to create Gloo process groups. Default: True + high_priority_stream_groups: High priority stream groups. Default: None + + Example config.json: + { + "parallelism": { + "tensor_model_parallel_size": 2, + "pipeline_model_parallel_size": 1, + "expert_model_parallel_size": 1, + "expert_tensor_parallel_size": 1, + "virtual_pipeline_model_parallel_size": null, + "pipeline_model_parallel_comm_backend": null, ##不要加入config中,保留加载逻辑 + "num_distributed_optimizer_instances": 1, + "nccl_communicator_config_path": null, + "distributed_timeout_minutes": 30, + "order": "tp-ep-dp-pp", + "create_gloo_process_groups": true, + "high_priority_stream_groups": null + }, + + // Note: The following parameters are NOT supported in DeepSpeed: + // - "context_parallel_size": must be 1 (default) + // - "hierarchical_context_parallel_sizes": not supported + "train_batch_size": 8, + ... + } + + Example usage: + # Basic usage from config file: + from deepspeed import DeepSpeedConfig + ds_config = DeepSpeedConfig("config.json") + initialize_parallel_state_from_config(ds_config) + + # Override specific parameters: + initialize_parallel_state_from_config( + ds_config, + tensor_model_parallel_size=4, # Override config value + expert_model_parallel_size=2 + ) + + # From config dictionary: + import json + with open("config.json") as f: + config_dict = json.load(f) + initialize_parallel_state_from_config(config_dict) + + # For named instances (RL scenarios): + initialize_parallel_state_from_config(ds_config, name="actor") + initialize_parallel_state_from_config( + ds_config, + name="critic", + tensor_model_parallel_size=2 # Override for critic + ) + """ + # Extract config dictionary + if hasattr(config, '_param_dict'): + # DeepSpeedConfig object + config_dict = config._param_dict + elif isinstance(config, dict): + # Already a dictionary + config_dict = config + else: + raise ValueError(f"config must be a DeepSpeedConfig object or a dict, got {type(config)}") + + # Check if parallelism config exists in config file + parallelism_config = config_dict.get(config_key, {}) + if parallelism_config and not isinstance(parallelism_config, dict): + raise ValueError(f"'{config_key}' in config must be a dictionary, got {type(parallelism_config)}") + + # Get the parallel state instance + ps = get_parallel_state_instance(name) + + # Check if already initialized + if ps.is_initialized(): + # Already initialized, skip + return + + # Import logging + import logging + logger = logging.getLogger(__name__) + + # Helper function to get value with proper priority handling + # Priority: config file (if explicitly set) > function parameter > default + def get_value(param_name, param_value, config_key, default_value): + """ + Get value with priority handling and conflict detection. + + Priority: + 1. If config file explicitly sets the value -> use config value (warn if param differs) + 2. If config file doesn't have the value -> use function parameter + 3. If both are None -> use default value + """ + config_has_key = config_key in parallelism_config + config_value = parallelism_config.get(config_key) + + # Case 1: Config file explicitly sets the value + if config_has_key: + # If function parameter is also provided and differs, warn and use config + if param_value is not None and param_value != config_value: + logger.warning(f"Parameter '{param_name}' conflict detected: " + f"config file specifies {config_value}, but function parameter is {param_value}. " + f"Using config file value ({config_value}). " + f"To override config, remove '{config_key}' from config file.") + return config_value + + # Case 2: Config file doesn't have the key, use function parameter if provided + if param_value is not None: + return param_value + + # Case 3: Neither config nor parameter provided, use default + return default_value + + # Extract parameters with proper priority: config (if set) > function param > default + init_kwargs = { + "tensor_model_parallel_size": + get_value("tensor_model_parallel_size", tensor_model_parallel_size, "tensor_model_parallel_size", 1), + "pipeline_model_parallel_size": + get_value("pipeline_model_parallel_size", pipeline_model_parallel_size, "pipeline_model_parallel_size", 1), + "virtual_pipeline_model_parallel_size": + get_value("virtual_pipeline_model_parallel_size", virtual_pipeline_model_parallel_size, + "virtual_pipeline_model_parallel_size", None), + "pipeline_model_parallel_comm_backend": + get_value("pipeline_model_parallel_comm_backend", pipeline_model_parallel_comm_backend, + "pipeline_model_parallel_comm_backend", None), + "context_parallel_size": + get_value("context_parallel_size", context_parallel_size, "context_parallel_size", 1), + "hierarchical_context_parallel_sizes": + get_value("hierarchical_context_parallel_sizes", hierarchical_context_parallel_sizes, + "hierarchical_context_parallel_sizes", None), + "expert_model_parallel_size": + get_value("expert_model_parallel_size", expert_model_parallel_size, "expert_model_parallel_size", 1), + "num_distributed_optimizer_instances": + get_value("num_distributed_optimizer_instances", num_distributed_optimizer_instances, + "num_distributed_optimizer_instances", 1), + "expert_tensor_parallel_size": + get_value("expert_tensor_parallel_size", expert_tensor_parallel_size, "expert_tensor_parallel_size", None), + "nccl_communicator_config_path": + get_value("nccl_communicator_config_path", nccl_communicator_config_path, "nccl_communicator_config_path", + None), + "distributed_timeout_minutes": + get_value("distributed_timeout_minutes", distributed_timeout_minutes, "distributed_timeout_minutes", 30), + "order": + get_value("order", order, "order", "tp-ep-dp-pp"), + "create_gloo_process_groups": + get_value("create_gloo_process_groups", create_gloo_process_groups, "create_gloo_process_groups", True), + "high_priority_stream_groups": + get_value("high_priority_stream_groups", high_priority_stream_groups, "high_priority_stream_groups", None), + } + + # Validate context_parallel_size + cp_size = init_kwargs["context_parallel_size"] + if cp_size != 1: + raise NotImplementedError( + f"DeepSpeed currently does not support context_parallel_size > 1. " + f"Got context_parallel_size={cp_size}. Please set context_parallel_size=1 in your config.") + + # Validate hierarchical_context_parallel_sizes + hcp_sizes = init_kwargs["hierarchical_context_parallel_sizes"] + if hcp_sizes is not None: + raise NotImplementedError( + f"DeepSpeed currently does not support hierarchical_context_parallel_sizes. " + f"Got hierarchical_context_parallel_sizes={hcp_sizes}. Please remove this configuration.") + + # Remove None values for optional parameters (except those that can be None) + # Keep None for: virtual_pipeline_model_parallel_size, pipeline_model_parallel_comm_backend, + # hierarchical_context_parallel_sizes, expert_tensor_parallel_size, nccl_communicator_config_path, + # high_priority_stream_groups + filtered_kwargs = {} + for key, value in init_kwargs.items(): + if value is not None or key in [ + "virtual_pipeline_model_parallel_size", "pipeline_model_parallel_comm_backend", + "hierarchical_context_parallel_sizes", "expert_tensor_parallel_size", "nccl_communicator_config_path", + "high_priority_stream_groups" + ]: + filtered_kwargs[key] = value + + # Initialize parallel state + ps.initialize_model_parallel(**filtered_kwargs) From 8ca894556f86dd8830e1e04fac101bc162b7ecbe Mon Sep 17 00:00:00 2001 From: yunqing Date: Thu, 15 Jan 2026 10:51:43 +0800 Subject: [PATCH 04/15] Add sequence parallel support to refactored parallel state - Extend RankGenerator to include SP dimension and enforce TP/PP/EP compatibility - Initialize sequence parallel and sequence+data parallel process groups in ParallelState.initialize_model_parallel - Add sequence-parallel accessor stubs in parallel_state_deepspeed for future unified SP interfaces Signed-off-by: Yuqing Li --- deepspeed/utils/parallel_state.py | 112 +++++++++++++++++++- deepspeed/utils/parallel_state_deepspeed.py | 81 +++++++++++++- 2 files changed, 189 insertions(+), 4 deletions(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index 495241daa523..0e4793e97a86 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -150,16 +150,32 @@ def decompose(index, shape, stride=None): class RankGenerator: """A class for generating rank groups for different modes of parallelism.""" - def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str, rank_offset: int = 0) -> None: + def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, sp: int, order: str, rank_offset: int = 0) -> None: assert (ep == 1 or cp == 1), "Both EP and CP > 1 is not allowed in one rank generator." + # Check SP compatibility: SP cannot be used with TP, PP, or EP + if sp > 1: + if tp > 1: + raise RuntimeError(f"Sequence Parallel (SP) cannot be used together with Tensor Parallel (TP). " + f"SP size: {sp}, TP size: {tp}. " + "Please set tp=1 when using SP.") + if pp > 1: + raise RuntimeError(f"Sequence Parallel (SP) cannot be used together with Pipeline Parallel (PP). " + f"SP size: {sp}, PP size: {pp}. " + "Please set pp=1 when using SP.") + if ep > 1: + raise RuntimeError(f"Sequence Parallel (SP) cannot be used together with Expert Parallel (EP). " + f"SP size: {sp}, EP size: {ep}. " + "Please set ep=1 when using SP.") + self.tp = tp self.ep = ep self.dp = dp self.pp = pp self.cp = cp + self.sp = sp self.rank_offset = rank_offset - self.world_size = tp * dp * pp * cp * ep + self.world_size = tp * dp * pp * cp * ep * sp self.name_to_size = { "tp": self.tp, @@ -167,6 +183,7 @@ def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str, rank "dp": self.dp, "ep": self.ep, "cp": self.cp, + "sp": self.sp, } self.order = order order = order.lower() @@ -231,6 +248,10 @@ def __init__(self): self.data_parallel_group_with_cp = None self.data_parallel_group_with_cp_gloo = None + # Sequence parallel groups + self.sequence_parallel_group = None + self.sequence_and_data_parallel_group = None + # Expert-related groups self.expert_model_parallel_group = None self.expert_tensor_parallel_group = None @@ -384,12 +405,13 @@ def initialize_model_parallel( virtual_pipeline_model_parallel_size: Optional[int] = None, pipeline_model_parallel_comm_backend: Optional[str] = None, context_parallel_size: int = 1, + sequence_parallel_size: int = 1, hierarchical_context_parallel_sizes: Optional[List[int]] = None, expert_model_parallel_size: int = 1, num_distributed_optimizer_instances: int = 1, expert_tensor_parallel_size: Optional[int] = None, distributed_timeout_minutes: int = 30, - order: str = "tp-cp-ep-dp-pp", + order: str = "tp-ep-dp-pp", get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, create_gloo_process_groups: bool = True, @@ -446,6 +468,7 @@ def default_position_embedding_ranks(pp_ranks): cp=context_parallel_size, order=order, rank_offset=0, + sp=1, ) # Build expert rank generator @@ -467,6 +490,7 @@ def default_position_embedding_ranks(pp_ranks): cp=1, order=order, rank_offset=0, + sp=1, ) timeout = timedelta(minutes=distributed_timeout_minutes) @@ -791,6 +815,48 @@ def default_position_embedding_ranks(pp_ranks): self.intra_distributed_optimizer_instance_group = intra_dist_opt_instance_group intra_dist_opt_ranks = [] + # Build sequence parallel groups + if sequence_parallel_size > 1: + assert self.sequence_parallel_group is None, "sequence parallel group is already initialized" + assert self.sequence_and_data_parallel_group is None, "sequence and data parallel group is already initialized" + + if world_size < sequence_parallel_size: + raise RuntimeError( + f"world_size ({world_size}) is less than sequence_parallel_size ({sequence_parallel_size})") + + if world_size % sequence_parallel_size != 0: + raise RuntimeError( + f"world_size ({world_size}) is not divisible by sequence_parallel_size ({sequence_parallel_size})") + + sp_data_parallel_size = world_size // sequence_parallel_size + sequence_and_data_parallel_size = sequence_parallel_size * sp_data_parallel_size + num_sequence_parallel_groups = world_size // sequence_parallel_size + num_sequence_and_data_parallel_groups = world_size // sequence_and_data_parallel_size + + # Build the sequence parallel groups + for i in range(num_sequence_parallel_groups): + ranks = list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)) + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("sp", pg_comm_cfgs), + group_desc="SEQUENCE_PARALLEL_GROUP", + ) + if rank in ranks: + self.sequence_parallel_group = group + + # Build the sequence and data parallel groups + for i in range(num_sequence_and_data_parallel_groups): + ranks = list(range(i * sequence_and_data_parallel_size, (i + 1) * sequence_and_data_parallel_size)) + group = self._create_group( + ranks, + timeout=timeout, + pg_options=self._get_pg_options("sp_dp", pg_comm_cfgs), + group_desc="SEQUENCE_AND_DATA_PARALLEL_GROUP", + ) + if rank in ranks: + self.sequence_and_data_parallel_group = group + # Initialize global memory buffer self._set_global_memory_buffer() @@ -837,6 +903,18 @@ def get_context_parallel_group(self, check_initialized=True): assert self.context_parallel_group is not None, "context parallel group is not initialized" return self.context_parallel_group + def get_sequence_parallel_group(self, check_initialized=True): + """Get the sequence-parallel group the caller rank belongs to.""" + if check_initialized: + assert self.sequence_parallel_group is not None, "sequence parallel group is not initialized" + return self.sequence_parallel_group + + def get_sequence_and_data_parallel_group(self, check_initialized=True): + """Get the sequence and data parallel group the caller rank belongs to.""" + if check_initialized: + assert self.sequence_and_data_parallel_group is not None, "sequence and data parallel group is not initialized" + return self.sequence_and_data_parallel_group + def get_embedding_group(self, check_initialized=True): """Get the embedding group the caller rank belongs to.""" if check_initialized: @@ -919,6 +997,34 @@ def get_context_parallel_rank(self): else: return 0 + def get_sequence_parallel_world_size(self): + """Return world size for the sequence parallel group.""" + if dist.is_available() and dist.is_initialized(): + if self.sequence_parallel_group is not None: + return self.get_sequence_parallel_group().size() + return 1 + + def get_sequence_parallel_rank(self): + """Return caller's rank in the sequence-parallel group.""" + if dist.is_available() and dist.is_initialized(): + if self.sequence_parallel_group is not None: + return self.get_sequence_parallel_group().rank() + return 0 + + def get_sequence_and_data_parallel_world_size(self): + """Return world size for the sequence and data parallel group.""" + if dist.is_available() and dist.is_initialized(): + if self.sequence_and_data_parallel_group is not None: + return self.get_sequence_and_data_parallel_group().size() + return 0 + + def get_sequence_and_data_parallel_rank(self): + """Return caller's rank in the sequence and data parallel group.""" + if dist.is_available() and dist.is_initialized(): + if self.sequence_and_data_parallel_group is not None: + return self.get_sequence_and_data_parallel_group().rank() + return 0 + def is_initialized(self): """Check if parallel state has been initialized""" return self.data_parallel_group is not None diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index 2d4cbf93915a..d0eb21acfa19 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -401,6 +401,77 @@ def get_context_parallel_rank(name: Optional[str] = None): return get_parallel_state(name).get_context_parallel_rank() +# ============================================================================ +# Sequence Parallel Functions +# ============================================================================ + + +def get_sequence_parallel_group(name: Optional[str] = None): + """Get the sequence-parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_parallel_group() + + +def get_sequence_parallel_world_size(name: Optional[str] = None): + """Return world size for the sequence parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_parallel_world_size() + + +def get_sequence_parallel_rank(name: Optional[str] = None): + """Return caller's rank in the sequence-parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_parallel_rank() + + +def get_sequence_and_data_parallel_group(name: Optional[str] = None): + """Get the sequence and data parallel group the caller rank belongs to. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_and_data_parallel_group() + + +def get_sequence_and_data_parallel_world_size(name: Optional[str] = None): + """Return world size for the sequence and data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_and_data_parallel_world_size() + + +def get_sequence_and_data_parallel_rank(name: Optional[str] = None): + """Return caller's rank in the sequence and data parallel group. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_sequence_and_data_parallel_rank() + + # ============================================================================ # Expert Parallel Functions # ============================================================================ @@ -638,12 +709,20 @@ def initialize_parallel_state_from_config( "distributed_timeout_minutes": 30, "order": "tp-ep-dp-pp", "create_gloo_process_groups": true, - "high_priority_stream_groups": null + "high_priority_stream_groups": null, + "sequence_parallel_size": 1 }, // Note: The following parameters are NOT supported in DeepSpeed: // - "context_parallel_size": must be 1 (default) // - "hierarchical_context_parallel_sizes": not supported + + // Sequence Parallel (SP) usage notes: + // - SP cannot be used together with TP, PP, or EP + // - When using SP, set tp=1, pp=1, ep=1 + // - Example SP config: {"sequence_parallel_size": 4, "order": "sp-dp"} + // - SP can be combined with DP: {"sequence_parallel_size": 4, "data_parallel_size": 2, "order": "sp-dp"} + "train_batch_size": 8, ... } From d0c6b699a0f17fe0f79b30289a48c25b1eb302d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Mon, 19 Jan 2026 10:02:23 +0800 Subject: [PATCH 05/15] fix: remove Chinese comment from config example Remove Chinese inline comment from the example config.json docstring to comply with DeepSpeed community coding standards. This ensures all comments and documentation are in English only. Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state_deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index d0eb21acfa19..a3e46a062519 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -703,7 +703,7 @@ def initialize_parallel_state_from_config( "expert_model_parallel_size": 1, "expert_tensor_parallel_size": 1, "virtual_pipeline_model_parallel_size": null, - "pipeline_model_parallel_comm_backend": null, ##不要加入config中,保留加载逻辑 + "pipeline_model_parallel_comm_backend": null, "num_distributed_optimizer_instances": 1, "nccl_communicator_config_path": null, "distributed_timeout_minutes": 30, From 5a51d021b5f2d2457c1fead6bb893a2abf5d6dad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Mon, 19 Jan 2026 15:08:35 +0800 Subject: [PATCH 06/15] fix: use torch.distributed.new_group directly in _create_group The deepspeed.comm.new_group() wrapper only accepts 'ranks' parameter, but _create_group() needs to pass additional parameters like timeout, backend, pg_options, etc. to support advanced process group configuration. This fix uses torch.distributed.new_group() directly to support all parameters while still using deepspeed.comm for other operations. Fixes TypeError: new_group() got an unexpected keyword argument 'timeout' Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index 0e4793e97a86..12d711f4e64f 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -326,6 +326,9 @@ def _create_group( group_desc=None, ): """Creates a ProcessGroup.""" + # Use torch.distributed directly for new_group as deepspeed.comm.new_group only supports ranks parameter + import torch.distributed as torch_dist + kwargs = { "ranks": ranks, "timeout": timeout, @@ -339,7 +342,7 @@ def _create_group( if timeout is None: kwargs.pop("timeout") - group = dist.new_group(**kwargs) + group = torch_dist.new_group(**kwargs) if self.global_process_group_list is None: self.global_process_group_list = [None] if dist.get_rank() in ranks: From 0093da3cc8685b2a39e97c8f527941d68c75b949 Mon Sep 17 00:00:00 2001 From: yunqing Date: Mon, 19 Jan 2026 16:05:00 +0800 Subject: [PATCH 07/15] fix: correct SP parallel group creation logic in parallel_state - Include sequence_parallel_size in model_size calculation - Fix SP group count: num_sequence_parallel_groups = data_parallel_size - Use consecutive rank grouping for SP (not RankGenerator) - SP uses different parallelism model than TP/PP/CP/EP Signed-off-by: Yuqing Li --- deepspeed/utils/parallel_state.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index 12d711f4e64f..c2757587f3b8 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -445,7 +445,7 @@ def default_position_embedding_ranks(pp_ranks): world_size: int = dist.get_world_size() rank = dist.get_rank() - model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size + model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size * sequence_parallel_size if world_size % model_size != 0: raise RuntimeError(f"world_size ({world_size}) is not divisible by {model_size}") @@ -471,7 +471,7 @@ def default_position_embedding_ranks(pp_ranks): cp=context_parallel_size, order=order, rank_offset=0, - sp=1, + sp=sequence_parallel_size, ) # Build expert rank generator @@ -831,12 +831,14 @@ def default_position_embedding_ranks(pp_ranks): raise RuntimeError( f"world_size ({world_size}) is not divisible by sequence_parallel_size ({sequence_parallel_size})") - sp_data_parallel_size = world_size // sequence_parallel_size - sequence_and_data_parallel_size = sequence_parallel_size * sp_data_parallel_size - num_sequence_parallel_groups = world_size // sequence_parallel_size - num_sequence_and_data_parallel_groups = world_size // sequence_and_data_parallel_size + # SP groups use consecutive ranks + # Number of SP groups = data_parallel_size (each DP rank has its own SP group) + num_sequence_parallel_groups = data_parallel_size + sequence_and_data_parallel_size = world_size + num_sequence_and_data_parallel_groups = 1 - # Build the sequence parallel groups + # Build the sequence parallel groups using consecutive ranks + # SP uses consecutive rank grouping, not orthogonal grouping like TP/PP/CP for i in range(num_sequence_parallel_groups): ranks = list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)) group = self._create_group( From e5b52686dd9f9613cccd675466fbc38d797e6c09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Wed, 21 Jan 2026 10:59:01 +0800 Subject: [PATCH 08/15] refactor: simplify _create_group to use deepspeed.comm interface Updated _create_group() to use deepspeed.comm.new_group() which currently only supports 'ranks' parameter. Other parameters (timeout, backend, pg_options, etc.) are commented out and documented in TODO comments. For non-nccl backends, the function returns None with a warning, as these are not yet supported by the deepspeed.comm interface. These parameters will be enabled once DeepSpeed's comm interface is enhanced to support them. Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state.py | 35 ++++++++++++++++++------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index c2757587f3b8..8033ebf72a68 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -29,6 +29,8 @@ from deepspeed.accelerator import get_accelerator import deepspeed.comm as dist +from deepspeed.utils.torch import required_torch_version + logger = logging.getLogger(__name__) @@ -324,25 +326,30 @@ def _create_group( pg_options=None, use_local_synchronization=False, group_desc=None, - ): + ): """Creates a ProcessGroup.""" - # Use torch.distributed directly for new_group as deepspeed.comm.new_group only supports ranks parameter - import torch.distributed as torch_dist - + if backend is not None and backend != "nccl": + logger.warning(f"{backend} backend is not supported for new_group. Using torch.distributed directly.") + return None + + # TODO: Currently using deepspeed.comm.new_group() which only supports 'ranks' parameter. + # The following parameters are commented out and will be enabled once DeepSpeed's + # comm interface supports them: + # - timeout: Timeout for process group operations + # - backend: Communication backend (e.g., 'nccl', 'gloo') + # - pg_options: Process group options + # - use_local_synchronization: Enable local synchronization + # - group_desc: Group description for debugging (requires PyTorch >= 2.4) kwargs = { "ranks": ranks, - "timeout": timeout, - "backend": backend, - "pg_options": pg_options, - "use_local_synchronization": use_local_synchronization, - "group_desc": group_desc, + # "timeout": timeout, + # "backend": backend, + # "pg_options": pg_options, + # "use_local_synchronization": use_local_synchronization, + # "group_desc": group_desc, } - if not is_torch_min_version("2.4.0"): - kwargs.pop("group_desc") - if timeout is None: - kwargs.pop("timeout") - group = torch_dist.new_group(**kwargs) + group = dist.new_group(**kwargs) if self.global_process_group_list is None: self.global_process_group_list = [None] if dist.get_rank() in ranks: From 9aad9538b75d59d045a22670b617efebdcca4c2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Wed, 21 Jan 2026 11:58:15 +0800 Subject: [PATCH 09/15] feat: migrate All-to-All groups to parallel_state architecture Migrate _get_local_all_to_all_group functionality from groups.py to the new parallel_state architecture to support ZeRO++ quantized gradients. Changes in parallel_state.py: - Add all_to_all_groups and all_to_all_initialized to ParallelState class - Implement initialize_all_to_all_groups() method to create local and global All-to-All groups based on node topology - Implement get_all_to_all_groups() method to retrieve initialized groups Changes in parallel_state_deepspeed.py: - Add initialize_all_to_all_groups() wrapper function - Add get_all_to_all_groups() wrapper function - Add _get_local_all_to_all_group() for backward compatibility with groups.py Benefits: - Supports multi-instance scenarios (e.g., RL with actor/critic models) - Consistent with the new parallel_state architecture - Maintains backward compatibility with existing groups.py interface - Enables future config-based initialization of All-to-All groups Note: This does not remove the implementation from groups.py yet to maintain backward compatibility during the transition period. Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state.py | 68 ++++++++++++++++++++- deepspeed/utils/parallel_state_deepspeed.py | 64 +++++++++++++++++++ 2 files changed, 131 insertions(+), 1 deletion(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index 8033ebf72a68..a8ef43091096 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -265,6 +265,10 @@ def __init__(self): self.intra_partial_expert_data_parallel_group_gloo = None self.inter_partial_expert_data_parallel_group = None + # All-to-All groups for ZeRO++ quantized gradients + self.all_to_all_groups = {} + self.all_to_all_initialized = False + # Global ranks lists self.embedding_global_ranks = None self.position_embedding_global_ranks = None @@ -326,7 +330,7 @@ def _create_group( pg_options=None, use_local_synchronization=False, group_desc=None, - ): + ): """Creates a ProcessGroup.""" if backend is not None and backend != "nccl": logger.warning(f"{backend} backend is not supported for new_group. Using torch.distributed directly.") @@ -1041,6 +1045,68 @@ def is_initialized(self): """Check if parallel state has been initialized""" return self.data_parallel_group is not None + def initialize_all_to_all_groups(self): + """Initialize All-to-All groups for quantized gradient communication. + + Creates local and global All-to-All groups based on node topology: + - Local groups: intra-node communication (NVLink/NVSwitch) + - Global groups: inter-node communication (cross-node) + + Used by ZeRO++ when zero_quantized_gradients is enabled. + + Returns: + Dictionary of All-to-All groups + """ + if self.all_to_all_initialized: + return self.all_to_all_groups + + assert dist.is_initialized(), 'dist is not initialized' + + device_per_node = get_accelerator().device_count() + world_size = dist.get_world_size() + num_nodes = world_size // device_per_node + + if num_nodes == 0 and world_size > 0: + # Single incomplete node + assert world_size >= 1, 'num_gpus must >=1, cannot initialize All-To-All' + ranks = list(range(world_size)) + self.all_to_all_groups['local_0'] = self._create_group(ranks) + + elif num_nodes == 1: + # Exactly one node + assert world_size == device_per_node, 'num_gpus not equal to device per node, cannot initialize All-To-All' + ranks = list(range(device_per_node)) + self.all_to_all_groups['local_0'] = self._create_group(ranks) + + else: + # Multiple nodes: create both local and global groups + assert world_size > device_per_node, 'num_nodes<2 cannot initialize All-To-All' + + # Local groups (intra-node) + for node_id in range(num_nodes): + local_ranks = [j + device_per_node * node_id for j in range(device_per_node)] + self.all_to_all_groups[f"local_{node_id}"] = self._create_group(local_ranks) + + # Global groups (inter-node) + for device_id in range(device_per_node): + global_ranks = [device_id + j * device_per_node for j in range(num_nodes)] + self.all_to_all_groups[f"global_{device_id}"] = self._create_group(global_ranks) + + self.all_to_all_initialized = True + return self.all_to_all_groups + + def get_all_to_all_groups(self): + """Get All-to-All groups dictionary. + + Initializes the groups if not already initialized. + + Returns: + Dictionary of All-to-All groups + """ + if not self.all_to_all_initialized: + self.initialize_all_to_all_groups() + return self.all_to_all_groups + def get_global_memory_buffer(self): """Return the global GlobalMemoryBuffer object""" assert self.global_memory_buffer is not None, "global memory buffer is not initialized" diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index a3e46a062519..3d65a8b70e9b 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -632,6 +632,70 @@ def is_initialized(name: Optional[str] = None): return get_parallel_state(name).is_initialized() +# ============================================================================ +# All-to-All Groups for ZeRO++ Quantized Gradients +# ============================================================================ + + +def initialize_all_to_all_groups(name: Optional[str] = None): + """Initialize All-to-All groups for quantized gradient communication. + + Creates local and global All-to-All groups based on node topology. + Used by ZeRO++ when zero_quantized_gradients is enabled. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + Returns: + Dictionary of All-to-All groups + + Example: + # Initialize for default instance + all_to_all_groups = initialize_all_to_all_groups() + + # Initialize for named instance (RL scenario) + actor_groups = initialize_all_to_all_groups("actor") + critic_groups = initialize_all_to_all_groups("critic") + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).initialize_all_to_all_groups() + + +def get_all_to_all_groups(name: Optional[str] = None): + """Get All-to-All groups dictionary. + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + Returns: + Dictionary of All-to-All groups + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_all_to_all_groups() + + +def _get_local_all_to_all_group(name: Optional[str] = None): + """Get All-to-All groups for current rank (backward compatible with groups.py). + + This function provides backward compatibility with the groups.py interface. + It returns all All-to-All groups (both local and global). + + Args: + name: Optional name of the parallel state instance. If None, uses current active instance. + + Returns: + Dictionary of All-to-All groups + + Note: + This is a compatibility wrapper. New code should use get_all_to_all_groups() instead. + + DeepSpeed-compatible interface. + """ + return get_parallel_state(name).get_all_to_all_groups() + + # ============================================================================ # Configuration-based Initialization # ============================================================================ From 4b75a1a8aa011f0d4a8000c9a61dd3204f7239aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Fri, 23 Jan 2026 15:21:48 +0800 Subject: [PATCH 10/15] fix: disable gloo process groups by default DeepSpeed's comm interface does not support gloo backend, so set create_gloo_process_groups default to False. Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index a8ef43091096..106155f3c846 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -428,7 +428,7 @@ def initialize_model_parallel( order: str = "tp-ep-dp-pp", get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None, - create_gloo_process_groups: bool = True, + create_gloo_process_groups: bool = False, ) -> None: """Initialize model data parallel groups. From a187ebe7a83bfaa0d7922308d59063c3e6502f53 Mon Sep 17 00:00:00 2001 From: yunqing Date: Fri, 23 Jan 2026 15:31:36 +0800 Subject: [PATCH 11/15] refactor: simplify SP group creation using RankGenerator - Replace manual consecutive rank grouping with RankGenerator.get_ranks('sp') - Remove redundant world_size validation logic (handled by RankGenerator) - Reduce SP group creation code from 41 lines to 26 lines - Maintain same SP group topology: consecutive ranks [0,1], [2,3] for sp_size=2 - Fix code style issues: remove unused import, update warning message This change unifies process group creation by leveraging RankGenerator's orthogonal parallelism algorithm, which naturally produces consecutive rank grouping when order='sp-dp'. Signed-off-by: Yuqing Li --- deepspeed/utils/parallel_state.py | 31 ++++++------------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index 106155f3c846..52719a807986 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -29,8 +29,6 @@ from deepspeed.accelerator import get_accelerator import deepspeed.comm as dist -from deepspeed.utils.torch import required_torch_version - logger = logging.getLogger(__name__) @@ -333,7 +331,7 @@ def _create_group( ): """Creates a ProcessGroup.""" if backend is not None and backend != "nccl": - logger.warning(f"{backend} backend is not supported for new_group. Using torch.distributed directly.") + logger.warning(f"{backend} backend is not supported for new_group. Using deepspeed.comm directly.") return None # TODO: Currently using deepspeed.comm.new_group() which only supports 'ranks' parameter. @@ -829,29 +827,13 @@ def default_position_embedding_ranks(pp_ranks): self.intra_distributed_optimizer_instance_group = intra_dist_opt_instance_group intra_dist_opt_ranks = [] - # Build sequence parallel groups + # Build sequence parallel groups using RankGenerator if sequence_parallel_size > 1: assert self.sequence_parallel_group is None, "sequence parallel group is already initialized" assert self.sequence_and_data_parallel_group is None, "sequence and data parallel group is already initialized" - if world_size < sequence_parallel_size: - raise RuntimeError( - f"world_size ({world_size}) is less than sequence_parallel_size ({sequence_parallel_size})") - - if world_size % sequence_parallel_size != 0: - raise RuntimeError( - f"world_size ({world_size}) is not divisible by sequence_parallel_size ({sequence_parallel_size})") - - # SP groups use consecutive ranks - # Number of SP groups = data_parallel_size (each DP rank has its own SP group) - num_sequence_parallel_groups = data_parallel_size - sequence_and_data_parallel_size = world_size - num_sequence_and_data_parallel_groups = 1 - - # Build the sequence parallel groups using consecutive ranks - # SP uses consecutive rank grouping, not orthogonal grouping like TP/PP/CP - for i in range(num_sequence_parallel_groups): - ranks = list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)) + # Build SP groups using RankGenerator + for ranks in self.decoder_rank_generator.get_ranks('sp'): group = self._create_group( ranks, timeout=timeout, @@ -861,9 +843,8 @@ def default_position_embedding_ranks(pp_ranks): if rank in ranks: self.sequence_parallel_group = group - # Build the sequence and data parallel groups - for i in range(num_sequence_and_data_parallel_groups): - ranks = list(range(i * sequence_and_data_parallel_size, (i + 1) * sequence_and_data_parallel_size)) + # Build SP+DP combined groups using RankGenerator + for ranks in self.decoder_rank_generator.get_ranks('sp-dp'): group = self._create_group( ranks, timeout=timeout, From af827a0dc164e8e7371b902b64be126c72d47e8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Mon, 26 Jan 2026 09:48:36 +0800 Subject: [PATCH 12/15] docs: fix config example and SP usage notes 1. Change create_gloo_process_groups from true to false - Aligns with default value change in previous commit - DeepSpeed comm interface does not support gloo backend 2. Correct Sequence Parallel usage description - SP is included in model_size calculation (tp * pp * cp * sp) - SP can be used together with TP/PP/EP - Number of SP groups equals data_parallel_size - SP uses consecutive rank grouping (not orthogonal like TP/PP/CP/EP) Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state_deepspeed.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index 3d65a8b70e9b..23bbe61b0b6d 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -772,7 +772,7 @@ def initialize_parallel_state_from_config( "nccl_communicator_config_path": null, "distributed_timeout_minutes": 30, "order": "tp-ep-dp-pp", - "create_gloo_process_groups": true, + "create_gloo_process_groups": false, "high_priority_stream_groups": null, "sequence_parallel_size": 1 }, @@ -782,10 +782,10 @@ def initialize_parallel_state_from_config( // - "hierarchical_context_parallel_sizes": not supported // Sequence Parallel (SP) usage notes: - // - SP cannot be used together with TP, PP, or EP - // - When using SP, set tp=1, pp=1, ep=1 - // - Example SP config: {"sequence_parallel_size": 4, "order": "sp-dp"} - // - SP can be combined with DP: {"sequence_parallel_size": 4, "data_parallel_size": 2, "order": "sp-dp"} + // - SP dimension is included in model_size calculation: model_size = tp * pp * cp * sp + // - Number of SP groups = data_parallel_size (each DP rank has its own SP group) + // - SP uses consecutive rank grouping, different from TP/PP/CP/EP orthogonal grouping + // - Example: world_size=16, tp=2, sp=2, pp=1, ep=1 => dp=4, and 4 SP groups "train_batch_size": 8, ... From addf430a542ff3b0f02737a42199ef20dbd5288e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Mon, 26 Jan 2026 09:59:14 +0800 Subject: [PATCH 13/15] refactor: remove unused is_torch_min_version function Remove is_torch_min_version function that is never called in the codebase. This reduces code complexity and removes unnecessary dependencies. Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/deepspeed/utils/parallel_state.py b/deepspeed/utils/parallel_state.py index 52719a807986..93bf297d0ba0 100644 --- a/deepspeed/utils/parallel_state.py +++ b/deepspeed/utils/parallel_state.py @@ -39,27 +39,6 @@ HAVE_EINOPS = False -def is_torch_min_version(version: str, check_equality: bool = True) -> bool: - """Check if PyTorch version meets minimum requirement. - - Args: - version: Version string to check (e.g., "2.4.0") - check_equality: If True, also check for equality - - Returns: - True if version requirement is met - """ - try: - from packaging.version import Version as PkgVersion - torch_version = PkgVersion(torch.__version__) - required_version = PkgVersion(version) - if check_equality: - return torch_version >= required_version - return torch_version > required_version - except Exception: - return False - - class GlobalMemoryBuffer: """Global buffer to avoid dynamic memory allocations.""" From d0ba240192c0c136100b9cdab3de29b8d8d5a423 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A8=80=E6=9E=A2?= Date: Mon, 26 Jan 2026 14:52:14 +0800 Subject: [PATCH 14/15] refactor: simplify config-based initialization to use top-level fields Remove the 'parallelism' config block concept and directly read from existing DeepSpeed config fields. This avoids adding new top-level config structure which requires changes throughout the codebase. Changes: - Remove 'parallelism' nested config block from examples - Read 'sequence_parallel_size' directly from top-level config - Change priority: function params > config values > defaults (was: config > params > defaults) - Update create_gloo_process_groups default from True to False - Simplify documentation to reflect current implementation This makes the config-based initialization fully backward compatible without requiring any new config schema validation or parsing logic. Signed-off-by: Jikang Mo --- deepspeed/utils/parallel_state_deepspeed.py | 99 +++++++-------------- 1 file changed, 32 insertions(+), 67 deletions(-) diff --git a/deepspeed/utils/parallel_state_deepspeed.py b/deepspeed/utils/parallel_state_deepspeed.py index 23bbe61b0b6d..0a2d71b112c6 100644 --- a/deepspeed/utils/parallel_state_deepspeed.py +++ b/deepspeed/utils/parallel_state_deepspeed.py @@ -704,7 +704,6 @@ def _get_local_all_to_all_group(name: Optional[str] = None): def initialize_parallel_state_from_config( config: Union[Dict[str, Any], Any], name: Optional[str] = None, - config_key: str = "parallelism", # Optional parameters to override config values tensor_model_parallel_size: Optional[int] = None, pipeline_model_parallel_size: Optional[int] = None, @@ -715,6 +714,7 @@ def initialize_parallel_state_from_config( expert_model_parallel_size: Optional[int] = None, num_distributed_optimizer_instances: Optional[int] = None, expert_tensor_parallel_size: Optional[int] = None, + sequence_parallel_size: Optional[int] = None, nccl_communicator_config_path: Optional[str] = None, distributed_timeout_minutes: Optional[int] = None, order: Optional[str] = None, @@ -724,15 +724,10 @@ def initialize_parallel_state_from_config( """Initialize parallel state from DeepSpeed config.json with optional parameter overrides. This function reads parallelism configuration from the DeepSpeed config file - and automatically initializes the ParallelState instance. This allows users - to configure all parallelism dimensions in a single place (config.json) - rather than having to read documentation and manually call initialize_model_parallel. + (top-level fields) and automatically initializes the ParallelState instance. + This allows code to work with both explicit initialization and config-based initialization. - Configuration priority: config file (if explicitly set) > function parameters > default values - - Note: If a value is explicitly set in config file, it takes precedence over function - parameters. A warning will be logged if there's a conflict. To override config file - values, remove them from the config file first. + Configuration priority: function parameters > config file values > default values (1) Args: config: Either a DeepSpeedConfig object or a config dictionary. @@ -740,8 +735,6 @@ def initialize_parallel_state_from_config( If dict, will use it directly. name: Optional name of the parallel state instance to initialize. If None, initializes the default global instance. - config_key: Key in the config dictionary where parallelism config is stored. - Default is "parallelism". # Parallelism dimension parameters (override config if provided): tensor_model_parallel_size: Size of tensor model parallel group. Default: 1 @@ -753,44 +746,27 @@ def initialize_parallel_state_from_config( expert_model_parallel_size: Size of expert model parallel group. Default: 1 num_distributed_optimizer_instances: Number of distributed optimizer instances. Default: 1 expert_tensor_parallel_size: Size of expert tensor parallel group. Default: None + sequence_parallel_size: Size of sequence parallel group. Default: 1 nccl_communicator_config_path: Path to NCCL communicator config. Default: None distributed_timeout_minutes: Timeout for distributed operations. Default: 30 order: Order of parallelism dimensions. Default: "tp-ep-dp-pp" - create_gloo_process_groups: Whether to create Gloo process groups. Default: True + create_gloo_process_groups: Whether to create Gloo process groups. Default: False high_priority_stream_groups: High priority stream groups. Default: None - Example config.json: + Example config.json (using existing DeepSpeed config fields): { - "parallelism": { - "tensor_model_parallel_size": 2, - "pipeline_model_parallel_size": 1, - "expert_model_parallel_size": 1, - "expert_tensor_parallel_size": 1, - "virtual_pipeline_model_parallel_size": null, - "pipeline_model_parallel_comm_backend": null, - "num_distributed_optimizer_instances": 1, - "nccl_communicator_config_path": null, - "distributed_timeout_minutes": 30, - "order": "tp-ep-dp-pp", - "create_gloo_process_groups": false, - "high_priority_stream_groups": null, - "sequence_parallel_size": 1 - }, - - // Note: The following parameters are NOT supported in DeepSpeed: - // - "context_parallel_size": must be 1 (default) - // - "hierarchical_context_parallel_sizes": not supported - - // Sequence Parallel (SP) usage notes: - // - SP dimension is included in model_size calculation: model_size = tp * pp * cp * sp - // - Number of SP groups = data_parallel_size (each DP rank has its own SP group) - // - SP uses consecutive rank grouping, different from TP/PP/CP/EP orthogonal grouping - // - Example: world_size=16, tp=2, sp=2, pp=1, ep=1 => dp=4, and 4 SP groups - "train_batch_size": 8, - ... + "sequence_parallel_size": 1, + "zero_optimization": { + "stage": 1 + } } + Note: + - Currently only "sequence_parallel_size" can be read from config (existing field) + - Other parallelism parameters must be passed via function parameters or use defaults + - Context Parallel is NOT supported (cp must be 1) + Example usage: # Basic usage from config file: from deepspeed import DeepSpeedConfig @@ -828,11 +804,6 @@ def initialize_parallel_state_from_config( else: raise ValueError(f"config must be a DeepSpeedConfig object or a dict, got {type(config)}") - # Check if parallelism config exists in config file - parallelism_config = config_dict.get(config_key, {}) - if parallelism_config and not isinstance(parallelism_config, dict): - raise ValueError(f"'{config_key}' in config must be a dictionary, got {type(parallelism_config)}") - # Get the parallel state instance ps = get_parallel_state_instance(name) @@ -846,37 +817,29 @@ def initialize_parallel_state_from_config( logger = logging.getLogger(__name__) # Helper function to get value with proper priority handling - # Priority: config file (if explicitly set) > function parameter > default + # Priority: function parameter > config file value > default value def get_value(param_name, param_value, config_key, default_value): """ - Get value with priority handling and conflict detection. + Get value with priority handling. Priority: - 1. If config file explicitly sets the value -> use config value (warn if param differs) - 2. If config file doesn't have the value -> use function parameter - 3. If both are None -> use default value + 1. If function parameter is provided -> use parameter value + 2. If config file has the value -> use config value + 3. Otherwise -> use default value """ - config_has_key = config_key in parallelism_config - config_value = parallelism_config.get(config_key) - - # Case 1: Config file explicitly sets the value - if config_has_key: - # If function parameter is also provided and differs, warn and use config - if param_value is not None and param_value != config_value: - logger.warning(f"Parameter '{param_name}' conflict detected: " - f"config file specifies {config_value}, but function parameter is {param_value}. " - f"Using config file value ({config_value}). " - f"To override config, remove '{config_key}' from config file.") - return config_value - - # Case 2: Config file doesn't have the key, use function parameter if provided + # Case 1: Function parameter provided if param_value is not None: return param_value - # Case 3: Neither config nor parameter provided, use default + # Case 2: Config file has the key + if config_key in config_dict: + config_value = config_dict[config_key] + return config_value + + # Case 3: Use default return default_value - # Extract parameters with proper priority: config (if set) > function param > default + # Extract parameters with proper priority: function param > config value > default init_kwargs = { "tensor_model_parallel_size": get_value("tensor_model_parallel_size", tensor_model_parallel_size, "tensor_model_parallel_size", 1), @@ -890,6 +853,8 @@ def get_value(param_name, param_value, config_key, default_value): "pipeline_model_parallel_comm_backend", None), "context_parallel_size": get_value("context_parallel_size", context_parallel_size, "context_parallel_size", 1), + "sequence_parallel_size": + get_value("sequence_parallel_size", sequence_parallel_size, "sequence_parallel_size", 1), "hierarchical_context_parallel_sizes": get_value("hierarchical_context_parallel_sizes", hierarchical_context_parallel_sizes, "hierarchical_context_parallel_sizes", None), @@ -908,7 +873,7 @@ def get_value(param_name, param_value, config_key, default_value): "order": get_value("order", order, "order", "tp-ep-dp-pp"), "create_gloo_process_groups": - get_value("create_gloo_process_groups", create_gloo_process_groups, "create_gloo_process_groups", True), + get_value("create_gloo_process_groups", create_gloo_process_groups, "create_gloo_process_groups", False), "high_priority_stream_groups": get_value("high_priority_stream_groups", high_priority_stream_groups, "high_priority_stream_groups", None), } From fa341163d934318d310f63a92c905b787c82cfcd Mon Sep 17 00:00:00 2001 From: Junjie Mao Date: Mon, 26 Jan 2026 18:08:41 +0800 Subject: [PATCH 15/15] tests: Drop test_mpu.py from the PR The test_mpu.py script is used to verify the equivalence between existing process group management facilities and the proposed, unified ParallelState. It is meant to be an temporary helper and will not be useful after we switch existing implementations to the new interfaces. Thus remove it from the current PR. The test is still available at https://gist.github.com/eternalNight/b76c72216b4be84832b615b76465396f. Signed-off-by: Junjie Mao --- tests/unit/utils/test_mpu.py | 1692 ---------------------------------- 1 file changed, 1692 deletions(-) delete mode 100644 tests/unit/utils/test_mpu.py diff --git a/tests/unit/utils/test_mpu.py b/tests/unit/utils/test_mpu.py deleted file mode 100644 index 11ed585c92b3..000000000000 --- a/tests/unit/utils/test_mpu.py +++ /dev/null @@ -1,1692 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) DeepSpeed Team - -# DeepSpeed Team -""" -Automated testing of parallel strategy combinations using random configurations. - -This test automatically generates random parallel configurations and tests -both parallel_state_refactored and DeepSpeed to see if they produce compatible results. -""" - -import pytest -import random -from typing import Dict, List, Tuple, Optional -from collections import defaultdict - -# Try to import both libraries -try: - from deepspeed.utils.parallel_state import RankGenerator - PARALLEL_STATE_AVAILABLE = True -except ImportError as e: - PARALLEL_STATE_AVAILABLE = False - print(f"Warning: Could not import Megatron parallel_state_refactored: {e}") - -try: - from deepspeed.utils import groups as ds_groups - from deepspeed.runtime.sequence_parallel import parallel_state_sp as ds_sp - DEEPSPEED_AVAILABLE = True -except ImportError as e: - DEEPSPEED_AVAILABLE = False - print(f"Warning: Could not import DeepSpeed: {e}") - - -class ParallelConfigGenerator: - """Generate random parallel configurations for testing.""" - - def __init__(self, seed=None): - if seed is not None: - random.seed(seed) - self.tested_configs = [] - self.failed_configs = [] - - def generate_random_config(self, max_size=1024, min_parallel_size=1, max_parallel_size=32): - """Generate a random parallel configuration. - - Args: - max_size: Maximum world size to consider - min_parallel_size: Minimum parallel size for each dimension - max_parallel_size: Maximum parallel size for each dimension - - Returns: - Dict with tp, dp, pp, cp, ep values and order - """ - # Generate random sizes for each dimension - # Don't filter invalid configurations - we want to test and report all cases - tp = random.randint(min_parallel_size, max_parallel_size) - dp = random.randint(min_parallel_size, max_parallel_size) - pp = random.randint(min_parallel_size, max_parallel_size) - cp = random.randint(min_parallel_size, max_parallel_size) - ep = random.randint(min_parallel_size, max_parallel_size) - - # Calculate world size - world_size = tp * dp * pp * cp * ep - - # If world size is too large, scale down proportionally - # But try to keep at least one dimension > 1 - if world_size > max_size: - # Scale down proportionally - scale_factor = (max_size / world_size)**0.25 - tp = max(1, int(tp * scale_factor)) - dp = max(1, int(dp * scale_factor)) - pp = max(1, int(pp * scale_factor)) - cp = max(1, int(cp * scale_factor)) - ep = max(1, int(ep * scale_factor)) - world_size = tp * dp * pp * cp * ep - - # Ensure at least one dimension is > 1 - if world_size == 1: - tp = 2 - world_size = 2 - - # Generate random order (but must include all non-1 dimensions) - dimensions = [] - if tp > 1: - dimensions.append('tp') - if dp > 1: - dimensions.append('dp') - if pp > 1: - dimensions.append('pp') - if cp > 1: - dimensions.append('cp') - if ep > 1: - dimensions.append('ep') - - # Shuffle to get random order - random.shuffle(dimensions) - order = '-'.join(dimensions) if dimensions else 'tp' - - # If no dimensions > 1, use default - if not dimensions: - order = 'tp-dp' - tp = 2 - dp = 2 - - config = { - "tp": tp, - "dp": dp, - "pp": pp, - "cp": cp, - "ep": ep, - "order": order, - "world_size": tp * dp * pp * cp * ep, - } - - return config - - def generate_systematic_configs(self, max_world_size=512): - """Generate systematic configurations covering common cases. - - Args: - max_world_size: Maximum world size to consider - - Returns: - List of configurations - """ - configs = [] - - # Single parallelism - test larger sizes - for size in [2, 4, 8, 16, 32, 64, 128, 256]: - if size <= max_world_size: - configs.append({"tp": size, "dp": 1, "pp": 1, "cp": 1, "ep": 1, "order": "tp", "world_size": size}) - configs.append({"tp": 1, "dp": size, "pp": 1, "cp": 1, "ep": 1, "order": "dp", "world_size": size}) - configs.append({"tp": 1, "dp": 1, "pp": size, "cp": 1, "ep": 1, "order": "pp", "world_size": size}) - - # Two-way combinations - more variations - for tp, dp in [(2, 2), (2, 4), (4, 2), (2, 8), (8, 2), (4, 4), (2, 16), (16, 2), (4, 8), (8, 4)]: - if tp * dp <= max_world_size: - configs.append({ - "tp": tp, - "dp": dp, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "tp-dp", - "world_size": tp * dp - }) - configs.append({ - "tp": tp, - "dp": dp, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "dp-tp", - "world_size": tp * dp - }) - - for tp, pp in [(2, 2), (2, 4), (4, 2), (2, 8), (8, 2), (4, 4)]: - if tp * pp <= max_world_size: - configs.append({ - "tp": tp, - "dp": 1, - "pp": pp, - "cp": 1, - "ep": 1, - "order": "tp-pp", - "world_size": tp * pp - }) - - for tp, cp in [(2, 2), (2, 4), (4, 2), (2, 8)]: - if tp * cp <= max_world_size: - configs.append({ - "tp": tp, - "dp": 1, - "pp": 1, - "cp": cp, - "ep": 1, - "order": "tp-cp", - "world_size": tp * cp - }) - - for tp, ep in [(2, 2), (2, 4), (4, 2), (2, 8)]: - if tp * ep <= max_world_size: - configs.append({ - "tp": tp, - "dp": 1, - "pp": 1, - "cp": 1, - "ep": ep, - "order": "tp-ep", - "world_size": tp * ep - }) - - # Three-way combinations - more variations - for tp, pp, dp in [(2, 2, 2), (2, 2, 4), (2, 4, 2), (4, 2, 2), (2, 2, 8), (2, 4, 4), (4, 4, 2)]: - if tp * pp * dp <= max_world_size: - configs.append({ - "tp": tp, - "dp": dp, - "pp": pp, - "cp": 1, - "ep": 1, - "order": "tp-pp-dp", - "world_size": tp * pp * dp - }) - configs.append({ - "tp": tp, - "dp": dp, - "pp": pp, - "cp": 1, - "ep": 1, - "order": "tp-dp-pp", - "world_size": tp * pp * dp - }) - - for tp, cp, dp in [(2, 2, 2), (2, 2, 4), (2, 4, 2)]: - if tp * cp * dp <= max_world_size: - configs.append({ - "tp": tp, - "dp": dp, - "pp": 1, - "cp": cp, - "ep": 1, - "order": "tp-cp-dp", - "world_size": tp * cp * dp - }) - - for tp, ep, dp in [(2, 2, 2), (2, 2, 4), (2, 4, 2)]: - if tp * ep * dp <= max_world_size: - configs.append({ - "tp": tp, - "dp": dp, - "pp": 1, - "cp": 1, - "ep": ep, - "order": "tp-ep-dp", - "world_size": tp * ep * dp - }) - - # Four-way combinations - more variations - for tp, pp, dp, cp in [(2, 2, 2, 2), (2, 2, 2, 4), (2, 2, 4, 2), (2, 4, 2, 2)]: - if tp * pp * dp * cp <= max_world_size: - configs.append({ - "tp": tp, - "dp": dp, - "pp": pp, - "cp": cp, - "ep": 1, - "order": "tp-pp-dp-cp", - "world_size": tp * pp * dp * cp - }) - - for tp, ep, pp, dp in [(2, 2, 2, 2), (2, 2, 2, 4), (2, 2, 4, 2)]: - if tp * ep * pp * dp <= max_world_size: - configs.append({ - "tp": tp, - "dp": dp, - "pp": pp, - "cp": 1, - "ep": ep, - "order": "tp-ep-pp-dp", - "world_size": tp * ep * pp * dp - }) - - return configs - - def generate_random_configs(self, count=1000, max_size=1024): - """Generate multiple random configurations. - - Args: - count: Number of random configurations to generate - max_size: Maximum world size - - Returns: - List of configurations - """ - configs = [] - seen = set() - - for _ in range(count): - config = self.generate_random_config(max_size=max_size) - # Create a unique key for this configuration - key = (config["tp"], config["dp"], config["pp"], config["cp"], config["ep"], config["order"]) - if key not in seen: - seen.add(key) - configs.append(config) - - return configs - - def generate_random_config_by_dimension(self, - dimension_count: int, - max_size=1024, - min_parallel_size=2, - max_parallel_size=32): - """Generate a random configuration with exactly the specified number of dimensions > 1. - - Args: - dimension_count: Number of dimensions that should be > 1 (1-5) - max_size: Maximum world size - min_parallel_size: Minimum parallel size for each dimension - max_parallel_size: Maximum parallel size for each dimension - - Returns: - Dict with tp, dp, pp, cp, ep values and order - """ - # All possible dimensions - all_dims = ['tp', 'dp', 'pp', 'cp', 'ep'] - - # Randomly select which dimensions to activate - active_dims = random.sample(all_dims, min(dimension_count, len(all_dims))) - - # Initialize all dimensions to 1 - config = { - "tp": 1, - "dp": 1, - "pp": 1, - "cp": 1, - "ep": 1, - } - - # Set active dimensions to random values - for dim in active_dims: - config[dim] = random.randint(min_parallel_size, max_parallel_size) - - # Calculate world size - world_size = config["tp"] * config["dp"] * config["pp"] * config["cp"] * config["ep"] - - # If world size is too large, scale down proportionally - if world_size > max_size: - scale_factor = (max_size / world_size)**(1.0 / dimension_count) - for dim in active_dims: - config[dim] = max(min_parallel_size, int(config[dim] * scale_factor)) - world_size = config["tp"] * config["dp"] * config["pp"] * config["cp"] * config["ep"] - - # Generate random order from active dimensions - random.shuffle(active_dims) - order = '-'.join(active_dims) - - config["order"] = order - config["world_size"] = world_size - - return config - - def generate_random_configs_by_dimension(self, - counts_by_dimension: Dict[int, int], - max_size=1024, - min_parallel_size=2, - max_parallel_size=32): - """Generate random configurations for each dimension separately. - - Args: - counts_by_dimension: Dict mapping dimension count (1-5) to number of configs to generate - e.g., {1: 100, 2: 200, 3: 150, 4: 100, 5: 50} - max_size: Maximum world size - min_parallel_size: Minimum parallel size for each dimension - max_parallel_size: Maximum parallel size for each dimension - - Returns: - List of configurations grouped by dimension count - """ - all_configs = [] - seen = set() - - for dim_count, count in counts_by_dimension.items(): - if dim_count < 1 or dim_count > 5: - continue - - dim_configs = [] - attempts = 0 - # Increased max_attempts for larger test sets (20x more configs) - max_attempts = count * 20 # Prevent infinite loops, allow more attempts for uniqueness - - while len(dim_configs) < count and attempts < max_attempts: - attempts += 1 - config = self.generate_random_config_by_dimension(dim_count, max_size, min_parallel_size, - max_parallel_size) - - # Create a unique key for this configuration - key = (config["tp"], config["dp"], config["pp"], config["cp"], config["ep"], config["order"]) - - if key not in seen: - seen.add(key) - dim_configs.append(config) - all_configs.append(config) - - if len(dim_configs) < count: - print( - f"Warning: Only generated {len(dim_configs)}/{count} configs for {dim_count}D combinations (attempted {attempts} times)" - ) - - return all_configs - - -class ErrorCategorizer: - """Categorize and aggregate errors by type.""" - - def __init__(self): - self.error_categories = defaultdict(list) - self.combination_stats = defaultdict(int) - - def categorize_error(self, error_msg: str, config: Dict) -> str: - """Categorize an error message into a category.""" - error_lower = error_msg.lower() - - if "ep and cp cannot both be > 1" in error_lower: - return "EP_CP_CONFLICT" - elif "cp not supported" in error_lower: - return "CP_NOT_SUPPORTED" - elif "pp requires" in error_lower or "pipeline" in error_lower: - return "PP_REQUIRES_MPU" - elif "not divisible" in error_lower: - return "DIVISIBILITY_ERROR" - elif "order" in error_lower and "specified" in error_lower: - return "ORDER_MISMATCH" - elif "not available" in error_lower: - return "FEATURE_NOT_AVAILABLE" - else: - return "OTHER_ERROR" - - def get_combination_type(self, config: Dict) -> str: - """Get the combination type string for a configuration.""" - dims = [] - if config["tp"] > 1: - dims.append("TP") - if config["dp"] > 1: - dims.append("DP") - if config["pp"] > 1: - dims.append("PP") - if config["cp"] > 1: - dims.append("CP") - if config["ep"] > 1: - dims.append("EP") - - if not dims: - return "NONE" - - return "+".join(sorted(dims)) - - def record_error(self, error_msg: str, config: Dict, library: str): - """Record an error with categorization.""" - category = self.categorize_error(error_msg, config) - combo_type = self.get_combination_type(config) - - self.error_categories[category].append({ - "error": error_msg, - "config": config, - "library": library, - "combination": combo_type, - }) - - self.combination_stats[combo_type] += 1 - - def get_error_summary(self) -> Dict: - """Get summary of errors by category.""" - summary = {} - for category, errors in self.error_categories.items(): - summary[category] = { - "count": len(errors), - "examples": errors[:5], # First 5 examples - "unique_combinations": len(set(e["combination"] for e in errors)), - } - return summary - - -class ParallelCompatibilityTester: - """Test compatibility between Megatron and DeepSpeed for parallel configurations.""" - - def __init__(self): - self.results = { - "megatron_success": [], - "megatron_failures": [], - "deepspeed_success": [], - "deepspeed_failures": [], - "compatible": [], - "incompatible": [], - "megatron_only": [], - "deepspeed_only": [], - } - self.error_categorizer = ErrorCategorizer() - self.combination_stats = defaultdict( - lambda: { - "total": 0, - "megatron_success": 0, - "megatron_failures": 0, - "deepspeed_success": 0, - "deepspeed_failures": 0, - "compatible": 0, - "megatron_only": 0, - "deepspeed_only": 0, - "incompatible": 0, - }) - - def test_megatron_config(self, config: Dict) -> Tuple[bool, Optional[str], Optional[Dict]]: - """Test if a configuration works with Megatron. - - Returns: - (success, error_message, result_data) - """ - if not PARALLEL_STATE_AVAILABLE: - return False, "Megatron not available", None - - try: - # Check EP and CP constraint - if config["ep"] > 1 and config["cp"] > 1: - return False, "EP and CP cannot both be > 1 in Megatron", None - - # Create RankGenerator - rg = RankGenerator(tp=config["tp"], - ep=config["ep"], - dp=config["dp"], - pp=config["pp"], - cp=config["cp"], - order=config["order"]) - - # Test getting ranks for each dimension - result_data = { - "world_size": rg.world_size, - "tp_groups": rg.get_ranks("tp") if config["tp"] > 1 else [], - "dp_groups": rg.get_ranks("dp") if config["dp"] > 1 else [], - "pp_groups": rg.get_ranks("pp") if config["pp"] > 1 else [], - "cp_groups": rg.get_ranks("cp") if config["cp"] > 1 else [], - "ep_groups": rg.get_ranks("ep") if config["ep"] > 1 else [], - } - - # Test combined groups - if len([d for d in ["tp", "dp", "pp", "cp", "ep"] if config[d] > 1]) > 1: - combined_token = config["order"] - result_data["combined_groups"] = rg.get_ranks(combined_token) - - return True, None, result_data - - except Exception as e: - return False, str(e), None - - def test_deepspeed_config(self, config: Dict) -> Tuple[bool, Optional[str], Optional[Dict]]: - """Test if a configuration is supported by DeepSpeed. - - Returns: - (supported, error_message, support_info) - """ - if not DEEPSPEED_AVAILABLE: - return False, "DeepSpeed not available", None - - support_info = { - "tp_supported": False, - "dp_supported": False, - "pp_supported": False, - "cp_supported": False, - "ep_supported": False, - "sp_supported": False, - "notes": [], - } - - # Check TP support - if config["tp"] > 1: - support_info["tp_supported"] = hasattr(ds_groups, 'get_tensor_model_parallel_group') - - # Check DP support - if config["dp"] > 1: - support_info["dp_supported"] = hasattr(ds_groups, 'get_data_parallel_group') - - # Check PP support - if config["pp"] > 1: - # DeepSpeed supports PP via mpu or pipe module - support_info["pp_supported"] = (hasattr(ds_groups, 'bwc_pipeline_parallel_world_size') - or self._check_module_exists('deepspeed.pipe')) - if not support_info["pp_supported"]: - support_info["notes"].append("PP requires mpu object or deepspeed.pipe module") - - # Check CP support - if config["cp"] > 1: - support_info["cp_supported"] = hasattr(ds_groups, 'get_context_parallel_group') - if not support_info["cp_supported"]: - support_info["notes"].append("CP not supported in DeepSpeed") - - # Check EP support - if config["ep"] > 1: - support_info["ep_supported"] = (hasattr(ds_groups, '_create_expert_and_data_parallel') - or hasattr(ds_groups, '_create_expert_data_and_model_parallel')) - - # Check SP support (DeepSpeed-specific) - support_info["sp_supported"] = hasattr(ds_sp, 'initialize_sequence_parallel') - - # Determine overall support - required_dims = [d for d in ["tp", "dp", "pp", "cp", "ep"] if config[d] > 1] - supported_dims = [] - if config["tp"] > 1 and support_info["tp_supported"]: - supported_dims.append("tp") - if config["dp"] > 1 and support_info["dp_supported"]: - supported_dims.append("dp") - if config["pp"] > 1 and support_info["pp_supported"]: - supported_dims.append("pp") - if config["cp"] > 1 and support_info["cp_supported"]: - supported_dims.append("cp") - if config["ep"] > 1 and support_info["ep_supported"]: - supported_dims.append("ep") - - fully_supported = len(supported_dims) == len(required_dims) - - return fully_supported, None, support_info - - def _check_module_exists(self, module_name): - """Check if a module exists.""" - try: - __import__(module_name) - return True - except ImportError: - return False - - def _simulate_deepspeed_rank_generation(self, config: Dict) -> Optional[Dict]: - """Simulate DeepSpeed's rank generation logic based on code analysis. - - This attempts to replicate DeepSpeed's rank assignment logic for comparison. - """ - if not DEEPSPEED_AVAILABLE: - return None - - try: - world_size = config["world_size"] - result = {} - - # For TP+DP: DeepSpeed uses mesh_device which creates groups in a specific way - if config["tp"] > 1 and config["dp"] > 1 and config["pp"] == 1 and config["cp"] == 1 and config["ep"] == 1: - # DeepSpeed's _init_tp_mesh_device creates: - # TP groups: [0,1], [2,3], [4,5], ... (consecutive) - # DP groups: [0,2,4,...], [1,3,5,...] (strided) - tp_size = config["tp"] - dp_size = config["dp"] - - tp_groups = [] - for i in range(world_size // tp_size): - group = list(range(i * tp_size, (i + 1) * tp_size)) - tp_groups.append(group) - - dp_groups = [] - for i in range(tp_size): - group = list(range(i, world_size, tp_size)) - dp_groups.append(group) - - result["tp_groups"] = tp_groups - result["dp_groups"] = dp_groups - result["world_size"] = world_size - return result - - # For other combinations, we can't easily simulate without actual distributed setup - # But we can note that DeepSpeed supports it - return {"supported": True, "note": "Rank generation requires actual distributed setup"} - - except Exception as e: - return {"error": str(e)} - - def _compare_rank_groups(self, megatron_groups: List[List[int]], deepspeed_groups: List[List[int]]) -> Dict: - """Compare rank groups from Megatron and DeepSpeed. - - Returns: - Dict with comparison results - """ - comparison = {"same_structure": False, "same_ranks": False, "differences": []} - - if not megatron_groups or not deepspeed_groups: - return comparison - - # Check if same number of groups - if len(megatron_groups) != len(deepspeed_groups): - comparison["differences"].append( - f"Group count mismatch: Megatron={len(megatron_groups)}, DeepSpeed={len(deepspeed_groups)}") - return comparison - - # Check if same group sizes - megatron_sizes = [len(g) for g in megatron_groups] - deepspeed_sizes = [len(g) for g in deepspeed_groups] - if megatron_sizes != deepspeed_sizes: - comparison["differences"].append( - f"Group size mismatch: Megatron={megatron_sizes}, DeepSpeed={deepspeed_sizes}") - return comparison - - # Check if same ranks (order may differ) - megatron_ranks = set() - for group in megatron_groups: - megatron_ranks.update(group) - - deepspeed_ranks = set() - for group in deepspeed_groups: - deepspeed_ranks.update(group) - - if megatron_ranks != deepspeed_ranks: - comparison["differences"].append( - f"Rank set mismatch: Megatron={sorted(megatron_ranks)}, DeepSpeed={sorted(deepspeed_ranks)}") - return comparison - - # Check if same structure (same groups, possibly different order) - megatron_sets = [set(g) for g in megatron_groups] - deepspeed_sets = [set(g) for g in deepspeed_groups] - - if sorted(megatron_sets, key=lambda x: min(x)) == sorted(deepspeed_sets, key=lambda x: min(x)): - comparison["same_structure"] = True - comparison["same_ranks"] = True - else: - comparison["differences"].append("Group structure differs (same ranks but different grouping)") - - return comparison - - def test_config_compatibility(self, config: Dict): - """Test compatibility of a configuration between both libraries.""" - # Get combination type for statistics - combo_type = self.error_categorizer.get_combination_type(config) - self.combination_stats[combo_type]["total"] += 1 - - # Test Megatron - megatron_success, megatron_error, megatron_result = self.test_megatron_config(config) - - # Test DeepSpeed - deepspeed_success, deepspeed_error, deepspeed_support = self.test_deepspeed_config(config) - - # Record errors in categorizer - if not megatron_success and megatron_error: - self.error_categorizer.record_error(megatron_error, config, "Megatron") - self.combination_stats[combo_type]["megatron_failures"] += 1 - else: - self.combination_stats[combo_type]["megatron_success"] += 1 - - if not deepspeed_success: - # Get error message from support_info notes - error_msg = deepspeed_support.get("notes", ["Not supported"])[0] if deepspeed_support else "Not supported" - self.error_categorizer.record_error(error_msg, config, "DeepSpeed") - self.combination_stats[combo_type]["deepspeed_failures"] += 1 - else: - self.combination_stats[combo_type]["deepspeed_success"] += 1 - - # Try to simulate DeepSpeed rank generation for comparison - deepspeed_simulated = None - if megatron_success and deepspeed_success: - deepspeed_simulated = self._simulate_deepspeed_rank_generation(config) - - # Compare rank generation if both succeeded and we have simulated results - rank_comparison = None - if megatron_success and deepspeed_success and deepspeed_simulated and "tp_groups" in deepspeed_simulated: - # Compare TP groups - if config["tp"] > 1 and "tp_groups" in megatron_result: - rank_comparison = self._compare_rank_groups(megatron_result["tp_groups"], - deepspeed_simulated.get("tp_groups", [])) - # Compare DP groups - if config["dp"] > 1 and "dp_groups" in megatron_result and not rank_comparison: - rank_comparison = self._compare_rank_groups(megatron_result["dp_groups"], - deepspeed_simulated.get("dp_groups", [])) - - # Record results - config_key = f"tp={config['tp']},dp={config['dp']},pp={config['pp']},cp={config['cp']},ep={config['ep']},order={config['order']}" - - if megatron_success: - self.results["megatron_success"].append(config_key) - else: - self.results["megatron_failures"].append({ - "config": config_key, - "error": megatron_error, - "combination": combo_type, - }) - - if deepspeed_success: - self.results["deepspeed_success"].append(config_key) - else: - self.results["deepspeed_failures"].append({ - "config": config_key, - "error": deepspeed_error, - "support_info": deepspeed_support, - "combination": combo_type, - }) - - # Determine compatibility and update stats - if megatron_success and deepspeed_success: - compat_entry = { - "config": config_key, - "megatron_result": megatron_result, - "deepspeed_support": deepspeed_support, - "combination": combo_type, - } - if rank_comparison: - compat_entry["rank_comparison"] = rank_comparison - if rank_comparison.get("same_structure"): - compat_entry["rank_match"] = True - else: - compat_entry["rank_match"] = False - compat_entry["rank_differences"] = rank_comparison.get("differences", []) - self.results["compatible"].append(compat_entry) - self.combination_stats[combo_type]["compatible"] += 1 - elif megatron_success and not deepspeed_success: - self.results["megatron_only"].append({ - "config": - config_key, - "megatron_result": - megatron_result, - "deepspeed_issue": - deepspeed_support.get("notes", []) if deepspeed_support else [], - "combination": - combo_type, - }) - self.combination_stats[combo_type]["megatron_only"] += 1 - elif not megatron_success and deepspeed_success: - self.results["deepspeed_only"].append({ - "config": config_key, - "megatron_error": megatron_error, - "deepspeed_support": deepspeed_support, - "combination": combo_type, - }) - self.combination_stats[combo_type]["deepspeed_only"] += 1 - else: - self.results["incompatible"].append({ - "config": - config_key, - "megatron_error": - megatron_error, - "deepspeed_issue": - deepspeed_support.get("notes", []) if deepspeed_support else [], - "combination": - combo_type, - }) - self.combination_stats[combo_type]["incompatible"] += 1 - - -class TestAutomatedParallelCombinations: - """Automated tests for parallel strategy combinations.""" - - def test_systematic_configurations(self): - """Test systematic configurations covering common cases.""" - generator = ParallelConfigGenerator(seed=42) - tester = ParallelCompatibilityTester() - - configs = generator.generate_systematic_configs(max_world_size=16) - - print("\n" + "=" * 80) - print("SYSTEMATIC CONFIGURATION TESTING") - print("=" * 80) - print(f"\nTesting {len(configs)} systematic configurations...") - - for i, config in enumerate(configs, 1): - print(f"\n[{i}/{len(configs)}] Testing: {config}") - tester.test_config_compatibility(config) - - self._print_results(tester, "Systematic") - self._generate_comprehensive_report(tester, "Systematic") - - def test_random_configurations(self): - """Test random configurations.""" - generator = ParallelConfigGenerator(seed=123) - tester = ParallelCompatibilityTester() - - configs = generator.generate_random_configs(count=1000, max_size=1024) - - print("\n" + "=" * 80) - print("RANDOM CONFIGURATION TESTING") - print("=" * 80) - print(f"\nTesting {len(configs)} random configurations...") - print(f"Max world size: 1024, Max parallel size per dimension: 32") - - for i, config in enumerate(configs, 1): - if i % 100 == 0: - print(f"Progress: {i}/{len(configs)} ({(i/len(configs)*100):.1f}%)") - tester.test_config_compatibility(config) - - self._print_results(tester, "Random") - self._generate_comprehensive_report(tester, "Random") - - def test_random_configurations_by_dimension(self): - """Test random configurations generated separately for each dimension.""" - generator = ParallelConfigGenerator(seed=789) - tester = ParallelCompatibilityTester() - - # Generate configs for each dimension separately - # This ensures balanced coverage across all dimensions - # Increased by 20x for comprehensive testing - counts_by_dimension = { - 1: 4000, # 1D: 4000 configs (200 * 20) - 2: 6000, # 2D: 6000 configs (300 * 20) - more because there are more 2D combinations - 3: 5000, # 3D: 5000 configs (250 * 20) - 4: 3000, # 4D: 3000 configs (150 * 20) - 5: 2000, # 5D: 2000 configs (100 * 20) - } - - print("\n" + "=" * 80) - print("RANDOM CONFIGURATION TESTING BY DIMENSION") - print("=" * 80) - print(f"\nGenerating configurations by dimension:") - for dim, count in counts_by_dimension.items(): - print(f" {dim}D: {count} configurations") - - configs = generator.generate_random_configs_by_dimension(counts_by_dimension=counts_by_dimension, - max_size=1024, - min_parallel_size=2, - max_parallel_size=32) - - print(f"\nTotal unique configurations generated: {len(configs)}") - print(f"Max world size: 1024, Parallel size range: 2-32") - - # Count configs by dimension - dim_counts = defaultdict(int) - for config in configs: - dim_count = len([d for d in ["tp", "dp", "pp", "cp", "ep"] if config[d] > 1]) - dim_counts[dim_count] += 1 - - print("\nActual distribution:") - for dim in sorted(dim_counts.keys()): - print(f" {dim}D: {dim_counts[dim]} configurations") - - print(f"\nTesting {len(configs)} configurations...") - - for i, config in enumerate(configs, 1): - # Update progress more frequently for large test sets - if i % 1000 == 0 or i == len(configs): - print(f"Progress: {i}/{len(configs)} ({(i/len(configs)*100):.1f}%)") - tester.test_config_compatibility(config) - - self._print_results(tester, "Random by Dimension") - self._generate_comprehensive_report(tester, "Random by Dimension") - - def test_edge_cases(self): - """Test edge cases and boundary conditions.""" - generator = ParallelConfigGenerator(seed=456) - tester = ParallelCompatibilityTester() - - # Edge cases - including larger sizes - edge_configs = [ - # Maximum dimensions - larger sizes - { - "tp": 8, - "dp": 8, - "pp": 8, - "cp": 1, - "ep": 1, - "order": "tp-dp-pp", - "world_size": 512 - }, - { - "tp": 16, - "dp": 16, - "pp": 4, - "cp": 1, - "ep": 1, - "order": "tp-dp-pp", - "world_size": 1024 - }, - # EP and CP conflict - { - "tp": 2, - "dp": 2, - "pp": 1, - "cp": 2, - "ep": 2, - "order": "tp-ep-dp", - "world_size": 8 - }, - { - "tp": 4, - "dp": 4, - "pp": 1, - "cp": 4, - "ep": 4, - "order": "tp-ep-dp", - "world_size": 64 - }, - # Single dimension - larger sizes - { - "tp": 1, - "dp": 1, - "pp": 64, - "cp": 1, - "ep": 1, - "order": "pp", - "world_size": 64 - }, - { - "tp": 128, - "dp": 1, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "tp", - "world_size": 128 - }, - { - "tp": 1, - "dp": 256, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "dp", - "world_size": 256 - }, - # All dimensions - larger sizes - { - "tp": 2, - "dp": 2, - "pp": 2, - "cp": 2, - "ep": 1, - "order": "tp-pp-dp-cp", - "world_size": 16 - }, - { - "tp": 4, - "dp": 4, - "pp": 4, - "cp": 4, - "ep": 1, - "order": "tp-pp-dp-cp", - "world_size": 256 - }, - # Different orders - { - "tp": 2, - "dp": 4, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "dp-tp", - "world_size": 8 - }, - { - "tp": 2, - "dp": 4, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "tp-dp", - "world_size": 8 - }, - { - "tp": 8, - "dp": 16, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "dp-tp", - "world_size": 128 - }, - { - "tp": 8, - "dp": 16, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "tp-dp", - "world_size": 128 - }, - # Large multi-dimensional - { - "tp": 8, - "dp": 8, - "pp": 4, - "cp": 1, - "ep": 1, - "order": "tp-pp-dp", - "world_size": 256 - }, - { - "tp": 4, - "dp": 8, - "pp": 8, - "cp": 1, - "ep": 1, - "order": "tp-pp-dp", - "world_size": 256 - }, - ] - - print("\n" + "=" * 80) - print("EDGE CASE TESTING") - print("=" * 80) - print(f"\nTesting {len(edge_configs)} edge case configurations...") - - for i, config in enumerate(edge_configs, 1): - print(f"\n[{i}/{len(edge_configs)}] Testing: {config}") - tester.test_config_compatibility(config) - - self._print_results(tester, "Edge Cases") - self._generate_comprehensive_report(tester, "Edge Cases") - - def _print_results(self, tester: ParallelCompatibilityTester, test_type: str): - """Print test results.""" - results = tester.results - - print("\n" + "=" * 80) - print(f"{test_type} TEST RESULTS") - print("=" * 80) - - print(f"\n✓ Megatron Success: {len(results['megatron_success'])}") - print(f"✗ Megatron Failures: {len(results['megatron_failures'])}") - if results['megatron_failures']: - print("\nMegatron Failures:") - for failure in results['megatron_failures'][:10]: # Show first 10 - print(f" - {failure['config']}: {failure['error']}") - if len(results['megatron_failures']) > 10: - print(f" ... and {len(results['megatron_failures']) - 10} more") - - print(f"\n✓ DeepSpeed Success: {len(results['deepspeed_success'])}") - print(f"✗ DeepSpeed Failures: {len(results['deepspeed_failures'])}") - if results['deepspeed_failures']: - print("\nDeepSpeed Failures:") - for failure in results['deepspeed_failures'][:10]: # Show first 10 - print(f" - {failure['config']}") - if failure.get('support_info'): - notes = failure['support_info'].get('notes', []) - if notes: - print(f" Notes: {', '.join(notes)}") - if len(results['deepspeed_failures']) > 10: - print(f" ... and {len(results['deepspeed_failures']) - 10} more") - - print(f"\n✓ Compatible (Both Support): {len(results['compatible'])}") - if results['compatible']: - print(" Examples:") - rank_matches = 0 - rank_mismatches = 0 - for item in results['compatible'][:10]: - if isinstance(item, dict): - config = item.get('config', 'Unknown') - rank_comp = item.get('rank_comparison') - if rank_comp: - if rank_comp.get('same_structure'): - print(f" - {config} ✓ Rank groups match") - rank_matches += 1 - else: - print(f" - {config} ⚠ Rank groups differ") - rank_mismatches += 1 - if rank_comp.get('differences'): - for diff in rank_comp['differences'][:2]: - print(f" {diff}") - else: - print(f" - {config}") - else: - print(f" - {item}") - if len(results['compatible']) > 10: - print(f" ... and {len(results['compatible']) - 10} more") - - if rank_matches > 0 or rank_mismatches > 0: - print(f"\n Rank Comparison Summary:") - print(f" Matches: {rank_matches}") - print(f" Mismatches: {rank_mismatches}") - print(f" (Note: Comparison only available for TP+DP combinations)") - - print(f"\n⚠ Megatron Only: {len(results['megatron_only'])}") - if results['megatron_only']: - print(" Examples:") - for item in results['megatron_only'][:5]: - print(f" - {item['config']}") - if item.get('deepspeed_issue'): - print(f" DeepSpeed issue: {', '.join(item['deepspeed_issue'])}") - if len(results['megatron_only']) > 5: - print(f" ... and {len(results['megatron_only']) - 5} more") - - print(f"\n→ DeepSpeed Only: {len(results['deepspeed_only'])}") - if results['deepspeed_only']: - print(" Examples:") - for item in results['deepspeed_only'][:5]: - print(f" - {item['config']}") - print(f" Megatron error: {item['megatron_error']}") - if len(results['deepspeed_only']) > 5: - print(f" ... and {len(results['deepspeed_only']) - 5} more") - - print(f"\n✗ Incompatible (Neither Support): {len(results['incompatible'])}") - if results['incompatible']: - print(" Examples:") - for item in results['incompatible'][:5]: - print(f" - {item['config']}") - print(f" Megatron: {item['megatron_error']}") - if len(results['incompatible']) > 5: - print(f" ... and {len(results['incompatible']) - 5} more") - - print("\n" + "=" * 80) - - def _generate_comprehensive_report(self, tester: ParallelCompatibilityTester, test_type: str): - """Generate comprehensive test report with error categorization and combination statistics.""" - results = tester.results - error_summary = tester.error_categorizer.get_error_summary() - combo_stats = tester.combination_stats - - print("\n" + "=" * 80) - print(f"{test_type} COMPREHENSIVE TEST REPORT") - print("=" * 80) - - # Overall statistics - print("\n" + "-" * 80) - print("OVERALL STATISTICS") - print("-" * 80) - total_tested = (len(results['megatron_success']) + len(results['megatron_failures']) + - len(results['deepspeed_success']) + len(results['deepspeed_failures'])) - print(f"Total Configurations Tested: {total_tested}") - print( - f" Megatron Success: {len(results['megatron_success'])} ({len(results['megatron_success'])/total_tested*100:.1f}%)" - ) - print( - f" Megatron Failures: {len(results['megatron_failures'])} ({len(results['megatron_failures'])/total_tested*100:.1f}%)" - ) - print( - f" DeepSpeed Success: {len(results['deepspeed_success'])} ({len(results['deepspeed_success'])/total_tested*100:.1f}%)" - ) - print( - f" DeepSpeed Failures: {len(results['deepspeed_failures'])} ({len(results['deepspeed_failures'])/total_tested*100:.1f}%)" - ) - print(f" Compatible: {len(results['compatible'])} ({len(results['compatible'])/total_tested*100:.1f}%)") - print( - f" Megatron Only: {len(results['megatron_only'])} ({len(results['megatron_only'])/total_tested*100:.1f}%)" - ) - print( - f" DeepSpeed Only: {len(results['deepspeed_only'])} ({len(results['deepspeed_only'])/total_tested*100:.1f}%)" - ) - print(f" Incompatible: {len(results['incompatible'])} ({len(results['incompatible'])/total_tested*100:.1f}%)") - - # Error categorization - print("\n" + "-" * 80) - print("ERROR CATEGORIZATION (Aggregated by Type)") - print("-" * 80) - for category, summary in sorted(error_summary.items(), key=lambda x: x[1]['count'], reverse=True): - print(f"\n{category}: {summary['count']} occurrences") - print(f" Affects {summary['unique_combinations']} unique combination types") - print(f" Examples:") - for example in summary['examples'][:3]: - combo = example.get('combination', 'Unknown') - lib = example.get('library', 'Unknown') - print(f" - {combo} ({lib}): {example['error'][:80]}") - if len(summary['examples']) > 3: - print(f" ... and {len(summary['examples']) - 3} more examples") - - # Combination type statistics - print("\n" + "-" * 80) - print("COMBINATION TYPE STATISTICS") - print("-" * 80) - print( - f"{'Combination':<20} {'Total':<8} {'M-Succ':<8} {'M-Fail':<8} {'DS-Succ':<8} {'DS-Fail':<8} {'Compat':<8} {'M-Only':<8} {'DS-Only':<8} {'Incomp':<8}" - ) - print("-" * 100) - - # Sort by total count - sorted_combos = sorted(combo_stats.items(), key=lambda x: x[1]['total'], reverse=True) - for combo_type, stats in sorted_combos: - if stats['total'] > 0: - print(f"{combo_type:<20} {stats['total']:<8} {stats['megatron_success']:<8} " - f"{stats['megatron_failures']:<8} {stats['deepspeed_success']:<8} " - f"{stats['deepspeed_failures']:<8} {stats['compatible']:<8} " - f"{stats['megatron_only']:<8} {stats['deepspeed_only']:<8} " - f"{stats['incompatible']:<8}") - - # Detailed combination analysis - print("\n" + "-" * 80) - print("DETAILED COMBINATION ANALYSIS") - print("-" * 80) - - # Group by number of dimensions - by_dimension_count = defaultdict(list) - for combo_type, stats in combo_stats.items(): - dim_count = len([c for c in combo_type.split('+') if c != 'NONE']) - by_dimension_count[dim_count].append((combo_type, stats)) - - for dim_count in sorted(by_dimension_count.keys()): - print(f"\n{dim_count}-Dimensional Combinations:") - combos = sorted(by_dimension_count[dim_count], key=lambda x: x[1]['total'], reverse=True) - for combo_type, stats in combos[:10]: # Show top 10 - if stats['total'] > 0: - compat_rate = (stats['compatible'] / stats['total'] * 100) if stats['total'] > 0 else 0 - print(f" {combo_type}:") - print(f" Total: {stats['total']}, Compatible: {stats['compatible']} ({compat_rate:.1f}%)") - print(f" Megatron: {stats['megatron_success']} success, {stats['megatron_failures']} failures") - print( - f" DeepSpeed: {stats['deepspeed_success']} success, {stats['deepspeed_failures']} failures") - if len(combos) > 10: - print(f" ... and {len(combos) - 10} more {dim_count}-dimensional combinations") - - print("\n" + "=" * 80) - - def test_cp_vs_sp_compatibility_by_dimension(self): - """Test CP vs SP compatibility using the same config generation as test_random_configurations_by_dimension. - - This test: - 1. Uses parallel_state_refactored with CP - 2. Uses DeepSpeed with SP - 3. Compares CP rank groups with SP rank groups to see if they match - """ - generator = ParallelConfigGenerator(seed=789) - - # Use the same configuration generation as test_random_configurations_by_dimension - counts_by_dimension = { - 1: 4000, # 1D: 4000 configs - 2: 6000, # 2D: 6000 configs - 3: 5000, # 3D: 5000 configs - 4: 3000, # 4D: 3000 configs - 5: 2000, # 5D: 2000 configs - } - - print("\n" + "=" * 80) - print("CP vs SP COMPATIBILITY TESTING BY DIMENSION") - print("=" * 80) - print(f"\nGenerating configurations by dimension:") - for dim, count in counts_by_dimension.items(): - print(f" {dim}D: {count} configurations") - - configs = generator.generate_random_configs_by_dimension(counts_by_dimension=counts_by_dimension, - max_size=1024, - min_parallel_size=2, - max_parallel_size=32) - - # Filter to only include configs with CP > 1 and EP == 1 (EP and CP cannot both be > 1) - configs_with_cp = [c for c in configs if c["cp"] > 1 and c["ep"] == 1] - - print(f"\nTotal unique configurations generated: {len(configs)}") - print(f"Configurations with CP > 1 and EP == 1: {len(configs_with_cp)}") - print(f"Max world size: 1024, Parallel size range: 2-32") - - # Test CP vs SP compatibility - results = { - "total_tested": 0, - "cp_groups_generated": 0, - "sp_groups_generated": 0, - "rank_groups_match": 0, - "rank_groups_differ": 0, - "errors": 0, - "match_details": [], - "differ_details": [], - } - - combination_stats = defaultdict(lambda: { - "total": 0, - "match": 0, - "differ": 0, - "errors": 0, - }) - - print(f"\nTesting {len(configs_with_cp)} configurations for CP vs SP compatibility...") - - for i, config in enumerate(configs_with_cp, 1): - if i % 1000 == 0 or i == len(configs_with_cp): - print(f"Progress: {i}/{len(configs_with_cp)} ({(i/len(configs_with_cp)*100):.1f}%)") - - results["total_tested"] += 1 - - # Get combination type - combo_type = self._get_combination_type_for_cp_sp(config) - combination_stats[combo_type]["total"] += 1 - - try: - # Get CP rank groups from Megatron - if not PARALLEL_STATE_AVAILABLE: - results["errors"] += 1 - combination_stats[combo_type]["errors"] += 1 - continue - - rg = RankGenerator(tp=config["tp"], - ep=config["ep"], - dp=config["dp"], - pp=config["pp"], - cp=config["cp"], - order=config["order"]) - - cp_groups = rg.get_ranks("cp") - if cp_groups: - results["cp_groups_generated"] += 1 - - # Simulate SP rank groups from DeepSpeed - # DeepSpeed SP creates consecutive rank groups - sp_groups = self._simulate_deepspeed_sp_groups(config["world_size"], config["cp"]) - if sp_groups: - results["sp_groups_generated"] += 1 - - # Compare CP and SP groups - if self._compare_cp_sp_groups(cp_groups, sp_groups): - results["rank_groups_match"] += 1 - combination_stats[combo_type]["match"] += 1 - results["match_details"].append(config) - else: - results["rank_groups_differ"] += 1 - combination_stats[combo_type]["differ"] += 1 - results["differ_details"].append({ - "config": config, - "cp_groups": cp_groups, - "sp_groups": sp_groups, - }) - - except Exception as e: - results["errors"] += 1 - combination_stats[combo_type]["errors"] += 1 - - # Generate report - self._generate_cp_vs_sp_report(results, combination_stats) - - def _simulate_deepspeed_sp_groups(self, world_size: int, sp_size: int) -> List[List[int]]: - """Simulate DeepSpeed's SP rank group generation. - - DeepSpeed SP creates groups as consecutive ranks: - - Group 0: [0, 1, ..., sp_size-1] - - Group 1: [sp_size, sp_size+1, ..., 2*sp_size-1] - - etc. - """ - if sp_size <= 1 or world_size % sp_size != 0: - return [] - - num_groups = world_size // sp_size - groups = [] - for i in range(num_groups): - group = list(range(i * sp_size, (i + 1) * sp_size)) - groups.append(group) - - return groups - - def _compare_cp_sp_groups(self, cp_groups: List[List[int]], sp_groups: List[List[int]]) -> bool: - """Compare CP and SP rank groups to see if they match.""" - if not cp_groups and not sp_groups: - return True - - if not cp_groups or not sp_groups: - return False - - if len(cp_groups) != len(sp_groups): - return False - - # Check if all CP groups have a matching SP group (order may differ) - cp_sets = [set(g) for g in cp_groups] - sp_sets = [set(g) for g in sp_groups] - - # Check if all CP groups match SP groups - for cp_set in cp_sets: - found = False - for sp_set in sp_sets: - if cp_set == sp_set: - found = True - break - if not found: - return False - - # Check if all SP groups match CP groups - for sp_set in sp_sets: - found = False - for cp_set in cp_sets: - if sp_set == cp_set: - found = True - break - if not found: - return False - - return True - - def _get_combination_type_for_cp_sp(self, config: Dict) -> str: - """Get combination type string for CP vs SP testing.""" - dims = [] - if config["tp"] > 1: - dims.append("TP") - if config["dp"] > 1: - dims.append("DP") - if config["pp"] > 1: - dims.append("PP") - if config["cp"] > 1: - dims.append("CP") - # Note: EP is always 1 in this test - - if not dims: - return "NONE" - - return "+".join(sorted(dims)) - - def _generate_cp_vs_sp_report(self, results: Dict, combination_stats: Dict): - """Generate comprehensive CP vs SP compatibility report.""" - print("\n" + "=" * 80) - print("CP vs SP COMPATIBILITY TEST REPORT") - print("=" * 80) - - # Overall statistics - print("\n" + "-" * 80) - print("OVERALL STATISTICS") - print("-" * 80) - print(f"Total Configurations Tested: {results['total_tested']}") - print(f" CP Groups Generated: {results['cp_groups_generated']}") - print(f" SP Groups Generated: {results['sp_groups_generated']}") - print(f" Rank Groups Match: {results['rank_groups_match']}") - print(f" Rank Groups Differ: {results['rank_groups_differ']}") - print(f" Errors: {results['errors']}") - - if results['total_tested'] > 0: - match_rate = (results['rank_groups_match'] / results['total_tested']) * 100 - print(f"\n Match Rate: {match_rate:.2f}%") - print(f" CP can replace SP in {match_rate:.2f}% of tested configurations") - - # Combination type statistics - print("\n" + "-" * 80) - print("COMBINATION TYPE STATISTICS") - print("-" * 80) - print(f"{'Combination':<20} {'Total':<8} {'Match':<8} {'Differ':<8} {'Errors':<8} {'Match Rate':<12}") - print("-" * 80) - - sorted_combos = sorted(combination_stats.items(), key=lambda x: x[1]['total'], reverse=True) - for combo_type, stats in sorted_combos: - if stats['total'] > 0: - match_rate = (stats['match'] / stats['total'] * 100) if stats['total'] > 0 else 0 - print(f"{combo_type:<20} {stats['total']:<8} {stats['match']:<8} " - f"{stats['differ']:<8} {stats['errors']:<8} {match_rate:.1f}%") - - # Examples of matching configurations - print("\n" + "-" * 80) - print("EXAMPLES OF MATCHING CONFIGURATIONS (CP can replace SP)") - print("-" * 80) - for i, config in enumerate(results['match_details'][:10], 1): - print(f"{i}. {config}") - print(f" CP size: {config['cp']}, Order: {config['order']}") - - if len(results['match_details']) > 10: - print(f"\n... and {len(results['match_details']) - 10} more matching configurations") - - # Examples of differing configurations - if results['differ_details']: - print("\n" + "-" * 80) - print("EXAMPLES OF DIFFERING CONFIGURATIONS (CP cannot replace SP)") - print("-" * 80) - for i, item in enumerate(results['differ_details'][:10], 1): - config = item['config'] - cp_groups = item['cp_groups'] - sp_groups = item['sp_groups'] - print(f"{i}. {config}") - print(f" CP size: {config['cp']}, Order: {config['order']}") - print(f" CP groups count: {len(cp_groups)}, SP groups count: {len(sp_groups)}") - if cp_groups and sp_groups: - print(f" CP first group: {cp_groups[0]}") - print(f" SP first group: {sp_groups[0]}") - - if len(results['differ_details']) > 10: - print(f"\n... and {len(results['differ_details']) - 10} more differing configurations") - - # Conclusion - print("\n" + "=" * 80) - print("CONCLUSION") - print("=" * 80) - if results['rank_groups_match'] > 0: - match_rate = (results['rank_groups_match'] / results['total_tested']) * 100 - print(f"\n✓ CP can replace SP in {match_rate:.2f}% of tested configurations") - print( - f" - {results['rank_groups_match']} out of {results['total_tested']} configurations have matching rank groups" - ) - else: - print("\n✗ CP cannot replace SP in any of the tested configurations") - - if results['rank_groups_differ'] > 0: - print(f"\n⚠ {results['rank_groups_differ']} configurations have different rank groups") - print(" - These configurations may require special handling when migrating from CP to SP") - - print("\n" + "=" * 80) - - def test_comprehensive_automated_testing(self): - """Comprehensive automated testing with all test types.""" - print("\n" + "=" * 80) - print("COMPREHENSIVE AUTOMATED PARALLEL COMBINATION TESTING") - print("=" * 80) - - # Create a combined tester for overall report - combined_tester = ParallelCompatibilityTester() - - # Run all test types and accumulate results - print("\n[1/3] Running systematic configurations...") - generator1 = ParallelConfigGenerator(seed=42) - configs1 = generator1.generate_systematic_configs(max_world_size=512) - print(f"Testing {len(configs1)} systematic configurations...") - for i, config in enumerate(configs1, 1): - if i % 50 == 0 or i == len(configs1): - print(f" Progress: {i}/{len(configs1)}") - combined_tester.test_config_compatibility(config) - - print("\n[2/4] Running random configurations by dimension...") - generator2 = ParallelConfigGenerator(seed=789) - # Increased by 20x for comprehensive testing - counts_by_dimension = { - 1: 4000, # 1D: 4000 configs (200 * 20) - 2: 6000, # 2D: 6000 configs (300 * 20) - 3: 5000, # 3D: 5000 configs (250 * 20) - 4: 3000, # 4D: 3000 configs (150 * 20) - 5: 2000, # 5D: 2000 configs (100 * 20) - } - configs2 = generator2.generate_random_configs_by_dimension(counts_by_dimension=counts_by_dimension, - max_size=1024, - min_parallel_size=2, - max_parallel_size=32) - print(f"Testing {len(configs2)} random configurations (balanced by dimension)...") - print(f"Max world size: 1024, Parallel size range: 2-32") - for i, config in enumerate(configs2, 1): - # Update progress more frequently for large test sets - if i % 1000 == 0 or i == len(configs2): - print(f" Progress: {i}/{len(configs2)} ({(i/len(configs2)*100):.1f}%)") - combined_tester.test_config_compatibility(config) - - print("\n[3/4] Running additional random configurations...") - generator3 = ParallelConfigGenerator(seed=123) - # Increased by 20x: 500 * 20 = 10000 - configs3 = generator3.generate_random_configs(count=10000, max_size=1024) - print(f"Testing {len(configs3)} additional random configurations...") - for i, config in enumerate(configs3, 1): - # Update progress more frequently for large test sets - if i % 1000 == 0 or i == len(configs3): - print(f" Progress: {i}/{len(configs3)} ({(i/len(configs3)*100):.1f}%)") - combined_tester.test_config_compatibility(config) - - print("\n[4/4] Running edge cases...") - edge_configs = [ - { - "tp": 8, - "dp": 8, - "pp": 8, - "cp": 1, - "ep": 1, - "order": "tp-dp-pp", - "world_size": 512 - }, - { - "tp": 16, - "dp": 16, - "pp": 4, - "cp": 1, - "ep": 1, - "order": "tp-dp-pp", - "world_size": 1024 - }, - { - "tp": 2, - "dp": 2, - "pp": 1, - "cp": 2, - "ep": 2, - "order": "tp-ep-dp", - "world_size": 8 - }, - { - "tp": 4, - "dp": 4, - "pp": 1, - "cp": 4, - "ep": 4, - "order": "tp-ep-dp", - "world_size": 64 - }, - { - "tp": 1, - "dp": 1, - "pp": 64, - "cp": 1, - "ep": 1, - "order": "pp", - "world_size": 64 - }, - { - "tp": 128, - "dp": 1, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "tp", - "world_size": 128 - }, - { - "tp": 1, - "dp": 256, - "pp": 1, - "cp": 1, - "ep": 1, - "order": "dp", - "world_size": 256 - }, - { - "tp": 2, - "dp": 2, - "pp": 2, - "cp": 2, - "ep": 1, - "order": "tp-pp-dp-cp", - "world_size": 16 - }, - { - "tp": 4, - "dp": 4, - "pp": 4, - "cp": 4, - "ep": 1, - "order": "tp-pp-dp-cp", - "world_size": 256 - }, - ] - print(f"Testing {len(edge_configs)} edge case configurations...") - for config in edge_configs: - combined_tester.test_config_compatibility(config) - - # Generate comprehensive report - print("\n" + "=" * 80) - print("COMPREHENSIVE FINAL REPORT") - print("=" * 80) - self._generate_comprehensive_report(combined_tester, "COMPREHENSIVE") - - print("\n" + "=" * 80) - print("ALL TESTS COMPLETED") - print("=" * 80) - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"])