From 387bdcfa0629ab41a027a2d2d579c41519a816cc Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 23 Jan 2026 10:38:21 +0800 Subject: [PATCH 1/7] provide raw consumption&production status retrieval Signed-off-by: 0oshowero0 --- transfer_queue/client.py | 130 +++++++++++++++++++++++------ transfer_queue/controller.py | 131 +++++++++++++++++------------- transfer_queue/utils/zmq_utils.py | 8 +- 3 files changed, 184 insertions(+), 85 deletions(-) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 59a3886..ecc1983 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -21,9 +21,11 @@ from uuid import uuid4 import ray +import torch import zmq import zmq.asyncio from tensordict import TensorDict +from torch import Tensor from transfer_queue.controller import TransferQueueController from transfer_queue.metadata import ( @@ -536,13 +538,13 @@ async def _clear_partition_in_controller(self, partition_id, socket=None): raise RuntimeError(f"Failed to clear partition {partition_id} in controller.") @dynamic_socket(socket_name="request_handle_socket") - async def async_check_consumption_status( + async def async_get_consumption_status( self, task_name: str, partition_id: str, socket: Optional[zmq.asyncio.Socket] = None, - ) -> bool: - """Check if all samples for current partition have been consumed by a specific task. + ) -> tuple[Optional[Tensor], Optional[Tensor]]: + """Get consumption status for current partition in a specific task. Args: task_name: Name of the task to check consumption for @@ -550,19 +552,22 @@ async def async_check_consumption_status( socket: ZMQ async socket for message transmission (injected by decorator) Returns: - bool: True if all samples have been consumed by the task, False otherwise + Tuple of: + - Partition global index tensor + - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed. 
Raises: RuntimeError: If communication fails or controller returns error response Example: - >>> # Check if all samples have been consumed - >>> is_consumed = asyncio.run(client.async_check_consumption_status( + >>> # Get consumption status + >>> global_index, consumption_status = asyncio.run(client.async_get_consumption_status( ... task_name="generate_sequences", ... partition_id="train_0" ... )) - >>> print(f"All samples consumed: {is_consumed}") + >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}") """ + assert socket is not None request_msg = ZMQMessage.create( request_type=ZMQRequestType.CHECK_CONSUMPTION, sender_id=self.client_id, receiver_id=self._controller.id, body={ @@ -579,29 +584,30 @@ async def async_check_consumption_status( response_serialized = await socket.recv_multipart() response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( - f"[{self.client_id}]: Client check consumption response: {response_msg} " + f"[{self.client_id}]: Client get consumption response: {response_msg} " f"from controller {self._controller.id}" ) if response_msg.request_type == ZMQRequestType.CONSUMPTION_RESPONSE: - consumed = response_msg.body.get("consumed", False) - return consumed + global_index = response_msg.body.get("global_index") + consumption_status = response_msg.body.get("consumption_status") + return global_index, consumption_status else: raise RuntimeError( - f"[{self.client_id}]: Failed to check consumption status from controller {self._controller.id}: " + f"[{self.client_id}]: Failed to get consumption status from controller {self._controller.id}: " f"{response_msg.body.get('message', 'Unknown error')}" ) except Exception as e: - raise RuntimeError(f"[{self.client_id}]: Error in check_data_consumption_status: {str(e)}") from e + raise RuntimeError(f"[{self.client_id}]: Error in get_consumption_status: {str(e)}") from e @dynamic_socket(socket_name="request_handle_socket") - async def async_check_production_status( + async def async_get_production_status( self, data_fields: 
list[str], partition_id: str, socket: Optional[zmq.asyncio.Socket] = None, - ) -> bool: - """Check if all samples for current partition are ready (produced) for consumption. + ) -> tuple[Optional[Tensor], Optional[Tensor]]: + """Get production status for current partition in a specific task. Args: data_fields: Data fields to check production status for @@ -609,22 +615,24 @@ async def async_check_production_status( socket: ZMQ async socket for message transmission (injected by decorator) Returns: - bool: True if all samples have been produced and ready, False otherwise + Tuple of: + - Partition global index tensor + - Production status tensor for the specified task. 1 for ready, 0 for not ready. Raises: RuntimeError: If communication fails or controller returns error response Example: - >>> # Check if all samples are ready for consumption - >>> is_ready = asyncio.run(client.async_check_production_status( + >>> # Get production status + >>> global_index, production_status = asyncio.run(client.async_get_production_status( ... data_fields=["input_ids", "attention_mask"], ... partition_id="train_0" ... 
)) - >>> print(f"All samples ready: {is_ready}") + >>> print(f"Global index: {global_index}, Production status: {production_status}") """ assert socket is not None request_msg = ZMQMessage.create( - request_type=ZMQRequestType.CHECK_PRODUCTION, + request_type=ZMQRequestType.GET_PRODUCTION, sender_id=self.client_id, receiver_id=self._controller.id, body={ @@ -638,20 +646,92 @@ async def async_check_production_status( response_serialized = await socket.recv_multipart() response_msg = ZMQMessage.deserialize(response_serialized) logger.debug( - f"[{self.client_id}]: Client check production response: {response_msg} " + f"[{self.client_id}]: Client get production response: {response_msg} " f"from controller {self._controller.id}" ) if response_msg.request_type == ZMQRequestType.PRODUCTION_RESPONSE: - produced = response_msg.body.get("produced", False) - return produced + global_index = response_msg.body.get("global_index") + production_status = response_msg.body.get("production_status") + return global_index, production_status else: raise RuntimeError( - f"[{self.client_id}]: Failed to check production status from controller {self._controller.id}: " + f"[{self.client_id}]: Failed to get production status from controller {self._controller.id}: " f"{response_msg.body.get('message', 'Unknown error')}" ) except Exception as e: - raise RuntimeError(f"[{self.client_id}]: Error in check_data_production_status: {str(e)}") from e + raise RuntimeError(f"[{self.client_id}]: Error in get_production_status: {str(e)}") from e + + async def async_check_consumption_status( + self, + task_name: str, + partition_id: str, + ) -> bool: + """Check if all samples for current partition have been consumed by a specific task. 
+ + Args: + task_name: Name of the task to check consumption for + partition_id: Partition id to check consumption status for + socket: ZMQ async socket for message transmission (injected by decorator) + + Returns: + bool: True if all samples have been consumed by the task, False otherwise + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Check if all samples have been consumed + >>> is_consumed = asyncio.run(client.async_check_consumption_status( + ... task_name="generate_sequences", + ... partition_id="train_0" + ... )) + >>> print(f"All samples consumed: {is_consumed}") + """ + + _, consumption_status = await self.async_get_consumption_status( + task_name=task_name, + partition_id=partition_id, + ) + + if consumption_status is None: + return False + return torch.all(consumption_status).item() == 1 + + async def async_check_production_status( + self, + data_fields: list[str], + partition_id: str, + ) -> bool: + """Check if all samples for current partition are ready (produced) for consumption. + + Args: + data_fields: Data fields to check production status for + partition_id: Partition id to check production status for + socket: ZMQ async socket for message transmission (injected by decorator) + + Returns: + bool: True if all samples have been produced and ready, False otherwise + + Raises: + RuntimeError: If communication fails or controller returns error response + + Example: + >>> # Check if all samples are ready for consumption + >>> is_ready = asyncio.run(client.async_check_production_status( + ... data_fields=["input_ids", "attention_mask"], + ... partition_id="train_0" + ... 
)) + >>> print(f"All samples ready: {is_ready}") + """ + _, production_status = await self.async_get_production_status( + data_fields=data_fields, + partition_id=partition_id, + ) + + if production_status is None: + return False + return torch.all(production_status).item() == 1 @dynamic_socket(socket_name="request_handle_socket") async def async_get_partition_list( diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index fae9e96..0eea081 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -29,6 +29,7 @@ import torch import zmq from ray.util import get_node_ip_address +from torch import Tensor from transfer_queue.metadata import ( BatchMeta, @@ -215,11 +216,11 @@ class DataPartitionStatus: # Production status tensor - dynamically expandable # Values: 0 = not produced, 1 = ready for consumption - production_status: Optional[torch.Tensor] = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8) + production_status: Optional[Tensor] = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8) # Consumption status per task - task_name -> consumption_tensor # Each tensor tracks which samples have been consumed by that task - consumption_status: dict[str, torch.Tensor] = field(default_factory=dict) + consumption_status: dict[str, Tensor] = field(default_factory=dict) # Sample metadata global_indexes: set[int] = field( @@ -366,7 +367,7 @@ def update_production_status( # Update production status if self.production_status is not None and global_indices and field_names: field_indices = [self.field_name_mapping.get(field) for field in field_names] - self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1 + self.production_status[Tensor(global_indices)[:, None], Tensor(field_indices)] = 1 # Update field metadata self._update_field_metadata(global_indices, dtypes, shapes, custom_meta) @@ -442,28 +443,6 @@ def _update_field_metadata( self.field_custom_metas[global_idx] 
= {} self.field_custom_metas[global_idx].update(custom_meta_value[i]) - # ==================== Consumption Status Interface ==================== - - def get_consumption_status(self, task_name: str) -> torch.Tensor: - """ - Get or create consumption status for a specific task. - Handles dynamic expansion when new samples are added. - - Args: - task_name: Name of the consumer task - - Returns: - Consumption status tensor for the specified task - """ - - if task_name not in self.consumption_status: - if self.production_status is not None: - self.consumption_status[task_name] = torch.zeros(self.allocated_samples_num, dtype=torch.int8) - else: - self.consumption_status[task_name] = torch.zeros(0, dtype=torch.int8) - - return self.consumption_status[task_name] - def mark_consumed(self, task_name: str, global_indices: list[int]): """ Mark specific samples as consumed by a task. @@ -485,7 +464,38 @@ def mark_consumed(self, task_name: str, global_indices: list[int]): f"shape {consumption_status.shape}" ) - def get_production_status_for_fields(self, field_names: list[str]) -> bool: + # ==================== Consumption Status Interface ==================== + + def get_consumption_status(self, task_name: str) -> tuple[Tensor, Tensor]: + """ + Get or create consumption status for a specific task. + Handles dynamic expansion when new samples are added. + + Args: + task_name: Name of the consumer task + + Returns: + Tuple of: + - Partition global index tensor + - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed. 
+ """ + + if task_name not in self.consumption_status: + if self.production_status is not None: + self.consumption_status[task_name] = torch.zeros(self.allocated_samples_num, dtype=torch.int8) + else: + self.consumption_status[task_name] = torch.zeros(0, dtype=torch.int8) + + # Get mask for target partition + partition_global_index = Tensor(sorted(self.global_indexes), dtype=torch.long) + + # Get consumption status for requested task + relevant_consumption_status = self.consumption_status[task_name][partition_global_index] + + return partition_global_index, relevant_consumption_status + + # ==================== Production Status Interface ==================== + def get_production_status_for_fields(self, field_names: list[str]) -> tuple[Tensor, Tensor]: """ Check if all samples for specified fields are fully produced and ready. @@ -493,7 +503,9 @@ def get_production_status_for_fields(self, field_names: list[str]) -> bool: field_names: List of field names to check production status for Returns: - bool: True if all samples have been produced for all specified fields, False otherwise + Tuple of: + - Partition global index tensor + - Production status tensor for the specified task. 1 for ready, 0 for not ready. 
""" if self.production_status is None or field_names is None or len(field_names) == 0: return False @@ -509,13 +521,13 @@ def get_production_status_for_fields(self, field_names: list[str]) -> bool: if field_indices: col_mask[field_indices] = True - # Get production status for requested fields - relevant_status = self.production_status[:, col_mask] + # Get mask for target partition + partition_global_index = Tensor(sorted(self.global_indexes), dtype=torch.long) - # Check if all samples have all requested fields produced (all values are 1) - all_fields_produced = torch.all(relevant_status == 1).item() + # Get production status for requested fields + relevant_production_status = self.production_status[partition_global_index, col_mask] - return all_fields_produced + return partition_global_index, relevant_production_status # ==================== Data Scanning and Query Methods ==================== @@ -650,7 +662,7 @@ def _perform_copy(): if name == "data_status_lock": continue - if isinstance(value, torch.Tensor): + if isinstance(value, Tensor): new_val = value.clone().detach() else: new_val = copy.deepcopy(value) @@ -873,7 +885,7 @@ def update_production_status( # ==================== Data Consumption API ==================== - def get_consumption_status(self, partition_id: str, task_name: str) -> Optional[torch.Tensor]: + def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Optional[Tensor], Optional[Tensor]]: """ Get or create consumption status for a specific task and partition. Delegates to the partition's own method. @@ -883,15 +895,19 @@ def get_consumption_status(self, partition_id: str, task_name: str) -> Optional[ task_name: Name of the consumer task Returns: - Consumption status tensor if partition exists, None otherwise + Tuple of: + - Partition global index tensor + - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed. 
""" partition = self._get_partition(partition_id) if not partition: - return None + return None, None return partition.get_consumption_status(task_name) - def get_production_status(self, partition_id: str, data_fields: list[str]) -> bool: + def get_production_status( + self, partition_id: str, data_fields: list[str] + ) -> tuple[Optional[Tensor], Optional[Tensor]]: """ Check if all samples for specified fields are fully produced in a partition. @@ -900,11 +916,13 @@ def get_production_status(self, partition_id: str, data_fields: list[str]) -> bo data_fields: List of field names to check production status for Returns: - bool: True if all samples have been produced for all specified fields, False otherwise + Tuple of: + - Partition global index tensor + - Production status tensor for the specified task. 1 for ready, 0 for not ready. """ partition = self._get_partition(partition_id) if not partition: - return False + return None, None return partition.get_production_status_for_fields(data_fields) @@ -1392,22 +1410,19 @@ def _process_request(self): body={"message": f"Clear partition operation completed by controller {self.controller_id}"}, ) - elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION: - with perf_monitor.measure(op_type="CHECK_CONSUMPTION"): + elif request_msg.request_type == ZMQRequestType.GET_CONSUMPTION: + with perf_monitor.measure(op_type="GET_CONSUMPTION"): # Handle consumption status checks params = request_msg.body - consumption_status = self.get_consumption_status(params["partition_id"], params["task_name"]) - sample_filter = params.get("sample_filter") + global_index, consumption_status = self.get_consumption_status( + params["partition_id"], params["task_name"] + ) + sample_filter = params.get("sample_filter") # DEPRECATED in future - if consumption_status is not None and sample_filter: - batch_status = consumption_status[sample_filter] - consumed = torch.all(batch_status == 1).item() - elif consumption_status is not None: - 
batch_status = consumption_status - consumed = torch.all(batch_status == 1).item() - else: - consumed = False + if sample_filter: + # DEPRECATED in future + consumption_status = consumption_status[sample_filter] response_msg = ZMQMessage.create( request_type=ZMQRequestType.CONSUMPTION_RESPONSE, @@ -1415,16 +1430,19 @@ def _process_request(self): receiver_id=request_msg.sender_id, body={ "partition_id": params["partition_id"], - "consumed": consumed, + "global_index": global_index, + "consumption_status": consumption_status, }, ) - elif request_msg.request_type == ZMQRequestType.CHECK_PRODUCTION: - with perf_monitor.measure(op_type="CHECK_PRODUCTION"): + elif request_msg.request_type == ZMQRequestType.GET_PRODUCTION: + with perf_monitor.measure(op_type="GET_PRODUCTION"): # Handle production status checks params = request_msg.body - produced = self.get_production_status(params["partition_id"], params["data_fields"]) + global_index, production_status = self.get_production_status( + params["partition_id"], params["data_fields"] + ) response_msg = ZMQMessage.create( request_type=ZMQRequestType.PRODUCTION_RESPONSE, @@ -1432,7 +1450,8 @@ def _process_request(self): receiver_id=request_msg.sender_id, body={ "partition_id": params["partition_id"], - "produced": produced, + "global_index": global_index, + "production_status": production_status, }, ) diff --git a/transfer_queue/utils/zmq_utils.py b/transfer_queue/utils/zmq_utils.py index 05b35f8..f462773 100644 --- a/transfer_queue/utils/zmq_utils.py +++ b/transfer_queue/utils/zmq_utils.py @@ -72,12 +72,12 @@ class ZMQRequestType(ExplicitEnum): CLEAR_PARTITION = "CLEAR_PARTITION" CLEAR_PARTITION_RESPONSE = "CLEAR_PARTITION_RESPONSE" - # CHECK_CONSUMPTION - CHECK_CONSUMPTION = "CHECK_CONSUMPTION" + # GET_CONSUMPTION + GET_CONSUMPTION = "GET_CONSUMPTION" CONSUMPTION_RESPONSE = "CONSUMPTION_RESPONSE" - # CHECK_PRODUCTION - CHECK_PRODUCTION = "CHECK_PRODUCTION" + # GET_PRODUCTION + GET_PRODUCTION = "GET_PRODUCTION" 
PRODUCTION_RESPONSE = "PRODUCTION_RESPONSE" # LIST_PARTITIONS From c68400799e4782bd2128ba372e999e0b3450260d Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 23 Jan 2026 11:24:09 +0800 Subject: [PATCH 2/7] try fix Signed-off-by: 0oshowero0 --- tests/test_client.py | 10 ++++--- transfer_queue/client.py | 52 +++++++++++++++++++++++++++++++++++- transfer_queue/controller.py | 38 +++++++++++++++----------- 3 files changed, 79 insertions(+), 21 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index 4b8394c..f06ede5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -112,18 +112,20 @@ def _handle_requests(self): # Mock partition metadata response response_body = {"metadata": self._mock_batch_meta(request_msg.body)} response_type = ZMQRequestType.GET_PARTITION_META_RESPONSE - elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION: + elif request_msg.request_type == ZMQRequestType.GET_CONSUMPTION: # Mock consumption status check - all consumed response_body = { "partition_id": request_msg.body.get("partition_id"), - "consumed": True, + "global_index": torch.tensor([0,1,2]), + "consumption_status": torch.tensor([1,1,1]), } response_type = ZMQRequestType.CONSUMPTION_RESPONSE - elif request_msg.request_type == ZMQRequestType.CHECK_PRODUCTION: + elif request_msg.request_type == ZMQRequestType.GET_PRODUCTION: # Mock production status check - all produced response_body = { "partition_id": request_msg.body.get("partition_id"), - "produced": True, + "global_index": torch.tensor([0,1,2]), + "production_status": torch.tensor([[1,1,1],[1,1,1]]), } response_type = ZMQRequestType.PRODUCTION_RESPONSE elif request_msg.request_type == ZMQRequestType.GET_LIST_PARTITIONS: diff --git a/transfer_queue/client.py b/transfer_queue/client.py index ecc1983..4586bbb 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -570,7 +570,7 @@ async def async_get_consumption_status( assert socket is not None request_msg = ZMQMessage.create( 
- request_type=ZMQRequestType.CHECK_CONSUMPTION, + request_type=ZMQRequestType.GET_CONSUMPTION, sender_id=self.client_id, receiver_id=self._controller.id, body={ @@ -891,6 +891,31 @@ def check_consumption_status(self, task_name: str, partition_id: str) -> bool: """ return asyncio.run(self.async_check_consumption_status(task_name, partition_id)) + def get_consumption_status( + self, + task_name: str, + partition_id: str, + ) -> tuple[Optional[Tensor], Optional[Tensor]]: + """Synchronously get consumption status for a specific task and partition. + + Args: + task_name: Name of the task to check consumption for + partition_id: Partition id to check consumption status for + + Returns: + Tuple of: + - Partition global index tensor + - Consumption status tensor for the specified task. 1 for consumed, 0 for not consumed. + + Example: + >>> global_index, consumption_status = client.get_consumption_status( + ... task_name="generate_sequences", + ... partition_id="train_0" + ... ) + >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}") + """ + return asyncio.run(self.async_get_consumption_status(task_name, partition_id)) + def check_production_status(self, data_fields: list[str], partition_id: str) -> bool: """Synchronously check if all samples for a partition are ready (produced) for consumption. @@ -903,6 +928,31 @@ def check_production_status(self, data_fields: list[str], partition_id: str) -> """ return asyncio.run(self.async_check_production_status(data_fields, partition_id)) + def get_production_status( + self, + data_fields: list[str], + partition_id: str, + ) -> tuple[Optional[Tensor], Optional[Tensor]]: + """Synchronously get production status for a specific data fields and partition. + + Args: + data_fields: Data fields to check production status for + partition_id: Partition id to check production status for + + Returns: + Tuple of: + - Partition global index tensor + - Production status tensor for the specified task. 
1 for ready, 0 for not ready. + + Example: + >>> global_index, production_status = client.get_production_status( + ... data_fields=["input_ids", "attention_mask"], + ... partition_id="train_0" + ... ) + >>> print(f"Global index: {global_index}, Production status: {production_status}") + """ + return asyncio.run(self.async_get_production_status(data_fields, partition_id)) + def get_partition_list( self, ): diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 0eea081..50308be 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -453,7 +453,7 @@ def mark_consumed(self, task_name: str, global_indices: list[int]): """ try: - consumption_status = self.get_consumption_status(task_name) + _, consumption_status = self.get_consumption_status(task_name, mask=False) if consumption_status.numel() > 0 and global_indices: consumption_status[global_indices] = 1 @@ -466,13 +466,14 @@ def mark_consumed(self, task_name: str, global_indices: list[int]): # ==================== Consumption Status Interface ==================== - def get_consumption_status(self, task_name: str) -> tuple[Tensor, Tensor]: + def get_consumption_status(self, task_name: str, mask:bool=False) -> tuple[Tensor, Tensor]: """ Get or create consumption status for a specific task. Handles dynamic expansion when new samples are added. 
Args: task_name: Name of the consumer task + mask: Whether to return only the status for current partition samples Returns: Tuple of: @@ -486,21 +487,24 @@ def get_consumption_status(self, task_name: str) -> tuple[Tensor, Tensor]: else: self.consumption_status[task_name] = torch.zeros(0, dtype=torch.int8) - # Get mask for target partition - partition_global_index = Tensor(sorted(self.global_indexes), dtype=torch.long) + consumption_status = self.consumption_status[task_name] - # Get consumption status for requested task - relevant_consumption_status = self.consumption_status[task_name][partition_global_index] + if mask: + # Get mask for target partition + partition_global_index = Tensor(sorted(self.global_indexes), dtype=torch.long) + # Get consumption status for requested task + consumption_status = consumption_status[partition_global_index] - return partition_global_index, relevant_consumption_status + return partition_global_index, consumption_status # ==================== Production Status Interface ==================== - def get_production_status_for_fields(self, field_names: list[str]) -> tuple[Tensor, Tensor]: + def get_production_status_for_fields(self, field_names: list[str], mask:bool=False) -> tuple[Tensor, Tensor]: """ Check if all samples for specified fields are fully produced and ready. 
Args: field_names: List of field names to check production status for + mask: Whether to return only the status for current partition samples Returns: Tuple of: @@ -521,13 +525,15 @@ def get_production_status_for_fields(self, field_names: list[str]) -> tuple[Tens if field_indices: col_mask[field_indices] = True - # Get mask for target partition - partition_global_index = Tensor(sorted(self.global_indexes), dtype=torch.long) + production_status = self.production_status[:, col_mask] - # Get production status for requested fields - relevant_production_status = self.production_status[partition_global_index, col_mask] + if mask: + # Get mask for target partition + partition_global_index = Tensor(sorted(self.global_indexes), dtype=torch.long) + # Get production status for requested fields + production_status = production_status[partition_global_index] - return partition_global_index, relevant_production_status + return partition_global_index, production_status # ==================== Data Scanning and Query Methods ==================== @@ -555,7 +561,7 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: row_mask = torch.ones(self.allocated_samples_num, dtype=torch.bool) # Apply consumption filter (exclude already consumed samples) - consumption_status = self.get_consumption_status(task_name) + consumption_status = self.get_consumption_status(task_name, mask=False) if consumption_status is not None: unconsumed_mask = consumption_status == 0 row_mask &= unconsumed_mask @@ -903,7 +909,7 @@ def get_consumption_status(self, partition_id: str, task_name: str) -> tuple[Opt if not partition: return None, None - return partition.get_consumption_status(task_name) + return partition.get_consumption_status(task_name, mask=True) def get_production_status( self, partition_id: str, data_fields: list[str] @@ -924,7 +930,7 @@ def get_production_status( if not partition: return None, None - return partition.get_production_status_for_fields(data_fields) + 
return partition.get_production_status_for_fields(data_fields, mask=True) def get_metadata( self, From 6c1bca17ae3bd0af3b893e4c79a101ae023aee18 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 23 Jan 2026 12:02:27 +0800 Subject: [PATCH 3/7] fix Signed-off-by: 0oshowero0 --- tests/test_controller.py | 82 +++++++++++++++++++++--- tests/test_controller_data_partitions.py | 8 +-- transfer_queue/controller.py | 15 ++--- 3 files changed, 85 insertions(+), 20 deletions(-) diff --git a/tests/test_controller.py b/tests/test_controller.py index 360b7f0..728cd9d 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -89,6 +89,7 @@ def test_controller_with_single_partition(self, ray_setup): field_names=metadata.field_names, dtypes=dtypes, shapes=shapes, + custom_meta=None, ) ) assert success @@ -97,13 +98,18 @@ def test_controller_with_single_partition(self, ray_setup): assert partition.production_status.size(0) == gbs * num_n_samples # Test for get production status - production_status = ray.get( + global_index, production_status = ray.get( tq_controller.get_production_status.remote( partition_id=partition_id, data_fields=data_fields, ) ) - assert production_status + # Verify global_index contains all expected indexes + assert torch.equal(global_index, torch.tensor(range(gbs * num_n_samples), dtype=torch.long)) + # Verify all samples are produced for all fields (status should be 1) + expected_production_status = torch.ones(gbs * num_n_samples, len(metadata.field_names), dtype=torch.int8) + assert torch.equal(production_status, expected_production_status) + print("✓ Get production status returns correct global_index and production_status") # Total fields should match the number of fields we added assert partition.total_fields_num == len(data_fields) @@ -126,14 +132,19 @@ def test_controller_with_single_partition(self, ray_setup): print(f"✓ Updated production status for partition {partition_id}") - # Test for get consumption status - consumption_status = 
ray.get( + # Test for get consumption status BEFORE consumption + global_index, consumption_status = ray.get( tq_controller.get_consumption_status.remote( partition_id=partition_id, task_name="generate_sequences", ) ) - assert torch.equal(consumption_status, torch.zeros(gbs * num_n_samples)) + # Verify global_index + assert torch.equal(global_index, torch.tensor(range(gbs * num_n_samples), dtype=torch.long)) + # Verify all samples are NOT consumed yet (status should be 0) + expected_consumption_status_before = torch.zeros(gbs * num_n_samples, dtype=torch.int8) + assert torch.equal(consumption_status, expected_consumption_status_before) + print("✓ Get consumption status returns correct global_index and status (before consumption)") # Test get metadate in fetch mode gen_meta = ray.get( @@ -153,14 +164,19 @@ def test_controller_with_single_partition(self, ray_setup): assert torch.equal(partition.consumption_status["generate_sequences"], torch.ones(gbs * num_n_samples)) print("✓ Get metadata in fetch mode correct") - # Test for get consumption status - consumption_status = ray.get( + # Test for get consumption status AFTER consumption + global_index, consumption_status = ray.get( tq_controller.get_consumption_status.remote( partition_id=partition_id, task_name="generate_sequences", ) ) - assert torch.equal(consumption_status, torch.ones(gbs * num_n_samples)) + # Verify global_index + assert torch.equal(global_index, torch.tensor(range(gbs * num_n_samples), dtype=torch.long)) + # Verify all samples are consumed (status should be 1) + expected_consumption_status_after = torch.ones(gbs * num_n_samples, dtype=torch.int8) + assert torch.equal(consumption_status, expected_consumption_status_after) + print("✓ Get consumption status returns correct global_index and status (after consumption)") # Test get clear meta clear_meta = ray.get( @@ -222,6 +238,19 @@ def test_controller_with_multi_partitions(self, ray_setup): ) assert success + # Verify get production status returns 
correct data + global_index_1, production_status_1 = ray.get( + tq_controller.get_production_status.remote( + partition_id=partition_id_1, + data_fields=data_fields, + ) + ) + expected_global_index_1 = torch.tensor(range(gbs_1 * num_n_samples_1), dtype=torch.long) + assert torch.equal(global_index_1, expected_global_index_1) + expected_production_status_1 = torch.ones(gbs_1 * num_n_samples_1, len(data_fields), dtype=torch.int8) + assert torch.equal(production_status_1, expected_production_status_1) + print("✓ Get production status for partition_1 returns correct global_index and status") + # Test get metadate in fetch mode gen_meta = ray.get( tq_controller.get_metadata.remote( @@ -234,6 +263,18 @@ def test_controller_with_multi_partitions(self, ray_setup): ) assert gen_meta + # Verify get consumption status after fetch (samples should be consumed) + global_index_1_consumed, consumption_status_1 = ray.get( + tq_controller.get_consumption_status.remote( + partition_id=partition_id_1, + task_name="generate_sequences", + ) + ) + assert torch.equal(global_index_1_consumed, expected_global_index_1) + expected_consumption_status_1 = torch.ones(gbs_1 * num_n_samples_1, dtype=torch.int8) + assert torch.equal(consumption_status_1, expected_consumption_status_1) + print("✓ Get consumption status for partition_1 returns correct global_index and status (after fetch)") + # Test get clear meta clear_meta = ray.get( tq_controller.get_metadata.remote( @@ -282,6 +323,31 @@ def test_controller_with_multi_partitions(self, ray_setup): ) assert success + # Verify get production status for partition_2 + global_index_2, production_status_2 = ray.get( + tq_controller.get_production_status.remote( + partition_id=partition_id_2, + data_fields=data_fields, + ) + ) + expected_global_index_2 = torch.tensor(range(part1_index_range, part2_index_range + part1_index_range), dtype=torch.long) + assert torch.equal(global_index_2, expected_global_index_2) + expected_production_status_2 = 
torch.ones(part2_index_range, len(data_fields), dtype=torch.int8) + assert torch.equal(production_status_2, expected_production_status_2) + print("✓ Get production status for partition_2 returns correct global_index and status") + + # Verify get consumption status for partition_2 (before consumption - should be all zeros) + global_index_2_consumed, consumption_status_2 = ray.get( + tq_controller.get_consumption_status.remote( + partition_id=partition_id_2, + task_name="generate_sequences", + ) + ) + assert torch.equal(global_index_2_consumed, expected_global_index_2) + expected_consumption_status_2 = torch.zeros(part2_index_range, dtype=torch.int8) + assert torch.equal(consumption_status_2, expected_consumption_status_2) + print("✓ Get consumption status for partition_2 returns correct global_index and status (before consumption)") + # Clear partition 1 partition_index_range_1 = ray.get(tq_controller.get_partition_index_range.remote(partition_id_1)) assert partition_index_range_1 diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index ee1092a..32ede5f 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -86,7 +86,7 @@ def test_data_partition_status(): print("✓ Field metadata retrieval works") # Test consumption status - consumption_tensor = partition.get_consumption_status("test_task") + global_index, consumption_tensor = partition.get_consumption_status("test_task", ) assert consumption_tensor is not None assert consumption_tensor.shape[0] == partition.total_samples_num @@ -240,7 +240,7 @@ def test_data_partition_status_advanced(): # Initial consumption tracking partition.mark_consumed(task_name, [0, 1]) - initial_consumption = partition.get_consumption_status(task_name) + global_index, initial_consumption = partition.get_consumption_status(task_name) assert initial_consumption[0] == 1 assert initial_consumption[1] == 1 @@ -258,7 +258,7 @@ def 
test_data_partition_status_advanced(): 12: {"field_d": (32,)}, } partition.update_production_status([10, 11, 12], ["field_d"], dtypes, shapes, None) # Triggers sample expansion - expanded_consumption = partition.get_consumption_status(task_name) + global_index, expanded_consumption = partition.get_consumption_status(task_name) assert expanded_consumption[0] == 1 # Preserved assert expanded_consumption[1] == 1 # Preserved assert expanded_consumption.shape[0] >= 13 # Expanded to accommodate new samples @@ -358,7 +358,7 @@ def test_edge_cases_and_error_handling(): # Test 3: Consumption status edge cases # Test consumption status creation before production status task_name = "early_task" - consumption_tensor = partition.get_consumption_status(task_name) + _, consumption_tensor = partition.get_consumption_status(task_name) assert consumption_tensor is not None assert consumption_tensor.shape[0] == partition.allocated_samples_num diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 50308be..c61d1c2 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -367,7 +367,7 @@ def update_production_status( # Update production status if self.production_status is not None and global_indices and field_names: field_indices = [self.field_name_mapping.get(field) for field in field_names] - self.production_status[Tensor(global_indices)[:, None], Tensor(field_indices)] = 1 + self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1 # Update field metadata self._update_field_metadata(global_indices, dtypes, shapes, custom_meta) @@ -487,12 +487,12 @@ def get_consumption_status(self, task_name: str, mask:bool=False) -> tuple[Tenso else: self.consumption_status[task_name] = torch.zeros(0, dtype=torch.int8) + # Get consumption status for requested task consumption_status = self.consumption_status[task_name] + partition_global_index = torch.tensor(sorted(self.global_indexes), dtype=torch.long) + if mask: - 
# Get mask for target partition - partition_global_index = Tensor(sorted(self.global_indexes), dtype=torch.long) - # Get consumption status for requested task consumption_status = consumption_status[partition_global_index] return partition_global_index, consumption_status @@ -527,10 +527,9 @@ def get_production_status_for_fields(self, field_names: list[str], mask:bool=Fal production_status = self.production_status[:, col_mask] + partition_global_index = torch.tensor(sorted(self.global_indexes), dtype=torch.long) + if mask: - # Get mask for target partition - partition_global_index = Tensor(sorted(self.global_indexes), dtype=torch.long) - # Get production status for requested fields production_status = production_status[partition_global_index] return partition_global_index, production_status @@ -561,7 +560,7 @@ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]: row_mask = torch.ones(self.allocated_samples_num, dtype=torch.bool) # Apply consumption filter (exclude already consumed samples) - consumption_status = self.get_consumption_status(task_name, mask=False) + _, consumption_status = self.get_consumption_status(task_name, mask=False) if consumption_status is not None: unconsumed_mask = consumption_status == 0 row_mask &= unconsumed_mask From 11f8eebc7f0d65fec5d07e6faa6c48dab826ab1d Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 23 Jan 2026 12:09:19 +0800 Subject: [PATCH 4/7] fix bugs & add CI Signed-off-by: 0oshowero0 --- tests/test_client.py | 102 ++++++++++++++++++++- tests/test_controller_data_partitions.py | 112 ++++++++++++++++++++++- 2 files changed, 208 insertions(+), 6 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index f06ede5..e8027b4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -116,16 +116,16 @@ def _handle_requests(self): # Mock consumption status check - all consumed response_body = { "partition_id": request_msg.body.get("partition_id"), - "global_index": 
torch.tensor([0,1,2]), - "consumption_status": torch.tensor([1,1,1]), + "global_index": torch.tensor([0, 1, 2]), + "consumption_status": torch.tensor([1, 1, 1]), } response_type = ZMQRequestType.CONSUMPTION_RESPONSE elif request_msg.request_type == ZMQRequestType.GET_PRODUCTION: # Mock production status check - all produced response_body = { "partition_id": request_msg.body.get("partition_id"), - "global_index": torch.tensor([0,1,2]), - "production_status": torch.tensor([[1,1,1],[1,1,1]]), + "global_index": torch.tensor([0, 1, 2]), + "production_status": torch.tensor([[1, 1, 1], [1, 1, 1]]), } response_type = ZMQRequestType.PRODUCTION_RESPONSE elif request_msg.request_type == ZMQRequestType.GET_LIST_PARTITIONS: @@ -469,6 +469,52 @@ def test_check_production_status(client_setup): assert is_produced is True +def test_get_consumption_status(client_setup): + """Test get_consumption_status - returns global_index and consumption_status tensors""" + client, _, _ = client_setup + + # Test synchronous get_consumption_status + global_index, consumption_status = client.get_consumption_status( + task_name="generate_sequences", partition_id="train_0" + ) + + # Verify return types + assert global_index is not None + assert consumption_status is not None + + # Verify global_index contains expected values + assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long)) + + # Verify consumption_status (mock returns all consumed) + expected_status = torch.tensor([1, 1, 1], dtype=torch.int8) + assert torch.equal(consumption_status, expected_status) + + print("✓ get_consumption_status returns correct global_index and consumption_status") + + +def test_get_production_status(client_setup): + """Test get_production_status - returns global_index and production_status tensors""" + client, _, _ = client_setup + + # Test synchronous get_production_status + global_index, production_status = client.get_production_status( + data_fields=["prompt_ids", "attention_mask"], 
partition_id="train_0" + ) + + # Verify return types + assert global_index is not None + assert production_status is not None + + # Verify global_index contains expected values + assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long)) + + # Verify production_status shape (mock returns 2x3 matrix) + expected_status = torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.int8) + assert torch.equal(production_status, expected_status) + + print("✓ get_production_status returns correct global_index and production_status") + + def test_get_partition_list(client_setup): """Test partition list retrieval""" client, _, _ = client_setup @@ -504,6 +550,54 @@ async def test_async_check_production_status(client_setup): assert is_produced is True +@pytest.mark.asyncio +async def test_async_get_consumption_status(client_setup): + """Test async get_consumption_status - returns global_index and consumption_status tensors""" + client, _, _ = client_setup + + # Test async_get_consumption_status + global_index, consumption_status = await client.async_get_consumption_status( + task_name="generate_sequences", partition_id="train_0" + ) + + # Verify return types + assert global_index is not None + assert consumption_status is not None + + # Verify global_index contains expected values + assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long)) + + # Verify consumption_status (mock returns all consumed) + expected_status = torch.tensor([1, 1, 1], dtype=torch.int8) + assert torch.equal(consumption_status, expected_status) + + print("✓ async_get_consumption_status returns correct global_index and consumption_status") + + +@pytest.mark.asyncio +async def test_async_get_production_status(client_setup): + """Test async get_production_status - returns global_index and production_status tensors""" + client, _, _ = client_setup + + # Test async_get_production_status + global_index, production_status = await client.async_get_production_status( + 
data_fields=["prompt_ids", "attention_mask"], partition_id="train_0" + ) + + # Verify return types + assert global_index is not None + assert production_status is not None + + # Verify global_index contains expected values + assert torch.equal(global_index, torch.tensor([0, 1, 2], dtype=torch.long)) + + # Verify production_status shape (mock returns 2x3 matrix) + expected_status = torch.tensor([[1, 1, 1], [1, 1, 1]], dtype=torch.int8) + assert torch.equal(production_status, expected_status) + + print("✓ async_get_production_status returns correct global_index and production_status") + + @pytest.mark.asyncio async def test_async_get_partition_list(client_setup): """Test async partition list retrieval""" diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index 32ede5f..06b580b 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -86,9 +86,9 @@ def test_data_partition_status(): print("✓ Field metadata retrieval works") # Test consumption status - global_index, consumption_tensor = partition.get_consumption_status("test_task", ) + global_index, consumption_tensor = partition.get_consumption_status("test_task", mask=False) assert consumption_tensor is not None - assert consumption_tensor.shape[0] == partition.total_samples_num + assert consumption_tensor.shape[0] == partition.allocated_samples_num print("✓ Consumption status creation works") @@ -526,3 +526,111 @@ def test_update_field_metadata_variants(): # Length mismatch should raise ValueError when provided mapping lengths differ from global_indices with pytest.raises(ValueError): partition._update_field_metadata([0, 1, 2], dtypes={0: {}}, shapes=None, custom_meta=None) + + +def test_get_production_status_for_fields(): + """Test get_production_status_for_fields method with mask parameter.""" + print("Testing get_production_status_for_fields...") + + import torch + + from transfer_queue.controller import 
DataPartitionStatus + + partition = DataPartitionStatus(partition_id="production_status_test") + + # Add some data first + partition.update_production_status( + global_indices=[0, 1, 2, 3, 4], + field_names=["field_a", "field_b"], + dtypes={i: {"field_a": "torch.int64", "field_b": "torch.bool"} for i in range(5)}, + shapes={i: {"field_a": (32,), "field_b": (32,)} for i in range(5)}, + ) + + # Test get_production_status_for_fields WITHOUT mask (mask=False) + global_index, production_status = partition.get_production_status_for_fields( + field_names=["field_a", "field_b"], mask=False + ) + # Without mask, should return all allocated samples + assert production_status.shape[0] == partition.allocated_samples_num + # Production status should be 1 for samples 0-4 (produced), 0 for others + # Check that samples 0-4 have all fields produced (all 1s) + assert torch.all(production_status[0:5] == 1), "Samples 0-4 should all be produced" + # Verify shape - should have 2 fields (columns) + assert production_status.shape[1] == 2 + + print("✓ get_production_status_for_fields without mask works") + + # Test get_production_status_for_fields WITH mask (mask=True) + global_index_masked, production_status_masked = partition.get_production_status_for_fields( + field_names=["field_a", "field_b"], mask=True + ) + # With mask, should return only global_indexes (0-4) + assert global_index_masked.shape[0] == 5 + assert torch.equal(global_index_masked, torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)) + # Masked status should be same as original for these indices + assert production_status_masked.shape[0] == 5 + assert torch.all(production_status_masked == 1) + + print("✓ get_production_status_for_fields with mask works") + + # Test get_production_status_for_fields with subset of fields + global_index_subset, production_status_subset = partition.get_production_status_for_fields( + field_names=["field_a"], mask=True + ) + assert global_index_subset.shape[0] == 5 + assert 
production_status_subset.shape == (5, 1) # Only one field + + print("✓ get_production_status_for_fields with subset fields works") + + print("get_production_status_for_fields tests passed!\n") + + +def test_consumption_status_mask_parameter(): + """Test get_consumption_status method with mask parameter.""" + print("Testing consumption status mask parameter...") + + import torch + + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="consumption_mask_test") + + # Add some data + partition.update_production_status( + global_indices=[0, 1, 2, 3, 4], + field_names=["field_a"], + dtypes={i: {"field_a": "torch.int64"} for i in range(5)}, + shapes={i: {"field_a": (32,)} for i in range(5)}, + ) + + # Mark some samples as consumed + partition.mark_consumed("test_task", [0, 2, 4]) + + # Test get_consumption_status WITHOUT mask (mask=False) + global_index, consumption_status = partition.get_consumption_status("test_task", mask=False) + # Without mask, should return all allocated samples + assert consumption_status.shape[0] == partition.allocated_samples_num + assert consumption_status[0].item() == 1 + assert consumption_status[1].item() == 0 # Not consumed + assert consumption_status[2].item() == 1 + assert consumption_status[3].item() == 0 + assert consumption_status[4].item() == 1 + + print("✓ get_consumption_status without mask works") + + # Test get_consumption_status WITH mask (mask=True) + global_index_masked, consumption_status_masked = partition.get_consumption_status("test_task", mask=True) + # With mask, should return only global_indexes (0-4) + assert global_index_masked.shape[0] == 5 + assert torch.equal(global_index_masked, torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)) + # Masked status should correspond to global indexes + assert consumption_status_masked.shape[0] == 5 + assert consumption_status_masked[0].item() == 1 + assert consumption_status_masked[1].item() == 0 + assert 
consumption_status_masked[2].item() == 1 + assert consumption_status_masked[3].item() == 0 + assert consumption_status_masked[4].item() == 1 + + print("✓ get_consumption_status with mask works") + + print("Consumption status mask parameter tests passed!\n") From a9bc74ea66e8d510432d6e47f32df541b5aeb462 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 23 Jan 2026 12:20:24 +0800 Subject: [PATCH 5/7] fix pre-commit Signed-off-by: 0oshowero0 --- tests/test_controller.py | 4 +++- transfer_queue/controller.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_controller.py b/tests/test_controller.py index 728cd9d..3d1379e 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -330,7 +330,9 @@ def test_controller_with_multi_partitions(self, ray_setup): data_fields=data_fields, ) ) - expected_global_index_2 = torch.tensor(range(part1_index_range, part2_index_range + part1_index_range), dtype=torch.long) + expected_global_index_2 = torch.tensor( + range(part1_index_range, part2_index_range + part1_index_range), dtype=torch.long + ) assert torch.equal(global_index_2, expected_global_index_2) expected_production_status_2 = torch.ones(part2_index_range, len(data_fields), dtype=torch.int8) assert torch.equal(production_status_2, expected_production_status_2) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index c61d1c2..03b8530 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -466,7 +466,7 @@ def mark_consumed(self, task_name: str, global_indices: list[int]): # ==================== Consumption Status Interface ==================== - def get_consumption_status(self, task_name: str, mask:bool=False) -> tuple[Tensor, Tensor]: + def get_consumption_status(self, task_name: str, mask: bool = False) -> tuple[Tensor, Tensor]: """ Get or create consumption status for a specific task. Handles dynamic expansion when new samples are added. 
@@ -498,7 +498,7 @@ def get_consumption_status(self, task_name: str, mask:bool=False) -> tuple[Tenso return partition_global_index, consumption_status # ==================== Production Status Interface ==================== - def get_production_status_for_fields(self, field_names: list[str], mask:bool=False) -> tuple[Tensor, Tensor]: + def get_production_status_for_fields(self, field_names: list[str], mask: bool = False) -> tuple[Tensor, Tensor]: """ Check if all samples for specified fields are fully produced and ready. From 9150ad0cfc17350ab89f1e904b72885596750db4 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 23 Jan 2026 13:11:00 +0800 Subject: [PATCH 6/7] better CI Signed-off-by: 0oshowero0 --- tests/test_controller_data_partitions.py | 73 +++++++++++++++--------- 1 file changed, 46 insertions(+), 27 deletions(-) diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index 06b580b..df9f60d 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -538,23 +538,28 @@ def test_get_production_status_for_fields(): partition = DataPartitionStatus(partition_id="production_status_test") - # Add some data first + # Add some data first (using non-contiguous indices) partition.update_production_status( - global_indices=[0, 1, 2, 3, 4], + global_indices=[0, 1, 2, 3, 9], field_names=["field_a", "field_b"], - dtypes={i: {"field_a": "torch.int64", "field_b": "torch.bool"} for i in range(5)}, - shapes={i: {"field_a": (32,), "field_b": (32,)} for i in range(5)}, + dtypes={i: {"field_a": "torch.int64", "field_b": "torch.bool"} for i in [0, 1, 2, 3, 9]}, + shapes={i: {"field_a": (32,), "field_b": (32,)} for i in [0, 1, 2, 3, 9]}, ) # Test get_production_status_for_fields WITHOUT mask (mask=False) global_index, production_status = partition.get_production_status_for_fields( field_names=["field_a", "field_b"], mask=False ) + assert torch.equal(global_index, torch.tensor([0, 1, 2, 3, 9], 
dtype=torch.long)) # Without mask, should return all allocated samples assert production_status.shape[0] == partition.allocated_samples_num - # Production status should be 1 for samples 0-4 (produced), 0 for others - # Check that samples 0-4 have all fields produced (all 1s) - assert torch.all(production_status[0:5] == 1), "Samples 0-4 should all be produced" + # Production status should be 1 for produced samples (0, 1, 2, 3, 9), 0 for others + # Check that produced samples have all fields produced (all 1s) + assert torch.all(production_status[0] == 1), "Sample 0 should be produced" + assert torch.all(production_status[1] == 1), "Sample 1 should be produced" + assert torch.all(production_status[2] == 1), "Sample 2 should be produced" + assert torch.all(production_status[3] == 1), "Sample 3 should be produced" + assert torch.all(production_status[9] == 1), "Sample 9 should be produced" # Verify shape - should have 2 fields (columns) assert production_status.shape[1] == 2 @@ -564,11 +569,10 @@ def test_get_production_status_for_fields(): global_index_masked, production_status_masked = partition.get_production_status_for_fields( field_names=["field_a", "field_b"], mask=True ) - # With mask, should return only global_indexes (0-4) - assert global_index_masked.shape[0] == 5 - assert torch.equal(global_index_masked, torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)) + assert torch.equal(global_index_masked, torch.tensor([0, 1, 2, 3, 9], dtype=torch.long)) # Masked status should be same as original for these indices - assert production_status_masked.shape[0] == 5 + assert production_status_masked.shape == (len([0, 1, 2, 3, 9]), 2) + # All returned samples should be produced assert torch.all(production_status_masked == 1) print("✓ get_production_status_for_fields with mask works") @@ -577,15 +581,15 @@ def test_get_production_status_for_fields(): global_index_subset, production_status_subset = partition.get_production_status_for_fields( field_names=["field_a"], mask=True ) - 
assert global_index_subset.shape[0] == 5 - assert production_status_subset.shape == (5, 1) # Only one field + assert global_index_subset.shape[0] == len([0, 1, 2, 3, 9]) + assert production_status_subset.shape == (len([0, 1, 2, 3, 9]), 1) # Only one field print("✓ get_production_status_for_fields with subset fields works") print("get_production_status_for_fields tests passed!\n") -def test_consumption_status_mask_parameter(): +def test_get_consumption_status_parameter(): """Test get_consumption_status method with mask parameter.""" print("Testing consumption status mask parameter...") @@ -594,42 +598,57 @@ def test_consumption_status_mask_parameter(): from transfer_queue.controller import DataPartitionStatus partition = DataPartitionStatus(partition_id="consumption_mask_test") + partition_another = DataPartitionStatus(partition_id="other_partition") # Add some data partition.update_production_status( - global_indices=[0, 1, 2, 3, 4], + global_indices=[0, 1, 2, 3, 9], field_names=["field_a"], - dtypes={i: {"field_a": "torch.int64"} for i in range(5)}, - shapes={i: {"field_a": (32,)} for i in range(5)}, + dtypes={i: {"field_a": "torch.int64"} for i in [0, 1, 2, 3, 9]}, + shapes={i: {"field_a": (32,)} for i in [0, 1, 2, 3, 9]}, + ) + + partition_another.update_production_status( + global_indices=[5, 6, 7], + field_names=["field_a"], + dtypes={i: {"field_a": "torch.int64"} for i in [5, 6, 7]}, + shapes={i: {"field_a": (32,)} for i in [5, 6, 7]}, ) # Mark some samples as consumed - partition.mark_consumed("test_task", [0, 2, 4]) + partition.mark_consumed("test_task", [0, 2]) # Test get_consumption_status WITHOUT mask (mask=False) global_index, consumption_status = partition.get_consumption_status("test_task", mask=False) + assert global_index.shape[0] == partition.total_samples_num + assert torch.equal(global_index, torch.tensor([0, 1, 2, 3, 9], dtype=torch.long)) # Without mask, should return all allocated samples - assert consumption_status.shape[0] == 
partition.allocated_samples_num + assert consumption_status.shape[0] == 10 assert consumption_status[0].item() == 1 - assert consumption_status[1].item() == 0 # Not consumed + assert consumption_status[1].item() == 0 assert consumption_status[2].item() == 1 assert consumption_status[3].item() == 0 - assert consumption_status[4].item() == 1 + assert consumption_status[4].item() == 0 # empty slot + assert consumption_status[5].item() == 0 # empty slot + assert consumption_status[6].item() == 0 # empty slot + assert consumption_status[7].item() == 0 # empty slot + assert consumption_status[8].item() == 0 # empty slot + assert consumption_status[9].item() == 0 print("✓ get_consumption_status without mask works") # Test get_consumption_status WITH mask (mask=True) global_index_masked, consumption_status_masked = partition.get_consumption_status("test_task", mask=True) - # With mask, should return only global_indexes (0-4) - assert global_index_masked.shape[0] == 5 - assert torch.equal(global_index_masked, torch.tensor([0, 1, 2, 3, 4], dtype=torch.long)) - # Masked status should correspond to global indexes - assert consumption_status_masked.shape[0] == 5 + # With mask, should return only global_indexes [0, 1, 2, 3, 9] + assert global_index_masked.shape[0] == partition.total_samples_num + assert torch.equal(global_index_masked, torch.tensor([0, 1, 2, 3, 9], dtype=torch.long)) + # Masked status shape[0] should correspond to global indexes + assert consumption_status_masked.shape[0] == partition.total_samples_num assert consumption_status_masked[0].item() == 1 assert consumption_status_masked[1].item() == 0 assert consumption_status_masked[2].item() == 1 assert consumption_status_masked[3].item() == 0 - assert consumption_status_masked[4].item() == 1 + assert consumption_status_masked[4].item() == 0 # no empty slot. 
this corresponds to global_index=9 print("✓ get_consumption_status with mask works") From 779fa01c5dfba6fe46a90c9481cc3b0e2fec3331 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Fri, 23 Jan 2026 13:19:09 +0800 Subject: [PATCH 7/7] fix comments Signed-off-by: 0oshowero0 --- transfer_queue/client.py | 19 +++++++++---------- transfer_queue/controller.py | 6 +++--- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 4586bbb..37096b4 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -561,7 +561,7 @@ async def async_get_consumption_status( Example: >>> # Get consumption status - >>> global_index, consumption_status = asyncio.run(client.async_check_consumption_status( + >>> global_index, consumption_status = asyncio.run(client.async_get_consumption_status( ... task_name="generate_sequences", ... partition_id="train_0" ... )) @@ -607,7 +607,7 @@ async def async_get_production_status( partition_id: str, socket: Optional[zmq.asyncio.Socket] = None, ) -> tuple[Optional[Tensor], Optional[Tensor]]: - """Get production status for current partition in a specific task. + """Get production status for current partition for specific fields. Args: data_fields: Data fields to check production status for @@ -617,14 +617,14 @@ async def async_get_production_status( Returns: Tuple of: - Partition global index tensor - - Production status tensor for the specified task. 1 for ready, 0 for not ready. + - Production status tensor for the specified fields. 1 for ready, 0 for not ready. Raises: RuntimeError: If communication fails or controller returns error response Example: >>> # Get production status - >>> global_index, production_status = asyncio.run(client.async_check_production_status( + >>> global_index, production_status = asyncio.run(client.async_get_production_status( ... data_fields=["input_ids", "attention_mask"], ... partition_id="train_0" ... 
)) @@ -672,7 +672,6 @@ async def async_check_consumption_status( Args: task_name: Name of the task to check consumption for partition_id: Partition id to check consumption status for - socket: ZMQ async socket for message transmission (injected by decorator) Returns: bool: True if all samples have been consumed by the task, False otherwise @@ -696,19 +695,19 @@ async def async_check_consumption_status( if consumption_status is None: return False - return torch.all(consumption_status).item() == 1 + return torch.all(consumption_status == 1).item() async def async_check_production_status( self, data_fields: list[str], partition_id: str, ) -> bool: - """Check if all samples for current partition are ready (produced) for consumption. + """Check if the all specific fields of samples for current partition are ready + (produced) for consumption. Args: data_fields: Data fields to check production status for partition_id: Partition id to check production status for - socket: ZMQ async socket for message transmission (injected by decorator) Returns: bool: True if all samples have been produced and ready, False otherwise @@ -731,7 +730,7 @@ async def async_check_production_status( if production_status is None: return False - return torch.all(production_status).item() == 1 + return torch.all(production_status == 1).item() @dynamic_socket(socket_name="request_handle_socket") async def async_get_partition_list( @@ -942,7 +941,7 @@ def get_production_status( Returns: Tuple of: - Partition global index tensor - - Production status tensor for the specified task. 1 for ready, 0 for not ready. + - Production status tensor for the specified fields. 1 for ready, 0 for not ready. 
Example: >>> global_index, production_status = client.get_production_status( diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 03b8530..2b6767f 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -1423,10 +1423,10 @@ def _process_request(self): global_index, consumption_status = self.get_consumption_status( params["partition_id"], params["task_name"] ) - sample_filter = params.get("sample_filter") # DEPRECATED in future + sample_filter = params.get("sample_filter") # TODO: DEPRECATED in future - if sample_filter: - # DEPRECATED in future + if sample_filter and consumption_status is not None: + # TODO: DEPRECATED in future consumption_status = consumption_status[sample_filter] response_msg = ZMQMessage.create(