[Performance] Lazy stack optimization for collector-to-buffer writes #3438
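For context, a minimal sketch of the write pattern this PR targets (the `obs` key, shapes, and the `dest` buffer are illustrative, not taken from the PR): collectors hand back a `LazyStackedTensorDict`, and assigning it into preallocated storage via plain indexing first materializes an intermediate contiguous copy, which the new code path avoids.

```python
import torch
from tensordict import TensorDict, LazyStackedTensorDict

# A collector typically yields a lazy stack of per-step tensordicts,
# not one contiguous tensordict.
steps = [TensorDict({"obs": torch.rand(4, 8)}, batch_size=()) for _ in range(10)]
lazy_td = LazyStackedTensorDict.lazy_stack(steps, dim=0)

# Preallocated destination, standing in for the replay-buffer storage.
dest = TensorDict({"obs": torch.zeros(100, 4, 8)}, batch_size=[100])

# Plain assignment materializes the stack, then copies it over:
dest[:10] = lazy_td

# The optimized path writes each stacked tensordict straight into the
# view returned by the slice, skipping the intermediate copy:
dest[10:20]._stack_onto_(list(lazy_td.unbind(0)), dim=0)
```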
@@ -918,6 +918,115 @@ def test_extend_lazystack(self, storage_type):
         rb.sample(3)
         assert len(rb) == 5
 
+    @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
+    def test_extend_lazystack_direct_write(self, storage_type):
+        """Test that lazy stacks can be extended to storage correctly.
+
+        This tests that lazy stacks from collectors are properly stored in
+        replay buffers and that data integrity is preserved. It also verifies
+        that the update_at_() optimization is used for tensor indices.
+        """
+        rb = ReplayBuffer(
+            storage=storage_type(100),
+            batch_size=10,
+        )
+        # Create a list of tensordicts (like a collector would produce)
+        tensordicts = [
+            TensorDict(
+                {"obs": torch.rand(4, 8), "action": torch.rand(2)}, batch_size=()
+            )
+            for _ in range(10)
+        ]
+        # Create a lazy stack with stack_dim=0 (the batch dimension)
+        lazy_td = LazyStackedTensorDict.lazy_stack(tensordicts, dim=0)
+        assert isinstance(lazy_td, LazyStackedTensorDict)
+
+        # Track calls to update_at_() - used for tensor indices
+        update_at_called = []
+        original_update_at = TensorDictBase.update_at_
+
+        def mock_update_at(self, *args, **kwargs):
+            update_at_called.append(True)
+            return original_update_at(self, *args, **kwargs)
+
+        # Extend with the lazy stack and verify that update_at_() is called
+        # (rb.extend uses tensor indices, so the update_at_() path is taken)
+        with mock.patch.object(TensorDictBase, "update_at_", mock_update_at):
+            rb.extend(lazy_td)
+
+        # Verify update_at_() was called (the optimization was used)
+        assert len(update_at_called) > 0, "update_at_() should have been called"
+
+        # Verify data integrity
+        assert len(rb) == 10
+        sample = rb.sample(5)
+        assert sample["obs"].shape == (5, 4, 8)
+        assert sample["action"].shape == (5, 2)
+
+        # Verify all data is accessible by reading the entire storage
+        all_data = rb[:]
+        assert all_data["obs"].shape == (10, 4, 8)
+        assert all_data["action"].shape == (10, 2)
+
+        # Verify data values are preserved (check against the original stacked data)
+        expected = lazy_td.to_tensordict()
+        assert torch.allclose(all_data["obs"], expected["obs"])
+        assert torch.allclose(all_data["action"], expected["action"])
+
+    @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
+    def test_extend_lazystack_2d_storage(self, storage_type):
+        """Test the lazy stack optimization for 2D storage (parallel envs).
+
+        When using parallel environments, the storage is 2D [max_size, n_steps]
+        and the lazy stack has stack_dim=1 (the time dimension). This test
+        verifies that the optimization handles this case correctly.
+        """
+        n_envs = 4
+        n_steps = 10
+        img_shape = (3, 32, 32)
+
+        # Create 2D storage - capacity is 100 * n_steps when ndim=2
+        storage = storage_type(100 * n_steps, ndim=2)
+
+        # Pre-initialize the storage with the correct shape by setting the first element
+        init_td = TensorDict(
+            {"pixels": torch.zeros(n_steps, *img_shape)},
+            batch_size=[n_steps],
+        )
+        storage.set(0, init_td, set_cursor=False)
Collaborator (Author): Per se, any write that uses a view (an int, some kinds of slices, etc.) should trigger something akin to `stack_onto_`.
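A small pure-PyTorch illustration of the view/copy distinction this comment points at (illustrative, not part of the PR): basic indexing (ints, slices) returns views that in-place writes propagate through, while advanced indexing (tensor or list indices) returns copies, which is why the tensor-index path needs an explicit update such as `update_at_()`.

```python
import torch

base = torch.zeros(4, 3)

# Basic indexing (int, slice) returns a view: in-place writes reach `base`.
view = base[1:3]
view.copy_(torch.ones(2, 3))
assert base[1].sum() == 3  # the write propagated through the view

# Advanced indexing (tensor/list indices) returns a copy: writing to it
# leaves `base` untouched, so a scatter-style write is needed instead.
idx = torch.tensor([0, 3])
copied = base[idx]
copied.copy_(torch.full((2, 3), 5.0))
assert base[0].sum() == 0  # unchanged

# The correct write for tensor indices goes through __setitem__ (index_put_):
base[idx] = torch.full((2, 3), 5.0)
assert base[0].sum() == 15
```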
+        # Expand the storage to full size
+        full_init = TensorDict(
+            {"pixels": torch.zeros(100, n_steps, *img_shape)},
+            batch_size=[100, n_steps],
+        )
+        storage.set(slice(0, 100), full_init, set_cursor=False)
+
+        # Create a lazy stack simulating parallel env output
+        # stack_dim=1 means stacked along the time dimension
+        time_tds = [
+            TensorDict(
+                {"pixels": torch.rand(n_envs, *img_shape)},
+                batch_size=[n_envs],
+            )
+            for _ in range(n_steps)
+        ]
+        lazy_td = LazyStackedTensorDict.lazy_stack(time_tds, dim=1)
+        assert lazy_td.stack_dim == 1
+        assert lazy_td.batch_size == torch.Size([n_envs, n_steps])
+
+        # Write using tensor indices (simulating circular buffer behavior)
+        cursor = torch.tensor([0, 1, 2, 3])
+        storage.set(cursor, lazy_td)
+
+        # Verify data integrity
+        for i in range(n_envs):
+            stored = storage[i]
+            expected = lazy_td[i].to_tensordict()
+            assert torch.allclose(
+                stored["pixels"], expected["pixels"]
+            ), f"Data mismatch for env {i}"
+
     @pytest.mark.parametrize("device_data", get_default_devices())
     @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
     @pytest.mark.parametrize("data_type", ["tensor", "tc", "td", "pytree"])
@@ -1057,7 +1057,22 @@ def set(
         if storage_keys is not None:
             data = data.select(*storage_keys, strict=False)
         try:
-            self._storage[cursor] = data
+            # Optimize lazy stack writes: write each tensordict directly to
+            # storage to avoid creating an intermediate contiguous copy.
+            if isinstance(data, LazyStackedTensorDict):
+                stack_dim = data.stack_dim
+                if isinstance(cursor, slice):
Collaborator (Author): Some slices don't return views; I think it depends on both the tensor layout and the slice.
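One way to probe that concern empirically (a hypothetical helper, not part of the PR): check whether an indexed tensor still aliases its source storage before relying on write-through. Whether this holds for composite storages such as memmapped tensordicts is exactly what the comment flags.

```python
import torch

def aliases_storage(indexed: torch.Tensor, source: torch.Tensor) -> bool:
    # Views share their underlying storage with the source tensor;
    # copies produced by advanced indexing do not.
    return indexed.untyped_storage().data_ptr() == source.untyped_storage().data_ptr()

t = torch.arange(12).reshape(3, 4)

assert aliases_storage(t[0:2], t)         # plain slice: a view
assert aliases_storage(t[:, 1:3], t)      # slicing an inner dim: still a view
assert not aliases_storage(t[[0, 2]], t)  # list index: a copy
```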
+                    # For slices, storage[slice] typically returns a view.
+                    # Use _stack_onto_ to write directly without intermediate copy.
+                    self._storage[cursor]._stack_onto_(
+                        list(data.unbind(stack_dim)), dim=stack_dim
+                    )
+                else:
+                    # For tensor/sequence indices, use update_at_ which handles
+                    # lazy stacks efficiently in a single call.
+                    self._storage.update_at_(data, cursor)
+            else:
+                self._storage[cursor] = data
         except RuntimeError as e:
             if "locked" in str(e).lower():
                 # Provide informative error about key differences
@@ -1128,7 +1143,22 @@ def set(  # noqa: F811
         if storage_keys is not None:
             data = data.select(*storage_keys, strict=False)
         try:
-            self._storage[cursor] = data
+            # Optimize lazy stack writes: write each tensordict directly to
+            # storage to avoid creating an intermediate contiguous copy.
+            if is_tensor_collection(data) and isinstance(data, LazyStackedTensorDict):
+                stack_dim = data.stack_dim
+                if isinstance(cursor, slice):
Collaborator (Author): Ditto.
| # For slices, storage[slice] typically returns a view. | ||
| # Use _stack_onto_ to write directly without intermediate copy. | ||
| self._storage[cursor]._stack_onto_( | ||
| list(data.unbind(stack_dim)), dim=stack_dim | ||
| ) | ||
| else: | ||
| # For tensor/sequence indices, use update_at_ which handles | ||
| # lazy stacks efficiently in a single call. | ||
| self._storage.update_at_(data, cursor) | ||
| else: | ||
| self._storage[cursor] = data | ||
| except RuntimeError as e: | ||
| if "locked" in str(e).lower(): | ||
| # Provide informative error about key differences | ||
|
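To make the intended effect concrete, a hedged usage sketch modeled on the PR's own tests (the buffer size and the `obs` key are illustrative): extending a replay buffer with a collector-style lazy stack now routes through the `update_at_()` branch instead of building a contiguous intermediate.

```python
import torch
from tensordict import TensorDict, LazyStackedTensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(100), batch_size=10)
steps = [TensorDict({"obs": torch.rand(8)}, batch_size=()) for _ in range(10)]
lazy_td = LazyStackedTensorDict.lazy_stack(steps, dim=0)

# rb.extend computes tensor indices for its circular cursor, so this write
# takes the update_at_() branch added above.
rb.extend(lazy_td)
assert len(rb) == 10
```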