/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/batch_norm_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

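// Forward pass of BatchNormWithGlobalNormalization: normalizes a 4-D input
// with the supplied per-channel mean and variance, then applies the beta
// offset (and the gamma scale when scale_after_normalization is true).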
template <typename Device, typename T>
class BatchNormOp : public OpKernel {
 public:
  explicit BatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
    float variance_epsilon;
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon));
    variance_epsilon_ = T(variance_epsilon);
    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
                                             &scale_after_normalization_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& mean = context->input(1);
    const Tensor& var = context->input(2);
    const Tensor& beta = context->input(3);
    const Tensor& gamma = context->input(4);

    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, mean.dims() == 1,
                errors::InvalidArgument("mean must be 1-dimensional",
                                        mean.shape().DebugString()));
    OP_REQUIRES(context, var.dims() == 1,
                errors::InvalidArgument("var must be 1-dimensional",
                                        var.shape().DebugString()));
    OP_REQUIRES(context, beta.dims() == 1,
                errors::InvalidArgument("beta must be 1-dimensional",
                                        beta.shape().DebugString()));
    OP_REQUIRES(context, gamma.dims() == 1,
                errors::InvalidArgument("gamma must be 1-dimensional",
                                        gamma.shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

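    // The functor computes, per channel, approximately
    //   output = (input - mean) * rsqrt(var + variance_epsilon) * gamma + beta
    // with gamma omitted when scale_after_normalization is false.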
    functor::BatchNorm<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(),
        var.vec<T>(), beta.vec<T>(), gamma.vec<T>(), variance_epsilon_,
        scale_after_normalization_, output->tensor<T, 4>());
  }

 private:
  T variance_epsilon_;
  bool scale_after_normalization_;
};

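// Gradient of BatchNormWithGlobalNormalization: given the incoming
// out_backprop, it produces gradients with respect to the input, mean,
// variance, beta, and gamma.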
template <typename Device, typename T>
class BatchNormGradOp : public OpKernel {
 public:
  explicit BatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) {
    float variance_epsilon;
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon));
    variance_epsilon_ = T(variance_epsilon);
    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
                                             &scale_after_normalization_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& mean = context->input(1);
    const Tensor& var = context->input(2);
    const Tensor& gamma = context->input(3);
    const Tensor& out_backprop = context->input(4);

    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, mean.dims() == 1,
                errors::InvalidArgument("mean must be 1-dimensional",
                                        mean.shape().DebugString()));
    OP_REQUIRES(context, var.dims() == 1,
                errors::InvalidArgument("var must be 1-dimensional",
                                        var.shape().DebugString()));
    OP_REQUIRES(context, gamma.dims() == 1,
                errors::InvalidArgument("gamma must be 1-dimensional",
                                        gamma.shape().DebugString()));
    OP_REQUIRES(context, out_backprop.dims() == 4,
                errors::InvalidArgument("out_backprop must be 4-dimensional",
                                        out_backprop.shape().DebugString()));

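    // Where possible, forward (reuse) the buffer of one of the listed inputs
    // as the output buffer; otherwise allocate a fresh output tensor.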
    Tensor* dx = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0, 4}, 0, input.shape(), &dx));
    Tensor* dm = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {1}, 1, mean.shape(), &dm));
    Tensor* dv = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {2}, 2, var.shape(), &dv));
    Tensor* db = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {3}, 3, mean.shape(), &db));
    Tensor* dg = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(4, gamma.shape(), &dg));

    // Scratch buffer of [depth] dimension, i.e. the 4th dimension of input
    // (dim_size(3)), used for computing various combinations of
    // (var + epsilon).
    Tensor scratch1;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                TensorShape({input.dim_size(3)}), &scratch1));

    // Scratch buffer of [depth] dimension for holding intermediate calculation
    // values.
    Tensor scratch2;
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DataTypeToEnum<T>::value,
                                TensorShape({input.dim_size(3)}), &scratch2));

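    // The functor fills dx, dm, dv, db, and dg with the gradients with
    // respect to input, mean, var, beta, and gamma respectively, using the
    // two scratch vectors for per-channel intermediates.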
    functor::BatchNormGrad<Device, T>()(
        context->eigen_device<Device>(), input.tensor<T, 4>(), mean.vec<T>(),
        var.vec<T>(), gamma.vec<T>(), out_backprop.tensor<T, 4>(),
        variance_epsilon_, scale_after_normalization_, dx->tensor<T, 4>(),
        dm->vec<T>(), dv->vec<T>(), db->vec<T>(), dg->vec<T>(),
        scratch1.vec<T>(), scratch2.vec<T>());
  }

 private:
  T variance_epsilon_;
  bool scale_after_normalization_;
};

#define REGISTER_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
                              .Device(DEVICE_CPU)                  \
                              .TypeConstraint<T>("T"),             \
                          BatchNormOp<CPUDevice, T>);

TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
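// The matching definitions are compiled in the GPU-only translation unit
// (presumably batch_norm_op_gpu.cu.cc); only the declarations are needed here.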
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                  \
  template <>                                                                \
  void BatchNorm<GPUDevice, T>::operator()(                                  \
      const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,          \
      typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var,   \
      typename TTypes<T>::ConstVec beta, typename TTypes<T>::ConstVec gamma, \
      T variance_epsilon, bool scale_after_normalization,                    \
      typename TTypes<T, 4>::Tensor output);                                 \
  extern template struct BatchNorm<GPUDevice, T>;

#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);

TF_CALL_half(DECLARE_GPU_SPECS);
TF_CALL_float(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                     \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T"),             \
                          BatchNormOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

#if TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
                              .Device(DEVICE_SYCL)                 \
                              .TypeConstraint<T>("T"),             \
                          BatchNormOp<SYCLDevice, T>);

TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL
#endif  // TENSORFLOW_USE_SYCL

#define REGISTER_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
                              .Device(DEVICE_CPU)                      \
                              .TypeConstraint<T>("T"),                 \
                          BatchNormGradOp<CPUDevice, T>);

TF_CALL_half(REGISTER_KERNEL);
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                                \
  template <>                                                              \
  void BatchNormGrad<GPUDevice, T>::operator()(                            \
      const GPUDevice& d, typename TTypes<T, 4>::ConstTensor input,        \
      typename TTypes<T>::ConstVec mean, typename TTypes<T>::ConstVec var, \
      typename TTypes<T>::ConstVec gamma,                                  \
      typename TTypes<T, 4>::ConstTensor out_backprop, T variance_epsilon, \
      bool scale_after_normalization, typename TTypes<T, 4>::Tensor dx,    \
      typename TTypes<T>::Vec dm, typename TTypes<T>::Vec dv,              \
      typename TTypes<T>::Vec db, typename TTypes<T>::Vec dg,              \
      typename TTypes<T>::Vec scratch1, typename TTypes<T>::Vec scratch2); \
  extern template struct BatchNormGrad<GPUDevice, T>;

#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);

TF_CALL_half(DECLARE_GPU_SPECS);
TF_CALL_float(DECLARE_GPU_SPECS);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
                              .Device(DEVICE_GPU)                      \
                              .TypeConstraint<T>("T"),                 \
                          BatchNormGradOp<GPUDevice, T>);

TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

#if TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
                              .Device(DEVICE_SYCL)                     \
                              .TypeConstraint<T>("T"),                 \
                          BatchNormGradOp<SYCLDevice, T>);

TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#endif  // TENSORFLOW_USE_SYCL

}  // namespace tensorflow