diff --git a/acme/environment_loops/__init__.py b/acme/environment_loops/__init__.py new file mode 100644 index 0000000000..b018025c9b --- /dev/null +++ b/acme/environment_loops/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Specialized environment loops.""" + +try: + from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop +except ImportError: + pass diff --git a/acme/environment_loops/open_spiel_environment_loop.py b/acme/environment_loops/open_spiel_environment_loop.py new file mode 100644 index 0000000000..9584be4bab --- /dev/null +++ b/acme/environment_loops/open_spiel_environment_loop.py @@ -0,0 +1,221 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""An OpenSpiel multi-agent/environment training loop."""

import operator
import time
from typing import Optional, Sequence

from acme import core
from acme.utils import counting
from acme.utils import loggers
from acme.wrappers import open_spiel_wrapper
import dm_env
from dm_env import specs
import numpy as np
# pytype: disable=import-error
import pyspiel
# pytype: enable=import-error
import tree


class OpenSpielEnvironmentLoop(core.Worker):
  """An OpenSpiel RL environment loop.

  This takes `Environment` and list of `Actor` instances and coordinates their
  interaction. Agents are updated if `should_update=True`. This can be used as:

    loop = EnvironmentLoop(environment, actors)
    loop.run(num_episodes)

  A `Counter` instance can optionally be given in order to maintain counts
  between different Acme components. If not given a local Counter will be
  created to maintain counts between calls to the `run` method.

  A `Logger` instance can also be passed in order to control the output of the
  loop. If not given a platform-specific default logger will be used as defined
  by utils.loggers.make_default_logger. A string `label` can be passed to
  easily change the label associated with the default logger; this is ignored
  if a `Logger` instance is given.
  """

  def __init__(
      self,
      environment: open_spiel_wrapper.OpenSpielWrapper,
      actors: Sequence[core.Actor],
      counter: Optional[counting.Counter] = None,
      logger: Optional[loggers.Logger] = None,
      should_update: bool = True,
      label: str = 'open_spiel_environment_loop',
  ):
    # Internalize agent and environment.
    self._environment = environment
    self._actors = actors
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger(label)
    self._should_update = should_update

    # Per-actor bookkeeping used to coordinate observations and updates among
    # multiple actors in a sequential-move game.
    self._observed_first = [False] * len(self._actors)
    self._prev_actions = [pyspiel.INVALID_ACTION] * len(self._actors)

  def _send_observation(self, timestep: dm_env.TimeStep, player: int):
    """Routes `timestep` to the actor(s) that should observe it."""
    # If terminal all actors must update so every open trajectory is closed.
    if player == pyspiel.PlayerId.TERMINAL:
      for player_id in range(len(self._actors)):
        # Note: we must account for situations where the first observation
        # is a terminal state, e.g. if an opponent folds in poker before we
        # get to act.
        if self._observed_first[player_id]:
          player_timestep = self._get_player_timestep(timestep, player_id)
          self._actors[player_id].observe(self._prev_actions[player_id],
                                          player_timestep)
          if self._should_update:
            self._actors[player_id].update()
      # Reset per-episode bookkeeping for the next episode.
      self._observed_first = [False] * len(self._actors)
      self._prev_actions = [pyspiel.INVALID_ACTION] * len(self._actors)
    else:
      if not self._observed_first[player]:
        # First observation for this player: no reward/discount yet.
        player_timestep = dm_env.TimeStep(
            observation=timestep.observation[player],
            reward=None,
            discount=None,
            step_type=dm_env.StepType.FIRST)
        self._actors[player].observe_first(player_timestep)
        self._observed_first[player] = True
      else:
        player_timestep = self._get_player_timestep(timestep, player)
        self._actors[player].observe(self._prev_actions[player],
                                     player_timestep)
        if self._should_update:
          self._actors[player].update()

  def _get_action(self, timestep: dm_env.TimeStep, player: int) -> int:
    """Queries `player`'s actor for an action and records it for replay."""
    self._prev_actions[player] = self._actors[player].select_action(
        timestep.observation[player])
    return self._prev_actions[player]

  def _get_player_timestep(self, timestep: dm_env.TimeStep,
                           player: int) -> dm_env.TimeStep:
    """Extracts the single-player view of a multiplayer timestep."""
    return dm_env.TimeStep(observation=timestep.observation[player],
                           reward=timestep.reward[player],
                           discount=timestep.discount[player],
                           step_type=timestep.step_type)

  def run_episode(self) -> loggers.LoggingData:
    """Run one episode.

    Each episode is a loop which interacts first with the environment to get
    an observation and then give that observation to the agent in order to
    retrieve an action.

    Returns:
      An instance of `loggers.LoggingData`.
    """
    # Reset any counts and start the environment.
    start_time = time.time()
    episode_steps = 0

    # For evaluation, this keeps track of the total undiscounted reward
    # for each player accumulated during the episode.
    multiplayer_reward_spec = specs.BoundedArray(
        (self._environment.game.num_players(),),
        np.float32,
        minimum=self._environment.game.min_utility(),
        maximum=self._environment.game.max_utility())
    episode_return = tree.map_structure(_generate_zeros_from_spec,
                                        multiplayer_reward_spec)

    timestep = self._environment.reset()

    # Make the first observation.
    self._send_observation(timestep, self._environment.current_player)

    # Run an episode.
    while not timestep.last():
      # Generate an action from the agent's policy and step the environment.
      if self._environment.is_turn_based:
        action_list = [
            self._get_action(timestep, self._environment.current_player)
        ]
      else:
        # TODO: Support simultaneous move games.
        raise ValueError('Currently only supports sequential games.')

      timestep = self._environment.step(action_list)

      # Have the agent observe the timestep and let the actor update itself.
      self._send_observation(timestep, self._environment.current_player)

      # Book-keeping.
      episode_steps += 1

      # Equivalent to: episode_return += timestep.reward
      tree.map_structure(operator.iadd, episode_return, timestep.reward)

    # Record counts.
    counts = self._counter.increment(episodes=1, steps=episode_steps)

    # Collect the results and combine with counts.
    steps_per_second = episode_steps / (time.time() - start_time)
    result = {
        'episode_length': episode_steps,
        'episode_return': episode_return,
        'steps_per_second': steps_per_second,
    }
    result.update(counts)
    return result

  def run(self,
          num_episodes: Optional[int] = None,
          num_steps: Optional[int] = None):
    """Perform the run loop.

    Run the environment loop either for `num_episodes` episodes or for at
    least `num_steps` steps (the last episode is always run until completion,
    so the total number of steps may be slightly more than `num_steps`).
    At least one of these two arguments has to be None.

    Upon termination of an episode a new episode will be started. If the
    number of episodes and the number of steps are not given then this will
    interact with the environment infinitely.

    Args:
      num_episodes: number of episodes to run the loop for.
      num_steps: minimal number of steps to run the loop for.

    Raises:
      ValueError: If both 'num_episodes' and 'num_steps' are not None.
    """
    if not (num_episodes is None or num_steps is None):
      raise ValueError('Either "num_episodes" or "num_steps" should be None.')

    def should_terminate(episode_count: int, step_count: int) -> bool:
      return ((num_episodes is not None and episode_count >= num_episodes) or
              (num_steps is not None and step_count >= num_steps))

    episode_count, step_count = 0, 0
    while not should_terminate(episode_count, step_count):
      result = self.run_episode()
      episode_count += 1
      step_count += result['episode_length']
      # Log the given results.
      self._logger.write(result)


def _generate_zeros_from_spec(spec: specs.Array) -> np.ndarray:
  """Returns a zero-filled array matching `spec`'s shape and dtype."""
  return np.zeros(spec.shape, spec.dtype)
+ +try: + # pytype: disable=import-error + from acme.environment_loops import open_spiel_environment_loop + from acme.wrappers import open_spiel_wrapper + from open_spiel.python import rl_environment + # pytype: enable=import-error + + class RandomActor(core.Actor): + """Fake actor which generates random actions and validates specs.""" + + def __init__(self, spec: specs.EnvironmentSpec): + self._spec = spec + self.num_updates = 0 + + def select_action(self, observation: open_spiel_wrapper.OLT) -> int: + _validate_spec(self._spec.observations, observation) + legals = np.array(np.nonzero(observation.legal_actions), dtype=np.int32) + return np.random.choice(legals[0]) + + def observe_first(self, timestep: dm_env.TimeStep): + _validate_spec(self._spec.observations, timestep.observation) + + def observe(self, action: types.NestedArray, + next_timestep: dm_env.TimeStep): + _validate_spec(self._spec.actions, action) + _validate_spec(self._spec.rewards, next_timestep.reward) + _validate_spec(self._spec.discounts, next_timestep.discount) + _validate_spec(self._spec.observations, next_timestep.observation) + + def update(self, wait: bool = False): + self.num_updates += 1 + +except ModuleNotFoundError: + SKIP_OPEN_SPIEL_TESTS = True + + +def _validate_spec(spec: types.NestedSpec, value: types.NestedArray): + """Validate a value from a potentially nested spec.""" + tree.assert_same_structure(value, spec) + tree.map_structure(lambda s, v: s.validate(v), spec, value) + + +@unittest.skipIf(SKIP_OPEN_SPIEL_TESTS, SKIP_OPEN_SPIEL_MESSAGE) +class OpenSpielEnvironmentLoopTest(parameterized.TestCase): + + def test_loop_run(self): + raw_env = rl_environment.Environment('tic_tac_toe') + env = open_spiel_wrapper.OpenSpielWrapper(raw_env) + env = wrappers.SinglePrecisionWrapper(env) + environment_spec = acme.make_environment_spec(env) + + actors = [] + for _ in range(env.num_players): + actors.append(RandomActor(environment_spec)) + + loop = 
open_spiel_environment_loop.OpenSpielEnvironmentLoop(env, actors) + result = loop.run_episode() + self.assertIn('episode_length', result) + self.assertIn('episode_return', result) + self.assertIn('steps_per_second', result) + + loop.run(num_episodes=10) + loop.run(num_steps=100) + + +if __name__ == '__main__': + absltest.main() diff --git a/acme/tf/networks/__init__.py b/acme/tf/networks/__init__.py index 9427b025c9..992cbee845 100644 --- a/acme/tf/networks/__init__.py +++ b/acme/tf/networks/__init__.py @@ -34,6 +34,11 @@ from acme.tf.networks.distributional import UnivariateGaussianMixture from acme.tf.networks.distributions import DiscreteValuedDistribution from acme.tf.networks.duelling import DuellingMLP +try: + from acme.tf.networks.legal_actions import MaskedSequential + from acme.tf.networks.legal_actions import EpsilonGreedy +except ImportError: + pass from acme.tf.networks.masked_epsilon_greedy import NetworkWithMaskedEpsilonGreedy from acme.tf.networks.multihead import Multihead from acme.tf.networks.multiplexers import CriticMultiplexer diff --git a/acme/tf/networks/legal_actions.py b/acme/tf/networks/legal_actions.py new file mode 100644 index 0000000000..56ca232291 --- /dev/null +++ b/acme/tf/networks/legal_actions.py @@ -0,0 +1,120 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Networks used for handling illegal actions."""

from typing import Any, Callable, Iterable, Optional, Union

# pytype: disable=import-error
from acme.wrappers import open_spiel_wrapper
# pytype: enable=import-error

import numpy as np
import sonnet as snt
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


class MaskedSequential(snt.Module):
  """Applies a legal actions mask to a linear chain of modules / callables.

  It is assumed the trailing dimension of the final layer (representing
  action values) is the same as the trailing dimension of legal_actions.
  """

  def __init__(self,
               layers: Optional[Iterable[Callable[..., Any]]] = None,
               name: str = 'MaskedSequential'):
    """Initializes the module.

    Args:
      layers: Callables applied in order to the observation; the last one is
        expected to produce per-action values.
      name: Sonnet module name.
    """
    super().__init__(name=name)
    self._layers = list(layers) if layers is not None else []
    # Note: illegal_action_penalty cannot be -np.inf because trfl's qlearning
    # ops utilize a batched_index function that returns NaN whenever -np.inf
    # is present among action values.
    self._illegal_action_penalty = -1e9

  def __call__(self, inputs: open_spiel_wrapper.OLT) -> tf.Tensor:
    """Returns masked action values for the given OLT input."""
    # Extract observation, legal actions, and terminal.
    outputs = inputs.observation
    legal_actions = inputs.legal_actions
    terminal = inputs.terminal

    for mod in self._layers:
      outputs = mod(outputs)

    # Apply legal actions mask: illegal actions receive a large negative
    # penalty so a greedy policy never selects them.
    outputs = tf.where(tf.equal(legal_actions, 1), outputs,
                       tf.fill(tf.shape(outputs), self._illegal_action_penalty))

    # When computing the Q-learning target (r_t + d_t * max q_t) we need to
    # ensure max q_t = 0 in terminal states.
    outputs = tf.where(tf.equal(terminal, 1), tf.zeros_like(outputs), outputs)

    return outputs


# TODO: Add functionality to support decaying epsilon parameter.
# TODO: This is a modified version of trfl's epsilon_greedy() which
# incorporates code from the bug fix described here:
# https://github.com/deepmind/trfl/pull/28
class EpsilonGreedy(snt.Module):
  """Computes an epsilon-greedy distribution over actions.

  This policy does the following:
  - With probability 1 - epsilon, take the action corresponding to the highest
    action value, breaking ties uniformly at random.
  - With probability epsilon, take an action uniformly at random.
  """

  def __init__(self,
               epsilon: Union[tf.Tensor, float],
               threshold: float,
               name: str = 'EpsilonGreedy'):
    """Initializes the policy.

    Args:
      epsilon: Exploratory param with value between 0 and 1.
      threshold: Action values must exceed this value to qualify as a legal
        action and possibly be selected by the policy.
      name: Sonnet module name.
    """
    super().__init__(name=name)
    # Stored as a non-trainable variable so callers can anneal epsilon
    # without rebuilding the module.
    self._epsilon = tf.Variable(epsilon, trainable=False)
    self._threshold = threshold

  def __call__(self, action_values: tf.Tensor) -> tfd.Categorical:
    """Builds the epsilon-greedy policy over `action_values`.

    Args:
      action_values: Per-action values; entries <= `threshold` are treated
        as illegal actions.

    Returns:
      A tfp.distributions.Categorical distribution representing the policy.
    """
    # 1.0 for legal actions (value above threshold), 0.0 otherwise.
    legal_actions_mask = tf.where(
        tf.math.less_equal(action_values, self._threshold),
        tf.fill(tf.shape(action_values), 0.),
        tf.fill(tf.shape(action_values), 1.))

    # Dithering action distribution: uniform over legal actions.
    dither_probs = 1 / tf.reduce_sum(legal_actions_mask, axis=-1,
                                     keepdims=True) * legal_actions_mask

    # Mask illegal actions to -inf so they can never attain the maximum.
    masked_action_values = tf.where(tf.equal(legal_actions_mask, 1),
                                    action_values,
                                    tf.fill(tf.shape(action_values), -np.inf))

    # Greedy action distribution, breaking ties uniformly at random.
    # Bug fix: compare against the -inf-masked values rather than
    # `action_values * legal_actions_mask` — multiplying by the mask zeroes
    # illegal entries, which would spuriously match whenever the best legal
    # action value is exactly 0, giving illegal actions greedy probability.
    max_value = tf.reduce_max(masked_action_values, axis=-1, keepdims=True)
    greedy_probs = tf.cast(
        tf.equal(masked_action_values, max_value),
        action_values.dtype)

    greedy_probs /= tf.reduce_sum(greedy_probs, axis=-1, keepdims=True)

    # Epsilon-greedy action distribution.
    probs = self._epsilon * dither_probs + (1 - self._epsilon) * greedy_probs

    # Make the policy object.
    policy = tfd.Categorical(probs=probs)

    return policy
"""Wraps an OpenSpiel RL environment to be used as a dm_env environment."""

from typing import List, NamedTuple

from acme import specs
from acme import types
import dm_env
import numpy as np
# pytype: disable=import-error
from open_spiel.python import rl_environment
import pyspiel
# pytype: enable=import-error


class OLT(NamedTuple):
  """Container for (observation, legal_actions, terminal) tuples."""
  observation: types.Nest
  legal_actions: types.Nest
  terminal: types.Nest


class OpenSpielWrapper(dm_env.Environment):
  """Environment wrapper for OpenSpiel RL environments."""

  # Note: we don't inherit from base.EnvironmentWrapper because that class
  # assumes that the wrapped environment is a dm_env.Environment.

  def __init__(self, environment: rl_environment.Environment):
    """Wraps `environment`.

    Args:
      environment: An OpenSpiel `rl_environment.Environment`.

    Raises:
      ValueError: If the underlying game is not sequential (turn-based).
    """
    self._environment = environment
    self._reset_next_step = True
    if environment.game.get_type(
    ).dynamics != pyspiel.GameType.Dynamics.SEQUENTIAL:
      raise ValueError("Currently only supports sequential games.")

  def reset(self) -> dm_env.TimeStep:
    """Resets the episode."""
    self._reset_next_step = False
    open_spiel_timestep = self._environment.reset()
    observations = self._convert_observation(open_spiel_timestep)
    return dm_env.restart(observations)

  def step(self, action: types.NestedArray) -> dm_env.TimeStep:
    """Steps the environment."""
    if self._reset_next_step:
      return self.reset()

    open_spiel_timestep = self._environment.step(action)

    if open_spiel_timestep.step_type == rl_environment.StepType.LAST:
      self._reset_next_step = True

    observations = self._convert_observation(open_spiel_timestep)
    rewards = np.asarray(open_spiel_timestep.rewards)
    discounts = np.asarray(open_spiel_timestep.discounts)

    # Translate OpenSpiel's step type into the dm_env equivalent.
    step_type = open_spiel_timestep.step_type
    if step_type == rl_environment.StepType.FIRST:
      step_type = dm_env.StepType.FIRST
    elif step_type == rl_environment.StepType.MID:
      step_type = dm_env.StepType.MID
    elif step_type == rl_environment.StepType.LAST:
      step_type = dm_env.StepType.LAST
    else:
      raise ValueError(
          "Did not recognize OpenSpiel StepType: {}".format(step_type))

    return dm_env.TimeStep(observation=observations,
                           reward=rewards,
                           discount=discounts,
                           step_type=step_type)

  def _convert_observation(
      self, open_spiel_timestep: rl_environment.TimeStep) -> List[OLT]:
    """Converts an OpenSpiel observation to a per-player list of OLTs.

    The OpenSpiel list of legal action ids is converted to a one-hot legal
    actions mask over all distinct actions.
    """
    observations = []
    for pid in range(self._environment.num_players):
      legals = np.zeros(self._environment.game.num_distinct_actions(),
                        dtype=np.float32)
      legals[open_spiel_timestep.observations["legal_actions"][pid]] = 1.0
      player_observation = OLT(observation=np.asarray(
          open_spiel_timestep.observations["info_state"][pid],
          dtype=np.float32),
                               legal_actions=legals,
                               terminal=np.asarray([open_spiel_timestep.last()],
                                                   dtype=np.float32))
      observations.append(player_observation)
    return observations

  def observation_spec(self) -> OLT:
    """Returns the per-player observation spec.

    The observation size depends on whether the OpenSpiel environment is
    using observation or information-state tensors.
    """
    if self._environment.use_observation:
      obs_size = self._environment.game.observation_tensor_size()
    else:
      obs_size = self._environment.game.information_state_tensor_size()
    return OLT(observation=specs.Array((obs_size,), np.float32),
               legal_actions=specs.Array(
                   (self._environment.game.num_distinct_actions(),),
                   np.float32),
               terminal=specs.Array((1,), np.float32))

  def action_spec(self) -> specs.DiscreteArray:
    """Returns the flat discrete action spec over all distinct actions."""
    return specs.DiscreteArray(self._environment.game.num_distinct_actions())

  def reward_spec(self) -> specs.BoundedArray:
    """Returns the single-player reward spec, bounded by the game utility."""
    return specs.BoundedArray((),
                              np.float32,
                              minimum=self._environment.game.min_utility(),
                              maximum=self._environment.game.max_utility())

  def discount_spec(self) -> specs.BoundedArray:
    """Returns the discount spec (a scalar in [0, 1])."""
    return specs.BoundedArray((), np.float32, minimum=0, maximum=1.0)

  @property
  def environment(self) -> rl_environment.Environment:
    """Returns the wrapped environment."""
    return self._environment

  @property
  def current_player(self) -> int:
    """Returns the id of the player whose turn it currently is."""
    return self._environment.get_state.current_player()

  def __getattr__(self, name: str):
    """Expose any other attributes of the underlying environment."""
    return getattr(self._environment, name)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for open_spiel_wrapper.""" + +import unittest +from absl.testing import absltest +from dm_env import specs +import numpy as np + +SKIP_OPEN_SPIEL_TESTS = False +SKIP_OPEN_SPIEL_MESSAGE = 'open_spiel not installed.' + +try: + # pytype: disable=import-error + from acme.wrappers import open_spiel_wrapper + from open_spiel.python import rl_environment + # pytype: enable=import-error +except ModuleNotFoundError: + SKIP_OPEN_SPIEL_TESTS = True + + +@unittest.skipIf(SKIP_OPEN_SPIEL_TESTS, SKIP_OPEN_SPIEL_MESSAGE) +class OpenSpielWrapperTest(absltest.TestCase): + + def test_tic_tac_toe(self): + raw_env = rl_environment.Environment('tic_tac_toe') + env = open_spiel_wrapper.OpenSpielWrapper(raw_env) + + # Test converted observation spec. + observation_spec = env.observation_spec() + self.assertEqual(type(observation_spec), open_spiel_wrapper.OLT) + self.assertEqual(type(observation_spec.observation), specs.Array) + self.assertEqual(type(observation_spec.legal_actions), specs.Array) + self.assertEqual(type(observation_spec.terminal), specs.Array) + + # Test converted action spec. + action_spec: specs.DiscreteArray = env.action_spec() + self.assertEqual(type(action_spec), specs.DiscreteArray) + self.assertEqual(action_spec.shape, ()) + self.assertEqual(action_spec.minimum, 0) + self.assertEqual(action_spec.maximum, 8) + self.assertEqual(action_spec.num_values, 9) + self.assertEqual(action_spec.dtype, np.dtype('int32')) + + # Test step. 
+ timestep = env.reset() + self.assertTrue(timestep.first()) + _ = env.step([0]) + env.close() + + +if __name__ == '__main__': + absltest.main() diff --git a/examples/open_spiel/run_dqn.py b/examples/open_spiel/run_dqn.py new file mode 100644 index 0000000000..d9756ae513 --- /dev/null +++ b/examples/open_spiel/run_dqn.py @@ -0,0 +1,75 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running DQN on OpenSpiel game in a single process.""" + +from absl import app +from absl import flags + +import acme +from acme import wrappers +from acme.agents.tf import dqn +from acme.environment_loops import open_spiel_environment_loop +from acme.tf.networks import legal_actions +from acme.wrappers import open_spiel_wrapper +from open_spiel.python import rl_environment +import sonnet as snt + +flags.DEFINE_string('game', 'tic_tac_toe', 'Name of the game') +flags.DEFINE_integer('num_players', None, 'Number of players') + +FLAGS = flags.FLAGS + + +def main(_): + # Create an environment and grab the spec. + env_configs = {'players': FLAGS.num_players} if FLAGS.num_players else {} + raw_environment = rl_environment.Environment(FLAGS.game, **env_configs) + + environment = open_spiel_wrapper.OpenSpielWrapper(raw_environment) + environment = wrappers.SinglePrecisionWrapper(environment) + environment_spec = acme.make_environment_spec(environment) + + # Build the networks. 
+ networks = [] + policy_networks = [] + for _ in range(environment.num_players): + network = legal_actions.MaskedSequential([ + snt.Flatten(), + snt.nets.MLP([50, 50, environment_spec.actions.num_values]) + ]) + policy_network = snt.Sequential( + [network, + legal_actions.EpsilonGreedy(epsilon=0.1, threshold=-1e8)]) + networks.append(network) + policy_networks.append(policy_network) + + # Construct the agents. + agents = [] + + for network, policy_network in zip(networks, policy_networks): + agents.append( + dqn.DQN(environment_spec=environment_spec, + network=network, + policy_network=policy_network)) + + # Run the environment loop. + loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop( + environment, agents) + loop.run(num_episodes=100000) + + +if __name__ == '__main__': + app.run(main)