From d1e427ed2941d6e3eaa23aca5120ecb295dd8b8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Tue, 22 Oct 2024 17:07:30 +0800 Subject: [PATCH 1/9] feature(pu): add pistonball_env, its unittest and qmix config --- .../config/ptz_pistonball_qmix_config.py | 76 ++++++ .../envs/petting_zoo_pistonball_env.py | 247 ++++++++++++++++++ .../envs/petting_zoo_simple_spread_env.py | 10 +- .../envs/test_petting_zoo_pistonball_env.py | 106 ++++++++ 4 files changed, 434 insertions(+), 5 deletions(-) create mode 100644 dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py create mode 100644 dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py create mode 100644 dizoo/petting_zoo/envs/test_petting_zoo_pistonball_env.py diff --git a/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py b/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py new file mode 100644 index 0000000000..f1b2da682a --- /dev/null +++ b/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py @@ -0,0 +1,76 @@ +from easydict import EasyDict + +n_pistons = 20 +collector_env_num = 8 +evaluator_env_num = 8 + +main_config = dict( + exp_name='ptz_pistonball_qmix_seed0', + env=dict( + env_family='butterfly', + env_id='pistonball_v6', + n_pistons=n_pistons, + max_cycles=125, + agent_obs_only=False, + continuous_actions=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + stop_value=1e6, + manager=dict( + shared_memory=False, + reset_timeout=6000, + ), + ), + policy=dict( + cuda=True, + model=dict( + agent_num=n_pistons, + obs_shape=(3, 457, 120), # RGB image observation shape for each piston agent + global_obs_shape=(3, 560, 880), # Global state shape + action_shape=3, # Discrete actions (0, 1, 2) + hidden_size_list=[128, 128, 64], + # mixer=True, # TODO: mixer is not supported image observation now + mixer=False, + ), + learn=dict( + update_per_collect=100, + batch_size=32, + learning_rate=0.0005, + target_update_theta=0.001, + discount_factor=0.99, + double_q=True, + ), + collect=dict( + n_sample=600, + unroll_len=16, + env_num=collector_env_num, + ), + eval=dict(env_num=evaluator_env_num), + other=dict(eps=dict( + type='exp', + start=1.0, + end=0.05, + decay=100000, + )), + ), +) +main_config = EasyDict(main_config) + +create_config = dict( + env=dict( + import_names=['dizoo.petting_zoo.envs.petting_zoo_pistonball_env'], + type='petting_zoo_pistonball', + ), + env_manager=dict(type='subprocess'), + policy=dict(type='qmix'), +) +create_config = EasyDict(create_config) + +ptz_pistonball_qmix_config = main_config +ptz_pistonball_qmix_create_config = create_config + +if __name__ == '__main__': + # or you can enter `ding -m serial -c ptz_pistonball_qmix_config.py -s 0` + from ding.entry import serial_pipeline + serial_pipeline((main_config, create_config), seed=0) \ No newline at end of file diff --git a/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py b/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py new file mode 100644 index 0000000000..4e456db710 --- /dev/null +++ b/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py @@ -0,0 +1,247 @@ +from typing import Any, List, Union, Optional, Dict +import gymnasium as gym +import numpy as np +from functools import reduce + +from ding.envs import BaseEnv, BaseEnvTimestep, FrameStackWrapper +from ding.torch_utils import to_ndarray, to_list +from ding.envs.common.common_function import affine_transform +from ding.utils import ENV_REGISTRY +from pettingzoo.utils.conversions import 
parallel_wrapper_fn +from pettingzoo.butterfly import pistonball_v6 + + +# Custom wrapper for recording videos in PettingZoo environments +class PTZRecordVideo(gym.wrappers.RecordVideo): + def step(self, action): + """ + Custom step function for handling PettingZoo environments + with gymnasium's RecordVideo wrapper. + """ + observations, rewards, terminateds, truncateds, infos = self.env.step(action) + + # Check if any agent has terminated or truncated + if not (self.terminated is True or self.truncated is True): + self.step_id += 1 + if not self.is_vector_env: + if terminateds or truncateds: + self.episode_id += 1 + self.terminated = terminateds + self.truncated = truncateds + elif terminateds[0] or truncateds[0]: + self.episode_id += 1 + self.terminated = terminateds[0] + self.truncated = truncateds[0] + + # Capture the video frame if recording + if self.recording: + assert self.video_recorder is not None + self.video_recorder.capture_frame() + self.recorded_frames += 1 + if self.video_length > 0 and self.recorded_frames > self.video_length: + self.close_video_recorder() + elif not self.is_vector_env: + if terminateds is True or truncateds is True: + self.close_video_recorder() + elif terminateds[0] or truncateds[0]: + self.close_video_recorder() + + elif self._video_enabled(): + self.start_video_recorder() + + return observations, rewards, terminateds, truncateds, infos + + +@ENV_REGISTRY.register('petting_zoo_pistonball') +class PettingZooPistonballEnv(BaseEnv): + """ + DI-engine PettingZoo environment adapter for the Pistonball environment. + This class integrates the `pistonball_v6` environment into the DI-engine + framework, supporting both continuous and discrete actions. + """ + + def __init__(self, cfg: dict) -> None: + self._cfg = cfg + self._init_flag = False + self._replay_path = None + self._num_pistons = self._cfg.get('n_pistons', 20) + self._continuous_actions = self._cfg.get('continuous_actions', False) + self._max_cycles = self._cfg.get('max_cycles', 125) + self._act_scale = self._cfg.get('act_scale', False) + self._agent_specific_global_state = self._cfg.get('agent_specific_global_state', False) + if self._act_scale: + assert self._continuous_actions, 'Action scaling only applies to continuous action spaces.' + self._channel_first = self._cfg.get('channel_first', True) + + def reset(self) -> np.ndarray: + """ + Resets the environment and returns the initial observations. 
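+        Returns:
+            - obs (:obj:`np.ndarray` or :obj:`dict`): the processed initial observation, either the stacked
+              per-agent image array (when ``agent_obs_only`` is True) or a dict containing ``agent_state``,
+              ``global_state`` and ``action_mask``.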
+ """ + if not self._init_flag: + # Initialize the pistonball environment + parallel_env = pistonball_v6.parallel_env + self._env = parallel_env( + n_pistons=self._num_pistons, + continuous=self._continuous_actions, + max_cycles=self._max_cycles + ) + self._env.reset() + self._agents = self._env.agents + + # Define action and observation spaces + self._action_space = gym.spaces.Dict({agent: self._env.action_space(agent) for agent in self._agents}) + single_agent_obs_space = self._env.observation_space(self._agents[0]) + single_agent_action_space = self._env.action_space(self._agents[0]) + + if isinstance(single_agent_action_space, gym.spaces.Box): + self._action_dim = single_agent_action_space.shape + elif isinstance(single_agent_action_space, gym.spaces.Discrete): + self._action_dim = (single_agent_action_space.n, ) + else: + raise Exception('Only support `Box` or `Discrete` obs space for single agent.') + + if isinstance(single_agent_obs_space, gym.spaces.Box): + self._obs_shape = single_agent_obs_space.shape + else: + raise ValueError("Only support `Box` observation space for each agent.") + + self._observation_space = gym.spaces.Box( + low=0, high=255, shape=(self._num_pistons, *self._obs_shape), dtype=np.uint8 + ) + + self._reward_space = gym.spaces.Dict( + { + agent: gym.spaces.Box(low=float('-inf'), high=float('inf'), shape=(1,), dtype=np.float32) + for agent in self._agents + } + ) + + if self._replay_path is not None: + self._env.render_mode = 'rgb_array' + self._env = PTZRecordVideo(self._env, self._replay_path, name_prefix=f'rl-video-{id(self)}', disable_logger=True) + self._init_flag = True + + if hasattr(self, '_seed'): + obs = self._env.reset(seed=self._seed) + else: + obs = self._env.reset() + + self._eval_episode_return = 0.0 + self._step_count = 0 + obs_n = self._process_obs(obs) + return obs_n + + def close(self) -> None: + """ + Closes the environment. + """ + if self._init_flag: + self._env.close() + self._init_flag = False + + def render(self) -> None: + """ + Renders the environment. + """ + self._env.render() + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + """ + Sets the seed for the environment. + """ + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) + + def step(self, action: np.ndarray) -> BaseEnvTimestep: + """ + Steps through the environment using the provided action. + """ + self._step_count += 1 + assert isinstance(action, np.ndarray), type(action) + action = self._process_action(action) + if self._act_scale: + for agent in self._agents: + action[agent] = affine_transform(action[agent], min_val=self.action_space[agent].low, max_val=self.action_space[agent].high) + + obs, rew, done, trunc, info = self._env.step(action) + obs_n = self._process_obs(obs) + rew_n = np.array([sum([rew[agent] for agent in self._agents])]) + rew_n = rew_n.astype(np.float32) + self._eval_episode_return += rew_n.item() + + done_n = reduce(lambda x, y: x and y, done.values()) or self._step_count >= self._max_cycles + if done_n: + info['eval_episode_return'] = self._eval_episode_return + + return BaseEnvTimestep(obs_n, rew_n, done_n, info) + + def enable_save_replay(self, replay_path: Optional[str] = None) -> None: + """ + Enables video recording during the episode. + """ + if replay_path is None: + replay_path = './video' + self._replay_path = replay_path + + def _process_obs(self, obs: Dict[str, np.ndarray]) -> np.ndarray: + """ + Processes the observations into the required format. 
+ """ + if self._channel_first: + obs = np.array([np.transpose(obs[agent], (2, 0, 1)) for agent in self._agents]).astype(np.uint8) + else: + obs = np.array([obs[agent] for agent in self._agents]).astype(np.uint8) + if self._cfg.get('agent_obs_only', False): + return obs + ret = {} + ret['agent_state'] = obs + if self._channel_first: + ret['global_state'] = self._env.state().transpose(2, 0, 1) + else: + ret['global_state'] = self._env.state() + if self._agent_specific_global_state: # TODO: more elegant way to handle this + ret['global_state'] = np.repeat(np.expand_dims(ret['global_state'], axis=0), self._num_pistons, axis=0) + ret['action_mask'] = np.ones((self._num_pistons, *self._action_dim)).astype(np.float32) + + return ret + + def _process_action(self, action: np.ndarray) -> Dict[str, np.ndarray]: + """ + Processes the action array into a dictionary format for each agent. + """ + dict_action = {} + for i, agent in enumerate(self._agents): + dict_action[agent] = action[i] + return dict_action + + def random_action(self) -> np.ndarray: + """ + Generates a random action for each agent. + """ + random_action = self.action_space.sample() + for k in random_action: + if isinstance(random_action[k], np.ndarray): + pass + elif isinstance(random_action[k], int): + random_action[k] = to_ndarray([random_action[k]], dtype=np.int64) + return random_action + + def __repr__(self) -> str: + return "DI-engine PettingZoo Pistonball Env" + + @property + def agents(self) -> List[str]: + return self._agents + + @property + def observation_space(self) -> gym.spaces.Space: + return self._observation_space + + @property + def action_space(self) -> gym.spaces.Space: + return self._action_space + + @property + def reward_space(self) -> gym.spaces.Space: + return self._reward_space \ No newline at end of file diff --git a/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index bde84685f0..10c642026d 100644 --- a/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -101,11 +101,11 @@ def reset(self) -> np.ndarray: self._agents = self._env.agents self._action_space = gym.spaces.Dict({agent: self._env.action_space(agent) for agent in self._agents}) - single_agent_obs_space = self._env.action_space(self._agents[0]) - if isinstance(single_agent_obs_space, gym.spaces.Box): - self._action_dim = single_agent_obs_space.shape - elif isinstance(single_agent_obs_space, gym.spaces.Discrete): - self._action_dim = (single_agent_obs_space.n, ) + single_agent_action_space = self._env.action_space(self._agents[0]) + if isinstance(single_agent_action_space, gym.spaces.Box): + self._action_dim = single_agent_action_space.shape + elif isinstance(single_agent_action_space, gym.spaces.Discrete): + self._action_dim = (single_agent_action_space.n, ) else: raise Exception('Only support `Box` or `Discrete` obs space for single agent.') diff --git a/dizoo/petting_zoo/envs/test_petting_zoo_pistonball_env.py b/dizoo/petting_zoo/envs/test_petting_zoo_pistonball_env.py new file mode 100644 index 0000000000..ea5ac988d7 --- /dev/null +++ b/dizoo/petting_zoo/envs/test_petting_zoo_pistonball_env.py @@ -0,0 +1,106 @@ +from easydict import EasyDict +import pytest +import numpy as np +from dizoo.petting_zoo.envs.petting_zoo_pistonball_env import PettingZooPistonballEnv + + +@pytest.mark.envtest +class TestPettingZooPistonballEnv: + + def test_agent_obs_only(self): + n_pistons = 20 + env = PettingZooPistonballEnv( + 
EasyDict( + dict( + n_pistons=n_pistons, + max_cycles=125, + agent_obs_only=True, + continuous_actions=True, + act_scale=False, + ) + ) + ) + env.seed(123) + assert env._seed == 123 + obs = env.reset() + assert obs.shape == (n_pistons, 3, 457, 120) + for i in range(10): + random_action = env.random_action() + random_action = np.array([random_action[agent] for agent in random_action]) + timestep = env.step(random_action) + # print(timestep) + assert isinstance(timestep.obs, np.ndarray), timestep.obs + assert timestep.obs.shape == (n_pistons, 3, 457, 120) + assert isinstance(timestep.done, bool), timestep.done + assert isinstance(timestep.reward, np.ndarray), timestep.reward + assert timestep.reward.dtype == np.float32 + print(env.observation_space, env.action_space, env.reward_space) + env.close() + + def test_dict_obs(self): + n_pistons = 20 + env = PettingZooPistonballEnv( + EasyDict( + dict( + n_pistons=n_pistons, + max_cycles=125, + agent_obs_only=False, + agent_specific_global_state=False, + continuous_actions=True, + act_scale=False, + ) + ) + ) + env.seed(123) + assert env._seed == 123 + obs = env.reset() + for k, v in obs.items(): + print(k, v.shape) + for i in range(10): + random_action = env.random_action() + random_action = np.array([random_action[agent] for agent in random_action]) + timestep = env.step(random_action) + # print(timestep) + assert isinstance(timestep.obs['agent_state'], np.ndarray), timestep.obs['agent_state'] + assert isinstance(timestep.obs['global_state'], np.ndarray), timestep.obs['global_state'] + assert timestep.obs['agent_state'].shape == (n_pistons, 3, 457, 120) + assert timestep.obs['global_state'].shape == (3, 560, 880) + assert isinstance(timestep.done, bool), timestep.done + assert isinstance(timestep.reward, np.ndarray), timestep.reward + print(env.observation_space, env.action_space, env.reward_space) + env.close() + + def test_agent_specific_global_state(self): + n_pistons = 20 + env = PettingZooPistonballEnv( + EasyDict( + dict( + n_pistons=n_pistons, + max_cycles=125, + agent_obs_only=False, + continuous_actions=True, + agent_specific_global_state=True, + act_scale=False, + ) + ) + ) + env.seed(123) + assert env._seed == 123 + obs = env.reset() + for k, v in obs.items(): + print(k, v.shape) + for i in range(10): + random_action = env.random_action() + random_action = np.array([random_action[agent] for agent in random_action]) + timestep = env.step(random_action) + # print(timestep) + assert isinstance(timestep.obs['agent_state'], np.ndarray), timestep.obs['agent_state'] + assert isinstance(timestep.obs['global_state'], np.ndarray), timestep.obs['global_state'] + assert timestep.obs['agent_state'].shape == (n_pistons, 3, 457, 120) + assert timestep.obs['global_state'].shape == (n_pistons, 3, 560, 880) + assert timestep.obs['global_state'].shape == (n_pistons, 3, 560, 880) + + assert isinstance(timestep.done, bool), timestep.done + assert isinstance(timestep.reward, np.ndarray), timestep.reward + print(env.observation_space, env.action_space, env.reward_space) + env.close() \ No newline at end of file From e91684197cfcd6732427150277f74de95cbf3cd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Tue, 22 Oct 2024 17:15:05 +0800 Subject: [PATCH 2/9] polish(pu): pistonball reuse PTZRecordVideo --- .../envs/petting_zoo_pistonball_env.py | 53 +++---------------- .../envs/petting_zoo_simple_spread_env.py | 10 ++-- 2 files changed, 9 insertions(+), 54 deletions(-) diff --git 
a/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py b/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py index 4e456db710..2238b85c53 100644 --- a/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py +++ b/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py @@ -1,57 +1,16 @@ -from typing import Any, List, Union, Optional, Dict -import gymnasium as gym -import numpy as np from functools import reduce +from typing import List, Optional, Dict -from ding.envs import BaseEnv, BaseEnvTimestep, FrameStackWrapper -from ding.torch_utils import to_ndarray, to_list +import gymnasium as gym +import numpy as np +from ding.envs import BaseEnv, BaseEnvTimestep from ding.envs.common.common_function import affine_transform +from ding.torch_utils import to_ndarray from ding.utils import ENV_REGISTRY -from pettingzoo.utils.conversions import parallel_wrapper_fn +from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PTZRecordVideo from pettingzoo.butterfly import pistonball_v6 -# Custom wrapper for recording videos in PettingZoo environments -class PTZRecordVideo(gym.wrappers.RecordVideo): - def step(self, action): - """ - Custom step function for handling PettingZoo environments - with gymnasium's RecordVideo wrapper. - """ - observations, rewards, terminateds, truncateds, infos = self.env.step(action) - - # Check if any agent has terminated or truncated - if not (self.terminated is True or self.truncated is True): - self.step_id += 1 - if not self.is_vector_env: - if terminateds or truncateds: - self.episode_id += 1 - self.terminated = terminateds - self.truncated = truncateds - elif terminateds[0] or truncateds[0]: - self.episode_id += 1 - self.terminated = terminateds[0] - self.truncated = truncateds[0] - - # Capture the video frame if recording - if self.recording: - assert self.video_recorder is not None - self.video_recorder.capture_frame() - self.recorded_frames += 1 - if self.video_length > 0 and self.recorded_frames > self.video_length: - self.close_video_recorder() - elif not self.is_vector_env: - if terminateds is True or truncateds is True: - self.close_video_recorder() - elif terminateds[0] or truncateds[0]: - self.close_video_recorder() - - elif self._video_enabled(): - self.start_video_recorder() - - return observations, rewards, terminateds, truncateds, infos - - @ENV_REGISTRY.register('petting_zoo_pistonball') class PettingZooPistonballEnv(BaseEnv): """ diff --git a/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py index 10c642026d..186f41e9e2 100644 --- a/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py +++ b/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py @@ -13,17 +13,12 @@ from pettingzoo.mpe.simple_spread.simple_spread import Scenario +# Custom wrapper for recording videos in PettingZoo environments class PTZRecordVideo(gym.wrappers.RecordVideo): def step(self, action): """Steps through the environment using action, recording observations if :attr:`self.recording`.""" # gymnasium==0.27.1 - ( - observations, - rewards, - terminateds, - truncateds, - infos, - ) = self.env.step(action) + observations, rewards, terminateds, truncateds, infos = self.env.step(action) # Because pettingzoo returns a dict of terminated and truncated, we need to check if any of the values are True if not (self.terminated is True or self.truncated is True): # the first location for modifications @@ -39,6 +34,7 @@ def step(self, action): self.terminated = terminateds[0] self.truncated = truncateds[0] + # Capture the 
video frame if recording if self.recording: assert self.video_recorder is not None self.video_recorder.capture_frame() From 55dc25472238f8f057ff181417e390cc9849ca73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Mon, 28 Oct 2024 15:06:45 +0800 Subject: [PATCH 3/9] polish(pu): adapt qmix's mixer to support image obs --- ding/model/template/qmix.py | 17 +++++++++++++---- .../config/ptz_pistonball_qmix_config.py | 6 +++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/ding/model/template/qmix.py b/ding/model/template/qmix.py index 68354e0cf7..a4af8e0ba8 100644 --- a/ding/model/template/qmix.py +++ b/ding/model/template/qmix.py @@ -5,6 +5,7 @@ from functools import reduce from ding.utils import list_split, MODEL_REGISTRY from ding.torch_utils import fc_block, MLP +from ..common import ConvEncoder from .q_learning import DRQN @@ -111,7 +112,7 @@ def __init__( self, agent_num: int, obs_shape: int, - global_obs_shape: int, + global_obs_shape: Union[int, List[int]], action_shape: int, hidden_size_list: list, mixer: bool = True, @@ -146,8 +147,14 @@ def __init__( embedding_size = hidden_size_list[-1] self.mixer = mixer if self.mixer: - self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation) - self._global_state_encoder = nn.Identity() + if len(global_obs_shape) == 1: + self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation) + self._global_state_encoder = nn.Identity() + elif len(global_obs_shape) == 3: + self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation) + self._global_state_encoder = ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN') + else: + raise ValueError("Not support global_obs_shape: {}".format(global_obs_shape)) def forward(self, data: dict, single_step: bool = True) -> dict: """ @@ -183,7 +190,9 @@ def forward(self, data: dict, single_step: bool = True) -> dict: 'prev_state'] action = data.get('action', None) if single_step: - agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0) + agent_state = agent_state.unsqueeze(0) + if single_step and len(global_state.shape) == 2: + global_state = global_state.unsqueeze(0) T, B, A = agent_state.shape[:3] assert len(prev_state) == B and all( [len(p) == A for p in prev_state] diff --git a/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py b/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py index f1b2da682a..7af0289d89 100644 --- a/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py +++ b/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py @@ -21,6 +21,7 @@ shared_memory=False, reset_timeout=6000, ), + max_env_step=3e6, ), policy=dict( cuda=True, @@ -30,8 +31,7 @@ global_obs_shape=(3, 560, 880), # Global state shape action_shape=3, # Discrete actions (0, 1, 2) hidden_size_list=[128, 128, 64], - # mixer=True, # TODO: mixer is not supported image observation now - mixer=False, + mixer=True, ), learn=dict( update_per_collect=100, @@ -73,4 +73,4 @@ if __name__ == '__main__': # or you can enter `ding -m serial -c ptz_pistonball_qmix_config.py -s 0` from ding.entry import serial_pipeline - serial_pipeline((main_config, create_config), seed=0) \ No newline at end of file + serial_pipeline((main_config, create_config), seed=0, max_env_step=main_config.env.max_env_step) \ No newline at end of file From e6a18baa66e2f7ac3de1db1908e8186a17417bbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cpuyuan1996=E2=80=9D?= 
<2402552459@qq.com> Date: Tue, 12 Nov 2024 18:01:34 +0800 Subject: [PATCH 4/9] tmp commit: unizero_mt_ddp_v2 --- ding/model/template/qmix.py | 5 ++- ding/policy/base_policy.py | 14 ++++++- ding/worker/learner/base_learner.py | 4 ++ .../config/ptz_pistonball_qmix_config.py | 17 ++++---- .../envs/petting_zoo_pistonball_env.py | 40 +++++++++++++------ 5 files changed, 57 insertions(+), 23 deletions(-) diff --git a/ding/model/template/qmix.py b/ding/model/template/qmix.py index a4af8e0ba8..6d28bac6bc 100644 --- a/ding/model/template/qmix.py +++ b/ding/model/template/qmix.py @@ -214,7 +214,10 @@ def forward(self, data: dict, single_step: bool = True) -> dict: agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1)) agent_q_act = agent_q_act.squeeze(-1) # T, B, A if self.mixer: - global_state_embedding = self._global_state_encoder(global_state) + if len(global_state.shape) == 5: + global_state_embedding = self._global_state_encoder(global_state.reshape(-1, *global_state.shape[-3:])).reshape(global_state.shape[0], global_state.shape[1], -1) + else: + global_state_embedding = self._global_state_encoder(global_state) total_q = self._mixer(agent_q_act, global_state_embedding) else: total_q = agent_q_act.sum(-1) diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py index 7e843f8429..1f0b827ce8 100644 --- a/ding/policy/base_policy.py +++ b/ding/policy/base_policy.py @@ -421,10 +421,22 @@ def sync_gradients(self, model: torch.nn.Module) -> None: gradients allreduce and optimizer updates. """ + # if self._bp_update_sync: + # for name, param in model.named_parameters(): + # if param.requires_grad: + # if param.grad is not None: + # allreduce(param.grad.data) + # else: + # synchronize() if self._bp_update_sync: for name, param in model.named_parameters(): if param.requires_grad: - allreduce(param.grad.data) + if param.grad is not None: + allreduce(param.grad.data) + else: + # 如果梯度为 None,则创建一个与 param.grad_size 相同的零张量,并执行 allreduce + zero_grad = torch.zeros_like(param.data) + allreduce(zero_grad) else: synchronize() diff --git a/ding/worker/learner/base_learner.py b/ding/worker/learner/base_learner.py index 1144a412cd..e888945763 100644 --- a/ding/worker/learner/base_learner.py +++ b/ding/worker/learner/base_learner.py @@ -108,6 +108,10 @@ def __init__( './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False ) self._tb_logger = None + + self._tb_logger = None + + self._log_buffer = { 'scalar': build_log_buffer(), 'scalars': build_log_buffer(), diff --git a/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py b/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py index 7af0289d89..3ef0b91313 100644 --- a/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py +++ b/dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py @@ -1,11 +1,12 @@ from easydict import EasyDict n_pistons = 20 +# n_pistons = 2 collector_env_num = 8 evaluator_env_num = 8 main_config = dict( - exp_name='ptz_pistonball_qmix_seed0', + exp_name=f'data_pistonball/ptz_pistonball_n{n_pistons}_qmix_seed0', env=dict( env_family='butterfly', env_id='pistonball_v6', @@ -17,10 +18,7 @@ evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, stop_value=1e6, - manager=dict( - shared_memory=False, - reset_timeout=6000, - ), + manager=dict(shared_memory=False,), max_env_step=3e6, ), policy=dict( @@ -34,15 +32,16 @@ mixer=True, ), learn=dict( - update_per_collect=100, - batch_size=32, - learning_rate=0.0005, + update_per_collect=20, + batch_size=16, + 
learning_rate=0.0001, target_update_theta=0.001, discount_factor=0.99, double_q=True, + clip_value=10, ), collect=dict( - n_sample=600, + n_sample=32, unroll_len=16, env_num=collector_env_num, ), diff --git a/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py b/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py index 2238b85c53..5483651e3d 100644 --- a/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py +++ b/dizoo/petting_zoo/envs/petting_zoo_pistonball_env.py @@ -147,21 +147,37 @@ def _process_obs(self, obs: Dict[str, np.ndarray]) -> np.ndarray: """ Processes the observations into the required format. """ - if self._channel_first: - obs = np.array([np.transpose(obs[agent], (2, 0, 1)) for agent in self._agents]).astype(np.uint8) - else: - obs = np.array([obs[agent] for agent in self._agents]).astype(np.uint8) + # Process agent observations, transpose if channel_first is True + obs = np.array( + [np.transpose(obs[agent], (2, 0, 1)) if self._channel_first else obs[agent] + for agent in self._agents], + dtype=np.uint8 + ) + + # Return only agent observations if configured to do so if self._cfg.get('agent_obs_only', False): return obs - ret = {} - ret['agent_state'] = obs + + # Initialize return dictionary + ret = { + 'agent_state': (obs / 255.0).astype(np.float32) + } + + # Obtain global state, transpose if channel_first is True + global_state = self._env.state() if self._channel_first: - ret['global_state'] = self._env.state().transpose(2, 0, 1) - else: - ret['global_state'] = self._env.state() - if self._agent_specific_global_state: # TODO: more elegant way to handle this - ret['global_state'] = np.repeat(np.expand_dims(ret['global_state'], axis=0), self._num_pistons, axis=0) - ret['action_mask'] = np.ones((self._num_pistons, *self._action_dim)).astype(np.float32) + global_state = global_state.transpose(2, 0, 1) + ret['global_state'] = (global_state / 255.0).astype(np.float32) + + # Handle agent-specific global states by repeating the global state for each agent + if self._agent_specific_global_state: + ret['global_state'] = np.tile( + np.expand_dims(ret['global_state'], axis=0), + (self._num_pistons, 1, 1, 1) + ) + + # Set action mask for each agent + ret['action_mask'] = np.ones((self._num_pistons, *self._action_dim), dtype=np.float32) return ret From a42c85be67a9e2e03b4933a7b143125555801ca4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cpuyuan1996=E2=80=9D?= <2402552459@qq.com> Date: Tue, 12 Nov 2024 18:16:26 +0800 Subject: [PATCH 5/9] polish(pu): adapt learner to unizero_multitask_ddp_v2 --- ding/worker/learner/base_learner.py | 6 +++--- ding/worker/learner/learner_hook.py | 16 +++++++++++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/ding/worker/learner/base_learner.py b/ding/worker/learner/base_learner.py index e888945763..9862216f6c 100644 --- a/ding/worker/learner/base_learner.py +++ b/ding/worker/learner/base_learner.py @@ -107,9 +107,9 @@ def __init__( self._logger, _ = build_logger( './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False ) - self._tb_logger = None - - self._tb_logger = None + # self._tb_logger = None + # ========== TODO: unizero_multitask ddp_v2 ======== + self._tb_logger = tb_logger self._log_buffer = { diff --git a/ding/worker/learner/learner_hook.py b/ding/worker/learner/learner_hook.py index 250a8f1950..3b564812e4 100644 --- a/ding/worker/learner/learner_hook.py +++ b/ding/worker/learner/learner_hook.py @@ -273,8 +273,22 @@ def aggregate(data): Returns: - new_data (:obj:`dict`): data after 
reduce """ + # if isinstance(data, dict): + # new_data = {k: aggregate(v) for k, v in data.items()} + + def should_reduce(key): + # 检查 key 是否以 "noreduce_" 前缀开头 + return not key.startswith("noreduce_") + if isinstance(data, dict): - new_data = {k: aggregate(v) for k, v in data.items()} + new_data = {} + for k, v in data.items(): + if should_reduce(k): + new_data[k] = aggregate(v) # 对需要 reduce 的数据执行 allreduce + else: + new_data[k] = v # 不需要 reduce 的数据直接保留 + + elif isinstance(data, list) or isinstance(data, tuple): new_data = [aggregate(t) for t in data] elif isinstance(data, torch.Tensor): From 7a66b76ea94e5114f155d151accf1170754bb7bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cpuyuan1996=E2=80=9D?= <2402552459@qq.com> Date: Tue, 12 Nov 2024 18:47:28 +0800 Subject: [PATCH 6/9] polish(pu): adapt learner to unizero_multitask_ddp_v2 --- ding/worker/learner/base_learner.py | 7 +++++-- ding/worker/learner/learner_hook.py | 10 +++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/ding/worker/learner/base_learner.py b/ding/worker/learner/base_learner.py index 9862216f6c..4ffef2e85f 100644 --- a/ding/worker/learner/base_learner.py +++ b/ding/worker/learner/base_learner.py @@ -436,8 +436,11 @@ def policy(self, _policy: 'Policy') -> None: # noqa Policy variable monitor is set alongside with policy, because variables are determined by specific policy. """ self._policy = _policy - if self._rank == 0: - self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10) + # if self._rank == 0: + # self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10) + + self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10) + if self._cfg.log_policy: self.info(self._policy.info()) diff --git a/ding/worker/learner/learner_hook.py b/ding/worker/learner/learner_hook.py index 3b564812e4..ae03d8c8b6 100644 --- a/ding/worker/learner/learner_hook.py +++ b/ding/worker/learner/learner_hook.py @@ -202,10 +202,11 @@ def __call__(self, engine: 'BaseLearner') -> None: # noqa - engine (:obj:`BaseLearner`): the BaseLearner """ # Only show log for rank 0 learner - if engine.rank != 0: - for k in engine.log_buffer: - engine.log_buffer[k].clear() - return + # if engine.rank != 0: + # for k in engine.log_buffer: + # engine.log_buffer[k].clear() + # return + # For 'scalar' type variables: log_buffer -> tick_monitor -> monitor_time.step for k, v in engine.log_buffer['scalar'].items(): setattr(engine.monitor, k, v) @@ -288,7 +289,6 @@ def should_reduce(key): else: new_data[k] = v # 不需要 reduce 的数据直接保留 - elif isinstance(data, list) or isinstance(data, tuple): new_data = [aggregate(t) for t in data] elif isinstance(data, torch.Tensor): From 0c4b33874329746cc22568a26c92255ff2245fe1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cpuyuan1996=E2=80=9D?= <2402552459@qq.com> Date: Thu, 28 Nov 2024 15:38:43 +0800 Subject: [PATCH 7/9] test(pu): add timeout in dist.init_process_group --- ding/utils/pytorch_ddp_dist_helper.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ding/utils/pytorch_ddp_dist_helper.py b/ding/utils/pytorch_ddp_dist_helper.py index 13d9e1e299..09c7901955 100644 --- a/ding/utils/pytorch_ddp_dist_helper.py +++ b/ding/utils/pytorch_ddp_dist_helper.py @@ -171,7 +171,10 @@ def dist_init(backend: str = 'nccl', else: world_size = int(ntasks) - dist.init_process_group(backend=backend, rank=rank, world_size=world_size) + # dist.init_process_group(backend=backend, rank=rank, 
world_size=world_size) + # TODO: + import datetime + dist.init_process_group(backend=backend, rank=rank, world_size=world_size, timeout=datetime.timedelta(seconds=60000)) num_gpus = torch.cuda.device_count() torch.cuda.set_device(rank % num_gpus) From 141cf512a9794f8f22ac9e49b2c97cb05a94ef54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cpuyuan1996=E2=80=9D?= <2402552459@qq.com> Date: Tue, 21 Jan 2025 18:58:33 +0800 Subject: [PATCH 8/9] feature(pu): adapt ppo vac to env that return obs dict (include action mask), e.g. detective env --- ding/model/template/vac.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/ding/model/template/vac.py b/ding/model/template/vac.py index 00ef4162b9..254a670c78 100644 --- a/ding/model/template/vac.py +++ b/ding/model/template/vac.py @@ -265,12 +265,17 @@ def compute_actor(self, x: torch.Tensor) -> Dict: >>> assert actor_outputs['logit'].shape == torch.Size([4, 64]) """ if self.share_encoder: - x = self.encoder(x) + # import ipdb;ipdb.set_trace() + # x = self.encoder(x) + action_mask = x['action_mask'] + x = self.encoder(x['observation']) else: x = self.actor_encoder(x) if self.action_space == 'discrete': - return self.actor_head(x) + # return self.actor_head(x) + # import ipdb;ipdb.set_trace() + return {'logit': self.actor_head(x)['logit'], 'action_mask': action_mask} elif self.action_space == 'continuous': x = self.actor_head(x) # mu, sigma return {'logit': x} @@ -299,7 +304,9 @@ def compute_critic(self, x: torch.Tensor) -> Dict: >>> assert critic_outputs['value'].shape == torch.Size([4]) """ if self.share_encoder: - x = self.encoder(x) + # x = self.encoder(x) + # action_mask = x['action_mask'] + x = self.encoder(x['observation']) else: x = self.critic_encoder(x) x = self.critic_head(x) @@ -339,7 +346,9 @@ def compute_actor_critic(self, x: torch.Tensor) -> Dict: dict output. 
""" if self.share_encoder: - actor_embedding = critic_embedding = self.encoder(x) + action_mask = x['action_mask'] + actor_embedding = critic_embedding = self.encoder(x['observation']) + # actor_embedding = critic_embedding = self.encoder(x) else: actor_embedding = self.actor_encoder(x) critic_embedding = self.critic_encoder(x) @@ -348,7 +357,8 @@ def compute_actor_critic(self, x: torch.Tensor) -> Dict: if self.action_space == 'discrete': logit = self.actor_head(actor_embedding)['logit'] - return {'logit': logit, 'value': value} + # return {'logit': logit, 'value': value} + return {'logit': logit, 'action_mask': action_mask, 'value': value} elif self.action_space == 'continuous': x = self.actor_head(actor_embedding) return {'logit': x, 'value': value} From d6f5fcbb7f381102cd727c64fbfabf4d53ac39a1 Mon Sep 17 00:00:00 2001 From: puyuan Date: Thu, 10 Apr 2025 13:50:26 +0000 Subject: [PATCH 9/9] polish(pu): polish sync gradients --- ding/policy/base_policy.py | 56 +++++++++++++++++++++------ ding/utils/__init__.py | 2 +- ding/utils/pytorch_ddp_dist_helper.py | 19 +++++++++ 3 files changed, 64 insertions(+), 13 deletions(-) diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py index 1f0b827ce8..39ac51eb75 100644 --- a/ding/policy/base_policy.py +++ b/ding/policy/base_policy.py @@ -7,7 +7,7 @@ import torch from ding.model import create_model -from ding.utils import import_module, allreduce, broadcast, get_rank, allreduce_async, synchronize, deep_merge_dicts, \ +from ding.utils import import_module, allreduce, allreduce_with_indicator, broadcast, get_rank, allreduce_async, synchronize, deep_merge_dicts, \ POLICY_REGISTRY @@ -421,25 +421,57 @@ def sync_gradients(self, model: torch.nn.Module) -> None: gradients allreduce and optimizer updates. """ - # if self._bp_update_sync: - # for name, param in model.named_parameters(): - # if param.requires_grad: - # if param.grad is not None: - # allreduce(param.grad.data) - # else: - # synchronize() if self._bp_update_sync: for name, param in model.named_parameters(): if param.requires_grad: + # Create an indicator tensor on the same device as the parameter (or its gradient) if param.grad is not None: - allreduce(param.grad.data) + # If the gradient exists, extract its data and set indicator to 1. + grad_tensor = param.grad.data + indicator = torch.tensor(1.0, device=grad_tensor.device) else: - # 如果梯度为 None,则创建一个与 param.grad_size 相同的零张量,并执行 allreduce - zero_grad = torch.zeros_like(param.data) - allreduce(zero_grad) + # If the parameter did not participate in the computation (grad is None), + # create a zero tensor for the gradient and set the indicator to 0. + grad_tensor = torch.zeros_like(param.data) + indicator = torch.tensor(0.0, device=grad_tensor.device) + + # Assign the zero gradient to param.grad to ensure that all GPUs + # participate in the subsequent allreduce call (avoiding deadlock). + param.grad = grad_tensor + + # Use the custom allreduce function to reduce the gradient using the indicator. 
+                    allreduce_with_indicator(param.grad, indicator)
+                # else:
+                #     # Parameters that do not require gradients must also take part in the collective
+                #     # communication, so that the order of collective calls stays consistent across all processes.
+                #     dummy_tensor = torch.tensor(0.0, device=param.data.device)
+                #     dummy_indicator = torch.tensor(0.0, device=param.data.device)
+                #     allreduce_with_indicator(dummy_tensor, dummy_indicator)
         else:
             synchronize()
+
+
+        # if self._bp_update_sync:
+        #     for name, param in model.named_parameters():
+        #         if param.requires_grad:
+        #             if param.grad is not None:
+        #                 allreduce(param.grad.data)
+        #             else:
+        #                 # If the gradient is None, create a zero tensor with the same shape as the parameter and run allreduce on it.
+        #                 zero_grad = torch.zeros_like(param.data)
+        #                 allreduce(zero_grad)
+        #     else:
+        #         synchronize()
+
+
+        # if self._bp_update_sync:
+        #     for name, param in model.named_parameters():
+        #         if param.requires_grad:
+        #             if param.grad is not None:
+        #                 allreduce(param.grad.data)
+        #     else:
+        #         synchronize()
+
     # don't need to implement default_model method by force
     def default_model(self) -> Tuple[str, List[str]]:
         """
diff --git a/ding/utils/__init__.py b/ding/utils/__init__.py
index 68a92efcb5..216cea468f 100644
--- a/ding/utils/__init__.py
+++ b/ding/utils/__init__.py
@@ -39,5 +39,5 @@
         allreduce, broadcast, DistContext, allreduce_async, synchronize
 else:
     from .pytorch_ddp_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
-        allreduce, broadcast, DDPContext, allreduce_async, synchronize, reduce_data, broadcast_object_list, \
+        allreduce, allreduce_with_indicator, broadcast, DDPContext, allreduce_async, synchronize, reduce_data, broadcast_object_list, \
         to_ddp_config, allreduce_data
diff --git a/ding/utils/pytorch_ddp_dist_helper.py b/ding/utils/pytorch_ddp_dist_helper.py
index 09c7901955..5dafa83906 100644
--- a/ding/utils/pytorch_ddp_dist_helper.py
+++ b/ding/utils/pytorch_ddp_dist_helper.py
@@ -45,6 +45,25 @@ def allreduce(x: torch.Tensor) -> None:
     dist.all_reduce(x)
     x.div_(get_world_size())
 
+def allreduce_with_indicator(grad: torch.Tensor, indicator: torch.Tensor) -> None:
+    """
+    Overview:
+        Custom allreduce: Sum both the gradient and indicator tensors across all processes.
+        Then, if at least one process contributed (i.e., the summation of indicator > 0),
+        divide the gradient by the summed indicator. This ensures that if only a subset of
+        GPUs contributed a gradient, the averaging is performed based on the actual number
+        of contributors rather than the total number of GPUs.
+    Arguments:
+        - grad (torch.Tensor): Local gradient tensor to be reduced.
+        - indicator (torch.Tensor): A tensor flag (1 if the gradient is computed, 0 otherwise).
+    """
+    # Allreduce (sum) the gradient and indicator
+    dist.all_reduce(grad)
+    dist.all_reduce(indicator)
+
+    # Avoid division by zero. If indicator is close to 0 (extreme case), grad remains zeros.
+    if not torch.isclose(indicator, torch.tensor(0.0)):
+        grad.div_(indicator.item())
 
 def allreduce_async(name: str, x: torch.Tensor) -> None:
     """