/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// LRN = Local Response Normalization
// See docs in ../ops/nn_ops.cc.
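//
// For an input of shape [batch, rows, cols, depth], LRN normalizes each
// element over a window of nearby depth channels:
//
//   output[b, r, c, d] =
//       input[b, r, c, d] /
//       (bias + alpha * sum_{d' in [d - depth_radius, d + depth_radius]}
//                           input[b, r, c, d']^2) ^ beta
//
// where the window is clipped to valid depth indices.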

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/util/work_sharder.h"
#endif

#if GOOGLE_CUDA
#include "cuda/include/cuda.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

namespace {

// When the depth is large and beta_ is 0.5 or 1.0, single-threaded LRN is
// faster than the band-matrix approach used below. Benchmarks suggest
// switching to SingleThreadedLRN when depth > 384.
const int kSingleThreadedLRNDepthCutoff = 384;

// Create a depth-by-depth band matrix with 1s along a swath of size (2 *
// depth_radius + 1) around the diagonal.
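//
// For example, with depth = 5 and depth_radius = 1 the band matrix is
//
//   1 1 0 0 0
//   1 1 1 0 0
//   0 1 1 1 0
//   0 0 1 1 1
//   0 0 0 1 1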
template <typename T>
void GetBandMatrix(int depth, int depth_radius,
                   Eigen::Tensor<T, 2, Eigen::RowMajor>* result) {
  result->setZero();
  for (int row = 0; row < depth; ++row) {
    const int begin = std::max<int>(0, row - depth_radius);
    const int end = std::min<int>(depth, row + depth_radius + 1);
    Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
    Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
    result->slice(start, sizes).setConstant(T(1));
  }
}

}  // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
struct LaunchLRN;

template <typename T>
struct LaunchLRN<CPUDevice, T> {
  LaunchLRN(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel, const Tensor& in,
              Tensor* output) {
    const int batch = static_cast<int>(in.dim_size(0));
    const int rows = static_cast<int>(in.dim_size(1));
    const int cols = static_cast<int>(in.dim_size(2));
    const int depth = static_cast<int>(in.dim_size(3));

#if defined(IS_MOBILE_PLATFORM)
    SingleThreadedLRN(in, batch, rows, cols, depth, output);
#else
    const int nodes = cols * rows;
    if (depth > kSingleThreadedLRNDepthCutoff &&
        (beta_ == T(0.5) || beta_ == T(1))) {
      SingleThreadedLRN(in, batch, rows, cols, depth, output);
      return;
    }

    auto in_shaped = in.shaped<T, 2>({nodes * batch, depth});

    // Multiplying the squared input by the band matrix sums each window of
    // 2 * depth_radius + 1 neighboring depth channels, i.e. it reduces over
    // the correct patch along the depth dimension.
    Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
    GetBandMatrix<T>(depth, depth_radius_, &multiplier);

    auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
    Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
    auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
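    // tmp now holds the normalizer N = bias + alpha * sum(x^2) for every
    // (pixel, depth) element. The common exponents beta == 1 and beta == 0.5
    // are specialized to a reciprocal and a reciprocal square root; the
    // generic case falls back to exp(-beta * log(N)).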
    if (beta_ == T(1)) {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * tmp.inverse();
    } else if (beta_ == T(0.5)) {
      out_shaped.device(context->eigen_cpu_device()) = in_shaped * tmp.rsqrt();
    } else {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * (tmp.log() * -beta_).exp();
    }
#endif
  }

 private:
  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;

  void SingleThreadedLRN(const Tensor& in, const int batch, const int rows,
                         const int cols, const int depth, Tensor* out) {
    Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> data_in(
        in.flat<T>().data(), depth, batch * rows * cols);

    Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> data_out(
        out->flat<T>().data(), depth, batch * rows * cols);

    const int double_depth_radius = depth_radius_ * 2;
    Eigen::Matrix<T, Eigen::Dynamic, 1> padded_square(data_in.rows() +
                                                      double_depth_radius);
    padded_square.setZero();
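    // padded_square holds alpha * x^2 for one column of data_in, with
    // depth_radius zeros of padding at each end, so the sliding window below
    // can sum 2 * depth_radius + 1 consecutive entries without bounds checks.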
    for (int r = 0; r < data_in.cols(); ++r) {
      // Do local response normalization for data_in(:, r). First, compute the
      // squares and store them in the buffer for repeated use.
      padded_square.block(depth_radius_, 0, data_out.rows(), 1) =
          data_in.col(r).cwiseProduct(data_in.col(r)) * alpha_;
      // Then, compute the scale and write it to data_out.
      T accumulated_scale(0);
      for (int i = 0; i < double_depth_radius; ++i) {
        accumulated_scale += padded_square(i);
      }
      for (int i = 0; i < data_in.rows(); ++i) {
        accumulated_scale += padded_square(i + double_depth_radius);
        data_out(i, r) = bias_ + accumulated_scale;
        accumulated_scale -= padded_square(i);
      }
    }

    if (beta_ == T(1)) {
      data_out.array() = data_in.array() * data_out.array().inverse();
    } else if (beta_ == T(0.5)) {
      data_out.array() = data_in.array() * data_out.array().rsqrt();
    } else {
      data_out.array() =
          data_in.array() * (data_out.array().log() * -beta_).exp();
    }
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#if GOOGLE_CUDA

template <typename T>
struct LaunchLRN<GPUDevice, T> {
  LaunchLRN(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel, const Tensor& in,
              Tensor* output) {
    OP_REQUIRES(
        context, beta_ >= 0.01,
        errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));

    OP_REQUIRES(
        context, depth_radius_ > 0 && depth_radius_ <= 7,
        errors::InvalidArgument("cuDNN requires depth_radius in [1, 7], got: ",
                                depth_radius_));
    OP_REQUIRES(
        context, bias_ >= 1e-5,
        errors::InvalidArgument("cuDNN requires bias >= 1e-5, got: ", bias_));

    // Cast to platform-specific int to avoid conversion warnings.
    const int batch = static_cast<int>(in.dim_size(0));
    const int rows = static_cast<int>(in.dim_size(1));
    const int cols = static_cast<int>(in.dim_size(2));
    const int depth = static_cast<int>(in.dim_size(3));

    perftools::gputools::dnn::BatchDescriptor dimensions_desc;
    dimensions_desc.set_count(batch)
        .set_height(rows)
        .set_width(cols)
        .set_feature_map_count(depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);

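    // The normalize descriptor forwards the kernel attributes to cuDNN via
    // StreamExecutor; depth_radius_ is passed as the half-width ("range") of
    // the normalization window.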
    perftools::gputools::dnn::NormalizeDescriptor normalize_desc;
    normalize_desc.set_bias(bias_)
        .set_range(depth_radius_)
        .set_alpha(alpha_)
        .set_beta(beta_);

    auto input_data = StreamExecutorUtil::AsDeviceMemory<T>(in);
    auto output_data = StreamExecutorUtil::AsDeviceMemory<T>(*output);

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    bool status =
        stream
            ->ThenNormalizeWithDimensions(normalize_desc, dimensions_desc,
                                          input_data, &output_data)
            .ok();
    OP_REQUIRES(context, status,
                errors::Internal("NormalizeWithDimensions launch failed"));
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#endif  // GOOGLE_CUDA

template <typename Device, typename T>
class LRNOp : public OpKernel {
 public:
  explicit LRNOp(OpKernelConstruction* context) : OpKernel(context) {
    int64 depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<int>(depth_radius64);
    float tmp;
    OP_REQUIRES_OK(context, context->GetAttr("bias", &tmp));
    bias_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &tmp));
    alpha_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("beta", &tmp));
    beta_ = T(tmp);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in = context->input(0);
    OP_REQUIRES(context, in.dims() == 4,
                errors::InvalidArgument("in must be 4-dimensional"));
    OP_REQUIRES(
        context,
        FastBoundsCheck(in.NumElements(), std::numeric_limits<int>::max()),
        errors::InvalidArgument("argument to LRN too large"));
    // Cast to platform-specific int to avoid conversion warnings.
    const int batch = static_cast<int>(in.dim_size(0));
    const int rows = static_cast<int>(in.dim_size(1));
    const int cols = static_cast<int>(in.dim_size(2));
    const int depth = static_cast<int>(in.dim_size(3));

    OP_REQUIRES(context,
                (depth + depth_radius_) <= std::numeric_limits<int>::max(),
                errors::InvalidArgument("depth ", depth, " + depth_radius ",
                                        depth_radius_, " exceeds int max."));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({batch, rows, cols, depth}), &output));

    LaunchLRN<Device, T> launcher(depth_radius_, bias_, alpha_, beta_);
    launcher.launch(context, this, in, output);
  }

 private:
  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#define REGISTER_CPU(T)                                      \
  REGISTER_KERNEL_BUILDER(                                   \
      Name("LRN").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      LRNOp<CPUDevice, T>);
TF_CALL_float(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);

#undef REGISTER_CPU

#if GOOGLE_CUDA

#define REGISTER_GPU(T)                                      \
  REGISTER_KERNEL_BUILDER(                                   \
      Name("LRN").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LRNOp<GPUDevice, T>);
TF_CALL_float(REGISTER_GPU);

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA

#if !defined(IS_MOBILE_PLATFORM)

template <typename Device, typename T>
struct LaunchLRNGrad;

template <typename T>
struct LaunchLRNGrad<CPUDevice, T> {
  LaunchLRNGrad(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel,
              const Tensor& in_grads, const Tensor& in_image,
              const Tensor& out_image, Tensor* output) {
    const int64 batch = in_grads.dim_size(0);
    const int64 rows = in_grads.dim_size(1);
    const int64 cols = in_grads.dim_size(2);
    const int64 depth = in_grads.dim_size(3);
    const auto nodes = cols * rows;
    auto grads_shaped = in_grads.shaped<T, 2>({nodes * batch, depth});
    auto in_shaped = in_image.shaped<T, 2>({nodes * batch, depth});
    auto activations = out_image.shaped<T, 2>({nodes * batch, depth});

    auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
    out_shaped.setZero();
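    // The shard lambda below accumulates gradient contributions with +=, so
    // the output must start out zeroed.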

    auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
                  depth](int64 begin, int64 end) {
      for (int64 i = begin; i < end; ++i) {
        for (int64 j = 0; j < depth; ++j) {
          // Let y_i be the LRN activations and x_i be the inputs along the
          // depth dimension. (LRN operates independently along rows, cols,
          // and batch.) We have
          //
          //   y_i = x_i / N_i^beta, where
          //   N_i = bias + alpha * sum_{j = i - depth_radius}^{i + depth_radius} x_j^2
          //
          // and the partial derivatives are
          //
          //   dy_i/dx_i = (N_i^beta - x_i * beta * N_i^(beta-1) * 2 * alpha * x_i) / N_i^(2*beta)
          //   dy_i/dx_j = (        - x_i * beta * N_i^(beta-1) * 2 * alpha * x_j) / N_i^(2*beta)   for j != i
          //
          // NOTE(keveman): We could compute N_i as (y_i / x_i)^(1/beta), but
          // that is numerically unstable for small values of x_i. We compute
          // N_i explicitly here to avoid that.

          int64 depth_begin = std::max<int64>(0, j - depth_radius_);
          int64 depth_end = std::min<int64>(depth, j + depth_radius_ + 1);

          T norm(0);
          for (int64 k = depth_begin; k < depth_end; ++k) {
            norm += in_shaped(i, k) * in_shaped(i, k);
          }
          norm = alpha_ * norm + bias_;
          DCHECK_GT(norm, T(1e-6));
          for (int64 k = depth_begin; k < depth_end; ++k) {
            T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) *
                    activations(i, j) / norm;
            if (k == j) {
              dyi += Eigen::numext::pow(norm, -beta_);
            }
            dyi *= grads_shaped(i, j);
            const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) += dyi;
          }
        }
      }
    };
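    // Shard the work over the batch * rows * cols pixel vectors, using
    // depth * depth as the estimated cost of processing each one.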
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
          depth * depth, shard);
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#if GOOGLE_CUDA

template <typename T>
struct LaunchLRNGrad<GPUDevice, T> {
  LaunchLRNGrad(int depth_radius, T bias, T alpha, T beta)
      : depth_radius_(depth_radius), bias_(bias), alpha_(alpha), beta_(beta) {}

  void launch(OpKernelContext* context, OpKernel* kernel,
              const Tensor& in_grads, const Tensor& in_image,
              const Tensor& out_image, Tensor* output) {
    OP_REQUIRES(
        context, beta_ >= 0.01,
        errors::InvalidArgument("cuDNN requires beta >= 0.01, got: ", beta_));

    OP_REQUIRES(
        context, depth_radius_ > 0 && depth_radius_ <= 7,
        errors::InvalidArgument("cuDNN requires depth_radius in [1, 7], got: ",
                                depth_radius_));
    OP_REQUIRES(
        context, bias_ >= 1e-5,
        errors::InvalidArgument("cuDNN requires bias >= 1e-5, got: ", bias_));

    const int64 batch = in_grads.dim_size(0);
    const int64 rows = in_grads.dim_size(1);
    const int64 cols = in_grads.dim_size(2);
    const int64 depth = in_grads.dim_size(3);

    perftools::gputools::dnn::BatchDescriptor dimensions_desc;
    dimensions_desc.set_count(batch)
        .set_height(rows)
        .set_width(cols)
        .set_feature_map_count(depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);

    perftools::gputools::dnn::NormalizeDescriptor normalize_desc;
    normalize_desc.set_bias(bias_)
        .set_range(depth_radius_)
        .set_alpha(alpha_)
        .set_beta(beta_);

    auto input_grads_data = StreamExecutorUtil::AsDeviceMemory<T>(in_grads);
    auto input_image_data = StreamExecutorUtil::AsDeviceMemory<T>(in_image);
    auto output_image_data = StreamExecutorUtil::AsDeviceMemory<T>(out_image);
    auto output_grads_data = StreamExecutorUtil::AsDeviceMemory<T>(*output);

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    bool status =
        stream
            ->ThenNormalizeBackwardWithDimensions(
                normalize_desc, dimensions_desc, input_image_data,
                output_image_data, input_grads_data, &output_grads_data)
            .ok();
    OP_REQUIRES(
        context, status,
        errors::Internal("NormalizeBackwardWithDimensions launch failed"));
  }

  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#endif  // GOOGLE_CUDA

template <typename Device, typename T>
class LRNGradOp : public OpKernel {
 public:
  explicit LRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
    int64 depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<int>(depth_radius64);
    float tmp;
    OP_REQUIRES_OK(context, context->GetAttr("bias", &tmp));
    bias_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &tmp));
    alpha_ = T(tmp);
    OP_REQUIRES_OK(context, context->GetAttr("beta", &tmp));
    beta_ = T(tmp);
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& in_grads = context->input(0);
    const Tensor& in_image = context->input(1);
    const Tensor& out_image = context->input(2);

    OP_REQUIRES(context, in_grads.dims() == 4 && in_image.dims() == 4,
                errors::InvalidArgument("inputs must be 4-dimensional"));
    const int64 batch = in_grads.dim_size(0);
    const int64 rows = in_grads.dim_size(1);
    const int64 cols = in_grads.dim_size(2);
    const int64 depth = in_grads.dim_size(3);
    OP_REQUIRES(
        context,
        in_image.dim_size(0) == batch && in_image.dim_size(1) == rows &&
            in_image.dim_size(2) == cols && in_image.dim_size(3) == depth &&
            out_image.dim_size(0) == batch && out_image.dim_size(1) == rows &&
            out_image.dim_size(2) == cols && out_image.dim_size(3) == depth,
        errors::InvalidArgument(
            "input_grads, input_image, and out_image should have the same "
            "shape"));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({batch, rows, cols, depth}), &output));

    LaunchLRNGrad<Device, T> launcher(depth_radius_, bias_, alpha_, beta_);
    launcher.launch(context, this, in_grads, in_image, out_image, output);
  }

 private:
  int depth_radius_;
  T bias_;
  T alpha_;
  T beta_;
};

#define REGISTER_CPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("LRNGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      LRNGradOp<CPUDevice, T>);
TF_CALL_float(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);

#undef REGISTER_CPU

#if GOOGLE_CUDA

#define REGISTER_GPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("LRNGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      LRNGradOp<GPUDevice, T>);
TF_CALL_float(REGISTER_GPU);

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA

#endif  // !defined(IS_MOBILE_PLATFORM)

}  // namespace tensorflow