From 3c81edb7a3313090154d3aaf357bd4bf8b19748d Mon Sep 17 00:00:00 2001 From: ln23415 Date: Wed, 27 Dec 2023 19:47:07 +0800 Subject: [PATCH] fix bugs to fetch all motions features after encoder --- calm/learning/anyskill_agent.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/calm/learning/anyskill_agent.py b/calm/learning/anyskill_agent.py index 300dacc..6607ce4 100644 --- a/calm/learning/anyskill_agent.py +++ b/calm/learning/anyskill_agent.py @@ -276,6 +276,18 @@ def init_tensors(self): return + def _get_motion_encoding(self): + all_encoded_demo_amp_obs = [] + for motion_id in range(self.vec_env.env.task._motion_lib._motion_weights.shape[0]): + motion_amp_obs = self.vec_env.env.task.fetch_amp_obs_demo_per_id(32, motion_id)[-1].view(32, self.vec_env.env.task._num_amp_obs_enc_steps, self.vec_env.env.task._num_amp_obs_per_step) + + preproc_amp_obs = self._llc_agent._preproc_amp_obs(motion_amp_obs) + encoded_demo_amp_obs = self._llc_agent.model.a2c_network.eval_enc(preproc_amp_obs) + all_encoded_demo_amp_obs.append(encoded_demo_amp_obs) + all_encoded_demo_amp_obs = torch.stack(all_encoded_demo_amp_obs, dim=0) + + return all_encoded_demo_amp_obs + def _build_llc(self, config_params, checkpoint_file): network_params = config_params['network'] @@ -293,12 +305,13 @@ def _build_llc(self, config_params, checkpoint_file): print("Loaded LLC checkpoint from {:s}".format(checkpoint_file)) self._llc_agent.set_eval() - enc_amp_obs = self._llc_agent._fetch_amp_obs_demo(128) - if len(enc_amp_obs) == 2: - enc_amp_obs = enc_amp_obs[0] - - preproc_enc_amp_obs = self._llc_agent._preproc_amp_obs(enc_amp_obs) - self.encoded_motion = self._llc_agent.model.a2c_network.eval_enc(amp_obs=preproc_enc_amp_obs).unsqueeze(0) + # enc_amp_obs = self._llc_agent._fetch_amp_obs_demo(128) + # if len(enc_amp_obs) == 2: + # enc_amp_obs = enc_amp_obs[0] + # + # preproc_enc_amp_obs = self._llc_agent._preproc_amp_obs(enc_amp_obs) + # self.encoded_motion = self._llc_agent.model.a2c_network.eval_enc(amp_obs=preproc_enc_amp_obs).unsqueeze(0) + self.encoded_motion = self._get_motion_encoding() return