Home | History | Annotate | Download | only in debug
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_DEBUG_DEBUG_SESSION_H_
     17 #define TENSORFLOW_DEBUG_DEBUG_SESSION_H_
     18 
     19 #include <unordered_map>
     20 
     21 #include "tensorflow/core/common_runtime/direct_session.h"
     22 #include "tensorflow/core/common_runtime/executor.h"
     23 
     24 namespace tensorflow {
     25 
     26 // Experimental. tfdb (TensorFlow Debugger): Gateway to intermediate node
     27 // outputs during Session Run calls. Currently limited to DirectSession.
     28 class DebugGateway {
     29  public:
     30   DebugGateway(DirectSession* session);
     31   virtual ~DebugGateway();
     32 
     33   // Callback for node completion. This callback is invoked only once for
     34   // a node regardless of whether it has one or more outputs. The value(s) of
     35   // the output tensor(s) are not necessarily available when this callback is
     36   // invoked. They may need to be asynchronously copied from device (e.g.,
     37   // GPU) to host, hence the need for the NodeValueCallback below.
     38   //
     39   // Args:
     40   //   node_name: Name of the node that has just completed execution
     41   //   any_output: Whether the node has any output(s)
     42   typedef std::function<void(const string& node_name, const bool any_output)>
     43       NodeCompletionCallback;
     44   void SetNodeCompletionCallback(NodeCompletionCallback callback);
     45 
     46   // Callback for node value. This is invoked when the value of a node's
     47   // output tensor is available on the host, possibly after copying from
     48   // a device (e.g., GPU).
     49   //
     50   // Args:
     51   //   node_name: Name of the node of which the output has become available
     52   //   output_slot: Output slot number of the output Tensor
     53   //   tensor_value: Reference to the tensor value
     54   //   is_ref: Whether the output of the reference type
     55   typedef std::function<void(const string& node_name, const int output_slot,
     56                              const Tensor& tensor_value, const bool is_ref)>
     57       NodeValueCallback;
     58   void SetNodeValueCallback(NodeValueCallback callback);
     59 
     60   // TODO(cais): Add whitelists for ops/tensors (e.g., {"A:0", "B:0"})
     61   // for node completion callback (whitelist_comp_) and node value callback
     62   // (whitelist_val_). If whitelist_comp_ is non-empty, the gateway will
     63   // invoke the NodeCompletionCallback only for the nodes specified in the
     64   // whitelist. And so forth for whitelist_val_.
     65 
     66  private:
     67   DirectSession* session_;
     68   // TODO(cais): DebugGateway currently supports only DirectSession. Add
     69   // support for GrpcSession.
     70 
     71   NodeCompletionCallback comp_cb_ = nullptr;
     72   NodeValueCallback val_cb_ = nullptr;
     73 
     74   typedef std::function<void(const Tensor* dst_tensor)> CopyDoneCallback;
     75 
     76   void CopyTensor(const string& node_name, const int output_slot,
     77                   const Tensor* src_tensor, OpKernelContext* ctx,
     78                   CopyDoneCallback copy_done_cb);
     79 };
     80 
     81 }  // end namespace tensorflow
     82 
     83 #endif  // TENSORFLOW_DEBUG_DEBUG_SESSION_H_
     84