/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA specific pooling ops.

#include <numeric>
#include <vector>

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/pooling_ops_common.h"

namespace tensorflow {
namespace {

// Superclass of pooling ops.
36 class PoolingOp : public XlaOpKernel { 37 public: 38 PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims) 39 : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { 40 if (ctx->num_inputs() == 1) { 41 std::vector<int32> ksize_int; 42 std::vector<int32> stride_int; 43 OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int)); 44 OP_REQUIRES(ctx, ksize_int.size() == num_dims(), 45 errors::InvalidArgument("Sliding window ksize field must " 46 "specify ", 47 num_dims(), " dimensions")); 48 OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int)); 49 OP_REQUIRES(ctx, stride_int.size() == num_dims(), 50 errors::InvalidArgument("Sliding window stride field must " 51 "specify ", 52 num_dims(), " dimensions")); 53 for (int i = 0; i < num_dims(); ++i) { 54 ksize_.push_back(ksize_int[i]); 55 stride_.push_back(stride_int[i]); 56 } 57 } 58 Padding padding; 59 OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding)); 60 padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame; 61 } 62 63 int num_dims() const { return num_spatial_dims_ + 2; } 64 65 // Method that builds an initial value to use in reductions. 66 virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, 67 DataType data_type) = 0; 68 69 // The reduction operation to apply to each window. 70 virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx, 71 DataType dtype) = 0; 72 73 // A post-processing operation to apply on the outputs of the ReduceWindow. 
74 virtual xla::ComputationDataHandle PostProcessOutput( 75 XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, 76 DataType dtype, const TensorShape& input_shape) = 0; 77 78 void Compile(XlaOpKernelContext* ctx) override { 79 xla::ComputationDataHandle input = ctx->Input(0); 80 const TensorShape input_shape = ctx->InputShape(0); 81 82 std::vector<int64> ksize = ksize_; 83 std::vector<int64> stride = stride_; 84 if (ctx->num_inputs() != 1) { 85 const TensorShape ksize_shape = ctx->InputShape(1); 86 // Validate input sizes. 87 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), 88 errors::InvalidArgument("ksize must be a vector, not shape ", 89 ksize_shape.DebugString())); 90 OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(), 91 errors::InvalidArgument("Sliding window ksize field must " 92 "specify ", 93 num_dims(), " dimensions")); 94 ksize.clear(); 95 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize)); 96 97 const TensorShape stride_shape = ctx->InputShape(2); 98 // Validate input sizes. 
99 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), 100 errors::InvalidArgument("stride must be a vector, not shape ", 101 stride_shape.DebugString())); 102 OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(), 103 errors::InvalidArgument("Sliding window stride field must " 104 "specify ", 105 num_dims(), " dimensions")); 106 stride.clear(); 107 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride)); 108 } 109 OP_REQUIRES(ctx, input_shape.dims() == num_dims(), 110 errors::InvalidArgument("Input to ", type_string(), 111 " operator must have ", num_dims(), 112 " dimensions")); 113 114 const DataType type = input_type(0); 115 xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow( 116 input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize, 117 stride, padding_); 118 ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape)); 119 } 120 121 protected: 122 const int num_spatial_dims_; 123 std::vector<int64> ksize_; 124 std::vector<int64> stride_; 125 xla::Padding padding_; 126 TensorFormat data_format_ = FORMAT_NHWC; 127 }; 128 129 class MaxPoolOp : public PoolingOp { 130 public: 131 MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) 132 : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims) {} 133 134 xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, 135 DataType data_type) override { 136 return XlaHelpers::MinValue(b, data_type); 137 } 138 139 const xla::Computation* Reduction(XlaOpKernelContext* ctx, 140 DataType dtype) override { 141 return ctx->GetOrCreateMax(dtype); 142 } 143 144 xla::ComputationDataHandle PostProcessOutput( 145 XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, 146 DataType dtype, const TensorShape& input_shape) override { 147 return output; 148 } 149 }; 150 151 class MaxPool2DOp : public MaxPoolOp { 152 public: 153 explicit MaxPool2DOp(OpKernelConstruction* ctx) 154 : MaxPoolOp(ctx, /*num_spatial_dims=*/2) { 155 string data_format_str; 156 
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); 157 OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), 158 errors::InvalidArgument("Invalid data format")); 159 } 160 }; 161 REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp); 162 REGISTER_XLA_OP(Name("MaxPoolV2") 163 .CompileTimeConstInput("ksize") 164 .CompileTimeConstInput("strides"), 165 MaxPool2DOp); 166 167 class MaxPool3DOp : public MaxPoolOp { 168 public: 169 explicit MaxPool3DOp(OpKernelConstruction* ctx) 170 : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {} 171 }; 172 REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); 173 174 // Common computation shared between AvgPool and AvgPoolGrad. Divide each 175 // element of an image by the count of elements that contributed to that 176 // element during pooling. 177 static xla::ComputationDataHandle AvgPoolDivideByCount( 178 XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, 179 DataType dtype, const TensorShape& input_shape, xla::Padding padding, 180 const std::vector<int64>& ksize, const std::vector<int64>& stride, 181 int num_spatial_dims, TensorFormat data_format) { 182 if (padding == xla::Padding::kValid) { 183 // In VALID padding, all windows have the same number of elements 184 // contributing to each average. Divide by the window size everywhere to 185 // get the average. 186 int64 window_size = std::accumulate(ksize.begin(), ksize.end(), 1, 187 [](int64 a, int64 b) { return a * b; }); 188 189 auto divisor = 190 XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size); 191 return ctx->builder()->Div(output, divisor); 192 } else { 193 // For SAME padding, the padding shouldn't be included in the 194 // counts. We use another ReduceWindow to find the right counts. 195 196 // TODO(phawkins): use a less brute-force way to compute this. Only 197 // the boundary regions will have interesting values here. 
198 199 std::vector<int64> input_dim_sizes(num_spatial_dims); 200 std::vector<int64> window_dims(num_spatial_dims); 201 std::vector<int64> window_ksize(num_spatial_dims); 202 std::vector<int64> window_stride(num_spatial_dims); 203 for (int i = 0; i < num_spatial_dims; ++i) { 204 int dim = GetTensorSpatialDimIndex(num_spatial_dims + 2, data_format, i); 205 input_dim_sizes[i] = input_shape.dim_size(dim); 206 window_dims[i] = dim; 207 window_ksize[i] = ksize[dim]; 208 window_stride[i] = stride[dim]; 209 } 210 211 // Build a matrix of all 1s, with the same width/height as the input. 212 auto ones = ctx->builder()->Broadcast( 213 XlaHelpers::One(ctx->builder(), dtype), input_dim_sizes); 214 215 // Perform a ReduceWindow with the same window size, strides, and padding 216 // to count the number of contributions to each result element. 217 auto counts = ctx->builder()->ReduceWindow( 218 ones, XlaHelpers::Zero(ctx->builder(), dtype), 219 *ctx->GetOrCreateAdd(dtype), window_ksize, window_stride, 220 xla::Padding::kSame); 221 222 return ctx->builder()->Div(output, counts, window_dims); 223 } 224 } 225 226 class AvgPoolOp : public PoolingOp { 227 public: 228 AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) 229 : PoolingOp(ctx, num_spatial_dims) {} 230 231 xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b, 232 DataType data_type) override { 233 return XlaHelpers::Zero(b, data_type); 234 } 235 236 const xla::Computation* Reduction(XlaOpKernelContext* ctx, 237 DataType dtype) override { 238 return ctx->GetOrCreateAdd(dtype); 239 } 240 241 xla::ComputationDataHandle PostProcessOutput( 242 XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, 243 DataType dtype, const TensorShape& input_shape) override { 244 return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_, 245 ksize_, stride_, num_spatial_dims_, 246 data_format_); 247 } 248 }; 249 250 class AvgPool2DOp : public AvgPoolOp { 251 public: 252 explicit 
AvgPool2DOp(OpKernelConstruction* ctx) 253 : AvgPoolOp(ctx, /*num_spatial_dims=*/2) { 254 string data_format_str; 255 OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); 256 OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), 257 errors::InvalidArgument("Invalid data format")); 258 } 259 }; 260 REGISTER_XLA_OP(Name("AvgPool"), AvgPool2DOp); 261 262 class AvgPool3DOp : public AvgPoolOp { 263 public: 264 explicit AvgPool3DOp(OpKernelConstruction* ctx) 265 : AvgPoolOp(ctx, /*num_spatial_dims=*/3) {} 266 }; 267 REGISTER_XLA_OP(Name("AvgPool3D"), AvgPool3DOp); 268 269 // The operation to compute MaxPool gradients. 270 // It takes three inputs: 271 // - The original input tensor 272 // - The original output tensor 273 // - Backprop tensor for output 274 // It produces one output: backprop tensor for input. 275 class MaxPoolGradOp : public XlaOpKernel { 276 public: 277 MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims) 278 : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { 279 if (ctx->num_inputs() == 3) { 280 OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); 281 OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); 282 } 283 OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); 284 } 285 286 int num_dims() const { return num_spatial_dims_ + 2; } 287 288 void Compile(XlaOpKernelContext* ctx) override { 289 if (ctx->num_inputs() != 3) { 290 OP_REQUIRES( 291 ctx, ctx->num_inputs() == 5, 292 errors::InvalidArgument("Must supply ksize and stride arguments.")); 293 const TensorShape ksize_shape = ctx->InputShape(3); 294 // Validate input sizes. 295 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), 296 errors::InvalidArgument("ksize must be a vector, not shape ", 297 ksize_shape.DebugString())); 298 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_)); 299 300 const TensorShape stride_shape = ctx->InputShape(4); 301 // Validate input sizes. 
302 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), 303 errors::InvalidArgument("stride must be a vector, not shape ", 304 stride_shape.DebugString())); 305 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_)); 306 } 307 308 OP_REQUIRES(ctx, ksize_.size() == num_dims(), 309 errors::InvalidArgument("Sliding window ksize field must " 310 "specify ", 311 num_dims(), " dimensions")); 312 OP_REQUIRES(ctx, stride_.size() == num_dims(), 313 errors::InvalidArgument("Sliding window strides field must " 314 "specify ", 315 num_dims(), " dimensions")); 316 317 const TensorShape tensor_in_shape = ctx->InputShape(0); 318 const TensorShape tensor_out_shape = ctx->InputShape(1); 319 const TensorShape out_backprop_shape = ctx->InputShape(2); 320 321 // For maxpooling, tensor_in should have num_dims() dimensions. 322 OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(), 323 errors::InvalidArgument("tensor_in must be ", num_dims(), 324 "-dimensional")); 325 OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(), 326 errors::InvalidArgument("tensor_out must be ", num_dims(), 327 "-dimensional")); 328 // For maxpooling, out_backprop should have num_dims() dimensions. 329 OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(), 330 errors::InvalidArgument("out_backprop must be ", num_dims(), 331 "-dimensional")); 332 333 // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate 334 // whether this is a good time/space tradeoff. 335 auto input = ctx->Input(0); 336 auto out_backprop = ctx->Input(2); 337 338 xla::Padding xla_padding = 339 (padding_ == VALID) ? 
xla::Padding::kValid : xla::Padding::kSame; 340 341 xla::PrimitiveType element_type; 342 OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); 343 xla::ComputationDataHandle init_value = 344 XlaHelpers::Zero(ctx->builder(), input_type(2)); 345 auto select = CreateScalarGeComputation(element_type, ctx->builder()); 346 auto scatter = CreateScalarAddComputation(element_type, ctx->builder()); 347 xla::ComputationDataHandle gradients = ctx->builder()->SelectAndScatter( 348 input, select, ksize_, stride_, xla_padding, out_backprop, init_value, 349 scatter); 350 351 ctx->SetOutput(0, gradients); 352 } 353 354 protected: 355 const int num_spatial_dims_; 356 std::vector<int64> ksize_; 357 std::vector<int64> stride_; 358 Padding padding_; 359 TensorFormat data_format_ = FORMAT_NHWC; 360 }; 361 362 class MaxPool2DGradOp : public MaxPoolGradOp { 363 public: 364 explicit MaxPool2DGradOp(OpKernelConstruction* ctx) 365 : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) { 366 string data_format; 367 OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); 368 OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), 369 errors::InvalidArgument("Invalid data format")); 370 } 371 }; 372 REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp); 373 REGISTER_XLA_OP(Name("MaxPoolGradV2") 374 .CompileTimeConstInput("ksize") 375 .CompileTimeConstInput("strides"), 376 MaxPool2DGradOp); 377 378 class MaxPool3DGradOp : public MaxPoolGradOp { 379 public: 380 explicit MaxPool3DGradOp(OpKernelConstruction* ctx) 381 : MaxPoolGradOp(ctx, /*num_spatial_dims=*/3) {} 382 }; 383 REGISTER_XLA_OP(Name("MaxPool3DGrad"), MaxPool3DGradOp); 384 385 // Average-pooling gradient 386 class AvgPoolGradOp : public XlaOpKernel { 387 public: 388 AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims) 389 : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { 390 OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); 391 OP_REQUIRES(ctx, ksize_.size() == num_dims(), 392 
errors::InvalidArgument("Sliding window ksize field must " 393 "specify ", 394 num_dims(), " dimensions")); 395 OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); 396 OP_REQUIRES(ctx, stride_.size() == num_dims(), 397 errors::InvalidArgument("Sliding window strides field must " 398 "specify ", 399 num_dims(), " dimensions")); 400 OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); 401 OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1, 402 errors::Unimplemented( 403 "Pooling is not yet supported on the batch dimension.")); 404 } 405 406 int num_dims() const { return num_spatial_dims_ + 2; } 407 408 void Compile(XlaOpKernelContext* ctx) override { 409 TensorShape gradients_shape; 410 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape)); 411 412 const TensorShape out_backprop_shape = ctx->InputShape(1); 413 414 // For avgpooling, tensor_in_shape should have num_dims() dimensions. 415 OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(), 416 errors::InvalidArgument("orig_input_shape must be ", num_dims(), 417 "-dimensional")); 418 419 // For avgpooling, out_backprop should have num_dims() dimensions. 420 OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(), 421 errors::InvalidArgument("out_backprop must be ", num_dims(), 422 "-dimensional")); 423 424 int depth_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); 425 int64 depth = out_backprop_shape.dim_size(depth_dim); 426 427 // We can think of average-pooling as: 428 // * a convolution with a kernel consisting entirely of 1s, where the 429 // input feature and output feature are equal, and 0s everywhere else. 430 // * followed by dividing by the counts. 431 // 432 // This then gives us an algorithm to build the gradient: 433 // * divide out_backprop by the counts, followed by 434 // * Conv2DBackpropInput specialized for that kernel, which simplifies to 435 // a Pad and a ReduceWindow. 
436 // 437 // For an explanation of backpropagation for convolution, see the comments 438 // in third_party/tensorflow/core/kernels/conv_grad_ops.h 439 440 // TF filter shape is [ H, W, ..., inC, outC ] 441 std::vector<int64> filter_dims(num_dims()); 442 for (int i = 0; i < num_spatial_dims_; ++i) { 443 int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); 444 filter_dims[i] = ksize_[dim]; 445 } 446 filter_dims[num_dims() - 2] = depth; 447 filter_dims[num_dims() - 1] = depth; 448 TensorShape filter_shape(filter_dims); 449 450 // Reuse the logic from Conv2DBackpropInput to compute padding. 451 ConvBackpropDimensions dims; 452 OP_REQUIRES_OK( 453 ctx, ConvBackpropComputeDimensions( 454 type_string(), /*num_spatial_dims=*/num_spatial_dims_, 455 gradients_shape, filter_shape, out_backprop_shape, stride_, 456 padding_, data_format_, &dims)); 457 458 auto out_backprop = ctx->Input(1); 459 460 // The input gradients are computed by a convolution of the output 461 // gradients 462 // and the filter, with some appropriate padding. See the comment at 463 // the top of conv_grad_ops.h for details. 464 DataType dtype = input_type(1); 465 466 xla::Padding xla_padding = 467 (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; 468 469 // Divide the out_backprop values by the counts for each spatial position. 470 std::vector<int64> stride_int64s(stride_.begin(), stride_.end()); 471 auto out_backprop_div = AvgPoolDivideByCount( 472 ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_, 473 stride_int64s, num_spatial_dims_, data_format_); 474 475 // Pad the gradients in the spatial dimensions. We use the same padding 476 // as Conv2DBackpropInput. 
477 xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(num_dims()); 478 for (int i = 0; i < num_spatial_dims_; ++i) { 479 int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); 480 auto* padding = padding_config.mutable_dimensions(dim); 481 padding->set_edge_padding_low(dims.spatial_dims[i].pad_before); 482 padding->set_edge_padding_high(dims.spatial_dims[i].pad_after); 483 padding->set_interior_padding(dims.spatial_dims[i].stride - 1); 484 } 485 486 auto zero = XlaHelpers::Zero(ctx->builder(), dtype); 487 auto padded_gradients = 488 ctx->builder()->Pad(out_backprop_div, zero, padding_config); 489 490 // in_backprop = padded_gradients <conv> ones 491 std::vector<int64> ones(num_dims(), 1LL); 492 xla::ComputationDataHandle in_backprop = ctx->builder()->ReduceWindow( 493 padded_gradients, zero, *ctx->GetOrCreateAdd(dtype), ksize_, 494 /* window_strides=*/ones, xla::Padding::kValid); 495 496 ctx->SetOutput(0, in_backprop); 497 } 498 499 protected: 500 const int num_spatial_dims_; 501 std::vector<int64> ksize_; 502 std::vector<int32> stride_; 503 Padding padding_; 504 TensorFormat data_format_ = FORMAT_NHWC; 505 }; 506 507 class AvgPool2DGradOp : public AvgPoolGradOp { 508 public: 509 explicit AvgPool2DGradOp(OpKernelConstruction* ctx) 510 : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) { 511 string data_format; 512 OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); 513 OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), 514 errors::InvalidArgument("Invalid data format")); 515 } 516 }; 517 REGISTER_XLA_OP(Name("AvgPoolGrad").CompileTimeConstInput("orig_input_shape"), 518 AvgPool2DGradOp); 519 520 class AvgPool3DGradOp : public AvgPoolGradOp { 521 public: 522 explicit AvgPool3DGradOp(OpKernelConstruction* ctx) 523 : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {} 524 }; 525 REGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"), 526 AvgPool3DGradOp); 527 528 } // anonymous namespace 529 } // namespace 
tensorflow 530