74 changes: 34 additions & 40 deletions llmfoundry/loggers/composer_aim_logger.py
@@ -3,6 +3,7 @@
import atexit
from logging import getLogger
from typing import Dict, Optional, Any, Sequence, Union
from inspect import signature

import numpy as np
import torch
@@ -54,29 +55,19 @@ def __init__(
system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT,
log_system_params: bool = True,
capture_terminal_logs: Optional[bool] = True,
log_env_variables: Optional[bool] = False,
rank_zero_only: bool = True,
entity: Optional[str] = None,
project: Optional[str] = None,
upload_on_close: bool = True,
):
"""
Args:
repo (str, optional): Path or URI for Aim repo location.
experiment_name (str, optional): The Aim experiment name.
system_tracking_interval (int, optional): How often, in seconds, system resource usage is polled and logged.
log_system_params (bool): Whether to log system-level metrics (CPU usage, etc.).
capture_terminal_logs (bool, optional): Whether to capture stdout logs.
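log_env_variables (bool, optional): Whether to log environment variables. Honored only on Aim versions whose Run constructor accepts this parameter; on older versions, environment variables are captured as part of log_system_params.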
rank_zero_only (bool): Whether to log only on the global rank zero process.
entity (str, optional): For parity with WandB. Not strictly required by Aim.
project (str, optional): For parity with WandB. Not strictly required by Aim.
upload_on_close (bool): Whether to upload the Aim repo on trainer close.
"""
super().__init__()
self.repo = repo
self.experiment_name = experiment_name
self.system_tracking_interval = system_tracking_interval
self.log_system_params = log_system_params
self.capture_terminal_logs = capture_terminal_logs
self.log_env_variables = log_env_variables

self.rank_zero_only = rank_zero_only
# If this rank is not zero, we won't log anything
@@ -92,6 +83,7 @@ def __init__(
self._is_in_atexit = False
atexit.register(self._set_is_in_atexit)
self.upload_on_close = upload_on_close

def _set_is_in_atexit(self):
self._is_in_atexit = True

@@ -106,30 +98,32 @@ def _setup(self, state: Optional[State] = None):
"""Initialize the Aim Run if not already initialized."""
if self._run is not None or not self._enabled:
return

try:
# Get the signature for Run.__init__ to check for parameters
sig = signature(Run.__init__)

# Build common parameters
params = {
"repo": self.repo,
"system_tracking_interval": self.system_tracking_interval,
"log_system_params": self.log_system_params,
"capture_terminal_logs": self.capture_terminal_logs,
}
if "log_env_variables" in sig.parameters:
params["log_env_variables"] = self.log_env_variables
elif not self.log_env_variables and self.log_system_params:
sys_logger.warning("`log_env_variables` is not supported in this version of Aim. Environment variables will be logged with `log_system_params`.") # fmt: skip

if self._run_hash:
self._run = Run(
self._run_hash,
repo=self.repo,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
capture_terminal_logs=self.capture_terminal_logs,
)
self._run = Run(self._run_hash, **params)
else:
# Provide the composer run_name if not explicitly given
# (Aim calls this "experiment", so we can unify them)
experiment = self.experiment_name
if experiment is None and state is not None and state.run_name is not None:
experiment = state.run_name

self._run = Run(
repo=self.repo,
experiment=experiment,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
capture_terminal_logs=self.capture_terminal_logs,
)
params["experiment"] = experiment
self._run = Run(**params)
self._run_hash = self._run.hash

# If available, record a "run_dir" or "run_url" for the run
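
For context, the signature check in the hunk above follows a standard feature-detection pattern: inspect Run.__init__ once and forward a keyword argument only when the installed Aim version accepts it. A minimal sketch of the pattern in isolation (supports_param is an illustrative helper, not part of this PR):

from inspect import signature

from aim import Run


def supports_param(param: str) -> bool:
    """Return True if the installed Aim Run constructor accepts `param`."""
    return param in signature(Run.__init__).parameters


kwargs = {}
if supports_param("log_env_variables"):
    # Safe to forward; older Aim releases would raise a TypeError on this kwarg.
    kwargs["log_env_variables"] = False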
@@ -218,22 +212,22 @@ def log_table(
'rows': rows,
}
# Typically you'd store this as a plain dictionary, or under a custom context namespace:
self._run.track(table_data, name=name, step=step, context={'type': 'table'})
self._run.track(table_data, name=name, step=step, context={"type": "table"})

def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
"""Log metrics to Aim."""
if not self._enabled or not self._run:
return

try:
# Batch metrics instead of logging individually
metric_dict = {}
array_metrics = {}

for name, value in metrics.items():
if value is None:
continue

if isinstance(value, (int, float)):
metric_dict[name] = value
elif isinstance(value, torch.Tensor):
@@ -242,11 +236,11 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> No
else:
# Store array metrics separately
array_metrics[name] = value.detach().cpu().numpy()

# Log scalar metrics in batch
if metric_dict:
self._run.track(metric_dict, step=step)

# Log array metrics
for name, value in array_metrics.items():
self._run.track(
@@ -255,7 +249,7 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> No
step=step,
context={"type": "array"}
)

except Exception as e:
sys_logger.warning(f"Failed to log metrics: {e}")

@@ -264,7 +258,7 @@ def log_images(
images: Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]],
name: str = 'Images',
step: Optional[int] = None,
**kwargs
**kwargs, # type: ignore
):
"""Log images, optionally with segmentation masks, to Aim.

@@ -279,18 +273,18 @@
"format": "CHW" if not kwargs.get('channels_last', False) else "HWC",
**kwargs
}

# Convert to sequence if single image
if not isinstance(images, Sequence):
images = [images]

for idx, img in enumerate(images):
img_data = _to_numpy_image(img, channels_last=kwargs.get('channels_last', False))
self._run.track(
img_data,
name=f"{name}/{idx}" if len(images) > 1 else name,
step=step,
context=context
context=context,
)
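
Worth noting: the loop above tracks raw pixel arrays tagged with a {"format": ...} context rather than Aim's native image type, so the Aim UI will treat them as generic values. If native image rendering is preferred, Aim provides aim.Image, which accepts numpy arrays in HWC layout among other inputs; a hedged alternative sketch (not what this PR does):

import numpy as np
from aim import Image, Run

run = Run()
# A random 32x32 RGB image in HWC layout, which aim.Image accepts directly.
pixels = (np.random.rand(32, 32, 3) * 255).astype(np.uint8)
run.track(Image(pixels), name="sample_image", step=0)
run.close()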

### THIS DOES NOT WORK CURRENTLY - AIM DOESN'T NATIVELY SUPPORT ARTIFACT STORAGE SO IT REQUIRES A CUSTOM APPROACH ###
@@ -386,4 +380,4 @@ def _to_numpy_image(
# This is not required for Aim to store it, but it emulates the W&B practice
if not channels_last and image.ndim == 3:
image = np.transpose(image, (1, 2, 0))
return image
return image
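
For reference, the final transpose converts PyTorch's channels-first layout into the channels-last layout most image tooling expects; a quick shape check of the same operation:

import numpy as np

chw = np.zeros((3, 32, 64))         # (C, H, W): channels-first, as PyTorch yields
hwc = np.transpose(chw, (1, 2, 0))  # (H, W, C): channels-last, as viewers expect
assert hwc.shape == (32, 64, 3)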