Home | History | Annotate | Download | only in tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tf.dynamic_stitch."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.compiler.tests.xla_test import XLATestCase
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.ops import data_flow_ops
     27 from tensorflow.python.platform import googletest
     28 
     29 
     30 class DynamicStitchTest(XLATestCase):
     31 
     32   def _AssertDynamicStitchResultIs(self, indices, data, expected):
     33     with self.test_session() as session:
     34       index_placeholders = [
     35           array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices
     36       ]
     37       data_placeholders = [
     38           array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in data
     39       ]
     40       with self.test_scope():
     41         output = data_flow_ops.dynamic_stitch(index_placeholders,
     42                                               data_placeholders)
     43 
     44       feed_dict = {}
     45       for placeholder, value in zip(index_placeholders, indices):
     46         feed_dict[placeholder] = value
     47       for placeholder, value in zip(data_placeholders, data):
     48         feed_dict[placeholder] = value
     49       result = session.run(output, feed_dict=feed_dict)
     50       self.assertAllClose(expected, result, rtol=1e-3)
     51 
     52   def testSimpleEmpty(self):
     53     idx1 = np.array([0, 2], dtype=np.int32)
     54     idx2 = np.array([[1], [3]], dtype=np.int32)
     55     val1 = np.array([[], []], dtype=np.int32)
     56     val2 = np.array([[[]], [[]]], dtype=np.int32)
     57     self._AssertDynamicStitchResultIs(
     58         [idx1, idx2], [val1, val2],
     59         expected=np.array([[], [], [], []], np.int32))
     60 
     61   def testSimple1D(self):
     62     val1 = np.array([0, 4, 7], dtype=np.int32)
     63     val2 = np.array([1, 6, 2, 3, 5], dtype=np.int32)
     64     val3 = np.array([0, 40, 70], dtype=np.float32)
     65     val4 = np.array([10, 60, 20, 30, 50], dtype=np.float32)
     66     expected = np.array([0, 10, 20, 30, 40, 50, 60, 70], dtype=np.float32)
     67     self._AssertDynamicStitchResultIs(
     68         [val1, val2], [val3, val4], expected=expected)
     69 
     70   def testSimple2D(self):
     71     val1 = np.array([0, 4, 7], dtype=np.int32)
     72     val2 = np.array([1, 6], dtype=np.int32)
     73     val3 = np.array([2, 3, 5], dtype=np.int32)
     74     val4 = np.array([[0, 1], [40, 41], [70, 71]], dtype=np.float32)
     75     val5 = np.array([[10, 11], [60, 61]], dtype=np.float32)
     76     val6 = np.array([[20, 21], [30, 31], [50, 51]], dtype=np.float32)
     77     expected = np.array(
     78         [[0, 1], [10, 11], [20, 21], [30, 31], [40, 41], [50, 51], [60, 61],
     79          [70, 71]],
     80         dtype=np.float32)
     81     self._AssertDynamicStitchResultIs(
     82         [val1, val2, val3], [val4, val5, val6], expected=expected)
     83 
     84 
if __name__ == "__main__":
  # Script entry point: delegate to googletest.main() to run this module's tests.
  googletest.main()
     87