18 changes: 18 additions & 0 deletions agentlightning/llm_proxy.py
@@ -210,6 +210,11 @@ class LightningSpanExporter(SpanExporter):
internal loop, then waits for completion.
"""

# Maximum number of spans to hold in the buffer before dropping old entries.
# This prevents unbounded memory growth when spans cannot be flushed (e.g.,
# missing headers, store unavailable).
MAX_BUFFER_SIZE = 10000

def __init__(self, _store: Optional[LightningStore] = None):
self._store: Optional[LightningStore] = _store # this is only for testing purposes
self._buffer: List[ReadableSpan] = []
@@ -307,6 +312,19 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
with self._ensure_lock():
for span in spans:
self._buffer.append(span)

# Prevent unbounded buffer growth by dropping oldest entries when
# the buffer exceeds MAX_BUFFER_SIZE. This can happen when spans
# are missing required headers and can never be flushed.
if len(self._buffer) > self.MAX_BUFFER_SIZE:
n_drop = len(self._buffer) - self.MAX_BUFFER_SIZE
logger.warning(
"Span exporter buffer exceeded %d entries. Dropping %d oldest spans to prevent memory leak.",
self.MAX_BUFFER_SIZE,
n_drop,
)
self._buffer = self._buffer[n_drop:]

default_endpoint = self._otlp_exporter._endpoint # pyright: ignore[reportPrivateUsage]
try:
self._maybe_flush()
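The capping logic above can be illustrated on its own. A minimal sketch of the drop-oldest behavior, using a plain list and a tiny cap in place of `ReadableSpan` objects and the real `MAX_BUFFER_SIZE` (the `buffer` and `incoming` names and the value 5 are illustrative only):

```python
# Illustrative cap, much smaller than the real MAX_BUFFER_SIZE of 10000.
MAX_BUFFER_SIZE = 5

buffer = []
incoming = list(range(12))   # stand-ins for a batch of ReadableSpan objects
for span in incoming:
    buffer.append(span)

# Cap once per export call: keep only the newest MAX_BUFFER_SIZE entries.
if len(buffer) > MAX_BUFFER_SIZE:
    n_drop = len(buffer) - MAX_BUFFER_SIZE
    buffer = buffer[n_drop:]

print(buffer)  # [7, 8, 9, 10, 11] -- the 7 oldest spans were dropped
```

Since spans that can never be flushed (for example, ones missing the required headers) would otherwise sit in the buffer forever, this bound is what keeps the exporter's memory flat over long runs.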
78 changes: 78 additions & 0 deletions agentlightning/store/memory.py
@@ -378,3 +378,81 @@ async def _evict_spans_for_rollout(self, collections: InMemoryLightningCollectio
# There is something removed for real
self._total_span_bytes = max(self._total_span_bytes - removed_bytes, 0)
self._evicted_rollout_span_sets.add(rollout_id)

@tracked("cleanup_finished_rollouts")
async def cleanup_finished_rollouts(self, rollout_ids: Optional[List[str]] = None) -> int:
"""Remove all data associated with finished rollouts to free memory.

This should be called after training data has been extracted from completed
rollouts (e.g., after get_train_data_batch or get_test_metrics). It removes
rollouts, their attempts, spans, and associated tracking metadata from all
in-memory data structures.

Args:
rollout_ids: Optional list of rollout IDs to clean up. If None, all
finished rollouts are cleaned up; IDs that do not exist or are not
yet finished are skipped.

Returns:
The number of rollouts cleaned up.
"""
cleaned_count = 0

async with self.collections.atomic(
mode="rw", snapshot=self._read_snapshot,
labels=["rollouts", "attempts", "spans", "span_sequence_ids"],
) as collections:
# Determine which rollouts to clean up
if rollout_ids is None:
all_rollouts = await collections.rollouts.query()
target_ids = [
r.rollout_id for r in all_rollouts.items if is_finished(r)
]
else:
target_ids = list(rollout_ids)

for rollout_id in target_ids:
rollout = await collections.rollouts.get(
{"rollout_id": {"exact": rollout_id}}
)
if rollout is None:
continue
if not is_finished(rollout):
continue

# Remove spans for this rollout
await collections.evict_spans_for_rollout(rollout_id)

# Remove attempts for this rollout
attempts_result = await collections.attempts.query(
filter={"rollout_id": {"exact": rollout_id}}
)
if attempts_result.items:
await collections.attempts.delete(attempts_result.items)

# Remove the rollout itself
await collections.rollouts.delete([rollout])

# Remove span sequence ID tracking
await collections.span_sequence_ids.pop(rollout_id)

cleaned_count += 1

# Clean up auxiliary tracking dicts outside the collection lock
for rollout_id in target_ids:
self._completion_events.pop(rollout_id, None)
self._start_time_by_rollout.pop(rollout_id, None)
self._span_bytes_by_rollout.pop(rollout_id, None)
self._running_rollout_ids.discard(rollout_id)
self._evicted_rollout_span_sets.discard(rollout_id)

if cleaned_count > 0:
logger.info(
"Cleaned up %d finished rollouts. Completion events: %d, "
"start time entries: %d, span byte entries: %d",
cleaned_count,
len(self._completion_events),
len(self._start_time_by_rollout),
len(self._span_bytes_by_rollout),
)

return cleaned_count
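A hedged usage sketch for the new method, assuming `store` is an `InMemoryLightningStore` that already holds some finished rollouts and `finished_ids` is the list of rollout IDs whose training data has already been extracted (both names, and the wrapper function, are illustrative):

```python
import asyncio

from agentlightning.store.memory import InMemoryLightningStore

async def release_finished(store: InMemoryLightningStore, finished_ids):
    # Targeted cleanup: only the listed rollouts are touched, and IDs that
    # are missing or not yet finished are silently skipped.
    cleaned = await store.cleanup_finished_rollouts(finished_ids)

    # Blanket cleanup: with no argument, every finished rollout still held
    # in memory is removed.
    cleaned += await store.cleanup_finished_rollouts()
    return cleaned

# e.g. asyncio.run(release_finished(store, finished_ids))
```

The integer return value makes it easy to log how much was actually reclaimed, mirroring the `logger.info` call inside the method.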
39 changes: 35 additions & 4 deletions agentlightning/verl/daemon.py
@@ -1134,14 +1134,45 @@ def get_train_data_batch(
return data_proto, data_metrics

def clear_data_and_server(self):
"""Resets the internal state of the daemon for the next run."""
"""Resets the internal state of the daemon for the next run.

Also cleans up finished rollouts from the store to prevent memory leaks
during long training runs.
"""
# Collect rollout IDs before clearing so we can clean up the store.
# In v1 mode, _task_id_to_original_sample maps rollout_id -> sample.
finished_rollout_ids = list(self._task_id_to_original_sample.keys())

self.backend_llm_server_addresses = []
self._completed_rollouts_v0.clear()
self._task_id_to_original_sample.clear()
self._total_tasks_queued = 0
# For a true reset, the server's internal queues would also need clearing.
# This implementation assumes that `set_up_data_and_server` is called
# for each new run, effectively starting a fresh batch.

# Clean up finished rollouts from the store to free memory.
# This is critical for long training runs where rollout data accumulates
# in the in-memory store across training steps, causing OOM.
if self.mode == "v1" and finished_rollout_ids:
self._cleanup_store_rollouts(finished_rollout_ids)

def _cleanup_store_rollouts(self, rollout_ids: List[str]):
"""Clean up finished rollouts from the store to prevent memory buildup.

This runs the async cleanup synchronously on the internal event loop.
"""
from agentlightning.store.memory import InMemoryLightningStore

if not isinstance(self.store, InMemoryLightningStore):
return

if self._internal_loop is None:
return

try:
coro = self.store.cleanup_finished_rollouts(rollout_ids)
future = asyncio.run_coroutine_threadsafe(coro, self._internal_loop)
future.result(timeout=60)
except Exception as e:
print(f"Warning: Failed to clean up store rollouts: {e}")

def _fillna_reward(self, rollout: RolloutLegacy):
if rollout.final_reward is None:
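The `asyncio.run_coroutine_threadsafe` call in `_cleanup_store_rollouts` is the standard way to run a coroutine on an event loop owned by another thread and block until it finishes. A standalone sketch of that pattern (the `cleanup` coroutine and the thread setup are illustrative, not the daemon's actual internals):

```python
import asyncio
import threading

# An event loop running in a background thread, standing in for the
# daemon's internal loop.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def cleanup(rollout_ids):
    await asyncio.sleep(0)      # stand-in for the real async store cleanup
    return len(rollout_ids)

# Schedule the coroutine on the background loop and wait (with a timeout)
# for its result from the calling thread.
future = asyncio.run_coroutine_threadsafe(cleanup(["r1", "r2"]), loop)
print(future.result(timeout=60))  # -> 2

loop.call_soon_threadsafe(loop.stop)
```

The timeout mirrors the 60-second bound in the diff, so a stuck cleanup degrades into a warning rather than hanging the training step.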
12 changes: 12 additions & 0 deletions agentlightning/verl/trainer.py
@@ -4,6 +4,7 @@

from __future__ import annotations

import gc
import random
from contextlib import contextmanager
from copy import deepcopy
@@ -272,6 +273,10 @@ def _train_step(self, batch_dict: dict) -> dict:
self.agent_mode_daemon.clear_data_and_server()
self.async_rollout_manager.sleep()

# Release the original input batch to free memory now that training data
# has been extracted from the daemon. Keep it alive when the REMAX branch
# below still needs it for the baseline generation pass.
if self.config.algorithm.adv_estimator != AdvantageEstimator.REMAX:
    del gen_batch

if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with _timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
@@ -433,6 +438,13 @@ def _train_step(self, batch_dict: dict) -> dict:
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))

# Explicitly release batch tensors and trigger garbage collection to
# prevent memory accumulation across training steps.
del batch
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()

return metrics

def fit(self):
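The release pattern at the end of `_train_step` can be exercised in isolation. A hedged, self-contained sketch, with the `batch` dict and its tensor shape invented purely for illustration:

```python
import gc

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
batch = {"input_ids": torch.zeros(8, 4096, device=device)}

# Drop the Python reference so the tensors become collectable...
del batch
# ...collect any lingering reference cycles that might still pin them...
gc.collect()
# ...and return cached, now-unused CUDA blocks to the driver so the memory
# shows up as free outside the PyTorch allocator as well.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
```

Note that `torch.cuda.empty_cache()` only releases blocks the allocator is no longer using; the `del` and `gc.collect()` beforehand are what make those blocks unused in the first place.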