diff --git a/calm/learning/anyskill_agent.py b/calm/learning/anyskill_agent.py index 300dacc..6607ce4 100644 --- a/calm/learning/anyskill_agent.py +++ b/calm/learning/anyskill_agent.py @@ -276,6 +276,18 @@ def init_tensors(self): return + def _get_motion_encoding(self): + all_encoded_demo_amp_obs = [] + for motion_id in range(self.vec_env.env.task._motion_lib._motion_weights.shape[0]): + motion_amp_obs = self.vec_env.env.task.fetch_amp_obs_demo_per_id(32, motion_id)[-1].view(32, self.vec_env.env.task._num_amp_obs_enc_steps, self.vec_env.env.task._num_amp_obs_per_step) + + preproc_amp_obs = self._llc_agent._preproc_amp_obs(motion_amp_obs) + encoded_demo_amp_obs = self._llc_agent.model.a2c_network.eval_enc(preproc_amp_obs) + all_encoded_demo_amp_obs.append(encoded_demo_amp_obs) + all_encoded_demo_amp_obs = torch.stack(all_encoded_demo_amp_obs, dim=0) + + return all_encoded_demo_amp_obs + def _build_llc(self, config_params, checkpoint_file): network_params = config_params['network'] @@ -293,12 +305,13 @@ def _build_llc(self, config_params, checkpoint_file): print("Loaded LLC checkpoint from {:s}".format(checkpoint_file)) self._llc_agent.set_eval() - enc_amp_obs = self._llc_agent._fetch_amp_obs_demo(128) - if len(enc_amp_obs) == 2: - enc_amp_obs = enc_amp_obs[0] - - preproc_enc_amp_obs = self._llc_agent._preproc_amp_obs(enc_amp_obs) - self.encoded_motion = self._llc_agent.model.a2c_network.eval_enc(amp_obs=preproc_enc_amp_obs).unsqueeze(0) + # enc_amp_obs = self._llc_agent._fetch_amp_obs_demo(128) + # if len(enc_amp_obs) == 2: + # enc_amp_obs = enc_amp_obs[0] + # + # preproc_enc_amp_obs = self._llc_agent._preproc_amp_obs(enc_amp_obs) + # self.encoded_motion = self._llc_agent.model.a2c_network.eval_enc(amp_obs=preproc_enc_amp_obs).unsqueeze(0) + self.encoded_motion = self._get_motion_encoding() return