/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/conv_3d.h"

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
using stream_executor::dnn::DimIndex;
#endif

namespace {

// TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and
// conv_grad_input_ops_3d.cc.

// TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels.

// "Depth" is already used for the channel dimension, so for the third spatial
// dimension in this file we use "plane", although in NDHWC layout it's
// indicated with a "D".

// Returns in 'im_data' (assumed to be zero-initialized) the image, in storage
// order (planes, height, width, depth), reconstructed from patches in
// 'col_data', which is required to be in storage order (out_planes *
// out_height * out_width, filter_planes, filter_height, filter_width,
// in_depth).
//
// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
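//
// The patch grid walked below has, along each spatial dimension, an extent of
// (size + pad_before + pad_after - filter_size) / stride + 1, i.e. the usual
// convolution output-size formula (see planes_col / height_col / width_col).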
template <typename T>
void Col2im(const T* col_data, const int depth, const int planes,
            const int height, const int width, const int filter_p,
            const int filter_h, const int filter_w, const int pad_pt,
            const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
            const int pad_r, const int stride_p, const int stride_h,
            const int stride_w, T* im_data) {
  const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
  const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
  int p_pad = -pad_pt;
  for (int p = 0; p < planes_col; ++p) {
    int h_pad = -pad_t;
    for (int h = 0; h < height_col; ++h) {
      int w_pad = -pad_l;
      for (int w = 0; w < width_col; ++w) {
        T* im_patch_data =
            im_data + (p_pad * height * width + h_pad * width + w_pad) * depth;
        for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
          for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
            for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
              if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
                  iw < width) {
                for (int i = 0; i < depth; ++i) {
                  im_patch_data[i] += col_data[i];
                }
              }
              im_patch_data += depth;
              col_data += depth;
            }
            // Jump over the remaining depth values in this row.
            im_patch_data += depth * (width - filter_w);
          }
          // Jump over the remaining (depth * width) values in this plane.
          im_patch_data += (depth * width) * (height - filter_h);
        }
        w_pad += stride_w;
      }
      h_pad += stride_h;
    }
    p_pad += stride_p;
  }
}

// Returns in 'col_data', image patches in storage order (planes, height, width,
// depth) extracted from image at 'input_data', which is required to be in
// storage order (batch, planes, height, width, depth).
//
// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
template <typename T>
void Im2col(const T* input_data, const int depth, const int planes,
            const int height, const int width, const int filter_p,
            const int filter_h, const int filter_w, const int pad_pt,
            const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
            const int pad_r, const int stride_p, const int stride_h,
            const int stride_w, T* col_data) {
  const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
  const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;

  int p_pad = -pad_pt;
  for (int p = 0; p < planes_col; ++p) {
    int h_pad = -pad_t;
    for (int h = 0; h < height_col; ++h) {
      int w_pad = -pad_l;
      for (int w = 0; w < width_col; ++w) {
        for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
          for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
            for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
              if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
                  iw < width) {
                memcpy(col_data,
                       input_data +
                           (ip * height * width + ih * width + iw) * depth,
                       sizeof(T) * depth);
              } else {
                // This should be simply padded with zero.
                memset(col_data, 0, sizeof(T) * depth);
              }
              col_data += depth;
            }
          }
        }
        w_pad += stride_w;
      }
      h_pad += stride_h;
    }
    p_pad += stride_p;
  }
}

}  // namespace

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Backprop for input that offloads computation to
// Eigen::CuboidConvolutionBackwardInput.
template <typename Device, class T>
class Conv3DBackpropInputOp : public OpKernel {
 public:
  explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'N') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));

    // TODO(yangzihao): Add CPU version of dilated conv 3D.
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, '0') == 1 &&
                 GetTensorDim(dilation_, data_format_, '1') == 1 &&
                 GetTensorDim(dilation_, data_format_, '2') == 1),
                errors::InvalidArgument(
                    "Current CPU implementation does not yet support "
                    "dilation rates larger than 1."));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& filter = context->input(1);
    const TensorShape& filter_shape = filter.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape input_shape;
    if (takes_shape_) {
      const Tensor& input_sizes = context->input(0);
      // MakeShape is able to handle both DT_INT32 and DT_INT64 for
      // input_sizes.
      OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
    } else {
      input_shape = context->input(0).shape();
    }

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
                                "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
                                input_shape, filter_shape, out_backprop_shape,
                                stride_, padding_, data_format_, &dims));

    Tensor* in_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    functor::CuboidConvolutionBackwardInput<Device, T>()(
        context->eigen_device<Device>(),
        in_backprop->tensor<T, 5>(),                     // input_backward
        filter.tensor<T, 5>(),                           // filter
        out_backprop.tensor<T, 5>(),                     // output_backward
        static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
        static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
        static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp);
};

// Custom backprop for input that explicitly does the work sharding and calls
// Eigen only to multiply matrices.
template <typename Device, class T>
class Conv3DCustomBackpropInputOp : public OpKernel {
  // Limit the maximum size of allocated temporary buffer to
  // kMaxTempAllocationOverhead times the size of the input tensors (input,
  // filter, out_backprop). If the size of the temporary buffer exceeds this
  // limit, fallback on Eigen implementation.
  static constexpr int kMaxTempAllocationOverhead = 25;

 public:
  explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'N') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));

    // TODO(yangzihao): Add CPU version of dilated conv 3D.
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, '0') == 1 &&
                 GetTensorDim(dilation_, data_format_, '1') == 1 &&
                 GetTensorDim(dilation_, data_format_, '2') == 1),
                errors::InvalidArgument(
                    "Current CPU implementation does not yet support "
                    "dilation rates larger than 1."));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& filter = context->input(1);
    const TensorShape& filter_shape = filter.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape input_shape;
    if (takes_shape_) {
      const Tensor& input_sizes = context->input(0);
      // MakeShape is able to handle both DT_INT32 and DT_INT64 for
      // input_sizes.
      OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
    } else {
      input_shape = context->input(0).shape();
    }

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
                                "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
                                input_shape, filter_shape, out_backprop_shape,
                                stride_, padding_, data_format_, &dims));

    Tensor* in_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    int64 top_pad_planes, bottom_pad_planes;
    int64 top_pad_rows, bottom_pad_rows;
    int64 left_pad_cols, right_pad_cols;

    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[0].input_size,
                                dims.spatial_dims[0].filter_size,
                                dims.spatial_dims[0].stride, padding_,
                                &dims.spatial_dims[0].output_size,
                                &top_pad_planes, &bottom_pad_planes));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[1].input_size,
                                dims.spatial_dims[1].filter_size,
                                dims.spatial_dims[1].stride, padding_,
                                &dims.spatial_dims[1].output_size,
                                &top_pad_rows, &bottom_pad_rows));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[2].input_size,
                                dims.spatial_dims[2].filter_size,
                                dims.spatial_dims[2].stride, padding_,
                                &dims.spatial_dims[2].output_size,
                                &left_pad_cols, &right_pad_cols));

    // TODO(ezhulenev): Extract work size and shard estimation to shared
    // functions in conv_grad_ops, and update 2d convolution backprop.

    // The total dimension size of each kernel.
    const int64 filter_total_size =
        dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
        dims.spatial_dims[2].filter_size * dims.in_depth;

    // The output image size is the spatial size of the output.
    const int64 output_image_size = dims.spatial_dims[0].output_size *
                                    dims.spatial_dims[1].output_size *
                                    dims.spatial_dims[2].output_size;

    const auto cache_sizes = Eigen::internal::CacheSizes();
    const ptrdiff_t l3_cache_size = cache_sizes.m_l3;

    // Use L3 cache size as target working set size.
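    // Images in the batch are processed in groups of 'shard_size' (computed
    // below as the number of per-image work units that fit into this working
    // set), so the matrices fed to the matmul stay cache resident; the filter
    // backprop kernel below uses the same sharding strategy.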
    const size_t target_working_set_size = l3_cache_size / sizeof(T);

    // Calculate size of matrices involved in MatMul: C = A x B.
    const int64 size_A = output_image_size * dims.out_depth;

    const int64 size_B = filter_total_size * dims.out_depth;

    const int64 size_C = output_image_size * filter_total_size;

    const int64 work_unit_size = size_A + size_B + size_C;

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    // Use parallel tensor contractions if there is no batching.
    //
    // Compared to the Conv2D code, this version is missing the work size
    // estimation. In benchmarks I didn't find a case where it's beneficial to
    // run a parallel contraction compared to sharding and matmuls.
    const bool use_parallel_contraction = dims.batch_size == 1;

    const size_t shard_size =
        use_parallel_contraction
            ? 1
            : (target_working_set_size + work_unit_size - 1) / work_unit_size;

    // Total number of elements in all the tensors used by this kernel.
    int64 total_tensor_elements = input_shape.num_elements() +
                                  filter_shape.num_elements() +
                                  out_backprop_shape.num_elements();

    // Shape of the temporary workspace buffer.
    TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
                                    static_cast<int64>(output_image_size),
                                    static_cast<int64>(filter_total_size)};
    int64 col_buffer_elements = col_buffer_shape.num_elements();

    // If the temporary allocation overhead is too large, fallback on Eigen
    // implementation which requires much less memory.
    int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
    if (col_buffer_overhead > kMaxTempAllocationOverhead) {
      VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: "
                 "col_buffer_overhead="
              << col_buffer_overhead;

      functor::CuboidConvolutionBackwardInput<Device, T>()(
          context->eigen_device<Device>(),
          in_backprop->tensor<T, 5>(),                     // input_backward
          filter.tensor<T, 5>(),                           // filter
          out_backprop.tensor<T, 5>(),                     // output_backward
          static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
          static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
          static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols

      return;
    }

    Tensor col_buffer;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::value,
                                          col_buffer_shape, &col_buffer));

    // The input offset corresponding to a single input image.
    const int64 input_offset = dims.spatial_dims[0].input_size *
                               dims.spatial_dims[1].input_size *
                               dims.spatial_dims[2].input_size * dims.in_depth;

    // The output offset corresponding to a single output image.
    const int64 output_offset =
        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
        dims.spatial_dims[2].output_size * dims.out_depth;

    const T* filter_data = filter.template flat<T>().data();
    T* col_buffer_data = col_buffer.template flat<T>().data();
    const T* out_backprop_data = out_backprop.template flat<T>().data();

    auto in_backprop_flat = in_backprop->template flat<T>();
    T* input_backprop_data = in_backprop_flat.data();
    in_backprop_flat.device(context->eigen_device<Device>()) =
        in_backprop_flat.constant(T(0));

    if (use_parallel_contraction) {
      typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          TensorMap;
      typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          ConstTensorMap;

      // Initialize contraction dims (we need to transpose 'B' below).
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
      contract_dims[0].first = 1;
      contract_dims[0].second = 1;

      for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
        // Compute gradient into col_buffer.
        TensorMap C(col_buffer_data, output_image_size, filter_total_size);

        ConstTensorMap A(out_backprop_data + output_offset * image_id,
                         output_image_size, dims.out_depth);
        ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);

        C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);

        Col2im<T>(col_buffer_data, dims.in_depth,
                  // Input spatial dimensions.
                  dims.spatial_dims[0].input_size,  // input planes
                  dims.spatial_dims[1].input_size,  // input rows
                  dims.spatial_dims[2].input_size,  // input cols
                  // Filter spatial dimensions.
                  dims.spatial_dims[0].filter_size,  // filter planes
                  dims.spatial_dims[1].filter_size,  // filter rows
                  dims.spatial_dims[2].filter_size,  // filter cols
                  // Spatial padding.
                  top_pad_planes, top_pad_rows, left_pad_cols,
                  bottom_pad_planes, bottom_pad_rows, right_pad_cols,
                  // Spatial striding.
                  dims.spatial_dims[0].stride,  // stride planes
                  dims.spatial_dims[1].stride,  // stride rows
                  dims.spatial_dims[2].stride,  // stride cols
                  input_backprop_data);

        input_backprop_data += input_offset;
      }
    } else {
      typedef Eigen::Map<
          Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
          MatrixMap;
      typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
                                             Eigen::RowMajor>>
          ConstMatrixMap;

      for (int image_id = 0; image_id < dims.batch_size;
           image_id += shard_size) {
        const int shard_limit =
            std::min(static_cast<int>(shard_size),
                     static_cast<int>(dims.batch_size) - image_id);

        auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols,
                      &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols,
                      &output_image_size, &filter_total_size,
                      &input_backprop_data, &col_buffer_data,
                      &out_backprop_data, &filter_data, &input_offset,
                      &output_offset, &size_C](int64 start, int64 limit) {
          for (int shard_id = start; shard_id < limit; ++shard_id) {
            T* im2col_buf = col_buffer_data + shard_id * size_C;
            T* input_data = input_backprop_data + shard_id * input_offset;
            const T* out_data = out_backprop_data + shard_id * output_offset;

            // Compute gradient into 'im2col_buf'.
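            //
            // This is the same GEMM as in the parallel-contraction branch,
            // written as C = A * B^T: A is the out_backprop slice
            // [output_image_size x out_depth], B is the filter
            // [filter_total_size x out_depth], and C is the column buffer
            // [output_image_size x filter_total_size] that Col2im scatters
            // back into the input gradient.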
            MatrixMap C(im2col_buf, output_image_size, filter_total_size);

            ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
            ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);

            C.noalias() = A * B.transpose();

            Col2im<T>(im2col_buf, dims.in_depth,
                      // Input spatial dimensions.
                      dims.spatial_dims[0].input_size,  // input planes
                      dims.spatial_dims[1].input_size,  // input rows
                      dims.spatial_dims[2].input_size,  // input cols
                      // Filter spatial dimensions.
                      dims.spatial_dims[0].filter_size,  // filter planes
                      dims.spatial_dims[1].filter_size,  // filter rows
                      dims.spatial_dims[2].filter_size,  // filter cols
                      // Spatial padding.
                      top_pad_planes, top_pad_rows, left_pad_cols,
                      bottom_pad_planes, bottom_pad_rows, right_pad_cols,
                      // Spatial striding.
                      dims.spatial_dims[0].stride,  // stride planes
                      dims.spatial_dims[1].stride,  // stride rows
                      dims.spatial_dims[2].stride,  // stride cols
                      input_data);
          }
        };
        Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
              work_unit_size, shard);

        input_backprop_data += input_offset * shard_limit;
        out_backprop_data += output_offset * shard_limit;
      }
    }
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp);
};

// Custom backprop input kernel is 30% - 4x faster when compiled with AVX2 than
// default Eigen implementation (at the cost of ~2x-8x peak memory usage).

#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      Conv3DCustomBackpropInputOp<CPUDevice, T>);                              \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DCustomBackpropInputOp<CPUDevice, T>);                              \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput")                          \
                              .Device(DEVICE_CPU)                              \
                              .Label("custom")                                 \
                              .TypeConstraint<T>("T"),                         \
                          Conv3DCustomBackpropInputOp<CPUDevice, T>);          \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                        \
                              .Device(DEVICE_CPU)                              \
                              .Label("custom")                                 \
                              .TypeConstraint<T>("T"),                         \
                          Conv3DCustomBackpropInputOp<CPUDevice, T>);          \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput")                          \
                              .Device(DEVICE_CPU)                              \
                              .Label("eigen_tensor")                           \
                              .TypeConstraint<T>("T"),                         \
                          Conv3DBackpropInputOp<CPUDevice, T>);                \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                        \
                              .Device(DEVICE_CPU)                              \
                              .Label("eigen_tensor")                           \
                              .TypeConstraint<T>("T"),                         \
                          Conv3DBackpropInputOp<CPUDevice, T>);

TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

// Backprop for filter that offloads computation to
// Eigen::CuboidConvolutionBackwardFilter.
template <typename Device, class T>
class Conv3DBackpropFilterOp : public OpKernel {
 public:
  explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'N') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));

    // TODO(yangzihao): Add CPU version of dilated conv 3D.
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, '0') == 1 &&
                 GetTensorDim(dilation_, data_format_, '1') == 1 &&
                 GetTensorDim(dilation_, data_format_, '2') == 1),
                errors::InvalidArgument(
                    "Current CPU implementation does not yet support "
                    "dilation rates larger than 1."));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape filter_shape;
    if (takes_shape_) {
      const Tensor& filter_sizes = context->input(1);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  filter_sizes.vec<int32>(), &filter_shape));
    } else {
      filter_shape = context->input(1).shape();
    }

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensions(
                       "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
                       input_shape, filter_shape, out_backprop_shape, stride_,
                       padding_, data_format_, &dims));

    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    if (input_shape.num_elements() == 0) {
      filter_backprop->template flat<T>().setZero();
      return;
    }

    functor::CuboidConvolutionBackwardFilter<Device, T>()(
        context->eigen_device<Device>(),
        filter_backprop->tensor<T, 5>(),                 // filter_backward
        input.tensor<T, 5>(),                            // input
        out_backprop.tensor<T, 5>(),                     // output_backward
        static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
        static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
        static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp);
};

// Custom backprop for filter that explicitly does the work sharding and calls
// Eigen only to multiply matrices.
template <typename Device, class T>
class Conv3DCustomBackpropFilterOp : public OpKernel {
  // Limit the maximum size of allocated temporary buffer to
  // kMaxTempAllocationOverhead times the size of the input tensors (input,
  // filter, out_backprop). If the size of the temporary buffer exceeds this
  // limit, fallback on Eigen implementation.
  static constexpr int kMaxTempAllocationOverhead = 25;

 public:
  explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'N') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));

    // TODO(yangzihao): Add CPU version of dilated conv 3D.
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, '0') == 1 &&
                 GetTensorDim(dilation_, data_format_, '1') == 1 &&
                 GetTensorDim(dilation_, data_format_, '2') == 1),
                errors::InvalidArgument(
                    "Current CPU implementation does not yet support "
                    "dilation rates larger than 1."));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape filter_shape;
    if (takes_shape_) {
      const Tensor& filter_sizes = context->input(1);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  filter_sizes.vec<int32>(), &filter_shape));
    } else {
      filter_shape = context->input(1).shape();
    }

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensions(
                       "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
                       input_shape, filter_shape, out_backprop_shape, stride_,
                       padding_, data_format_, &dims));

    Tensor* filter_backprop;
    OP_REQUIRES_OK(context, context->allocate_output(0, filter_shape,
                                                     &filter_backprop));

    if (input_shape.num_elements() == 0) {
      filter_backprop->template flat<T>().setZero();
      return;
    }

    int64 top_pad_planes, bottom_pad_planes;
    int64 top_pad_rows, bottom_pad_rows;
    int64 left_pad_cols, right_pad_cols;

    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[0].input_size,
                                dims.spatial_dims[0].filter_size,
                                dims.spatial_dims[0].stride, padding_,
                                &dims.spatial_dims[0].output_size,
                                &top_pad_planes, &bottom_pad_planes));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[1].input_size,
                                dims.spatial_dims[1].filter_size,
                                dims.spatial_dims[1].stride, padding_,
                                &dims.spatial_dims[1].output_size,
                                &top_pad_rows, &bottom_pad_rows));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                dims.spatial_dims[2].input_size,
                                dims.spatial_dims[2].filter_size,
                                dims.spatial_dims[2].stride, padding_,
                                &dims.spatial_dims[2].output_size,
                                &left_pad_cols, &right_pad_cols));

    // TODO(ezhulenev): Extract work size and shard estimation to shared
    // functions in conv_grad_ops, and update 2d convolution backprop.

    // The total dimension size of each kernel.
    const int64 filter_total_size =
        dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
        dims.spatial_dims[2].filter_size * dims.in_depth;
    // The output image size is the spatial size of the output.
    const int64 output_image_size = dims.spatial_dims[0].output_size *
                                    dims.spatial_dims[1].output_size *
                                    dims.spatial_dims[2].output_size;

    // Shard 'batch' images (volumes) into 'shard_size' groups of images
    // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by
    // dividing the L3 cache size ('target_working_set_size') by the matmul
    // size of an individual image ('work_unit_size').

    const auto cache_sizes = Eigen::internal::CacheSizes();
    const ptrdiff_t l3_cache_size = cache_sizes.m_l3;

    // TODO(andydavis)
    // *) Consider reducing 'target_working_set_size' if L3 is shared by
    //    other concurrently running tensorflow ops.
    const size_t target_working_set_size = l3_cache_size / sizeof(T);

    const int64 size_A = output_image_size * filter_total_size;

    const int64 size_B = output_image_size * dims.out_depth;

    const int64 size_C = filter_total_size * dims.out_depth;

    const int64 work_unit_size = size_A + size_B + size_C;

    const size_t shard_size =
        (target_working_set_size + work_unit_size - 1) / work_unit_size;

    // Total number of elements in all the tensors used by this kernel.
    int64 total_tensor_elements = input_shape.num_elements() +
                                  filter_shape.num_elements() +
                                  out_backprop_shape.num_elements();

    // Shape of the temporary workspace buffer.
    TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
                                    static_cast<int64>(output_image_size),
                                    static_cast<int64>(filter_total_size)};
    int64 col_buffer_elements = col_buffer_shape.num_elements();

    // If the temporary allocation overhead is too large, fallback on Eigen
    // implementation which requires much less memory.
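    //
    // The overhead is measured as the ratio of col_buffer elements to the
    // combined number of elements in the input, filter, and out_backprop
    // tensors, and is compared against kMaxTempAllocationOverhead above.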
    int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
    if (col_buffer_overhead > kMaxTempAllocationOverhead) {
      VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: "
                 "col_buffer_overhead="
              << col_buffer_overhead;

      functor::CuboidConvolutionBackwardFilter<Device, T>()(
          context->eigen_device<Device>(),
          filter_backprop->tensor<T, 5>(),                 // filter_backward
          input.tensor<T, 5>(),                            // input
          out_backprop.tensor<T, 5>(),                     // output_backward
          static_cast<int>(dims.spatial_dims[0].stride),   // stride_planes
          static_cast<int>(dims.spatial_dims[1].stride),   // stride_rows
          static_cast<int>(dims.spatial_dims[2].stride));  // stride_cols

      return;
    }

    Tensor col_buffer;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::value,
                                          col_buffer_shape, &col_buffer));

    // The input offset corresponding to a single input image.
    const int64 input_offset = dims.spatial_dims[0].input_size *
                               dims.spatial_dims[1].input_size *
                               dims.spatial_dims[2].input_size * dims.in_depth;
    // The output offset corresponding to a single output image.
    const int64 output_offset =
        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
        dims.spatial_dims[2].output_size * dims.out_depth;

    const T* input_data = input.template flat<T>().data();
    T* col_buffer_data = col_buffer.template flat<T>().data();
    const T* out_backprop_data = out_backprop.template flat<T>().data();
    T* filter_backprop_data = filter_backprop->template flat<T>().data();

    typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                             Eigen::Unaligned>
        TensorMap;
    typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                             Eigen::Unaligned>
        ConstTensorMap;

    TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
    C.setZero();

    // Initialize contraction dims (we need to transpose 'A' below).
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
    contract_dims[0].first = 0;
    contract_dims[0].second = 0;

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
      const int shard_limit =
          std::min(static_cast<int>(shard_size),
                   static_cast<int>(dims.batch_size) - image_id);

      auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes,
                    &top_pad_rows, &left_pad_cols, &bottom_pad_planes,
                    &bottom_pad_rows, &right_pad_cols, &input_offset,
                    &size_A](int64 start, int64 limit) {
        for (int shard_id = start; shard_id < limit; ++shard_id) {
          const T* input_data_shard = input_data + shard_id * input_offset;
          T* col_data_shard = col_buffer_data + shard_id * size_A;

          // When we compute the gradient with respect to the filters, we need
          // to do im2col to allow gemm-type computation.
          Im2col<T>(input_data_shard, dims.in_depth,
                    // Input spatial dimensions.
                    dims.spatial_dims[0].input_size,  // input planes
                    dims.spatial_dims[1].input_size,  // input rows
                    dims.spatial_dims[2].input_size,  // input cols
                    // Filter spatial dimensions.
                    dims.spatial_dims[0].filter_size,  // filter planes
                    dims.spatial_dims[1].filter_size,  // filter rows
                    dims.spatial_dims[2].filter_size,  // filter cols
                    // Spatial padding.
                    top_pad_planes, top_pad_rows, left_pad_cols,
                    bottom_pad_planes, bottom_pad_rows, right_pad_cols,
                    // Spatial striding.
                    dims.spatial_dims[0].stride,  // stride planes
                    dims.spatial_dims[1].stride,  // stride rows
                    dims.spatial_dims[2].stride,  // stride cols
                    col_data_shard);
        }
      };
      Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
            size_A, shard);

      ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
                       filter_total_size);
      ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
                       dims.out_depth);

      // Gradient with respect to filter.
      C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);

      input_data += input_offset * shard_limit;
      out_backprop_data += output_offset * shard_limit;
    }
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp);
};

// Custom backprop filter kernel is 30% - 4x faster when compiled with AVX2 than
// default Eigen implementation (at the cost of ~2x-8x peak memory usage).

#define REGISTER_CPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DCustomBackpropFilterOp<CPUDevice, T>);                            \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter")                        \
                              .Device(DEVICE_CPU)                             \
                              .Label("custom")                                \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_CPU)                             \
                              .Label("custom")                                \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DCustomBackpropFilterOp<CPUDevice, T>);        \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter")                        \
                              .Device(DEVICE_CPU)                             \
                              .Label("eigen_tensor")                          \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DBackpropFilterOp<CPUDevice, T>);              \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_CPU)                             \
                              .Label("eigen_tensor")                          \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DBackpropFilterOp<CPUDevice, T>);

TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

// WARNING: Eigen::half is not trivially copyable and can't be used in
// custom backprop filter kernel because of memcpy and memset in Im2col.
#define REGISTER_CPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DBackpropFilterOp<CPUDevice, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DBackpropFilterOp<CPUDevice, T>);

TF_CALL_half(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

// GPU definitions of both ops.
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
// This ensures that the custom implementation is used instead of the default
// Eigen one (which is used for CPU).
namespace functor {
#define DECLARE_GPU_SPEC(T)                                           \
  template <>                                                         \
  void TransformFilter<GPUDevice, T, int, 5>::operator()(             \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,       \
      typename TTypes<T, 5, int>::ConstTensor in,                     \
      typename TTypes<T, 5, int>::Tensor out);                        \
  template <>                                                         \
  void ReverseTransformFilter<GPUDevice, T, 5>::operator()(           \
      const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in,      \
      typename TTypes<T, 5>::Tensor out);                             \
  template <>                                                         \
  void PadInput<GPUDevice, T, int, 5>::operator()(                    \
      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
      const std::array<int, 3>& padding_left,                         \
      const std::array<int, 3>& padding_right,                        \
      typename TTypes<T, 5, int>::Tensor out, TensorFormat format);

DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC
}  // namespace functor

// A dummy type to group backward data autotune results together.
struct Conv3dBackwardDataAutoTuneGroup {
  static string name() { return "Conv3dBwdData"; }
};

typedef AutoTuneSingleton<Conv3dBackwardDataAutoTuneGroup, ConvParameters,
                          se::dnn::AlgorithmConfig>
    AutoTuneConv3dBwdData;

template <typename T>
class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
 public:
  explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    }
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'N') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));
    OP_REQUIRES(
        context,
        (GetTensorDim(dilation_, data_format_, '0') > 0 &&
         GetTensorDim(dilation_, data_format_, '1') > 0 &&
         GetTensorDim(dilation_, data_format_, '2') > 0),
        errors::InvalidArgument("Dilated rates should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, '0') > 0 &&
         GetTensorDim(stride_, data_format_, '1') > 0 &&
         GetTensorDim(stride_, data_format_, '2') > 0),
        errors::InvalidArgument("Spatial strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& filter = context->input(1);
    const TensorShape& filter_shape = filter.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape input_shape;
    if (takes_shape_) {
      const Tensor& input_sizes = context->input(0);
      OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
    } else {
      input_shape = context->input(0).shape();
    }

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context, ConvBackpropComputeDimensionsV2(
                                "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
                                input_shape, filter_shape, out_backprop_shape,
                                dilation_, stride_, padding_,
                                /*explicit_paddings=*/{}, data_format_, &dims));

    Tensor* in_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 &&
        dims.filter_size(2) == 1 && dims.dilation(0) == 1 &&
        dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 &&
        dims.stride(1) == 1 && dims.stride(2) == 1 &&
        data_format_ == FORMAT_NHWC) {
      const uint64 m = dims.batch_size * dims.input_size(0) *
                       dims.input_size(1) * dims.input_size(2);
      const uint64 k = dims.out_depth;
      const uint64 n = dims.in_depth;

      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                  in_backprop->template flat<T>().size());

      auto transpose = se::blas::Transpose::kTranspose;
      auto no_transpose = se::blas::Transpose::kNoTranspose;

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                             a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    } else if (dims.filter_size(0) == dims.input_size(0) &&
               dims.filter_size(1) == dims.input_size(1) &&
               dims.filter_size(2) == dims.input_size(2) &&
               padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
      const uint64 m = dims.batch_size;
      const uint64 k = dims.out_depth;
      const uint64 n = dims.input_size(0) * dims.input_size(1) *
                       dims.input_size(2) * dims.in_depth;

      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                  in_backprop->template flat<T>().size());

      auto transpose = se::blas::Transpose::kTranspose;
      auto no_transpose = se::blas::Transpose::kNoTranspose;

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                             a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    }
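
    // Neither BLAS fast path applied (the filter is not 1x1x1 with unit
    // strides and dilations, and it does not cover the entire input with
    // VALID padding), so fall through to the general cuDNN
    // convolution-backward-data path below.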

    int padding_planes = dims.SpatialPadding(padding_, 0);
    int padding_rows = dims.SpatialPadding(padding_, 1);
    int padding_cols = dims.SpatialPadding(padding_, 2);
    const bool planes_odd = (padding_planes % 2 != 0);
    const bool rows_odd = (padding_rows % 2 != 0);
    const bool cols_odd = (padding_cols % 2 != 0);

    TensorShape compatible_input_shape;
    if (rows_odd || cols_odd || planes_odd) {
      // cuDNN only supports the same amount of padding on both sides.
      compatible_input_shape = {
          dims.batch_size,
          dims.in_depth,
          dims.input_size(0) + planes_odd,
          dims.input_size(1) + rows_odd,
          dims.input_size(2) + cols_odd,
      };
    } else {
      compatible_input_shape = {dims.batch_size, dims.in_depth,
                                dims.input_size(0), dims.input_size(1),
                                dims.input_size(2)};
    }

    CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
        << "Negative paddings: (" << padding_rows << ", " << padding_cols
        << ", " << padding_planes << ")";
    se::dnn::BatchDescriptor input_desc(3);
    input_desc.set_count(dims.batch_size)
        .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
        .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
        .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
        .set_feature_map_count(dims.in_depth)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    se::dnn::BatchDescriptor output_desc(3);
    output_desc.set_count(dims.batch_size)
        .set_spatial_dim(DimIndex::X, dims.output_size(2))
        .set_spatial_dim(DimIndex::Y, dims.output_size(1))
        .set_spatial_dim(DimIndex::Z, dims.output_size(0))
        .set_feature_map_count(dims.out_depth)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    se::dnn::FilterDescriptor filter_desc(3);
    filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
        .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
        .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
        .set_input_feature_map_count(dims.in_depth)
        .set_output_feature_map_count(dims.out_depth);
    se::dnn::ConvolutionDescriptor conv_desc(3);
    conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
        .set_dilation_rate(DimIndex::Y, dims.dilation(1))
        .set_dilation_rate(DimIndex::Z, dims.dilation(0))
        .set_filter_stride(DimIndex::X, dims.stride(2))
        .set_filter_stride(DimIndex::Y, dims.stride(1))
        .set_filter_stride(DimIndex::Z, dims.stride(0))
        .set_zero_padding(DimIndex::X, padding_cols / 2)
        .set_zero_padding(DimIndex::Y, padding_rows / 2)
        .set_zero_padding(DimIndex::Z, padding_planes / 2);

    // Shape: out, in, z, y, x.
    Tensor transformed_filter;
    OP_REQUIRES_OK(
        context,
        context->allocate_temp(
            DataTypeToEnum<T>::value,
            TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
                         dims.filter_size(1), dims.filter_size(2)}),
            &transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 5>()(
        context->eigen_device<GPUDevice>(), FORMAT_OIHW,
        To32Bit(filter.tensor<T, 5>()),
        To32Bit(transformed_filter.tensor<T, 5>()));

    // Shape: batch, filters, z, y, x.
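    //
    // The cuDNN descriptors above use the kBatchDepthYX (NCDHW) layout, so an
    // NHWC out_backprop has to be transposed here; when out_depth == 1 the
    // NHWC and NCHW layouts coincide and a reshaping copy is enough.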
    Tensor transformed_out_backprop;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
                                dims.output_size(0), dims.output_size(1),
                                dims.output_size(2)};
      if (dims.out_depth > 1) {
        OP_REQUIRES_OK(context, context->allocate_temp(
                                    DataTypeToEnum<T>::value, nchw_shape,
                                    &transformed_out_backprop));
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
            transformed_out_backprop.tensor<T, 5>());
      } else {
        CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
      }
    } else {
      transformed_out_backprop = out_backprop;
    }
    // Shape: batch, filters, z, y, x.
    Tensor pre_transformed_in_backprop;
    OP_REQUIRES_OK(
        context,
        context->allocate_temp(DataTypeToEnum<T>::value, compatible_input_shape,
                               &pre_transformed_in_backprop));

    auto out_backprop_ptr =
        AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                       transformed_out_backprop.template flat<T>().size());
    auto filter_ptr =
        AsDeviceMemory(transformed_filter.template flat<T>().data(),
                       transformed_filter.template flat<T>().size());
    auto in_backprop_ptr =
        AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
                       pre_transformed_in_backprop.template flat<T>().size());

    static int64 ConvolveBackwardDataScratchSize = GetDnnWorkspaceLimit(
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default

    const int device_id = stream->parent()->device_ordinal();
    DataType dtype = context->input(0).dtype();
    const ConvParameters conv_parameters = {
        dims.batch_size,
        dims.in_depth,
        {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
        FORMAT_NCHW,
        dims.out_depth,
        {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
        {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
        {{dims.stride(0), dims.stride(1), dims.stride(2)}},
        {{padding_planes, padding_rows, padding_cols}},
        dtype,
        device_id,
    };

    using se::dnn::AlgorithmConfig;
    using se::dnn::AlgorithmDesc;
    using se::dnn::ProfileResult;
    AlgorithmConfig algorithm_config;
    if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
                                   conv_parameters, &algorithm_config)) {
      std::vector<AlgorithmDesc> algorithms;
      CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
              stream->parent()),
          &algorithms));
      ProfileResult best_result;
      ProfileResult best_result_no_scratch;
      for (auto profile_algorithm : algorithms) {
        // TODO(zhengxq): profile each algorithm multiple times for better
        // accuracy.
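        //
        // Each candidate algorithm is profiled once here; the fastest result
        // (and the fastest result that needed no scratch space) is kept and
        // the resulting configuration is cached in AutoTuneConv3dBwdData.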
        DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                              context);
        ProfileResult profile_result;
        bool cudnn_launch_status =
            stream
                ->ThenConvolveBackwardDataWithAlgorithm(
                    filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                    conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
                    AlgorithmConfig(profile_algorithm), &profile_result)
                .ok();
        if (cudnn_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
            if (scratch_allocator.TotalByteSize() == 0 &&
                profile_result.elapsed_time_in_ms() <
                    best_result_no_scratch.elapsed_time_in_ms()) {
              best_result_no_scratch = profile_result;
            }
          }
        }
      }
      OP_REQUIRES(context,
                  best_result.is_valid() || best_result_no_scratch.is_valid(),
                  errors::NotFound("No algorithm worked!"));
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      if (best_result_no_scratch.is_valid()) {
        algorithm_config.set_algorithm_no_scratch(
            best_result_no_scratch.algorithm());
      }
      AutoTuneConv3dBwdData::GetInstance()->Insert(conv_parameters,
                                                   algorithm_config);
    }

    DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                          context);
    bool cudnn_launch_status =
        stream
            ->ThenConvolveBackwardDataWithAlgorithm(
                filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
                algorithm_config, nullptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(errors::Internal(
          "cuDNN Backward Data function launch failure : input shape(",
          input_shape.DebugString(), ") filter shape(",
          filter_shape.DebugString(), ")"));
    }

    if (rows_odd || cols_odd || planes_odd) {
      Tensor in_backprop_remove_padding;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(
                         DataTypeToEnum<T>::value,
                         {dims.batch_size, dims.in_depth, dims.input_size(0),
                          dims.input_size(1), dims.input_size(2)},
                         &in_backprop_remove_padding));

      // Remove the padding for odd spatial dimensions.
      functor::PadInput<GPUDevice, T, int, 5>()(
          context->eigen_device<GPUDevice>(),
          To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
                      .tensor<T, 5>()),
          {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
          To32Bit(in_backprop_remove_padding.tensor<T, 5>()), FORMAT_NCHW);

      pre_transformed_in_backprop = in_backprop_remove_padding;
    }

    if (data_format_ == FORMAT_NHWC) {
      auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
      functor::NCHWToNHWC<GPUDevice, T, 5>()(
          context->eigen_device<GPUDevice>(),
          toConstTensor(pre_transformed_in_backprop).template tensor<T, 5>(),
          in_backprop->tensor<T, 5>());
    } else {
      *in_backprop = pre_transformed_in_backprop;
    }
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
  bool cudnn_use_autotune_;
};

// A dummy type to group backward filter autotune results together.
struct Conv3dBackwardFilterAutoTuneGroup {
  static string name() { return "Conv3dBwdFilter"; }
};
typedef AutoTuneSingleton<Conv3dBackwardFilterAutoTuneGroup, ConvParameters,
                          se::dnn::AlgorithmConfig>
    AutoTuneConv3dBwdFilter;

template <typename T>
class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
 public:
  explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    }
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
    OP_REQUIRES(context, dilation_.size() == 5,
                errors::InvalidArgument("Dilation rates field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(context,
                (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
                 GetTensorDim(dilation_, data_format_, 'N') == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilation rates in the batch and depth dimensions."));
    OP_REQUIRES(
        context,
        (GetTensorDim(dilation_, data_format_, '0') > 0 &&
         GetTensorDim(dilation_, data_format_, '1') > 0 &&
         GetTensorDim(dilation_, data_format_, '2') > 0),
        errors::InvalidArgument("Dilation rates should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, '0') > 0 &&
         GetTensorDim(stride_, data_format_, '1') > 0 &&
         GetTensorDim(stride_, data_format_, '2') > 0),
        errors::InvalidArgument("Spatial strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();

    const Tensor& out_backprop = context->input(2);
    const TensorShape& out_backprop_shape = out_backprop.shape();

    TensorShape filter_shape;
    if (takes_shape_) {
      const Tensor& filter_sizes = context->input(1);
      OP_REQUIRES_OK(context, MakeShape(filter_sizes, &filter_shape));
    } else {
      filter_shape = context->input(1).shape();
    }

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(
        context,
        ConvBackpropComputeDimensionsV2(
            "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, input_shape,
            filter_shape, out_backprop_shape, dilation_, stride_, padding_,
            /*explicit_paddings=*/{}, data_format_, &dims));

    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    auto* stream = context->op_device_context()->stream();
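    // Everything below requires a valid StreamExecutor stream. Two special
    // cases are handled with a single cuBLAS GEMM instead of cuDNN: a 1x1x1
    // filter with unit strides and dilations in NHWC, and a filter that spans
    // the entire input with VALID padding.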
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
        dims.filter_size(0) == 1 && dims.dilation(2) == 1 &&
        dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 &&
        dims.stride(1) == 1 && dims.stride(0) == 1 &&
        data_format_ == FORMAT_NHWC) {
      const uint64 m = dims.in_depth;
      const uint64 k = dims.batch_size * dims.input_size(1) *
                       dims.input_size(2) * dims.input_size(0);
      const uint64 n = dims.out_depth;

      // The shape of the output backprop is
      //   [batch, out_z, out_y, out_x, out_depth].
      // From cuBLAS's perspective, it is: n x k.
      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());

      // The shape of the input is
      //   [batch, in_z, in_y, in_x, in_depth].
      // From cuBLAS's perspective, it is: m x k.
      auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());

      // The shape of the filter backprop is
      //   [1, 1, 1, in_depth, out_depth].
      // From cuBLAS's perspective, it is: n x m.
      auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                  filter_backprop->template flat<T>().size());

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
                             se::blas::Transpose::kTranspose, n, m, k, 1.0f,
                             a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    } else if (dims.filter_size(0) == dims.input_size(0) &&
               dims.filter_size(1) == dims.input_size(1) &&
               dims.filter_size(2) == dims.input_size(2) &&
               padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
      const uint64 m = dims.input_size(0) * dims.input_size(1) *
                       dims.input_size(2) * dims.in_depth;
      const uint64 k = dims.batch_size;
      const uint64 n = dims.out_depth;

      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                  filter_backprop->template flat<T>().size());

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
                             se::blas::Transpose::kTranspose, n, m, k, 1.0f,
                             b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    }

    int padding_planes = dims.SpatialPadding(padding_, 0);
    int padding_rows = dims.SpatialPadding(padding_, 1);
    int padding_cols = dims.SpatialPadding(padding_, 2);
    const bool planes_odd = (padding_planes % 2 != 0);
    const bool rows_odd = (padding_rows % 2 != 0);
    const bool cols_odd = (padding_cols % 2 != 0);

    Tensor compatible_input;
    if (rows_odd || cols_odd || planes_odd) {
      OP_REQUIRES_OK(context,
                     context->allocate_temp(
                         DataTypeToEnum<T>::value,
                         ShapeFromFormat(data_format_, dims.batch_size,
                                         {{dims.input_size(0) + planes_odd,
                                           dims.input_size(1) + rows_odd,
                                           dims.input_size(2) + cols_odd}},
                                         dims.in_depth),
                         &compatible_input));
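      // Pad the input by one at the end of each spatial dimension with odd
      // total padding, so that the remaining padding is even and can be
      // passed to cuDNN as symmetric zero padding below.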
      functor::PadInput<GPUDevice, T, int, 5>()(
          context->template eigen_device<GPUDevice>(),
          To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
          {{planes_odd, rows_odd, cols_odd}},
          To32Bit(compatible_input.tensor<T, 5>()), data_format_);
    } else {
      compatible_input = input;
    }

    CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
        << "Negative paddings: (" << padding_rows << ", " << padding_cols
        << ", " << padding_planes << ")";
    se::dnn::BatchDescriptor input_desc(3);
    input_desc.set_count(dims.batch_size)
        .set_spatial_dim(DimIndex::X,
                         GetTensorDim(compatible_input, data_format_, '2'))
        .set_spatial_dim(DimIndex::Y,
                         GetTensorDim(compatible_input, data_format_, '1'))
        .set_spatial_dim(DimIndex::Z,
                         GetTensorDim(compatible_input, data_format_, '0'))
        .set_feature_map_count(dims.in_depth)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    se::dnn::BatchDescriptor output_desc(3);
    output_desc.set_count(dims.batch_size)
        .set_spatial_dim(DimIndex::X, dims.output_size(2))
        .set_spatial_dim(DimIndex::Y, dims.output_size(1))
        .set_spatial_dim(DimIndex::Z, dims.output_size(0))
        .set_feature_map_count(dims.out_depth)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    se::dnn::FilterDescriptor filter_desc(3);
    filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
        .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
        .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
        .set_input_feature_map_count(dims.in_depth)
        .set_output_feature_map_count(dims.out_depth);
    se::dnn::ConvolutionDescriptor conv_desc(3);
    conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
        .set_dilation_rate(DimIndex::Y, dims.dilation(1))
        .set_dilation_rate(DimIndex::Z, dims.dilation(0))
        .set_filter_stride(DimIndex::X, dims.stride(2))
        .set_filter_stride(DimIndex::Y, dims.stride(1))
        .set_filter_stride(DimIndex::Z, dims.stride(0))
        .set_zero_padding(DimIndex::X, padding_cols / 2)
        .set_zero_padding(DimIndex::Y, padding_rows / 2)
        .set_zero_padding(DimIndex::Z, padding_planes / 2);

    Tensor pre_transformed_filter_backprop;
    OP_REQUIRES_OK(
        context,
        context->allocate_temp(
            DataTypeToEnum<T>::value,
            TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
                         dims.filter_size(1), dims.filter_size(2)}),
            &pre_transformed_filter_backprop));

    Tensor transformed_out_backprop;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
                                dims.output_size(0), dims.output_size(1),
                                dims.output_size(2)};
      OP_REQUIRES_OK(
          context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                          &transformed_out_backprop));
      if (dims.out_depth > 1) {
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
            transformed_out_backprop.tensor<T, 5>());
      } else {
        CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
      }
    } else {
      transformed_out_backprop = out_backprop;
    }
    Tensor transformed_input;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {
          dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
          compatible_input.dim_size(2), compatible_input.dim_size(3)};
      if (dims.in_depth > 1) {
        OP_REQUIRES_OK(context,
                       context->allocate_temp(DataTypeToEnum<T>::value,
                                              nchw_shape,
                                              &transformed_input));
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(),
            const_cast<const Tensor&>(compatible_input).tensor<T, 5>(),
            transformed_input.tensor<T, 5>());
      } else {
        CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
      }
    } else {
      transformed_input = compatible_input;
    }

    auto out_backprop_ptr =
        AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                       transformed_out_backprop.template flat<T>().size());
    auto filter_backprop_ptr = AsDeviceMemory(
        pre_transformed_filter_backprop.template flat<T>().data(),
        pre_transformed_filter_backprop.template flat<T>().size());
    auto input_ptr =
        AsDeviceMemory(transformed_input.template flat<T>().data(),
                       transformed_input.template flat<T>().size());

    static int64 ConvolveBackwardFilterScratchSize = GetDnnWorkspaceLimit(
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default

    const int device_id = stream->parent()->device_ordinal();
    DataType dtype = input.dtype();
    const ConvParameters conv_parameters = {
        dims.batch_size,
        dims.in_depth,
        {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
        FORMAT_NCHW,
        dims.out_depth,
        {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
        {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
        {{dims.stride(0), dims.stride(1), dims.stride(2)}},
        {{padding_planes, padding_rows, padding_cols}},
        dtype,
        device_id,
    };

    using se::dnn::AlgorithmConfig;
    using se::dnn::AlgorithmDesc;
    using se::dnn::ProfileResult;
    AlgorithmConfig algorithm_config;
    if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find(
                                   conv_parameters, &algorithm_config)) {
      std::vector<AlgorithmDesc> algorithms;
      CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
              stream->parent()),
          &algorithms));
      ProfileResult best_result;
      ProfileResult best_result_no_scratch;
      for (auto profile_algorithm : algorithms) {
        // TODO(zhengxq): profile each algorithm multiple times for better
        // accuracy.
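        // As in the backward-data path above, each algorithm is profiled with
        // its own scratch allocator, and both the overall fastest result and
        // the fastest zero-scratch result are kept as candidates.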
        DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                              context);
        ProfileResult profile_result;
        bool cudnn_launch_status =
            stream
                ->ThenConvolveBackwardFilterWithAlgorithm(
                    input_desc, input_ptr, output_desc, out_backprop_ptr,
                    conv_desc, filter_desc, &filter_backprop_ptr,
                    &scratch_allocator, AlgorithmConfig(profile_algorithm),
                    &profile_result)
                .ok();
        if (cudnn_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
            if (scratch_allocator.TotalByteSize() == 0 &&
                profile_result.elapsed_time_in_ms() <
                    best_result_no_scratch.elapsed_time_in_ms()) {
              best_result_no_scratch = profile_result;
            }
          }
        }
      }
      OP_REQUIRES(context,
                  best_result.is_valid() || best_result_no_scratch.is_valid(),
                  errors::NotFound("No algorithm worked!"));
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      if (best_result_no_scratch.is_valid()) {
        algorithm_config.set_algorithm_no_scratch(
            best_result_no_scratch.algorithm());
      }
      AutoTuneConv3dBwdFilter::GetInstance()->Insert(conv_parameters,
                                                     algorithm_config);
    }
    DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                          context);
    bool cudnn_launch_status =
        stream
            ->ThenConvolveBackwardFilterWithAlgorithm(
                input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
                filter_desc, &filter_backprop_ptr, &scratch_allocator,
                algorithm_config, nullptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(errors::Internal(
          "cuDNN Backward Filter function launch failure : input shape(",
          input_shape.DebugString(), ") filter shape(",
          filter_shape.DebugString(), ")"));
    }

    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::ReverseTransformFilter<GPUDevice, T, 5>()(
        context->eigen_device<GPUDevice>(),
        toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(),
        filter_backprop->tensor<T, 5>());
  }

 private:
  std::vector<int32> dilation_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
  bool cudnn_use_autotune_;
};

#define REGISTER_GPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"),  \
      Conv3DBackpropInputOp<GPUDevice, T>);                                   \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                       \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<T>("T")                         \
                              .HostMemory("input_sizes"),                     \
                          Conv3DBackpropInputOp<GPUDevice, T>);               \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      Conv3DBackpropFilterOp<GPUDevice, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<T>("T")                         \
                              .HostMemory("filter_sizes"),                    \
                          Conv3DBackpropFilterOp<GPUDevice, T>);
TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

}  // namespace tensorflow