/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/conv_3d.h"

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
using perftools::gputools::dnn::DimIndex;
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// TODO(mjanusz): Get rid of the macro and return shapes directly.
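// EXTRACT_AND_VERIFY_DIMENSIONS checks the ranks and sizes of input, filter
// and out_backprop, and derives the geometry of the "inflated" out_backprop
// that lets both backprop passes be expressed as forward convolutions:
//   expanded = (output - 1) * stride + 1   (out_backprop dilated by stride)
//   padded   = input + filter - 1          (size needed for a full VALID conv)
//   top_pad  = filter - 1 - padding        (top/left padding of the inflation)
// Worked 1-D example (illustrative): input = 5, filter = 3, stride = 2, VALID
// padding => output = 2, expanded = 3, padded = 7, top_pad = 2 and
// bottom_pad = 7 - 3 - 2 = 2.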
#define EXTRACT_AND_VERIFY_DIMENSIONS(label)                                   \
  const Tensor& out_backprop = context->input(2);                              \
  OP_REQUIRES(                                                                 \
      context, input_shape.dims() == 5,                                        \
      errors::InvalidArgument(label, ": input must be 5-dimensional"));        \
  OP_REQUIRES(                                                                 \
      context, filter_shape.dims() == 5,                                       \
      errors::InvalidArgument(label, ": filter must be 5-dimensional"));       \
  OP_REQUIRES(                                                                 \
      context, out_backprop.dims() == 5,                                       \
      errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \
  const int64 batch = input_shape.dim_size(0);                                 \
  OP_REQUIRES(                                                                 \
      context, batch == out_backprop.dim_size(0),                              \
      errors::InvalidArgument(                                                 \
          label, ": input and out_backprop must have the same batch size"));   \
  const std::array<int64, 3> input_size = {                                    \
      {GetTensorDim(input_shape, data_format_, '0'),                           \
       GetTensorDim(input_shape, data_format_, '1'),                           \
       GetTensorDim(input_shape, data_format_, '2')}};                         \
  const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C');         \
  const std::array<int64, 3> filter_size = {{filter_shape.dim_size(0),         \
                                             filter_shape.dim_size(1),         \
                                             filter_shape.dim_size(2)}};       \
  const int64 output_cols = GetTensorDim(out_backprop, data_format_, '2');     \
  const int64 output_rows = GetTensorDim(out_backprop, data_format_, '1');     \
  const int64 output_planes = GetTensorDim(out_backprop, data_format_, '0');   \
  OP_REQUIRES(context, in_depth == filter_shape.dim_size(3),                   \
              errors::InvalidArgument(                                         \
                  label, ": input and filter must have the same depth"));      \
  const int64 out_depth = filter_shape.dim_size(4);                            \
  OP_REQUIRES(                                                                 \
      context, out_depth == GetTensorDim(out_backprop, data_format_, 'C'),     \
      errors::InvalidArgument(                                                 \
          label, ": filter and out_backprop must have the same out_depth"));   \
  const std::array<int64, 3> strides = {                                       \
      {GetTensorDim(stride_, data_format_, '0'),                               \
       GetTensorDim(stride_, data_format_, '1'),                               \
       GetTensorDim(stride_, data_format_, '2')}};                             \
  std::array<int64, 3> out, padding;                                           \
  OP_REQUIRES_OK(context, Get3dOutputSize(input_size, filter_size, strides,    \
                                          padding_, &out, &padding));          \
  OP_REQUIRES(context, output_planes == out[0],                                \
              errors::InvalidArgument(                                         \
                  label,                                                       \
                  ": Number of planes of out_backprop doesn't match "          \
                  "computed: actual = ",                                       \
                  output_planes, ", computed = ", out[0]));                    \
  OP_REQUIRES(                                                                 \
      context, output_rows == out[1],                                          \
      errors::InvalidArgument(                                                 \
          label, ": Number of rows of out_backprop doesn't match computed: ",  \
          "actual = ", output_rows, ", computed = ", out[1]));                 \
  OP_REQUIRES(                                                                 \
      context, output_cols == out[2],                                          \
      errors::InvalidArgument(                                                 \
          label, ": Number of cols of out_backprop doesn't match computed: ",  \
          "actual = ", output_cols, ", computed = ", out[2]));                 \
  const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1;       \
  const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1;           \
  const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1;           \
  const auto padded_out_planes = input_size[0] + filter_size[0] - 1;           \
  const auto padded_out_rows = input_size[1] + filter_size[1] - 1;             \
  const auto padded_out_cols = input_size[2] + filter_size[2] - 1;             \
  const auto top_pad_planes = filter_size[0] - 1 - padding[0];                 \
  const auto top_pad_rows = filter_size[1] - 1 - padding[1];                   \
  const auto left_pad_cols = filter_size[2] - 1 - padding[2];                  \
  const auto bottom_pad_planes =                                               \
      padded_out_planes - expanded_out_planes - top_pad_planes;                \
  const auto bottom_pad_rows =                                                 \
      padded_out_rows - expanded_out_rows - top_pad_rows;                      \
  const auto right_pad_cols =                                                  \
      padded_out_cols - expanded_out_cols - left_pad_cols;                     \
  VLOG(2) << "Conv3d: " << label                                               \
          << ": expanded_out_planes = " << expanded_out_planes                 \
          << ", expanded_out_rows = " << expanded_out_rows                     \
          << ", expanded_out_cols = " << expanded_out_cols                     \
          << ", padded_out_planes = " << padded_out_planes                     \
          << ", padded_out_rows = " << padded_out_rows                         \
          << ", padded_out_cols = " << padded_out_cols                         \
          << ", top_pad_planes = " << top_pad_planes                           \
          << ", top_pad_rows = " << top_pad_rows                               \
          << ", left_pad_cols = " << left_pad_cols                             \
          << ", bottom_pad_planes = " << bottom_pad_planes                     \
          << ", bottom_pad_rows = " << bottom_pad_rows                         \
          << ", right_pad_cols = " << right_pad_cols

// Backprop for input.
template <typename Device, class T>
class Conv3DBackpropInputOp : public OpKernel {
 public:
  explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& filter = context->input(1);
    const TensorShape& filter_shape = filter.shape();
    TensorShape input_shape;
    if (takes_shape_) {
      const Tensor& input_sizes = context->input(0);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  input_sizes.vec<int32>(), &input_shape));
    } else {
      input_shape = context->input(0).shape();
    }
    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
        {0, 0},
        {top_pad_planes, bottom_pad_planes},
        {top_pad_rows, bottom_pad_rows},
        {left_pad_cols, right_pad_cols},
        {0, 0}};
    Tensor* in_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    // Fill out a padded out_backprop.
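    // Conceptually this computes the transposed convolution: out_backprop is
    // dilated by the stride and padded so that a VALID, stride-1 convolution
    // with the spatially reversed, depth-swapped filter (built below) yields
    // the gradient with respect to the input.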
    TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows,
                                  padded_out_cols, out_depth});
    Tensor padded_output;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::v(),
                                          padded_out_shape, &padded_output));
    Eigen::DSizes<Eigen::DenseIndex, 5> no_op_shuffle{0, 1, 2, 3, 4};
    Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
                                                      strides[2], 1};
    functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
        eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor<T, 5>());
    const Tensor& padded_output_cref = padded_output;

    // Fill a new "reverted" filter. We need to transpose the in_depth and
    // out_depth for the filter and reverse the planes, rows and cols.
    TensorShape r_filter_shape(
        {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth});
    Tensor r_filter;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
                                                   r_filter_shape, &r_filter));
    Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{0, 1, 2, 4, 3};
    Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
    functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), filter.tensor<T, 5>(), filter_order,
        filter_rev_dims, r_filter.tensor<T, 5>());
    const Tensor& r_filter_cref = r_filter;

    // Now we can call conv_3d directly.
    functor::CuboidConvolution<Device, T>()(
        context->eigen_device<Device>(), in_backprop->tensor<T, 5>(),
        padded_output_cref.tensor<T, 5>(), r_filter_cref.tensor<T, 5>(), 1, 1,
        1, BrainPadding2EigenPadding(VALID));
  }

 private:
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
};

#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      Conv3DBackpropInputOp<CPUDevice, T>);                                    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DBackpropInputOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
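
// These kernels implement the gradients of the Conv3D op. In Python, taking
// tf.gradients of a tf.nn.conv3d op dispatches to Conv3DBackpropInputV2 and
// Conv3DBackpropFilterV2; the non-V2 ops are kept for backwards
// compatibility.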

// Backprop for filter.
template <typename Device, class T>
class Conv3DBackpropFilterOp : public OpKernel {
 public:
  explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument(
              "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
    }

    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    TensorShape filter_shape;

    if (takes_shape_) {
      const Tensor& filter_sizes = context->input(1);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  filter_sizes.vec<int32>(), &filter_shape));
    } else {
      filter_shape = context->input(1).shape();
    }

    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
        {0, 0},
        {top_pad_planes, bottom_pad_planes},
        {top_pad_rows, bottom_pad_rows},
        {left_pad_cols, right_pad_cols},
        {0, 0}};
    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    if (input_shape.num_elements() == 0) {
      filter_backprop->template flat<T>().setZero();
      return;
    }

    // For the backprop of the filter, we need to also transpose the
    // out_backprop. The shape of the backprop is
    //   [batch, out_z, out_y, out_x, out_depth]
    // and we need to change it to
    //   [out_depth, out_z, out_y, out_x, batch].
    Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0};
    TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows,
                                  padded_out_cols, batch});
    Tensor padded_output;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::v(),
                                          padded_out_shape, &padded_output));
    Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
                                                      strides[2], 1};
    functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
        eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>());
    const Tensor& padded_output_cref = padded_output;

    // For the backprop of the filter, we need to transpose the input.
    // The shape of the input is
    //   [batch, in_z, in_y, in_x, in_depth]
    // and we need to change it to
    //   [in_z, in_y, in_x, batch, in_depth].
    Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4};
    TensorShape in_shuffle_shape(
        {input_size[0], input_size[1], input_size[2], batch, in_depth});
    Tensor in_shuffle;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(DataTypeToEnum<T>::v(),
                                          in_shuffle_shape, &in_shuffle));
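
    // Conceptually, the filter gradient is itself a VALID convolution: the
    // dilated, padded, transposed out_backprop acts as the "input" and the
    // transposed input acts as the "kernel" of the CuboidConvolution call
    // below; the result is then shuffled back and spatially reversed into
    // filter_backprop.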
    // No need for reversing this time.
    Eigen::array<bool, 5> no_reverse{false, false, false, false, false};
    functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), input.tensor<T, 5>(), in_order,
        no_reverse, in_shuffle.tensor<T, 5>());
    const Tensor& in_shuffle_cref = in_shuffle;

    // The output of the conv_3d would be
    //   [out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth]
    // and we need to shuffle it back to
    //   [filter_size[0], filter_size[1], filter_size[2], in_depth, out_depth]
    // and reverse the spatial dimensions of the filter backprop. So we need
    // to allocate (sigh) yet another piece of memory to hold the output.
    TensorShape filter_shuffle_shape(
        {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth});
    Tensor filter_shuffle;
    OP_REQUIRES_OK(
        context, context->allocate_temp(DataTypeToEnum<T>::v(),
                                        filter_shuffle_shape, &filter_shuffle));
    functor::CuboidConvolution<Device, T>()(
        context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(),
        padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1,
        1, BrainPadding2EigenPadding(VALID));

    // Now copy the filter_backprop back to the destination.
    Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0};
    Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
    const Tensor& filter_shuffle_cref = filter_shuffle;
    functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
        context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(),
        filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>());
  }

 private:
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
};

#define REGISTER_CPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv3DBackpropFilterOp<CPUDevice, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_CPU)                             \
                              .TypeConstraint<T>("T"),                        \
                          Conv3DBackpropFilterOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

// GPU definitions of both ops.
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
// This ensures that the custom implementation is used instead of the default
// Eigen one (which is used for CPU).
namespace functor {
#define DECLARE_GPU_SPEC(T)                                           \
  template <>                                                         \
  void TransformFilter<GPUDevice, T, int, 5>::operator()(             \
      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
      typename TTypes<T, 5, int>::Tensor out);                        \
  template <>                                                         \
  void ReverseTransformFilter<GPUDevice, T, 5>::operator()(           \
      const GPUDevice& d, typename TTypes<T, 5>::ConstTensor in,      \
      typename TTypes<T, 5>::Tensor out);                             \
  template <>                                                         \
  void PadInput<GPUDevice, T, int, 5>::operator()(                    \
      const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
      const std::array<int, 3>& padding_left,                         \
      const std::array<int, 3>& padding_right,                        \
      typename TTypes<T, 5, int>::Tensor out, TensorFormat format);

DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
}  // namespace functor
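
// Autotuning: AutoTuneSingleton keeps a process-wide cache from
// ConvParameters (shapes, strides, padding, dtype, device) to the best cuDNN
// algorithm found so far, so each distinct convolution configuration is
// profiled at most once.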

// A dummy type to group backward data autotune results together.
struct Conv3dBackwardDataAutoTuneGroup {
  static string name() { return "Conv3dBwdData"; }
};
typedef AutoTuneSingleton<Conv3dBackwardDataAutoTuneGroup, ConvParameters,
                          perftools::gputools::dnn::AlgorithmConfig>
    AutoTuneConv3dBwdData;

template <typename T>
class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
 public:
  explicit Conv3DBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    }
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& filter = context->input(1);
    const TensorShape& filter_shape = filter.shape();
    TensorShape input_shape;
    if (takes_shape_) {
      const Tensor& input_sizes = context->input(0);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  input_sizes.vec<int32>(), &input_shape));
    } else {
      input_shape = context->input(0).shape();
    }
    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
    Tensor* in_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
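
    // Two special cases reduce to a single GEMM: a 1x1x1 filter with unit
    // strides (a pointwise convolution), and a filter that covers the whole
    // input with VALID padding (fully connected per batch element). In the
    // pointwise case below, in_backprop = out_backprop * filter^T, where
    // out_backprop is viewed as an m x k matrix with
    // m = batch * planes * rows * cols and k = out_depth, and filter as an
    // n x k matrix with n = in_depth.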
    if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 &&
        strides[0] == 1 && strides[1] == 1 && strides[2] == 1 &&
        data_format_ == FORMAT_NHWC) {
      const uint64 m = batch * input_size[0] * input_size[1] * input_size[2];
      const uint64 k = out_depth;
      const uint64 n = in_depth;

      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                  in_backprop->template flat<T>().size());

      auto transpose = perftools::gputools::blas::Transpose::kTranspose;
      auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                             a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    } else if (filter_size[0] == input_size[0] &&
               filter_size[1] == input_size[1] &&
               filter_size[2] == input_size[2] &&
               padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
      const uint64 m = batch;
      const uint64 k = out_depth;
      const uint64 n = input_size[0] * input_size[1] * input_size[2] * in_depth;

      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                  filter.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                  in_backprop->template flat<T>().size());

      auto transpose = perftools::gputools::blas::Transpose::kTranspose;
      auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                             a_ptr, k, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    }

    int padding_rows = 0, padding_cols = 0, padding_planes = 0;

    if (padding_ == Padding::SAME) {
      padding_planes = std::max<int>(
          0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
      padding_cols = std::max<int>(
          0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
      padding_rows = std::max<int>(
          0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
    }
    const bool rows_odd = (padding_rows % 2 != 0);
    const bool cols_odd = (padding_cols % 2 != 0);
    const bool planes_odd = (padding_planes % 2 != 0);
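
    // SAME padding example (1-D, illustrative): input = 6, filter = 3,
    // stride = 2 => output = 3 and total padding = (3 - 1) * 2 + 3 - 6 = 1.
    // cuDNN pads symmetrically (1 / 2 = 0 on each side here), so the odd
    // leftover element is handled by extending the input by one, which is
    // what the *_odd flags above track.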
    TensorShape compatible_input_shape;
    if (rows_odd || cols_odd || planes_odd) {
      // cuDNN only supports the same amount of padding on both sides.
      compatible_input_shape = {
          batch,
          in_depth,
          input_size[0] + planes_odd,
          input_size[1] + rows_odd,
          input_size[2] + cols_odd,
      };
    } else {
      compatible_input_shape = {batch, in_depth, input_size[0], input_size[1],
                                input_size[2]};
    }

    CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
        << "Negative paddings: (" << padding_rows << ", " << padding_cols
        << ", " << padding_planes << ")";
    perftools::gputools::dnn::BatchDescriptor input_desc(3);
    input_desc.set_count(batch)
        .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
        .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
        .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
        .set_feature_map_count(in_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::BatchDescriptor output_desc(3);
    output_desc.set_count(batch)
        .set_spatial_dim(DimIndex::X, output_cols)
        .set_spatial_dim(DimIndex::Y, output_rows)
        .set_spatial_dim(DimIndex::Z, output_planes)
        .set_feature_map_count(out_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::FilterDescriptor filter_desc(3);
    filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
        .set_spatial_dim(DimIndex::Y, filter_size[1])
        .set_spatial_dim(DimIndex::Z, filter_size[0])
        .set_input_feature_map_count(in_depth)
        .set_output_feature_map_count(out_depth);
    perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3);
    conv_desc.set_filter_stride(DimIndex::X, strides[2])
        .set_filter_stride(DimIndex::Y, strides[1])
        .set_filter_stride(DimIndex::Z, strides[0])
        .set_zero_padding(DimIndex::X, padding_cols / 2)
        .set_zero_padding(DimIndex::Y, padding_rows / 2)
        .set_zero_padding(DimIndex::Z, padding_planes / 2);

    // Shape: out, in, z, y, x.
    Tensor transformed_filter;
    OP_REQUIRES_OK(
        context,
        context->allocate_temp(DataTypeToEnum<T>::value,
                               TensorShape({out_depth, in_depth, filter_size[0],
                                            filter_size[1], filter_size[2]}),
                               &transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 5>()(
        context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
        To32Bit(transformed_filter.tensor<T, 5>()));

    // Shape: batch, filters, z, y, x.
    Tensor transformed_out_backprop;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
                                output_cols};
      if (out_depth > 1) {
        OP_REQUIRES_OK(context, context->allocate_temp(
                                    DataTypeToEnum<T>::value, nchw_shape,
                                    &transformed_out_backprop));
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
            transformed_out_backprop.tensor<T, 5>());
      } else {
        CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
      }
    } else {
      transformed_out_backprop = out_backprop;
    }
    // Shape: batch, filters, z, y, x.
    Tensor pre_transformed_in_backprop;
    OP_REQUIRES_OK(
        context,
        context->allocate_temp(DataTypeToEnum<T>::value, compatible_input_shape,
                               &pre_transformed_in_backprop));

    auto out_backprop_ptr =
        AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                       transformed_out_backprop.template flat<T>().size());
    auto filter_ptr =
        AsDeviceMemory(transformed_filter.template flat<T>().data(),
                       transformed_filter.template flat<T>().size());
    auto in_backprop_ptr =
        AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
                       pre_transformed_in_backprop.template flat<T>().size());

    static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default

    const int device_id = stream->parent()->device_ordinal();
    DataType dtype = context->input(0).dtype();
    const ConvParameters conv_parameters = {
        batch,
        in_depth,
        {{input_size[0], input_size[1], input_size[2]}},
        out_depth,
        {{filter_size[0], filter_size[1], filter_size[2]}},
        // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
        // conv is supported.
        /*dilation=*/{{1, 1, 1}},
        {{strides[0], strides[1], strides[2]}},
        {{padding_planes, padding_rows, padding_cols}},
        dtype,
        device_id,
    };

    using perftools::gputools::dnn::AlgorithmConfig;
    using perftools::gputools::dnn::AlgorithmDesc;
    using perftools::gputools::dnn::ProfileResult;
    AlgorithmConfig algorithm_config;
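
    // On the first run for a given ConvParameters key, profile every
    // available backward-data algorithm and remember the fastest choice (with
    // and without scratch memory); subsequent calls reuse the cached result.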
    if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
                                   conv_parameters, &algorithm_config)) {
      std::vector<AlgorithmDesc> algorithms;
      CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
      ProfileResult best_result;
      ProfileResult best_result_no_scratch;
      for (auto profile_algorithm : algorithms) {
        // TODO(zhengxq): profile each algorithm multiple times for better
        // accuracy.
        CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                                context);
        ProfileResult profile_result;
        bool cudnn_launch_status =
            stream
                ->ThenConvolveBackwardDataWithAlgorithm(
                    filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                    conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
                    AlgorithmConfig(profile_algorithm), &profile_result)
                .ok();
        if (cudnn_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
            if (scratch_allocator.TotalByteSize() == 0 &&
                profile_result.elapsed_time_in_ms() <
                    best_result_no_scratch.elapsed_time_in_ms()) {
              best_result_no_scratch = profile_result;
            }
          }
        }
      }
      OP_REQUIRES(context,
                  best_result.is_valid() || best_result_no_scratch.is_valid(),
                  errors::NotFound("No algorithm worked!"));
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      if (best_result_no_scratch.is_valid()) {
        algorithm_config.set_algorithm_no_scratch(
            best_result_no_scratch.algorithm());
      }
      AutoTuneConv3dBwdData::GetInstance()->Insert(conv_parameters,
                                                   algorithm_config);
    }
    CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                            context);
    bool cudnn_launch_status =
        stream
            ->ThenConvolveBackwardDataWithAlgorithm(
                filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
                algorithm_config, nullptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(errors::Internal(
          "cuDNN Backward Data function launch failure : input shape(",
          input_shape.DebugString(), ") filter shape(",
          filter_shape.DebugString(), ")"));
    }

    if (rows_odd || cols_odd || planes_odd) {
      Tensor in_backprop_remove_padding;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<T>::value,
                                            {batch, in_depth, input_size[0],
                                             input_size[1], input_size[2]},
                                            &in_backprop_remove_padding));

      // Remove the padding for odd spatial dimensions.
      functor::PadInput<GPUDevice, T, int, 5>()(
          context->eigen_device<GPUDevice>(),
          To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
                      .tensor<T, 5>()),
          {{0, 0, 0}}, {{-planes_odd, -rows_odd, -cols_odd}},
          To32Bit(in_backprop_remove_padding.tensor<T, 5>()), FORMAT_NCHW);

      pre_transformed_in_backprop = in_backprop_remove_padding;
    }

    if (data_format_ == FORMAT_NHWC) {
      auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
      functor::NCHWToNHWC<GPUDevice, T, 5>()(
          context->eigen_device<GPUDevice>(),
          toConstTensor(pre_transformed_in_backprop).template tensor<T, 5>(),
          in_backprop->tensor<T, 5>());
    } else {
      *in_backprop = pre_transformed_in_backprop;
    }
  }

 private:
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
  bool cudnn_use_autotune_;
};

// A dummy type to group backward filter autotune results together.
struct Conv3dBackwardFilterAutoTuneGroup {
  static string name() { return "Conv3dBwdFilter"; }
};
typedef AutoTuneSingleton<Conv3dBackwardFilterAutoTuneGroup, ConvParameters,
                          perftools::gputools::dnn::AlgorithmConfig>
    AutoTuneConv3dBwdFilter;

template <typename T>
class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
 public:
  explicit Conv3DBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context),
        data_format_(FORMAT_NHWC),
        takes_shape_(type_string().find("V2") != std::string::npos) {
    // data_format is only available in V2.
    if (takes_shape_) {
      string data_format;
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
      OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                  errors::InvalidArgument("Invalid data format"));
    }
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 5 dimensions"));
    OP_REQUIRES(
        context,
        (GetTensorDim(stride_, data_format_, 'C') == 1 &&
         GetTensorDim(stride_, data_format_, 'N') == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();
    TensorShape filter_shape;
    if (takes_shape_) {
      const Tensor& filter_sizes = context->input(1);
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  filter_sizes.vec<int32>(), &filter_shape));
    } else {
      filter_shape = context->input(1).shape();
    }

    EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");

    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
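
    // As in the input-backprop kernel, two cases collapse to a single GEMM.
    // For a 1x1x1 filter with unit strides below,
    // filter_backprop = input^T * out_backprop, contracting over the
    // k = batch * planes * rows * cols spatial positions.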
    if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 &&
        strides[0] == 1 && strides[1] == 1 && strides[2] == 1 &&
        data_format_ == FORMAT_NHWC) {
      const uint64 m = in_depth;
      const uint64 k = batch * input_size[1] * input_size[2] * input_size[0];
      const uint64 n = out_depth;

      // The shape of the output backprop is
      //   [batch, out_z, out_y, out_x, out_depth];
      // from cublas's perspective, it is: n x k.
      auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());

      // The shape of the input is
      //   [batch, in_z, in_y, in_x, in_depth];
      // from cublas's perspective, it is: m x k.
      auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());

      // The shape of the filter backprop is
      //   [1, 1, 1, in_depth, out_depth];
      // from cublas's perspective, it is: n x m.
      auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                  filter_backprop->template flat<T>().size());

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
                             perftools::gputools::blas::Transpose::kTranspose,
                             n, m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    } else if (filter_size[0] == input_size[0] &&
               filter_size[1] == input_size[1] &&
               filter_size[2] == input_size[2] &&
               padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
      const uint64 m = input_size[0] * input_size[1] * input_size[2] * in_depth;
      const uint64 k = batch;
      const uint64 n = out_depth;

      auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                  input.template flat<T>().size());
      auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                  out_backprop.template flat<T>().size());
      auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                  filter_backprop->template flat<T>().size());

      bool blas_launch_status =
          stream
              ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
                             perftools::gputools::blas::Transpose::kTranspose,
                             n, m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                            ", n=", n, ", k=", k));
      }
      return;
    }

    int padding_rows = 0, padding_cols = 0, padding_planes = 0;

    if (padding_ == Padding::SAME) {
      padding_planes = std::max<int>(
          0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
      padding_cols = std::max<int>(
          0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
      padding_rows = std::max<int>(
          0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
    }
    bool rows_odd = (padding_rows % 2 != 0);
    bool cols_odd = (padding_cols % 2 != 0);
    bool planes_odd = (padding_planes % 2 != 0);

    Tensor compatible_input;
    if (rows_odd || cols_odd || planes_odd) {
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(data_format_, batch,
                                                  {{input_size[0] + planes_odd,
                                                    input_size[1] + rows_odd,
                                                    input_size[2] + cols_odd}},
                                                  in_depth),
                                  &compatible_input));
      functor::PadInput<GPUDevice, T, int, 5>()(
          context->template eigen_device<GPUDevice>(),
          To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
          {{planes_odd, rows_odd, cols_odd}},
          To32Bit(compatible_input.tensor<T, 5>()), data_format_);
    } else {
      compatible_input = input;
    }

    CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
        << "Negative paddings: (" << padding_rows << ", " << padding_cols
        << ", " << padding_planes << ")";
    perftools::gputools::dnn::BatchDescriptor input_desc(3);
    input_desc.set_count(batch)
        .set_spatial_dim(DimIndex::X,
                         GetTensorDim(compatible_input, data_format_, '2'))
        .set_spatial_dim(DimIndex::Y,
                         GetTensorDim(compatible_input, data_format_, '1'))
        .set_spatial_dim(DimIndex::Z,
                         GetTensorDim(compatible_input, data_format_, '0'))
        .set_feature_map_count(in_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::BatchDescriptor output_desc(3);
    output_desc.set_count(batch)
        .set_spatial_dim(DimIndex::X, output_cols)
        .set_spatial_dim(DimIndex::Y, output_rows)
        .set_spatial_dim(DimIndex::Z, output_planes)
        .set_feature_map_count(out_depth)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
    perftools::gputools::dnn::FilterDescriptor filter_desc(3);
    filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
        .set_spatial_dim(DimIndex::Y, filter_size[1])
        .set_spatial_dim(DimIndex::Z, filter_size[0])
        .set_input_feature_map_count(in_depth)
        .set_output_feature_map_count(out_depth);
    perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3);
    conv_desc.set_filter_stride(DimIndex::X, strides[2])
        .set_filter_stride(DimIndex::Y, strides[1])
        .set_filter_stride(DimIndex::Z, strides[0])
        .set_zero_padding(DimIndex::X, padding_cols / 2)
        .set_zero_padding(DimIndex::Y, padding_rows / 2)
        .set_zero_padding(DimIndex::Z, padding_planes / 2);

    Tensor pre_transformed_filter_backprop;
    OP_REQUIRES_OK(
        context,
        context->allocate_temp(DataTypeToEnum<T>::value,
                               TensorShape({out_depth, in_depth, filter_size[0],
                                            filter_size[1], filter_size[2]}),
                               &pre_transformed_filter_backprop));

    Tensor transformed_out_backprop;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
                                output_cols};
      OP_REQUIRES_OK(
          context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                          &transformed_out_backprop));
      if (out_depth > 1) {
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
            transformed_out_backprop.tensor<T, 5>());
      } else {
        CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
      }
    } else {
      transformed_out_backprop = out_backprop;
    }
    Tensor transformed_input;
    if (data_format_ == FORMAT_NHWC) {
      TensorShape nchw_shape = {batch, in_depth, compatible_input.dim_size(1),
                                compatible_input.dim_size(2),
                                compatible_input.dim_size(3)};
      if (in_depth > 1) {
        OP_REQUIRES_OK(context,
                       context->allocate_temp(DataTypeToEnum<T>::value,
                                              nchw_shape, &transformed_input));
        functor::NHWCToNCHW<GPUDevice, T, 5>()(
            context->eigen_device<GPUDevice>(),
            const_cast<const Tensor&>(compatible_input).tensor<T, 5>(),
            transformed_input.tensor<T, 5>());
      } else {
        CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
      }
    } else {
      transformed_input = compatible_input;
    }

    auto out_backprop_ptr =
        AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                       transformed_out_backprop.template flat<T>().size());
    auto filter_backprop_ptr = AsDeviceMemory(
        pre_transformed_filter_backprop.template flat<T>().data(),
        pre_transformed_filter_backprop.template flat<T>().size());
    auto input_ptr =
        AsDeviceMemory(transformed_input.template flat<T>().data(),
                       transformed_input.template flat<T>().size());

    static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32);  // 4GB by default

    const int device_id = stream->parent()->device_ordinal();
    DataType dtype = input.dtype();
    const ConvParameters conv_parameters = {
        batch,
        in_depth,
        {{input_size[0], input_size[1], input_size[2]}},
        out_depth,
        {{filter_size[0], filter_size[1], filter_size[2]}},
        /*dilation=*/{{1, 1, 1}},
        {{strides[0], strides[1], strides[2]}},
        {{padding_planes, padding_rows, padding_cols}},
        dtype,
        device_id,
    };

    using perftools::gputools::dnn::AlgorithmConfig;
    using perftools::gputools::dnn::AlgorithmDesc;
    using perftools::gputools::dnn::ProfileResult;
    AlgorithmConfig algorithm_config;
    if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find(
                                   conv_parameters, &algorithm_config)) {
      std::vector<AlgorithmDesc> algorithms;
      CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
          conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
      ProfileResult best_result;
      ProfileResult best_result_no_scratch;
      for (auto profile_algorithm : algorithms) {
        // TODO(zhengxq): profile each algorithm multiple times for better
        // accuracy.
        CudnnScratchAllocator scratch_allocator(
            ConvolveBackwardFilterScratchSize, context);
        ProfileResult profile_result;
        bool cudnn_launch_status =
            stream
                ->ThenConvolveBackwardFilterWithAlgorithm(
                    input_desc, input_ptr, output_desc, out_backprop_ptr,
                    conv_desc, filter_desc, &filter_backprop_ptr,
                    &scratch_allocator, AlgorithmConfig(profile_algorithm),
                    &profile_result)
                .ok();
        if (cudnn_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
            if (scratch_allocator.TotalByteSize() == 0 &&
                profile_result.elapsed_time_in_ms() <
                    best_result_no_scratch.elapsed_time_in_ms()) {
              best_result_no_scratch = profile_result;
            }
          }
        }
      }
      OP_REQUIRES(context,
                  best_result.is_valid() || best_result_no_scratch.is_valid(),
                  errors::NotFound("No algorithm worked!"));
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      if (best_result_no_scratch.is_valid()) {
        algorithm_config.set_algorithm_no_scratch(
            best_result_no_scratch.algorithm());
      }
      AutoTuneConv3dBwdFilter::GetInstance()->Insert(conv_parameters,
                                                     algorithm_config);
    }
    CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                            context);
    bool cudnn_launch_status =
        stream
            ->ThenConvolveBackwardFilterWithAlgorithm(
                input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
                filter_desc, &filter_backprop_ptr, &scratch_allocator,
                algorithm_config, nullptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(errors::Internal(
          "cuDNN Backward Filter function launch failure : input shape(",
          input_shape.DebugString(), ") filter shape(",
          filter_shape.DebugString(), ")"));
    }

    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::ReverseTransformFilter<GPUDevice, T, 5>()(
        context->eigen_device<GPUDevice>(),
        toConstTensor(pre_transformed_filter_backprop).template tensor<T, 5>(),
        filter_backprop->tensor<T, 5>());
  }

 private:
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
  bool takes_shape_;
  bool cudnn_use_autotune_;
};

#define REGISTER_GPU_KERNEL(T)                                                \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"),  \
      Conv3DBackpropInputOp<GPUDevice, T>);                                   \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")                       \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<T>("T")                         \
                              .HostMemory("input_sizes"),                     \
                          Conv3DBackpropInputOp<GPUDevice, T>);               \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      Conv3DBackpropFilterOp<GPUDevice, T>);                                  \
  REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")                      \
                              .Device(DEVICE_GPU)                             \
                              .TypeConstraint<T>("T")                         \
                              .HostMemory("filter_sizes"),                    \
                          Conv3DBackpropFilterOp<GPUDevice, T>);
TF_CALL_half(REGISTER_GPU_KERNEL);
TF_CALL_float(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif  // GOOGLE_CUDA

}  // namespace tensorflow