feature(zc): add MetaDiffuser and prompt-dt #771
Super1ce wants to merge 17 commits into opendilab:main
Conversation
```python
) -> 'Policy':  # noqa
    """
    Overview:
        Serial pipeline entry.
```
```python
# use the original batch size per gpu and increase learning rate
# correspondingly.
cfg.policy.learn.batch_size // get_world_size(),
# cfg.policy.learn.batch_size
```
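For context, the trade-off this hunk is weighing is the linear-scaling rule: if every rank keeps the full configured batch size, the effective global batch grows with the number of GPUs, so the learning rate should be scaled up by the same factor. A minimal sketch of the two options, assuming an illustrative `learning_rate` config field (not necessarily the PR's actual field name):

```python
from easydict import EasyDict
from ding.utils import get_world_size

cfg = EasyDict(dict(policy=dict(learn=dict(batch_size=64, learning_rate=1e-4))))
world_size = get_world_size()
if world_size > 1:
    # Option in the diff: keep the global batch fixed by splitting it across ranks.
    per_gpu_batch_size = cfg.policy.learn.batch_size // world_size
    # Option in the comment: keep the configured batch size on every GPU and
    # scale the learning rate linearly with the enlarged global batch.
    # cfg.policy.learn.learning_rate *= world_size
```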
```python
for epoch in range(cfg.policy.learn.train_epoch):
    if get_world_size() > 1:
        dataloader.sampler.set_epoch(epoch)
    for i in range(cfg.policy.train_num):
```
"train_num"->"batch_size"?
```python
    (prompt_returns_embeddings, prompt_state_embeddings, prompt_action_embeddings), dim=1
).permute(0, 2, 1, 3).reshape(prompt_states.shape[0], 3 * prompt_seq_length, self.h_dim)

# prompt_stacked_attention_mask = torch.stack(
```
Remove these unused lines?
ding/model/template/diffusion.py (Outdated)
```python
self.returns_condition = returns_condition
self.condition_guidance_w = condition_guidance_w

# def get_loss_weights(self, discount: int):
```
Remove these unused lines?
```python
return model_mean + model_std * noise, y

def free_guidance_sample(
```
Add type hints for all arguments, and add an `Overview` docstring for every function and class.
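To make the request concrete, here is a hedged sketch of the expected style: type hints on every argument plus DI-engine's `Overview`/`Arguments`/`Returns` docstring sections. The signature, shapes, and the classifier-free-guidance body are assumptions for illustration, not the PR's actual code:

```python
from typing import Optional

import torch
import torch.nn as nn


class FreeGuidanceModel(nn.Module):
    """
    Overview:
        Stub illustrating the requested documentation style; wraps a noise \
        predictor and applies classifier-free guidance at sampling time.
    """

    def __init__(self, model: nn.Module, condition_guidance_w: float = 1.2) -> None:
        super().__init__()
        self.model = model
        self.condition_guidance_w = condition_guidance_w

    def free_guidance_sample(self, x: torch.Tensor, t: torch.Tensor, returns: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Overview:
            Blend conditional and unconditional noise predictions with \
            weight ``condition_guidance_w`` (classifier-free guidance).
        Arguments:
            - x (:obj:`torch.Tensor`): Noised trajectory at diffusion step ``t``, shape ``(B, T, D)``.
            - t (:obj:`torch.Tensor`): Diffusion timestep indices, shape ``(B, )``.
            - returns (:obj:`Optional[torch.Tensor]`): Return-to-go condition, shape ``(B, 1)``.
        Returns:
            - eps (:obj:`torch.Tensor`): Guided noise prediction, same shape as ``x``.
        """
        eps_cond = self.model(x, t, returns)
        eps_uncond = self.model(x, t, None)  # unconditional pass: condition dropped
        return eps_uncond + self.condition_guidance_w * (eps_cond - eps_uncond)
```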
ding/model/template/diffusion.py (Outdated)
```python
self.embed = nn.Sequential(
    nn.Linear((obs_dim * 2 + action_dim + 1) * encoder_horizon, dim * 4),
    Mish(),  # nn.Mish()
```
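The `Mish(),  # nn.Mish()` pair suggests a compatibility shim: `nn.Mish` only exists since PyTorch 1.9, so older versions need a hand-rolled module. A minimal sketch of such a shim, assuming it is meant to match the built-in:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Mish(nn.Module):
    """Mish activation, ``x * tanh(softplus(x))``; equivalent to ``nn.Mish``
    (added in PyTorch 1.9), kept as a custom module for older torch versions."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.tanh(F.softplus(x))
```

If the minimum supported torch version is at least 1.9, the shim could simply be dropped in favor of `nn.Mish()`.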
```python
self._learn_model = model_wrap(self._model, wrapper_name='base')
self._learn_model.reset()

def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]:
```
`data` should be collated into a batch before entering `policy._forward_learn`; its type should be `Dict[str, torch.Tensor]`.
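Concretely, the suggestion is to move collation out of the policy: build one batched dict from the list of per-sample dicts before calling `_forward_learn`. A minimal sketch with illustrative field names, using DI-engine's `default_collate`:

```python
import torch
from ding.utils.data import default_collate

# A list of per-sample dicts, as produced by the dataset / replay buffer.
samples = [
    {'obs': torch.randn(17), 'action': torch.randn(6), 'reward': torch.randn(1)}
    for _ in range(32)
]
batch = default_collate(samples)        # -> Dict[str, torch.Tensor]
assert batch['obs'].shape == (32, 17)   # each key stacked along a new batch dim
# policy._forward_learn(batch)          # the policy now receives one collated dict
```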
```python
if self.have_train:
    if self.task_id is None:
        self.task_id = [0] * self.eval_batch_size
        # if data_id is None:
```
```python
if self._cuda:
    data = to_device(data, self._device)

p_s, p_a, p_rtg, p_t, p_mask, timesteps, states, actions, rewards, returns_to_go, \
```
`data` should be collated into a batch before entering `policy._forward_learn`; its type should be `Dict[str, torch.Tensor]`, so that each field can be accessed by name reliably.
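With that change, the long positional unpacking above collapses into named access, which cannot silently mis-assign fields if the order ever changes. A sketch of the resulting method body, with assumed key names:

```python
from typing import Any, Dict

import torch
from ding.torch_utils import to_device


def _forward_learn(self, data: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    # Method-body sketch for the policy class; the keys below are illustrative.
    if self._cuda:
        data = to_device(data, self._device)
    states, actions, rewards = data['states'], data['actions'], data['rewards']
    returns_to_go, timesteps = data['returns_to_go'], data['timesteps']
    prompt = {k: data[k] for k in ('p_s', 'p_a', 'p_rtg', 'p_t', 'p_mask')}
    ...
```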
```python
self.returns_mlp = nn.Sequential(
    SinusoidalPosEmb(dim),
    nn.Linear(dim, dim * 4),
    # nn.Mish(),
```
```python
@DATASET_REGISTRY.register('meta_traj')
class MetaTraj(Dataset):

    def __init__(self, cfg):
```
Add documentation (docstring and type annotations) for this class and its config items.
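As a concrete starting point, the requested documentation could follow DI-engine's docstring convention. The config item names below are assumptions for illustration, not the PR's actual fields:

```python
from easydict import EasyDict
from torch.utils.data import Dataset
from ding.utils import DATASET_REGISTRY


@DATASET_REGISTRY.register('meta_traj')  # registration mirrors the diff above
class MetaTraj(Dataset):
    """
    Overview:
        Trajectory dataset for meta-RL: each item pairs a task-specific \
        prompt segment with a target training segment.
    Config:
        - data_dir_prefix (:obj:`str`): Path prefix of the saved trajectories.
        - context_len (:obj:`int`): Length of the sampled training segment.
        - prompt_len (:obj:`int`): Length of the sampled prompt segment.
    """

    def __init__(self, cfg: EasyDict) -> None:
        self.context_len = cfg.context_len
        self.prompt_len = cfg.prompt_len
```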
```python
        Interaction serial evaluator class, in which the policy interacts with the env. \
        This evaluator evaluates the algorithm on a list of test environments.
    Interfaces:
        __init__, reset, reset_policy, reset_env, close, should_eval, eval
```
Add MetaDiffuser and prompt-dt algorithms