/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/conv_grad_ops.h"

#include <algorithm>
#include <vector>

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#ifdef TENSORFLOW_USE_LIBXSMM
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace {

// Returns in 'im_data' (assumed to be zero-initialized) an image patch in
// storage order (height, width, depth), constructed from patches in
// 'col_data', which is required to be in storage order (out_height *
// out_width, filter_height, filter_width, in_depth).
// Implementation by Yangqing Jia (jiayq).
template <typename T>
void Col2im(const T* col_data, const int depth, const int height,
            const int width, const int filter_h, const int filter_w,
            const int pad_t, const int pad_l, const int pad_b, const int pad_r,
            const int stride_h, const int stride_w, T* im_data) {
  int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
  int h_pad = -pad_t;
  for (int h = 0; h < height_col; ++h) {
    int w_pad = -pad_l;
    for (int w = 0; w < width_col; ++w) {
      T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
      for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
        for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
          if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
            // TODO(andydavis) Vectorize this loop (if compiler does not).
            for (int i = 0; i < depth; ++i) {
              im_patch_data[i] += col_data[i];
            }
          }
          im_patch_data += depth;
          col_data += depth;
        }
        // Jump over the remaining depth values of this image row.
        im_patch_data += depth * (width - filter_w);
      }
      w_pad += stride_w;
    }
    h_pad += stride_h;
  }
}

}  // namespace

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// The fast versions using Eigen computations directly. They are only enabled
// for CPU for now since nvcc times out when trying to compile them.
// TODO(yangke): enable them for GPUs when we have a faster compiler.

template <typename T>
struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& out_backprop, const Tensor& filter,
                  int row_dilation, int col_dilation, int row_stride,
                  int col_stride, const Padding& padding, Tensor* in_backprop,
                  TensorFormat data_format) {
    const CPUDevice& d = ctx->eigen_device<CPUDevice>();
    functor::SpatialConvolutionBackwardInput<CPUDevice, T>()(
        d, in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
        out_backprop.tensor<T, 4>(), row_stride, col_stride,
        /*row_dilation=*/1, /*col_dilation=*/1);
  }
};

#ifdef TENSORFLOW_USE_LIBXSMM
template <typename Device, class T>
struct LaunchXsmmBackwardInputConvolution {
  bool operator()(OpKernelContext* context, const Device& d,
                  typename TTypes<T, 4>::Tensor input_backward,
                  typename TTypes<T, 4>::ConstTensor kernel,
                  typename TTypes<T, 4>::ConstTensor output_backward,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, int pad_h, int pad_w,
                  TensorFormat data_format) const {
    return false;
  }
};

template <>
struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
  bool operator()(OpKernelContext* context, const CPUDevice& d,
                  typename TTypes<float, 4>::Tensor input_backward,
                  typename TTypes<float, 4>::ConstTensor kernel,
                  typename TTypes<float, 4>::ConstTensor output_backward,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, int pad_h, int pad_w,
                  TensorFormat data_format) const {
    auto batch = input_backward.dimension(0);
    auto in_depth = input_backward.dimension(3);
    auto out_depth = output_backward.dimension(3);
    auto filter_rows = kernel.dimension(0);
    auto filter_cols = kernel.dimension(1);
    auto num_threads =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    // See libxsmm_dnn.h for this struct definition.
    libxsmm_dnn_conv_desc desc;
    desc.N = batch;
    desc.C = in_depth;
    desc.H = input_rows;
    desc.W = input_cols;
    desc.K = out_depth;
    desc.R = filter_rows;
    desc.S = filter_cols;
    desc.u = row_stride;
    desc.v = col_stride;
    desc.pad_h = pad_h;
    desc.pad_w = pad_w;
    desc.pad_h_in = 0;
    desc.pad_w_in = 0;
    desc.pad_h_out = 0;
    desc.pad_w_out = 0;
    desc.threads = num_threads;
    desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
    desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
    desc.filter_format =
        LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;  // LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
    desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
    desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
    desc.datatype = LIBXSMM_DNN_DATATYPE_F32;

    auto input_ptr = input_backward.data();
    auto filter_ptr = kernel.data();
    auto output_ptr = output_backward.data();

    bool success = functor::XsmmBkwInputConv2D<CPUDevice, float>()(
        context, desc, input_ptr, filter_ptr, output_ptr);
    return success;
  }
};
#endif

template <typename Device, class T>
class Conv2DFastBackpropInputOp : public OpKernel {
 public:
  explicit Conv2DFastBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Eigen Conv2DFastBackpropInputOp only supports NHWC."));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    // TODO(yangzihao): Add a CPU implementation for dilated convolution.
    OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
                errors::InvalidArgument(
                    "Current Eigen and libxsmm implementations do not "
                    "yet support dilation rates larger than 1."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_sizes = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(input_sizes.shape()),
        errors::InvalidArgument(
            "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
            input_sizes.dims()));
    TensorShape input_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_sizes.vec<int32>(), &input_shape));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensions(
                       "Conv2DFastBackpropInput", /*num_spatial_dims=*/2,
                       input_shape, filter.shape(), out_backprop.shape(),
                       strides_, padding_, data_format_, &dims));

    Tensor* in_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    // If there is nothing to compute, return.
    if (input_shape.num_elements() == 0) {
      return;
    }

#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
    int64 pad_top, pad_bottom;
    int64 pad_left, pad_right;
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));

    if (pad_left == pad_right && pad_top == pad_bottom) {
      if (LaunchXsmmBackwardInputConvolution<Device, T>()(
              context, context->eigen_device<Device>(),
              in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
              out_backprop.tensor<T, 4>(), dims.spatial_dims[0].input_size,
              dims.spatial_dims[1].input_size,
              static_cast<int>(dims.spatial_dims[0].stride),
              static_cast<int>(dims.spatial_dims[1].stride),
              static_cast<int>(pad_top), static_cast<int>(pad_left),
              data_format_)) {
        return;
      }
    }
#endif

    LaunchConv2DBackpropInputOp<Device, T>()(
        context, false, false, out_backprop, filter, /*row_dilation=*/1,
        /*col_dilation=*/1, dims.spatial_dims[0].stride,
        dims.spatial_dims[1].stride, padding_, in_backprop, data_format_);
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DFastBackpropInputOp);
};

// Based on implementation written by Yangqing Jia (jiayq).
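// The op below computes the input gradient without Eigen's spatial
// convolution: for each image, the output gradient, viewed as an
// (out_rows * out_cols) x out_depth matrix, is multiplied by the transposed
// filter, viewed as a (filter_rows * filter_cols * in_depth) x out_depth
// matrix, and the resulting per-patch gradients are scattered back into the
// input gradient with Col2im above.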
template <typename Device, class T>
class Conv2DCustomBackpropInputOp : public OpKernel {
 public:
  explicit Conv2DCustomBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Conv2DCustomBackpropInputOp only supports NHWC."));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    // TODO(yangzihao): Add a CPU implementation for dilated convolution.
    OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
                errors::InvalidArgument(
                    "Current libxsmm and customized CPU implementations do "
                    "not yet support dilation rates larger than 1."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_sizes = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(input_sizes.shape()),
        errors::InvalidArgument(
            "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
            input_sizes.dims()));
    TensorShape input_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_sizes.vec<int32>(), &input_shape));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensions(
                       "Conv2DCustomBackpropInput", /*num_spatial_dims=*/2,
                       input_shape, filter.shape(), out_backprop.shape(),
                       strides_, padding_, data_format_, &dims));

    Tensor* in_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    // If there is nothing to compute, return.
    if (input_shape.num_elements() == 0) {
      return;
    }

    // TODO(andydavis) Consider moving code shared with
    // Conv2DCustomBackpropFilterOp into a shared helper function.
#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
    int64 pad_top, pad_bottom;
    int64 pad_left, pad_right;
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));

    if (pad_left == pad_right && pad_top == pad_bottom) {
      if (LaunchXsmmBackwardInputConvolution<Device, T>()(
              context, context->eigen_device<Device>(),
              in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
              out_backprop.tensor<T, 4>(), dims.spatial_dims[0].input_size,
              dims.spatial_dims[1].input_size,
              static_cast<int>(dims.spatial_dims[0].stride),
              static_cast<int>(dims.spatial_dims[1].stride),
              static_cast<int>(pad_top), static_cast<int>(pad_left),
              data_format_)) {
        return;
      }
    }
#else
    int64 pad_top, pad_bottom;
    int64 pad_left, pad_right;
#endif
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));

    // The total dimension size of each kernel.
    const int filter_total_size = dims.spatial_dims[0].filter_size *
                                  dims.spatial_dims[1].filter_size *
                                  dims.in_depth;
    // The output image size is the spatial size of the output.
    const int output_image_size =
        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;

    // TODO(andydavis) Get L2/L3 cache sizes from device.
    const size_t l2_cache_size = 256LL << 10;
    const size_t l3_cache_size = 30LL << 20;

    // Use L3 cache size as target working set size.
    const size_t target_working_set_size = l3_cache_size / sizeof(T);

    // Calculate size of matrices involved in MatMul: C = A x B.
    const size_t size_A = output_image_size * dims.out_depth;

    const size_t size_B = filter_total_size * dims.out_depth;

    const size_t size_C = output_image_size * filter_total_size;

    const size_t work_unit_size = size_A + size_B + size_C;

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    // Calculate per-thread work unit size.
    const size_t thread_work_unit_size =
        work_unit_size / worker_threads.num_threads;

    // Set minimum per-thread work unit size to size of L2 cache.
    const size_t min_thread_work_unit_size = l2_cache_size / sizeof(T);

    // Use parallel tensor contractions if there is no batching, or if the
    // minimum per-thread work unit size threshold has been exceeded.
    // Otherwise, revert to multiple single-threaded matmul ops running in
    // parallel to keep all threads busy.
    // TODO(andydavis) Explore alternatives to branching the code in this way
    // (i.e. run multiple, parallel tensor contractions in another thread pool).
    const bool use_parallel_contraction =
        dims.batch_size == 1 ||
        thread_work_unit_size >= min_thread_work_unit_size;

    const size_t shard_size =
        use_parallel_contraction
            ? 1
            : (target_working_set_size + work_unit_size - 1) / work_unit_size;

    Tensor col_buffer;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(
                       DataTypeToEnum<T>::value,
                       TensorShape({static_cast<int64>(shard_size),
                                    static_cast<int64>(output_image_size),
                                    static_cast<int64>(filter_total_size)}),
                       &col_buffer));

    // The input offset corresponding to a single input image.
    const int input_offset = dims.spatial_dims[0].input_size *
                             dims.spatial_dims[1].input_size * dims.in_depth;
    // The output offset corresponding to a single output image.
    const int output_offset = dims.spatial_dims[0].output_size *
                              dims.spatial_dims[1].output_size * dims.out_depth;

    const T* filter_data = filter.template flat<T>().data();
    T* col_buffer_data = col_buffer.template flat<T>().data();
    const T* out_backprop_data = out_backprop.template flat<T>().data();

    auto in_backprop_flat = in_backprop->template flat<T>();
    T* input_backprop_data = in_backprop_flat.data();
    in_backprop_flat.device(context->eigen_device<Device>()) =
        in_backprop_flat.constant(T(0));

    if (use_parallel_contraction) {
      typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          TensorMap;
      typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          ConstTensorMap;

      // Initialize contraction dims (we need to transpose 'B' below).
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
      contract_dims[0].first = 1;
      contract_dims[0].second = 1;

      for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
        // Compute gradient into col_buffer.
        TensorMap C(col_buffer_data, output_image_size, filter_total_size);

        ConstTensorMap A(out_backprop_data + output_offset * image_id,
                         output_image_size, dims.out_depth);
        ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);

        C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);

        Col2im<T>(
            col_buffer_data, dims.in_depth, dims.spatial_dims[0].input_size,
            dims.spatial_dims[1].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[1].filter_size, pad_top, pad_left, pad_bottom,
            pad_right, dims.spatial_dims[0].stride, dims.spatial_dims[1].stride,
            input_backprop_data);

        input_backprop_data += input_offset;
      }
    } else {
      typedef Eigen::Map<
          Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
          MatrixMap;
      typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
                                             Eigen::RowMajor>>
          ConstMatrixMap;

      for (int image_id = 0; image_id < dims.batch_size;
           image_id += shard_size) {
        const int shard_limit =
            std::min(static_cast<int>(shard_size),
                     static_cast<int>(dims.batch_size) - image_id);

        auto shard = [&dims, &pad_top, &pad_left, &pad_bottom, &pad_right,
                      &output_image_size, &filter_total_size,
                      &input_backprop_data, &col_buffer_data,
                      &out_backprop_data, &filter_data, &input_offset,
                      &output_offset, &size_C](int64 start, int64 limit) {
          for (int shard_id = start; shard_id < limit; ++shard_id) {
            T* im2col_buf = col_buffer_data + shard_id * size_C;
            T* input_data = input_backprop_data + shard_id * input_offset;
            const T* out_data = out_backprop_data + shard_id * output_offset;

            // Compute gradient into 'im2col_buf'.
            MatrixMap C(im2col_buf, output_image_size, filter_total_size);

            ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
            ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);

            C.noalias() = A * B.transpose();

            Col2im<T>(im2col_buf, dims.in_depth,
                      dims.spatial_dims[0].input_size,
                      dims.spatial_dims[1].input_size,
                      dims.spatial_dims[0].filter_size,
                      dims.spatial_dims[1].filter_size, pad_top, pad_left,
                      pad_bottom, pad_right, dims.spatial_dims[0].stride,
                      dims.spatial_dims[1].stride, input_data);
          }
        };
        Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
              work_unit_size, shard);

        input_backprop_data += input_offset * shard_limit;
        out_backprop_data += output_offset * shard_limit;
      }
    }
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropInputOp);
};

#define REGISTER_CPU_KERNELS(T)                                              \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("Conv2DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv2DCustomBackpropInputOp<CPUDevice, T>);                            \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")                        \
                              .Device(DEVICE_CPU)                            \
                              .Label("custom")                               \
                              .TypeConstraint<T>("T"),                       \
                          Conv2DCustomBackpropInputOp<CPUDevice, T>);        \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")                        \
                              .Device(DEVICE_CPU)                            \
                              .Label("eigen_tensor")                         \
                              .TypeConstraint<T>("T"),                       \
                          Conv2DFastBackpropInputOp<CPUDevice, T>);

TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
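
// The unlabeled registration above makes Conv2DCustomBackpropInputOp the
// default CPU kernel; the "custom" and "eigen_tensor" labels let a graph
// request a specific implementation through its kernel label map.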

// GPU definitions.
#if GOOGLE_CUDA
// The slow version (but compiles for GPU).

// A dummy type to group backward-data convolution autotune results together.
struct ConvBackwardDataAutoTuneGroup {
  static string name() { return "ConvBwdData"; }
};
typedef AutoTuneSingleton<ConvBackwardDataAutoTuneGroup, ConvParameters,
                          perftools::gputools::dnn::AlgorithmConfig>
    AutoTuneConvBwdData;

// Backprop for input.
template <typename Device, class T>
class Conv2DSlowBackpropInputOp : public OpKernel {
 public:
  explicit Conv2DSlowBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    int stride_h = GetTensorDim(strides_, data_format_, 'H');
    int stride_w = GetTensorDim(strides_, data_format_, 'W');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
    int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
    int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
    int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
    OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    OP_REQUIRES(
        context, dilation_h > 0 && dilation_w > 0,
        errors::InvalidArgument("Dilation rates should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    use_cudnn_ &= CanUseCudnn();
    cudnn_use_autotune_ = CudnnUseAutotune();
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_sizes = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(input_sizes.shape()),
        errors::InvalidArgument(
            "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
            input_sizes.dims()));
    TensorShape input_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                input_sizes.vec<int32>(), &input_shape));

    Tensor* in_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    // If there is nothing to compute, return.
    if (input_shape.num_elements() == 0) {
      return;
    }

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
    const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
    const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');

    launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter,
              dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
              in_backprop, data_format_);
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  bool use_cudnn_;
  TensorFormat data_format_;
  LaunchConv2DBackpropInputOp<Device, T> launcher_;
  bool cudnn_use_autotune_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropInputOp);
};

template <typename T>
void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
    OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    const Tensor& out_backprop, const Tensor& filter, int row_dilation,
    int col_dilation, int row_stride, int col_stride, const Padding& padding,
    Tensor* in_backprop, TensorFormat data_format) {
  using perftools::gputools::dnn::AlgorithmConfig;
  using perftools::gputools::dnn::AlgorithmDesc;
  using perftools::gputools::dnn::ProfileResult;

  std::vector<int32> strides(4, 1);
  std::vector<int32> dilations(4, 1);
  auto input_h = GetTensorDimIndex(data_format, 'H');
  auto input_w = GetTensorDimIndex(data_format, 'W');
  strides[input_h] = row_stride;
  strides[input_w] = col_stride;
  dilations[input_h] = row_dilation;
  dilations[input_w] = col_dilation;
  TensorShape input_shape = in_backprop->shape();

  const TensorShape& filter_shape = filter.shape();
  ConvBackpropDimensions dims;
  OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensionsV2(
                          "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2,
                          input_shape, filter_shape, out_backprop.shape(),
                          dilations, strides, padding, data_format, &dims));

  // TODO(yangzihao): The padding computations should be done in
  // GetWindowedOutputSize() functions.
  const int padding_rows =
      (padding == VALID)
          ? 0
          : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
                                     dims.spatial_dims[0].stride +
                                 (dims.spatial_dims[0].filter_size - 1) *
                                     dims.spatial_dims[0].dilation +
                                 1 - dims.spatial_dims[0].input_size);
  const int padding_cols =
      (padding == VALID)
          ? 0
          : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
                                     dims.spatial_dims[1].stride +
                                 (dims.spatial_dims[1].filter_size - 1) *
                                     dims.spatial_dims[1].dilation +
                                 1 - dims.spatial_dims[1].input_size);

  // TODO(keveman): cuDNN only supports equal padding on both sides, so only
  // calling it when that is true. Remove this check when (if?) cuDNN starts
  // supporting different padding.
  bool rows_odd = (padding_rows % 2 != 0);
  bool cols_odd = (padding_cols % 2 != 0);

  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  if (!use_cudnn) {
    ctx->SetStatus(errors::Unimplemented(
        "Conv2DBackpropInput for GPU is not currently supported "
        "without cudnn"));
    return;
  }

  if (dims.spatial_dims[0].filter_size == 1 &&
      dims.spatial_dims[1].filter_size == 1 &&
      dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
      data_format == FORMAT_NHWC) {
    // 1x1 filter, so call cublas directly.
    const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size;
    const uint64 k = dims.out_depth;
    const uint64 n = dims.in_depth;

    auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                in_backprop->template flat<T>().size());

    auto transpose = perftools::gputools::blas::Transpose::kTranspose;
    auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;

    bool blas_launch_status =
        stream
            ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                           a_ptr, k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  } else if (dims.spatial_dims[0].filter_size ==
                 dims.spatial_dims[0].input_size &&
             dims.spatial_dims[1].filter_size ==
                 dims.spatial_dims[1].input_size &&
             padding == VALID && data_format == FORMAT_NHWC) {
    // The input data and filter have the same height/width, so call cublas
    // directly.
    const uint64 m = dims.batch_size;
    const uint64 k = dims.out_depth;
    const uint64 n = dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size * dims.in_depth;

    auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                in_backprop->template flat<T>().size());

    auto transpose = perftools::gputools::blas::Transpose::kTranspose;
    auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;

    bool blas_launch_status =
        stream
            ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k,
                           a_ptr, k, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  }

  TensorShape compatible_input_shape;
  if (rows_odd || cols_odd) {
    // If a padding dimension is odd, we have one more element on the right
    // side or the bottom side. This is unsupported in cudnn. Therefore,
    // we pad that extra element and make it compatible.
    compatible_input_shape = ShapeFromFormat(
        data_format, dims.batch_size,
        dims.spatial_dims[0].input_size + rows_odd,
        dims.spatial_dims[1].input_size + cols_odd, dims.in_depth);
  } else {
    compatible_input_shape = input_shape;
  }

  CHECK(padding_rows >= 0 && padding_cols >= 0)
      << "Negative row or col paddings: (" << padding_rows << ", "
      << padding_cols << ")";
  perftools::gputools::dnn::BatchDescriptor input_desc;
  input_desc.set_count(dims.batch_size)
      .set_height(GetTensorDim(compatible_input_shape, data_format, 'H'))
      .set_width(GetTensorDim(compatible_input_shape, data_format, 'W'))
      .set_feature_map_count(dims.in_depth)
      .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
  perftools::gputools::dnn::BatchDescriptor output_desc;
  output_desc.set_count(dims.batch_size)
      .set_height(dims.spatial_dims[0].output_size)
      .set_width(dims.spatial_dims[1].output_size)
      .set_feature_map_count(dims.out_depth)
      .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
  perftools::gputools::dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
      .set_input_filter_width(dims.spatial_dims[1].filter_size)
      .set_input_feature_map_count(dims.in_depth)
      .set_output_feature_map_count(dims.out_depth);
  perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
  conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
      .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
      .set_vertical_filter_stride(dims.spatial_dims[0].stride)
      .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
      .set_zero_padding_height(padding_rows / 2)
      .set_zero_padding_width(padding_cols / 2);

  // NOTE(keveman):
  // cuDNN only supports the following layouts :
  // Input  : B x D x R x C
  // Filter : OD x ID x R x C
  // Whereas, we have
  // Input  : B x R x C x D
  // Filter : R x C x ID x OD
  // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C)
  // The first TransformDepth performs
  // (B x R x C x D) => (B x D x R x C).
  // Since the tensor returned from cuDNN is B x D x R x C also,
  // the second TransformDepth performs
  // (B x D x R x C) => (B x R x C x D).
  Tensor transformed_filter;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                              TensorShape({dims.out_depth, dims.in_depth,
                                           dims.spatial_dims[0].filter_size,
                                           dims.spatial_dims[1].filter_size}),
                              &transformed_filter));

  functor::TransformFilter<GPUDevice, T, int, 4>()(
      ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
      To32Bit(transformed_filter.tensor<T, 4>()));

  Tensor transformed_out_backprop;
  if (data_format == FORMAT_NHWC) {
    TensorShape nchw_shape = ShapeFromFormat(
        FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size,
        dims.spatial_dims[1].output_size, dims.out_depth);
    if (dims.out_depth > 1) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                        &transformed_out_backprop));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
          transformed_out_backprop.tensor<T, 4>());
    } else {
      // If depth <= 1, then just reshape.
      CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
    }
  } else {
    transformed_out_backprop = out_backprop;
  }

  Tensor pre_transformed_in_backprop;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(
               DataTypeToEnum<T>::value,
               ShapeFromFormat(
                   FORMAT_NCHW,
                   GetTensorDim(compatible_input_shape, data_format, 'N'),
                   GetTensorDim(compatible_input_shape, data_format, 'H'),
                   GetTensorDim(compatible_input_shape, data_format, 'W'),
                   GetTensorDim(compatible_input_shape, data_format, 'C')),
               &pre_transformed_in_backprop));

  auto out_backprop_ptr =
      AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                     transformed_out_backprop.template flat<T>().size());
  auto filter_ptr =
      AsDeviceMemory(transformed_filter.template flat<T>().data(),
                     transformed_filter.template flat<T>().size());
  auto in_backprop_ptr =
      AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
                     pre_transformed_in_backprop.template flat<T>().size());

  static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB by default
  );
  CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx);
  int device_id = stream->parent()->device_ordinal();
  DataType dtype = out_backprop.dtype();
  ConvParameters conv_parameters = {
      dims.batch_size,                       // batch
      dims.in_depth,                         // in_depths
      {{input_desc.height(),                 // in_rows
        input_desc.width()}},                // in_cols
      dims.out_depth,                        // out_depths
      {{dims.spatial_dims[0].filter_size,    // filter_rows
        dims.spatial_dims[1].filter_size}},  // filter_cols
      {{dims.spatial_dims[0].dilation,       // dilation_rows
        dims.spatial_dims[1].dilation}},     // dilation_cols
      {{dims.spatial_dims[0].stride,         // stride_rows
        dims.spatial_dims[1].stride}},       // stride_cols
      {{padding_rows,                        // padding_rows
        padding_cols}},                      // padding_cols
      dtype,                                 // tensor data type
      device_id,                             // device_id
  };
  AlgorithmConfig algorithm_config;
  if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
                                conv_parameters, &algorithm_config)) {
    std::vector<AlgorithmDesc> algorithms;
    CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
        conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
    ProfileResult best_result;
    ProfileResult best_result_no_scratch;
    for (auto profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times for better
      // accuracy.
      CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
                                              ctx);
      ProfileResult profile_result;
      bool cudnn_launch_status =
          stream
              ->ThenConvolveBackwardDataWithAlgorithm(
                  filter_desc, filter_ptr, output_desc, out_backprop_ptr,
                  conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
                  AlgorithmConfig(profile_algorithm), &profile_result)
              .ok();
      if (cudnn_launch_status) {
        if (profile_result.is_valid()) {
          if (profile_result.elapsed_time_in_ms() <
              best_result.elapsed_time_in_ms()) {
            best_result = profile_result;
          }
          if (scratch_allocator.TotalByteSize() == 0 &&
              profile_result.elapsed_time_in_ms() <
                  best_result_no_scratch.elapsed_time_in_ms()) {
            best_result_no_scratch = profile_result;
          }
        }
      }
    }
    OP_REQUIRES(ctx,
                best_result.is_valid() || best_result_no_scratch.is_valid(),
                errors::NotFound("No algorithm worked!"));
    if (best_result.is_valid()) {
      algorithm_config.set_algorithm(best_result.algorithm());
    }
    if (best_result_no_scratch.is_valid()) {
      algorithm_config.set_algorithm_no_scratch(
          best_result_no_scratch.algorithm());
    }
    AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
                                               algorithm_config);
  }
  bool cudnn_launch_status =
      stream
          ->ThenConvolveBackwardDataWithAlgorithm(
              filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc,
              input_desc, &in_backprop_ptr, &scratch_allocator,
              algorithm_config, nullptr)
          .ok();

  if (!cudnn_launch_status) {
    ctx->SetStatus(errors::Internal(
        "cuDNN Backward Data function launch failure : input shape(",
        input_shape.DebugString(), ") filter shape(",
        filter_shape.DebugString(), ")"));
    return;
  }

  if (rows_odd || cols_odd) {
    Tensor in_backprop_remove_padding;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 ShapeFromFormat(FORMAT_NCHW,
                                 GetTensorDim(input_shape, data_format, 'N'),
                                 GetTensorDim(input_shape, data_format, 'H'),
                                 GetTensorDim(input_shape, data_format, 'W'),
                                 GetTensorDim(input_shape, data_format, 'C')),
                 &in_backprop_remove_padding));

    // Remove the padding for odd rows or cols.
    functor::PadInput<GPUDevice, T, int, 4>()(
        ctx->template eigen_device<GPUDevice>(),
        To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
                    .tensor<T, 4>()),
        {{0, 0}}, {{-rows_odd, -cols_odd}},
        To32Bit(in_backprop_remove_padding.tensor<T, 4>()), FORMAT_NCHW);

    pre_transformed_in_backprop = in_backprop_remove_padding;
  }

  if (data_format == FORMAT_NHWC) {
    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        ctx->eigen_device<GPUDevice>(),
        toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
        in_backprop->tensor<T, 4>());
  } else {
    *in_backprop = pre_transformed_in_backprop;
  }
}

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                               \
  template <>                                                             \
  void ShuffleAndReverse<GPUDevice, T, 4, int>::operator()(               \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input,  \
      const Eigen::DSizes<int, 4>& order,                                 \
      const Eigen::array<bool, 4>& reverse_dims,                          \
      typename TTypes<T, 4, int>::Tensor output);                         \
  extern template struct ShuffleAndReverse<GPUDevice, T, 4, int>;         \
  template <>                                                             \
  void InflatePadAndShuffle<GPUDevice, T, 4, int>::operator()(            \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input,  \
      const Eigen::DSizes<int, 4>& strides,                               \
      const Eigen::array<Eigen::IndexPair<int>, 4>& pad_dims,             \
      const Eigen::DSizes<int, 4>& order,                                 \
      typename TTypes<T, 4, int>::Tensor output);                         \
  extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>;      \
  template <>                                                             \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(                 \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,     \
      typename TTypes<T, 4, int>::Tensor out);                            \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;           \
  template <>                                                             \
  void TransformDepth<GPUDevice, T, int>::operator()(                     \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,     \
      const Eigen::DSizes<int, 4>& shuffle,                               \
      typename TTypes<T, 4, int>::Tensor out);                            \
  extern template struct TransformDepth<GPUDevice, T, int>;               \
  template <>                                                             \
  void PadInput<GPUDevice, T, int, 4>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,     \
      const std::array<int, 2>& padding_left,                             \
      const std::array<int, 2>& padding_right,                            \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format);  \
  extern template struct PadInput<GPUDevice, T, int, 4>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
#undef DECLARE_GPU_SPEC
}  // namespace functor

REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .HostMemory("input_sizes"),
                        Conv2DSlowBackpropInputOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .HostMemory("input_sizes"),
                        Conv2DSlowBackpropInputOp<GPUDevice, Eigen::half>);
#endif  // GOOGLE_CUDA

}  // namespace tensorflow