# (Code-viewer navigation header removed; this file lives under kernel_tests.)
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the experimental input pipeline ops."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import os
     21 import warnings
     22 
     23 import numpy as np
     24 
     25 from tensorflow.core.protobuf import config_pb2
     26 from tensorflow.python.client import session
     27 from tensorflow.python.data.ops import dataset_ops
     28 from tensorflow.python.data.ops import iterator_ops
     29 from tensorflow.python.data.ops import readers
     30 from tensorflow.python.framework import constant_op
     31 from tensorflow.python.framework import dtypes
     32 from tensorflow.python.framework import errors
     33 from tensorflow.python.framework import function
     34 from tensorflow.python.framework import ops
     35 from tensorflow.python.framework import test_util
     36 from tensorflow.python.ops import array_ops
     37 from tensorflow.python.ops import functional_ops
     38 from tensorflow.python.ops import gen_dataset_ops
     39 from tensorflow.python.ops import gradients_impl
     40 from tensorflow.python.ops import io_ops
     41 from tensorflow.python.ops import math_ops
     42 from tensorflow.python.ops import parsing_ops
     43 from tensorflow.python.ops import script_ops
     44 from tensorflow.python.ops import variables
     45 from tensorflow.python.platform import test
     46 from tensorflow.python.training import server_lib
     47 
     48 
     49 class IteratorTest(test.TestCase):
     50 
     51   def testAttemptingGradientsRaiseExceptions(self):
     52     component = constant_op.constant([1])
     53     side = constant_op.constant(0)
     54     add = lambda x: x + side
     55     dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add)
     56     value = dataset.make_one_shot_iterator().get_next()
     57     with self.assertRaisesRegexp(LookupError, "No gradient defined"):
     58       gradients_impl.gradients(value, component)
     59     with self.assertRaisesRegexp(LookupError, "No gradient defined"):
     60       gradients_impl.gradients(value, side)
     61     with self.assertRaisesRegexp(LookupError, "No gradient defined"):
     62       gradients_impl.gradients(value, [component, side])
     63 
     64   def testCapturingStateInOneShotRaisesException(self):
     65     var = variables.Variable(37.0, name="myvar")
     66     dataset = (dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0])
     67                .map(lambda x: x + var))
     68     with self.assertRaisesRegexp(
     69         ValueError, r"`Dataset.make_one_shot_iterator\(\)` does not support "
     70         "datasets that capture stateful objects.+myvar"):
     71       dataset.make_one_shot_iterator()
     72 
     73   def testOneShotIterator(self):
     74     components = (np.arange(7),
     75                   np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
     76                   np.array(37.0) * np.arange(7))
     77 
     78     def _map_fn(x, y, z):
     79       return math_ops.square(x), math_ops.square(y), math_ops.square(z)
     80 
     81     iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
     82                 .repeat(14).make_one_shot_iterator())
     83     get_next = iterator.get_next()
     84 
     85     self.assertEqual([c.shape[1:] for c in components],
     86                      [t.shape for t in get_next])
     87 
     88     with self.test_session() as sess:
     89       for _ in range(14):
     90         for i in range(7):
     91           result = sess.run(get_next)
     92           for component, result_component in zip(components, result):
     93             self.assertAllEqual(component[i]**2, result_component)
     94       with self.assertRaises(errors.OutOfRangeError):
     95         sess.run(get_next)
     96 
     97   def testOneShotIteratorCaptureByValue(self):
     98     components = (np.arange(7),
     99                   np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
    100                   np.array(37.0) * np.arange(7))
    101     tensor_components = tuple([ops.convert_to_tensor(c) for c in components])
    102 
    103     def _map_fn(x, y, z):
    104       return math_ops.square(x), math_ops.square(y), math_ops.square(z)
    105 
    106     iterator = (dataset_ops.Dataset.from_tensor_slices(tensor_components)
    107                 .map(_map_fn).repeat(14).make_one_shot_iterator())
    108     get_next = iterator.get_next()
    109 
    110     self.assertEqual([c.shape[1:] for c in components],
    111                      [t.shape for t in get_next])
    112 
    113     with self.test_session() as sess:
    114       for _ in range(14):
    115         for i in range(7):
    116           result = sess.run(get_next)
    117           for component, result_component in zip(components, result):
    118             self.assertAllEqual(component[i]**2, result_component)
    119       with self.assertRaises(errors.OutOfRangeError):
    120         sess.run(get_next)
    121 
    122   def testOneShotIteratorInsideContainer(self):
    123     components = (np.arange(7),
    124                   np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
    125                   np.array(37.0) * np.arange(7))
    126 
    127     def within_container():
    128       def _map_fn(x, y, z):
    129         return math_ops.square(x), math_ops.square(y), math_ops.square(z)
    130       iterator = (dataset_ops.Dataset.from_tensor_slices(components)
    131                   .map(_map_fn).repeat(14).make_one_shot_iterator())
    132       return iterator.get_next()
    133 
    134     server = server_lib.Server.create_local_server()
    135 
    136     # Create two iterators within unique containers, and run them to
    137     # make sure that the resources aren't shared.
    138     #
    139     # The test below would fail if cname were the same across both
    140     # sessions.
    141     for i in range(2):
    142       with session.Session(server.target) as sess:
    143         cname = "iteration%d" % i
    144         with ops.container(cname):
    145           get_next = within_container()
    146 
    147         for _ in range(14):
    148           for i in range(7):
    149             result = sess.run(get_next)
    150             for component, result_component in zip(components, result):
    151               self.assertAllEqual(component[i]**2, result_component)
    152         with self.assertRaises(errors.OutOfRangeError):
    153           sess.run(get_next)
    154 
  def testOneShotIteratorNonBlocking(self):
    """One-shot iterator initialization must not block session threads."""
    dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    # Create a session with a single thread to ensure that the
    # one-shot iterator initializer does not deadlock.
    config = config_pb2.ConfigProto(inter_op_parallelism_threads=1,
                                    use_per_session_threads=True)
    with session.Session(config=config) as sess:
      self.assertAllEqual([1, 4, 9], sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

    # Test with multiple threads invoking the one-shot iterator concurrently.
    with session.Session(config=config) as sess:
      results = []
      def consumer_thread():
        # Each consumer records either the element it got or None when the
        # single-element dataset is already exhausted.
        try:
          results.append(sess.run(next_element))
        except errors.OutOfRangeError:
          results.append(None)

      num_threads = 8
      threads = [
          self.checkedThread(consumer_thread) for _ in range(num_threads)]
      for t in threads:
        t.start()
      for t in threads:
        t.join()

      # Exactly one thread should have received the single element; the
      # remaining num_threads - 1 should have observed OutOfRangeError.
      self.assertEqual(num_threads, len(results))
      self.assertEqual(num_threads - 1,
                       len([None for r in results if r is None]))
      self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None])
    190 
  def testOneShotIteratorInitializerFails(self):
    """Errors raised during one-shot iterator initialization are sticky."""
    # Define a dataset whose initialization will always fail.
    dataset = dataset_ops.Dataset.from_tensors(
        array_ops.check_numerics(
            constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    with self.test_session() as sess:
      with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
        sess.run(next_element)

      # Test that subsequent attempts to use the iterator also fail.
      with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
        sess.run(next_element)

    with self.test_session() as sess:
      # Every concurrent consumer must observe the same initialization error.
      def consumer_thread():
        with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
          sess.run(next_element)

      num_threads = 8
      threads = [
          self.checkedThread(consumer_thread) for _ in range(num_threads)]
      for t in threads:
        t.start()
      for t in threads:
        t.join()
    219 
  def testSimpleSharedResource(self):
    """Iterator state persists across sessions via a shared resource name."""
    components = (
        np.array(1, dtype=np.int64),
        np.array([1, 2, 3], dtype=np.int64),
        np.array(37.0, dtype=np.float64)
    )

    server = server_lib.Server.create_local_server()

    # Create two non-overlapping sessions that share the same iterator
    # resource on the same server, and verify that an action of the
    # first session (initializing the iterator) is visible in the
    # second session.
    with ops.Graph().as_default():
      iterator = (dataset_ops.Dataset.from_tensors(components)
                  .map(lambda x, y, z: (x, y, z)).make_initializable_iterator(
                      shared_name="shared_iterator"))
      init_op = iterator.initializer
      get_next = iterator.get_next()

      with session.Session(server.target) as sess:
        sess.run(init_op)
        results = sess.run(get_next)
        for component, result_component in zip(components, results):
          self.assertAllEqual(component, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

        # Re-initialize the iterator in the first session.
        sess.run(init_op)

    with ops.Graph().as_default():
      # Re-define the iterator manually, without defining any of the
      # functions in this graph, to ensure that we are not
      # accidentally redefining functions with the same names in the
      # new graph.
      iterator = iterator_ops.Iterator.from_structure(
          shared_name="shared_iterator",
          output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
          output_shapes=([], [3], []))
      get_next = iterator.get_next()

      with session.Session(server.target) as sess:
        # Use the iterator without re-initializing in the second session.
        results = sess.run(get_next)
        for component, result_component in zip(components, results):
          self.assertAllEqual(component, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)
    269 
    270   def testNotInitializedError(self):
    271     components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
    272     iterator = (dataset_ops.Dataset.from_tensors(components)
    273                 .make_initializable_iterator())
    274     get_next = iterator.get_next()
    275 
    276     with self.test_session() as sess:
    277       with self.assertRaisesRegexp(errors.FailedPreconditionError,
    278                                    "iterator has not been initialized"):
    279         sess.run(get_next)
    280 
    281   def testReinitializableIterator(self):
    282     dataset_3 = dataset_ops.Dataset.from_tensors(
    283         constant_op.constant([1, 2, 3]))
    284     dataset_4 = dataset_ops.Dataset.from_tensors(
    285         constant_op.constant([4, 5, 6, 7]))
    286     iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types,
    287                                                     [None])
    288 
    289     dataset_3_init_op = iterator.make_initializer(dataset_3)
    290     dataset_4_init_op = iterator.make_initializer(dataset_4)
    291     get_next = iterator.get_next()
    292 
    293     self.assertEqual(dataset_3.output_types, iterator.output_types)
    294     self.assertEqual(dataset_4.output_types, iterator.output_types)
    295     self.assertEqual([None], iterator.output_shapes.as_list())
    296 
    297     with self.test_session() as sess:
    298       # The iterator is initially uninitialized.
    299       with self.assertRaises(errors.FailedPreconditionError):
    300         sess.run(get_next)
    301 
    302       # Initialize with one dataset.
    303       sess.run(dataset_3_init_op)
    304       self.assertAllEqual([1, 2, 3], sess.run(get_next))
    305       with self.assertRaises(errors.OutOfRangeError):
    306         sess.run(get_next)
    307 
    308       # Initialize with a different dataset.
    309       sess.run(dataset_4_init_op)
    310       self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
    311       with self.assertRaises(errors.OutOfRangeError):
    312         sess.run(get_next)
    313 
    314       # Reinitialize with the first dataset.
    315       sess.run(dataset_3_init_op)
    316       self.assertAllEqual([1, 2, 3], sess.run(get_next))
    317       with self.assertRaises(errors.OutOfRangeError):
    318         sess.run(get_next)
    319 
  def testReinitializableIteratorStaticErrors(self):
    """Structure mismatches are rejected at graph-construction time."""
    # Non-matching structure for types and shapes.
    with self.assertRaises(TypeError):
      iterator = iterator_ops.Iterator.from_structure((dtypes.int64,
                                                       dtypes.float64), [None])

    # Test validation of dataset argument.
    iterator = iterator_ops.Iterator.from_structure((dtypes.int64,
                                                     dtypes.float64))

    # Incompatible structure: nested 1-tuples instead of a flat 2-tuple.
    with self.assertRaises(ValueError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors(((constant_op.constant(
              [1, 2, 3], dtype=dtypes.int64),), (constant_op.constant(
                  [4., 5., 6., 7.], dtype=dtypes.float64),))))

    # Incompatible types: (int32, float32) vs the declared (int64, float64).
    with self.assertRaises(TypeError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors((constant_op.constant(
              [1, 2, 3], dtype=dtypes.int32), constant_op.constant(
                  [4., 5., 6., 7.], dtype=dtypes.float32))))

    # Incompatible shapes: second component is [4] but the iterator wants [].
    iterator = iterator_ops.Iterator.from_structure(
        (dtypes.int64, dtypes.float64), ([None], []))
    with self.assertRaises(TypeError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors((constant_op.constant(
              [1, 2, 3], dtype=dtypes.int64), constant_op.constant(
                  [4., 5., 6., 7.], dtype=dtypes.float64))))
    352 
    353   def testIteratorStringHandle(self):
    354     dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    355     dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
    356 
    357     iterator_3 = dataset_3.make_one_shot_iterator()
    358     iterator_4 = dataset_4.make_one_shot_iterator()
    359 
    360     handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    361     feedable_iterator = iterator_ops.Iterator.from_string_handle(
    362         handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
    363     next_element = feedable_iterator.get_next()
    364 
    365     self.assertEqual(dataset_3.output_types, feedable_iterator.output_types)
    366     self.assertEqual(dataset_4.output_types, feedable_iterator.output_types)
    367     self.assertEqual([], feedable_iterator.output_shapes)
    368 
    369     with self.test_session() as sess:
    370       iterator_3_handle = sess.run(iterator_3.string_handle())
    371       iterator_4_handle = sess.run(iterator_4.string_handle())
    372 
    373       self.assertEqual(
    374           10, sess.run(next_element,
    375                        feed_dict={handle_placeholder: iterator_4_handle}))
    376       self.assertEqual(
    377           1, sess.run(next_element,
    378                       feed_dict={handle_placeholder: iterator_3_handle}))
    379       self.assertEqual(
    380           20, sess.run(next_element,
    381                        feed_dict={handle_placeholder: iterator_4_handle}))
    382       self.assertEqual(
    383           2, sess.run(next_element,
    384                       feed_dict={handle_placeholder: iterator_3_handle}))
    385       self.assertEqual(
    386           30, sess.run(next_element,
    387                        feed_dict={handle_placeholder: iterator_4_handle}))
    388       self.assertEqual(
    389           3, sess.run(next_element,
    390                       feed_dict={handle_placeholder: iterator_3_handle}))
    391       self.assertEqual(
    392           40, sess.run(next_element,
    393                        feed_dict={handle_placeholder: iterator_4_handle}))
    394       with self.assertRaises(errors.OutOfRangeError):
    395         sess.run(next_element,
    396                  feed_dict={handle_placeholder: iterator_3_handle})
    397       with self.assertRaises(errors.OutOfRangeError):
    398         sess.run(next_element,
    399                  feed_dict={handle_placeholder: iterator_4_handle})
    400 
    401   def testIteratorStringHandleReuseTensorObject(self):
    402     dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    403     one_shot_iterator = dataset.make_one_shot_iterator()
    404     initializable_iterator = dataset.make_initializable_iterator()
    405     structure_iterator = iterator_ops.Iterator.from_structure(
    406         dataset.output_types)
    407 
    408     created_ops = len(ops.get_default_graph().get_operations())
    409 
    410     self.assertIs(one_shot_iterator.string_handle(),
    411                   one_shot_iterator.string_handle())
    412     self.assertIs(initializable_iterator.string_handle(),
    413                   initializable_iterator.string_handle())
    414     self.assertIs(structure_iterator.string_handle(),
    415                   structure_iterator.string_handle())
    416 
    417     # Assert that getting the (default) string handle creates no ops.
    418     self.assertEqual(created_ops, len(ops.get_default_graph().get_operations()))
    419 
    420     # Specifying an explicit name will create a new op.
    421     handle_with_name = one_shot_iterator.string_handle(name="foo")
    422     self.assertEqual("foo", handle_with_name.op.name)
    423     self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name)
    424 
    425     handle_with_same_name = one_shot_iterator.string_handle(name="foo")
    426     self.assertEqual("foo_1", handle_with_same_name.op.name)
    427     self.assertIsNot(handle_with_name, handle_with_same_name)
    428 
  def testIteratorStringHandleError(self):
    """Feeding a mismatched handle fails at run time, not at graph time."""
    dataset_int_scalar = (dataset_ops.Dataset.from_tensor_slices([1, 2,
                                                                  3]).repeat())
    dataset_float_vector = (dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]))

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])

    # Three feedable iterators that declare different expected shapes for
    # the same handle placeholder.
    feedable_int_scalar = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32, [])
    feedable_int_vector = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32, [None])
    feedable_int_any = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32)

    with self.test_session() as sess:
      handle_int_scalar = sess.run(
          dataset_int_scalar.make_one_shot_iterator().string_handle())
      handle_float_vector = sess.run(
          dataset_float_vector.make_one_shot_iterator().string_handle())

      # Declared scalar shape ([]) matches the int-scalar iterator.
      self.assertEqual(1,
                       sess.run(
                           feedable_int_scalar.get_next(),
                           feed_dict={handle_placeholder: handle_int_scalar}))

      # Leaving the shape unspecified is compatible with any element shape.
      self.assertEqual(2,
                       sess.run(
                           feedable_int_any.get_next(),
                           feed_dict={handle_placeholder: handle_int_scalar}))

      # Declared shape [None] does not match the scalar elements.
      with self.assertRaises(errors.InvalidArgumentError):
        print(sess.run(
            feedable_int_vector.get_next(),
            feed_dict={handle_placeholder: handle_int_scalar}))

      # Declared dtype int32 does not match the float elements.
      with self.assertRaises(errors.InvalidArgumentError):
        print(sess.run(
            feedable_int_vector.get_next(),
            feed_dict={handle_placeholder: handle_float_vector}))
    468 
  def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
    """An iterator's string handle can drive `remote_call` across devices."""
    worker_config = config_pb2.ConfigProto()
    # Three virtual CPU devices: the iterator lives on cpu:1, the caller on
    # cpu:0, and cpu:2 is targeted to check the failure when the resource is
    # absent from the target device.
    worker_config.device_count["CPU"] = 3

    with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_3.make_one_shot_iterator()
      iterator_3_handle = iterator_3.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      # Executed on the target device: reconstitute the iterator from its
      # string handle and pull the next element.
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_3.output_types, dataset_3.output_shapes)
      return remote_iterator.get_next()

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with self.test_session(config=worker_config) as sess:
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [1])
      # Fails when target is cpu:2 where the resource is not located.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(
            remote_op,
            feed_dict={
                target_placeholder: "/job:localhost/replica:0/task:0/cpu:2"
            })
      # The failed call above must not have advanced the iterator: the next
      # successful calls still return elements 2 and 3 in order.
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [2])
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [3])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(
            remote_op,
            feed_dict={
                target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
            })
    524 
  def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
    """Remote-call iteration works when the calling op is placed on a GPU."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_3.make_one_shot_iterator()
      iterator_3_handle = iterator_3.string_handle()

    def _encode_raw(byte_array):
      # Reassemble a uint8 array back into a Python bytes object.
      return bytes(bytearray(byte_array))

    @function.Defun(dtypes.uint8)
    def _remote_fn(h):
      # The handle arrives as uint8 bytes; rebuild the string handle with a
      # py_func before looking the iterator up.
      handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, dataset_3.output_types, dataset_3.output_shapes)
      return remote_iterator.get_next()

    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      # NOTE(review): the string handle is round-tripped through uint8 here,
      # presumably because string tensors cannot be placed on the GPU device
      # hosting remote_call -- confirm.
      iterator_3_handle_uint8 = parsing_ops.decode_raw(
          bytes=iterator_3_handle, out_type=dtypes.uint8)
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle_uint8],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with self.test_session() as sess:
      # Each remote call pulls the next element from the CPU-resident
      # iterator; after three elements it is exhausted.
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
          })
      self.assertEqual(elem, [1])
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
          })
      self.assertEqual(elem, [2])
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
          })
      self.assertEqual(elem, [3])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(
            remote_op,
            feed_dict={
                target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
            })
    579 
  def testIncorrectIteratorRestore(self):
    """Restoring saved state into an incompatible iterator must fail."""

    def _path():
      # Temp-file location used to stash the serialized iterator state.
      return os.path.join(self.get_temp_dir(), "iterator")

    def _save_op(iterator_resource):
      # Serializes the iterator's state and writes it to `_path()`.
      iterator_state_variant = gen_dataset_ops.serialize_iterator(
          iterator_resource)
      save_op = io_ops.write_file(
          _path(), parsing_ops.serialize_tensor(iterator_state_variant))
      return save_op

    def _restore_op(iterator_resource):
      # Reads the serialized state back and loads it into the iterator.
      iterator_state_variant = parsing_ops.parse_tensor(
          io_ops.read_file(_path()), dtypes.variant)
      restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
                                                        iterator_state_variant)
      return restore_op

    def _build_range_dataset_graph():
      # Iterator over a range dataset; its elements are int64.
      start = 1
      stop = 10
      iterator = dataset_ops.Dataset.range(start,
                                           stop).make_initializable_iterator()
      init_op = iterator.initializer
      get_next = iterator.get_next()
      save_op = _save_op(iterator._iterator_resource)
      restore_op = _restore_op(iterator._iterator_resource)
      return init_op, get_next, save_op, restore_op

    def _build_reader_dataset_graph():
      # Iterator over a fixed-length-record dataset; its elements are string.
      filenames = ["test"]  # Does not exist but we don't care in this test.
      iterator = readers.FixedLengthRecordDataset(
          filenames, 1, 0, 0).make_initializable_iterator()
      init_op = iterator.initializer
      get_next_op = iterator.get_next()
      save_op = _save_op(iterator._iterator_resource)
      restore_op = _restore_op(iterator._iterator_resource)
      return init_op, get_next_op, save_op, restore_op

    # Saving iterator for RangeDataset graph.
    with ops.Graph().as_default() as g:
      init_op, _, save_op, _ = _build_range_dataset_graph()
      with self.test_session(graph=g) as sess:
        sess.run(init_op)
        sess.run(save_op)

    # Attempt to restore the saved iterator into an IteratorResource of
    # incompatible type. An iterator of RangeDataset has output type int64,
    # while an iterator of FixedLengthRecordDataset has output type string.
    # So an InvalidArgumentError should be raised by
    # IteratorResource::set_iterator.
    with ops.Graph().as_default() as g:
      _, _, _, restore_op = _build_reader_dataset_graph()
      with self.test_session(graph=g) as sess:
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(restore_op)
    637 
    638   def testRepeatedGetNextWarning(self):
    639     iterator = dataset_ops.Dataset.range(10).make_one_shot_iterator()
    640     warnings.simplefilter("always")
    641     with warnings.catch_warnings(record=True) as w:
    642       for _ in range(100):
    643         iterator.get_next()
    644     self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD,
    645                      len(w))
    646     for warning in w:
    647       self.assertTrue(
    648           iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE in str(warning.message))
    649 
    650 
# Run all test cases in this module when the file is executed directly.
if __name__ == "__main__":
  test.main()
    653