Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the experimental input pipeline ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.python.data.ops import dataset_ops
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.framework import errors
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.platform import test
     27 
     28 
     29 class ZipDatasetTest(test.TestCase):
     30 
     31   def testZipDataset(self):
     32     component_placeholders = [
     33         array_ops.placeholder(dtypes.int64),
     34         array_ops.placeholder(dtypes.int64),
     35         array_ops.placeholder(dtypes.float64)
     36     ]
     37 
     38     datasets = tuple([
     39         dataset_ops.Dataset.from_tensor_slices(component_placeholder)
     40         for component_placeholder in component_placeholders
     41     ])
     42     zipped = dataset_ops.Dataset.zip(datasets)
     43 
     44     iterator = zipped.make_initializable_iterator()
     45     init_op = iterator.initializer
     46     get_next = iterator.get_next()
     47 
     48     with self.test_session() as sess:
     49       equal_length_components = [
     50           np.tile(np.array([[1], [2], [3], [4]]), 20),
     51           np.tile(np.array([[12], [13], [14], [15]]), 22),
     52           np.array([37.0, 38.0, 39.0, 40.0])
     53       ]
     54       sess.run(init_op, feed_dict={ph: value for ph, value in zip(
     55           component_placeholders, equal_length_components)})
     56       for i in range(4):
     57         results = sess.run(get_next)
     58         for component, result_component in zip(
     59             equal_length_components, results):
     60           self.assertAllEqual(component[i], result_component)
     61       with self.assertRaises(errors.OutOfRangeError):
     62         sess.run(get_next)
     63 
     64       variable_length_components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]]
     65       sess.run(init_op, feed_dict={ph: value for ph, value in zip(
     66           component_placeholders, variable_length_components)})
     67       for i in range(2):
     68         results = sess.run(get_next)
     69         for component, result_component in zip(
     70             variable_length_components, results):
     71           self.assertAllEqual(component[i], result_component)
     72       with self.assertRaises(errors.OutOfRangeError):
     73         sess.run(get_next)
     74 
     75   def testNestedZipDataset(self):
     76     component_placeholders = [
     77         array_ops.placeholder(dtypes.int64, shape=[4, 20]),
     78         array_ops.placeholder(dtypes.int64, shape=[4, 22]),
     79         array_ops.placeholder(dtypes.float64, shape=[4])
     80     ]
     81 
     82     datasets = [
     83         dataset_ops.Dataset.from_tensor_slices(component_placeholder)
     84         for component_placeholder in component_placeholders
     85     ]
     86     zipped = dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
     87 
     88     iterator = zipped.make_initializable_iterator()
     89     init_op = iterator.initializer
     90     get_next = iterator.get_next()
     91 
     92     self.assertEqual([20], get_next[0].shape)
     93     self.assertEqual([22], get_next[1][0].shape)
     94     self.assertEqual([], get_next[1][1].shape)
     95 
     96     with self.test_session() as sess:
     97       equal_length_components = [
     98           np.tile(np.array([[1], [2], [3], [4]]), 20),
     99           np.tile(np.array([[12], [13], [14], [15]]), 22),
    100           np.array([37.0, 38.0, 39.0, 40.0])
    101       ]
    102       sess.run(init_op, feed_dict={ph: value for ph, value in zip(
    103           component_placeholders, equal_length_components)})
    104       for i in range(4):
    105         result1, (result2, result3) = sess.run(get_next)
    106         self.assertAllEqual(equal_length_components[0][i], result1)
    107         self.assertAllEqual(equal_length_components[1][i], result2)
    108         self.assertAllEqual(equal_length_components[2][i], result3)
    109       with self.assertRaises(errors.OutOfRangeError):
    110         sess.run(get_next)
    111 
    112 
    113 if __name__ == "__main__":
    114   test.main()
    115