// tensorflow/core/kernels/inplace_ops.cc (code-viewer navigation residue removed)
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include "tensorflow/core/framework/op_kernel.h"
     19 #include "tensorflow/core/framework/register_types.h"
     20 #include "tensorflow/core/framework/tensor.h"
     21 #include "tensorflow/core/framework/tensor_shape.h"
     22 #include "tensorflow/core/kernels/fill_functor.h"
     23 #include "tensorflow/core/kernels/inplace_ops_functor.h"
     24 #include "tensorflow/core/lib/core/status.h"
     25 
     26 namespace tensorflow {
     27 typedef Eigen::ThreadPoolDevice CPUDevice;
     28 #ifdef TENSORFLOW_USE_SYCL
     29 typedef Eigen::SyclDevice SyclDevice;
     30 #endif  // TENSORFLOW_USE_SYCL
     31 
     32 namespace functor {
     33 
     34 template <typename Device, typename T>
     35 Status DoParallelConcatUpdate(const Device& d, const Tensor& value, int32 loc,
     36                               Tensor* output) {
     37   auto Tvalue = value.shaped<T, 2>({1, value.NumElements()});
     38   auto Toutput = output->flat_outer_dims<T>();
     39   auto nrows = Toutput.dimension(0);
     40   auto r = (loc % nrows + nrows) % nrows;  // Guard index range.
     41   Toutput.template chip<0>(r).device(d) = Tvalue.template chip<0>(0);
     42   return Status::OK();
     43 }
     44 
     45 template <>
     46 Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32 loc,
     47                         Tensor* output) {
     48   CHECK_EQ(value.dtype(), output->dtype());
     49   switch (value.dtype()) {
     50 #define CASE(type)                  \
     51   case DataTypeToEnum<type>::value: \
     52     return DoParallelConcatUpdate<CPUDevice, type>(d, value, loc, output);
     53     TF_CALL_NUMBER_TYPES(CASE);
     54     TF_CALL_string(CASE);
     55     TF_CALL_variant(CASE);
     56 #undef CASE
     57     default:
     58       return errors::InvalidArgument("Unsupported data type: ", value.dtype());
     59   }
     60 }
     61 
     62 #ifdef TENSORFLOW_USE_SYCL
     63 template <>
     64 Status DoParallelConcat(const SyclDevice& d, const Tensor& value, int32 loc,
     65                         Tensor* output) {
     66   CHECK_EQ(value.dtype(), output->dtype());
     67   switch (value.dtype()) {
     68 #define CASE(type)                  \
     69   case DataTypeToEnum<type>::value: \
     70     return DoParallelConcatUpdate<SyclDevice, type>(d, value, loc, output);
     71     TF_CALL_GPU_NUMBER_TYPES_NO_HALF(CASE);
     72 #undef CASE
     73     default:
     74       return errors::InvalidArgument("Unsupported data type: ", value.dtype());
     75   }
     76 }
     77 #endif  // TENSORFLOW_USE_SYCL
     78 
     79 }  // end namespace functor
     80 
     81 namespace {
     82 
     83 template <typename Device>
     84 class ParallelConcatUpdate : public OpKernel {
     85  public:
     86   explicit ParallelConcatUpdate(OpKernelConstruction* ctx) : OpKernel(ctx) {
     87     OP_REQUIRES_OK(ctx, ctx->GetAttr("loc", &loc_));
     88   }
     89 
     90   void Compute(OpKernelContext* ctx) override {
     91     auto value = ctx->input(0);
     92     auto update = ctx->input(1);
     93 
     94     OP_REQUIRES(
     95         ctx, value.dims() == update.dims(),
     96         errors::InvalidArgument("value and update shape doesn't match: ",
     97                                 value.shape().DebugString(), " vs. ",
     98                                 update.shape().DebugString()));
     99     for (int i = 1; i < value.dims(); ++i) {
    100       OP_REQUIRES(
    101           ctx, value.dim_size(i) == update.dim_size(i),
    102           errors::InvalidArgument("value and update shape doesn't match ",
    103                                   value.shape().DebugString(), " vs. ",
    104                                   update.shape().DebugString()));
    105     }
    106     OP_REQUIRES(ctx, 1 == update.dim_size(0),
    107                 errors::InvalidArgument("update shape doesn't match: ",
    108                                         update.shape().DebugString()));
    109 
    110     Tensor output = value;  // This creates an alias intentionally.
    111     const auto& d = ctx->eigen_device<Device>();
    112     OP_REQUIRES_OK(
    113         ctx, ::tensorflow::functor::DoParallelConcat(d, update, loc_, &output));
    114     ctx->set_output(0, output);
    115   }
    116 
    117  private:
    118   int32 loc_;
    119 };
    120 
// Kernel for _ParallelConcatStart: allocates an (uninitialized) output of the
// statically-known `shape` attr. Rows are filled in later by
// _ParallelConcatUpdate kernels that alias this buffer.
template <typename Device, typename T>
class ParallelConcatStart : public OpKernel {
 public:
  explicit ParallelConcatStart(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
  }

  void Compute(OpKernelContext* ctx) override {
    Tensor* out = nullptr;
    // We do not know whether the output will be used on GPU. Setting it to be
    // gpu-compatible for now.
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape_, &out, attr));
  }

 private:
  TensorShape shape_;  // Static output shape, taken from the "shape" attr.
};
    140 
// Deliberately-failing kernel registered for ParallelConcat itself. The op is
// expected to be replaced by _ParallelConcatStart/_ParallelConcatUpdate during
// graph rewriting, so constructing this kernel means the rewrite did not run;
// fail at construction time with an Internal error.
class FailureKernel : public OpKernel {
 public:
  explicit FailureKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx,
                   errors::Internal("Found instance of parallel_stack which "
                                    "could not be properly replaced."));
  }

  // Never reached: construction always fails.
  void Compute(OpKernelContext*) override {}
};
    151 
// CPU kernel registrations. _ParallelConcatUpdate and _ParallelConcatStart
// are registered for all POD + string types; ParallelConcat itself gets the
// FailureKernel because it must be rewritten away before execution.
#define REGISTER(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")   \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<type>("T"), \
                          ParallelConcatUpdate<CPUDevice>);
TF_CALL_POD_STRING_TYPES(REGISTER)
#undef REGISTER

#define REGISTER_EMPTY(type)                                  \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart")        \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("dtype"), \
                          ParallelConcatStart<CPUDevice, type>)

TF_CALL_POD_STRING_TYPES(REGISTER_EMPTY)
#undef REGISTER_EMPTY

#define REGISTER_PARALLEL_CONCAT(type)                                     \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("ParallelConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      FailureKernel);
TF_CALL_POD_STRING_TYPES(REGISTER_PARALLEL_CONCAT);
#undef REGISTER_PARALLEL_CONCAT
    175 
#ifdef TENSORFLOW_USE_SYCL
// SYCL kernel registrations, mirroring the CPU set but restricted to the
// non-half GPU numeric types.
#define REGISTER_EMPTY(type)                                  \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart")        \
                              .Device(DEVICE_SYCL)            \
                              .TypeConstraint<type>("dtype"), \
                          ParallelConcatStart<SyclDevice, type>);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_EMPTY)
#undef REGISTER_EMPTY

#define REGISTER_PARALLEL_CONCAT(type)                                      \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("ParallelConcat").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      FailureKernel);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_PARALLEL_CONCAT);
#undef REGISTER_PARALLEL_CONCAT

#define REGISTER(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")   \
                              .Device(DEVICE_SYCL)        \
                              .TypeConstraint<type>("T"), \
                          ParallelConcatUpdate<SyclDevice>);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER)
#undef REGISTER

// Register versions that operate on int32 data on the CPU even though the op
// has been placed on the SYCL

// All tensors are pinned to host memory, so the CPUDevice kernel is used.
REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")
                            .Device(DEVICE_SYCL)
                            .HostMemory("value")
                            .HostMemory("update")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ParallelConcatUpdate<CPUDevice>);
#endif  // TENSORFLOW_USE_SYCL
    211 
#if GOOGLE_CUDA

typedef Eigen::GpuDevice GPUDevice;

// CUDA kernel registrations, mirroring the CPU set for GPU numeric types.
#define REGISTER_EMPTY(type)                                  \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart")        \
                              .Device(DEVICE_GPU)             \
                              .TypeConstraint<type>("dtype"), \
                          ParallelConcatStart<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_EMPTY)
#undef REGISTER_EMPTY

#define REGISTER_PARALLEL_CONCAT(type)                                     \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("ParallelConcat").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      FailureKernel);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARALLEL_CONCAT);
#undef REGISTER_PARALLEL_CONCAT

#define REGISTER(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")   \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T"), \
                          ParallelConcatUpdate<GPUDevice>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER)
#undef REGISTER

// Register versions that operate on int32 data on the CPU even though the op
// has been placed on the GPU

// All tensors are pinned to host memory, so the CPUDevice kernel is used.
REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("value")
                            .HostMemory("update")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ParallelConcatUpdate<CPUDevice>);
#endif
    250 
    251 }  // end namespace
    252 }  // end namespace tensorflow
    253