# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for input_pipeline_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.input_pipeline.python.ops import input_pipeline_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class InputPipelineOpsTest(test.TestCase):
  """Exercises obtain_next and seek_next from input_pipeline_ops."""

  def testObtainNext(self):
    """Each obtain_next op advances the counter and returns that element."""
    with self.test_session():
      counter = state_ops.variable_op([], dtypes.int64)
      # Start the counter at -1 so the first step is expected to land on
      # index 0.
      state_ops.assign(counter, -1).op.run()
      elements = constant_op.constant(["a", "b"])
      # (value returned, counter afterwards) for three consecutive steps;
      # the third step checks that the index wraps back to 0.
      steps = [(b"a", 0), (b"b", 1), (b"a", 0)]
      for want_value, want_index in steps:
        # A fresh op is built per step, exactly like separate calls would.
        fetched = input_pipeline_ops.obtain_next(elements, counter)
        self.assertEqual(want_value, fetched.eval())
        self.assertEqual(want_index, counter.eval())

  def testSeekNext(self):
    """Without num_epochs, seek_next cycles through the list indefinitely."""
    string_list = ["a", "b", "c"]
    with self.test_session() as session:
      elem = input_pipeline_ops.seek_next(string_list)
      session.run([variables.global_variables_initializer()])
      # The final b"a" makes sure we loop back to the start of the list.
      for want in (b"a", b"b", b"c", b"a"):
        self.assertEqual(want, session.run(elem))

  def _assert_output(self, expected_list, session, op):
    """Checks that `op` yields `expected_list` in order, then is exhausted.

    Runs `op` once per entry of `expected_list`, asserting each result equals
    the corresponding entry, and then asserts that one more run raises
    errors.OutOfRangeError.

    Args:
      expected_list: values expected from successive runs of `op`, in order.
      session: the session used to run `op`.
      op: the tensor under test.
    """
    for want in expected_list:
      got = session.run(op)
      self.assertEqual(want, got)
    with self.assertRaises(errors.OutOfRangeError):
      session.run(op)

  def _check_seek_next_epochs(self, num_epochs):
    """Checks that seek_next stops after `num_epochs` passes over a, b, c."""
    string_list = ["a", "b", "c"]
    with self.test_session() as session:
      elem = input_pipeline_ops.seek_next(string_list, num_epochs=num_epochs)
      # Both initializers are run: the epoch-limited variant is initialized
      # with local variables as well as global ones.
      session.run([
          variables.local_variables_initializer(),
          variables.global_variables_initializer()
      ])
      self._assert_output([b"a", b"b", b"c"] * num_epochs, session, elem)

  def testSeekNextLimitEpochs(self):
    """A single epoch yields [a, b, c] once and then runs out of range."""
    self._check_seek_next_epochs(1)

  def testSeekNextLimitEpochsThree(self):
    """Three epochs yield [a, b, c] three times and then run out of range."""
    self._check_seek_next_epochs(3)


if __name__ == "__main__":
  test.main()