class TestRankAwareSampler:
    """Unit tests for the per-DP-group coordination done by RankAwareSampler."""

    def test_rank_aware_sampler_initialization(self):
        """A new sampler is a BaseSampler and starts with an empty state cache."""
        instance = RankAwareSampler()

        assert isinstance(instance, BaseSampler)
        assert hasattr(instance, "_states")
        assert instance._states == {}

    def test_rank_aware_sampler_first_rank_sampling(self):
        """The first caller of a DP group triggers the actual sampling."""
        instance = RankAwareSampler()
        requested = 3

        # world_size == dp_world_size gives one fetch per batch, so the single
        # call both samples and releases the cached state.
        picked, used = instance.sample(
            [0, 1, 2, 3, 4, 5], requested, dp_group=0, dp_world_size=2, world_size=2
        )

        assert picked == [0, 1, 2]
        assert used == [0, 1, 2]
        assert len(picked) == requested
        assert instance._states == {}

    def test_rank_aware_sampler_second_rank_gets_cached(self):
        """A second caller of the same DP group receives the cached batch."""
        instance = RankAwareSampler()
        pool = [0, 1, 2, 3, 4, 5]
        # world_size=4 with dp_world_size=2 gives two fetches per batch.
        topology = {"dp_group": 0, "dp_world_size": 2, "world_size": 4}

        first_pick, first_used = instance.sample(pool, 3, **topology)
        second_pick, second_used = instance.sample(pool, 3, **topology)

        assert first_pick == second_pick == [0, 1, 2]
        # The first caller already reports the consumed indexes ...
        assert first_used == [0, 1, 2]
        # ... and the second caller sees the same ones before cleanup.
        assert second_used == [0, 1, 2]
        assert instance._states == {}

    def test_rank_aware_sampler_multiple_dp_groups(self):
        """Distinct DP groups keep independent cached batches."""
        instance = RankAwareSampler()
        remaining = [0, 1, 2, 3, 4, 5, 6, 7]

        def fetch(group):
            """Sample for ``group`` and drop consumed indexes from the pool.

            The pool update mimics the consumption bookkeeping that the
            TransferQueueController performs between calls.
            """
            nonlocal remaining
            picked, used = instance.sample(
                remaining, 2, dp_group=group, dp_world_size=4, world_size=8
            )
            remaining = [idx for idx in remaining if idx not in used]
            return picked, used

        # First fetch of each group performs real sampling.
        g0_first = fetch(0)
        g1_first = fetch(1)
        assert g0_first[0] == [0, 1]
        assert g1_first[0] == [2, 3]
        assert g0_first[1] == [0, 1]
        assert g1_first[1] == [2, 3]

        # Second fetch of group 0 is served from the cache.
        picked, used = fetch(0)
        assert picked == [0, 1]
        assert used == [0, 1]

        # Second fetch of group 1 is served from the cache.
        picked, used = fetch(1)
        assert picked == [2, 3]
        assert used == [2, 3]

        # Group 0 starts a new round and therefore gets new data.
        picked, used = fetch(0)
        assert picked == [4, 5]
        assert used == [4, 5]

        # The matching second fetch of that round is cached again.
        picked, used = fetch(0)
        assert picked == [4, 5]
        assert used == [4, 5]

        # Every group has completed its rounds, so nothing stays cached.
        assert instance._states == {}

    def test_rank_aware_sampler_empty_ready_indexes(self):
        """Nothing is sampled when no index is ready."""
        instance = RankAwareSampler()

        picked, used = instance.sample([], 3, dp_group=0, dp_world_size=2, world_size=2)

        assert picked == []
        assert used == []

    def test_rank_aware_sampler_batch_size_larger_than_ready(self):
        """Nothing is sampled when fewer indexes are ready than requested."""
        instance = RankAwareSampler()

        picked, used = instance.sample([0, 1], 5, dp_group=0, dp_world_size=2, world_size=2)

        assert picked == []
        assert used == []

    def test_rank_aware_sampler_zero_batch_size(self):
        """A zero batch size yields empty results."""
        instance = RankAwareSampler()

        picked, used = instance.sample([0, 1, 2, 3], 0, dp_group=0, dp_world_size=2, world_size=2)

        assert picked == []
        assert used == []
class RankAwareSampler(BaseSampler):
    """Rank-aware sampler for distributed training with TransferQueue.

    This sampler is designed for distributed data parallel training scenarios
    where each rank retrieves data independently. Callers that pass the same
    ``dp_group`` receive the same sample indices.

    The sampler keeps one cache entry per ``dp_group`` in ``self._states``:

    - The first call for a ``dp_group`` samples from ``ready_indexes`` and
      caches the result together with a fetch counter.
    - Later calls for the same ``dp_group`` return the cached indices and
      increment the counter.
    - Once the counter reaches ``world_size // dp_world_size`` the cache entry
      is removed, so the next call for that ``dp_group`` samples again.

    Please refer to our roadmap for more details:
    [Roadmap] StreamingDataLoader for task-separated RL post-training
    https://github.com/Ascend/TransferQueue/issues/1
    """

    def __init__(self):
        """Initialize the RankAwareSampler with the base sampler's empty state."""
        super().__init__()

    def sample(
        self,
        ready_indexes: list[int],
        batch_size: int,
        dp_group: int,
        dp_world_size: int,
        world_size: int,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[list[int], list[int]]:
        """Sample indices for the current rank, coordinating with its DP group.

        Args:
            ready_indexes: List of global indices for which all required fields of the
                corresponding samples have been produced, and the samples are not labeled
                as consumed in the corresponding task.
            batch_size: Number of samples to select. When a new batch has to be
                sampled and fewer than ``batch_size`` indices are ready, or
                ``batch_size`` is not positive, nothing is sampled and two
                empty lists are returned.
            dp_group: The id of the current data parallel group; used as the
                key of the per-group cache.
            dp_world_size: Data parallel size. Together with ``world_size`` it
                determines how many fetches share one sampled batch
                (``world_size // dp_world_size``).
            world_size: Total number of ranks across all parallelism dimensions.
            *args: Additional positional arguments (ignored).
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            A tuple ``(sampled_indexes, consumed_indexes)``. Both lists hold
            the same indices: either a batch of length ``batch_size`` or
            nothing at all. Each list is a fresh copy, so callers may mutate
            them freely.

        Raises:
            RuntimeError: If ``dp_world_size`` is not positive or ``world_size``
                is not divisible by ``dp_world_size``.
        """
        # Validate the topology first so a bad configuration never leaves
        # partially initialized state behind.
        if dp_world_size <= 0 or world_size % dp_world_size != 0:
            raise RuntimeError(f"world_size ({world_size}) is not divisible by dp_world_size ({dp_world_size})")

        # Number of calls that must be served from one sampled batch.
        fetches_per_batch = world_size // dp_world_size

        state = self._states.get(dp_group)
        if state is None:
            # Reject non-positive sizes up front: a negative size would slice
            # from the end of ready_indexes, and caching an empty batch would
            # hand that empty batch to the next real request of this group.
            if batch_size <= 0:
                return [], []

            batch = ready_indexes[:batch_size]
            if len(batch) < batch_size:
                # Not enough ready samples yet; do not cache anything.
                return [], []

            state = {"index": batch, "fetch_count": 1}
            self._states[dp_group] = state
        else:
            # Serve the cached batch and record this fetch.
            state["fetch_count"] += 1

        # Drop the cache once every expected fetch has been served.
        if state["fetch_count"] >= fetches_per_batch:
            del self._states[dp_group]

        cached = state["index"]
        # Hand out independent copies so a caller mutating its result cannot
        # change what later ranks of the same DP group receive.
        return list(cached), list(cached)