# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test PrefetchDataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class PrefetchDatasetTest(test.TestCase):
  """Tests `Dataset.prefetch()` with a `buffer_size` fed at runtime."""

  def testBufferSize(self):
    """Prefetching yields every element, in order, for several buffer sizes.

    `buffer_size` is a scalar int64 placeholder, so a single graph is reused
    and the iterator is re-initialized once per size. The sizes cover a
    one-element buffer, a buffer smaller than the dataset, and buffers equal
    to and larger than the dataset.
    """
    num_elements = 10
    buffer_size = array_ops.placeholder(dtypes.int64, shape=[])
    iterator = dataset_ops.Dataset.range(num_elements).prefetch(
        buffer_size=buffer_size).make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      for size in (1, 5, num_elements, num_elements * 2):
        # Re-initialize the iterator so each size starts from element 0.
        sess.run(init_op, feed_dict={buffer_size: size})
        for expected in range(num_elements):
          self.assertEqual(expected, sess.run(get_next))
        # All elements have been consumed; the next fetch must fail.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  def testInvalidBufferSize(self):
    """Zero and negative buffer sizes are rejected when the iterator starts.

    Running the initializer with such a size must raise
    `InvalidArgumentError` whose message mentions `buffer_size`.
    """
    buffer_size = array_ops.placeholder(dtypes.int64, shape=[])
    iterator = dataset_ops.Dataset.range(10).prefetch(
        buffer_size=buffer_size).make_initializable_iterator()
    init_op = iterator.initializer

    with self.test_session() as sess:
      for invalid_size in (0, -5):
        # Only the initializer run is expected to fail, so the assertion
        # wraps just that call rather than session creation as well.
        with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                     "buffer_size"):
          sess.run(init_op, feed_dict={buffer_size: invalid_size})


if __name__ == "__main__":
  test.main()