// (Code-browser navigation artifact, preserved as a comment:
//  Home | History | Annotate | Download | only in debug)
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7   http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/debug/debug_grpc_testlib.h"
     17 
     18 #include "tensorflow/core/debug/debug_graph_utils.h"
     19 #include "tensorflow/core/debug/debugger_event_metadata.pb.h"
     20 #include "tensorflow/core/framework/summary.pb.h"
     21 #include "tensorflow/core/lib/io/path.h"
     22 #include "tensorflow/core/lib/strings/str_util.h"
     23 #include "tensorflow/core/platform/env.h"
     24 #include "tensorflow/core/platform/protobuf.h"
     25 #include "tensorflow/core/platform/tracing.h"
     26 
     27 namespace tensorflow {
     28 
     29 namespace test {
     30 
// Handles one bidirectional debugger stream: records every Event the client
// sends (log messages, encoded GraphDefs, debug tensor Summaries) into the
// listener's public member vectors, then — once the client half-closes —
// flushes any queued debug-op state changes back to the client.
::grpc::Status TestEventListenerImpl::SendEvents(
    ::grpc::ServerContext* context,
    ::grpc::ServerReaderWriter<::tensorflow::EventReply, ::tensorflow::Event>*
        stream) {
  Event event;

  // Drain the stream: each Event is exactly one of the three kinds below.
  while (stream->Read(&event)) {
    if (event.has_log_message()) {
      // Debugger metadata arrives as a log message; ack with an empty reply.
      debug_metadata_strings.push_back(event.log_message().message());
      stream->Write(EventReply());
    } else if (!event.graph_def().empty()) {
      // A serialized GraphDef; ack with an empty reply.
      encoded_graph_defs.push_back(event.graph_def());
      stream->Write(EventReply());
    } else if (event.has_summary()) {
      // NOTE(review): assumes the Summary carries at least one value and that
      // its node_name has the form "<node>:<output_slot>:<debug_op>" (three
      // ':'-separated fields) — confirm against the debug gRPC client before
      // reusing this handler elsewhere; out-of-range access otherwise.
      const Summary::Value& val = event.summary().value(0);

      std::vector<string> name_items =
          tensorflow::str_util::Split(val.node_name(), ':');

      const string node_name = name_items[0];
      const string debug_op = name_items[2];

      const TensorProto& tensor_proto = val.tensor();
      Tensor tensor(tensor_proto.dtype());
      if (!tensor.FromProto(tensor_proto)) {
        // A malformed tensor proto aborts the entire stream.
        return ::grpc::Status::CANCELLED;
      }

      // Obtain the device name, which is encoded in JSON.
      third_party::tensorflow::core::debug::DebuggerEventMetadata metadata;
      if (val.metadata().plugin_data().plugin_name() != "debugger") {
        // This plugin data was meant for another plugin.
        continue;
      }
      auto status = tensorflow::protobuf::util::JsonStringToMessage(
          val.metadata().plugin_data().content(), &metadata);
      if (!status.ok()) {
        // The device name could not be determined.
        continue;
      }

      // Record the fully-decoded debug tensor and its addressing fields.
      device_names.push_back(metadata.device());
      node_names.push_back(node_name);
      output_slots.push_back(metadata.output_slot());
      debug_ops.push_back(debug_op);
      debug_tensors.push_back(tensor);

      // If the debug node is currently in the READ_WRITE mode, send an
      // EventReply to 1) unblock the execution and 2) optionally modify the
      // value.
      const DebugNodeKey debug_node_key(metadata.device(), node_name,
                                        metadata.output_slot(), debug_op);
      if (write_enabled_debug_node_keys_.find(debug_node_key) !=
          write_enabled_debug_node_keys_.end()) {
        stream->Write(EventReply());
      }
    }
  }

  // The client has finished sending. Flush the state changes queued by
  // RequestDebugOpStateChangeAtNextStream() while holding states_mu_.
  {
    mutex_lock l(states_mu_);
    for (size_t i = 0; i < new_states_.size(); ++i) {
      EventReply event_reply;
      EventReply::DebugOpStateChange* change =
          event_reply.add_debug_op_state_changes();

      // State changes will take effect in the next stream, i.e., next debugged
      // Session.run() call.
      change->set_state(new_states_[i]);
      const DebugNodeKey& debug_node_key = debug_node_keys_[i];
      change->set_node_name(debug_node_key.node_name);
      change->set_output_slot(debug_node_key.output_slot);
      change->set_debug_op(debug_node_key.debug_op);
      stream->Write(event_reply);

      // Track which keys are in READ_WRITE mode so that matching tensors on
      // the next stream receive an unblocking EventReply (see read loop above).
      if (new_states_[i] == EventReply::DebugOpStateChange::READ_WRITE) {
        write_enabled_debug_node_keys_.insert(debug_node_key);
      } else {
        write_enabled_debug_node_keys_.erase(debug_node_key);
      }
    }

    debug_node_keys_.clear();
    new_states_.clear();
  }

  return ::grpc::Status::OK;
}
    119 
    120 void TestEventListenerImpl::ClearReceivedDebugData() {
    121   debug_metadata_strings.clear();
    122   encoded_graph_defs.clear();
    123   device_names.clear();
    124   node_names.clear();
    125   output_slots.clear();
    126   debug_ops.clear();
    127   debug_tensors.clear();
    128 }
    129 
    130 void TestEventListenerImpl::RequestDebugOpStateChangeAtNextStream(
    131     const EventReply::DebugOpStateChange::State new_state,
    132     const DebugNodeKey& debug_node_key) {
    133   mutex_lock l(states_mu_);
    134 
    135   debug_node_keys_.push_back(debug_node_key);
    136   new_states_.push_back(new_state);
    137 }
    138 
    139 void TestEventListenerImpl::RunServer(const int server_port) {
    140   ::grpc::ServerBuilder builder;
    141   builder.AddListeningPort(strings::StrCat("localhost:", server_port),
    142                            ::grpc::InsecureServerCredentials());
    143   builder.RegisterService(this);
    144   std::unique_ptr<::grpc::Server> server = builder.BuildAndStart();
    145 
    146   while (!stop_requested_.load()) {
    147     Env::Default()->SleepForMicroseconds(200 * 1000);
    148   }
    149   server->Shutdown();
    150   stopped_.store(true);
    151 }
    152 
    153 void TestEventListenerImpl::StopServer() {
    154   stop_requested_.store(true);
    155   while (!stopped_.load()) {
    156   }
    157 }
    158 
    159 bool PollTillFirstRequestSucceeds(const string& server_url,
    160                                   const size_t max_attempts) {
    161   const int kSleepDurationMicros = 100 * 1000;
    162   size_t n_attempts = 0;
    163   bool success = false;
    164 
    165   // Try a number of times to send the Event proto to the server, as it may
    166   // take the server a few seconds to start up and become responsive.
    167   Tensor prep_tensor(DT_FLOAT, TensorShape({1, 1}));
    168   prep_tensor.flat<float>()(0) = 42.0f;
    169 
    170   while (n_attempts++ < max_attempts) {
    171     const uint64 wall_time = Env::Default()->NowMicros();
    172     Status publish_s = DebugIO::PublishDebugTensor(
    173         DebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "prep_node", 0,
    174                      "DebugIdentity"),
    175         prep_tensor, wall_time, {server_url});
    176     Status close_s = DebugIO::CloseDebugURL(server_url);
    177 
    178     if (publish_s.ok() && close_s.ok()) {
    179       success = true;
    180       break;
    181     } else {
    182       Env::Default()->SleepForMicroseconds(kSleepDurationMicros);
    183     }
    184   }
    185 
    186   return success;
    187 }
    188 
    189 }  // namespace test
    190 
    191 }  // namespace tensorflow
    192