/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def_builder.h"

namespace tensorflow {
namespace {

// Local response normalization: normalizes each element by a windowed sum of
// squares along the depth dimension.
class LRNOp : public XlaOpKernel {
 public:
  explicit LRNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("depth_radius", &depth_radius_));

    // TODO(phawkins): handle non-float types for attributes.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("bias", &bias_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", &beta_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape in_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, in_shape.dims() == 4,
                errors::InvalidArgument("in must be 4-dimensional"));

    xla::ComputationBuilder* builder = ctx->builder();
    xla::ComputationDataHandle input = ctx->Input(0);

    // sqr_sum[a, b, c, d] =
    //    sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
    // output = input / (bias + alpha * sqr_sum) ** beta

    // We use a window of depth_radius_ * 2 + 1 elements to cover the current
    // element and depth_radius_ elements on either side.
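    // For example, depth_radius_ = 2 gives window_dimensions of {1, 1, 1, 5}.
    // xla::Padding::kSame pads the depth boundaries with the init value zero,
    // which contributes nothing to the sum of squares and so matches clamping
    // the window to the valid depth range.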
    auto squared = builder->Mul(input, input);
    auto sqr_sum = builder->ReduceWindow(
        squared, XlaHelpers::Zero(builder, input_type(0)),
        *ctx->GetOrCreateAdd(input_type(0)),
        /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
        /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);

    auto scale = builder->Pow(
        builder->Add(builder->ConstantR0<float>(bias_),
                     builder->Mul(builder->ConstantR0<float>(alpha_), sqr_sum)),
        builder->ConstantR0<float>(-beta_));
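    // scale = (bias + alpha * sqr_sum) ** -beta, so the multiplication below
    // implements the division in the formula above.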

    ctx->SetOutput(0, builder->Mul(input, scale));
  }

 private:
  int64 depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
};

REGISTER_XLA_OP(Name("LRN"), LRNOp);

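// Gradient of local response normalization with respect to its input.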
class LRNGradOp : public XlaOpKernel {
 public:
  explicit LRNGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("depth_radius", &depth_radius_));

    // TODO(phawkins): handle non-float types for attributes.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("bias", &bias_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", &beta_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape in_grads_shape = ctx->InputShape(0);
    const TensorShape in_image_shape = ctx->InputShape(1);
    const TensorShape out_image_shape = ctx->InputShape(2);

    OP_REQUIRES(ctx, in_grads_shape.dims() == 4 && in_image_shape.dims() == 4,
                errors::InvalidArgument("inputs must be 4-dimensional"));
    const int64 batch = in_grads_shape.dim_size(0);
    const int64 rows = in_grads_shape.dim_size(1);
    const int64 cols = in_grads_shape.dim_size(2);
    const int64 depth = in_grads_shape.dim_size(3);
    OP_REQUIRES(
        ctx, in_image_shape.dim_size(0) == batch &&
                 in_image_shape.dim_size(1) == rows &&
                 in_image_shape.dim_size(2) == cols &&
                 in_image_shape.dim_size(3) == depth &&
                 out_image_shape.dim_size(0) == batch &&
                 out_image_shape.dim_size(1) == rows &&
                 out_image_shape.dim_size(2) == cols &&
                 out_image_shape.dim_size(3) == depth,
        errors::InvalidArgument(
            "input_grads, input_image, and out_image should have the same "
            "shape"));

    xla::ComputationBuilder* builder = ctx->builder();
    xla::ComputationDataHandle in_grads = ctx->Input(0);
    xla::ComputationDataHandle in_image = ctx->Input(1);
    xla::ComputationDataHandle out_image = ctx->Input(2);

    // This code is ported from tensorflow/core/kernels/lrn_op.cc. In Python
    // pseudo-code, the Eigen code does this for each spatial position:
    // grads = [0.0] * depth
    // for j in range(depth):
    //   depth_begin = max(0, j - depth_radius)
    //   depth_end = min(depth, j + depth_radius + 1)
    //
    //   norm = 0
    //   for k in range(depth_begin, depth_end):
    //     norm += in_image[k] * in_image[k]
    //   norm = alpha * norm + bias
    //
    //   for k in range(depth_begin, depth_end):
    //     dyi = -2.0 * alpha * beta * in_image[k] * out_image[j] / norm
    //     if k == j:
    //       dyi += norm ** (-beta)
    //     dyi *= out_grads[j]
    //     grads[k] += dyi
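    //
    // Below, these loops are vectorized as windowed reductions over the depth
    // dimension: in_grads is out_grads in the pseudo-code, dy_reduced sums
    // each k's cross terms over its window, and the final Add supplies the
    // k == j term.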
    auto squared = builder->Mul(in_image, in_image);
    auto sqr_sum = builder->ReduceWindow(
        squared, XlaHelpers::Zero(builder, input_type(0)),
        *ctx->GetOrCreateAdd(input_type(0)),
        /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
        /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);

    auto norm =
        builder->Add(builder->ConstantR0<float>(bias_),
                     builder->Mul(builder->ConstantR0<float>(alpha_), sqr_sum));

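    // Per-j factor of the cross term in the pseudo-code:
    // -2 * alpha * beta * out_image[j] / norm * out_grads[j]. The in_image[k]
    // factor is applied after the windowed sum below.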
    auto dy = builder->Mul(
        builder->Mul(builder->ConstantR0<float>(-2.0f * alpha_ * beta_),
                     builder->Div(out_image, norm)),
        in_grads);

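    // Sum dy over each k's window of j positions (the grads[k] += dyi loop).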
    auto dy_reduced = builder->ReduceWindow(
        dy, XlaHelpers::Zero(builder, input_type(0)),
        *ctx->GetOrCreateAdd(input_type(0)),
        /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
        /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);

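    // gradients[k] = in_image[k] * sum(dy) + in_grads[k] * norm ** -beta; the
    // second term is the k == j contribution.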
    xla::ComputationDataHandle gradients = builder->Add(
        builder->Mul(in_image, dy_reduced),
        builder->Mul(in_grads,
                     builder->Pow(norm, builder->ConstantR0<float>(-beta_))));

    ctx->SetOutput(0, gradients);
  }

 private:
  int64 depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
};

REGISTER_XLA_OP(Name("LRNGrad"), LRNGradOp);

}  // anonymous namespace
}  // namespace tensorflow