diff --git a/tests/test_samplers.py b/tests/test_samplers.py
index 0d665b6..d184bba 100644
--- a/tests/test_samplers.py
+++ b/tests/test_samplers.py
@@ -445,66 +445,93 @@ def test_rank_aware_sampler_initialization(self):
         assert sampler._states == {}

     def test_rank_aware_sampler_first_rank_sampling(self):
-        """Test that first rank in DP group performs actual sampling."""
+        """Test that first rank in data replica group performs actual sampling."""
         sampler = RankAwareSampler()
         ready_indexes = [0, 1, 2, 3, 4, 5]
         batch_size = 3

-        # When world_size == dp_world_size, fetches_per_batch = 1
-        # First rank samples and immediately marks consumed (no other ranks to wait for)
-        sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)
+        # Rank 0 (first in group) samples and caches the result for the other ranks
+        # Since rank 1 will fetch next, the state is kept until rank 1 retrieves it
+        sampled, consumed = sampler.sample(
+            ready_indexes,
+            batch_size,
+            data_replica_group=0,
+            data_replica_rank=0,
+            data_replica_world_size=2,
+            task_name="task",
+            partition_id="test",
+        )

         assert sampled == [0, 1, 2]
-        # consumed is returned
         assert consumed == [0, 1, 2]
         assert len(sampled) == batch_size
-        # State should be cleaned up
-        assert sampler._states == {}
+        # State is kept for the other ranks to fetch

     def test_rank_aware_sampler_second_rank_gets_cached(self):
-        """Test that second rank in DP group gets cached indices."""
+        """Test that second rank in data replica group gets cached indices."""
         sampler = RankAwareSampler()
         ready_indexes = [0, 1, 2, 3, 4, 5]
         batch_size = 3
-        dp_world_size = 2
-        world_size = 4  # Use world_size=4 so fetches_per_batch=2

-        # Rank 0 (dp_group=0) samples first
+        # Rank 0 (first in group) samples first
         sampled1, consumed1 = sampler.sample(
-            ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
+            ready_indexes,
+            batch_size,
+            data_replica_group=0,
+            data_replica_rank=0,
+            data_replica_world_size=2,
+            task_name="task",
+            partition_id="test",
        )

-        # Rank 1 (dp_group=0) should get same cached indices
+        # Rank 1 (second in group) should get the same cached indices
         sampled2, consumed2 = sampler.sample(
-            ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
+            ready_indexes,
+            batch_size,
+            data_replica_group=0,
+            data_replica_rank=1,
+            data_replica_world_size=2,
+            task_name="task",
+            partition_id="test",
        )

         assert sampled1 == sampled2 == [0, 1, 2]
-        # First rank already returns consumed indexes
         assert consumed1 == [0, 1, 2]
-        # Second rank also sees the same consumed indexes; state is then cleaned up
         assert consumed2 == [0, 1, 2]
-        # State should be cleaned up
-        assert sampler._states == {}
+
+        # Cache should be empty after all ranks have fetched
+        assert len(sampler._states["test"]["task"][0][0]) == 0
+        assert len(sampler._states["test"]["task"][0][1]) == 0

     def test_rank_aware_sampler_multiple_dp_groups(self):
-        """Test that multiple DP groups work independently."""
+        """Test that multiple data replica groups work independently."""
         sampler = RankAwareSampler()
         ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7]
         batch_size = 2
-        dp_world_size = 4
-        world_size = 8
+        data_replica_world_size = 2  # Each group has 2 ranks

-        # DP group 0: rank 0 samples first
+        # Data replica group 0: rank 0 samples first
         sampled0_g0, consumed0_g0 = sampler.sample(
-            ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
+            ready_indexes,
+            batch_size,
+            data_replica_group=0,
+            data_replica_rank=0,
+            data_replica_world_size=data_replica_world_size,
+            task_name="task",
+            partition_id="test",
        )

         # mimic the consumption status update managed in TransferQueueController
         ready_indexes = [i for i in ready_indexes if i not in consumed0_g0]

-        # DP group 1: rank 0 samples first
+        # Data replica group 1: rank 0 samples first
         sampled0_g1, consumed0_g1 = sampler.sample(
-            ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size
+            ready_indexes,
+            batch_size,
+            data_replica_group=1,
+            data_replica_rank=0,
+            data_replica_world_size=data_replica_world_size,
+            task_name="task",
+            partition_id="test",
        )

         ready_indexes = [i for i in ready_indexes if i not in consumed0_g1]
@@ -514,39 +541,66 @@ def test_rank_aware_sampler_multiple_dp_groups(self):
         assert consumed0_g0 == [0, 1]
         assert consumed0_g1 == [2, 3]

-        # DP group 0: rank 1 fetches cached, and all the data should be labeled as consumed
+        # Data replica group 0: rank 1 fetches cached
         sampled1_g0, consumed1_g0 = sampler.sample(
-            ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
+            ready_indexes,
+            batch_size,
+            data_replica_group=0,
+            data_replica_rank=1,
+            data_replica_world_size=data_replica_world_size,
+            task_name="task",
+            partition_id="test",
        )
         ready_indexes = [i for i in ready_indexes if i not in consumed1_g0]

         assert sampled1_g0 == [0, 1]
         assert consumed1_g0 == [0, 1]

-        # DP group 1: rank 1 fetches cached, and all the data should be labeled as consumed
+        # Data replica group 1: rank 1 fetches cached
         sampled1_g1, consumed1_g1 = sampler.sample(
-            ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size
+            ready_indexes,
+            batch_size,
+            data_replica_group=1,
+            data_replica_rank=1,
+            data_replica_world_size=data_replica_world_size,
+            task_name="task",
+            partition_id="test",
        )
         ready_indexes = [i for i in ready_indexes if i not in consumed1_g1]

         assert sampled1_g1 == [2, 3]
         assert consumed1_g1 == [2, 3]

-        # DP group 0: rank 0 fetches again, this should return new data
+        # Data replica group 0: rank 0 fetches again, which should return new data
         sampled2_g0, consumed2_g0 = sampler.sample(
-            ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
+            ready_indexes,
+            batch_size,
+            data_replica_group=0,
+            data_replica_rank=0,
+            data_replica_world_size=data_replica_world_size,
+            task_name="task",
+            partition_id="test",
        )
         ready_indexes = [i for i in ready_indexes if i not in consumed2_g0]

         assert sampled2_g0 == [4, 5]
         assert consumed2_g0 == [4, 5]

-        # DP group 0: rank 1 fetches cached
+        # Data replica group 0: rank 1 fetches cached
         sampled3_g0, consumed3_g0 = sampler.sample(
-            ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
+            ready_indexes,
+            batch_size,
+            data_replica_group=0,
+            data_replica_rank=1,
+            data_replica_world_size=data_replica_world_size,
+            task_name="task",
+            partition_id="test",
        )

         assert sampled3_g0 == [4, 5]
         assert consumed3_g0 == [4, 5]

-        # Both groups should be cleaned up
-        assert sampler._states == {}
+        # Examine the internal state to ensure proper caching and clearing
+        assert len(sampler._states["test"]["task"][0][0]) == 0
+        assert len(sampler._states["test"]["task"][0][1]) == 0
+        assert len(sampler._states["test"]["task"][1][0]) == 0
+        assert len(sampler._states["test"]["task"][1][1]) == 0

     def test_rank_aware_sampler_empty_ready_indexes(self):
         """Test behavior with empty ready indexes."""
@@ -554,7 +608,15 @@ def test_rank_aware_sampler_empty_ready_indexes(self):
         ready_indexes = []
         batch_size = 3

-        sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)
+        sampled, consumed = sampler.sample(
+            ready_indexes,
+            batch_size,
+            data_replica_group=0,
+            data_replica_rank=0,
+            data_replica_world_size=2,
+            task_name="task",
+            partition_id="test",
+        )

         assert sampled == []
         assert consumed == []
@@ -565,8 +627,15 @@ def test_rank_aware_sampler_batch_size_larger_than_ready(self):
         ready_indexes = [0, 1]
         batch_size = 5

-        # When world_size == dp_world_size, fetches_per_batch=1, consumed returned immediately
-        sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)
+        sampled, consumed = sampler.sample(
+            ready_indexes,
+            batch_size,
+            data_replica_group=0,
+            data_replica_rank=0,
+            data_replica_world_size=2,
+            task_name="task",
+            partition_id="test",
+        )

         assert sampled == []
         assert consumed == []
@@ -577,11 +646,112 @@ def test_rank_aware_sampler_zero_batch_size(self):
         ready_indexes = [0, 1, 2, 3]
         batch_size = 0

-        sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)
+        sampled, consumed = sampler.sample(
+            ready_indexes,
+            batch_size,
+            data_replica_group=0,
+            data_replica_rank=0,
+            data_replica_world_size=2,
+            task_name="task",
+            partition_id="test",
+        )

         assert sampled == []
         assert consumed == []

+    def test_rank_aware_sampler_data_prefetch(self):
+        """Test behavior with data prefetch."""
+        sampler = RankAwareSampler()
+        ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7]
+        batch_size = 2
+
+        sampled_rank0_time0, consumed_rank0_time0 = sampler.sample(
+            ready_indexes,
+            batch_size,
+            data_replica_group=0,
+            data_replica_rank=0,
+            data_replica_world_size=2,
+            task_name="task",
+            partition_id="test",
+        )
+
+        assert sampled_rank0_time0 == [0, 1]
+        assert consumed_rank0_time0 == [0, 1]
+        assert sampler._states["test"]["task"][0][0] == []
+        assert sampler._states["test"]["task"][0][1] == [[0, 1]]
+
+        ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time0]
+
+        sampled_rank0_time1, consumed_rank0_time1 = sampler.sample(
+            ready_indexes,
+            batch_size,
+            data_replica_group=0,
+            data_replica_rank=0,
+            data_replica_world_size=2,
+            task_name="task",
+            partition_id="test",
+        )
+
+        assert sampled_rank0_time1 == [2, 3]
+        assert consumed_rank0_time1 == [2, 3]
+        assert sampler._states["test"]["task"][0][0] == []
+        assert sampler._states["test"]["task"][0][1] == [[0, 1], [2, 3]]
+
+        ready_indexes = [i for i in ready_indexes if i not in consumed_rank0_time1]
+
+        sampled_rank1_time0, consumed_rank1_time0 = sampler.sample(
+            ready_indexes,
+            batch_size,
+            data_replica_group=0,
+            data_replica_rank=1,
+            data_replica_world_size=2,
+            task_name="task",
+            partition_id="test",
+        )
+        assert sampled_rank1_time0 == [0, 1]
+        assert consumed_rank1_time0 == [0, 1]
+
+        assert sampler._states["test"]["task"][0][0] == []
+        assert sampler._states["test"]["task"][0][1] == [[2, 3]]
+
+    def test_rank_aware_sampler_multiple_tasks(self):
+        """Test behavior with multiple tasks."""
+        sampler = RankAwareSampler()
+        ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7]
+        batch_size = 2
+
+        sampled_rank0_task0, consumed_rank0_task0 = sampler.sample(
+            ready_indexes,
+            batch_size,
+            data_replica_group=0,
+            data_replica_rank=0,
+            data_replica_world_size=2,
+            task_name="task0",
+            partition_id="test",
+        )
+
+        assert sampled_rank0_task0 == [0, 1]
+        assert consumed_rank0_task0 == [0, 1]
+        assert sampler._states["test"]["task0"][0][0] == []
+        assert sampler._states["test"]["task0"][0][1] == [[0, 1]]
sampler._states["test"]["task0"][0][1] == [[0, 1]] + + sampled_rank0_task1, consumed_rank0_task1 = sampler.sample( + ready_indexes, + batch_size, + data_replica_group=0, + data_replica_rank=0, + data_replica_world_size=2, + task_name="task1", + partition_id="test", + ) + + assert sampled_rank0_task1 == [0, 1] + assert consumed_rank0_task1 == [0, 1] + assert sampler._states["test"]["task0"][0][0] == [] + assert sampler._states["test"]["task0"][0][1] == [[0, 1]] + assert sampler._states["test"]["task1"][0][0] == [] + assert sampler._states["test"]["task1"][0][1] == [[0, 1]] + class TestSamplerIntegration: """Integration tests for samplers.""" diff --git a/transfer_queue/sampler/rank_aware_sampler.py b/transfer_queue/sampler/rank_aware_sampler.py index b369ca5..64531fd 100644 --- a/transfer_queue/sampler/rank_aware_sampler.py +++ b/transfer_queue/sampler/rank_aware_sampler.py @@ -24,16 +24,15 @@ class RankAwareSampler(BaseSampler): This sampler is designed for distributed data parallel training scenarios where each rank retrieves data independently. - This sampler guarantees that all ranks within the same DP group receive + This sampler guarantees that all ranks within the same data replica group receive the same sample indices. - The sampler maintains per-DP-group state to coordinate sampling across ranks: + The sampler maintains inner state to coordinate sampling across ranks: - - First rank in a DP group to call :meth:`sample` performs actual sampling from - ``ready_indexes`` and caches the result - - Subsequent ranks in the same DP group retrieve the cached indices - - Once all ranks in the DP group have fetched their samples, the cached state is - cleaned up. + - First rank in a data replica group to call :meth:`sample` performs actual sampling from + ``ready_indexes`` and caches the result for other ranks in the same group + - Subsequent ranks in the same group retrieve the cached indices. + - If no cached indices are available, sampling is performed again and cached for others. Please refer to our roadmap for more details: @@ -45,88 +44,111 @@ def __init__(self): """Initialize the RankAwareSampler. The sampler maintains internal state to coordinate sampling across ranks - within the same DP group. This state tracks which samples have been sampled + within the same data replica group. This state tracks which samples have been sampled and how many times they have been fetched. """ + super().__init__() def sample( self, ready_indexes: list[int], batch_size: int, - dp_group: int, - dp_world_size: int, - world_size: int, + data_replica_group: int, + data_replica_rank: int, + data_replica_world_size: int, + task_name: str, + partition_id: str, *args: Any, **kwargs: Any, ) -> tuple[list[int], list[int]]: - """Sample indices for the current rank, coordinating with other DP ranks. + """Sample indices for the current rank, coordinating with other data replica ranks. This method implements coordinated sampling for distributed training. - The first rank in each DP group to call this method performs actual sampling + The first rank in each data replica group to call this method performs actual sampling from ``ready_indexes`` and caches the result. Subsequent ranks in the same - DP group receive the cached indices directly. + data replica group receive the cached indices directly. + + Internal state structure (self._states): + + .. code-block:: python + + self._states = { + "partition_id": { + "task_name": { + data_replica_group: { + data_replica_rank: [[sampled_indexes], ...] 
# Buffer of cached sampled indices + } + } + } + } + + State lifecycle: + 1. First rank samples from ``ready_indexes``, caches results for other ranks + 2. Other ranks pop and retrieve the cached indices Args: ready_indexes: List of global indices for which all required fields of the corresponding samples have been produced, and the samples are not labeled as consumed in the corresponding task. batch_size: Number of samples to select. If larger than available - ready samples, all available samples will be returned. - dp_group: The group id of current data parallel group. Used to - identify which DP group this rank belongs to. - dp_world_size: Number of ranks in the data parallelism group. Used to - determine when all ranks have fetched their samples. - world_size: Total number of ranks across all parallelism dimensions. - Used to determine when all ranks have fetched their samples. + ready samples, no samples are returned and both lists are empty. + data_replica_group: The group id of current data replica group. Used to + identify which data replica group this rank belongs to. + data_replica_rank: Local rank inside this data_replica_group. + data_replica_world_size: Total number of ranks in this data_replica_group. + task_name: Identifier for the task. + partition_id: Partition ID for data management. *args: Additional positional arguments (ignored). **kwargs: Additional keyword arguments (ignored). Returns: - List of sampled global indices. Typically, has length `batch_size`, - or returns an empty list if samples are insufficient. - - List of global indices that should be labeled as consumed - (will never be retrieved by other dp_groups in the future). + Tuple of two lists: + - List of sampled global indices. Typically, has length ``batch_size``, + or empty if samples are insufficient. + - List of global indices to mark as consumed (excluded from future + retrieval by other data_replica_groups). Raises: - RuntimeError: If ``world_size`` is not divisible by ``dp_world_size``. + ValueError: If ``data_replica_rank`` or ``data_replica_world_size`` is invalid. 
+ """ - # Check if this DP group already has sampled data cached - data_for_dp_group = self._states.get(dp_group, None) + if data_replica_world_size < 1: + raise ValueError(f"data_replica_world_size {data_replica_world_size} must >= 1") + + if data_replica_rank >= data_replica_world_size or data_replica_rank < 0: + raise ValueError( + f"data_replica_rank {data_replica_rank} must be greater than or equal to 0 and less than " + f"data_replica_world_size {data_replica_world_size}" + ) + + if partition_id not in self._states: + self._states[partition_id] = {} - # Calculate how many times this batch should be fetched across all ranks - if dp_world_size <= 0 or world_size % dp_world_size != 0: - raise RuntimeError(f"world_size ({world_size}) is not divisible by dp_world_size ({dp_world_size})") + if task_name not in self._states[partition_id]: + self._states[partition_id][task_name] = {} - fetches_per_batch = world_size // dp_world_size + if data_replica_group not in self._states[partition_id][task_name]: + self._states[partition_id][task_name][data_replica_group] = {i: [] for i in range(data_replica_world_size)} - if data_for_dp_group is None: + if len(self._states[partition_id][task_name][data_replica_group][data_replica_rank]) == 0: # Select first batch_size indices from ready_indexes sampled_indexes = ready_indexes[:batch_size] if len(sampled_indexes) < batch_size: return [], [] - # Initialize state for this DP group - self._states[dp_group] = {} consumed_indexes = sampled_indexes - # Cache the sampled indices for other ranks in this DP group - self._states[dp_group]["index"] = sampled_indexes - self._states[dp_group]["fetch_count"] = 1 + # Cache the sampled indices for other ranks in this data replica group + for i in range(data_replica_world_size): + if i != data_replica_rank: + self._states[partition_id][task_name][data_replica_group][i].append(sampled_indexes) else: # Return the cached indices (identical to what first rank received) - sampled_indexes = self._states[dp_group]["index"] - consumed_indexes = self._states[dp_group]["index"] - - # Increment fetch count to track progress - self._states[dp_group]["fetch_count"] += 1 - - # Check if this was the last rank in the DP group to fetch - if self._states[dp_group]["fetch_count"] >= fetches_per_batch: - del self._states[dp_group] + sampled_indexes = self._states[partition_id][task_name][data_replica_group][data_replica_rank].pop(0) + consumed_indexes = sampled_indexes return sampled_indexes, consumed_indexes