From c69720daaff589b5d2462b8ae649186dfa9b9f6f Mon Sep 17 00:00:00 2001 From: WiemKhlifi Date: Fri, 25 Oct 2024 12:08:10 +0100 Subject: [PATCH] fix: fix the gym box shape for grid observation --- lbforaging/foraging/environment.py | 17 ++++++----------- lbforaging/foraging/test.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 11 deletions(-) create mode 100644 lbforaging/foraging/test.py diff --git a/lbforaging/foraging/environment.py b/lbforaging/foraging/environment.py index b3c8636..c0dcccf 100644 --- a/lbforaging/foraging/environment.py +++ b/lbforaging/foraging/environment.py @@ -245,7 +245,7 @@ def _get_observation_space(self): high_obs = np.array(max_obs) assert low_obs.shape == high_obs.shape return gym.spaces.Box( - low=low_obs, high=high_obs, shape=[len(low_obs)], dtype=np.float32 + low=low_obs, high=high_obs, shape=low_obs.shape, dtype=np.float32 ) @classmethod @@ -436,7 +436,7 @@ def _is_valid_action(self, player, action): elif action == Action.LOAD: return self.adjacent_food(*player.position) > 0 - self.logger.error("Undefined action {} from {}".format(action, player.name)) + self.logger.error(f"Undefined action {action} from {player.name}") raise ValueError("Undefined action") def _transform_to_neighborhood(self, center, sight, position): @@ -574,13 +574,11 @@ def get_agent_grid_bounds(agent_x, agent_y): get_agent_grid_bounds(*player.position) for player in self.players ] nobs = tuple( - [ - layers[:, start_x:end_x, start_y:end_y] - for start_x, end_x, start_y, end_y in agents_bounds - ] + layers[:, start_x:end_x, start_y:end_y] + for start_x, end_x, start_y, end_y in agents_bounds ) else: - nobs = tuple([make_obs_array(obs) for obs in observations]) + nobs = tuple(make_obs_array(obs) for obs in observations) # check the space of obs for i, obs in enumerate(nobs): @@ -631,10 +629,7 @@ def step(self, actions): for i, (player, action) in enumerate(zip(self.players, actions)): if action not in self._valid_actions[player]: self.logger.info( - "{}{} attempted invalid action {}.".format( - player.name, player.position, action - ) - ) + f"{player.name}{player.position} attempted invalid action {action}.") actions[i] = Action.NONE loading_players = set() diff --git a/lbforaging/foraging/test.py b/lbforaging/foraging/test.py new file mode 100644 index 0000000..015942e --- /dev/null +++ b/lbforaging/foraging/test.py @@ -0,0 +1,19 @@ +from environment import ForagingEnv + +env = ForagingEnv( + grid_observation=True, + players=2, + max_player_level=2, + field_size=(8,8), + max_num_food=2, + sight=8, + force_coop=True, + min_player_level=1, + min_food_level=1, + max_food_level=10, + max_episode_steps=100) + +nobs, infos = env.reset() +actions = [5,5,0] +nobs, rewards, done, truncated, info = env.step(actions=actions) +print(f"nobs: \n {nobs[0].shape}") \ No newline at end of file