1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the experimental input pipeline ops.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import os 21 import warnings 22 23 import numpy as np 24 25 from tensorflow.core.protobuf import config_pb2 26 from tensorflow.python.client import session 27 from tensorflow.python.data.ops import dataset_ops 28 from tensorflow.python.data.ops import iterator_ops 29 from tensorflow.python.data.ops import readers 30 from tensorflow.python.framework import constant_op 31 from tensorflow.python.framework import dtypes 32 from tensorflow.python.framework import errors 33 from tensorflow.python.framework import function 34 from tensorflow.python.framework import ops 35 from tensorflow.python.framework import test_util 36 from tensorflow.python.ops import array_ops 37 from tensorflow.python.ops import functional_ops 38 from tensorflow.python.ops import gen_dataset_ops 39 from tensorflow.python.ops import gradients_impl 40 from tensorflow.python.ops import io_ops 41 from tensorflow.python.ops import math_ops 42 from tensorflow.python.ops import parsing_ops 43 from tensorflow.python.ops import script_ops 44 from tensorflow.python.ops import variables 45 from tensorflow.python.platform import test 46 from tensorflow.python.training import server_lib 47 48 49 class IteratorTest(test.TestCase): 50 51 def testAttemptingGradientsRaiseExceptions(self): 52 component = constant_op.constant([1]) 53 side = constant_op.constant(0) 54 add = lambda x: x + side 55 dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add) 56 value = dataset.make_one_shot_iterator().get_next() 57 with self.assertRaisesRegexp(LookupError, "No gradient defined"): 58 gradients_impl.gradients(value, component) 59 with self.assertRaisesRegexp(LookupError, "No gradient defined"): 60 gradients_impl.gradients(value, side) 61 with self.assertRaisesRegexp(LookupError, "No gradient defined"): 62 gradients_impl.gradients(value, [component, side]) 63 64 def testCapturingStateInOneShotRaisesException(self): 65 var = variables.Variable(37.0, name="myvar") 66 dataset = (dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0]) 67 .map(lambda x: x + var)) 68 with self.assertRaisesRegexp( 69 ValueError, r"`Dataset.make_one_shot_iterator\(\)` does not support " 70 "datasets that capture stateful objects.+myvar"): 71 dataset.make_one_shot_iterator() 72 73 def testOneShotIterator(self): 74 components = (np.arange(7), 75 np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], 76 np.array(37.0) * np.arange(7)) 77 78 def _map_fn(x, y, z): 79 return math_ops.square(x), math_ops.square(y), math_ops.square(z) 80 81 iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) 82 .repeat(14).make_one_shot_iterator()) 83 get_next = iterator.get_next() 84 85 self.assertEqual([c.shape[1:] for c in components], 86 [t.shape for t in get_next]) 87 88 with self.test_session() as sess: 89 for _ in range(14): 90 for i in range(7): 91 result = sess.run(get_next) 92 for component, result_component in zip(components, result): 93 self.assertAllEqual(component[i]**2, result_component) 94 with self.assertRaises(errors.OutOfRangeError): 95 sess.run(get_next) 96 97 def testOneShotIteratorCaptureByValue(self): 98 components = (np.arange(7), 99 np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], 100 np.array(37.0) * np.arange(7)) 101 tensor_components = tuple([ops.convert_to_tensor(c) for c in components]) 102 103 def _map_fn(x, y, z): 104 return math_ops.square(x), math_ops.square(y), math_ops.square(z) 105 106 iterator = (dataset_ops.Dataset.from_tensor_slices(tensor_components) 107 .map(_map_fn).repeat(14).make_one_shot_iterator()) 108 get_next = iterator.get_next() 109 110 self.assertEqual([c.shape[1:] for c in components], 111 [t.shape for t in get_next]) 112 113 with self.test_session() as sess: 114 for _ in range(14): 115 for i in range(7): 116 result = sess.run(get_next) 117 for component, result_component in zip(components, result): 118 self.assertAllEqual(component[i]**2, result_component) 119 with self.assertRaises(errors.OutOfRangeError): 120 sess.run(get_next) 121 122 def testOneShotIteratorInsideContainer(self): 123 components = (np.arange(7), 124 np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], 125 np.array(37.0) * np.arange(7)) 126 127 def within_container(): 128 def _map_fn(x, y, z): 129 return math_ops.square(x), math_ops.square(y), math_ops.square(z) 130 iterator = (dataset_ops.Dataset.from_tensor_slices(components) 131 .map(_map_fn).repeat(14).make_one_shot_iterator()) 132 return iterator.get_next() 133 134 server = server_lib.Server.create_local_server() 135 136 # Create two iterators within unique containers, and run them to 137 # make sure that the resources aren't shared. 138 # 139 # The test below would fail if cname were the same across both 140 # sessions. 141 for i in range(2): 142 with session.Session(server.target) as sess: 143 cname = "iteration%d" % i 144 with ops.container(cname): 145 get_next = within_container() 146 147 for _ in range(14): 148 for i in range(7): 149 result = sess.run(get_next) 150 for component, result_component in zip(components, result): 151 self.assertAllEqual(component[i]**2, result_component) 152 with self.assertRaises(errors.OutOfRangeError): 153 sess.run(get_next) 154 155 def testOneShotIteratorNonBlocking(self): 156 dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x) 157 iterator = dataset.make_one_shot_iterator() 158 next_element = iterator.get_next() 159 160 # Create a session with a single thread to ensure that the 161 # one-shot iterator initializer does not deadlock. 162 config = config_pb2.ConfigProto(inter_op_parallelism_threads=1, 163 use_per_session_threads=True) 164 with session.Session(config=config) as sess: 165 self.assertAllEqual([1, 4, 9], sess.run(next_element)) 166 with self.assertRaises(errors.OutOfRangeError): 167 sess.run(next_element) 168 169 # Test with multiple threads invoking the one-shot iterator concurrently. 170 with session.Session(config=config) as sess: 171 results = [] 172 def consumer_thread(): 173 try: 174 results.append(sess.run(next_element)) 175 except errors.OutOfRangeError: 176 results.append(None) 177 178 num_threads = 8 179 threads = [ 180 self.checkedThread(consumer_thread) for _ in range(num_threads)] 181 for t in threads: 182 t.start() 183 for t in threads: 184 t.join() 185 186 self.assertEqual(num_threads, len(results)) 187 self.assertEqual(num_threads - 1, 188 len([None for r in results if r is None])) 189 self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None]) 190 191 def testOneShotIteratorInitializerFails(self): 192 # Define a dataset whose initialization will always fail. 193 dataset = dataset_ops.Dataset.from_tensors( 194 array_ops.check_numerics( 195 constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) 196 iterator = dataset.make_one_shot_iterator() 197 next_element = iterator.get_next() 198 199 with self.test_session() as sess: 200 with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): 201 sess.run(next_element) 202 203 # Test that subsequent attempts to use the iterator also fail. 204 with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): 205 sess.run(next_element) 206 207 with self.test_session() as sess: 208 def consumer_thread(): 209 with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): 210 sess.run(next_element) 211 212 num_threads = 8 213 threads = [ 214 self.checkedThread(consumer_thread) for _ in range(num_threads)] 215 for t in threads: 216 t.start() 217 for t in threads: 218 t.join() 219 220 def testSimpleSharedResource(self): 221 components = ( 222 np.array(1, dtype=np.int64), 223 np.array([1, 2, 3], dtype=np.int64), 224 np.array(37.0, dtype=np.float64) 225 ) 226 227 server = server_lib.Server.create_local_server() 228 229 # Create two non-overlapping sessions that share the same iterator 230 # resource on the same server, and verify that an action of the 231 # first session (initializing the iterator) is visible in the 232 # second session. 233 with ops.Graph().as_default(): 234 iterator = (dataset_ops.Dataset.from_tensors(components) 235 .map(lambda x, y, z: (x, y, z)).make_initializable_iterator( 236 shared_name="shared_iterator")) 237 init_op = iterator.initializer 238 get_next = iterator.get_next() 239 240 with session.Session(server.target) as sess: 241 sess.run(init_op) 242 results = sess.run(get_next) 243 for component, result_component in zip(components, results): 244 self.assertAllEqual(component, result_component) 245 with self.assertRaises(errors.OutOfRangeError): 246 sess.run(get_next) 247 248 # Re-initialize the iterator in the first session. 249 sess.run(init_op) 250 251 with ops.Graph().as_default(): 252 # Re-define the iterator manually, without defining any of the 253 # functions in this graph, to ensure that we are not 254 # accidentally redefining functions with the same names in the 255 # new graph. 256 iterator = iterator_ops.Iterator.from_structure( 257 shared_name="shared_iterator", 258 output_types=(dtypes.int64, dtypes.int64, dtypes.float64), 259 output_shapes=([], [3], [])) 260 get_next = iterator.get_next() 261 262 with session.Session(server.target) as sess: 263 # Use the iterator without re-initializing in the second session. 264 results = sess.run(get_next) 265 for component, result_component in zip(components, results): 266 self.assertAllEqual(component, result_component) 267 with self.assertRaises(errors.OutOfRangeError): 268 sess.run(get_next) 269 270 def testNotInitializedError(self): 271 components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) 272 iterator = (dataset_ops.Dataset.from_tensors(components) 273 .make_initializable_iterator()) 274 get_next = iterator.get_next() 275 276 with self.test_session() as sess: 277 with self.assertRaisesRegexp(errors.FailedPreconditionError, 278 "iterator has not been initialized"): 279 sess.run(get_next) 280 281 def testReinitializableIterator(self): 282 dataset_3 = dataset_ops.Dataset.from_tensors( 283 constant_op.constant([1, 2, 3])) 284 dataset_4 = dataset_ops.Dataset.from_tensors( 285 constant_op.constant([4, 5, 6, 7])) 286 iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types, 287 [None]) 288 289 dataset_3_init_op = iterator.make_initializer(dataset_3) 290 dataset_4_init_op = iterator.make_initializer(dataset_4) 291 get_next = iterator.get_next() 292 293 self.assertEqual(dataset_3.output_types, iterator.output_types) 294 self.assertEqual(dataset_4.output_types, iterator.output_types) 295 self.assertEqual([None], iterator.output_shapes.as_list()) 296 297 with self.test_session() as sess: 298 # The iterator is initially uninitialized. 299 with self.assertRaises(errors.FailedPreconditionError): 300 sess.run(get_next) 301 302 # Initialize with one dataset. 303 sess.run(dataset_3_init_op) 304 self.assertAllEqual([1, 2, 3], sess.run(get_next)) 305 with self.assertRaises(errors.OutOfRangeError): 306 sess.run(get_next) 307 308 # Initialize with a different dataset. 309 sess.run(dataset_4_init_op) 310 self.assertAllEqual([4, 5, 6, 7], sess.run(get_next)) 311 with self.assertRaises(errors.OutOfRangeError): 312 sess.run(get_next) 313 314 # Reinitialize with the first dataset. 315 sess.run(dataset_3_init_op) 316 self.assertAllEqual([1, 2, 3], sess.run(get_next)) 317 with self.assertRaises(errors.OutOfRangeError): 318 sess.run(get_next) 319 320 def testReinitializableIteratorStaticErrors(self): 321 # Non-matching structure for types and shapes. 322 with self.assertRaises(TypeError): 323 iterator = iterator_ops.Iterator.from_structure((dtypes.int64, 324 dtypes.float64), [None]) 325 326 # Test validation of dataset argument. 327 iterator = iterator_ops.Iterator.from_structure((dtypes.int64, 328 dtypes.float64)) 329 330 # Incompatible structure. 331 with self.assertRaises(ValueError): 332 iterator.make_initializer( 333 dataset_ops.Dataset.from_tensors(((constant_op.constant( 334 [1, 2, 3], dtype=dtypes.int64),), (constant_op.constant( 335 [4., 5., 6., 7.], dtype=dtypes.float64),)))) 336 337 # Incompatible types. 338 with self.assertRaises(TypeError): 339 iterator.make_initializer( 340 dataset_ops.Dataset.from_tensors((constant_op.constant( 341 [1, 2, 3], dtype=dtypes.int32), constant_op.constant( 342 [4., 5., 6., 7.], dtype=dtypes.float32)))) 343 344 # Incompatible shapes. 345 iterator = iterator_ops.Iterator.from_structure( 346 (dtypes.int64, dtypes.float64), ([None], [])) 347 with self.assertRaises(TypeError): 348 iterator.make_initializer( 349 dataset_ops.Dataset.from_tensors((constant_op.constant( 350 [1, 2, 3], dtype=dtypes.int64), constant_op.constant( 351 [4., 5., 6., 7.], dtype=dtypes.float64)))) 352 353 def testIteratorStringHandle(self): 354 dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) 355 dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40]) 356 357 iterator_3 = dataset_3.make_one_shot_iterator() 358 iterator_4 = dataset_4.make_one_shot_iterator() 359 360 handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) 361 feedable_iterator = iterator_ops.Iterator.from_string_handle( 362 handle_placeholder, dataset_3.output_types, dataset_3.output_shapes) 363 next_element = feedable_iterator.get_next() 364 365 self.assertEqual(dataset_3.output_types, feedable_iterator.output_types) 366 self.assertEqual(dataset_4.output_types, feedable_iterator.output_types) 367 self.assertEqual([], feedable_iterator.output_shapes) 368 369 with self.test_session() as sess: 370 iterator_3_handle = sess.run(iterator_3.string_handle()) 371 iterator_4_handle = sess.run(iterator_4.string_handle()) 372 373 self.assertEqual( 374 10, sess.run(next_element, 375 feed_dict={handle_placeholder: iterator_4_handle})) 376 self.assertEqual( 377 1, sess.run(next_element, 378 feed_dict={handle_placeholder: iterator_3_handle})) 379 self.assertEqual( 380 20, sess.run(next_element, 381 feed_dict={handle_placeholder: iterator_4_handle})) 382 self.assertEqual( 383 2, sess.run(next_element, 384 feed_dict={handle_placeholder: iterator_3_handle})) 385 self.assertEqual( 386 30, sess.run(next_element, 387 feed_dict={handle_placeholder: iterator_4_handle})) 388 self.assertEqual( 389 3, sess.run(next_element, 390 feed_dict={handle_placeholder: iterator_3_handle})) 391 self.assertEqual( 392 40, sess.run(next_element, 393 feed_dict={handle_placeholder: iterator_4_handle})) 394 with self.assertRaises(errors.OutOfRangeError): 395 sess.run(next_element, 396 feed_dict={handle_placeholder: iterator_3_handle}) 397 with self.assertRaises(errors.OutOfRangeError): 398 sess.run(next_element, 399 feed_dict={handle_placeholder: iterator_4_handle}) 400 401 def testIteratorStringHandleReuseTensorObject(self): 402 dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) 403 one_shot_iterator = dataset.make_one_shot_iterator() 404 initializable_iterator = dataset.make_initializable_iterator() 405 structure_iterator = iterator_ops.Iterator.from_structure( 406 dataset.output_types) 407 408 created_ops = len(ops.get_default_graph().get_operations()) 409 410 self.assertIs(one_shot_iterator.string_handle(), 411 one_shot_iterator.string_handle()) 412 self.assertIs(initializable_iterator.string_handle(), 413 initializable_iterator.string_handle()) 414 self.assertIs(structure_iterator.string_handle(), 415 structure_iterator.string_handle()) 416 417 # Assert that getting the (default) string handle creates no ops. 418 self.assertEqual(created_ops, len(ops.get_default_graph().get_operations())) 419 420 # Specifying an explicit name will create a new op. 421 handle_with_name = one_shot_iterator.string_handle(name="foo") 422 self.assertEqual("foo", handle_with_name.op.name) 423 self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name) 424 425 handle_with_same_name = one_shot_iterator.string_handle(name="foo") 426 self.assertEqual("foo_1", handle_with_same_name.op.name) 427 self.assertIsNot(handle_with_name, handle_with_same_name) 428 429 def testIteratorStringHandleError(self): 430 dataset_int_scalar = (dataset_ops.Dataset.from_tensor_slices([1, 2, 431 3]).repeat()) 432 dataset_float_vector = (dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])) 433 434 handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) 435 436 feedable_int_scalar = iterator_ops.Iterator.from_string_handle( 437 handle_placeholder, dtypes.int32, []) 438 feedable_int_vector = iterator_ops.Iterator.from_string_handle( 439 handle_placeholder, dtypes.int32, [None]) 440 feedable_int_any = iterator_ops.Iterator.from_string_handle( 441 handle_placeholder, dtypes.int32) 442 443 with self.test_session() as sess: 444 handle_int_scalar = sess.run( 445 dataset_int_scalar.make_one_shot_iterator().string_handle()) 446 handle_float_vector = sess.run( 447 dataset_float_vector.make_one_shot_iterator().string_handle()) 448 449 self.assertEqual(1, 450 sess.run( 451 feedable_int_scalar.get_next(), 452 feed_dict={handle_placeholder: handle_int_scalar})) 453 454 self.assertEqual(2, 455 sess.run( 456 feedable_int_any.get_next(), 457 feed_dict={handle_placeholder: handle_int_scalar})) 458 459 with self.assertRaises(errors.InvalidArgumentError): 460 print(sess.run( 461 feedable_int_vector.get_next(), 462 feed_dict={handle_placeholder: handle_int_scalar})) 463 464 with self.assertRaises(errors.InvalidArgumentError): 465 print(sess.run( 466 feedable_int_vector.get_next(), 467 feed_dict={handle_placeholder: handle_float_vector})) 468 469 def testRemoteIteratorUsingRemoteCallOpDirectSession(self): 470 worker_config = config_pb2.ConfigProto() 471 worker_config.device_count["CPU"] = 3 472 473 with ops.device("/job:localhost/replica:0/task:0/cpu:1"): 474 dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) 475 iterator_3 = dataset_3.make_one_shot_iterator() 476 iterator_3_handle = iterator_3.string_handle() 477 478 @function.Defun(dtypes.string) 479 def _remote_fn(h): 480 remote_iterator = iterator_ops.Iterator.from_string_handle( 481 h, dataset_3.output_types, dataset_3.output_shapes) 482 return remote_iterator.get_next() 483 484 with ops.device("/job:localhost/replica:0/task:0/cpu:0"): 485 target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) 486 remote_op = functional_ops.remote_call( 487 args=[iterator_3_handle], 488 Tout=[dtypes.int32], 489 f=_remote_fn, 490 target=target_placeholder) 491 492 with self.test_session(config=worker_config) as sess: 493 elem = sess.run( 494 remote_op, 495 feed_dict={ 496 target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" 497 }) 498 self.assertEqual(elem, [1]) 499 # Fails when target is cpu:2 where the resource is not located. 500 with self.assertRaises(errors.InvalidArgumentError): 501 sess.run( 502 remote_op, 503 feed_dict={ 504 target_placeholder: "/job:localhost/replica:0/task:0/cpu:2" 505 }) 506 elem = sess.run( 507 remote_op, 508 feed_dict={ 509 target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" 510 }) 511 self.assertEqual(elem, [2]) 512 elem = sess.run( 513 remote_op, 514 feed_dict={ 515 target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" 516 }) 517 self.assertEqual(elem, [3]) 518 with self.assertRaises(errors.OutOfRangeError): 519 sess.run( 520 remote_op, 521 feed_dict={ 522 target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" 523 }) 524 525 def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self): 526 if not test_util.is_gpu_available(): 527 self.skipTest("No GPU available") 528 529 with ops.device("/job:localhost/replica:0/task:0/cpu:0"): 530 dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) 531 iterator_3 = dataset_3.make_one_shot_iterator() 532 iterator_3_handle = iterator_3.string_handle() 533 534 def _encode_raw(byte_array): 535 return bytes(bytearray(byte_array)) 536 537 @function.Defun(dtypes.uint8) 538 def _remote_fn(h): 539 handle = script_ops.py_func(_encode_raw, [h], dtypes.string) 540 remote_iterator = iterator_ops.Iterator.from_string_handle( 541 handle, dataset_3.output_types, dataset_3.output_shapes) 542 return remote_iterator.get_next() 543 544 with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): 545 target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) 546 iterator_3_handle_uint8 = parsing_ops.decode_raw( 547 bytes=iterator_3_handle, out_type=dtypes.uint8) 548 remote_op = functional_ops.remote_call( 549 args=[iterator_3_handle_uint8], 550 Tout=[dtypes.int32], 551 f=_remote_fn, 552 target=target_placeholder) 553 554 with self.test_session() as sess: 555 elem = sess.run( 556 remote_op, 557 feed_dict={ 558 target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" 559 }) 560 self.assertEqual(elem, [1]) 561 elem = sess.run( 562 remote_op, 563 feed_dict={ 564 target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" 565 }) 566 self.assertEqual(elem, [2]) 567 elem = sess.run( 568 remote_op, 569 feed_dict={ 570 target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" 571 }) 572 self.assertEqual(elem, [3]) 573 with self.assertRaises(errors.OutOfRangeError): 574 sess.run( 575 remote_op, 576 feed_dict={ 577 target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" 578 }) 579 580 def testIncorrectIteratorRestore(self): 581 582 def _path(): 583 return os.path.join(self.get_temp_dir(), "iterator") 584 585 def _save_op(iterator_resource): 586 iterator_state_variant = gen_dataset_ops.serialize_iterator( 587 iterator_resource) 588 save_op = io_ops.write_file( 589 _path(), parsing_ops.serialize_tensor(iterator_state_variant)) 590 return save_op 591 592 def _restore_op(iterator_resource): 593 iterator_state_variant = parsing_ops.parse_tensor( 594 io_ops.read_file(_path()), dtypes.variant) 595 restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, 596 iterator_state_variant) 597 return restore_op 598 599 def _build_range_dataset_graph(): 600 start = 1 601 stop = 10 602 iterator = dataset_ops.Dataset.range(start, 603 stop).make_initializable_iterator() 604 init_op = iterator.initializer 605 get_next = iterator.get_next() 606 save_op = _save_op(iterator._iterator_resource) 607 restore_op = _restore_op(iterator._iterator_resource) 608 return init_op, get_next, save_op, restore_op 609 610 def _build_reader_dataset_graph(): 611 filenames = ["test"] # Does not exist but we don't care in this test. 612 iterator = readers.FixedLengthRecordDataset( 613 filenames, 1, 0, 0).make_initializable_iterator() 614 init_op = iterator.initializer 615 get_next_op = iterator.get_next() 616 save_op = _save_op(iterator._iterator_resource) 617 restore_op = _restore_op(iterator._iterator_resource) 618 return init_op, get_next_op, save_op, restore_op 619 620 # Saving iterator for RangeDataset graph. 621 with ops.Graph().as_default() as g: 622 init_op, _, save_op, _ = _build_range_dataset_graph() 623 with self.test_session(graph=g) as sess: 624 sess.run(init_op) 625 sess.run(save_op) 626 627 # Attempt to restore the saved iterator into an IteratorResource of 628 # incompatible type. An iterator of RangeDataset has output type int64, 629 # while an iterator of FixedLengthRecordDataset has output type string. 630 # So an InvalidArgumentError should be raised by 631 # IteratorResource::set_iterator. 632 with ops.Graph().as_default() as g: 633 _, _, _, restore_op = _build_reader_dataset_graph() 634 with self.test_session(graph=g) as sess: 635 with self.assertRaises(errors.InvalidArgumentError): 636 sess.run(restore_op) 637 638 def testRepeatedGetNextWarning(self): 639 iterator = dataset_ops.Dataset.range(10).make_one_shot_iterator() 640 warnings.simplefilter("always") 641 with warnings.catch_warnings(record=True) as w: 642 for _ in range(100): 643 iterator.get_next() 644 self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD, 645 len(w)) 646 for warning in w: 647 self.assertTrue( 648 iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE in str(warning.message)) 649 650 651 if __name__ == "__main__": 652 test.main() 653