1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the reconstruction of non-debugger-decorated GraphDefs.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import shutil 21 import tempfile 22 23 from tensorflow.core.framework import graph_pb2 24 from tensorflow.core.protobuf import config_pb2 25 from tensorflow.core.protobuf import rewriter_config_pb2 26 from tensorflow.python.client import session 27 from tensorflow.python.debug.lib import debug_data 28 from tensorflow.python.debug.lib import debug_graphs 29 from tensorflow.python.debug.lib import debug_utils 30 from tensorflow.python.framework import constant_op 31 from tensorflow.python.framework import ops 32 from tensorflow.python.framework import test_util 33 from tensorflow.python.ops import control_flow_ops 34 from tensorflow.python.ops import math_ops 35 from tensorflow.python.ops import variables 36 from tensorflow.python.platform import test 37 from tensorflow.python.training import gradient_descent 38 39 40 class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase): 41 42 _OP_TYPE_BLACKLIST = ( 43 "_Send", "_Recv", "_HostSend", "_HostRecv", "_Retval") 44 45 def _no_rewrite_session_config(self): 46 rewriter_config = rewriter_config_pb2.RewriterConfig( 47 dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) 48 graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) 49 return config_pb2.ConfigProto(graph_options=graph_options) 50 51 def setUp(self): 52 super(ReconstructNonDebugGraphTest, self).setUp() 53 self._dump_dir = tempfile.mkdtemp() 54 self._debug_url = "file://" + self._dump_dir 55 ops.reset_default_graph() 56 57 def tearDown(self): 58 shutil.rmtree(self._dump_dir) 59 super(ReconstructNonDebugGraphTest, self).tearDown() 60 61 def _graphDefWithoutBlacklistedNodes(self, graph_def): 62 output_graph_def = graph_pb2.GraphDef() 63 for node in graph_def.node: 64 if node.op not in self._OP_TYPE_BLACKLIST: 65 new_node = output_graph_def.node.add() 66 new_node.CopyFrom(node) 67 68 if new_node.op == "Enter": 69 # The debugger sets parallel_iterations attribute of while-loop Enter 70 # nodes to 1 for debugging. 71 for attr_key in new_node.attr: 72 if attr_key == "parallel_iterations": 73 new_node.attr[attr_key].i = 1 74 elif new_node.op == "Switch": 75 # We don't check the inputs to Switch ops as their inputs may be 76 # Send/Recv nodes. 77 del new_node.input[:] 78 79 return output_graph_def 80 81 def _compareOriginalAndReconstructedGraphDefs(self, 82 sess, 83 fetches, 84 feed_dict=None, 85 expected_output=None): 86 run_options = config_pb2.RunOptions(output_partition_graphs=True) 87 run_metadata = config_pb2.RunMetadata() 88 output = sess.run(fetches, feed_dict=feed_dict, options=run_options, 89 run_metadata=run_metadata) 90 if expected_output is not None: 91 self.assertAllClose(expected_output, output) 92 non_debug_graph_defs = run_metadata.partition_graphs 93 94 debug_utils.watch_graph( 95 run_options, sess.graph, debug_urls=self._debug_url) 96 run_metadata = config_pb2.RunMetadata() 97 output = sess.run(fetches, feed_dict=feed_dict, options=run_options, 98 run_metadata=run_metadata) 99 if expected_output is not None: 100 self.assertAllClose(expected_output, output) 101 102 dump = debug_data.DebugDumpDir( 103 self._dump_dir, partition_graphs=run_metadata.partition_graphs, 104 validate=True) 105 reconstructed = dump.reconstructed_non_debug_partition_graphs() 106 107 self.assertEqual(len(non_debug_graph_defs), len(reconstructed)) 108 for i, non_debug_graph_def in enumerate(non_debug_graph_defs): 109 device_name = debug_graphs._infer_device_name(non_debug_graph_def) 110 test_util.assert_equal_graph_def( 111 self._graphDefWithoutBlacklistedNodes(reconstructed[device_name]), 112 self._graphDefWithoutBlacklistedNodes(non_debug_graph_def)) 113 114 # Test debug_graphs.reconstruct_non_debug_graph_def. 115 reconstructed_again = ( 116 debug_graphs.reconstruct_non_debug_graph_def( 117 run_metadata.partition_graphs[i])) 118 test_util.assert_equal_graph_def( 119 self._graphDefWithoutBlacklistedNodes(reconstructed_again), 120 self._graphDefWithoutBlacklistedNodes(non_debug_graph_def)) 121 122 def testReconstructSimpleGraph(self): 123 with session.Session() as sess: 124 u = variables.Variable([12.0], name="u") 125 v = variables.Variable([30.0], name="v") 126 w = math_ops.add(u, v, name="w") 127 sess.run(u.initializer) 128 sess.run(v.initializer) 129 130 self._compareOriginalAndReconstructedGraphDefs( 131 sess, w, expected_output=[42.0]) 132 133 def testReconstructGraphWithControlEdge(self): 134 with session.Session() as sess: 135 a = variables.Variable(10.0, name="a") 136 with ops.control_dependencies([a]): 137 b = math_ops.add(a, a, name="b") 138 with ops.control_dependencies([a, b]): 139 c = math_ops.multiply(b, b, name="c") 140 sess.run(a.initializer) 141 142 self._compareOriginalAndReconstructedGraphDefs( 143 sess, c, expected_output=400.0) 144 145 def testReonstructGraphWithCond(self): 146 with session.Session(config=self._no_rewrite_session_config()) as sess: 147 x = variables.Variable(10.0, name="x") 148 y = variables.Variable(20.0, name="y") 149 cond = control_flow_ops.cond( 150 x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1)) 151 sess.run(x.initializer) 152 sess.run(y.initializer) 153 154 self._compareOriginalAndReconstructedGraphDefs( 155 sess, cond, expected_output=21.0) 156 157 def testReconstructGraphWithWhileLoop(self): 158 with session.Session() as sess: 159 loop_body = lambda i: math_ops.add(i, 2) 160 loop_cond = lambda i: math_ops.less(i, 16) 161 i = constant_op.constant(10, name="i") 162 loop = control_flow_ops.while_loop(loop_cond, loop_body, [i]) 163 164 self._compareOriginalAndReconstructedGraphDefs(sess, loop) 165 166 def testReconstructGraphWithGradients(self): 167 with session.Session(config=self._no_rewrite_session_config()) as sess: 168 u = variables.Variable(12.0, name="u") 169 v = variables.Variable(30.0, name="v") 170 x = constant_op.constant(1.1, name="x") 171 toy_loss = x * (u - v) 172 train_op = gradient_descent.GradientDescentOptimizer( 173 learning_rate=0.1).minimize(toy_loss, name="train_op") 174 sess.run(u.initializer) 175 sess.run(v.initializer) 176 177 self._compareOriginalAndReconstructedGraphDefs(sess, train_op) 178 179 180 if __name__ == "__main__": 181 test.main() 182