Home | History | Annotate | Download | only in common_runtime
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
     17 #define TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
     18 
     19 #include "tensorflow/core/common_runtime/device.h"
     20 #include "tensorflow/core/framework/rendezvous.h"
     21 #include "tensorflow/core/framework/session_state.h"
     22 #include "tensorflow/core/framework/tensor.h"
     23 #include "tensorflow/core/graph/graph.h"
     24 #include "tensorflow/core/lib/core/notification.h"
     25 #include "tensorflow/core/lib/core/status.h"
     26 #include "tensorflow/core/platform/logging.h"
     27 #include "tensorflow/core/platform/macros.h"
     28 
     29 namespace tensorflow {
     30 
     31 class StepStatsCollector;
     32 
     33 // Executor runs a graph computation.
     34 // Example:
     35 //   Graph* graph = ...;
     36 //      ... construct graph ...
     37 //   Executor* executor;
     38 //   TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor));
     39 //   Rendezvous* rendezvous = NewNaiveRendezvous();
     40 //   TF_CHECK_OK(rendezvous->Send("input", some_input_tensor));
     41 //   TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr}));
     42 //   TF_CHECK_OK(rendezvous->Recv("output", &output_tensor));
     43 //   ... ...
     44 //
     45 // Multiple threads can call Executor::Run concurrently.
     46 class Executor {
     47  public:
     48   virtual ~Executor() {}
     49 
     50   // RunAsync() executes the graph computation. "done" is run when the
     51   // graph computation completes. If any error happens during the
     52   // computation, "done" is run and the error is passed to "done".
     53   //
     54   // RunAsync() is given a few arguments in Args. The caller must
     55   // ensure objects passed in Args (rendezvous, stats_collector, etc.)
     56   // are alive at least until done is invoked. All pointers to the
     57   // argument objects can be nullptr.
     58   //
     59   // "step_id" is a process-wide unique identifier for the step being
     60   // run. Executors on different devices may receive the same step_id
     61   // in the case that a step runs Ops on more than one device. The
     62   // step_id is used for tracking resource usage of a given step.
     63   //
     64   // RunAsync() uses the given "rendezvous", if not null, as the
     65   // mechanism to communicate inputs and outputs of the underlying
     66   // graph computation.
     67   //
     68   // RunAsync() calls "stats_collector", if not null, to keep track of
     69   // stats. This allows us to collect statistics and traces on demand.
     70   //
     71   // RunAsync() is provided a "call_frame", if the executor is used
     72   // for executing a function, is used to pass arguments and return
     73   // values between the caller and the callee.
     74   //
     75   // RunAsync() uses "cancellation_manager", if not nullptr, to
     76   // register callbacks that should be called if the graph computation
     77   // is canceled. Note that the callbacks merely unblock any
     78   // long-running computation, and a canceled step will terminate by
     79   // returning/calling the DoneCallback as usual.
     80   //
     81   // RunAsync() dispatches closures to "runner". Typically, "runner"
     82   // is backed up by a bounded threadpool.
     83   struct Args {
     84     int64 step_id = 0;
     85     Rendezvous* rendezvous = nullptr;
     86     StepStatsCollector* stats_collector = nullptr;
     87     CallFrameInterface* call_frame = nullptr;
     88     CancellationManager* cancellation_manager = nullptr;
     89     SessionState* session_state = nullptr;
     90     TensorStore* tensor_store = nullptr;
     91     ScopedStepContainer* step_container = nullptr;
     92 
     93     // If true, calls Sync() on the device.
     94     bool sync_on_finish = false;
     95 
     96     typedef std::function<void()> Closure;
     97     typedef std::function<void(Closure)> Runner;
     98     Runner runner = nullptr;
     99 
    100     // A callback that is invoked each time a node has finished executing.
    101     typedef std::function<Status(const string& node_name, const int output_slot,
    102                                  const Tensor* tensor, const bool is_ref,
    103                                  OpKernelContext* ctx)>
    104         NodeOutputsCallback;
    105     NodeOutputsCallback node_outputs_cb = nullptr;
    106   };
    107   typedef std::function<void(const Status&)> DoneCallback;
    108   virtual void RunAsync(const Args& args, DoneCallback done) = 0;
    109 
    110   // Synchronous wrapper for RunAsync().
    111   Status Run(const Args& args) {
    112     Status ret;
    113     Notification n;
    114     RunAsync(args, [&ret, &n](const Status& s) {
    115       ret = s;
    116       n.Notify();
    117     });
    118     n.WaitForNotification();
    119     return ret;
    120   }
    121 };
    122 
    123 // Creates an Executor that computes the given "graph".
    124 //
    125 // If successful, returns the constructed executor in "*executor". Otherwise,
    126 // returns an error status.
    127 //
    128 // "params" provides a set of context for the executor. We expect that
    129 // different context would provide different implementations.
    130 struct LocalExecutorParams {
    131   Device* device;
    132 
    133   // The library runtime support.
    134   FunctionLibraryRuntime* function_library = nullptr;
    135 
    136   // create_kernel returns an instance of op kernel based on NodeDef.
    137   // delete_kernel is called for every kernel used by the executor
    138   // when the executor is deleted.
    139   std::function<Status(const NodeDef&, OpKernel**)> create_kernel;
    140   std::function<void(OpKernel*)> delete_kernel;
    141 
    142   Executor::Args::NodeOutputsCallback node_outputs_cb;
    143 };
    144 ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
    145                                       std::unique_ptr<const Graph> graph,
    146                                       Executor** executor);
    147 
    148 // A class to help run multiple executors in parallel and wait until
    149 // all of them are complete.
    150 //
    151 // ExecutorBarrier deletes itself after the function returned by Get()
    152 // is called.
class ExecutorBarrier {
 public:
  typedef std::function<void(const Status&)> StatusCallback;

  // Create an ExecutorBarrier for 'num' different executors.
  //
  // 'r' is the shared Rendezvous object that is used to communicate
  // state.  If any of the executors experiences an error, the
  // rendezvous object will be aborted exactly once.
  //
  // 'done' is called after the last executor completes, and
  // ExecutorBarrier is deleted.
  ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
      : rendez_(r), done_cb_(done), pending_(num) {}

  ~ExecutorBarrier() {}

  // Returns a closure that Executors must call when they are done
  // computing, passing the status of their execution as an argument.
  StatusCallback Get() {
    return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1);
  }

 private:
  // Shared rendezvous used by all participating executors; not owned.
  // Aborted at most once, by the thread that records the first error.
  Rendezvous* rendez_ = nullptr;
  // Final callback; invoked exactly once, by the last WhenDone() call,
  // after which this barrier has deleted itself.
  StatusCallback done_cb_ = nullptr;

  mutable mutex mu_;
  // Number of executors that have not yet reported completion.
  int pending_ GUARDED_BY(mu_) = 0;
  // The first non-OK status reported by any executor (OK otherwise).
  Status status_ GUARDED_BY(mu_);

  // Called once per executor (through the closure returned by Get())
  // with that executor's final status. The last call runs the final
  // callback and destroys the barrier.
  void WhenDone(const Status& s) {
    bool error = false;
    Rendezvous* error_rendez = nullptr;
    StatusCallback done = nullptr;
    Status status;
    {
      mutex_lock l(mu_);
      // If we are the first error encountered, mark the status
      // appropriately and later trigger an abort of the Rendezvous
      // object by this thread only.
      if (status_.ok() && !s.ok()) {
        error = true;
        error_rendez = rendez_;
        // Take a reference while holding the lock so the rendezvous
        // stays alive for the StartAbort() call made below, after the
        // lock has been released.
        error_rendez->Ref();
        status_ = s;
      }

      // If this is the last call to WhenDone, call the final callback
      // below.
      if (--pending_ == 0) {
        CHECK(done_cb_ != nullptr);
        std::swap(done, done_cb_);
      }

      // Copy the recorded error status to a local while still under
      // the lock; it is read again below without the lock held.
      if (!status_.ok()) {
        status = status_;
      }
    }

    // Abort outside the critical section so mu_ is not held across a
    // potentially expensive StartAbort() call.
    if (error) {
      error_rendez->StartAbort(status);
      error_rendez->Unref();
    }
    if (done != nullptr) {
      // 'done' and 'status' are local copies, so the barrier can be
      // deleted before the final callback runs (per the class
      // contract, the barrier deletes itself).
      delete this;
      done(status);
    }
  }

  TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier);
};
    225 
    226 // A few helpers to facilitate create/delete kernels.
    227 
    228 // Creates a kernel based on "ndef" on device "device". The kernel can
    229 // access the functions in the "flib". The caller takes ownership of
    230 // returned "*kernel".
    231 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
    232                              const NodeDef& ndef, int graph_def_version,
    233                              OpKernel** kernel);
    234 
// Deletes "kernel" returned by CreateNonCachedKernel.
    236 void DeleteNonCachedKernel(OpKernel* kernel);
    237 
    238 }  // end namespace tensorflow
    239 
    240 #endif  // TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
    241