Home | History | Annotate | Download | only in lib
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for debugger functionalities in tf.Session with grpc:// URLs.
     16 
     17 This test focuses on grpc:// debugging of distributed (gRPC) sessions.
     18 """
     19 from __future__ import absolute_import
     20 from __future__ import division
     21 from __future__ import print_function
     22 
     23 import json
     24 import subprocess
     25 import sys
     26 import time
     27 
     28 import portpicker
     29 from six.moves import xrange  # pylint: disable=redefined-builtin
     30 
     31 from tensorflow.core.protobuf import config_pb2
     32 from tensorflow.python.client import session
     33 from tensorflow.python.debug.lib import debug_utils
     34 from tensorflow.python.debug.lib import grpc_debug_test_server
     35 from tensorflow.python.debug.wrappers import framework
     36 from tensorflow.python.debug.wrappers import grpc_wrapper
     37 from tensorflow.python.framework import ops
     38 from tensorflow.python.framework import test_util
     39 from tensorflow.python.ops import math_ops
     40 from tensorflow.python.ops import state_ops
     41 from tensorflow.python.ops import variables
     42 from tensorflow.python.platform import googletest
     43 from tensorflow.python.platform import test
     44 from tensorflow.python.platform import tf_logging
     45 
     46 
     47 class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
     48   """Test the debugging of distributed sessions."""
     49 
     50   PER_PROC_GPU_MEMORY_FRACTION = 0.1
     51   POLLING_INTERVAL_SEC = 0.025
     52 
     53   @classmethod
     54   def setUpClass(cls):
     55     gpu_memory_fraction_opt = (
     56         "--gpu_memory_fraction=%f" % cls.PER_PROC_GPU_MEMORY_FRACTION)
     57 
     58     worker_port = portpicker.pick_unused_port()
     59     cluster_spec = "worker|localhost:%d" % worker_port
     60     tf_logging.info("cluster_spec: %s", cluster_spec)
     61 
     62     server_bin = test.test_src_dir_path(
     63         "tools/dist_test/server/grpc_tensorflow_server")
     64 
     65     cls.server_target = "grpc://localhost:%d" % worker_port
     66 
     67     cls.server_procs = {}
     68     cls.server_procs["worker"] = subprocess.Popen(
     69         [
     70             server_bin,
     71             "--cluster_spec=%s" % cluster_spec,
     72             "--job_name=worker",
     73             "--task_id=0",
     74             gpu_memory_fraction_opt,
     75         ],
     76         stdout=sys.stdout,
     77         stderr=sys.stderr)
     78 
     79     # Start debug server in-process, on separate thread.
     80     (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread,
     81      cls.debug_server
     82     ) = grpc_debug_test_server.start_server_on_separate_thread(
     83         dump_to_filesystem=False)
     84     tf_logging.info("debug server url: %s", cls.debug_server_url)
     85 
     86     cls.session_config = config_pb2.ConfigProto(
     87         gpu_options=config_pb2.GPUOptions(
     88             per_process_gpu_memory_fraction=cls.PER_PROC_GPU_MEMORY_FRACTION))
     89 
     90   @classmethod
     91   def tearDownClass(cls):
     92     for key in cls.server_procs:
     93       cls.server_procs[key].terminate()
     94     cls.debug_server.stop_server().wait()
     95     cls.debug_server_thread.join()
     96 
     97   def setUp(self):
     98     pass
     99 
    100   def tearDown(self):
    101     self.debug_server.clear_data()
    102 
    103   def _pollingAssertDebugTensorValuesAllClose(self, expected_values,
    104                                               debug_tensor_name):
    105     """Poll debug_server till tensor appears and matches expected values."""
    106     while (debug_tensor_name not in self.debug_server.debug_tensor_values or
    107            len(self.debug_server.debug_tensor_values) < len(expected_values)):
    108       time.sleep(self.POLLING_INTERVAL_SEC)
    109     self.assertAllClose(
    110         expected_values,
    111         self.debug_server.debug_tensor_values[debug_tensor_name])
    112 
    113   def _createGraph(self):
    114     """Create graph for testing.
    115 
    116     Returns:
    117       Python Graph object.
    118     """
    119     with ops.Graph().as_default() as graph:
    120       with ops.device("/job:worker/task:0/cpu:0"):
    121         self.a = variables.Variable(10.0, name="a")
    122         self.b = variables.Variable(100.0, name="b")
    123         self.inc_a = state_ops.assign_add(self.a, 2.0, name="inc_a")
    124         self.dec_b = state_ops.assign_add(self.b, -5.0, name="dec_b")
    125         self.p = math_ops.multiply(self.inc_a, self.dec_b, name="p")
    126         self.q = math_ops.negative(self.p, name="q")
    127     return graph
    128 
    129   def testDistributedRunWithGatedGrpcCommunicatesWithDebugServerCorrectly(self):
    130     graph = self._createGraph()
    131     with session.Session(
    132         config=self.session_config, graph=graph,
    133         target=self.server_target) as sess:
    134       sess.run(self.a.initializer)
    135       sess.run(self.b.initializer)
    136 
    137       run_options = config_pb2.RunOptions()
    138       debug_utils.watch_graph(
    139           run_options,
    140           sess.graph,
    141           node_name_regex_whitelist=r"a",
    142           debug_ops=["DebugIdentity"],
    143           debug_urls=[self.debug_server_url])
    144 
    145       # Test gated_grpc for an op located on the worker, i.e., on the same
    146       # host as where MasterSession is.
    147       # TODO(cais): gRPC gating of debug ops does not work on partition graphs
    148       # not located on MasterSession hosts (e.g., parameter servers) yet. Make
    149       # it work.
    150       debug_utils.watch_graph(
    151           run_options,
    152           sess.graph,
    153           node_name_regex_whitelist=r"p",
    154           debug_ops=["DebugIdentity(gated_grpc=True)"],
    155           debug_urls=[self.debug_server_url])
    156 
    157       for i in xrange(4):
    158         if i % 2 == 0:
    159           self.debug_server.request_watch("p", 0, "DebugIdentity")
    160         else:
    161           self.debug_server.request_unwatch("p", 0, "DebugIdentity")
    162 
    163         expected_p = (10.0 + 2.0 * (i + 1)) * (100.0 - 5.0 * (i + 1))
    164         self.assertAllClose(-expected_p, sess.run(self.q, options=run_options))
    165 
    166         self.assertEqual(1, len(self.debug_server.core_metadata_json_strings))
    167         core_metadata = json.loads(
    168             self.debug_server.core_metadata_json_strings[0])
    169         self.assertEqual([], core_metadata["input_names"])
    170         self.assertEqual(["q:0"], core_metadata["output_names"])
    171         self.assertEqual(i, core_metadata["executor_step_index"])
    172 
    173         if i == 0:
    174           self.assertEqual(1, len(self.debug_server.partition_graph_defs))
    175 
    176         # Tensor "a" is from a PS. It may take longer to arrive due to the fact
    177         # that the stream connection between the PS and the debug server is
    178         # persistent and not torn down at the end of each Session.run()
    179         self._pollingAssertDebugTensorValuesAllClose([10.0 + 2.0 * i],
    180                                                      "a:0:DebugIdentity")
    181 
    182         # Due to the gRPC gating of the debug op for "p", the debug tensor
    183         # should be available on odd-indexed runs.
    184         if i % 2 == 0:
    185           self.assertAllClose(
    186               [expected_p],
    187               self.debug_server.debug_tensor_values["p:0:DebugIdentity"])
    188         else:
    189           self.assertNotIn("p:0:DebugIdentity",
    190                            self.debug_server.debug_tensor_values)
    191 
    192         self.assertNotIn("b:0:DebugIdentity",
    193                          self.debug_server.debug_tensor_values)
    194         self.debug_server.clear_data()
    195 
    196   def testDistributedRunWithGrpcDebugWrapperWorks(self):
    197     graph = self._createGraph()
    198     with session.Session(
    199         config=self.session_config, graph=graph,
    200         target=self.server_target) as sess:
    201       sess.run(self.a.initializer)
    202       sess.run(self.b.initializer)
    203 
    204       def watch_fn(feeds, fetch_keys):
    205         del feeds, fetch_keys
    206         return framework.WatchOptions(
    207             debug_ops=["DebugIdentity"],
    208             node_name_regex_whitelist=r"p")
    209       sess = grpc_wrapper.GrpcDebugWrapperSession(
    210           sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
    211 
    212       for i in xrange(4):
    213         expected_p = (10.0 + 2.0 * (i + 1)) * (100.0 - 5.0 * (i + 1))
    214         self.assertAllClose(-expected_p, sess.run(self.q))
    215 
    216         if i == 0:
    217           self.assertEqual(1, len(self.debug_server.partition_graph_defs))
    218 
    219         self.assertAllClose(
    220             [expected_p],
    221             self.debug_server.debug_tensor_values["p:0:DebugIdentity"])
    222         self.assertNotIn("b:0:DebugIdentity",
    223                          self.debug_server.debug_tensor_values)
    224         self.debug_server.clear_data()
    225 
    226 
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  googletest.main()
    229