Skip to content
63 changes: 63 additions & 0 deletions benchmarks/test_collectors_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,69 @@ def execute_collector(c):
next(c)


def execute_collector_with_rb(c, rb):
    """Advance the collector iterator by one batch.

    The collector writes into ``rb`` internally as a side effect of iteration;
    ``rb`` is accepted so the benchmark keeps the buffer alive for the whole
    run, but no explicit verification of the stored data is performed here
    (the previous docstring incorrectly claimed it was).
    """
    next(c)


# --- Benchmarks for collector with replay buffer (lazy stack optimization) ---


def single_collector_with_rb_setup():
    """Build a warmed-up (collector, replay buffer) pair for benchmarking.

    Exercises the lazy-stack optimization path taken when a replay buffer is
    handed directly to the collector.
    """
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    base_env = DMControlEnv("cheetah", "run", device=device)
    env = TransformedEnv(base_env, StepCounter(50))
    buffer = ReplayBuffer(storage=LazyTensorStorage(10000))
    collector = SyncDataCollector(
        env,
        RandomPolicy(env.action_spec),
        total_frames=-1,
        frames_per_batch=100,
        device=device,
        replay_buffer=buffer,
    )
    collector_iter = iter(collector)
    # Warmup: consume batches 0..10 before handing off to the benchmark.
    for step, _ in enumerate(collector_iter):
        if step >= 10:
            break
    return ((collector_iter, buffer), {})


def single_collector_with_rb_setup_pixels():
    """Build a warmed-up pixel-observation (collector, replay buffer) pair.

    Same as :func:`single_collector_with_rb_setup` but on an Atari environment
    so the batches carry image observations.
    """
    has_cuda = bool(torch.cuda.device_count())
    device = "cuda:0" if has_cuda else "cpu"
    env = TransformedEnv(GymEnv("ALE/Pong-v5"), StepCounter(50))
    buffer = ReplayBuffer(storage=LazyTensorStorage(10000))
    collector_iter = iter(
        SyncDataCollector(
            env,
            RandomPolicy(env.action_spec),
            total_frames=-1,
            frames_per_batch=100,
            device=device,
            replay_buffer=buffer,
        )
    )
    # Warmup: consume batches 0..10 before handing off to the benchmark.
    for step, _ in enumerate(collector_iter):
        if step >= 10:
            break
    return ((collector_iter, buffer), {})


def test_single_with_rb(benchmark):
    """Benchmark a single collector writing through a replay buffer (lazy stack path)."""
    (collector, buffer), _unused = single_collector_with_rb_setup()
    benchmark(execute_collector_with_rb, collector, buffer)


@pytest.mark.skipif(not torch.cuda.device_count(), reason="no rendering without cuda")
def test_single_with_rb_pixels(benchmark):
    """Benchmark a single collector with replay buffer on pixel observations."""
    args, _unused = single_collector_with_rb_setup_pixels()
    benchmark(execute_collector_with_rb, *args)


def test_single(benchmark):
    """Benchmark a single collector iterating without a replay buffer."""
    (collector,), _unused = single_collector_setup()
    benchmark(execute_collector, collector)
Expand Down
117 changes: 117 additions & 0 deletions test/test_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4051,6 +4051,123 @@ def test_collector_rb_sync(self):
del collector, env
assert assert_allclose_td(rbdata0, rbdata1)

@pytest.mark.skipif(not _has_gym, reason="requires gym.")
@pytest.mark.parametrize("storage_type", [LazyTensorStorage, LazyMemmapStorage])
def test_collector_with_rb_uses_lazy_stack(self, storage_type, tmpdir):
    """Test that collector uses lazy stack path when replay buffer is provided.

    This tests the optimization where collectors create lazy stacks instead of
    materializing data into a contiguous buffer, allowing the storage to write
    directly to its buffer without intermediate copies.
    """
    # LazyMemmapStorage needs an on-disk scratch directory.
    if storage_type is LazyMemmapStorage:
        storage = storage_type(1000, scratch_dir=tmpdir)
    else:
        storage = storage_type(1000)

    env = GymEnv(CARTPOLE_VERSIONED())
    env.set_seed(0)
    rb = ReplayBuffer(storage=storage, batch_size=10)
    collector = Collector(
        env,
        RandomPolicy(env.action_spec),
        frames_per_batch=50,
        total_frames=200,
        replay_buffer=rb,
    )
    torch.manual_seed(0)

    try:
        # Track calls to update_at_() - used for tensor indices
        update_at_called = []
        original_update_at = TensorDictBase.update_at_

        def mock_update_at(self, *args, **kwargs):
            update_at_called.append(True)
            return original_update_at(self, *args, **kwargs)

        with patch.object(TensorDictBase, "update_at_", mock_update_at):
            collected_frames = 0
            for data in collector:
                # When replay buffer is used, collector yields None
                assert data is None
                collected_frames += 50

        # Fix: collected_frames was previously accumulated but never checked.
        # Verify the collector ran for exactly the requested frame budget.
        assert collected_frames == 200

        # Verify update_at_() was called (optimization was used)
        assert len(update_at_called) > 0, "update_at_() should have been called"

        # Verify data was properly stored in the replay buffer
        assert len(rb) == 200, f"Expected 200 frames in buffer, got {len(rb)}"

        # Sample and verify data integrity
        sample = rb.sample(10)
        assert "observation" in sample.keys()
        assert "action" in sample.keys()
        assert "next" in sample.keys()
        assert sample["observation"].shape[0] == 10

        # Verify we can sample multiple times without issues
        for _ in range(5):
            sample = rb.sample(20)
            assert sample["observation"].shape[0] == 20
    finally:
        collector.shutdown()

@pytest.mark.skipif(not _has_gym, reason="requires gym.")
@pytest.mark.parametrize("storage_type", [LazyTensorStorage, LazyMemmapStorage])
def test_collector_with_rb_parallel_env(self, storage_type, tmpdir):
    """Test collector with replay buffer using parallel envs (2D storage).

    With parallel environments, the storage is 2D [max_size, n_steps] and the
    lazy stack has stack_dim=1. This tests that data is correctly stored and
    can be sampled from the replay buffer.
    """
    n_envs = 4

    def make_env():
        return GymEnv(CARTPOLE_VERSIONED())

    env = SerialEnv(n_envs, make_env)
    env.set_seed(0)

    if storage_type is LazyMemmapStorage:
        storage = storage_type(1000, scratch_dir=tmpdir, ndim=2)
    else:
        storage = storage_type(1000, ndim=2)

    rb = ReplayBuffer(storage=storage, batch_size=10)
    collector = Collector(
        env,
        RandomPolicy(env.action_spec),
        frames_per_batch=100,  # 100 frames = 25 steps per env
        total_frames=200,
        replay_buffer=rb,
    )
    torch.manual_seed(0)

    try:
        collected_frames = 0
        for data in collector:
            # When replay buffer is used, collector yields None
            assert data is None
            collected_frames += 100

        # Fix: collected_frames was previously accumulated but never checked.
        assert collected_frames == 200

        # With 2D storage [n_rows, n_steps], len(rb) returns n_rows
        # Each batch adds n_envs rows, so 2 batches = 8 rows
        assert len(rb) >= 8, f"Expected >= 8 rows in buffer, got {len(rb)}"

        # Sample and verify data integrity
        sample = rb.sample(4)
        assert "observation" in sample.keys()
        assert "action" in sample.keys()
        assert "next" in sample.keys()

        # Verify we can sample multiple times without issues
        for _ in range(5):
            sample = rb.sample(4)
    finally:
        collector.shutdown()

@pytest.mark.skipif(not _has_gym, reason="requires gym.")
@pytest.mark.parametrize("extend_buffer", [False, True])
@pytest.mark.parametrize("env_creator", [False, True])
Expand Down
109 changes: 109 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,115 @@ def test_extend_lazystack(self, storage_type):
rb.sample(3)
assert len(rb) == 5

@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
def test_extend_lazystack_direct_write(self, storage_type):
"""Test that lazy stacks can be extended to storage correctly.

This tests that lazy stacks from collectors are properly stored in
replay buffers and that the data integrity is preserved. Also verifies
that the update_() optimization is used for tensor indices.
"""
rb = ReplayBuffer(
storage=storage_type(100),
batch_size=10,
)
# Create a list of tensordicts (like a collector would produce)
tensordicts = [
TensorDict(
{"obs": torch.rand(4, 8), "action": torch.rand(2)}, batch_size=()
)
for _ in range(10)
]
# Create lazy stack with stack_dim=0 (the batch dimension)
lazy_td = LazyStackedTensorDict.lazy_stack(tensordicts, dim=0)
assert isinstance(lazy_td, LazyStackedTensorDict)

# Track calls to update_at_() - used for tensor indices
update_at_called = []
original_update_at = TensorDictBase.update_at_

def mock_update_at(self, *args, **kwargs):
update_at_called.append(True)
return original_update_at(self, *args, **kwargs)

# Extend with lazy stack and verify update_at_() is called
# (rb.extend uses tensor indices, so update_at_() path is taken)
with mock.patch.object(TensorDictBase, "update_at_", mock_update_at):
rb.extend(lazy_td)

# Verify update_at_() was called (optimization was used)
assert len(update_at_called) > 0, "update_at_() should have been called"

# Verify data integrity
assert len(rb) == 10
sample = rb.sample(5)
assert sample["obs"].shape == (5, 4, 8)
assert sample["action"].shape == (5, 2)

# Verify all data is accessible by reading the entire storage
all_data = rb[:]
assert all_data["obs"].shape == (10, 4, 8)
assert all_data["action"].shape == (10, 2)

# Verify data values are preserved (check against original stacked data)
expected = lazy_td.to_tensordict()
assert torch.allclose(all_data["obs"], expected["obs"])
assert torch.allclose(all_data["action"], expected["action"])

@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
def test_extend_lazystack_2d_storage(self, storage_type):
"""Test lazy stack optimization for 2D storage (parallel envs).

When using parallel environments, the storage is 2D [max_size, n_steps]
and the lazy stack has stack_dim=1 (time dimension). This test verifies
the optimization handles this case correctly.
"""
n_envs = 4
n_steps = 10
img_shape = (3, 32, 32)

# Create 2D storage - capacity is 100 * n_steps when ndim=2
storage = storage_type(100 * n_steps, ndim=2)

# Pre-initialize storage with correct shape by setting first element
init_td = TensorDict(
{"pixels": torch.zeros(n_steps, *img_shape)},
batch_size=[n_steps],
)
storage.set(0, init_td, set_cursor=False)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per se any write that uses a view (with int, some kinds of slice etc) should trigger something akin to stack_onto_
Im not sure if PyTorch has a utility that given a tensor tells you if an index will return a view or a copy. If that does not exist we should invent it.


# Expand storage to full size
full_init = TensorDict(
{"pixels": torch.zeros(100, n_steps, *img_shape)},
batch_size=[100, n_steps],
)
storage.set(slice(0, 100), full_init, set_cursor=False)

# Create lazy stack simulating parallel env output
# stack_dim=1 means stacked along time dimension
time_tds = [
TensorDict(
{"pixels": torch.rand(n_envs, *img_shape)},
batch_size=[n_envs],
)
for _ in range(n_steps)
]
lazy_td = LazyStackedTensorDict.lazy_stack(time_tds, dim=1)
assert lazy_td.stack_dim == 1
assert lazy_td.batch_size == torch.Size([n_envs, n_steps])

# Write using tensor indices (simulating circular buffer behavior)
cursor = torch.tensor([0, 1, 2, 3])
storage.set(cursor, lazy_td)

# Verify data integrity
for i in range(n_envs):
stored = storage[i]
expected = lazy_td[i].to_tensordict()
assert torch.allclose(
stored["pixels"], expected["pixels"]
), f"Data mismatch for env {i}"

@pytest.mark.parametrize("device_data", get_default_devices())
@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
@pytest.mark.parametrize("data_type", ["tensor", "tc", "td", "pytree"])
Expand Down
16 changes: 16 additions & 0 deletions torchrl/collectors/_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,6 +1749,13 @@ def rollout(self) -> TensorDictBase:
self._final_rollout.ndim - 1,
out=self._final_rollout[..., : t + 1],
)
elif (
self.replay_buffer is not None
and not self._ignore_rb
and self.extend_buffer
):
# Use lazy stack for direct storage write optimization
result = LazyStackedTensorDict.lazy_stack(tensordicts, dim=-1)
else:
result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
break
Expand All @@ -1775,6 +1782,15 @@ def rollout(self) -> TensorDictBase:
and not self.extend_buffer
):
return
elif (
self.replay_buffer is not None
and not self._ignore_rb
and self.extend_buffer
):
# Use lazy stack for direct storage write optimization.
# This avoids creating an intermediate contiguous copy -
# the storage will stack directly into its buffer.
result = LazyStackedTensorDict.lazy_stack(tensordicts, dim=-1)
else:
result = TensorDict.maybe_dense_stack(tensordicts, dim=-1)
result.refine_names(..., "time")
Expand Down
34 changes: 32 additions & 2 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,22 @@ def set(
if storage_keys is not None:
data = data.select(*storage_keys, strict=False)
try:
self._storage[cursor] = data
# Optimize lazy stack writes: write each tensordict directly to
# storage to avoid creating an intermediate contiguous copy.
if isinstance(data, LazyStackedTensorDict):
stack_dim = data.stack_dim
if isinstance(cursor, slice):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some slices don't return views. I think it depends both on the tensor layout and the slice

# For slices, storage[slice] typically returns a view.
# Use _stack_onto_ to write directly without intermediate copy.
self._storage[cursor]._stack_onto_(
list(data.unbind(stack_dim)), dim=stack_dim
)
else:
# For tensor/sequence indices, use update_at_ which handles
# lazy stacks efficiently in a single call.
self._storage.update_at_(data, cursor)
else:
self._storage[cursor] = data
except RuntimeError as e:
if "locked" in str(e).lower():
# Provide informative error about key differences
Expand Down Expand Up @@ -1128,7 +1143,22 @@ def set( # noqa: F811
if storage_keys is not None:
data = data.select(*storage_keys, strict=False)
try:
self._storage[cursor] = data
# Optimize lazy stack writes: write each tensordict directly to
# storage to avoid creating an intermediate contiguous copy.
if is_tensor_collection(data) and isinstance(data, LazyStackedTensorDict):
stack_dim = data.stack_dim
if isinstance(cursor, slice):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

# For slices, storage[slice] typically returns a view.
# Use _stack_onto_ to write directly without intermediate copy.
self._storage[cursor]._stack_onto_(
list(data.unbind(stack_dim)), dim=stack_dim
)
else:
# For tensor/sequence indices, use update_at_ which handles
# lazy stacks efficiently in a single call.
self._storage.update_at_(data, cursor)
else:
self._storage[cursor] = data
except RuntimeError as e:
if "locked" in str(e).lower():
# Provide informative error about key differences
Expand Down
Loading