1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA-specific Ops for 2D convolution. 17 18 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 19 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 20 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 21 #include "tensorflow/compiler/xla/literal_util.h" 22 #include "tensorflow/core/framework/numeric_op.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/framework/tensor_shape.h" 26 #include "tensorflow/core/framework/tensor_slice.h" 27 #include "tensorflow/core/kernels/bounds_check.h" 28 #include "tensorflow/core/kernels/conv_grad_ops.h" 29 #include "tensorflow/core/kernels/ops_util.h" 30 #include "tensorflow/core/util/padding.h" 31 #include "tensorflow/core/util/tensor_format.h" 32 33 namespace tensorflow { 34 35 namespace { 36 37 // Returns the expanded size of a filter used for depthwise convolution. 38 // If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N]. 39 TensorShape ExpandedFilterShapeForDepthwiseConvolution( 40 const TensorShape& shape) { 41 int num_dims = shape.dims(); 42 CHECK_GE(num_dims, 2); 43 TensorShape expanded_shape = shape; 44 expanded_shape.set_dim(num_dims - 1, shape.dim_size(num_dims - 2) * 45 shape.dim_size(num_dims - 1)); 46 return expanded_shape; 47 } 48 49 // Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. 50 xla::ComputationDataHandle CreateExpandedZero( 51 const TensorShape& filter_shape, DataType dtype, 52 xla::ComputationBuilder* builder) { 53 TensorShape expanded_filter_shape = 54 ExpandedFilterShapeForDepthwiseConvolution(filter_shape); 55 return builder->Broadcast(XlaHelpers::Zero(builder, dtype), 56 expanded_filter_shape.dim_sizes()); 57 } 58 59 // Create a mask for depthwise convolution that will make a normal convolution 60 // produce the same results as a depthwise convolution. For a [2, 2, 3, 2] 61 // depthwise filter this returns a [2, 2, 3, 6] tesnsor 62 // 1 1 0 0 0 0 1 1 0 0 0 0 63 // 0 0 1 1 0 0 0 0 1 1 0 0 64 // 0 0 0 0 1 1 0 0 0 0 1 1 65 // 66 // 1 1 0 0 0 0 1 1 0 0 0 0 67 // 0 0 1 1 0 0 0 0 1 1 0 0 68 // 0 0 0 0 1 1 0 0 0 0 1 1 69 // 70 // The first step is to create a one tensor, A, that is [3] 71 // 0 1 2 72 // 73 // and another tensor, B, that is [3 * 2] 74 // 0 1 2 3 4 5 75 // 76 // and divide B it by 2 to get 77 // 0 0 1 1 2 2 78 // 79 // then we broadcast the B to [2, 2, 3, 3 * 2] 80 // 0 0 1 1 2 2 0 0 1 1 2 2 81 // 0 0 1 1 2 2 0 0 1 1 2 2 82 // 0 0 1 1 2 2 0 0 1 1 2 2 83 // 84 // 0 0 1 1 2 2 0 0 1 1 2 2 85 // 0 0 1 1 2 2 0 0 1 1 2 2 86 // 0 0 1 1 2 2 0 0 1 1 2 2 87 // 88 // Finally compare A and broadcasted B in dimension 2 amd return the result at 89 // the beginning of the comment. 90 xla::ComputationDataHandle CreateExpandedFilterMask( 91 const TensorShape& filter_shape, xla::ComputationBuilder* builder) { 92 TensorShape expanded_filter_shape = 93 ExpandedFilterShapeForDepthwiseConvolution(filter_shape); 94 int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); 95 int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); 96 97 // Create a M sized linspace and an M*N sized linspace that will be 98 // broadcasted into perpendicular dimensions and compared. 99 xla::ComputationDataHandle input_feature_iota; 100 // DT_INT32 Iota will always return status::OK(). 101 TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, 102 &input_feature_iota)); 103 xla::ComputationDataHandle expanded_feature_iota; 104 TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, 105 input_feature * depthwise_multiplier, 106 &expanded_feature_iota)); 107 108 // Divide the M*N sized linspace by the depthwise_multiplier to create 109 // [0 0 1 1 2 2] in the example in the function comment. 110 expanded_feature_iota = 111 builder->Div(expanded_feature_iota, 112 XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, 113 depthwise_multiplier)); 114 115 // Broadcast the N*M linspace to [H, W, ..., M, M*N]. 116 auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); 117 expanded_feature_broadcast_dims.pop_back(); 118 auto broadcasted_expanded_feature_iota = builder->Broadcast( 119 expanded_feature_iota, expanded_feature_broadcast_dims); 120 121 // Compare the broadcasted linspace to the input feature linspace in the 122 // input feature dimension to create a diagonal predicate. 123 return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota, 124 {expanded_filter_shape.dims() - 2}); 125 } 126 127 // Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding 128 // zeros for the cross-depth filters. Used to build a depthwise convolution. 129 xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution( 130 const TensorShape& filter_shape, DataType dtype, 131 const xla::ComputationDataHandle& filter, 132 xla::ComputationBuilder* builder) { 133 int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); 134 int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); 135 TensorShape expanded_filter_shape = 136 ExpandedFilterShapeForDepthwiseConvolution(filter_shape); 137 138 // Create a [H, W, ..., 1, N*M] reshape of the filter. 139 TensorShape implicit_broadcast_filter_shape = expanded_filter_shape; 140 implicit_broadcast_filter_shape.set_dim( 141 implicit_broadcast_filter_shape.dims() - 2, 1); 142 implicit_broadcast_filter_shape.set_dim( 143 implicit_broadcast_filter_shape.dims() - 1, 144 depthwise_multiplier * input_feature); 145 auto implicit_broadcast_filter = 146 builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); 147 148 // Broadcast the filter to [H, W, ..., M, M*N]. 149 auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder); 150 auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero); 151 152 // If the filter mask is set, choose the broadcasted filter, othwerwise, 153 // choose zero. 154 return builder->Select(CreateExpandedFilterMask(filter_shape, builder), 155 expanded_filter, expanded_zero); 156 } 157 158 // Inverse of ExpandFilterForDepthwiseConvolution. 159 xla::ComputationDataHandle ContractFilterForDepthwiseBackprop( 160 XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype, 161 const xla::ComputationDataHandle& filter_backprop, 162 xla::ComputationBuilder* builder) { 163 TensorShape expanded_filter_shape = 164 ExpandedFilterShapeForDepthwiseConvolution(filter_shape); 165 auto masked_expanded_filter = builder->Select( 166 CreateExpandedFilterMask(filter_shape, builder), filter_backprop, 167 CreateExpandedZero(filter_shape, dtype, builder)); 168 return builder->Reshape( 169 builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), 170 *ctx->GetOrCreateAdd(dtype), 171 {expanded_filter_shape.dims() - 2}), 172 filter_shape.dim_sizes()); 173 } 174 175 class ConvOp : public XlaOpKernel { 176 public: 177 explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims, 178 bool depthwise) 179 : XlaOpKernel(ctx), 180 num_spatial_dims_(num_spatial_dims), 181 depthwise_(depthwise) { 182 OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); 183 OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); 184 OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); 185 186 string data_format; 187 OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); 188 OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), 189 errors::InvalidArgument("Invalid data format")); 190 } 191 192 int num_dims() const { return num_spatial_dims_ + 2; } 193 194 void Compile(XlaOpKernelContext* ctx) override { 195 OP_REQUIRES(ctx, strides_.size() == num_dims(), 196 errors::InvalidArgument("Sliding window strides field must " 197 "specify ", 198 num_dims(), " dimensions")); 199 int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); 200 int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); 201 OP_REQUIRES( 202 ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, 203 errors::Unimplemented("Current implementation does not yet support " 204 "strides in the batch and depth dimensions.")); 205 206 OP_REQUIRES(ctx, dilations_.size() == num_dims(), 207 errors::InvalidArgument("Dilations field must " 208 "specify ", 209 num_dims(), " dimensions")); 210 OP_REQUIRES( 211 ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, 212 errors::Unimplemented("Current implementation does not support " 213 "dilations in the batch and depth dimensions.")); 214 for (int i = 0; i < num_spatial_dims_; ++i) { 215 int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); 216 OP_REQUIRES(ctx, dilations_[input_dim] >= 1, 217 errors::Unimplemented("Dilation values must be positive; ", i, 218 "th spatial dimension had dilation ", 219 dilations_[input_dim])); 220 } 221 222 const TensorShape input_shape = ctx->InputShape(0); 223 // Input filter is of the following dimensions: 224 // [ filter_rows, filter_cols, ..., in_depth, out_depth] 225 const TensorShape filter_shape = ctx->InputShape(1); 226 227 // For 2D convolution, there should be 4 dimensions. 228 OP_REQUIRES( 229 ctx, input_shape.dims() == num_dims(), 230 errors::InvalidArgument("input must be ", num_dims(), "-dimensional", 231 input_shape.DebugString())); 232 OP_REQUIRES( 233 ctx, filter_shape.dims() == num_dims(), 234 errors::InvalidArgument("filter must be ", num_dims(), 235 "-dimensional: ", filter_shape.DebugString())); 236 237 // The last two dimension of the filter are the input and output shapes. 238 const int64 in_depth = filter_shape.dim_size(num_spatial_dims_); 239 240 // The 'C' dimension for input is in_depth. It must be the same as 241 // the filter's in_depth. 242 OP_REQUIRES(ctx, in_depth == input_shape.dim_size(feature_dim), 243 errors::InvalidArgument( 244 "input and filter must have the same depth: ", in_depth, 245 " vs ", input_shape.dim_size(feature_dim))); 246 247 xla::ComputationBuilder* b = ctx->builder(); 248 249 xla::ComputationDataHandle filter = ctx->Input(1); 250 TensorShape expanded_filter_shape = filter_shape; 251 if (depthwise_) { 252 filter = ExpandFilterForDepthwiseConvolution( 253 filter_shape, ctx->input_type(0), filter, b); 254 expanded_filter_shape = 255 ExpandedFilterShapeForDepthwiseConvolution(filter_shape); 256 } 257 258 xla::ConvolutionDimensionNumbers dims; 259 std::vector<int64> window_strides(num_spatial_dims_); 260 std::vector<int64> lhs_dilation(num_spatial_dims_, 1); 261 std::vector<int64> rhs_dilation(num_spatial_dims_); 262 std::vector<std::pair<int64, int64>> padding(num_spatial_dims_); 263 264 dims.set_input_batch_dimension(batch_dim); 265 dims.set_output_batch_dimension(batch_dim); 266 dims.set_input_feature_dimension(feature_dim); 267 dims.set_output_feature_dimension(feature_dim); 268 dims.set_kernel_input_feature_dimension(num_spatial_dims_); 269 dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1); 270 271 for (int i = 0; i < num_spatial_dims_; ++i) { 272 const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); 273 dims.add_input_spatial_dimensions(dim); 274 dims.add_kernel_spatial_dimensions(i); 275 dims.add_output_spatial_dimensions(dim); 276 window_strides[i] = strides_.at(dim); 277 rhs_dilation[i] = dilations_.at(dim); 278 279 int64 unused_output_size; 280 OP_REQUIRES_OK( 281 ctx, GetWindowedOutputSizeVerboseV2( 282 input_shape.dim_size(dim), expanded_filter_shape.dim_size(i), 283 rhs_dilation[i], window_strides[i], padding_, 284 &unused_output_size, &padding[i].first, &padding[i].second)); 285 } 286 287 xla::ComputationDataHandle conv = 288 b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, 289 lhs_dilation, rhs_dilation, dims); 290 ctx->SetOutput(0, conv); 291 } 292 293 protected: 294 const int num_spatial_dims_; 295 const bool depthwise_; 296 std::vector<int32> dilations_; 297 std::vector<int32> strides_; 298 Padding padding_; 299 TensorFormat data_format_ = FORMAT_NHWC; 300 301 private: 302 TF_DISALLOW_COPY_AND_ASSIGN(ConvOp); 303 }; 304 305 class Conv2DOp : public ConvOp { 306 public: 307 explicit Conv2DOp(OpKernelConstruction* ctx) 308 : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {} 309 }; 310 REGISTER_XLA_OP(Name("Conv2D"), Conv2DOp); 311 312 class Conv3DOp : public ConvOp { 313 public: 314 explicit Conv3DOp(OpKernelConstruction* ctx) 315 : ConvOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {} 316 }; 317 REGISTER_XLA_OP(Name("Conv3D"), Conv3DOp); 318 319 class DepthwiseConv2DOp : public ConvOp { 320 public: 321 explicit DepthwiseConv2DOp(OpKernelConstruction* ctx) 322 : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} 323 }; 324 REGISTER_XLA_OP(Name("DepthwiseConv2dNative"), DepthwiseConv2DOp); 325 326 // Backprop for input. 327 class ConvBackpropInputOp : public XlaOpKernel { 328 public: 329 explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims, 330 bool depthwise) 331 : XlaOpKernel(ctx), 332 num_spatial_dims_(num_spatial_dims), 333 depthwise_(depthwise) { 334 OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); 335 OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); 336 OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); 337 string data_format; 338 OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); 339 OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), 340 errors::InvalidArgument("Invalid data format")); 341 } 342 343 int num_dims() const { return num_spatial_dims_ + 2; } 344 345 void Compile(XlaOpKernelContext* ctx) override { 346 OP_REQUIRES(ctx, strides_.size() == num_dims(), 347 errors::InvalidArgument("Sliding window strides field must " 348 "specify ", 349 num_dims(), " dimensions")); 350 int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); 351 int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); 352 OP_REQUIRES( 353 ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, 354 errors::Unimplemented("Current implementation does not yet support " 355 "strides in the batch and depth dimensions.")); 356 357 OP_REQUIRES(ctx, dilations_.size() == num_dims(), 358 errors::InvalidArgument("Dilations field must " 359 "specify ", 360 num_dims(), " dimensions")); 361 OP_REQUIRES( 362 ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, 363 errors::Unimplemented("Current implementation does not support " 364 "dilations in the batch and depth dimensions.")); 365 for (int i = 0; i < num_spatial_dims_; ++i) { 366 int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); 367 OP_REQUIRES(ctx, dilations_[input_dim] >= 1, 368 errors::Unimplemented("Dilation values must be positive; ", i, 369 "th spatial dimension had dilation ", 370 dilations_[input_dim])); 371 } 372 373 TensorShape input_shape; 374 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); 375 376 const TensorShape filter_shape = ctx->InputShape(1); 377 const TensorShape out_backprop_shape = ctx->InputShape(2); 378 379 const TensorShape expanded_filter_shape = 380 depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) 381 : filter_shape; 382 // Reuse dimension computation logic from conv_grad_ops.cc. 383 ConvBackpropDimensions dims; 384 OP_REQUIRES_OK(ctx, 385 ConvBackpropComputeDimensionsV2( 386 type_string(), num_spatial_dims_, input_shape, 387 expanded_filter_shape, out_backprop_shape, dilations_, 388 strides_, padding_, data_format_, &dims)); 389 390 xla::ComputationBuilder* b = ctx->builder(); 391 auto filter = ctx->Input(1); 392 auto out_backprop = ctx->Input(2); 393 394 // The input gradients are computed by a convolution of the output 395 // gradients and the filter, with some appropriate padding. See the 396 // comment at the top of conv_grad_ops.h for details. 397 398 xla::ConvolutionDimensionNumbers dnums; 399 dnums.set_input_batch_dimension(batch_dim); 400 dnums.set_output_batch_dimension(batch_dim); 401 dnums.set_input_feature_dimension(feature_dim); 402 dnums.set_output_feature_dimension(feature_dim); 403 404 // TF filter shape is [ H, W, ..., inC, outC ] 405 // Transpose the input and output features for computing the gradient. 406 dnums.set_kernel_input_feature_dimension(num_spatial_dims_ + 1); 407 dnums.set_kernel_output_feature_dimension(num_spatial_dims_); 408 409 std::vector<int64> kernel_spatial_dims(num_spatial_dims_); 410 std::vector<std::pair<int64, int64>> padding(num_spatial_dims_); 411 std::vector<int64> lhs_dilation(num_spatial_dims_); 412 std::vector<int64> rhs_dilation(num_spatial_dims_); 413 std::vector<int64> ones(num_spatial_dims_, 1); 414 for (int i = 0; i < num_spatial_dims_; ++i) { 415 int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); 416 dnums.add_input_spatial_dimensions(dim); 417 dnums.add_kernel_spatial_dimensions(i); 418 dnums.add_output_spatial_dimensions(dim); 419 420 kernel_spatial_dims[i] = i; 421 padding[i] = {dims.spatial_dims[i].pad_before, 422 dims.spatial_dims[i].pad_after}; 423 lhs_dilation[i] = dims.spatial_dims[i].stride; 424 rhs_dilation[i] = dilations_[dim]; 425 } 426 427 // If this is a depthwise convolution, expand the filter. 428 if (depthwise_) { 429 filter = ExpandFilterForDepthwiseConvolution( 430 filter_shape, ctx->input_type(1), filter, b); 431 } 432 433 // Mirror the filter in the spatial dimensions. 434 xla::ComputationDataHandle mirrored_weights = 435 b->Rev(filter, kernel_spatial_dims); 436 437 // activation gradients 438 // = gradients (with padding and dilation) <conv> mirrored_weights 439 xla::ComputationDataHandle in_backprop = b->ConvGeneralDilated( 440 out_backprop, mirrored_weights, /*window_strides=*/ones, padding, 441 lhs_dilation, rhs_dilation, dnums); 442 443 ctx->SetOutput(0, in_backprop); 444 } 445 446 protected: 447 const int num_spatial_dims_; 448 const bool depthwise_; 449 std::vector<int32> dilations_; 450 std::vector<int32> strides_; 451 Padding padding_; 452 TensorFormat data_format_ = FORMAT_NHWC; 453 454 private: 455 TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp); 456 }; 457 458 class Conv2DBackpropInputOp : public ConvBackpropInputOp { 459 public: 460 explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx) 461 : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {} 462 }; 463 REGISTER_XLA_OP( 464 Name("Conv2DBackpropInput").CompileTimeConstInput("input_sizes"), 465 Conv2DBackpropInputOp); 466 467 class Conv3DBackpropInputOp : public ConvBackpropInputOp { 468 public: 469 explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx) 470 : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {} 471 }; 472 REGISTER_XLA_OP( 473 Name("Conv3DBackpropInputV2").CompileTimeConstInput("input_sizes"), 474 Conv3DBackpropInputOp); 475 476 class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp { 477 public: 478 explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx) 479 : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} 480 }; 481 REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput") 482 .CompileTimeConstInput("input_sizes"), 483 DepthwiseConv2DBackpropInputOp); 484 485 class ConvBackpropFilterOp : public XlaOpKernel { 486 public: 487 explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims, 488 bool depthwise) 489 : XlaOpKernel(ctx), 490 num_spatial_dims_(num_spatial_dims), 491 depthwise_(depthwise) { 492 OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); 493 OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); 494 OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); 495 string data_format; 496 OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); 497 OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), 498 errors::InvalidArgument("Invalid data format")); 499 } 500 501 int num_dims() const { return num_spatial_dims_ + 2; } 502 503 void Compile(XlaOpKernelContext* ctx) override { 504 const int n_dim = GetTensorBatchDimIndex(num_dims(), data_format_); 505 const int c_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); 506 507 OP_REQUIRES( 508 ctx, (strides_[n_dim] == 1 && strides_[c_dim] == 1), 509 errors::InvalidArgument("Current implementation does not yet support " 510 "strides in the batch and depth dimensions.")); 511 512 OP_REQUIRES(ctx, dilations_.size() == num_dims(), 513 errors::InvalidArgument("Dilations field must " 514 "specify ", 515 num_dims(), " dimensions")); 516 OP_REQUIRES( 517 ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1, 518 errors::Unimplemented("Current implementation does not support " 519 "dilations in the batch and depth dimensions.")); 520 for (int i = 0; i < num_spatial_dims_; ++i) { 521 int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); 522 OP_REQUIRES(ctx, dilations_[input_dim] >= 1, 523 errors::Unimplemented("Dilation values must be positive; ", i, 524 "th spatial dimension had dilation ", 525 dilations_[input_dim])); 526 } 527 528 const TensorShape activations_shape = ctx->InputShape(0); 529 TensorShape filter_shape; 530 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape)); 531 const TensorShape out_backprop_shape = ctx->InputShape(2); 532 533 const TensorShape expanded_filter_shape = 534 depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) 535 : filter_shape; 536 537 // Reuse dimension computation logic from conv_grad_ops.cc. 538 ConvBackpropDimensions dims; 539 OP_REQUIRES_OK(ctx, 540 ConvBackpropComputeDimensionsV2( 541 type_string(), num_spatial_dims_, activations_shape, 542 expanded_filter_shape, out_backprop_shape, dilations_, 543 strides_, padding_, data_format_, &dims)); 544 545 xla::ComputationBuilder* b = ctx->builder(); 546 xla::ComputationDataHandle activations = ctx->Input(0); 547 xla::ComputationDataHandle gradients = ctx->Input(2); 548 549 // The filter gradients are computed by a convolution of the input 550 // activations and the output gradients, with some appropriate padding. 551 // See the comment at the top of conv_grad_ops.h for details. 552 553 xla::ConvolutionDimensionNumbers dnums; 554 555 // The activations (inputs) form the LHS of the convolution. 556 // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] 557 // For the gradient computation, we flip the roles of the batch and 558 // feature dimensions. 559 // Each spatial entry has size in_depth * batch 560 561 // Swap n_dim and c_dim in the activations. 562 dnums.set_input_batch_dimension(c_dim); 563 dnums.set_input_feature_dimension(n_dim); 564 565 // The gradients become the RHS of the convolution. 566 // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] 567 // where the batch becomes the input feature for the convolution. 568 dnums.set_kernel_input_feature_dimension(n_dim); 569 dnums.set_kernel_output_feature_dimension(c_dim); 570 571 std::vector<std::pair<int64, int64>> padding(num_spatial_dims_); 572 std::vector<int64> rhs_dilation(num_spatial_dims_); 573 std::vector<int64> window_strides(num_spatial_dims_); 574 std::vector<int64> ones(num_spatial_dims_, 1); 575 576 // Tensorflow filter shape is [ H, W, ..., inC, outC ]. 577 for (int i = 0; i < num_spatial_dims_; ++i) { 578 dnums.add_output_spatial_dimensions(i); 579 } 580 dnums.set_output_batch_dimension(num_spatial_dims_); 581 dnums.set_output_feature_dimension(num_spatial_dims_ + 1); 582 583 for (int i = 0; i < num_spatial_dims_; ++i) { 584 int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); 585 dnums.add_input_spatial_dimensions(dim); 586 dnums.add_kernel_spatial_dimensions(dim); 587 588 // We will also need to pad the input with zeros such that after the 589 // convolution, we get the right size for the filter. 590 // The padded_in_rows should be such that when we convolve this with the 591 // expanded_out_rows as a filter, we should get filter_rows back. 592 // 593 const int64 padded_in_size = 594 dims.spatial_dims[i].expanded_output_size + 595 (dims.spatial_dims[i].filter_size - 1) * dilations_[dim]; 596 597 // However it can be smaller than input_rows: in this 598 // case it means some of the inputs are not used. 599 // 600 // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: 601 // 602 // INPUT = [ A B C ] 603 // 604 // FILTER = [ x y ] 605 // 606 // and the output will only have one column: a = A * x + B * y 607 // 608 // and input "C" is not used at all. 609 // 610 // We apply negative padding in this case. 611 const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; 612 613 // + For the VALID padding, we don't pad anything on the top/left side 614 // and pad the bottom/right side with the remaining space. 615 // + For the SAME padding, we pad top/left side the same as bottom/right 616 // side. 617 // 618 // In addition, if the padded input size is smaller than the input size, 619 // we need to ignore some training elements of the input. We do this by 620 // applying negative padding on the right/bottom. 621 const int64 pad_before = 622 padding_ == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0; 623 624 padding[i] = {pad_before, pad_total - pad_before}; 625 rhs_dilation[i] = dims.spatial_dims[i].stride; 626 window_strides[i] = dilations_[dim]; 627 } 628 629 // Besides padding the input, we will also expand output_rows to 630 // expanded_out_rows = (output_rows - 1) * stride + 1 631 // with zeros in between: 632 // 633 // a . . . b . . . c . . . d . . . e 634 // 635 // This is done by specifying the window dilation factors in the 636 // convolution HLO below. 637 auto filter_backprop = 638 b->ConvGeneralDilated(activations, gradients, window_strides, padding, 639 /*lhs_dilation=*/ones, rhs_dilation, dnums); 640 641 if (depthwise_) { 642 filter_backprop = ContractFilterForDepthwiseBackprop( 643 ctx, filter_shape, ctx->input_type(0), filter_backprop, b); 644 } 645 ctx->SetOutput(0, filter_backprop); 646 } 647 648 protected: 649 const int num_spatial_dims_; 650 const bool depthwise_; 651 std::vector<int32> dilations_; 652 std::vector<int32> strides_; 653 Padding padding_; 654 TensorFormat data_format_ = FORMAT_NHWC; 655 656 private: 657 TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp); 658 }; 659 660 class Conv2DBackpropFilterOp : public ConvBackpropFilterOp { 661 public: 662 explicit Conv2DBackpropFilterOp(OpKernelConstruction* ctx) 663 : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) { 664 } 665 }; 666 REGISTER_XLA_OP( 667 Name("Conv2DBackpropFilter").CompileTimeConstInput("filter_sizes"), 668 Conv2DBackpropFilterOp); 669 670 class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { 671 public: 672 explicit Conv3DBackpropFilterOp(OpKernelConstruction* ctx) 673 : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) { 674 } 675 }; 676 REGISTER_XLA_OP( 677 Name("Conv3DBackpropFilterV2").CompileTimeConstInput("filter_sizes"), 678 Conv3DBackpropFilterOp); 679 680 class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp { 681 public: 682 explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx) 683 : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} 684 }; 685 REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter") 686 .CompileTimeConstInput("filter_sizes"), 687 DepthwiseConv2DBackpropFilterOp); 688 689 } // namespace 690 } // namespace tensorflow 691