Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     19 #include "tensorflow/core/framework/function.h"
     20 #include "tensorflow/core/framework/op_kernel.h"
     21 #include "tensorflow/core/framework/tensor.h"
     22 #include "tensorflow/core/framework/tensor_shape.h"
     23 #include "tensorflow/core/lib/core/threadpool.h"
     24 #include "tensorflow/core/platform/mutex.h"
     25 
     26 namespace tensorflow {
     27 
     28 typedef Eigen::GpuDevice GPUDevice;
     29 typedef Eigen::ThreadPoolDevice CPUDevice;
     30 typedef FunctionLibraryRuntime::Handle FHandle;
     31 typedef std::vector<Tensor> TensorVec;
     32 
     33 namespace {
     34 
     35 // Helper to instantiate function "func" in the library "lib".
     36 Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
     37                    FunctionLibraryRuntime::Handle* handle) {
     38   return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
     39 }
     40 
     41 // If "t" is a scalar of a supported type, returns t != 0 in "*v".
     42 Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) {
     43   if (t.size() != 1) {
     44     return errors::InvalidArgument(
     45         "Expected a single scalar which can be converted to a boolean, got ",
     46         t.size(), " tensors.");
     47   }
     48   if (TensorShapeUtils::IsScalar(t[0].shape())) {
     49     switch (t[0].dtype()) {
     50 #define CASE(T)                   \
     51   case DataTypeToEnum<T>::value:  \
     52     *v = t[0].scalar<T>()() != 0; \
     53     break;
     54 
     55       CASE(float);
     56       CASE(double);
     57       CASE(int32);
     58       CASE(uint8);
     59       CASE(int16);
     60       CASE(int8);
     61       CASE(int64);
     62 #undef CASE
     63       case DT_BOOL:
     64         *v = t[0].scalar<bool>()();
     65         break;
     66       case DT_STRING:
     67         *v = !t[0].scalar<string>()().empty();
     68         break;
     69       default:
     70         return errors::InvalidArgument(DataTypeString(t[0].dtype()),
     71                                        " cannot be converted to a boolean");
     72     }
     73   } else {
     74     *v = t[0].NumElements() > 0;
     75   }
     76   return Status::OK();
     77 }
     78 
     79 // Sets "rets" to be the output of "ctx". Validates rets' types based
     80 // on "kernel".
     81 Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx,
     82                   gtl::ArraySlice<Tensor> rets) {
     83   if (rets.size() != ctx->num_outputs()) {
     84     return errors::Internal("Expect to produce ", ctx->num_outputs(),
     85                             " tensors, but only get ", rets.size());
     86   }
     87   for (int i = 0; i < rets.size(); ++i) {
     88     if (rets[i].dtype() != kernel->output_type(i)) {
     89       return errors::Internal("Expect ", i, "-th output is of type ",
     90                               DataTypeString(kernel->output_type(i)),
     91                               " but get ", DataTypeString(rets[i].dtype()));
     92     }
     93     ctx->set_output(i, rets[i]);
     94   }
     95   return Status::OK();
     96 }
     97 
     98 void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
     99                    bool always_collect_stats) {
    100   opts->step_id = ctx->step_id();
    101   opts->rendezvous = ctx->rendezvous();
    102   opts->cancellation_manager = ctx->cancellation_manager();
    103   if (always_collect_stats) {
    104     opts->stats_collector = ctx->stats_collector();
    105   }
    106   opts->runner = ctx->runner();
    107 }
    108 
    109 }  // end namespace
    110 
    111 class FunctionalIf : public AsyncOpKernel {
    112  public:
    113   explicit FunctionalIf(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    114     auto lib = ctx->function_library();
    115     OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
    116     const NameAttrList* func;
    117     OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &func));
    118     OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &then_handle_));
    119     OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &func));
    120     OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &else_handle_));
    121   }
    122 
    123   ~FunctionalIf() override {}
    124 
    125   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    126     bool cond;
    127     OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond));
    128     (new State(this, ctx, cond, done))->Start();
    129   }
    130 
    131  private:
    132   FHandle then_handle_;
    133   FHandle else_handle_;
    134 
    135   class State {
    136    public:
    137     State(FunctionalIf* kernel, OpKernelContext* ctx, bool cond,
    138           DoneCallback done)
    139         : kernel_(kernel),
    140           ctx_(ctx),
    141           cond_(cond),
    142           done_(done),
    143           lib_(CHECK_NOTNULL(ctx_->function_library())) {
    144       SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
    145       for (int i = 1; i < ctx_->num_inputs(); ++i) {
    146         args_.push_back(ctx_->input(i));
    147       }
    148     }
    149 
    150     ~State() {}
    151 
    152     void Start() {
    153       FHandle handle = cond_ ? kernel_->then_handle_ : kernel_->else_handle_;
    154       rets_.clear();
    155       lib_->Run(
    156           // Evaluate one of the branch.
    157           opts_, handle, args_, &rets_,
    158           // Done callback
    159           [this](Status s) {
    160             if (s.ok()) {
    161               s = SetOutputs(kernel_, ctx_, rets_);
    162             }
    163             ctx_->SetStatus(s);
    164             auto done = done_;
    165             delete this;
    166             done();
    167           });
    168     }
    169 
    170    private:
    171     FunctionalIf* const kernel_;
    172     OpKernelContext* const ctx_;
    173     const bool cond_;
    174     const DoneCallback done_;
    175     FunctionLibraryRuntime* const lib_;
    176     FunctionLibraryRuntime::Options opts_;
    177     TensorVec args_;
    178     TensorVec rets_;
    179   };
    180 };
    181 
    182 REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_CPU), FunctionalIf);
    183 REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"),
    184                         FunctionalIf);
    185 
// Kernel for the "_While" op. Repeatedly runs "cond" on the current
// loop variables; while it evaluates to true, runs "body" and feeds
// the body's outputs back in as the next iteration's inputs. The op's
// inputs seed the loop variables and its outputs are their final
// values.
class FunctionalWhile : public AsyncOpKernel {
 public:
  explicit FunctionalWhile(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    // The function attrs are resolved to handles lazily in
    // ComputeAsync, because the FunctionLibraryRuntime may differ per
    // invocation (see the TODO below).
    OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &cond_func_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &body_func_));
  }

  ~FunctionalWhile() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    auto lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library"), done);

    // TODO(b/37549631): Because this op has `SetIsStateful()` in its
    // op registration, this kernel may be shared by multiple
    // subgraphs, which have different associated
    // `FunctionLibraryRuntime` objects and hence different `FHandle`
    // namespaces. We currently work around this by caching the map
    // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
    // functions this op uses.
    FHandle cond_handle;
    FHandle body_handle;
    {
      // Instantiate cond/body at most once per library runtime.
      mutex_lock l(mu_);
      const auto iter = handles_.find(lib);
      if (iter == handles_.end()) {
        OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle),
                             done);
        OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle),
                             done);
        handles_[lib] = {cond_handle, body_handle};
      } else {
        cond_handle = iter->second.first;
        body_handle = iter->second.second;
      }
    }

    // "State" deletes itself when the loop terminates or fails.
    (new State(this, ctx, cond_handle, body_handle, done))->Start();
  }

 private:
  NameAttrList cond_func_;
  NameAttrList body_func_;

  mutex mu_;
  // Cached {cond, body} handle pairs per runtime; see the TODO above.
  std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
      handles_ GUARDED_BY(mu_);

  // Per-invocation loop driver. Alternates EvalCond() and StartBody()
  // asynchronously until the condition is false or an error occurs,
  // then Finish() reports the outputs and deletes this State.
  class State {
   public:
    State(FunctionalWhile* kernel, OpKernelContext* ctx, FHandle cond_handle,
          FHandle body_handle, DoneCallback done)
        : kernel_(kernel),
          ctx_(ctx),
          cond_handle_(cond_handle),
          body_handle_(body_handle),
          done_(done),
          lib_(CHECK_NOTNULL(ctx_->function_library())) {
      SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);
      // All op inputs become the initial loop variables.
      for (int i = 0; i < ctx_->num_inputs(); ++i) {
        args_.push_back(ctx_->input(i));
      }
    }

    ~State() {}

    void Start() { EvalCond(); }

   private:
    FunctionalWhile* const kernel_;
    OpKernelContext* const ctx_;
    const FHandle cond_handle_;
    const FHandle body_handle_;
    const DoneCallback done_;
    FunctionLibraryRuntime* const lib_;
    FunctionLibraryRuntime::Options opts_;
    TensorVec args_;  // Current loop variables; inputs to cond/body.
    TensorVec rets_;  // Outputs of the most recent cond/body run.

    // Runs "cond" on args_; its result lands in rets_ and is
    // interpreted by StartBody().
    void EvalCond() {
      lib_->Run(
          // Evaluate the condition.
          opts_, cond_handle_, args_, &rets_,
          // Done cb.
          [this](const Status& s) {
            if (!s.ok()) {
              return Finish(s);
            }
            StartBody();
          });
    }

    // Converts the condition result to a bool. If true, runs "body"
    // and loops back to EvalCond() with the body's outputs as the new
    // args_; otherwise finishes with the current args_ as the op's
    // outputs.
    void StartBody() {
      bool cond;
      Status s = ToBool(rets_, &cond);
      if (!s.ok()) {
        return Finish(s);
      }
      if (!cond) {
        return Finish(Status::OK());
      }
      rets_.clear();
      lib_->Run(
          // Evaluate the body.
          opts_, body_handle_, args_, &rets_,
          // Done callback
          [this](const Status& s) {
            if (!s.ok()) {
              return Finish(s);
            }
            // The body must preserve the loop-variable arity.
            if (args_.size() != rets_.size()) {
              return Finish(errors::InvalidArgument(
                  "While loop body returned ", rets_.size(),
                  " arguments. Expected: ", args_.size()));
            }
            // The body's outputs become the next iteration's inputs.
            args_.clear();
            using std::swap;
            swap(args_, rets_);
            EvalCond();
          });
    }

    // Publishes args_ as the op's outputs (on success), records the
    // status, signals completion, and destroys this State.
    void Finish(Status s) {
      if (s.ok()) {
        s = SetOutputs(kernel_, ctx_, args_);
      }
      ctx_->SetStatus(s);
      done_();
      delete this;
    }
  };
};
    319 REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_CPU), FunctionalWhile);
    320 REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), FunctionalWhile);
    321 
    322 }  // namespace tensorflow
    323