Conversation
ding/policy/__init__.py
from .ppo import PPOPolicy, PPOPGPolicy, PPOOffPolicy
from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy
from .cql import CQLPolicy, DiscreteCQLPolicy
from .qtransformer import QtransformerPolicy
from ding.entry import serial_pipeline_offline
from ding.config import read_config
from pathlib import Path
from ding.model.template.qtransformer import QTransformer
Import from the secondary directory, such as:
from ding.model import QTransformer
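For context, a minimal sketch of the re-export that makes this import work; the exact contents of ding/model/__init__.py are an assumption here:

# Hypothetical excerpt of ding/model/__init__.py: re-export the template model
# so downstream code can simply write `from ding.model import QTransformer`.
from .template.qtransformer import QTransformer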
alpha=0.2,
discount_factor_gamma=0.9,
min_reward = 0.1,
auto_alpha=False,
Remove unused fields like these.
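As an illustration only (exactly which of these fields the policy reads is not visible in this thread), the cleanup amounts to deleting the dead entries from the config dict:

# Hypothetical cleaned-up fragment: keep only fields the policy consumes,
# assuming auto_alpha and min_reward are the unused entries being flagged.
learn=dict(
    alpha=0.2,
    discount_factor_gamma=0.9,
),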
ding/policy/qtransformer.py
update_type='momentum',
update_kwargs={'theta': self._cfg.learn.target_theta}
)
self._low = np.array(self._cfg.other["low"])
We don't need low and high here; we always assume the action value range in the policy is [-1, 1].
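A minimal sketch of that convention (function and argument names are illustrative, not DI-engine API): the policy always works in [-1, 1] and the environment side rescales into its own bounds:

import numpy as np

def scale_action(action, low, high):
    # Map a policy action in [-1, 1] to the environment's [low, high] range.
    return low + (np.asarray(action) + 1.0) * 0.5 * (np.asarray(high) - np.asarray(low))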
cuda=True,
model=dict(
    num_actions = 3,
    action_bins = 256,
This action_bins field is not used in the policy.
ding/policy/qtransformer.py
selected = t.gather(-1, indices)
return rearrange(selected, '... 1 -> ...')

def _discretize_action(self, actions):
We can optimize this for loop:

# Build a (num_actions, num_bins) grid of bin centers, then pick the nearest
# bin for every action dimension with a single vectorized argmin.
action_values = np.linspace(-1, 1, 8)[np.newaxis, ...].repeat(4, 0)
action_values = torch.as_tensor(action_values).to(self._device)
diff = (actions.unsqueeze(-1) - action_values.unsqueeze(0)) ** 2
indices = diff.argmin(-1)
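For reference, the inverse lookup (bin indices back to continuous values in [-1, 1]) can reuse the same grid; this is a sketch under the bin and dimension counts in the snippet above:

# Illustrative inverse of the vectorized discretization: gather the selected
# bin centers back into continuous actions.
batch_size = indices.shape[0]
continuous = torch.gather(action_values.expand(batch_size, -1, -1), -1, indices.unsqueeze(-1)).squeeze(-1)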
ding/policy/qtransformer.py
actions = data['action']

# get q
num_timesteps, device = states.shape[1], states.device
Use self._device, which is a default member variable of Policy.
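A minimal sketch of the suggested change, keeping the variable names from the diff above:

# Rely on the policy's configured device instead of inferring it from the data tensor.
num_timesteps, device = states.shape[1], self._device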
ding/policy/qtransformer.py
import torch
import torch.nn.functional as F
from torch.distributions import Normal, Independent
from ema_pytorch import EMA
Remove unused third-party libraries.
ding/policy/qtransformer.py
from pathlib import Path
from functools import partial
from contextlib import nullcontext
ding/policy/qtransformer.py
from torchtyping import TensorType

from einops import rearrange, repeat, pack, unpack
ding/policy/qtransformer.py
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

from beartype import beartype
We will not use beartype to validate runtime types in the current version, so remove it in this PR.
ding/model/template/qtransformer.py
@@ -0,0 +1,753 @@
from random import random
from functools import partial, cache
functools.cache is a new feature in Python 3.9; for compatibility, you should implement it as follows:

try:
    from functools import cache  # only available in Python >= 3.9
except ImportError:
    from functools import lru_cache
    cache = lru_cache(maxsize=None)
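A quick usage sketch of the shim (the decorated helper below is hypothetical):

@cache
def _bin_centers(num_bins):
    # Memoize the discretization grid; works the same with functools.cache
    # and with the lru_cache(maxsize=None) fallback.
    return tuple(-1.0 + 2.0 * i / (num_bins - 1) for i in range(num_bins))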
Description
Related Issue
TODO
Check List