diff --git a/llmfoundry/loggers/composer_aim_logger.py b/llmfoundry/loggers/composer_aim_logger.py
index 2d74fca..682dbcd 100644
--- a/llmfoundry/loggers/composer_aim_logger.py
+++ b/llmfoundry/loggers/composer_aim_logger.py
@@ -3,6 +3,7 @@
 import atexit
 from logging import getLogger
 from typing import Dict, Optional, Any, Sequence, Union
+from inspect import signature
 
 import numpy as np
 import torch
@@ -54,29 +55,19 @@ def __init__(
         system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT,
         log_system_params: bool = True,
         capture_terminal_logs: Optional[bool] = True,
+        log_env_variables: Optional[bool] = False,
         rank_zero_only: bool = True,
         entity: Optional[str] = None,
         project: Optional[str] = None,
         upload_on_close: bool = True,
     ):
-        """
-        Args:
-            repo (str, optional): Path or URI for Aim repo location.
-            experiment_name (str, optional): The Aim experiment name.
-            system_tracking_interval (int, optional): Controls how often system usage is logged.
-            log_system_params (bool): Whether to log system-level metrics (CPU usage, etc.).
-            capture_terminal_logs (bool, optional): Whether to capture stdout logs.
-            rank_zero_only (bool): Whether to log only on the global rank zero process.
-            entity (str, optional): For parity with WandB. Not strictly required by Aim.
-            project (str, optional): For parity with WandB. Not strictly required by Aim.
-            upload_on_close (bool): Whether to upload the Aim repo on trainer close.
-        """
         super().__init__()
         self.repo = repo
         self.experiment_name = experiment_name
         self.system_tracking_interval = system_tracking_interval
         self.log_system_params = log_system_params
         self.capture_terminal_logs = capture_terminal_logs
+        self.log_env_variables = log_env_variables
         self.rank_zero_only = rank_zero_only
 
         # If this rank is not zero, we won't log anything
@@ -92,6 +83,7 @@ def __init__(
         self._is_in_atexit = False
         atexit.register(self._set_is_in_atexit)
         self.upload_on_close = upload_on_close
+
     def _set_is_in_atexit(self):
         self._is_in_atexit = True
 
@@ -106,30 +98,33 @@ def _setup(self, state: Optional[State] = None):
         """Initialize the Aim Run if not already initialized."""
        if self._run is not None or not self._enabled:
            return
-
+
         try:
+            # Get the signature of Run.__init__ to check which parameters this Aim version supports
+            sig = signature(Run.__init__)
+
+            # Build the keyword arguments shared by both branches below
+            params = {
+                "repo": self.repo,
+                "system_tracking_interval": self.system_tracking_interval,
+                "log_system_params": self.log_system_params,
+                "capture_terminal_logs": self.capture_terminal_logs,
+            }
+            if "log_env_variables" in sig.parameters:
+                params["log_env_variables"] = self.log_env_variables
+            elif not self.log_env_variables and self.log_system_params:
+                # Older Aim folds env-var capture into `log_system_params`
+                sys_logger.warning("`log_env_variables` is not supported in this version of Aim. Environment variables will be logged as part of `log_system_params`.")  # fmt: skip
+
             if self._run_hash:
-                self._run = Run(
-                    self._run_hash,
-                    repo=self.repo,
-                    system_tracking_interval=self.system_tracking_interval,
-                    log_system_params=self.log_system_params,
-                    capture_terminal_logs=self.capture_terminal_logs,
-                )
+                self._run = Run(self._run_hash, **params)
             else:
                 # Provide the composer run_name if not explicitly given
                 # (Aim calls this "experiment", so we can unify them)
                 experiment = self.experiment_name
                 if experiment is None and state is not None and state.run_name is not None:
                     experiment = state.run_name
-
-                self._run = Run(
-                    repo=self.repo,
-                    experiment=experiment,
-                    system_tracking_interval=self.system_tracking_interval,
-                    log_system_params=self.log_system_params,
-                    capture_terminal_logs=self.capture_terminal_logs,
-                )
+                self._run = Run(experiment=experiment, **params)
             self._run_hash = self._run.hash
 
             # If available, store or conceive a notion of "run_dir" or "run_url"
@@ -218,22 +212,22 @@ def log_table(
             'rows': rows,
         }
         # Typically you'd store as just a dictionary, or in custom namespace:
-        self._run.track(table_data, name=name, step=step, context={'type': 'table'})
+        self._run.track(table_data, name=name, step=step, context={"type": "table"})
 
     def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
         """Log metrics to Aim."""
         if not self._enabled or not self._run:
             return
-
+
         try:
             # Batch metrics instead of logging individually
             metric_dict = {}
             array_metrics = {}
-
+
             for name, value in metrics.items():
                 if value is None:
                     continue
-
+
                 if isinstance(value, (int, float)):
                     metric_dict[name] = value
                 elif isinstance(value, torch.Tensor):
@@ -242,11 +236,11 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> No
                 else:
                     # Store array metrics separately
                     array_metrics[name] = value.detach().cpu().numpy()
-
+
             # Log scalar metrics in batch
             if metric_dict:
                 self._run.track(metric_dict, step=step)
-
+
             # Log array metrics
             for name, value in array_metrics.items():
                 self._run.track(
@@ -255,7 +249,7 @@
                     name,
                    value,
                     step=step,
                     context={"type": "array"}
                 )
-
+
         except Exception as e:
             sys_logger.warning(f"Failed to log metrics: {e}")
 
@@ -264,7 +258,7 @@ def log_images(
         self,
         images: Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]],
         name: str = 'Images',
         step: Optional[int] = None,
-        **kwargs
+        **kwargs,  # type: ignore
     ):
         """Log images, optionally with segmentation masks, to Aim.
@@ -279,18 +273,18 @@
             "format": "CHW" if not kwargs.get('channels_last', False) else "HWC",
             **kwargs
         }
-
+
         # Convert to sequence if single image
         if not isinstance(images, Sequence):
             images = [images]
-
+
         for idx, img in enumerate(images):
             img_data = _to_numpy_image(img, channels_last=kwargs.get('channels_last', False))
             self._run.track(
                 img_data,
                 name=f"{name}/{idx}" if len(images) > 1 else name,
                 step=step,
-                context=context
+                context=context,
             )
 
 ### THIS DOES NOT WORK CURRENTLY - AIM DOESN'T NATIVELY SUPPORT ARTIFACT STORAGE SO IT REQUIRES A CUSTOM APPROACH ###
@@ -386,4 +380,4 @@ def _to_numpy_image(
     # This is not required for Aim to store it, but it emulates the W&B practice
     if not channels_last and image.ndim == 3:
         image = np.transpose(image, (1, 2, 0))
-    return image
\ No newline at end of file
+    return image
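For reviewers who want to try the version guard in isolation, below is a minimal sketch of the same `inspect.signature` feature-detection pattern outside the logger class. The `build_run_kwargs` helper and the repo path are illustrative only, not part of this diff; the guard mirrors the check `_setup` now performs before constructing `aim.Run`.

from inspect import signature

from aim import Run

def build_run_kwargs(repo: str, log_env_variables: bool = False) -> dict:
    """Return Run kwargs, forwarding `log_env_variables` only if supported."""
    kwargs = {"repo": repo}
    if "log_env_variables" in signature(Run.__init__).parameters:
        # This Aim version accepts the flag, so it is safe to pass through.
        kwargs["log_env_variables"] = log_env_variables
    return kwargs

print(build_run_kwargs("./example-aim-repo"))

On older Aim versions the flag is simply omitted, so `Run(**build_run_kwargs(...))` never raises a TypeError for an unexpected keyword argument.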