diff --git a/tangent/tf_extensions.py b/tangent/tf_extensions.py index 6eec9c4..87ce1e1 100644 --- a/tangent/tf_extensions.py +++ b/tangent/tf_extensions.py @@ -57,7 +57,7 @@ def tensor_shapes_match(a, b): non_differentiable.register_non_differentiable_functions( - tf.shape, tf.to_float, tf.equal, tf.constant, + tf.shape, tf.cast, tf.float32, tf.equal, tf.constant, tf.zeros, tf.ones, tf.zeros_like, tf.ones_like, size, shape_as_list, dtype) @@ -238,7 +238,7 @@ def dtfreduce_mean(y, x, axis=None, keep_dims=False): @adjoint(tf.reduce_max) def dtfreduce_max(y, x, axis=None, keep_dims=False): - mask = tf.to_float( + mask = tf.cast( tf.equal( tangent.unreduce(y, tangent.shape_as_list(x), axis, keep_dims), x)) d[x] = tf.multiply( @@ -272,8 +272,8 @@ def dtfdivide(z, x, y): @adjoint(tf.maximum) def dtfmaximum(z, x, y): - d[x] = tf.multiply(d[z], tf.to_float(tf.equal(z, x))) - d[y] = tf.multiply(d[z], tf.to_float(tf.equal(z, y))) + d[x] = tf.multiply(d[z], tf.cast(tf.equal(z, x), tf.float32)) + d[y] = tf.multiply(d[z], tf.cast(tf.equal(z, y), tf.float32)) @adjoint(tf.squared_difference) @@ -385,10 +385,10 @@ def ttfreduce_mean(y, x, axis=None, keep_dims=False): @tangent_(tf.reduce_max) def ttfreduce_max(y, x, axis=None, keep_dims=False): - mask = tf.to_float( + mask = tf.cast( tf.equal( tangent.unreduce( - tf.ones_like(y), tangent.shape_as_list(x), axis, keep_dims), x)) + tf.ones_like(y), tangent.shape_as_list(x), axis, keep_dims), x), tf.float32) d[y] = tf.multiply(d[x], mask) @@ -421,7 +421,7 @@ def ttfdivide(z, x, y): @tangent_(tf.maximum) def ttfmaximum(z, x, y): - d[z] = d[x] * tf.to_float(tf.equal(z, x)) + d[y] * tf.to_float(tf.equal(z, y)) + d[z] = d[x] * tf.cast(tf.equal(z, x), tf.float32) + d[y] * tf.cast(tf.equal(z, y), tf.float32) @tangent_(tf.nn.avg_pool)