From 3ca3702b826c60c2e08adf35708a5b272e87e167 Mon Sep 17 00:00:00 2001 From: Andrew Choi Date: Fri, 13 Mar 2026 12:09:56 -0700 Subject: [PATCH 1/2] Add average_all_summaries context manager --- alf/summary/summary_ops.py | 55 ++++++++++++++++++++++- alf/summary/summary_ops_test.py | 78 +++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 1 deletion(-) diff --git a/alf/summary/summary_ops.py b/alf/summary/summary_ops.py index 76cb3107f..db6fce629 100644 --- a/alf/summary/summary_ops.py +++ b/alf/summary/summary_ops.py @@ -16,9 +16,11 @@ import functools import numpy as np import torch +import contextlib from torch.utils.tensorboard import SummaryWriter -from typing import Callable, Union +from typing import Callable, Union, List from alf.utils.schedulers import update_progress +import alf # These will be used by orig_tf_gfile_context() in alf.utils.common TF_IO_GFILE = None @@ -99,6 +101,57 @@ def __exit__(self, type, value, traceback): _scope_stack.pop() +@contextlib.contextmanager +def average_all_summaries(cond: Callable, target_names: List[str] = None): + """ + Context manager that sets all nested scalar summaries to average. + It also disables any nested recording interval logic. + Scalar summaries with an explicit average_over_summary_interval=False + will not be overridden. + + This is useful when training with small mini-batches, where per-step scalar + summaries can be noisy. + + Args: + cond (Callable): a function which returns whether the summary recordings + should be averaged and recorded. + target_names: An optional list of substring summary names to record. If None, + will average all summaries. + """ + orig_scalar = alf.summary.scalar + orig_record_if = alf.summary.record_if + orig_should_record_summaries = alf.summary.should_record_summaries + + def _wrap(fn): + + def wrapped(name, data, *args, **kwargs): + matched = True + if target_names is not None: + matched = any(t in name for t in target_names) + + if matched: + kwargs.setdefault("average_over_summary_interval", True) + + return fn(name, data, *args, **kwargs) + + return wrapped + + @contextlib.contextmanager + def _disabled_record_if(*args, **kwargs): + yield + + alf.summary.scalar = _wrap(orig_scalar) + alf.summary.record_if = _disabled_record_if + alf.summary.should_record_summaries = lambda: True + try: + with orig_record_if(cond): + yield + finally: + alf.summary.scalar = orig_scalar + alf.summary.record_if = orig_record_if + alf.summary.should_record_summaries = orig_should_record_summaries + + _SUMMARY_DATA_BUFFER = {} diff --git a/alf/summary/summary_ops_test.py b/alf/summary/summary_ops_test.py index 8ea27da30..401662f9f 100644 --- a/alf/summary/summary_ops_test.py +++ b/alf/summary/summary_ops_test.py @@ -81,6 +81,84 @@ def test_summary(self): self.assertEqual(tag2val['root/b/histogram'].max(), 99) self.assertEqual(len(tag2val['root/b/histogram']), 30) + def test_average_all_summaries(self): + with tempfile.TemporaryDirectory() as root_dir: + writer = alf.summary.create_summary_writer(root_dir, + flush_secs=10, + max_queue=10) + alf.summary.set_default_writer(writer) + alf.summary.enable_summary() + event_file = _find_event_file(root_dir) + self.assertIsNotNone(event_file) + + tag2val = { + 'scalar1': None, + 'scalar2': None, + } + + def load_summaries(): + writer.flush() + for event_str in event_file_loader.EventFileLoader( + event_file).Load(): + if event_str.summary.value: + for item in event_str.summary.value: + self.assertTrue(item.tag in tag2val) + tag2val[item.tag] = tensor_util.make_ndarray( + item.tensor) + + load_summaries() + + with alf.summary.record_if(lambda: True): + # average_over_summary_interval is False by default + alf.summary.scalar("scalar1", 101) + alf.summary.scalar("scalar1", 102) + alf.summary.scalar("scalar2", 103) + + load_summaries() + self.assertEqual(tag2val['scalar1'], 102) + self.assertEqual(tag2val['scalar2'], 103) + + # Test that average_all_summaries uses its own record boundary and + # ignores nested record_if settings. + num_iters = 4 + counter = 1 + with alf.summary.average_all_summaries(lambda: counter == + num_iters): + for i in range(num_iters): + # This record_if should be overwritten + with alf.summary.record_if(lambda: True): + # This scalar should be averaged + alf.summary.scalar("scalar1", 100 + i) + # This scalar should not be averaged given the explicit kwarg + alf.summary.scalar("scalar2", + 100 + i, + average_over_summary_interval=False) + counter += 1 + + load_summaries() + self.assertEqual(tag2val['scalar1'], 101.5) + self.assertEqual(tag2val['scalar2'], 103) + + # Test that average_all_summaries filters according to target_names. + num_iters = 4 + counter = 1 + with alf.summary.average_all_summaries( + lambda: counter == num_iters, target_names=["scalar2"]): + for i in range(num_iters): + # This record_if should be overwritten + with alf.summary.record_if(lambda: True): + # This scalar should not be averaged because it is not in target_names + alf.summary.scalar("scalar1", 100 + i) + # This scalar should be averaged + alf.summary.scalar("scalar2", 100 + i) + counter += 1 + + load_summaries() + self.assertEqual(tag2val['scalar1'], 103) + self.assertEqual(tag2val['scalar2'], 101.5) + + writer.close() + if __name__ == "__main__": alf.test.main() From 7e60b632ea4f41044e34ad46570afbaa33437f9e Mon Sep 17 00:00:00 2001 From: Andrew Choi Date: Mon, 16 Mar 2026 09:49:38 -0700 Subject: [PATCH 2/2] remove redundant logic --- alf/summary/summary_ops.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/alf/summary/summary_ops.py b/alf/summary/summary_ops.py index db6fce629..fdac7193f 100644 --- a/alf/summary/summary_ops.py +++ b/alf/summary/summary_ops.py @@ -120,7 +120,6 @@ def average_all_summaries(cond: Callable, target_names: List[str] = None): """ orig_scalar = alf.summary.scalar orig_record_if = alf.summary.record_if - orig_should_record_summaries = alf.summary.should_record_summaries def _wrap(fn): @@ -142,14 +141,12 @@ def _disabled_record_if(*args, **kwargs): alf.summary.scalar = _wrap(orig_scalar) alf.summary.record_if = _disabled_record_if - alf.summary.should_record_summaries = lambda: True try: with orig_record_if(cond): yield finally: alf.summary.scalar = orig_scalar alf.summary.record_if = orig_record_if - alf.summary.should_record_summaries = orig_should_record_summaries _SUMMARY_DATA_BUFFER = {}