1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Unit tests of the tfdbg Stepper.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 from tensorflow.core.protobuf import config_pb2 21 from tensorflow.core.protobuf import rewriter_config_pb2 22 from tensorflow.python.client import session 23 from tensorflow.python.debug.lib.stepper import NodeStepper 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import ops 27 from tensorflow.python.framework import test_util 28 from tensorflow.python.ops import array_ops 29 from tensorflow.python.ops import math_ops 30 from tensorflow.python.ops import state_ops 31 from tensorflow.python.ops import variables 32 from tensorflow.python.platform import googletest 33 from tensorflow.python.training import gradient_descent 34 35 36 class StepperTest(test_util.TensorFlowTestCase): 37 38 def setUp(self): 39 self.a = variables.Variable(2.0, name="a") 40 self.b = variables.Variable(3.0, name="b") 41 42 self.c = math_ops.multiply(self.a, self.b, name="c") # Should be 6.0. 43 self.d = math_ops.multiply(self.a, self.a, name="d") # Should be 4.0. 44 45 self.e = math_ops.multiply(self.d, self.c, name="e") # Should be 24.0. 46 47 self.f_y = constant_op.constant(0.30, name="f_y") 48 self.f = math_ops.div(self.b, self.f_y, name="f") # Should be 10.0. 49 50 # The there nodes x, y and z form a graph with "cross-links" in. I.e., x 51 # and y are both direct inputs to z, but x is also a direct input to y. 52 self.x = variables.Variable(2.0, name="x") # Should be 2.0 53 self.y = math_ops.negative(self.x, name="y") # Should be -2.0. 54 55 self.z = math_ops.multiply(self.x, self.y, name="z") # Should be -4.0. 56 57 rewriter_config = rewriter_config_pb2.RewriterConfig( 58 disable_model_pruning=True, 59 arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, 60 constant_folding=rewriter_config_pb2.RewriterConfig.OFF) 61 graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) 62 config = config_pb2.ConfigProto(graph_options=graph_options) 63 self.sess = session.Session(config=config) 64 self.sess.run(variables.global_variables_initializer()) 65 66 def tearDown(self): 67 ops.reset_default_graph() 68 69 def testContToFetchNotInTransitiveClosureShouldError(self): 70 with NodeStepper(self.sess, "e:0") as stepper: 71 sorted_nodes = stepper.sorted_nodes() 72 self.assertEqual(7, len(sorted_nodes)) 73 self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read")) 74 self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read")) 75 self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c")) 76 self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c")) 77 self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d")) 78 self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e")) 79 self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e")) 80 81 self.assertSetEqual( 82 {"e:0", "d:0", "c:0", "a/read:0", "b/read:0", "b:0", "a:0"}, 83 set(stepper.closure_elements())) 84 85 with self.assertRaisesRegexp( 86 ValueError, 87 "Target \"f:0\" is not in the transitive closure for the fetch of " 88 "the stepper"): 89 stepper.cont("f:0") 90 91 def testContToNodeNameShouldReturnTensorValue(self): 92 with NodeStepper(self.sess, "e:0") as stepper: 93 self.assertAllClose(6.0, stepper.cont("c")) 94 95 def testUsingNamesNotUsingIntermediateTensors(self): 96 with NodeStepper(self.sess, "e:0") as stepper: 97 # The first cont() call should have used no feeds. 98 result = stepper.cont("c:0") 99 self.assertAllClose(6.0, result) 100 self.assertItemsEqual(["a/read:0", "b/read:0"], 101 stepper.intermediate_tensor_names()) 102 self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0")) 103 self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0")) 104 self.assertEqual({}, stepper.last_feed_types()) 105 106 # The second cont() call should have used the tensor handle from the 107 # previous cont() call. 108 result = stepper.cont("e:0") 109 self.assertAllClose(24.0, result) 110 self.assertItemsEqual(["a/read:0", "b/read:0", "d:0"], 111 stepper.intermediate_tensor_names()) 112 self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0")) 113 self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0")) 114 self.assertAllClose(4.0, stepper.get_tensor_value("d:0")) 115 self.assertEqual({ 116 "c:0": NodeStepper.FEED_TYPE_HANDLE, 117 "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, 118 }, stepper.last_feed_types()) 119 120 def testUsingNodesNotUsingIntermediateTensors(self): 121 with NodeStepper(self.sess, self.e) as stepper: 122 # There should be no handles before any cont() calls. 123 self.assertEqual([], stepper.handle_names()) 124 self.assertSetEqual(set(), stepper.handle_node_names()) 125 126 # Before the cont() call, the stepper should not have access to the value 127 # of c:0. 128 with self.assertRaisesRegexp( 129 ValueError, 130 "This stepper instance does not have access to the value of tensor " 131 "\"c:0\""): 132 stepper.get_tensor_value("c:0") 133 134 # Using the node/tensor itself, instead of the name str, should work on 135 # cont(). 136 result = stepper.cont(self.c) 137 self.assertItemsEqual(["a/read:0", "b/read:0"], 138 stepper.intermediate_tensor_names()) 139 self.assertAllClose(6.0, result) 140 self.assertEqual({}, stepper.last_feed_types()) 141 142 self.assertEqual(["c:0"], stepper.handle_names()) 143 self.assertEqual({"c"}, stepper.handle_node_names()) 144 145 # After the cont() call, the stepper should have access to the value of 146 # c:0 via a tensor handle. 147 self.assertAllClose(6.0, stepper.get_tensor_value("c:0")) 148 149 result = stepper.cont(self.e) 150 self.assertAllClose(24.0, result) 151 self.assertItemsEqual(["a/read:0", "b/read:0", "d:0"], 152 stepper.intermediate_tensor_names()) 153 self.assertEqual({ 154 "c:0": NodeStepper.FEED_TYPE_HANDLE, 155 "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, 156 }, stepper.last_feed_types()) 157 158 def testContToTensorWithIntermediateDumpShouldUseDump(self): 159 with NodeStepper(self.sess, ["e:0", "f:0"]) as stepper: 160 stepper.cont("c:0") 161 self.assertItemsEqual(["a/read:0", "b/read:0"], 162 stepper.intermediate_tensor_names()) 163 self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0")) 164 self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0")) 165 166 self.assertAllClose(2.0, stepper.cont("a/read:0")) 167 self.assertEqual({ 168 "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE 169 }, stepper.last_feed_types()) 170 171 self.assertAllClose(10.0, stepper.cont("f:0")) 172 self.assertEqual({ 173 "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE 174 }, stepper.last_feed_types()) 175 176 def testDisablingUseDumpedIntermediatesWorks(self): 177 with NodeStepper(self.sess, ["e:0", "f:0"]) as stepper: 178 stepper.cont("c:0") 179 self.assertItemsEqual(["a/read:0", "b/read:0"], 180 stepper.intermediate_tensor_names()) 181 self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0")) 182 self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0")) 183 184 self.assertAllClose(10.0, 185 stepper.cont("f:0", use_dumped_intermediates=False)) 186 self.assertEqual({}, stepper.last_feed_types()) 187 188 def testIsFeedableShouldGiveCorrectAnswers(self): 189 with NodeStepper(self.sess, self.e) as stepper: 190 self.assertTrue(stepper.is_feedable("a/read:0")) 191 self.assertTrue(stepper.is_feedable("b/read:0")) 192 self.assertTrue(stepper.is_feedable("c:0")) 193 self.assertTrue(stepper.is_feedable("d:0")) 194 195 def testOverrideValue(self): 196 with NodeStepper(self.sess, self.e) as stepper: 197 result = stepper.cont(self.c) 198 self.assertAllClose(6.0, result) 199 self.assertEqual({}, stepper.last_feed_types()) 200 201 # There should be no overrides before any cont() calls. 202 self.assertEqual([], stepper.override_names()) 203 204 # Calling cont() on c again should lead to use of the handle. 205 result = stepper.cont(self.c) 206 self.assertAllClose(6.0, result) 207 self.assertEqual({ 208 "c:0": NodeStepper.FEED_TYPE_HANDLE 209 }, stepper.last_feed_types()) 210 211 # Override c:0. 212 stepper.override_tensor("c:0", 7.0) 213 214 # After the overriding, calling get_tensor_value() on c:0 should yield the 215 # overriding value. 216 self.assertEqual(7.0, stepper.get_tensor_value("c:0")) 217 218 # Now c:0 should have only an override value, but no cached handle, 219 # because the handle should have been invalidated. 220 self.assertEqual([], stepper.handle_names()) 221 self.assertSetEqual(set(), stepper.handle_node_names()) 222 self.assertEqual(["c:0"], stepper.override_names()) 223 224 # Run a downstream tensor after the value override. 225 result = stepper.cont(self.e) 226 self.assertAllClose(28.0, result) # Should reflect the overriding value. 227 228 # Should use override, instead of the handle. 229 self.assertEqual({ 230 "c:0": NodeStepper.FEED_TYPE_OVERRIDE, 231 "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, 232 }, stepper.last_feed_types()) 233 234 def testOverrideValueTwice(self): 235 with NodeStepper(self.sess, self.e) as stepper: 236 # Override once. 237 stepper.override_tensor("c:0", 7.0) 238 self.assertAllClose(28.0, stepper.cont(self.e)) 239 self.assertEqual({ 240 "c:0": NodeStepper.FEED_TYPE_OVERRIDE 241 }, stepper.last_feed_types()) 242 243 self.assertEqual(["e:0"], stepper.handle_names()) 244 self.assertSetEqual({"e"}, stepper.handle_node_names()) 245 self.assertEqual(["c:0"], stepper.override_names()) 246 247 # Calling cont(self.e) again. This time the cached tensor handle of e 248 # should be used. 249 self.assertEqual(28.0, stepper.cont(self.e)) 250 self.assertEqual({ 251 "e:0": NodeStepper.FEED_TYPE_HANDLE 252 }, stepper.last_feed_types()) 253 254 # Override c again. This should have invalidated the cache for e. 255 stepper.override_tensor("c:0", 8.0) 256 257 self.assertEqual([], stepper.handle_names()) 258 self.assertEqual(set(), stepper.handle_node_names()) 259 self.assertEqual(["c:0"], stepper.override_names()) 260 261 self.assertAllClose(32.0, stepper.cont(self.e)) 262 self.assertEqual({ 263 "c:0": NodeStepper.FEED_TYPE_OVERRIDE, 264 "d:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, 265 }, stepper.last_feed_types()) 266 267 def testRemoveOverrideValue(self): 268 with NodeStepper(self.sess, self.e) as stepper: 269 result = stepper.cont(self.c) 270 self.assertAllClose(6.0, result) 271 self.assertEqual({}, stepper.last_feed_types()) 272 273 # The previous cont() step should have generated a cached tensor handle. 274 self.assertEqual(["c:0"], stepper.handle_names()) 275 self.assertSetEqual({"c"}, stepper.handle_node_names()) 276 277 # Override c:0. 278 stepper.override_tensor("c:0", 7.0) 279 280 # The overriding should have invalidated the tensor handle. 281 self.assertEqual([], stepper.handle_names()) 282 self.assertSetEqual(set(), stepper.handle_node_names()) 283 self.assertEqual(["c:0"], stepper.override_names()) 284 285 result = stepper.cont(self.e) 286 self.assertAllClose(28.0, result) # Should reflect the overriding value. 287 self.assertEqual({ 288 "c:0": NodeStepper.FEED_TYPE_OVERRIDE, 289 "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, 290 }, stepper.last_feed_types()) 291 292 # The handle to tensor e:0 should have been cached, even though its 293 # transitive closure contains an override. 294 self.assertIn("e:0", stepper.handle_names()) 295 self.assertSetEqual({"e"}, stepper.handle_node_names()) 296 297 # Remove the override. 298 stepper.remove_override("c:0") 299 # c:0 should not be in the overrides anymore. 300 self.assertEqual([], stepper.override_names()) 301 302 # Removing the override should have invalidated the tensor handle for c. 303 self.assertNotIn("e:0", stepper.handle_names()) 304 self.assertNotIn("e", stepper.handle_node_names()) 305 306 # Should reflect the non-overriding value. 307 self.assertAllClose(24.0, stepper.cont(self.e)) 308 309 # This time, the handle to tensor e:0 should have been cached again, even 310 # thought its transitive closure contains an override. 311 self.assertIn("e:0", stepper.handle_names()) 312 self.assertIn("e", stepper.handle_node_names()) 313 314 # Calling cont(self.e) again should have used the tensor handle to e:0. 315 self.assertAllClose(24.0, stepper.cont(self.e)) 316 self.assertEqual({ 317 "e:0": NodeStepper.FEED_TYPE_HANDLE, 318 }, stepper.last_feed_types()) 319 320 def testOverrideAndContToSameTensor(self): 321 with NodeStepper(self.sess, self.e) as stepper: 322 result = stepper.cont(self.c) 323 self.assertAllClose(6.0, result) 324 self.assertEqual({}, stepper.last_feed_types()) 325 self.assertEqual(["c:0"], stepper.handle_names()) 326 self.assertSetEqual({"c"}, stepper.handle_node_names()) 327 328 self.assertAllClose(6.0, stepper.cont(self.c)) 329 330 # The last cont() call should use the tensor handle directly. 331 self.assertEqual({ 332 "c:0": NodeStepper.FEED_TYPE_HANDLE 333 }, stepper.last_feed_types()) 334 335 # Override c:0. 336 stepper.override_tensor("c:0", 7.0) 337 338 # As a result of the override, the tensor handle should have been 339 # invalidated. 340 self.assertEqual([], stepper.handle_names()) 341 self.assertSetEqual(set(), stepper.handle_node_names()) 342 343 result = stepper.cont(self.c) 344 self.assertAllClose(7.0, result) 345 346 self.assertEqual({ 347 "c:0": NodeStepper.FEED_TYPE_OVERRIDE 348 }, stepper.last_feed_types()) 349 350 def testFinalizeWithPreviousOverrides(self): 351 with NodeStepper(self.sess, self.e) as stepper: 352 stepper.override_tensor("a/read:0", 20.0) 353 self.assertEqual(["a/read:0"], stepper.override_names()) 354 355 # Should reflect the overriding value. 356 self.assertAllClose(24000.0, stepper.cont("e:0")) 357 self.assertEqual({ 358 "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE 359 }, stepper.last_feed_types()) 360 361 # Finalize call should have ignored the overriding value. 362 self.assertAllClose(24.0, stepper.finalize()) 363 364 def testRemoveNonexistentOverrideValue(self): 365 with NodeStepper(self.sess, self.e) as stepper: 366 self.assertEqual([], stepper.override_names()) 367 with self.assertRaisesRegexp( 368 ValueError, "No overriding value exists for tensor \"c:0\""): 369 stepper.remove_override("c:0") 370 371 def testAttemptToOverrideInvalidTensor(self): 372 stepper = NodeStepper(self.sess, self.e) 373 374 with self.assertRaisesRegexp(ValueError, "Cannot override tensor \"f:0\""): 375 stepper.override_tensor("f:0", 42.0) 376 377 def testInvalidOverrideArgumentType(self): 378 with NodeStepper(self.sess, self.e) as stepper: 379 with self.assertRaisesRegexp(TypeError, "Expected type str; got type"): 380 stepper.override_tensor(self.a, 42.0) 381 382 def testTransitiveClosureWithCrossLinksShouldHaveCorrectOrder(self): 383 with NodeStepper(self.sess, "z:0") as stepper: 384 sorted_nodes = stepper.sorted_nodes() 385 self.assertEqual(4, len(sorted_nodes)) 386 self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read")) 387 self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y")) 388 self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z")) 389 self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z")) 390 391 def testNodeStepperConstructorShouldAllowListOrTupleOrDictOfFetches(self): 392 for i in range(6): 393 if i == 0: 394 fetches = [self.e, [self.f, self.z]] 395 elif i == 1: 396 fetches = (self.e, (self.f, self.z)) 397 elif i == 2: 398 fetches = {"e": self.e, "fz": {"f": self.f, "z": self.z}} 399 elif i == 3: 400 fetches = ["e:0", ["f:0", "z:0"]] 401 elif i == 4: 402 fetches = ("e:0", ("f:0", "z:0")) 403 elif i == 5: 404 fetches = {"e": "e:0", "fz": {"f": "f:0", "z": "z:0"}} 405 406 with NodeStepper(self.sess, fetches) as stepper: 407 sorted_nodes = stepper.sorted_nodes() 408 self.assertEqual(13, len(sorted_nodes)) 409 410 # Check the topological order of the sorted nodes. 411 self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read")) 412 self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y")) 413 self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z")) 414 self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z")) 415 416 self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read")) 417 self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read")) 418 self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c")) 419 self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c")) 420 self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d")) 421 self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e")) 422 self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e")) 423 self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("f")) 424 self.assertLess(sorted_nodes.index("f_y"), sorted_nodes.index("f")) 425 426 closure_elements = stepper.closure_elements() 427 self.assertIn("x/read:0", closure_elements) 428 self.assertIn("e:0", closure_elements) 429 self.assertIn("f:0", closure_elements) 430 431 self.assertEqual([0], stepper.output_slots_in_closure("x/read")) 432 self.assertEqual([0], stepper.output_slots_in_closure("e")) 433 self.assertEqual([0], stepper.output_slots_in_closure("f")) 434 435 result = stepper.finalize() 436 if i == 0 or i == 1 or i == 3 or i == 4: 437 self.assertAllClose(24.0, result[0]) 438 self.assertAllClose(10.0, result[1][0]) 439 self.assertAllClose(-4.0, result[1][1]) 440 elif i == 2 or i == 5: 441 self.assertAllClose(24.0, result["e"]) 442 self.assertAllClose(10.0, result["fz"]["f"]) 443 self.assertAllClose(-4.0, result["fz"]["z"]) 444 445 446 class StepperTestWithPlaceHolders(test_util.TensorFlowTestCase): 447 448 def setUp(self): 449 self.ph0 = array_ops.placeholder(dtypes.float32, shape=(2, 2), name="ph0") 450 self.ph1 = array_ops.placeholder(dtypes.float32, shape=(2, 1), name="ph1") 451 452 self.x = math_ops.matmul(self.ph0, self.ph1, name="x") 453 self.y = math_ops.add(self.x, self.ph1, name="y") 454 455 self.sess = session.Session() 456 457 def tearDown(self): 458 ops.reset_default_graph() 459 460 def testGetTensorValueWorksOnPlaceholder(self): 461 with NodeStepper( 462 self.sess, 463 self.y, 464 feed_dict={ 465 self.ph0: [[1.0, 2.0], [-3.0, 5.0]], 466 self.ph1: [[-1.0], [0.5]] 467 }) as stepper: 468 self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]], 469 stepper.get_tensor_value("ph0")) 470 self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]], 471 stepper.get_tensor_value("ph0:0")) 472 with self.assertRaisesRegexp( 473 KeyError, 474 r"The name 'ph0:1' refers to a Tensor which does not exist"): 475 stepper.get_tensor_value("ph0:1") 476 477 def testIsPlaceholdersShouldGiveCorrectAnswers(self): 478 with NodeStepper(self.sess, self.y) as stepper: 479 self.assertTrue(stepper.is_placeholder(self.ph0.name)) 480 self.assertTrue(stepper.is_placeholder(self.ph1.name)) 481 482 self.assertFalse(stepper.is_placeholder(self.x.name)) 483 self.assertFalse(stepper.is_placeholder(self.y.name)) 484 485 with self.assertRaisesRegexp(ValueError, 486 "A is not in the transitive closure"): 487 self.assertFalse(stepper.is_placeholder("A")) 488 489 def testPlaceholdersShouldGiveCorrectAnswers(self): 490 with NodeStepper(self.sess, self.y) as stepper: 491 self.assertSetEqual({"ph0", "ph1"}, set(stepper.placeholders())) 492 493 def testContWithPlaceholders(self): 494 with NodeStepper( 495 self.sess, 496 self.y, 497 feed_dict={ 498 self.ph0: [[1.0, 2.0], [-3.0, 5.0]], 499 self.ph1: [[-1.0], [0.5]] 500 }) as stepper: 501 self.assertEqual(4, len(stepper.sorted_nodes())) 502 self.assertSetEqual({"ph0:0", "ph1:0", "x:0", "y:0"}, 503 set(stepper.closure_elements())) 504 505 result = stepper.cont(self.x) 506 self.assertAllClose([[0.0], [5.5]], result) 507 self.assertEqual({ 508 "ph0:0": NodeStepper.FEED_TYPE_CLIENT, 509 "ph1:0": NodeStepper.FEED_TYPE_CLIENT, 510 }, stepper.last_feed_types()) 511 512 self.assertEqual(["x:0"], stepper.handle_names()) 513 self.assertSetEqual({"x"}, stepper.handle_node_names()) 514 515 result = stepper.cont(self.y) 516 self.assertAllClose([[-1.0], [6.0]], result) 517 self.assertEqual({ 518 "x:0": NodeStepper.FEED_TYPE_HANDLE, 519 "ph1:0": NodeStepper.FEED_TYPE_CLIENT, 520 }, stepper.last_feed_types()) 521 522 def testAttemptToContToPlaceholderWithTensorFeedKeysShouldWork(self): 523 """Continuing to a placeholder should be allowed, using client feed.""" 524 525 ph0_feed = [[1.0, 2.0], [-3.0, 5.0]] 526 ph1_feed = [[-1.0], [0.5]] 527 with NodeStepper( 528 self.sess, self.y, feed_dict={ 529 self.ph0: ph0_feed, 530 self.ph1: ph1_feed, 531 }) as stepper: 532 self.assertAllClose(ph0_feed, stepper.cont(self.ph0)) 533 self.assertEqual({ 534 self.ph0.name: NodeStepper.FEED_TYPE_CLIENT 535 }, stepper.last_feed_types()) 536 537 self.assertAllClose(ph1_feed, stepper.cont(self.ph1)) 538 self.assertEqual({ 539 self.ph1.name: NodeStepper.FEED_TYPE_CLIENT 540 }, stepper.last_feed_types()) 541 542 ph0_node = self.sess.graph.as_graph_element("ph0") 543 self.assertAllClose(ph0_feed, stepper.cont(ph0_node)) 544 self.assertEqual({ 545 self.ph0.name: NodeStepper.FEED_TYPE_CLIENT 546 }, stepper.last_feed_types()) 547 548 self.assertAllClose([[-1.0], [6.0]], stepper.finalize()) 549 550 def testAttemptToContToPlaceholderWithTensorNameFeedKeysShouldWork(self): 551 552 ph0_feed = [[1.0, 2.0], [-3.0, 5.0]] 553 ph1_feed = [[-1.0], [0.5]] 554 with NodeStepper( 555 self.sess, 556 self.y, 557 feed_dict={ 558 self.ph0.name: ph0_feed, 559 self.ph1.name: ph1_feed, 560 }) as stepper: 561 self.assertAllClose(ph0_feed, stepper.cont(self.ph0)) 562 self.assertEqual({ 563 self.ph0.name: NodeStepper.FEED_TYPE_CLIENT 564 }, stepper.last_feed_types()) 565 566 self.assertAllClose(ph1_feed, stepper.cont(self.ph1)) 567 self.assertEqual({ 568 self.ph1.name: NodeStepper.FEED_TYPE_CLIENT 569 }, stepper.last_feed_types()) 570 571 ph0_node = self.sess.graph.as_graph_element("ph0") 572 self.assertAllClose(ph0_feed, stepper.cont(ph0_node)) 573 self.assertEqual({ 574 self.ph0.name: NodeStepper.FEED_TYPE_CLIENT 575 }, stepper.last_feed_types()) 576 577 self.assertAllClose([[-1.0], [6.0]], stepper.finalize()) 578 579 580 class StepperAssignAddTest(test_util.TensorFlowTestCase): 581 582 def setUp(self): 583 self.v = variables.Variable(10.0, name="v") 584 self.p = math_ops.add(self.v, self.v, name="p") 585 self.q = math_ops.multiply(self.p, self.p, name="q") 586 self.delta = constant_op.constant(2.0, name="delta") 587 self.v_add = state_ops.assign_add(self.v, self.delta, name="v_add") 588 self.v_add_plus_one = math_ops.add(self.v_add, 589 1.0, 590 name="v_add_plus_one") 591 592 rewriter_config = rewriter_config_pb2.RewriterConfig( 593 disable_model_pruning=True, 594 arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, 595 constant_folding=rewriter_config_pb2.RewriterConfig.OFF) 596 graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) 597 config = config_pb2.ConfigProto(graph_options=graph_options) 598 self.sess = session.Session(config=config) 599 self.sess.run(self.v.initializer) 600 601 def tearDown(self): 602 ops.reset_default_graph() 603 604 def testLastUpdatedVariablesReturnsNoneBeforeAnyContCalls(self): 605 with NodeStepper(self.sess, [self.q, self.v_add]) as stepper: 606 self.assertIsNone(stepper.last_updated()) 607 608 def testContToUpdateInvalidatesDumpedIntermediates(self): 609 with NodeStepper(self.sess, [self.q, self.v_add]) as stepper: 610 self.assertAllClose(400.0, stepper.cont("q:0")) 611 self.assertItemsEqual(["v/read:0", "p:0"], 612 stepper.intermediate_tensor_names()) 613 self.assertAllClose(10.0, stepper.get_tensor_value("v/read:0")) 614 self.assertAllClose(20.0, stepper.get_tensor_value("p:0")) 615 616 self.assertAllClose( 617 12.0, stepper.cont( 618 self.v_add, invalidate_from_updated_variables=True)) 619 self.assertAllClose(12.0, self.sess.run(self.v)) 620 self.assertSetEqual({self.v.name}, stepper.last_updated()) 621 self.assertItemsEqual(["v:0"], stepper.dirty_variables()) 622 # Updating the value of v by calling v_add should have invalidated the 623 # dumped intermediate tensors for v/read:0 and p:0. 624 self.assertItemsEqual(["delta:0"], stepper.intermediate_tensor_names()) 625 with self.assertRaisesRegexp( 626 ValueError, 627 r"This stepper instance does not have access to the value of tensor " 628 r"\"p:0\""): 629 stepper.get_tensor_value("p:0") 630 631 # The next cont to q should not have used any dumped intermediate tensors 632 # and its result should reflect the updated value. 633 self.assertAllClose(576.0, stepper.cont("q:0")) 634 self.assertSetEqual(set(), stepper.last_updated()) 635 self.assertEqual({}, stepper.last_feed_types()) 636 637 def testOverridingUpstreamTensorInvalidatesDumpedIntermediates(self): 638 with NodeStepper(self.sess, self.q) as stepper: 639 self.assertAllClose(400.0, stepper.cont("q:0")) 640 self.assertItemsEqual(["v/read:0", "p:0"], 641 stepper.intermediate_tensor_names()) 642 self.assertAllClose(10.0, stepper.get_tensor_value("v/read:0")) 643 self.assertAllClose(20.0, stepper.get_tensor_value("p:0")) 644 645 stepper.override_tensor("v/read:0", 11.0) 646 self.assertItemsEqual(["v/read:0"], stepper.override_names()) 647 # Overriding the upstream v/read:0 should have invalidated the dumped 648 # intermediate tensor for the downstream p:0. 649 self.assertItemsEqual([], stepper.intermediate_tensor_names()) 650 651 # The next cont to q should not have used any dumped intermediate tensors 652 # and its result should reflect the overriding value. 653 self.assertAllClose(484.0, stepper.cont("q:0")) 654 self.assertEqual({ 655 "v/read:0": NodeStepper.FEED_TYPE_OVERRIDE 656 }, stepper.last_feed_types()) 657 658 def testRemovingOverrideToUpstreamTensorInvalidatesDumpedIntermediates(self): 659 with NodeStepper(self.sess, self.q) as stepper: 660 stepper.override_tensor("v/read:0", 9.0) 661 self.assertItemsEqual(["v/read:0"], stepper.override_names()) 662 663 self.assertAllClose(324.0, stepper.cont(self.q)) 664 self.assertItemsEqual(["p:0"], stepper.intermediate_tensor_names()) 665 666 stepper.remove_override("v/read:0") 667 self.assertItemsEqual([], stepper.override_names()) 668 # Removing the pre-existing override to v/read:0 should have invalidated 669 # the dumped intermediate tensor. 670 self.assertItemsEqual([], stepper.intermediate_tensor_names()) 671 672 def testRepeatedCallsToAssignAddDoesNotUpdateVariableAgain(self): 673 with NodeStepper(self.sess, self.v_add) as stepper: 674 stepper.cont(self.v_add) 675 self.assertSetEqual({self.v.name}, stepper.last_updated()) 676 self.assertAllClose(12.0, stepper.cont(self.v)) 677 stepper.cont(self.v_add) 678 self.assertSetEqual(set(), stepper.last_updated()) 679 self.assertEqual({"v_add:0": NodeStepper.FEED_TYPE_HANDLE}, 680 stepper.last_feed_types()) 681 self.assertAllClose(12.0, stepper.cont(self.v)) 682 683 def testRepeatedCallsToAssignAddDownStreamDoesNotUpdateVariableAgain(self): 684 with NodeStepper(self.sess, self.v_add_plus_one) as stepper: 685 stepper.cont(self.v_add_plus_one) 686 self.assertSetEqual({self.v.name}, stepper.last_updated()) 687 self.assertAllClose(12.0, stepper.cont(self.v)) 688 stepper.cont(self.v_add_plus_one) 689 self.assertSetEqual(set(), stepper.last_updated()) 690 self.assertEqual({"v_add_plus_one:0": NodeStepper.FEED_TYPE_HANDLE}, 691 stepper.last_feed_types()) 692 self.assertAllClose(12.0, stepper.cont(self.v)) 693 694 695 class StepperBackwardRunTest(test_util.TensorFlowTestCase): 696 697 def setUp(self): 698 """Test setup. 699 700 Structure of the forward graph: 701 f 702 | | 703 ----- ----- 704 | | 705 d e 706 | | | | 707 --- --------- --- 708 | | | 709 a b c 710 711 Construct a backward graph using the GradientDescentOptimizer. 712 """ 713 714 self.a = variables.Variable(1.0, name="a") 715 self.b = variables.Variable(2.0, name="b") 716 self.c = variables.Variable(4.0, name="c") 717 self.d = math_ops.multiply(self.a, self.b, name="d") 718 self.e = math_ops.multiply(self.b, self.c, name="e") 719 self.f = math_ops.multiply(self.d, self.e, name="f") 720 721 # Gradient descent optimizer that minimizes g. 722 gradient_descent.GradientDescentOptimizer(0.01).minimize( 723 self.f, name="optim") 724 725 rewriter_config = rewriter_config_pb2.RewriterConfig( 726 disable_model_pruning=True, 727 arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, 728 constant_folding=rewriter_config_pb2.RewriterConfig.OFF) 729 graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) 730 config = config_pb2.ConfigProto(graph_options=graph_options) 731 self.sess = session.Session(config=config) 732 self.sess.run(variables.global_variables_initializer()) 733 734 def tearDown(self): 735 ops.reset_default_graph() 736 737 def testContToUpdateA(self): 738 with NodeStepper(self.sess, "optim") as stepper: 739 result = stepper.cont("a:0") 740 self.assertAllClose(1.0, result) 741 self.assertEqual({}, stepper.last_feed_types()) 742 743 result = stepper.cont("optim/learning_rate:0") 744 self.assertAllClose(0.01, result) 745 self.assertEqual({}, stepper.last_feed_types()) 746 747 # Before any cont calls on ApplyGradientDescent, there should be no 748 # "dirty" variables. 749 self.assertEqual(set(), stepper.dirty_variables()) 750 751 # First, all the two control inputs to optim. 752 result = stepper.cont("optim/update_a/ApplyGradientDescent", 753 invalidate_from_updated_variables=True) 754 755 # Now variable a should have been marked as dirty due to the update 756 # by optim/update_a/ApplyGradientDescent. 757 self.assertSetEqual({"a:0"}, stepper.last_updated()) 758 self.assertEqual({"a:0"}, stepper.dirty_variables()) 759 self.assertIsNone(result) 760 self.assertEqual({ 761 "optim/learning_rate:0": NodeStepper.FEED_TYPE_HANDLE 762 }, stepper.last_feed_types()) 763 764 # Check that Variable "a" has been updated properly, but "b", "c" and "d" 765 # remain the same. 766 # For backprop on Variable a: 767 # Because f = a * b * b * c, df / da = b * b * c. 768 # 1.0 - learning_rate * b * b * c 769 # = 1.0 - 0.01 * 2.0 * 2.0 * 4.0 = 0.84. 770 self.assertAllClose(0.84, self.sess.run(self.a)) 771 self.assertAllClose(2.0, self.sess.run(self.b)) 772 self.assertAllClose(4.0, self.sess.run(self.c)) 773 774 def testContToUpdateB(self): 775 with NodeStepper(self.sess, "optim") as stepper: 776 result = stepper.cont("optim/update_b/ApplyGradientDescent", 777 invalidate_from_updated_variables=True) 778 self.assertIsNone(result) 779 self.assertSetEqual({"b:0"}, stepper.last_updated()) 780 self.assertEqual(set(["b:0"]), stepper.dirty_variables()) 781 782 # For backprop on Variable b: 783 # Because f = a * b * b * c, df / da = 2 * a * b * c. 784 # 2.0 - learning_rate * 2 * a * b * c 785 # = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84 786 self.assertAllClose(1.0, self.sess.run(self.a)) 787 self.assertAllClose(1.84, self.sess.run(self.b)) 788 self.assertAllClose(4.0, self.sess.run(self.c)) 789 790 def testContAfterUpdateWithoutRestoringVariableValue(self): 791 with NodeStepper(self.sess, "optim") as stepper: 792 # First, update Variable a from 1.0 to 0.84. 793 result = stepper.cont( 794 "optim/update_a/ApplyGradientDescent", 795 invalidate_from_updated_variables=True, 796 restore_variable_values=True) 797 self.assertIsNone(result) 798 self.assertSetEqual({"a:0"}, stepper.last_updated()) 799 self.assertEqual(set(["a:0"]), stepper.dirty_variables()) 800 self.assertAllClose(0.84, self.sess.run(self.a)) 801 self.assertAllClose(2.0, self.sess.run(self.b)) 802 self.assertAllClose(4.0, self.sess.run(self.c)) 803 # Tracking of the updated variables should have invalidated all 804 # intermediate tensors downstream to a:0. 805 self.assertNotIn("a/read:0", stepper.intermediate_tensor_names()) 806 self.assertNotIn("d:0", stepper.intermediate_tensor_names()) 807 808 # Second, update Variable b without the default restore_variable_values. 809 result = stepper.cont( 810 "optim/update_b/ApplyGradientDescent", restore_variable_values=False) 811 self.assertIsNone(result) 812 # For the backprop on Variable b under the updated value of a: 813 # 2.0 - learning_rate * 2 * a' * b * c 814 # = 2.0 - 0.01 * 2 * 0.84 * 2.0 * 4.0 = 1.8656 815 self.assertAllClose(0.84, self.sess.run(self.a)) 816 self.assertAllClose(1.8656, self.sess.run(self.b)) 817 self.assertAllClose(4.0, self.sess.run(self.c)) 818 819 def testContNotInvalidatingFromVariableUpdatesWorksForNextUpdate(self): 820 with NodeStepper(self.sess, "optim") as stepper: 821 self.assertIsNone(stepper.cont( 822 "optim/update_a/ApplyGradientDescent", 823 invalidate_from_updated_variables=False)) 824 # Even though invalidate_from_updated_variables is set to False, dirty 825 # variables should still have been tracked. 826 self.assertSetEqual({"a:0"}, stepper.last_updated()) 827 self.assertEqual({"a:0"}, stepper.dirty_variables()) 828 self.assertIn("a/read:0", stepper.intermediate_tensor_names()) 829 self.assertIn("b/read:0", stepper.intermediate_tensor_names()) 830 self.assertIn("c/read:0", stepper.intermediate_tensor_names()) 831 self.assertIn("d:0", stepper.intermediate_tensor_names()) 832 self.assertIn("e:0", stepper.intermediate_tensor_names()) 833 self.assertIn("optim/learning_rate:0", 834 stepper.intermediate_tensor_names()) 835 self.assertNotIn("a:0", stepper.intermediate_tensor_names()) 836 self.assertNotIn("b:0", stepper.intermediate_tensor_names()) 837 self.assertNotIn("c:0", stepper.intermediate_tensor_names()) 838 839 self.assertAllClose(0.84, self.sess.run(self.a)) 840 self.assertAllClose(2.0, self.sess.run(self.b)) 841 self.assertAllClose(4.0, self.sess.run(self.c)) 842 843 # For the backprop on Variable b, the result should reflect the original 844 # value of Variable a, even though Variable a has actually been updated. 845 # 2.0 - learning_rate * 2 * a * b * c 846 # = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84 847 self.assertIsNone(stepper.cont( 848 "optim/update_b/ApplyGradientDescent", 849 invalidate_from_updated_variables=False, 850 restore_variable_values=False)) 851 self.assertAllClose(0.84, self.sess.run(self.a)) 852 self.assertAllClose(1.84, self.sess.run(self.b)) 853 self.assertAllClose(4.0, self.sess.run(self.c)) 854 855 def testUpdateTwiceRestoreVariable(self): 856 with NodeStepper(self.sess, "optim") as stepper: 857 result = stepper.cont( 858 "optim/update_a/ApplyGradientDescent", 859 invalidate_from_updated_variables=True, 860 restore_variable_values=True) 861 self.assertIsNone(result) 862 self.assertSetEqual({"a:0"}, stepper.last_updated()) 863 self.assertEqual({"a:0"}, stepper.dirty_variables()) 864 865 result = stepper.cont( 866 "optim/update_b/ApplyGradientDescent", 867 invalidate_from_updated_variables=True, 868 restore_variable_values=True) 869 self.assertIsNone(result) 870 # Variables a and c should have been restored and hence no longer dirty. 871 # Variable b should have been marked as dirty. 872 self.assertSetEqual({"b:0"}, stepper.last_updated()) 873 self.assertEqual({"b:0"}, stepper.dirty_variables()) 874 875 # The result of the update should be identitcal to as if only update_b is 876 # run. 877 self.assertAllClose(1.0, self.sess.run(self.a)) 878 self.assertAllClose(1.84, self.sess.run(self.b)) 879 self.assertAllClose(4.0, self.sess.run(self.c)) 880 881 def testSelectiveHandleUsageDependingOnTransitiveCleanliness(self): 882 """Test tensor handlers are using only during clean transitive closure. 883 884 "clean" means no Variables have been updated by preceding cont() calls. 885 """ 886 887 with NodeStepper(self.sess, "optim") as stepper: 888 # First, call cont() on the two tensors on the intermediate level: e and 889 # f. 890 result = stepper.cont("d:0") 891 self.assertAllClose(2.0, result) 892 self.assertEqual({}, stepper.last_feed_types()) 893 self.assertItemsEqual(["a/read:0", "b/read:0"], 894 stepper.intermediate_tensor_names()) 895 self.assertItemsEqual(["d:0"], stepper.handle_names()) 896 self.assertSetEqual(set(), stepper.last_updated()) 897 self.assertEqual(set(), stepper.dirty_variables()) 898 899 result = stepper.cont("e:0") 900 self.assertAllClose(8.0, result) 901 self.assertEqual({ 902 "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE 903 }, stepper.last_feed_types()) 904 self.assertItemsEqual(["d:0", "e:0"], stepper.handle_names()) 905 self.assertItemsEqual(["a/read:0", "b/read:0", "c/read:0"], 906 stepper.intermediate_tensor_names()) 907 self.assertSetEqual(set(), stepper.last_updated()) 908 self.assertEqual(set(), stepper.dirty_variables()) 909 910 # Now run update_a, so as to let Variable a be dirty. 911 result = stepper.cont( 912 "optim/update_a/ApplyGradientDescent", 913 invalidate_from_updated_variables=True, 914 restore_variable_values=True) 915 self.assertIsNone(result) 916 # Due to the update to the value of a:0, the dumped intermediate a/read:0 917 # should have been invalidated. 918 self.assertNotIn("a/read:0", stepper.intermediate_tensor_names()) 919 self.assertSetEqual({"a:0"}, stepper.last_updated()) 920 self.assertEqual({"a:0"}, stepper.dirty_variables()) 921 922 # Now, run update_b. 923 result = stepper.cont( 924 "optim/update_b/ApplyGradientDescent", restore_variable_values=True) 925 self.assertIsNone(result) 926 927 # The last cont() run should have use the handle of tensor e, but not the 928 # handle of tensor d, because the transitive closure of e is clean, 929 # whereas that of d is dirty due to the update to a in the previous cont() 930 # call. 931 last_feed_types = stepper.last_feed_types() 932 self.assertNotIn("d:0", last_feed_types) 933 self.assertEqual(NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, 934 last_feed_types["b/read:0"]) 935 self.assertEqual(NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, 936 last_feed_types["c/read:0"]) 937 938 # The result of the update_b should be identical to as if no other 939 # update_* cont() calls have occurred before. 940 self.assertAllClose(1.0, self.sess.run(self.a)) 941 self.assertAllClose(1.84, self.sess.run(self.b)) 942 self.assertAllClose(4.0, self.sess.run(self.c)) 943 944 def testRestoreVariableValues(self): 945 """Test restore_variable_values() restores the old values of variables.""" 946 947 with NodeStepper(self.sess, "optim") as stepper: 948 stepper.cont( 949 "optim/update_b/ApplyGradientDescent", 950 invalidate_from_updated_variables=True, 951 restore_variable_values=True) 952 self.assertAllClose(1.84, self.sess.run(self.b)) 953 954 stepper.restore_variable_values() 955 self.assertAllClose(2.0, self.sess.run(self.b)) 956 957 def testFinalize(self): 958 """Test finalize() to restore variables and run the original fetch.""" 959 960 with NodeStepper(self.sess, "optim") as stepper: 961 # Invoke update_b before calling finalize. 962 stepper.cont( 963 "optim/update_b/ApplyGradientDescent", 964 invalidate_from_updated_variables=True, 965 restore_variable_values=True) 966 967 result = stepper.finalize() 968 self.assertIsNone(result) 969 970 # The results of the Variable updates should be the same as if no cont() 971 # call has occurred on update_b. 972 self.assertAllClose(0.84, self.sess.run(self.a)) 973 self.assertAllClose(1.84, self.sess.run(self.b)) 974 self.assertAllClose(3.96, self.sess.run(self.c)) 975 976 def testOverrideThenContToUpdateThenRemoveOverrideThenUpdateAgain(self): 977 """Test cont() to update nodes after overriding tensor values.""" 978 979 with NodeStepper(self.sess, "optim") as stepper: 980 result = stepper.cont("d:0") 981 self.assertAllClose(2.0, result) 982 self.assertEqual({}, stepper.last_feed_types()) 983 self.assertSetEqual(set(), stepper.last_updated()) 984 self.assertEqual(set(), stepper.dirty_variables()) 985 self.assertEqual(["d:0"], stepper.handle_names()) 986 self.assertSetEqual({"d"}, stepper.handle_node_names()) 987 988 # Override the value from 1.0 to 10.0. 989 stepper.override_tensor("a/read:0", 10.0) 990 991 self.assertEqual(["a/read:0"], stepper.override_names()) 992 993 result = stepper.cont( 994 "optim/update_c/ApplyGradientDescent", 995 invalidate_from_updated_variables=True, 996 restore_variable_values=True) 997 self.assertIsNone(result) 998 999 # The last cont() call should have not used the tensor handle to d:0, 1000 # because the transitive closure of d:0 contains an override tensor. 1001 self.assertEqual({ 1002 "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE, 1003 "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, 1004 }, stepper.last_feed_types()) 1005 1006 # The tensor handle to d:0 should have been removed due to the dirty 1007 # transitive closure. 1008 self.assertEqual([], stepper.handle_names()) 1009 self.assertSetEqual(set(), stepper.handle_node_names()) 1010 1011 # For this backprop on c, the overriding value of a/read:0 should have 1012 # been used: 1013 # 4.0 - learning_rate * a * b * b 1014 # = 4.0 - 0.01 * 10.0 * 2.0 * 2.0 = 3.6. 1015 self.assertAllClose(3.6, self.sess.run(self.c)) 1016 1017 # Now remove the overriding value of a/read:0. 1018 stepper.remove_override("a/read:0") 1019 self.assertEqual([], stepper.override_names()) 1020 1021 # Obtain the tensor handle to d:0 again. 1022 result = stepper.cont("d:0") 1023 self.assertAllClose(2.0, result) 1024 self.assertEqual(["d:0"], stepper.handle_names()) 1025 self.assertSetEqual({"d"}, stepper.handle_node_names()) 1026 self.assertNotIn("a/read:0", stepper.last_feed_types()) 1027 1028 # Then call update_c again, without restoring c. 1029 result = stepper.cont("optim/update_c/ApplyGradientDescent", 1030 restore_variable_values=False) 1031 self.assertIsNone(result) 1032 self.assertNotIn("a/read:0", stepper.last_feed_types()) 1033 1034 # This time, the d:0 tensor handle should have been used, because its 1035 # transitive closure is clean. 1036 self.assertEqual({ 1037 "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, 1038 "d:0": NodeStepper.FEED_TYPE_HANDLE, 1039 "optim/learning_rate:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE, 1040 }, stepper.last_feed_types()) 1041 1042 # For this backprop on c, the overriding value of a/read:0 should have 1043 # been used: 1044 # 3.6 - learning_rate * a * b * b 1045 # = 3.6 - 0.01 * 1.0 * 2.0 * 2.0 = 3.56. 1046 self.assertAllClose(3.56, self.sess.run(self.c)) 1047 1048 def testContToNodeWithOutputTensors(self): 1049 """cont() to an op should cache its output tensors if appropriate.""" 1050 1051 with NodeStepper(self.sess, "optim") as stepper: 1052 # In the transitive closure of the stepper, look for an op of which the 1053 # output tensor also is in the transitive closure. 1054 # Do not assume a specific op, e.g., ""gradients/e_grad/Reshape_1", 1055 # because it may vary between builds. 1056 closure_elements = stepper.closure_elements() 1057 op_with_output_in_closure = None 1058 for element_name in closure_elements: 1059 if element_name + ":0" in closure_elements: 1060 op_with_output_in_closure = str(element_name) 1061 break 1062 1063 self.assertEqual( 1064 [0], stepper.output_slots_in_closure(op_with_output_in_closure)) 1065 1066 self.assertIsNotNone(op_with_output_in_closure) 1067 output_tensor = op_with_output_in_closure + ":0" 1068 1069 # The op "gradients/?_grad/Reshape_1" is in the transitive closure of the 1070 # stepper, because it is the control input to another o. However, its 1071 # output tensor "gradients/?_grad/Reshape_1:0" is also in the transitive 1072 # closure, because it is the (non-control) input of certain ops. Calling 1073 # cont() on the op should lead to the caching of the tensor handle for 1074 # the output tensor. 1075 stepper.cont(op_with_output_in_closure) 1076 1077 self.assertEqual([output_tensor], stepper.handle_names()) 1078 self.assertSetEqual({op_with_output_in_closure}, 1079 stepper.handle_node_names()) 1080 1081 # Do a cont() call that uses the cached tensor of 1082 # "gradients/?_grad/Reshape_1:0". 1083 stepper.cont(output_tensor) 1084 self.assertEqual({ 1085 output_tensor: NodeStepper.FEED_TYPE_HANDLE 1086 }, stepper.last_feed_types()) 1087 1088 1089 if __name__ == "__main__": 1090 googletest.main() 1091