From fd809ae38f734a62100ce5d5bfaaff9280d254dc Mon Sep 17 00:00:00 2001 From: John Schultz Date: Tue, 17 Nov 2020 02:22:36 +0000 Subject: [PATCH 01/39] Add OpenSpiel interface. --- acme/open_spiel/__init__.py | 22 ++ acme/open_spiel/agents/agent.py | 132 +++++++++ acme/open_spiel/agents/tf/__init__.py | 13 + acme/open_spiel/agents/tf/actors.py | 101 +++++++ acme/open_spiel/agents/tf/dqn/README.md | 21 ++ acme/open_spiel/agents/tf/dqn/__init__.py | 19 ++ acme/open_spiel/agents/tf/dqn/agent.py | 252 ++++++++++++++++++ acme/open_spiel/agents/tf/dqn/learning.py | 211 +++++++++++++++ acme/open_spiel/examples/run_dqn.py | 75 ++++++ .../open_spiel/open_spiel_environment_loop.py | 197 ++++++++++++++ acme/open_spiel/open_spiel_specs.py | 40 +++ acme/open_spiel/open_spiel_wrapper.py | 140 ++++++++++ 12 files changed, 1223 insertions(+) create mode 100644 acme/open_spiel/__init__.py create mode 100644 acme/open_spiel/agents/agent.py create mode 100644 acme/open_spiel/agents/tf/__init__.py create mode 100644 acme/open_spiel/agents/tf/actors.py create mode 100644 acme/open_spiel/agents/tf/dqn/README.md create mode 100644 acme/open_spiel/agents/tf/dqn/__init__.py create mode 100644 acme/open_spiel/agents/tf/dqn/agent.py create mode 100644 acme/open_spiel/agents/tf/dqn/learning.py create mode 100644 acme/open_spiel/examples/run_dqn.py create mode 100644 acme/open_spiel/open_spiel_environment_loop.py create mode 100644 acme/open_spiel/open_spiel_specs.py create mode 100644 acme/open_spiel/open_spiel_wrapper.py diff --git a/acme/open_spiel/__init__.py b/acme/open_spiel/__init__.py new file mode 100644 index 0000000000..ac6dc86ca2 --- /dev/null +++ b/acme/open_spiel/__init__.py @@ -0,0 +1,22 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Expose specs and types modules. +from acme.open_spiel import open_spiel_specs + +# Expose the environment loop. +from acme.open_spiel import open_spiel_environment_loop + +# Acme loves OpenSpiel. diff --git a/acme/open_spiel/agents/agent.py b/acme/open_spiel/agents/agent.py new file mode 100644 index 0000000000..078ebf38b9 --- /dev/null +++ b/acme/open_spiel/agents/agent.py @@ -0,0 +1,132 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OpenSpiel agent interface.""" + +from typing import List, Tuple + +from acme import core +from acme import types +from acme import specs +from acme.agents import agent +from acme.open_spiel import open_spiel_specs +from acme.tf import utils as tf2_utils +# Internal imports. 
+ +import dm_env +import numpy as np +from open_spiel.python import rl_environment +import pyspiel +import tensorflow as tf + + +class OpenSpielAgent(agent.Agent): + """Agent class which combines acting and learning.""" + + def __init__(self, + actor: core.Actor, + learner: core.Learner, + min_observations: int, + observations_per_step: float, + player_id: int, + should_update: bool = True): + self._player_id = player_id + self._should_update = should_update + self._observed_first = False + self._prev_action = None + super().__init__(actor=actor, + learner=learner, + min_observations=min_observations, + observations_per_step=observations_per_step) + + def set_update(self, should_update: bool): + self._should_update = should_update + + def select_action(self, observation: types.NestedArray) -> types.NestedArray: + current_player = observation["current_player"] + assert current_player == self._player_id + player_observation = observation["info_state"][current_player] + legal_actions = observation["legal_actions"][current_player] + self._prev_action = self._actor.select_action(player_observation, + legal_actions) + return self._prev_action + + # TODO Eventually remove? Currently used for debugging. 
+ def print_policy(self, observation: types.NestedArray) -> types.NestedArray: + current_player = observation["current_player"] + assert current_player == self._player_id + player_observation = observation["info_state"][current_player] + legal_actions = observation["legal_actions"][current_player] + + batched_observation = tf2_utils.add_batch_dim(player_observation) + policy = self._actor._policy_network(batched_observation, legal_actions) + tf.print("Policy: ", policy.probs, summarize=-1) + + def observe_first(self, timestep: dm_env.TimeStep): + current_player = timestep.observation["current_player"] + assert current_player == self._player_id + timestep = dm_env.TimeStep( + observation=timestep.observation["info_state"][current_player], + reward=None, + discount=None, + step_type=dm_env.StepType.FIRST) + self._actor.observe_first(timestep) + self._observed_first = True + + def observe( + self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + ): + current_player = next_timestep.observation["current_player"] + if current_player == self._player_id: + if not self._observed_first: + self.observe_first(next_timestep) + else: + next_timestep, extras = self._convert_timestep(next_timestep) + self._num_observations += 1 + self._actor.observe(self._prev_action, next_timestep, extras) + if self._should_update: + super().update() + # TODO Note: we must account for situations where the first obs is a + # terminal state, e.g. if an opponent folds in poker before we get to act. + elif current_player == pyspiel.PlayerId.TERMINAL and self._observed_first: + next_timestep, extras = self._convert_timestep(next_timestep) + self._num_observations += 1 + self._actor.observe(self._prev_action, next_timestep, extras) + self._observed_first = False + self._prev_action = None + if self._should_update: + super().update() + else: + # TODO We ignore observations not relevant to this agent. 
+ pass + + # TODO In order to avoid bookkeeping in the environment loop, OpenSpiel agents + # receive full timesteps that contain information for all agents. Here we + # extract the information specific to this agent. + def _convert_timestep( + self, timestep: dm_env.TimeStep + ) -> Tuple[dm_env.TimeStep, open_spiel_specs.Extras]: + legal_actions = timestep.observation["legal_actions"][self._player_id] + terminal = np.array(timestep.last(), dtype=np.float32) + extras = open_spiel_specs.Extras(legal_actions=legal_actions, + terminals=terminal) + converted_timestep = dm_env.TimeStep( + observation=timestep.observation["info_state"][self._player_id], + reward=timestep.reward[self._player_id], + discount=timestep.discount[self._player_id], + step_type=timestep.step_type) + return converted_timestep, extras diff --git a/acme/open_spiel/agents/tf/__init__.py b/acme/open_spiel/agents/tf/__init__.py new file mode 100644 index 0000000000..de867df849 --- /dev/null +++ b/acme/open_spiel/agents/tf/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/acme/open_spiel/agents/tf/actors.py b/acme/open_spiel/agents/tf/actors.py new file mode 100644 index 0000000000..783ded3605 --- /dev/null +++ b/acme/open_spiel/agents/tf/actors.py @@ -0,0 +1,101 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic actor implementation, using TensorFlow and Sonnet.""" + +from typing import Optional, Tuple + +from acme import adders +from acme import core +from acme import types +# Internal imports. +from acme.tf import utils as tf2_utils +from acme.tf import variable_utils as tf2_variable_utils + +import dm_env +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp + +tfd = tfp.distributions + + +class FeedForwardActor(core.Actor): + """A feed-forward actor. + + An actor based on a feed-forward policy which takes non-batched observations + and outputs non-batched actions. It also allows adding experiences to replay + and updating the weights from the policy on the learner. + """ + + def __init__( + self, + policy_network: snt.Module, + adder: Optional[adders.Adder] = None, + variable_client: Optional[tf2_variable_utils.VariableClient] = None, + ): + """Initializes the actor. + + Args: + policy_network: the policy to run. + adder: the adder object to which allows to add experiences to a + dataset/replay buffer. + variable_client: object which allows to copy weights from the learner copy + of the policy to the actor copy (in case they are separate). + """ + + # Store these for later use. 
+ self._adder = adder + self._variable_client = variable_client + self._policy_network = policy_network + + @tf.function + def _policy(self, observation: types.NestedTensor, + legal_actions: types.NestedTensor) -> types.NestedTensor: + # Add a dummy batch dimension and as a side effect convert numpy to TF. + batched_observation = tf2_utils.add_batch_dim(observation) + + # Compute the policy, conditioned on the observation. + policy = self._policy_network(batched_observation, legal_actions) + + # Sample from the policy if it is stochastic. + action = policy.sample() if isinstance(policy, tfd.Distribution) else policy + + return action + + def select_action( + self, + observation: types.NestedArray, + legal_actions: Optional[types.NestedTensor] = None) -> types.NestedArray: + # Pass the observation through the policy network. + action = self._policy(observation, legal_actions) + + # Return a numpy array with squeezed out batch dimension. + return tf2_utils.to_numpy_squeeze(action) + + def observe_first(self, timestep: dm_env.TimeStep): + if self._adder: + self._adder.add_first(timestep) + + def observe(self, + action: types.NestedArray, + next_timestep: dm_env.TimeStep, + extras: types.NestedArray = ()): + if self._adder: + self._adder.add(action, next_timestep, extras) + + def update(self, wait: bool = False): + if self._variable_client: + self._variable_client.update(wait) diff --git a/acme/open_spiel/agents/tf/dqn/README.md b/acme/open_spiel/agents/tf/dqn/README.md new file mode 100644 index 0000000000..8303a3d51c --- /dev/null +++ b/acme/open_spiel/agents/tf/dqn/README.md @@ -0,0 +1,21 @@ +# Deep Q-Networks (DQN) + +This folder contains an implementation of the DQN algorithm +([Mnih et al., 2013], [Mnih et al., 2015]), with extra bells & whistles, +similar to Rainbow DQN ([Hessel et al., 2017]). + +- Q-learning with neural network function approximation. The loss is + given by the Huber loss applied to the temporal difference error. 
+- Target Q' network updated periodically ([Mnih et al., 2015]). +- N-step bootstrapping ([Sutton & Barto, 2018]). +- Double Q-learning ([van Hasselt et al., 2015]). +- Prioritized experience replay ([Schaul et al., 2015]). + + +[Mnih et al., 2013]: https://arxiv.org/abs/1312.5602 +[Mnih et al., 2015]: https://www.nature.com/articles/nature14236 +[van Hasselt et al., 2015]: https://arxiv.org/abs/1509.06461 +[Schaul et al., 2015]: https://arxiv.org/abs/1511.05952 +[Hessel et al., 2017]: https://arxiv.org/abs/1710.02298 +[Horgan et al., 2018]: https://arxiv.org/abs/1803.00933 +[Sutton & Barto, 2018]: http://incompleteideas.net/book/the-book.html diff --git a/acme/open_spiel/agents/tf/dqn/__init__.py b/acme/open_spiel/agents/tf/dqn/__init__.py new file mode 100644 index 0000000000..e995bc506c --- /dev/null +++ b/acme/open_spiel/agents/tf/dqn/__init__.py @@ -0,0 +1,19 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementation of a deep Q-networks (DQN) agent.""" + +from acme.open_spiel.agents.tf.dqn.agent import DQN +from acme.open_spiel.agents.tf.dqn.learning import DQNLearner diff --git a/acme/open_spiel/agents/tf/dqn/agent.py b/acme/open_spiel/agents/tf/dqn/agent.py new file mode 100644 index 0000000000..fe2151fcee --- /dev/null +++ b/acme/open_spiel/agents/tf/dqn/agent.py @@ -0,0 +1,252 @@ +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DQN agent implementation.""" + +import copy +from typing import Any, Callable, Iterable, Optional, Text + +from acme import datasets +from acme import specs +from acme.adders import reverb as adders +from acme.open_spiel import open_spiel_specs +from acme.open_spiel.agents import agent +from acme.open_spiel.agents.tf import actors +from acme.open_spiel.agents.tf.dqn import learning +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import loggers +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp +import trfl + + +# TODO Import this from trfl once legal_actions_mask bug is fixed. +# See https://github.com/deepmind/trfl/pull/28 +def epsilon_greedy(action_values, epsilon, legal_actions_mask=None): + """Computes an epsilon-greedy distribution over actions. + This returns a categorical distribution over a discrete action space. It is + assumed that the trailing dimension of `action_values` is of length A, i.e. + the number of actions. It is also assumed that actions are 0-indexed. + This policy does the following: + - With probability 1 - epsilon, take the action corresponding to the highest + action value, breaking ties uniformly at random. + - With probability epsilon, take an action uniformly at random. + Args: + action_values: A Tensor of action values with any rank >= 1 and dtype float. 
+ Shape can be flat ([A]), batched ([B, A]), a batch of sequences + ([T, B, A]), and so on. + epsilon: A scalar Tensor (or Python float) with value between 0 and 1. + legal_actions_mask: An optional one-hot tensor having the same shape and + dtypes as `action_values`, defining the legal actions: + legal_actions_mask[..., a] = 1 if a is legal, 0 otherwise. + If not provided, all actions will be considered legal and + `tf.ones_like(action_values)` is used. + Returns: + policy: tfp.distributions.Categorical distribution representing the policy. + """ + # Convert inputs to Tensors if they aren't already. + action_values = tf.convert_to_tensor(action_values) + epsilon = tf.convert_to_tensor(epsilon, dtype=action_values.dtype) + + # We compute the action space dynamically. + num_actions = tf.cast(tf.shape(action_values)[-1], action_values.dtype) + + if legal_actions_mask is None: + # Dithering action distribution. + dither_probs = 1 / num_actions * tf.ones_like(action_values) + # Greedy action distribution, breaking ties uniformly at random. + max_value = tf.reduce_max(action_values, axis=-1, keepdims=True) + greedy_probs = tf.cast(tf.equal(action_values, max_value), + action_values.dtype) + else: + legal_actions_mask = tf.convert_to_tensor(legal_actions_mask) + # Dithering action distribution. + dither_probs = 1 / tf.reduce_sum(legal_actions_mask, axis=-1, + keepdims=True) * legal_actions_mask + masked_action_values = tf.where(tf.equal(legal_actions_mask, 1), + action_values, + tf.fill(tf.shape(action_values), -np.inf)) + # Greedy action distribution, breaking ties uniformly at random. + max_value = tf.reduce_max(masked_action_values, axis=-1, keepdims=True) + greedy_probs = tf.cast( + tf.equal(action_values * legal_actions_mask, max_value), + action_values.dtype) + + greedy_probs /= tf.reduce_sum(greedy_probs, axis=-1, keepdims=True) + + # Epsilon-greedy action distribution. + probs = epsilon * dither_probs + (1 - epsilon) * greedy_probs + + # Make the policy object. 
+ policy = tfp.distributions.Categorical(probs=probs) + + return policy + + +# TODO Move to separate file. +class MaskedSequential(snt.Module): + """Sonnet Module similar to Sequential but masks illegal actions.""" + + def __init__(self, + layers: Iterable[Callable[..., Any]] = None, + epsilon: tf.Tensor = 0.0, + name: Optional[Text] = None): + super(MaskedSequential, self).__init__(name=name) + self._layers = list(layers) if layers is not None else [] + self._epsilon = epsilon + + def __call__(self, inputs, legal_actions_mask): + outputs = inputs + for mod in self._layers: + outputs = mod(outputs) + outputs = epsilon_greedy(outputs, self._epsilon, legal_actions_mask) + return outputs + +class DQN(agent.OpenSpielAgent): + """DQN agent. + + This implements a single-process DQN agent. This is a simple Q-learning + algorithm that inserts N-step transitions into a replay buffer, and + periodically updates its policy by sampling these transitions using + prioritization. + """ + + def __init__( + self, + environment_spec: specs.EnvironmentSpec, + extras_spec: open_spiel_specs.ExtrasSpec, + network: snt.Module, + player_id: int, + batch_size: int = 256, + prefetch_size: int = 4, + target_update_period: int = 100, + samples_per_insert: float = 32.0, + min_replay_size: int = 1000, + max_replay_size: int = 1000000, + importance_sampling_exponent: float = 0.2, + priority_exponent: float = 0.6, + n_step: int = 5, + epsilon: Optional[tf.Tensor] = None, + learning_rate: float = 1e-3, + discount: float = 0.99, + logger: loggers.Logger = None, + checkpoint: bool = True, + checkpoint_subpath: str = '~/acme/', + ): + """Initialize the agent. + + Args: + environment_spec: description of the actions, observations, etc. + network: the online Q network (the one being optimized) + batch_size: batch size for updates. + prefetch_size: size to prefetch from replay. + target_update_period: number of learner steps to perform before updating + the target networks. 
+ samples_per_insert: number of samples to take from replay for every insert + that is made. + min_replay_size: minimum replay size before updating. This and all + following arguments are related to dataset construction and will be + ignored if a dataset argument is passed. + max_replay_size: maximum replay size. + importance_sampling_exponent: power to which importance weights are raised + before normalizing. + priority_exponent: exponent used in prioritized sampling. + n_step: number of steps to squash into a single transition. + epsilon: probability of taking a random action; ignored if a policy + network is given. + learning_rate: learning rate for the q-network update. + discount: discount to use for TD updates. + logger: logger object to be used by learner. + checkpoint: boolean indicating whether to checkpoint the learner. + checkpoint_subpath: directory for the checkpoint. + """ + + # Create a replay server to add data to. This uses no limiter behavior in + # order to allow the Agent interface to handle it. + replay_table = reverb.Table( + name=adders.DEFAULT_PRIORITY_TABLE, + sampler=reverb.selectors.Prioritized(priority_exponent), + remover=reverb.selectors.Fifo(), + max_size=max_replay_size, + rate_limiter=reverb.rate_limiters.MinSize(1), + signature=adders.NStepTransitionAdder.signature(environment_spec, + extras_spec)) + self._server = reverb.Server([replay_table], port=None) + + # The adder is used to insert observations into replay. + address = f'localhost:{self._server.port}' + adder = adders.NStepTransitionAdder( + client=reverb.Client(address), + n_step=n_step, + discount=discount) + + # The dataset provides an interface to sample from replay. + replay_client = reverb.TFClient(address) + dataset = datasets.make_reverb_dataset( + server_address=address, + batch_size=batch_size, + prefetch_size=prefetch_size) + + # Use constant 0.05 epsilon greedy policy by default. 
+ if epsilon is None: + epsilon = tf.Variable(0.05, trainable=False) + policy_network = MaskedSequential([network], epsilon) + + # Create a target network. + target_network = copy.deepcopy(network) + + # Ensure that we create the variables before proceeding (maybe not needed). + tf2_utils.create_variables(network, [environment_spec.observations]) + tf2_utils.create_variables(target_network, [environment_spec.observations]) + + # Create the actor which defines how we take actions. + actor = actors.FeedForwardActor(policy_network, adder) + + # The learner updates the parameters (and initializes them). + learner = learning.DQNLearner( + network=network, + target_network=target_network, + discount=discount, + importance_sampling_exponent=importance_sampling_exponent, + learning_rate=learning_rate, + target_update_period=target_update_period, + dataset=dataset, + replay_client=replay_client, + logger=logger, + checkpoint=checkpoint) + + if checkpoint: + self._checkpointer = tf2_savers.Checkpointer( + directory=checkpoint_subpath, + objects_to_save=learner.state, + subdirectory='dqn_learner', + time_delta_minutes=60.) + else: + self._checkpointer = None + + super().__init__( + actor=actor, + learner=learner, + min_observations=max(batch_size, min_replay_size), + observations_per_step=float(batch_size) / samples_per_insert, + player_id=player_id) + + def update(self): + super().update() + if self._checkpointer is not None: + self._checkpointer.save() diff --git a/acme/open_spiel/agents/tf/dqn/learning.py b/acme/open_spiel/agents/tf/dqn/learning.py new file mode 100644 index 0000000000..d14c825904 --- /dev/null +++ b/acme/open_spiel/agents/tf/dqn/learning.py @@ -0,0 +1,211 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DQN learner implementation.""" + +import time +from typing import Dict, List + +import acme +from acme.adders import reverb as adders +from acme.tf import losses +from acme.tf import savers as tf2_savers +from acme.tf import utils as tf2_utils +from acme.utils import counting +from acme.utils import loggers +import numpy as np +import reverb +import sonnet as snt +import tensorflow as tf +import trfl + +ILLEGAL_ACTION_LOGITS_PENALTY = -1e9 + + +class DQNLearner(acme.Learner, tf2_savers.TFSaveable): + """DQN learner. + + This is the learning component of a DQN agent. It takes a dataset as input + and implements update functionality to learn from this dataset. Optionally + it takes a replay client as well to allow for updating of priorities. + """ + + def __init__( + self, + network: snt.Module, + target_network: snt.Module, + discount: float, + importance_sampling_exponent: float, + learning_rate: float, + target_update_period: int, + dataset: tf.data.Dataset, + huber_loss_parameter: float = 1., + replay_client: reverb.TFClient = None, + counter: counting.Counter = None, + logger: loggers.Logger = None, + checkpoint: bool = True, + ): + """Initializes the learner. + + Args: + network: the online Q network (the one being optimized) + target_network: the target Q critic (which lags behind the online net). + discount: discount to use for TD updates. + importance_sampling_exponent: power to which importance weights are raised + before normalizing. + learning_rate: learning rate for the q-network update. 
+ target_update_period: number of learner steps to perform before updating + the target networks. + dataset: dataset to learn from, whether fixed or from a replay buffer (see + `acme.datasets.reverb.make_dataset` documentation). + huber_loss_parameter: Quadratic-linear boundary for Huber loss. + replay_client: client to replay to allow for updating priorities. + counter: Counter object for (potentially distributed) counting. + logger: Logger object for writing logs to. + checkpoint: boolean indicating whether to checkpoint the learner. + """ + + # Internalise agent components (replay buffer, networks, optimizer). + # TODO(b/155086959): Fix type stubs and remove. + self._iterator = iter(dataset) # pytype: disable=wrong-arg-types + self._network = network + self._target_network = target_network + self._optimizer = snt.optimizers.Adam(learning_rate) + self._replay_client = replay_client + + # Internalise the hyperparameters. + self._discount = discount + self._target_update_period = target_update_period + self._importance_sampling_exponent = importance_sampling_exponent + self._huber_loss_parameter = huber_loss_parameter + + # Learner state. + self._variables: List[List[tf.Tensor]] = [network.trainable_variables] + self._num_steps = tf.Variable(0, dtype=tf.int32) + + # Internalise logging/counting objects. + self._counter = counter or counting.Counter() + self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) + + # Create a snapshotter object. + if checkpoint: + self._snapshotter = tf2_savers.Snapshotter( + objects_to_save={'network': network}, time_delta_minutes=60.) + else: + self._snapshotter = None + + # Do not record timestamps until after the first learning step is done. + # This is to avoid including the time it takes for actors to come online and + # fill the replay buffer. 
+ self._timestamp = None + + @tf.function + def _step(self) -> Dict[str, tf.Tensor]: + """Do a step of SGD and update the priorities.""" + + # Pull out the data needed for updates/priorities. + inputs = next(self._iterator) + o_tm1, a_tm1, r_t, d_t, o_t, e_t = inputs.data + keys, probs = inputs.info[:2] + + legal_actions_t = e_t.legal_actions + terminals_t = e_t.terminals + + with tf.GradientTape() as tape: + # Evaluate our networks. + q_tm1 = self._network(o_tm1) + q_t_value = self._target_network(o_t) + q_t_selector = self._network(o_t) + + # TODO Here we apply the legal actions mask. + illegal_actions = 1 - legal_actions_t + illegal_logits = illegal_actions * ILLEGAL_ACTION_LOGITS_PENALTY + q_t_value = tf.math.add(tf.stop_gradient(q_t_value), illegal_logits) + q_t_selector = tf.math.add(tf.stop_gradient(q_t_selector), illegal_logits) + + # The rewards and discounts have to have the same type as network values. + r_t = tf.cast(r_t, q_tm1.dtype) + # TODO Remove reward clipping in OpenSpiel DQN? + #r_t = tf.clip_by_value(r_t, -1., 1.) + d_t = tf.cast(d_t, q_tm1.dtype) * tf.cast(self._discount, q_tm1.dtype) + + # Compute the loss. + _, extra = trfl.double_qlearning(q_tm1, a_tm1, r_t, d_t, q_t_value, + q_t_selector) + loss = losses.huber(extra.td_error, self._huber_loss_parameter) + + # Get the importance weights. + importance_weights = 1. / probs # [B] + importance_weights **= self._importance_sampling_exponent + importance_weights /= tf.reduce_max(importance_weights) + + # Reweight. + loss *= tf.cast(importance_weights, loss.dtype) # [B] + loss = tf.reduce_mean(loss, axis=[0]) # [] + + # Do a step of SGD. + gradients = tape.gradient(loss, self._network.trainable_variables) + self._optimizer.apply(gradients, self._network.trainable_variables) + + # Update the priorities in the replay buffer. 
+ if self._replay_client: + priorities = tf.cast(tf.abs(extra.td_error), tf.float64) + self._replay_client.update_priorities( + table=adders.DEFAULT_PRIORITY_TABLE, keys=keys, priorities=priorities) + + # Periodically update the target network. + if tf.math.mod(self._num_steps, self._target_update_period) == 0: + for src, dest in zip(self._network.variables, + self._target_network.variables): + dest.assign(src) + self._num_steps.assign_add(1) + + # Report loss & statistics for logging. + fetches = { + 'loss': loss, + } + + return fetches + + def step(self): + # Do a batch of SGD. + result = self._step() + + # Compute elapsed time. + timestamp = time.time() + elapsed_time = timestamp - self._timestamp if self._timestamp else 0 + self._timestamp = timestamp + + # Update our counts and record it. + counts = self._counter.increment(steps=1, walltime=elapsed_time) + result.update(counts) + + # Snapshot and attempt to write logs. + if self._snapshotter is not None: + self._snapshotter.save() + self._logger.write(result) + + def get_variables(self, names: List[str]) -> List[np.ndarray]: + return tf2_utils.to_numpy(self._variables) + + @property + def state(self): + """Returns the stateful parts of the learner for checkpointing.""" + return { + 'network': self._network, + 'target_network': self._target_network, + 'optimizer': self._optimizer, + 'num_steps': self._num_steps + } diff --git a/acme/open_spiel/examples/run_dqn.py b/acme/open_spiel/examples/run_dqn.py new file mode 100644 index 0000000000..15387875c0 --- /dev/null +++ b/acme/open_spiel/examples/run_dqn.py @@ -0,0 +1,75 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example running DQN on OpenSpiel in a single process.""" + +from absl import app +from absl import flags + +import acme +from acme import specs +from acme import wrappers +from acme.open_spiel import open_spiel_environment_loop +from acme.open_spiel import open_spiel_specs +from acme.open_spiel import open_spiel_wrapper +from acme.open_spiel.agents.tf import dqn +from acme.tf import networks +from open_spiel.python import rl_environment +import sonnet as snt + +flags.DEFINE_string("game", "tic_tac_toe", "Name of the game") +flags.DEFINE_integer("num_players", None, "Number of players") + +FLAGS = flags.FLAGS + + +def main(_): + # Create an environment and grab the spec. + env_configs = {"players": FLAGS.num_players} if FLAGS.num_players else {} + raw_environment = rl_environment.Environment(FLAGS.game, **env_configs) + + environment = open_spiel_wrapper.OpenSpielWrapper(raw_environment) + environment = wrappers.SinglePrecisionWrapper(environment) + environment_spec = specs.make_environment_spec(environment) + extras_spec = open_spiel_specs.make_extras_spec(environment) + + network = snt.Sequential([ + snt.Flatten(), + snt.nets.MLP([50, 50, environment_spec.actions.num_values]) + ]) + + # Construct the agent. + agents = [] + + for i in range(environment.num_players): + agents.append( + dqn.DQN( + environment_spec=environment_spec, + extras_spec=extras_spec, + priority_exponent=0.0, # TODO Test priority_exponent. + discount=1.0, + n_step=1, # TODO Appear to be convergence issues when n > 1. 
+ epsilon=0.1, + network=network, + player_id=i)) + + # Run the environment loop. + loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop( + environment, agents) + loop.run(num_episodes=100000) # pytype: disable=attribute-error + + +if __name__ == '__main__': + app.run(main) diff --git a/acme/open_spiel/open_spiel_environment_loop.py b/acme/open_spiel/open_spiel_environment_loop.py new file mode 100644 index 0000000000..6aae60b1a8 --- /dev/null +++ b/acme/open_spiel/open_spiel_environment_loop.py @@ -0,0 +1,197 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An OpenSpiel agent-environment training loop.""" + +import operator +import time +from typing import List, Optional + +from acme import core +# Internal imports. +from acme.utils import counting +from acme.utils import loggers + +import dm_env +from dm_env import specs +import numpy as np +import tree + + +class OpenSpielEnvironmentLoop(core.Worker): + """A simple RL environment loop. + + This takes `Environment` and `Actor` instances and coordinates their + interaction. This can be used as: + + loop = EnvironmentLoop(environment, actor) + loop.run(num_episodes) + + A `Counter` instance can optionally be given in order to maintain counts + between different Acme components. If not given a local Counter will be + created to maintain counts between calls to the `run` method. 
+ + A `Logger` instance can also be passed in order to control the output of the + loop. If not given a platform-specific default logger will be used as defined + by utils.loggers.make_default_logger. A string `label` can be passed to easily + change the label associated with the default logger; this is ignored if a + `Logger` instance is given. + """ + + def __init__( + self, + environment: dm_env.Environment, + actors: List[core.Actor], + counter: counting.Counter = None, + logger: loggers.Logger = None, + label: str = 'open_spiel_environment_loop', + ): + # Internalize agent and environment. + self._environment = environment + self._actors = actors + self._counter = counter or counting.Counter() + self._logger = logger or loggers.make_default_logger(label) + + def run_episode(self, verbose=False) -> loggers.LoggingData: + """Run one episode. + + Each episode is a loop which interacts first with the environment to get an + observation and then give that observation to the agent in order to retrieve + an action. + + Returns: + An instance of `loggers.LoggingData`. + """ + # Reset any counts and start the environment. + start_time = time.time() + episode_steps = 0 + + # For evaluation, this keeps track of the total undiscounted reward + # for each player accumulated during the episode. + multiplayer_reward_spec = specs.BoundedArray( + (self._environment._game.num_players(),), + np.float32, + minimum=self._environment._game.min_utility(), + maximum=self._environment._game.max_utility()) + episode_return = tree.map_structure(_generate_zeros_from_spec, + multiplayer_reward_spec) + + timestep = self._environment.reset() + + # Make the first observation. + # TODO Note: OpenSpiel agents handle observe_first() internally. + for actor in self._actors: + actor.observe(None, next_timestep=timestep) + + # Run an episode. + while not timestep.last(): + # Generate an action from the agent's policy and step the environment. 
+ pid = timestep.observation["current_player"] + + if self._environment.is_turn_based: + action_list = [self._actors[pid].select_action(timestep.observation)] + else: + # TODO Test this on simultaneous move games. + agents_output = [agent.step(time_step) for agent in agents] + action_list = [ + actor.select_action(timestep.observation) for actor in self._actors + ] + + # TODO Delete or move to logger? + if verbose: + self._actors[pid].print_policy(timestep.observation) + print("Action: ", action_list[0]) + + timestep = self._environment.step(action_list) + + # TODO Delete or move to logger? + if verbose: + print("State:") + print(str(self._environment._state)) + + # Have the agent observe the timestep and let the actor update itself. + for actor in self._actors: + actor.observe(action_list, next_timestep=timestep) + + # Book-keeping. + episode_steps += 1 + + # Equivalent to: episode_return += timestep.reward + tree.map_structure(operator.iadd, episode_return, timestep.reward) + + # TODO Delete or move to logger? + if verbose: + print("Reward: ", timestep.reward) + + # Record counts. + counts = self._counter.increment(episodes=1, steps=episode_steps) + + # Collect the results and combine with counts. + steps_per_second = episode_steps / (time.time() - start_time) + result = { + 'episode_length': episode_steps, + 'episode_return': episode_return, + 'steps_per_second': steps_per_second, + } + result.update(counts) + return result + + def run(self, + num_episodes: Optional[int] = None, + num_steps: Optional[int] = None): + """Perform the run loop. + + Run the environment loop either for `num_episodes` episodes or for at + least `num_steps` steps (the last episode is always run until completion, + so the total number of steps may be slightly more than `num_steps`). + At least one of these two arguments has to be None. + + Upon termination of an episode a new episode will be started. 
If the number + of episodes and the number of steps are not given then this will interact + with the environment infinitely. + + Args: + num_episodes: number of episodes to run the loop for. + num_steps: minimal number of steps to run the loop for. + + Raises: + ValueError: If both 'num_episodes' and 'num_steps' are not None. + """ + + if not (num_episodes is None or num_steps is None): + raise ValueError('Either "num_episodes" or "num_steps" should be None.') + + def should_terminate(episode_count: int, step_count: int) -> bool: + return ((num_episodes is not None and episode_count >= num_episodes) or + (num_steps is not None and step_count >= num_steps)) + + episode_count, step_count = 0, 0 + while not should_terminate(episode_count, step_count): + # TODO Remove verbose? + if episode_count % 1000 == 0: + result = self.run_episode(verbose=True) + else: + result = self.run_episode(verbose=False) + episode_count += 1 + step_count += result['episode_length'] + # Log the given results. + self._logger.write(result) + + +def _generate_zeros_from_spec(spec: specs.Array) -> np.ndarray: + return np.zeros(spec.shape, spec.dtype) + + +# Internal class. diff --git a/acme/open_spiel/open_spiel_specs.py b/acme/open_spiel/open_spiel_specs.py new file mode 100644 index 0000000000..f73ee92919 --- /dev/null +++ b/acme/open_spiel/open_spiel_specs.py @@ -0,0 +1,40 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Objects which specify the extra information used by an OpenSpiel environment.""" + +from typing import Any, NamedTuple + +from acme.open_spiel import open_spiel_wrapper + + +# TODO Move elsewhere? Not actually a spec. +class Extras(NamedTuple): + """Extras used by a given environment.""" + legal_actions: Any + terminals: Any + + +class ExtrasSpec(NamedTuple): + """Full specification of the extras used by a given environment.""" + legal_actions: Any + terminals: Any + + +def make_extras_spec( + environment: open_spiel_wrapper.OpenSpielWrapper) -> ExtrasSpec: + """Returns an `ExtrasSpec` describing additional values used by OpenSpiel.""" + return ExtrasSpec(legal_actions=environment.legal_actions_spec(), + terminals=environment.terminals_spec()) diff --git a/acme/open_spiel/open_spiel_wrapper.py b/acme/open_spiel/open_spiel_wrapper.py new file mode 100644 index 0000000000..10281aa89e --- /dev/null +++ b/acme/open_spiel/open_spiel_wrapper.py @@ -0,0 +1,140 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Wraps an OpenSpiel RL environment to be used as a dm_env environment.""" + +from typing import List + +from acme import specs +from acme import types +import dm_env +import numpy as np +from open_spiel.python import rl_environment +import pyspiel + + +# TODO Wrap the underlying OpenSpiel game directly instead of OpenSpiel's +# rl_environment? +class OpenSpielWrapper(dm_env.Environment): + """Environment wrapper for OpenSpiel RL environments.""" + + # Note: we don't inherit from base.EnvironmentWrapper because that class + # assumes that the wrapped environment is a dm_env.Environment. + + def __init__(self, environment: rl_environment.Environment): + self._environment = environment + self._reset_next_step = True + assert environment._game.get_type( + ).dynamics == pyspiel.GameType.Dynamics.SEQUENTIAL, ( + "Currently only supports sequential games.") + + def reset(self) -> dm_env.TimeStep: + """Resets the episode.""" + self._reset_next_step = False + open_spiel_timestep = self._environment.reset() + observation = self._convert_observation(open_spiel_timestep.observations) + assert open_spiel_timestep.step_type == rl_environment.StepType.FIRST + return dm_env.restart(observation) + + def step(self, action: types.NestedArray) -> dm_env.TimeStep: + """Steps the environment.""" + if self._reset_next_step: + return self.reset() + + open_spiel_timestep = self._environment.step(action) + + if open_spiel_timestep.step_type == rl_environment.StepType.LAST: + self._reset_next_step = True + + observation = self._convert_observation(open_spiel_timestep.observations) + reward = np.asarray(open_spiel_timestep.rewards) + discount = np.asarray(open_spiel_timestep.discounts) + step_type = open_spiel_timestep.step_type + + if step_type == rl_environment.StepType.FIRST: + step_type = dm_env.StepType.FIRST + elif step_type == rl_environment.StepType.MID: + step_type = dm_env.StepType.MID + elif step_type == rl_environment.StepType.LAST: + step_type = dm_env.StepType.LAST + else: 
+ raise ValueError("Did not recognize OpenSpiel StepType") + + return dm_env.TimeStep(observation=observation, + reward=reward, + discount=discount, + step_type=step_type) + + # TODO Convert OpenSpiel observation so it's dm_env compatible. Dm_env + # timesteps allow for dicts and nesting, but they require the leaf elements + # be numpy arrays, whereas OpenSpiel timestep leaf elements are python lists. + # Also, the list of legal actions must be converted to a legal actions mask. + def _convert_observation( + self, open_spiel_observation: types.NestedArray) -> types.NestedArray: + observation = {"info_state": [], "legal_actions": [], "current_player": []} + info_state = [] + for player_info_state in open_spiel_observation["info_state"]: + info_state.append(np.asarray(player_info_state)) + observation["info_state"] = info_state + legal_actions = [] + for indicies in open_spiel_observation["legal_actions"]: + legals = np.zeros(self._environment._game.num_distinct_actions()) + legals[indicies] = 1 + legal_actions.append(legals) + observation["legal_actions"] = legal_actions + observation["current_player"] = self._environment._state.current_player() + return observation + + # TODO These specs describe the timestep that the actor and learner ultimately + # receive, not the timestep that gets passed to the OpenSpiel agent. See + # acme/open_spiel/agents/agent.py for more details. 
+ def observation_spec(self) -> types.NestedSpec: + if self._environment._use_observation: + return specs.Array((self._environment._game.observation_tensor_size(),), + np.float32) + else: + return specs.Array( + (self._environment._game.information_state_tensor_size(),), + np.float32) + + def action_spec(self) -> types.NestedSpec: + return specs.DiscreteArray(self._environment._game.num_distinct_actions()) + + def reward_spec(self) -> types.NestedSpec: + return specs.BoundedArray((), + np.float32, + minimum=self._game.min_utility(), + maximum=self._game.max_utility()) + + def discount_spec(self) -> types.NestedSpec: + return specs.BoundedArray((), np.float32, minimum=0, maximum=1.0) + + def legal_actions_spec(self) -> types.NestedSpec: + return specs.BoundedArray((self._environment._game.num_distinct_actions(),), + np.float32, + minimum=0, + maximum=1.0) + + def terminals_spec(self) -> types.NestedSpec: + return specs.BoundedArray((), np.float32, minimum=0, maximum=1.0) + + @property + def environment(self): + """Returns the wrapped environment.""" + return self._environment + + def __getattr__(self, name: str): + # Expose any other attributes of the underlying environment. + return getattr(self._environment, name) From a02237ec9a203c9a7e24a690ae53500d0d5aa18c Mon Sep 17 00:00:00 2001 From: John Schultz Date: Sat, 21 Nov 2020 18:04:43 +0000 Subject: [PATCH 02/39] OpenSpiel interface v2. 
--- acme/agents/tf/dqn/agent.py | 10 +- .../tf => environment_loops}/__init__.py | 4 + .../open_spiel_environment_loop.py | 102 +++++-- acme/open_spiel/__init__.py | 22 -- acme/open_spiel/agents/agent.py | 132 --------- acme/open_spiel/agents/tf/actors.py | 101 ------- acme/open_spiel/agents/tf/dqn/README.md | 21 -- acme/open_spiel/agents/tf/dqn/__init__.py | 19 -- acme/open_spiel/agents/tf/dqn/agent.py | 252 ------------------ acme/open_spiel/agents/tf/dqn/learning.py | 211 --------------- acme/open_spiel/open_spiel_specs.py | 40 --- acme/tf/networks/__init__.py | 2 + acme/tf/networks/legal_actions.py | 117 ++++++++ acme/wrappers/__init__.py | 1 + .../open_spiel_wrapper.py | 95 +++---- .../open_spiel}/run_dqn.py | 29 +- 16 files changed, 269 insertions(+), 889 deletions(-) rename acme/{open_spiel/agents/tf => environment_loops}/__init__.py (82%) rename acme/{open_spiel => environment_loops}/open_spiel_environment_loop.py (61%) delete mode 100644 acme/open_spiel/__init__.py delete mode 100644 acme/open_spiel/agents/agent.py delete mode 100644 acme/open_spiel/agents/tf/actors.py delete mode 100644 acme/open_spiel/agents/tf/dqn/README.md delete mode 100644 acme/open_spiel/agents/tf/dqn/__init__.py delete mode 100644 acme/open_spiel/agents/tf/dqn/agent.py delete mode 100644 acme/open_spiel/agents/tf/dqn/learning.py delete mode 100644 acme/open_spiel/open_spiel_specs.py create mode 100644 acme/tf/networks/legal_actions.py rename acme/{open_spiel => wrappers}/open_spiel_wrapper.py (59%) rename {acme/open_spiel/examples => examples/open_spiel}/run_dqn.py (71%) diff --git a/acme/agents/tf/dqn/agent.py b/acme/agents/tf/dqn/agent.py index c8a00d74e6..f5bed82e7c 100644 --- a/acme/agents/tf/dqn/agent.py +++ b/acme/agents/tf/dqn/agent.py @@ -46,6 +46,7 @@ def __init__( self, environment_spec: specs.EnvironmentSpec, network: snt.Module, + policy_network: Optional[snt.Module] = None, batch_size: int = 256, prefetch_size: int = 4, target_update_period: int = 100, @@ -118,10 
+119,11 @@ def __init__( # Use constant 0.05 epsilon greedy policy by default. if epsilon is None: epsilon = tf.Variable(0.05, trainable=False) - policy_network = snt.Sequential([ - network, - lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(), - ]) + if policy_network is None: + policy_network = snt.Sequential([ + network, + lambda q: trfl.epsilon_greedy(q, epsilon=epsilon).sample(), + ]) # Create a target network. target_network = copy.deepcopy(network) diff --git a/acme/open_spiel/agents/tf/__init__.py b/acme/environment_loops/__init__.py similarity index 82% rename from acme/open_spiel/agents/tf/__init__.py rename to acme/environment_loops/__init__.py index de867df849..08747b703a 100644 --- a/acme/open_spiel/agents/tf/__init__.py +++ b/acme/environment_loops/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Specialized environment loops.""" + +from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop diff --git a/acme/open_spiel/open_spiel_environment_loop.py b/acme/environment_loops/open_spiel_environment_loop.py similarity index 61% rename from acme/open_spiel/open_spiel_environment_loop.py rename to acme/environment_loops/open_spiel_environment_loop.py index 6aae60b1a8..08e225401a 100644 --- a/acme/open_spiel/open_spiel_environment_loop.py +++ b/acme/environment_loops/open_spiel_environment_loop.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""An OpenSpiel agent-environment training loop.""" +"""An OpenSpiel multi-agent/environment training loop.""" import operator import time @@ -21,22 +21,24 @@ from acme import core # Internal imports. 
+from acme.tf import utils as tf2_utils from acme.utils import counting from acme.utils import loggers - import dm_env from dm_env import specs import numpy as np +import pyspiel +import tensorflow as tf import tree class OpenSpielEnvironmentLoop(core.Worker): - """A simple RL environment loop. + """An OpenSpiel RL environment loop. - This takes `Environment` and `Actor` instances and coordinates their - interaction. This can be used as: + This takes `Environment` and list of `Actor` instances and coordinates their + interaction. Agents are updated if `should_update=True`. This can be used as: - loop = EnvironmentLoop(environment, actor) + loop = EnvironmentLoop(environment, actors) loop.run(num_episodes) A `Counter` instance can optionally be given in order to maintain counts @@ -56,6 +58,7 @@ def __init__( actors: List[core.Actor], counter: counting.Counter = None, logger: loggers.Logger = None, + should_update: bool = True, label: str = 'open_spiel_environment_loop', ): # Internalize agent and environment. @@ -63,8 +66,65 @@ def __init__( self._actors = actors self._counter = counter or counting.Counter() self._logger = logger or loggers.make_default_logger(label) - - def run_episode(self, verbose=False) -> loggers.LoggingData: + self._should_update = should_update + + # Track information necessary to coordinate updates among multiple actors. + self._observed_first = [False] * len(self._actors) + self._prev_actions = [None] * len(self._actors) + + def _send_observation(self, timestep: dm_env.TimeStep, player: int): + # If terminal all actors must update + if player == pyspiel.PlayerId.TERMINAL: + for player_id in range(len(self._actors)): + # Note: we must account for situations where the first observation + # is a terminal state, e.g. if an opponent folds in poker before we get + # to act. 
+ if self._observed_first[player_id]: + player_timestep = self._get_player_timestep(timestep, player_id) + self._actors[player_id].observe(self._prev_actions[player_id], + player_timestep) + if self._should_update: + self._actors[player_id].update() + self._observed_first = [False] * len(self._actors) + self._prev_actions = [None] * len(self._actors) + else: + if not self._observed_first[player]: + player_timestep = dm_env.TimeStep( + observation=timestep.observation[player], + reward=None, + discount=None, + step_type=dm_env.StepType.FIRST) + self._actors[player].observe_first(player_timestep) + self._observed_first[player] = True + else: + player_timestep = self._get_player_timestep(timestep, player) + self._actors[player].observe(self._prev_actions[player], + player_timestep) + if self._should_update: + self._actors[player].update() + + def _get_action(self, timestep: dm_env.TimeStep, player: int) -> int: + self._prev_actions[player] = self._actors[player].select_action( + timestep.observation[player]) + return self._prev_actions[player] + + def _get_player_timestep(self, timestep: dm_env.TimeStep, + player: int) -> dm_env.TimeStep: + return dm_env.TimeStep(observation=timestep.observation[player], + reward=timestep.reward[player], + discount=timestep.discount[player], + step_type=timestep.step_type) + + # TODO Remove? Currently used for debugging. + def _print_policy(self, timestep: dm_env.TimeStep, player: int): + batched_observation = tf2_utils.add_batch_dim(timestep.observation[player]) + policy = tf.squeeze( + self._actors[player]._learner._network(batched_observation)) + tf.print(policy, summarize=-1) + tf.print("Greedy action: ", tf.math.argmax(policy)) + + # TODO Remove verbose or add to logger? + def run_episode(self, verbose: bool = False) -> loggers.LoggingData: """Run one episode. 
Each episode is a loop which interacts first with the environment to get an @@ -91,39 +151,31 @@ def run_episode(self, verbose=False) -> loggers.LoggingData: timestep = self._environment.reset() # Make the first observation. - # TODO Note: OpenSpiel agents handle observe_first() internally. - for actor in self._actors: - actor.observe(None, next_timestep=timestep) + self._send_observation(timestep, self._environment.current_player) # Run an episode. while not timestep.last(): # Generate an action from the agent's policy and step the environment. - pid = timestep.observation["current_player"] - if self._environment.is_turn_based: - action_list = [self._actors[pid].select_action(timestep.observation)] - else: - # TODO Test this on simultaneous move games. - agents_output = [agent.step(time_step) for agent in agents] action_list = [ - actor.select_action(timestep.observation) for actor in self._actors + self._get_action(timestep, self._environment.current_player) ] + else: + # TODO Support simultaneous move games. + raise ValueError("Currently only supports sequential games.") - # TODO Delete or move to logger? if verbose: - self._actors[pid].print_policy(timestep.observation) - print("Action: ", action_list[0]) + self._print_policy(timestep, self._environment.current_player) + print("Selected action: ", action_list[0]) timestep = self._environment.step(action_list) - # TODO Delete or move to logger? if verbose: print("State:") print(str(self._environment._state)) # Have the agent observe the timestep and let the actor update itself. - for actor in self._actors: - actor.observe(action_list, next_timestep=timestep) + self._send_observation(timestep, self._environment.current_player) # Book-keeping. episode_steps += 1 @@ -131,7 +183,6 @@ def run_episode(self, verbose=False) -> loggers.LoggingData: # Equivalent to: episode_return += timestep.reward tree.map_structure(operator.iadd, episode_return, timestep.reward) - # TODO Delete or move to logger? 
if verbose: print("Reward: ", timestep.reward) @@ -179,7 +230,6 @@ def should_terminate(episode_count: int, step_count: int) -> bool: episode_count, step_count = 0, 0 while not should_terminate(episode_count, step_count): - # TODO Remove verbose? if episode_count % 1000 == 0: result = self.run_episode(verbose=True) else: diff --git a/acme/open_spiel/__init__.py b/acme/open_spiel/__init__.py deleted file mode 100644 index ac6dc86ca2..0000000000 --- a/acme/open_spiel/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# python3 -# Copyright 2018 DeepMind Technologies Limited. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Expose specs and types modules. -from acme.open_spiel import open_spiel_specs - -# Expose the environment loop. -from acme.open_spiel import open_spiel_environment_loop - -# Acme loves OpenSpiel. diff --git a/acme/open_spiel/agents/agent.py b/acme/open_spiel/agents/agent.py deleted file mode 100644 index 078ebf38b9..0000000000 --- a/acme/open_spiel/agents/agent.py +++ /dev/null @@ -1,132 +0,0 @@ -# python3 -# Copyright 2018 DeepMind Technologies Limited. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""OpenSpiel agent interface.""" - -from typing import List, Tuple - -from acme import core -from acme import types -from acme import specs -from acme.agents import agent -from acme.open_spiel import open_spiel_specs -from acme.tf import utils as tf2_utils -# Internal imports. - -import dm_env -import numpy as np -from open_spiel.python import rl_environment -import pyspiel -import tensorflow as tf - - -class OpenSpielAgent(agent.Agent): - """Agent class which combines acting and learning.""" - - def __init__(self, - actor: core.Actor, - learner: core.Learner, - min_observations: int, - observations_per_step: float, - player_id: int, - should_update: bool = True): - self._player_id = player_id - self._should_update = should_update - self._observed_first = False - self._prev_action = None - super().__init__(actor=actor, - learner=learner, - min_observations=min_observations, - observations_per_step=observations_per_step) - - def set_update(self, should_update: bool): - self._should_update = should_update - - def select_action(self, observation: types.NestedArray) -> types.NestedArray: - current_player = observation["current_player"] - assert current_player == self._player_id - player_observation = observation["info_state"][current_player] - legal_actions = observation["legal_actions"][current_player] - self._prev_action = self._actor.select_action(player_observation, - legal_actions) - return self._prev_action - - # TODO Eventually remove? Currently used for debugging. 
- def print_policy(self, observation: types.NestedArray) -> types.NestedArray: - current_player = observation["current_player"] - assert current_player == self._player_id - player_observation = observation["info_state"][current_player] - legal_actions = observation["legal_actions"][current_player] - - batched_observation = tf2_utils.add_batch_dim(player_observation) - policy = self._actor._policy_network(batched_observation, legal_actions) - tf.print("Policy: ", policy.probs, summarize=-1) - - def observe_first(self, timestep: dm_env.TimeStep): - current_player = timestep.observation["current_player"] - assert current_player == self._player_id - timestep = dm_env.TimeStep( - observation=timestep.observation["info_state"][current_player], - reward=None, - discount=None, - step_type=dm_env.StepType.FIRST) - self._actor.observe_first(timestep) - self._observed_first = True - - def observe( - self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - ): - current_player = next_timestep.observation["current_player"] - if current_player == self._player_id: - if not self._observed_first: - self.observe_first(next_timestep) - else: - next_timestep, extras = self._convert_timestep(next_timestep) - self._num_observations += 1 - self._actor.observe(self._prev_action, next_timestep, extras) - if self._should_update: - super().update() - # TODO Note: we must account for situations where the first obs is a - # terminal state, e.g. if an opponent folds in poker before we get to act. - elif current_player == pyspiel.PlayerId.TERMINAL and self._observed_first: - next_timestep, extras = self._convert_timestep(next_timestep) - self._num_observations += 1 - self._actor.observe(self._prev_action, next_timestep, extras) - self._observed_first = False - self._prev_action = None - if self._should_update: - super().update() - else: - # TODO We ignore observations not relevant to this agent. 
- pass - - # TODO In order to avoid bookkeeping in the environment loop, OpenSpiel agents - # receive full timesteps that contain information for all agents. Here we - # extract the information specific to this agent. - def _convert_timestep( - self, timestep: dm_env.TimeStep - ) -> Tuple[dm_env.TimeStep, open_spiel_specs.Extras]: - legal_actions = timestep.observation["legal_actions"][self._player_id] - terminal = np.array(timestep.last(), dtype=np.float32) - extras = open_spiel_specs.Extras(legal_actions=legal_actions, - terminals=terminal) - converted_timestep = dm_env.TimeStep( - observation=timestep.observation["info_state"][self._player_id], - reward=timestep.reward[self._player_id], - discount=timestep.discount[self._player_id], - step_type=timestep.step_type) - return converted_timestep, extras diff --git a/acme/open_spiel/agents/tf/actors.py b/acme/open_spiel/agents/tf/actors.py deleted file mode 100644 index 783ded3605..0000000000 --- a/acme/open_spiel/agents/tf/actors.py +++ /dev/null @@ -1,101 +0,0 @@ -# python3 -# Copyright 2018 DeepMind Technologies Limited. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Generic actor implementation, using TensorFlow and Sonnet.""" - -from typing import Optional, Tuple - -from acme import adders -from acme import core -from acme import types -# Internal imports. 
-from acme.tf import utils as tf2_utils -from acme.tf import variable_utils as tf2_variable_utils - -import dm_env -import sonnet as snt -import tensorflow as tf -import tensorflow_probability as tfp - -tfd = tfp.distributions - - -class FeedForwardActor(core.Actor): - """A feed-forward actor. - - An actor based on a feed-forward policy which takes non-batched observations - and outputs non-batched actions. It also allows adding experiences to replay - and updating the weights from the policy on the learner. - """ - - def __init__( - self, - policy_network: snt.Module, - adder: Optional[adders.Adder] = None, - variable_client: Optional[tf2_variable_utils.VariableClient] = None, - ): - """Initializes the actor. - - Args: - policy_network: the policy to run. - adder: the adder object to which allows to add experiences to a - dataset/replay buffer. - variable_client: object which allows to copy weights from the learner copy - of the policy to the actor copy (in case they are separate). - """ - - # Store these for later use. - self._adder = adder - self._variable_client = variable_client - self._policy_network = policy_network - - @tf.function - def _policy(self, observation: types.NestedTensor, - legal_actions: types.NestedTensor) -> types.NestedTensor: - # Add a dummy batch dimension and as a side effect convert numpy to TF. - batched_observation = tf2_utils.add_batch_dim(observation) - - # Compute the policy, conditioned on the observation. - policy = self._policy_network(batched_observation, legal_actions) - - # Sample from the policy if it is stochastic. - action = policy.sample() if isinstance(policy, tfd.Distribution) else policy - - return action - - def select_action( - self, - observation: types.NestedArray, - legal_actions: Optional[types.NestedTensor] = None) -> types.NestedArray: - # Pass the observation through the policy network. - action = self._policy(observation, legal_actions) - - # Return a numpy array with squeezed out batch dimension. 
- return tf2_utils.to_numpy_squeeze(action) - - def observe_first(self, timestep: dm_env.TimeStep): - if self._adder: - self._adder.add_first(timestep) - - def observe(self, - action: types.NestedArray, - next_timestep: dm_env.TimeStep, - extras: types.NestedArray = ()): - if self._adder: - self._adder.add(action, next_timestep, extras) - - def update(self, wait: bool = False): - if self._variable_client: - self._variable_client.update(wait) diff --git a/acme/open_spiel/agents/tf/dqn/README.md b/acme/open_spiel/agents/tf/dqn/README.md deleted file mode 100644 index 8303a3d51c..0000000000 --- a/acme/open_spiel/agents/tf/dqn/README.md +++ /dev/null @@ -1,21 +0,0 @@ -# Deep Q-Networks (DQN) - -This folder contains an implementation of the DQN algorithm -([Mnih et al., 2013], [Mnih et al., 2015]), with extras bells & whistles, -similar to Rainbow DQN ([Hessel et al., 2017]). - -- Q-learning with neural network function approximation. The loss is - given by the Huber loss applied to the temporal difference error. -- Target Q' network updated periodically ([Mnih et al., 2015]). -- N-step bootstrapping ([Sutton & Barto, 2018]). -- Double Q-learning ([van Hasselt et al., 2015]). -- Prioritized experience replay ([Schaul et al., 2015]). - - -[Mnih et al., 2013]: https://arxiv.org/abs/1312.5602 -[Mnih et al., 2015]: https://www.nature.com/articles/nature14236 -[van Hasselt et al., 2015]: https://arxiv.org/abs/1509.06461 -[Schaul et al., 2015]: https://arxiv.org/abs/1511.05952 -[Hessel et al., 2017]: https://arxiv.org/abs/1710.02298 -[Horgan et al., 2018]: https://arxiv.org/abs/1803.00933 -[Sutton & Barto, 2018]: http://incompleteideas.net/book/the-book.html diff --git a/acme/open_spiel/agents/tf/dqn/__init__.py b/acme/open_spiel/agents/tf/dqn/__init__.py deleted file mode 100644 index e995bc506c..0000000000 --- a/acme/open_spiel/agents/tf/dqn/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# python3 -# Copyright 2018 DeepMind Technologies Limited. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Implementation of a deep Q-networks (DQN) agent.""" - -from acme.open_spiel.agents.tf.dqn.agent import DQN -from acme.open_spiel.agents.tf.dqn.learning import DQNLearner diff --git a/acme/open_spiel/agents/tf/dqn/agent.py b/acme/open_spiel/agents/tf/dqn/agent.py deleted file mode 100644 index fe2151fcee..0000000000 --- a/acme/open_spiel/agents/tf/dqn/agent.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright 2018 DeepMind Technologies Limited. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""DQN agent implementation.""" - -import copy -from typing import Any, Callable, Iterable, Optional, Text - -from acme import datasets -from acme import specs -from acme.adders import reverb as adders -from acme.open_spiel import open_spiel_specs -from acme.open_spiel.agents import agent -from acme.open_spiel.agents.tf import actors -from acme.open_spiel.agents.tf.dqn import learning -from acme.tf import savers as tf2_savers -from acme.tf import utils as tf2_utils -from acme.utils import loggers -import numpy as np -import reverb -import sonnet as snt -import tensorflow as tf -import tensorflow_probability as tfp -import trfl - - -# TODO Import this from trfl once legal_actions_mask bug is fixed. -# See https://github.com/deepmind/trfl/pull/28 -def epsilon_greedy(action_values, epsilon, legal_actions_mask=None): - """Computes an epsilon-greedy distribution over actions. - This returns a categorical distribution over a discrete action space. It is - assumed that the trailing dimension of `action_values` is of length A, i.e. - the number of actions. It is also assumed that actions are 0-indexed. - This policy does the following: - - With probability 1 - epsilon, take the action corresponding to the highest - action value, breaking ties uniformly at random. - - With probability epsilon, take an action uniformly at random. - Args: - action_values: A Tensor of action values with any rank >= 1 and dtype float. - Shape can be flat ([A]), batched ([B, A]), a batch of sequences - ([T, B, A]), and so on. - epsilon: A scalar Tensor (or Python float) with value between 0 and 1. - legal_actions_mask: An optional one-hot tensor having the shame shape and - dtypes as `action_values`, defining the legal actions: - legal_actions_mask[..., a] = 1 if a is legal, 0 otherwise. - If not provided, all actions will be considered legal and - `tf.ones_like(action_values)`. - Returns: - policy: tfp.distributions.Categorical distribution representing the policy. 
- """ - # Convert inputs to Tensors if they aren't already. - action_values = tf.convert_to_tensor(action_values) - epsilon = tf.convert_to_tensor(epsilon, dtype=action_values.dtype) - - # We compute the action space dynamically. - num_actions = tf.cast(tf.shape(action_values)[-1], action_values.dtype) - - if legal_actions_mask is None: - # Dithering action distribution. - dither_probs = 1 / num_actions * tf.ones_like(action_values) - # Greedy action distribution, breaking ties uniformly at random. - max_value = tf.reduce_max(action_values, axis=-1, keepdims=True) - greedy_probs = tf.cast(tf.equal(action_values, max_value), - action_values.dtype) - else: - legal_actions_mask = tf.convert_to_tensor(legal_actions_mask) - # Dithering action distribution. - dither_probs = 1 / tf.reduce_sum(legal_actions_mask, axis=-1, - keepdims=True) * legal_actions_mask - masked_action_values = tf.where(tf.equal(legal_actions_mask, 1), - action_values, - tf.fill(tf.shape(action_values), -np.inf)) - # Greedy action distribution, breaking ties uniformly at random. - max_value = tf.reduce_max(masked_action_values, axis=-1, keepdims=True) - greedy_probs = tf.cast( - tf.equal(action_values * legal_actions_mask, max_value), - action_values.dtype) - - greedy_probs /= tf.reduce_sum(greedy_probs, axis=-1, keepdims=True) - - # Epsilon-greedy action distribution. - probs = epsilon * dither_probs + (1 - epsilon) * greedy_probs - - # Make the policy object. - policy = tfp.distributions.Categorical(probs=probs) - - return policy - - -# TODO Move to separate file. 
-class MaskedSequential(snt.Module): - """Sonnet Module similar to Sequential but masks illegal actions.""" - - def __init__(self, - layers: Iterable[Callable[..., Any]] = None, - epsilon: tf.Tensor = 0.0, - name: Optional[Text] = None): - super(MaskedSequential, self).__init__(name=name) - self._layers = list(layers) if layers is not None else [] - self._epsilon = epsilon - - def __call__(self, inputs, legal_actions_mask): - outputs = inputs - for mod in self._layers: - outputs = mod(outputs) - outputs = epsilon_greedy(outputs, self._epsilon, legal_actions_mask) - return outputs - -class DQN(agent.OpenSpielAgent): - """DQN agent. - - This implements a single-process DQN agent. This is a simple Q-learning - algorithm that inserts N-step transitions into a replay buffer, and - periodically updates its policy by sampling these transitions using - prioritization. - """ - - def __init__( - self, - environment_spec: specs.EnvironmentSpec, - extras_spec: open_spiel_specs.ExtrasSpec, - network: snt.Module, - player_id: int, - batch_size: int = 256, - prefetch_size: int = 4, - target_update_period: int = 100, - samples_per_insert: float = 32.0, - min_replay_size: int = 1000, - max_replay_size: int = 1000000, - importance_sampling_exponent: float = 0.2, - priority_exponent: float = 0.6, - n_step: int = 5, - epsilon: Optional[tf.Tensor] = None, - learning_rate: float = 1e-3, - discount: float = 0.99, - logger: loggers.Logger = None, - checkpoint: bool = True, - checkpoint_subpath: str = '~/acme/', - ): - """Initialize the agent. - - Args: - environment_spec: description of the actions, observations, etc. - network: the online Q network (the one being optimized) - batch_size: batch size for updates. - prefetch_size: size to prefetch from replay. - target_update_period: number of learner steps to perform before updating - the target networks. - samples_per_insert: number of samples to take from replay for every insert - that is made. 
- min_replay_size: minimum replay size before updating. This and all - following arguments are related to dataset construction and will be - ignored if a dataset argument is passed. - max_replay_size: maximum replay size. - importance_sampling_exponent: power to which importance weights are raised - before normalizing. - priority_exponent: exponent used in prioritized sampling. - n_step: number of steps to squash into a single transition. - epsilon: probability of taking a random action; ignored if a policy - network is given. - learning_rate: learning rate for the q-network update. - discount: discount to use for TD updates. - logger: logger object to be used by learner. - checkpoint: boolean indicating whether to checkpoint the learner. - checkpoint_subpath: directory for the checkpoint. - """ - - # Create a replay server to add data to. This uses no limiter behavior in - # order to allow the Agent interface to handle it. - replay_table = reverb.Table( - name=adders.DEFAULT_PRIORITY_TABLE, - sampler=reverb.selectors.Prioritized(priority_exponent), - remover=reverb.selectors.Fifo(), - max_size=max_replay_size, - rate_limiter=reverb.rate_limiters.MinSize(1), - signature=adders.NStepTransitionAdder.signature(environment_spec, - extras_spec)) - self._server = reverb.Server([replay_table], port=None) - - # The adder is used to insert observations into replay. - address = f'localhost:{self._server.port}' - adder = adders.NStepTransitionAdder( - client=reverb.Client(address), - n_step=n_step, - discount=discount) - - # The dataset provides an interface to sample from replay. - replay_client = reverb.TFClient(address) - dataset = datasets.make_reverb_dataset( - server_address=address, - batch_size=batch_size, - prefetch_size=prefetch_size) - - # Use constant 0.05 epsilon greedy policy by default. - if epsilon is None: - epsilon = tf.Variable(0.05, trainable=False) - policy_network = MaskedSequential([network], epsilon) - - # Create a target network. 
- target_network = copy.deepcopy(network) - - # Ensure that we create the variables before proceeding (maybe not needed). - tf2_utils.create_variables(network, [environment_spec.observations]) - tf2_utils.create_variables(target_network, [environment_spec.observations]) - - # Create the actor which defines how we take actions. - actor = actors.FeedForwardActor(policy_network, adder) - - # The learner updates the parameters (and initializes them). - learner = learning.DQNLearner( - network=network, - target_network=target_network, - discount=discount, - importance_sampling_exponent=importance_sampling_exponent, - learning_rate=learning_rate, - target_update_period=target_update_period, - dataset=dataset, - replay_client=replay_client, - logger=logger, - checkpoint=checkpoint) - - if checkpoint: - self._checkpointer = tf2_savers.Checkpointer( - directory=checkpoint_subpath, - objects_to_save=learner.state, - subdirectory='dqn_learner', - time_delta_minutes=60.) - else: - self._checkpointer = None - - super().__init__( - actor=actor, - learner=learner, - min_observations=max(batch_size, min_replay_size), - observations_per_step=float(batch_size) / samples_per_insert, - player_id=player_id) - - def update(self): - super().update() - if self._checkpointer is not None: - self._checkpointer.save() diff --git a/acme/open_spiel/agents/tf/dqn/learning.py b/acme/open_spiel/agents/tf/dqn/learning.py deleted file mode 100644 index d14c825904..0000000000 --- a/acme/open_spiel/agents/tf/dqn/learning.py +++ /dev/null @@ -1,211 +0,0 @@ -# python3 -# Copyright 2018 DeepMind Technologies Limited. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DQN learner implementation.""" - -import time -from typing import Dict, List - -import acme -from acme.adders import reverb as adders -from acme.tf import losses -from acme.tf import savers as tf2_savers -from acme.tf import utils as tf2_utils -from acme.utils import counting -from acme.utils import loggers -import numpy as np -import reverb -import sonnet as snt -import tensorflow as tf -import trfl - -ILLEGAL_ACTION_LOGITS_PENALTY = -1e9 - - -class DQNLearner(acme.Learner, tf2_savers.TFSaveable): - """DQN learner. - - This is the learning component of a DQN agent. It takes a dataset as input - and implements update functionality to learn from this dataset. Optionally - it takes a replay client as well to allow for updating of priorities. - """ - - def __init__( - self, - network: snt.Module, - target_network: snt.Module, - discount: float, - importance_sampling_exponent: float, - learning_rate: float, - target_update_period: int, - dataset: tf.data.Dataset, - huber_loss_parameter: float = 1., - replay_client: reverb.TFClient = None, - counter: counting.Counter = None, - logger: loggers.Logger = None, - checkpoint: bool = True, - ): - """Initializes the learner. - - Args: - network: the online Q network (the one being optimized) - target_network: the target Q critic (which lags behind the online net). - discount: discount to use for TD updates. - importance_sampling_exponent: power to which importance weights are raised - before normalizing. - learning_rate: learning rate for the q-network update. 
- target_update_period: number of learner steps to perform before updating - the target networks. - dataset: dataset to learn from, whether fixed or from a replay buffer (see - `acme.datasets.reverb.make_dataset` documentation). - huber_loss_parameter: Quadratic-linear boundary for Huber loss. - replay_client: client to replay to allow for updating priorities. - counter: Counter object for (potentially distributed) counting. - logger: Logger object for writing logs to. - checkpoint: boolean indicating whether to checkpoint the learner. - """ - - # Internalise agent components (replay buffer, networks, optimizer). - # TODO(b/155086959): Fix type stubs and remove. - self._iterator = iter(dataset) # pytype: disable=wrong-arg-types - self._network = network - self._target_network = target_network - self._optimizer = snt.optimizers.Adam(learning_rate) - self._replay_client = replay_client - - # Internalise the hyperparameters. - self._discount = discount - self._target_update_period = target_update_period - self._importance_sampling_exponent = importance_sampling_exponent - self._huber_loss_parameter = huber_loss_parameter - - # Learner state. - self._variables: List[List[tf.Tensor]] = [network.trainable_variables] - self._num_steps = tf.Variable(0, dtype=tf.int32) - - # Internalise logging/counting objects. - self._counter = counter or counting.Counter() - self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.) - - # Create a snapshotter object. - if checkpoint: - self._snapshotter = tf2_savers.Snapshotter( - objects_to_save={'network': network}, time_delta_minutes=60.) - else: - self._snapshotter = None - - # Do not record timestamps until after the first learning step is done. - # This is to avoid including the time it takes for actors to come online and - # fill the replay buffer. 
- self._timestamp = None - - @tf.function - def _step(self) -> Dict[str, tf.Tensor]: - """Do a step of SGD and update the priorities.""" - - # Pull out the data needed for updates/priorities. - inputs = next(self._iterator) - o_tm1, a_tm1, r_t, d_t, o_t, e_t = inputs.data - keys, probs = inputs.info[:2] - - legal_actions_t = e_t.legal_actions - terminals_t = e_t.terminals - - with tf.GradientTape() as tape: - # Evaluate our networks. - q_tm1 = self._network(o_tm1) - q_t_value = self._target_network(o_t) - q_t_selector = self._network(o_t) - - # TODO Here we apply the legal actions mask. - illegal_actions = 1 - legal_actions_t - illegal_logits = illegal_actions * ILLEGAL_ACTION_LOGITS_PENALTY - q_t_value = tf.math.add(tf.stop_gradient(q_t_value), illegal_logits) - q_t_selector = tf.math.add(tf.stop_gradient(q_t_selector), illegal_logits) - - # The rewards and discounts have to have the same type as network values. - r_t = tf.cast(r_t, q_tm1.dtype) - # TODO Remove reward clipping in OpenSpiel DQN? - #r_t = tf.clip_by_value(r_t, -1., 1.) - d_t = tf.cast(d_t, q_tm1.dtype) * tf.cast(self._discount, q_tm1.dtype) - - # Compute the loss. - _, extra = trfl.double_qlearning(q_tm1, a_tm1, r_t, d_t, q_t_value, - q_t_selector) - loss = losses.huber(extra.td_error, self._huber_loss_parameter) - - # Get the importance weights. - importance_weights = 1. / probs # [B] - importance_weights **= self._importance_sampling_exponent - importance_weights /= tf.reduce_max(importance_weights) - - # Reweight. - loss *= tf.cast(importance_weights, loss.dtype) # [B] - loss = tf.reduce_mean(loss, axis=[0]) # [] - - # Do a step of SGD. - gradients = tape.gradient(loss, self._network.trainable_variables) - self._optimizer.apply(gradients, self._network.trainable_variables) - - # Update the priorities in the replay buffer. 
- if self._replay_client: - priorities = tf.cast(tf.abs(extra.td_error), tf.float64) - self._replay_client.update_priorities( - table=adders.DEFAULT_PRIORITY_TABLE, keys=keys, priorities=priorities) - - # Periodically update the target network. - if tf.math.mod(self._num_steps, self._target_update_period) == 0: - for src, dest in zip(self._network.variables, - self._target_network.variables): - dest.assign(src) - self._num_steps.assign_add(1) - - # Report loss & statistics for logging. - fetches = { - 'loss': loss, - } - - return fetches - - def step(self): - # Do a batch of SGD. - result = self._step() - - # Compute elapsed time. - timestamp = time.time() - elapsed_time = timestamp - self._timestamp if self._timestamp else 0 - self._timestamp = timestamp - - # Update our counts and record it. - counts = self._counter.increment(steps=1, walltime=elapsed_time) - result.update(counts) - - # Snapshot and attempt to write logs. - if self._snapshotter is not None: - self._snapshotter.save() - self._logger.write(result) - - def get_variables(self, names: List[str]) -> List[np.ndarray]: - return tf2_utils.to_numpy(self._variables) - - @property - def state(self): - """Returns the stateful parts of the learner for checkpointing.""" - return { - 'network': self._network, - 'target_network': self._target_network, - 'optimizer': self._optimizer, - 'num_steps': self._num_steps - } diff --git a/acme/open_spiel/open_spiel_specs.py b/acme/open_spiel/open_spiel_specs.py deleted file mode 100644 index f73ee92919..0000000000 --- a/acme/open_spiel/open_spiel_specs.py +++ /dev/null @@ -1,40 +0,0 @@ -# python3 -# Copyright 2018 DeepMind Technologies Limited. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Objects which specify the extra information used by an OpenSpiel environment.""" - -from typing import Any, NamedTuple - -from acme.open_spiel import open_spiel_wrapper - - -# TODO Move elsewhere? Not actually a spec. -class Extras(NamedTuple): - """Extras used by a given environment.""" - legal_actions: Any - terminals: Any - - -class ExtrasSpec(NamedTuple): - """Full specification of the extras used by a given environment.""" - legal_actions: Any - terminals: Any - - -def make_extras_spec( - environment: open_spiel_wrapper.OpenSpielWrapper) -> ExtrasSpec: - """Returns an `ExtrasSpec` describing additional values used by OpenSpiel.""" - return ExtrasSpec(legal_actions=environment.legal_actions_spec(), - terminals=environment.terminals_spec()) diff --git a/acme/tf/networks/__init__.py b/acme/tf/networks/__init__.py index f099bc2559..5b783d78f3 100644 --- a/acme/tf/networks/__init__.py +++ b/acme/tf/networks/__init__.py @@ -34,6 +34,8 @@ from acme.tf.networks.distributional import UnivariateGaussianMixture from acme.tf.networks.distributions import DiscreteValuedDistribution from acme.tf.networks.duelling import DuellingMLP +from acme.tf.networks.legal_actions import MaskedSequential +from acme.tf.networks.legal_actions import EpsilonGreedy from acme.tf.networks.multihead import Multihead from acme.tf.networks.multiplexers import CriticMultiplexer from acme.tf.networks.noise import ClippedGaussian diff --git a/acme/tf/networks/legal_actions.py b/acme/tf/networks/legal_actions.py new file mode 100644 index 0000000000..f8d8b2749c --- /dev/null +++ 
b/acme/tf/networks/legal_actions.py @@ -0,0 +1,117 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Networks used for handling illegal actions.""" + +from typing import Any, Callable, Iterable, Optional, Text, Union + +import acme +from acme.wrappers import open_spiel_wrapper +import numpy as np +import sonnet as snt +import tensorflow as tf +import tensorflow_probability as tfp + +tfd = tfp.distributions + + +class MaskedSequential(snt.Module): + """Applies a legal actions mask to a linear chain of modules / callables. + + It is assumed the trailing dimension of the final layer (representing + action values) is the same as the trailing dimension of legal_actions. 
+ """ + + def __init__(self, + layers: Iterable[Callable[..., Any]] = None, + name: Optional[Text] = None): + super(MaskedSequential, self).__init__(name=name) + self._layers = list(layers) if layers is not None else [] + self._illegal_action_penalty = -1e9 + # TODO Note: illegal_action_penalty cannot be -np.inf, throws + # error "Priority must not be NaN" + + def __call__(self, inputs: open_spiel_wrapper.OLT) -> tf.Tensor: + # Extract observation, legal actions, and terminal + outputs = inputs.observation + legal_actions = inputs.legal_actions + terminal = inputs.terminal + + for mod in self._layers: + outputs = mod(outputs) + + # Apply legal actions mask + outputs = tf.where(tf.equal(legal_actions, 1), outputs, + tf.fill(tf.shape(outputs), self._illegal_action_penalty)) + + # When computing the Q-learning target (r_t + d_t * max q_t) we need to + # ensure max q_t = 0 in terminal states. + outputs = outputs * (1 - terminal) + + return outputs + + +# TODO Function to update epsilon +# TODO This is a modified version of trfl's epsilon_greedy() which incorporates +# code from the bug fix described here https://github.com/deepmind/trfl/pull/28 +class EpsilonGreedy(snt.Module): + """Computes an epsilon-greedy distribution over actions. + This policy does the following: + - With probability 1 - epsilon, take the action corresponding to the highest + action value, breaking ties uniformly at random. + - With probability epsilon, take an action uniformly at random. + Args: + epsilon: Exploratory param with value between 0 and 1. + threshold: Action values must exceed this value to qualify as a legal action + and possibly be selected by the policy. + Returns: + policy: tfp.distributions.Categorical distribution representing the policy. 
+ """ + + def __init__(self, + epsilon: Union[tf.Tensor, float], + threshold: float = -np.inf, + name: Optional[Text] = 'EpsilonGreedy'): + super(EpsilonGreedy, self).__init__(name=name) + self._epsilon = tf.Variable(epsilon, trainable=False) + self._threshold = threshold + + def __call__(self, action_values: tf.Tensor) -> tfd.Categorical: + legal_actions_mask = tf.where( + tf.math.less_equal(action_values, self._threshold), + tf.fill(tf.shape(action_values), 0.), + tf.fill(tf.shape(action_values), 1.)) + + # Dithering action distribution. + dither_probs = 1 / tf.reduce_sum(legal_actions_mask, axis=-1, + keepdims=True) * legal_actions_mask + masked_action_values = tf.where(tf.equal(legal_actions_mask, 1), + action_values, + tf.fill(tf.shape(action_values), -np.inf)) + # Greedy action distribution, breaking ties uniformly at random. + max_value = tf.reduce_max(masked_action_values, axis=-1, keepdims=True) + greedy_probs = tf.cast( + tf.equal(action_values * legal_actions_mask, max_value), + action_values.dtype) + + greedy_probs /= tf.reduce_sum(greedy_probs, axis=-1, keepdims=True) + + # Epsilon-greedy action distribution. + probs = self._epsilon * dither_probs + (1 - self._epsilon) * greedy_probs + + # Make the policy object. 
+ policy = tfd.Categorical(probs=probs) + + return policy diff --git a/acme/wrappers/__init__.py b/acme/wrappers/__init__.py index 9930564e12..9978d29158 100644 --- a/acme/wrappers/__init__.py +++ b/acme/wrappers/__init__.py @@ -23,6 +23,7 @@ from acme.wrappers.gym_wrapper import GymAtariAdapter from acme.wrappers.gym_wrapper import GymWrapper from acme.wrappers.observation_action_reward import ObservationActionRewardWrapper +from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper from acme.wrappers.single_precision import SinglePrecisionWrapper from acme.wrappers.step_limit import StepLimitWrapper from acme.wrappers.video import MujocoVideoWrapper diff --git a/acme/open_spiel/open_spiel_wrapper.py b/acme/wrappers/open_spiel_wrapper.py similarity index 59% rename from acme/open_spiel/open_spiel_wrapper.py rename to acme/wrappers/open_spiel_wrapper.py index 10281aa89e..27b07bcdb5 100644 --- a/acme/open_spiel/open_spiel_wrapper.py +++ b/acme/wrappers/open_spiel_wrapper.py @@ -15,7 +15,7 @@ """Wraps an OpenSpiel RL environment to be used as a dm_env environment.""" -from typing import List +from typing import List, NamedTuple from acme import specs from acme import types @@ -25,6 +25,13 @@ import pyspiel +class OLT(NamedTuple): + """Container for (observation, legal_actions, terminal) tuples.""" + observation: types.Nest + legal_actions: types.Nest + terminal: types.Nest + + # TODO Wrap the underlying OpenSpiel game directly instead of OpenSpiel's # rl_environment? 
class OpenSpielWrapper(dm_env.Environment): @@ -44,9 +51,9 @@ def reset(self) -> dm_env.TimeStep: """Resets the episode.""" self._reset_next_step = False open_spiel_timestep = self._environment.reset() - observation = self._convert_observation(open_spiel_timestep.observations) assert open_spiel_timestep.step_type == rl_environment.StepType.FIRST - return dm_env.restart(observation) + observations = self._convert_observation(open_spiel_timestep) + return dm_env.restart(observations) def step(self, action: types.NestedArray) -> dm_env.TimeStep: """Steps the environment.""" @@ -58,9 +65,9 @@ def step(self, action: types.NestedArray) -> dm_env.TimeStep: if open_spiel_timestep.step_type == rl_environment.StepType.LAST: self._reset_next_step = True - observation = self._convert_observation(open_spiel_timestep.observations) - reward = np.asarray(open_spiel_timestep.rewards) - discount = np.asarray(open_spiel_timestep.discounts) + observations = self._convert_observation(open_spiel_timestep) + rewards = np.asarray(open_spiel_timestep.rewards) + discounts = np.asarray(open_spiel_timestep.discounts) step_type = open_spiel_timestep.step_type if step_type == rl_environment.StepType.FIRST: @@ -72,42 +79,45 @@ def step(self, action: types.NestedArray) -> dm_env.TimeStep: else: raise ValueError("Did not recognize OpenSpiel StepType") - return dm_env.TimeStep(observation=observation, - reward=reward, - discount=discount, + return dm_env.TimeStep(observation=observations, + reward=rewards, + discount=discounts, step_type=step_type) - # TODO Convert OpenSpiel observation so it's dm_env compatible. Dm_env - # timesteps allow for dicts and nesting, but they require the leaf elements - # be numpy arrays, whereas OpenSpiel timestep leaf elements are python lists. - # Also, the list of legal actions must be converted to a legal actions mask. + # Convert OpenSpiel observation so it's dm_env compatible. Also, the list + # of legal actions must be converted to a legal actions mask. 
def _convert_observation( - self, open_spiel_observation: types.NestedArray) -> types.NestedArray: - observation = {"info_state": [], "legal_actions": [], "current_player": []} - info_state = [] - for player_info_state in open_spiel_observation["info_state"]: - info_state.append(np.asarray(player_info_state)) - observation["info_state"] = info_state - legal_actions = [] - for indicies in open_spiel_observation["legal_actions"]: + self, open_spiel_timestep: NamedTuple) -> types.NestedArray: + observations = [] + for pid in range(self._environment.num_players): legals = np.zeros(self._environment._game.num_distinct_actions()) - legals[indicies] = 1 - legal_actions.append(legals) - observation["legal_actions"] = legal_actions - observation["current_player"] = self._environment._state.current_player() - return observation - - # TODO These specs describe the timestep that the actor and learner ultimately - # receive, not the timestep that gets passed to the OpenSpiel agent. See - # acme/open_spiel/agents/agent.py for more details. 
+ legals[open_spiel_timestep.observations["legal_actions"][pid]] = 1 + player_observation = OLT(observation=np.asarray( + open_spiel_timestep.observations["info_state"][pid]), + legal_actions=legals, + terminal=np.asarray( + [float(open_spiel_timestep.last())])) + observations.append(player_observation) + return observations + def observation_spec(self) -> types.NestedSpec: + # Observation spec depends on whether the OpenSpiel environment is using + # observation/information_state tensors if self._environment._use_observation: - return specs.Array((self._environment._game.observation_tensor_size(),), - np.float32) + return OLT(observation=specs.Array( + (self._environment._game.observation_tensor_size(),), np.float32), + legal_actions=specs.Array( + (self._environment._game.num_distinct_actions(),), + np.float32), + terminal=specs.Array((1,), np.float32)) else: - return specs.Array( + return OLT(observation=specs.Array( (self._environment._game.information_state_tensor_size(),), - np.float32) + np.float32), + legal_actions=specs.Array( + (self._environment._game.num_distinct_actions(),), + np.float32), + terminal=specs.Array((1,), np.float32)) def action_spec(self) -> types.NestedSpec: return specs.DiscreteArray(self._environment._game.num_distinct_actions()) @@ -121,20 +131,15 @@ def reward_spec(self) -> types.NestedSpec: def discount_spec(self) -> types.NestedSpec: return specs.BoundedArray((), np.float32, minimum=0, maximum=1.0) - def legal_actions_spec(self) -> types.NestedSpec: - return specs.BoundedArray((self._environment._game.num_distinct_actions(),), - np.float32, - minimum=0, - maximum=1.0) - - def terminals_spec(self) -> types.NestedSpec: - return specs.BoundedArray((), np.float32, minimum=0, maximum=1.0) - @property - def environment(self): + def environment(self) -> rl_environment.Environment: """Returns the wrapped environment.""" return self._environment + @property + def current_player(self) -> int: + return 
self._environment._state.current_player() + def __getattr__(self, name: str): - # Expose any other attributes of the underlying environment. + """Expose any other attributes of the underlying environment.""" return getattr(self._environment, name) diff --git a/acme/open_spiel/examples/run_dqn.py b/examples/open_spiel/run_dqn.py similarity index 71% rename from acme/open_spiel/examples/run_dqn.py rename to examples/open_spiel/run_dqn.py index 15387875c0..c2c8911784 100644 --- a/acme/open_spiel/examples/run_dqn.py +++ b/examples/open_spiel/run_dqn.py @@ -13,19 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Example running DQN on OpenSpiel in a single process.""" +"""Example running DQN on OpenSpiel game in a single process.""" from absl import app from absl import flags import acme -from acme import specs from acme import wrappers -from acme.open_spiel import open_spiel_environment_loop -from acme.open_spiel import open_spiel_specs -from acme.open_spiel import open_spiel_wrapper -from acme.open_spiel.agents.tf import dqn -from acme.tf import networks +from acme.agents.tf import dqn +from acme.environment_loops import open_spiel_environment_loop +from acme.tf.networks import legal_actions +from acme.wrappers import open_spiel_wrapper from open_spiel.python import rl_environment import sonnet as snt @@ -42,28 +40,27 @@ def main(_): environment = open_spiel_wrapper.OpenSpielWrapper(raw_environment) environment = wrappers.SinglePrecisionWrapper(environment) - environment_spec = specs.make_environment_spec(environment) - extras_spec = open_spiel_specs.make_extras_spec(environment) + environment_spec = acme.make_environment_spec(environment) - network = snt.Sequential([ + network = legal_actions.MaskedSequential([ snt.Flatten(), snt.nets.MLP([50, 50, environment_spec.actions.num_values]) ]) - # Construct the agent. 
+ policy_network = snt.Sequential( + [network, legal_actions.EpsilonGreedy(epsilon=0.1, threshold=-1e8)]) + + # Construct the agents. agents = [] for i in range(environment.num_players): agents.append( dqn.DQN( environment_spec=environment_spec, - extras_spec=extras_spec, - priority_exponent=0.0, # TODO Test priority_exponent. discount=1.0, - n_step=1, # TODO Appear to be convergence issues when n > 1. - epsilon=0.1, + n_step=1, # Note: does indeed converge for n > 1 network=network, - player_id=i)) + policy_network=policy_network)) # Run the environment loop. loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop( From c8a625b04f726b72709f8a9f82db1fb6cdfd70d9 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 25 Nov 2020 23:14:34 +0000 Subject: [PATCH 03/39] Implement review feedback. --- .../open_spiel_environment_loop.py | 13 +++++-------- acme/tf/networks/legal_actions.py | 4 ++-- acme/wrappers/open_spiel_wrapper.py | 14 +++++++------- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/acme/environment_loops/open_spiel_environment_loop.py b/acme/environment_loops/open_spiel_environment_loop.py index 08e225401a..3991e805cf 100644 --- a/acme/environment_loops/open_spiel_environment_loop.py +++ b/acme/environment_loops/open_spiel_environment_loop.py @@ -17,13 +17,13 @@ import operator import time -from typing import List, Optional +from typing import Optional, Sequence from acme import core -# Internal imports. 
from acme.tf import utils as tf2_utils from acme.utils import counting from acme.utils import loggers +from acme.wrappers import open_spiel_wrapper import dm_env from dm_env import specs import numpy as np @@ -54,8 +54,8 @@ class OpenSpielEnvironmentLoop(core.Worker): def __init__( self, - environment: dm_env.Environment, - actors: List[core.Actor], + environment: open_spiel_wrapper.OpenSpielWrapper, + actors: Sequence[core.Actor], counter: counting.Counter = None, logger: loggers.Logger = None, should_update: bool = True, @@ -120,7 +120,7 @@ def _print_policy(self, timestep: dm_env.TimeStep, player: int): batched_observation = tf2_utils.add_batch_dim(timestep.observation[player]) policy = tf.squeeze( self._actors[player]._learner._network(batched_observation)) - tf.print(policy, summarize=-1) + tf.print("Policy: ", policy, summarize=-1) tf.print("Greedy action: ", tf.math.argmax(policy)) # TODO Remove verbose or add to logger? @@ -242,6 +242,3 @@ def should_terminate(episode_count: int, step_count: int) -> bool: def _generate_zeros_from_spec(spec: specs.Array) -> np.ndarray: return np.zeros(spec.shape, spec.dtype) - - -# Internal class. 
diff --git a/acme/tf/networks/legal_actions.py b/acme/tf/networks/legal_actions.py index f8d8b2749c..0f673ff1d3 100644 --- a/acme/tf/networks/legal_actions.py +++ b/acme/tf/networks/legal_actions.py @@ -40,8 +40,8 @@ def __init__(self, super(MaskedSequential, self).__init__(name=name) self._layers = list(layers) if layers is not None else [] self._illegal_action_penalty = -1e9 - # TODO Note: illegal_action_penalty cannot be -np.inf, throws - # error "Priority must not be NaN" + # Note: illegal_action_penalty cannot be -np.inf, throws error "Priority + # must not be NaN" def __call__(self, inputs: open_spiel_wrapper.OLT) -> tf.Tensor: # Extract observation, legal actions, and terminal diff --git a/acme/wrappers/open_spiel_wrapper.py b/acme/wrappers/open_spiel_wrapper.py index 27b07bcdb5..c2779dafc1 100644 --- a/acme/wrappers/open_spiel_wrapper.py +++ b/acme/wrappers/open_spiel_wrapper.py @@ -15,7 +15,7 @@ """Wraps an OpenSpiel RL environment to be used as a dm_env environment.""" -from typing import List, NamedTuple +from typing import NamedTuple from acme import specs from acme import types @@ -43,15 +43,14 @@ class OpenSpielWrapper(dm_env.Environment): def __init__(self, environment: rl_environment.Environment): self._environment = environment self._reset_next_step = True - assert environment._game.get_type( - ).dynamics == pyspiel.GameType.Dynamics.SEQUENTIAL, ( - "Currently only supports sequential games.") + if environment._game.get_type( + ).dynamics != pyspiel.GameType.Dynamics.SEQUENTIAL: + raise ValueError("Currently only supports sequential games.") def reset(self) -> dm_env.TimeStep: """Resets the episode.""" self._reset_next_step = False open_spiel_timestep = self._environment.reset() - assert open_spiel_timestep.step_type == rl_environment.StepType.FIRST observations = self._convert_observation(open_spiel_timestep) return dm_env.restart(observations) @@ -77,7 +76,8 @@ def step(self, action: types.NestedArray) -> dm_env.TimeStep: elif step_type == 
rl_environment.StepType.LAST: step_type = dm_env.StepType.LAST else: - raise ValueError("Did not recognize OpenSpiel StepType") + raise ValueError( + "Did not recognize OpenSpiel StepType: {}".format(step_type)) return dm_env.TimeStep(observation=observations, reward=rewards, @@ -87,7 +87,7 @@ def step(self, action: types.NestedArray) -> dm_env.TimeStep: # Convert OpenSpiel observation so it's dm_env compatible. Also, the list # of legal actions must be converted to a legal actions mask. def _convert_observation( - self, open_spiel_timestep: NamedTuple) -> types.NestedArray: + self, open_spiel_timestep: rl_environment.TimeStep) -> types.NestedArray: observations = [] for pid in range(self._environment.num_players): legals = np.zeros(self._environment._game.num_distinct_actions()) From 4ea4d785331f0b1961881855ebcb5f4d069ff980 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 25 Nov 2020 23:18:39 +0000 Subject: [PATCH 04/39] Fix run_dqn.py example so each agent gets its own network. --- examples/open_spiel/run_dqn.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/examples/open_spiel/run_dqn.py b/examples/open_spiel/run_dqn.py index c2c8911784..1c8bbaab12 100644 --- a/examples/open_spiel/run_dqn.py +++ b/examples/open_spiel/run_dqn.py @@ -42,25 +42,28 @@ def main(_): environment = wrappers.SinglePrecisionWrapper(environment) environment_spec = acme.make_environment_spec(environment) - network = legal_actions.MaskedSequential([ - snt.Flatten(), - snt.nets.MLP([50, 50, environment_spec.actions.num_values]) - ]) - - policy_network = snt.Sequential( - [network, legal_actions.EpsilonGreedy(epsilon=0.1, threshold=-1e8)]) + # Build the networks. 
+ networks = [] + policy_networks = [] + for i in range(environment.num_players): + network = legal_actions.MaskedSequential([ + snt.Flatten(), + snt.nets.MLP([50, 50, environment_spec.actions.num_values]) + ]) + policy_network = snt.Sequential( + [network, + legal_actions.EpsilonGreedy(epsilon=0.1, threshold=-1e8)]) + networks.append(network) + policy_networks.append(policy_network) # Construct the agents. agents = [] for i in range(environment.num_players): agents.append( - dqn.DQN( - environment_spec=environment_spec, - discount=1.0, - n_step=1, # Note: does indeed converge for n > 1 - network=network, - policy_network=policy_network)) + dqn.DQN(environment_spec=environment_spec, + network=networks[i], + policy_network=policy_networks[i])) # Run the environment loop. loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop( From 55e0ce775e6e2575c4ffae60dd56cf5b6a8216ec Mon Sep 17 00:00:00 2001 From: John Schultz Date: Thu, 7 Jan 2021 22:58:43 +0000 Subject: [PATCH 05/39] Disable pytest import errors. 
--- acme/environment_loops/open_spiel_environment_loop.py | 1 + acme/tf/networks/legal_actions.py | 2 ++ acme/wrappers/open_spiel_wrapper.py | 1 + 3 files changed, 4 insertions(+) diff --git a/acme/environment_loops/open_spiel_environment_loop.py b/acme/environment_loops/open_spiel_environment_loop.py index 3991e805cf..99fd6b110d 100644 --- a/acme/environment_loops/open_spiel_environment_loop.py +++ b/acme/environment_loops/open_spiel_environment_loop.py @@ -19,6 +19,7 @@ import time from typing import Optional, Sequence +# pytype: disable=import-error from acme import core from acme.tf import utils as tf2_utils from acme.utils import counting diff --git a/acme/tf/networks/legal_actions.py b/acme/tf/networks/legal_actions.py index 0f673ff1d3..4486acb945 100644 --- a/acme/tf/networks/legal_actions.py +++ b/acme/tf/networks/legal_actions.py @@ -18,7 +18,9 @@ from typing import Any, Callable, Iterable, Optional, Text, Union import acme +# pytype: disable=import-error from acme.wrappers import open_spiel_wrapper +# pytype: enable=import-error import numpy as np import sonnet as snt import tensorflow as tf diff --git a/acme/wrappers/open_spiel_wrapper.py b/acme/wrappers/open_spiel_wrapper.py index c2779dafc1..508d8fed91 100644 --- a/acme/wrappers/open_spiel_wrapper.py +++ b/acme/wrappers/open_spiel_wrapper.py @@ -17,6 +17,7 @@ from typing import NamedTuple +# pytype: disable=import-error from acme import specs from acme import types import dm_env From 8d0f41e0c05066fa6a839abf1f78ede5104dd724 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Fri, 8 Jan 2021 01:08:58 +0000 Subject: [PATCH 06/39] Remove duplicate policy_network argument. 
--- acme/agents/tf/dqn/agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/acme/agents/tf/dqn/agent.py b/acme/agents/tf/dqn/agent.py index 13f7783602..1437a02699 100644 --- a/acme/agents/tf/dqn/agent.py +++ b/acme/agents/tf/dqn/agent.py @@ -46,7 +46,6 @@ def __init__( self, environment_spec: specs.EnvironmentSpec, network: snt.Module, - policy_network: Optional[snt.Module] = None, batch_size: int = 256, prefetch_size: int = 4, target_update_period: int = 100, From 1f9251d6c549ffeb97f13a9354730705bd568a60 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Fri, 8 Jan 2021 01:10:26 +0000 Subject: [PATCH 07/39] Set prev_actions default to INVALID_ACTION instead of None. --- acme/environment_loops/open_spiel_environment_loop.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/acme/environment_loops/open_spiel_environment_loop.py b/acme/environment_loops/open_spiel_environment_loop.py index 99fd6b110d..8bcb349ffd 100644 --- a/acme/environment_loops/open_spiel_environment_loop.py +++ b/acme/environment_loops/open_spiel_environment_loop.py @@ -71,7 +71,7 @@ def __init__( # Track information necessary to coordinate updates among multiple actors. 
self._observed_first = [False] * len(self._actors) - self._prev_actions = [None] * len(self._actors) + self._prev_actions = [pyspiel.INVALID_ACTION] * len(self._actors) def _send_observation(self, timestep: dm_env.TimeStep, player: int): # If terminal all actors must update @@ -87,7 +87,7 @@ def _send_observation(self, timestep: dm_env.TimeStep, player: int): if self._should_update: self._actors[player_id].update() self._observed_first = [False] * len(self._actors) - self._prev_actions = [None] * len(self._actors) + self._prev_actions = [pyspiel.INVALID_ACTION] * len(self._actors) else: if not self._observed_first[player]: player_timestep = dm_env.TimeStep( From 5f0191fd13b3432f1c09ebc4eca1321d8d70bab3 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Fri, 8 Jan 2021 01:23:16 +0000 Subject: [PATCH 08/39] Remove verbose printouts from env loop. --- .../open_spiel_environment_loop.py | 27 ++----------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/acme/environment_loops/open_spiel_environment_loop.py b/acme/environment_loops/open_spiel_environment_loop.py index 8bcb349ffd..81d480440f 100644 --- a/acme/environment_loops/open_spiel_environment_loop.py +++ b/acme/environment_loops/open_spiel_environment_loop.py @@ -116,16 +116,7 @@ def _get_player_timestep(self, timestep: dm_env.TimeStep, discount=timestep.discount[player], step_type=timestep.step_type) - # TODO Remove? Currently used for debugging. - def _print_policy(self, timestep: dm_env.TimeStep, player: int): - batched_observation = tf2_utils.add_batch_dim(timestep.observation[player]) - policy = tf.squeeze( - self._actors[player]._learner._network(batched_observation)) - tf.print("Policy: ", policy, summarize=-1) - tf.print("Greedy action: ", tf.math.argmax(policy)) - - # TODO Remove verbose or add to logger? - def run_episode(self, verbose: bool = False) -> loggers.LoggingData: + def run_episode(self) -> loggers.LoggingData: """Run one episode. 
Each episode is a loop which interacts first with the environment to get an @@ -165,16 +156,8 @@ def run_episode(self, verbose: bool = False) -> loggers.LoggingData: # TODO Support simultaneous move games. raise ValueError("Currently only supports sequential games.") - if verbose: - self._print_policy(timestep, self._environment.current_player) - print("Selected action: ", action_list[0]) - timestep = self._environment.step(action_list) - if verbose: - print("State:") - print(str(self._environment._state)) - # Have the agent observe the timestep and let the actor update itself. self._send_observation(timestep, self._environment.current_player) @@ -184,9 +167,6 @@ def run_episode(self, verbose: bool = False) -> loggers.LoggingData: # Equivalent to: episode_return += timestep.reward tree.map_structure(operator.iadd, episode_return, timestep.reward) - if verbose: - print("Reward: ", timestep.reward) - # Record counts. counts = self._counter.increment(episodes=1, steps=episode_steps) @@ -231,10 +211,7 @@ def should_terminate(episode_count: int, step_count: int) -> bool: episode_count, step_count = 0, 0 while not should_terminate(episode_count, step_count): - if episode_count % 1000 == 0: - result = self.run_episode(verbose=True) - else: - result = self.run_episode(verbose=False) + result = self.run_episode() episode_count += 1 step_count += result['episode_length'] # Log the given results. From 5c20c2399bf7e72e2f8c51fb22e5fe1bfc7b2e64 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Fri, 8 Jan 2021 02:49:44 +0000 Subject: [PATCH 09/39] Comment out OpenSpielWrapper from __init__.py so pytest passes. 
--- acme/wrappers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/acme/wrappers/__init__.py b/acme/wrappers/__init__.py index 9978d29158..4697285b8b 100644 --- a/acme/wrappers/__init__.py +++ b/acme/wrappers/__init__.py @@ -23,7 +23,7 @@ from acme.wrappers.gym_wrapper import GymAtariAdapter from acme.wrappers.gym_wrapper import GymWrapper from acme.wrappers.observation_action_reward import ObservationActionRewardWrapper -from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper +#from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper from acme.wrappers.single_precision import SinglePrecisionWrapper from acme.wrappers.step_limit import StepLimitWrapper from acme.wrappers.video import MujocoVideoWrapper From 35d434f45b6ebf82b7b1475cb65153b1d0e23a49 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Fri, 8 Jan 2021 02:49:44 +0000 Subject: [PATCH 10/39] Comment out files in __init__.py with OpenSpiel imports so pytest passes. --- acme/tf/networks/__init__.py | 4 ++-- acme/wrappers/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/acme/tf/networks/__init__.py b/acme/tf/networks/__init__.py index 2fc342f56f..dc02d089a6 100644 --- a/acme/tf/networks/__init__.py +++ b/acme/tf/networks/__init__.py @@ -34,8 +34,8 @@ from acme.tf.networks.distributional import UnivariateGaussianMixture from acme.tf.networks.distributions import DiscreteValuedDistribution from acme.tf.networks.duelling import DuellingMLP -from acme.tf.networks.legal_actions import MaskedSequential -from acme.tf.networks.legal_actions import EpsilonGreedy +#from acme.tf.networks.legal_actions import MaskedSequential +#from acme.tf.networks.legal_actions import EpsilonGreedy from acme.tf.networks.masked_epsilon_greedy import NetworkWithMaskedEpsilonGreedy from acme.tf.networks.multihead import Multihead from acme.tf.networks.multiplexers import CriticMultiplexer diff --git a/acme/wrappers/__init__.py b/acme/wrappers/__init__.py index 
9978d29158..4697285b8b 100644 --- a/acme/wrappers/__init__.py +++ b/acme/wrappers/__init__.py @@ -23,7 +23,7 @@ from acme.wrappers.gym_wrapper import GymAtariAdapter from acme.wrappers.gym_wrapper import GymWrapper from acme.wrappers.observation_action_reward import ObservationActionRewardWrapper -from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper +#from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper from acme.wrappers.single_precision import SinglePrecisionWrapper from acme.wrappers.step_limit import StepLimitWrapper from acme.wrappers.video import MujocoVideoWrapper From 2ccd84bd3ab0ced7c6dca3f60a8c55860e49d754 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Fri, 8 Jan 2021 19:17:07 +0000 Subject: [PATCH 11/39] Add OpenSpiel wrapper test. --- acme/wrappers/open_spiel_wrapper_test.py | 67 ++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 acme/wrappers/open_spiel_wrapper_test.py diff --git a/acme/wrappers/open_spiel_wrapper_test.py b/acme/wrappers/open_spiel_wrapper_test.py new file mode 100644 index 0000000000..fb16c1a64d --- /dev/null +++ b/acme/wrappers/open_spiel_wrapper_test.py @@ -0,0 +1,67 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for open_spiel_wrapper.""" + +import unittest +from absl.testing import absltest +from absl.testing import parameterized +from dm_env import specs +import numpy as np + +SKIP_OPEN_SPIEL_TESTS = False +SKIP_OPEN_SPIEL_MESSAGE = 'open_spiel not installed.' + +try: + # pylint: disable=g-import-not-at-top + # pytype: disable=import-error + from acme.wrappers import open_spiel_wrapper + from open_spiel.python import rl_environment +except ModuleNotFoundError: + SKIP_OPEN_SPIEL_TESTS = True + + +@unittest.skipIf(SKIP_OPEN_SPIEL_TESTS, SKIP_OPEN_SPIEL_MESSAGE) +class OpenSpielWrapperTest(parameterized.TestCase): + + def test_tic_tac_toe(self): + raw_env = rl_environment.Environment('tic_tac_toe') + env = open_spiel_wrapper.OpenSpielWrapper(raw_env) + + # Test converted observation spec. + observation_spec = env.observation_spec() + self.assertEqual(type(observation_spec), open_spiel_wrapper.OLT) + self.assertEqual(type(observation_spec.observation), specs.Array) + self.assertEqual(type(observation_spec.legal_actions), specs.Array) + self.assertEqual(type(observation_spec.terminal), specs.Array) + + # Test converted action spec. + action_spec: specs.DiscreteArray = env.action_spec() + self.assertEqual(type(action_spec), specs.DiscreteArray) + self.assertEqual(action_spec.shape, ()) + self.assertEqual(action_spec.minimum, 0) + self.assertEqual(action_spec.maximum, 8) + self.assertEqual(action_spec.num_values, 9) + self.assertEqual(action_spec.dtype, np.dtype('int32')) + + # Test step. + timestep = env.reset() + self.assertTrue(timestep.first()) + _ = env.step([0]) + env.close() + + +if __name__ == '__main__': + absltest.main() From 785fb868289307251bdf96fad8eabb366c23f16b Mon Sep 17 00:00:00 2001 From: John Schultz Date: Sat, 9 Jan 2021 01:52:14 +0000 Subject: [PATCH 12/39] Add OpenSpiel env loop test. 
--- acme/environment_loops/__init__.py | 2 +- .../open_spiel_environment_loop_test.py | 102 ++++++++++++++++++ 2 files changed, 103 insertions(+), 1 deletion(-) create mode 100644 acme/environment_loops/open_spiel_environment_loop_test.py diff --git a/acme/environment_loops/__init__.py b/acme/environment_loops/__init__.py index 08747b703a..c1fb195e95 100644 --- a/acme/environment_loops/__init__.py +++ b/acme/environment_loops/__init__.py @@ -14,4 +14,4 @@ """Specialized environment loops.""" -from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop +#from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop diff --git a/acme/environment_loops/open_spiel_environment_loop_test.py b/acme/environment_loops/open_spiel_environment_loop_test.py new file mode 100644 index 0000000000..fd8fa58256 --- /dev/null +++ b/acme/environment_loops/open_spiel_environment_loop_test.py @@ -0,0 +1,102 @@ +# python3 +# Copyright 2018 DeepMind Technologies Limited. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for OpenSpiel environment loop.""" + +import unittest +from absl.testing import absltest +from absl.testing import parameterized + +import acme +from acme import core +from acme.testing import fakes +from acme import specs +from acme import types +from acme import wrappers + +import dm_env +import numpy as np +import tree + +SKIP_OPEN_SPIEL_TESTS = False +SKIP_OPEN_SPIEL_MESSAGE = 'open_spiel not installed.' + +try: + # pylint: disable=g-import-not-at-top + # pytype: disable=import-error + from acme.environment_loops import open_spiel_environment_loop + from acme.wrappers import open_spiel_wrapper + from open_spiel.python import rl_environment + + class RandomActor(core.Actor): + """Fake actor which generates random actions and validates specs.""" + + def __init__(self, spec: specs.EnvironmentSpec): + self._spec = spec + self.num_updates = 0 + + def select_action(self, observation: open_spiel_wrapper.OLT) -> int: + _validate_spec(self._spec.observations, observation) + legals = np.array(np.nonzero(observation.legal_actions), dtype=np.int32) + return np.random.choice(legals[0]) + + def observe_first(self, timestep: dm_env.TimeStep): + _validate_spec(self._spec.observations, timestep.observation) + + def observe(self, action: types.NestedArray, + next_timestep: dm_env.TimeStep): + _validate_spec(self._spec.actions, action) + _validate_spec(self._spec.rewards, next_timestep.reward) + _validate_spec(self._spec.discounts, next_timestep.discount) + _validate_spec(self._spec.observations, next_timestep.observation) + + def update(self, wait: bool = False): + self.num_updates += 1 + +except ModuleNotFoundError: + SKIP_OPEN_SPIEL_TESTS = True + + +def _validate_spec(spec: types.NestedSpec, value: types.NestedArray): + """Validate a value from a potentially nested spec.""" + tree.assert_same_structure(value, spec) + tree.map_structure(lambda s, v: s.validate(v), spec, value) + + +@unittest.skipIf(SKIP_OPEN_SPIEL_TESTS, SKIP_OPEN_SPIEL_MESSAGE) +class 
OpenSpielEnvironmentLoopTest(parameterized.TestCase): + + def test_loop_run(self): + raw_env = rl_environment.Environment('tic_tac_toe') + env = open_spiel_wrapper.OpenSpielWrapper(raw_env) + env = wrappers.SinglePrecisionWrapper(env) + environment_spec = acme.make_environment_spec(env) + + actors = [] + for i in range(env.num_players): + actors.append(RandomActor(environment_spec)) + + loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop(env, actors) + result = loop.run_episode() + self.assertIn('episode_length', result) + self.assertIn('episode_return', result) + self.assertIn('steps_per_second', result) + + loop.run(num_episodes=10) + loop.run(num_steps=100) + + +if __name__ == '__main__': + absltest.main() From 99b57d0d35371c0bdb4df2c8901f367ae0f0797c Mon Sep 17 00:00:00 2001 From: John Schultz Date: Sat, 9 Jan 2021 03:46:23 +0000 Subject: [PATCH 13/39] Remove unnecessary TODO. --- acme/wrappers/open_spiel_wrapper.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/acme/wrappers/open_spiel_wrapper.py b/acme/wrappers/open_spiel_wrapper.py index 508d8fed91..23ec03ca01 100644 --- a/acme/wrappers/open_spiel_wrapper.py +++ b/acme/wrappers/open_spiel_wrapper.py @@ -33,8 +33,6 @@ class OLT(NamedTuple): terminal: types.Nest -# TODO Wrap the underlying OpenSpiel game directly instead of OpenSpiel's -# rl_environment? class OpenSpielWrapper(dm_env.Environment): """Environment wrapper for OpenSpiel RL environments.""" From 592002a8910ffd5340a1b57c9cf9aa87d794fc51 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 13 Jan 2021 20:55:07 +0000 Subject: [PATCH 14/39] Remove unnecessary pytype disable. --- examples/open_spiel/run_dqn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/open_spiel/run_dqn.py b/examples/open_spiel/run_dqn.py index 1c8bbaab12..a7aed3bcfc 100644 --- a/examples/open_spiel/run_dqn.py +++ b/examples/open_spiel/run_dqn.py @@ -68,7 +68,7 @@ def main(_): # Run the environment loop. 
loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop( environment, agents) - loop.run(num_episodes=100000) # pytype: disable=attribute-error + loop.run(num_episodes=100000) if __name__ == '__main__': From e7b77a508e645c4d53075cf4685b76fe80974d5b Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 13 Jan 2021 20:59:46 +0000 Subject: [PATCH 15/39] Refactor loop for constructing agents. --- examples/open_spiel/run_dqn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/open_spiel/run_dqn.py b/examples/open_spiel/run_dqn.py index a7aed3bcfc..ff952c2162 100644 --- a/examples/open_spiel/run_dqn.py +++ b/examples/open_spiel/run_dqn.py @@ -59,11 +59,11 @@ def main(_): # Construct the agents. agents = [] - for i in range(environment.num_players): + for network, policy_network in zip(networks, policy_networks): agents.append( dqn.DQN(environment_spec=environment_spec, - network=networks[i], - policy_network=policy_networks[i])) + network=network, + policy_network=policy_network)) # Run the environment loop. loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop( From 74dfddfc4ec8e179639db906344cacca68dcff8d Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 13 Jan 2021 21:17:20 +0000 Subject: [PATCH 16/39] Change TestCase base class of OpenSpielWrapperTest. 
--- acme/wrappers/open_spiel_wrapper_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/acme/wrappers/open_spiel_wrapper_test.py b/acme/wrappers/open_spiel_wrapper_test.py index fb16c1a64d..16bec7b8a2 100644 --- a/acme/wrappers/open_spiel_wrapper_test.py +++ b/acme/wrappers/open_spiel_wrapper_test.py @@ -17,7 +17,6 @@ import unittest from absl.testing import absltest -from absl.testing import parameterized from dm_env import specs import numpy as np @@ -34,7 +33,7 @@ @unittest.skipIf(SKIP_OPEN_SPIEL_TESTS, SKIP_OPEN_SPIEL_MESSAGE) -class OpenSpielWrapperTest(parameterized.TestCase): +class OpenSpielWrapperTest(absltest.TestCase): def test_tic_tac_toe(self): raw_env = rl_environment.Environment('tic_tac_toe') From 0328e630252da579d80b33439c7df18bcf81cd0e Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 13 Jan 2021 21:37:12 +0000 Subject: [PATCH 17/39] Re-enable pytype import-error. --- acme/wrappers/open_spiel_wrapper_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/acme/wrappers/open_spiel_wrapper_test.py b/acme/wrappers/open_spiel_wrapper_test.py index 16bec7b8a2..cb6f75e36b 100644 --- a/acme/wrappers/open_spiel_wrapper_test.py +++ b/acme/wrappers/open_spiel_wrapper_test.py @@ -28,6 +28,7 @@ # pytype: disable=import-error from acme.wrappers import open_spiel_wrapper from open_spiel.python import rl_environment + # pytype: enable=import-error except ModuleNotFoundError: SKIP_OPEN_SPIEL_TESTS = True From 0e538c78fbbb2aa55c1c8fdb53ad7a32fc99929b Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 13 Jan 2021 21:41:07 +0000 Subject: [PATCH 18/39] Minor cosmetic changes. 
--- examples/open_spiel/run_dqn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/open_spiel/run_dqn.py b/examples/open_spiel/run_dqn.py index ff952c2162..d9756ae513 100644 --- a/examples/open_spiel/run_dqn.py +++ b/examples/open_spiel/run_dqn.py @@ -27,15 +27,15 @@ from open_spiel.python import rl_environment import sonnet as snt -flags.DEFINE_string("game", "tic_tac_toe", "Name of the game") -flags.DEFINE_integer("num_players", None, "Number of players") +flags.DEFINE_string('game', 'tic_tac_toe', 'Name of the game') +flags.DEFINE_integer('num_players', None, 'Number of players') FLAGS = flags.FLAGS def main(_): # Create an environment and grab the spec. - env_configs = {"players": FLAGS.num_players} if FLAGS.num_players else {} + env_configs = {'players': FLAGS.num_players} if FLAGS.num_players else {} raw_environment = rl_environment.Environment(FLAGS.game, **env_configs) environment = open_spiel_wrapper.OpenSpielWrapper(raw_environment) @@ -45,7 +45,7 @@ def main(_): # Build the networks. networks = [] policy_networks = [] - for i in range(environment.num_players): + for _ in range(environment.num_players): network = legal_actions.MaskedSequential([ snt.Flatten(), snt.nets.MLP([50, 50, environment_spec.actions.num_values]) From 381692c41a179312d75cf913dacaad5a303b0fee Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 13 Jan 2021 21:47:22 +0000 Subject: [PATCH 19/39] Narrow down pytype disable import-error and re-enable. 
--- acme/wrappers/open_spiel_wrapper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/acme/wrappers/open_spiel_wrapper.py b/acme/wrappers/open_spiel_wrapper.py index 23ec03ca01..318f284721 100644 --- a/acme/wrappers/open_spiel_wrapper.py +++ b/acme/wrappers/open_spiel_wrapper.py @@ -17,13 +17,14 @@ from typing import NamedTuple -# pytype: disable=import-error from acme import specs from acme import types import dm_env import numpy as np +# pytype: disable=import-error from open_spiel.python import rl_environment import pyspiel +# pytype: enable=import-error class OLT(NamedTuple): From fdbf4c01cf058d4191f35d6a8103ccf5062cc33a Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 13 Jan 2021 22:08:16 +0000 Subject: [PATCH 20/39] Access OpenSpiel env properties instead of private attributes. --- acme/wrappers/open_spiel_wrapper.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/acme/wrappers/open_spiel_wrapper.py b/acme/wrappers/open_spiel_wrapper.py index 318f284721..06b6541023 100644 --- a/acme/wrappers/open_spiel_wrapper.py +++ b/acme/wrappers/open_spiel_wrapper.py @@ -90,7 +90,7 @@ def _convert_observation( self, open_spiel_timestep: rl_environment.TimeStep) -> types.NestedArray: observations = [] for pid in range(self._environment.num_players): - legals = np.zeros(self._environment._game.num_distinct_actions()) + legals = np.zeros(self._environment.game.num_distinct_actions()) legals[open_spiel_timestep.observations["legal_actions"][pid]] = 1 player_observation = OLT(observation=np.asarray( open_spiel_timestep.observations["info_state"][pid]), @@ -102,25 +102,25 @@ def _convert_observation( def observation_spec(self) -> types.NestedSpec: # Observation spec depends on whether the OpenSpiel environment is using - # observation/information_state tensors + # observation/information_state tensors. 
if self._environment._use_observation: return OLT(observation=specs.Array( - (self._environment._game.observation_tensor_size(),), np.float32), + (self._environment.game.observation_tensor_size(),), np.float32), legal_actions=specs.Array( - (self._environment._game.num_distinct_actions(),), + (self._environment.game.num_distinct_actions(),), np.float32), terminal=specs.Array((1,), np.float32)) else: return OLT(observation=specs.Array( - (self._environment._game.information_state_tensor_size(),), + (self._environment.game.information_state_tensor_size(),), np.float32), legal_actions=specs.Array( - (self._environment._game.num_distinct_actions(),), + (self._environment.game.num_distinct_actions(),), np.float32), terminal=specs.Array((1,), np.float32)) def action_spec(self) -> types.NestedSpec: - return specs.DiscreteArray(self._environment._game.num_distinct_actions()) + return specs.DiscreteArray(self._environment.game.num_distinct_actions()) def reward_spec(self) -> types.NestedSpec: return specs.BoundedArray((), @@ -138,7 +138,7 @@ def environment(self) -> rl_environment.Environment: @property def current_player(self) -> int: - return self._environment._state.current_player() + return self._environment.get_state.current_player() def __getattr__(self, name: str): """Expose any other attributes of the underlying environment.""" From 429b1ec2039728cb6cee97b4e09524526f54c9a5 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 13 Jan 2021 22:30:59 +0000 Subject: [PATCH 21/39] Remove unused imports. 
--- acme/environment_loops/open_spiel_environment_loop.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/acme/environment_loops/open_spiel_environment_loop.py b/acme/environment_loops/open_spiel_environment_loop.py index 81d480440f..16e1c49b27 100644 --- a/acme/environment_loops/open_spiel_environment_loop.py +++ b/acme/environment_loops/open_spiel_environment_loop.py @@ -21,7 +21,6 @@ # pytype: disable=import-error from acme import core -from acme.tf import utils as tf2_utils from acme.utils import counting from acme.utils import loggers from acme.wrappers import open_spiel_wrapper @@ -29,7 +28,6 @@ from dm_env import specs import numpy as np import pyspiel -import tensorflow as tf import tree From 2592165405f66fc3fc7d561185fb36b60e93db8e Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 13 Jan 2021 22:31:52 +0000 Subject: [PATCH 22/39] Use consistent quote delimiter formatting. --- acme/environment_loops/open_spiel_environment_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/acme/environment_loops/open_spiel_environment_loop.py b/acme/environment_loops/open_spiel_environment_loop.py index 16e1c49b27..2e6d0d1c83 100644 --- a/acme/environment_loops/open_spiel_environment_loop.py +++ b/acme/environment_loops/open_spiel_environment_loop.py @@ -152,7 +152,7 @@ def run_episode(self) -> loggers.LoggingData: ] else: # TODO Support simultaneous move games. - raise ValueError("Currently only supports sequential games.") + raise ValueError('Currently only supports sequential games.') timestep = self._environment.step(action_list) From 4de0f90769373471bff5bc5af6276b1bbeca2bc3 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 13 Jan 2021 22:40:18 +0000 Subject: [PATCH 23/39] Access OpenSpiel env properties instead of private attributes. 
--- acme/wrappers/open_spiel_wrapper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/acme/wrappers/open_spiel_wrapper.py b/acme/wrappers/open_spiel_wrapper.py index 06b6541023..7cc61a469d 100644 --- a/acme/wrappers/open_spiel_wrapper.py +++ b/acme/wrappers/open_spiel_wrapper.py @@ -43,7 +43,7 @@ class OpenSpielWrapper(dm_env.Environment): def __init__(self, environment: rl_environment.Environment): self._environment = environment self._reset_next_step = True - if environment._game.get_type( + if environment.game.get_type( ).dynamics != pyspiel.GameType.Dynamics.SEQUENTIAL: raise ValueError("Currently only supports sequential games.") @@ -125,8 +125,8 @@ def action_spec(self) -> types.NestedSpec: def reward_spec(self) -> types.NestedSpec: return specs.BoundedArray((), np.float32, - minimum=self._game.min_utility(), - maximum=self._game.max_utility()) + minimum=self._environment.game.min_utility(), + maximum=self._environment.game.max_utility()) def discount_spec(self) -> types.NestedSpec: return specs.BoundedArray((), np.float32, minimum=0, maximum=1.0) From 17cb075a630b1617eb45c0ecee4f302d7a63fa57 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 13 Jan 2021 22:41:32 +0000 Subject: [PATCH 24/39] Access OpenSpiel env properties instead of private attributes. --- acme/environment_loops/open_spiel_environment_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/acme/environment_loops/open_spiel_environment_loop.py b/acme/environment_loops/open_spiel_environment_loop.py index 2e6d0d1c83..fb50f866d1 100644 --- a/acme/environment_loops/open_spiel_environment_loop.py +++ b/acme/environment_loops/open_spiel_environment_loop.py @@ -131,10 +131,10 @@ def run_episode(self) -> loggers.LoggingData: # For evaluation, this keeps track of the total undiscounted reward # for each player accumulated during the episode. 
multiplayer_reward_spec = specs.BoundedArray( - (self._environment._game.num_players(),), + (self._environment.game.num_players(),), np.float32, - minimum=self._environment._game.min_utility(), - maximum=self._environment._game.max_utility()) + minimum=self._environment.game.min_utility(), + maximum=self._environment.game.max_utility()) episode_return = tree.map_structure(_generate_zeros_from_spec, multiplayer_reward_spec) From d0811ef1e08f90aceb1ab4ea8b9da41ad64b3cc6 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 13 Jan 2021 22:51:51 +0000 Subject: [PATCH 25/39] Minor cosmetic changes. --- acme/environment_loops/open_spiel_environment_loop_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/acme/environment_loops/open_spiel_environment_loop_test.py b/acme/environment_loops/open_spiel_environment_loop_test.py index fd8fa58256..092e7316ff 100644 --- a/acme/environment_loops/open_spiel_environment_loop_test.py +++ b/acme/environment_loops/open_spiel_environment_loop_test.py @@ -21,7 +21,6 @@ import acme from acme import core -from acme.testing import fakes from acme import specs from acme import types from acme import wrappers @@ -39,6 +38,7 @@ from acme.environment_loops import open_spiel_environment_loop from acme.wrappers import open_spiel_wrapper from open_spiel.python import rl_environment + # pytype: disable=import-error class RandomActor(core.Actor): """Fake actor which generates random actions and validates specs.""" @@ -85,7 +85,7 @@ def test_loop_run(self): environment_spec = acme.make_environment_spec(env) actors = [] - for i in range(env.num_players): + for _ in range(env.num_players): actors.append(RandomActor(environment_spec)) loop = open_spiel_environment_loop.OpenSpielEnvironmentLoop(env, actors) From d57c5161377b9d4a871d6d9b091b97dabfe8726e Mon Sep 17 00:00:00 2001 From: John Schultz Date: Wed, 13 Jan 2021 23:25:17 +0000 Subject: [PATCH 26/39] Specify np.array dtypes. 
--- acme/wrappers/open_spiel_wrapper.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/acme/wrappers/open_spiel_wrapper.py b/acme/wrappers/open_spiel_wrapper.py index 7cc61a469d..3b3a0af4dc 100644 --- a/acme/wrappers/open_spiel_wrapper.py +++ b/acme/wrappers/open_spiel_wrapper.py @@ -90,13 +90,15 @@ def _convert_observation( self, open_spiel_timestep: rl_environment.TimeStep) -> types.NestedArray: observations = [] for pid in range(self._environment.num_players): - legals = np.zeros(self._environment.game.num_distinct_actions()) - legals[open_spiel_timestep.observations["legal_actions"][pid]] = 1 + legals = np.zeros(self._environment.game.num_distinct_actions(), + dtype=np.float32) + legals[open_spiel_timestep.observations["legal_actions"][pid]] = 1.0 player_observation = OLT(observation=np.asarray( - open_spiel_timestep.observations["info_state"][pid]), + open_spiel_timestep.observations["info_state"][pid], + dtype=np.float32), legal_actions=legals, - terminal=np.asarray( - [float(open_spiel_timestep.last())])) + terminal=np.asarray([open_spiel_timestep.last()], + dtype=np.float32)) observations.append(player_observation) return observations From b95a28c23c79efb13f277c8aec89e7a007cd4f11 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Thu, 14 Jan 2021 00:02:12 +0000 Subject: [PATCH 27/39] Wrap __init__.py imports in try-except block. 
--- acme/environment_loops/__init__.py | 5 ++++- acme/tf/networks/__init__.py | 7 +++++-- acme/wrappers/__init__.py | 5 ++++- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/acme/environment_loops/__init__.py b/acme/environment_loops/__init__.py index c1fb195e95..b018025c9b 100644 --- a/acme/environment_loops/__init__.py +++ b/acme/environment_loops/__init__.py @@ -14,4 +14,7 @@ """Specialized environment loops.""" -#from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop +try: + from acme.environment_loops.open_spiel_environment_loop import OpenSpielEnvironmentLoop +except ImportError: + pass diff --git a/acme/tf/networks/__init__.py b/acme/tf/networks/__init__.py index dc02d089a6..992cbee845 100644 --- a/acme/tf/networks/__init__.py +++ b/acme/tf/networks/__init__.py @@ -34,8 +34,11 @@ from acme.tf.networks.distributional import UnivariateGaussianMixture from acme.tf.networks.distributions import DiscreteValuedDistribution from acme.tf.networks.duelling import DuellingMLP -#from acme.tf.networks.legal_actions import MaskedSequential -#from acme.tf.networks.legal_actions import EpsilonGreedy +try: + from acme.tf.networks.legal_actions import MaskedSequential + from acme.tf.networks.legal_actions import EpsilonGreedy +except ImportError: + pass from acme.tf.networks.masked_epsilon_greedy import NetworkWithMaskedEpsilonGreedy from acme.tf.networks.multihead import Multihead from acme.tf.networks.multiplexers import CriticMultiplexer diff --git a/acme/wrappers/__init__.py b/acme/wrappers/__init__.py index 4697285b8b..24434eb05a 100644 --- a/acme/wrappers/__init__.py +++ b/acme/wrappers/__init__.py @@ -23,7 +23,10 @@ from acme.wrappers.gym_wrapper import GymAtariAdapter from acme.wrappers.gym_wrapper import GymWrapper from acme.wrappers.observation_action_reward import ObservationActionRewardWrapper -#from acme.wrappers.open_spiel_wrapper import OpenSpielWrapper +try: + from acme.wrappers.open_spiel_wrapper import 
OpenSpielWrapper +except ImportError: + pass from acme.wrappers.single_precision import SinglePrecisionWrapper from acme.wrappers.step_limit import StepLimitWrapper from acme.wrappers.video import MujocoVideoWrapper From 4a3677c6828c366dd34b8823da0abe277881056b Mon Sep 17 00:00:00 2001 From: John Schultz Date: Thu, 14 Jan 2021 00:21:05 +0000 Subject: [PATCH 28/39] Remove unnecessary pylint disable. --- acme/environment_loops/open_spiel_environment_loop_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/acme/environment_loops/open_spiel_environment_loop_test.py b/acme/environment_loops/open_spiel_environment_loop_test.py index 092e7316ff..adecdbd426 100644 --- a/acme/environment_loops/open_spiel_environment_loop_test.py +++ b/acme/environment_loops/open_spiel_environment_loop_test.py @@ -33,7 +33,6 @@ SKIP_OPEN_SPIEL_MESSAGE = 'open_spiel not installed.' try: - # pylint: disable=g-import-not-at-top # pytype: disable=import-error from acme.environment_loops import open_spiel_environment_loop from acme.wrappers import open_spiel_wrapper From 7853707b8d5180d898f0a3a1ea7f43e50c75a16a Mon Sep 17 00:00:00 2001 From: John Schultz Date: Thu, 14 Jan 2021 00:25:45 +0000 Subject: [PATCH 29/39] Narrow down pytype disable import-error and re-enable. 
--- acme/environment_loops/open_spiel_environment_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/acme/environment_loops/open_spiel_environment_loop.py b/acme/environment_loops/open_spiel_environment_loop.py index fb50f866d1..9584be4bab 100644 --- a/acme/environment_loops/open_spiel_environment_loop.py +++ b/acme/environment_loops/open_spiel_environment_loop.py @@ -19,7 +19,6 @@ import time from typing import Optional, Sequence -# pytype: disable=import-error from acme import core from acme.utils import counting from acme.utils import loggers @@ -27,7 +26,9 @@ import dm_env from dm_env import specs import numpy as np +# pytype: disable=import-error import pyspiel +# pytype: enable=import-error import tree From fa986b20be5244dcc72b49bcbda50dbe4adf0cf3 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Thu, 14 Jan 2021 16:09:33 +0000 Subject: [PATCH 30/39] Refactor handling of terminal states. --- acme/tf/networks/legal_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/acme/tf/networks/legal_actions.py b/acme/tf/networks/legal_actions.py index 4486acb945..36c78f7231 100644 --- a/acme/tf/networks/legal_actions.py +++ b/acme/tf/networks/legal_actions.py @@ -60,7 +60,7 @@ def __call__(self, inputs: open_spiel_wrapper.OLT) -> tf.Tensor: # When computing the Q-learning target (r_t + d_t * max q_t) we need to # ensure max q_t = 0 in terminal states. - outputs = outputs * (1 - terminal) + outputs = tf.where(tf.equal(terminal, 1), tf.zeros_like(outputs), outputs) return outputs From 920d28233a03ff642989a9859d9da8be87a6a540 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Thu, 14 Jan 2021 16:13:51 +0000 Subject: [PATCH 31/39] Clarify note regarding use of -np.inf. 
--- acme/tf/networks/legal_actions.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/acme/tf/networks/legal_actions.py b/acme/tf/networks/legal_actions.py index 36c78f7231..87fa4e509b 100644 --- a/acme/tf/networks/legal_actions.py +++ b/acme/tf/networks/legal_actions.py @@ -42,8 +42,9 @@ def __init__(self, super(MaskedSequential, self).__init__(name=name) self._layers = list(layers) if layers is not None else [] self._illegal_action_penalty = -1e9 - # Note: illegal_action_penalty cannot be -np.inf, throws error "Priority - # must not be NaN" + # Note: illegal_action_penalty cannot be -np.inf because trfl's qlearning + # ops utilize a batched_index function that returns NaN whenever -np.inf + # is present among action values. def __call__(self, inputs: open_spiel_wrapper.OLT) -> tf.Tensor: # Extract observation, legal actions, and terminal From c5b0070c1ddf25b6494918bc45a5f979a7b1fbb5 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Thu, 14 Jan 2021 16:30:42 +0000 Subject: [PATCH 32/39] Remove unused import. --- acme/tf/networks/legal_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/acme/tf/networks/legal_actions.py b/acme/tf/networks/legal_actions.py index 87fa4e509b..866488ff8f 100644 --- a/acme/tf/networks/legal_actions.py +++ b/acme/tf/networks/legal_actions.py @@ -17,10 +17,10 @@ from typing import Any, Callable, Iterable, Optional, Text, Union -import acme # pytype: disable=import-error from acme.wrappers import open_spiel_wrapper # pytype: enable=import-error + import numpy as np import sonnet as snt import tensorflow as tf From 6141f1ad13ef39062913d9b7e26b09f326711288 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Thu, 14 Jan 2021 16:31:55 +0000 Subject: [PATCH 33/39] Remove default threshold argument. 
--- acme/tf/networks/legal_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/acme/tf/networks/legal_actions.py b/acme/tf/networks/legal_actions.py index 866488ff8f..95cdd37f39 100644 --- a/acme/tf/networks/legal_actions.py +++ b/acme/tf/networks/legal_actions.py @@ -85,7 +85,7 @@ class EpsilonGreedy(snt.Module): def __init__(self, epsilon: Union[tf.Tensor, float], - threshold: float = -np.inf, + threshold: float, name: Optional[Text] = 'EpsilonGreedy'): super(EpsilonGreedy, self).__init__(name=name) self._epsilon = tf.Variable(epsilon, trainable=False) From f676b8a631956e0d7880f6a05eb984c50e2d7f32 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Thu, 14 Jan 2021 16:38:46 +0000 Subject: [PATCH 34/39] Pylint formatting changes. --- acme/tf/networks/legal_actions.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/acme/tf/networks/legal_actions.py b/acme/tf/networks/legal_actions.py index 95cdd37f39..8fa4be61ed 100644 --- a/acme/tf/networks/legal_actions.py +++ b/acme/tf/networks/legal_actions.py @@ -15,7 +15,7 @@ """Networks used for handling illegal actions.""" -from typing import Any, Callable, Iterable, Optional, Text, Union +from typing import Any, Callable, Iterable, Union # pytype: disable=import-error from acme.wrappers import open_spiel_wrapper @@ -38,8 +38,8 @@ class MaskedSequential(snt.Module): def __init__(self, layers: Iterable[Callable[..., Any]] = None, - name: Optional[Text] = None): - super(MaskedSequential, self).__init__(name=name) + name: str = 'MaskedSequential'): + super().__init__(name=name) self._layers = list(layers) if layers is not None else [] self._illegal_action_penalty = -1e9 # Note: illegal_action_penalty cannot be -np.inf because trfl's qlearning @@ -86,8 +86,8 @@ class EpsilonGreedy(snt.Module): def __init__(self, epsilon: Union[tf.Tensor, float], threshold: float, - name: Optional[Text] = 'EpsilonGreedy'): - super(EpsilonGreedy, self).__init__(name=name) + name: str = 
'EpsilonGreedy'): + super().__init__(name=name) self._epsilon = tf.Variable(epsilon, trainable=False) self._threshold = threshold From ff9b7addd5ee83c333f5c0725672673446f01477 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Thu, 14 Jan 2021 16:48:34 +0000 Subject: [PATCH 35/39] Clarify TODO. --- acme/tf/networks/legal_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/acme/tf/networks/legal_actions.py b/acme/tf/networks/legal_actions.py index 8fa4be61ed..56ca232291 100644 --- a/acme/tf/networks/legal_actions.py +++ b/acme/tf/networks/legal_actions.py @@ -66,7 +66,7 @@ def __call__(self, inputs: open_spiel_wrapper.OLT) -> tf.Tensor: return outputs -# TODO Function to update epsilon +# TODO Add functionality to support decaying epsilon parameter. # TODO This is a modified version of trfl's epsilon_greedy() which incorporates # code from the bug fix described here https://github.com/deepmind/trfl/pull/28 class EpsilonGreedy(snt.Module): From 5caa0aca0d0e968d1b0c93e544473849d8ae0f4b Mon Sep 17 00:00:00 2001 From: John Schultz Date: Thu, 14 Jan 2021 17:16:08 +0000 Subject: [PATCH 36/39] Remove unnecessary pylint disable. --- acme/wrappers/open_spiel_wrapper_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/acme/wrappers/open_spiel_wrapper_test.py b/acme/wrappers/open_spiel_wrapper_test.py index cb6f75e36b..a6cc2e24a6 100644 --- a/acme/wrappers/open_spiel_wrapper_test.py +++ b/acme/wrappers/open_spiel_wrapper_test.py @@ -24,7 +24,6 @@ SKIP_OPEN_SPIEL_MESSAGE = 'open_spiel not installed.' try: - # pylint: disable=g-import-not-at-top # pytype: disable=import-error from acme.wrappers import open_spiel_wrapper from open_spiel.python import rl_environment From b46960d4daa944473ab733e61c3c737b02470617 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Thu, 14 Jan 2021 17:16:08 +0000 Subject: [PATCH 37/39] Remove unnecessary pylint disable. 
--- acme/wrappers/open_spiel_wrapper_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/acme/wrappers/open_spiel_wrapper_test.py b/acme/wrappers/open_spiel_wrapper_test.py index cb6f75e36b..a6cc2e24a6 100644 --- a/acme/wrappers/open_spiel_wrapper_test.py +++ b/acme/wrappers/open_spiel_wrapper_test.py @@ -24,7 +24,6 @@ SKIP_OPEN_SPIEL_MESSAGE = 'open_spiel not installed.' try: - # pylint: disable=g-import-not-at-top # pytype: disable=import-error from acme.wrappers import open_spiel_wrapper from open_spiel.python import rl_environment From 092aabafc150d28c7bdd7824f271217b6462924f Mon Sep 17 00:00:00 2001 From: John Schultz Date: Fri, 15 Jan 2021 14:09:49 +0000 Subject: [PATCH 38/39] Access environment._use_observation attribute via property. --- acme/wrappers/open_spiel_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/acme/wrappers/open_spiel_wrapper.py b/acme/wrappers/open_spiel_wrapper.py index 3b3a0af4dc..760df05bf9 100644 --- a/acme/wrappers/open_spiel_wrapper.py +++ b/acme/wrappers/open_spiel_wrapper.py @@ -105,7 +105,7 @@ def _convert_observation( def observation_spec(self) -> types.NestedSpec: # Observation spec depends on whether the OpenSpiel environment is using # observation/information_state tensors. - if self._environment._use_observation: + if self._environment.use_observation: return OLT(observation=specs.Array( (self._environment.game.observation_tensor_size(),), np.float32), legal_actions=specs.Array( From e00393a7a88e9f3684775bfa16f75d61efc0a417 Mon Sep 17 00:00:00 2001 From: John Schultz Date: Fri, 15 Jan 2021 16:14:18 +0000 Subject: [PATCH 39/39] Fix type hints. 
--- acme/wrappers/open_spiel_wrapper.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/acme/wrappers/open_spiel_wrapper.py b/acme/wrappers/open_spiel_wrapper.py index 760df05bf9..8323c4ded4 100644 --- a/acme/wrappers/open_spiel_wrapper.py +++ b/acme/wrappers/open_spiel_wrapper.py @@ -15,7 +15,7 @@ """Wraps an OpenSpiel RL environment to be used as a dm_env environment.""" -from typing import NamedTuple +from typing import List, NamedTuple from acme import specs from acme import types @@ -87,7 +87,7 @@ def step(self, action: types.NestedArray) -> dm_env.TimeStep: # Convert OpenSpiel observation so it's dm_env compatible. Also, the list # of legal actions must be converted to a legal actions mask. def _convert_observation( - self, open_spiel_timestep: rl_environment.TimeStep) -> types.NestedArray: + self, open_spiel_timestep: rl_environment.TimeStep) -> List[OLT]: observations = [] for pid in range(self._environment.num_players): legals = np.zeros(self._environment.game.num_distinct_actions(), @@ -102,7 +102,7 @@ def _convert_observation( observations.append(player_observation) return observations - def observation_spec(self) -> types.NestedSpec: + def observation_spec(self) -> OLT: # Observation spec depends on whether the OpenSpiel environment is using # observation/information_state tensors. 
if self._environment.use_observation: @@ -121,16 +121,16 @@ def observation_spec(self) -> types.NestedSpec: np.float32), terminal=specs.Array((1,), np.float32)) - def action_spec(self) -> types.NestedSpec: + def action_spec(self) -> specs.DiscreteArray: return specs.DiscreteArray(self._environment.game.num_distinct_actions()) - def reward_spec(self) -> types.NestedSpec: + def reward_spec(self) -> specs.BoundedArray: return specs.BoundedArray((), np.float32, minimum=self._environment.game.min_utility(), maximum=self._environment.game.max_utility()) - def discount_spec(self) -> types.NestedSpec: + def discount_spec(self) -> specs.BoundedArray: return specs.BoundedArray((), np.float32, minimum=0, maximum=1.0) @property