diff --git a/benchmarks/test_collectors_benchmark.py b/benchmarks/test_collectors_benchmark.py
index c3887352b7d..789e4134edf 100644
--- a/benchmarks/test_collectors_benchmark.py
+++ b/benchmarks/test_collectors_benchmark.py
@@ -151,6 +151,69 @@ def execute_collector(c):
     next(c)
 
 
+def execute_collector_with_rb(c, rb):
+    """Execute one collector iteration; data is written to the replay buffer as a side effect."""
+    next(c)
+
+
+# --- Benchmarks for collector with replay buffer (lazy stack optimization) ---
+
+
+def single_collector_with_rb_setup():
+    """Set up a single collector with a replay buffer - tests the lazy stack optimization."""
+    device = "cuda:0" if torch.cuda.device_count() else "cpu"
+    env = TransformedEnv(DMControlEnv("cheetah", "run", device=device), StepCounter(50))
+    rb = ReplayBuffer(storage=LazyTensorStorage(10000))
+    c = SyncDataCollector(
+        env,
+        RandomPolicy(env.action_spec),
+        total_frames=-1,
+        frames_per_batch=100,
+        device=device,
+        replay_buffer=rb,
+    )
+    c = iter(c)
+    # Warmup
+    for i, _ in enumerate(c):
+        if i == 10:
+            break
+    return ((c, rb), {})
+
+
+def single_collector_with_rb_setup_pixels():
+    """Set up a single collector with a replay buffer for pixel observations."""
+    device = "cuda:0" if torch.cuda.device_count() else "cpu"
+    env = TransformedEnv(GymEnv("ALE/Pong-v5"), StepCounter(50))
+    rb = ReplayBuffer(storage=LazyTensorStorage(10000))
+    c = SyncDataCollector(
+        env,
+        RandomPolicy(env.action_spec),
+        total_frames=-1,
+        frames_per_batch=100,
+        device=device,
+        replay_buffer=rb,
+    )
+    c = iter(c)
+    # Warmup
+    for i, _ in enumerate(c):
+        if i == 10:
+            break
+    return ((c, rb), {})
+
+
+def test_single_with_rb(benchmark):
+    """Benchmark single collector with replay buffer (lazy stack path)."""
+    (c, rb), _ = single_collector_with_rb_setup()
+    benchmark(execute_collector_with_rb, c, rb)
+
+
+@pytest.mark.skipif(not torch.cuda.device_count(), reason="no rendering without cuda")
+def test_single_with_rb_pixels(benchmark):
+    """Benchmark single collector with replay buffer for pixel observations."""
+    (c, rb), _ = single_collector_with_rb_setup_pixels()
+    benchmark(execute_collector_with_rb, c, rb)
+
+
 def test_single(benchmark):
     (c,), _ = single_collector_setup()
     benchmark(execute_collector, c)
diff --git a/test/test_collectors.py b/test/test_collectors.py
index dd60eb431a4..8c246ea68c9 100644
--- a/test/test_collectors.py
+++ b/test/test_collectors.py
@@ -4051,6 +4051,123 @@ def test_collector_rb_sync(self):
         del collector, env
         assert assert_allclose_td(rbdata0, rbdata1)
 
+    @pytest.mark.skipif(not _has_gym, reason="requires gym.")
+    @pytest.mark.parametrize("storage_type", [LazyTensorStorage, LazyMemmapStorage])
+    def test_collector_with_rb_uses_lazy_stack(self, storage_type, tmpdir):
+        """Test that collector uses lazy stack path when replay buffer is provided.
+
+        This tests the optimization where collectors create lazy stacks instead of
+        materializing data into a contiguous buffer, allowing the storage to write
+        directly to its buffer without intermediate copies.
+ """ + if storage_type is LazyMemmapStorage: + storage = storage_type(1000, scratch_dir=tmpdir) + else: + storage = storage_type(1000) + + env = GymEnv(CARTPOLE_VERSIONED()) + env.set_seed(0) + rb = ReplayBuffer(storage=storage, batch_size=10) + collector = Collector( + env, + RandomPolicy(env.action_spec), + frames_per_batch=50, + total_frames=200, + replay_buffer=rb, + ) + torch.manual_seed(0) + + try: + # Track calls to update_at_() - used for tensor indices + update_at_called = [] + original_update_at = TensorDictBase.update_at_ + + def mock_update_at(self, *args, **kwargs): + update_at_called.append(True) + return original_update_at(self, *args, **kwargs) + + with patch.object(TensorDictBase, "update_at_", mock_update_at): + collected_frames = 0 + for data in collector: + # When replay buffer is used, collector yields None + assert data is None + collected_frames += 50 + + # Verify update_at_() was called (optimization was used) + assert len(update_at_called) > 0, "update_at_() should have been called" + + # Verify data was properly stored in the replay buffer + assert len(rb) == 200, f"Expected 200 frames in buffer, got {len(rb)}" + + # Sample and verify data integrity + sample = rb.sample(10) + assert "observation" in sample.keys() + assert "action" in sample.keys() + assert "next" in sample.keys() + assert sample["observation"].shape[0] == 10 + + # Verify we can sample multiple times without issues + for _ in range(5): + sample = rb.sample(20) + assert sample["observation"].shape[0] == 20 + finally: + collector.shutdown() + + @pytest.mark.skipif(not _has_gym, reason="requires gym.") + @pytest.mark.parametrize("storage_type", [LazyTensorStorage, LazyMemmapStorage]) + def test_collector_with_rb_parallel_env(self, storage_type, tmpdir): + """Test collector with replay buffer using parallel envs (2D storage). + + With parallel environments, the storage is 2D [max_size, n_steps] and the + lazy stack has stack_dim=1. This tests that data is correctly stored and + can be sampled from the replay buffer. 
+ """ + n_envs = 4 + + def make_env(): + return GymEnv(CARTPOLE_VERSIONED()) + + env = SerialEnv(n_envs, make_env) + env.set_seed(0) + + if storage_type is LazyMemmapStorage: + storage = storage_type(1000, scratch_dir=tmpdir, ndim=2) + else: + storage = storage_type(1000, ndim=2) + + rb = ReplayBuffer(storage=storage, batch_size=10) + collector = Collector( + env, + RandomPolicy(env.action_spec), + frames_per_batch=100, # 100 frames = 25 steps per env + total_frames=200, + replay_buffer=rb, + ) + torch.manual_seed(0) + + try: + collected_frames = 0 + for data in collector: + # When replay buffer is used, collector yields None + assert data is None + collected_frames += 100 + + # With 2D storage [n_rows, n_steps], len(rb) returns n_rows + # Each batch adds n_envs rows, so 2 batches = 8 rows + assert len(rb) >= 8, f"Expected >= 8 rows in buffer, got {len(rb)}" + + # Sample and verify data integrity + sample = rb.sample(4) + assert "observation" in sample.keys() + assert "action" in sample.keys() + assert "next" in sample.keys() + + # Verify we can sample multiple times without issues + for _ in range(5): + sample = rb.sample(4) + finally: + collector.shutdown() + @pytest.mark.skipif(not _has_gym, reason="requires gym.") @pytest.mark.parametrize("extend_buffer", [False, True]) @pytest.mark.parametrize("env_creator", [False, True]) diff --git a/test/test_rb.py b/test/test_rb.py index 95da8810606..3490938dfbe 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -918,6 +918,115 @@ def test_extend_lazystack(self, storage_type): rb.sample(3) assert len(rb) == 5 + @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) + def test_extend_lazystack_direct_write(self, storage_type): + """Test that lazy stacks can be extended to storage correctly. + + This tests that lazy stacks from collectors are properly stored in + replay buffers and that the data integrity is preserved. Also verifies + that the update_() optimization is used for tensor indices. 
+ """ + rb = ReplayBuffer( + storage=storage_type(100), + batch_size=10, + ) + # Create a list of tensordicts (like a collector would produce) + tensordicts = [ + TensorDict( + {"obs": torch.rand(4, 8), "action": torch.rand(2)}, batch_size=() + ) + for _ in range(10) + ] + # Create lazy stack with stack_dim=0 (the batch dimension) + lazy_td = LazyStackedTensorDict.lazy_stack(tensordicts, dim=0) + assert isinstance(lazy_td, LazyStackedTensorDict) + + # Track calls to update_at_() - used for tensor indices + update_at_called = [] + original_update_at = TensorDictBase.update_at_ + + def mock_update_at(self, *args, **kwargs): + update_at_called.append(True) + return original_update_at(self, *args, **kwargs) + + # Extend with lazy stack and verify update_at_() is called + # (rb.extend uses tensor indices, so update_at_() path is taken) + with mock.patch.object(TensorDictBase, "update_at_", mock_update_at): + rb.extend(lazy_td) + + # Verify update_at_() was called (optimization was used) + assert len(update_at_called) > 0, "update_at_() should have been called" + + # Verify data integrity + assert len(rb) == 10 + sample = rb.sample(5) + assert sample["obs"].shape == (5, 4, 8) + assert sample["action"].shape == (5, 2) + + # Verify all data is accessible by reading the entire storage + all_data = rb[:] + assert all_data["obs"].shape == (10, 4, 8) + assert all_data["action"].shape == (10, 2) + + # Verify data values are preserved (check against original stacked data) + expected = lazy_td.to_tensordict() + assert torch.allclose(all_data["obs"], expected["obs"]) + assert torch.allclose(all_data["action"], expected["action"]) + + @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) + def test_extend_lazystack_2d_storage(self, storage_type): + """Test lazy stack optimization for 2D storage (parallel envs). + + When using parallel environments, the storage is 2D [max_size, n_steps] + and the lazy stack has stack_dim=1 (time dimension). This test verifies + the optimization handles this case correctly. 
+ """ + n_envs = 4 + n_steps = 10 + img_shape = (3, 32, 32) + + # Create 2D storage - capacity is 100 * n_steps when ndim=2 + storage = storage_type(100 * n_steps, ndim=2) + + # Pre-initialize storage with correct shape by setting first element + init_td = TensorDict( + {"pixels": torch.zeros(n_steps, *img_shape)}, + batch_size=[n_steps], + ) + storage.set(0, init_td, set_cursor=False) + + # Expand storage to full size + full_init = TensorDict( + {"pixels": torch.zeros(100, n_steps, *img_shape)}, + batch_size=[100, n_steps], + ) + storage.set(slice(0, 100), full_init, set_cursor=False) + + # Create lazy stack simulating parallel env output + # stack_dim=1 means stacked along time dimension + time_tds = [ + TensorDict( + {"pixels": torch.rand(n_envs, *img_shape)}, + batch_size=[n_envs], + ) + for _ in range(n_steps) + ] + lazy_td = LazyStackedTensorDict.lazy_stack(time_tds, dim=1) + assert lazy_td.stack_dim == 1 + assert lazy_td.batch_size == torch.Size([n_envs, n_steps]) + + # Write using tensor indices (simulating circular buffer behavior) + cursor = torch.tensor([0, 1, 2, 3]) + storage.set(cursor, lazy_td) + + # Verify data integrity + for i in range(n_envs): + stored = storage[i] + expected = lazy_td[i].to_tensordict() + assert torch.allclose( + stored["pixels"], expected["pixels"] + ), f"Data mismatch for env {i}" + @pytest.mark.parametrize("device_data", get_default_devices()) @pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage]) @pytest.mark.parametrize("data_type", ["tensor", "tc", "td", "pytree"]) diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index 3afd9d0deb5..cdafacada44 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -1749,6 +1749,13 @@ def rollout(self) -> TensorDictBase: self._final_rollout.ndim - 1, out=self._final_rollout[..., : t + 1], ) + elif ( + self.replay_buffer is not None + and not self._ignore_rb + and self.extend_buffer + ): + # Use lazy stack for direct storage write optimization + result = LazyStackedTensorDict.lazy_stack(tensordicts, dim=-1) else: result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) break @@ -1775,6 +1782,15 @@ def rollout(self) -> TensorDictBase: and not self.extend_buffer ): return + elif ( + self.replay_buffer is not None + and not self._ignore_rb + and self.extend_buffer + ): + # Use lazy stack for direct storage write optimization. + # This avoids creating an intermediate contiguous copy - + # the storage will stack directly into its buffer. + result = LazyStackedTensorDict.lazy_stack(tensordicts, dim=-1) else: result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) result.refine_names(..., "time") diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index e6cdd64d583..3ce5de75793 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1057,7 +1057,22 @@ def set( if storage_keys is not None: data = data.select(*storage_keys, strict=False) try: - self._storage[cursor] = data + # Optimize lazy stack writes: write each tensordict directly to + # storage to avoid creating an intermediate contiguous copy. + if isinstance(data, LazyStackedTensorDict): + stack_dim = data.stack_dim + if isinstance(cursor, slice): + # For slices, storage[slice] typically returns a view. + # Use _stack_onto_ to write directly without intermediate copy. 
+                    self._storage[cursor]._stack_onto_(
+                        list(data.unbind(stack_dim)), dim=stack_dim
+                    )
+                else:
+                    # For tensor/sequence indices, use update_at_ which handles
+                    # lazy stacks efficiently in a single call.
+                    self._storage.update_at_(data, cursor)
+            else:
+                self._storage[cursor] = data
         except RuntimeError as e:
             if "locked" in str(e).lower():
                 # Provide informative error about key differences
@@ -1128,7 +1143,22 @@ def set(  # noqa: F811
         if storage_keys is not None:
             data = data.select(*storage_keys, strict=False)
         try:
-            self._storage[cursor] = data
+            # Optimize lazy stack writes: write each tensordict directly to
+            # storage to avoid creating an intermediate contiguous copy.
+            if is_tensor_collection(data) and isinstance(data, LazyStackedTensorDict):
+                stack_dim = data.stack_dim
+                if isinstance(cursor, slice):
+                    # For slices, storage[slice] typically returns a view.
+                    # Use _stack_onto_ to write directly without intermediate copy.
+                    self._storage[cursor]._stack_onto_(
+                        list(data.unbind(stack_dim)), dim=stack_dim
+                    )
+                else:
+                    # For tensor/sequence indices, use update_at_ which handles
+                    # lazy stacks efficiently in a single call.
+                    self._storage.update_at_(data, cursor)
+            else:
+                self._storage[cursor] = data
         except RuntimeError as e:
             if "locked" in str(e).lower():
                 # Provide informative error about key differences
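For orientation, below is a minimal sketch of the user-facing write path this patch optimizes, assuming the patch above is applied; the key names, sizes, and the `steps`/`rollout` variables are illustrative and not taken from the PR.

import torch
from tensordict import TensorDict, LazyStackedTensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

# A rollout kept as a list of per-step tensordicts, as a collector produces it.
steps = [
    TensorDict({"observation": torch.rand(4), "action": torch.rand(2)}, batch_size=[])
    for _ in range(8)
]

# Keep a lazy view over the steps instead of materializing a contiguous stack.
rollout = LazyStackedTensorDict.lazy_stack(steps, dim=0)

# With the patched TensorStorage.set(), the lazy stack is written via update_at_
# (tensor indices, the rb.extend path) or _stack_onto_ (slice indices), so no
# intermediate contiguous copy of the rollout is created.
rb = ReplayBuffer(storage=LazyTensorStorage(100), batch_size=4)
rb.extend(rollout)

assert len(rb) == 8
sample = rb.sample()  # a batch of 4 stored transitions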