diff --git a/agentlightning/llm_proxy.py b/agentlightning/llm_proxy.py
index 42368edef..89ae4a0ee 100644
--- a/agentlightning/llm_proxy.py
+++ b/agentlightning/llm_proxy.py
@@ -210,6 +210,11 @@ class LightningSpanExporter(SpanExporter):
     internal loop, then waits for completion.
     """
 
+    # Maximum number of spans to hold in the buffer before dropping old entries.
+    # This prevents unbounded memory growth when spans cannot be flushed (e.g.,
+    # missing headers, store unavailable).
+    MAX_BUFFER_SIZE = 10000
+
     def __init__(self, _store: Optional[LightningStore] = None):
         self._store: Optional[LightningStore] = _store  # this is only for testing purposes
         self._buffer: List[ReadableSpan] = []
@@ -307,6 +312,19 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
         with self._ensure_lock():
             for span in spans:
                 self._buffer.append(span)
+
+            # Prevent unbounded buffer growth by dropping oldest entries when
+            # the buffer exceeds MAX_BUFFER_SIZE. This can happen when spans
+            # are missing required headers and can never be flushed.
+            if len(self._buffer) > self.MAX_BUFFER_SIZE:
+                n_drop = len(self._buffer) - self.MAX_BUFFER_SIZE
+                logger.warning(
+                    "Span exporter buffer exceeded %d entries. Dropping %d oldest spans to prevent memory leak.",
+                    self.MAX_BUFFER_SIZE,
+                    n_drop,
+                )
+                self._buffer = self._buffer[n_drop:]
+
         default_endpoint = self._otlp_exporter._endpoint  # pyright: ignore[reportPrivateUsage]
         try:
             self._maybe_flush()
diff --git a/agentlightning/store/memory.py b/agentlightning/store/memory.py
index ec97963b7..582e8866f 100644
--- a/agentlightning/store/memory.py
+++ b/agentlightning/store/memory.py
@@ -378,3 +378,81 @@ async def _evict_spans_for_rollout(self, collections: InMemoryLightningCollectio
             # There is something removed for real
             self._total_span_bytes = max(self._total_span_bytes - removed_bytes, 0)
             self._evicted_rollout_span_sets.add(rollout_id)
+
+    @tracked("cleanup_finished_rollouts")
+    async def cleanup_finished_rollouts(self, rollout_ids=None):
+        """Remove all data associated with finished rollouts to free memory.
+
+        This should be called after training data has been extracted from completed
+        rollouts (e.g., after get_train_data_batch or get_test_metrics). It removes
+        rollouts, their attempts, spans, and associated tracking metadata from all
+        in-memory data structures.
+
+        Args:
+            rollout_ids: Optional list of rollout IDs to clean up. If None, all
+                finished rollouts will be cleaned up.
+
+        Returns:
+            The number of rollouts cleaned up.
+        """
+        cleaned_ids = []
+
+        async with self.collections.atomic(
+            mode="rw", snapshot=self._read_snapshot,
+            labels=["rollouts", "attempts", "spans", "span_sequence_ids"],
+        ) as collections:
+            # Determine which rollouts to clean up
+            if rollout_ids is None:
+                all_rollouts = await collections.rollouts.query()
+                target_ids = [
+                    r.rollout_id for r in all_rollouts.items if is_finished(r)
+                ]
+            else:
+                target_ids = list(rollout_ids)
+
+            for rollout_id in target_ids:
+                rollout = await collections.rollouts.get(
+                    {"rollout_id": {"exact": rollout_id}}
+                )
+                if rollout is None:
+                    continue
+                if not is_finished(rollout):
+                    continue
+
+                # Remove spans for this rollout
+                await self._evict_spans_for_rollout(collections, rollout_id)
+
+                # Remove attempts for this rollout
+                attempts_result = await collections.attempts.query(
+                    filter={"rollout_id": {"exact": rollout_id}}
+                )
+                if attempts_result.items:
+                    await collections.attempts.delete(attempts_result.items)
+
+                # Remove the rollout itself
+                await collections.rollouts.delete([rollout])
+
+                # Remove span sequence ID tracking
+                await collections.span_sequence_ids.pop(rollout_id)
+
+                cleaned_ids.append(rollout_id)
+
+        # Clean up auxiliary tracking dicts outside the collection lock
+        for rollout_id in cleaned_ids:
+            self._completion_events.pop(rollout_id, None)
+            self._start_time_by_rollout.pop(rollout_id, None)
+            self._span_bytes_by_rollout.pop(rollout_id, None)
+            self._running_rollout_ids.discard(rollout_id)
+            self._evicted_rollout_span_sets.discard(rollout_id)
+
+        if cleaned_ids:
+            logger.info(
+                "Cleaned up %d finished rollouts. Completion events: %d, "
+                "start time entries: %d, span byte entries: %d",
+                len(cleaned_ids),
+                len(self._completion_events),
+                len(self._start_time_by_rollout),
+                len(self._span_bytes_by_rollout),
+            )
+
+        return len(cleaned_ids)
diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py
index 98c58f330..3b37129d8 100644
--- a/agentlightning/verl/daemon.py
+++ b/agentlightning/verl/daemon.py
@@ -1134,14 +1134,45 @@ def get_train_data_batch(
         return data_proto, data_metrics
 
     def clear_data_and_server(self):
-        """Resets the internal state of the daemon for the next run."""
+        """Resets the internal state of the daemon for the next run.
+
+        Also cleans up finished rollouts from the store to prevent memory leaks
+        during long training runs.
+        """
+        # Collect rollout IDs before clearing so we can clean up the store.
+        # In v1 mode, _task_id_to_original_sample maps rollout_id -> sample.
+        finished_rollout_ids = list(self._task_id_to_original_sample.keys())
+
         self.backend_llm_server_addresses = []
         self._completed_rollouts_v0.clear()
         self._task_id_to_original_sample.clear()
         self._total_tasks_queued = 0
-        # For a true reset, the server's internal queues would also need clearing.
-        # This implementation assumes that `set_up_data_and_server` is called
-        # for each new run, effectively starting a fresh batch.
+
+        # Clean up finished rollouts from the store to free memory.
+        # This is critical for long training runs where rollout data accumulates
+        # in the in-memory store across training steps, causing OOM.
+        if self.mode == "v1" and finished_rollout_ids:
+            self._cleanup_store_rollouts(finished_rollout_ids)
+
+    def _cleanup_store_rollouts(self, rollout_ids: List[str]):
+        """Clean up finished rollouts from the store to prevent memory buildup.
+
+        This runs the async cleanup synchronously on the internal event loop.
+        """
+        from agentlightning.store.memory import InMemoryLightningStore
+
+        if not isinstance(self.store, InMemoryLightningStore):
+            return
+
+        if self._internal_loop is None:
+            return
+
+        try:
+            coro = self.store.cleanup_finished_rollouts(rollout_ids)
+            future = asyncio.run_coroutine_threadsafe(coro, self._internal_loop)
+            future.result(timeout=60)
+        except Exception as e:
+            print(f"Warning: Failed to clean up store rollouts: {e}")
 
 def _fillna_reward(self, rollout: RolloutLegacy):
     if rollout.final_reward is None:
diff --git a/agentlightning/verl/trainer.py b/agentlightning/verl/trainer.py
index 413a0a1cf..7e8fb6013 100644
--- a/agentlightning/verl/trainer.py
+++ b/agentlightning/verl/trainer.py
@@ -4,6 +4,7 @@
 
 from __future__ import annotations
 
+import gc
 import random
 from contextlib import contextmanager
 from copy import deepcopy
@@ -272,6 +273,10 @@ def _train_step(self, batch_dict: dict) -> dict:
             self.agent_mode_daemon.clear_data_and_server()
             self.async_rollout_manager.sleep()
 
+        # Free the input batch early; REMAX still reads gen_batch below, so skip.
+        if self.config.algorithm.adv_estimator != AdvantageEstimator.REMAX:
+            del gen_batch
+
         if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
             with _timer("gen_max", timing_raw):
                 gen_baseline_batch = deepcopy(gen_batch)
@@ -433,6 +438,13 @@ def _train_step(self, batch_dict: dict) -> dict:
         n_gpus = self.resource_pool_manager.get_n_gpus()
         metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
 
+        # Explicitly release batch tensors and trigger garbage collection to
+        # prevent memory accumulation across training steps.
+        del batch
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         return metrics
 
     def fit(self):