1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the experimental input pipeline ops.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 22 from tensorflow.python.data.ops import dataset_ops 23 from tensorflow.python.data.util import nest 24 from tensorflow.python.framework import errors 25 from tensorflow.python.framework import tensor_shape 26 from tensorflow.python.platform import test 27 28 29 class ConcatenateDatasetTest(test.TestCase): 30 31 def testConcatenateDataset(self): 32 input_components = ( 33 np.tile(np.array([[1], [2], [3], [4]]), 20), 34 np.tile(np.array([[12], [13], [14], [15]]), 15), 35 np.array([37.0, 38.0, 39.0, 40.0])) 36 to_concatenate_components = ( 37 np.tile(np.array([[1], [2], [3], [4], [5]]), 20), 38 np.tile(np.array([[12], [13], [14], [15], [16]]), 15), 39 np.array([37.0, 38.0, 39.0, 40.0, 41.0])) 40 41 input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) 42 dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( 43 to_concatenate_components) 44 concatenated = input_dataset.concatenate(dataset_to_concatenate) 45 self.assertEqual(concatenated.output_shapes, (tensor_shape.TensorShape( 46 [20]), tensor_shape.TensorShape([15]), tensor_shape.TensorShape([]))) 47 48 iterator = concatenated.make_initializable_iterator() 49 init_op = iterator.initializer 50 get_next = iterator.get_next() 51 52 with self.test_session() as sess: 53 sess.run(init_op) 54 for i in range(9): 55 result = sess.run(get_next) 56 if i < 4: 57 for component, result_component in zip(input_components, result): 58 self.assertAllEqual(component[i], result_component) 59 else: 60 for component, result_component in zip(to_concatenate_components, 61 result): 62 self.assertAllEqual(component[i - 4], result_component) 63 with self.assertRaises(errors.OutOfRangeError): 64 sess.run(get_next) 65 66 def testConcatenateDatasetDifferentShape(self): 67 input_components = ( 68 np.tile(np.array([[1], [2], [3], [4]]), 20), 69 np.tile(np.array([[12], [13], [14], [15]]), 4)) 70 to_concatenate_components = ( 71 np.tile(np.array([[1], [2], [3], [4], [5]]), 20), 72 np.tile(np.array([[12], [13], [14], [15], [16]]), 15)) 73 74 input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) 75 dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( 76 to_concatenate_components) 77 concatenated = input_dataset.concatenate(dataset_to_concatenate) 78 self.assertEqual( 79 [ts.as_list() 80 for ts in nest.flatten(concatenated.output_shapes)], [[20], [None]]) 81 82 iterator = concatenated.make_initializable_iterator() 83 init_op = iterator.initializer 84 get_next = iterator.get_next() 85 86 with self.test_session() as sess: 87 sess.run(init_op) 88 for i in range(9): 89 result = sess.run(get_next) 90 if i < 4: 91 for component, result_component in zip(input_components, result): 92 self.assertAllEqual(component[i], result_component) 93 else: 94 for component, result_component in zip(to_concatenate_components, 95 result): 96 self.assertAllEqual(component[i - 4], result_component) 97 with self.assertRaises(errors.OutOfRangeError): 98 sess.run(get_next) 99 100 def testConcatenateDatasetDifferentStructure(self): 101 input_components = ( 102 np.tile(np.array([[1], [2], [3], [4]]), 5), 103 np.tile(np.array([[12], [13], [14], [15]]), 4)) 104 to_concatenate_components = ( 105 np.tile(np.array([[1], [2], [3], [4], [5]]), 20), 106 np.tile(np.array([[12], [13], [14], [15], [16]]), 15), 107 np.array([37.0, 38.0, 39.0, 40.0, 41.0])) 108 109 input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) 110 dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( 111 to_concatenate_components) 112 113 with self.assertRaisesRegexp(ValueError, 114 "don't have the same number of elements"): 115 input_dataset.concatenate(dataset_to_concatenate) 116 117 def testConcatenateDatasetDifferentType(self): 118 input_components = ( 119 np.tile(np.array([[1], [2], [3], [4]]), 5), 120 np.tile(np.array([[12], [13], [14], [15]]), 4)) 121 to_concatenate_components = ( 122 np.tile(np.array([[1.0], [2.0], [3.0], [4.0]]), 5), 123 np.tile(np.array([[12], [13], [14], [15]]), 15)) 124 125 input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) 126 dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( 127 to_concatenate_components) 128 129 with self.assertRaisesRegexp(TypeError, "have different types"): 130 input_dataset.concatenate(dataset_to_concatenate) 131 132 133 if __name__ == "__main__": 134 test.main() 135