1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the experimental input pipeline ops.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import itertools 21 22 from tensorflow.python.data.ops import dataset_ops 23 from tensorflow.python.framework import dtypes 24 from tensorflow.python.framework import errors 25 from tensorflow.python.framework import sparse_tensor 26 from tensorflow.python.ops import array_ops 27 from tensorflow.python.ops import sparse_ops 28 from tensorflow.python.platform import test 29 30 31 class InterleaveDatasetTest(test.TestCase): 32 33 def _interleave(self, lists, cycle_length, block_length): 34 num_open = 0 35 36 # `all_iterators` acts as a queue of iterators over each element of `lists`. 37 all_iterators = [iter(l) for l in lists] 38 39 # `open_iterators` are the iterators whose elements are currently being 40 # interleaved. 41 open_iterators = [] 42 for i in range(cycle_length): 43 if all_iterators: 44 open_iterators.append(all_iterators.pop(0)) 45 num_open += 1 46 else: 47 open_iterators.append(None) 48 49 while num_open or all_iterators: 50 for i in range(cycle_length): 51 if open_iterators[i] is None: 52 if all_iterators: 53 open_iterators[i] = all_iterators.pop(0) 54 num_open += 1 55 else: 56 continue 57 for _ in range(block_length): 58 try: 59 yield next(open_iterators[i]) 60 except StopIteration: 61 open_iterators[i] = None 62 num_open -= 1 63 break 64 65 def testPythonImplementation(self): 66 input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], 67 [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] 68 69 # Cycle length 1 acts like `Dataset.flat_map()`. 70 expected_elements = itertools.chain(*input_lists) 71 for expected, produced in zip( 72 expected_elements, self._interleave(input_lists, 1, 1)): 73 self.assertEqual(expected, produced) 74 75 # Cycle length > 1. 76 expected_elements = [4, 5, 4, 5, 4, 5, 4, 77 5, 5, 6, 6, # NOTE(mrry): When we cycle back 78 # to a list and are already at 79 # the end of that list, we move 80 # on to the next element. 81 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5] 82 for expected, produced in zip( 83 expected_elements, self._interleave(input_lists, 2, 1)): 84 self.assertEqual(expected, produced) 85 86 # Cycle length > 1 and block length > 1. 87 expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 88 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] 89 for expected, produced in zip( 90 expected_elements, self._interleave(input_lists, 2, 3)): 91 self.assertEqual(expected, produced) 92 93 # Cycle length > len(input_values). 94 expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 95 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] 96 for expected, produced in zip( 97 expected_elements, self._interleave(input_lists, 7, 2)): 98 self.assertEqual(expected, produced) 99 100 def testInterleaveDataset(self): 101 input_values = array_ops.placeholder(dtypes.int64, shape=[None]) 102 cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) 103 block_length = array_ops.placeholder(dtypes.int64, shape=[]) 104 105 repeat_count = 2 106 107 dataset = ( 108 dataset_ops.Dataset.from_tensor_slices(input_values) 109 .repeat(repeat_count) 110 .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), 111 cycle_length, block_length)) 112 iterator = dataset.make_initializable_iterator() 113 init_op = iterator.initializer 114 next_element = iterator.get_next() 115 116 with self.test_session() as sess: 117 # Cycle length 1 acts like `Dataset.flat_map()`. 118 sess.run(init_op, feed_dict={input_values: [4, 5, 6], 119 cycle_length: 1, block_length: 3}) 120 121 for expected_element in self._interleave( 122 [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): 123 self.assertEqual(expected_element, sess.run(next_element)) 124 125 # Cycle length > 1. 126 # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 127 # 6, 5, 6, 5, 6, 5, 6, 5] 128 sess.run(init_op, feed_dict={input_values: [4, 5, 6], 129 cycle_length: 2, block_length: 1}) 130 for expected_element in self._interleave( 131 [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): 132 self.assertEqual(expected_element, sess.run(next_element)) 133 with self.assertRaises(errors.OutOfRangeError): 134 sess.run(next_element) 135 136 # Cycle length > 1 and block length > 1. 137 # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, 138 # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] 139 sess.run(init_op, feed_dict={input_values: [4, 5, 6], 140 cycle_length: 2, block_length: 3}) 141 for expected_element in self._interleave( 142 [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): 143 self.assertEqual(expected_element, sess.run(next_element)) 144 with self.assertRaises(errors.OutOfRangeError): 145 sess.run(next_element) 146 147 # Cycle length > len(input_values) * repeat_count. 148 # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 149 # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] 150 sess.run(init_op, feed_dict={input_values: [4, 5, 6], 151 cycle_length: 7, block_length: 2}) 152 for expected_element in self._interleave( 153 [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): 154 self.assertEqual(expected_element, sess.run(next_element)) 155 with self.assertRaises(errors.OutOfRangeError): 156 sess.run(next_element) 157 158 # Empty input. 159 sess.run(init_op, feed_dict={input_values: [], 160 cycle_length: 2, block_length: 3}) 161 with self.assertRaises(errors.OutOfRangeError): 162 sess.run(next_element) 163 164 # Non-empty input leading to empty output. 165 sess.run(init_op, feed_dict={input_values: [0, 0, 0], 166 cycle_length: 2, block_length: 3}) 167 with self.assertRaises(errors.OutOfRangeError): 168 sess.run(next_element) 169 170 # Mixture of non-empty and empty interleaved datasets. 171 sess.run(init_op, feed_dict={input_values: [4, 0, 6], 172 cycle_length: 2, block_length: 3}) 173 for expected_element in self._interleave( 174 [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): 175 self.assertEqual(expected_element, sess.run(next_element)) 176 with self.assertRaises(errors.OutOfRangeError): 177 sess.run(next_element) 178 179 def testSparse(self): 180 181 def _map_fn(i): 182 return sparse_tensor.SparseTensorValue( 183 indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) 184 185 def _interleave_fn(x): 186 return dataset_ops.Dataset.from_tensor_slices( 187 sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) 188 189 iterator = ( 190 dataset_ops.Dataset.range(10).map(_map_fn).interleave( 191 _interleave_fn, cycle_length=1).make_initializable_iterator()) 192 init_op = iterator.initializer 193 get_next = iterator.get_next() 194 195 with self.test_session() as sess: 196 sess.run(init_op) 197 for i in range(10): 198 for j in range(2): 199 expected = [i, 0] if j % 2 == 0 else [0, -i] 200 self.assertAllEqual(expected, sess.run(get_next)) 201 with self.assertRaises(errors.OutOfRangeError): 202 sess.run(get_next) 203 204 def testEmptyInput(self): 205 iterator = ( 206 dataset_ops.Dataset.from_tensor_slices([]) 207 .repeat(None) 208 .interleave(dataset_ops.Dataset.from_tensors, cycle_length=2) 209 .make_initializable_iterator()) 210 init_op = iterator.initializer 211 get_next = iterator.get_next() 212 213 with self.test_session() as sess: 214 sess.run(init_op) 215 with self.assertRaises(errors.OutOfRangeError): 216 sess.run(get_next) 217 218 219 if __name__ == "__main__": 220 test.main() 221