From 0d0a13a5c639a18c6e68a252a7ae5f42e8a955c5 Mon Sep 17 00:00:00 2001 From: Andrew Choi Date: Fri, 13 Mar 2026 11:01:55 -0700 Subject: [PATCH] Add critic_warm_start_iters to ActorCriticLoss --- alf/algorithms/actor_critic_loss.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/alf/algorithms/actor_critic_loss.py b/alf/algorithms/actor_critic_loss.py index 756ca3a3d..c4e65d638 100644 --- a/alf/algorithms/actor_critic_loss.py +++ b/alf/algorithms/actor_critic_loss.py @@ -60,6 +60,7 @@ def __init__(self, advantage_clip=None, entropy_regularization=None, td_loss_weight=1.0, + critic_warm_start_iters=0, debug_summaries=False, name="ActorCriticLoss"): """An actor-critic loss equals to @@ -75,7 +76,7 @@ def __init__(self, gamma (float|list[float]): A discount factor for future rewards. For multi-dim reward, this can also be a list of discounts, each discount applies to a reward dim. - td_errors_loss_fn (Callable): A function for computing the TD errors + td_error_loss_fn (Callable): A function for computing the TD errors loss. This function takes as input the target and the estimated Q values and returns the loss for each element of the batch. use_gae (bool): If True, uses generalized advantage estimation for @@ -99,7 +100,13 @@ def __init__(self, advantage_clip (float): If set, clip advantages to :math:`[-x, x]` entropy_regularization (float): Coefficient for entropy regularization loss term. - td_loss_weight (float): the weigt for the loss of td error. + td_loss_weight (float): the weight for the loss of td error. + critic_warm_start_iters (int): Number of iterations for warm starting + the critic. Actor loss will still be computed for summary but + not added to the loss sent to the optimizer. Note this is based + on alf's global counter. Therefore, gradient-free iterations + (i.e., initial_collect_steps unrolling) will still be counted and + should be offset accordingly. 
""" super().__init__(name=name) @@ -110,6 +117,7 @@ def __init__(self, self._use_gae = use_gae self._lambda = td_lambda self._use_td_lambda_return = use_td_lambda_return + self._critic_warm_start_iters = critic_warm_start_iters if normalize_scalar_advantages: self._adv_norm = torch.nn.BatchNorm1d( num_features=1, @@ -212,21 +220,25 @@ def _summarize(v, r, adv, suffix): if td_loss.ndim == 3: td_loss = td_loss.mean(dim=2) - loss = pg_loss + self._td_loss_weight * td_loss - entropy_loss = () if self._entropy_regularization is not None: # If entropy is explicitly provided, we'll use it. # Otherwise, we will compute it from the provided action_distribution. - if info.entropy is not (): + if info.entropy != (): entropy = info.entropy entropy_for_gradient = info.entropy else: entropy, entropy_for_gradient = dist_utils.entropy_with_fallback( info.action_distribution, return_sum=False) entropy_loss = alf.nest.map_structure(lambda x: -x, entropy) - loss -= self._entropy_regularization * sum( - alf.nest.flatten(entropy_for_gradient)) + + loss = self._td_loss_weight * td_loss + + if alf.summary.get_global_counter() > self._critic_warm_start_iters: + loss += pg_loss + if self._entropy_regularization is not None: + loss -= self._entropy_regularization * sum( + alf.nest.flatten(entropy_for_gradient)) return LossInfo(loss=loss, extra=ActorCriticLossInfo(td_loss=td_loss,