Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
fd809ae
Add OpenSpiel interface.
jhtschultz Nov 17, 2020
a02237e
OpenSpiel interface v2.
jhtschultz Nov 21, 2020
c8a625b
Implement review feedback.
jhtschultz Nov 25, 2020
4ea4d78
Fix run_dqn.py example so each agent gets its own network.
jhtschultz Nov 25, 2020
55e0ce7
Disable pytest import errors.
jhtschultz Jan 7, 2021
58f3ba0
Merge branch 'master' into open_spiel
jhtschultz Jan 7, 2021
8d0f41e
Remove duplicate policy_network argument.
jhtschultz Jan 8, 2021
1f9251d
Set prev_actions default to INVALID_ACTION instead of None.
jhtschultz Jan 8, 2021
5f0191f
Remove verbose printouts from env loop.
jhtschultz Jan 8, 2021
5c20c23
Comment out OpenSpielWrapper from __init__.py so pytest passes.
jhtschultz Jan 8, 2021
35d434f
Comment out files in __init__.py with OpenSpiel imports so pytest pas…
jhtschultz Jan 8, 2021
4a44c64
Merge branch 'open_spiel' of https://github.com/jhtschultz/acme into …
jhtschultz Jan 8, 2021
2ccd84b
Add OpenSpiel wrapper test.
jhtschultz Jan 8, 2021
785fb86
Add OpenSpiel env loop test.
jhtschultz Jan 9, 2021
99b57d0
Remove unnecessary TODO.
jhtschultz Jan 9, 2021
592002a
Remove unnecessary pytype disable.
jhtschultz Jan 13, 2021
e7b77a5
Refactor loop for constructing agents.
jhtschultz Jan 13, 2021
74dfddf
Change TestCase base class of OpenSpielWrapperTest.
jhtschultz Jan 13, 2021
0328e63
Re-enable pytype import-error.
jhtschultz Jan 13, 2021
0e538c7
Minor cosmetic changes.
jhtschultz Jan 13, 2021
381692c
Narrow down pytype disable import-error and re-enable.
jhtschultz Jan 13, 2021
fdbf4c0
Access OpenSpiel env properties instead of private attributes.
jhtschultz Jan 13, 2021
429b1ec
Remove unused imports.
jhtschultz Jan 13, 2021
2592165
Use consistent quote delimiter formatting.
jhtschultz Jan 13, 2021
4de0f90
Access OpenSpiel env properties instead of private attributes.
jhtschultz Jan 13, 2021
17cb075
Access OpenSpiel env properties instead of private attributes.
jhtschultz Jan 13, 2021
d0811ef
Minor cosmetic changes.
jhtschultz Jan 13, 2021
d57c516
Specify np.array dtypes.
jhtschultz Jan 13, 2021
b95a28c
Wrap __init__.py imports in try-except block.
jhtschultz Jan 14, 2021
4a3677c
Remove unnecessary pylint disable.
jhtschultz Jan 14, 2021
7853707
Narrow down pytype disable import-error and re-enable.
jhtschultz Jan 14, 2021
fa986b2
Refactor handling of terminal states.
jhtschultz Jan 14, 2021
920d282
Clarify note regarding use of -np.inf.
jhtschultz Jan 14, 2021
c5b0070
Remove unused import.
jhtschultz Jan 14, 2021
6141f1a
Remove default threshold argument.
jhtschultz Jan 14, 2021
f676b8a
Pylint formatting changes.
jhtschultz Jan 14, 2021
ff9b7ad
Clarify TODO.
jhtschultz Jan 14, 2021
5caa0ac
Remove unnecessary pylint disable.
jhtschultz Jan 14, 2021
b46960d
Remove unnecessary pylint disable.
jhtschultz Jan 14, 2021
6af4844
Merge branch 'open_spiel' of https://github.com/jhtschultz/acme into …
jhtschultz Jan 14, 2021
092aaba
Access environment._use_observation attribute via property.
jhtschultz Jan 15, 2021
e00393a
Fix type hints.
jhtschultz Jan 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions acme/environment_loops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Specialized environment loops."""

# OpenSpiel (and its `pyspiel` bindings) is an optional dependency. Swallow the
# ImportError so that importing `acme.environment_loops` still succeeds when
# OpenSpiel is not installed; in that case `OpenSpielEnvironmentLoop` is simply
# not exported.
try:
  from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop
except ImportError:
  pass
221 changes: 221 additions & 0 deletions acme/environment_loops/open_spiel_environment_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# python3
# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An OpenSpiel multi-agent/environment training loop."""

import operator
import time
from typing import Optional, Sequence

from acme import core
from acme.utils import counting
from acme.utils import loggers
from acme.wrappers import open_spiel_wrapper
import dm_env
from dm_env import specs
import numpy as np
# pytype: disable=import-error
import pyspiel
# pytype: enable=import-error
import tree


class OpenSpielEnvironmentLoop(core.Worker):
"""An OpenSpiel RL environment loop.

This takes `Environment` and list of `Actor` instances and coordinates their
interaction. Agents are updated if `should_update=True`. This can be used as:

loop = EnvironmentLoop(environment, actors)
loop.run(num_episodes)

A `Counter` instance can optionally be given in order to maintain counts
between different Acme components. If not given a local Counter will be
created to maintain counts between calls to the `run` method.

A `Logger` instance can also be passed in order to control the output of the
loop. If not given a platform-specific default logger will be used as defined
by utils.loggers.make_default_logger. A string `label` can be passed to easily
change the label associated with the default logger; this is ignored if a
`Logger` instance is given.
"""

  def __init__(
      self,
      environment: open_spiel_wrapper.OpenSpielWrapper,
      actors: Sequence[core.Actor],
      counter: Optional[counting.Counter] = None,
      logger: Optional[loggers.Logger] = None,
      should_update: bool = True,
      label: str = 'open_spiel_environment_loop',
  ):
    """Initializes the loop.

    Args:
      environment: A wrapped OpenSpiel environment.
      actors: One actor per player, indexed by OpenSpiel player id.
      counter: Optional counter for maintaining counts across components;
        a local Counter is created when omitted.
      logger: Optional logger for episode results; a platform-specific
        default logger is created when omitted.
      should_update: Whether actors should be updated after observing.
      label: Label used by the default logger when `logger` is None.
    """
    # Internalize agent and environment.
    self._environment = environment
    self._actors = actors
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(label)
    self._should_update = should_update

    # Track information necessary to coordinate updates among multiple actors.
    self._observed_first = [False] * len(self._actors)
    self._prev_actions = [pyspiel.INVALID_ACTION] * len(self._actors)

def _send_observation(self, timestep: dm_env.TimeStep, player: int):
# If terminal all actors must update
if player == pyspiel.PlayerId.TERMINAL:
for player_id in range(len(self._actors)):
# Note: we must account for situations where the first observation
# is a terminal state, e.g. if an opponent folds in poker before we get
# to act.
if self._observed_first[player_id]:
player_timestep = self._get_player_timestep(timestep, player_id)
self._actors[player_id].observe(self._prev_actions[player_id],
player_timestep)
if self._should_update:
self._actors[player_id].update()
self._observed_first = [False] * len(self._actors)
self._prev_actions = [pyspiel.INVALID_ACTION] * len(self._actors)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe initialize these two once at the beginning of run_episode? That would eliminate code duplication.

else:
if not self._observed_first[player]:
player_timestep = dm_env.TimeStep(
observation=timestep.observation[player],
reward=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't OpenSpiel return reward in this case? In generic case I think we should use reward provided by the environment. Can't we call _get_player_timestep instead (setting dm_env.StepType.FIRST)?

discount=None,
step_type=dm_env.StepType.FIRST)
self._actors[player].observe_first(player_timestep)
self._observed_first[player] = True
else:
player_timestep = self._get_player_timestep(timestep, player)
self._actors[player].observe(self._prev_actions[player],
player_timestep)
if self._should_update:
self._actors[player].update()

def _get_action(self, timestep: dm_env.TimeStep, player: int) -> int:
self._prev_actions[player] = self._actors[player].select_action(
timestep.observation[player])
return self._prev_actions[player]

def _get_player_timestep(self, timestep: dm_env.TimeStep,
player: int) -> dm_env.TimeStep:
return dm_env.TimeStep(observation=timestep.observation[player],
reward=timestep.reward[player],
discount=timestep.discount[player],
step_type=timestep.step_type)

def run_episode(self) -> loggers.LoggingData:
"""Run one episode.

Each episode is a loop which interacts first with the environment to get an
observation and then give that observation to the agent in order to retrieve
an action.

Returns:
An instance of `loggers.LoggingData`.
"""
# Reset any counts and start the environment.
start_time = time.time()
episode_steps = 0

# For evaluation, this keeps track of the total undiscounted reward
# for each player accumulated during the episode.
multiplayer_reward_spec = specs.BoundedArray(
(self._environment.game.num_players(),),
np.float32,
minimum=self._environment.game.min_utility(),
maximum=self._environment.game.max_utility())
episode_return = tree.map_structure(_generate_zeros_from_spec,
multiplayer_reward_spec)

timestep = self._environment.reset()

# Make the first observation.
self._send_observation(timestep, self._environment.current_player)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't _send_observation be called as the first thing in the loop (which would eliminate duplicated code)?

# Run an episode.
while not timestep.last():
# Generate an action from the agent's policy and step the environment.
if self._environment.is_turn_based:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I was thinking about possibility of making this environment loop more generic by covering cases of multi-agent/multi-team environments that are not necessarily turn based or fully simultaneous.
Lets consider multi-agent, turn-based environment (for example soccer environment like https://github.com/google-research/football/) where each team consists of 11 players, but teams execute actions in turns. To cover OpenSpiel and this use case with one implementation we could:

  • eliminate _environment.is_turn_based
  • replace _environment.get_state.current_player() with a set of current players
    That would cover "# TODO Support simultaneous move games." in a generic way.
    What do you think about that?

Copy link

@lanctot lanctot Feb 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @qstanczyk,

Can you elaborate on your motivation for this? I worry because if we make this too general, it gets harder to interface with the specific API it was designed to be used with and could introduce some awkward bugs for use cases we didn't anticipate by supporting multiple APIs.

Also I don't understand the football env example. If that environment is doing turn-based already, then it should fit exactly the OpenSpiel API as well already (which as you say is already turn-based and you just have 11 consecutive actions from the same player before switching the current player to the next one)?

WDYT about a wrapper of the football env as an OpenSpiel python game (living in this repos or in OpenSpiel)? Sounds like they are compatible from your description. Let me know if I'm missing something.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If turn-based Soccer environment would be integrated the way you suggest (11 actions for each team) then what is the use case you anticipate for "!is_turn_based"? Any multi-agent environment can be implemented by passing N actions, so maybe in such case there is no need to have "is_turn_based" check around?
My motivation was to try and generalize the API by supporting any is_turn_based / multi-agent mix. I'm fine with not taking this thought into account, but I just wonder at this point if "is_turn_based" is needed.

Copy link

@lanctot lanctot Feb 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If turn-based Soccer environment would be integrated the way you suggest (11 actions for each team) then what is the use case you anticipate for "!is_turn_based"?

Wait, is that not how it's implemented? I only assumed that based on what you said. The !is_turn_based is necessary here because of the way OpenSpiel categorizes its dynamics. Some games require actions from all agents (simultaneous game) and apply them all at once, and others are turn-based so you only get one for the current player and apply that action only.

Any multi-agent environment can be implemented by passing N actions, so maybe in such case there is no need to have "is_turn_based" check around?

Yes, I agree, but that's not how OpenSpiel implements turn-based games.

My motivation was to try and generalize the API by supporting any is_turn_based / multi-agent mix.

Ok yes I think I understand. I think this would make more sense if we had a use case for this generalization, and I'm not seeing it (yet). So my hesitation is due mainly to this wrapper being designed for OpenSpiel specifically.

OpenSpiel treats these two cases separately because the algorithms that are built on top of the different type of games depend on these choices. Also, in OpenSpiel a simultaneous game doesn't need to be fully simultaneous at all states and also it's fine for a subset of agents to only have 1 legal pass-like move, hence my suggestion of wrapping the ones you mention as OpenSpiel simultaneous-move games, because then this interface would work on them. (Would that work, e.g. for the football env?)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack, I was just thinking out-loud about generalization of the API. We don't look at hooking up GRF at this point, but was wondering if this OpenSpiel environment loop could be more generic.

action_list = [
self._get_action(timestep, self._environment.current_player)
]
else:
# TODO Support simultaneous move games.
raise ValueError('Currently only supports sequential games.')

timestep = self._environment.step(action_list)

# Have the agent observe the timestep and let the actor update itself.
self._send_observation(timestep, self._environment.current_player)

# Book-keeping.
episode_steps += 1

# Equivalent to: episode_return += timestep.reward
tree.map_structure(operator.iadd, episode_return, timestep.reward)

# Record counts.
counts = self._counter.increment(episodes=1, steps=episode_steps)

# Collect the results and combine with counts.
steps_per_second = episode_steps / (time.time() - start_time)
result = {
'episode_length': episode_steps,
'episode_return': episode_return,
'steps_per_second': steps_per_second,
}
result.update(counts)
return result

def run(self,
num_episodes: Optional[int] = None,
num_steps: Optional[int] = None):
"""Perform the run loop.

Run the environment loop either for `num_episodes` episodes or for at
least `num_steps` steps (the last episode is always run until completion,
so the total number of steps may be slightly more than `num_steps`).
At least one of these two arguments has to be None.

Upon termination of an episode a new episode will be started. If the number
of episodes and the number of steps are not given then this will interact
with the environment infinitely.

Args:
num_episodes: number of episodes to run the loop for.
num_steps: minimal number of steps to run the loop for.

Raises:
ValueError: If both 'num_episodes' and 'num_steps' are not None.
"""

if not (num_episodes is None or num_steps is None):
raise ValueError('Either "num_episodes" or "num_steps" should be None.')

def should_terminate(episode_count: int, step_count: int) -> bool:
return ((num_episodes is not None and episode_count >= num_episodes) or
(num_steps is not None and step_count >= num_steps))

episode_count, step_count = 0, 0
while not should_terminate(episode_count, step_count):
result = self.run_episode()
episode_count += 1
step_count += result['episode_length']
# Log the given results.
self._logger.write(result)


def _generate_zeros_from_spec(spec: specs.Array) -> np.ndarray:
  """Returns a zero-filled array matching `spec`'s shape and dtype."""
  return np.zeros(shape=spec.shape, dtype=spec.dtype)
101 changes: 101 additions & 0 deletions acme/environment_loops/open_spiel_environment_loop_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# python3
# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for OpenSpiel environment loop."""

import unittest
from absl.testing import absltest
from absl.testing import parameterized

import acme
from acme import core
from acme import specs
from acme import types
from acme import wrappers

import dm_env
import numpy as np
import tree

# Set to True when OpenSpiel is unavailable so the tests below are skipped.
SKIP_OPEN_SPIEL_TESTS = False
SKIP_OPEN_SPIEL_MESSAGE = 'open_spiel not installed.'

try:
  # pytype: disable=import-error
  from acme.environment_loops import open_spiel_environment_loop
  from acme.wrappers import open_spiel_wrapper
  from open_spiel.python import rl_environment
  # pytype: enable=import-error

  class RandomActor(core.Actor):
    """Fake actor which generates random actions and validates specs."""

    def __init__(self, spec: specs.EnvironmentSpec):
      self._spec = spec
      self.num_updates = 0

    def select_action(self, observation: open_spiel_wrapper.OLT) -> int:
      """Returns a uniformly random legal action for the given observation."""
      _validate_spec(self._spec.observations, observation)
      # `legal_actions` is a mask; the nonzero indices are the legal actions.
      legals = np.array(np.nonzero(observation.legal_actions), dtype=np.int32)
      return np.random.choice(legals[0])

    def observe_first(self, timestep: dm_env.TimeStep):
      _validate_spec(self._spec.observations, timestep.observation)

    def observe(self, action: types.NestedArray,
                next_timestep: dm_env.TimeStep):
      # Validate every field of the transition against the environment spec.
      _validate_spec(self._spec.actions, action)
      _validate_spec(self._spec.rewards, next_timestep.reward)
      _validate_spec(self._spec.discounts, next_timestep.discount)
      _validate_spec(self._spec.observations, next_timestep.observation)

    def update(self, wait: bool = False):
      self.num_updates += 1

except ModuleNotFoundError:
  SKIP_OPEN_SPIEL_TESTS = True


def _validate_spec(spec: types.NestedSpec, value: types.NestedArray):
  """Validate a value from a potentially nested spec."""
  tree.assert_same_structure(value, spec)

  def _check_leaf(leaf_spec, leaf_value):
    leaf_spec.validate(leaf_value)

  tree.map_structure(_check_leaf, spec, value)


@unittest.skipIf(SKIP_OPEN_SPIEL_TESTS, SKIP_OPEN_SPIEL_MESSAGE)
class OpenSpielEnvironmentLoopTest(parameterized.TestCase):
  """End-to-end test of the OpenSpiel environment loop on tic-tac-toe."""

  def test_loop_run(self):
    # Build a single-precision, Acme-compatible tic-tac-toe environment.
    env = wrappers.SinglePrecisionWrapper(
        open_spiel_wrapper.OpenSpielWrapper(
            rl_environment.Environment('tic_tac_toe')))
    environment_spec = acme.make_environment_spec(env)

    # One spec-validating random actor per player.
    actors = [RandomActor(environment_spec) for _ in range(env.num_players)]

    loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop(env, actors)

    # A single episode should produce the standard logging keys.
    result = loop.run_episode()
    for key in ('episode_length', 'episode_return', 'steps_per_second'):
      self.assertIn(key, result)

    # The loop should also run to completion for fixed episode/step budgets.
    loop.run(num_episodes=10)
    loop.run(num_steps=100)


# Run the absl test runner when executed directly.
if __name__ == '__main__':
  absltest.main()
5 changes: 5 additions & 0 deletions acme/tf/networks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
from acme.tf.networks.distributional import UnivariateGaussianMixture
from acme.tf.networks.distributions import DiscreteValuedDistribution
from acme.tf.networks.duelling import DuellingMLP
try:
from acme.tf.networks.legal_actions import MaskedSequential
from acme.tf.networks.legal_actions import EpsilonGreedy
except ImportError:
pass
from acme.tf.networks.masked_epsilon_greedy import NetworkWithMaskedEpsilonGreedy
from acme.tf.networks.multihead import Multihead
from acme.tf.networks.multiplexers import CriticMultiplexer
Expand Down
Loading