Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion alf/summary/summary_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import functools
import numpy as np
import torch
import contextlib
from torch.utils.tensorboard import SummaryWriter
from typing import Callable, Union
from typing import Callable, Union, List
from alf.utils.schedulers import update_progress
import alf

# These will be used by orig_tf_gfile_context() in alf.utils.common
TF_IO_GFILE = None
Expand Down Expand Up @@ -99,6 +101,54 @@ def __exit__(self, type, value, traceback):
_scope_stack.pop()


@contextlib.contextmanager
def average_all_summaries(cond: Callable, target_names: List[str] = None):
    """Context manager that makes nested scalar summaries averaged by default.

    While the context is active:

    * ``alf.summary.scalar`` is replaced by a wrapper that defaults the
      ``average_over_summary_interval`` keyword to ``True`` for the targeted
      summaries. A call that passes ``average_over_summary_interval``
      explicitly as a keyword (e.g. ``False``) is left untouched.
    * ``alf.summary.record_if`` is replaced by a no-op context manager, so any
      nested recording interval logic is disabled. The body runs inside a
      single ``record_if(cond)`` entered with the original implementation.

    Both attributes are restored on exit, even if the body raises.

    This is useful when training with small mini-batches, where per-step
    scalar summaries can be noisy.

    Note:
        The override works by rebinding attributes on the ``alf.summary``
        module, so it only affects code that looks up
        ``alf.summary.scalar`` / ``alf.summary.record_if`` at call time; a
        reference captured before entering the context is not redirected.

    Args:
        cond (Callable): a function which returns whether the summary
            recordings should be averaged and recorded. It is passed
            unchanged to the original ``alf.summary.record_if``.
        target_names: An optional list of substrings. Only scalar summaries
            whose name contains at least one of them get the averaging
            default; all other scalar summaries are forwarded unchanged. If
            None, the averaging default is applied to all scalar summaries.
    """
    orig_scalar = alf.summary.scalar
    orig_record_if = alf.summary.record_if

    def _wrap(fn):

        # Keep the name/docstring of the wrapped summary function so the
        # temporary replacement still introspects like the original.
        @functools.wraps(fn)
        def wrapped(name, data, *args, **kwargs):
            matched = (target_names is None
                       or any(t in name for t in target_names))

            if matched:
                # setdefault leaves an explicit keyword from the caller alone.
                # NOTE(review): this only sees keyword arguments; if ``fn``
                # accepts the flag positionally and a caller passes it that
                # way, it would be supplied twice -- confirm callers use the
                # keyword form.
                kwargs.setdefault("average_over_summary_interval", True)

            return fn(name, data, *args, **kwargs)

        return wrapped

    @contextlib.contextmanager
    def _disabled_record_if(*args, **kwargs):
        # Accept and ignore whatever a nested record_if(...) call passes.
        yield

    try:
        # Patch inside the ``try`` so the ``finally`` below always restores
        # the originals.
        alf.summary.scalar = _wrap(orig_scalar)
        alf.summary.record_if = _disabled_record_if
        with orig_record_if(cond):
            yield
    finally:
        alf.summary.scalar = orig_scalar
        alf.summary.record_if = orig_record_if


# Module-level dict, empty at import time.
# NOTE(review): presumably holds summary values buffered between writes (e.g.
# for ``average_over_summary_interval``); its readers/writers are outside this
# view -- confirm before relying on this description.
_SUMMARY_DATA_BUFFER = {}


Expand Down
78 changes: 78 additions & 0 deletions alf/summary/summary_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,84 @@ def test_summary(self):
self.assertEqual(tag2val['root/b/histogram'].max(), 99)
self.assertEqual(len(tag2val['root/b/histogram']), 30)

def test_average_all_summaries(self):
    """Check ``alf.summary.average_all_summaries`` against written events.

    Three scenarios are exercised in order, each followed by re-reading the
    event file:

    1. Plain ``record_if``: the last written value of each scalar is seen.
    2. ``average_all_summaries`` without ``target_names``: scalars are
       averaged unless ``average_over_summary_interval=False`` is passed
       explicitly, and nested ``record_if`` calls are ignored.
    3. ``average_all_summaries`` with ``target_names``: only scalars whose
       name matches are averaged.
    """
    with tempfile.TemporaryDirectory() as root_dir:
        writer = alf.summary.create_summary_writer(root_dir,
                                                   flush_secs=10,
                                                   max_queue=10)
        # Close the writer even when an assertion fails, so it is not left
        # open while the temporary directory is being removed.
        try:
            alf.summary.set_default_writer(writer)
            alf.summary.enable_summary()
            event_file = _find_event_file(root_dir)
            self.assertIsNotNone(event_file)

            # Latest value seen in the event file for each expected tag.
            tag2val = {
                'scalar1': None,
                'scalar2': None,
            }

            def load_summaries():
                # Re-read the whole event file; later events overwrite
                # earlier ones, so tag2val ends up with the newest values.
                writer.flush()
                for event_str in event_file_loader.EventFileLoader(
                        event_file).Load():
                    if event_str.summary.value:
                        for item in event_str.summary.value:
                            self.assertIn(item.tag, tag2val)
                            tag2val[item.tag] = tensor_util.make_ndarray(
                                item.tensor)

            load_summaries()

            with alf.summary.record_if(lambda: True):
                # average_over_summary_interval is False by default
                alf.summary.scalar("scalar1", 101)
                alf.summary.scalar("scalar1", 102)
                alf.summary.scalar("scalar2", 103)

            load_summaries()
            self.assertEqual(tag2val['scalar1'], 102)
            self.assertEqual(tag2val['scalar2'], 103)

            # Test that average_all_summaries uses its own record boundary
            # and ignores nested record_if settings.
            num_iters = 4
            # The lambda below reads ``counter`` at call time, so the
            # condition only becomes true on the last iteration.
            counter = 1
            with alf.summary.average_all_summaries(
                    lambda: counter == num_iters):
                for i in range(num_iters):
                    # This record_if should be overridden (made a no-op)
                    with alf.summary.record_if(lambda: True):
                        # This scalar should be averaged
                        alf.summary.scalar("scalar1", 100 + i)
                        # This scalar should not be averaged given the
                        # explicit kwarg
                        alf.summary.scalar(
                            "scalar2",
                            100 + i,
                            average_over_summary_interval=False)
                    counter += 1

            load_summaries()
            # mean(100, 101, 102, 103) == 101.5; scalar2 keeps the last value.
            self.assertEqual(tag2val['scalar1'], 101.5)
            self.assertEqual(tag2val['scalar2'], 103)

            # Test that average_all_summaries filters according to
            # target_names.
            num_iters = 4
            counter = 1
            with alf.summary.average_all_summaries(
                    lambda: counter == num_iters,
                    target_names=["scalar2"]):
                for i in range(num_iters):
                    # This record_if should be overridden (made a no-op)
                    with alf.summary.record_if(lambda: True):
                        # This scalar should not be averaged because it is
                        # not in target_names
                        alf.summary.scalar("scalar1", 100 + i)
                        # This scalar should be averaged
                        alf.summary.scalar("scalar2", 100 + i)
                    counter += 1

            load_summaries()
            self.assertEqual(tag2val['scalar1'], 103)
            self.assertEqual(tag2val['scalar2'], 101.5)
        finally:
            writer.close()


# Run this module's tests through ALF's test entry point when executed
# directly as a script.
if __name__ == "__main__":
    alf.test.main()
Loading