1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tf.dynamic_stitch.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.compiler.tests.xla_test import XLATestCase 24 from tensorflow.python.framework import dtypes 25 from tensorflow.python.ops import array_ops 26 from tensorflow.python.ops import data_flow_ops 27 from tensorflow.python.platform import googletest 28 29 30 class DynamicStitchTest(XLATestCase): 31 32 def _AssertDynamicStitchResultIs(self, indices, data, expected): 33 with self.test_session() as session: 34 index_placeholders = [ 35 array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices 36 ] 37 data_placeholders = [ 38 array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in data 39 ] 40 with self.test_scope(): 41 output = data_flow_ops.dynamic_stitch(index_placeholders, 42 data_placeholders) 43 44 feed_dict = {} 45 for placeholder, value in zip(index_placeholders, indices): 46 feed_dict[placeholder] = value 47 for placeholder, value in zip(data_placeholders, data): 48 feed_dict[placeholder] = value 49 result = session.run(output, feed_dict=feed_dict) 50 self.assertAllClose(expected, result, rtol=1e-3) 51 52 def testSimpleEmpty(self): 53 idx1 = np.array([0, 2], dtype=np.int32) 54 idx2 = np.array([[1], [3]], dtype=np.int32) 55 val1 = np.array([[], []], dtype=np.int32) 56 val2 = np.array([[[]], [[]]], dtype=np.int32) 57 self._AssertDynamicStitchResultIs( 58 [idx1, idx2], [val1, val2], 59 expected=np.array([[], [], [], []], np.int32)) 60 61 def testSimple1D(self): 62 val1 = np.array([0, 4, 7], dtype=np.int32) 63 val2 = np.array([1, 6, 2, 3, 5], dtype=np.int32) 64 val3 = np.array([0, 40, 70], dtype=np.float32) 65 val4 = np.array([10, 60, 20, 30, 50], dtype=np.float32) 66 expected = np.array([0, 10, 20, 30, 40, 50, 60, 70], dtype=np.float32) 67 self._AssertDynamicStitchResultIs( 68 [val1, val2], [val3, val4], expected=expected) 69 70 def testSimple2D(self): 71 val1 = np.array([0, 4, 7], dtype=np.int32) 72 val2 = np.array([1, 6], dtype=np.int32) 73 val3 = np.array([2, 3, 5], dtype=np.int32) 74 val4 = np.array([[0, 1], [40, 41], [70, 71]], dtype=np.float32) 75 val5 = np.array([[10, 11], [60, 61]], dtype=np.float32) 76 val6 = np.array([[20, 21], [30, 31], [50, 51]], dtype=np.float32) 77 expected = np.array( 78 [[0, 1], [10, 11], [20, 21], [30, 31], [40, 41], [50, 51], [60, 61], 79 [70, 71]], 80 dtype=np.float32) 81 self._AssertDynamicStitchResultIs( 82 [val1, val2, val3], [val4, val5, val6], expected=expected) 83 84 85 if __name__ == "__main__": 86 googletest.main() 87