// Source listing header (export residue): Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <deque>
     17 #include <vector>
     18 
     19 #include "tensorflow/core/common_runtime/device.h"
     20 #include "tensorflow/core/common_runtime/executor.h"
     21 #include "tensorflow/core/common_runtime/function.h"
     22 #include "tensorflow/core/common_runtime/memory_types.h"
     23 #include "tensorflow/core/framework/function.h"
     24 #include "tensorflow/core/framework/op.h"
     25 #include "tensorflow/core/framework/op_kernel.h"
     26 #include "tensorflow/core/framework/register_types.h"
     27 #include "tensorflow/core/graph/algorithm.h"
     28 #include "tensorflow/core/graph/gradients.h"
     29 #include "tensorflow/core/graph/graph_constructor.h"
     30 #include "tensorflow/core/platform/macros.h"
     31 #include "tensorflow/core/util/device_name_utils.h"
     32 
     33 namespace tensorflow {
     34 
// Name of the op (and of the function-library entry) used to instantiate a
// symbolic gradient function; referenced both by SymbolicGradientOp::ComputeAsync
// and by the kernel registrations at the bottom of this file.
static const char* const kGradientOp = "SymbolicGradient";
     36 
     37 class ArgOp : public OpKernel {
     38  public:
     39   explicit ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     40     OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
     41     OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
     42   }
     43 
     44   void Compute(OpKernelContext* ctx) override {
     45     auto frame = ctx->call_frame();
     46     OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
     47     Tensor val;
     48     OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
     49     OP_REQUIRES(ctx, val.dtype() == dtype_,
     50                 errors::InvalidArgument(
     51                     "Type mismatch: actual ", DataTypeString(val.dtype()),
     52                     " vs. expect ", DataTypeString(dtype_)));
     53     ctx->set_output(0, val);
     54   }
     55 
     56   bool IsExpensive() override { return false; }
     57 
     58  private:
     59   int index_;
     60   DataType dtype_;
     61 
     62   TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
     63 };
     64 
     65 class RetvalOp : public OpKernel {
     66  public:
     67   explicit RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     68     OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
     69     OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
     70   }
     71 
     72   void Compute(OpKernelContext* ctx) override {
     73     const Tensor& val = ctx->input(0);
     74     OP_REQUIRES(ctx, val.dtype() == dtype_,
     75                 errors::InvalidArgument(
     76                     "Type mismatch: actual ", DataTypeString(val.dtype()),
     77                     " vs. expect ", DataTypeString(dtype_)));
     78     auto frame = ctx->call_frame();
     79     OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
     80     OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val));
     81   }
     82 
     83   bool IsExpensive() override { return false; }
     84 
     85  private:
     86   int index_;
     87   DataType dtype_;
     88 
     89   TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
     90 };
     91 
// CPU registrations. _Arg/_Retval are "system" kernels: they must always be
// present for function execution regardless of which ops are linked in.
REGISTER_SYSTEM_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);

#if TENSORFLOW_USE_SYCL
// SYCL registrations: one kernel per numeric type plus bool. int32 is handled
// separately below so its tensor stays in host memory (int32 values are
// typically shape/control data that the host needs to read).
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name("_Arg").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
                                                   .Device(DEVICE_SYCL)
                                                   .HostMemory("output")
                                                   .TypeConstraint<int32>("T"),
                                               ArgOp);
#undef REGISTER
#define REGISTER(type)                                               \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("_Retval").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval")
                                                   .Device(DEVICE_SYCL)
                                                   .HostMemory("input")
                                                   .TypeConstraint<int32>("T"),
                                               RetvalOp);
#undef REGISTER
#endif

// GPU registrations, same pattern as SYCL: all numeric types plus bool on
// device, with int32 pinned to host memory.
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name("_Arg").Device(DEVICE_GPU).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
                                                   .Device(DEVICE_GPU)
                                                   .HostMemory("output")
                                                   .TypeConstraint<int32>("T"),
                                               ArgOp);
#undef REGISTER

// ResourceHandle arguments also live on the host even when placed on GPU.
REGISTER_KERNEL_BUILDER(Name("_Arg")
                            .Device(DEVICE_GPU)
                            .HostMemory("output")
                            .TypeConstraint<ResourceHandle>("T"),
                        ArgOp);

#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name("_Retval").Device(DEVICE_GPU).TypeConstraint<type>("T"), RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval")
                                                   .Device(DEVICE_GPU)
                                                   .HostMemory("input")
                                                   .TypeConstraint<int32>("T"),
                                               RetvalOp);
#undef REGISTER
    146 
    147 class PassOn : public OpKernel {
    148  public:
    149   explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {
    150     OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
    151                 errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
    152                                  " vs. ", ctx->num_outputs()));
    153     for (int i = 0; i < ctx->num_inputs(); ++i) {
    154       OP_REQUIRES(
    155           ctx, input_type(i) == output_type(i),
    156           errors::Internal("Input and output types for position ", i,
    157                            " do not match: ", DataTypeString(input_type(i)),
    158                            " vs. ", DataTypeString(output_type(i))));
    159     }
    160   }
    161 
    162   void Compute(OpKernelContext* ctx) override {
    163     for (int i = 0; i < ctx->num_inputs(); ++i) {
    164       ctx->set_output(i, ctx->input(i));
    165     }
    166   }
    167 };
    168 
// CPU registrations for the list/array conversion ops (both use PassOn).
REGISTER_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn);

// GPU registrations: floating-point types on device; int32 (below) stays in
// host memory on both sides.
#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_ListToArray").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PassOn);                                                           \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_ArrayToList").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PassOn);

REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);

#undef REGISTER_GPU_KERNELS

REGISTER_KERNEL_BUILDER(Name("_ListToArray")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);

#ifdef TENSORFLOW_USE_SYCL
// SYCL registrations mirror the GPU ones (float/double on device, int32 on
// host memory).
#define REGISTER_SYCL_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("_ListToArray").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      PassOn);                                                            \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("_ArrayToList").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      PassOn);

REGISTER_SYCL_KERNELS(float);
REGISTER_SYCL_KERNELS(double);

#undef REGISTER_SYCL_KERNELS

REGISTER_KERNEL_BUILDER(Name("_ListToArray")
                            .Device(DEVICE_SYCL)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
                            .Device(DEVICE_SYCL)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
#endif  // TENSORFLOW_USE_SYCL
    226 
    227 class SymbolicGradientOp : public AsyncOpKernel {
    228  public:
    229   explicit SymbolicGradientOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}
    230 
    231   ~SymbolicGradientOp() override {}
    232 
    233   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    234     FunctionLibraryRuntime* lib = ctx->function_library();
    235     OP_REQUIRES_ASYNC(ctx, lib != nullptr,
    236                       errors::Internal("No function library is provided."),
    237                       done);
    238 
    239     FunctionLibraryRuntime::Handle handle;
    240     OP_REQUIRES_OK_ASYNC(
    241         ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done);
    242 
    243     FunctionLibraryRuntime::Options opts;
    244     opts.step_id = ctx->step_id();
    245     opts.rendezvous = ctx->rendezvous();
    246     opts.cancellation_manager = ctx->cancellation_manager();
    247     opts.runner = ctx->runner();
    248     opts.stats_collector = ctx->stats_collector();
    249     opts.step_container = ctx->step_container();
    250     std::vector<Tensor> args;
    251     args.reserve(ctx->num_inputs());
    252     for (int i = 0; i < ctx->num_inputs(); ++i) {
    253       args.push_back(ctx->input(i));
    254     }
    255     std::vector<Tensor>* rets = new std::vector<Tensor>;
    256     lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) {
    257       if (!status.ok()) {
    258         ctx->SetStatus(status);
    259       } else if (rets->size() != ctx->num_outputs()) {
    260         ctx->SetStatus(errors::InvalidArgument(
    261             "SymGrad expects to return ", ctx->num_outputs(),
    262             " tensor(s), but get ", rets->size(), " tensor(s) instead."));
    263       } else {
    264         for (size_t i = 0; i < rets->size(); ++i) {
    265           ctx->set_output(i, (*rets)[i]);
    266         }
    267       }
    268       delete rets;
    269       done();
    270     });
    271   }
    272 
    273  private:
    274   TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp);
    275 };
    276 
// SymbolicGradient is type-agnostic at the kernel level, so a single
// registration per device suffices.
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_CPU),
                        SymbolicGradientOp);
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_GPU),
                        SymbolicGradientOp);
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_SYCL),
                        SymbolicGradientOp);

#endif  // TENSORFLOW_USE_SYCL
    286 
    287 class RemoteCallOp : public AsyncOpKernel {
    288  public:
    289   explicit RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
    290     OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
    291   }
    292 
    293   ~RemoteCallOp() override {}
    294 
    295   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    296     const Tensor* target;
    297     OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
    298     const string& target_device =
    299         DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()());
    300 
    301     FunctionLibraryRuntime* lib = ctx->function_library();
    302     OP_REQUIRES_ASYNC(ctx, lib != nullptr,
    303                       errors::Internal("No function library is provided."),
    304                       done);
    305     AttrValueMap attr_values = func_.attr();
    306     FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
    307     instantiate_opts.target = target_device;
    308     FunctionLibraryRuntime::Handle handle;
    309     OP_REQUIRES_OK_ASYNC(ctx,
    310                          lib->Instantiate(func_.name(), AttrSlice(&attr_values),
    311                                           instantiate_opts, &handle),
    312                          done);
    313 
    314     OpInputList arguments;
    315     OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
    316 
    317     FunctionLibraryRuntime::Options opts;
    318     opts.step_id = ctx->step_id();
    319     opts.runner = ctx->runner();
    320     opts.source_device = lib->device()->name();
    321     if (opts.source_device != target_device) {
    322       opts.remote_execution = true;
    323     }
    324     opts.create_rendezvous = true;
    325     std::vector<Tensor> args;
    326     args.reserve(arguments.size());
    327     for (const Tensor& argument : arguments) {
    328       args.push_back(argument);
    329     }
    330     auto* rets = new std::vector<Tensor>;
    331     lib->Run(opts, handle, args, rets, [rets, done, ctx](const Status& status) {
    332       if (!status.ok()) {
    333         ctx->SetStatus(status);
    334       } else {
    335         for (size_t i = 0; i < rets->size(); ++i) {
    336           ctx->set_output(i, (*rets)[i]);
    337         }
    338       }
    339       delete rets;
    340       done();
    341     });
    342   }
    343 
    344  private:
    345   string target_;
    346   NameAttrList func_;
    347   TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp);
    348 };
    349 
// RemoteCall registrations. The "target" input is a host-memory string scalar
// on every device type, since the runtime must read it on the host.
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp);
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_SYCL).HostMemory("target"), RemoteCallOp);

#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow
    360