Home | History | Annotate | Download | only in lib
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Unit tests of the tfdbg Stepper."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.core.protobuf import config_pb2
     21 from tensorflow.core.protobuf import rewriter_config_pb2
     22 from tensorflow.python.client import session
     23 from tensorflow.python.debug.lib.stepper import NodeStepper
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.framework import test_util
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops import math_ops
     30 from tensorflow.python.ops import state_ops
     31 from tensorflow.python.ops import variables
     32 from tensorflow.python.platform import googletest
     33 from tensorflow.python.training import gradient_descent
     34 
     35 
     36 class StepperTest(test_util.TensorFlowTestCase):
     37 
     38   def setUp(self):
     39     self.a = variables.Variable(2.0, name="a")
     40     self.b = variables.Variable(3.0, name="b")
     41 
     42     self.c = math_ops.multiply(self.a, self.b, name="c")  # Should be 6.0.
     43     self.d = math_ops.multiply(self.a, self.a, name="d")  # Should be 4.0.
     44 
     45     self.e = math_ops.multiply(self.d, self.c, name="e")  # Should be 24.0.
     46 
     47     self.f_y = constant_op.constant(0.30, name="f_y")
     48     self.f = math_ops.div(self.b, self.f_y, name="f")  # Should be 10.0.
     49 
     50     # The there nodes x, y and z form a graph with "cross-links" in. I.e., x
     51     # and y are both direct inputs to z, but x is also a direct input to y.
     52     self.x = variables.Variable(2.0, name="x")  # Should be 2.0
     53     self.y = math_ops.negative(self.x, name="y")  # Should be -2.0.
     54 
     55     self.z = math_ops.multiply(self.x, self.y, name="z")  # Should be -4.0.
     56 
     57     rewriter_config = rewriter_config_pb2.RewriterConfig(
     58         disable_model_pruning=True,
     59         arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
     60         constant_folding=rewriter_config_pb2.RewriterConfig.OFF)
     61     graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
     62     config = config_pb2.ConfigProto(graph_options=graph_options)
     63     self.sess = session.Session(config=config)
     64     self.sess.run(variables.global_variables_initializer())
     65 
     66   def tearDown(self):
     67     ops.reset_default_graph()
     68 
     69   def testContToFetchNotInTransitiveClosureShouldError(self):
     70     with NodeStepper(self.sess, "e:0") as stepper:
     71       sorted_nodes = stepper.sorted_nodes()
     72       self.assertEqual(7, len(sorted_nodes))
     73       self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read"))
     74       self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read"))
     75       self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c"))
     76       self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c"))
     77       self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d"))
     78       self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e"))
     79       self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e"))
     80 
     81       self.assertSetEqual(
     82           {"e:0", "d:0", "c:0", "a/read:0", "b/read:0", "b:0", "a:0"},
     83           set(stepper.closure_elements()))
     84 
     85       with self.assertRaisesRegexp(
     86           ValueError,
     87           "Target \"f:0\" is not in the transitive closure for the fetch of "
     88           "the stepper"):
     89         stepper.cont("f:0")
     90 
     91   def testContToNodeNameShouldReturnTensorValue(self):
     92     with NodeStepper(self.sess, "e:0") as stepper:
     93       self.assertAllClose(6.0, stepper.cont("c"))
     94 
     95   def testUsingNamesNotUsingIntermediateTensors(self):
     96     with NodeStepper(self.sess, "e:0") as stepper:
     97       # The first cont() call should have used no feeds.
     98       result = stepper.cont("c:0")
     99       self.assertAllClose(6.0, result)
    100       self.assertItemsEqual(["a/read:0", "b/read:0"],
    101                             stepper.intermediate_tensor_names())
    102       self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0"))
    103       self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0"))
    104       self.assertEqual({}, stepper.last_feed_types())
    105 
    106       # The second cont() call should have used the tensor handle from the
    107       # previous cont() call.
    108       result = stepper.cont("e:0")
    109       self.assertAllClose(24.0, result)
    110       self.assertItemsEqual(["a/read:0", "b/read:0", "d:0"],
    111                             stepper.intermediate_tensor_names())
    112       self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0"))
    113       self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0"))
    114       self.assertAllClose(4.0, stepper.get_tensor_value("d:0"))
    115       self.assertEqual({
    116           "c:0": NodeStepper.FEED_TYPE_HANDLE,
    117           "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
    118       }, stepper.last_feed_types())
    119 
    120   def testUsingNodesNotUsingIntermediateTensors(self):
    121     with NodeStepper(self.sess, self.e) as stepper:
    122       # There should be no handles before any cont() calls.
    123       self.assertEqual([], stepper.handle_names())
    124       self.assertSetEqual(set(), stepper.handle_node_names())
    125 
    126       # Before the cont() call, the stepper should not have access to the value
    127       # of c:0.
    128       with self.assertRaisesRegexp(
    129           ValueError,
    130           "This stepper instance does not have access to the value of tensor "
    131           "\"c:0\""):
    132         stepper.get_tensor_value("c:0")
    133 
    134       # Using the node/tensor itself, instead of the name str, should work on
    135       # cont().
    136       result = stepper.cont(self.c)
    137       self.assertItemsEqual(["a/read:0", "b/read:0"],
    138                             stepper.intermediate_tensor_names())
    139       self.assertAllClose(6.0, result)
    140       self.assertEqual({}, stepper.last_feed_types())
    141 
    142       self.assertEqual(["c:0"], stepper.handle_names())
    143       self.assertEqual({"c"}, stepper.handle_node_names())
    144 
    145       # After the cont() call, the stepper should have access to the value of
    146       # c:0 via a tensor handle.
    147       self.assertAllClose(6.0, stepper.get_tensor_value("c:0"))
    148 
    149       result = stepper.cont(self.e)
    150       self.assertAllClose(24.0, result)
    151       self.assertItemsEqual(["a/read:0", "b/read:0", "d:0"],
    152                             stepper.intermediate_tensor_names())
    153       self.assertEqual({
    154           "c:0": NodeStepper.FEED_TYPE_HANDLE,
    155           "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
    156       }, stepper.last_feed_types())
    157 
    158   def testContToTensorWithIntermediateDumpShouldUseDump(self):
    159     with NodeStepper(self.sess, ["e:0", "f:0"]) as stepper:
    160       stepper.cont("c:0")
    161       self.assertItemsEqual(["a/read:0", "b/read:0"],
    162                             stepper.intermediate_tensor_names())
    163       self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0"))
    164       self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0"))
    165 
    166       self.assertAllClose(2.0, stepper.cont("a/read:0"))
    167       self.assertEqual({
    168           "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE
    169       }, stepper.last_feed_types())
    170 
    171       self.assertAllClose(10.0, stepper.cont("f:0"))
    172       self.assertEqual({
    173           "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE
    174       }, stepper.last_feed_types())
    175 
    176   def testDisablingUseDumpedIntermediatesWorks(self):
    177     with NodeStepper(self.sess, ["e:0", "f:0"]) as stepper:
    178       stepper.cont("c:0")
    179       self.assertItemsEqual(["a/read:0", "b/read:0"],
    180                             stepper.intermediate_tensor_names())
    181       self.assertAllClose(2.0, stepper.get_tensor_value("a/read:0"))
    182       self.assertAllClose(3.0, stepper.get_tensor_value("b/read:0"))
    183 
    184       self.assertAllClose(10.0,
    185                           stepper.cont("f:0", use_dumped_intermediates=False))
    186       self.assertEqual({}, stepper.last_feed_types())
    187 
    188   def testIsFeedableShouldGiveCorrectAnswers(self):
    189     with NodeStepper(self.sess, self.e) as stepper:
    190       self.assertTrue(stepper.is_feedable("a/read:0"))
    191       self.assertTrue(stepper.is_feedable("b/read:0"))
    192       self.assertTrue(stepper.is_feedable("c:0"))
    193       self.assertTrue(stepper.is_feedable("d:0"))
    194 
    195   def testOverrideValue(self):
    196     with NodeStepper(self.sess, self.e) as stepper:
    197       result = stepper.cont(self.c)
    198       self.assertAllClose(6.0, result)
    199       self.assertEqual({}, stepper.last_feed_types())
    200 
    201       # There should be no overrides before any cont() calls.
    202       self.assertEqual([], stepper.override_names())
    203 
    204       # Calling cont() on c again should lead to use of the handle.
    205       result = stepper.cont(self.c)
    206       self.assertAllClose(6.0, result)
    207       self.assertEqual({
    208           "c:0": NodeStepper.FEED_TYPE_HANDLE
    209       }, stepper.last_feed_types())
    210 
    211       # Override c:0.
    212       stepper.override_tensor("c:0", 7.0)
    213 
    214       # After the overriding, calling get_tensor_value() on c:0 should yield the
    215       # overriding value.
    216       self.assertEqual(7.0, stepper.get_tensor_value("c:0"))
    217 
    218       # Now c:0 should have only an override value, but no cached handle,
    219       # because the handle should have been invalidated.
    220       self.assertEqual([], stepper.handle_names())
    221       self.assertSetEqual(set(), stepper.handle_node_names())
    222       self.assertEqual(["c:0"], stepper.override_names())
    223 
    224       # Run a downstream tensor after the value override.
    225       result = stepper.cont(self.e)
    226       self.assertAllClose(28.0, result)  # Should reflect the overriding value.
    227 
    228       # Should use override, instead of the handle.
    229       self.assertEqual({
    230           "c:0": NodeStepper.FEED_TYPE_OVERRIDE,
    231           "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
    232       }, stepper.last_feed_types())
    233 
    234   def testOverrideValueTwice(self):
    235     with NodeStepper(self.sess, self.e) as stepper:
    236       # Override once.
    237       stepper.override_tensor("c:0", 7.0)
    238       self.assertAllClose(28.0, stepper.cont(self.e))
    239       self.assertEqual({
    240           "c:0": NodeStepper.FEED_TYPE_OVERRIDE
    241       }, stepper.last_feed_types())
    242 
    243       self.assertEqual(["e:0"], stepper.handle_names())
    244       self.assertSetEqual({"e"}, stepper.handle_node_names())
    245       self.assertEqual(["c:0"], stepper.override_names())
    246 
    247       # Calling cont(self.e) again. This time the cached tensor handle of e
    248       # should be used.
    249       self.assertEqual(28.0, stepper.cont(self.e))
    250       self.assertEqual({
    251           "e:0": NodeStepper.FEED_TYPE_HANDLE
    252       }, stepper.last_feed_types())
    253 
    254       # Override c again. This should have invalidated the cache for e.
    255       stepper.override_tensor("c:0", 8.0)
    256 
    257       self.assertEqual([], stepper.handle_names())
    258       self.assertEqual(set(), stepper.handle_node_names())
    259       self.assertEqual(["c:0"], stepper.override_names())
    260 
    261       self.assertAllClose(32.0, stepper.cont(self.e))
    262       self.assertEqual({
    263           "c:0": NodeStepper.FEED_TYPE_OVERRIDE,
    264           "d:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
    265       }, stepper.last_feed_types())
    266 
    267   def testRemoveOverrideValue(self):
    268     with NodeStepper(self.sess, self.e) as stepper:
    269       result = stepper.cont(self.c)
    270       self.assertAllClose(6.0, result)
    271       self.assertEqual({}, stepper.last_feed_types())
    272 
    273       # The previous cont() step should have generated a cached tensor handle.
    274       self.assertEqual(["c:0"], stepper.handle_names())
    275       self.assertSetEqual({"c"}, stepper.handle_node_names())
    276 
    277       # Override c:0.
    278       stepper.override_tensor("c:0", 7.0)
    279 
    280       # The overriding should have invalidated the tensor handle.
    281       self.assertEqual([], stepper.handle_names())
    282       self.assertSetEqual(set(), stepper.handle_node_names())
    283       self.assertEqual(["c:0"], stepper.override_names())
    284 
    285       result = stepper.cont(self.e)
    286       self.assertAllClose(28.0, result)  # Should reflect the overriding value.
    287       self.assertEqual({
    288           "c:0": NodeStepper.FEED_TYPE_OVERRIDE,
    289           "a/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
    290       }, stepper.last_feed_types())
    291 
    292       # The handle to tensor e:0 should have been cached, even though its
    293       # transitive closure contains an override.
    294       self.assertIn("e:0", stepper.handle_names())
    295       self.assertSetEqual({"e"}, stepper.handle_node_names())
    296 
    297       # Remove the override.
    298       stepper.remove_override("c:0")
    299       # c:0 should not be in the overrides anymore.
    300       self.assertEqual([], stepper.override_names())
    301 
    302       # Removing the override should have invalidated the tensor handle for c.
    303       self.assertNotIn("e:0", stepper.handle_names())
    304       self.assertNotIn("e", stepper.handle_node_names())
    305 
    306       # Should reflect the non-overriding value.
    307       self.assertAllClose(24.0, stepper.cont(self.e))
    308 
    309       # This time, the handle to tensor e:0 should have been cached again, even
    310       # thought its transitive closure contains an override.
    311       self.assertIn("e:0", stepper.handle_names())
    312       self.assertIn("e", stepper.handle_node_names())
    313 
    314       # Calling cont(self.e) again should have used the tensor handle to e:0.
    315       self.assertAllClose(24.0, stepper.cont(self.e))
    316       self.assertEqual({
    317           "e:0": NodeStepper.FEED_TYPE_HANDLE,
    318       }, stepper.last_feed_types())
    319 
    320   def testOverrideAndContToSameTensor(self):
    321     with NodeStepper(self.sess, self.e) as stepper:
    322       result = stepper.cont(self.c)
    323       self.assertAllClose(6.0, result)
    324       self.assertEqual({}, stepper.last_feed_types())
    325       self.assertEqual(["c:0"], stepper.handle_names())
    326       self.assertSetEqual({"c"}, stepper.handle_node_names())
    327 
    328       self.assertAllClose(6.0, stepper.cont(self.c))
    329 
    330       # The last cont() call should use the tensor handle directly.
    331       self.assertEqual({
    332           "c:0": NodeStepper.FEED_TYPE_HANDLE
    333       }, stepper.last_feed_types())
    334 
    335       # Override c:0.
    336       stepper.override_tensor("c:0", 7.0)
    337 
    338       # As a result of the override, the tensor handle should have been
    339       # invalidated.
    340       self.assertEqual([], stepper.handle_names())
    341       self.assertSetEqual(set(), stepper.handle_node_names())
    342 
    343       result = stepper.cont(self.c)
    344       self.assertAllClose(7.0, result)
    345 
    346       self.assertEqual({
    347           "c:0": NodeStepper.FEED_TYPE_OVERRIDE
    348       }, stepper.last_feed_types())
    349 
    350   def testFinalizeWithPreviousOverrides(self):
    351     with NodeStepper(self.sess, self.e) as stepper:
    352       stepper.override_tensor("a/read:0", 20.0)
    353       self.assertEqual(["a/read:0"], stepper.override_names())
    354 
    355       # Should reflect the overriding value.
    356       self.assertAllClose(24000.0, stepper.cont("e:0"))
    357       self.assertEqual({
    358           "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE
    359       }, stepper.last_feed_types())
    360 
    361       # Finalize call should have ignored the overriding value.
    362       self.assertAllClose(24.0, stepper.finalize())
    363 
    364   def testRemoveNonexistentOverrideValue(self):
    365     with NodeStepper(self.sess, self.e) as stepper:
    366       self.assertEqual([], stepper.override_names())
    367       with self.assertRaisesRegexp(
    368           ValueError, "No overriding value exists for tensor \"c:0\""):
    369         stepper.remove_override("c:0")
    370 
    371   def testAttemptToOverrideInvalidTensor(self):
    372     stepper = NodeStepper(self.sess, self.e)
    373 
    374     with self.assertRaisesRegexp(ValueError, "Cannot override tensor \"f:0\""):
    375       stepper.override_tensor("f:0", 42.0)
    376 
    377   def testInvalidOverrideArgumentType(self):
    378     with NodeStepper(self.sess, self.e) as stepper:
    379       with self.assertRaisesRegexp(TypeError, "Expected type str; got type"):
    380         stepper.override_tensor(self.a, 42.0)
    381 
    382   def testTransitiveClosureWithCrossLinksShouldHaveCorrectOrder(self):
    383     with NodeStepper(self.sess, "z:0") as stepper:
    384       sorted_nodes = stepper.sorted_nodes()
    385       self.assertEqual(4, len(sorted_nodes))
    386       self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read"))
    387       self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y"))
    388       self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z"))
    389       self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z"))
    390 
    391   def testNodeStepperConstructorShouldAllowListOrTupleOrDictOfFetches(self):
    392     for i in range(6):
    393       if i == 0:
    394         fetches = [self.e, [self.f, self.z]]
    395       elif i == 1:
    396         fetches = (self.e, (self.f, self.z))
    397       elif i == 2:
    398         fetches = {"e": self.e, "fz": {"f": self.f, "z": self.z}}
    399       elif i == 3:
    400         fetches = ["e:0", ["f:0", "z:0"]]
    401       elif i == 4:
    402         fetches = ("e:0", ("f:0", "z:0"))
    403       elif i == 5:
    404         fetches = {"e": "e:0", "fz": {"f": "f:0", "z": "z:0"}}
    405 
    406       with NodeStepper(self.sess, fetches) as stepper:
    407         sorted_nodes = stepper.sorted_nodes()
    408         self.assertEqual(13, len(sorted_nodes))
    409 
    410         # Check the topological order of the sorted nodes.
    411         self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("x/read"))
    412         self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("y"))
    413         self.assertLess(sorted_nodes.index("x"), sorted_nodes.index("z"))
    414         self.assertLess(sorted_nodes.index("y"), sorted_nodes.index("z"))
    415 
    416         self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("a/read"))
    417         self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("b/read"))
    418         self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("c"))
    419         self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("c"))
    420         self.assertLess(sorted_nodes.index("a"), sorted_nodes.index("d"))
    421         self.assertLess(sorted_nodes.index("d"), sorted_nodes.index("e"))
    422         self.assertLess(sorted_nodes.index("c"), sorted_nodes.index("e"))
    423         self.assertLess(sorted_nodes.index("b"), sorted_nodes.index("f"))
    424         self.assertLess(sorted_nodes.index("f_y"), sorted_nodes.index("f"))
    425 
    426         closure_elements = stepper.closure_elements()
    427         self.assertIn("x/read:0", closure_elements)
    428         self.assertIn("e:0", closure_elements)
    429         self.assertIn("f:0", closure_elements)
    430 
    431         self.assertEqual([0], stepper.output_slots_in_closure("x/read"))
    432         self.assertEqual([0], stepper.output_slots_in_closure("e"))
    433         self.assertEqual([0], stepper.output_slots_in_closure("f"))
    434 
    435         result = stepper.finalize()
    436         if i == 0 or i == 1 or i == 3 or i == 4:
    437           self.assertAllClose(24.0, result[0])
    438           self.assertAllClose(10.0, result[1][0])
    439           self.assertAllClose(-4.0, result[1][1])
    440         elif i == 2 or i == 5:
    441           self.assertAllClose(24.0, result["e"])
    442           self.assertAllClose(10.0, result["fz"]["f"])
    443           self.assertAllClose(-4.0, result["fz"]["z"])
    444 
    445 
    446 class StepperTestWithPlaceHolders(test_util.TensorFlowTestCase):
    447 
    448   def setUp(self):
    449     self.ph0 = array_ops.placeholder(dtypes.float32, shape=(2, 2), name="ph0")
    450     self.ph1 = array_ops.placeholder(dtypes.float32, shape=(2, 1), name="ph1")
    451 
    452     self.x = math_ops.matmul(self.ph0, self.ph1, name="x")
    453     self.y = math_ops.add(self.x, self.ph1, name="y")
    454 
    455     self.sess = session.Session()
    456 
    457   def tearDown(self):
    458     ops.reset_default_graph()
    459 
    460   def testGetTensorValueWorksOnPlaceholder(self):
    461     with NodeStepper(
    462         self.sess,
    463         self.y,
    464         feed_dict={
    465             self.ph0: [[1.0, 2.0], [-3.0, 5.0]],
    466             self.ph1: [[-1.0], [0.5]]
    467         }) as stepper:
    468       self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]],
    469                           stepper.get_tensor_value("ph0"))
    470       self.assertAllClose([[1.0, 2.0], [-3.0, 5.0]],
    471                           stepper.get_tensor_value("ph0:0"))
    472       with self.assertRaisesRegexp(
    473           KeyError,
    474           r"The name 'ph0:1' refers to a Tensor which does not exist"):
    475         stepper.get_tensor_value("ph0:1")
    476 
    477   def testIsPlaceholdersShouldGiveCorrectAnswers(self):
    478     with NodeStepper(self.sess, self.y) as stepper:
    479       self.assertTrue(stepper.is_placeholder(self.ph0.name))
    480       self.assertTrue(stepper.is_placeholder(self.ph1.name))
    481 
    482       self.assertFalse(stepper.is_placeholder(self.x.name))
    483       self.assertFalse(stepper.is_placeholder(self.y.name))
    484 
    485       with self.assertRaisesRegexp(ValueError,
    486                                    "A is not in the transitive closure"):
    487         self.assertFalse(stepper.is_placeholder("A"))
    488 
    489   def testPlaceholdersShouldGiveCorrectAnswers(self):
    490     with NodeStepper(self.sess, self.y) as stepper:
    491       self.assertSetEqual({"ph0", "ph1"}, set(stepper.placeholders()))
    492 
    493   def testContWithPlaceholders(self):
    494     with NodeStepper(
    495         self.sess,
    496         self.y,
    497         feed_dict={
    498             self.ph0: [[1.0, 2.0], [-3.0, 5.0]],
    499             self.ph1: [[-1.0], [0.5]]
    500         }) as stepper:
    501       self.assertEqual(4, len(stepper.sorted_nodes()))
    502       self.assertSetEqual({"ph0:0", "ph1:0", "x:0", "y:0"},
    503                           set(stepper.closure_elements()))
    504 
    505       result = stepper.cont(self.x)
    506       self.assertAllClose([[0.0], [5.5]], result)
    507       self.assertEqual({
    508           "ph0:0": NodeStepper.FEED_TYPE_CLIENT,
    509           "ph1:0": NodeStepper.FEED_TYPE_CLIENT,
    510       }, stepper.last_feed_types())
    511 
    512       self.assertEqual(["x:0"], stepper.handle_names())
    513       self.assertSetEqual({"x"}, stepper.handle_node_names())
    514 
    515       result = stepper.cont(self.y)
    516       self.assertAllClose([[-1.0], [6.0]], result)
    517       self.assertEqual({
    518           "x:0": NodeStepper.FEED_TYPE_HANDLE,
    519           "ph1:0": NodeStepper.FEED_TYPE_CLIENT,
    520       }, stepper.last_feed_types())
    521 
    522   def testAttemptToContToPlaceholderWithTensorFeedKeysShouldWork(self):
    523     """Continuing to a placeholder should be allowed, using client feed."""
    524 
    525     ph0_feed = [[1.0, 2.0], [-3.0, 5.0]]
    526     ph1_feed = [[-1.0], [0.5]]
    527     with NodeStepper(
    528         self.sess, self.y, feed_dict={
    529             self.ph0: ph0_feed,
    530             self.ph1: ph1_feed,
    531         }) as stepper:
    532       self.assertAllClose(ph0_feed, stepper.cont(self.ph0))
    533       self.assertEqual({
    534           self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
    535       }, stepper.last_feed_types())
    536 
    537       self.assertAllClose(ph1_feed, stepper.cont(self.ph1))
    538       self.assertEqual({
    539           self.ph1.name: NodeStepper.FEED_TYPE_CLIENT
    540       }, stepper.last_feed_types())
    541 
    542       ph0_node = self.sess.graph.as_graph_element("ph0")
    543       self.assertAllClose(ph0_feed, stepper.cont(ph0_node))
    544       self.assertEqual({
    545           self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
    546       }, stepper.last_feed_types())
    547 
    548       self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
    549 
    550   def testAttemptToContToPlaceholderWithTensorNameFeedKeysShouldWork(self):
    551 
    552     ph0_feed = [[1.0, 2.0], [-3.0, 5.0]]
    553     ph1_feed = [[-1.0], [0.5]]
    554     with NodeStepper(
    555         self.sess,
    556         self.y,
    557         feed_dict={
    558             self.ph0.name: ph0_feed,
    559             self.ph1.name: ph1_feed,
    560         }) as stepper:
    561       self.assertAllClose(ph0_feed, stepper.cont(self.ph0))
    562       self.assertEqual({
    563           self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
    564       }, stepper.last_feed_types())
    565 
    566       self.assertAllClose(ph1_feed, stepper.cont(self.ph1))
    567       self.assertEqual({
    568           self.ph1.name: NodeStepper.FEED_TYPE_CLIENT
    569       }, stepper.last_feed_types())
    570 
    571       ph0_node = self.sess.graph.as_graph_element("ph0")
    572       self.assertAllClose(ph0_feed, stepper.cont(ph0_node))
    573       self.assertEqual({
    574           self.ph0.name: NodeStepper.FEED_TYPE_CLIENT
    575       }, stepper.last_feed_types())
    576 
    577       self.assertAllClose([[-1.0], [6.0]], stepper.finalize())
    578 
    579 
    580 class StepperAssignAddTest(test_util.TensorFlowTestCase):
    581 
    582   def setUp(self):
    583     self.v = variables.Variable(10.0, name="v")
    584     self.p = math_ops.add(self.v, self.v, name="p")
    585     self.q = math_ops.multiply(self.p, self.p, name="q")
    586     self.delta = constant_op.constant(2.0, name="delta")
    587     self.v_add = state_ops.assign_add(self.v, self.delta, name="v_add")
    588     self.v_add_plus_one = math_ops.add(self.v_add,
    589                                        1.0,
    590                                        name="v_add_plus_one")
    591 
    592     rewriter_config = rewriter_config_pb2.RewriterConfig(
    593         disable_model_pruning=True,
    594         arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
    595         constant_folding=rewriter_config_pb2.RewriterConfig.OFF)
    596     graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
    597     config = config_pb2.ConfigProto(graph_options=graph_options)
    598     self.sess = session.Session(config=config)
    599     self.sess.run(self.v.initializer)
    600 
    601   def tearDown(self):
    602     ops.reset_default_graph()
    603 
    604   def testLastUpdatedVariablesReturnsNoneBeforeAnyContCalls(self):
    605     with NodeStepper(self.sess, [self.q, self.v_add]) as stepper:
    606       self.assertIsNone(stepper.last_updated())
    607 
    608   def testContToUpdateInvalidatesDumpedIntermediates(self):
    609     with NodeStepper(self.sess, [self.q, self.v_add]) as stepper:
    610       self.assertAllClose(400.0, stepper.cont("q:0"))
    611       self.assertItemsEqual(["v/read:0", "p:0"],
    612                             stepper.intermediate_tensor_names())
    613       self.assertAllClose(10.0, stepper.get_tensor_value("v/read:0"))
    614       self.assertAllClose(20.0, stepper.get_tensor_value("p:0"))
    615 
    616       self.assertAllClose(
    617           12.0, stepper.cont(
    618               self.v_add, invalidate_from_updated_variables=True))
    619       self.assertAllClose(12.0, self.sess.run(self.v))
    620       self.assertSetEqual({self.v.name}, stepper.last_updated())
    621       self.assertItemsEqual(["v:0"], stepper.dirty_variables())
    622       # Updating the value of v by calling v_add should have invalidated the
    623       # dumped intermediate tensors for v/read:0 and p:0.
    624       self.assertItemsEqual(["delta:0"], stepper.intermediate_tensor_names())
    625       with self.assertRaisesRegexp(
    626           ValueError,
    627           r"This stepper instance does not have access to the value of tensor "
    628           r"\"p:0\""):
    629         stepper.get_tensor_value("p:0")
    630 
    631       # The next cont to q should not have used any dumped intermediate tensors
    632       # and its result should reflect the updated value.
    633       self.assertAllClose(576.0, stepper.cont("q:0"))
    634       self.assertSetEqual(set(), stepper.last_updated())
    635       self.assertEqual({}, stepper.last_feed_types())
    636 
    637   def testOverridingUpstreamTensorInvalidatesDumpedIntermediates(self):
    638     with NodeStepper(self.sess, self.q) as stepper:
    639       self.assertAllClose(400.0, stepper.cont("q:0"))
    640       self.assertItemsEqual(["v/read:0", "p:0"],
    641                             stepper.intermediate_tensor_names())
    642       self.assertAllClose(10.0, stepper.get_tensor_value("v/read:0"))
    643       self.assertAllClose(20.0, stepper.get_tensor_value("p:0"))
    644 
    645       stepper.override_tensor("v/read:0", 11.0)
    646       self.assertItemsEqual(["v/read:0"], stepper.override_names())
    647       # Overriding the upstream v/read:0 should have invalidated the dumped
    648       # intermediate tensor for the downstream p:0.
    649       self.assertItemsEqual([], stepper.intermediate_tensor_names())
    650 
    651       # The next cont to q should not have used any dumped intermediate tensors
    652       # and its result should reflect the overriding value.
    653       self.assertAllClose(484.0, stepper.cont("q:0"))
    654       self.assertEqual({
    655           "v/read:0": NodeStepper.FEED_TYPE_OVERRIDE
    656       }, stepper.last_feed_types())
    657 
    658   def testRemovingOverrideToUpstreamTensorInvalidatesDumpedIntermediates(self):
    659     with NodeStepper(self.sess, self.q) as stepper:
    660       stepper.override_tensor("v/read:0", 9.0)
    661       self.assertItemsEqual(["v/read:0"], stepper.override_names())
    662 
    663       self.assertAllClose(324.0, stepper.cont(self.q))
    664       self.assertItemsEqual(["p:0"], stepper.intermediate_tensor_names())
    665 
    666       stepper.remove_override("v/read:0")
    667       self.assertItemsEqual([], stepper.override_names())
    668       # Removing the pre-existing override to v/read:0 should have invalidated
    669       # the dumped intermediate tensor.
    670       self.assertItemsEqual([], stepper.intermediate_tensor_names())
    671 
    672   def testRepeatedCallsToAssignAddDoesNotUpdateVariableAgain(self):
    673     with NodeStepper(self.sess, self.v_add) as stepper:
    674       stepper.cont(self.v_add)
    675       self.assertSetEqual({self.v.name}, stepper.last_updated())
    676       self.assertAllClose(12.0, stepper.cont(self.v))
    677       stepper.cont(self.v_add)
    678       self.assertSetEqual(set(), stepper.last_updated())
    679       self.assertEqual({"v_add:0": NodeStepper.FEED_TYPE_HANDLE},
    680                        stepper.last_feed_types())
    681       self.assertAllClose(12.0, stepper.cont(self.v))
    682 
    683   def testRepeatedCallsToAssignAddDownStreamDoesNotUpdateVariableAgain(self):
    684     with NodeStepper(self.sess, self.v_add_plus_one) as stepper:
    685       stepper.cont(self.v_add_plus_one)
    686       self.assertSetEqual({self.v.name}, stepper.last_updated())
    687       self.assertAllClose(12.0, stepper.cont(self.v))
    688       stepper.cont(self.v_add_plus_one)
    689       self.assertSetEqual(set(), stepper.last_updated())
    690       self.assertEqual({"v_add_plus_one:0": NodeStepper.FEED_TYPE_HANDLE},
    691                        stepper.last_feed_types())
    692       self.assertAllClose(12.0, stepper.cont(self.v))
    693 
    694 
    695 class StepperBackwardRunTest(test_util.TensorFlowTestCase):
    696 
    697   def setUp(self):
    698     """Test setup.
    699 
    700     Structure of the forward graph:
    701               f
    702              | |
    703         -----   -----
    704         |           |
    705         d           e
    706        | |         | |
    707     ---   ---------  ---
    708     |         |        |
    709     a         b        c
    710 
    711     Construct a backward graph using the GradientDescentOptimizer.
    712     """
    713 
    714     self.a = variables.Variable(1.0, name="a")
    715     self.b = variables.Variable(2.0, name="b")
    716     self.c = variables.Variable(4.0, name="c")
    717     self.d = math_ops.multiply(self.a, self.b, name="d")
    718     self.e = math_ops.multiply(self.b, self.c, name="e")
    719     self.f = math_ops.multiply(self.d, self.e, name="f")
    720 
    721     # Gradient descent optimizer that minimizes g.
    722     gradient_descent.GradientDescentOptimizer(0.01).minimize(
    723         self.f, name="optim")
    724 
    725     rewriter_config = rewriter_config_pb2.RewriterConfig(
    726         disable_model_pruning=True,
    727         arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
    728         constant_folding=rewriter_config_pb2.RewriterConfig.OFF)
    729     graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
    730     config = config_pb2.ConfigProto(graph_options=graph_options)
    731     self.sess = session.Session(config=config)
    732     self.sess.run(variables.global_variables_initializer())
    733 
    734   def tearDown(self):
    735     ops.reset_default_graph()
    736 
    737   def testContToUpdateA(self):
    738     with NodeStepper(self.sess, "optim") as stepper:
    739       result = stepper.cont("a:0")
    740       self.assertAllClose(1.0, result)
    741       self.assertEqual({}, stepper.last_feed_types())
    742 
    743       result = stepper.cont("optim/learning_rate:0")
    744       self.assertAllClose(0.01, result)
    745       self.assertEqual({}, stepper.last_feed_types())
    746 
    747       # Before any cont calls on ApplyGradientDescent, there should be no
    748       # "dirty" variables.
    749       self.assertEqual(set(), stepper.dirty_variables())
    750 
    751       # First, all the two control inputs to optim.
    752       result = stepper.cont("optim/update_a/ApplyGradientDescent",
    753                             invalidate_from_updated_variables=True)
    754 
    755       # Now variable a should have been marked as dirty due to the update
    756       # by optim/update_a/ApplyGradientDescent.
    757       self.assertSetEqual({"a:0"}, stepper.last_updated())
    758       self.assertEqual({"a:0"}, stepper.dirty_variables())
    759       self.assertIsNone(result)
    760       self.assertEqual({
    761           "optim/learning_rate:0": NodeStepper.FEED_TYPE_HANDLE
    762       }, stepper.last_feed_types())
    763 
    764       # Check that Variable "a" has been updated properly, but "b", "c" and "d"
    765       # remain the same.
    766       # For backprop on Variable a:
    767       #   Because f = a * b * b * c, df / da = b * b * c.
    768       #   1.0 - learning_rate * b * b * c
    769       #     = 1.0 -  0.01 * 2.0 * 2.0 * 4.0 = 0.84.
    770       self.assertAllClose(0.84, self.sess.run(self.a))
    771       self.assertAllClose(2.0, self.sess.run(self.b))
    772       self.assertAllClose(4.0, self.sess.run(self.c))
    773 
    774   def testContToUpdateB(self):
    775     with NodeStepper(self.sess, "optim") as stepper:
    776       result = stepper.cont("optim/update_b/ApplyGradientDescent",
    777                             invalidate_from_updated_variables=True)
    778       self.assertIsNone(result)
    779       self.assertSetEqual({"b:0"}, stepper.last_updated())
    780       self.assertEqual(set(["b:0"]), stepper.dirty_variables())
    781 
    782       # For backprop on Variable b:
    783       #   Because f = a * b * b * c, df / da = 2 * a * b * c.
    784       #   2.0 - learning_rate * 2 * a * b * c
    785       #     = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84
    786       self.assertAllClose(1.0, self.sess.run(self.a))
    787       self.assertAllClose(1.84, self.sess.run(self.b))
    788       self.assertAllClose(4.0, self.sess.run(self.c))
    789 
    790   def testContAfterUpdateWithoutRestoringVariableValue(self):
    791     with NodeStepper(self.sess, "optim") as stepper:
    792       # First, update Variable a from 1.0 to 0.84.
    793       result = stepper.cont(
    794           "optim/update_a/ApplyGradientDescent",
    795           invalidate_from_updated_variables=True,
    796           restore_variable_values=True)
    797       self.assertIsNone(result)
    798       self.assertSetEqual({"a:0"}, stepper.last_updated())
    799       self.assertEqual(set(["a:0"]), stepper.dirty_variables())
    800       self.assertAllClose(0.84, self.sess.run(self.a))
    801       self.assertAllClose(2.0, self.sess.run(self.b))
    802       self.assertAllClose(4.0, self.sess.run(self.c))
    803       # Tracking of the updated variables should have invalidated all
    804       # intermediate tensors downstream to a:0.
    805       self.assertNotIn("a/read:0", stepper.intermediate_tensor_names())
    806       self.assertNotIn("d:0", stepper.intermediate_tensor_names())
    807 
    808       # Second, update Variable b without the default restore_variable_values.
    809       result = stepper.cont(
    810           "optim/update_b/ApplyGradientDescent", restore_variable_values=False)
    811       self.assertIsNone(result)
    812       # For the backprop on Variable b under the updated value of a:
    813       #   2.0 - learning_rate * 2 * a' * b * c
    814       #     = 2.0 - 0.01 * 2 * 0.84 * 2.0 * 4.0 = 1.8656
    815       self.assertAllClose(0.84, self.sess.run(self.a))
    816       self.assertAllClose(1.8656, self.sess.run(self.b))
    817       self.assertAllClose(4.0, self.sess.run(self.c))
    818 
    819   def testContNotInvalidatingFromVariableUpdatesWorksForNextUpdate(self):
    820     with NodeStepper(self.sess, "optim") as stepper:
    821       self.assertIsNone(stepper.cont(
    822           "optim/update_a/ApplyGradientDescent",
    823           invalidate_from_updated_variables=False))
    824       # Even though invalidate_from_updated_variables is set to False, dirty
    825       # variables should still have been tracked.
    826       self.assertSetEqual({"a:0"}, stepper.last_updated())
    827       self.assertEqual({"a:0"}, stepper.dirty_variables())
    828       self.assertIn("a/read:0", stepper.intermediate_tensor_names())
    829       self.assertIn("b/read:0", stepper.intermediate_tensor_names())
    830       self.assertIn("c/read:0", stepper.intermediate_tensor_names())
    831       self.assertIn("d:0", stepper.intermediate_tensor_names())
    832       self.assertIn("e:0", stepper.intermediate_tensor_names())
    833       self.assertIn("optim/learning_rate:0",
    834                     stepper.intermediate_tensor_names())
    835       self.assertNotIn("a:0", stepper.intermediate_tensor_names())
    836       self.assertNotIn("b:0", stepper.intermediate_tensor_names())
    837       self.assertNotIn("c:0", stepper.intermediate_tensor_names())
    838 
    839       self.assertAllClose(0.84, self.sess.run(self.a))
    840       self.assertAllClose(2.0, self.sess.run(self.b))
    841       self.assertAllClose(4.0, self.sess.run(self.c))
    842 
    843       # For the backprop on Variable b, the result should reflect the original
    844       # value of Variable a, even though Variable a has actually been updated.
    845       #   2.0 - learning_rate * 2 * a * b * c
    846       #     = 2.0 - 0.01 * 2 * 1.0 * 2.0 * 4.0 = 1.84
    847       self.assertIsNone(stepper.cont(
    848           "optim/update_b/ApplyGradientDescent",
    849           invalidate_from_updated_variables=False,
    850           restore_variable_values=False))
    851       self.assertAllClose(0.84, self.sess.run(self.a))
    852       self.assertAllClose(1.84, self.sess.run(self.b))
    853       self.assertAllClose(4.0, self.sess.run(self.c))
    854 
    855   def testUpdateTwiceRestoreVariable(self):
    856     with NodeStepper(self.sess, "optim") as stepper:
    857       result = stepper.cont(
    858           "optim/update_a/ApplyGradientDescent",
    859           invalidate_from_updated_variables=True,
    860           restore_variable_values=True)
    861       self.assertIsNone(result)
    862       self.assertSetEqual({"a:0"}, stepper.last_updated())
    863       self.assertEqual({"a:0"}, stepper.dirty_variables())
    864 
    865       result = stepper.cont(
    866           "optim/update_b/ApplyGradientDescent",
    867           invalidate_from_updated_variables=True,
    868           restore_variable_values=True)
    869       self.assertIsNone(result)
    870       # Variables a and c should have been restored and hence no longer dirty.
    871       # Variable b should have been marked as dirty.
    872       self.assertSetEqual({"b:0"}, stepper.last_updated())
    873       self.assertEqual({"b:0"}, stepper.dirty_variables())
    874 
    875     # The result of the update should be identitcal to as if only update_b is
    876     # run.
    877     self.assertAllClose(1.0, self.sess.run(self.a))
    878     self.assertAllClose(1.84, self.sess.run(self.b))
    879     self.assertAllClose(4.0, self.sess.run(self.c))
    880 
    881   def testSelectiveHandleUsageDependingOnTransitiveCleanliness(self):
    882     """Test tensor handlers are using only during clean transitive closure.
    883 
    884     "clean" means no Variables have been updated by preceding cont() calls.
    885     """
    886 
    887     with NodeStepper(self.sess, "optim") as stepper:
    888       # First, call cont() on the two tensors on the intermediate level: e and
    889       # f.
    890       result = stepper.cont("d:0")
    891       self.assertAllClose(2.0, result)
    892       self.assertEqual({}, stepper.last_feed_types())
    893       self.assertItemsEqual(["a/read:0", "b/read:0"],
    894                             stepper.intermediate_tensor_names())
    895       self.assertItemsEqual(["d:0"], stepper.handle_names())
    896       self.assertSetEqual(set(), stepper.last_updated())
    897       self.assertEqual(set(), stepper.dirty_variables())
    898 
    899       result = stepper.cont("e:0")
    900       self.assertAllClose(8.0, result)
    901       self.assertEqual({
    902           "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE
    903       }, stepper.last_feed_types())
    904       self.assertItemsEqual(["d:0", "e:0"], stepper.handle_names())
    905       self.assertItemsEqual(["a/read:0", "b/read:0", "c/read:0"],
    906                             stepper.intermediate_tensor_names())
    907       self.assertSetEqual(set(), stepper.last_updated())
    908       self.assertEqual(set(), stepper.dirty_variables())
    909 
    910       # Now run update_a, so as to let Variable a be dirty.
    911       result = stepper.cont(
    912           "optim/update_a/ApplyGradientDescent",
    913           invalidate_from_updated_variables=True,
    914           restore_variable_values=True)
    915       self.assertIsNone(result)
    916       # Due to the update to the value of a:0, the dumped intermediate a/read:0
    917       # should have been invalidated.
    918       self.assertNotIn("a/read:0", stepper.intermediate_tensor_names())
    919       self.assertSetEqual({"a:0"}, stepper.last_updated())
    920       self.assertEqual({"a:0"}, stepper.dirty_variables())
    921 
    922       # Now, run update_b.
    923       result = stepper.cont(
    924           "optim/update_b/ApplyGradientDescent", restore_variable_values=True)
    925       self.assertIsNone(result)
    926 
    927       # The last cont() run should have use the handle of tensor e, but not the
    928       # handle of tensor d, because the transitive closure of e is clean,
    929       # whereas that of d is dirty due to the update to a in the previous cont()
    930       # call.
    931       last_feed_types = stepper.last_feed_types()
    932       self.assertNotIn("d:0", last_feed_types)
    933       self.assertEqual(NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
    934                        last_feed_types["b/read:0"])
    935       self.assertEqual(NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
    936                        last_feed_types["c/read:0"])
    937 
    938       # The result of the update_b should be identical to as if no other
    939       # update_* cont() calls have occurred before.
    940       self.assertAllClose(1.0, self.sess.run(self.a))
    941       self.assertAllClose(1.84, self.sess.run(self.b))
    942       self.assertAllClose(4.0, self.sess.run(self.c))
    943 
    944   def testRestoreVariableValues(self):
    945     """Test restore_variable_values() restores the old values of variables."""
    946 
    947     with NodeStepper(self.sess, "optim") as stepper:
    948       stepper.cont(
    949           "optim/update_b/ApplyGradientDescent",
    950           invalidate_from_updated_variables=True,
    951           restore_variable_values=True)
    952       self.assertAllClose(1.84, self.sess.run(self.b))
    953 
    954       stepper.restore_variable_values()
    955       self.assertAllClose(2.0, self.sess.run(self.b))
    956 
    957   def testFinalize(self):
    958     """Test finalize() to restore variables and run the original fetch."""
    959 
    960     with NodeStepper(self.sess, "optim") as stepper:
    961       # Invoke update_b before calling finalize.
    962       stepper.cont(
    963           "optim/update_b/ApplyGradientDescent",
    964           invalidate_from_updated_variables=True,
    965           restore_variable_values=True)
    966 
    967       result = stepper.finalize()
    968       self.assertIsNone(result)
    969 
    970       # The results of the Variable updates should be the same as if no cont()
    971       # call has occurred on update_b.
    972       self.assertAllClose(0.84, self.sess.run(self.a))
    973       self.assertAllClose(1.84, self.sess.run(self.b))
    974       self.assertAllClose(3.96, self.sess.run(self.c))
    975 
    976   def testOverrideThenContToUpdateThenRemoveOverrideThenUpdateAgain(self):
    977     """Test cont() to update nodes after overriding tensor values."""
    978 
    979     with NodeStepper(self.sess, "optim") as stepper:
    980       result = stepper.cont("d:0")
    981       self.assertAllClose(2.0, result)
    982       self.assertEqual({}, stepper.last_feed_types())
    983       self.assertSetEqual(set(), stepper.last_updated())
    984       self.assertEqual(set(), stepper.dirty_variables())
    985       self.assertEqual(["d:0"], stepper.handle_names())
    986       self.assertSetEqual({"d"}, stepper.handle_node_names())
    987 
    988       # Override the value from 1.0 to 10.0.
    989       stepper.override_tensor("a/read:0", 10.0)
    990 
    991       self.assertEqual(["a/read:0"], stepper.override_names())
    992 
    993       result = stepper.cont(
    994           "optim/update_c/ApplyGradientDescent",
    995           invalidate_from_updated_variables=True,
    996           restore_variable_values=True)
    997       self.assertIsNone(result)
    998 
    999       # The last cont() call should have not used the tensor handle to d:0,
   1000       # because the transitive closure of d:0 contains an override tensor.
   1001       self.assertEqual({
   1002           "a/read:0": NodeStepper.FEED_TYPE_OVERRIDE,
   1003           "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
   1004       }, stepper.last_feed_types())
   1005 
   1006       # The tensor handle to d:0 should have been removed due to the dirty
   1007       # transitive closure.
   1008       self.assertEqual([], stepper.handle_names())
   1009       self.assertSetEqual(set(), stepper.handle_node_names())
   1010 
   1011       # For this backprop on c, the overriding value of a/read:0 should have
   1012       # been used:
   1013       #   4.0 - learning_rate * a * b * b
   1014       #     = 4.0 - 0.01 * 10.0 * 2.0 * 2.0 = 3.6.
   1015       self.assertAllClose(3.6, self.sess.run(self.c))
   1016 
   1017       # Now remove the overriding value of a/read:0.
   1018       stepper.remove_override("a/read:0")
   1019       self.assertEqual([], stepper.override_names())
   1020 
   1021       # Obtain the tensor handle to d:0 again.
   1022       result = stepper.cont("d:0")
   1023       self.assertAllClose(2.0, result)
   1024       self.assertEqual(["d:0"], stepper.handle_names())
   1025       self.assertSetEqual({"d"}, stepper.handle_node_names())
   1026       self.assertNotIn("a/read:0", stepper.last_feed_types())
   1027 
   1028       # Then call update_c again, without restoring c.
   1029       result = stepper.cont("optim/update_c/ApplyGradientDescent",
   1030                             restore_variable_values=False)
   1031       self.assertIsNone(result)
   1032       self.assertNotIn("a/read:0", stepper.last_feed_types())
   1033 
   1034       # This time, the d:0 tensor handle should have been used, because its
   1035       # transitive closure is clean.
   1036       self.assertEqual({
   1037           "b/read:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
   1038           "d:0": NodeStepper.FEED_TYPE_HANDLE,
   1039           "optim/learning_rate:0": NodeStepper.FEED_TYPE_DUMPED_INTERMEDIATE,
   1040       }, stepper.last_feed_types())
   1041 
   1042       # For this backprop on c, the overriding value of a/read:0 should have
   1043       # been used:
   1044       #   3.6 - learning_rate * a * b * b
   1045       #     = 3.6 - 0.01 * 1.0 * 2.0 * 2.0 = 3.56.
   1046       self.assertAllClose(3.56, self.sess.run(self.c))
   1047 
   1048   def testContToNodeWithOutputTensors(self):
   1049     """cont() to an op should cache its output tensors if appropriate."""
   1050 
   1051     with NodeStepper(self.sess, "optim") as stepper:
   1052       # In the transitive closure of the stepper, look for an op of which the
   1053       # output tensor also is in the transitive closure.
   1054       # Do not assume a specific op, e.g., ""gradients/e_grad/Reshape_1",
   1055       # because it may vary between builds.
   1056       closure_elements = stepper.closure_elements()
   1057       op_with_output_in_closure = None
   1058       for element_name in closure_elements:
   1059         if element_name + ":0" in closure_elements:
   1060           op_with_output_in_closure = str(element_name)
   1061           break
   1062 
   1063       self.assertEqual(
   1064           [0], stepper.output_slots_in_closure(op_with_output_in_closure))
   1065 
   1066       self.assertIsNotNone(op_with_output_in_closure)
   1067       output_tensor = op_with_output_in_closure + ":0"
   1068 
   1069       # The op "gradients/?_grad/Reshape_1" is in the transitive closure of the
   1070       # stepper, because it is the control input to another o. However, its
   1071       # output tensor "gradients/?_grad/Reshape_1:0" is also in the transitive
   1072       # closure, because it is the (non-control) input of certain ops. Calling
   1073       # cont() on the op should lead to the caching of the tensor handle for
   1074       # the output tensor.
   1075       stepper.cont(op_with_output_in_closure)
   1076 
   1077       self.assertEqual([output_tensor], stepper.handle_names())
   1078       self.assertSetEqual({op_with_output_in_closure},
   1079                           stepper.handle_node_names())
   1080 
   1081       # Do a cont() call that uses the cached tensor of
   1082       # "gradients/?_grad/Reshape_1:0".
   1083       stepper.cont(output_tensor)
   1084       self.assertEqual({
   1085           output_tensor: NodeStepper.FEED_TYPE_HANDLE
   1086       }, stepper.last_feed_types())
   1087 
   1088 
if __name__ == "__main__":
  # Run this module's test cases via googletest.main().
  googletest.main()
   1091