Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 100 additions & 4 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,20 @@ def _handle_requests(self):
# Mock partition metadata response
response_body = {"metadata": self._mock_batch_meta(request_msg.body)}
response_type = ZMQRequestType.GET_PARTITION_META_RESPONSE
elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION:
elif request_msg.request_type == ZMQRequestType.GET_CONSUMPTION:
# Mock consumption status check - all consumed
response_body = {
"partition_id": request_msg.body.get("partition_id"),
"consumed": True,
"global_index": torch.tensor([0, 1, 2]),
"consumption_status": torch.tensor([1, 1, 1]),
}
response_type = ZMQRequestType.CONSUMPTION_RESPONSE
elif request_msg.request_type == ZMQRequestType.CHECK_PRODUCTION:
elif request_msg.request_type == ZMQRequestType.GET_PRODUCTION:
# Mock production status check - all produced
response_body = {
"partition_id": request_msg.body.get("partition_id"),
"produced": True,
"global_index": torch.tensor([0, 1, 2]),
"production_status": torch.tensor([[1, 1, 1], [1, 1, 1]]),
}
response_type = ZMQRequestType.PRODUCTION_RESPONSE
elif request_msg.request_type == ZMQRequestType.GET_LIST_PARTITIONS:
Expand Down Expand Up @@ -467,6 +469,52 @@ def test_check_production_status(client_setup):
assert is_produced is True


def test_get_consumption_status(client_setup):
    """Test get_consumption_status - returns global_index and consumption_status tensors"""
    client, _, _ = client_setup

    # Invoke the synchronous API under test.
    result = client.get_consumption_status(
        task_name="generate_sequences", partition_id="train_0"
    )
    index_tensor, status_tensor = result

    # Both halves of the returned pair must be populated.
    assert index_tensor is not None
    assert status_tensor is not None

    # The mock server reports global indexes [0, 1, 2] ...
    assert torch.equal(index_tensor, torch.tensor([0, 1, 2], dtype=torch.long))

    # ... with every sample marked as consumed.
    assert torch.equal(status_tensor, torch.tensor([1, 1, 1], dtype=torch.int8))

    print("✓ get_consumption_status returns correct global_index and consumption_status")


def test_get_production_status(client_setup):
    """Test get_production_status - returns global_index and production_status tensors"""
    client, _, _ = client_setup

    # Invoke the synchronous API under test for two data fields.
    result = client.get_production_status(
        data_fields=["prompt_ids", "attention_mask"], partition_id="train_0"
    )
    index_tensor, status_tensor = result

    # Both halves of the returned pair must be populated.
    assert index_tensor is not None
    assert status_tensor is not None

    # The mock server reports global indexes [0, 1, 2] ...
    assert torch.equal(index_tensor, torch.tensor([0, 1, 2], dtype=torch.long))

    # ... and a 2x3 all-produced status matrix.
    assert torch.equal(status_tensor, torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.int8))

    print("✓ get_production_status returns correct global_index and production_status")


def test_get_partition_list(client_setup):
"""Test partition list retrieval"""
client, _, _ = client_setup
Expand Down Expand Up @@ -502,6 +550,54 @@ async def test_async_check_production_status(client_setup):
assert is_produced is True


@pytest.mark.asyncio
async def test_async_get_consumption_status(client_setup):
    """Test async get_consumption_status - returns global_index and consumption_status tensors"""
    client, _, _ = client_setup

    # Await the async variant of the consumption-status API.
    index_tensor, status_tensor = await client.async_get_consumption_status(
        task_name="generate_sequences", partition_id="train_0"
    )

    # Both returned tensors must be populated.
    assert index_tensor is not None
    assert status_tensor is not None

    # The mock server reports global indexes [0, 1, 2] ...
    assert torch.equal(index_tensor, torch.tensor([0, 1, 2], dtype=torch.long))

    # ... with every sample marked as consumed.
    assert torch.equal(status_tensor, torch.tensor([1, 1, 1], dtype=torch.int8))

    print("✓ async_get_consumption_status returns correct global_index and consumption_status")


@pytest.mark.asyncio
async def test_async_get_production_status(client_setup):
    """Test async get_production_status - returns global_index and production_status tensors"""
    client, _, _ = client_setup

    # Await the async variant of the production-status API for two data fields.
    index_tensor, status_tensor = await client.async_get_production_status(
        data_fields=["prompt_ids", "attention_mask"], partition_id="train_0"
    )

    # Both returned tensors must be populated.
    assert index_tensor is not None
    assert status_tensor is not None

    # The mock server reports global indexes [0, 1, 2] ...
    assert torch.equal(index_tensor, torch.tensor([0, 1, 2], dtype=torch.long))

    # ... and a 2x3 all-produced status matrix.
    assert torch.equal(status_tensor, torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.int8))

    print("✓ async_get_production_status returns correct global_index and production_status")


@pytest.mark.asyncio
async def test_async_get_partition_list(client_setup):
"""Test async partition list retrieval"""
Expand Down
84 changes: 76 additions & 8 deletions tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def test_controller_with_single_partition(self, ray_setup):
field_names=metadata.field_names,
dtypes=dtypes,
shapes=shapes,
custom_meta=None,
)
)
assert success
Expand All @@ -97,13 +98,18 @@ def test_controller_with_single_partition(self, ray_setup):
assert partition.production_status.size(0) == gbs * num_n_samples

# Test for get production status
production_status = ray.get(
global_index, production_status = ray.get(
tq_controller.get_production_status.remote(
partition_id=partition_id,
data_fields=data_fields,
)
)
assert production_status
# Verify global_index contains all expected indexes
assert torch.equal(global_index, torch.tensor(range(gbs * num_n_samples), dtype=torch.long))
# Verify all samples are produced for all fields (status should be 1)
expected_production_status = torch.ones(gbs * num_n_samples, len(metadata.field_names), dtype=torch.int8)
assert torch.equal(production_status, expected_production_status)
print("✓ Get production status returns correct global_index and production_status")

# Total fields should match the number of fields we added
assert partition.total_fields_num == len(data_fields)
Expand All @@ -126,14 +132,19 @@ def test_controller_with_single_partition(self, ray_setup):

print(f"✓ Updated production status for partition {partition_id}")

# Test for get consumption status
consumption_status = ray.get(
# Test for get consumption status BEFORE consumption
global_index, consumption_status = ray.get(
tq_controller.get_consumption_status.remote(
partition_id=partition_id,
task_name="generate_sequences",
)
)
assert torch.equal(consumption_status, torch.zeros(gbs * num_n_samples))
# Verify global_index
assert torch.equal(global_index, torch.tensor(range(gbs * num_n_samples), dtype=torch.long))
# Verify all samples are NOT consumed yet (status should be 0)
expected_consumption_status_before = torch.zeros(gbs * num_n_samples, dtype=torch.int8)
assert torch.equal(consumption_status, expected_consumption_status_before)
print("✓ Get consumption status returns correct global_index and status (before consumption)")

# Test get metadate in fetch mode
gen_meta = ray.get(
Expand All @@ -153,14 +164,19 @@ def test_controller_with_single_partition(self, ray_setup):
assert torch.equal(partition.consumption_status["generate_sequences"], torch.ones(gbs * num_n_samples))
print("✓ Get metadata in fetch mode correct")

# Test for get consumption status
consumption_status = ray.get(
# Test for get consumption status AFTER consumption
global_index, consumption_status = ray.get(
tq_controller.get_consumption_status.remote(
partition_id=partition_id,
task_name="generate_sequences",
)
)
assert torch.equal(consumption_status, torch.ones(gbs * num_n_samples))
# Verify global_index
assert torch.equal(global_index, torch.tensor(range(gbs * num_n_samples), dtype=torch.long))
# Verify all samples are consumed (status should be 1)
expected_consumption_status_after = torch.ones(gbs * num_n_samples, dtype=torch.int8)
assert torch.equal(consumption_status, expected_consumption_status_after)
print("✓ Get consumption status returns correct global_index and status (after consumption)")

# Test get clear meta
clear_meta = ray.get(
Expand Down Expand Up @@ -222,6 +238,19 @@ def test_controller_with_multi_partitions(self, ray_setup):
)
assert success

# Verify get production status returns correct data
global_index_1, production_status_1 = ray.get(
tq_controller.get_production_status.remote(
partition_id=partition_id_1,
data_fields=data_fields,
)
)
expected_global_index_1 = torch.tensor(range(gbs_1 * num_n_samples_1), dtype=torch.long)
assert torch.equal(global_index_1, expected_global_index_1)
expected_production_status_1 = torch.ones(gbs_1 * num_n_samples_1, len(data_fields), dtype=torch.int8)
assert torch.equal(production_status_1, expected_production_status_1)
print("✓ Get production status for partition_1 returns correct global_index and status")

# Test get metadate in fetch mode
gen_meta = ray.get(
tq_controller.get_metadata.remote(
Expand All @@ -234,6 +263,18 @@ def test_controller_with_multi_partitions(self, ray_setup):
)
assert gen_meta

# Verify get consumption status after fetch (samples should be consumed)
global_index_1_consumed, consumption_status_1 = ray.get(
tq_controller.get_consumption_status.remote(
partition_id=partition_id_1,
task_name="generate_sequences",
)
)
assert torch.equal(global_index_1_consumed, expected_global_index_1)
expected_consumption_status_1 = torch.ones(gbs_1 * num_n_samples_1, dtype=torch.int8)
assert torch.equal(consumption_status_1, expected_consumption_status_1)
print("✓ Get consumption status for partition_1 returns correct global_index and status (after fetch)")

# Test get clear meta
clear_meta = ray.get(
tq_controller.get_metadata.remote(
Expand Down Expand Up @@ -282,6 +323,33 @@ def test_controller_with_multi_partitions(self, ray_setup):
)
assert success

# Verify get production status for partition_2
global_index_2, production_status_2 = ray.get(
tq_controller.get_production_status.remote(
partition_id=partition_id_2,
data_fields=data_fields,
)
)
expected_global_index_2 = torch.tensor(
range(part1_index_range, part2_index_range + part1_index_range), dtype=torch.long
)
assert torch.equal(global_index_2, expected_global_index_2)
expected_production_status_2 = torch.ones(part2_index_range, len(data_fields), dtype=torch.int8)
assert torch.equal(production_status_2, expected_production_status_2)
print("✓ Get production status for partition_2 returns correct global_index and status")

# Verify get consumption status for partition_2 (before consumption - should be all zeros)
global_index_2_consumed, consumption_status_2 = ray.get(
tq_controller.get_consumption_status.remote(
partition_id=partition_id_2,
task_name="generate_sequences",
)
)
assert torch.equal(global_index_2_consumed, expected_global_index_2)
expected_consumption_status_2 = torch.zeros(part2_index_range, dtype=torch.int8)
assert torch.equal(consumption_status_2, expected_consumption_status_2)
print("✓ Get consumption status for partition_2 returns correct global_index and status (before consumption)")

# Clear partition 1
partition_index_range_1 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_1))
assert partition_index_range_1
Expand Down
Loading