Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions alf/algorithms/actor_critic_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self,
advantage_clip=None,
entropy_regularization=None,
td_loss_weight=1.0,
critic_warm_start_iters=0,
debug_summaries=False,
name="ActorCriticLoss"):
"""An actor-critic loss equals to
Expand All @@ -75,7 +76,7 @@ def __init__(self,
gamma (float|list[float]): A discount factor for future rewards. For
multi-dim reward, this can also be a list of discounts, each
discount applies to a reward dim.
td_errors_loss_fn (Callable): A function for computing the TD errors
td_error_loss_fn (Callable): A function for computing the TD errors
loss. This function takes as input the target and the estimated
Q values and returns the loss for each element of the batch.
use_gae (bool): If True, uses generalized advantage estimation for
Expand All @@ -99,7 +100,13 @@ def __init__(self,
advantage_clip (float): If set, clip advantages to :math:`[-x, x]`
entropy_regularization (float): Coefficient for entropy
regularization loss term.
td_loss_weight (float): the weigt for the loss of td error.
td_loss_weight (float): the weight for the loss of td error.
critic_warm_start_iters (int): Number of iterations for warm starting
the critic. Actor loss will still be computed for summary but
not added to the loss sent to the optimizer. Note this is based
on alf's global counter. Therefore, gradient-free iterations
(i.e., initial_collect_steps unrolling) will still be counted and
should be offset accordingly.
"""
super().__init__(name=name)

Expand All @@ -110,6 +117,7 @@ def __init__(self,
self._use_gae = use_gae
self._lambda = td_lambda
self._use_td_lambda_return = use_td_lambda_return
self._critic_warm_start_iters = critic_warm_start_iters
if normalize_scalar_advantages:
self._adv_norm = torch.nn.BatchNorm1d(
num_features=1,
Expand Down Expand Up @@ -212,21 +220,25 @@ def _summarize(v, r, adv, suffix):
if td_loss.ndim == 3:
td_loss = td_loss.mean(dim=2)

loss = pg_loss + self._td_loss_weight * td_loss

entropy_loss = ()
if self._entropy_regularization is not None:
# If entropy is explicitly provided, we'll use it.
# Otherwise, we will compute it from the provided action_distribution.
if info.entropy is not ():
if info.entropy != ():
entropy = info.entropy
entropy_for_gradient = info.entropy
else:
entropy, entropy_for_gradient = dist_utils.entropy_with_fallback(
info.action_distribution, return_sum=False)
entropy_loss = alf.nest.map_structure(lambda x: -x, entropy)
loss -= self._entropy_regularization * sum(
alf.nest.flatten(entropy_for_gradient))

loss = self._td_loss_weight * td_loss

if alf.summary.get_global_counter() > self._critic_warm_start_iters:
loss += pg_loss
if self._entropy_regularization is not None:
loss -= self._entropy_regularization * sum(
alf.nest.flatten(entropy_for_gradient))

return LossInfo(loss=loss,
extra=ActorCriticLossInfo(td_loss=td_loss,
Expand Down
Loading