From 1a75b60f85d0cb941997fbfe8c033a118f566970 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 18:01:02 +0800 Subject: [PATCH 1/7] fix for async retrieval and data preload Signed-off-by: 0oshowero0 --- transfer_queue/sampler/rank_aware_sampler.py | 85 +++++++++++--------- 1 file changed, 48 insertions(+), 37 deletions(-) diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index b369ca5..a44c4cb 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -48,15 +48,18 @@ def __init__(self): within the same DP group. This state tracks which samples have been sampled and how many times they have been fetched. """ + super().__init__() def sample( self, ready_indexes: list[int], batch_size: int, - dp_group: int, - dp_world_size: int, - world_size: int, + data_replica_group: int, + data_replica_rank: int, + data_replica_world_size: int, + task_name: str, + partition_id: str, *args: Any, **kwargs: Any, ) -> tuple[list[int], list[int]]: @@ -67,66 +70,74 @@ def sample( from ``ready_indexes`` and caches the result. Subsequent ranks in the same DP group receive the cached indices directly. + Internal state structure (self._states): + + .. code-block:: python + + self._states = { + "partition_id": { + "task_name": { + data_replica_group: { + data_replica_rank: [sampled_indexes] # Cached sampled indices + } + } + } + } + + State lifecycle: + 1. First rank samples from ``ready_indexes``, caches results for other ranks + 2. Other ranks pop and retrieve the cached indices + Args: ready_indexes: List of global indices for which all required fields of the corresponding samples have been produced, and the samples are not labeled as consumed in the corresponding task. batch_size: Number of samples to select. If larger than available ready samples, all available samples will be returned. - dp_group: The group id of current data parallel group. 
Used to - identify which DP group this rank belongs to. - dp_world_size: Number of ranks in the data parallelism group. Used to - determine when all ranks have fetched their samples. - world_size: Total number of ranks across all parallelism dimensions. - Used to determine when all ranks have fetched their samples. + data_replica_group: The group id of current data replica group. Used to + identify which data replica group this rank belongs to. + data_replica_rank: Local rank inside this data_replica_group. + data_replica_world_size: Total number of ranks in this data_replica_group. + task_name: Identifier for the task. + partition_id: Partition ID for data management. *args: Additional positional arguments (ignored). **kwargs: Additional keyword arguments (ignored). Returns: - List of sampled global indices. Typically, has length `batch_size`, - or returns an empty list if samples are insufficient. + Tuple of two lists: + - List of sampled global indices. Typically, has length ``batch_size``, + or empty if samples are insufficient. + - List of global indices to mark as consumed (excluded from future + retrieval by other data_replica_groups). - List of global indices that should be labeled as consumed - (will never be retrieved by other dp_groups in the future). - - Raises: - RuntimeError: If ``world_size`` is not divisible by ``dp_world_size``. 
""" - # Check if this DP group already has sampled data cached - data_for_dp_group = self._states.get(dp_group, None) + if partition_id not in self._states: + self._states[partition_id] = {} - # Calculate how many times this batch should be fetched across all ranks - if dp_world_size <= 0 or world_size % dp_world_size != 0: - raise RuntimeError(f"world_size ({world_size}) is not divisible by dp_world_size ({dp_world_size})") + if task_name not in self._states[partition_id]: + self._states[partition_id][task_name] = {} - fetches_per_batch = world_size // dp_world_size + if data_replica_group not in self._states[partition_id][task_name]: + self._states[partition_id][task_name][data_replica_group] = {i: [] for i in range(data_replica_world_size)} - if data_for_dp_group is None: + if len(self._states[partition_id][task_name][data_replica_group][data_replica_rank]) == 0: # Select first batch_size indices from ready_indexes sampled_indexes = ready_indexes[:batch_size] if len(sampled_indexes) < batch_size: return [], [] - # Initialize state for this DP group - self._states[dp_group] = {} consumed_indexes = sampled_indexes - # Cache the sampled indices for other ranks in this DP group - self._states[dp_group]["index"] = sampled_indexes - self._states[dp_group]["fetch_count"] = 1 + # Cache the sampled indices for other ranks in this data replica group + for i in range(data_replica_world_size): + if i != data_replica_rank: + self._states[partition_id][task_name][data_replica_group][i].append(sampled_indexes) else: # Return the cached indices (identical to what first rank received) - sampled_indexes = self._states[dp_group]["index"] - consumed_indexes = self._states[dp_group]["index"] - - # Increment fetch count to track progress - self._states[dp_group]["fetch_count"] += 1 - - # Check if this was the last rank in the DP group to fetch - if self._states[dp_group]["fetch_count"] >= fetches_per_batch: - del self._states[dp_group] + sampled_indexes = 
self._states[partition_id][task_name][data_replica_group][data_replica_rank].pop() + consumed_indexes = sampled_indexes return sampled_indexes, consumed_indexes From 2f81d3f1cad4c3631a56f5036cbadaaadff9611f Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 18:32:14 +0800 Subject: [PATCH 2/7] update CI Signed-off-by: 0oshowero0 --- tests/test_samplers.py | 249 ++++++++++++++++--- transfer_queue/sampler/rank_aware_sampler.py | 2 +- 2 files changed, 211 insertions(+), 40 deletions(-) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 0d665b6..4cf9864 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -445,66 +445,93 @@ def test_rank_aware_sampler_initialization(self): assert sampler._states == {} def test_rank_aware_sampler_first_rank_sampling(self): - """Test that first rank in DP group performs actual sampling.""" + """Test that first rank in data replica group performs actual sampling.""" sampler = RankAwareSampler() ready_indexes = [0, 1, 2, 3, 4, 5] batch_size = 3 - # When world_size == dp_world_size, fetches_per_batch = 1 - # First rank samples and immediately marks consumed (no other ranks to wait for) - sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2) + # Rank 0 (first in group) samples and caches for all ranks + # Since rank 1 will call next, state is kept until rank 1 fetches + sampled, consumed = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="test", + partition_id="test", + ) assert sampled == [0, 1, 2] - # consumed is returned assert consumed == [0, 1, 2] assert len(sampled) == batch_size - # State should be cleaned up - assert sampler._states == {} + # State is kept for other ranks to fetch def test_rank_aware_sampler_second_rank_gets_cached(self): - """Test that second rank in DP group gets cached indices.""" + """Test that second rank in data replica group gets 
cached indices.""" sampler = RankAwareSampler() ready_indexes = [0, 1, 2, 3, 4, 5] batch_size = 3 - dp_world_size = 2 - world_size = 4 # Use world_size=4 so fetches_per_batch=2 - # Rank 0 (dp_group=0) samples first + # Rank 0 (first in group) samples first sampled1, consumed1 = sampler.sample( - ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="test", + partition_id="test", ) - # Rank 1 (dp_group=0) should get same cached indices + # Rank 1 (second in group) should get same cached indices sampled2, consumed2 = sampler.sample( - ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=1, + data_replica_world_size=2, + task_name="test", + partition_id="test", ) assert sampled1 == sampled2 == [0, 1, 2] - # First rank already returns consumed indexes assert consumed1 == [0, 1, 2] - # Second rank also sees the same consumed indexes; state is then cleaned up assert consumed2 == [0, 1, 2] - # State should be cleaned up - assert sampler._states == {} + + # cache should be empty after all ranks fetch + assert len(sampler._states["test"]["test"][0][0]) == 0 + assert len(sampler._states["test"]["test"][0][1]) == 0 def test_rank_aware_sampler_multiple_dp_groups(self): - """Test that multiple DP groups work independently.""" + """Test that multiple data replica groups work independently.""" sampler = RankAwareSampler() ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] batch_size = 2 - dp_world_size = 4 - world_size = 8 + data_replica_world_size = 2 # Each group has 2 ranks - # DP group 0: rank 0 samples first + # data replica group 0: rank 0 samples first sampled0_g0, consumed0_g0 = sampler.sample( - ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + 
data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=data_replica_world_size, + task_name="test", + partition_id="test", ) # mimic the consumption status update managed in TransferQueueController ready_indexes = [i for i in ready_indexes if i not in consumed0_g0] - # DP group 1: rank 0 samples first + # data replica group 1: rank 0 samples first sampled0_g1, consumed0_g1 = sampler.sample( - ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=1, + data_replica_rank=0, + data_replica_world_size=data_replica_world_size, + task_name="test", + partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed0_g1] @@ -514,39 +541,66 @@ def test_rank_aware_sampler_multiple_dp_groups(self): assert consumed0_g0 == [0, 1] assert consumed0_g1 == [2, 3] - # DP group 0: rank 1 fetches cached, and all the data should be labeled as consumed + # data replica group 0: rank 1 fetches cached sampled1_g0, consumed1_g0 = sampler.sample( - ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=1, + data_replica_world_size=data_replica_world_size, + task_name="test", + partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed1_g0] assert sampled1_g0 == [0, 1] assert consumed1_g0 == [0, 1] - # DP group 1: rank 1 fetches cached, and all the data should be labeled as consumed + # data replica group 1: rank 1 fetches cached sampled1_g1, consumed1_g1 = sampler.sample( - ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=1, + data_replica_rank=1, + data_replica_world_size=data_replica_world_size, + task_name="test", + partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed1_g1] assert sampled1_g1 == [2, 3] assert 
consumed1_g1 == [2, 3] - # DP group 0: rank 0 fetches again, this should return new data + # data replica group 0: rank 0 fetches again, this should return new data sampled2_g0, consumed2_g0 = sampler.sample( - ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=data_replica_world_size, + task_name="test", + partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed2_g0] assert sampled2_g0 == [4, 5] assert consumed2_g0 == [4, 5] - # DP group 0: rank 1 fetches cached + # data replica group 0: rank 1 fetches cached sampled3_g0, consumed3_g0 = sampler.sample( - ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=1, + data_replica_world_size=data_replica_world_size, + task_name="test", + partition_id="test", ) assert sampled3_g0 == [4, 5] assert consumed3_g0 == [4, 5] - # Both groups should be cleaned up - assert sampler._states == {} + # examine the internal state to ensure proper caching and clearing + assert len(sampler._states["test"]["test"][0][0]) == 0 + assert len(sampler._states["test"]["test"][0][1]) == 0 + assert len(sampler._states["test"]["test"][1][0]) == 0 + assert len(sampler._states["test"]["test"][1][1]) == 0 def test_rank_aware_sampler_empty_ready_indexes(self): """Test behavior with empty ready indexes.""" @@ -554,7 +608,15 @@ def test_rank_aware_sampler_empty_ready_indexes(self): ready_indexes = [] batch_size = 3 - sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2) + sampled, consumed = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="test", + partition_id="test", + ) assert sampled == [] assert consumed == [] @@ -565,8 +627,15 @@ def 
test_rank_aware_sampler_batch_size_larger_than_ready(self): ready_indexes = [0, 1] batch_size = 5 - # When world_size == dp_world_size, fetches_per_batch=1, consumed returned immediately - sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2) + sampled, consumed = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="test", + partition_id="test", + ) assert sampled == [] assert consumed == [] @@ -577,11 +646,113 @@ def test_rank_aware_sampler_zero_batch_size(self): ready_indexes = [0, 1, 2, 3] batch_size = 0 - sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2) + sampled, consumed = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="test", + partition_id="test", + ) assert sampled == [] assert consumed == [] + def test_rank_aware_sampler_data_prefetch(self): + """Test behavior with data prefetch.""" + sampler = RankAwareSampler() + ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] + batch_size = 2 + + sampled_rank0_time0, consumed_rank0_time0 = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="test", + partition_id="test", + ) + + assert sampled_rank0_time0 == [0, 1] + assert consumed_rank0_time0 == [0, 1] + assert sampler._states["test"]["test"][0][0] == [] + assert sampler._states["test"]["test"][0][1] == [[0, 1]] + + ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time0] + + sampled_rank0_time1, consumed_rank0_time1 = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="test", + partition_id="test", + ) + + assert sampled_rank0_time1 == [2, 3] + assert consumed_rank0_time1 == [2, 3] + assert 
sampler._states["test"]["test"][0][0] == [] + assert sampler._states["test"]["test"][0][1] == [[0, 1], [2, 3]] + + ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time1] + + sampled_rank1_time0, consumed_rank1_time0 = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=1, + data_replica_world_size=2, + task_name="test", + partition_id="test", + ) + assert sampled_rank1_time0 == [0, 1] + assert consumed_rank1_time0 == [0, 1] + ready_indexes = [i for i in ready_indexes if i not in consumed_rank1_time0] + + assert sampler._states["test"]["test"][0][0] == [] + assert sampler._states["test"]["test"][0][1] == [[2, 3]] + + def test_rank_aware_sampler_multiple_tasks(self): + """Test behavior with multiple tasks.""" + sampler = RankAwareSampler() + ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7] + batch_size = 2 + + sampled_rank0_task0, consumed_rank0_task0 = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="task0", + partition_id="test", + ) + + assert sampled_rank0_task0 == [0, 1] + assert consumed_rank0_task0 == [0, 1] + assert sampler._states["test"]["task0"][0][0] == [] + assert sampler._states["test"]["task0"][0][1] == [[0, 1]] + + sampled_rank0_task1, consumed_rank0_task1 = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=1, + data_replica_world_size=2, + task_name="task1", + partition_id="test", + ) + + assert sampled_rank0_task1 == [0, 1] + assert consumed_rank0_task1 == [0, 1] + assert sampler._states["test"]["task0"][0][0] == [] + assert sampler._states["test"]["task0"][0][1] == [[0, 1]] + assert sampler._states["test"]["task1"][0][0] == [[0, 1]] + assert sampler._states["test"]["task1"][0][1] == [] + class TestSamplerIntegration: """Integration tests for samplers.""" diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index 
a44c4cb..1984362 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -137,7 +137,7 @@ def sample( else: # Return the cached indices (identical to what first rank received) - sampled_indexes = self._states[partition_id][task_name][data_replica_group][data_replica_rank].pop() + sampled_indexes = self._states[partition_id][task_name][data_replica_group][data_replica_rank].pop(0) consumed_indexes = sampled_indexes return sampled_indexes, consumed_indexes From 0d84f2582db8242ac6943ccdf1da1cb5d38ab9bd Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 18:33:34 +0800 Subject: [PATCH 3/7] fix Signed-off-by: 0oshowero0 --- tests/test_samplers.py | 54 +++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 4cf9864..399797c 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -458,7 +458,7 @@ def test_rank_aware_sampler_first_rank_sampling(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) @@ -480,7 +480,7 @@ def test_rank_aware_sampler_second_rank_gets_cached(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) @@ -491,7 +491,7 @@ def test_rank_aware_sampler_second_rank_gets_cached(self): data_replica_group=0, data_replica_rank=1, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) @@ -500,8 +500,8 @@ def test_rank_aware_sampler_second_rank_gets_cached(self): assert consumed2 == [0, 1, 2] # cache should be empty after all ranks fetch - assert len(sampler._states["test"]["test"][0][0]) == 0 - assert len(sampler._states["test"]["test"][0][1]) == 0 + assert len(sampler._states["test"]["task"][0][0]) == 0 + assert len(sampler._states["test"]["task"][0][1]) == 0 def 
test_rank_aware_sampler_multiple_dp_groups(self): """Test that multiple data replica groups work independently.""" @@ -517,7 +517,7 @@ def test_rank_aware_sampler_multiple_dp_groups(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=data_replica_world_size, - task_name="test", + task_name="task", partition_id="test", ) # mimic the consumption status update managed in TransferQueueController @@ -530,7 +530,7 @@ def test_rank_aware_sampler_multiple_dp_groups(self): data_replica_group=1, data_replica_rank=0, data_replica_world_size=data_replica_world_size, - task_name="test", + task_name="task", partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed0_g1] @@ -548,7 +548,7 @@ def test_rank_aware_sampler_multiple_dp_groups(self): data_replica_group=0, data_replica_rank=1, data_replica_world_size=data_replica_world_size, - task_name="test", + task_name="task", partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed1_g0] @@ -562,7 +562,7 @@ def test_rank_aware_sampler_multiple_dp_groups(self): data_replica_group=1, data_replica_rank=1, data_replica_world_size=data_replica_world_size, - task_name="test", + task_name="task", partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed1_g1] @@ -576,7 +576,7 @@ def test_rank_aware_sampler_multiple_dp_groups(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=data_replica_world_size, - task_name="test", + task_name="task", partition_id="test", ) ready_indexes = [i for i in ready_indexes if i not in consumed2_g0] @@ -590,17 +590,17 @@ def test_rank_aware_sampler_multiple_dp_groups(self): data_replica_group=0, data_replica_rank=1, data_replica_world_size=data_replica_world_size, - task_name="test", + task_name="task", partition_id="test", ) assert sampled3_g0 == [4, 5] assert consumed3_g0 == [4, 5] # examine the internal state to ensure proper caching and clearing - assert 
len(sampler._states["test"]["test"][0][0]) == 0 - assert len(sampler._states["test"]["test"][0][1]) == 0 - assert len(sampler._states["test"]["test"][1][0]) == 0 - assert len(sampler._states["test"]["test"][1][1]) == 0 + assert len(sampler._states["test"]["task"][0][0]) == 0 + assert len(sampler._states["test"]["task"][0][1]) == 0 + assert len(sampler._states["test"]["task"][1][0]) == 0 + assert len(sampler._states["test"]["task"][1][1]) == 0 def test_rank_aware_sampler_empty_ready_indexes(self): """Test behavior with empty ready indexes.""" @@ -614,7 +614,7 @@ def test_rank_aware_sampler_empty_ready_indexes(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) @@ -633,7 +633,7 @@ def test_rank_aware_sampler_batch_size_larger_than_ready(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) @@ -652,7 +652,7 @@ def test_rank_aware_sampler_zero_batch_size(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) @@ -671,14 +671,14 @@ def test_rank_aware_sampler_data_prefetch(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) assert sampled_rank0_time0 == [0, 1] assert consumed_rank0_time0 == [0, 1] - assert sampler._states["test"]["test"][0][0] == [] - assert sampler._states["test"]["test"][0][1] == [[0, 1]] + assert sampler._states["test"]["task"][0][0] == [] + assert sampler._states["test"]["task"][0][1] == [[0, 1]] ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time0] @@ -688,14 +688,14 @@ def test_rank_aware_sampler_data_prefetch(self): data_replica_group=0, data_replica_rank=0, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) assert sampled_rank0_time1 == [2, 3] 
assert consumed_rank0_time1 == [2, 3] - assert sampler._states["test"]["test"][0][0] == [] - assert sampler._states["test"]["test"][0][1] == [[0, 1], [2, 3]] + assert sampler._states["test"]["task"][0][0] == [] + assert sampler._states["test"]["task"][0][1] == [[0, 1], [2, 3]] ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time1] @@ -705,15 +705,15 @@ def test_rank_aware_sampler_data_prefetch(self): data_replica_group=0, data_replica_rank=1, data_replica_world_size=2, - task_name="test", + task_name="task", partition_id="test", ) assert sampled_rank1_time0 == [0, 1] assert consumed_rank1_time0 == [0, 1] ready_indexes = [i for i in ready_indexes if i not in consumed_rank1_time0] - assert sampler._states["test"]["test"][0][0] == [] - assert sampler._states["test"]["test"][0][1] == [[2, 3]] + assert sampler._states["test"]["task"][0][0] == [] + assert sampler._states["test"]["task"][0][1] == [[2, 3]] def test_rank_aware_sampler_multiple_tasks(self): """Test behavior with multiple tasks.""" From f808f78625d5647a1dc73778e6f1585e550f986b Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 18:36:01 +0800 Subject: [PATCH 4/7] add check logics Signed-off-by: 0oshowero0 --- transfer_queue/sampler/rank_aware_sampler.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index 1984362..2b18dfa 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -110,8 +110,17 @@ def sample( - List of global indices to mark as consumed (excluded from future retrieval by other data_replica_groups). + Raises: + ValueError: If ``data_replica_rank`` is invalid. 
+ """ + if data_replica_rank >= data_replica_world_size or data_replica_rank < 0: + raise ValueError( + f"data_replica_rank {data_replica_rank} must bigger than 0 and less than " + f"data_replica_world_size {data_replica_world_size}" + ) + if partition_id not in self._states: self._states[partition_id] = {} From 1e8d0bb5cc83e14b07afa388013090ff5f0d79f6 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 19:09:42 +0800 Subject: [PATCH 5/7] fix Signed-off-by: 0oshowero0 --- transfer_queue/sampler/rank_aware_sampler.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index 2b18dfa..fa2a115 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -24,16 +24,15 @@ class RankAwareSampler(BaseSampler): This sampler is designed for distributed data parallel training scenarios where each rank retrieves data independently. - This sampler guarantees that all ranks within the same DP group receive + This sampler guarantees that all ranks within the same data replica group receive the same sample indices. - The sampler maintains per-DP-group state to coordinate sampling across ranks: + The sampler maintains inner state to coordinate sampling across ranks: - - First rank in a DP group to call :meth:`sample` performs actual sampling from - ``ready_indexes`` and caches the result - - Subsequent ranks in the same DP group retrieve the cached indices - - Once all ranks in the DP group have fetched their samples, the cached state is - cleaned up. + - First rank in a data replica group to call :meth:`sample` performs actual sampling from + ``ready_indexes`` and caches the result for other ranks in the same group + - Subsequent ranks in the same group retrieve the cached indices. + - If no cached indices are available, sampling is performed again and cached for others. 
Please refer to our roadmap for more details: From 6a9dfb34d8bf9cde467ee1c135ad4b3f3b911b9a Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 19:16:58 +0800 Subject: [PATCH 6/7] fix Signed-off-by: 0oshowero0 --- tests/test_samplers.py | 7 +++---- transfer_queue/sampler/rank_aware_sampler.py | 9 ++++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 399797c..d184bba 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -710,7 +710,6 @@ def test_rank_aware_sampler_data_prefetch(self): ) assert sampled_rank1_time0 == [0, 1] assert consumed_rank1_time0 == [0, 1] - ready_indexes = [i for i in ready_indexes if i not in consumed_rank1_time0] assert sampler._states["test"]["task"][0][0] == [] assert sampler._states["test"]["task"][0][1] == [[2, 3]] @@ -740,7 +739,7 @@ def test_rank_aware_sampler_multiple_tasks(self): ready_indexes, batch_size, data_replica_group=0, - data_replica_rank=1, + data_replica_rank=0, data_replica_world_size=2, task_name="task1", partition_id="test", @@ -750,8 +749,8 @@ def test_rank_aware_sampler_multiple_tasks(self): assert consumed_rank0_task1 == [0, 1] assert sampler._states["test"]["task0"][0][0] == [] assert sampler._states["test"]["task0"][0][1] == [[0, 1]] - assert sampler._states["test"]["task1"][0][0] == [[0, 1]] - assert sampler._states["test"]["task1"][0][1] == [] + assert sampler._states["test"]["task1"][0][0] == [] + assert sampler._states["test"]["task1"][0][1] == [[0, 1]] class TestSamplerIntegration: diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index fa2a115..c879556 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -92,7 +92,7 @@ def sample( corresponding samples have been produced, and the samples are not labeled as consumed in the corresponding task. batch_size: Number of samples to select. 
If larger than available - ready samples, all available samples will be returned. + ready samples, no samples are returned and both lists are empty. data_replica_group: The group id of current data replica group. Used to identify which data replica group this rank belongs to. data_replica_rank: Local rank inside this data_replica_group. @@ -110,13 +110,16 @@ def sample( retrieval by other data_replica_groups). Raises: - ValueError: If ``data_replica_rank`` is invalid. + ValueError: If ``data_replica_rank`` or ``data_replica_world_size`` is invalid. """ + if data_replica_world_size < 1: + raise ValueError(f"data_replica_world_size {data_replica_world_size} must >= 1") + if data_replica_rank >= data_replica_world_size or data_replica_rank < 0: raise ValueError( - f"data_replica_rank {data_replica_rank} must bigger than 0 and less than " + f"data_replica_rank {data_replica_rank} must be greater than or equal to 0 and less than " f"data_replica_world_size {data_replica_world_size}" ) From 8cc1281c3e0827a1d06fcb6899c14025ec5794e7 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Thu, 22 Jan 2026 20:26:43 +0800 Subject: [PATCH 7/7] fix Signed-off-by: 0oshowero0 --- transfer_queue/sampler/rank_aware_sampler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index c879556..64531fd 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -44,7 +44,7 @@ def __init__(self): """Initialize the RankAwareSampler. The sampler maintains internal state to coordinate sampling across ranks - within the same DP group. This state tracks which samples have been sampled + within the same data replica group. This state tracks which samples have been sampled and how many times they have been fetched. 
""" @@ -62,12 +62,12 @@ def sample( *args: Any, **kwargs: Any, ) -> tuple[list[int], list[int]]: - """Sample indices for the current rank, coordinating with other DP ranks. + """Sample indices for the current rank, coordinating with other data replica ranks. This method implements coordinated sampling for distributed training. - The first rank in each DP group to call this method performs actual sampling + The first rank in each data replica group to call this method performs actual sampling from ``ready_indexes`` and caches the result. Subsequent ranks in the same - DP group receive the cached indices directly. + data replica group receive the cached indices directly. Internal state structure (self._states): @@ -77,7 +77,7 @@ def sample( "partition_id": { "task_name": { data_replica_group: { - data_replica_rank: [sampled_indexes] # Cached sampled indices + data_replica_rank: [[sampled_indexes], ...] # Buffer of cached sampled indices } } }