From 4740782a5f0654bbf1d770e5323cff12988d1f28 Mon Sep 17 00:00:00 2001 From: tianyi-ge Date: Wed, 21 Jan 2026 18:41:34 +0800 Subject: [PATCH 1/8] 1. add custom_meta to controller 2. allow kv storage manager put returns their custom metadata per index per field Signed-off-by: tianyi-ge --- transfer_queue/controller.py | 73 ++++++++++++++++--- transfer_queue/metadata.py | 33 ++++++++- transfer_queue/storage/clients/base.py | 13 +++- .../storage/clients/mooncake_client.py | 8 +- .../storage/clients/yuanrong_client.py | 6 +- transfer_queue/storage/managers/base.py | 69 +++++++++++++++--- 6 files changed, 169 insertions(+), 33 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 3eab84b..98bfefe 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -230,6 +230,7 @@ class DataPartitionStatus: field_name_mapping: dict[str, int] = field(default_factory=dict) # field_name -> column_index field_dtypes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: dtype} field_shapes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: shape} + field_custom_metas: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: custom_meta} # Threading lock for concurrency control; only for preventing mask operation error when expanding production_status. # No need to strictly lock for every read/write operation since freshness is not critical. @@ -326,6 +327,7 @@ def update_production_status( field_names: list[str], dtypes: Optional[dict[int, dict[str, Any]]], shapes: Optional[dict[int, dict[str, Any]]], + custom_meta: Optional[dict[int, dict[str, Any]]], ) -> bool: """ Update production status for specific samples and fields. @@ -336,6 +338,7 @@ def update_production_status( field_names: List of field names to mark as produced dtypes: Optional per-sample field dtype information shapes: Optional per-sample field shape information + custom_meta: Optional per-sample field custom metadata Returns: True if update was successful, False on error @@ -366,7 +369,7 @@ def update_production_status( self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1 # Update field metadata - self._update_field_metadata(global_indices, dtypes, shapes) + self._update_field_metadata(global_indices, dtypes, shapes, custom_meta) # Save these global_indexes self.global_indexes.update(global_indices) @@ -380,8 +383,9 @@ def update_production_status( def _update_field_metadata( self, global_indices: list[int], - dtypes: Optional[dict[int, dict[str, Any]]], - shapes: Optional[dict[int, dict[str, Any]]], + dtypes: dict[int, dict[str, Any]], + shapes: dict[int, dict[str, Any]], + custom_meta: Optional[dict[int, dict[str, Any]]], ): """Update field dtype and shape metadata.""" if not global_indices: @@ -409,6 +413,21 @@ def _update_field_metadata( if shape_value is not None: self.field_shapes[global_idx].update(shape_value[i]) + if custom_meta: + if len(global_indices) != len(custom_meta): + raise ValueError( + f"Length of global_indices ({len(global_indices)}) does not match " + f"length of custom_meta ({len(custom_meta)})" + ) + custom_meta_value = itemgetter(*global_indices)(custom_meta) if custom_meta else None + if not isinstance(custom_meta_value, tuple): + custom_meta_value = (custom_meta_value,) + for i, global_idx in enumerate(global_indices): + if global_idx not in self.field_custom_metas: + self.field_custom_metas[global_idx] = {} + if custom_meta_value is not 
None: + self.field_custom_metas[global_idx].update(custom_meta_value[i]) + # ==================== Consumption Status Interface ==================== def get_consumption_status(self, task_name: str) -> torch.Tensor: @@ -544,6 +563,14 @@ def get_field_shape(self, global_index: int, field_name: str) -> Optional[Any]: """Get shape for a specific sample and field.""" return self.field_shapes.get(global_index, {}).get(field_name) + def get_field_custom_meta(self, global_indices: list[int], field_names: list[str]) -> Optional[Any]: + """Get custom_meta for a specific sample and field.""" + return { + idx: {f: v for f, v in self.field_custom_metas[idx].items() if f in field_names} + for idx in global_indices + if idx in self.field_custom_metas + } + # ==================== Statistics and Monitoring ==================== def get_statistics(self) -> dict[str, Any]: @@ -571,7 +598,9 @@ def get_statistics(self) -> dict[str, Any]: field_produced = (self.production_status[:, field_idx] == 1).sum().item() field_stats[field_name] = { "produced_samples": field_produced, - "production_progress": field_produced / self.total_samples_num if self.total_samples_num > 0 else 0, + "production_progress": ( + field_produced / self.total_samples_num if self.total_samples_num > 0 else 0 + ), } stats["field_statistics"] = field_stats @@ -581,7 +610,9 @@ def get_statistics(self) -> dict[str, Any]: consumed_samples = (consumption_tensor == 1).sum().item() consumption_stats[task_name] = { "consumed_samples": consumed_samples, - "consumption_progress": consumed_samples / self.total_samples_num if self.total_samples_num > 0 else 0, + "consumption_progress": ( + consumed_samples / self.total_samples_num if self.total_samples_num > 0 else 0 + ), } stats["consumption_statistics"] = consumption_stats @@ -632,6 +663,9 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr consumption_tensor[indexes_to_release] = 0 self.global_indexes.difference_update(indexes_to_release) + self.field_dtypes.difference_update(indexes_to_release) + self.field_shapes.difference_update(indexes_to_release) + self.field_custom_metas.difference_update(indexes_to_release) except Exception as e: logger.error( @@ -658,7 +692,9 @@ class TransferQueueController: """ def __init__( - self, sampler: BaseSampler | type[BaseSampler] = SequentialSampler, polling_mode: bool = False + self, + sampler: BaseSampler | type[BaseSampler] = SequentialSampler, + polling_mode: bool = False, ) -> None: """Initialize the TransferQueue Controller. @@ -791,6 +827,7 @@ def update_production_status( field_names: list[str], dtypes: Optional[dict[int, dict[str, Any]]], shapes: Optional[dict[int, dict[str, Any]]], + custom_meta: Optional[dict[int, dict[str, Any]]], ) -> bool: """ Update production status for specific samples and fields in a partition. 
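For illustration, the new custom_meta argument follows the same per-sample, per-field layout as dtypes and shapes, and the controller-level method simply forwards it to the owning DataPartitionStatus. A minimal sketch at the partition level, assuming hypothetical index values, partition id, and an invented "object_ref" payload (only the DataPartitionStatus API itself comes from this patch):

    from transfer_queue.controller import DataPartitionStatus

    # custom_meta mirrors the dtypes/shapes layout:
    # {global_index: {field_name: backend-specific payload}}
    partition = DataPartitionStatus(partition_id="sketch_partition")  # hypothetical id
    global_indices = [0, 1]
    ok = partition.update_production_status(
        global_indices=global_indices,
        field_names=["prompt_ids"],
        dtypes={i: {"prompt_ids": "torch.int64"} for i in global_indices},
        shapes={i: {"prompt_ids": (32,)} for i in global_indices},
        custom_meta={i: {"prompt_ids": {"object_ref": f"ref-{i}"}} for i in global_indices},
    )
    assert ok
    # Stored payloads can later be read back, filtered by field:
    partition.get_field_custom_meta(global_indices, ["prompt_ids"])
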
@@ -811,7 +848,7 @@ def update_production_status( logger.error(f"Partition {partition_id} not found") return False - success = partition.update_production_status(global_indexes, field_names, dtypes, shapes) + success = partition.update_production_status(global_indexes, field_names, dtypes, shapes, custom_meta) if success: logger.debug( f"[{self.controller_id}]: Updated production status for partition {partition_id}: " @@ -1070,7 +1107,11 @@ def generate_batch_meta( ) samples.append(sample) - return BatchMeta(samples=samples) + custom_meta = partition.get_field_custom_meta(batch_global_indexes, data_fields) + + batch_meta = BatchMeta(samples=samples) + batch_meta.update_custom_meta(custom_meta) + return batch_meta def clear_partition(self, partition_id: str, clear_consumption: bool = True): """ @@ -1092,7 +1133,12 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True): self.index_manager.release_partition(partition_id) self.partitions.pop(partition_id) - def clear_meta(self, global_indexes: list[int], partition_ids: list[str], clear_consumption: bool = True): + def clear_meta( + self, + global_indexes: list[int], + partition_ids: list[str], + clear_consumption: bool = True, + ): """ Clear meta for individual samples (preserving the partition). @@ -1230,7 +1276,9 @@ def _wait_connection(self): def _start_process_handshake(self): """Start the handshake process thread.""" self.wait_connection_thread = Thread( - target=self._wait_connection, name="TransferQueueControllerWaitConnectionThread", daemon=True + target=self._wait_connection, + name="TransferQueueControllerWaitConnectionThread", + daemon=True, ) self.wait_connection_thread.start() @@ -1246,7 +1294,9 @@ def _start_process_update_data_status(self): def _start_process_request(self): """Start the request processing thread.""" self.process_request_thread = Thread( - target=self._process_request, name="TransferQueueControllerProcessRequestThread", daemon=True + target=self._process_request, + name="TransferQueueControllerProcessRequestThread", + daemon=True, ) self.process_request_thread.start() @@ -1408,6 +1458,7 @@ def _update_data_status(self): field_names=message_data.get("fields", []), dtypes=message_data.get("dtypes", {}), shapes=message_data.get("shapes", {}), + custom_meta=message_data.get("custom_meta", {}), ) if success: diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index c389157..4ac7d94 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -126,7 +126,9 @@ def select_fields(self, field_names: list[str]) -> "SampleMeta": # construct new SampleMeta instance selected_sample_meta = SampleMeta( - fields=selected_fields, partition_id=self.partition_id, global_index=self.global_index + fields=selected_fields, + partition_id=self.partition_id, + global_index=self.global_index, ) return selected_sample_meta @@ -174,6 +176,8 @@ class BatchMeta: samples: list[SampleMeta] extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) + # internal data for different storage backends: _custom_meta[index][field] + _custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) def __post_init__(self): """Initialize all computed properties during initialization""" @@ -189,7 +193,11 @@ def __post_init__(self): for idx, sample in enumerate(self.samples): object.__setattr__(sample, "_batch_index", idx) # Ensure batch_index is set correctly - object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples]) + object.__setattr__( 
+ self, + "_global_indexes", + [sample.global_index for sample in self.samples], + ) # check if all samples have the same field names first_sample_field_names = sorted(self.samples[0].field_names) @@ -230,6 +238,23 @@ def partition_ids(self) -> list[str]: """Get partition ids for all samples in this batch as a list (one per sample)""" return getattr(self, "_partition_ids", []) + # Custom meta methods for different storage backends + def get_custom_meta_list(self) -> list[Any]: + """Get required custom meta as a list""" + return [ + self._custom_meta.get(index, {}).get(field_name, None) + for field_name, index in itertools.product(sorted(self.field_names), range(self.size)) + ] + + def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: + """Get the entire custom meta dictionary""" + return copy.deepcopy(self._custom_meta) + + def update_custom_meta(self, new_custom_meta: dict[int, dict[str, Any]] = None): + """Update custom meta with a new dictionary""" + if new_custom_meta: + self._custom_meta.update(new_custom_meta) + # Extra info interface methods def get_extra_info(self, key: str, default: Any = None) -> Any: """Get extra info by key""" @@ -529,7 +554,9 @@ def _update_after_reorder(self) -> None: @classmethod def from_samples( - cls, samples: SampleMeta | list[SampleMeta], extra_info: Optional[dict[str, Any]] = None + cls, + samples: SampleMeta | list[SampleMeta], + extra_info: Optional[dict[str, Any]] = None, ) -> "BatchMeta": """ Create a BatchMeta from a single SampleMeta or a list of SampleMeta objects. diff --git a/transfer_queue/storage/clients/base.py b/transfer_queue/storage/clients/base.py index fc5677d..fed031d 100644 --- a/transfer_queue/storage/clients/base.py +++ b/transfer_queue/storage/clients/base.py @@ -14,6 +14,7 @@ # limitations under the License. from abc import ABC, abstractmethod +from typing import Any, Optional from torch import Tensor @@ -25,11 +26,19 @@ class TransferQueueStorageKVClient(ABC): """ @abstractmethod - def put(self, keys: list[str], values: list[Tensor]) -> None: + def put(self, keys: list[str], values: list[Tensor]) -> Optional[list[Any]]: + """ + Store key-value pairs in the storage backend. + Args: + keys (list[str]): List of keys to store. + values (list[Tensor]): List of tensor values to store. + Returns: + Optional[list[Any]]: Optional list of custom metadata from each storage backend. 
+ """ raise NotImplementedError("Subclasses must implement put") @abstractmethod - def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Tensor]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Tensor]: raise NotImplementedError("Subclasses must implement get") @abstractmethod diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index a71262b..80efa09 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -1,7 +1,7 @@ import logging import os import pickle -from typing import Any +from typing import Any, Optional import torch from torch import Tensor @@ -53,7 +53,7 @@ def __init__(self, config: dict[str, Any]): if ret != 0: raise RuntimeError(f"Mooncake store setup failed with error code: {ret}") - def put(self, keys: list[str], values: list[Any]): + def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: if not isinstance(keys, list) or not isinstance(values, list): raise ValueError("keys and values must be lists") if len(keys) != len(values): @@ -82,6 +82,8 @@ def put(self, keys: list[str], values: list[Any]): if non_tensor_keys: self._batch_put_bytes(non_tensor_keys, non_tensor_values) + return None + def _batch_put_tensors(self, keys: list[str], tensors: list[Tensor]): for i in range(0, len(keys), BATCH_SIZE_LIMIT): batch_keys = keys[i : i + BATCH_SIZE_LIMIT] @@ -104,7 +106,7 @@ def _batch_put_bytes(self, keys: list[str], values: list[bytes]): if ret != 0: raise RuntimeError(f"put_batch failed with error code: {ret}") - def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Any]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: if shapes is None or dtypes is None: raise ValueError("MooncakeStorageClient needs shapes and dtypes") if not (len(keys) == len(shapes) == len(dtypes)): diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index 4652314..6045abc 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -16,7 +16,7 @@ import logging import os import pickle -from typing import Any +from typing import Any, Optional import torch from torch import Tensor @@ -177,7 +177,7 @@ def put(self, keys: list[str], values: list[Any]): raise ValueError("Number of keys must match number of values") self._batch_put(keys, values) - def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]: + def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> Optional[list[Any]]: """Retrieves a batch of values from remote storage using expected metadata. NPU tensors are fetched via DsTensorClient using pre-allocated buffers. @@ -262,7 +262,7 @@ def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]: idx += 1 return results - def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Any]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: """Retrieves multiple values from remote storage with expected metadata. Requires shape and dtype hints to reconstruct NPU tensors correctly. 
diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 64927fd..44252d7 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -28,7 +28,12 @@ from transfer_queue.metadata import BatchMeta from transfer_queue.storage.clients.factory import StorageClientFactory -from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket +from transfer_queue.utils.zmq_utils import ( + ZMQMessage, + ZMQRequestType, + ZMQServerInfo, + create_zmq_socket, +) logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -185,6 +190,7 @@ async def notify_data_update( global_indexes: list[int], dtypes: dict[int, dict[str, Any]], shapes: dict[int, dict[str, Any]], + custom_meta: dict[int, dict[str, Any]] = None, ) -> None: """ Notify controller that new data is ready. @@ -195,6 +201,7 @@ async def notify_data_update( global_indexes: Data update related global_indexes. dtypes: Per-field dtypes for each field, in {global_index: {field: dtype}} format. shapes: Per-field shapes for each field, in {global_index: {field: shape}} format. + custom_meta: Per-field custom_meta for each field, in {global_index: {field: custom_meta}} format. """ # Create zmq poller for notifying data update information @@ -218,6 +225,7 @@ async def notify_data_update( "global_indexes": global_indexes, "dtypes": dtypes, "shapes": shapes, + "custom_meta": custom_meta, }, ).serialize() @@ -322,6 +330,19 @@ def __init__(self, config: dict[str, Any]): super().__init__(config) self.storage_client = StorageClientFactory.create(client_name, config) + @staticmethod + def _generate_key(field_name: str, global_index: int) -> str: + """ + Generate a KV key in the format 'global_index@field_name'. + + Args: + field_name : Name of the field. + global_index : Global index of the sample. + Returns: + str: Generated key, e.g., '0@field_a' + """ + return f"{global_index}@{field_name}" + @staticmethod def _generate_keys(field_names: list[str], global_indexes: list[int]) -> list[str]: """ @@ -335,7 +356,10 @@ def _generate_keys(field_names: list[str], global_indexes: list[int]) -> list[st Returns: list[str]: List of keys, e.g., ['0@field_a', '1@field_a', '0@field_b', ...] """ - return [f"{index}@{field}" for field, index in itertools.product(sorted(field_names), global_indexes)] + return [ + KVStorageManager._generate_key(field, index) + for field, index in itertools.product(sorted(field_names), global_indexes) + ] @staticmethod def _generate_values(data: TensorDict) -> list[Tensor]: @@ -405,16 +429,16 @@ def _merge_tensors_to_tensordict(metadata: BatchMeta, values: list[Tensor]) -> T return TensorDict(merged_data, batch_size=len(global_indexes)) @staticmethod - def _get_shape_type_list(metadata: BatchMeta): + def _get_shape_type_custom_meta_list(metadata: BatchMeta): """ - Extract the expected shape and dtype for each field-sample pair in metadata. + Extract the expected shape, dtype, and custom meta for each field-sample pair in metadata. The order matches the key/value order: sorted by field name, then by global index. Args: metadata (BatchMeta): Metadata containing sample and field information. Returns: - tuple[list[torch.Size], list[torch.dtype]]: Two lists containing the shape and dtype - for each tensor to be retrieved. + tuple[list[torch.Size], list[torch.dtype], list[Any]]: the shape list, dtype list and + custom meta list for each tensor to be retrieved. 
""" shapes = [] dtypes = [] @@ -423,7 +447,8 @@ def _get_shape_type_list(metadata: BatchMeta): field = metadata.samples[index].get_field_by_name(field_name) shapes.append(field.shape) dtypes.append(field.dtype) - return shapes, dtypes + custom_meta_list = metadata.get_custom_meta_list() + return shapes, dtypes, custom_meta_list async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: """ @@ -445,7 +470,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: keys = self._generate_keys(data.keys(), metadata.global_indexes) values = self._generate_values(data) loop = asyncio.get_event_loop() - await loop.run_in_executor(None, self.storage_client.put, keys, values) + custom_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values) per_field_dtypes = {} per_field_shapes = {} @@ -466,13 +491,35 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: getattr(data_item, "shape", None) if isinstance(data_item, Tensor) else None ) + # Prepare per-field custom_meta if available + per_field_custom_meta = {} + if custom_meta: + if len(custom_meta) != len(keys): + raise ValueError(f"Length of custom_meta ({len(custom_meta)}) does not match expected ({len(keys)})") + # custom meta is a flat list aligned with keys/values + # Use itertools.product to eliminate nested loops + for (global_idx, field_name), meta_value in zip( + itertools.product(metadata.global_indexes, metadata.field_names), + custom_meta, + strict=False, + ): + if global_idx not in per_field_custom_meta: + per_field_custom_meta[global_idx] = {} + per_field_custom_meta[global_idx][field_name] = meta_value + metadata.update_custom_meta(per_field_custom_meta) + # Get current data partition id # Note: Currently we only support putting to & getting data from a single data partition simultaneously, # but in the future we may support putting to & getting data from multiple data partitions concurrently. partition_id = metadata.samples[0].partition_id # notify controller that new data is ready await self.notify_data_update( - partition_id, list(data.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes + partition_id, + list(data.keys()), + metadata.global_indexes, + per_field_dtypes, + per_field_shapes, + per_field_custom_meta, ) async def get_data(self, metadata: BatchMeta) -> TensorDict: @@ -486,8 +533,8 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: logger.warning("Attempted to get data, but metadata contains no fields.") return TensorDict({}, batch_size=len(metadata)) keys = self._generate_keys(metadata.field_names, metadata.global_indexes) - shapes, dtypes = self._get_shape_type_list(metadata) - values = self.storage_client.get(keys=keys, shapes=shapes, dtypes=dtypes) + shapes, dtypes, custom_meta = self._get_shape_type_custom_meta_list(metadata) + values = self.storage_client.get(keys=keys, shapes=shapes, dtypes=dtypes, custom_meta=custom_meta) return self._merge_tensors_to_tensordict(metadata, values) async def clear_data(self, metadata: BatchMeta) -> None: From c267892b945868674f1d99efeec48a1eb1d687d6 Mon Sep 17 00:00:00 2001 From: tianyi-ge Date: Wed, 21 Jan 2026 18:41:34 +0800 Subject: [PATCH 2/8] 1. add custom_meta to controller 2. 
allow kv storage manager put returns their custom metadata per index per field Signed-off-by: tianyi-ge --- transfer_queue/controller.py | 50 ++++++++++++++++--- transfer_queue/metadata.py | 19 +++++++ transfer_queue/storage/clients/base.py | 13 ++++- .../storage/clients/mooncake_client.py | 8 +-- .../storage/clients/yuanrong_client.py | 9 ++-- transfer_queue/storage/managers/base.py | 44 ++++++++++++---- 6 files changed, 119 insertions(+), 24 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 3eab84b..eda521c 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -230,6 +230,7 @@ class DataPartitionStatus: field_name_mapping: dict[str, int] = field(default_factory=dict) # field_name -> column_index field_dtypes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: dtype} field_shapes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: shape} + field_custom_metas: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: custom_meta} # Threading lock for concurrency control; only for preventing mask operation error when expanding production_status. # No need to strictly lock for every read/write operation since freshness is not critical. @@ -326,6 +327,7 @@ def update_production_status( field_names: list[str], dtypes: Optional[dict[int, dict[str, Any]]], shapes: Optional[dict[int, dict[str, Any]]], + custom_meta: Optional[dict[int, dict[str, Any]]], ) -> bool: """ Update production status for specific samples and fields. @@ -336,6 +338,7 @@ def update_production_status( field_names: List of field names to mark as produced dtypes: Optional per-sample field dtype information shapes: Optional per-sample field shape information + custom_meta: Optional per-sample field custom metadata Returns: True if update was successful, False on error @@ -366,7 +369,7 @@ def update_production_status( self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1 # Update field metadata - self._update_field_metadata(global_indices, dtypes, shapes) + self._update_field_metadata(global_indices, dtypes, shapes, custom_meta) # Save these global_indexes self.global_indexes.update(global_indices) @@ -380,8 +383,9 @@ def update_production_status( def _update_field_metadata( self, global_indices: list[int], - dtypes: Optional[dict[int, dict[str, Any]]], - shapes: Optional[dict[int, dict[str, Any]]], + dtypes: dict[int, dict[str, Any]], + shapes: dict[int, dict[str, Any]], + custom_meta: Optional[dict[int, dict[str, Any]]], ): """Update field dtype and shape metadata.""" if not global_indices: @@ -409,6 +413,21 @@ def _update_field_metadata( if shape_value is not None: self.field_shapes[global_idx].update(shape_value[i]) + if custom_meta: + if len(global_indices) != len(custom_meta): + raise ValueError( + f"Length of global_indices ({len(global_indices)}) does not match " + f"length of custom_meta ({len(custom_meta)})" + ) + custom_meta_value = itemgetter(*global_indices)(custom_meta) if custom_meta else None + if not isinstance(custom_meta_value, tuple): + custom_meta_value = (custom_meta_value,) + for i, global_idx in enumerate(global_indices): + if global_idx not in self.field_custom_metas: + self.field_custom_metas[global_idx] = {} + if custom_meta_value is not None: + self.field_custom_metas[global_idx].update(custom_meta_value[i]) + # ==================== Consumption Status Interface ==================== def get_consumption_status(self, 
task_name: str) -> torch.Tensor: @@ -544,6 +563,14 @@ def get_field_shape(self, global_index: int, field_name: str) -> Optional[Any]: """Get shape for a specific sample and field.""" return self.field_shapes.get(global_index, {}).get(field_name) + def get_field_custom_meta(self, global_indices: list[int], field_names: list[str]) -> Optional[Any]: + """Get custom_meta for a specific sample and field.""" + return { + idx: {f: v for f, v in self.field_custom_metas[idx].items() if f in field_names} + for idx in global_indices + if idx in self.field_custom_metas + } + # ==================== Statistics and Monitoring ==================== def get_statistics(self) -> dict[str, Any]: @@ -632,6 +659,9 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr consumption_tensor[indexes_to_release] = 0 self.global_indexes.difference_update(indexes_to_release) + self.field_dtypes.difference_update(indexes_to_release) + self.field_shapes.difference_update(indexes_to_release) + self.field_custom_metas.difference_update(indexes_to_release) except Exception as e: logger.error( @@ -658,7 +688,9 @@ class TransferQueueController: """ def __init__( - self, sampler: BaseSampler | type[BaseSampler] = SequentialSampler, polling_mode: bool = False + self, + sampler: BaseSampler | type[BaseSampler] = SequentialSampler, + polling_mode: bool = False, ) -> None: """Initialize the TransferQueue Controller. @@ -791,6 +823,7 @@ def update_production_status( field_names: list[str], dtypes: Optional[dict[int, dict[str, Any]]], shapes: Optional[dict[int, dict[str, Any]]], + custom_meta: Optional[dict[int, dict[str, Any]]], ) -> bool: """ Update production status for specific samples and fields in a partition. @@ -811,7 +844,7 @@ def update_production_status( logger.error(f"Partition {partition_id} not found") return False - success = partition.update_production_status(global_indexes, field_names, dtypes, shapes) + success = partition.update_production_status(global_indexes, field_names, dtypes, shapes, custom_meta) if success: logger.debug( f"[{self.controller_id}]: Updated production status for partition {partition_id}: " @@ -1070,7 +1103,11 @@ def generate_batch_meta( ) samples.append(sample) - return BatchMeta(samples=samples) + custom_meta = partition.get_field_custom_meta(batch_global_indexes, data_fields) + + batch_meta = BatchMeta(samples=samples) + batch_meta.update_custom_meta(custom_meta) + return batch_meta def clear_partition(self, partition_id: str, clear_consumption: bool = True): """ @@ -1408,6 +1445,7 @@ def _update_data_status(self): field_names=message_data.get("fields", []), dtypes=message_data.get("dtypes", {}), shapes=message_data.get("shapes", {}), + custom_meta=message_data.get("custom_meta", {}), ) if success: diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index c389157..417fe58 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -174,6 +174,8 @@ class BatchMeta: samples: list[SampleMeta] extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) + # internal data for different storage backends: _custom_meta[index][field] + _custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) def __post_init__(self): """Initialize all computed properties during initialization""" @@ -230,6 +232,23 @@ def partition_ids(self) -> list[str]: """Get partition ids for all samples in this batch as a list (one per sample)""" return getattr(self, "_partition_ids", []) + # Custom meta methods for different storage 
backends + def get_custom_meta_list(self) -> list[Any]: + """Get required custom meta as a list""" + return [ + self._custom_meta.get(index, {}).get(field_name, None) + for field_name, index in itertools.product(sorted(self.field_names), range(self.size)) + ] + + def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: + """Get the entire custom meta dictionary""" + return copy.deepcopy(self._custom_meta) + + def update_custom_meta(self, new_custom_meta: dict[int, dict[str, Any]] = None): + """Update custom meta with a new dictionary""" + if new_custom_meta: + self._custom_meta.update(new_custom_meta) + # Extra info interface methods def get_extra_info(self, key: str, default: Any = None) -> Any: """Get extra info by key""" diff --git a/transfer_queue/storage/clients/base.py b/transfer_queue/storage/clients/base.py index fc5677d..fed031d 100644 --- a/transfer_queue/storage/clients/base.py +++ b/transfer_queue/storage/clients/base.py @@ -14,6 +14,7 @@ # limitations under the License. from abc import ABC, abstractmethod +from typing import Any, Optional from torch import Tensor @@ -25,11 +26,19 @@ class TransferQueueStorageKVClient(ABC): """ @abstractmethod - def put(self, keys: list[str], values: list[Tensor]) -> None: + def put(self, keys: list[str], values: list[Tensor]) -> Optional[list[Any]]: + """ + Store key-value pairs in the storage backend. + Args: + keys (list[str]): List of keys to store. + values (list[Tensor]): List of tensor values to store. + Returns: + Optional[list[Any]]: Optional list of custom metadata from each storage backend. + """ raise NotImplementedError("Subclasses must implement put") @abstractmethod - def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Tensor]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Tensor]: raise NotImplementedError("Subclasses must implement get") @abstractmethod diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index a71262b..80efa09 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -1,7 +1,7 @@ import logging import os import pickle -from typing import Any +from typing import Any, Optional import torch from torch import Tensor @@ -53,7 +53,7 @@ def __init__(self, config: dict[str, Any]): if ret != 0: raise RuntimeError(f"Mooncake store setup failed with error code: {ret}") - def put(self, keys: list[str], values: list[Any]): + def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: if not isinstance(keys, list) or not isinstance(values, list): raise ValueError("keys and values must be lists") if len(keys) != len(values): @@ -82,6 +82,8 @@ def put(self, keys: list[str], values: list[Any]): if non_tensor_keys: self._batch_put_bytes(non_tensor_keys, non_tensor_values) + return None + def _batch_put_tensors(self, keys: list[str], tensors: list[Tensor]): for i in range(0, len(keys), BATCH_SIZE_LIMIT): batch_keys = keys[i : i + BATCH_SIZE_LIMIT] @@ -104,7 +106,7 @@ def _batch_put_bytes(self, keys: list[str], values: list[bytes]): if ret != 0: raise RuntimeError(f"put_batch failed with error code: {ret}") - def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Any]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: if shapes is None or dtypes is None: raise ValueError("MooncakeStorageClient needs shapes and dtypes") if not (len(keys) == len(shapes) == len(dtypes)): diff --git 
a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index 4652314..0956a56 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -16,7 +16,7 @@ import logging import os import pickle -from typing import Any +from typing import Any, Optional import torch from torch import Tensor @@ -161,7 +161,7 @@ def _batch_put(self, keys: list[str], values: list[Any]): batch_vals = pickled_values[i : i + CPU_DS_CLIENT_KEYS_LIMIT] self._cpu_ds_client.mset(batch_keys, batch_vals) - def put(self, keys: list[str], values: list[Any]): + def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: """Stores multiple key-value pairs to remote storage. Automatically routes NPU tensors to high-performance tensor storage, @@ -176,8 +176,9 @@ def put(self, keys: list[str], values: list[Any]): if len(keys) != len(values): raise ValueError("Number of keys must match number of values") self._batch_put(keys, values) + return None - def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]: + def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> Optional[list[Any]]: """Retrieves a batch of values from remote storage using expected metadata. NPU tensors are fetched via DsTensorClient using pre-allocated buffers. @@ -262,7 +263,7 @@ def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]: idx += 1 return results - def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Any]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: """Retrieves multiple values from remote storage with expected metadata. Requires shape and dtype hints to reconstruct NPU tensors correctly. diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 64927fd..3cf5bf9 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -185,6 +185,7 @@ async def notify_data_update( global_indexes: list[int], dtypes: dict[int, dict[str, Any]], shapes: dict[int, dict[str, Any]], + custom_meta: dict[int, dict[str, Any]] = None, ) -> None: """ Notify controller that new data is ready. @@ -195,6 +196,7 @@ async def notify_data_update( global_indexes: Data update related global_indexes. dtypes: Per-field dtypes for each field, in {global_index: {field: dtype}} format. shapes: Per-field shapes for each field, in {global_index: {field: shape}} format. + custom_meta: Per-field custom_meta for each field, in {global_index: {field: custom_meta}} format. """ # Create zmq poller for notifying data update information @@ -218,6 +220,7 @@ async def notify_data_update( "global_indexes": global_indexes, "dtypes": dtypes, "shapes": shapes, + "custom_meta": custom_meta, }, ).serialize() @@ -405,16 +408,16 @@ def _merge_tensors_to_tensordict(metadata: BatchMeta, values: list[Tensor]) -> T return TensorDict(merged_data, batch_size=len(global_indexes)) @staticmethod - def _get_shape_type_list(metadata: BatchMeta): + def _get_shape_type_custom_meta_list(metadata: BatchMeta): """ - Extract the expected shape and dtype for each field-sample pair in metadata. + Extract the expected shape, dtype, and custom meta for each field-sample pair in metadata. The order matches the key/value order: sorted by field name, then by global index. Args: metadata (BatchMeta): Metadata containing sample and field information. 
Returns: - tuple[list[torch.Size], list[torch.dtype]]: Two lists containing the shape and dtype - for each tensor to be retrieved. + tuple[list[torch.Size], list[torch.dtype], list[Any]]: the shape list, dtype list and + custom meta list for each tensor to be retrieved. """ shapes = [] dtypes = [] @@ -423,7 +426,8 @@ def _get_shape_type_list(metadata: BatchMeta): field = metadata.samples[index].get_field_by_name(field_name) shapes.append(field.shape) dtypes.append(field.dtype) - return shapes, dtypes + custom_meta_list = metadata.get_custom_meta_list() + return shapes, dtypes, custom_meta_list async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: """ @@ -445,7 +449,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: keys = self._generate_keys(data.keys(), metadata.global_indexes) values = self._generate_values(data) loop = asyncio.get_event_loop() - await loop.run_in_executor(None, self.storage_client.put, keys, values) + custom_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values) per_field_dtypes = {} per_field_shapes = {} @@ -466,13 +470,35 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: getattr(data_item, "shape", None) if isinstance(data_item, Tensor) else None ) + # Prepare per-field custom_meta if available + per_field_custom_meta = {} + if custom_meta: + if len(custom_meta) != len(keys): + raise ValueError(f"Length of custom_meta ({len(custom_meta)}) does not match expected ({len(keys)})") + # custom meta is a flat list aligned with keys/values + # Use itertools.product to eliminate nested loops + for (global_idx, field_name), meta_value in zip( + itertools.product(metadata.global_indexes, metadata.field_names), + custom_meta, + strict=False, + ): + if global_idx not in per_field_custom_meta: + per_field_custom_meta[global_idx] = {} + per_field_custom_meta[global_idx][field_name] = meta_value + metadata.update_custom_meta(per_field_custom_meta) + # Get current data partition id # Note: Currently we only support putting to & getting data from a single data partition simultaneously, # but in the future we may support putting to & getting data from multiple data partitions concurrently. partition_id = metadata.samples[0].partition_id # notify controller that new data is ready await self.notify_data_update( - partition_id, list(data.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes + partition_id, + list(data.keys()), + metadata.global_indexes, + per_field_dtypes, + per_field_shapes, + per_field_custom_meta, ) async def get_data(self, metadata: BatchMeta) -> TensorDict: @@ -486,8 +512,8 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: logger.warning("Attempted to get data, but metadata contains no fields.") return TensorDict({}, batch_size=len(metadata)) keys = self._generate_keys(metadata.field_names, metadata.global_indexes) - shapes, dtypes = self._get_shape_type_list(metadata) - values = self.storage_client.get(keys=keys, shapes=shapes, dtypes=dtypes) + shapes, dtypes, custom_meta = self._get_shape_type_custom_meta_list(metadata) + values = self.storage_client.get(keys=keys, shapes=shapes, dtypes=dtypes, custom_meta=custom_meta) return self._merge_tensors_to_tensordict(metadata, values) async def clear_data(self, metadata: BatchMeta) -> None: From 9a553b936a7c52c68c4895a566515a52ad4fd125 Mon Sep 17 00:00:00 2001 From: tianyi-ge Date: Wed, 21 Jan 2026 23:47:08 +0800 Subject: [PATCH 3/8] 1. add tests for controller with custom meta 2. 
add tests for kv storage with custom meta 3. update nits Signed-off-by: tianyi-ge --- tests/test_controller.py | 462 ++++++++++++++++++ tests/test_controller_data_partitions.py | 185 ++++++- tests/test_kv_storage_manager.py | 228 +++++++++ tests/test_metadata.py | 167 +++++++ transfer_queue/controller.py | 39 +- transfer_queue/metadata.py | 9 +- transfer_queue/storage/clients/base.py | 25 +- .../storage/clients/ray_storage_client.py | 7 +- .../storage/clients/yuanrong_client.py | 1 + transfer_queue/storage/managers/base.py | 13 +- 10 files changed, 1090 insertions(+), 46 deletions(-) diff --git a/tests/test_controller.py b/tests/test_controller.py index 360b7f0..35a95f1 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -382,3 +382,465 @@ def test_controller_clear_meta(self, ray_setup): assert set(partition_after.global_indexes) == set([4, 5, 7]) print("✓ Clear meta correct") + + +class TestCustomMeta: + """Test suite for custom_meta functionality in TransferQueueController""" + + def test_custom_meta_basic_storage_and_retrieval(self, ray_setup): + """Test basic custom_meta storage via update_production_status and retrieval via get_metadata""" + gbs = 4 + partition_id = "test_custom_meta_basic" + + tq_controller = TransferQueueController.remote() + + # Create metadata in insert mode + data_fields = ["prompt_ids", "attention_mask"] + metadata = ray.get( + tq_controller.get_metadata.remote( + data_fields=data_fields, + batch_size=gbs, + partition_id=partition_id, + mode="insert", + ) + ) + + assert metadata.global_indexes == list(range(gbs)) + + # Update production status with custom_meta + dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in metadata.global_indexes} + shapes = {k: {"prompt_ids": (32,), "attention_mask": (32,)} for k in metadata.global_indexes} + custom_meta = { + k: {"prompt_ids": {"token_count": 100 + k}, "attention_mask": {"mask_ratio": 0.1 * k}} + for k in metadata.global_indexes + } + + success = ray.get( + tq_controller.update_production_status.remote( + partition_id=partition_id, + global_indexes=metadata.global_indexes, + field_names=metadata.field_names, + dtypes=dtypes, + shapes=shapes, + custom_meta=custom_meta, + ) + ) + assert success + + # Verify custom_meta is stored in partition + partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) + assert partition is not None + assert len(partition.field_custom_metas) == gbs + + for idx in metadata.global_indexes: + assert idx in partition.field_custom_metas + assert "prompt_ids" in partition.field_custom_metas[idx] + assert "attention_mask" in partition.field_custom_metas[idx] + assert partition.field_custom_metas[idx]["prompt_ids"]["token_count"] == 100 + idx + assert partition.field_custom_metas[idx]["attention_mask"]["mask_ratio"] == 0.1 * idx + + print("✓ Basic custom_meta storage correct") + + # Retrieve via get_metadata in fetch mode and verify custom_meta is in batch_meta + fetch_meta = ray.get( + tq_controller.get_metadata.remote( + data_fields=data_fields, + batch_size=gbs, + partition_id=partition_id, + mode="fetch", + task_name="test_task", + ) + ) + + assert fetch_meta is not None + custom_meta_retrieved = fetch_meta.get_all_custom_meta() + assert custom_meta_retrieved is not None + + for idx in metadata.global_indexes: + assert idx in custom_meta_retrieved + assert "prompt_ids" in custom_meta_retrieved[idx] + assert "attention_mask" in custom_meta_retrieved[idx] + + print("✓ Basic custom_meta retrieval via get_metadata correct") + 
+ def test_custom_meta_with_partial_fields(self, ray_setup): + """Test custom_meta retrieval when only requesting subset of fields""" + gbs = 4 + partition_id = "test_custom_meta_partial" + + tq_controller = TransferQueueController.remote() + + # Create metadata in insert mode with multiple fields + data_fields = ["prompt_ids", "attention_mask", "labels"] + metadata = ray.get( + tq_controller.get_metadata.remote( + data_fields=data_fields, + batch_size=gbs, + partition_id=partition_id, + mode="insert", + ) + ) + + # Update production status with custom_meta for all fields + dtypes = { + k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool", "labels": "torch.int64"} + for k in metadata.global_indexes + } + shapes = {k: {"prompt_ids": (32,), "attention_mask": (32,), "labels": (32,)} for k in metadata.global_indexes} + custom_meta = { + k: { + "prompt_ids": {"meta_prompt": f"prompt_{k}"}, + "attention_mask": {"meta_mask": f"mask_{k}"}, + "labels": {"meta_label": f"label_{k}"}, + } + for k in metadata.global_indexes + } + + success = ray.get( + tq_controller.update_production_status.remote( + partition_id=partition_id, + global_indexes=metadata.global_indexes, + field_names=metadata.field_names, + dtypes=dtypes, + shapes=shapes, + custom_meta=custom_meta, + ) + ) + assert success + + # Fetch with only a subset of fields + subset_fields = ["prompt_ids", "labels"] + fetch_meta = ray.get( + tq_controller.get_metadata.remote( + data_fields=subset_fields, + batch_size=gbs, + partition_id=partition_id, + mode="fetch", + task_name="test_task", + ) + ) + + assert fetch_meta is not None + custom_meta_retrieved = fetch_meta.get_all_custom_meta() + assert custom_meta_retrieved is not None + + # Verify only requested fields are in custom_meta + for idx in metadata.global_indexes: + assert idx in custom_meta_retrieved + assert "prompt_ids" in custom_meta_retrieved[idx] + assert "labels" in custom_meta_retrieved[idx] + # attention_mask should not be in the custom_meta since it wasn't requested + assert "attention_mask" not in custom_meta_retrieved[idx] + + print("✓ Custom_meta with partial fields correct") + + def test_custom_meta_length_mismatch_returns_false(self, ray_setup): + """Test that custom_meta length mismatch with global_indices returns False""" + gbs = 4 + partition_id = "test_custom_meta_mismatch" + + tq_controller = TransferQueueController.remote() + + # Create metadata in insert mode + data_fields = ["prompt_ids"] + metadata = ray.get( + tq_controller.get_metadata.remote( + data_fields=data_fields, + batch_size=gbs, + partition_id=partition_id, + mode="insert", + ) + ) + + # Prepare mismatched custom_meta (fewer entries than global_indexes) + dtypes = {k: {"prompt_ids": "torch.int64"} for k in metadata.global_indexes} + shapes = {k: {"prompt_ids": (32,)} for k in metadata.global_indexes} + # Only provide custom_meta for half the samples + custom_meta = {k: {"prompt_ids": {"meta": k}} for k in metadata.global_indexes[:2]} + + # The method should return False when there's a length mismatch + success = ray.get( + tq_controller.update_production_status.remote( + partition_id=partition_id, + global_indexes=metadata.global_indexes, + field_names=metadata.field_names, + dtypes=dtypes, + shapes=shapes, + custom_meta=custom_meta, + ) + ) + assert success is False, "Expected update_production_status to return False for length mismatch" + + print("✓ Custom_meta length mismatch error handling correct") + + def test_custom_meta_none_does_not_store(self, ray_setup): + """Test that passing None 
for custom_meta doesn't create custom_meta entries""" + gbs = 4 + partition_id = "test_custom_meta_none" + + tq_controller = TransferQueueController.remote() + + # Create metadata in insert mode + data_fields = ["prompt_ids"] + metadata = ray.get( + tq_controller.get_metadata.remote( + data_fields=data_fields, + batch_size=gbs, + partition_id=partition_id, + mode="insert", + ) + ) + + # Update production status without custom_meta (None) + dtypes = {k: {"prompt_ids": "torch.int64"} for k in metadata.global_indexes} + shapes = {k: {"prompt_ids": (32,)} for k in metadata.global_indexes} + + success = ray.get( + tq_controller.update_production_status.remote( + partition_id=partition_id, + global_indexes=metadata.global_indexes, + field_names=metadata.field_names, + dtypes=dtypes, + shapes=shapes, + custom_meta=None, + ) + ) + assert success + + # Verify no custom_meta is stored + partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) + assert partition is not None + assert len(partition.field_custom_metas) == 0 + + print("✓ Custom_meta None handling correct") + + def test_custom_meta_preserved_after_partial_clear(self, ray_setup): + """Test that custom_meta for non-cleared samples is preserved after clear_meta""" + gbs = 4 + partition_id = "test_custom_meta_clear" + + tq_controller = TransferQueueController.remote() + + # Create metadata in insert mode + data_fields = ["prompt_ids"] + metadata = ray.get( + tq_controller.get_metadata.remote( + data_fields=data_fields, + batch_size=gbs, + partition_id=partition_id, + mode="insert", + ) + ) + + # Update production status with custom_meta + dtypes = {k: {"prompt_ids": "torch.int64"} for k in metadata.global_indexes} + shapes = {k: {"prompt_ids": (32,)} for k in metadata.global_indexes} + custom_meta = {k: {"prompt_ids": {"sample_id": k * 10}} for k in metadata.global_indexes} + + success = ray.get( + tq_controller.update_production_status.remote( + partition_id=partition_id, + global_indexes=metadata.global_indexes, + field_names=metadata.field_names, + dtypes=dtypes, + shapes=shapes, + custom_meta=custom_meta, + ) + ) + assert success + + # Clear only first 2 samples + global_indexes_to_clear = [0, 1] + partition_ids_to_clear = [partition_id] * len(global_indexes_to_clear) + + ray.get( + tq_controller.clear_meta.remote( + global_indexes=global_indexes_to_clear, + partition_ids=partition_ids_to_clear, + ) + ) + + # Verify custom_meta is cleared for cleared samples and preserved for others + partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) + assert partition is not None + + # Cleared samples should not have custom_meta + assert 0 not in partition.field_custom_metas + assert 1 not in partition.field_custom_metas + + # Non-cleared samples should still have custom_meta + assert 2 in partition.field_custom_metas + assert 3 in partition.field_custom_metas + assert partition.field_custom_metas[2]["prompt_ids"]["sample_id"] == 20 + assert partition.field_custom_metas[3]["prompt_ids"]["sample_id"] == 30 + + print("✓ Custom_meta preserved after partial clear correct") + + def test_custom_meta_update_merges_values(self, ray_setup): + """Test that updating custom_meta for the same sample merges values""" + gbs = 2 + partition_id = "test_custom_meta_merge" + + tq_controller = TransferQueueController.remote() + + # Create metadata in insert mode with first field + data_fields_1 = ["prompt_ids"] + metadata_1 = ray.get( + tq_controller.get_metadata.remote( + data_fields=data_fields_1, + batch_size=gbs, + 
partition_id=partition_id, + mode="insert", + ) + ) + + # First update with custom_meta for prompt_ids + dtypes_1 = {k: {"prompt_ids": "torch.int64"} for k in metadata_1.global_indexes} + shapes_1 = {k: {"prompt_ids": (32,)} for k in metadata_1.global_indexes} + custom_meta_1 = {k: {"prompt_ids": {"first_update": True}} for k in metadata_1.global_indexes} + + success = ray.get( + tq_controller.update_production_status.remote( + partition_id=partition_id, + global_indexes=metadata_1.global_indexes, + field_names=metadata_1.field_names, + dtypes=dtypes_1, + shapes=shapes_1, + custom_meta=custom_meta_1, + ) + ) + assert success + + # Second update with new field and its custom_meta + data_fields_2 = ["attention_mask"] + dtypes_2 = {k: {"attention_mask": "torch.bool"} for k in metadata_1.global_indexes} + shapes_2 = {k: {"attention_mask": (32,)} for k in metadata_1.global_indexes} + custom_meta_2 = {k: {"attention_mask": {"second_update": True}} for k in metadata_1.global_indexes} + + success = ray.get( + tq_controller.update_production_status.remote( + partition_id=partition_id, + global_indexes=metadata_1.global_indexes, + field_names=data_fields_2, + dtypes=dtypes_2, + shapes=shapes_2, + custom_meta=custom_meta_2, + ) + ) + assert success + + # Verify both custom_meta entries are present (merged) + partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) + assert partition is not None + + for idx in metadata_1.global_indexes: + assert idx in partition.field_custom_metas + assert "prompt_ids" in partition.field_custom_metas[idx] + assert "attention_mask" in partition.field_custom_metas[idx] + assert partition.field_custom_metas[idx]["prompt_ids"]["first_update"] is True + assert partition.field_custom_metas[idx]["attention_mask"]["second_update"] is True + + print("✓ Custom_meta merge on update correct") + + def test_custom_meta_with_complex_nested_data(self, ray_setup): + """Test custom_meta with complex nested data structures""" + gbs = 2 + partition_id = "test_custom_meta_complex" + + tq_controller = TransferQueueController.remote() + + # Create metadata in insert mode + data_fields = ["prompt_ids"] + metadata = ray.get( + tq_controller.get_metadata.remote( + data_fields=data_fields, + batch_size=gbs, + partition_id=partition_id, + mode="insert", + ) + ) + + # Create complex nested custom_meta + dtypes = {k: {"prompt_ids": "torch.int64"} for k in metadata.global_indexes} + shapes = {k: {"prompt_ids": (32,)} for k in metadata.global_indexes} + custom_meta = { + k: { + "prompt_ids": { + "nested_dict": {"level1": {"level2": {"value": k}}}, + "list_data": [1, 2, 3, k], + "mixed_types": {"string": "test", "number": 42, "float": 3.14, "bool": True}, + } + } + for k in metadata.global_indexes + } + + success = ray.get( + tq_controller.update_production_status.remote( + partition_id=partition_id, + global_indexes=metadata.global_indexes, + field_names=metadata.field_names, + dtypes=dtypes, + shapes=shapes, + custom_meta=custom_meta, + ) + ) + assert success + + # Verify complex nested data is preserved + partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) + assert partition is not None + + for idx in metadata.global_indexes: + stored_meta = partition.field_custom_metas[idx]["prompt_ids"] + assert stored_meta["nested_dict"]["level1"]["level2"]["value"] == idx + assert stored_meta["list_data"] == [1, 2, 3, idx] + assert stored_meta["mixed_types"]["string"] == "test" + assert stored_meta["mixed_types"]["number"] == 42 + assert 
stored_meta["mixed_types"]["float"] == 3.14 + assert stored_meta["mixed_types"]["bool"] is True + + print("✓ Complex nested custom_meta correct") + + def test_custom_meta_cleared_on_partition_clear(self, ray_setup): + """Test that custom_meta is fully cleared when partition is cleared""" + gbs = 4 + partition_id = "test_custom_meta_partition_clear" + + tq_controller = TransferQueueController.remote() + + # Create metadata in insert mode + data_fields = ["prompt_ids"] + metadata = ray.get( + tq_controller.get_metadata.remote( + data_fields=data_fields, + batch_size=gbs, + partition_id=partition_id, + mode="insert", + ) + ) + + # Update production status with custom_meta + dtypes = {k: {"prompt_ids": "torch.int64"} for k in metadata.global_indexes} + shapes = {k: {"prompt_ids": (32,)} for k in metadata.global_indexes} + custom_meta = {k: {"prompt_ids": {"data": k}} for k in metadata.global_indexes} + + success = ray.get( + tq_controller.update_production_status.remote( + partition_id=partition_id, + global_indexes=metadata.global_indexes, + field_names=metadata.field_names, + dtypes=dtypes, + shapes=shapes, + custom_meta=custom_meta, + ) + ) + assert success + + # Clear the entire partition + ray.get(tq_controller.clear_partition.remote(partition_id)) + + # Verify partition is gone + partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) + assert partition is None + + print("✓ Custom_meta cleared on partition clear correct") diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index e018fb3..49cb715 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -63,6 +63,7 @@ def test_data_partition_status(): 1: {"input_ids": (512,), "attention_mask": (512,)}, 2: {"input_ids": (512,), "attention_mask": (512,)}, }, + custom_meta=None, ) assert success @@ -172,6 +173,7 @@ def test_dynamic_expansion_scenarios(): 5: {"field_1": (32,)}, 10: {"field_1": (32,)}, }, + custom_meta=None, ) assert partition.total_samples_num == 3 assert partition.allocated_samples_num >= 11 # Should accommodate index 10 @@ -180,7 +182,7 @@ def test_dynamic_expansion_scenarios(): # Scenario 2: Adding many fields dynamically for i in range(15): partition.update_production_status( - [0], [f"field_{i}"], {0: {f"field_{i}": "torch.bool"}}, {0: {f"field_{i}": (32,)}} + [0], [f"field_{i}"], {0: {f"field_{i}": "torch.bool"}}, {0: {f"field_{i}": (32,)}}, None ) assert partition.total_fields_num == 16 # Original + 15 new fields @@ -222,7 +224,7 @@ def test_data_partition_status_advanced(): # Add data to trigger expansion dtypes = {i: {f"dynamic_field_{s}": "torch.bool" for s in ["a", "b", "c"]} for i in range(5)} shapes = {i: {f"dynamic_field_{s}": (32,) for s in ["a", "b", "c"]} for i in range(5)} - partition.update_production_status([0, 1, 2, 3, 4], ["field_a", "field_b", "field_c"], dtypes, shapes) + partition.update_production_status([0, 1, 2, 3, 4], ["field_a", "field_b", "field_c"], dtypes, shapes, None) # Properties should reflect current state assert partition.total_samples_num >= 5 # At least 5 samples @@ -253,7 +255,7 @@ def test_data_partition_status_advanced(): 11: {"field_d": (32,)}, 12: {"field_d": (32,)}, } - partition.update_production_status([10, 11, 12], ["field_d"], dtypes, shapes) # Triggers sample expansion + partition.update_production_status([10, 11, 12], ["field_d"], dtypes, shapes, None) # Triggers sample expansion expanded_consumption = partition.get_consumption_status(task_name) assert 
expanded_consumption[0] == 1 # Preserved assert expanded_consumption[1] == 1 # Preserved @@ -265,13 +267,13 @@ def test_data_partition_status_advanced(): # Start with some fields dtypes = {0: {"initial_field": "torch.bool"}} shapes = {0: {"field_d": (32,)}} - partition.update_production_status([0], ["initial_field"], dtypes, shapes) + partition.update_production_status([0], ["initial_field"], dtypes, shapes, None) # Add many fields to trigger column expansion new_fields = [f"dynamic_field_{i}" for i in range(20)] dtypes = {1: {f"dynamic_field_{i}": "torch.bool" for i in range(20)}} shapes = {1: {f"dynamic_field_{i}": (32,) for i in range(20)}} - partition.update_production_status([1], new_fields, dtypes, shapes) + partition.update_production_status([1], new_fields, dtypes, shapes, None) # Verify all fields are registered and accessible assert "initial_field" in partition.field_name_mapping @@ -441,3 +443,176 @@ def test_performance_characteristics(): print("✓ Memory usage patterns reasonable") print("Performance characteristics tests passed!\n") + + +def test_custom_meta_in_data_partition_status(): + """Test custom_meta functionality in DataPartitionStatus.""" + print("Testing custom_meta in DataPartitionStatus...") + + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="custom_meta_test") + + # Test 1: Basic custom_meta storage via update_production_status + global_indices = [0, 1, 2] + field_names = ["input_ids", "attention_mask"] + dtypes = { + 0: {"input_ids": "torch.int32", "attention_mask": "torch.bool"}, + 1: {"input_ids": "torch.int32", "attention_mask": "torch.bool"}, + 2: {"input_ids": "torch.int32", "attention_mask": "torch.bool"}, + } + shapes = { + 0: {"input_ids": (512,), "attention_mask": (512,)}, + 1: {"input_ids": (512,), "attention_mask": (512,)}, + 2: {"input_ids": (512,), "attention_mask": (512,)}, + } + custom_meta = { + 0: {"input_ids": {"token_count": 100}, "attention_mask": {"mask_ratio": 0.1}}, + 1: {"input_ids": {"token_count": 200}, "attention_mask": {"mask_ratio": 0.2}}, + 2: {"input_ids": {"token_count": 300}, "attention_mask": {"mask_ratio": 0.3}}, + } + + success = partition.update_production_status( + global_indices=global_indices, + field_names=field_names, + dtypes=dtypes, + shapes=shapes, + custom_meta=custom_meta, + ) + + assert success + assert len(partition.field_custom_metas) == 3 + + # Verify custom_meta is stored correctly + assert partition.field_custom_metas[0]["input_ids"]["token_count"] == 100 + assert partition.field_custom_metas[1]["attention_mask"]["mask_ratio"] == 0.2 + assert partition.field_custom_metas[2]["input_ids"]["token_count"] == 300 + + print("✓ Basic custom_meta storage works") + + # Test 2: get_field_custom_meta retrieval + retrieved_meta = partition.get_field_custom_meta([0, 1, 2], ["input_ids", "attention_mask"]) + + assert 0 in retrieved_meta + assert 1 in retrieved_meta + assert 2 in retrieved_meta + assert retrieved_meta[0]["input_ids"]["token_count"] == 100 + assert retrieved_meta[1]["attention_mask"]["mask_ratio"] == 0.2 + + print("✓ get_field_custom_meta retrieval works") + + # Test 3: get_field_custom_meta with partial field filter + partial_meta = partition.get_field_custom_meta([0, 1], ["input_ids"]) + + assert 0 in partial_meta + assert 1 in partial_meta + assert "input_ids" in partial_meta[0] + assert "attention_mask" not in partial_meta[0] # Should not include non-requested fields + + print("✓ get_field_custom_meta with partial fields works") + + # Test 4: 
get_field_custom_meta with non-existent global_index + empty_meta = partition.get_field_custom_meta([999], ["input_ids"]) + assert 999 not in empty_meta # Should not include non-existent indices + + print("✓ get_field_custom_meta handles non-existent indices correctly") + + # Test 5: custom_meta update (merge) on same global_index + additional_custom_meta = { + 0: {"new_field": {"new_key": "new_value"}}, + } + success = partition.update_production_status( + global_indices=[0], + field_names=["new_field"], + dtypes={0: {"new_field": "torch.float32"}}, + shapes={0: {"new_field": (64,)}}, + custom_meta=additional_custom_meta, + ) + + assert success + # Original custom_meta should be preserved + assert partition.field_custom_metas[0]["input_ids"]["token_count"] == 100 + # New custom_meta should be merged + assert partition.field_custom_metas[0]["new_field"]["new_key"] == "new_value" + + print("✓ Custom_meta merge on update works") + + # Test 6: custom_meta cleared on clear_data + partition.clear_data([0], clear_consumption=True) + + assert 0 not in partition.field_custom_metas + assert 1 in partition.field_custom_metas # Other samples should remain + assert 2 in partition.field_custom_metas + + print("✓ Custom_meta cleared on clear_data works") + + # Test 7: custom_meta None does not create entries + partition2 = DataPartitionStatus(partition_id="custom_meta_test_2") + success = partition2.update_production_status( + global_indices=[0, 1], + field_names=["field1"], + dtypes={0: {"field1": "torch.int32"}, 1: {"field1": "torch.int32"}}, + shapes={0: {"field1": (32,)}, 1: {"field1": (32,)}}, + custom_meta=None, + ) + + assert success + assert len(partition2.field_custom_metas) == 0 + + print("✓ Custom_meta None handling works") + + # Test 8: custom_meta length mismatch raises ValueError + partition3 = DataPartitionStatus(partition_id="custom_meta_test_3") + mismatched_custom_meta = { + 0: {"field1": {"key": "value"}}, + # Missing entries for 1 and 2 + } + success = partition3.update_production_status( + global_indices=[0, 1, 2], + field_names=["field1"], + dtypes={0: {"field1": "torch.int32"}, 1: {"field1": "torch.int32"}, 2: {"field1": "torch.int32"}}, + shapes={0: {"field1": (32,)}, 1: {"field1": (32,)}, 2: {"field1": (32,)}}, + custom_meta=mismatched_custom_meta, + ) + + # Should return False due to length mismatch (caught by exception handler) + assert success is False + + print("✓ Custom_meta length mismatch error handling works") + + # Test 9: Complex nested custom_meta + partition4 = DataPartitionStatus(partition_id="custom_meta_test_4") + complex_custom_meta = { + 0: { + "field1": { + "nested": {"level1": {"level2": {"value": 42}}}, + "list_data": [1, 2, 3], + "mixed": {"str": "test", "int": 100, "float": 3.14, "bool": True}, + } + }, + } + success = partition4.update_production_status( + global_indices=[0], + field_names=["field1"], + dtypes={0: {"field1": "torch.int32"}}, + shapes={0: {"field1": (32,)}}, + custom_meta=complex_custom_meta, + ) + + assert success + stored_meta = partition4.field_custom_metas[0]["field1"] + assert stored_meta["nested"]["level1"]["level2"]["value"] == 42 + assert stored_meta["list_data"] == [1, 2, 3] + assert stored_meta["mixed"]["str"] == "test" + assert stored_meta["mixed"]["bool"] is True + + print("✓ Complex nested custom_meta storage works") + + # Test 10: custom_meta preserved in snapshot + snapshot = partition4.to_snapshot() + assert 0 in snapshot.field_custom_metas + assert 
snapshot.field_custom_metas[0]["field1"]["nested"]["level1"]["level2"]["value"] == 42 + + print("✓ Custom_meta preserved in snapshot") + + print("Custom_meta in DataPartitionStatus tests passed!\n") diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py index 3cfe168..c982227 100644 --- a/tests/test_kv_storage_manager.py +++ b/tests/test_kv_storage_manager.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import sys import unittest from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch import torch from tensordict import TensorDict @@ -97,6 +99,232 @@ def test_merge_kv_to_tensordict(self): self.assertEqual(reconstructed.batch_size, torch.Size([3])) + def test_get_shape_type_custom_meta_list_without_custom_meta(self): + """Test _get_shape_type_custom_meta_list returns correct shapes and dtypes without custom_meta.""" + shapes, dtypes, custom_meta_list = KVStorageManager._get_shape_type_custom_meta_list(self.metadata) + + # Expected order: sorted by field name (label, mask, text), then by global_index order + # 3 fields * 3 samples = 9 entries + self.assertEqual(len(shapes), 9) + self.assertEqual(len(dtypes), 9) + self.assertEqual(len(custom_meta_list), 9) + + # Check shapes - order is label, mask, text (sorted alphabetically) + # label shapes: [()]*3, mask shapes: [(1,)]*3, text shapes: [(2,)]*3 + expected_shapes = [ + torch.Size([]), # label[0] + torch.Size([]), # label[1] + torch.Size([]), # label[2] + torch.Size([1]), # mask[0] + torch.Size([1]), # mask[1] + torch.Size([1]), # mask[2] + torch.Size([2]), # text[0] + torch.Size([2]), # text[1] + torch.Size([2]), # text[2] + ] + self.assertEqual(shapes, expected_shapes) + + # All dtypes should be torch.int64 + for dtype in dtypes: + self.assertEqual(dtype, torch.int64) + + # No custom_meta provided, so all should be None + for meta in custom_meta_list: + self.assertIsNone(meta) + + def test_get_shape_type_custom_meta_list_with_custom_meta(self): + """Test _get_shape_type_custom_meta_list returns correct custom_meta when provided.""" + # Add custom_meta to metadata + custom_meta = { + 8: {"text": {"key1": "value1"}, "label": {"key2": "value2"}, "mask": {"key3": "value3"}}, + 9: {"text": {"key4": "value4"}, "label": {"key5": "value5"}, "mask": {"key6": "value6"}}, + 10: {"text": {"key7": "value7"}, "label": {"key8": "value8"}, "mask": {"key9": "value9"}}, + } + self.metadata.update_custom_meta(custom_meta) + + shapes, dtypes, custom_meta_list = KVStorageManager._get_shape_type_custom_meta_list(self.metadata) + + # Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index + expected_custom_meta = [ + {"key2": "value2"}, # label, global_index=8 + {"key5": "value5"}, # label, global_index=9 + {"key8": "value8"}, # label, global_index=10 + {"key3": "value3"}, # mask, global_index=8 + {"key6": "value6"}, # mask, global_index=9 + {"key9": "value9"}, # mask, global_index=10 + {"key1": "value1"}, # text, global_index=8 + {"key4": "value4"}, # text, global_index=9 + {"key7": "value7"}, # text, global_index=10 + ] + self.assertEqual(custom_meta_list, expected_custom_meta) + + def test_get_shape_type_custom_meta_list_with_partial_custom_meta(self): + """Test _get_shape_type_custom_meta_list handles partial custom_meta correctly.""" + # Add custom_meta only for some global_indexes and fields + custom_meta = { + 8: {"text": {"key1": "value1"}}, # Only text field + # global_index 9 has no 
custom_meta + 10: {"label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # label and mask only + } + self.metadata.update_custom_meta(custom_meta) + + shapes, dtypes, custom_meta_list = KVStorageManager._get_shape_type_custom_meta_list(self.metadata) + + # Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index + expected_custom_meta = [ + None, # label, global_index=8 (not in custom_meta) + None, # label, global_index=9 (not in custom_meta) + {"key2": "value2"}, # label, global_index=10 + None, # mask, global_index=8 (not in custom_meta) + None, # mask, global_index=9 (not in custom_meta) + {"key3": "value3"}, # mask, global_index=10 + {"key1": "value1"}, # text, global_index=8 + None, # text, global_index=9 (not in custom_meta) + None, # text, global_index=10 (not in custom_meta for text) + ] + self.assertEqual(custom_meta_list, expected_custom_meta) + + +class TestPutDataWithCustomMeta(unittest.TestCase): + """Test put_data with custom_meta functionality.""" + + def setUp(self): + """Set up test fixtures for put_data tests.""" + self.field_names = ["text", "label"] + self.global_indexes = [0, 1, 2] + + # Create test data + self.data = TensorDict( + { + "text": torch.tensor([[1, 2], [3, 4], [5, 6]]), + "label": torch.tensor([0, 1, 2]), + }, + batch_size=3, + ) + + # Create metadata without production status set (for insert mode) + samples = [] + for sample_id in range(self.data.batch_size[0]): + fields_dict = {} + for field_name in self.data.keys(): + tensor = self.data[field_name][sample_id] + field_meta = FieldMeta(name=field_name, dtype=tensor.dtype, shape=tensor.shape, production_status=0) + fields_dict[field_name] = field_meta + sample = SampleMeta( + partition_id="test_partition", + global_index=self.global_indexes[sample_id], + fields=fields_dict, + ) + samples.append(sample) + self.metadata = BatchMeta(samples=samples) + + @patch.object(KVStorageManager, "_connect_to_controller") + @patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock) + def test_put_data_with_custom_meta_from_storage_client(self, mock_notify, mock_connect): + """Test that put_data correctly processes custom_meta returned by storage client.""" + # Create a mock storage client + mock_storage_client = MagicMock() + # Simulate storage client returning custom_meta (one per key) + # Keys order: label[0,1,2], text[0,1,2] (sorted by field name) + mock_custom_meta = [ + {"storage_key": "0@label"}, + {"storage_key": "1@label"}, + {"storage_key": "2@label"}, + {"storage_key": "0@text"}, + {"storage_key": "1@text"}, + {"storage_key": "2@text"}, + ] + mock_storage_client.put.return_value = mock_custom_meta + + # Create manager with mocked dependencies + config = {"client_name": "MockClient"} + with patch( + "transfer_queue.storage.managers.base.StorageClientFactory.create", return_value=mock_storage_client + ): + manager = KVStorageManager(config) + + # Run put_data + asyncio.run(manager.put_data(self.data, self.metadata)) + + # Verify storage client was called with correct keys and values + mock_storage_client.put.assert_called_once() + call_args = mock_storage_client.put.call_args + keys = call_args[0][0] + values = call_args[0][1] + + # Verify keys are correct + expected_keys = ["0@label", "1@label", "2@label", "0@text", "1@text", "2@text"] + self.assertEqual(keys, expected_keys) + self.assertEqual(len(values), 6) + + # Verify notify_data_update was called with correct custom_meta structure + mock_notify.assert_called_once() + notify_call_args = 
mock_notify.call_args + per_field_custom_meta = notify_call_args[0][5] # 6th positional argument + + # Verify custom_meta is structured correctly: {global_index: {field: meta}} + self.assertIn(0, per_field_custom_meta) + self.assertIn(1, per_field_custom_meta) + self.assertIn(2, per_field_custom_meta) + + self.assertEqual(per_field_custom_meta[0]["label"], {"storage_key": "0@label"}) + self.assertEqual(per_field_custom_meta[0]["text"], {"storage_key": "0@text"}) + self.assertEqual(per_field_custom_meta[1]["label"], {"storage_key": "1@label"}) + self.assertEqual(per_field_custom_meta[1]["text"], {"storage_key": "1@text"}) + self.assertEqual(per_field_custom_meta[2]["label"], {"storage_key": "2@label"}) + self.assertEqual(per_field_custom_meta[2]["text"], {"storage_key": "2@text"}) + + # Verify metadata was updated with custom_meta + all_custom_meta = self.metadata.get_all_custom_meta() + self.assertEqual(all_custom_meta[0]["label"], {"storage_key": "0@label"}) + self.assertEqual(all_custom_meta[2]["text"], {"storage_key": "2@text"}) + + @patch.object(KVStorageManager, "_connect_to_controller") + @patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock) + def test_put_data_without_custom_meta(self, mock_notify, mock_connect): + """Test that put_data works correctly when storage client returns no custom_meta.""" + # Create a mock storage client that returns None for custom_meta + mock_storage_client = MagicMock() + mock_storage_client.put.return_value = None + + # Create manager with mocked dependencies + config = {"client_name": "MockClient"} + with patch( + "transfer_queue.storage.managers.base.StorageClientFactory.create", return_value=mock_storage_client + ): + manager = KVStorageManager(config) + + # Run put_data + asyncio.run(manager.put_data(self.data, self.metadata)) + + # Verify notify_data_update was called with empty dict for custom_meta + mock_notify.assert_called_once() + notify_call_args = mock_notify.call_args + per_field_custom_meta = notify_call_args[0][5] # 6th positional argument + self.assertEqual(per_field_custom_meta, {}) + + @patch.object(KVStorageManager, "_connect_to_controller") + @patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock) + def test_put_data_custom_meta_length_mismatch_raises_error(self, mock_notify, mock_connect): + """Test that put_data raises ValueError when custom_meta length doesn't match keys.""" + # Create a mock storage client that returns mismatched custom_meta length + mock_storage_client = MagicMock() + # Return only 3 custom_meta entries when 6 are expected + mock_storage_client.put.return_value = [{"key": "1"}, {"key": "2"}, {"key": "3"}] + + # Create manager with mocked dependencies + config = {"client_name": "MockClient"} + with patch( + "transfer_queue.storage.managers.base.StorageClientFactory.create", return_value=mock_storage_client + ): + manager = KVStorageManager(config) + + # Run put_data and expect ValueError + with self.assertRaises(ValueError) as context: + asyncio.run(manager.put_data(self.data, self.metadata)) + + self.assertIn("does not match", str(context.exception)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 23a6a72..2780f31 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -773,3 +773,170 @@ def test_batch_meta_concat_validation_error(self): with pytest.raises(ValueError) as exc_info: BatchMeta.concat([batch1, batch2], validate=True) assert "Field names do not match" in str(exc_info.value) + + 
+class TestCustomMeta: + """Unit tests for BatchMeta custom meta methods.""" + + def test_get_all_custom_meta_returns_deep_copy(self): + """Test get_all_custom_meta returns a deep copy of the custom meta dict.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + ] + batch = BatchMeta(samples=samples) + + custom_meta = {0: {"field_a": {"nested": "value"}}} + batch.update_custom_meta(custom_meta) + + # Get all custom meta + result = batch.get_all_custom_meta() + + # Verify it's a deep copy - modifying result should not affect original + result[0]["field_a"]["nested"] = "modified" + + original = batch.get_all_custom_meta() + assert original[0]["field_a"]["nested"] == "value" + + def test_get_all_custom_meta_empty(self): + """Test get_all_custom_meta with no custom meta returns empty dict.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + ] + batch = BatchMeta(samples=samples) + + result = batch.get_all_custom_meta() + + assert result == {} + + def test_update_custom_meta_basic(self): + """Test update_custom_meta adds new entries.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields), + ] + batch = BatchMeta(samples=samples) + + # Update with custom meta + custom_meta = { + 0: {"field_a": "value_0"}, + 1: {"field_a": "value_1"}, + } + batch.update_custom_meta(custom_meta) + + result = batch.get_all_custom_meta() + assert result[0]["field_a"] == "value_0" + assert result[1]["field_a"] == "value_1" + + def test_update_custom_meta_overwrites_existing(self): + """Test update_custom_meta overwrites existing entries at the top level.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + ] + batch = BatchMeta(samples=samples) + + # Initial custom meta + batch.update_custom_meta({0: {"field_a": "original"}}) + + # Update with new value - dict.update replaces the entire value for key 0 + batch.update_custom_meta({0: {"field_a": "updated"}}) + + result = batch.get_all_custom_meta() + assert result[0]["field_a"] == "updated" + + def test_update_custom_meta_merges_different_keys(self): + """Test update_custom_meta merges different top-level keys.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields), + ] + batch = BatchMeta(samples=samples) + + # First update + batch.update_custom_meta({0: {"field_a": "value_0"}}) + + # Second update with different key + batch.update_custom_meta({1: {"field_a": "value_1"}}) + + result = batch.get_all_custom_meta() + assert result[0]["field_a"] == "value_0" + assert result[1]["field_a"] == "value_1" + + def test_update_custom_meta_with_none(self): + """Test update_custom_meta with None does nothing.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + ] + batch = 
BatchMeta(samples=samples) + + # Set initial value + batch.update_custom_meta({0: {"field_a": "value"}}) + + # Update with None should not change anything + batch.update_custom_meta(None) + + result = batch.get_all_custom_meta() + assert result[0]["field_a"] == "value" + + def test_update_custom_meta_with_empty_dict(self): + """Test update_custom_meta with empty dict does nothing.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + ] + batch = BatchMeta(samples=samples) + + # Set initial value + batch.update_custom_meta({0: {"field_a": "value"}}) + + # Update with empty dict should not change anything + batch.update_custom_meta({}) + + result = batch.get_all_custom_meta() + assert result[0]["field_a"] == "value" + + def test_custom_meta_with_complex_values(self): + """Test custom meta can store complex values like dicts, lists, tensors.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + ] + batch = BatchMeta(samples=samples) + + # Store complex values + custom_meta = { + 0: { + "field_a": { + "nested_dict": {"key": "value"}, + "list": [1, 2, 3], + "number": 42, + } + } + } + batch.update_custom_meta(custom_meta) + + result = batch.get_all_custom_meta() + assert result[0]["field_a"]["nested_dict"]["key"] == "value" + assert result[0]["field_a"]["list"] == [1, 2, 3] + assert result[0]["field_a"]["number"] == 42 diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index eda521c..4469632 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -327,7 +327,7 @@ def update_production_status( field_names: list[str], dtypes: Optional[dict[int, dict[str, Any]]], shapes: Optional[dict[int, dict[str, Any]]], - custom_meta: Optional[dict[int, dict[str, Any]]], + custom_meta: Optional[dict[int, dict[str, Any]]] = None, ) -> bool: """ Update production status for specific samples and fields. @@ -383,19 +383,21 @@ def update_production_status( def _update_field_metadata( self, global_indices: list[int], - dtypes: dict[int, dict[str, Any]], - shapes: dict[int, dict[str, Any]], + dtypes: Optional[dict[int, dict[str, Any]]], + shapes: Optional[dict[int, dict[str, Any]]], custom_meta: Optional[dict[int, dict[str, Any]]], ): """Update field dtype and shape metadata.""" - if not global_indices: + if not global_indices or not dtypes or not shapes: return - assert len(global_indices) == len(dtypes), "`global_indices` and `dtypes` length mismatch." - assert len(global_indices) == len(shapes), "`global_indices` and `shapes` length mismatch." 
+ if len(global_indices) != len(dtypes): + raise ValueError(f"`global_indices` {len(global_indices)} and `dtypes` {len(dtypes)} length mismatch.") + if len(global_indices) != len(shapes): + raise ValueError(f"`global_indices` {len(global_indices)} and `shapes` {len(shapes)} length mismatch.") - dtype_value = itemgetter(*global_indices)(dtypes) if dtypes else None - shape_value = itemgetter(*global_indices)(shapes) if shapes else None + dtype_value = itemgetter(*global_indices)(dtypes) + shape_value = itemgetter(*global_indices)(shapes) if not isinstance(dtype_value, tuple): dtype_value = (dtype_value,) @@ -408,16 +410,13 @@ def _update_field_metadata( if global_idx not in self.field_shapes: self.field_shapes[global_idx] = {} - if dtype_value is not None: - self.field_dtypes[global_idx].update(dtype_value[i]) - if shape_value is not None: - self.field_shapes[global_idx].update(shape_value[i]) + self.field_dtypes[global_idx].update(dtype_value[i]) + self.field_shapes[global_idx].update(shape_value[i]) if custom_meta: if len(global_indices) != len(custom_meta): raise ValueError( - f"Length of global_indices ({len(global_indices)}) does not match " - f"length of custom_meta ({len(custom_meta)})" + f"`global_indices` {len(global_indices)} and `custom_meta` {len(custom_meta)} length mismatch." ) custom_meta_value = itemgetter(*global_indices)(custom_meta) if custom_meta else None if not isinstance(custom_meta_value, tuple): @@ -425,8 +424,7 @@ def _update_field_metadata( for i, global_idx in enumerate(global_indices): if global_idx not in self.field_custom_metas: self.field_custom_metas[global_idx] = {} - if custom_meta_value is not None: - self.field_custom_metas[global_idx].update(custom_meta_value[i]) + self.field_custom_metas[global_idx].update(custom_meta_value[i]) # ==================== Consumption Status Interface ==================== @@ -659,9 +657,10 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr consumption_tensor[indexes_to_release] = 0 self.global_indexes.difference_update(indexes_to_release) - self.field_dtypes.difference_update(indexes_to_release) - self.field_shapes.difference_update(indexes_to_release) - self.field_custom_metas.difference_update(indexes_to_release) + for idx in indexes_to_release: + self.field_dtypes.pop(idx, None) + self.field_shapes.pop(idx, None) + self.field_custom_metas.pop(idx, None) except Exception as e: logger.error( @@ -823,7 +822,7 @@ def update_production_status( field_names: list[str], dtypes: Optional[dict[int, dict[str, Any]]], shapes: Optional[dict[int, dict[str, Any]]], - custom_meta: Optional[dict[int, dict[str, Any]]], + custom_meta: Optional[dict[int, dict[str, Any]]] = None, ) -> bool: """ Update production status for specific samples and fields in a partition. 
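Note on the tuple normalization kept in `_update_field_metadata` above: `operator.itemgetter` returns a bare value when given a single index but a tuple when given several, so the result has to be wrapped before it can be iterated alongside `global_indices`. A minimal standalone sketch, illustrative only (the `dtypes` mapping below is made up):

    from operator import itemgetter

    dtypes = {0: {"f1": "torch.int32"}, 1: {"f1": "torch.bool"}}  # hypothetical mapping

    one = itemgetter(0)(dtypes)      # bare dict: {"f1": "torch.int32"}
    many = itemgetter(0, 1)(dtypes)  # tuple of dicts

    # Normalize so the caller can always zip positionally against global_indices,
    # regardless of how many indices were requested.
    if not isinstance(one, tuple):
        one = (one,)
    assert len(one) == 1 and len(many) == 2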
diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 417fe58..a310634 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -174,7 +174,7 @@ class BatchMeta: samples: list[SampleMeta] extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) - # internal data for different storage backends: _custom_meta[index][field] + # internal data for different storage backends: _custom_meta[global_index][field] _custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -233,13 +233,6 @@ def partition_ids(self) -> list[str]: return getattr(self, "_partition_ids", []) # Custom meta methods for different storage backends - def get_custom_meta_list(self) -> list[Any]: - """Get required custom meta as a list""" - return [ - self._custom_meta.get(index, {}).get(field_name, None) - for field_name, index in itertools.product(sorted(self.field_names), range(self.size)) - ] - def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: """Get the entire custom meta dictionary""" return copy.deepcopy(self._custom_meta) diff --git a/transfer_queue/storage/clients/base.py b/transfer_queue/storage/clients/base.py index fed031d..651d86e 100644 --- a/transfer_queue/storage/clients/base.py +++ b/transfer_queue/storage/clients/base.py @@ -16,8 +16,6 @@ from abc import ABC, abstractmethod from typing import Any, Optional -from torch import Tensor - class TransferQueueStorageKVClient(ABC): """ @@ -26,19 +24,36 @@ class TransferQueueStorageKVClient(ABC): """ @abstractmethod - def put(self, keys: list[str], values: list[Tensor]) -> Optional[list[Any]]: + def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: """ Store key-value pairs in the storage backend. Args: keys (list[str]): List of keys to store. - values (list[Tensor]): List of tensor values to store. + values (list[Any]): List of any type to store. Returns: Optional[list[Any]]: Optional list of custom metadata from each storage backend. """ raise NotImplementedError("Subclasses must implement put") @abstractmethod - def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Tensor]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: + """ + Retrieve values from the storage backend by key. + Args: + keys (list[str]): List of keys whose values should be retrieved. + shapes: Optional shape information for the expected tensors. The + structure and interpretation of this argument are determined + by the concrete storage backend implementation. + dtypes: Optional data type information for the expected tensors. + The structure and interpretation of this argument are + determined by the concrete storage backend implementation. + custom_meta: Optional backend-specific metadata used to control + or optimize the retrieval process. Its format is defined by + the concrete storage backend implementation. + Returns: + list[Tensor]: List of tensors retrieved from the storage backend, + in the same order as the provided keys. 
+ """ raise NotImplementedError("Subclasses must implement get") @abstractmethod diff --git a/transfer_queue/storage/clients/ray_storage_client.py b/transfer_queue/storage/clients/ray_storage_client.py index 5b85825..c7b1e2f 100644 --- a/transfer_queue/storage/clients/ray_storage_client.py +++ b/transfer_queue/storage/clients/ray_storage_client.py @@ -1,5 +1,5 @@ import itertools -from typing import Any +from typing import Any, Optional import ray import torch @@ -38,7 +38,7 @@ def __init__(self, config=None): except ValueError: self.storage_actor = RayObjectRefStorage.options(name="RayObjectRefStorage", get_if_exists=False).remote() - def put(self, keys: list[str], values: list[Any]): + def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: """ Store tensors to remote storage. Args: @@ -59,13 +59,14 @@ def put(self, keys: list[str], values: list[Any]): ) ray.get(self.storage_actor.put_obj_ref.remote(keys, obj_refs)) - def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Any]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: """ Retrieve objects from remote storage. Args: keys (list): List of string keys to fetch. shapes (list, optional): Ignored. For compatibility with KVStorageManager. dtypes (list, optional): Ignored. For compatibility with KVStorageManager. + custom_meta (list, optional): Ray object ref for each key Returns: list: List of retrieved objects """ diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index 0956a56..280f4c9 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -272,6 +272,7 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> li keys (List[str]): Keys to fetch. shapes (List[List[int]]): Expected tensor shapes (use [] for scalars). dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data. + custom_meta (List[str], optional): Device type (npu/cpu) for each key Returns: List[Any]: Retrieved values in the same order as input keys. 
diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 3cf5bf9..3dfd047 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -421,12 +421,15 @@ def _get_shape_type_custom_meta_list(metadata: BatchMeta): """ shapes = [] dtypes = [] + custom_meta_list = [] + all_custom_meta = metadata.get_all_custom_meta() for field_name in sorted(metadata.field_names): for index in range(len(metadata)): field = metadata.samples[index].get_field_by_name(field_name) shapes.append(field.shape) dtypes.append(field.dtype) - custom_meta_list = metadata.get_custom_meta_list() + global_index = metadata.global_indexes[index] + custom_meta_list.append(all_custom_meta.get(global_index, {}).get(field_name, None)) return shapes, dtypes, custom_meta_list async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: @@ -477,13 +480,13 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: raise ValueError(f"Length of custom_meta ({len(custom_meta)}) does not match expected ({len(keys)})") # custom meta is a flat list aligned with keys/values # Use itertools.product to eliminate nested loops - for (global_idx, field_name), meta_value in zip( - itertools.product(metadata.global_indexes, metadata.field_names), + for global_idx in metadata.global_indexes: + per_field_custom_meta[global_idx] = {} + for (field_name, global_idx), meta_value in zip( + itertools.product(sorted(metadata.field_names), metadata.global_indexes), custom_meta, strict=False, ): - if global_idx not in per_field_custom_meta: - per_field_custom_meta[global_idx] = {} per_field_custom_meta[global_idx][field_name] = meta_value metadata.update_custom_meta(per_field_custom_meta) From 3dee20c403cd3849e441ebb4e2a128e590144b9e Mon Sep 17 00:00:00 2001 From: tianyi-ge Date: Thu, 22 Jan 2026 10:47:10 +0800 Subject: [PATCH 4/8] revert unnecessary changes Signed-off-by: tianyi-ge --- transfer_queue/metadata.py | 10 ++------- transfer_queue/storage/managers/base.py | 28 +++++-------------------- 2 files changed, 7 insertions(+), 31 deletions(-) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 3cfdce5..0e9fa61 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -193,11 +193,7 @@ def __post_init__(self): for idx, sample in enumerate(self.samples): object.__setattr__(sample, "_batch_index", idx) # Ensure batch_index is set correctly - object.__setattr__( - self, - "_global_indexes", - [sample.global_index for sample in self.samples], - ) + object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples]) # check if all samples have the same field names first_sample_field_names = sorted(self.samples[0].field_names) @@ -547,9 +543,7 @@ def _update_after_reorder(self) -> None: @classmethod def from_samples( - cls, - samples: SampleMeta | list[SampleMeta], - extra_info: Optional[dict[str, Any]] = None, + cls, samples: SampleMeta | list[SampleMeta], extra_info: Optional[dict[str, Any]] = None ) -> "BatchMeta": """ Create a BatchMeta from a single SampleMeta or a list of SampleMeta objects. 
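To make the ordering contract relied on in `put_data` above concrete: keys are generated as `field` x `global_index` with fields sorted, and the flat custom-meta list returned by the client is zipped against that same product. A standalone sketch with made-up field names and indices:

    import itertools

    field_names = ["text", "label"]   # illustrative only
    global_indexes = [0, 1, 2]

    keys = [f"{idx}@{field}" for field, idx in itertools.product(sorted(field_names), global_indexes)]
    # -> ['0@label', '1@label', '2@label', '0@text', '1@text', '2@text']

    # Suppose the storage client returned one metadata dict per key, in key order:
    custom_meta = [{"storage_key": k} for k in keys]

    per_field_custom_meta: dict[int, dict[str, dict]] = {idx: {} for idx in global_indexes}
    for (field, idx), meta in zip(
        itertools.product(sorted(field_names), global_indexes), custom_meta, strict=True
    ):
        per_field_custom_meta[idx][field] = meta

    assert per_field_custom_meta[1]["text"] == {"storage_key": "1@text"}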
diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 1127e37..fb98f91 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -28,12 +28,7 @@ from transfer_queue.metadata import BatchMeta from transfer_queue.storage.clients.factory import StorageClientFactory -from transfer_queue.utils.zmq_utils import ( - ZMQMessage, - ZMQRequestType, - ZMQServerInfo, - create_zmq_socket, -) +from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -330,19 +325,6 @@ def __init__(self, config: dict[str, Any]): super().__init__(config) self.storage_client = StorageClientFactory.create(client_name, config) - @staticmethod - def _generate_key(field_name: str, global_index: int) -> str: - """ - Generate a KV key in the format 'global_index@field_name'. - - Args: - field_name : Name of the field. - global_index : Global index of the sample. - Returns: - str: Generated key, e.g., '0@field_a' - """ - return f"{global_index}@{field_name}" - @staticmethod def _generate_keys(field_names: list[str], global_indexes: list[int]) -> list[str]: """ @@ -356,10 +338,7 @@ def _generate_keys(field_names: list[str], global_indexes: list[int]) -> list[st Returns: list[str]: List of keys, e.g., ['0@field_a', '1@field_a', '0@field_b', ...] """ - return [ - KVStorageManager._generate_key(field, index) - for field, index in itertools.product(sorted(field_names), global_indexes) - ] + return [f"{index}@{field}" for field, index in itertools.product(sorted(field_names), global_indexes)] @staticmethod def _generate_values(data: TensorDict) -> list[Tensor]: @@ -503,6 +482,9 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: # Use itertools.product to eliminate nested loops for global_idx in metadata.global_indexes: per_field_custom_meta[global_idx] = {} + + # TODO(tianyi): the order of custom meta is coupled with keys/values + # if _generate_keys or _generate_values changes, this will break for (field_name, global_idx), meta_value in zip( itertools.product(sorted(metadata.field_names), metadata.global_indexes), custom_meta, From 91297657c9bc7a94dff4792304f98a35cccf0597 Mon Sep 17 00:00:00 2001 From: tianyi-ge Date: Thu, 22 Jan 2026 11:39:36 +0800 Subject: [PATCH 5/8] fix minor code reviews Signed-off-by: tianyi-ge --- transfer_queue/controller.py | 21 ++++--------------- transfer_queue/storage/clients/base.py | 6 +++--- .../storage/clients/ray_storage_client.py | 1 + .../storage/clients/yuanrong_client.py | 2 +- 4 files changed, 9 insertions(+), 21 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index b8040d0..52fc5d5 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -410,8 +410,10 @@ def _update_field_metadata( if global_idx not in self.field_shapes: self.field_shapes[global_idx] = {} - self.field_dtypes[global_idx].update(dtype_value[i]) - self.field_shapes[global_idx].update(shape_value[i]) + if dtype_value[i] is not None: + self.field_dtypes[global_idx].update(dtype_value[i]) + if shape_value[i] is not None: + self.field_shapes[global_idx].update(shape_value[i]) if custom_meta: if len(global_indices) != len(custom_meta): @@ -426,21 +428,6 @@ def _update_field_metadata( self.field_custom_metas[global_idx] = {} self.field_custom_metas[global_idx].update(custom_meta_value[i]) - if custom_meta: - if 
len(global_indices) != len(custom_meta): - raise ValueError( - f"Length of global_indices ({len(global_indices)}) does not match " - f"length of custom_meta ({len(custom_meta)})" - ) - custom_meta_value = itemgetter(*global_indices)(custom_meta) if custom_meta else None - if not isinstance(custom_meta_value, tuple): - custom_meta_value = (custom_meta_value,) - for i, global_idx in enumerate(global_indices): - if global_idx not in self.field_custom_metas: - self.field_custom_metas[global_idx] = {} - if custom_meta_value is not None: - self.field_custom_metas[global_idx].update(custom_meta_value[i]) - # ==================== Consumption Status Interface ==================== def get_consumption_status(self, task_name: str) -> torch.Tensor: diff --git a/transfer_queue/storage/clients/base.py b/transfer_queue/storage/clients/base.py index 651d86e..db6239d 100644 --- a/transfer_queue/storage/clients/base.py +++ b/transfer_queue/storage/clients/base.py @@ -41,17 +41,17 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> li Retrieve values from the storage backend by key. Args: keys (list[str]): List of keys whose values should be retrieved. - shapes: Optional shape information for the expected tensors. The + shapes: Optional shape information for the expected values. The structure and interpretation of this argument are determined by the concrete storage backend implementation. - dtypes: Optional data type information for the expected tensors. + dtypes: Optional data type information for the expected values. The structure and interpretation of this argument are determined by the concrete storage backend implementation. custom_meta: Optional backend-specific metadata used to control or optimize the retrieval process. Its format is defined by the concrete storage backend implementation. Returns: - list[Tensor]: List of tensors retrieved from the storage backend, + list[Any]: List of values retrieved from the storage backend, in the same order as the provided keys. """ raise NotImplementedError("Subclasses must implement get") diff --git a/transfer_queue/storage/clients/ray_storage_client.py b/transfer_queue/storage/clients/ray_storage_client.py index c7b1e2f..5ffd023 100644 --- a/transfer_queue/storage/clients/ray_storage_client.py +++ b/transfer_queue/storage/clients/ray_storage_client.py @@ -58,6 +58,7 @@ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: ) ) ray.get(self.storage_actor.put_obj_ref.remote(keys, obj_refs)) + return None def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: """ diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index 280f4c9..c7da326 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -178,7 +178,7 @@ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: self._batch_put(keys, values) return None - def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> Optional[list[Any]]: + def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]: """Retrieves a batch of values from remote storage using expected metadata. NPU tensors are fetched via DsTensorClient using pre-allocated buffers. 
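On the metadata side, the custom meta handed back by a client ends up on BatchMeta, whose helpers behave as exercised in tests/test_metadata.py above: `update_custom_meta(None)` and `update_custom_meta({})` are no-ops, updates replace whole per-index entries rather than deep-merging, and `get_all_custom_meta` returns a deep copy. A usage sketch following those test fixtures (partition and field names are made up; the import path is assumed to match the classes used in the tests):

    import torch
    from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta  # path assumed

    fields = {"field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,))}
    batch = BatchMeta(samples=[SampleMeta(partition_id="p0", global_index=0, fields=fields)])

    batch.update_custom_meta({0: {"field_a": {"storage_key": "0@field_a"}}})
    batch.update_custom_meta(None)  # no-op
    batch.update_custom_meta({})    # no-op

    snapshot = batch.get_all_custom_meta()         # deep copy
    snapshot[0]["field_a"]["storage_key"] = "xxx"  # does not affect the batch
    assert batch.get_all_custom_meta()[0]["field_a"]["storage_key"] == "0@field_a"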
From a21ef480501cfd08e71e794e95872fa8d34d762c Mon Sep 17 00:00:00 2001 From: tianyi-ge Date: Thu, 22 Jan 2026 15:59:30 +0800 Subject: [PATCH 6/8] fix unit test reviews Signed-off-by: tianyi-ge --- tests/test_controller.py | 462 ----------------------- tests/test_controller_data_partitions.py | 178 +++------ transfer_queue/controller.py | 60 +-- transfer_queue/metadata.py | 1 + 4 files changed, 82 insertions(+), 619 deletions(-) diff --git a/tests/test_controller.py b/tests/test_controller.py index 35a95f1..360b7f0 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -382,465 +382,3 @@ def test_controller_clear_meta(self, ray_setup): assert set(partition_after.global_indexes) == set([4, 5, 7]) print("✓ Clear meta correct") - - -class TestCustomMeta: - """Test suite for custom_meta functionality in TransferQueueController""" - - def test_custom_meta_basic_storage_and_retrieval(self, ray_setup): - """Test basic custom_meta storage via update_production_status and retrieval via get_metadata""" - gbs = 4 - partition_id = "test_custom_meta_basic" - - tq_controller = TransferQueueController.remote() - - # Create metadata in insert mode - data_fields = ["prompt_ids", "attention_mask"] - metadata = ray.get( - tq_controller.get_metadata.remote( - data_fields=data_fields, - batch_size=gbs, - partition_id=partition_id, - mode="insert", - ) - ) - - assert metadata.global_indexes == list(range(gbs)) - - # Update production status with custom_meta - dtypes = {k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool"} for k in metadata.global_indexes} - shapes = {k: {"prompt_ids": (32,), "attention_mask": (32,)} for k in metadata.global_indexes} - custom_meta = { - k: {"prompt_ids": {"token_count": 100 + k}, "attention_mask": {"mask_ratio": 0.1 * k}} - for k in metadata.global_indexes - } - - success = ray.get( - tq_controller.update_production_status.remote( - partition_id=partition_id, - global_indexes=metadata.global_indexes, - field_names=metadata.field_names, - dtypes=dtypes, - shapes=shapes, - custom_meta=custom_meta, - ) - ) - assert success - - # Verify custom_meta is stored in partition - partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) - assert partition is not None - assert len(partition.field_custom_metas) == gbs - - for idx in metadata.global_indexes: - assert idx in partition.field_custom_metas - assert "prompt_ids" in partition.field_custom_metas[idx] - assert "attention_mask" in partition.field_custom_metas[idx] - assert partition.field_custom_metas[idx]["prompt_ids"]["token_count"] == 100 + idx - assert partition.field_custom_metas[idx]["attention_mask"]["mask_ratio"] == 0.1 * idx - - print("✓ Basic custom_meta storage correct") - - # Retrieve via get_metadata in fetch mode and verify custom_meta is in batch_meta - fetch_meta = ray.get( - tq_controller.get_metadata.remote( - data_fields=data_fields, - batch_size=gbs, - partition_id=partition_id, - mode="fetch", - task_name="test_task", - ) - ) - - assert fetch_meta is not None - custom_meta_retrieved = fetch_meta.get_all_custom_meta() - assert custom_meta_retrieved is not None - - for idx in metadata.global_indexes: - assert idx in custom_meta_retrieved - assert "prompt_ids" in custom_meta_retrieved[idx] - assert "attention_mask" in custom_meta_retrieved[idx] - - print("✓ Basic custom_meta retrieval via get_metadata correct") - - def test_custom_meta_with_partial_fields(self, ray_setup): - """Test custom_meta retrieval when only requesting subset of fields""" - gbs = 4 - partition_id 
= "test_custom_meta_partial" - - tq_controller = TransferQueueController.remote() - - # Create metadata in insert mode with multiple fields - data_fields = ["prompt_ids", "attention_mask", "labels"] - metadata = ray.get( - tq_controller.get_metadata.remote( - data_fields=data_fields, - batch_size=gbs, - partition_id=partition_id, - mode="insert", - ) - ) - - # Update production status with custom_meta for all fields - dtypes = { - k: {"prompt_ids": "torch.int64", "attention_mask": "torch.bool", "labels": "torch.int64"} - for k in metadata.global_indexes - } - shapes = {k: {"prompt_ids": (32,), "attention_mask": (32,), "labels": (32,)} for k in metadata.global_indexes} - custom_meta = { - k: { - "prompt_ids": {"meta_prompt": f"prompt_{k}"}, - "attention_mask": {"meta_mask": f"mask_{k}"}, - "labels": {"meta_label": f"label_{k}"}, - } - for k in metadata.global_indexes - } - - success = ray.get( - tq_controller.update_production_status.remote( - partition_id=partition_id, - global_indexes=metadata.global_indexes, - field_names=metadata.field_names, - dtypes=dtypes, - shapes=shapes, - custom_meta=custom_meta, - ) - ) - assert success - - # Fetch with only a subset of fields - subset_fields = ["prompt_ids", "labels"] - fetch_meta = ray.get( - tq_controller.get_metadata.remote( - data_fields=subset_fields, - batch_size=gbs, - partition_id=partition_id, - mode="fetch", - task_name="test_task", - ) - ) - - assert fetch_meta is not None - custom_meta_retrieved = fetch_meta.get_all_custom_meta() - assert custom_meta_retrieved is not None - - # Verify only requested fields are in custom_meta - for idx in metadata.global_indexes: - assert idx in custom_meta_retrieved - assert "prompt_ids" in custom_meta_retrieved[idx] - assert "labels" in custom_meta_retrieved[idx] - # attention_mask should not be in the custom_meta since it wasn't requested - assert "attention_mask" not in custom_meta_retrieved[idx] - - print("✓ Custom_meta with partial fields correct") - - def test_custom_meta_length_mismatch_returns_false(self, ray_setup): - """Test that custom_meta length mismatch with global_indices returns False""" - gbs = 4 - partition_id = "test_custom_meta_mismatch" - - tq_controller = TransferQueueController.remote() - - # Create metadata in insert mode - data_fields = ["prompt_ids"] - metadata = ray.get( - tq_controller.get_metadata.remote( - data_fields=data_fields, - batch_size=gbs, - partition_id=partition_id, - mode="insert", - ) - ) - - # Prepare mismatched custom_meta (fewer entries than global_indexes) - dtypes = {k: {"prompt_ids": "torch.int64"} for k in metadata.global_indexes} - shapes = {k: {"prompt_ids": (32,)} for k in metadata.global_indexes} - # Only provide custom_meta for half the samples - custom_meta = {k: {"prompt_ids": {"meta": k}} for k in metadata.global_indexes[:2]} - - # The method should return False when there's a length mismatch - success = ray.get( - tq_controller.update_production_status.remote( - partition_id=partition_id, - global_indexes=metadata.global_indexes, - field_names=metadata.field_names, - dtypes=dtypes, - shapes=shapes, - custom_meta=custom_meta, - ) - ) - assert success is False, "Expected update_production_status to return False for length mismatch" - - print("✓ Custom_meta length mismatch error handling correct") - - def test_custom_meta_none_does_not_store(self, ray_setup): - """Test that passing None for custom_meta doesn't create custom_meta entries""" - gbs = 4 - partition_id = "test_custom_meta_none" - - tq_controller = TransferQueueController.remote() - - 
# Create metadata in insert mode - data_fields = ["prompt_ids"] - metadata = ray.get( - tq_controller.get_metadata.remote( - data_fields=data_fields, - batch_size=gbs, - partition_id=partition_id, - mode="insert", - ) - ) - - # Update production status without custom_meta (None) - dtypes = {k: {"prompt_ids": "torch.int64"} for k in metadata.global_indexes} - shapes = {k: {"prompt_ids": (32,)} for k in metadata.global_indexes} - - success = ray.get( - tq_controller.update_production_status.remote( - partition_id=partition_id, - global_indexes=metadata.global_indexes, - field_names=metadata.field_names, - dtypes=dtypes, - shapes=shapes, - custom_meta=None, - ) - ) - assert success - - # Verify no custom_meta is stored - partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) - assert partition is not None - assert len(partition.field_custom_metas) == 0 - - print("✓ Custom_meta None handling correct") - - def test_custom_meta_preserved_after_partial_clear(self, ray_setup): - """Test that custom_meta for non-cleared samples is preserved after clear_meta""" - gbs = 4 - partition_id = "test_custom_meta_clear" - - tq_controller = TransferQueueController.remote() - - # Create metadata in insert mode - data_fields = ["prompt_ids"] - metadata = ray.get( - tq_controller.get_metadata.remote( - data_fields=data_fields, - batch_size=gbs, - partition_id=partition_id, - mode="insert", - ) - ) - - # Update production status with custom_meta - dtypes = {k: {"prompt_ids": "torch.int64"} for k in metadata.global_indexes} - shapes = {k: {"prompt_ids": (32,)} for k in metadata.global_indexes} - custom_meta = {k: {"prompt_ids": {"sample_id": k * 10}} for k in metadata.global_indexes} - - success = ray.get( - tq_controller.update_production_status.remote( - partition_id=partition_id, - global_indexes=metadata.global_indexes, - field_names=metadata.field_names, - dtypes=dtypes, - shapes=shapes, - custom_meta=custom_meta, - ) - ) - assert success - - # Clear only first 2 samples - global_indexes_to_clear = [0, 1] - partition_ids_to_clear = [partition_id] * len(global_indexes_to_clear) - - ray.get( - tq_controller.clear_meta.remote( - global_indexes=global_indexes_to_clear, - partition_ids=partition_ids_to_clear, - ) - ) - - # Verify custom_meta is cleared for cleared samples and preserved for others - partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) - assert partition is not None - - # Cleared samples should not have custom_meta - assert 0 not in partition.field_custom_metas - assert 1 not in partition.field_custom_metas - - # Non-cleared samples should still have custom_meta - assert 2 in partition.field_custom_metas - assert 3 in partition.field_custom_metas - assert partition.field_custom_metas[2]["prompt_ids"]["sample_id"] == 20 - assert partition.field_custom_metas[3]["prompt_ids"]["sample_id"] == 30 - - print("✓ Custom_meta preserved after partial clear correct") - - def test_custom_meta_update_merges_values(self, ray_setup): - """Test that updating custom_meta for the same sample merges values""" - gbs = 2 - partition_id = "test_custom_meta_merge" - - tq_controller = TransferQueueController.remote() - - # Create metadata in insert mode with first field - data_fields_1 = ["prompt_ids"] - metadata_1 = ray.get( - tq_controller.get_metadata.remote( - data_fields=data_fields_1, - batch_size=gbs, - partition_id=partition_id, - mode="insert", - ) - ) - - # First update with custom_meta for prompt_ids - dtypes_1 = {k: {"prompt_ids": "torch.int64"} for k in 
metadata_1.global_indexes} - shapes_1 = {k: {"prompt_ids": (32,)} for k in metadata_1.global_indexes} - custom_meta_1 = {k: {"prompt_ids": {"first_update": True}} for k in metadata_1.global_indexes} - - success = ray.get( - tq_controller.update_production_status.remote( - partition_id=partition_id, - global_indexes=metadata_1.global_indexes, - field_names=metadata_1.field_names, - dtypes=dtypes_1, - shapes=shapes_1, - custom_meta=custom_meta_1, - ) - ) - assert success - - # Second update with new field and its custom_meta - data_fields_2 = ["attention_mask"] - dtypes_2 = {k: {"attention_mask": "torch.bool"} for k in metadata_1.global_indexes} - shapes_2 = {k: {"attention_mask": (32,)} for k in metadata_1.global_indexes} - custom_meta_2 = {k: {"attention_mask": {"second_update": True}} for k in metadata_1.global_indexes} - - success = ray.get( - tq_controller.update_production_status.remote( - partition_id=partition_id, - global_indexes=metadata_1.global_indexes, - field_names=data_fields_2, - dtypes=dtypes_2, - shapes=shapes_2, - custom_meta=custom_meta_2, - ) - ) - assert success - - # Verify both custom_meta entries are present (merged) - partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) - assert partition is not None - - for idx in metadata_1.global_indexes: - assert idx in partition.field_custom_metas - assert "prompt_ids" in partition.field_custom_metas[idx] - assert "attention_mask" in partition.field_custom_metas[idx] - assert partition.field_custom_metas[idx]["prompt_ids"]["first_update"] is True - assert partition.field_custom_metas[idx]["attention_mask"]["second_update"] is True - - print("✓ Custom_meta merge on update correct") - - def test_custom_meta_with_complex_nested_data(self, ray_setup): - """Test custom_meta with complex nested data structures""" - gbs = 2 - partition_id = "test_custom_meta_complex" - - tq_controller = TransferQueueController.remote() - - # Create metadata in insert mode - data_fields = ["prompt_ids"] - metadata = ray.get( - tq_controller.get_metadata.remote( - data_fields=data_fields, - batch_size=gbs, - partition_id=partition_id, - mode="insert", - ) - ) - - # Create complex nested custom_meta - dtypes = {k: {"prompt_ids": "torch.int64"} for k in metadata.global_indexes} - shapes = {k: {"prompt_ids": (32,)} for k in metadata.global_indexes} - custom_meta = { - k: { - "prompt_ids": { - "nested_dict": {"level1": {"level2": {"value": k}}}, - "list_data": [1, 2, 3, k], - "mixed_types": {"string": "test", "number": 42, "float": 3.14, "bool": True}, - } - } - for k in metadata.global_indexes - } - - success = ray.get( - tq_controller.update_production_status.remote( - partition_id=partition_id, - global_indexes=metadata.global_indexes, - field_names=metadata.field_names, - dtypes=dtypes, - shapes=shapes, - custom_meta=custom_meta, - ) - ) - assert success - - # Verify complex nested data is preserved - partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) - assert partition is not None - - for idx in metadata.global_indexes: - stored_meta = partition.field_custom_metas[idx]["prompt_ids"] - assert stored_meta["nested_dict"]["level1"]["level2"]["value"] == idx - assert stored_meta["list_data"] == [1, 2, 3, idx] - assert stored_meta["mixed_types"]["string"] == "test" - assert stored_meta["mixed_types"]["number"] == 42 - assert stored_meta["mixed_types"]["float"] == 3.14 - assert stored_meta["mixed_types"]["bool"] is True - - print("✓ Complex nested custom_meta correct") - - def 
test_custom_meta_cleared_on_partition_clear(self, ray_setup): - """Test that custom_meta is fully cleared when partition is cleared""" - gbs = 4 - partition_id = "test_custom_meta_partition_clear" - - tq_controller = TransferQueueController.remote() - - # Create metadata in insert mode - data_fields = ["prompt_ids"] - metadata = ray.get( - tq_controller.get_metadata.remote( - data_fields=data_fields, - batch_size=gbs, - partition_id=partition_id, - mode="insert", - ) - ) - - # Update production status with custom_meta - dtypes = {k: {"prompt_ids": "torch.int64"} for k in metadata.global_indexes} - shapes = {k: {"prompt_ids": (32,)} for k in metadata.global_indexes} - custom_meta = {k: {"prompt_ids": {"data": k}} for k in metadata.global_indexes} - - success = ray.get( - tq_controller.update_production_status.remote( - partition_id=partition_id, - global_indexes=metadata.global_indexes, - field_names=metadata.field_names, - dtypes=dtypes, - shapes=shapes, - custom_meta=custom_meta, - ) - ) - assert success - - # Clear the entire partition - ray.get(tq_controller.clear_partition.remote(partition_id)) - - # Verify partition is gone - partition = ray.get(tq_controller.get_partition_snapshot.remote(partition_id)) - assert partition is None - - print("✓ Custom_meta cleared on partition clear correct") diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index 49cb715..ee1092a 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -19,6 +19,8 @@ import time from pathlib import Path +import pytest + parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) @@ -446,30 +448,23 @@ def test_performance_characteristics(): def test_custom_meta_in_data_partition_status(): - """Test custom_meta functionality in DataPartitionStatus.""" - print("Testing custom_meta in DataPartitionStatus...") + """Simplified tests for custom_meta functionality in DataPartitionStatus.""" + + print("Testing simplified custom_meta in DataPartitionStatus...") from transfer_queue.controller import DataPartitionStatus partition = DataPartitionStatus(partition_id="custom_meta_test") - # Test 1: Basic custom_meta storage via update_production_status + # Basic custom_meta storage via update_production_status global_indices = [0, 1, 2] field_names = ["input_ids", "attention_mask"] - dtypes = { - 0: {"input_ids": "torch.int32", "attention_mask": "torch.bool"}, - 1: {"input_ids": "torch.int32", "attention_mask": "torch.bool"}, - 2: {"input_ids": "torch.int32", "attention_mask": "torch.bool"}, - } - shapes = { - 0: {"input_ids": (512,), "attention_mask": (512,)}, - 1: {"input_ids": (512,), "attention_mask": (512,)}, - 2: {"input_ids": (512,), "attention_mask": (512,)}, - } + dtypes = {i: {"input_ids": "torch.int32", "attention_mask": "torch.bool"} for i in global_indices} + shapes = {i: {"input_ids": (512,), "attention_mask": (512,)} for i in global_indices} custom_meta = { - 0: {"input_ids": {"token_count": 100}, "attention_mask": {"mask_ratio": 0.1}}, - 1: {"input_ids": {"token_count": 200}, "attention_mask": {"mask_ratio": 0.2}}, - 2: {"input_ids": {"token_count": 300}, "attention_mask": {"mask_ratio": 0.3}}, + 0: {"input_ids": {"token_count": 100}}, + 1: {"attention_mask": {"mask_ratio": 0.2}}, + 2: {"input_ids": {"token_count": 300}}, } success = partition.update_production_status( @@ -481,138 +476,53 @@ def test_custom_meta_in_data_partition_status(): ) assert success - assert len(partition.field_custom_metas) == 
3 - # Verify custom_meta is stored correctly + # Verify some stored values assert partition.field_custom_metas[0]["input_ids"]["token_count"] == 100 assert partition.field_custom_metas[1]["attention_mask"]["mask_ratio"] == 0.2 - assert partition.field_custom_metas[2]["input_ids"]["token_count"] == 300 - - print("✓ Basic custom_meta storage works") - - # Test 2: get_field_custom_meta retrieval - retrieved_meta = partition.get_field_custom_meta([0, 1, 2], ["input_ids", "attention_mask"]) - - assert 0 in retrieved_meta - assert 1 in retrieved_meta - assert 2 in retrieved_meta - assert retrieved_meta[0]["input_ids"]["token_count"] == 100 - assert retrieved_meta[1]["attention_mask"]["mask_ratio"] == 0.2 - - print("✓ get_field_custom_meta retrieval works") - - # Test 3: get_field_custom_meta with partial field filter - partial_meta = partition.get_field_custom_meta([0, 1], ["input_ids"]) - - assert 0 in partial_meta - assert 1 in partial_meta - assert "input_ids" in partial_meta[0] - assert "attention_mask" not in partial_meta[0] # Should not include non-requested fields - - print("✓ get_field_custom_meta with partial fields works") - - # Test 4: get_field_custom_meta with non-existent global_index - empty_meta = partition.get_field_custom_meta([999], ["input_ids"]) - assert 999 not in empty_meta # Should not include non-existent indices - - print("✓ get_field_custom_meta handles non-existent indices correctly") - - # Test 5: custom_meta update (merge) on same global_index - additional_custom_meta = { - 0: {"new_field": {"new_key": "new_value"}}, - } - success = partition.update_production_status( - global_indices=[0], - field_names=["new_field"], - dtypes={0: {"new_field": "torch.float32"}}, - shapes={0: {"new_field": (64,)}}, - custom_meta=additional_custom_meta, - ) - - assert success - # Original custom_meta should be preserved - assert partition.field_custom_metas[0]["input_ids"]["token_count"] == 100 - # New custom_meta should be merged - assert partition.field_custom_metas[0]["new_field"]["new_key"] == "new_value" - print("✓ Custom_meta merge on update works") + # Retrieval via helper for a subset of fields + retrieved = partition.get_field_custom_meta([0, 1], ["input_ids", "attention_mask"]) + assert 0 in retrieved and "input_ids" in retrieved[0] + assert 1 in retrieved and "attention_mask" in retrieved[1] - # Test 6: custom_meta cleared on clear_data + # Clearing a sample should remove its custom_meta partition.clear_data([0], clear_consumption=True) - assert 0 not in partition.field_custom_metas - assert 1 in partition.field_custom_metas # Other samples should remain - assert 2 in partition.field_custom_metas - - print("✓ Custom_meta cleared on clear_data works") - - # Test 7: custom_meta None does not create entries - partition2 = DataPartitionStatus(partition_id="custom_meta_test_2") - success = partition2.update_production_status( - global_indices=[0, 1], - field_names=["field1"], - dtypes={0: {"field1": "torch.int32"}, 1: {"field1": "torch.int32"}}, - shapes={0: {"field1": (32,)}, 1: {"field1": (32,)}}, - custom_meta=None, - ) - assert success - assert len(partition2.field_custom_metas) == 0 + print("✓ Custom_meta tests passed") - print("✓ Custom_meta None handling works") - - # Test 8: custom_meta length mismatch raises ValueError - partition3 = DataPartitionStatus(partition_id="custom_meta_test_3") - mismatched_custom_meta = { - 0: {"field1": {"key": "value"}}, - # Missing entries for 1 and 2 - } - success = partition3.update_production_status( - global_indices=[0, 1, 2], - 
field_names=["field1"], - dtypes={0: {"field1": "torch.int32"}, 1: {"field1": "torch.int32"}, 2: {"field1": "torch.int32"}}, - shapes={0: {"field1": (32,)}, 1: {"field1": (32,)}, 2: {"field1": (32,)}}, - custom_meta=mismatched_custom_meta, - ) - # Should return False due to length mismatch (caught by exception handler) - assert success is False +def test_update_field_metadata_variants(): + """Test _update_field_metadata handles dtypes/shapes/custom_meta being optional and merging.""" + from transfer_queue.controller import DataPartitionStatus - print("✓ Custom_meta length mismatch error handling works") + partition = DataPartitionStatus(partition_id="update_meta_test") - # Test 9: Complex nested custom_meta - partition4 = DataPartitionStatus(partition_id="custom_meta_test_4") - complex_custom_meta = { - 0: { - "field1": { - "nested": {"level1": {"level2": {"value": 42}}}, - "list_data": [1, 2, 3], - "mixed": {"str": "test", "int": 100, "float": 3.14, "bool": True}, - } - }, - } - success = partition4.update_production_status( - global_indices=[0], - field_names=["field1"], - dtypes={0: {"field1": "torch.int32"}}, - shapes={0: {"field1": (32,)}}, - custom_meta=complex_custom_meta, - ) + # Only dtypes provided + global_indices = [0, 1] + dtypes = {0: {"f1": "torch.int32"}, 1: {"f1": "torch.bool"}} - assert success - stored_meta = partition4.field_custom_metas[0]["field1"] - assert stored_meta["nested"]["level1"]["level2"]["value"] == 42 - assert stored_meta["list_data"] == [1, 2, 3] - assert stored_meta["mixed"]["str"] == "test" - assert stored_meta["mixed"]["bool"] is True + partition._update_field_metadata(global_indices, dtypes, shapes=None, custom_meta=None) + assert partition.field_dtypes[0]["f1"] == "torch.int32" + assert partition.field_dtypes[1]["f1"] == "torch.bool" + assert partition.field_shapes == {} + assert partition.field_custom_metas == {} - print("✓ Complex nested custom_meta storage works") + # Only shapes provided for a new index + partition._update_field_metadata([2], dtypes=None, shapes={2: {"f2": (16,)}}, custom_meta=None) + assert partition.field_shapes[2]["f2"] == (16,) - # Test 10: custom_meta preserved in snapshot - snapshot = partition4.to_snapshot() - assert 0 in snapshot.field_custom_metas - assert snapshot.field_custom_metas[0]["field1"]["nested"]["level1"]["level2"]["value"] == 42 + # Only custom_meta provided and merged with existing entries + partition._update_field_metadata([2], dtypes=None, shapes=None, custom_meta={2: {"f2": {"meta": 1}}}) + assert 2 in partition.field_custom_metas + assert partition.field_custom_metas[2]["f2"]["meta"] == 1 - print("✓ Custom_meta preserved in snapshot") + # Merging dtypes on an existing index should preserve previous keys + partition._update_field_metadata([0], dtypes={0: {"f2": "torch.float32"}}, shapes=None, custom_meta=None) + assert partition.field_dtypes[0]["f1"] == "torch.int32" + assert partition.field_dtypes[0]["f2"] == "torch.float32" - print("Custom_meta in DataPartitionStatus tests passed!\n") + # Length mismatch should raise ValueError when provided mapping lengths differ from global_indices + with pytest.raises(ValueError): + partition._update_field_metadata([0, 1, 2], dtypes={0: {}}, shapes=None, custom_meta=None) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 52fc5d5..42970d3 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -388,42 +388,56 @@ def _update_field_metadata( custom_meta: Optional[dict[int, dict[str, Any]]], ): """Update field dtype 
and shape metadata.""" - if not global_indices or not dtypes or not shapes: + if not global_indices: return - if len(global_indices) != len(dtypes): + # Validate lengths only for provided mappings + if dtypes and len(global_indices) != len(dtypes): raise ValueError(f"`global_indices` {len(global_indices)} and `dtypes` {len(dtypes)} length mismatch.") - if len(global_indices) != len(shapes): + if shapes and len(global_indices) != len(shapes): raise ValueError(f"`global_indices` {len(global_indices)} and `shapes` {len(shapes)} length mismatch.") + if custom_meta and len(global_indices) != len(custom_meta): + raise ValueError( + f"`global_indices` {len(global_indices)} and `custom_meta` {len(custom_meta)} length mismatch." + ) - dtype_value = itemgetter(*global_indices)(dtypes) - shape_value = itemgetter(*global_indices)(shapes) + # Extract values for each provided mapping; if a mapping is absent, use Nones + if dtypes: + dtype_value = itemgetter(*global_indices)(dtypes) + if not isinstance(dtype_value, tuple): + dtype_value = (dtype_value,) + else: + dtype_value = tuple([None] * len(global_indices)) - if not isinstance(dtype_value, tuple): - dtype_value = (dtype_value,) - if not isinstance(shape_value, tuple): - shape_value = (shape_value,) + if shapes: + shape_value = itemgetter(*global_indices)(shapes) + if not isinstance(shape_value, tuple): + shape_value = (shape_value,) + else: + shape_value = tuple([None] * len(global_indices)) - for i, global_idx in enumerate(global_indices): - if global_idx not in self.field_dtypes: - self.field_dtypes[global_idx] = {} - if global_idx not in self.field_shapes: - self.field_shapes[global_idx] = {} + if custom_meta: + custom_meta_value = itemgetter(*global_indices)(custom_meta) + if not isinstance(custom_meta_value, tuple): + custom_meta_value = (custom_meta_value,) + else: + custom_meta_value = tuple([None] * len(global_indices)) + for i, global_idx in enumerate(global_indices): + # Only create and update dtype mapping if a dtype value was provided if dtype_value[i] is not None: + if global_idx not in self.field_dtypes: + self.field_dtypes[global_idx] = {} self.field_dtypes[global_idx].update(dtype_value[i]) + + # Only create and update shape mapping if a shape value was provided if shape_value[i] is not None: + if global_idx not in self.field_shapes: + self.field_shapes[global_idx] = {} self.field_shapes[global_idx].update(shape_value[i]) - if custom_meta: - if len(global_indices) != len(custom_meta): - raise ValueError( - f"`global_indices` {len(global_indices)} and `custom_meta` {len(custom_meta)} length mismatch." 
- ) - custom_meta_value = itemgetter(*global_indices)(custom_meta) if custom_meta else None - if not isinstance(custom_meta_value, tuple): - custom_meta_value = (custom_meta_value,) - for i, global_idx in enumerate(global_indices): + # Only create and update custom_meta mapping if a custom_meta value was provided + if custom_meta_value[i] is not None: if global_idx not in self.field_custom_metas: self.field_custom_metas[global_idx] = {} self.field_custom_metas[global_idx].update(custom_meta_value[i]) diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index 0e9fa61..a75e6cd 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -125,6 +125,7 @@ def select_fields(self, field_names: list[str]) -> "SampleMeta": selected_fields = {name: self.fields[name] for name in field_names if name in self.fields} # construct new SampleMeta instance + # TODO(tianyi): move custom_meta to FieldMeta level selected_sample_meta = SampleMeta( fields=selected_fields, partition_id=self.partition_id, From 92c0305889ca5cd2ba17905a16e4b6bcbf4b92be Mon Sep 17 00:00:00 2001 From: tianyi-ge Date: Thu, 22 Jan 2026 16:37:41 +0800 Subject: [PATCH 7/8] fix minor reviews Signed-off-by: tianyi-ge --- transfer_queue/storage/managers/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index fb98f91..e2d5f1b 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -484,11 +484,10 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: per_field_custom_meta[global_idx] = {} # TODO(tianyi): the order of custom meta is coupled with keys/values - # if _generate_keys or _generate_values changes, this will break for (field_name, global_idx), meta_value in zip( itertools.product(sorted(metadata.field_names), metadata.global_indexes), custom_meta, - strict=False, + strict=True, ): per_field_custom_meta[global_idx][field_name] = meta_value metadata.update_custom_meta(per_field_custom_meta) From 3a16244cdd1152393a3fb42efb5558e90419271e Mon Sep 17 00:00:00 2001 From: tianyi-ge Date: Thu, 22 Jan 2026 20:18:38 +0800 Subject: [PATCH 8/8] fix type check Signed-off-by: tianyi-ge --- transfer_queue/controller.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 42970d3..fae9e96 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -577,8 +577,8 @@ def get_field_shape(self, global_index: int, field_name: str) -> Optional[Any]: """Get shape for a specific sample and field.""" return self.field_shapes.get(global_index, {}).get(field_name) - def get_field_custom_meta(self, global_indices: list[int], field_names: list[str]) -> Optional[Any]: - """Get custom_meta for a specific sample and field.""" + def get_field_custom_meta(self, global_indices: list[int], field_names: list[str]) -> dict[int, dict[str, Any]]: + """Get custom_meta for multiple samples and fields.""" return { idx: {f: v for f, v in self.field_custom_metas[idx].items() if f in field_names} for idx in global_indices