1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Test RangeDataset.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import os 21 22 from tensorflow.python.data.ops import dataset_ops 23 from tensorflow.python.data.ops import iterator_ops 24 from tensorflow.python.framework import dtypes 25 from tensorflow.python.framework import errors 26 from tensorflow.python.framework import ops 27 from tensorflow.python.framework import tensor_shape 28 from tensorflow.python.ops import array_ops 29 from tensorflow.python.ops import gen_dataset_ops 30 from tensorflow.python.ops import io_ops 31 from tensorflow.python.ops import parsing_ops 32 from tensorflow.python.ops import variables 33 from tensorflow.python.platform import gfile 34 from tensorflow.python.platform import test 35 36 37 class RangeDatasetTest(test.TestCase): 38 39 def tearDown(self): 40 # Remove all checkpoint files. 41 prefix = self._iterator_checkpoint_prefix() 42 pattern = prefix + "*" 43 files = gfile.Glob(pattern) 44 map(gfile.Remove, files) 45 46 def testStop(self): 47 stop = array_ops.placeholder(dtypes.int64, shape=[]) 48 iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator() 49 init_op = iterator.initializer 50 get_next = iterator.get_next() 51 52 with self.test_session() as sess: 53 sess.run(init_op, feed_dict={stop: 5}) 54 for i in range(5): 55 self.assertEqual(i, sess.run(get_next)) 56 with self.assertRaises(errors.OutOfRangeError): 57 sess.run(get_next) 58 59 def testStartStop(self): 60 start = array_ops.placeholder(dtypes.int64, shape=[]) 61 stop = array_ops.placeholder(dtypes.int64, shape=[]) 62 iterator = dataset_ops.Dataset.range(start, 63 stop).make_initializable_iterator() 64 init_op = iterator.initializer 65 get_next = iterator.get_next() 66 67 with self.test_session() as sess: 68 sess.run(init_op, feed_dict={start: 2, stop: 5}) 69 for i in range(2, 5): 70 self.assertEqual(i, sess.run(get_next)) 71 with self.assertRaises(errors.OutOfRangeError): 72 sess.run(get_next) 73 74 def testStartStopStep(self): 75 start = array_ops.placeholder(dtypes.int64, shape=[]) 76 stop = array_ops.placeholder(dtypes.int64, shape=[]) 77 step = array_ops.placeholder(dtypes.int64, shape=[]) 78 iterator = dataset_ops.Dataset.range(start, stop, 79 step).make_initializable_iterator() 80 init_op = iterator.initializer 81 get_next = iterator.get_next() 82 83 with self.test_session() as sess: 84 sess.run(init_op, feed_dict={start: 2, stop: 10, step: 2}) 85 for i in range(2, 10, 2): 86 self.assertEqual(i, sess.run(get_next)) 87 with self.assertRaises(errors.OutOfRangeError): 88 sess.run(get_next) 89 90 def testZeroStep(self): 91 start = array_ops.placeholder(dtypes.int64, shape=[]) 92 stop = array_ops.placeholder(dtypes.int64, shape=[]) 93 step = array_ops.placeholder(dtypes.int64, shape=[]) 94 iterator = dataset_ops.Dataset.range(start, stop, 95 step).make_initializable_iterator() 96 init_op = iterator.initializer 97 98 with self.test_session() as sess: 99 with self.assertRaises(errors.InvalidArgumentError): 100 sess.run(init_op, feed_dict={start: 2, stop: 10, step: 0}) 101 102 def testNegativeStep(self): 103 start = array_ops.placeholder(dtypes.int64, shape=[]) 104 stop = array_ops.placeholder(dtypes.int64, shape=[]) 105 step = array_ops.placeholder(dtypes.int64, shape=[]) 106 iterator = dataset_ops.Dataset.range(start, stop, 107 step).make_initializable_iterator() 108 init_op = iterator.initializer 109 get_next = iterator.get_next() 110 111 with self.test_session() as sess: 112 sess.run(init_op, feed_dict={start: 2, stop: 10, step: -1}) 113 # This for loop is a no-op but will ensure that the implementation is 114 # consistent with range if it ever changes. 115 for i in range(2, 10, -1): 116 self.assertEqual(i, sess.run(get_next)) 117 with self.assertRaises(errors.OutOfRangeError): 118 sess.run(get_next) 119 120 def testStopLessThanStart(self): 121 start = array_ops.placeholder(dtypes.int64, shape=[]) 122 stop = array_ops.placeholder(dtypes.int64, shape=[]) 123 iterator = dataset_ops.Dataset.range(start, 124 stop).make_initializable_iterator() 125 init_op = iterator.initializer 126 get_next = iterator.get_next() 127 128 with self.test_session() as sess: 129 sess.run(init_op, feed_dict={start: 10, stop: 2}) 130 # This for loop is a no-op but will ensure that the implementation is 131 # consistent with range if it ever changes. 132 for i in range(10, 2): 133 self.assertEqual(i, sess.run(get_next)) 134 with self.assertRaises(errors.OutOfRangeError): 135 sess.run(get_next) 136 137 def testStopLessThanStartWithPositiveStep(self): 138 start = array_ops.placeholder(dtypes.int64, shape=[]) 139 stop = array_ops.placeholder(dtypes.int64, shape=[]) 140 step = array_ops.placeholder(dtypes.int64, shape=[]) 141 iterator = dataset_ops.Dataset.range(start, stop, 142 step).make_initializable_iterator() 143 init_op = iterator.initializer 144 get_next = iterator.get_next() 145 146 with self.test_session() as sess: 147 sess.run(init_op, feed_dict={start: 10, stop: 2, step: 2}) 148 # This for loop is a no-op but will ensure that the implementation is 149 # consistent with range if it ever changes. 150 for i in range(10, 2, 2): 151 self.assertEqual(i, sess.run(get_next)) 152 with self.assertRaises(errors.OutOfRangeError): 153 sess.run(get_next) 154 155 def testStopLessThanStartWithNegativeStep(self): 156 start = array_ops.placeholder(dtypes.int64, shape=[]) 157 stop = array_ops.placeholder(dtypes.int64, shape=[]) 158 step = array_ops.placeholder(dtypes.int64, shape=[]) 159 iterator = dataset_ops.Dataset.range(start, stop, 160 step).make_initializable_iterator() 161 init_op = iterator.initializer 162 get_next = iterator.get_next() 163 164 with self.test_session() as sess: 165 sess.run(init_op, feed_dict={start: 10, stop: 2, step: -1}) 166 for i in range(10, 2, -1): 167 self.assertEqual(i, sess.run(get_next)) 168 with self.assertRaises(errors.OutOfRangeError): 169 sess.run(get_next) 170 171 def _iterator_checkpoint_prefix(self): 172 return os.path.join(self.get_temp_dir(), "iterator") 173 174 def _save_op(self, iterator_resource): 175 iterator_state_variant = gen_dataset_ops.serialize_iterator( 176 iterator_resource) 177 save_op = io_ops.write_file( 178 self._iterator_checkpoint_prefix(), 179 parsing_ops.serialize_tensor(iterator_state_variant)) 180 return save_op 181 182 def _restore_op(self, iterator_resource): 183 iterator_state_variant = parsing_ops.parse_tensor( 184 io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant) 185 restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, 186 iterator_state_variant) 187 return restore_op 188 189 def testSaveRestore(self): 190 191 def _build_graph(start, stop): 192 iterator = dataset_ops.Dataset.range(start, 193 stop).make_initializable_iterator() 194 init_op = iterator.initializer 195 get_next = iterator.get_next() 196 save_op = self._save_op(iterator._iterator_resource) 197 restore_op = self._restore_op(iterator._iterator_resource) 198 return init_op, get_next, save_op, restore_op 199 200 # Saving and restoring in different sessions. 201 start = 2 202 stop = 10 203 break_point = 5 204 with ops.Graph().as_default() as g: 205 init_op, get_next, save_op, _ = _build_graph(start, stop) 206 with self.test_session(graph=g) as sess: 207 sess.run(variables.global_variables_initializer()) 208 sess.run(init_op) 209 for i in range(start, break_point): 210 self.assertEqual(i, sess.run(get_next)) 211 sess.run(save_op) 212 213 with ops.Graph().as_default() as g: 214 init_op, get_next, _, restore_op = _build_graph(start, stop) 215 with self.test_session(graph=g) as sess: 216 sess.run(init_op) 217 sess.run(restore_op) 218 for i in range(break_point, stop): 219 self.assertEqual(i, sess.run(get_next)) 220 with self.assertRaises(errors.OutOfRangeError): 221 sess.run(get_next) 222 223 # Saving and restoring in same session. 224 with ops.Graph().as_default() as g: 225 init_op, get_next, save_op, restore_op = _build_graph(start, stop) 226 with self.test_session(graph=g) as sess: 227 sess.run(variables.global_variables_initializer()) 228 sess.run(init_op) 229 for i in range(start, break_point): 230 self.assertEqual(i, sess.run(get_next)) 231 sess.run(save_op) 232 sess.run(restore_op) 233 for i in range(break_point, stop): 234 self.assertEqual(i, sess.run(get_next)) 235 with self.assertRaises(errors.OutOfRangeError): 236 sess.run(get_next) 237 238 def testRestoreWithoutBuildingDatasetGraph(self): 239 240 def _build_graph(start, stop, num_epochs): 241 dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs) 242 iterator = dataset.make_initializable_iterator() 243 init_op = iterator.initializer 244 get_next = iterator.get_next() 245 save_op = self._save_op(iterator._iterator_resource) 246 restore_op = self._restore_op(iterator._iterator_resource) 247 return init_op, get_next, save_op, restore_op 248 249 # Saving and restoring in different sessions. 250 start = 2 251 stop = 10 252 num_epochs = 5 253 break_point = 5 254 break_epoch = 3 255 with ops.Graph().as_default() as g: 256 init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs) 257 with self.test_session(graph=g) as sess: 258 sess.run(variables.global_variables_initializer()) 259 sess.run(init_op) 260 for _ in range(break_epoch): 261 for i in range(start, stop): 262 self.assertEqual(i, sess.run(get_next)) 263 for i in range(start, break_point): 264 self.assertEqual(i, sess.run(get_next)) 265 sess.run(save_op) 266 267 with ops.Graph().as_default() as g: 268 # Create an empty IteratorResource and restore the Iterator into it. 269 output_types = dtypes.int64 270 output_shapes = tensor_shape.scalar() 271 iterator = iterator_ops.Iterator.from_structure(output_types, 272 output_shapes) 273 restore_op = self._restore_op(iterator._iterator_resource) 274 get_next = iterator.get_next() 275 with self.test_session(graph=g) as sess: 276 sess.run(restore_op) 277 for i in range(break_point, stop): 278 self.assertEqual(i, sess.run(get_next)) 279 for _ in range(break_epoch + 1, num_epochs): 280 for i in range(start, stop): 281 self.assertEqual(i, sess.run(get_next)) 282 with self.assertRaises(errors.OutOfRangeError): 283 sess.run(get_next) 284 285 def testRestoreInModifiedGraph(self): 286 287 def _build_graph(start, stop): 288 dataset = dataset_ops.Dataset.range(start, stop) 289 iterator = dataset.make_initializable_iterator() 290 init_op = iterator.initializer 291 get_next = iterator.get_next() 292 save_op = self._save_op(iterator._iterator_resource) 293 restore_op = self._restore_op(iterator._iterator_resource) 294 return init_op, get_next, save_op, restore_op 295 296 # Saving and restoring in different sessions. 297 start = 2 298 stop = 10 299 stop_1 = 8 300 break_point = 5 301 with ops.Graph().as_default() as g: 302 init_op, get_next, save_op, _ = _build_graph(start, stop) 303 with self.test_session(graph=g) as sess: 304 sess.run(variables.global_variables_initializer()) 305 sess.run(init_op) 306 for i in range(start, break_point): 307 self.assertEqual(i, sess.run(get_next)) 308 sess.run(save_op) 309 310 with ops.Graph().as_default() as g: 311 # Intentionally build a graph with a different value for stop to make sure 312 # the original dataset graph is actually getting loaded. 313 init_op, get_next, _, restore_op = _build_graph(start, stop_1) 314 with self.test_session(graph=g) as sess: 315 sess.run(restore_op) 316 for i in range(break_point, stop): 317 self.assertEqual(i, sess.run(get_next)) 318 with self.assertRaises(errors.OutOfRangeError): 319 sess.run(get_next) 320 321 def testInitThenRestore(self): 322 # Note: Calling init_op before restore_op is redundant. This test just makes 323 # sure we do not fail if restore is called on an already initialized 324 # iterator resource. 325 326 def _build_graph(start, stop): 327 dataset = dataset_ops.Dataset.range(start, stop) 328 iterator = dataset.make_initializable_iterator() 329 init_op = iterator.initializer 330 get_next = iterator.get_next() 331 save_op = self._save_op(iterator._iterator_resource) 332 restore_op = self._restore_op(iterator._iterator_resource) 333 return init_op, get_next, save_op, restore_op 334 335 # Saving and restoring in different sessions. 336 start = 2 337 stop = 10 338 break_point = 5 339 with ops.Graph().as_default() as g: 340 init_op, get_next, save_op, _ = _build_graph(start, stop) 341 with self.test_session(graph=g) as sess: 342 sess.run(variables.global_variables_initializer()) 343 sess.run(init_op) 344 for i in range(start, break_point): 345 self.assertEqual(i, sess.run(get_next)) 346 sess.run(save_op) 347 348 with ops.Graph().as_default() as g: 349 init_op, get_next, _, restore_op = _build_graph(start, stop) 350 with self.test_session(graph=g) as sess: 351 sess.run(init_op) 352 sess.run(restore_op) 353 for i in range(break_point, stop): 354 self.assertEqual(i, sess.run(get_next)) 355 with self.assertRaises(errors.OutOfRangeError): 356 sess.run(get_next) 357 358 def testMultipleSaves(self): 359 360 def _build_graph(start, stop): 361 iterator = dataset_ops.Dataset.range(start, 362 stop).make_initializable_iterator() 363 init_op = iterator.initializer 364 get_next = iterator.get_next() 365 save_op = self._save_op(iterator._iterator_resource) 366 restore_op = self._restore_op(iterator._iterator_resource) 367 return init_op, get_next, save_op, restore_op 368 369 start = 2 370 stop = 10 371 break_point1 = 5 372 break_point2 = 7 373 374 with ops.Graph().as_default() as g: 375 init_op, get_next, save_op, _ = _build_graph(start, stop) 376 with self.test_session(graph=g) as sess: 377 sess.run(variables.global_variables_initializer()) 378 sess.run(init_op) 379 for i in range(start, break_point1): 380 self.assertEqual(i, sess.run(get_next)) 381 sess.run(save_op) 382 383 with ops.Graph().as_default() as g: 384 init_op, get_next, save_op, restore_op = _build_graph(start, stop) 385 with self.test_session(graph=g) as sess: 386 sess.run(restore_op) 387 for i in range(break_point1, break_point2): 388 self.assertEqual(i, sess.run(get_next)) 389 sess.run(save_op) 390 391 break_point2 = 7 392 with ops.Graph().as_default() as g: 393 init_op, get_next, save_op, restore_op = _build_graph(start, stop) 394 with self.test_session(graph=g) as sess: 395 sess.run(restore_op) 396 for i in range(break_point2, stop): 397 self.assertEqual(i, sess.run(get_next)) 398 with self.assertRaises(errors.OutOfRangeError): 399 sess.run(get_next) 400 401 def testSaveRestoreWithRepeat(self): 402 403 def _build_graph(start, stop, num_epochs): 404 iterator = dataset_ops.Dataset.range( 405 start, stop).repeat(num_epochs).make_initializable_iterator() 406 init_op = iterator.initializer 407 get_next = iterator.get_next() 408 save_op = self._save_op(iterator._iterator_resource) 409 restore_op = self._restore_op(iterator._iterator_resource) 410 return init_op, get_next, save_op, restore_op 411 412 start = 2 413 stop = 10 414 num_epochs = 5 415 break_range = 5 416 break_epoch = 3 417 with ops.Graph().as_default() as g: 418 init_op, get_next, save_op, restore_op = _build_graph( 419 start, stop, num_epochs) 420 with self.test_session(graph=g) as sess: 421 sess.run(variables.global_variables_initializer()) 422 sess.run(init_op) 423 # Note: There is no checkpoint saved currently so a NotFoundError is 424 # raised. 425 with self.assertRaises(errors.NotFoundError): 426 sess.run(restore_op) 427 for _ in range(break_epoch - 1): 428 for i in range(start, stop): 429 self.assertEqual(i, sess.run(get_next)) 430 for i in range(start, break_range): 431 self.assertEqual(i, sess.run(get_next)) 432 sess.run(save_op) 433 434 with ops.Graph().as_default() as g: 435 init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) 436 with self.test_session(graph=g) as sess: 437 sess.run(restore_op) 438 for i in range(break_range, stop): 439 self.assertEqual(i, sess.run(get_next)) 440 for _ in range(break_epoch, num_epochs): 441 for i in range(start, stop): 442 self.assertEqual(i, sess.run(get_next)) 443 with self.assertRaises(errors.OutOfRangeError): 444 sess.run(get_next) 445 446 def testSaveRestoreExhaustedIterator(self): 447 448 def _build_graph(start, stop, num_epochs): 449 iterator = dataset_ops.Dataset.range( 450 start, stop).repeat(num_epochs).make_initializable_iterator() 451 init_op = iterator.initializer 452 get_next = iterator.get_next() 453 save_op = self._save_op(iterator._iterator_resource) 454 restore_op = self._restore_op(iterator._iterator_resource) 455 return init_op, get_next, save_op, restore_op 456 457 start = 2 458 stop = 10 459 num_epochs = 5 460 with ops.Graph().as_default() as g: 461 init_op, get_next, save_op, restore_op = _build_graph( 462 start, stop, num_epochs) 463 with self.test_session(graph=g) as sess: 464 sess.run(variables.global_variables_initializer()) 465 sess.run(init_op) 466 # Note: There is no checkpoint saved currently so a NotFoundError is 467 # raised. 468 with self.assertRaises(errors.NotFoundError): 469 sess.run(restore_op) 470 for _ in range(num_epochs): 471 for i in range(start, stop): 472 self.assertEqual(i, sess.run(get_next)) 473 with self.assertRaises(errors.OutOfRangeError): 474 sess.run(get_next) 475 sess.run(save_op) 476 477 with ops.Graph().as_default() as g: 478 init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) 479 with self.test_session(graph=g) as sess: 480 sess.run(restore_op) 481 with self.assertRaises(errors.OutOfRangeError): 482 sess.run(get_next) 483 484 485 if __name__ == "__main__": 486 test.main() 487