diff --git a/lbforaging/foraging/environment.py b/lbforaging/foraging/environment.py
index 35497f7..a14c1a5 100644
--- a/lbforaging/foraging/environment.py
+++ b/lbforaging/foraging/environment.py
@@ -82,6 +82,7 @@ def __init__(
         sight,
         max_episode_steps,
         force_coop,
+        randomise_spawn=True,
         normalize_reward=True,
         grid_observation=False,
         penalty=0.0,
@@ -99,6 +100,7 @@ def __init__(
         self.max_player_level = max_player_level
         self.sight = sight
         self.force_coop = force_coop
+        self.randomise_spawn = randomise_spawn
         self._game_over = None
 
         self._rendering_initialized = False
@@ -253,8 +255,16 @@ def spawn_food(self, max_food, max_level):
 
         while food_count < max_food and attempts < 1000:
             attempts += 1
-            row = self.np_random.randint(1, self.rows - 1)
-            col = self.np_random.randint(1, self.cols - 1)
+            if self.randomise_spawn:
+                # spawn food randomly
+                row = self.np_random.randint(1, self.rows - 1)
+                col = self.np_random.randint(1, self.cols - 1)
+            else:
+                # spawn food in the center of the map
+                # randomise around center to be able to spawn multiple food
+                rand = attempts // 10
+                row = max(1, min(self.rows - 2, self.rows // 2 + self.np_random.randint(-rand, rand+1)))
+                col = max(1, min(self.cols - 2, self.cols // 2 + self.np_random.randint(-rand, rand+1)))
 
             # check if it has neighbors:
             if (
@@ -264,6 +274,7 @@ def spawn_food(self, max_food, max_level):
             ):
                 continue
 
+
             self.field[row, col] = (
                 min_level
                 if min_level == max_level
@@ -290,8 +301,27 @@ def spawn_players(self, max_player_level):
             player.reward = 0
 
             while attempts < 1000:
-                row = self.np_random.randint(0, self.rows)
-                col = self.np_random.randint(0, self.cols)
+                if self.randomise_spawn:
+                    # spawn player randomly
+                    row = self.np_random.randint(0, self.rows)
+                    col = self.np_random.randint(0, self.cols)
+                else:
+                    # spawn players in corners
+                    rand = attempts // 10
+                    corners = [
+                        (0, 0),
+                        (0, self.cols - 1),
+                        (self.rows - 1, 0),
+                        (self.rows - 1, self.cols - 1),
+                    ]
+                    for row, col in corners:
+                        # add increasing small randomisation after 10+ attempts
+                        row = max(0, min(self.rows - 1, row + self.np_random.randint(-rand, rand+1)))
+                        col = max(0, min(self.cols - 1, col + self.np_random.randint(-rand, rand+1)))
+                        if self._is_empty_location(row, col):
+                            # found empty corner
+                            break
+
                 if self._is_empty_location(row, col):
                     player.setup(
                         (row, col),
@@ -475,6 +505,7 @@ def reset(self):
         self.spawn_food(
             self.max_food, max_level=sum(player_levels[:3])
         )
+
         self.current_step = 0
         self._game_over = False
         self._gen_valid_moves()