Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for input_pipeline_ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.contrib.input_pipeline.python.ops import input_pipeline_ops
     21 from tensorflow.python.framework import constant_op
     22 from tensorflow.python.framework import dtypes
     23 from tensorflow.python.framework import errors
     24 from tensorflow.python.ops import state_ops
     25 from tensorflow.python.ops import variables
     26 from tensorflow.python.platform import test
     27 
     28 
     29 class InputPipelineOpsTest(test.TestCase):
     30 
     31   def testObtainNext(self):
     32     with self.test_session():
     33       var = state_ops.variable_op([], dtypes.int64)
     34       state_ops.assign(var, -1).op.run()
     35       c = constant_op.constant(["a", "b"])
     36       sample1 = input_pipeline_ops.obtain_next(c, var)
     37       self.assertEqual(b"a", sample1.eval())
     38       self.assertEqual(0, var.eval())
     39       sample2 = input_pipeline_ops.obtain_next(c, var)
     40       self.assertEqual(b"b", sample2.eval())
     41       self.assertEqual(1, var.eval())
     42       sample3 = input_pipeline_ops.obtain_next(c, var)
     43       self.assertEqual(b"a", sample3.eval())
     44       self.assertEqual(0, var.eval())
     45 
     46   def testSeekNext(self):
     47     string_list = ["a", "b", "c"]
     48     with self.test_session() as session:
     49       elem = input_pipeline_ops.seek_next(string_list)
     50       session.run([variables.global_variables_initializer()])
     51       self.assertEqual(b"a", session.run(elem))
     52       self.assertEqual(b"b", session.run(elem))
     53       self.assertEqual(b"c", session.run(elem))
     54       # Make sure we loop.
     55       self.assertEqual(b"a", session.run(elem))
     56 
     57   # Helper method that runs the op len(expected_list) number of times, asserts
     58   # that the results are elements of the expected_list and then throws an
     59   # OutOfRangeError.
     60   def _assert_output(self, expected_list, session, op):
     61     for element in expected_list:
     62       self.assertEqual(element, session.run(op))
     63     with self.assertRaises(errors.OutOfRangeError):
     64       session.run(op)
     65 
     66   def testSeekNextLimitEpochs(self):
     67     string_list = ["a", "b", "c"]
     68     with self.test_session() as session:
     69       elem = input_pipeline_ops.seek_next(string_list, num_epochs=1)
     70       session.run([
     71           variables.local_variables_initializer(),
     72           variables.global_variables_initializer()
     73       ])
     74       self._assert_output([b"a", b"b", b"c"], session, elem)
     75 
     76   def testSeekNextLimitEpochsThree(self):
     77     string_list = ["a", "b", "c"]
     78     with self.test_session() as session:
     79       elem = input_pipeline_ops.seek_next(string_list, num_epochs=3)
     80       session.run([
     81           variables.local_variables_initializer(),
     82           variables.global_variables_initializer()
     83       ])
     84       # Expect to see [a, b, c] three times.
     85       self._assert_output([b"a", b"b", b"c"] * 3, session, elem)
     86 
     87 
     88 if __name__ == "__main__":
     89   test.main()
     90