1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the experimental input pipeline ops.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 22 from tensorflow.python.data.ops import dataset_ops 23 from tensorflow.python.framework import dtypes 24 from tensorflow.python.framework import errors 25 from tensorflow.python.ops import array_ops 26 from tensorflow.python.platform import test 27 28 29 class ZipDatasetTest(test.TestCase): 30 31 def testZipDataset(self): 32 component_placeholders = [ 33 array_ops.placeholder(dtypes.int64), 34 array_ops.placeholder(dtypes.int64), 35 array_ops.placeholder(dtypes.float64) 36 ] 37 38 datasets = tuple([ 39 dataset_ops.Dataset.from_tensor_slices(component_placeholder) 40 for component_placeholder in component_placeholders 41 ]) 42 zipped = dataset_ops.Dataset.zip(datasets) 43 44 iterator = zipped.make_initializable_iterator() 45 init_op = iterator.initializer 46 get_next = iterator.get_next() 47 48 with self.test_session() as sess: 49 equal_length_components = [ 50 np.tile(np.array([[1], [2], [3], [4]]), 20), 51 np.tile(np.array([[12], [13], [14], [15]]), 22), 52 np.array([37.0, 38.0, 39.0, 40.0]) 53 ] 54 sess.run(init_op, feed_dict={ph: value for ph, value in zip( 55 component_placeholders, equal_length_components)}) 56 for i in range(4): 57 results = sess.run(get_next) 58 for component, result_component in zip( 59 equal_length_components, results): 60 self.assertAllEqual(component[i], result_component) 61 with self.assertRaises(errors.OutOfRangeError): 62 sess.run(get_next) 63 64 variable_length_components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]] 65 sess.run(init_op, feed_dict={ph: value for ph, value in zip( 66 component_placeholders, variable_length_components)}) 67 for i in range(2): 68 results = sess.run(get_next) 69 for component, result_component in zip( 70 variable_length_components, results): 71 self.assertAllEqual(component[i], result_component) 72 with self.assertRaises(errors.OutOfRangeError): 73 sess.run(get_next) 74 75 def testNestedZipDataset(self): 76 component_placeholders = [ 77 array_ops.placeholder(dtypes.int64, shape=[4, 20]), 78 array_ops.placeholder(dtypes.int64, shape=[4, 22]), 79 array_ops.placeholder(dtypes.float64, shape=[4]) 80 ] 81 82 datasets = [ 83 dataset_ops.Dataset.from_tensor_slices(component_placeholder) 84 for component_placeholder in component_placeholders 85 ] 86 zipped = dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2]))) 87 88 iterator = zipped.make_initializable_iterator() 89 init_op = iterator.initializer 90 get_next = iterator.get_next() 91 92 self.assertEqual([20], get_next[0].shape) 93 self.assertEqual([22], get_next[1][0].shape) 94 self.assertEqual([], get_next[1][1].shape) 95 96 with self.test_session() as sess: 97 equal_length_components = [ 98 np.tile(np.array([[1], [2], [3], [4]]), 20), 99 np.tile(np.array([[12], [13], [14], [15]]), 22), 100 np.array([37.0, 38.0, 39.0, 40.0]) 101 ] 102 sess.run(init_op, feed_dict={ph: value for ph, value in zip( 103 component_placeholders, equal_length_components)}) 104 for i in range(4): 105 result1, (result2, result3) = sess.run(get_next) 106 self.assertAllEqual(equal_length_components[0][i], result1) 107 self.assertAllEqual(equal_length_components[1][i], result2) 108 self.assertAllEqual(equal_length_components[2][i], result3) 109 with self.assertRaises(errors.OutOfRangeError): 110 sess.run(get_next) 111 112 113 if __name__ == "__main__": 114 test.main() 115