1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tensorflow.ops.data_flow_ops.{,parallel_}dynamic_stitch.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.framework import constant_op 24 from tensorflow.python.framework import dtypes 25 from tensorflow.python.ops import array_ops 26 from tensorflow.python.ops import data_flow_ops 27 from tensorflow.python.ops import gradients_impl 28 import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import 29 from tensorflow.python.platform import test 30 from tensorflow.python.framework import dtypes 31 32 33 class DynamicStitchTestBase(object): 34 35 def __init__(self, stitch_op): 36 self.stitch_op = stitch_op 37 38 def testScalar(self): 39 with self.test_session(use_gpu=True): 40 indices = [constant_op.constant(0), constant_op.constant(1)] 41 data = [constant_op.constant(40), constant_op.constant(60)] 42 for step in -1, 1: 43 stitched_t = self.stitch_op(indices[::step], data) 44 stitched_val = stitched_t.eval() 45 self.assertAllEqual([40, 60][::step], stitched_val) 46 # Dimension 0 is max(flatten(indices))+1. 47 self.assertEqual([2], stitched_t.get_shape().as_list()) 48 49 def testShapeInferenceForScalarWithNonConstantIndices(self): 50 with self.test_session(use_gpu=True): 51 indices = [ 52 array_ops.placeholder(dtype=dtypes.int32), 53 constant_op.constant(1) 54 ] 55 data = [constant_op.constant(40), constant_op.constant(60)] 56 for step in -1, 1: 57 stitched_t = self.stitch_op(indices[::step], data) 58 # Dimension 0 is max(flatten(indices))+1, but the first indices input is 59 # not a constant tensor, so we can only infer it as a vector of unknown 60 # length. 61 self.assertEqual([None], stitched_t.get_shape().as_list()) 62 63 def testSimpleOneDimensional(self): 64 with self.test_session(use_gpu=True): 65 indices = [ 66 constant_op.constant([0, 4, 7]), 67 constant_op.constant([1, 6, 2, 3, 5]) 68 ] 69 data = [ 70 constant_op.constant([0, 40, 70]), 71 constant_op.constant([10, 60, 20, 30, 50]) 72 ] 73 stitched_t = self.stitch_op(indices, data) 74 stitched_val = stitched_t.eval() 75 self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val) 76 # Dimension 0 is max(flatten(indices))+1. 77 self.assertEqual([8], stitched_t.get_shape().as_list()) 78 79 def testOneListOneDimensional(self): 80 with self.test_session(use_gpu=True): 81 indices = [constant_op.constant([1, 6, 2, 3, 5, 0, 4, 7])] 82 data = [constant_op.constant([10, 60, 20, 30, 50, 0, 40, 70])] 83 stitched_t = self.stitch_op(indices, data) 84 stitched_val = stitched_t.eval() 85 self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val) 86 # Dimension 0 is max(flatten(indices))+1. 87 self.assertEqual([8], stitched_t.get_shape().as_list()) 88 89 def testSimpleTwoDimensional(self): 90 with self.test_session(use_gpu=True): 91 indices = [ 92 constant_op.constant([0, 4, 7]), 93 constant_op.constant([1, 6]), 94 constant_op.constant([2, 3, 5]) 95 ] 96 data = [ 97 constant_op.constant([[0, 1], [40, 41], [70, 71]]), 98 constant_op.constant([[10, 11], [60, 61]]), 99 constant_op.constant([[20, 21], [30, 31], [50, 51]]) 100 ] 101 stitched_t = self.stitch_op(indices, data) 102 stitched_val = stitched_t.eval() 103 self.assertAllEqual([[0, 1], [10, 11], [20, 21], [30, 31], [40, 41], 104 [50, 51], [60, 61], [70, 71]], stitched_val) 105 # Dimension 0 is max(flatten(indices))+1. 106 self.assertEqual([8, 2], stitched_t.get_shape().as_list()) 107 108 def testHigherRank(self): 109 with self.test_session(use_gpu=True) as sess: 110 indices = [ 111 constant_op.constant(6), 112 constant_op.constant([4, 1]), 113 constant_op.constant([[5, 2], [0, 3]]) 114 ] 115 data = [ 116 constant_op.constant([61, 62]), 117 constant_op.constant([[41, 42], [11, 12]]), 118 constant_op.constant([[[51, 52], [21, 22]], [[1, 2], [31, 32]]]) 119 ] 120 stitched_t = self.stitch_op(indices, data) 121 stitched_val = stitched_t.eval() 122 correct = 10 * np.arange(7)[:, None] + [1, 2] 123 self.assertAllEqual(correct, stitched_val) 124 self.assertEqual([7, 2], stitched_t.get_shape().as_list()) 125 # Test gradients 126 stitched_grad = 7 * stitched_val 127 grads = gradients_impl.gradients(stitched_t, indices + data, 128 stitched_grad) 129 self.assertEqual(grads[:3], [None] * 3) # Indices have no gradients 130 for datum, grad in zip(data, sess.run(grads[3:])): 131 self.assertAllEqual(7 * datum.eval(), grad) 132 133 def testErrorIndicesMultiDimensional(self): 134 indices = [ 135 constant_op.constant([0, 4, 7]), 136 constant_op.constant([[1, 6, 2, 3, 5]]) 137 ] 138 data = [ 139 constant_op.constant([[0, 40, 70]]), 140 constant_op.constant([10, 60, 20, 30, 50]) 141 ] 142 with self.assertRaises(ValueError): 143 self.stitch_op(indices, data) 144 145 def testErrorDataNumDimsMismatch(self): 146 indices = [ 147 constant_op.constant([0, 4, 7]), 148 constant_op.constant([1, 6, 2, 3, 5]) 149 ] 150 data = [ 151 constant_op.constant([0, 40, 70]), 152 constant_op.constant([[10, 60, 20, 30, 50]]) 153 ] 154 with self.assertRaises(ValueError): 155 self.stitch_op(indices, data) 156 157 def testErrorDataDimSizeMismatch(self): 158 indices = [ 159 constant_op.constant([0, 4, 5]), 160 constant_op.constant([1, 6, 2, 3]) 161 ] 162 data = [ 163 constant_op.constant([[0], [40], [70]]), 164 constant_op.constant([[10, 11], [60, 61], [20, 21], [30, 31]]) 165 ] 166 with self.assertRaises(ValueError): 167 self.stitch_op(indices, data) 168 169 def testErrorDataAndIndicesSizeMismatch(self): 170 indices = [ 171 constant_op.constant([0, 4, 7]), 172 constant_op.constant([1, 6, 2, 3, 5]) 173 ] 174 data = [ 175 constant_op.constant([0, 40, 70]), 176 constant_op.constant([10, 60, 20, 30]) 177 ] 178 with self.assertRaises(ValueError): 179 self.stitch_op(indices, data) 180 181 182 class DynamicStitchTest(DynamicStitchTestBase, test.TestCase): 183 184 def __init__(self, *test_case_args): 185 test.TestCase.__init__(self, *test_case_args) 186 DynamicStitchTestBase.__init__(self, data_flow_ops.dynamic_stitch) 187 188 189 class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase): 190 191 def __init__(self, *test_case_args): 192 test.TestCase.__init__(self, *test_case_args) 193 DynamicStitchTestBase.__init__(self, data_flow_ops.parallel_dynamic_stitch) 194 195 def testScalar(self): 196 with self.test_session(use_gpu=True): 197 indices = [constant_op.constant(0), constant_op.constant(1)] 198 data = [constant_op.constant(40.0), constant_op.constant(60.0)] 199 for step in -1, 1: 200 stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data) 201 stitched_val = stitched_t.eval() 202 self.assertAllEqual([40.0, 60.0][::step], stitched_val) 203 # Dimension 0 is max(flatten(indices))+1. 204 self.assertEqual([2], stitched_t.get_shape().as_list()) 205 206 def testHigherRank(self): 207 with self.test_session(use_gpu=True) as sess: 208 indices = [ 209 constant_op.constant(6), 210 constant_op.constant([4, 1]), 211 constant_op.constant([[5, 2], [0, 3]]) 212 ] 213 data = [ 214 constant_op.constant([61, 62], dtype=dtypes.float32), 215 constant_op.constant([[41, 42], [11, 12]], dtype=dtypes.float32), 216 constant_op.constant( 217 [[[51, 52], [21, 22]], [[1, 2], [31, 32]]], dtype=dtypes.float32) 218 ] 219 stitched_t = data_flow_ops.dynamic_stitch(indices, data) 220 stitched_val = stitched_t.eval() 221 correct = 10 * np.arange(7)[:, None] + [1.0, 2.0] 222 self.assertAllEqual(correct, stitched_val) 223 self.assertEqual([7, 2], stitched_t.get_shape().as_list()) 224 # Test gradients 225 stitched_grad = 7 * stitched_val 226 grads = gradients_impl.gradients(stitched_t, indices + data, 227 stitched_grad) 228 self.assertEqual(grads[:3], [None] * 3) # Indices have no gradients 229 for datum, grad in zip(data, sess.run(grads[3:])): 230 self.assertAllEqual(7.0 * datum.eval(), grad) 231 232 # GPU version unit tests 233 def testScalarGPU(self): 234 with self.test_session(): 235 indices = [constant_op.constant(0), constant_op.constant(1)] 236 data = [constant_op.constant(40.0), constant_op.constant(60.0)] 237 for step in -1, 1: 238 stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data) 239 stitched_val = stitched_t.eval() 240 self.assertAllEqual([40.0, 60.0][::step], stitched_val) 241 # Dimension 0 is max(flatten(indices))+1. 242 self.assertEqual([2], stitched_t.get_shape().as_list()) 243 244 def testHigherRankGPU(self): 245 with self.test_session() as sess: 246 indices = [ 247 constant_op.constant(6), 248 constant_op.constant([4, 1]), 249 constant_op.constant([[5, 2], [0, 3]]) 250 ] 251 data = [ 252 constant_op.constant([61, 62], dtype=dtypes.float32), 253 constant_op.constant([[41, 42], [11, 12]], dtype=dtypes.float32), 254 constant_op.constant( 255 [[[51, 52], [21, 22]], [[1, 2], [31, 32]]], dtype=dtypes.float32) 256 ] 257 stitched_t = data_flow_ops.dynamic_stitch(indices, data) 258 stitched_val = stitched_t.eval() 259 correct = 10 * np.arange(7)[:, None] + [1.0, 2.0] 260 self.assertAllEqual(correct, stitched_val) 261 self.assertEqual([7, 2], stitched_t.get_shape().as_list()) 262 # Test gradients 263 stitched_grad = 7 * stitched_val 264 grads = gradients_impl.gradients(stitched_t, indices + data, 265 stitched_grad) 266 self.assertEqual(grads[:3], [None] * 3) # Indices have no gradients 267 for datum, grad in zip(data, sess.run(grads[3:])): 268 self.assertAllEqual(7.0 * datum.eval(), grad) 269 270 271 if __name__ == "__main__": 272 test.main() 273