Merged
tests/test_controller_data_partitions.py (95 changes: 90 additions & 5 deletions)
@@ -19,6 +19,8 @@
import time
from pathlib import Path

import pytest

parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

@@ -63,6 +65,7 @@ def test_data_partition_status():
1: {"input_ids": (512,), "attention_mask": (512,)},
2: {"input_ids": (512,), "attention_mask": (512,)},
},
custom_meta=None,
)

assert success
@@ -172,6 +175,7 @@ def test_dynamic_expansion_scenarios():
5: {"field_1": (32,)},
10: {"field_1": (32,)},
},
custom_meta=None,
)
assert partition.total_samples_num == 3
assert partition.allocated_samples_num >= 11 # Should accommodate index 10
@@ -180,7 +184,7 @@
# Scenario 2: Adding many fields dynamically
for i in range(15):
partition.update_production_status(
[0], [f"field_{i}"], {0: {f"field_{i}": "torch.bool"}}, {0: {f"field_{i}": (32,)}}
[0], [f"field_{i}"], {0: {f"field_{i}": "torch.bool"}}, {0: {f"field_{i}": (32,)}}, None
)

assert partition.total_fields_num == 16 # Original + 15 new fields
@@ -222,7 +226,7 @@ def test_data_partition_status_advanced():
# Add data to trigger expansion
dtypes = {i: {f"dynamic_field_{s}": "torch.bool" for s in ["a", "b", "c"]} for i in range(5)}
shapes = {i: {f"dynamic_field_{s}": (32,) for s in ["a", "b", "c"]} for i in range(5)}
partition.update_production_status([0, 1, 2, 3, 4], ["field_a", "field_b", "field_c"], dtypes, shapes)
partition.update_production_status([0, 1, 2, 3, 4], ["field_a", "field_b", "field_c"], dtypes, shapes, None)

# Properties should reflect current state
assert partition.total_samples_num >= 5 # At least 5 samples
@@ -253,7 +257,7 @@ def test_data_partition_status_advanced():
11: {"field_d": (32,)},
12: {"field_d": (32,)},
}
partition.update_production_status([10, 11, 12], ["field_d"], dtypes, shapes) # Triggers sample expansion
partition.update_production_status([10, 11, 12], ["field_d"], dtypes, shapes, None) # Triggers sample expansion
expanded_consumption = partition.get_consumption_status(task_name)
assert expanded_consumption[0] == 1 # Preserved
assert expanded_consumption[1] == 1 # Preserved
@@ -265,13 +269,13 @@
# Start with some fields
dtypes = {0: {"initial_field": "torch.bool"}}
shapes = {0: {"field_d": (32,)}}
partition.update_production_status([0], ["initial_field"], dtypes, shapes)
partition.update_production_status([0], ["initial_field"], dtypes, shapes, None)

# Add many fields to trigger column expansion
new_fields = [f"dynamic_field_{i}" for i in range(20)]
dtypes = {1: {f"dynamic_field_{i}": "torch.bool" for i in range(20)}}
shapes = {1: {f"dynamic_field_{i}": (32,) for i in range(20)}}
partition.update_production_status([1], new_fields, dtypes, shapes)
partition.update_production_status([1], new_fields, dtypes, shapes, None)

# Verify all fields are registered and accessible
assert "initial_field" in partition.field_name_mapping
@@ -441,3 +445,84 @@ def test_performance_characteristics():
print("✓ Memory usage patterns reasonable")

print("Performance characteristics tests passed!\n")


def test_custom_meta_in_data_partition_status():
"""Simplified tests for custom_meta functionality in DataPartitionStatus."""

print("Testing simplified custom_meta in DataPartitionStatus...")

from transfer_queue.controller import DataPartitionStatus

partition = DataPartitionStatus(partition_id="custom_meta_test")

# Basic custom_meta storage via update_production_status
global_indices = [0, 1, 2]
field_names = ["input_ids", "attention_mask"]
dtypes = {i: {"input_ids": "torch.int32", "attention_mask": "torch.bool"} for i in global_indices}
shapes = {i: {"input_ids": (512,), "attention_mask": (512,)} for i in global_indices}
custom_meta = {
0: {"input_ids": {"token_count": 100}},
1: {"attention_mask": {"mask_ratio": 0.2}},
2: {"input_ids": {"token_count": 300}},
}

success = partition.update_production_status(
global_indices=global_indices,
field_names=field_names,
dtypes=dtypes,
shapes=shapes,
custom_meta=custom_meta,
)

assert success

# Verify some stored values
assert partition.field_custom_metas[0]["input_ids"]["token_count"] == 100
assert partition.field_custom_metas[1]["attention_mask"]["mask_ratio"] == 0.2

# Retrieval via helper for a subset of fields
retrieved = partition.get_field_custom_meta([0, 1], ["input_ids", "attention_mask"])
assert 0 in retrieved and "input_ids" in retrieved[0]
assert 1 in retrieved and "attention_mask" in retrieved[1]

# Clearing a sample should remove its custom_meta
partition.clear_data([0], clear_consumption=True)
assert 0 not in partition.field_custom_metas

print("✓ Custom_meta tests passed")


def test_update_field_metadata_variants():
"""Test _update_field_metadata handles dtypes/shapes/custom_meta being optional and merging."""
from transfer_queue.controller import DataPartitionStatus

partition = DataPartitionStatus(partition_id="update_meta_test")

# Only dtypes provided
global_indices = [0, 1]
dtypes = {0: {"f1": "torch.int32"}, 1: {"f1": "torch.bool"}}

partition._update_field_metadata(global_indices, dtypes, shapes=None, custom_meta=None)
assert partition.field_dtypes[0]["f1"] == "torch.int32"
assert partition.field_dtypes[1]["f1"] == "torch.bool"
assert partition.field_shapes == {}
assert partition.field_custom_metas == {}

# Only shapes provided for a new index
partition._update_field_metadata([2], dtypes=None, shapes={2: {"f2": (16,)}}, custom_meta=None)
assert partition.field_shapes[2]["f2"] == (16,)

# Only custom_meta provided and merged with existing entries
partition._update_field_metadata([2], dtypes=None, shapes=None, custom_meta={2: {"f2": {"meta": 1}}})
assert 2 in partition.field_custom_metas
assert partition.field_custom_metas[2]["f2"]["meta"] == 1

# Merging dtypes on an existing index should preserve previous keys
partition._update_field_metadata([0], dtypes={0: {"f2": "torch.float32"}}, shapes=None, custom_meta=None)
assert partition.field_dtypes[0]["f1"] == "torch.int32"
assert partition.field_dtypes[0]["f2"] == "torch.float32"

# Length mismatch should raise ValueError when provided mapping lengths differ from global_indices
with pytest.raises(ValueError):
partition._update_field_metadata([0, 1, 2], dtypes={0: {}}, shapes=None, custom_meta=None)
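
Taken together, these variants pin down the merge semantics of _update_field_metadata: each of dtypes, shapes, and custom_meta is optional; when present, a mapping must match global_indices in length; and updates merge into existing per-index entries rather than replacing them. A minimal sketch of those semantics, assuming per-index dict attributes named as the assertions read them (an illustration, not the transfer_queue implementation):

class _PartitionMetaSketch:
    """Hypothetical stand-in for the metadata stores on DataPartitionStatus."""

    def __init__(self):
        self.field_dtypes = {}        # {global_index: {field: dtype_str}}
        self.field_shapes = {}        # {global_index: {field: shape_tuple}}
        self.field_custom_metas = {}  # {global_index: {field: meta_dict}}

    def _update_field_metadata(self, global_indices, dtypes=None, shapes=None, custom_meta=None):
        for label, mapping, store in (
            ("dtypes", dtypes, self.field_dtypes),
            ("shapes", shapes, self.field_shapes),
            ("custom_meta", custom_meta, self.field_custom_metas),
        ):
            if mapping is None:
                continue  # every metadata argument is optional
            if len(mapping) != len(global_indices):
                raise ValueError(f"{label} length does not match global_indices")
            for idx in global_indices:
                if idx in mapping:
                    # Merge per-field entries; keys already stored for idx survive.
                    store.setdefault(idx, {}).update(mapping[idx])
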
tests/test_kv_storage_manager.py (228 changes: 228 additions & 0 deletions)
@@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import sys
import unittest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import torch
from tensordict import TensorDict
@@ -97,6 +99,232 @@ def test_merge_kv_to_tensordict(self):

self.assertEqual(reconstructed.batch_size, torch.Size([3]))

def test_get_shape_type_custom_meta_list_without_custom_meta(self):
"""Test _get_shape_type_custom_meta_list returns correct shapes and dtypes without custom_meta."""
shapes, dtypes, custom_meta_list = KVStorageManager._get_shape_type_custom_meta_list(self.metadata)

# Expected order: sorted by field name (label, mask, text), then by global_index order
# 3 fields * 3 samples = 9 entries
self.assertEqual(len(shapes), 9)
self.assertEqual(len(dtypes), 9)
self.assertEqual(len(custom_meta_list), 9)

# Check shapes - order is label, mask, text (sorted alphabetically)
# label shapes: [()]*3, mask shapes: [(1,)]*3, text shapes: [(2,)]*3
expected_shapes = [
torch.Size([]), # label[0]
torch.Size([]), # label[1]
torch.Size([]), # label[2]
torch.Size([1]), # mask[0]
torch.Size([1]), # mask[1]
torch.Size([1]), # mask[2]
torch.Size([2]), # text[0]
torch.Size([2]), # text[1]
torch.Size([2]), # text[2]
]
self.assertEqual(shapes, expected_shapes)

# All dtypes should be torch.int64
for dtype in dtypes:
self.assertEqual(dtype, torch.int64)

# No custom_meta provided, so all should be None
for meta in custom_meta_list:
self.assertIsNone(meta)

def test_get_shape_type_custom_meta_list_with_custom_meta(self):
"""Test _get_shape_type_custom_meta_list returns correct custom_meta when provided."""
# Add custom_meta to metadata
custom_meta = {
8: {"text": {"key1": "value1"}, "label": {"key2": "value2"}, "mask": {"key3": "value3"}},
9: {"text": {"key4": "value4"}, "label": {"key5": "value5"}, "mask": {"key6": "value6"}},
10: {"text": {"key7": "value7"}, "label": {"key8": "value8"}, "mask": {"key9": "value9"}},
}
self.metadata.update_custom_meta(custom_meta)

shapes, dtypes, custom_meta_list = KVStorageManager._get_shape_type_custom_meta_list(self.metadata)

# Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index
expected_custom_meta = [
{"key2": "value2"}, # label, global_index=8
{"key5": "value5"}, # label, global_index=9
{"key8": "value8"}, # label, global_index=10
{"key3": "value3"}, # mask, global_index=8
{"key6": "value6"}, # mask, global_index=9
{"key9": "value9"}, # mask, global_index=10
{"key1": "value1"}, # text, global_index=8
{"key4": "value4"}, # text, global_index=9
{"key7": "value7"}, # text, global_index=10
]
self.assertEqual(custom_meta_list, expected_custom_meta)

def test_get_shape_type_custom_meta_list_with_partial_custom_meta(self):
"""Test _get_shape_type_custom_meta_list handles partial custom_meta correctly."""
# Add custom_meta only for some global_indexes and fields
custom_meta = {
8: {"text": {"key1": "value1"}}, # Only text field
# global_index 9 has no custom_meta
10: {"label": {"key2": "value2"}, "mask": {"key3": "value3"}}, # label and mask only
}
self.metadata.update_custom_meta(custom_meta)

shapes, dtypes, custom_meta_list = KVStorageManager._get_shape_type_custom_meta_list(self.metadata)

# Check custom_meta - order is label, mask, text (sorted alphabetically) by global_index
expected_custom_meta = [
None, # label, global_index=8 (not in custom_meta)
None, # label, global_index=9 (not in custom_meta)
{"key2": "value2"}, # label, global_index=10
None, # mask, global_index=8 (not in custom_meta)
None, # mask, global_index=9 (not in custom_meta)
{"key3": "value3"}, # mask, global_index=10
{"key1": "value1"}, # text, global_index=8
None, # text, global_index=9 (not in custom_meta)
None, # text, global_index=10 (not in custom_meta for text)
]
self.assertEqual(custom_meta_list, expected_custom_meta)
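
All three tests above rely on the same flattening contract: one entry per (field, sample) pair, with fields iterated in sorted order and samples in global_index order, and None wherever no custom_meta exists for the pair. A sketch of that contract under an assumed per-sample metadata layout (not the actual KVStorageManager code):

def flatten_meta(field_names, global_indexes, field_meta, custom_meta=None):
    # Assumed layout: field_meta is {global_index: {field: {"shape": ..., "dtype": ...}}}
    # and custom_meta is the optional {global_index: {field: meta}} mapping.
    shapes, dtypes, custom_meta_list = [], [], []
    custom_meta = custom_meta or {}
    for field in sorted(field_names):  # e.g. label, mask, text
        for gi in global_indexes:      # e.g. 8, 9, 10
            shapes.append(field_meta[gi][field]["shape"])
            dtypes.append(field_meta[gi][field]["dtype"])
            # None when this (sample, field) pair has no custom_meta.
            custom_meta_list.append(custom_meta.get(gi, {}).get(field))
    return shapes, dtypes, custom_meta_list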


class TestPutDataWithCustomMeta(unittest.TestCase):
"""Test put_data with custom_meta functionality."""

def setUp(self):
"""Set up test fixtures for put_data tests."""
self.field_names = ["text", "label"]
self.global_indexes = [0, 1, 2]

# Create test data
self.data = TensorDict(
{
"text": torch.tensor([[1, 2], [3, 4], [5, 6]]),
"label": torch.tensor([0, 1, 2]),
},
batch_size=3,
)

# Create metadata without production status set (for insert mode)
samples = []
for sample_id in range(self.data.batch_size[0]):
fields_dict = {}
for field_name in self.data.keys():
tensor = self.data[field_name][sample_id]
field_meta = FieldMeta(name=field_name, dtype=tensor.dtype, shape=tensor.shape, production_status=0)
fields_dict[field_name] = field_meta
sample = SampleMeta(
partition_id="test_partition",
global_index=self.global_indexes[sample_id],
fields=fields_dict,
)
samples.append(sample)
self.metadata = BatchMeta(samples=samples)

@patch.object(KVStorageManager, "_connect_to_controller")
@patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock)
def test_put_data_with_custom_meta_from_storage_client(self, mock_notify, mock_connect):
"""Test that put_data correctly processes custom_meta returned by storage client."""
# Create a mock storage client
mock_storage_client = MagicMock()
# Simulate storage client returning custom_meta (one per key)
# Keys order: label[0,1,2], text[0,1,2] (sorted by field name)
mock_custom_meta = [
{"storage_key": "0@label"},
{"storage_key": "1@label"},
{"storage_key": "2@label"},
{"storage_key": "0@text"},
{"storage_key": "1@text"},
{"storage_key": "2@text"},
]
mock_storage_client.put.return_value = mock_custom_meta

# Create manager with mocked dependencies
config = {"client_name": "MockClient"}
with patch(
"transfer_queue.storage.managers.base.StorageClientFactory.create", return_value=mock_storage_client
):
manager = KVStorageManager(config)

# Run put_data
asyncio.run(manager.put_data(self.data, self.metadata))

# Verify storage client was called with correct keys and values
mock_storage_client.put.assert_called_once()
call_args = mock_storage_client.put.call_args
keys = call_args[0][0]
values = call_args[0][1]

# Verify keys are correct
expected_keys = ["0@label", "1@label", "2@label", "0@text", "1@text", "2@text"]
self.assertEqual(keys, expected_keys)
self.assertEqual(len(values), 6)

# Verify notify_data_update was called with correct custom_meta structure
mock_notify.assert_called_once()
notify_call_args = mock_notify.call_args
per_field_custom_meta = notify_call_args[0][5] # 6th positional argument

# Verify custom_meta is structured correctly: {global_index: {field: meta}}
self.assertIn(0, per_field_custom_meta)
self.assertIn(1, per_field_custom_meta)
self.assertIn(2, per_field_custom_meta)

self.assertEqual(per_field_custom_meta[0]["label"], {"storage_key": "0@label"})
self.assertEqual(per_field_custom_meta[0]["text"], {"storage_key": "0@text"})
self.assertEqual(per_field_custom_meta[1]["label"], {"storage_key": "1@label"})
self.assertEqual(per_field_custom_meta[1]["text"], {"storage_key": "1@text"})
self.assertEqual(per_field_custom_meta[2]["label"], {"storage_key": "2@label"})
self.assertEqual(per_field_custom_meta[2]["text"], {"storage_key": "2@text"})

# Verify metadata was updated with custom_meta
all_custom_meta = self.metadata.get_all_custom_meta()
self.assertEqual(all_custom_meta[0]["label"], {"storage_key": "0@label"})
self.assertEqual(all_custom_meta[2]["text"], {"storage_key": "2@text"})
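
The nested structure asserted here follows mechanically from the flat lists: with keys in the "<global_index>@<field>" form shown above, a single pass over the zipped keys and metas rebuilds {global_index: {field: meta}}. A sketch under that key-format assumption (hypothetical helper, not the manager's code):

def fold_custom_meta(keys, metas):
    # fold_custom_meta(["0@label", "0@text"], [m1, m2]) -> {0: {"label": m1, "text": m2}}
    per_field = {}
    for key, meta in zip(keys, metas):
        index_str, field = key.split("@", 1)  # "0@label" -> ("0", "label")
        per_field.setdefault(int(index_str), {})[field] = meta
    return per_field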

@patch.object(KVStorageManager, "_connect_to_controller")
@patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock)
def test_put_data_without_custom_meta(self, mock_notify, mock_connect):
"""Test that put_data works correctly when storage client returns no custom_meta."""
# Create a mock storage client that returns None for custom_meta
mock_storage_client = MagicMock()
mock_storage_client.put.return_value = None

# Create manager with mocked dependencies
config = {"client_name": "MockClient"}
with patch(
"transfer_queue.storage.managers.base.StorageClientFactory.create", return_value=mock_storage_client
):
manager = KVStorageManager(config)

# Run put_data
asyncio.run(manager.put_data(self.data, self.metadata))

# Verify notify_data_update was called with empty dict for custom_meta
mock_notify.assert_called_once()
notify_call_args = mock_notify.call_args
per_field_custom_meta = notify_call_args[0][5] # 6th positional argument
self.assertEqual(per_field_custom_meta, {})

@patch.object(KVStorageManager, "_connect_to_controller")
@patch.object(KVStorageManager, "notify_data_update", new_callable=AsyncMock)
def test_put_data_custom_meta_length_mismatch_raises_error(self, mock_notify, mock_connect):
"""Test that put_data raises ValueError when custom_meta length doesn't match keys."""
# Create a mock storage client that returns mismatched custom_meta length
mock_storage_client = MagicMock()
# Return only 3 custom_meta entries when 6 are expected
mock_storage_client.put.return_value = [{"key": "1"}, {"key": "2"}, {"key": "3"}]

# Create manager with mocked dependencies
config = {"client_name": "MockClient"}
with patch(
"transfer_queue.storage.managers.base.StorageClientFactory.create", return_value=mock_storage_client
):
manager = KVStorageManager(config)

# Run put_data and expect ValueError
with self.assertRaises(ValueError) as context:
asyncio.run(manager.put_data(self.data, self.metadata))

self.assertIn("does not match", str(context.exception))


if __name__ == "__main__":
unittest.main()