From 28f1cef0fcff31cb6ea23e7d8943d139b8a91e8c Mon Sep 17 00:00:00 2001 From: Haichao Zhang Date: Tue, 8 Mar 2022 16:01:42 -0800 Subject: [PATCH 1/2] Munchausen RL --- README.md | 1 + alf/algorithms/sac_algorithm.py | 35 +++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/README.md b/README.md index 6bf2e49f1..66d38c633 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ Read the ALF documentation [here](https://alf.readthedocs.io/). |[OAC](alf/algorithms/oac_algorithm.py)|Off-policy RL|Ciosek et al. "Better Exploration with Optimistic Actor-Critic" [arXiv:1910.12807](https://arxiv.org/abs/1910.12807)| |[HER](https://github.com/HorizonRobotics/alf/blob/911d9573866df41e9e3adf6cdd94ee03016bf5a8/alf/algorithms/data_transformer.py#L672)|Off-policy RL|Andrychowicz et al. "Hindsight Experience Replay" [arXiv:1707.01495](https://arxiv.org/abs/1707.01495)| |[TAAC](alf/algorithms/taac_algorithm.py)|Off-policy RL|Yu et al. "TAAC: Temporally Abstract Actor-Critic for Continuous Control" [arXiv:2104.06521](https://arxiv.org/abs/2104.06521)| +|[Munchausen RL](alf/algorithms/sac_algorithm.py)|Off-policy RL|Vieillard et al. "Munchausen Reinforcement Learning" [arXiv:2007.14430](https://arxiv.org/abs/2007.14430)| |[DIAYN](alf/algorithms/diayn_algorithm.py)|Intrinsic motivation/Exploration|Eysenbach et al. "Diversity is All You Need: Learning Diverse Skills without a Reward Function" [arXiv:1802.06070](https://arxiv.org/abs/1802.06070)| |[ICM](alf/algorithms/icm_algorithm.py)|Intrinsic motivation/Exploration|Pathak et al. "Curiosity-driven Exploration by Self-supervised Prediction" [arXiv:1705.05363](https://arxiv.org/abs/1705.05363)| |[RND](alf/algorithms/rnd_algorithm.py)|Intrinsic motivation/Exploration|Burda et al. 
"Exploration by Random Network Distillation" [arXiv:1810.12894](https://arxiv.org/abs/1810.12894)| diff --git a/alf/algorithms/sac_algorithm.py b/alf/algorithms/sac_algorithm.py index ea1cfad13..dd29f126e 100644 --- a/alf/algorithms/sac_algorithm.py +++ b/alf/algorithms/sac_algorithm.py @@ -153,6 +153,7 @@ def __init__(self, reward_weights=None, epsilon_greedy=None, use_entropy_reward=True, + munchausen_reward_weight=0, normalize_entropy_reward=False, calculate_priority=False, num_critic_replicas=2, @@ -204,6 +205,11 @@ def __init__(self, from ``config.epsilon_greedy`` and then ``alf.get_config_value(TrainerConfig.epsilon_greedy)``. use_entropy_reward (bool): whether to include entropy as reward + munchausen_reward_weight (float): the weight of augmenting the task + reward with munchausen reward, as introduced in ``Munchausen + Reinforcement Learning``, which is essentially the log_pi of + the given action. A non-positive value means the munchausen + reward is not used. normalize_entropy_reward (bool): if True, normalize entropy reward to reduce bias in episodic cases. Only used if ``use_entropy_reward==True``. @@ -267,6 +273,11 @@ def __init__(self, critic_network_cls, q_network_cls) self._use_entropy_reward = use_entropy_reward + self._munchausen_reward_weight = min(0, munchausen_reward_weight) + if munchausen_reward_weight > 0: + assert not normalize_entropy_reward, ( + "should not normalize entropy " + "reward when using munchausen reward") if reward_spec.numel > 1: assert self._act_type != ActionType.Mixed, ( @@ -846,6 +857,30 @@ def _calc_critic_loss(self, info: SacInfo): When the reward is multi-dim, the entropy reward will be added to *all* dims. 
""" + if self._munchausen_reward_weight > 0: + with torch.no_grad(): + # calculate the log probability of the rollout action + log_pi_rollout_a = nest.map_structure( + lambda dist, a: dist.log_prob(a), info.action_distribution, + info.action) + + if self._act_type == ActionType.Mixed: + # For mixed type, add log_pi separately + log_pi_rollout_a = type(self._action_spec)( + (sum(nest.flatten(log_pi_rollout_a[0])), + sum(nest.flatten(log_pi_rollout_a[1])))) + else: + log_pi_rollout_a = sum(nest.flatten(log_pi_rollout_a)) + + munchausen_reward = nest.map_structure( + lambda la, lp: torch.exp(la) * lp, self._log_alpha, + log_pi_rollout_a) + munchausen_reward = sum(nest.flatten(munchausen_reward)) + info = info._replace( + reward=( + info.reward + self._munchausen_reward_weight * + common.expand_dims_as(munchausen_reward, info.reward))) + if self._use_entropy_reward: with torch.no_grad(): log_pi = info.log_pi From b3306303b38ab6359504dcb5e13bf9ea2da8725c Mon Sep 17 00:00:00 2001 From: Haichao Zhang Date: Tue, 8 Mar 2022 17:24:58 -0800 Subject: [PATCH 2/2] Fix alignment --- alf/algorithms/sac_algorithm.py | 45 ++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/alf/algorithms/sac_algorithm.py b/alf/algorithms/sac_algorithm.py index dd29f126e..451dab241 100644 --- a/alf/algorithms/sac_algorithm.py +++ b/alf/algorithms/sac_algorithm.py @@ -273,7 +273,7 @@ def __init__(self, critic_network_cls, q_network_cls) self._use_entropy_reward = use_entropy_reward - self._munchausen_reward_weight = min(0, munchausen_reward_weight) + self._munchausen_reward_weight = max(0, munchausen_reward_weight) if munchausen_reward_weight > 0: assert not normalize_entropy_reward, ( "should not normalize entropy " @@ -853,10 +853,23 @@ def _calc_critic_loss(self, info: SacInfo): (There is an issue in their implementation: their "terminals" can't differentiate between discount=0 (NormalEnd) and discount=1 (TimeOut). 
In the latter case, masking should not be performed.) - - When the reward is multi-dim, the entropy reward will be added to *all* - dims. """ + if self._use_entropy_reward: + with torch.no_grad(): + log_pi = info.log_pi + if self._entropy_normalizer is not None: + log_pi = self._entropy_normalizer.normalize(log_pi) + entropy_reward = nest.map_structure( + lambda la, lp: -torch.exp(la) * lp, self._log_alpha, + log_pi) + entropy_reward = sum(nest.flatten(entropy_reward)) + discount = self._critic_losses[0].gamma * info.discount + # When the reward is multi-dim, the entropy reward will be + # added to *all* dims. + info = info._replace( + reward=(info.reward + common.expand_dims_as( + entropy_reward * discount, info.reward))) + if self._munchausen_reward_weight > 0: with torch.no_grad(): # calculate the log probability of the rollout action @@ -875,26 +888,22 @@ def _calc_critic_loss(self, info: SacInfo): munchausen_reward = nest.map_structure( lambda la, lp: torch.exp(la) * lp, self._log_alpha, log_pi_rollout_a) + # [T, B] munchausen_reward = sum(nest.flatten(munchausen_reward)) + # forward shift the munchausen reward one-step temporally, + # with zero-padding for the first step. This dummy reward + # for the first step does not impact training as it is not + # used in TD-learning. + munchausen_reward = torch.cat((torch.zeros_like( + munchausen_reward[0:1]), munchausen_reward[:-1]), + dim=0) + # When the reward is multi-dim, the munchausen reward will be + # added to *all* dims. 
info = info._replace( reward=( info.reward + self._munchausen_reward_weight * common.expand_dims_as(munchausen_reward, info.reward))) - if self._use_entropy_reward: - with torch.no_grad(): - log_pi = info.log_pi - if self._entropy_normalizer is not None: - log_pi = self._entropy_normalizer.normalize(log_pi) - entropy_reward = nest.map_structure( - lambda la, lp: -torch.exp(la) * lp, self._log_alpha, - log_pi) - entropy_reward = sum(nest.flatten(entropy_reward)) - discount = self._critic_losses[0].gamma * info.discount - info = info._replace( - reward=(info.reward + common.expand_dims_as( - entropy_reward * discount, info.reward))) - critic_info = info.critic critic_losses = [] for i, l in enumerate(self._critic_losses):