162 changes: 159 additions & 3 deletions tests/test_samplers.py
@@ -15,13 +15,20 @@

"""Unit tests for TransferQueue samplers."""

import sys
from pathlib import Path
from typing import Any

import pytest

from transfer_queue.sampler import BaseSampler
from transfer_queue.sampler.grpo_group_n_sampler import GRPOGroupNSampler
from transfer_queue.sampler.sequential_sampler import SequentialSampler
# Setup path
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))
Copilot AI commented on lines +24 to +26 (Jan 20, 2026):

The manual sys.path manipulation is generally an anti-pattern in Python testing. Modern test runners like pytest automatically handle module discovery. This could cause issues in CI/CD environments or when running tests from different directories. Consider removing this manual path setup and relying on proper package installation (e.g., pip install -e .) or pytest's natural module discovery instead.
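If an editable install (pip install -e .) is not always available in CI, one minimal alternative, sketched below as a suggestion rather than as part of this PR, is a single repository-root conftest.py that centralizes the path tweak so individual test modules need no sys.path code:

# Hypothetical conftest.py at the repository root (a sketch, not part of this PR).
# With this in place, the sys.path lines and noqa markers in tests/test_samplers.py
# could be dropped.
import sys
from pathlib import Path

# Make the repository root importable regardless of where pytest is invoked from.
REPO_ROOT = Path(__file__).resolve().parent
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
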
from transfer_queue.sampler import BaseSampler # noqa: E402
from transfer_queue.sampler.grpo_group_n_sampler import GRPOGroupNSampler # noqa: E402
from transfer_queue.sampler.rank_aware_sampler import RankAwareSampler # noqa: E402
from transfer_queue.sampler.sequential_sampler import SequentialSampler # noqa: E402


class TestBaseSampler:
@@ -427,6 +434,155 @@ def test_grpo_sampler_insufficient_groups(self):
assert consumed == []


class TestRankAwareSampler:
"""Test cases for RankAwareSampler."""

def test_rank_aware_sampler_initialization(self):
"""Test RankAwareSampler initialization."""
sampler = RankAwareSampler()
assert isinstance(sampler, BaseSampler)
assert hasattr(sampler, "_states")
assert sampler._states == {}

def test_rank_aware_sampler_first_rank_sampling(self):
"""Test that first rank in DP group performs actual sampling."""
sampler = RankAwareSampler()
ready_indexes = [0, 1, 2, 3, 4, 5]
batch_size = 3

# When world_size == dp_world_size, fetches_per_batch = 1
# First rank samples and immediately marks consumed (no other ranks to wait for)
sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)

assert sampled == [0, 1, 2]
# consumed indexes are returned immediately (fetches_per_batch == 1)
assert consumed == [0, 1, 2]
assert len(sampled) == batch_size
# State should be cleaned up
assert sampler._states == {}

def test_rank_aware_sampler_second_rank_gets_cached(self):
"""Test that second rank in DP group gets cached indices."""
sampler = RankAwareSampler()
ready_indexes = [0, 1, 2, 3, 4, 5]
batch_size = 3
dp_world_size = 2
world_size = 4 # Use world_size=4 so fetches_per_batch=2

# Rank 0 (dp_group=0) samples first
sampled1, consumed1 = sampler.sample(
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
)

# Rank 1 (dp_group=0) should get same cached indices
sampled2, consumed2 = sampler.sample(
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
)

assert sampled1 == sampled2 == [0, 1, 2]
# First rank already returns consumed indexes
assert consumed1 == [0, 1, 2]
# Second rank also sees the same consumed indexes; state is then cleaned up
assert consumed2 == [0, 1, 2]
# State should be cleaned up
assert sampler._states == {}

def test_rank_aware_sampler_multiple_dp_groups(self):
"""Test that multiple DP groups work independently."""
sampler = RankAwareSampler()
ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7]
batch_size = 2
dp_world_size = 4
world_size = 8

# DP group 0: rank 0 samples first
sampled0_g0, consumed0_g0 = sampler.sample(
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
)
# mimic the consumption status update managed in TransferQueueController
ready_indexes = [i for i in ready_indexes if i not in consumed0_g0]

# DP group 1: rank 0 samples first
sampled0_g1, consumed0_g1 = sampler.sample(
ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size
)
ready_indexes = [i for i in ready_indexes if i not in consumed0_g1]

# Both should have sampled their first batch
assert sampled0_g0 == [0, 1]
assert sampled0_g1 == [2, 3]
assert consumed0_g0 == [0, 1]
assert consumed0_g1 == [2, 3]

# DP group 0: rank 1 fetches cached, and all the data should be labeled as consumed
sampled1_g0, consumed1_g0 = sampler.sample(
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
)
ready_indexes = [i for i in ready_indexes if i not in consumed1_g0]
assert sampled1_g0 == [0, 1]
assert consumed1_g0 == [0, 1]

# DP group 1: rank 1 fetches cached, and all the data should be labeled as consumed
sampled1_g1, consumed1_g1 = sampler.sample(
ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size
)
ready_indexes = [i for i in ready_indexes if i not in consumed1_g1]
assert sampled1_g1 == [2, 3]
assert consumed1_g1 == [2, 3]

# DP group 0: rank 0 fetches again, this should return new data
sampled2_g0, consumed2_g0 = sampler.sample(
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
)
ready_indexes = [i for i in ready_indexes if i not in consumed2_g0]
assert sampled2_g0 == [4, 5]
assert consumed2_g0 == [4, 5]

# DP group 0: rank 1 fetches cached
sampled3_g0, consumed3_g0 = sampler.sample(
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
)
assert sampled3_g0 == [4, 5]
assert consumed3_g0 == [4, 5]

# Both groups should be cleaned up
assert sampler._states == {}

def test_rank_aware_sampler_empty_ready_indexes(self):
"""Test behavior with empty ready indexes."""
sampler = RankAwareSampler()
ready_indexes = []
batch_size = 3

sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)

assert sampled == []
assert consumed == []

def test_rank_aware_sampler_batch_size_larger_than_ready(self):
"""Test behavior when batch_size > len(ready_indexes)."""
sampler = RankAwareSampler()
ready_indexes = [0, 1]
batch_size = 5

# batch_size exceeds the available ready samples, so no sampling occurs and empty lists are returned
sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)

assert sampled == []
assert consumed == []

def test_rank_aware_sampler_zero_batch_size(self):
"""Test behavior with zero batch size."""
sampler = RankAwareSampler()
ready_indexes = [0, 1, 2, 3]
batch_size = 0

sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)

assert sampled == []
assert consumed == []


class TestSamplerIntegration:
"""Integration tests for samplers."""

2 changes: 2 additions & 0 deletions transfer_queue/__init__.py
@@ -24,6 +24,7 @@
from .metadata import BatchMeta
from .sampler import BaseSampler
from .sampler.grpo_group_n_sampler import GRPOGroupNSampler
from .sampler.rank_aware_sampler import RankAwareSampler
from .sampler.sequential_sampler import SequentialSampler
from .storage import SimpleStorageUnit
from .utils.utils import get_placement_group
@@ -41,6 +42,7 @@
"BaseSampler",
"GRPOGroupNSampler",
"SequentialSampler",
"RankAwareSampler",
]

version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))
3 changes: 2 additions & 1 deletion transfer_queue/sampler/__init__.py
@@ -15,6 +15,7 @@

from .base import BaseSampler
from .grpo_group_n_sampler import GRPOGroupNSampler
from .rank_aware_sampler import RankAwareSampler
from .sequential_sampler import SequentialSampler

__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler"]
__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler", "RankAwareSampler"]
5 changes: 3 additions & 2 deletions transfer_queue/sampler/base.py
@@ -34,13 +34,14 @@ class BaseSampler(ABC):
- **SequentialSampler**: Default sampler, selects samples sequentially without replacement
- **GRPOGroupNSampler**: A sampler that performs sampling on continuous N samples only when all of them are ready.
It assumes the N samples associated with the same prompt are stored contiguously
- **RankAwareSampler**: Rank-aware sampling for distributed scenarios (TODO)
- **RankAwareSampler**: Rank-aware sampling for distributed training where each rank retrieves data independently.
This sampler guarantees that ranks within the same DP group consume identical samples.

NOTE: Always return both sampled and consumed indexes (may be identical).
"""

def __init__(self):
self._states: dict[str, Any] = {}
self._states: dict[Any, Any] = {}

@abstractmethod
def sample(
132 changes: 132 additions & 0 deletions transfer_queue/sampler/rank_aware_sampler.py
@@ -0,0 +1,132 @@
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from transfer_queue.sampler import BaseSampler


class RankAwareSampler(BaseSampler):
"""Rank-aware sampler for distributed training with TransferQueue.

This sampler is designed for distributed data parallel training scenarios
where each rank retrieves data independently.

This sampler guarantees that all ranks within the same DP group receive
the same sample indices.

The sampler maintains per-DP-group state to coordinate sampling across ranks:

- First rank in a DP group to call :meth:`sample` performs actual sampling from
``ready_indexes`` and caches the result
- Subsequent ranks in the same DP group retrieve the cached indices
- Once all ranks in the DP group have fetched their samples, the cached state is
cleaned up.


Please refer to our roadmap for more details:
[Roadmap] StreamingDataLoader for task-separated RL post-training
https://github.com/Ascend/TransferQueue/issues/1
"""

def __init__(self):
"""Initialize the RankAwareSampler.

The sampler maintains internal state to coordinate sampling across ranks
within the same DP group. This state tracks which samples have been sampled
and how many times they have been fetched.
"""
super().__init__()

def sample(
self,
ready_indexes: list[int],
batch_size: int,
dp_group: int,
dp_world_size: int,
world_size: int,
*args: Any,
**kwargs: Any,
) -> tuple[list[int], list[int]]:
"""Sample indices for the current rank, coordinating with other DP ranks.

This method implements coordinated sampling for distributed training.
The first rank in each DP group to call this method performs actual sampling
from ``ready_indexes`` and caches the result. Subsequent ranks in the same
DP group receive the cached indices directly.

Args:
ready_indexes: List of global indices for which all required fields of the
corresponding samples have been produced, and the samples are not labeled
as consumed in the corresponding task.
batch_size: Number of samples to select. If larger than the number of
available ready samples, no sampling occurs and empty lists are returned.
dp_group: The group id of the current data parallel group. Used to
identify which DP group this rank belongs to.
dp_world_size: The data parallel world size, i.e., the number of DP groups.
Together with ``world_size`` it determines how many ranks share each
DP group's batch and must fetch it.
world_size: Total number of ranks across all parallelism dimensions.
Used to determine when all ranks have fetched their samples.
*args: Additional positional arguments (ignored).
**kwargs: Additional keyword arguments (ignored).

Returns:
List of sampled global indices. Typically has length ``batch_size``;
an empty list is returned if ready samples are insufficient.

List of global indices that should be labeled as consumed
(will never be retrieved by other dp_groups in the future).

Raises:
RuntimeError: If ``dp_world_size`` is not positive or ``world_size`` is not
divisible by ``dp_world_size``.
"""

# Check if this DP group already has sampled data cached
data_for_dp_group = self._states.get(dp_group, None)

# Calculate how many times this batch should be fetched across all ranks
if dp_world_size <= 0 or world_size % dp_world_size != 0:
raise RuntimeError(f"world_size ({world_size}) is not divisible by dp_world_size ({dp_world_size})")

fetches_per_batch = world_size // dp_world_size

if data_for_dp_group is None:
# Select first batch_size indices from ready_indexes
sampled_indexes = ready_indexes[:batch_size]

if len(sampled_indexes) < batch_size:
return [], []

# Initialize state for this DP group
self._states[dp_group] = {}
consumed_indexes = sampled_indexes

# Cache the sampled indices for other ranks in this DP group
self._states[dp_group]["index"] = sampled_indexes
self._states[dp_group]["fetch_count"] = 1

else:
# Return the cached indices (identical to what first rank received)
sampled_indexes = self._states[dp_group]["index"]
consumed_indexes = self._states[dp_group]["index"]

# Increment fetch count to track progress
self._states[dp_group]["fetch_count"] += 1

# Check if this was the last rank in the DP group to fetch
if self._states[dp_group]["fetch_count"] >= fetches_per_batch:
del self._states[dp_group]

Copilot AI commented on lines +128 to +131 (Jan 20, 2026):

If some ranks fail to call sample() (e.g., due to crashes or network issues), the state for that dp_group will never be cleaned up and will remain in self._states indefinitely, causing a memory leak. Consider adding a timeout mechanism or state cleanup strategy for orphaned entries, or document this limitation in the class docstring.
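One possible mitigation, sketched below as a rough illustration rather than a change in this PR (the subclass name and the state_ttl_s parameter are hypothetical), is to timestamp each cached entry and evict entries that have waited longer than a timeout:

# A rough sketch, not part of this PR: evict cached per-dp_group state that has
# been waiting longer than a configurable timeout.
import time

from transfer_queue.sampler.rank_aware_sampler import RankAwareSampler


class RankAwareSamplerWithTTL(RankAwareSampler):
    def __init__(self, state_ttl_s: float = 600.0):
        super().__init__()
        self._state_ttl_s = state_ttl_s

    def sample(self, ready_indexes, batch_size, dp_group, dp_world_size, world_size, *args, **kwargs):
        now = time.monotonic()
        # Drop state for DP groups whose remaining ranks never called sample().
        stale = [g for g, s in self._states.items() if now - s.get("created_at", now) > self._state_ttl_s]
        for group in stale:
            del self._states[group]
        sampled, consumed = super().sample(ready_indexes, batch_size, dp_group, dp_world_size, world_size)
        # Timestamp freshly created state so it can be aged out later.
        if dp_group in self._states:
            self._states[dp_group].setdefault("created_at", now)
        return sampled, consumed
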
return sampled_indexes, consumed_indexes