Home | History | Annotate | Download | only in lib
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the reconstruction of non-debugger-decorated GraphDefs."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import shutil
     21 import tempfile
     22 
     23 from tensorflow.core.framework import graph_pb2
     24 from tensorflow.core.protobuf import config_pb2
     25 from tensorflow.core.protobuf import rewriter_config_pb2
     26 from tensorflow.python.client import session
     27 from tensorflow.python.debug.lib import debug_data
     28 from tensorflow.python.debug.lib import debug_graphs
     29 from tensorflow.python.debug.lib import debug_utils
     30 from tensorflow.python.framework import constant_op
     31 from tensorflow.python.framework import ops
     32 from tensorflow.python.framework import test_util
     33 from tensorflow.python.ops import control_flow_ops
     34 from tensorflow.python.ops import math_ops
     35 from tensorflow.python.ops import variables
     36 from tensorflow.python.platform import test
     37 from tensorflow.python.training import gradient_descent
     38 
     39 
     40 class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
     41 
     42   _OP_TYPE_BLACKLIST = (
     43       "_Send", "_Recv", "_HostSend", "_HostRecv", "_Retval")
     44 
     45   def _no_rewrite_session_config(self):
     46     rewriter_config = rewriter_config_pb2.RewriterConfig(
     47         dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
     48     graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
     49     return config_pb2.ConfigProto(graph_options=graph_options)
     50 
     51   def setUp(self):
     52     super(ReconstructNonDebugGraphTest, self).setUp()
     53     self._dump_dir = tempfile.mkdtemp()
     54     self._debug_url = "file://" + self._dump_dir
     55     ops.reset_default_graph()
     56 
     57   def tearDown(self):
     58     shutil.rmtree(self._dump_dir)
     59     super(ReconstructNonDebugGraphTest, self).tearDown()
     60 
     61   def _graphDefWithoutBlacklistedNodes(self, graph_def):
     62     output_graph_def = graph_pb2.GraphDef()
     63     for node in graph_def.node:
     64       if node.op not in self._OP_TYPE_BLACKLIST:
     65         new_node = output_graph_def.node.add()
     66         new_node.CopyFrom(node)
     67 
     68         if new_node.op == "Enter":
     69           # The debugger sets parallel_iterations attribute of while-loop Enter
     70           # nodes to 1 for debugging.
     71           for attr_key in new_node.attr:
     72             if attr_key == "parallel_iterations":
     73               new_node.attr[attr_key].i = 1
     74         elif new_node.op == "Switch":
     75           # We don't check the inputs to Switch ops as their inputs may be
     76           # Send/Recv nodes.
     77           del new_node.input[:]
     78 
     79     return output_graph_def
     80 
     81   def _compareOriginalAndReconstructedGraphDefs(self,
     82                                                 sess,
     83                                                 fetches,
     84                                                 feed_dict=None,
     85                                                 expected_output=None):
     86     run_options = config_pb2.RunOptions(output_partition_graphs=True)
     87     run_metadata = config_pb2.RunMetadata()
     88     output = sess.run(fetches, feed_dict=feed_dict, options=run_options,
     89                       run_metadata=run_metadata)
     90     if expected_output is not None:
     91       self.assertAllClose(expected_output, output)
     92     non_debug_graph_defs = run_metadata.partition_graphs
     93 
     94     debug_utils.watch_graph(
     95         run_options, sess.graph, debug_urls=self._debug_url)
     96     run_metadata = config_pb2.RunMetadata()
     97     output = sess.run(fetches, feed_dict=feed_dict, options=run_options,
     98                       run_metadata=run_metadata)
     99     if expected_output is not None:
    100       self.assertAllClose(expected_output, output)
    101 
    102     dump = debug_data.DebugDumpDir(
    103         self._dump_dir, partition_graphs=run_metadata.partition_graphs,
    104         validate=True)
    105     reconstructed = dump.reconstructed_non_debug_partition_graphs()
    106 
    107     self.assertEqual(len(non_debug_graph_defs), len(reconstructed))
    108     for i, non_debug_graph_def in enumerate(non_debug_graph_defs):
    109       device_name = debug_graphs._infer_device_name(non_debug_graph_def)
    110       test_util.assert_equal_graph_def(
    111           self._graphDefWithoutBlacklistedNodes(reconstructed[device_name]),
    112           self._graphDefWithoutBlacklistedNodes(non_debug_graph_def))
    113 
    114       # Test debug_graphs.reconstruct_non_debug_graph_def.
    115       reconstructed_again = (
    116           debug_graphs.reconstruct_non_debug_graph_def(
    117               run_metadata.partition_graphs[i]))
    118       test_util.assert_equal_graph_def(
    119           self._graphDefWithoutBlacklistedNodes(reconstructed_again),
    120           self._graphDefWithoutBlacklistedNodes(non_debug_graph_def))
    121 
    122   def testReconstructSimpleGraph(self):
    123     with session.Session() as sess:
    124       u = variables.Variable([12.0], name="u")
    125       v = variables.Variable([30.0], name="v")
    126       w = math_ops.add(u, v, name="w")
    127       sess.run(u.initializer)
    128       sess.run(v.initializer)
    129 
    130       self._compareOriginalAndReconstructedGraphDefs(
    131           sess, w, expected_output=[42.0])
    132 
    133   def testReconstructGraphWithControlEdge(self):
    134     with session.Session() as sess:
    135       a = variables.Variable(10.0, name="a")
    136       with ops.control_dependencies([a]):
    137         b = math_ops.add(a, a, name="b")
    138       with ops.control_dependencies([a, b]):
    139         c = math_ops.multiply(b, b, name="c")
    140       sess.run(a.initializer)
    141 
    142       self._compareOriginalAndReconstructedGraphDefs(
    143           sess, c, expected_output=400.0)
    144 
    145   def testReonstructGraphWithCond(self):
    146     with session.Session(config=self._no_rewrite_session_config()) as sess:
    147       x = variables.Variable(10.0, name="x")
    148       y = variables.Variable(20.0, name="y")
    149       cond = control_flow_ops.cond(
    150           x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
    151       sess.run(x.initializer)
    152       sess.run(y.initializer)
    153 
    154       self._compareOriginalAndReconstructedGraphDefs(
    155           sess, cond, expected_output=21.0)
    156 
    157   def testReconstructGraphWithWhileLoop(self):
    158     with session.Session() as sess:
    159       loop_body = lambda i: math_ops.add(i, 2)
    160       loop_cond = lambda i: math_ops.less(i, 16)
    161       i = constant_op.constant(10, name="i")
    162       loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])
    163 
    164       self._compareOriginalAndReconstructedGraphDefs(sess, loop)
    165 
    166   def testReconstructGraphWithGradients(self):
    167     with session.Session(config=self._no_rewrite_session_config()) as sess:
    168       u = variables.Variable(12.0, name="u")
    169       v = variables.Variable(30.0, name="v")
    170       x = constant_op.constant(1.1, name="x")
    171       toy_loss = x * (u - v)
    172       train_op = gradient_descent.GradientDescentOptimizer(
    173           learning_rate=0.1).minimize(toy_loss, name="train_op")
    174       sess.run(u.initializer)
    175       sess.run(v.initializer)
    176 
    177       self._compareOriginalAndReconstructedGraphDefs(sess, train_op)
    178 
    179 
if __name__ == "__main__":
  # Run this module's tests through TensorFlow's test entry point
  # (`tensorflow.python.platform.test`).
  test.main()
    182