diff --git a/tests/test_client.py b/tests/test_client.py index 4b8394c..e8027b4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -112,18 +112,20 @@ def _handle_requests(self): # Mock partition metadata response response_body = {"metadata": self._mock_batch_meta(request_msg.body)} response_type = ZMQRequestType.GET_PARTITION_META_RESPONSE - elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION: + elif request_msg.request_type == ZMQRequestType.GET_CONSUMPTION: # Mock consumption status check - all consumed response_body = { "partition_id": request_msg.body.get("partition_id"), - "consumed": True, + "global_index": torch.tensor([0, 1, 2]), + "consumption_status": torch.tensor([1, 1, 1]), } response_type = ZMQRequestType.CONSUMPTION_RESPONSE - elif request_msg.request_type == ZMQRequestType.CHECK_PRODUCTION: + elif request_msg.request_type == ZMQRequestType.GET_PRODUCTION: # Mock production status check - all produced response_body = { "partition_id": request_msg.body.get("partition_id"), - "produced": True, + "global_index": torch.tensor([0, 1, 2]), + "production_status": torch.tensor([[1, 1, 1], [1, 1, 1]]), } response_type = ZMQRequestType.PRODUCTION_RESPONSE elif request_msg.request_type == ZMQRequestType.GET_LIST_PARTITIONS: @@ -467,6 +469,52 @@ def test_check_production_status(client_setup): assert is_produced is True +def test_get_consumption_status(client_setup): + """Test get_consumption_status - returns global_index and consumption_status tensors""" + client, _, _ = client_setup + + # Test synchronous get_consumption_status + global_index, consumption_status = client.get_consumption_status( + task_name="generate_sequences", partition_id="train_0" + ) + + # Verify return types + assert global_index is not None + assert consumption_status is not None + + # Verify global_index contains expected values + assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long)) + + # Verify consumption_status (mock returns all 
consumed) + expected_status = torch.tensor([1, 1, 1], dtype=torch.int8) + assert torch.equal(consumption_status, expected_status) + + print("✓ get_consumption_status returns correct global_index and consumption_status") + + +def test_get_production_status(client_setup): + """Test get_production_status - returns global_index and production_status tensors""" + client, _, _ = client_setup + + # Test synchronous get_production_status + global_index, production_status = client.get_production_status( + data_fields=["prompt_ids", "attention_mask"], partition_id="train_0" + ) + + # Verify return types + assert global_index is not None + assert production_status is not None + + # Verify global_index contains expected values + assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long)) + + # Verify production_status shape (mock returns 2x3 matrix) + expected_status = torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.int8) + assert torch.equal(production_status, expected_status) + + print("✓ get_production_status returns correct global_index and production_status") + + def test_get_partition_list(client_setup): """Test partition list retrieval""" client, _, _ = client_setup @@ -502,6 +550,54 @@ async def test_async_check_production_status(client_setup): assert is_produced is True +@pytest.mark.asyncio +async def test_async_get_consumption_status(client_setup): + """Test async get_consumption_status - returns global_index and consumption_status tensors""" + client, _, _ = client_setup + + # Test async_get_consumption_status + global_index, consumption_status = await client.async_get_consumption_status( + task_name="generate_sequences", partition_id="train_0" + ) + + # Verify return types + assert global_index is not None + assert consumption_status is not None + + # Verify global_index contains expected values + assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long)) + + # Verify consumption_status (mock returns all consumed) + 
expected_status = torch.tensor([1, 1, 1], dtype=torch.int8) + assert torch.equal(consumption_status, expected_status) + + print("✓ async_get_consumption_status returns correct global_index and consumption_status") + + +@pytest.mark.asyncio +async def test_async_get_production_status(client_setup): + """Test async get_production_status - returns global_index and production_status tensors""" + client, _, _ = client_setup + + # Test async_get_production_status + global_index, production_status = await client.async_get_production_status( + data_fields=["prompt_ids", "attention_mask"], partition_id="train_0" + ) + + # Verify return types + assert global_index is not None + assert production_status is not None + + # Verify global_index contains expected values + assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long)) + + # Verify production_status shape (mock returns 2x3 matrix) + expected_status = torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.int8) + assert torch.equal(production_status, expected_status) + + print("✓ async_get_production_status returns correct global_index and production_status") + + @pytest.mark.asyncio async def test_async_get_partition_list(client_setup): """Test async partition list retrieval""" diff --git a/tests/test_controller.py b/tests/test_controller.py index 360b7f0..3d1379e 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -89,6 +89,7 @@ def test_controller_with_single_partition(self, ray_setup): field_names=metadata.field_names, dtypes=dtypes, shapes=shapes, + custom_meta=None, ) ) assert success @@ -97,13 +98,18 @@ def test_controller_with_single_partition(self, ray_setup): assert partition.production_status.size(0) == gbs * num_n_samples # Test for get production status - production_status = ray.get( + global_index, production_status = ray.get( tq_controller.get_production_status.remote( partition_id=partition_id, data_fields=data_fields, ) ) - assert production_status + # Verify global_index 
contains all expected indexes + assert torch.equal(global_index, torch.tensor(range(gbs * num_n_samples), dtype=torch.long)) + # Verify all samples are produced for all fields (status should be 1) + expected_production_status = torch.ones(gbs * num_n_samples, len(metadata.field_names), dtype=torch.int8) + assert torch.equal(production_status, expected_production_status) + print("✓ Get production status returns correct global_index and production_status") # Total fields should match the number of fields we added assert partition.total_fields_num == len(data_fields) @@ -126,14 +132,19 @@ def test_controller_with_single_partition(self, ray_setup): print(f"✓ Updated production status for partition {partition_id}") - # Test for get consumption status - consumption_status = ray.get( + # Test for get consumption status BEFORE consumption + global_index, consumption_status = ray.get( tq_controller.get_consumption_status.remote( partition_id=partition_id, task_name="generate_sequences", ) ) - assert torch.equal(consumption_status, torch.zeros(gbs * num_n_samples)) + # Verify global_index + assert torch.equal(global_index, torch.tensor(range(gbs * num_n_samples), dtype=torch.long)) + # Verify all samples are NOT consumed yet (status should be 0) + expected_consumption_status_before = torch.zeros(gbs * num_n_samples, dtype=torch.int8) + assert torch.equal(consumption_status, expected_consumption_status_before) + print("✓ Get consumption status returns correct global_index and status (before consumption)") # Test get metadate in fetch mode gen_meta = ray.get( @@ -153,14 +164,19 @@ def test_controller_with_single_partition(self, ray_setup): assert torch.equal(partition.consumption_status["generate_sequences"], torch.ones(gbs * num_n_samples)) print("✓ Get metadata in fetch mode correct") - # Test for get consumption status - consumption_status = ray.get( + # Test for get consumption status AFTER consumption + global_index, consumption_status = ray.get( 
tq_controller.get_consumption_status.remote( partition_id=partition_id, task_name="generate_sequences", ) ) - assert torch.equal(consumption_status, torch.ones(gbs * num_n_samples)) + # Verify global_index + assert torch.equal(global_index, torch.tensor(range(gbs * num_n_samples), dtype=torch.long)) + # Verify all samples are consumed (status should be 1) + expected_consumption_status_after = torch.ones(gbs * num_n_samples, dtype=torch.int8) + assert torch.equal(consumption_status, expected_consumption_status_after) + print("✓ Get consumption status returns correct global_index and status (after consumption)") # Test get clear meta clear_meta = ray.get( @@ -222,6 +238,19 @@ def test_controller_with_multi_partitions(self, ray_setup): ) assert success + # Verify get production status returns correct data + global_index_1, production_status_1 = ray.get( + tq_controller.get_production_status.remote( + partition_id=partition_id_1, + data_fields=data_fields, + ) + ) + expected_global_index_1 = torch.tensor(range(gbs_1 * num_n_samples_1), dtype=torch.long) + assert torch.equal(global_index_1, expected_global_index_1) + expected_production_status_1 = torch.ones(gbs_1 * num_n_samples_1, len(data_fields), dtype=torch.int8) + assert torch.equal(production_status_1, expected_production_status_1) + print("✓ Get production status for partition_1 returns correct global_index and status") + # Test get metadate in fetch mode gen_meta = ray.get( tq_controller.get_metadata.remote( @@ -234,6 +263,18 @@ def test_controller_with_multi_partitions(self, ray_setup): ) assert gen_meta + # Verify get consumption status after fetch (samples should be consumed) + global_index_1_consumed, consumption_status_1 = ray.get( + tq_controller.get_consumption_status.remote( + partition_id=partition_id_1, + task_name="generate_sequences", + ) + ) + assert torch.equal(global_index_1_consumed, expected_global_index_1) + expected_consumption_status_1 = torch.ones(gbs_1 * num_n_samples_1, dtype=torch.int8) 
+ assert torch.equal(consumption_status_1, expected_consumption_status_1) + print("✓ Get consumption status for partition_1 returns correct global_index and status (after fetch)") + # Test get clear meta clear_meta = ray.get( tq_controller.get_metadata.remote( @@ -282,6 +323,33 @@ def test_controller_with_multi_partitions(self, ray_setup): ) assert success + # Verify get production status for partition_2 + global_index_2, production_status_2 = ray.get( + tq_controller.get_production_status.remote( + partition_id=partition_id_2, + data_fields=data_fields, + ) + ) + expected_global_index_2 = torch.tensor( + range(part1_index_range, part2_index_range + part1_index_range), dtype=torch.long + ) + assert torch.equal(global_index_2, expected_global_index_2) + expected_production_status_2 = torch.ones(part2_index_range, len(data_fields), dtype=torch.int8) + assert torch.equal(production_status_2, expected_production_status_2) + print("✓ Get production status for partition_2 returns correct global_index and status") + + # Verify get consumption status for partition_2 (before consumption - should be all zeros) + global_index_2_consumed, consumption_status_2 = ray.get( + tq_controller.get_consumption_status.remote( + partition_id=partition_id_2, + task_name="generate_sequences", + ) + ) + assert torch.equal(global_index_2_consumed, expected_global_index_2) + expected_consumption_status_2 = torch.zeros(part2_index_range, dtype=torch.int8) + assert torch.equal(consumption_status_2, expected_consumption_status_2) + print("✓ Get consumption status for partition_2 returns correct global_index and status (before consumption)") + # Clear partition 1 partition_index_range_1 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_1)) assert partition_index_range_1 diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index ee1092a..df9f60d 100644 --- a/tests/test_controller_data_partitions.py +++ 
b/tests/test_controller_data_partitions.py @@ -86,9 +86,9 @@ def test_data_partition_status(): print("✓ Field metadata retrieval works") # Test consumption status - consumption_tensor = partition.get_consumption_status("test_task") + global_index, consumption_tensor = partition.get_consumption_status("test_task", mask=False) assert consumption_tensor is not None - assert consumption_tensor.shape[0] == partition.total_samples_num + assert consumption_tensor.shape[0] == partition.allocated_samples_num print("✓ Consumption status creation works") @@ -240,7 +240,7 @@ def test_data_partition_status_advanced(): # Initial consumption tracking partition.mark_consumed(task_name, [0, 1]) - initial_consumption = partition.get_consumption_status(task_name) + global_index, initial_consumption = partition.get_consumption_status(task_name) assert initial_consumption[0] == 1 assert initial_consumption[1] == 1 @@ -258,7 +258,7 @@ def test_data_partition_status_advanced(): 12: {"field_d": (32,)}, } partition.update_production_status([10, 11, 12], ["field_d"], dtypes, shapes, None) # Triggers sample expansion - expanded_consumption = partition.get_consumption_status(task_name) + global_index, expanded_consumption = partition.get_consumption_status(task_name) assert expanded_consumption[0] == 1 # Preserved assert expanded_consumption[1] == 1 # Preserved assert expanded_consumption.shape[0] >= 13 # Expanded to accommodate new samples @@ -358,7 +358,7 @@ def test_edge_cases_and_error_handling(): # Test 3: Consumption status edge cases # Test consumption status creation before production status task_name = "early_task" - consumption_tensor = partition.get_consumption_status(task_name) + _, consumption_tensor = partition.get_consumption_status(task_name) assert consumption_tensor is not None assert consumption_tensor.shape[0] == partition.allocated_samples_num @@ -526,3 +526,130 @@ def test_update_field_metadata_variants(): # Length mismatch should raise ValueError when provided mapping 
lengths differ from global_indices with pytest.raises(ValueError): partition._update_field_metadata([0, 1, 2], dtypes={0: {}}, shapes=None, custom_meta=None) + + +def test_get_production_status_for_fields(): + """Test get_production_status_for_fields method with mask parameter.""" + print("Testing get_production_status_for_fields...") + + import torch + + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="production_status_test") + + # Add some data first (using non-contiguous indices) + partition.update_production_status( + global_indices=[0, 1, 2, 3, 9], + field_names=["field_a", "field_b"], + dtypes={i: {"field_a": "torch.int64", "field_b": "torch.bool"} for i in [0, 1, 2, 3, 9]}, + shapes={i: {"field_a": (32,), "field_b": (32,)} for i in [0, 1, 2, 3, 9]}, + ) + + # Test get_production_status_for_fields WITHOUT mask (mask=False) + global_index, production_status = partition.get_production_status_for_fields( + field_names=["field_a", "field_b"], mask=False + ) + assert torch.equal(global_index, torch.tensor([0, 1, 2, 3, 9], dtype=torch.long)) + # Without mask, should return all allocated samples + assert production_status.shape[0] == partition.allocated_samples_num + # Production status should be 1 for produced samples (0, 1, 2, 3, 9), 0 for others + # Check that produced samples have all fields produced (all 1s) + assert torch.all(production_status[0] == 1), "Sample 0 should be produced" + assert torch.all(production_status[1] == 1), "Sample 1 should be produced" + assert torch.all(production_status[2] == 1), "Sample 2 should be produced" + assert torch.all(production_status[3] == 1), "Sample 3 should be produced" + assert torch.all(production_status[9] == 1), "Sample 9 should be produced" + # Verify shape - should have 2 fields (columns) + assert production_status.shape[1] == 2 + + print("✓ get_production_status_for_fields without mask works") + + # Test get_production_status_for_fields WITH mask 
(mask=True) + global_index_masked, production_status_masked = partition.get_production_status_for_fields( + field_names=["field_a", "field_b"], mask=True + ) + assert torch.equal(global_index_masked, torch.tensor([0, 1, 2, 3, 9], dtype=torch.long)) + # Masked status should be same as original for these indices + assert production_status_masked.shape == (len([0, 1, 2, 3, 9]), 2) + # All returned samples should be produced + assert torch.all(production_status_masked == 1) + + print("✓ get_production_status_for_fields with mask works") + + # Test get_production_status_for_fields with subset of fields + global_index_subset, production_status_subset = partition.get_production_status_for_fields( + field_names=["field_a"], mask=True + ) + assert global_index_subset.shape[0] == len([0, 1, 2, 3, 9]) + assert production_status_subset.shape == (len([0, 1, 2, 3, 9]), 1) # Only one field + + print("✓ get_production_status_for_fields with subset fields works") + + print("get_production_status_for_fields tests passed!\n") + + +def test_get_consumption_status_parameter(): + """Test get_consumption_status method with mask parameter.""" + print("Testing consumption status mask parameter...") + + import torch + + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="consumption_mask_test") + partition_another = DataPartitionStatus(partition_id="other_partition") + + # Add some data + partition.update_production_status( + global_indices=[0, 1, 2, 3, 9], + field_names=["field_a"], + dtypes={i: {"field_a": "torch.int64"} for i in [0, 1, 2, 3, 9]}, + shapes={i: {"field_a": (32,)} for i in [0, 1, 2, 3, 9]}, + ) + + partition_another.update_production_status( + global_indices=[5, 6, 7], + field_names=["field_a"], + dtypes={i: {"field_a": "torch.int64"} for i in [5, 6, 7]}, + shapes={i: {"field_a": (32,)} for i in [5, 6, 7]}, + ) + + # Mark some samples as consumed + partition.mark_consumed("test_task", [0, 2]) + + # Test 
get_consumption_status WITHOUT mask (mask=False) + global_index, consumption_status = partition.get_consumption_status("test_task", mask=False) + assert global_index.shape[0] == partition.total_samples_num + assert torch.equal(global_index, torch.tensor([0, 1, 2, 3, 9], dtype=torch.long)) + # Without mask, should return all allocated samples + assert consumption_status.shape[0] == 10 + assert consumption_status[0].item() == 1 + assert consumption_status[1].item() == 0 + assert consumption_status[2].item() == 1 + assert consumption_status[3].item() == 0 + assert consumption_status[4].item() == 0 # empty slot + assert consumption_status[5].item() == 0 # empty slot + assert consumption_status[6].item() == 0 # empty slot + assert consumption_status[7].item() == 0 # empty slot + assert consumption_status[8].item() == 0 # empty slot + assert consumption_status[9].item() == 0 + + print("✓ get_consumption_status without mask works") + + # Test get_consumption_status WITH mask (mask=True) + global_index_masked, consumption_status_masked = partition.get_consumption_status("test_task", mask=True) + # With mask, should return only global_indexes [0, 1, 2, 3, 9] + assert global_index_masked.shape[0] == partition.total_samples_num + assert torch.equal(global_index_masked, torch.tensor([0, 1, 2, 3, 9], dtype=torch.long)) + # Masked status shape[0] should correspond to global indexes + assert consumption_status_masked.shape[0] == partition.total_samples_num + assert consumption_status_masked[0].item() == 1 + assert consumption_status_masked[1].item() == 0 + assert consumption_status_masked[2].item() == 1 + assert consumption_status_masked[3].item() == 0 + assert consumption_status_masked[4].item() == 0 # no empty slot. 
this corresponds to global_index=9 + + print("✓ get_consumption_status with mask works") + + print("Consumption status mask parameter tests passed!\n") diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 59a3886..37096b4 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -21,9 +21,11 @@ from uuid import uuid4 import ray +import torch import zmq import zmq.asyncio from tensordict import TensorDict +from torch import Tensor from transfer_queue.controller import TransferQueueController from transfer_queue.metadata import ( @@ -536,13 +538,13 @@ async def _clear_partition_in_controller(self, partition_id, socket=None): raise RuntimeError(f"Failed to clear partition {partition_id} in controller.") @dynamic_socket(socket_name="request_handle_socket") - async def async_check_consumption_status( + async def async_get_consumption_status( self, task_name: str, partition_id: str, socket: Optional[zmq.asyncio.Socket] = None, - ) -> bool: - """Check if all samples for current partition have been consumed by a specific task. + ) -> tuple[Optional[Tensor], Optional[Tensor]]: + """Get consumption status for current partition in a specific task. Args: task_name: Name of the task to check consumption for @@ -550,22 +552,25 @@ async def async_check_consumption_status( socket: ZMQ async socket for message transmission (injected by decorator) Returns: - bool: True if all samples have been consumed by the task, False otherwise + Tuple of: + - Partition global index tensor + - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed. Raises: RuntimeError: If communication fails or controller returns error response Example: - >>> # Check if all samples have been consumed - >>> is_consumed = asyncio.run(client.async_check_consumption_status( + >>> # Get consumption status + >>> global_index, consumption_status = asyncio.run(client.async_get_consumption_status( ... task_name="generate_sequences", ... 
partition_id="train_0" ... )) - >>> print(f"All samples consumed: {is_consumed}") + >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}") """ + assert socket is not None request_msg = ZMQMessage.create( - request_type=ZMQRequestType.CHECK_CONSUMPTION, + request_type=ZMQRequestType.GET_CONSUMPTION, sender_id=self.client_id, receiver_id=self._controller.id, body={ @@ -579,29 +584,30 @@ async def async_check_consumption_status( response_serialized = await socket.recv_multipart() response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( - f"[{self.client_id}]: Client check consumption response: {response_msg} " + f"[{self.client_id}]: Client get consumption response: {response_msg} " f"from controller {self._controller.id}" ) if response_msg.request_type == ZMQRequestType.CONSUMPTION_RESPONSE: - consumed = response_msg.body.get("consumed", False) - return consumed + global_index = response_msg.body.get("global_index") + consumption_status = response_msg.body.get("consumption_status") + return global_index, consumption_status else: raise RuntimeError( - f"[{self.client_id}]: Failed to check consumption status from controller {self._controller.id}: " + f"[{self.client_id}]: Failed to get consumption status from controller {self._controller.id}: " f"{response_msg.body.get('message', 'Unknown error')}" ) except Exception as e: - raise RuntimeError(f"[{self.client_id}]: Error in check_data_consumption_status: {str(e)}") from e + raise RuntimeError(f"[{self.client_id}]: Error in get_consumption_status: {str(e)}") from e @dynamic_socket(socket_name="request_handle_socket") - async def async_check_production_status( + async def async_get_production_status( self, data_fields: list[str], partition_id: str, socket: Optional[zmq.asyncio.Socket] = None, - ) -> bool: - """Check if all samples for current partition are ready (produced) for consumption. 
+ ) -> tuple[Optional[Tensor], Optional[Tensor]]: + """Get production status for current partition for specific fields. Args: data_fields: Data fields to check production status for @@ -609,22 +615,24 @@ async def async_check_production_status( socket: ZMQ async socket for message transmission (injected by decorator) Returns: - bool: True if all samples have been produced and ready, False otherwise + Tuple of: + - Partition global index tensor + - Production status tensor for the specified fields. 1 for ready, 0 for not ready. Raises: RuntimeError: If communication fails or controller returns error response Example: - >>> # Check if all samples are ready for consumption - >>> is_ready = asyncio.run(client.async_check_production_status( + >>> # Get production status + >>> global_index, production_status = asyncio.run(client.async_get_production_status( ... data_fields=["input_ids", "attention_mask"], ... partition_id="train_0" ... )) - >>> print(f"All samples ready: {is_ready}") + >>> print(f"Global index: {global_index}, Production status: {production_status}") """ assert socket is not None request_msg = ZMQMessage.create( - request_type=ZMQRequestType.CHECK_PRODUCTION, + request_type=ZMQRequestType.GET_PRODUCTION, sender_id=self.client_id, receiver_id=self._controller.id, body={ @@ -638,20 +646,91 @@ async def async_check_production_status( response_serialized = await socket.recv_multipart() response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( - f"[{self.client_id}]: Client check production response: {response_msg} " + f"[{self.client_id}]: Client get production response: {response_msg} " f"from controller {self._controller.id}" ) if response_msg.request_type == ZMQRequestType.PRODUCTION_RESPONSE: - produced = response_msg.body.get("produced", False) - return produced + global_index = response_msg.body.get("global_index") + production_status = response_msg.body.get("production_status") + return global_index, production_status else: raise 
RuntimeError( - f"[{self.client_id}]: Failed to check production status from controller {self._controller.id}: " + f"[{self.client_id}]: Failed to get production status from controller {self._controller.id}: " f"{response_msg.body.get('message', 'Unknown error')}" ) except Exception as e: - raise RuntimeError(f"[{self.client_id}]: Error in check_data_production_status: {str(e)}") from e + raise RuntimeError(f"[{self.client_id}]: Error in get_data_production_status: {str(e)}") from e + + async def async_check_consumption_status( + self, + task_name: str, + partition_id: str, + ) -> bool: + """Check if all samples for current partition have been consumed by a specific task. + + Args: + task_name: Name of the task to check consumption for + partition_id: Partition id to check consumption status for + + Returns: + bool: True if all samples have been consumed by the task, False otherwise + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Check if all samples have been consumed + >>> is_consumed = asyncio.run(client.async_check_consumption_status( + ... task_name="generate_sequences", + ... partition_id="train_0" + ... )) + >>> print(f"All samples consumed: {is_consumed}") + """ + + _, consumption_status = await self.async_get_consumption_status( + task_name=task_name, + partition_id=partition_id, + ) + + if consumption_status is None: + return False + return torch.all(consumption_status == 1).item() + + async def async_check_production_status( + self, + data_fields: list[str], + partition_id: str, + ) -> bool: + """Check if the all specific fields of samples for current partition are ready + (produced) for consumption. 
+ + Args: + data_fields: Data fields to check production status for + partition_id: Partition id to check production status for + + Returns: + bool: True if all samples have been produced and ready, False otherwise + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Check if all samples are ready for consumption + >>> is_ready = asyncio.run(client.async_check_production_status( + ... data_fields=["input_ids", "attention_mask"], + ... partition_id="train_0" + ... )) + >>> print(f"All samples ready: {is_ready}") + """ + _, production_status = await self.async_get_production_status( + data_fields=data_fields, + partition_id=partition_id, + ) + + if production_status is None: + return False + return torch.all(production_status == 1).item() @dynamic_socket(socket_name="request_handle_socket") async def async_get_partition_list( @@ -811,6 +890,31 @@ def check_consumption_status(self, task_name: str, partition_id: str) -> bool: """ return asyncio.run(self.async_check_consumption_status(task_name, partition_id)) + def get_consumption_status( + self, + task_name: str, + partition_id: str, + ) -> tuple[Optional[Tensor], Optional[Tensor]]: + """Synchronously get consumption status for a specific task and partition. + + Args: + task_name: Name of the task to check consumption for + partition_id: Partition id to check consumption status for + + Returns: + Tuple of: + - Partition global index tensor + - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed. + + Example: + >>> global_index, consumption_status = client.get_consumption_status( + ... task_name="generate_sequences", + ... partition_id="train_0" + ... 
) + >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}") + """ + return asyncio.run(self.async_get_consumption_status(task_name, partition_id)) + def check_production_status(self, data_fields: list[str], partition_id: str) -> bool: """Synchronously check if all samples for a partition are ready (produced) for consumption. @@ -823,6 +927,31 @@ def check_production_status(self, data_fields: list[str], partition_id: str) -> """ return asyncio.run(self.async_check_production_status(data_fields, partition_id)) + def get_production_status( + self, + data_fields: list[str], + partition_id: str, + ) -> tuple[Optional[Tensor], Optional[Tensor]]: + """Synchronously get production status for a specific data fields and partition. + + Args: + data_fields: Data fields to check production status for + partition_id: Partition id to check production status for + + Returns: + Tuple of: + - Partition global index tensor + - Production status tensor for the specified fields. 1 for ready, 0 for not ready. + + Example: + >>> global_index, production_status = client.get_production_status( + ... data_fields=["input_ids", "attention_mask"], + ... partition_id="train_0" + ... 
) + >>> print(f"Global index: {global_index}, Production status: {production_status}") + """ + return asyncio.run(self.async_get_production_status(data_fields, partition_id)) + def get_partition_list( self, ): diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index fae9e96..2b6767f 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -29,6 +29,7 @@ import torch import zmq from ray.util import get_node_ip_address +from torch import Tensor from transfer_queue.metadata import ( BatchMeta, @@ -215,11 +216,11 @@ class DataPartitionStatus: # Production status tensor - dynamically expandable # Values: 0 = not produced, 1 = ready for consumption - production_status: Optional[torch.Tensor] = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8) + production_status: Optional[Tensor] = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8) # Consumption status per task - task_name -> consumption_tensor # Each tensor tracks which samples have been consumed by that task - consumption_status: dict[str, torch.Tensor] = field(default_factory=dict) + consumption_status: dict[str, Tensor] = field(default_factory=dict) # Sample metadata global_indexes: set[int] = field( @@ -442,18 +443,42 @@ def _update_field_metadata( self.field_custom_metas[global_idx] = {} self.field_custom_metas[global_idx].update(custom_meta_value[i]) + def mark_consumed(self, task_name: str, global_indices: list[int]): + """ + Mark specific samples as consumed by a task. + + Args: + task_name: Name of the consumer task + global_indices: List of sample indices to mark as consumed + + """ + try: + _, consumption_status = self.get_consumption_status(task_name, mask=False) + + if consumption_status.numel() > 0 and global_indices: + consumption_status[global_indices] = 1 + except Exception as e: + logger.error( + f"Error marking samples consumed for partition {self.partition_id}, task {task_name}: {e}. 
" + f"Target global_indices {global_indices}, but current consumption_status has " + f"shape {consumption_status.shape}" + ) + # ==================== Consumption Status Interface ==================== - def get_consumption_status(self, task_name: str) -> torch.Tensor: + def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Tensor, Tensor]: """ Get or create consumption status for a specific task. Handles dynamic expansion when new samples are added. Args: task_name: Name of the consumer task + mask: Whether to return only the status for current partition samples Returns: - Consumption status tensor for the specified task + Tuple of: + - Partition global index tensor + - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed. """ if task_name not in self.consumption_status: @@ -462,38 +487,29 @@ def get_consumption_status(self, task_name: str) -> torch.Tensor: else: self.consumption_status[task_name] = torch.zeros(0, dtype=torch.int8) - return self.consumption_status[task_name] - - def mark_consumed(self, task_name: str, global_indices: list[int]): - """ - Mark specific samples as consumed by a task. + # Get consumption status for requested task + consumption_status = self.consumption_status[task_name] - Args: - task_name: Name of the consumer task - global_indices: List of sample indices to mark as consumed + partition_global_index = torch.tensor(sorted(self.global_indexes), dtype=torch.long) - """ - try: - consumption_status = self.get_consumption_status(task_name) + if mask: + consumption_status = consumption_status[partition_global_index] - if consumption_status.numel() > 0 and global_indices: - consumption_status[global_indices] = 1 - except Exception as e: - logger.error( - f"Error marking samples consumed for partition {self.partition_id}, task {task_name}: {e}. 
" - f"Target global_indices {global_indices}, but current consumption_status has " - f"shape {consumption_status.shape}" - ) + return partition_global_index, consumption_status - def get_production_status_for_fields(self, field_names: list[str]) -> bool: + # ==================== Production Status Interface ==================== + def get_production_status_for_fields(self, field_names: list[str], mask: bool = False) -> tuple[Tensor, Tensor]: """ Check if all samples for specified fields are fully produced and ready. Args: field_names: List of field names to check production status for + mask: Whether to return only the status for current partition samples Returns: - bool: True if all samples have been produced for all specified fields, False otherwise + Tuple of: + - Partition global index tensor + - Production status tensor for the specified task. 1 for ready, 0 for not ready. """ if self.production_status is None or field_names is None or len(field_names) == 0: return False @@ -509,13 +525,14 @@ def get_production_status_for_fields(self, field_names: list[str]) -> bool: if field_indices: col_mask[field_indices] = True - # Get production status for requested fields - relevant_status = self.production_status[:, col_mask] + production_status = self.production_status[:, col_mask] + + partition_global_index = torch.tensor(sorted(self.global_indexes), dtype=torch.long) - # Check if all samples have all requested fields produced (all values are 1) - all_fields_produced = torch.all(relevant_status == 1).item() + if mask: + production_status = production_status[partition_global_index] - return all_fields_produced + return partition_global_index, production_status # ==================== Data Scanning and Query Methods ==================== @@ -543,7 +560,7 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: row_mask = torch.ones(self.allocated_samples_num, dtype=torch.bool) # Apply consumption filter (exclude already consumed samples) - 
consumption_status = self.get_consumption_status(task_name) + _, consumption_status = self.get_consumption_status(task_name, mask=False) if consumption_status is not None: unconsumed_mask = consumption_status == 0 row_mask &= unconsumed_mask @@ -650,7 +667,7 @@ def _perform_copy(): if name == "data_status_lock": continue - if isinstance(value, torch.Tensor): + if isinstance(value, Tensor): new_val = value.clone().detach() else: new_val = copy.deepcopy(value) @@ -873,7 +890,7 @@ def update_production_status( # ==================== Data Consumption API ==================== - def get_consumption_status(self, partition_id: str, task_name: str) -> Optional[torch.Tensor]: + def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Optional[Tensor], Optional[Tensor]]: """ Get or create consumption status for a specific task and partition. Delegates to the partition's own method. @@ -883,15 +900,19 @@ def get_consumption_status(self, partition_id: str, task_name: str) -> Optional[ task_name: Name of the consumer task Returns: - Consumption status tensor if partition exists, None otherwise + Tuple of: + - Partition global index tensor + - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed. """ partition = self._get_partition(partition_id) if not partition: - return None + return None, None - return partition.get_consumption_status(task_name) + return partition.get_consumption_status(task_name, mask=True) - def get_production_status(self, partition_id: str, data_fields: list[str]) -> bool: + def get_production_status( + self, partition_id: str, data_fields: list[str] + ) -> tuple[Optional[Tensor], Optional[Tensor]]: """ Check if all samples for specified fields are fully produced in a partition. 
@@ -900,13 +921,15 @@ def get_production_status(self, partition_id: str, data_fields: list[str]) -> bo data_fields: List of field names to check production status for Returns: - bool: True if all samples have been produced for all specified fields, False otherwise + Tuple of: + - Partition global index tensor + - Production status tensor for the specified fields. 1 for ready, 0 for not ready. """ partition = self._get_partition(partition_id) if not partition: - return False + return None, None - return partition.get_production_status_for_fields(data_fields) + return partition.get_production_status_for_fields(data_fields, mask=True) def get_metadata( self, @@ -1392,22 +1415,19 @@ def _process_request(self): body={"message": f"Clear partition operation completed by controller {self.controller_id}"}, ) - elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION: - with perf_monitor.measure(op_type="CHECK_CONSUMPTION"): + elif request_msg.request_type == ZMQRequestType.GET_CONSUMPTION: + with perf_monitor.measure(op_type="GET_CONSUMPTION"): # Handle consumption status checks params = request_msg.body - consumption_status = self.get_consumption_status(params["partition_id"], params["task_name"]) - sample_filter = params.get("sample_filter") + global_index, consumption_status = self.get_consumption_status( + params["partition_id"], params["task_name"] + ) + sample_filter = params.get("sample_filter") # TODO: DEPRECATED in future - if consumption_status is not None and sample_filter: - batch_status = consumption_status[sample_filter] - consumed = torch.all(batch_status == 1).item() - elif consumption_status is not None: - batch_status = consumption_status - consumed = torch.all(batch_status == 1).item() - else: - consumed = False + if sample_filter and consumption_status is not None: + # TODO: DEPRECATED in future + consumption_status = consumption_status[sample_filter] response_msg = ZMQMessage.create( request_type=ZMQRequestType.CONSUMPTION_RESPONSE, @@ -1415,16 
+1435,19 @@ def _process_request(self): receiver_id=request_msg.sender_id, body={ "partition_id": params["partition_id"], - "consumed": consumed, + "global_index": global_index, + "consumption_status": consumption_status, }, ) - elif request_msg.request_type == ZMQRequestType.CHECK_PRODUCTION: - with perf_monitor.measure(op_type="CHECK_PRODUCTION"): + elif request_msg.request_type == ZMQRequestType.GET_PRODUCTION: + with perf_monitor.measure(op_type="GET_PRODUCTION"): # Handle production status checks params = request_msg.body - produced = self.get_production_status(params["partition_id"], params["data_fields"]) + global_index, production_status = self.get_production_status( + params["partition_id"], params["data_fields"] + ) response_msg = ZMQMessage.create( request_type=ZMQRequestType.PRODUCTION_RESPONSE, @@ -1432,7 +1455,8 @@ def _process_request(self): receiver_id=request_msg.sender_id, body={ "partition_id": params["partition_id"], - "produced": produced, + "global_index": global_index, + "production_status": production_status, }, ) diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 05b35f8..f462773 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -72,12 +72,12 @@ class ZMQRequestType(ExplicitEnum): CLEAR_PARTITION = "CLEAR_PARTITION" CLEAR_PARTITION_RESPONSE = "CLEAR_PARTITION_RESPONSE" - # CHECK_CONSUMPTION - CHECK_CONSUMPTION = "CHECK_CONSUMPTION" + # GET_CONSUMPTION + GET_CONSUMPTION = "GET_CONSUMPTION" CONSUMPTION_RESPONSE = "CONSUMPTION_RESPONSE" - # CHECK_PRODUCTION - CHECK_PRODUCTION = "CHECK_PRODUCTION" + # GET_PRODUCTION + GET_PRODUCTION = "GET_PRODUCTION" PRODUCTION_RESPONSE = "PRODUCTION_RESPONSE" # LIST_PARTITIONS