From 07906642b8ca7de06c39a30901dcc585a39cf96e Mon Sep 17 00:00:00 2001
From: John Schultz
Date: Thu, 5 Nov 2020 23:14:10 -0500
Subject: [PATCH] Fix legal_actions_mask bug in epsilon_greedy().

---
Note: in the masked branch the greedy comparison uses `masked_action_values`
(illegal entries replaced by -inf) against the legal maximum. Comparing
`action_values * legal_actions_mask` instead maps every illegal action to 0,
so illegal actions tie with the greedy action whenever the best legal value is
exactly 0. testLegalActionsMaskZeroMax covers that case.

 trfl/policy_ops.py      | 24 ++++++++++++++++++------
 trfl/policy_ops_test.py | 25 +++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/trfl/policy_ops.py b/trfl/policy_ops.py
index 2cfcba6..97bdde3 100644
--- a/trfl/policy_ops.py
+++ b/trfl/policy_ops.py
@@ -20,6 +20,7 @@
 
 
 # Dependency imports
+import numpy as np
 import tensorflow.compat.v1 as tf
 import tensorflow_probability as tfp
 
@@ -51,7 +52,8 @@ def epsilon_greedy(action_values, epsilon, legal_actions_mask=None):
   Returns:
     policy: tfp.distributions.Categorical distribution representing the policy.
   """
-  with tf.name_scope("epsilon_greedy", values=[action_values, epsilon]):
+  with tf.name_scope("epsilon_greedy",
+                     values=[action_values, epsilon, legal_actions_mask]):
 
     # Convert inputs to Tensors if they aren't already.
     action_values = tf.convert_to_tensor(action_values)
@@ -60,17 +62,27 @@ def epsilon_greedy(action_values, epsilon, legal_actions_mask=None):
     # We compute the action space dynamically.
     num_actions = tf.cast(tf.shape(action_values)[-1], action_values.dtype)
 
-    # Dithering action distribution.
     if legal_actions_mask is None:
+      # Dithering action distribution.
       dither_probs = 1 / num_actions * tf.ones_like(action_values)
+      # Greedy action distribution, breaking ties uniformly at random.
+      max_value = tf.reduce_max(action_values, axis=-1, keepdims=True)
+      greedy_probs = tf.cast(tf.equal(action_values, max_value),
+                             action_values.dtype)
     else:
+      legal_actions_mask = tf.convert_to_tensor(legal_actions_mask)
+      # Dithering action distribution.
       dither_probs = 1 / tf.reduce_sum(
           legal_actions_mask, axis=-1, keepdims=True) * legal_actions_mask
+      masked_action_values = tf.where(tf.equal(legal_actions_mask, 1),
+                                      action_values,
+                                      tf.fill(tf.shape(action_values), -np.inf))
+      # Greedy action distribution, breaking ties uniformly at random.
+      max_value = tf.reduce_max(masked_action_values, axis=-1, keepdims=True)
+      greedy_probs = tf.cast(
+          tf.equal(masked_action_values, max_value),
+          action_values.dtype)
 
-    # Greedy action distribution, breaking ties uniformly at random.
-    max_value = tf.reduce_max(action_values, axis=-1, keepdims=True)
-    greedy_probs = tf.cast(tf.equal(action_values, max_value),
-                           action_values.dtype)
     greedy_probs /= tf.reduce_sum(greedy_probs, axis=-1, keepdims=True)
 
     # Epsilon-greedy action distribution.
diff --git a/trfl/policy_ops_test.py b/trfl/policy_ops_test.py
index d9dddfa..2a9f2d3 100644
--- a/trfl/policy_ops_test.py
+++ b/trfl/policy_ops_test.py
@@ -105,6 +105,31 @@ def testLegalActionsMask(self):
     with self.test_session() as sess:
       self.assertAllClose(sess.run(result), expected)
 
+  def testLegalActionsMask2(self):
+    action_values = [-0.8, 1., -0.8, -2.0]
+    legal_actions_mask = [0., 0., 1., 1.]
+    epsilon = 0.1
+
+    expected = [0.00, 0.00, 0.95, 0.05]
+
+    result = policy_ops.epsilon_greedy(action_values, epsilon,
+                                       legal_actions_mask).probs
+    with self.test_session() as sess:
+      self.assertAllClose(sess.run(result), expected)
+
+  def testLegalActionsMaskZeroMax(self):
+    # The best legal value is exactly 0; illegal actions must not tie with it.
+    action_values = [1., 0., -1., 2.]
+    legal_actions_mask = [0., 1., 1., 0.]
+    epsilon = 0.1
+
+    expected = [0.00, 0.95, 0.05, 0.00]
+
+    result = policy_ops.epsilon_greedy(action_values, epsilon,
+                                       legal_actions_mask).probs
+    with self.test_session() as sess:
+      self.assertAllClose(sess.run(result), expected)
+
 
 if __name__ == "__main__":
   tf.test.main()