1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_DEBUG_DEBUG_SESSION_H_ 17 #define TENSORFLOW_DEBUG_DEBUG_SESSION_H_ 18 19 #include <unordered_map> 20 21 #include "tensorflow/core/common_runtime/direct_session.h" 22 #include "tensorflow/core/common_runtime/executor.h" 23 24 namespace tensorflow { 25 26 // Experimental. tfdb (TensorFlow Debugger): Gateway to intermediate node 27 // outputs during Session Run calls. Currently limited to DirectSession. 28 class DebugGateway { 29 public: 30 DebugGateway(DirectSession* session); 31 virtual ~DebugGateway(); 32 33 // Callback for node completion. This callback is invoked only once for 34 // a node regardless of whether it has one or more outputs. The value(s) of 35 // the output tensor(s) are not necessarily available when this callback is 36 // invoked. They may need to be asynchronously copied from device (e.g., 37 // GPU) to host, hence the need for the NodeValueCallback below. 38 // 39 // Args: 40 // node_name: Name of the node that has just completed execution 41 // any_output: Whether the node has any output(s) 42 typedef std::function<void(const string& node_name, const bool any_output)> 43 NodeCompletionCallback; 44 void SetNodeCompletionCallback(NodeCompletionCallback callback); 45 46 // Callback for node value. This is invoked when the value of a node's 47 // output tensor is available on the host, possibly after copying from 48 // a device (e.g., GPU). 49 // 50 // Args: 51 // node_name: Name of the node of which the output has become available 52 // output_slot: Output slot number of the output Tensor 53 // tensor_value: Reference to the tensor value 54 // is_ref: Whether the output of the reference type 55 typedef std::function<void(const string& node_name, const int output_slot, 56 const Tensor& tensor_value, const bool is_ref)> 57 NodeValueCallback; 58 void SetNodeValueCallback(NodeValueCallback callback); 59 60 // TODO(cais): Add whitelists for ops/tensors (e.g., {"A:0", "B:0"}) 61 // for node completion callback (whitelist_comp_) and node value callback 62 // (whitelist_val_). If whitelist_comp_ is non-empty, the gateway will 63 // invoke the NodeCompletionCallback only for the nodes specified in the 64 // whitelist. And so forth for whitelist_val_. 65 66 private: 67 DirectSession* session_; 68 // TODO(cais): DebugGateway currently supports only DirectSession. Add 69 // support for GrpcSession. 70 71 NodeCompletionCallback comp_cb_ = nullptr; 72 NodeValueCallback val_cb_ = nullptr; 73 74 typedef std::function<void(const Tensor* dst_tensor)> CopyDoneCallback; 75 76 void CopyTensor(const string& node_name, const int output_slot, 77 const Tensor* src_tensor, OpKernelContext* ctx, 78 CopyDoneCallback copy_done_cb); 79 }; 80 81 } // end namespace tensorflow 82 83 #endif // TENSORFLOW_DEBUG_DEBUG_SESSION_H_ 84