diff --git a/tests/test_controller_data_partitions.py b/tests/test_controller_data_partitions.py index e018fb3..ee1092a 100644 --- a/tests/test_controller_data_partitions.py +++ b/tests/test_controller_data_partitions.py @@ -19,6 +19,8 @@ import time from pathlib import Path +import pytest + parent_dir = Path(__file__).resolve().parent.parent sys.path.append(str(parent_dir)) @@ -63,6 +65,7 @@ def test_data_partition_status(): 1: {"input_ids": (512,), "attention_mask": (512,)}, 2: {"input_ids": (512,), "attention_mask": (512,)}, }, + custom_meta=None, ) assert success @@ -172,6 +175,7 @@ def test_dynamic_expansion_scenarios(): 5: {"field_1": (32,)}, 10: {"field_1": (32,)}, }, + custom_meta=None, ) assert partition.total_samples_num == 3 assert partition.allocated_samples_num >= 11 # Should accommodate index 10 @@ -180,7 +184,7 @@ def test_dynamic_expansion_scenarios(): # Scenario 2: Adding many fields dynamically for i in range(15): partition.update_production_status( - [0], [f"field_{i}"], {0: {f"field_{i}": "torch.bool"}}, {0: {f"field_{i}": (32,)}} + [0], [f"field_{i}"], {0: {f"field_{i}": "torch.bool"}}, {0: {f"field_{i}": (32,)}}, None ) assert partition.total_fields_num == 16 # Original + 15 new fields @@ -222,7 +226,7 @@ def test_data_partition_status_advanced(): # Add data to trigger expansion dtypes = {i: {f"dynamic_field_{s}": "torch.bool" for s in ["a", "b", "c"]} for i in range(5)} shapes = {i: {f"dynamic_field_{s}": (32,) for s in ["a", "b", "c"]} for i in range(5)} - partition.update_production_status([0, 1, 2, 3, 4], ["field_a", "field_b", "field_c"], dtypes, shapes) + partition.update_production_status([0, 1, 2, 3, 4], ["field_a", "field_b", "field_c"], dtypes, shapes, None) # Properties should reflect current state assert partition.total_samples_num >= 5 # At least 5 samples @@ -253,7 +257,7 @@ def test_data_partition_status_advanced(): 11: {"field_d": (32,)}, 12: {"field_d": (32,)}, } - partition.update_production_status([10, 11, 12], ["field_d"], dtypes, shapes) # Triggers sample expansion + partition.update_production_status([10, 11, 12], ["field_d"], dtypes, shapes, None) # Triggers sample expansion expanded_consumption = partition.get_consumption_status(task_name) assert expanded_consumption[0] == 1 # Preserved assert expanded_consumption[1] == 1 # Preserved @@ -265,13 +269,13 @@ def test_data_partition_status_advanced(): # Start with some fields dtypes = {0: {"initial_field": "torch.bool"}} shapes = {0: {"field_d": (32,)}} - partition.update_production_status([0], ["initial_field"], dtypes, shapes) + partition.update_production_status([0], ["initial_field"], dtypes, shapes, None) # Add many fields to trigger column expansion new_fields = [f"dynamic_field_{i}" for i in range(20)] dtypes = {1: {f"dynamic_field_{i}": "torch.bool" for i in range(20)}} shapes = {1: {f"dynamic_field_{i}": (32,) for i in range(20)}} - partition.update_production_status([1], new_fields, dtypes, shapes) + partition.update_production_status([1], new_fields, dtypes, shapes, None) # Verify all fields are registered and accessible assert "initial_field" in partition.field_name_mapping @@ -441,3 +445,84 @@ def test_performance_characteristics(): print("✓ Memory usage patterns reasonable") print("Performance characteristics tests passed!\n") + + +def test_custom_meta_in_data_partition_status(): + """Simplified tests for custom_meta functionality in DataPartitionStatus.""" + + print("Testing simplified custom_meta in DataPartitionStatus...") + + from transfer_queue.controller import 
DataPartitionStatus + + partition = DataPartitionStatus(partition_id="custom_meta_test") + + # Basic custom_meta storage via update_production_status + global_indices = [0, 1, 2] + field_names = ["input_ids", "attention_mask"] + dtypes = {i: {"input_ids": "torch.int32", "attention_mask": "torch.bool"} for i in global_indices} + shapes = {i: {"input_ids": (512,), "attention_mask": (512,)} for i in global_indices} + custom_meta = { + 0: {"input_ids": {"token_count": 100}}, + 1: {"attention_mask": {"mask_ratio": 0.2}}, + 2: {"input_ids": {"token_count": 300}}, + } + + success = partition.update_production_status( + global_indices=global_indices, + field_names=field_names, + dtypes=dtypes, + shapes=shapes, + custom_meta=custom_meta, + ) + + assert success + + # Verify some stored values + assert partition.field_custom_metas[0]["input_ids"]["token_count"] == 100 + assert partition.field_custom_metas[1]["attention_mask"]["mask_ratio"] == 0.2 + + # Retrieval via helper for a subset of fields + retrieved = partition.get_field_custom_meta([0, 1], ["input_ids", "attention_mask"]) + assert 0 in retrieved and "input_ids" in retrieved[0] + assert 1 in retrieved and "attention_mask" in retrieved[1] + + # Clearing a sample should remove its custom_meta + partition.clear_data([0], clear_consumption=True) + assert 0 not in partition.field_custom_metas + + print("✓ Custom_meta tests passed") + + +def test_update_field_metadata_variants(): + """Test _update_field_metadata handles dtypes/shapes/custom_meta being optional and merging.""" + from transfer_queue.controller import DataPartitionStatus + + partition = DataPartitionStatus(partition_id="update_meta_test") + + # Only dtypes provided + global_indices = [0, 1] + dtypes = {0: {"f1": "torch.int32"}, 1: {"f1": "torch.bool"}} + + partition._update_field_metadata(global_indices, dtypes, shapes=None, custom_meta=None) + assert partition.field_dtypes[0]["f1"] == "torch.int32" + assert partition.field_dtypes[1]["f1"] == "torch.bool" + assert partition.field_shapes == {} + assert partition.field_custom_metas == {} + + # Only shapes provided for a new index + partition._update_field_metadata([2], dtypes=None, shapes={2: {"f2": (16,)}}, custom_meta=None) + assert partition.field_shapes[2]["f2"] == (16,) + + # Only custom_meta provided and merged with existing entries + partition._update_field_metadata([2], dtypes=None, shapes=None, custom_meta={2: {"f2": {"meta": 1}}}) + assert 2 in partition.field_custom_metas + assert partition.field_custom_metas[2]["f2"]["meta"] == 1 + + # Merging dtypes on an existing index should preserve previous keys + partition._update_field_metadata([0], dtypes={0: {"f2": "torch.float32"}}, shapes=None, custom_meta=None) + assert partition.field_dtypes[0]["f1"] == "torch.int32" + assert partition.field_dtypes[0]["f2"] == "torch.float32" + + # Length mismatch should raise ValueError when provided mapping lengths differ from global_indices + with pytest.raises(ValueError): + partition._update_field_metadata([0, 1, 2], dtypes={0: {}}, shapes=None, custom_meta=None) diff --git a/tests/test_kv_storage_manager.py b/tests/test_kv_storage_manager.py index 3cfe168..c982227 100644 --- a/tests/test_kv_storage_manager.py +++ b/tests/test_kv_storage_manager.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import sys import unittest from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch import torch from tensordict import TensorDict @@ -97,6 +99,232 @@ def test_merge_kv_to_tensordict(self): self.assertEqual(reconstructed.batch_size, torch.Size([3])) + def test_get_shape_type_custom_meta_list_without_custom_meta(self): + """Test _get_shape_type_custom_meta_list returns correct shapes and dtypes without custom_meta.""" + shapes, dtypes, custom_meta_list = KVStorageManager._get_shape_type_custom_meta_list(self.metadata) + + # Expected order: sorted by field name (label, mask, text), then by global_index order + # 3 fields * 3 samples = 9 entries + self.assertEqual(len(shapes), 9) + self.assertEqual(len(dtypes), 9) + self.assertEqual(len(custom_meta_list), 9) + + # Check shapes - order is label, mask, text (sorted alphabetically) + # label shapes: [()]*3, mask shapes: [(1,)]*3, text shapes: [(2,)]*3 + expected_shapes = [ + torch.Size([]), # label[0] + torch.Size([]), # label[1] + torch.Size([]), # label[2] + torch.Size([1]), # mask[0] + torch.Size([1]), # mask[1] + torch.Size([1]), # mask[2] + torch.Size([2]), # text[0] + torch.Size([2]), # text[1] + torch.Size([2]), # text[2] + ] + self.assertEqual(shapes, expected_shapes) + + # All dtypes should be torch.int64 + for dtype in dtypes: + self.assertEqual(dtype, torch.int64) + + # No custom_meta provided, so all should be None + for meta in custom_meta_list: + self.assertIsNone(meta) + + def test_get_shape_type_custom_meta_list_with_custom_meta(self): + """Test _get_shape_type_custom_meta_list returns correct custom_meta when provided.""" + # Add custom_meta to metadata + custom_meta = { + 8: {"text": {"key1": "value1"}, "label": {"key2": "value2"}, "mask": {"key3": "value3"}}, + 9: {"text": {"key4": "value4"}, "label": {"key5": "value5"}, "mask": {"key6": "value6"}}, + 10: {"text": {"key7": "value7"}, "label": {"key8": "value8"}, "mask": {"key9": "value9"}}, + } + self.metadata.update_custom_meta(custom_meta) + + shapes, dtypes, custom_meta_list = KVStorageManager._get_shape_type_custom_meta_list(self.metadata) + + # Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index + expected_custom_meta = [ + {"key2": "value2"}, # label, global_index=8 + {"key5": "value5"}, # label, global_index=9 + {"key8": "value8"}, # label, global_index=10 + {"key3": "value3"}, # mask, global_index=8 + {"key6": "value6"}, # mask, global_index=9 + {"key9": "value9"}, # mask, global_index=10 + {"key1": "value1"}, # text, global_index=8 + {"key4": "value4"}, # text, global_index=9 + {"key7": "value7"}, # text, global_index=10 + ] + self.assertEqual(custom_meta_list, expected_custom_meta) + + def test_get_shape_type_custom_meta_list_with_partial_custom_meta(self): + """Test _get_shape_type_custom_meta_list handles partial custom_meta correctly.""" + # Add custom_meta only for some global_indexes and fields + custom_meta = { + 8: {"text": {"key1": "value1"}}, # Only text field + # global_index 9 has no custom_meta + 10: {"label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # label and mask only + } + self.metadata.update_custom_meta(custom_meta) + + shapes, dtypes, custom_meta_list = KVStorageManager._get_shape_type_custom_meta_list(self.metadata) + + # Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index + expected_custom_meta = [ + None, # label, global_index=8 (not in custom_meta) + None, # label, global_index=9 (not in custom_meta) + {"key2": "value2"}, # 
label, global_index=10 + None, # mask, global_index=8 (not in custom_meta) + None, # mask, global_index=9 (not in custom_meta) + {"key3": "value3"}, # mask, global_index=10 + {"key1": "value1"}, # text, global_index=8 + None, # text, global_index=9 (not in custom_meta) + None, # text, global_index=10 (not in custom_meta for text) + ] + self.assertEqual(custom_meta_list, expected_custom_meta) + + +class TestPutDataWithCustomMeta(unittest.TestCase): + """Test put_data with custom_meta functionality.""" + + def setUp(self): + """Set up test fixtures for put_data tests.""" + self.field_names = ["text", "label"] + self.global_indexes = [0, 1, 2] + + # Create test data + self.data = TensorDict( + { + "text": torch.tensor([[1, 2], [3, 4], [5, 6]]), + "label": torch.tensor([0, 1, 2]), + }, + batch_size=3, + ) + + # Create metadata without production status set (for insert mode) + samples = [] + for sample_id in range(self.data.batch_size[0]): + fields_dict = {} + for field_name in self.data.keys(): + tensor = self.data[field_name][sample_id] + field_meta = FieldMeta(name=field_name, dtype=tensor.dtype, shape=tensor.shape, production_status=0) + fields_dict[field_name] = field_meta + sample = SampleMeta( + partition_id="test_partition", + global_index=self.global_indexes[sample_id], + fields=fields_dict, + ) + samples.append(sample) + self.metadata = BatchMeta(samples=samples) + + @patch.object(KVStorageManager, "_connect_to_controller") + @patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock) + def test_put_data_with_custom_meta_from_storage_client(self, mock_notify, mock_connect): + """Test that put_data correctly processes custom_meta returned by storage client.""" + # Create a mock storage client + mock_storage_client = MagicMock() + # Simulate storage client returning custom_meta (one per key) + # Keys order: label[0,1,2], text[0,1,2] (sorted by field name) + mock_custom_meta = [ + {"storage_key": "0@label"}, + {"storage_key": "1@label"}, + {"storage_key": "2@label"}, + {"storage_key": "0@text"}, + {"storage_key": "1@text"}, + {"storage_key": "2@text"}, + ] + mock_storage_client.put.return_value = mock_custom_meta + + # Create manager with mocked dependencies + config = {"client_name": "MockClient"} + with patch( + "transfer_queue.storage.managers.base.StorageClientFactory.create", return_value=mock_storage_client + ): + manager = KVStorageManager(config) + + # Run put_data + asyncio.run(manager.put_data(self.data, self.metadata)) + + # Verify storage client was called with correct keys and values + mock_storage_client.put.assert_called_once() + call_args = mock_storage_client.put.call_args + keys = call_args[0][0] + values = call_args[0][1] + + # Verify keys are correct + expected_keys = ["0@label", "1@label", "2@label", "0@text", "1@text", "2@text"] + self.assertEqual(keys, expected_keys) + self.assertEqual(len(values), 6) + + # Verify notify_data_update was called with correct custom_meta structure + mock_notify.assert_called_once() + notify_call_args = mock_notify.call_args + per_field_custom_meta = notify_call_args[0][5] # 6th positional argument + + # Verify custom_meta is structured correctly: {global_index: {field: meta}} + self.assertIn(0, per_field_custom_meta) + self.assertIn(1, per_field_custom_meta) + self.assertIn(2, per_field_custom_meta) + + self.assertEqual(per_field_custom_meta[0]["label"], {"storage_key": "0@label"}) + self.assertEqual(per_field_custom_meta[0]["text"], {"storage_key": "0@text"}) + self.assertEqual(per_field_custom_meta[1]["label"], 
{"storage_key": "1@label"}) + self.assertEqual(per_field_custom_meta[1]["text"], {"storage_key": "1@text"}) + self.assertEqual(per_field_custom_meta[2]["label"], {"storage_key": "2@label"}) + self.assertEqual(per_field_custom_meta[2]["text"], {"storage_key": "2@text"}) + + # Verify metadata was updated with custom_meta + all_custom_meta = self.metadata.get_all_custom_meta() + self.assertEqual(all_custom_meta[0]["label"], {"storage_key": "0@label"}) + self.assertEqual(all_custom_meta[2]["text"], {"storage_key": "2@text"}) + + @patch.object(KVStorageManager, "_connect_to_controller") + @patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock) + def test_put_data_without_custom_meta(self, mock_notify, mock_connect): + """Test that put_data works correctly when storage client returns no custom_meta.""" + # Create a mock storage client that returns None for custom_meta + mock_storage_client = MagicMock() + mock_storage_client.put.return_value = None + + # Create manager with mocked dependencies + config = {"client_name": "MockClient"} + with patch( + "transfer_queue.storage.managers.base.StorageClientFactory.create", return_value=mock_storage_client + ): + manager = KVStorageManager(config) + + # Run put_data + asyncio.run(manager.put_data(self.data, self.metadata)) + + # Verify notify_data_update was called with empty dict for custom_meta + mock_notify.assert_called_once() + notify_call_args = mock_notify.call_args + per_field_custom_meta = notify_call_args[0][5] # 6th positional argument + self.assertEqual(per_field_custom_meta, {}) + + @patch.object(KVStorageManager, "_connect_to_controller") + @patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock) + def test_put_data_custom_meta_length_mismatch_raises_error(self, mock_notify, mock_connect): + """Test that put_data raises ValueError when custom_meta length doesn't match keys.""" + # Create a mock storage client that returns mismatched custom_meta length + mock_storage_client = MagicMock() + # Return only 3 custom_meta entries when 6 are expected + mock_storage_client.put.return_value = [{"key": "1"}, {"key": "2"}, {"key": "3"}] + + # Create manager with mocked dependencies + config = {"client_name": "MockClient"} + with patch( + "transfer_queue.storage.managers.base.StorageClientFactory.create", return_value=mock_storage_client + ): + manager = KVStorageManager(config) + + # Run put_data and expect ValueError + with self.assertRaises(ValueError) as context: + asyncio.run(manager.put_data(self.data, self.metadata)) + + self.assertIn("does not match", str(context.exception)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 23a6a72..2780f31 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -773,3 +773,170 @@ def test_batch_meta_concat_validation_error(self): with pytest.raises(ValueError) as exc_info: BatchMeta.concat([batch1, batch2], validate=True) assert "Field names do not match" in str(exc_info.value) + + +class TestCustomMeta: + """Unit tests for BatchMeta custom meta methods.""" + + def test_get_all_custom_meta_returns_deep_copy(self): + """Test get_all_custom_meta returns a deep copy of the custom meta dict.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + ] + batch = BatchMeta(samples=samples) + + custom_meta = {0: {"field_a": {"nested": "value"}}} + 
batch.update_custom_meta(custom_meta) + + # Get all custom meta + result = batch.get_all_custom_meta() + + # Verify it's a deep copy - modifying result should not affect original + result[0]["field_a"]["nested"] = "modified" + + original = batch.get_all_custom_meta() + assert original[0]["field_a"]["nested"] == "value" + + def test_get_all_custom_meta_empty(self): + """Test get_all_custom_meta with no custom meta returns empty dict.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + ] + batch = BatchMeta(samples=samples) + + result = batch.get_all_custom_meta() + + assert result == {} + + def test_update_custom_meta_basic(self): + """Test update_custom_meta adds new entries.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields), + ] + batch = BatchMeta(samples=samples) + + # Update with custom meta + custom_meta = { + 0: {"field_a": "value_0"}, + 1: {"field_a": "value_1"}, + } + batch.update_custom_meta(custom_meta) + + result = batch.get_all_custom_meta() + assert result[0]["field_a"] == "value_0" + assert result[1]["field_a"] == "value_1" + + def test_update_custom_meta_overwrites_existing(self): + """Test update_custom_meta overwrites existing entries at the top level.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + ] + batch = BatchMeta(samples=samples) + + # Initial custom meta + batch.update_custom_meta({0: {"field_a": "original"}}) + + # Update with new value - dict.update replaces the entire value for key 0 + batch.update_custom_meta({0: {"field_a": "updated"}}) + + result = batch.get_all_custom_meta() + assert result[0]["field_a"] == "updated" + + def test_update_custom_meta_merges_different_keys(self): + """Test update_custom_meta merges different top-level keys.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + SampleMeta(partition_id="partition_0", global_index=1, fields=fields), + ] + batch = BatchMeta(samples=samples) + + # First update + batch.update_custom_meta({0: {"field_a": "value_0"}}) + + # Second update with different key + batch.update_custom_meta({1: {"field_a": "value_1"}}) + + result = batch.get_all_custom_meta() + assert result[0]["field_a"] == "value_0" + assert result[1]["field_a"] == "value_1" + + def test_update_custom_meta_with_none(self): + """Test update_custom_meta with None does nothing.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + ] + batch = BatchMeta(samples=samples) + + # Set initial value + batch.update_custom_meta({0: {"field_a": "value"}}) + + # Update with None should not change anything + batch.update_custom_meta(None) + + result = batch.get_all_custom_meta() + assert result[0]["field_a"] == "value" + + def test_update_custom_meta_with_empty_dict(self): + """Test update_custom_meta with empty dict does nothing.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + 
SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + ] + batch = BatchMeta(samples=samples) + + # Set initial value + batch.update_custom_meta({0: {"field_a": "value"}}) + + # Update with empty dict should not change anything + batch.update_custom_meta({}) + + result = batch.get_all_custom_meta() + assert result[0]["field_a"] == "value" + + def test_custom_meta_with_complex_values(self): + """Test custom meta can store complex values like dicts, lists, tensors.""" + fields = { + "field_a": FieldMeta(name="field_a", dtype=torch.float32, shape=(2,)), + } + samples = [ + SampleMeta(partition_id="partition_0", global_index=0, fields=fields), + ] + batch = BatchMeta(samples=samples) + + # Store complex values + custom_meta = { + 0: { + "field_a": { + "nested_dict": {"key": "value"}, + "list": [1, 2, 3], + "number": 42, + } + } + } + batch.update_custom_meta(custom_meta) + + result = batch.get_all_custom_meta() + assert result[0]["field_a"]["nested_dict"]["key"] == "value" + assert result[0]["field_a"]["list"] == [1, 2, 3] + assert result[0]["field_a"]["number"] == 42 diff --git a/transfer_queue/controller.py b/transfer_queue/controller.py index 3eab84b..fae9e96 100644 --- a/transfer_queue/controller.py +++ b/transfer_queue/controller.py @@ -230,6 +230,7 @@ class DataPartitionStatus: field_name_mapping: dict[str, int] = field(default_factory=dict) # field_name -> column_index field_dtypes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: dtype} field_shapes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: shape} + field_custom_metas: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: custom_meta} # Threading lock for concurrency control; only for preventing mask operation error when expanding production_status. # No need to strictly lock for every read/write operation since freshness is not critical. @@ -326,6 +327,7 @@ def update_production_status( field_names: list[str], dtypes: Optional[dict[int, dict[str, Any]]], shapes: Optional[dict[int, dict[str, Any]]], + custom_meta: Optional[dict[int, dict[str, Any]]] = None, ) -> bool: """ Update production status for specific samples and fields. @@ -336,6 +338,7 @@ def update_production_status( field_names: List of field names to mark as produced dtypes: Optional per-sample field dtype information shapes: Optional per-sample field shape information + custom_meta: Optional per-sample field custom metadata Returns: True if update was successful, False on error @@ -366,7 +369,7 @@ def update_production_status( self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1 # Update field metadata - self._update_field_metadata(global_indices, dtypes, shapes) + self._update_field_metadata(global_indices, dtypes, shapes, custom_meta) # Save these global_indexes self.global_indexes.update(global_indices) @@ -382,33 +385,63 @@ def _update_field_metadata( global_indices: list[int], dtypes: Optional[dict[int, dict[str, Any]]], shapes: Optional[dict[int, dict[str, Any]]], + custom_meta: Optional[dict[int, dict[str, Any]]], ): """Update field dtype and shape metadata.""" if not global_indices: return - assert len(global_indices) == len(dtypes), "`global_indices` and `dtypes` length mismatch." - assert len(global_indices) == len(shapes), "`global_indices` and `shapes` length mismatch." 
+ # Validate lengths only for provided mappings + if dtypes and len(global_indices) != len(dtypes): + raise ValueError(f"`global_indices` {len(global_indices)} and `dtypes` {len(dtypes)} length mismatch.") + if shapes and len(global_indices) != len(shapes): + raise ValueError(f"`global_indices` {len(global_indices)} and `shapes` {len(shapes)} length mismatch.") + if custom_meta and len(global_indices) != len(custom_meta): + raise ValueError( + f"`global_indices` {len(global_indices)} and `custom_meta` {len(custom_meta)} length mismatch." + ) - dtype_value = itemgetter(*global_indices)(dtypes) if dtypes else None - shape_value = itemgetter(*global_indices)(shapes) if shapes else None + # Extract values for each provided mapping; if a mapping is absent, use Nones + if dtypes: + dtype_value = itemgetter(*global_indices)(dtypes) + if not isinstance(dtype_value, tuple): + dtype_value = (dtype_value,) + else: + dtype_value = tuple([None] * len(global_indices)) - if not isinstance(dtype_value, tuple): - dtype_value = (dtype_value,) - if not isinstance(shape_value, tuple): - shape_value = (shape_value,) + if shapes: + shape_value = itemgetter(*global_indices)(shapes) + if not isinstance(shape_value, tuple): + shape_value = (shape_value,) + else: + shape_value = tuple([None] * len(global_indices)) - for i, global_idx in enumerate(global_indices): - if global_idx not in self.field_dtypes: - self.field_dtypes[global_idx] = {} - if global_idx not in self.field_shapes: - self.field_shapes[global_idx] = {} + if custom_meta: + custom_meta_value = itemgetter(*global_indices)(custom_meta) + if not isinstance(custom_meta_value, tuple): + custom_meta_value = (custom_meta_value,) + else: + custom_meta_value = tuple([None] * len(global_indices)) - if dtype_value is not None: + for i, global_idx in enumerate(global_indices): + # Only create and update dtype mapping if a dtype value was provided + if dtype_value[i] is not None: + if global_idx not in self.field_dtypes: + self.field_dtypes[global_idx] = {} self.field_dtypes[global_idx].update(dtype_value[i]) - if shape_value is not None: + + # Only create and update shape mapping if a shape value was provided + if shape_value[i] is not None: + if global_idx not in self.field_shapes: + self.field_shapes[global_idx] = {} self.field_shapes[global_idx].update(shape_value[i]) + # Only create and update custom_meta mapping if a custom_meta value was provided + if custom_meta_value[i] is not None: + if global_idx not in self.field_custom_metas: + self.field_custom_metas[global_idx] = {} + self.field_custom_metas[global_idx].update(custom_meta_value[i]) + # ==================== Consumption Status Interface ==================== def get_consumption_status(self, task_name: str) -> torch.Tensor: @@ -544,6 +577,14 @@ def get_field_shape(self, global_index: int, field_name: str) -> Optional[Any]: """Get shape for a specific sample and field.""" return self.field_shapes.get(global_index, {}).get(field_name) + def get_field_custom_meta(self, global_indices: list[int], field_names: list[str]) -> dict[int, dict[str, Any]]: + """Get custom_meta for multiple samples and fields.""" + return { + idx: {f: v for f, v in self.field_custom_metas[idx].items() if f in field_names} + for idx in global_indices + if idx in self.field_custom_metas + } + # ==================== Statistics and Monitoring ==================== def get_statistics(self) -> dict[str, Any]: @@ -571,7 +612,9 @@ def get_statistics(self) -> dict[str, Any]: field_produced = (self.production_status[:, field_idx] == 
1).sum().item() field_stats[field_name] = { "produced_samples": field_produced, - "production_progress": field_produced / self.total_samples_num if self.total_samples_num > 0 else 0, + "production_progress": ( + field_produced / self.total_samples_num if self.total_samples_num > 0 else 0 + ), } stats["field_statistics"] = field_stats @@ -581,7 +624,9 @@ def get_statistics(self) -> dict[str, Any]: consumed_samples = (consumption_tensor == 1).sum().item() consumption_stats[task_name] = { "consumed_samples": consumed_samples, - "consumption_progress": consumed_samples / self.total_samples_num if self.total_samples_num > 0 else 0, + "consumption_progress": ( + consumed_samples / self.total_samples_num if self.total_samples_num > 0 else 0 + ), } stats["consumption_statistics"] = consumption_stats @@ -632,6 +677,10 @@ def clear_data(self, indexes_to_release: list[int], clear_consumption: bool = Tr consumption_tensor[indexes_to_release] = 0 self.global_indexes.difference_update(indexes_to_release) + for idx in indexes_to_release: + self.field_dtypes.pop(idx, None) + self.field_shapes.pop(idx, None) + self.field_custom_metas.pop(idx, None) except Exception as e: logger.error( @@ -658,7 +707,9 @@ class TransferQueueController: """ def __init__( - self, sampler: BaseSampler | type[BaseSampler] = SequentialSampler, polling_mode: bool = False + self, + sampler: BaseSampler | type[BaseSampler] = SequentialSampler, + polling_mode: bool = False, ) -> None: """Initialize the TransferQueue Controller. @@ -791,6 +842,7 @@ def update_production_status( field_names: list[str], dtypes: Optional[dict[int, dict[str, Any]]], shapes: Optional[dict[int, dict[str, Any]]], + custom_meta: Optional[dict[int, dict[str, Any]]] = None, ) -> bool: """ Update production status for specific samples and fields in a partition. @@ -811,7 +863,7 @@ def update_production_status( logger.error(f"Partition {partition_id} not found") return False - success = partition.update_production_status(global_indexes, field_names, dtypes, shapes) + success = partition.update_production_status(global_indexes, field_names, dtypes, shapes, custom_meta) if success: logger.debug( f"[{self.controller_id}]: Updated production status for partition {partition_id}: " @@ -1070,7 +1122,11 @@ def generate_batch_meta( ) samples.append(sample) - return BatchMeta(samples=samples) + custom_meta = partition.get_field_custom_meta(batch_global_indexes, data_fields) + + batch_meta = BatchMeta(samples=samples) + batch_meta.update_custom_meta(custom_meta) + return batch_meta def clear_partition(self, partition_id: str, clear_consumption: bool = True): """ @@ -1092,7 +1148,12 @@ def clear_partition(self, partition_id: str, clear_consumption: bool = True): self.index_manager.release_partition(partition_id) self.partitions.pop(partition_id) - def clear_meta(self, global_indexes: list[int], partition_ids: list[str], clear_consumption: bool = True): + def clear_meta( + self, + global_indexes: list[int], + partition_ids: list[str], + clear_consumption: bool = True, + ): """ Clear meta for individual samples (preserving the partition). 
@@ -1230,7 +1291,9 @@ def _wait_connection(self): def _start_process_handshake(self): """Start the handshake process thread.""" self.wait_connection_thread = Thread( - target=self._wait_connection, name="TransferQueueControllerWaitConnectionThread", daemon=True + target=self._wait_connection, + name="TransferQueueControllerWaitConnectionThread", + daemon=True, ) self.wait_connection_thread.start() @@ -1246,7 +1309,9 @@ def _start_process_update_data_status(self): def _start_process_request(self): """Start the request processing thread.""" self.process_request_thread = Thread( - target=self._process_request, name="TransferQueueControllerProcessRequestThread", daemon=True + target=self._process_request, + name="TransferQueueControllerProcessRequestThread", + daemon=True, ) self.process_request_thread.start() @@ -1408,6 +1473,7 @@ def _update_data_status(self): field_names=message_data.get("fields", []), dtypes=message_data.get("dtypes", {}), shapes=message_data.get("shapes", {}), + custom_meta=message_data.get("custom_meta", {}), ) if success: diff --git a/transfer_queue/metadata.py b/transfer_queue/metadata.py index c389157..a75e6cd 100644 --- a/transfer_queue/metadata.py +++ b/transfer_queue/metadata.py @@ -125,8 +125,11 @@ def select_fields(self, field_names: list[str]) -> "SampleMeta": selected_fields = {name: self.fields[name] for name in field_names if name in self.fields} # construct new SampleMeta instance + # TODO(tianyi): move custom_meta to FieldMeta level selected_sample_meta = SampleMeta( - fields=selected_fields, partition_id=self.partition_id, global_index=self.global_index + fields=selected_fields, + partition_id=self.partition_id, + global_index=self.global_index, ) return selected_sample_meta @@ -174,6 +177,8 @@ class BatchMeta: samples: list[SampleMeta] extra_info: dict[str, Any] = dataclasses.field(default_factory=dict) + # internal data for different storage backends: _custom_meta[global_index][field] + _custom_meta: dict[int, dict[str, Any]] = dataclasses.field(default_factory=dict) def __post_init__(self): """Initialize all computed properties during initialization""" @@ -230,6 +235,16 @@ def partition_ids(self) -> list[str]: """Get partition ids for all samples in this batch as a list (one per sample)""" return getattr(self, "_partition_ids", []) + # Custom meta methods for different storage backends + def get_all_custom_meta(self) -> dict[int, dict[str, Any]]: + """Get the entire custom meta dictionary""" + return copy.deepcopy(self._custom_meta) + + def update_custom_meta(self, new_custom_meta: dict[int, dict[str, Any]] = None): + """Update custom meta with a new dictionary""" + if new_custom_meta: + self._custom_meta.update(new_custom_meta) + # Extra info interface methods def get_extra_info(self, key: str, default: Any = None) -> Any: """Get extra info by key""" diff --git a/transfer_queue/storage/clients/base.py b/transfer_queue/storage/clients/base.py index fc5677d..db6239d 100644 --- a/transfer_queue/storage/clients/base.py +++ b/transfer_queue/storage/clients/base.py @@ -14,8 +14,7 @@ # limitations under the License. from abc import ABC, abstractmethod - -from torch import Tensor +from typing import Any, Optional class TransferQueueStorageKVClient(ABC): @@ -25,11 +24,36 @@ class TransferQueueStorageKVClient(ABC): """ @abstractmethod - def put(self, keys: list[str], values: list[Tensor]) -> None: + def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: + """ + Store key-value pairs in the storage backend. 
+ Args: + keys (list[str]): List of keys to store. + values (list[Any]): List of any type to store. + Returns: + Optional[list[Any]]: Optional list of custom metadata from each storage backend. + """ raise NotImplementedError("Subclasses must implement put") @abstractmethod - def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Tensor]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: + """ + Retrieve values from the storage backend by key. + Args: + keys (list[str]): List of keys whose values should be retrieved. + shapes: Optional shape information for the expected values. The + structure and interpretation of this argument are determined + by the concrete storage backend implementation. + dtypes: Optional data type information for the expected values. + The structure and interpretation of this argument are + determined by the concrete storage backend implementation. + custom_meta: Optional backend-specific metadata used to control + or optimize the retrieval process. Its format is defined by + the concrete storage backend implementation. + Returns: + list[Any]: List of values retrieved from the storage backend, + in the same order as the provided keys. + """ raise NotImplementedError("Subclasses must implement get") @abstractmethod diff --git a/transfer_queue/storage/clients/mooncake_client.py b/transfer_queue/storage/clients/mooncake_client.py index a71262b..80efa09 100644 --- a/transfer_queue/storage/clients/mooncake_client.py +++ b/transfer_queue/storage/clients/mooncake_client.py @@ -1,7 +1,7 @@ import logging import os import pickle -from typing import Any +from typing import Any, Optional import torch from torch import Tensor @@ -53,7 +53,7 @@ def __init__(self, config: dict[str, Any]): if ret != 0: raise RuntimeError(f"Mooncake store setup failed with error code: {ret}") - def put(self, keys: list[str], values: list[Any]): + def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: if not isinstance(keys, list) or not isinstance(values, list): raise ValueError("keys and values must be lists") if len(keys) != len(values): @@ -82,6 +82,8 @@ def put(self, keys: list[str], values: list[Any]): if non_tensor_keys: self._batch_put_bytes(non_tensor_keys, non_tensor_values) + return None + def _batch_put_tensors(self, keys: list[str], tensors: list[Tensor]): for i in range(0, len(keys), BATCH_SIZE_LIMIT): batch_keys = keys[i : i + BATCH_SIZE_LIMIT] @@ -104,7 +106,7 @@ def _batch_put_bytes(self, keys: list[str], values: list[bytes]): if ret != 0: raise RuntimeError(f"put_batch failed with error code: {ret}") - def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Any]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: if shapes is None or dtypes is None: raise ValueError("MooncakeStorageClient needs shapes and dtypes") if not (len(keys) == len(shapes) == len(dtypes)): diff --git a/transfer_queue/storage/clients/ray_storage_client.py b/transfer_queue/storage/clients/ray_storage_client.py index 5b85825..5ffd023 100644 --- a/transfer_queue/storage/clients/ray_storage_client.py +++ b/transfer_queue/storage/clients/ray_storage_client.py @@ -1,5 +1,5 @@ import itertools -from typing import Any +from typing import Any, Optional import ray import torch @@ -38,7 +38,7 @@ def __init__(self, config=None): except ValueError: self.storage_actor = RayObjectRefStorage.options(name="RayObjectRefStorage", get_if_exists=False).remote() - def put(self, keys: list[str], values: list[Any]): 
+ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: """ Store tensors to remote storage. Args: @@ -58,14 +58,16 @@ def put(self, keys: list[str], values: list[Any]): ) ) ray.get(self.storage_actor.put_obj_ref.remote(keys, obj_refs)) + return None - def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Any]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: """ Retrieve objects from remote storage. Args: keys (list): List of string keys to fetch. shapes (list, optional): Ignored. For compatibility with KVStorageManager. dtypes (list, optional): Ignored. For compatibility with KVStorageManager. + custom_meta (list, optional): Ray object ref for each key Returns: list: List of retrieved objects """ diff --git a/transfer_queue/storage/clients/yuanrong_client.py b/transfer_queue/storage/clients/yuanrong_client.py index 4652314..c7da326 100644 --- a/transfer_queue/storage/clients/yuanrong_client.py +++ b/transfer_queue/storage/clients/yuanrong_client.py @@ -16,7 +16,7 @@ import logging import os import pickle -from typing import Any +from typing import Any, Optional import torch from torch import Tensor @@ -161,7 +161,7 @@ def _batch_put(self, keys: list[str], values: list[Any]): batch_vals = pickled_values[i : i + CPU_DS_CLIENT_KEYS_LIMIT] self._cpu_ds_client.mset(batch_keys, batch_vals) - def put(self, keys: list[str], values: list[Any]): + def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: """Stores multiple key-value pairs to remote storage. Automatically routes NPU tensors to high-performance tensor storage, @@ -176,6 +176,7 @@ def put(self, keys: list[str], values: list[Any]): if len(keys) != len(values): raise ValueError("Number of keys must match number of values") self._batch_put(keys, values) + return None def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]: """Retrieves a batch of values from remote storage using expected metadata. @@ -262,7 +263,7 @@ def _batch_get(self, keys: list[str], shapes: list, dtypes: list) -> list[Any]: idx += 1 return results - def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Any]: + def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: """Retrieves multiple values from remote storage with expected metadata. Requires shape and dtype hints to reconstruct NPU tensors correctly. @@ -271,6 +272,7 @@ def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Any]: keys (List[str]): Keys to fetch. shapes (List[List[int]]): Expected tensor shapes (use [] for scalars). dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data. + custom_meta (List[str], optional): Device type (npu/cpu) for each key Returns: List[Any]: Retrieved values in the same order as input keys. diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 64927fd..e2d5f1b 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -185,6 +185,7 @@ async def notify_data_update( global_indexes: list[int], dtypes: dict[int, dict[str, Any]], shapes: dict[int, dict[str, Any]], + custom_meta: dict[int, dict[str, Any]] = None, ) -> None: """ Notify controller that new data is ready. @@ -195,6 +196,7 @@ async def notify_data_update( global_indexes: Data update related global_indexes. dtypes: Per-field dtypes for each field, in {global_index: {field: dtype}} format. 
shapes: Per-field shapes for each field, in {global_index: {field: shape}} format. + custom_meta: Per-field custom_meta for each field, in {global_index: {field: custom_meta}} format. """ # Create zmq poller for notifying data update information @@ -218,6 +220,7 @@ async def notify_data_update( "global_indexes": global_indexes, "dtypes": dtypes, "shapes": shapes, + "custom_meta": custom_meta, }, ).serialize() @@ -405,25 +408,29 @@ def _merge_tensors_to_tensordict(metadata: BatchMeta, values: list[Tensor]) -> T return TensorDict(merged_data, batch_size=len(global_indexes)) @staticmethod - def _get_shape_type_list(metadata: BatchMeta): + def _get_shape_type_custom_meta_list(metadata: BatchMeta): """ - Extract the expected shape and dtype for each field-sample pair in metadata. + Extract the expected shape, dtype, and custom meta for each field-sample pair in metadata. The order matches the key/value order: sorted by field name, then by global index. Args: metadata (BatchMeta): Metadata containing sample and field information. Returns: - tuple[list[torch.Size], list[torch.dtype]]: Two lists containing the shape and dtype - for each tensor to be retrieved. + tuple[list[torch.Size], list[torch.dtype], list[Any]]: the shape list, dtype list and + custom meta list for each tensor to be retrieved. """ shapes = [] dtypes = [] + custom_meta_list = [] + all_custom_meta = metadata.get_all_custom_meta() for field_name in sorted(metadata.field_names): for index in range(len(metadata)): field = metadata.samples[index].get_field_by_name(field_name) shapes.append(field.shape) dtypes.append(field.dtype) - return shapes, dtypes + global_index = metadata.global_indexes[index] + custom_meta_list.append(all_custom_meta.get(global_index, {}).get(field_name, None)) + return shapes, dtypes, custom_meta_list async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: """ @@ -445,7 +452,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: keys = self._generate_keys(data.keys(), metadata.global_indexes) values = self._generate_values(data) loop = asyncio.get_event_loop() - await loop.run_in_executor(None, self.storage_client.put, keys, values) + custom_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values) per_field_dtypes = {} per_field_shapes = {} @@ -466,13 +473,37 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: getattr(data_item, "shape", None) if isinstance(data_item, Tensor) else None ) + # Prepare per-field custom_meta if available + per_field_custom_meta = {} + if custom_meta: + if len(custom_meta) != len(keys): + raise ValueError(f"Length of custom_meta ({len(custom_meta)}) does not match expected ({len(keys)})") + # custom meta is a flat list aligned with keys/values + # Use itertools.product to eliminate nested loops + for global_idx in metadata.global_indexes: + per_field_custom_meta[global_idx] = {} + + # TODO(tianyi): the order of custom meta is coupled with keys/values + for (field_name, global_idx), meta_value in zip( + itertools.product(sorted(metadata.field_names), metadata.global_indexes), + custom_meta, + strict=True, + ): + per_field_custom_meta[global_idx][field_name] = meta_value + metadata.update_custom_meta(per_field_custom_meta) + # Get current data partition id # Note: Currently we only support putting to & getting data from a single data partition simultaneously, # but in the future we may support putting to & getting data from multiple data partitions concurrently. 
partition_id = metadata.samples[0].partition_id # notify controller that new data is ready await self.notify_data_update( - partition_id, list(data.keys()), metadata.global_indexes, per_field_dtypes, per_field_shapes + partition_id, + list(data.keys()), + metadata.global_indexes, + per_field_dtypes, + per_field_shapes, + per_field_custom_meta, ) async def get_data(self, metadata: BatchMeta) -> TensorDict: @@ -486,8 +517,8 @@ async def get_data(self, metadata: BatchMeta) -> TensorDict: logger.warning("Attempted to get data, but metadata contains no fields.") return TensorDict({}, batch_size=len(metadata)) keys = self._generate_keys(metadata.field_names, metadata.global_indexes) - shapes, dtypes = self._get_shape_type_list(metadata) - values = self.storage_client.get(keys=keys, shapes=shapes, dtypes=dtypes) + shapes, dtypes, custom_meta = self._get_shape_type_custom_meta_list(metadata) + values = self.storage_client.get(keys=keys, shapes=shapes, dtypes=dtypes, custom_meta=custom_meta) return self._merge_tensors_to_tensordict(metadata, values) async def clear_data(self, metadata: BatchMeta) -> None:
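
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the custom_meta plumbing introduced in this diff fits together end to end, using only APIs that appear above (DataPartitionStatus.update_production_status and get_field_custom_meta). The field name and metadata values are invented for the example; in practice the per-field metadata comes from the storage client's put() return value via KVStorageManager.put_data and notify_data_update.

    from transfer_queue.controller import DataPartitionStatus

    partition = DataPartitionStatus(partition_id="example")

    # A storage manager reports production status together with arbitrary
    # per-field metadata (e.g. the storage key it used for each tensor).
    partition.update_production_status(
        global_indices=[0],
        field_names=["input_ids"],
        dtypes={0: {"input_ids": "torch.int32"}},
        shapes={0: {"input_ids": (512,)}},
        custom_meta={0: {"input_ids": {"storage_key": "0@input_ids"}}},
    )

    # The controller later returns that metadata alongside BatchMeta (see
    # generate_batch_meta), so get_data can forward it to StorageClient.get().
    meta = partition.get_field_custom_meta([0], ["input_ids"])
    assert meta[0]["input_ids"]["storage_key"] == "0@input_ids"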