From 86fef46dd26f775fbf9a43e317f0d9a7493f0386 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 20 Jan 2026 11:10:09 +0800 Subject: [PATCH 1/5] implement RankAwareSampler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 0oshowero0 Co-authored-by: ji-huazhong Co-authored-by: baymax591 Co-authored-by: jianjunzhong Co-authored-by: LLLLxmmm Co-authored-by: dpj135 <958208521@qq.com> Co-authored-by: Evelynn-V Co-authored-by: liujia7 Co-authored-by: 赵海源 Co-authored-by: NINGBENZHE --- tests/test_samplers.py | 163 ++++++++++++++++++- transfer_queue/__init__.py | 5 + transfer_queue/sampler/__init__.py | 3 +- transfer_queue/sampler/base.py | 6 +- transfer_queue/sampler/rank_aware_sampler.py | 131 +++++++++++++++ 5 files changed, 302 insertions(+), 6 deletions(-) create mode 100644 transfer_queue/sampler/rank_aware_sampler.py diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 628febc..c594724 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -15,13 +15,20 @@ """Unit tests for TransferQueue samplers.""" +import sys +from pathlib import Path from typing import Any import pytest -from transfer_queue.sampler import BaseSampler -from transfer_queue.sampler.grpo_group_n_sampler import GRPOGroupNSampler -from transfer_queue.sampler.sequential_sampler import SequentialSampler +# Setup path +parent_dir = Path(__file__).resolve().parent.parent +sys.path.append(str(parent_dir)) + +from transfer_queue.sampler import BaseSampler # noqa: E402 +from transfer_queue.sampler.grpo_group_n_sampler import GRPOGroupNSampler # noqa: E402 +from transfer_queue.sampler.rank_aware_sampler import RankAwareSampler # noqa: E402 +from transfer_queue.sampler.sequential_sampler import SequentialSampler # noqa: E402 class TestBaseSampler: @@ -427,6 +434,156 @@ def test_grpo_sampler_insufficient_groups(self): assert consumed == [] +class TestRankAwareSampler: + """Test cases for RankAwareSampler.""" + + def 
test_rank_aware_sampler_initialization(self): + """Test RankAwareSampler initialization.""" + sampler = RankAwareSampler() + assert isinstance(sampler, BaseSampler) + assert hasattr(sampler, "_states") + assert sampler._states == {} + + def test_rank_aware_sampler_first_rank_sampling(self): + """Test that first rank in DP group performs actual sampling.""" + sampler = RankAwareSampler() + ready_indexes = [0, 1, 2, 3, 4, 5] + batch_size = 3 + + # When world_size == dp_world_size, fetches_per_batch = 1 + # First rank samples and immediately marks consumed (no other ranks to wait for) + sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2) + + assert sampled == [0, 1, 2] + # consumed is returned + assert consumed == [0, 1, 2] + assert len(sampled) == batch_size + # State should be cleaned up + assert sampler._states == {} + + def test_rank_aware_sampler_second_rank_gets_cached(self): + """Test that second rank in DP group gets cached indices.""" + sampler = RankAwareSampler() + ready_indexes = [0, 1, 2, 3, 4, 5] + batch_size = 3 + dp_world_size = 2 + world_size = 4 # Use world_size=4 so fetches_per_batch=2 + + # Rank 0 (dp_group=0) samples first + sampled1, consumed1 = sampler.sample( + ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ) + + # Rank 1 (dp_group=0) should get same cached indices + sampled2, consumed2 = sampler.sample( + ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ) + + assert sampled1 == sampled2 == [0, 1, 2] + # First rank returns empty consumed (not all ranks have fetched yet) + assert consumed1 == [0, 1, 2] + # Last rank returns consumed when all ranks have fetched + assert consumed2 == [0, 1, 2] + # State should be cleaned up + assert sampler._states == {} + + def test_rank_aware_sampler_multiple_dp_groups(self): + """Test that multiple DP groups work independently.""" + sampler = RankAwareSampler() + 
ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] + batch_size = 2 + dp_world_size = 4 + world_size = 8 + + # DP group 0: rank 0 samples first + sampled0_g0, consumed0_g0 = sampler.sample( + ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ) + # minic the consumption status update managed in TransferQueueController + ready_indexes = [i for i in ready_indexes if i not in consumed0_g0] + + # DP group 1: rank 0 samples first + sampled0_g1, consumed0_g1 = sampler.sample( + ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size + ) + ready_indexes = [i for i in ready_indexes if i not in consumed0_g1] + + # Both should have sampled their first batch + assert sampled0_g0 == [0, 1] + assert sampled0_g1 == [2, 3] + assert consumed0_g0 == [0, 1] + assert consumed0_g1 == [2, 3] + + # DP group 0: rank 1 fetches cached, and all the data should be labeled as consumed + sampled1_g0, consumed1_g0 = sampler.sample( + ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ) + ready_indexes = [i for i in ready_indexes if i not in consumed1_g0] + assert sampled1_g0 == [0, 1] + assert consumed1_g0 == [0, 1] + + # DP group 1: rank 1 fetches cached, and all the data should be labeled as consumed + sampled1_g1, consumed1_g1 = sampler.sample( + ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size + ) + ready_indexes = [i for i in ready_indexes if i not in consumed1_g1] + assert sampled1_g1 == [2, 3] + assert consumed1_g1 == [2, 3] + + # DP group 0: rank 0 fetches again, this should return new data + sampled2_g0, consumed2_g0 = sampler.sample( + ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ) + ready_indexes = [i for i in ready_indexes if i not in consumed2_g0] + assert sampled2_g0 == [4, 5] + assert consumed2_g0 == [4, 5] + + # DP group 0: rank 1 fetches cached + sampled3_g0, consumed3_g0 = 
sampler.sample( + ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ) + assert sampled3_g0 == [4, 5] + assert consumed3_g0 == [4, 5] + + # Both groups should be cleaned up + assert sampler._states == {} + + def test_rank_aware_sampler_empty_ready_indexes(self): + """Test behavior with empty ready indexes.""" + sampler = RankAwareSampler() + ready_indexes = [] + batch_size = 3 + + sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2) + + assert sampled == [] + assert consumed == [] + + def test_rank_aware_sampler_batch_size_larger_than_ready(self): + """Test behavior when batch_size > len(ready_indexes).""" + sampler = RankAwareSampler() + ready_indexes = [0, 1] + batch_size = 5 + + # When world_size == dp_world_size, fetches_per_batch=1, consumed returned immediately + sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2) + + assert sampled == [0, 1] + assert consumed == [0, 1] + assert len(sampled) == len(ready_indexes) + + def test_rank_aware_sampler_zero_batch_size(self): + """Test behavior with zero batch size.""" + sampler = RankAwareSampler() + ready_indexes = [0, 1, 2, 3] + batch_size = 0 + + sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2) + + assert sampled == [] + assert consumed == [] + + class TestSamplerIntegration: """Integration tests for samplers.""" diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 167a26e..4f541a5 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -24,8 +24,10 @@ from .metadata import BatchMeta from .sampler import BaseSampler from .sampler.grpo_group_n_sampler import GRPOGroupNSampler +from .sampler.rank_aware_sampler import RankAwareSampler from .sampler.sequential_sampler import SequentialSampler from .storage import SimpleStorageUnit +from .streaming_dataloader import 
StreamDataLoader, StreamingDataset from .utils.utils import get_placement_group from .utils.zmq_utils import ZMQServerInfo @@ -41,6 +43,9 @@ "BaseSampler", "GRPOGroupNSampler", "SequentialSampler", + "RankAwareSampler", + "StreamingDataset", + "StreamDataLoader", ] version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) diff --git a/transfer_queue/sampler/__init__.py b/transfer_queue/sampler/__init__.py index c2ae449..26532a2 100644 --- a/transfer_queue/sampler/__init__.py +++ b/transfer_queue/sampler/__init__.py @@ -15,6 +15,7 @@ from .base import BaseSampler from .grpo_group_n_sampler import GRPOGroupNSampler +from .rank_aware_sampler import RankAwareSampler from .sequential_sampler import SequentialSampler -__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler"] +__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler", "RankAwareSampler"] diff --git a/transfer_queue/sampler/base.py b/transfer_queue/sampler/base.py index 3d307ff..69677ae 100644 --- a/transfer_queue/sampler/base.py +++ b/transfer_queue/sampler/base.py @@ -34,13 +34,15 @@ class BaseSampler(ABC): - **SequentialSampler**: Default sampler, selects samples sequentially without replacement - **GRPOGroupNSampler**: A sampler that performs sampling on continuous N samples only when all of them are ready. It assumes the N samples associated with the same prompt are stored contiguously - - **RankAwareSampler**: Rank-aware sampling for distributed scenarios (TODO) + - **RankAwareSampler**: Rank-aware sampling for distributed training where each ranks independently retrieve data + by themselves. This sampler will guarantee ranks of the same DP group consume identical + samples. NOTE: Always return both sampled and consumed indexes (may be identical). 
""" def __init__(self): - self._states: dict[str, Any] = {} + self._states: dict[Any, Any] = {} @abstractmethod def sample( diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py new file mode 100644 index 0000000..f504004 --- /dev/null +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -0,0 +1,131 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from transfer_queue.sampler import BaseSampler + + +class RankAwareSampler(BaseSampler): + """Rank-aware sampler for distributed training with TransferQueue. + + This sampler is designed for distributed data parallel training scenarios + where each ranks independently retrieve data by themselves. + + Each rank independently calls the sampler, passing its own rank information, + and the sampler guarantees that all ranks within the same DP group receive + the same sample indices. 
+ + The sampler maintains per-DP-group state to coordinate sampling across ranks: + + - First rank in a DP group to call :meth:`sample` performs actual sampling from + ``ready_indexes`` and caches the result + - Subsequent ranks in the same DP group retrieve the cached indices + - Once all ranks in the DP group have fetched their samples, the indices are + marked as consumed + + + Please refer to our roadmap for more details: + [Roadmap] StreamingDataLoader for task-separated RL post-training + https://github.com/Ascend/TransferQueue/issues/1 + """ + + def __init__(self): + """Initialize the RankAwareSampler. + + The sampler maintains internal state to coordinate sampling across ranks + within the same DP group. This state tracks which samples have been sampled + and how many times they have been fetched. + """ + super().__init__() + + def sample( + self, + ready_indexes: list[int], + batch_size: int, + dp_group: int, + dp_world_size: int, + world_size: int, + *args: Any, + **kwargs: Any, + ) -> tuple[list[int], list[int]]: + """Sample indices for the current rank, coordinating with other DP ranks. + + This method implements coordinated sampling for distributed training. + The first rank in each DP group to call this method performs actual sampling + from ``ready_indexes`` and caches the result. Subsequent ranks in the same + DP group receive the cached indices directly. + + Args: + ready_indexes: List of global indices for which all required fields of the + corresponding samples have been produced, and the samples are not labeled + as consumed in the corresponding task. + batch_size: batch_size: Number of samples to select. If larger than available + ready samples, all available samples will be returned. + dp_group: The group id of current data parallel group. Used to + identify which DP group this rank belongs to. + dp_world_size: Number of ranks in the data parallel group. Used to + determine when all ranks have fetched their samples. 
+ world_size: Total number of ranks across all parallelism dimensions. + Used to determine when all ranks have fetched their samples. + *args: Additional positional arguments (ignored). + **kwargs: Additional keyword arguments (ignored). + + Returns: + List of sampled global indices of length batch_size + + List of global indices of length batch_size that should be labeled as consumed + (will never be retrieved in the future) + + Raises: + RuntimeError: If the fetch count exceeds the expected number of + fetches per DP group. + + Note: + The ``world_size // dp_world_size`` calculation determines how many + times each batch should be fetched (once per TP/PP/... rank group). + """ + + # Check if this DP group already has sampled data cached + data_for_dp_group = self._states.get(dp_group, None) + + # Calculate how many times this batch should be fetched across all ranks + fetches_per_batch = world_size // dp_world_size + + if data_for_dp_group is None: + # Initialize state for this DP group + self._states[dp_group] = {} + + # Select first batch_size indices from ready_indexes + sampled_indexes = ready_indexes[:batch_size] + consumed_indexes = sampled_indexes + + # Cache the sampled indices for other ranks in this DP group + self._states[dp_group]["index"] = sampled_indexes + self._states[dp_group]["fetch_count"] = 1 + + else: + # Return the cached indices (identical to what first rank received) + sampled_indexes = self._states[dp_group]["index"] + consumed_indexes = self._states[dp_group]["index"] + + # Increment fetch count to track progress + self._states[dp_group]["fetch_count"] += 1 + + # Check if this was the last rank in the DP group to fetch + if self._states[dp_group]["fetch_count"] >= fetches_per_batch: + del self._states[dp_group] + + return sampled_indexes, consumed_indexes From fe8c9ccde350c4e309842facd5963b9aba4b7600 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 20 Jan 2026 11:52:52 +0800 Subject: [PATCH 2/5] fix Signed-off-by: 0oshowero0 --- 
transfer_queue/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/transfer_queue/__init__.py b/transfer_queue/__init__.py index 4f541a5..6569ada 100644 --- a/transfer_queue/__init__.py +++ b/transfer_queue/__init__.py @@ -27,7 +27,6 @@ from .sampler.rank_aware_sampler import RankAwareSampler from .sampler.sequential_sampler import SequentialSampler from .storage import SimpleStorageUnit -from .streaming_dataloader import StreamDataLoader, StreamingDataset from .utils.utils import get_placement_group from .utils.zmq_utils import ZMQServerInfo @@ -44,8 +43,6 @@ "GRPOGroupNSampler", "SequentialSampler", "RankAwareSampler", - "StreamingDataset", - "StreamDataLoader", ] version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) From 6a638fc4ca56a135a46cf0426f319988b9ead18f Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 20 Jan 2026 14:13:08 +0800 Subject: [PATCH 3/5] fix grammar & add more checks Signed-off-by: 0oshowero0 Co-authored-by: zhabuye <2947436155@qq.com> --- tests/test_samplers.py | 2 +- transfer_queue/sampler/base.py | 5 ++--- transfer_queue/sampler/rank_aware_sampler.py | 15 ++++++++------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index c594724..5e3cc7c 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -499,7 +499,7 @@ def test_rank_aware_sampler_multiple_dp_groups(self): sampled0_g0, consumed0_g0 = sampler.sample( ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size ) - # minic the consumption status update managed in TransferQueueController + # mimic the consumption status update managed in TransferQueueController ready_indexes = [i for i in ready_indexes if i not in consumed0_g0] # DP group 1: rank 0 samples first diff --git a/transfer_queue/sampler/base.py b/transfer_queue/sampler/base.py index 69677ae..2cdd1b5 100644 --- a/transfer_queue/sampler/base.py +++ b/transfer_queue/sampler/base.py @@ -34,9
+34,8 @@ class BaseSampler(ABC): - **SequentialSampler**: Default sampler, selects samples sequentially without replacement - **GRPOGroupNSampler**: A sampler that performs sampling on continuous N samples only when all of them are ready. It assumes the N samples associated with the same prompt are stored contiguously - - **RankAwareSampler**: Rank-aware sampling for distributed training where each ranks independently retrieve data - by themselves. This sampler will guarantee ranks of the same DP group consume identical - samples. + - **RankAwareSampler**: Rank-aware sampling for distributed training where each rank retrieves data independently. + This sampler will guarantee ranks of the same DP group consume identical samples. NOTE: Always return both sampled and consumed indexes (may be identical). """ diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index f504004..bb89089 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -22,10 +22,9 @@ class RankAwareSampler(BaseSampler): """Rank-aware sampler for distributed training with TransferQueue. This sampler is designed for distributed data parallel training scenarios - where each ranks independently retrieve data by themselves. + where each rank retrieves data independently. - Each rank independently calls the sampler, passing its own rank information, - and the sampler guarantees that all ranks within the same DP group receive + This sampler guarantees that all ranks within the same DP group receive the same sample indices. The sampler maintains per-DP-group state to coordinate sampling across ranks: @@ -72,7 +71,7 @@ def sample( ready_indexes: List of global indices for which all required fields of the corresponding samples have been produced, and the samples are not labeled as consumed in the corresponding task. - batch_size: batch_size: Number of samples to select. 
If larger than available + batch_size: Number of samples to select. If larger than available ready samples, all available samples will be returned. dp_group: The group id of current data parallel group. Used to identify which DP group this rank belongs to. @@ -89,9 +88,8 @@ def sample( List of global indices of length batch_size that should be labeled as consumed (will never be retrieved in the future) - Raises: - RuntimeError: If the fetch count exceeds the expected number of - fetches per DP group. + Raise: + RuntimeError: If ``world_size`` is not divisible by ``dp_world_size``. Note: The ``world_size // dp_world_size`` calculation determines how many @@ -102,6 +100,9 @@ def sample( data_for_dp_group = self._states.get(dp_group, None) # Calculate how many times this batch should be fetched across all ranks + if world_size % dp_world_size != 0: + raise RuntimeError(f"world_size ({world_size}) is not divisible by dp_world_size ({dp_world_size})") + fetches_per_batch = world_size // dp_world_size if data_for_dp_group is None: From 0af123b2e0b81f4948dca6cd057d5b4529702a14 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 20 Jan 2026 14:29:50 +0800 Subject: [PATCH 4/5] fix Signed-off-by: 0oshowero0 --- tests/test_samplers.py | 4 ++-- transfer_queue/sampler/rank_aware_sampler.py | 12 +++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 5e3cc7c..d8cd1d6 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -480,9 +480,9 @@ def test_rank_aware_sampler_second_rank_gets_cached(self): ) assert sampled1 == sampled2 == [0, 1, 2] - # First rank returns empty consumed (not all ranks have fetched yet) + # First rank already returns consumed indexes assert consumed1 == [0, 1, 2] - # Last rank returns consumed when all ranks have fetched + # Second rank also sees the same consumed indexes; state is then cleaned up assert consumed2 == [0, 1, 2] # State should be cleaned up assert 
sampler._states == {} diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index bb89089..77e64b8 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -75,7 +75,7 @@ def sample( ready samples, all available samples will be returned. dp_group: The group id of current data parallel group. Used to identify which DP group this rank belongs to. - dp_world_size: Number of ranks in the data parallel group. Used to + dp_world_size: Number of ranks in the data parallelism group. Used to determine when all ranks have fetched their samples. world_size: Total number of ranks across all parallelism dimensions. Used to determine when all ranks have fetched their samples. @@ -83,12 +83,14 @@ def sample( **kwargs: Additional keyword arguments (ignored). Returns: - List of sampled global indices of length batch_size + List of sampled global indices. The length is + min(batch_size, len(ready_indexes)), and may be smaller than + batch_size if fewer ready samples are available. - List of global indices of length batch_size that should be labeled as consumed - (will never be retrieved in the future) + List of global indices that should be labeled as consumed + (will never be retrieved by other dp_groups in the future). - Raise: + Raises: RuntimeError: If ``world_size`` is not divisible by ``dp_world_size``. 
Note: From 3b6eff26da37bcab471555c93b3613d34aa81846 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 20 Jan 2026 14:44:35 +0800 Subject: [PATCH 5/5] change behavior when batch_size > ready_indexes Signed-off-by: 0oshowero0 --- tests/test_samplers.py | 5 ++-- transfer_queue/sampler/rank_aware_sampler.py | 24 +++++++++----------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index d8cd1d6..0d665b6 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -568,9 +568,8 @@ def test_rank_aware_sampler_batch_size_larger_than_ready(self): # When world_size == dp_world_size, fetches_per_batch=1, consumed returned immediately sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2) - assert sampled == [0, 1] - assert consumed == [0, 1] - assert len(sampled) == len(ready_indexes) + assert sampled == [] + assert consumed == [] def test_rank_aware_sampler_zero_batch_size(self): """Test behavior with zero batch size.""" diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index 77e64b8..b369ca5 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -32,8 +32,8 @@ class RankAwareSampler(BaseSampler): - First rank in a DP group to call :meth:`sample` performs actual sampling from ``ready_indexes`` and caches the result - Subsequent ranks in the same DP group retrieve the cached indices - - Once all ranks in the DP group have fetched their samples, the indices are - marked as consumed + - Once all ranks in the DP group have fetched their samples, the cached state is + cleaned up. Please refer to our roadmap for more details: @@ -83,36 +83,34 @@ def sample( **kwargs: Additional keyword arguments (ignored). Returns: - List of sampled global indices. 
The length is - min(batch_size, len(ready_indexes)), and may be smaller than - batch_size if fewer ready samples are available. + List of sampled global indices. Typically, has length `batch_size`, + or returns an empty list if samples are insufficient. List of global indices that should be labeled as consumed (will never be retrieved by other dp_groups in the future). Raises: RuntimeError: If ``world_size`` is not divisible by ``dp_world_size``. - - Note: - The ``world_size // dp_world_size`` calculation determines how many - times each batch should be fetched (once per TP/PP/... rank group). """ # Check if this DP group already has sampled data cached data_for_dp_group = self._states.get(dp_group, None) # Calculate how many times this batch should be fetched across all ranks - if world_size % dp_world_size != 0: + if dp_world_size <= 0 or world_size % dp_world_size != 0: raise RuntimeError(f"world_size ({world_size}) is not divisible by dp_world_size ({dp_world_size})") fetches_per_batch = world_size // dp_world_size if data_for_dp_group is None: - # Initialize state for this DP group - self._states[dp_group] = {} - # Select first batch_size indices from ready_indexes sampled_indexes = ready_indexes[:batch_size] + + if len(sampled_indexes) < batch_size: + return [], [] + + # Initialize state for this DP group + self._states[dp_group] = {} consumed_indexes = sampled_indexes # Cache the sampled indices for other ranks in this DP group