1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the experimental input pipeline ops.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 22 from tensorflow.python.data.ops import dataset_ops 23 from tensorflow.python.framework import dtypes 24 from tensorflow.python.framework import errors 25 from tensorflow.python.ops import array_ops 26 from tensorflow.python.platform import test 27 28 29 class SequenceDatasetTest(test.TestCase): 30 31 def testRepeatTensorDataset(self): 32 """Test a dataset that repeats its input multiple times.""" 33 components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) 34 # This placeholder can be fed when dataset-definition subgraph 35 # runs (i.e. `init_op` below) to configure the number of 36 # repetitions used in a particular iterator. 37 count_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) 38 39 iterator = (dataset_ops.Dataset.from_tensors(components) 40 .repeat(count_placeholder).make_initializable_iterator()) 41 init_op = iterator.initializer 42 get_next = iterator.get_next() 43 44 self.assertEqual([c.shape for c in components], 45 [t.shape for t in get_next]) 46 47 with self.test_session() as sess: 48 # Test a finite repetition. 49 sess.run(init_op, feed_dict={count_placeholder: 3}) 50 for _ in range(3): 51 results = sess.run(get_next) 52 for component, result_component in zip(components, results): 53 self.assertAllEqual(component, result_component) 54 55 with self.assertRaises(errors.OutOfRangeError): 56 sess.run(get_next) 57 58 # Test a different finite repetition. 59 sess.run(init_op, feed_dict={count_placeholder: 7}) 60 for _ in range(7): 61 results = sess.run(get_next) 62 for component, result_component in zip(components, results): 63 self.assertAllEqual(component, result_component) 64 with self.assertRaises(errors.OutOfRangeError): 65 sess.run(get_next) 66 67 # Test an empty repetition. 68 sess.run(init_op, feed_dict={count_placeholder: 0}) 69 with self.assertRaises(errors.OutOfRangeError): 70 sess.run(get_next) 71 72 # Test an infinite repetition. 73 # NOTE(mrry): There's not a good way to test that the sequence 74 # actually is infinite. 75 sess.run(init_op, feed_dict={count_placeholder: -1}) 76 for _ in range(17): 77 results = sess.run(get_next) 78 for component, result_component in zip(components, results): 79 self.assertAllEqual(component, result_component) 80 81 def testTakeTensorDataset(self): 82 components = (np.arange(10),) 83 count_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) 84 85 iterator = (dataset_ops.Dataset.from_tensor_slices(components) 86 .take(count_placeholder).make_initializable_iterator()) 87 init_op = iterator.initializer 88 get_next = iterator.get_next() 89 90 self.assertEqual([c.shape[1:] for c in components], 91 [t.shape for t in get_next]) 92 93 with self.test_session() as sess: 94 # Take fewer than input size 95 sess.run(init_op, feed_dict={count_placeholder: 4}) 96 for i in range(4): 97 results = sess.run(get_next) 98 self.assertAllEqual(results, components[0][i:i+1]) 99 100 with self.assertRaises(errors.OutOfRangeError): 101 sess.run(get_next) 102 103 # Take more than input size 104 sess.run(init_op, feed_dict={count_placeholder: 25}) 105 for i in range(10): 106 results = sess.run(get_next) 107 self.assertAllEqual(results, components[0][i:i+1]) 108 109 with self.assertRaises(errors.OutOfRangeError): 110 sess.run(get_next) 111 112 # Take all of input 113 sess.run(init_op, feed_dict={count_placeholder: -1}) 114 for i in range(10): 115 results = sess.run(get_next) 116 self.assertAllEqual(results, components[0][i:i+1]) 117 118 with self.assertRaises(errors.OutOfRangeError): 119 sess.run(get_next) 120 121 # Take nothing 122 sess.run(init_op, feed_dict={count_placeholder: 0}) 123 124 with self.assertRaises(errors.OutOfRangeError): 125 sess.run(get_next) 126 127 def testSkipTensorDataset(self): 128 components = (np.arange(10),) 129 count_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) 130 131 iterator = (dataset_ops.Dataset.from_tensor_slices(components) 132 .skip(count_placeholder).make_initializable_iterator()) 133 init_op = iterator.initializer 134 get_next = iterator.get_next() 135 136 self.assertEqual([c.shape[1:] for c in components], 137 [t.shape for t in get_next]) 138 139 with self.test_session() as sess: 140 # Skip fewer than input size, we should skip 141 # the first 4 elements and then read the rest. 142 sess.run(init_op, feed_dict={count_placeholder: 4}) 143 for i in range(4, 10): 144 results = sess.run(get_next) 145 self.assertAllEqual(results, components[0][i:i+1]) 146 with self.assertRaises(errors.OutOfRangeError): 147 sess.run(get_next) 148 149 # Skip more than input size: get nothing. 150 sess.run(init_op, feed_dict={count_placeholder: 25}) 151 with self.assertRaises(errors.OutOfRangeError): 152 sess.run(get_next) 153 154 # Skip exactly input size. 155 sess.run(init_op, feed_dict={count_placeholder: 10}) 156 with self.assertRaises(errors.OutOfRangeError): 157 sess.run(get_next) 158 159 # Set -1 for 'count': skip the entire dataset. 160 sess.run(init_op, feed_dict={count_placeholder: -1}) 161 with self.assertRaises(errors.OutOfRangeError): 162 sess.run(get_next) 163 164 # Skip nothing 165 sess.run(init_op, feed_dict={count_placeholder: 0}) 166 for i in range(0, 10): 167 results = sess.run(get_next) 168 self.assertAllEqual(results, components[0][i:i+1]) 169 with self.assertRaises(errors.OutOfRangeError): 170 sess.run(get_next) 171 172 def testRepeatRepeatTensorDataset(self): 173 """Test the composition of repeat datasets.""" 174 components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) 175 inner_count = array_ops.placeholder(dtypes.int64, shape=[]) 176 outer_count = array_ops.placeholder(dtypes.int64, shape=[]) 177 178 iterator = (dataset_ops.Dataset.from_tensors(components).repeat(inner_count) 179 .repeat(outer_count).make_initializable_iterator()) 180 init_op = iterator.initializer 181 get_next = iterator.get_next() 182 183 self.assertEqual([c.shape for c in components], 184 [t.shape for t in get_next]) 185 186 with self.test_session() as sess: 187 sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14}) 188 for _ in range(7 * 14): 189 results = sess.run(get_next) 190 for component, result_component in zip(components, results): 191 self.assertAllEqual(component, result_component) 192 with self.assertRaises(errors.OutOfRangeError): 193 sess.run(get_next) 194 195 def testRepeatEmptyDataset(self): 196 """Test that repeating an empty dataset does not hang.""" 197 iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10).skip(10) 198 .repeat(-1).make_initializable_iterator()) 199 init_op = iterator.initializer 200 get_next = iterator.get_next() 201 202 with self.test_session() as sess: 203 sess.run(init_op) 204 with self.assertRaises(errors.OutOfRangeError): 205 sess.run(get_next) 206 207 208 if __name__ == "__main__": 209 test.main() 210