Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 209 additions & 39 deletions tests/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,66 +445,93 @@ def test_rank_aware_sampler_initialization(self):
assert sampler._states == {}

def test_rank_aware_sampler_first_rank_sampling(self):
    """Test that first rank in data replica group performs actual sampling.

    Rank 0 of a two-rank data replica group is the first caller, so it draws
    the batch from ``ready_indexes`` itself. Rank 1 has not fetched yet, so
    the batch must stay buffered for it afterwards.
    """
    sampler = RankAwareSampler()
    ready_indexes = [0, 1, 2, 3, 4, 5]
    batch_size = 3

    # Rank 0 (first in group) samples and caches the batch for the other rank.
    sampled, consumed = sampler.sample(
        ready_indexes,
        batch_size,
        data_replica_group=0,
        data_replica_rank=0,
        data_replica_world_size=2,
        task_name="task",
        partition_id="test",
    )

    assert sampled == [0, 1, 2]
    # The sampling rank reports the batch as consumed right away.
    assert consumed == [0, 1, 2]
    assert len(sampled) == batch_size

    # State is kept for the other rank to fetch. The state is indexed as
    # [partition_id][task_name][data_replica_group][data_replica_rank]:
    # rank 0 has nothing pending, rank 1 still has the batch queued.
    group_state = sampler._states["test"]["task"][0]
    assert group_state[0] == []
    assert group_state[1] == [[0, 1, 2]]

def test_rank_aware_sampler_second_rank_gets_cached(self):
    """Test that second rank in data replica group gets cached indices.

    Rank 0 samples a batch; rank 1 of the same data replica group must then
    receive exactly that batch. Once both ranks have fetched, no batch may
    remain buffered for either of them.
    """
    sampler = RankAwareSampler()
    ready_indexes = [0, 1, 2, 3, 4, 5]
    batch_size = 3

    # Rank 0 (first in group) samples first
    sampled1, consumed1 = sampler.sample(
        ready_indexes,
        batch_size,
        data_replica_group=0,
        data_replica_rank=0,
        data_replica_world_size=2,
        task_name="task",
        partition_id="test",
    )

    # Rank 1 (second in group) should get same cached indices
    sampled2, consumed2 = sampler.sample(
        ready_indexes,
        batch_size,
        data_replica_group=0,
        data_replica_rank=1,
        data_replica_world_size=2,
        task_name="task",
        partition_id="test",
    )

    assert sampled1 == sampled2 == [0, 1, 2]
    # First rank already returns consumed indexes
    assert consumed1 == [0, 1, 2]
    # Second rank also sees the same consumed indexes
    assert consumed2 == [0, 1, 2]

    # The per-rank buffers still exist but must be empty after all ranks
    # fetch, so the state is inspected by key rather than compared to {}.
    assert len(sampler._states["test"]["task"][0][0]) == 0
    assert len(sampler._states["test"]["task"][0][1]) == 0

def test_rank_aware_sampler_multiple_dp_groups(self):
"""Test that multiple DP groups work independently."""
"""Test that multiple data replica groups work independently."""
sampler = RankAwareSampler()
ready_indexes = [0, 1, 2, 3, 4, 5, 6, 7]
batch_size = 2
dp_world_size = 4
world_size = 8
data_replica_world_size = 2 # Each group has 2 ranks

# DP group 0: rank 0 samples first
# data replica group 0: rank 0 samples first
sampled0_g0, consumed0_g0 = sampler.sample(
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
ready_indexes,
batch_size,
data_replica_group=0,
data_replica_rank=0,
data_replica_world_size=data_replica_world_size,
task_name="task",
partition_id="test",
)
# mimic the consumption status update managed in TransferQueueController
ready_indexes = [i for i in ready_indexes if i not in consumed0_g0]

# DP group 1: rank 0 samples first
# data replica group 1: rank 0 samples first
sampled0_g1, consumed0_g1 = sampler.sample(
ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size
ready_indexes,
batch_size,
data_replica_group=1,
data_replica_rank=0,
data_replica_world_size=data_replica_world_size,
task_name="task",
partition_id="test",
)
ready_indexes = [i for i in ready_indexes if i not in consumed0_g1]

Expand All @@ -514,47 +541,82 @@ def test_rank_aware_sampler_multiple_dp_groups(self):
assert consumed0_g0 == [0, 1]
assert consumed0_g1 == [2, 3]

# DP group 0: rank 1 fetches cached, and all the data should be labeled as consumed
# data replica group 0: rank 1 fetches cached
sampled1_g0, consumed1_g0 = sampler.sample(
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
ready_indexes,
batch_size,
data_replica_group=0,
data_replica_rank=1,
data_replica_world_size=data_replica_world_size,
task_name="task",
partition_id="test",
)
ready_indexes = [i for i in ready_indexes if i not in consumed1_g0]
assert sampled1_g0 == [0, 1]
assert consumed1_g0 == [0, 1]

# DP group 1: rank 1 fetches cached, and all the data should be labeled as consumed
# data replica group 1: rank 1 fetches cached
sampled1_g1, consumed1_g1 = sampler.sample(
ready_indexes, batch_size, dp_group=1, dp_world_size=dp_world_size, world_size=world_size
ready_indexes,
batch_size,
data_replica_group=1,
data_replica_rank=1,
data_replica_world_size=data_replica_world_size,
task_name="task",
partition_id="test",
)
ready_indexes = [i for i in ready_indexes if i not in consumed1_g1]
assert sampled1_g1 == [2, 3]
assert consumed1_g1 == [2, 3]

# DP group 0: rank 0 fetches again, this should return new data
# data replica group 0: rank 0 fetches again, this should return new data
sampled2_g0, consumed2_g0 = sampler.sample(
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
ready_indexes,
batch_size,
data_replica_group=0,
data_replica_rank=0,
data_replica_world_size=data_replica_world_size,
task_name="task",
partition_id="test",
)
ready_indexes = [i for i in ready_indexes if i not in consumed2_g0]
assert sampled2_g0 == [4, 5]
assert consumed2_g0 == [4, 5]

# DP group 0: rank 1 fetches cached
# data replica group 0: rank 1 fetches cached
sampled3_g0, consumed3_g0 = sampler.sample(
ready_indexes, batch_size, dp_group=0, dp_world_size=dp_world_size, world_size=world_size
ready_indexes,
batch_size,
data_replica_group=0,
data_replica_rank=1,
data_replica_world_size=data_replica_world_size,
task_name="task",
partition_id="test",
)
assert sampled3_g0 == [4, 5]
assert consumed3_g0 == [4, 5]

# Both groups should be cleaned up
assert sampler._states == {}
# examine the internal state to ensure proper caching and clearing
assert len(sampler._states["test"]["task"][0][0]) == 0
assert len(sampler._states["test"]["task"][0][1]) == 0
assert len(sampler._states["test"]["task"][1][0]) == 0
assert len(sampler._states["test"]["task"][1][1]) == 0

def test_rank_aware_sampler_empty_ready_indexes(self):
    """Test behavior with empty ready indexes.

    With nothing ready, the sampler must return an empty sample and report
    nothing as consumed.
    """
    sampler = RankAwareSampler()
    ready_indexes = []
    batch_size = 3

    sampled, consumed = sampler.sample(
        ready_indexes,
        batch_size,
        data_replica_group=0,
        data_replica_rank=0,
        data_replica_world_size=2,
        task_name="task",
        partition_id="test",
    )

    assert sampled == []
    assert consumed == []
Expand All @@ -565,8 +627,15 @@ def test_rank_aware_sampler_batch_size_larger_than_ready(self):
ready_indexes = [0, 1]
batch_size = 5

# When world_size == dp_world_size, fetches_per_batch=1, consumed returned immediately
sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)
sampled, consumed = sampler.sample(
ready_indexes,
batch_size,
data_replica_group=0,
data_replica_rank=0,
data_replica_world_size=2,
task_name="task",
partition_id="test",
)

assert sampled == []
assert consumed == []
Expand All @@ -577,11 +646,112 @@ def test_rank_aware_sampler_zero_batch_size(self):
ready_indexes = [0, 1, 2, 3]
batch_size = 0

sampled, consumed = sampler.sample(ready_indexes, batch_size, dp_group=0, dp_world_size=2, world_size=2)
sampled, consumed = sampler.sample(
ready_indexes,
batch_size,
data_replica_group=0,
data_replica_rank=0,
data_replica_world_size=2,
task_name="task",
partition_id="test",
)

assert sampled == []
assert consumed == []

def test_rank_aware_sampler_data_prefetch(self):
    """Test behavior with data prefetch.

    Rank 0 fetches twice before rank 1 fetches at all. Every batch drawn by
    rank 0 must queue up for rank 1, and rank 1's first fetch must take the
    oldest queued batch while leaving the newer one in place.
    """
    sampler = RankAwareSampler()
    batch_size = 2
    remaining = list(range(8))

    def fetch(rank, candidates):
        # All calls target the same partition, task and data replica group.
        return sampler.sample(
            candidates,
            batch_size,
            data_replica_group=0,
            data_replica_rank=rank,
            data_replica_world_size=2,
            task_name="task",
            partition_id="test",
        )

    def rank_buffers():
        # Looked up afresh on every call so each check sees the current state.
        return sampler._states["test"]["task"][0]

    # First fetch by rank 0: a new batch is drawn and queued for rank 1.
    first_sampled, first_consumed = fetch(0, remaining)
    assert first_sampled == [0, 1]
    assert first_consumed == [0, 1]
    assert rank_buffers()[0] == []
    assert rank_buffers()[1] == [[0, 1]]

    remaining = [idx for idx in remaining if idx not in first_consumed]

    # Second fetch by rank 0 (prefetch): another batch is queued behind the first.
    second_sampled, second_consumed = fetch(0, remaining)
    assert second_sampled == [2, 3]
    assert second_consumed == [2, 3]
    assert rank_buffers()[0] == []
    assert rank_buffers()[1] == [[0, 1], [2, 3]]

    remaining = [idx for idx in remaining if idx not in second_consumed]

    # First fetch by rank 1: it receives the oldest queued batch.
    peer_sampled, peer_consumed = fetch(1, remaining)
    assert peer_sampled == [0, 1]
    assert peer_consumed == [0, 1]

    assert rank_buffers()[0] == []
    assert rank_buffers()[1] == [[2, 3]]

def test_rank_aware_sampler_multiple_tasks(self):
    """Test behavior with multiple tasks.

    Two tasks sample from the same ready indexes within one partition. Each
    task must get the same first batch and keep its own buffered state,
    without disturbing the state of the task that sampled before it.
    """
    sampler = RankAwareSampler()
    candidates = list(range(8))
    batch_size = 2
    task_names = ("task0", "task1")

    for position, task_name in enumerate(task_names):
        # The same untouched candidate list is offered to every task.
        sampled, consumed = sampler.sample(
            candidates,
            batch_size,
            data_replica_group=0,
            data_replica_rank=0,
            data_replica_world_size=2,
            task_name=task_name,
            partition_id="test",
        )

        assert sampled == [0, 1]
        assert consumed == [0, 1]

        # Every task that has sampled so far keeps its own queued batch.
        for sampled_task in task_names[: position + 1]:
            group_state = sampler._states["test"][sampled_task][0]
            assert group_state[0] == []
            assert group_state[1] == [[0, 1]]


class TestSamplerIntegration:
"""Integration tests for samplers."""
Expand Down
Loading