Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Test RangeDataset."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import os
     21 
     22 from tensorflow.python.data.ops import dataset_ops
     23 from tensorflow.python.data.ops import iterator_ops
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import errors
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.framework import tensor_shape
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops import gen_dataset_ops
     30 from tensorflow.python.ops import io_ops
     31 from tensorflow.python.ops import parsing_ops
     32 from tensorflow.python.ops import variables
     33 from tensorflow.python.platform import gfile
     34 from tensorflow.python.platform import test
     35 
     36 
     37 class RangeDatasetTest(test.TestCase):
     38 
     39   def tearDown(self):
     40     # Remove all checkpoint files.
     41     prefix = self._iterator_checkpoint_prefix()
     42     pattern = prefix + "*"
     43     files = gfile.Glob(pattern)
     44     map(gfile.Remove, files)
     45 
     46   def testStop(self):
     47     stop = array_ops.placeholder(dtypes.int64, shape=[])
     48     iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator()
     49     init_op = iterator.initializer
     50     get_next = iterator.get_next()
     51 
     52     with self.test_session() as sess:
     53       sess.run(init_op, feed_dict={stop: 5})
     54       for i in range(5):
     55         self.assertEqual(i, sess.run(get_next))
     56       with self.assertRaises(errors.OutOfRangeError):
     57         sess.run(get_next)
     58 
     59   def testStartStop(self):
     60     start = array_ops.placeholder(dtypes.int64, shape=[])
     61     stop = array_ops.placeholder(dtypes.int64, shape=[])
     62     iterator = dataset_ops.Dataset.range(start,
     63                                          stop).make_initializable_iterator()
     64     init_op = iterator.initializer
     65     get_next = iterator.get_next()
     66 
     67     with self.test_session() as sess:
     68       sess.run(init_op, feed_dict={start: 2, stop: 5})
     69       for i in range(2, 5):
     70         self.assertEqual(i, sess.run(get_next))
     71       with self.assertRaises(errors.OutOfRangeError):
     72         sess.run(get_next)
     73 
     74   def testStartStopStep(self):
     75     start = array_ops.placeholder(dtypes.int64, shape=[])
     76     stop = array_ops.placeholder(dtypes.int64, shape=[])
     77     step = array_ops.placeholder(dtypes.int64, shape=[])
     78     iterator = dataset_ops.Dataset.range(start, stop,
     79                                          step).make_initializable_iterator()
     80     init_op = iterator.initializer
     81     get_next = iterator.get_next()
     82 
     83     with self.test_session() as sess:
     84       sess.run(init_op, feed_dict={start: 2, stop: 10, step: 2})
     85       for i in range(2, 10, 2):
     86         self.assertEqual(i, sess.run(get_next))
     87       with self.assertRaises(errors.OutOfRangeError):
     88         sess.run(get_next)
     89 
     90   def testZeroStep(self):
     91     start = array_ops.placeholder(dtypes.int64, shape=[])
     92     stop = array_ops.placeholder(dtypes.int64, shape=[])
     93     step = array_ops.placeholder(dtypes.int64, shape=[])
     94     iterator = dataset_ops.Dataset.range(start, stop,
     95                                          step).make_initializable_iterator()
     96     init_op = iterator.initializer
     97 
     98     with self.test_session() as sess:
     99       with self.assertRaises(errors.InvalidArgumentError):
    100         sess.run(init_op, feed_dict={start: 2, stop: 10, step: 0})
    101 
    102   def testNegativeStep(self):
    103     start = array_ops.placeholder(dtypes.int64, shape=[])
    104     stop = array_ops.placeholder(dtypes.int64, shape=[])
    105     step = array_ops.placeholder(dtypes.int64, shape=[])
    106     iterator = dataset_ops.Dataset.range(start, stop,
    107                                          step).make_initializable_iterator()
    108     init_op = iterator.initializer
    109     get_next = iterator.get_next()
    110 
    111     with self.test_session() as sess:
    112       sess.run(init_op, feed_dict={start: 2, stop: 10, step: -1})
    113       # This for loop is a no-op but will ensure that the implementation is
    114       # consistent with range if it ever changes.
    115       for i in range(2, 10, -1):
    116         self.assertEqual(i, sess.run(get_next))
    117       with self.assertRaises(errors.OutOfRangeError):
    118         sess.run(get_next)
    119 
    120   def testStopLessThanStart(self):
    121     start = array_ops.placeholder(dtypes.int64, shape=[])
    122     stop = array_ops.placeholder(dtypes.int64, shape=[])
    123     iterator = dataset_ops.Dataset.range(start,
    124                                          stop).make_initializable_iterator()
    125     init_op = iterator.initializer
    126     get_next = iterator.get_next()
    127 
    128     with self.test_session() as sess:
    129       sess.run(init_op, feed_dict={start: 10, stop: 2})
    130       # This for loop is a no-op but will ensure that the implementation is
    131       # consistent with range if it ever changes.
    132       for i in range(10, 2):
    133         self.assertEqual(i, sess.run(get_next))
    134       with self.assertRaises(errors.OutOfRangeError):
    135         sess.run(get_next)
    136 
    137   def testStopLessThanStartWithPositiveStep(self):
    138     start = array_ops.placeholder(dtypes.int64, shape=[])
    139     stop = array_ops.placeholder(dtypes.int64, shape=[])
    140     step = array_ops.placeholder(dtypes.int64, shape=[])
    141     iterator = dataset_ops.Dataset.range(start, stop,
    142                                          step).make_initializable_iterator()
    143     init_op = iterator.initializer
    144     get_next = iterator.get_next()
    145 
    146     with self.test_session() as sess:
    147       sess.run(init_op, feed_dict={start: 10, stop: 2, step: 2})
    148       # This for loop is a no-op but will ensure that the implementation is
    149       # consistent with range if it ever changes.
    150       for i in range(10, 2, 2):
    151         self.assertEqual(i, sess.run(get_next))
    152       with self.assertRaises(errors.OutOfRangeError):
    153         sess.run(get_next)
    154 
    155   def testStopLessThanStartWithNegativeStep(self):
    156     start = array_ops.placeholder(dtypes.int64, shape=[])
    157     stop = array_ops.placeholder(dtypes.int64, shape=[])
    158     step = array_ops.placeholder(dtypes.int64, shape=[])
    159     iterator = dataset_ops.Dataset.range(start, stop,
    160                                          step).make_initializable_iterator()
    161     init_op = iterator.initializer
    162     get_next = iterator.get_next()
    163 
    164     with self.test_session() as sess:
    165       sess.run(init_op, feed_dict={start: 10, stop: 2, step: -1})
    166       for i in range(10, 2, -1):
    167         self.assertEqual(i, sess.run(get_next))
    168       with self.assertRaises(errors.OutOfRangeError):
    169         sess.run(get_next)
    170 
    171   def _iterator_checkpoint_prefix(self):
    172     return os.path.join(self.get_temp_dir(), "iterator")
    173 
    174   def _save_op(self, iterator_resource):
    175     iterator_state_variant = gen_dataset_ops.serialize_iterator(
    176         iterator_resource)
    177     save_op = io_ops.write_file(
    178         self._iterator_checkpoint_prefix(),
    179         parsing_ops.serialize_tensor(iterator_state_variant))
    180     return save_op
    181 
    182   def _restore_op(self, iterator_resource):
    183     iterator_state_variant = parsing_ops.parse_tensor(
    184         io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
    185     restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
    186                                                       iterator_state_variant)
    187     return restore_op
    188 
    189   def testSaveRestore(self):
    190 
    191     def _build_graph(start, stop):
    192       iterator = dataset_ops.Dataset.range(start,
    193                                            stop).make_initializable_iterator()
    194       init_op = iterator.initializer
    195       get_next = iterator.get_next()
    196       save_op = self._save_op(iterator._iterator_resource)
    197       restore_op = self._restore_op(iterator._iterator_resource)
    198       return init_op, get_next, save_op, restore_op
    199 
    200     # Saving and restoring in different sessions.
    201     start = 2
    202     stop = 10
    203     break_point = 5
    204     with ops.Graph().as_default() as g:
    205       init_op, get_next, save_op, _ = _build_graph(start, stop)
    206       with self.test_session(graph=g) as sess:
    207         sess.run(variables.global_variables_initializer())
    208         sess.run(init_op)
    209         for i in range(start, break_point):
    210           self.assertEqual(i, sess.run(get_next))
    211         sess.run(save_op)
    212 
    213     with ops.Graph().as_default() as g:
    214       init_op, get_next, _, restore_op = _build_graph(start, stop)
    215       with self.test_session(graph=g) as sess:
    216         sess.run(init_op)
    217         sess.run(restore_op)
    218         for i in range(break_point, stop):
    219           self.assertEqual(i, sess.run(get_next))
    220         with self.assertRaises(errors.OutOfRangeError):
    221           sess.run(get_next)
    222 
    223     # Saving and restoring in same session.
    224     with ops.Graph().as_default() as g:
    225       init_op, get_next, save_op, restore_op = _build_graph(start, stop)
    226       with self.test_session(graph=g) as sess:
    227         sess.run(variables.global_variables_initializer())
    228         sess.run(init_op)
    229         for i in range(start, break_point):
    230           self.assertEqual(i, sess.run(get_next))
    231         sess.run(save_op)
    232         sess.run(restore_op)
    233         for i in range(break_point, stop):
    234           self.assertEqual(i, sess.run(get_next))
    235         with self.assertRaises(errors.OutOfRangeError):
    236           sess.run(get_next)
    237 
    238   def testRestoreWithoutBuildingDatasetGraph(self):
    239 
    240     def _build_graph(start, stop, num_epochs):
    241       dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
    242       iterator = dataset.make_initializable_iterator()
    243       init_op = iterator.initializer
    244       get_next = iterator.get_next()
    245       save_op = self._save_op(iterator._iterator_resource)
    246       restore_op = self._restore_op(iterator._iterator_resource)
    247       return init_op, get_next, save_op, restore_op
    248 
    249     # Saving and restoring in different sessions.
    250     start = 2
    251     stop = 10
    252     num_epochs = 5
    253     break_point = 5
    254     break_epoch = 3
    255     with ops.Graph().as_default() as g:
    256       init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
    257       with self.test_session(graph=g) as sess:
    258         sess.run(variables.global_variables_initializer())
    259         sess.run(init_op)
    260         for _ in range(break_epoch):
    261           for i in range(start, stop):
    262             self.assertEqual(i, sess.run(get_next))
    263         for i in range(start, break_point):
    264           self.assertEqual(i, sess.run(get_next))
    265         sess.run(save_op)
    266 
    267     with ops.Graph().as_default() as g:
    268       # Create an empty IteratorResource and restore the Iterator into it.
    269       output_types = dtypes.int64
    270       output_shapes = tensor_shape.scalar()
    271       iterator = iterator_ops.Iterator.from_structure(output_types,
    272                                                       output_shapes)
    273       restore_op = self._restore_op(iterator._iterator_resource)
    274       get_next = iterator.get_next()
    275       with self.test_session(graph=g) as sess:
    276         sess.run(restore_op)
    277         for i in range(break_point, stop):
    278           self.assertEqual(i, sess.run(get_next))
    279         for _ in range(break_epoch + 1, num_epochs):
    280           for i in range(start, stop):
    281             self.assertEqual(i, sess.run(get_next))
    282         with self.assertRaises(errors.OutOfRangeError):
    283           sess.run(get_next)
    284 
    285   def testRestoreInModifiedGraph(self):
    286 
    287     def _build_graph(start, stop):
    288       dataset = dataset_ops.Dataset.range(start, stop)
    289       iterator = dataset.make_initializable_iterator()
    290       init_op = iterator.initializer
    291       get_next = iterator.get_next()
    292       save_op = self._save_op(iterator._iterator_resource)
    293       restore_op = self._restore_op(iterator._iterator_resource)
    294       return init_op, get_next, save_op, restore_op
    295 
    296     # Saving and restoring in different sessions.
    297     start = 2
    298     stop = 10
    299     stop_1 = 8
    300     break_point = 5
    301     with ops.Graph().as_default() as g:
    302       init_op, get_next, save_op, _ = _build_graph(start, stop)
    303       with self.test_session(graph=g) as sess:
    304         sess.run(variables.global_variables_initializer())
    305         sess.run(init_op)
    306         for i in range(start, break_point):
    307           self.assertEqual(i, sess.run(get_next))
    308         sess.run(save_op)
    309 
    310     with ops.Graph().as_default() as g:
    311       # Intentionally build a graph with a different value for stop to make sure
    312       # the original dataset graph is actually getting loaded.
    313       init_op, get_next, _, restore_op = _build_graph(start, stop_1)
    314       with self.test_session(graph=g) as sess:
    315         sess.run(restore_op)
    316         for i in range(break_point, stop):
    317           self.assertEqual(i, sess.run(get_next))
    318         with self.assertRaises(errors.OutOfRangeError):
    319           sess.run(get_next)
    320 
    321   def testInitThenRestore(self):
    322     # Note: Calling init_op before restore_op is redundant. This test just makes
    323     # sure we do not fail if restore is called on an already initialized
    324     # iterator resource.
    325 
    326     def _build_graph(start, stop):
    327       dataset = dataset_ops.Dataset.range(start, stop)
    328       iterator = dataset.make_initializable_iterator()
    329       init_op = iterator.initializer
    330       get_next = iterator.get_next()
    331       save_op = self._save_op(iterator._iterator_resource)
    332       restore_op = self._restore_op(iterator._iterator_resource)
    333       return init_op, get_next, save_op, restore_op
    334 
    335     # Saving and restoring in different sessions.
    336     start = 2
    337     stop = 10
    338     break_point = 5
    339     with ops.Graph().as_default() as g:
    340       init_op, get_next, save_op, _ = _build_graph(start, stop)
    341       with self.test_session(graph=g) as sess:
    342         sess.run(variables.global_variables_initializer())
    343         sess.run(init_op)
    344         for i in range(start, break_point):
    345           self.assertEqual(i, sess.run(get_next))
    346         sess.run(save_op)
    347 
    348     with ops.Graph().as_default() as g:
    349       init_op, get_next, _, restore_op = _build_graph(start, stop)
    350       with self.test_session(graph=g) as sess:
    351         sess.run(init_op)
    352         sess.run(restore_op)
    353         for i in range(break_point, stop):
    354           self.assertEqual(i, sess.run(get_next))
    355         with self.assertRaises(errors.OutOfRangeError):
    356           sess.run(get_next)
    357 
    358   def testMultipleSaves(self):
    359 
    360     def _build_graph(start, stop):
    361       iterator = dataset_ops.Dataset.range(start,
    362                                            stop).make_initializable_iterator()
    363       init_op = iterator.initializer
    364       get_next = iterator.get_next()
    365       save_op = self._save_op(iterator._iterator_resource)
    366       restore_op = self._restore_op(iterator._iterator_resource)
    367       return init_op, get_next, save_op, restore_op
    368 
    369     start = 2
    370     stop = 10
    371     break_point1 = 5
    372     break_point2 = 7
    373 
    374     with ops.Graph().as_default() as g:
    375       init_op, get_next, save_op, _ = _build_graph(start, stop)
    376       with self.test_session(graph=g) as sess:
    377         sess.run(variables.global_variables_initializer())
    378         sess.run(init_op)
    379         for i in range(start, break_point1):
    380           self.assertEqual(i, sess.run(get_next))
    381         sess.run(save_op)
    382 
    383     with ops.Graph().as_default() as g:
    384       init_op, get_next, save_op, restore_op = _build_graph(start, stop)
    385       with self.test_session(graph=g) as sess:
    386         sess.run(restore_op)
    387         for i in range(break_point1, break_point2):
    388           self.assertEqual(i, sess.run(get_next))
    389         sess.run(save_op)
    390 
    391     break_point2 = 7
    392     with ops.Graph().as_default() as g:
    393       init_op, get_next, save_op, restore_op = _build_graph(start, stop)
    394       with self.test_session(graph=g) as sess:
    395         sess.run(restore_op)
    396         for i in range(break_point2, stop):
    397           self.assertEqual(i, sess.run(get_next))
    398         with self.assertRaises(errors.OutOfRangeError):
    399           sess.run(get_next)
    400 
    401   def testSaveRestoreWithRepeat(self):
    402 
    403     def _build_graph(start, stop, num_epochs):
    404       iterator = dataset_ops.Dataset.range(
    405           start, stop).repeat(num_epochs).make_initializable_iterator()
    406       init_op = iterator.initializer
    407       get_next = iterator.get_next()
    408       save_op = self._save_op(iterator._iterator_resource)
    409       restore_op = self._restore_op(iterator._iterator_resource)
    410       return init_op, get_next, save_op, restore_op
    411 
    412     start = 2
    413     stop = 10
    414     num_epochs = 5
    415     break_range = 5
    416     break_epoch = 3
    417     with ops.Graph().as_default() as g:
    418       init_op, get_next, save_op, restore_op = _build_graph(
    419           start, stop, num_epochs)
    420       with self.test_session(graph=g) as sess:
    421         sess.run(variables.global_variables_initializer())
    422         sess.run(init_op)
    423         # Note: There is no checkpoint saved currently so a NotFoundError is
    424         # raised.
    425         with self.assertRaises(errors.NotFoundError):
    426           sess.run(restore_op)
    427         for _ in range(break_epoch - 1):
    428           for i in range(start, stop):
    429             self.assertEqual(i, sess.run(get_next))
    430         for i in range(start, break_range):
    431           self.assertEqual(i, sess.run(get_next))
    432         sess.run(save_op)
    433 
    434     with ops.Graph().as_default() as g:
    435       init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
    436       with self.test_session(graph=g) as sess:
    437         sess.run(restore_op)
    438         for i in range(break_range, stop):
    439           self.assertEqual(i, sess.run(get_next))
    440         for _ in range(break_epoch, num_epochs):
    441           for i in range(start, stop):
    442             self.assertEqual(i, sess.run(get_next))
    443         with self.assertRaises(errors.OutOfRangeError):
    444           sess.run(get_next)
    445 
    446   def testSaveRestoreExhaustedIterator(self):
    447 
    448     def _build_graph(start, stop, num_epochs):
    449       iterator = dataset_ops.Dataset.range(
    450           start, stop).repeat(num_epochs).make_initializable_iterator()
    451       init_op = iterator.initializer
    452       get_next = iterator.get_next()
    453       save_op = self._save_op(iterator._iterator_resource)
    454       restore_op = self._restore_op(iterator._iterator_resource)
    455       return init_op, get_next, save_op, restore_op
    456 
    457     start = 2
    458     stop = 10
    459     num_epochs = 5
    460     with ops.Graph().as_default() as g:
    461       init_op, get_next, save_op, restore_op = _build_graph(
    462           start, stop, num_epochs)
    463       with self.test_session(graph=g) as sess:
    464         sess.run(variables.global_variables_initializer())
    465         sess.run(init_op)
    466         # Note: There is no checkpoint saved currently so a NotFoundError is
    467         # raised.
    468         with self.assertRaises(errors.NotFoundError):
    469           sess.run(restore_op)
    470         for _ in range(num_epochs):
    471           for i in range(start, stop):
    472             self.assertEqual(i, sess.run(get_next))
    473         with self.assertRaises(errors.OutOfRangeError):
    474           sess.run(get_next)
    475         sess.run(save_op)
    476 
    477     with ops.Graph().as_default() as g:
    478       init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
    479       with self.test_session(graph=g) as sess:
    480         sess.run(restore_op)
    481         with self.assertRaises(errors.OutOfRangeError):
    482           sess.run(get_next)
    483 
    484 
if __name__ == "__main__":
  # Script entry point: hand control to
  # tensorflow.python.platform.test.main() to run this module's tests.
  test.main()
    487