From 00efea8975dcf605e81817317476227baeedbb25 Mon Sep 17 00:00:00 2001 From: Haichao Zhang Date: Tue, 8 Apr 2025 14:52:20 -0700 Subject: [PATCH 1/8] Episodic annotation and synced training --- alf/algorithms/algorithm.py | 13 +++-- alf/algorithms/rl_algorithm.py | 95 +++++++++++++++++++++++++++------- 2 files changed, 85 insertions(+), 23 deletions(-) diff --git a/alf/algorithms/algorithm.py b/alf/algorithms/algorithm.py index 19a6db97b..d1eb68377 100644 --- a/alf/algorithms/algorithm.py +++ b/alf/algorithms/algorithm.py @@ -1426,7 +1426,9 @@ def train_from_unroll(self, experience, train_info): return shape[0] * shape[1] @common.mark_replay - def train_from_replay_buffer(self, update_global_counter=False): + def train_from_replay_buffer(self, + effective_unroll_steps, + update_global_counter=False): """This function can be called by any algorithm that has its own replay buffer configured. There are several parameters specified in ``self._config`` that will affect how the training is performed: @@ -1469,6 +1471,7 @@ def train_from_replay_buffer(self, update_global_counter=False): ``True``, it will affect the counter only if ``config.update_counter_every_mini_batch=True``. """ + config: TrainerConfig = self._config # returns 0 if haven't started training yet, when ``_replay_buffer`` is @@ -1479,7 +1482,8 @@ def train_from_replay_buffer(self, update_global_counter=False): # training is not started yet, ``_replay_buffer`` will be None since it # is only lazily created later when online RL training started. if (self._replay_buffer and self._replay_buffer.total_size - < config.initial_collect_steps): + < config.initial_collect_steps) or (effective_unroll_steps + == 0): assert ( self._replay_buffer.num_environments * self._replay_buffer.max_length >= config.initial_collect_steps @@ -1493,6 +1497,7 @@ def _replay(): # ``_replay_buffer`` for training. 
# TODO: If this function can be called asynchronously, and using # prioritized replay, then make sure replay and train below is atomic. + effective_num_updates_per_train_iter = config.num_updates_per_train_iter with record_time("time/replay"): mini_batch_size = config.mini_batch_size if mini_batch_size is None: @@ -1500,14 +1505,14 @@ def _replay(): if config.whole_replay_buffer_training: experience, batch_info = self._replay_buffer.gather_all( ignore_earliest_frames=True) - num_updates = config.num_updates_per_train_iter + num_updates = effective_num_updates_per_train_iter else: assert config.mini_batch_length is not None, ( "No mini_batch_length is specified for off-policy training" ) experience, batch_info = self._replay_buffer.get_batch( batch_size=(mini_batch_size * - config.num_updates_per_train_iter), + effective_num_updates_per_train_iter), batch_length=config.mini_batch_length) num_updates = 1 return experience, batch_info, num_updates, mini_batch_size diff --git a/alf/algorithms/rl_algorithm.py b/alf/algorithms/rl_algorithm.py index 773647233..88abfd52f 100644 --- a/alf/algorithms/rl_algorithm.py +++ b/alf/algorithms/rl_algorithm.py @@ -19,7 +19,7 @@ import os import time import torch -from typing import Callable, Optional +from typing import Callable, List, Optional from absl import logging import alf @@ -147,6 +147,7 @@ def __init__(self, optimizer=None, checkpoint=None, is_eval: bool = False, + episodic_annotation: bool = False, overwrite_policy_output=False, debug_summaries=False, name="RLAlgorithm"): @@ -186,6 +187,8 @@ def __init__(self, during deployment. In this case, the algorithm do not need to create certain components such as value_network for ActorCriticAlgorithm, critic_networks for SacAlgorithm. + episodic_annotation: if True, annotate the episode before being observed by the + replay buffer. overwrite_policy_output (bool): if True, overwrite the policy output with next_step.prev_action. 
This option can be used in some cases such as data collection. @@ -203,6 +206,7 @@ def __init__(self, debug_summaries=debug_summaries, name=name) self._is_eval = is_eval + self._episodic_annotation = episodic_annotation self._env = env self._observation_spec = observation_spec @@ -235,11 +239,14 @@ def __init__(self, self._current_time_step = None self._current_policy_state = None self._current_transform_state = None - + self._cached_exp = [] # for lazy observation if self._env is not None and not self.on_policy: replay_buffer_length = adjust_replay_buffer_length( config, self._num_earliest_frames_ignored) + if self._episodic_annotation: + assert self._env.batch_size == 1, "only support non-batched environment" + if config.whole_replay_buffer_training and config.clear_replay_buffer: # For whole replay buffer training, we would like to be sure # that the replay buffer have enough samples in it to perform @@ -598,6 +605,25 @@ def _async_unroll(self, unroll_length: int): return experience + def should_post_process_episode(self, rollout_info, step_type: StepType): + """A function that determines whether the ``post_process_episode`` function should + be applied to the current list of experiences. + """ + return False + + def post_process_episode(self, experiences: List[Experience]): + """A function for postprocessing a list of experience. It is called when + ``should_post_process_episode`` is True. + It can be used to create a number of useful features such as 'hindsight relabeling' + of a trajectory etc. + + Args: + experiences: a list of experience, containing the experience starting from the + initial time when ``should_post_process_episode`` is False to the step where + ``should_post_process_episode`` is True. 
+ """ + return None + def _process_unroll_step(self, policy_step, action, time_step, transformed_time_step, policy_state, experience_list, original_reward_list): @@ -605,12 +631,36 @@ def _process_unroll_step(self, policy_step, action, time_step, exp = make_experience(time_step.cpu(), alf.layers.to_float32(policy_step), alf.layers.to_float32(policy_state)) - - store_exp_time = 0 - if not self.on_policy: - t0 = time.time() - self.observe_for_replay(exp) - store_exp_time = time.time() - t0 + effective_number_of_unroll_steps = 1 + if self._episodic_annotation: + store_exp_time = 0 + # if last step, annotate + rollout_info = policy_step.info + self._cached_exp.append(exp) + if self.should_post_process_episode(rollout_info, + time_step.step_type): + + # 1) process + annotated_exp_list = self.post_process_episode( + self._cached_exp) + effective_number_of_unroll_steps = len(annotated_exp_list) + # 2) observe + if not self.on_policy: + t0 = time.time() + for exp in annotated_exp_list: + self.observe_for_replay(exp) + store_exp_time = time.time() - t0 + # clean up the exp cache + self._cached_exp = [] + else: + # effective unroll steps as 0 if not post_process_episode timepoint yet + effective_number_of_unroll_steps = 0 + else: + store_exp_time = 0 + if not self.on_policy: + t0 = time.time() + self.observe_for_replay(exp) + store_exp_time = time.time() - t0 exp_for_training = Experience( time_step=transformed_time_step, @@ -620,7 +670,7 @@ def _process_unroll_step(self, policy_step, action, time_step, experience_list.append(exp_for_training) original_reward_list.append(time_step.reward) - return store_exp_time + return store_exp_time, effective_number_of_unroll_steps def reset_state(self): """Reset the state of the algorithm. @@ -665,6 +715,7 @@ def _sync_unroll(self, unroll_length: int): policy_step_time = 0. env_step_time = 0. store_exp_time = 0. 
+ effective_unroll_steps = 0 for _ in range(unroll_length): policy_state = common.reset_state_if_necessary( policy_state, initial_state, time_step.is_first()) @@ -693,9 +744,10 @@ def _sync_unroll(self, unroll_length: int): if self._overwrite_policy_output: policy_step = policy_step._replace( output=next_time_step.prev_action) - store_exp_time += self._process_unroll_step( + store_exp_time_i, effective_unroll_steps = self._process_unroll_step( policy_step, action, time_step, transformed_time_step, policy_state, experience_list, original_reward_list) + store_exp_time += store_exp_time_i time_step = next_time_step policy_state = policy_step.state @@ -723,7 +775,7 @@ def _sync_unroll(self, unroll_length: int): self._current_policy_state = common.detach(policy_state) self._current_transform_state = common.detach(trans_state) - return experience + return experience, effective_unroll_steps def train_iter(self): """Perform one iteration of training. @@ -804,6 +856,7 @@ def _unroll_iter_off_policy(self): unrolled = False root_inputs = None rollout_info = None + effective_unroll_steps = 0 if (alf.summary.get_global_counter() >= self._rl_train_after_update_steps and (unroll_length > 0 or config.unroll_length == 0) and @@ -822,7 +875,8 @@ def _unroll_iter_off_policy(self): # need to remember whether summary has been written between # two unrolls. 
with self._ensure_rollout_summary: - experience = self.unroll(unroll_length) + experience, effective_unroll_steps = self.unroll( + unroll_length) if experience: self.summarize_rollout(experience) self.summarize_metrics() @@ -830,11 +884,12 @@ def _unroll_iter_off_policy(self): if config.use_root_inputs_for_after_train_iter: root_inputs = experience.time_step del experience - return unrolled, root_inputs, rollout_info + return unrolled, root_inputs, rollout_info, effective_unroll_steps def _train_iter_off_policy(self): """User may override this for their own training procedure.""" - unrolled, root_inputs, rollout_info = self._unroll_iter_off_policy() + unrolled, root_inputs, rollout_info, effective_unroll_steps = self._unroll_iter_off_policy( + ) # replay buffer may not have been created for two different reasons: # 1. in online RL training (``has_offline`` is False), unroll is not @@ -846,11 +901,13 @@ def _train_iter_off_policy(self): return 0 self.train() - steps = self.train_from_replay_buffer(update_global_counter=True) - - if unrolled: - with record_time("time/after_train_iter"): - self.after_train_iter(root_inputs, rollout_info) + steps = 0 + for i in range(effective_unroll_steps): + steps += self.train_from_replay_buffer(effective_unroll_steps=1, + update_global_counter=True) + if unrolled: + with record_time("time/after_train_iter"): + self.after_train_iter(root_inputs, rollout_info) # For now, we only return the steps of the primary algorithm's training return steps From e4cdb811e45ee30575643d2f4c30b1988c8b8255 Mon Sep 17 00:00:00 2001 From: Haichao Zhang Date: Fri, 9 May 2025 15:59:21 -0700 Subject: [PATCH 2/8] Address comments --- alf/algorithms/algorithm.py | 7 ++----- alf/algorithms/rl_algorithm.py | 36 +++++++++++++++++++--------------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/alf/algorithms/algorithm.py b/alf/algorithms/algorithm.py index d1eb68377..cabe2c5c9 100644 --- a/alf/algorithms/algorithm.py +++ 
b/alf/algorithms/algorithm.py @@ -1426,9 +1426,7 @@ def train_from_unroll(self, experience, train_info): return shape[0] * shape[1] @common.mark_replay - def train_from_replay_buffer(self, - effective_unroll_steps, - update_global_counter=False): + def train_from_replay_buffer(self, update_global_counter=False): """This function can be called by any algorithm that has its own replay buffer configured. There are several parameters specified in ``self._config`` that will affect how the training is performed: @@ -1482,8 +1480,7 @@ def train_from_replay_buffer(self, # training is not started yet, ``_replay_buffer`` will be None since it # is only lazily created later when online RL training started. if (self._replay_buffer and self._replay_buffer.total_size - < config.initial_collect_steps) or (effective_unroll_steps - == 0): + < config.initial_collect_steps): assert ( self._replay_buffer.num_environments * self._replay_buffer.max_length >= config.initial_collect_steps diff --git a/alf/algorithms/rl_algorithm.py b/alf/algorithms/rl_algorithm.py index 88abfd52f..19cacd3a8 100644 --- a/alf/algorithms/rl_algorithm.py +++ b/alf/algorithms/rl_algorithm.py @@ -187,8 +187,10 @@ def __init__(self, during deployment. In this case, the algorithm do not need to create certain components such as value_network for ActorCriticAlgorithm, critic_networks for SacAlgorithm. - episodic_annotation: if True, annotate the episode before being observed by the - replay buffer. + episodic_annotation: episodic annotation is an operation that annotates the + episode after it being collected, and then the annotated episode will be + observed by the replay buffer. If True, annotate the episode before being + observed by the replay buffer. Otherwise, episodic annotation is not applied. overwrite_policy_output (bool): if True, overwrite the policy output with next_step.prev_action. This option can be used in some cases such as data collection. 
@@ -244,9 +246,6 @@ def __init__(self, replay_buffer_length = adjust_replay_buffer_length( config, self._num_earliest_frames_ignored) - if self._episodic_annotation: - assert self._env.batch_size == 1, "only support non-batched environment" - if config.whole_replay_buffer_training and config.clear_replay_buffer: # For whole replay buffer training, we would like to be sure # that the replay buffer have enough samples in it to perform @@ -608,21 +607,27 @@ def _async_unroll(self, unroll_length: int): def should_post_process_episode(self, rollout_info, step_type: StepType): """A function that determines whether the ``post_process_episode`` function should be applied to the current list of experiences. + Users can customize this function in the derived class. + Bu default, it returns True all the time steps. When this is combined with + ``post_process_episode`` which simply return the input unmodified (as the default + implementation in this class), it is a dummy version of eposodic annotation with + logic equivalent to the case of episodic_annotation=False. """ - return False + return True def post_process_episode(self, experiences: List[Experience]): """A function for postprocessing a list of experience. It is called when ``should_post_process_episode`` is True. - It can be used to create a number of useful features such as 'hindsight relabeling' - of a trajectory etc. + By default, it returns the input unmodified. + Users can customize this function in the derived class, to create a number of + useful features such as 'hindsight relabeling' of a trajectory etc. Args: experiences: a list of experience, containing the experience starting from the initial time when ``should_post_process_episode`` is False to the step where ``should_post_process_episode`` is True. 
""" - return None + return experiences def _process_unroll_step(self, policy_step, action, time_step, transformed_time_step, policy_state, @@ -633,6 +638,7 @@ def _process_unroll_step(self, policy_step, action, time_step, alf.layers.to_float32(policy_state)) effective_number_of_unroll_steps = 1 if self._episodic_annotation: + assert not self.on_policy, "only support episodic annotation for off policy training" store_exp_time = 0 # if last step, annotate rollout_info = policy_step.info @@ -645,11 +651,10 @@ def _process_unroll_step(self, policy_step, action, time_step, self._cached_exp) effective_number_of_unroll_steps = len(annotated_exp_list) # 2) observe - if not self.on_policy: - t0 = time.time() - for exp in annotated_exp_list: - self.observe_for_replay(exp) - store_exp_time = time.time() - t0 + t0 = time.time() + for exp in annotated_exp_list: + self.observe_for_replay(exp) + store_exp_time = time.time() - t0 # clean up the exp cache self._cached_exp = [] else: @@ -903,8 +908,7 @@ def _train_iter_off_policy(self): self.train() steps = 0 for i in range(effective_unroll_steps): - steps += self.train_from_replay_buffer(effective_unroll_steps=1, - update_global_counter=True) + steps += self.train_from_replay_buffer(update_global_counter=True) if unrolled: with record_time("time/after_train_iter"): self.after_train_iter(root_inputs, rollout_info) From a05e8daedea530b0ededce13a2324dd08533a710 Mon Sep 17 00:00:00 2001 From: Haichao Zhang Date: Wed, 21 May 2025 12:26:09 -0700 Subject: [PATCH 3/8] Update async unroll --- alf/algorithms/rl_algorithm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/alf/algorithms/rl_algorithm.py b/alf/algorithms/rl_algorithm.py index 19cacd3a8..16e5d4a49 100644 --- a/alf/algorithms/rl_algorithm.py +++ b/alf/algorithms/rl_algorithm.py @@ -572,10 +572,11 @@ def _async_unroll(self, unroll_length: int): step_time += unroll_result.step_time max_step_time = max(max_step_time, unroll_result.step_time) - store_exp_time 
+= self._process_unroll_step( + store_exp_time_i, effective_unroll_steps = self._process_unroll_step( policy_step, policy_step.output, time_step, transformed_time_step, policy_state, experience_list, original_reward_list) + store_exp_time += store_exp_time_i alf.summary.scalar("time/unroll_env_step", env_step_time, @@ -602,7 +603,7 @@ def _async_unroll(self, unroll_length: int): self._current_transform_state = common.detach(trans_state) - return experience + return experience, effective_unroll_steps def should_post_process_episode(self, rollout_info, step_type: StepType): """A function that determines whether the ``post_process_episode`` function should @@ -804,7 +805,7 @@ def _compute_train_info_and_loss_info_on_policy(self, unroll_length): with record_time("time/unroll"): with torch.cuda.amp.autocast(self._config.enable_amp, dtype=self._config.amp_dtype): - experience = self.unroll(self._config.unroll_length) + experience, _ = self.unroll(self._config.unroll_length) self.summarize_metrics() train_info = experience.rollout_info From 9cfe6a5f104b4baba4a3630cfe3b079302848ee3 Mon Sep 17 00:00:00 2001 From: Haichao Zhang Date: Fri, 23 May 2025 11:19:13 -0700 Subject: [PATCH 4/8] Update --- alf/algorithms/rl_algorithm.py | 122 ++++++++++++++++++--------------- 1 file changed, 67 insertions(+), 55 deletions(-) diff --git a/alf/algorithms/rl_algorithm.py b/alf/algorithms/rl_algorithm.py index 16e5d4a49..f7b2bef21 100644 --- a/alf/algorithms/rl_algorithm.py +++ b/alf/algorithms/rl_algorithm.py @@ -147,7 +147,6 @@ def __init__(self, optimizer=None, checkpoint=None, is_eval: bool = False, - episodic_annotation: bool = False, overwrite_policy_output=False, debug_summaries=False, name="RLAlgorithm"): @@ -187,10 +186,6 @@ def __init__(self, during deployment. In this case, the algorithm do not need to create certain components such as value_network for ActorCriticAlgorithm, critic_networks for SacAlgorithm. 
- episodic_annotation: episodic annotation is an operation that annotates the - episode after it being collected, and then the annotated episode will be - observed by the replay buffer. If True, annotate the episode before being - observed by the replay buffer. Otherwise, episodic annotation is not applied. overwrite_policy_output (bool): if True, overwrite the policy output with next_step.prev_action. This option can be used in some cases such as data collection. @@ -208,7 +203,6 @@ def __init__(self, debug_summaries=debug_summaries, name=name) self._is_eval = is_eval - self._episodic_annotation = episodic_annotation self._env = env self._observation_spec = observation_spec @@ -241,7 +235,6 @@ def __init__(self, self._current_time_step = None self._current_policy_state = None self._current_transform_state = None - self._cached_exp = [] # for lazy observation if self._env is not None and not self.on_policy: replay_buffer_length = adjust_replay_buffer_length( config, self._num_earliest_frames_ignored) @@ -550,6 +543,7 @@ def _async_unroll(self, unroll_length: int): store_exp_time = 0. step_time = 0. max_step_time = 0. 
+ effective_unroll_steps = 0 qsize = self._async_unroller.get_queue_size() unroll_results = self._async_unroller.gather_unroll_results( unroll_length, self._config.max_unroll_length) @@ -572,11 +566,12 @@ def _async_unroll(self, unroll_length: int): step_time += unroll_result.step_time max_step_time = max(max_step_time, unroll_result.step_time) - store_exp_time_i, effective_unroll_steps = self._process_unroll_step( + store_exp_time_i, effective_unroll_steps_i = self._process_unroll_step( policy_step, policy_step.output, time_step, transformed_time_step, policy_state, experience_list, original_reward_list) store_exp_time += store_exp_time_i + effective_unroll_steps += effective_unroll_steps_i alf.summary.scalar("time/unroll_env_step", env_step_time, @@ -603,70 +598,80 @@ def _async_unroll(self, unroll_length: int): self._current_transform_state = common.detach(trans_state) - return experience, effective_unroll_steps - - def should_post_process_episode(self, rollout_info, step_type: StepType): - """A function that determines whether the ``post_process_episode`` function should - be applied to the current list of experiences. - Users can customize this function in the derived class. - Bu default, it returns True all the time steps. When this is combined with - ``post_process_episode`` which simply return the input unmodified (as the default - implementation in this class), it is a dummy version of eposodic annotation with - logic equivalent to the case of episodic_annotation=False. + effective_unroll_iters = effective_unroll_steps // unroll_length + return experience, effective_unroll_iters + + def should_post_process_experience(self, rollout_info, + step_type: StepType): + """A function that determines whether the ``post_process_experience`` function should + be called. Users can customize this pair of functions in the derived class to achieve + different effects. 
For example: + - per-step processing: ``should_post_process_experience`` + returns True for all the steps (by default), and ``post_process_experience`` + returns the current step of experience unmodified (by default) or a modified version + according to their customized ``post_process_experience`` function. + As another example, task filtering can be simply achieved by returning ``[]`` + in ``post_process_experience`` for that particular task. + - per-episode processing: ``should_post_process_experience`` returns True on episode + end and ``post_process_experience`` can return a list of cached and processed + experiences. For example, this can be used for success episode labeling. """ return True - def post_process_episode(self, experiences: List[Experience]): + def post_process_experience(self, experiences: Experience): """A function for postprocessing a list of experience. It is called when - ``should_post_process_episode`` is True. + ``should_post_process_experience`` is True. By default, it returns the input unmodified. Users can customize this function in the derived class, to create a number of useful features such as 'hindsight relabeling' of a trajectory etc. Args: - experiences: a list of experience, containing the experience starting from the - initial time when ``should_post_process_episode`` is False to the step where - ``should_post_process_episode`` is True. + experiences: one step of experience. + + Returns: + A list of experiences. Users can customize this pair of functions in the + derived class to achieve different effects. For example: + - return a list that contains only the input experience (default behavior). + - return a list that contains a number of experiences. This can be useful + for episode processing such as success episode labeling. 
""" - return experiences + return [experiences] def _process_unroll_step(self, policy_step, action, time_step, transformed_time_step, policy_state, experience_list, original_reward_list): + """A function for processing the unroll steps. + By default, it returns the input unmodified. + Users can customize this function in the derived class, to create a number of + useful features such as 'hindsight relabeling' of a trajectory etc. + + Args: + experiences: a list of experience, containing the experience starting from the + initial time when ``should_post_process_experience`` is False to the step where + ``should_post_process_experience`` is True. + """ + self.observe_for_metrics(time_step.cpu()) exp = make_experience(time_step.cpu(), alf.layers.to_float32(policy_step), alf.layers.to_float32(policy_state)) - effective_number_of_unroll_steps = 1 - if self._episodic_annotation: - assert not self.on_policy, "only support episodic annotation for off policy training" - store_exp_time = 0 - # if last step, annotate + effective_unroll_steps = 1 + store_exp_time = 0 + if not self.on_policy: rollout_info = policy_step.info - self._cached_exp.append(exp) - if self.should_post_process_episode(rollout_info, - time_step.step_type): - + if self.should_post_process_experience(rollout_info, + time_step.step_type): # 1) process - annotated_exp_list = self.post_process_episode( - self._cached_exp) - effective_number_of_unroll_steps = len(annotated_exp_list) + post_processed_exp_list = self.post_process_experience(exp) + effective_unroll_steps = len(post_processed_exp_list) # 2) observe t0 = time.time() - for exp in annotated_exp_list: + for exp in post_processed_exp_list: self.observe_for_replay(exp) store_exp_time = time.time() - t0 - # clean up the exp cache - self._cached_exp = [] else: - # effective unroll steps as 0 if not post_process_episode timepoint yet - effective_number_of_unroll_steps = 0 - else: - store_exp_time = 0 - if not self.on_policy: - t0 = time.time() - 
self.observe_for_replay(exp) - store_exp_time = time.time() - t0 + # effective unroll steps as 0 if ``should_post_process_experience condition`` is False + effective_unroll_steps = 0 exp_for_training = Experience( time_step=transformed_time_step, @@ -676,7 +681,7 @@ def _process_unroll_step(self, policy_step, action, time_step, experience_list.append(exp_for_training) original_reward_list.append(time_step.reward) - return store_exp_time, effective_number_of_unroll_steps + return store_exp_time, effective_unroll_steps def reset_state(self): """Reset the state of the algorithm. @@ -700,6 +705,8 @@ def _sync_unroll(self, unroll_length: int): Returns: Experience: The stacked experience with shape :math:`[T, B, \ldots]` for each of its members. + effective_unroll_iters: the effective number of unroll iterations. + Each unroll iteration contains ``unroll_length`` unroll steps. """ if self._current_time_step is None: self._current_time_step = common.get_initial_time_step(self._env) @@ -750,10 +757,11 @@ def _sync_unroll(self, unroll_length: int): if self._overwrite_policy_output: policy_step = policy_step._replace( output=next_time_step.prev_action) - store_exp_time_i, effective_unroll_steps = self._process_unroll_step( + store_exp_time_i, effective_unroll_steps_i = self._process_unroll_step( policy_step, action, time_step, transformed_time_step, policy_state, experience_list, original_reward_list) store_exp_time += store_exp_time_i + effective_unroll_steps += effective_unroll_steps_i time_step = next_time_step policy_state = policy_step.state @@ -781,7 +789,8 @@ def _sync_unroll(self, unroll_length: int): self._current_policy_state = common.detach(policy_state) self._current_transform_state = common.detach(trans_state) - return experience, effective_unroll_steps + effective_unroll_iters = effective_unroll_steps // unroll_length + return experience, effective_unroll_iters def train_iter(self): """Perform one iteration of training. 
@@ -846,6 +855,9 @@ def _unroll_iter_off_policy(self): unroll length, it may not have been called. - root_inputs: root-level time step returned by the unroll - rollout_info: rollout info returned by the unroll + - effective_unroll_iters: the effective number of unroll iterations. + ``train_from_replay_buffer`` will be run ``effective_unroll_iters`` times + during ``_train_iter_off_policy``. """ config: TrainerConfig = self._config @@ -862,7 +874,7 @@ def _unroll_iter_off_policy(self): unrolled = False root_inputs = None rollout_info = None - effective_unroll_steps = 0 + effective_unroll_iters = 0 if (alf.summary.get_global_counter() >= self._rl_train_after_update_steps and (unroll_length > 0 or config.unroll_length == 0) and @@ -881,7 +893,7 @@ def _unroll_iter_off_policy(self): # need to remember whether summary has been written between # two unrolls. with self._ensure_rollout_summary: - experience, effective_unroll_steps = self.unroll( + experience, effective_unroll_iters = self.unroll( unroll_length) if experience: self.summarize_rollout(experience) @@ -890,11 +902,11 @@ def _unroll_iter_off_policy(self): if config.use_root_inputs_for_after_train_iter: root_inputs = experience.time_step del experience - return unrolled, root_inputs, rollout_info, effective_unroll_steps + return unrolled, root_inputs, rollout_info, effective_unroll_iters def _train_iter_off_policy(self): """User may override this for their own training procedure.""" - unrolled, root_inputs, rollout_info, effective_unroll_steps = self._unroll_iter_off_policy( + unrolled, root_inputs, rollout_info, effective_unroll_iters = self._unroll_iter_off_policy( ) # replay buffer may not have been created for two different reasons: @@ -908,7 +920,7 @@ def _train_iter_off_policy(self): self.train() steps = 0 - for i in range(effective_unroll_steps): + for i in range(effective_unroll_iters): steps += self.train_from_replay_buffer(update_global_counter=True) if unrolled: with 
record_time("time/after_train_iter"): From 26ab09a8ea0488f9c73b42b7829cdb8883e898e1 Mon Sep 17 00:00:00 2001 From: Haichao Zhang Date: Fri, 23 May 2025 12:15:48 -0700 Subject: [PATCH 5/8] Address more comments --- alf/algorithms/algorithm.py | 6 +-- alf/algorithms/rl_algorithm.py | 69 +++++++++++----------------------- 2 files changed, 24 insertions(+), 51 deletions(-) diff --git a/alf/algorithms/algorithm.py b/alf/algorithms/algorithm.py index cabe2c5c9..19a6db97b 100644 --- a/alf/algorithms/algorithm.py +++ b/alf/algorithms/algorithm.py @@ -1469,7 +1469,6 @@ def train_from_replay_buffer(self, update_global_counter=False): ``True``, it will affect the counter only if ``config.update_counter_every_mini_batch=True``. """ - config: TrainerConfig = self._config # returns 0 if haven't started training yet, when ``_replay_buffer`` is @@ -1494,7 +1493,6 @@ def _replay(): # ``_replay_buffer`` for training. # TODO: If this function can be called asynchronously, and using # prioritized replay, then make sure replay and train below is atomic. 
- effective_num_updates_per_train_iter = config.num_updates_per_train_iter with record_time("time/replay"): mini_batch_size = config.mini_batch_size if mini_batch_size is None: @@ -1502,14 +1500,14 @@ def _replay(): if config.whole_replay_buffer_training: experience, batch_info = self._replay_buffer.gather_all( ignore_earliest_frames=True) - num_updates = effective_num_updates_per_train_iter + num_updates = config.num_updates_per_train_iter else: assert config.mini_batch_length is not None, ( "No mini_batch_length is specified for off-policy training" ) experience, batch_info = self._replay_buffer.get_batch( batch_size=(mini_batch_size * - effective_num_updates_per_train_iter), + config.num_updates_per_train_iter), batch_length=config.mini_batch_length) num_updates = 1 return experience, batch_info, num_updates, mini_batch_size diff --git a/alf/algorithms/rl_algorithm.py b/alf/algorithms/rl_algorithm.py index f7b2bef21..0c2327dc7 100644 --- a/alf/algorithms/rl_algorithm.py +++ b/alf/algorithms/rl_algorithm.py @@ -19,7 +19,7 @@ import os import time import torch -from typing import Callable, List, Optional +from typing import Callable, Optional from absl import logging import alf @@ -601,35 +601,25 @@ def _async_unroll(self, unroll_length: int): effective_unroll_iters = effective_unroll_steps // unroll_length return experience, effective_unroll_iters - def should_post_process_experience(self, rollout_info, - step_type: StepType): - """A function that determines whether the ``post_process_experience`` function should - be called. Users can customize this pair of functions in the derived class to achieve - different effects. For example: - - per-step processing: ``should_post_process_experience`` - returns True for all the steps (by default), and ``post_process_experience`` - returns the current step of experience unmodified (by default) or a modified version - according to their customized ``post_process_experience`` function. 
+ def post_process_experience(self, rollout_info, step_type: StepType, + experiences: Experience): + """A function for postprocessing experience. By default, it returns the input + experience unmodified. Users can customize this function in the derived + class to achieve different effects. For example: + - per-step processing: return the current step of experience unmodified (by default) + or a modified version according to the customized ``post_process_experience``. As another example, task filtering can be simply achieved by returning ``[]`` - in ``post_process_experience`` for that particular task. - - per-episode processing: ``should_post_process_experience`` returns True on episode - end and ``post_process_experience`` can return a list of cached and processed + for that particular task. + - per-episode processing: this can be achieved by returning a list of processed experiences. For example, this can be used for success episode labeling. - """ - return True - - def post_process_experience(self, experiences: Experience): - """A function for postprocessing a list of experience. It is called when - ``should_post_process_experience`` is True. - By default, it returns the input unmodified. - Users can customize this function in the derived class, to create a number of - useful features such as 'hindsight relabeling' of a trajectory etc. Args: + rollout_info: the rollout info. + step_type: the step type of the current experience. experiences: one step of experience. Returns: - A list of experiences. Users can customize this pair of functions in the + A list of experiences. Users can customize this functions in the derived class to achieve different effects. For example: - return a list that contains only the input experience (default behavior). - return a list that contains a number of experiences. 
This can be useful @@ -640,17 +630,6 @@ def post_process_experience(self, experiences: Experience): def _process_unroll_step(self, policy_step, action, time_step, transformed_time_step, policy_state, experience_list, original_reward_list): - """A function for processing the unroll steps. - By default, it returns the input unmodified. - Users can customize this function in the derived class, to create a number of - useful features such as 'hindsight relabeling' of a trajectory etc. - - Args: - experiences: a list of experience, containing the experience starting from the - initial time when ``should_post_process_experience`` is False to the step where - ``should_post_process_experience`` is True. - """ - self.observe_for_metrics(time_step.cpu()) exp = make_experience(time_step.cpu(), alf.layers.to_float32(policy_step), @@ -659,19 +638,15 @@ def _process_unroll_step(self, policy_step, action, time_step, store_exp_time = 0 if not self.on_policy: rollout_info = policy_step.info - if self.should_post_process_experience(rollout_info, - time_step.step_type): - # 1) process - post_processed_exp_list = self.post_process_experience(exp) - effective_unroll_steps = len(post_processed_exp_list) - # 2) observe - t0 = time.time() - for exp in post_processed_exp_list: - self.observe_for_replay(exp) - store_exp_time = time.time() - t0 - else: - # effective unroll steps as 0 if ``should_post_process_experience condition`` is False - effective_unroll_steps = 0 + # 1) process + post_processed_exp_list = self.post_process_experience( + rollout_info, time_step.step_type, exp) + effective_unroll_steps = len(post_processed_exp_list) + # 2) observe + t0 = time.time() + for exp in post_processed_exp_list: + self.observe_for_replay(exp) + store_exp_time = time.time() - t0 exp_for_training = Experience( time_step=transformed_time_step, From 734dae80a6f40794d34ad105101797fdb2d3ae2d Mon Sep 17 00:00:00 2001 From: Haichao Zhang Date: Fri, 23 May 2025 14:38:46 -0700 Subject: [PATCH 6/8] Handle 
fractional unroll --- alf/algorithms/rl_algorithm.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/alf/algorithms/rl_algorithm.py b/alf/algorithms/rl_algorithm.py index 0c2327dc7..e2dbb4d1d 100644 --- a/alf/algorithms/rl_algorithm.py +++ b/alf/algorithms/rl_algorithm.py @@ -235,6 +235,7 @@ def __init__(self, self._current_time_step = None self._current_policy_state = None self._current_transform_state = None + if self._env is not None and not self.on_policy: replay_buffer_length = adjust_replay_buffer_length( config, self._num_earliest_frames_ignored) @@ -598,7 +599,9 @@ def _async_unroll(self, unroll_length: int): self._current_transform_state = common.detach(trans_state) - effective_unroll_iters = effective_unroll_steps // unroll_length + # if the input unroll_length is 0 (e.g. fractional unroll), then this it treated as + # an effective unroll iter + effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length return experience, effective_unroll_iters def post_process_experience(self, rollout_info, step_type: StepType, @@ -637,11 +640,12 @@ def _process_unroll_step(self, policy_step, action, time_step, effective_unroll_steps = 1 store_exp_time = 0 if not self.on_policy: - rollout_info = policy_step.info - # 1) process + # 1) post process post_processed_exp_list = self.post_process_experience( - rollout_info, time_step.step_type, exp) - effective_unroll_steps = len(post_processed_exp_list) + policy_step.info, time_step.step_type, exp) + effective_unroll_steps = sum( + exp.step_type.shape[0] + for exp in post_processed_exp_list) / exp.step_type.shape[0] # 2) observe t0 = time.time() for exp in post_processed_exp_list: @@ -764,7 +768,9 @@ def _sync_unroll(self, unroll_length: int): self._current_policy_state = common.detach(policy_state) self._current_transform_state = common.detach(trans_state) - effective_unroll_iters = effective_unroll_steps // unroll_length + # if the input unroll_length 
is 0 (e.g. fractional unroll), then this it treated as + # an effective unroll iter + effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length return experience, effective_unroll_iters def train_iter(self): From 94a50bfc52071f4d7df3f7f0e4e7dff499904fe5 Mon Sep 17 00:00:00 2001 From: Haichao Zhang Date: Thu, 5 Jun 2025 10:59:45 -0700 Subject: [PATCH 7/8] Let user set effective_unroll_steps --- alf/algorithms/rl_algorithm.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/alf/algorithms/rl_algorithm.py b/alf/algorithms/rl_algorithm.py index e2dbb4d1d..2ec85419d 100644 --- a/alf/algorithms/rl_algorithm.py +++ b/alf/algorithms/rl_algorithm.py @@ -19,7 +19,7 @@ import os import time import torch -from typing import Callable, Optional +from typing import Callable, List, Optional, Tuple from absl import logging import alf @@ -605,7 +605,7 @@ def _async_unroll(self, unroll_length: int): return experience, effective_unroll_iters def post_process_experience(self, rollout_info, step_type: StepType, - experiences: Experience): + experiences: Experience) -> Tuple[List, int]: """A function for postprocessing experience. By default, it returns the input experience unmodified. Users can customize this function in the derived class to achieve different effects. For example: @@ -622,17 +622,22 @@ class to achieve different effects. For example: experiences: one step of experience. Returns: - A list of experiences. Users can customize this functions in the - derived class to achieve different effects. For example: - - return a list that contains only the input experience (default behavior). - - return a list that contains a number of experiences. This can be useful - for episode processing such as success episode labeling. + - a list of experiences. Users can customize this functions in the + derived class to achieve different effects. 
For example: + * return a list that contains only the input experience (default behavior). + * return a list that contains a number of experiences. This can be useful + for episode processing such as success episode labeling. + - an integer representing the effective number of unroll steps per env. The + default value of 1, meaning the length of effective experience is 1 + after calling ``post_process_experience``, the same as the input length + of experience. """ - return [experiences] + return [experiences], 1 def _process_unroll_step(self, policy_step, action, time_step, transformed_time_step, policy_state, - experience_list, original_reward_list): + experience_list, + original_reward_list) -> Tuple[int, int]: self.observe_for_metrics(time_step.cpu()) exp = make_experience(time_step.cpu(), alf.layers.to_float32(policy_step), @@ -641,11 +646,8 @@ def _process_unroll_step(self, policy_step, action, time_step, store_exp_time = 0 if not self.on_policy: # 1) post process - post_processed_exp_list = self.post_process_experience( + post_processed_exp_list, effective_unroll_steps = self.post_process_experience( policy_step.info, time_step.step_type, exp) - effective_unroll_steps = sum( - exp.step_type.shape[0] - for exp in post_processed_exp_list) / exp.step_type.shape[0] # 2) observe t0 = time.time() for exp in post_processed_exp_list: From 8fc3ff294e4c470803f98bc8e8af14ea5592df5b Mon Sep 17 00:00:00 2001 From: Haichao Zhang Date: Fri, 6 Jun 2025 11:48:30 -0700 Subject: [PATCH 8/8] Address comments --- alf/algorithms/rl_algorithm.py | 80 +++++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 16 deletions(-) diff --git a/alf/algorithms/rl_algorithm.py b/alf/algorithms/rl_algorithm.py index 2ec85419d..81ae4a27b 100644 --- a/alf/algorithms/rl_algorithm.py +++ b/alf/algorithms/rl_algorithm.py @@ -604,13 +604,15 @@ def _async_unroll(self, unroll_length: int): effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length return 
experience, effective_unroll_iters - def post_process_experience(self, rollout_info, step_type: StepType, - experiences: Experience) -> Tuple[List, int]: - """A function for postprocessing experience. By default, it returns the input + def preprocess_unroll_experience( + self, rollout_info, step_type: StepType, + experiences: Experience) -> Tuple[List, float]: + """A function for processing the experience obtained from an unroll step before + being saved into the replay buffer. By default, it returns the input experience unmodified. Users can customize this function in the derived class to achieve different effects. For example: - per-step processing: return the current step of experience unmodified (by default) - or a modified version according to the customized ``post_process_experience``. + or a modified version according to the customized ``preprocess_unroll_experience``. As another example, task filtering can be simply achieved by returning ``[]`` for that particular task. - per-episode processing: this can be achieved by returning a list of processed @@ -622,22 +624,66 @@ class to achieve different effects. For example: experiences: one step of experience. Returns: - - a list of experiences. Users can customize this - functions in the derived class to achieve different effects. For example: + - ``effective_experiences``: a list of experiences. Users can customize this + function in the derived class to achieve different effects. For example: * return a list that contains only the input experience (default behavior). * return a list that contains a number of experiences. This can be useful for episode processing such as success episode labeling.
+ - ``effective_unroll_steps`` : a value representing the effective number of + unroll steps per env. The default value is 1, meaning the length of + effective experience is 1 after calling ``preprocess_unroll_experience``, + the same as the input length of experience. + The value of ``effective_unroll_steps`` can be set differently according + to different scenarios, e.g.: + (1) per-step saving without delay: saving each step of unroll experience + into the replay buffer as we get it. Set ``effective_unroll_steps`` + as 1 so that each step will be counted as valid and there will be no + impact on the train/unroll ratio. + (2) all-step saving with delay: saving all the steps of unroll experience into + the replay buffer with delay. This can happen in the case where we want to + annotate a trajectory based on some quantities that are not immediately + available in the current step (e.g. task success/failure). In this case, + we can simply cache the experiences and set ``effective_experiences=[]`` + before obtaining the quantities required for annotation. + After obtaining the quantities required for annotation, we can + set ``effective_experiences`` as the cached and annotated experience. + To maintain the original unroll/train iter ratio, we can set + ``effective_unroll_steps=1``, meaning each unroll step is regarded as + effective in terms of the unroll/train iter ratio, even though the + pace of saving the unroll steps into the replay buffer has been altered. + (3) selective saving: exclude some of the unroll experiences and only save + the rest. This could be useful in the case where there are transitions + that are irrelevant to the training (e.g. in the multi-task case, where + we want to exclude data from certain subtasks). + This can be achieved by setting ``effective_experiences=[]`` for the + steps to be excluded, while ``effective_experiences = [experiences]`` + otherwise.
If we do not want to trigger a train iter for the unroll + step that will be excluded, we can simply set ``effective_unroll_steps=0``. + Otherwise, we can simply set ``effective_unroll_steps=1``. + (4) parallel environments: in the case of parallel environments, the value + of ``effective_unroll_steps`` can be set according to the modes described + above and the status of each environment (e.g. ``effective_unroll_steps`` + can be set to an average value across environments). Note that this could + result in a floating number. """ - return [experiences], 1 + effective_experiences = [experiences] + effective_unroll_steps = 1 + return effective_experiences, effective_unroll_steps def _process_unroll_step(self, policy_step, action, time_step, transformed_time_step, policy_state, experience_list, - original_reward_list) -> Tuple[int, int]: + original_reward_list) -> Tuple[int, float]: + """ + + Returns: + - ``store_exp_time``: the time spent on storing the experience + - ``effective_unroll_steps``: a value representing the effective number + of unroll steps per env. The default value is 1, meaning the length of + effective experience is 1 after calling ``preprocess_unroll_experience``, + the same as the input length of experience. For more details on it, + please refer to the docstring of ``preprocess_unroll_experience``.
+ """ self.observe_for_metrics(time_step.cpu()) exp = make_experience(time_step.cpu(), alf.layers.to_float32(policy_step), @@ -645,12 +691,12 @@ def _process_unroll_step(self, policy_step, action, time_step, effective_unroll_steps = 1 store_exp_time = 0 if not self.on_policy: - # 1) post process - post_processed_exp_list, effective_unroll_steps = self.post_process_experience( + # 1) pre-process unroll experience + pre_processed_exp_list, effective_unroll_steps = self.preprocess_unroll_experience( policy_step.info, time_step.step_type, exp) # 2) observe t0 = time.time() - for exp in post_processed_exp_list: + for exp in pre_processed_exp_list: self.observe_for_replay(exp) store_exp_time = time.time() - t0 @@ -771,7 +817,9 @@ def _sync_unroll(self, unroll_length: int): self._current_transform_state = common.detach(trans_state) # if the input unroll_length is 0 (e.g. fractional unroll), then this it treated as - # an effective unroll iter + # an effective unroll iter. + # one ``effective_unroll_iter`` refers to the ``unroll_length`` times of calling + # of ``rollout_step`` in the unroll phase. effective_unroll_iters = 1 if unroll_length == 0 else effective_unroll_steps // unroll_length return experience, effective_unroll_iters