/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/bias_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/util/tensor_format.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/bias_op_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace {

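// Illustrative example for GetBiasValueDims below (hypothetical shapes, not
// taken from any caller in this file): an NHWC value tensor of shape
// [2, 3, 4, 5] collapses to batch = 2 * 3 * 4 = 24 and channel = 5, with
// height = width = 1, while the same shape interpreted as NCHW yields
// batch = 2, channel = 3, height = 4, width = 5.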
void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format,
                      int32* batch, int32* height, int32* width,
                      int32* channel) {
  *batch = 1;
  *width = 1;
  *height = 1;
  *channel = 1;
  if (data_format == FORMAT_NHWC) {
    int32 channel_dim = value_tensor.dims() - 1;
    *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
    for (int32 i = 0; i < channel_dim; i++) {
      *batch *= static_cast<int32>(value_tensor.dim_size(i));
    }
  } else if (data_format == FORMAT_NCHW) {
    int32 channel_dim = value_tensor.dims() - 3;
    int32 height_dim = value_tensor.dims() - 2;
    int32 width_dim = value_tensor.dims() - 1;
    *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
    *height = static_cast<int32>(value_tensor.dim_size(height_dim));
    *width = static_cast<int32>(value_tensor.dim_size(width_dim));
    for (int32 i = 0; i < channel_dim; i++) {
      *batch *= static_cast<int32>(value_tensor.dim_size(i));
    }
  }
}

template <class T>
struct AccumulatorType {
  typedef T type;
};

// float is faster on the CPU than half, and also more precise,
// so use float for the temporary accumulators.
template <>
struct AccumulatorType<Eigen::half> {
  typedef float type;
};

}  // namespace

template <typename Device, typename T>
class BiasOp : public BinaryOp<T> {
 public:
  explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& bias = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
                errors::InvalidArgument("Biases must be 1D: ",
                                        bias.shape().DebugString()));

    // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
    // used.
    size_t channel_dim;
    if (data_format_ == FORMAT_NCHW) {
      OP_REQUIRES(context, input.dims() == 4,
                  errors::InvalidArgument(
                      "NCHW format supports only 4D input tensor."));
      channel_dim = 1;
    } else {
      channel_dim = input.shape().dims() - 1;  // End of code by intel_tf.
    }

    OP_REQUIRES(
        context,
        bias.shape().dim_size(0) == input.shape().dim_size(channel_dim),
        errors::InvalidArgument(
            "Must provide as many biases as the channel dimension "
            "of the input tensor: ",
            bias.shape().DebugString(), " vs. ", input.shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, input.shape(), &output));
    if (input.NumElements() == 0) return;

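    // Sketch of the intent of the Eigen expression below (a reading of the
    // code, not an additional code path): the rank-1 bias of shape [channel]
    // is reshaped to [1, channel, 1, 1] and broadcast to
    // [batch, channel, height, width], so the assignment computes
    //   output(n, c, h, w) = input(n, c, h, w) + bias(c).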
    // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
    // used.
    if (data_format_ == FORMAT_NCHW) {
      int32 batch, height, width, channel;
      GetBiasValueDims(input, data_format_, &batch, &height, &width, &channel);
      Eigen::DSizes<int32, 4> four_dims(1, channel, 1, 1);
      Eigen::DSizes<int32, 4> broad_cast_dims(batch, 1, height, width);
      const Device& d = context->eigen_device<Device>();
      output->tensor<T, 4>().device(d) =
          input.tensor<T, 4>() +
          bias.tensor<T, 1>().reshape(four_dims).broadcast(broad_cast_dims);
      return;
    }  // End of code by intel_tf.

    switch (input.shape().dims()) {
      case 2:
        Compute<2>(context, input, bias, output);
        break;
      case 3:
        Compute<3>(context, input, bias, output);
        break;
      case 4:
        Compute<4>(context, input, bias, output);
        break;
      case 5:
        Compute<5>(context, input, bias, output);
        break;
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument("Only ranks up to 5 supported: ",
                                            input.shape().DebugString()));
    }
  }

  // Adds the bias to an input tensor of rank Dims using the Bias functor.
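  // Conceptually the functor (defined in bias_op.h) evaluates
  //   output(..., c) = input(..., c) + bias(c),
  // broadcasting the bias along the trailing (channel) dimension.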
  template <int Dims>
  void Compute(OpKernelContext* ctx, const Tensor& input, const Tensor& bias,
               Tensor* output) {
    functor::Bias<Device, T, Dims> functor;
    functor(ctx->eigen_device<Device>(), input.tensor<T, Dims>(), bias.vec<T>(),
            output->tensor<T, Dims>());
  }

 private:
  TensorFormat data_format_;
};

#define REGISTER_KERNEL(type)                                         \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAdd").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
      BiasOp<CPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAddV1").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BiasOp<CPUDevice, type>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
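// For reference, a single instantiation such as REGISTER_KERNEL(float) above
// expands to two REGISTER_KERNEL_BUILDER calls, registering
// BiasOp<CPUDevice, float> for both "BiasAdd" and "BiasAddV1" on DEVICE_CPU
// with the "T" attribute constrained to float.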

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(type)                                          \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("BiasAdd").Device(DEVICE_SYCL).TypeConstraint<type>("T"),   \
      BiasOp<SYCLDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("BiasAddV1").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      BiasOp<SYCLDevice, type>);

TF_CALL_INTEGRAL_TYPES(REGISTER_KERNEL);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T>
class BiasGradOp : public OpKernel {
 public:
  explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& output_backprop = context->input(0);

    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        output_backprop.shape().DebugString()));

    OP_REQUIRES(
        context,
        FastBoundsCheck(output_backprop.NumElements(),
                        std::numeric_limits<int32>::max()),
        errors::InvalidArgument("BiasGrad requires tensor size <= int32 max"));

    int32 batch, height, width, channel;
    GetBiasValueDims(output_backprop, data_format_, &batch, &height, &width,
                     &channel);
    Tensor* output = nullptr;
    TensorShape output_shape{channel};
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

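    // The gradient w.r.t. the bias is the sum of output_backprop over every
    // dimension except the channel dimension.  Below, the NCHW branch reshapes
    // to [batch, channel, height, width] and reduces over axes {0, 2, 3};
    // the NHWC branch reshapes to [batch * height * width, channel] and
    // reduces over axis 0.  Sums are accumulated in AccumulatorType<T>::type
    // (float when T is Eigen::half) before casting back to T.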
    if (channel == 0) {
      return;  // Nothing to do
    } else if (output_backprop.NumElements() == 0) {
      // Eigen often crashes by design on empty tensors, but setZero is safe
      output->template flat<T>().setZero();
    } else {
      // Added by intel_tf to support NCHW on CPU regardless of whether MKL is
      // used.
      if (data_format_ == FORMAT_NCHW) {
        OP_REQUIRES(context, output_backprop.dims() == 4,
                    errors::InvalidArgument(
                        "NCHW format supports only 4D input/output tensor."));
        Eigen::DSizes<int, 4> four_dims(batch, channel, height, width);
#ifdef EIGEN_HAS_INDEX_LIST
        using idx0 = Eigen::type2index<0>;
        using idx2 = Eigen::type2index<2>;
        using idx3 = Eigen::type2index<3>;
        Eigen::IndexList<idx0, idx2, idx3> reduction_axes;
#else
        Eigen::array<int, 3> reduction_axes = {0, 2, 3};
#endif
        output->template flat<T>().device(context->eigen_device<Device>()) =
            output_backprop.flat<T>()
                .template cast<typename AccumulatorType<T>::type>()
                .reshape(four_dims)
                .sum(reduction_axes)
                .template cast<T>();  // End of code by intel_tf.
      } else {
        Eigen::DSizes<int, 2> two_dims(batch * height * width, channel);
#ifdef EIGEN_HAS_INDEX_LIST
        Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
#else
        Eigen::array<int, 1> reduction_axis = {0};
#endif
        output->template flat<T>().device(context->eigen_device<Device>()) =
            output_backprop.flat<T>()
                .template cast<typename AccumulatorType<T>::type>()
                .reshape(two_dims)
                .sum(reduction_axis)
                .template cast<T>();
      }
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the CPU implementations.
#define REGISTER_KERNEL(type)                                           \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BiasAddGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BiasGradOp<CPUDevice, type>);

TF_CALL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(type)                                            \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("BiasAddGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      BiasGradOp<SYCLDevice, type>);

TF_CALL_INTEGRAL_TYPES(REGISTER_KERNEL);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
#endif  // TENSORFLOW_USE_SYCL

#if GOOGLE_CUDA
template <typename T>
class BiasOp<GPUDevice, T> : public BinaryOp<T> {
 public:
  typedef GPUDevice Device;
  explicit BiasOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NHWC;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& bias = context->input(1);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
                errors::InvalidArgument("Biases must be 1D: ",
                                        bias.shape().DebugString()));
    int32 batch, height, width, channel;
    GetBiasValueDims(input, data_format_, &batch, &height, &width, &channel);
    OP_REQUIRES(context, bias.shape().dim_size(0) == channel,
                errors::InvalidArgument(
                    "Must provide as many biases as the channel dimension "
                    "of the input tensor: ",
                    bias.shape().DebugString(), " vs. ", channel, " in ",
                    input.shape().DebugString()));
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, input.shape(), &output));
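    // BiasGPU<T>::compute is declared in bias_op_gpu.h (included above under
    // GOOGLE_CUDA) and is expected to launch the CUDA implementation; here we
    // only pass it the raw buffers together with the dimensions and
    // data_format_ computed above.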
    if (input.NumElements() > 0) {
      BiasGPU<T>::compute(context->template eigen_device<Device>(),
                          input.flat<T>().data(), bias.flat<T>().data(),
                          output->flat<T>().data(), batch, width, height,
                          channel, data_format_);
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(type)                                     \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAdd").Device(DEVICE_GPU).TypeConstraint<type>("T"),   \
      BiasOp<GPUDevice, type>);                                       \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("BiasAddV1").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BiasOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

template <typename T>
class BiasGradOp<GPUDevice, T> : public OpKernel {
 public:
  typedef GPUDevice Device;
  explicit BiasGradOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    if (context->GetAttr("data_format", &data_format).ok()) {
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    } else {
      data_format_ = FORMAT_NCHW;
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& output_backprop = context->input(0);

    OP_REQUIRES(context,
                TensorShapeUtils::IsMatrixOrHigher(output_backprop.shape()),
                errors::InvalidArgument("Input tensor must be at least 2D: ",
                                        output_backprop.shape().DebugString()));
    int32 batch, height, width, channel;
    GetBiasValueDims(output_backprop, data_format_, &batch, &height, &width,
                     &channel);
    Tensor* output = nullptr;
    TensorShape output_shape{channel};
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
    if (channel == 0) return;
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
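    // Zero the output buffer up front: when output_backprop is empty the
    // all-zero result is already the correct gradient, and (assuming
    // BiasGradGPU accumulates into the destination, as its use here suggests)
    // the kernel needs a zero-initialized buffer to sum into.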
    perftools::gputools::DeviceMemoryBase output_ptr(
        output->flat<T>().data(), output->NumElements() * sizeof(T));
    stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T));
    if (output_backprop.NumElements() > 0) {
      BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
                              output_backprop.template flat<T>().data(),
                              output->flat<T>().data(), batch, width, height,
                              channel, data_format_);
    }
  }

 private:
  TensorFormat data_format_;
};

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BiasAddGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      BiasGradOp<GPUDevice, type>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

}  // namespace tensorflow