1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 17 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 18 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 19 #include "tensorflow/core/framework/kernel_def_builder.h" 20 21 namespace tensorflow { 22 namespace { 23 24 // Local response normalization 25 class LRNOp : public XlaOpKernel { 26 public: 27 explicit LRNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 28 OP_REQUIRES_OK(ctx, ctx->GetAttr("depth_radius", &depth_radius_)); 29 30 // TODO(phawkins): handle non-float types for attributes. 31 OP_REQUIRES_OK(ctx, ctx->GetAttr("bias", &bias_)); 32 OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_)); 33 OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", &beta_)); 34 } 35 36 void Compile(XlaOpKernelContext* ctx) override { 37 const TensorShape in_shape = ctx->InputShape(0); 38 OP_REQUIRES(ctx, in_shape.dims() == 4, 39 errors::InvalidArgument("in must be 4-dimensional")); 40 41 xla::ComputationBuilder* builder = ctx->builder(); 42 xla::ComputationDataHandle input = ctx->Input(0); 43 44 // sqr_sum[a, b, c, d] = 45 // sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) 46 // output = input / (bias + alpha * sqr_sum) ** beta 47 48 // We use a window of depth_radius_ * 2 + 1, to account for the current 49 // element and a depth_radius_ on either side. 50 auto squared = builder->Mul(input, input); 51 auto sqr_sum = builder->ReduceWindow( 52 squared, XlaHelpers::Zero(builder, input_type(0)), 53 *ctx->GetOrCreateAdd(input_type(0)), 54 /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, 55 /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); 56 57 auto scale = builder->Pow( 58 builder->Add(builder->ConstantR0<float>(bias_), 59 builder->Mul(builder->ConstantR0<float>(alpha_), sqr_sum)), 60 builder->ConstantR0<float>(-beta_)); 61 62 ctx->SetOutput(0, builder->Mul(input, scale)); 63 } 64 65 private: 66 int64 depth_radius_; 67 float bias_; 68 float alpha_; 69 float beta_; 70 }; 71 72 REGISTER_XLA_OP(Name("LRN"), LRNOp); 73 74 class LRNGradOp : public XlaOpKernel { 75 public: 76 explicit LRNGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 77 OP_REQUIRES_OK(ctx, ctx->GetAttr("depth_radius", &depth_radius_)); 78 79 // TODO(phawkins): handle non-float types for attributes. 80 OP_REQUIRES_OK(ctx, ctx->GetAttr("bias", &bias_)); 81 OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_)); 82 OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", &beta_)); 83 } 84 85 void Compile(XlaOpKernelContext* ctx) override { 86 const TensorShape in_grads_shape = ctx->InputShape(0); 87 const TensorShape in_image_shape = ctx->InputShape(1); 88 const TensorShape out_image_shape = ctx->InputShape(2); 89 90 OP_REQUIRES(ctx, in_grads_shape.dims() == 4 && in_image_shape.dims() == 4, 91 errors::InvalidArgument("inputs must be 4-dimensional")); 92 const int64 batch = in_grads_shape.dim_size(0); 93 const int64 rows = in_grads_shape.dim_size(1); 94 const int64 cols = in_grads_shape.dim_size(2); 95 const int64 depth = in_grads_shape.dim_size(3); 96 OP_REQUIRES( 97 ctx, in_image_shape.dim_size(0) == batch && 98 in_image_shape.dim_size(1) == rows && 99 in_image_shape.dim_size(2) == cols && 100 in_image_shape.dim_size(3) == depth && 101 out_image_shape.dim_size(0) == batch && 102 out_image_shape.dim_size(1) == rows && 103 out_image_shape.dim_size(2) == cols && 104 out_image_shape.dim_size(3) == depth, 105 errors::InvalidArgument( 106 "input_grads, input_image, and out_image should have the same " 107 "shape")); 108 109 xla::ComputationBuilder* builder = ctx->builder(); 110 xla::ComputationDataHandle in_grads = ctx->Input(0); 111 xla::ComputationDataHandle in_image = ctx->Input(1); 112 xla::ComputationDataHandle out_image = ctx->Input(2); 113 114 // This code is ported from tensorflow/core/kernels/lrn_op.cc. In Python 115 // pseudo-code, the Eigen code does this for each spatial position: 116 // grads = [0.0] * depth 117 // for j in range(depth): 118 // depth_begin = max(0, j - depth_radius) 119 // depth_end = min(depth, j + depth_radius + 1) 120 // 121 // norm = 0 122 // for k in range(depth_begin, depth_end): 123 // norm += in_image[k] * in_image[k] 124 // norm = alpha * norm + bias 125 // 126 // for k in range(depth_begin, depth_end): 127 // dyi = -2.0 * alpha * beta * in_image[k] * out_image[j] / norm 128 // if k == j: 129 // dyi += norm ** (-beta) 130 // dyi *= out_grads[j] 131 // grads[k] += dyi 132 133 auto squared = builder->Mul(in_image, in_image); 134 auto sqr_sum = builder->ReduceWindow( 135 squared, XlaHelpers::Zero(builder, input_type(0)), 136 *ctx->GetOrCreateAdd(input_type(0)), 137 /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, 138 /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); 139 140 auto norm = 141 builder->Add(builder->ConstantR0<float>(bias_), 142 builder->Mul(builder->ConstantR0<float>(alpha_), sqr_sum)); 143 144 auto dy = builder->Mul( 145 builder->Mul(builder->ConstantR0<float>(-2.0f * alpha_ * beta_), 146 builder->Div(out_image, norm)), 147 in_grads); 148 149 auto dy_reduced = builder->ReduceWindow( 150 dy, XlaHelpers::Zero(builder, input_type(0)), 151 *ctx->GetOrCreateAdd(input_type(0)), 152 /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1}, 153 /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame); 154 155 xla::ComputationDataHandle gradients = builder->Add( 156 builder->Mul(in_image, dy_reduced), 157 builder->Mul(in_grads, 158 builder->Pow(norm, builder->ConstantR0<float>(-beta_)))); 159 160 ctx->SetOutput(0, gradients); 161 } 162 163 private: 164 int64 depth_radius_; 165 float bias_; 166 float alpha_; 167 float beta_; 168 }; 169 170 REGISTER_XLA_OP(Name("LRNGrad"), LRNGradOp); 171 172 } // anonymous namespace 173 } // namespace tensorflow 174