Hi, it seems like the documented return shapes for the following functions might be off:
- retrace_ops.retrace(...)
- retrace_ops.retrace_core(...)
- retrace_ops._general_off_policy_corrected_multistep_target(...)
The first two are documented to return shape [B] and the third shape [T, B, num_actions], while they all appear to return [T, B].
Some test code to check:
import numpy as np
import tensorflow as tf
from trfl import retrace_ops, indexing_ops
### Example input data:
# https://github.com/deepmind/trfl/blob/08ccb293edb929d6002786f1c0c177ef291f2956/trfl/retrace_ops_test.py#L41
lambda_ = 0.9
qs = [
[[2.2, 3.2, 4.2],
[5.2, 6.2, 7.2]],
[[7.2, 6.2, 5.2],
[4.2, 3.2, 2.2]],
[[3.2, 5.2, 7.2],
[4.2, 6.2, 9.2]],
[[2.2, 8.2, 4.2],
[9.2, 1.2, 8.2]]
]
targnet_qs = [
[[2., 3., 4.],
[5., 6., 7.]],
[[7., 6., 5.],
[4., 3., 2.]],
[[3., 5., 7.],
[4., 6., 9.]],
[[2., 8., 4.],
[9., 1., 8.]]
]
actions = [
[2, 0],
[1, 2],
[0, 1],
[2, 0]
]
rewards = [
[1.9, 2.9],
[3.9, 4.9],
[5.9, 6.9],
[np.nan, np.nan] # nan marks entries we should never use.
]
pcontinues = [
[0.8, 0.9],
[0.7, 0.8],
[0.6, 0.5],
[np.nan, np.nan]
]
target_policy_probs = [
[[np.nan] * 3,
[np.nan] * 3],
[[0.41, 0.28, 0.31],
[0.19, 0.77, 0.04]],
[[0.22, 0.44, 0.34],
[0.14, 0.25, 0.61]],
[[0.16, 0.72, 0.12],
[0.33, 0.30, 0.37]]
]
behaviour_policy_probs = [
[np.nan, np.nan],
[0.85, 0.86],
[0.87, 0.88],
[0.89, 0.84]
]
### Retrace Test: ###
retrace = retrace_ops.retrace(
lambda_, qs, targnet_qs, actions, rewards,
pcontinues, target_policy_probs, behaviour_policy_probs)
# qs: shape [(T+1), B, num_actions]
# https://github.com/deepmind/trfl/blob/08ccb293edb929d6002786f1c0c177ef291f2956/trfl/retrace_ops.py#L85
T = len(qs) - 1 # sequence length
B = len(qs[0]) # batch dimension
N = len(qs[0][0]) # number of actions
# loss: documented shape [B]
# https://github.com/deepmind/trfl/blob/08ccb293edb929d6002786f1c0c177ef291f2956/trfl/retrace_ops.py#L121
tf.debugging.assert_equal(retrace.loss.shape, [T, B]) # succeeds
### Multi-step target Test: ###
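# Build the sliced inputs the same way retrace() does internally before it
# calls the low-level target computation (see retrace_ops.py linked above).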
timesteps = tf.shape(qs)[0] # Batch size is qs_shape[1].
timestep_indices_tm1 = tf.range(0, timesteps - 1)
timestep_indices_t = tf.range(1, timesteps)
target_policy_t = tf.gather(target_policy_probs, timestep_indices_t)
behaviour_policy_t = tf.gather(behaviour_policy_probs, timestep_indices_t)
a_t = tf.gather(actions, timestep_indices_t)
r_t = tf.gather(rewards, timestep_indices_tm1)
pcont_t = tf.gather(pcontinues, timestep_indices_tm1)
targnet_q_t = tf.gather(targnet_qs, timestep_indices_t)
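# c_t are the truncated importance weights min(1, pi/mu) for the taken
# actions, scaled by lambda, matching what retrace_core computes.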
c_t = retrace_ops._retrace_weights(
indexing_ops.batched_index(target_policy_t, a_t),
behaviour_policy_t) * lambda_
target = retrace_ops._general_off_policy_corrected_multistep_target(
r_t, pcont_t, target_policy_t, c_t, targnet_q_t, a_t
)
# target: documented shape [T, B, N]
# https://github.com/deepmind/trfl/blob/08ccb293edb929d6002786f1c0c177ef291f2956/trfl/retrace_ops.py#L241
tf.debugging.assert_equal(target.shape, [T, B]) # succeeds
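### Retrace-core Test: ###
# retrace_core is also documented to return a loss of shape [B]; a minimal
# extra check (not part of the original report), reusing the slices above
# and assuming the argument order of retrace_ops.retrace_core.
q_tm1 = tf.gather(qs, timestep_indices_tm1)
a_tm1 = tf.gather(actions, timestep_indices_tm1)
core = retrace_ops.retrace_core(
    lambda_, q_tm1, a_tm1, r_t, pcont_t,
    target_policy_t, behaviour_policy_t, targnet_q_t, a_t)
# Per the claim above, this also comes back as [T, B], not the documented [B].
tf.debugging.assert_equal(core.loss.shape, [T, B])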