/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/conv_grad_ops.h"

#include <algorithm>
#include <vector>

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/fill_functor.h"
#ifdef TENSORFLOW_USE_LIBXSMM
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace {

// Returns in 'col_data', image patches in storage order (height, width, depth)
// extracted from image at 'input_data', which is required to be in storage
// order (batch, height, width, depth).
// Implementation written by Yangqing Jia (jiayq).
template <typename T>
void Im2col(const T* input_data, const int depth, const int height,
            const int width, const int filter_h, const int filter_w,
            const int pad_t, const int pad_l, const int pad_b, const int pad_r,
            const int stride_h, const int stride_w, T* col_data) {
  int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;

  int h_pad = -pad_t;
  for (int h = 0; h < height_col; ++h) {
    int w_pad = -pad_l;
    for (int w = 0; w < width_col; ++w) {
      for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
        for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
          if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
            memcpy(col_data, input_data + (ih * width + iw) * depth,
                   sizeof(T) * depth);
          } else {
            // This should be simply padded with zero.
            memset(col_data, 0, sizeof(T) * depth);
          }
          col_data += depth;
        }
      }
      w_pad += stride_w;
    }
    h_pad += stride_h;
  }
}

}  // namespace

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename T>
struct LaunchConv2DBackpropFilterOp<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, bool use_cudnn,
                  bool cudnn_use_autotune, const Tensor& out_backprop,
                  const Tensor& input, int row_stride, int col_stride,
                  const Padding& padding, Tensor* filter_backprop,
                  TensorFormat data_format) {
    const CPUDevice& d = ctx->eigen_device<CPUDevice>();
    functor::SpatialConvolutionBackwardFilter<CPUDevice, T>()(
        d, filter_backprop->tensor<T, 4>(), input.tensor<T, 4>(),
        out_backprop.tensor<T, 4>(), row_stride, col_stride,
        /*row_dilation=*/1, /*col_dilation=*/1);
  }
};

#ifdef TENSORFLOW_USE_LIBXSMM
template <typename Device, class T>
struct LaunchXsmmBackwardFilter {
  bool operator()(OpKernelContext* context, const Device& d,
                  typename TTypes<T, 4>::ConstTensor input_backward,
                  typename TTypes<T, 4>::Tensor kernel,
                  typename TTypes<T, 4>::ConstTensor output_backward,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, int pad_h, int pad_w,
                  TensorFormat data_format) const {
    return false;
  }
};

template <>
struct LaunchXsmmBackwardFilter<CPUDevice, float> {
  bool operator()(OpKernelContext* context, const CPUDevice& d,
                  typename TTypes<float, 4>::ConstTensor input,
                  typename TTypes<float, 4>::Tensor filter,
                  typename TTypes<float, 4>::ConstTensor output,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, int pad_h, int pad_w,
                  TensorFormat data_format) const {
    auto batch = input.dimension(0);
    auto in_depth = input.dimension(3);
    auto out_depth = output.dimension(3);
    auto filter_rows = filter.dimension(0);
    auto filter_cols = filter.dimension(1);

    auto num_threads =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    // See libxsmm_dnn.h for this struct definition.
    libxsmm_dnn_conv_desc desc;
    desc.N = batch;
    desc.C = in_depth;
    desc.H = input_rows;
    desc.W = input_cols;
    desc.K = out_depth;
    desc.R = filter_rows;
    desc.S = filter_cols;
    desc.u = row_stride;
    desc.v = col_stride;
    desc.pad_h = pad_h;
    desc.pad_w = pad_w;
    desc.pad_h_in = 0;  // pad_rows;  // ignored by libxsmm for now.
    desc.pad_w_in = 0;  // pad_cols;  // ignored by libxsmm for now.
    desc.pad_h_out = 0;
    desc.pad_w_out = 0;
    desc.threads = num_threads;
    desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
    desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
    desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
    desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
    desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
    desc.datatype = LIBXSMM_DNN_DATATYPE_F32;

    if (!CanUseXsmmConv2D(desc, data_format)) {
      return false;
    }

    auto input_ptr = input.data();
    auto filter_ptr = filter.data();
    auto output_ptr = output.data();
    bool success = functor::XsmmBkwFilterConv2D<CPUDevice, float>()(
        context, desc, input_ptr, filter_ptr, output_ptr);
    return success;
  }
};
#endif

template <typename Device, class T>
class Conv2DFastBackpropFilterOp : public OpKernel {
 public:
  explicit Conv2DFastBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Conv2DFastBackpropFilterOp only supports NHWC."));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    // TODO(yangzihao): Add a CPU implementation for dilated convolution.
    OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
                errors::InvalidArgument(
                    "Current Eigen and libxsmm implementations do not "
                    "yet support dilation rates larger than 1."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& filter_sizes = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(filter_sizes.shape()),
        errors::InvalidArgument(
            "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
            filter_sizes.dims()));
    TensorShape filter_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                filter_sizes.vec<int32>(), &filter_shape));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(
        context,
        ConvBackpropComputeDimensions(
            type_string(), /*num_spatial_dims=*/2, input.shape(), filter_shape,
            out_backprop.shape(), strides_, padding_, data_format_, &dims));

    Tensor* filter_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    // If there is nothing to compute, return.
    if (filter_shape.num_elements() == 0) {
      return;
    }

#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
    int64 pad_top, pad_bottom;
    int64 pad_left, pad_right;
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));

    if (pad_left == pad_right && pad_top == pad_bottom) {
      if (LaunchXsmmBackwardFilter<Device, T>()(
              context, context->eigen_device<Device>(), input.tensor<T, 4>(),
              filter_backprop->tensor<T, 4>(), out_backprop.tensor<T, 4>(),
              dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size,
              static_cast<int>(dims.spatial_dims[0].stride),
              static_cast<int>(dims.spatial_dims[1].stride),
              static_cast<int>(pad_top), static_cast<int>(pad_left),
              data_format_)) {
        return;
      }
    }
#endif

    LaunchConv2DBackpropFilterOp<Device, T>()(
        context, false, false, out_backprop, input,
        dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, padding_,
        filter_backprop, data_format_);
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DFastBackpropFilterOp);
};

// Based on implementation written by Yangqing Jia (jiayq).
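// The custom kernel below lowers the filter gradient to im2col + GEMM: each
// image is expanded into a patch matrix A of shape
// [output_image_size, filter_total_size], the matching slice of out_backprop
// is viewed as B of shape [output_image_size, out_depth], and the filter
// gradient accumulates the contraction A^T * B over all images. Images are
// processed in shards sized so that A, B and C together stay within a ~30MB
// working set. As a rough worked example (hypothetical sizes, not taken from
// any particular model): for float data with a 3x3 filter, 64 input channels,
// 128 output channels and 28x28 output maps,
//   filter_total_size  = 3 * 3 * 64      = 576
//   output_image_size  = 28 * 28         = 784
//   size_A             = 784 * 576       = 451584
//   size_B             = 784 * 128       = 100352
//   size_C             = 576 * 128       = 73728
//   work_unit_size     = 625664 elements
//   target_working_set = (30 << 20) / 4  = 7864320 elements
// which gives shard_size = ceil(7864320 / 625664) = 13 images per GEMM.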
template <typename Device, class T>
class Conv2DCustomBackpropFilterOp : public OpKernel {
 public:
  explicit Conv2DCustomBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Conv2DCustomBackpropFilterOp only supports NHWC."));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    // TODO(yangzihao): Add a CPU implementation for dilated convolution.
    OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
                errors::InvalidArgument(
                    "Current libxsmm and customized CPU implementations do "
                    "not yet support dilation rates larger than 1."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& filter_sizes = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(filter_sizes.shape()),
        errors::InvalidArgument(
            "Conv2DCustomBackpropFilter: filter_sizes input must be 1-dim, "
            "not ",
            filter_sizes.dims()));
    TensorShape filter_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                filter_sizes.vec<int32>(), &filter_shape));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensions(
                       "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2,
                       input.shape(), filter_shape, out_backprop.shape(),
                       strides_, padding_, data_format_, &dims));

    Tensor* filter_backprop;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    // If there is nothing to compute, return.
    if (filter_shape.num_elements() == 0) {
      return;
    }

    int64 pad_top, pad_bottom;
    int64 pad_left, pad_right;
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
    if (pad_left == pad_right && pad_top == pad_bottom) {
      if (LaunchXsmmBackwardFilter<Device, T>()(
              context, context->eigen_device<Device>(), input.tensor<T, 4>(),
              filter_backprop->tensor<T, 4>(), out_backprop.tensor<T, 4>(),
              dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size,
              static_cast<int>(dims.spatial_dims[0].stride),
              static_cast<int>(dims.spatial_dims[1].stride),
              static_cast<int>(pad_top), static_cast<int>(pad_left),
              data_format_)) {
        return;
      }
    }
#endif

    // The total dimension size of each kernel.
    const int filter_total_size = dims.spatial_dims[0].filter_size *
                                  dims.spatial_dims[1].filter_size *
                                  dims.in_depth;
    // The output image size is the spatial size of the output.
    const int output_image_size =
        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;

    // Shard 'batch' images into 'shard_size' groups of images to be fed
    // into the parallel matmul. Calculate 'shard_size' by dividing the L3 cache
    // size ('target_working_set_size') by the matmul size of an individual
    // image ('work_unit_size').

    // TODO(andydavis)
    // *) Get L3 cache size from device at runtime (30MB is from ivybridge).
    // *) Consider reducing 'target_working_set_size' if L3 is shared by
    //    other concurrently running tensorflow ops.
    const size_t target_working_set_size = (30LL << 20) / sizeof(T);

    const size_t size_A = output_image_size * filter_total_size;

    const size_t size_B = output_image_size * dims.out_depth;

    const size_t size_C = filter_total_size * dims.out_depth;

    const size_t work_unit_size = size_A + size_B + size_C;

    const size_t shard_size =
        (target_working_set_size + work_unit_size - 1) / work_unit_size;

    Tensor col_buffer;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(
                       DataTypeToEnum<T>::value,
                       TensorShape({static_cast<int64>(shard_size),
                                    static_cast<int64>(output_image_size),
                                    static_cast<int64>(filter_total_size)}),
                       &col_buffer));

    // The input offset corresponding to a single input image.
    const int input_offset = dims.spatial_dims[0].input_size *
                             dims.spatial_dims[1].input_size * dims.in_depth;
    // The output offset corresponding to a single output image.
    const int output_offset = dims.spatial_dims[0].output_size *
                              dims.spatial_dims[1].output_size * dims.out_depth;

    const T* input_data = input.template flat<T>().data();
    T* col_buffer_data = col_buffer.template flat<T>().data();
    const T* out_backprop_data = out_backprop.template flat<T>().data();
    T* filter_backprop_data = filter_backprop->template flat<T>().data();

    typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                             Eigen::Unaligned>
        TensorMap;
    typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                             Eigen::Unaligned>
        ConstTensorMap;

    TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
    C.setZero();

    // Initialize contraction dims (we need to transpose 'A' below).
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
    contract_dims[0].first = 0;
    contract_dims[0].second = 0;

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
      const int shard_limit =
          std::min(static_cast<int>(shard_size),
                   static_cast<int>(dims.batch_size) - image_id);

      auto shard = [&input_data, &col_buffer_data, &dims, &pad_top, &pad_left,
                    &pad_bottom, &pad_right, &input_offset,
                    &size_A](int64 start, int64 limit) {
        for (int shard_id = start; shard_id < limit; ++shard_id) {
          const T* input_data_shard = input_data + shard_id * input_offset;
          T* col_data_shard = col_buffer_data + shard_id * size_A;

          // When we compute the gradient with respect to the filters, we need
          // to do im2col to allow gemm-type computation.
          Im2col<T>(
              input_data_shard, dims.in_depth, dims.spatial_dims[0].input_size,
              dims.spatial_dims[1].input_size, dims.spatial_dims[0].filter_size,
              dims.spatial_dims[1].filter_size, pad_top, pad_left, pad_bottom,
              pad_right, dims.spatial_dims[0].stride,
              dims.spatial_dims[1].stride, col_data_shard);
        }
      };
      Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
            size_A, shard);

      ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
                       filter_total_size);
      ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
                       dims.out_depth);

      // Gradient with respect to filter.
      C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);

      input_data += input_offset * shard_limit;
      out_backprop_data += output_offset * shard_limit;
    }
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropFilterOp);
};

#define REGISTER_CPU_KERNELS(T)                                               \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("Conv2DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv2DCustomBackpropFilterOp<CPUDevice, T>);                            \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")                        \
                              .Device(DEVICE_CPU)                             \
                              .Label("custom")                                \
                              .TypeConstraint<T>("T"),                        \
                          Conv2DCustomBackpropFilterOp<CPUDevice, T>);        \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")                        \
                              .Device(DEVICE_CPU)                             \
                              .Label("eigen_tensor")                          \
                              .TypeConstraint<T>("T"),                        \
                          Conv2DFastBackpropFilterOp<CPUDevice, T>);

TF_CALL_half(REGISTER_CPU_KERNELS);
TF_CALL_float(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

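// Note on the registrations above: the default (unlabeled) CPU kernel for
// Conv2DBackpropFilter is the im2col+GEMM Conv2DCustomBackpropFilterOp; the
// same class is also registered under the "custom" label, while the
// "eigen_tensor" label selects Conv2DFastBackpropFilterOp, which calls
// Eigen's SpatialConvolutionBackwardFilter directly. Kernel labels are
// normally only set by tests or benchmarks that want to pin one particular
// implementation.
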
// GPU definitions.
#if GOOGLE_CUDA
// The slow version (but compiles for GPU)

// A dummy type to group backward filter autotune results together.
struct ConvBackwardFilterAutoTuneGroup {
  static string name() { return "ConvBwdFilter"; }
};
typedef AutoTuneSingleton<ConvBackwardFilterAutoTuneGroup, ConvParameters,
                          perftools::gputools::dnn::AlgorithmConfig>
    AutoTuneConvBwdFilter;

// Backprop for filter.
template <typename Device, class T>
class Conv2DSlowBackpropFilterOp : public OpKernel {
 public:
  explicit Conv2DSlowBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    int stride_h = GetTensorDim(strides_, data_format_, 'H');
    int stride_w = GetTensorDim(strides_, data_format_, 'W');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
    int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
    int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
    int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
    OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    OP_REQUIRES(
        context, dilation_h > 0 && dilation_w > 0,
        errors::InvalidArgument("Dilation rates should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    use_cudnn_ &= CanUseCudnn();
    cudnn_use_autotune_ = CudnnUseAutotune();
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& filter_sizes = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(filter_sizes.shape()),
        errors::InvalidArgument(
            "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
            filter_sizes.dims()));
    TensorShape filter_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                filter_sizes.vec<int32>(), &filter_shape));

    Tensor* filter_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, filter_shape, &filter_backprop));

    // If there is nothing to compute, return.
    if (filter_shape.num_elements() == 0) {
      return;
    }
    // If input is empty, set gradients to zero.
    if (input.shape().num_elements() == 0) {
      functor::SetZeroFunctor<Device, T> f;
      f(context->eigen_device<Device>(), filter_backprop->flat<T>());
      return;
    }

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
    const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
    const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');

    launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
              dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
              filter_backprop, data_format_);
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  bool use_cudnn_;
  TensorFormat data_format_;
  LaunchConv2DBackpropFilterOp<Device, T> launcher_;
  bool cudnn_use_autotune_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropFilterOp);
};

template <typename T>
void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
    OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    const Tensor& out_backprop, const Tensor& input, int row_dilation,
    int col_dilation, int row_stride, int col_stride, const Padding& padding,
    Tensor* filter_backprop, TensorFormat data_format) {
  using perftools::gputools::dnn::AlgorithmConfig;
  using perftools::gputools::dnn::AlgorithmDesc;
  using perftools::gputools::dnn::ProfileResult;

  std::vector<int32> dilations(4, 1);
  dilations[GetTensorDimIndex(data_format, 'H')] = row_dilation;
  dilations[GetTensorDimIndex(data_format, 'W')] = col_dilation;

  std::vector<int32> strides(4, 1);
  strides[GetTensorDimIndex(data_format, 'H')] = row_stride;
  strides[GetTensorDimIndex(data_format, 'W')] = col_stride;
  TensorShape filter_shape = filter_backprop->shape();

  ConvBackpropDimensions dims;
  OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensionsV2(
                          "Conv2DSlowBackpropFilter", /*num_spatial_dims=*/2,
                          input.shape(), filter_shape, out_backprop.shape(),
                          dilations, strides, padding, data_format, &dims));

  // TODO(yangzihao): The padding computations should be done in
  // GetWindowedOutputSize() functions.
  const int padding_rows =
      (padding == VALID)
          ? 0
          : std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
                                     dims.spatial_dims[0].stride +
                                 (dims.spatial_dims[0].filter_size - 1) *
                                     dims.spatial_dims[0].dilation +
                                 1 - dims.spatial_dims[0].input_size);
  const int padding_cols =
      (padding == VALID)
          ? 0
          : std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
                                     dims.spatial_dims[1].stride +
                                 (dims.spatial_dims[1].filter_size - 1) *
                                     dims.spatial_dims[1].dilation +
                                 1 - dims.spatial_dims[1].input_size);

  // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only
  // calling it when that is true. Remove this check when (if?) cuDNN starts
  // supporting different padding.
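  // As an illustration of the odd-padding handling below (hypothetical sizes):
  // with SAME padding, a 3x3 filter and stride 2, an input of 7 rows needs
  // (4 - 1) * 2 + (3 - 1) + 1 - 7 = 2 total padding rows, which splits evenly
  // between top and bottom; an input of 8 rows needs only 1 padding row, so
  // rows_odd is true and the input is first padded with one extra row at the
  // bottom (making it effectively 9 rows) so that the cuDNN call sees equal
  // padding on both sides.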
  bool rows_odd = (padding_rows % 2 != 0);
  bool cols_odd = (padding_cols % 2 != 0);

  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  if (!use_cudnn) {
    ctx->SetStatus(errors::Unimplemented(
        "Conv2DBackprop for GPU is not currently supported "
        "without cudnn"));
    return;
  }

  bool cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization();
  if (!cudnn_disable_conv_1x1_optimization_ &&
      dims.spatial_dims[0].filter_size == 1 &&
      dims.spatial_dims[1].filter_size == 1 &&
      dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
      data_format == FORMAT_NHWC) {
    const uint64 m = dims.in_depth;
    const uint64 k = dims.batch_size * dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size;
    const uint64 n = dims.out_depth;

    // The shape of output backprop is
    //   [batch, out_rows, out_cols, out_depth]
    // From cublas's perspective, it is: n x k
    auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());

    // The shape of input is
    //   [batch, in_rows, in_cols, in_depth],
    // From cublas's perspective, it is: m x k
    auto b_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                input.template flat<T>().size());

    // The shape of the filter backprop from the conv_2d should be
    //   [1, 1, in_depth, out_depth]
    // From cublas's perspective, it is: n x m
    auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                filter_backprop->template flat<T>().size());

    bool blas_launch_status =
        stream
            ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
                           perftools::gputools::blas::Transpose::kTranspose, n,
                           m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  } else if (dims.spatial_dims[0].filter_size ==
                 dims.spatial_dims[0].input_size &&
             dims.spatial_dims[1].filter_size ==
                 dims.spatial_dims[1].input_size &&
             padding == VALID && data_format == FORMAT_NHWC) {
    // The input data and filter have the same height/width, so call cublas
    // directly.
    const uint64 m = dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size * dims.in_depth;
    const uint64 k = dims.batch_size;
    const uint64 n = dims.out_depth;

    auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                input.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(filter_backprop->template flat<T>().data(),
                                filter_backprop->template flat<T>().size());

    bool blas_launch_status =
        stream
            ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
                           perftools::gputools::blas::Transpose::kTranspose, n,
                           m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
            .ok();
    if (!blas_launch_status) {
      ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
                                      ", n=", n, ", k=", k));
    }
    return;
  }

  Tensor compatible_input;
  if (rows_odd || cols_odd) {
    // If a padding dimension is odd, we have one more element on the right
    // side or the bottom side. This is unsupported in cudnn.
    // Therefore, we pad that extra element and make it compatible.
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 ShapeFromFormat(data_format, dims.batch_size,
                                 dims.spatial_dims[0].input_size + rows_odd,
                                 dims.spatial_dims[1].input_size + cols_odd,
                                 dims.in_depth),
                 &compatible_input));

    functor::PadInput<GPUDevice, T, int, 4>()(
        ctx->template eigen_device<GPUDevice>(), To32Bit(input.tensor<T, 4>()),
        {{0, 0}}, {{rows_odd, cols_odd}},
        To32Bit(compatible_input.tensor<T, 4>()), data_format);
  } else {
    compatible_input = input;
  }

  CHECK(padding_rows >= 0 && padding_cols >= 0)
      << "Negative row or col paddings: (" << padding_rows << ", "
      << padding_cols << ")";
  perftools::gputools::dnn::BatchDescriptor input_desc;
  input_desc.set_count(dims.batch_size)
      .set_height(GetTensorDim(compatible_input, data_format, 'H'))
      .set_width(GetTensorDim(compatible_input, data_format, 'W'))
      .set_feature_map_count(dims.in_depth)
      .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
  perftools::gputools::dnn::BatchDescriptor output_desc;
  output_desc.set_count(dims.batch_size)
      .set_height(dims.spatial_dims[0].output_size)
      .set_width(dims.spatial_dims[1].output_size)
      .set_feature_map_count(dims.out_depth)
      .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
  perftools::gputools::dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
      .set_input_filter_width(dims.spatial_dims[1].filter_size)
      .set_input_feature_map_count(dims.in_depth)
      .set_output_feature_map_count(dims.out_depth);
  perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
  conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
      .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
      .set_vertical_filter_stride(dims.spatial_dims[0].stride)
      .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
      .set_zero_padding_height(padding_rows / 2)
      .set_zero_padding_width(padding_cols / 2);

  // NOTE(zhengxq):
  // cuDNN only supports the following layouts:
  //   Input  : B x D x R x C
  //   Filter : OD x ID x R x C
  // Whereas, we have
  //   Input  : B x R x C x D
  //   Filter : R x C x ID x OD
  // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C).
  // The first TransformDepth performs
  //   (B x R x C x D) => (B x D x R x C).
  // Since the tensor returned from cuDNN is B x D x R x C also,
  // the second TransformDepth performs
  //   (B x D x R x C) => (B x R x C x D).

  Tensor pre_transformed_filter_backprop;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                              TensorShape({dims.out_depth, dims.in_depth,
                                           dims.spatial_dims[0].filter_size,
                                           dims.spatial_dims[1].filter_size}),
                              &pre_transformed_filter_backprop));

  Tensor transformed_out_backprop;
  if (data_format == FORMAT_NHWC) {
    TensorShape nchw_shape = ShapeFromFormat(
        FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size,
        dims.spatial_dims[1].output_size, dims.out_depth);
    if (dims.out_depth > 1) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
                                        &transformed_out_backprop));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
          transformed_out_backprop.tensor<T, 4>());
    } else {
      // If depth <= 1, just reshape.
      CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape));
    }
  } else {
    transformed_out_backprop = out_backprop;
  }

  Tensor transformed_input;
  if (data_format == FORMAT_NHWC) {
    TensorShape nchw_shape = ShapeFromFormat(
        FORMAT_NCHW, GetTensorDim(compatible_input, data_format, 'N'),
        GetTensorDim(compatible_input, data_format, 'H'),
        GetTensorDim(compatible_input, data_format, 'W'),
        GetTensorDim(compatible_input, data_format, 'C'));
    if (nchw_shape.dim_size(1) > 1) {
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             nchw_shape, &transformed_input));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
          transformed_input.tensor<T, 4>());
    } else {
      // If depth <= 1, just reshape.
      CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape));
    }
  } else {
    transformed_input = compatible_input;
  }

  auto out_backprop_ptr =
      AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                     transformed_out_backprop.template flat<T>().size());
  auto filter_backprop_ptr =
      AsDeviceMemory(pre_transformed_filter_backprop.template flat<T>().data(),
                     pre_transformed_filter_backprop.template flat<T>().size());
  auto input_ptr = AsDeviceMemory(transformed_input.template flat<T>().data(),
                                  transformed_input.template flat<T>().size());

  static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
      "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB by default
  );
  int device_id = stream->parent()->device_ordinal();
  DataType dtype = input.dtype();
  ConvParameters conv_parameters = {
      dims.batch_size,                       // batch
      dims.in_depth,                         // in_depths
      {{input_desc.height(),                 // in_rows
        input_desc.width()}},                // in_cols
      dims.out_depth,                        // out_depths
      {{dims.spatial_dims[0].filter_size,    // filter_rows
        dims.spatial_dims[1].filter_size}},  // filter_cols
      {{dims.spatial_dims[0].dilation,       // dilation_rows
        dims.spatial_dims[1].dilation}},     // dilation_cols
      {{dims.spatial_dims[0].stride,         // stride_rows
        dims.spatial_dims[1].stride}},       // stride_cols
      {{padding_rows,                        // padding_rows
        padding_cols}},                      // padding_cols
      dtype,                                 // tensor datatype
      device_id,                             // device_id
  };
  AlgorithmConfig algorithm_config;
  if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
                                conv_parameters, &algorithm_config)) {
    std::vector<AlgorithmDesc> algorithms;
    CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
        conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
    ProfileResult best_result;
    ProfileResult best_result_no_scratch;
    for (auto profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times for better
      // accuracy.
      CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                              ctx);
      ProfileResult profile_result;
      bool cudnn_launch_status =
          stream
              ->ThenConvolveBackwardFilterWithAlgorithm(
                  input_desc, input_ptr, output_desc, out_backprop_ptr,
                  conv_desc, filter_desc, &filter_backprop_ptr,
                  &scratch_allocator, AlgorithmConfig(profile_algorithm),
                  &profile_result)
              .ok();
      if (cudnn_launch_status) {
        if (profile_result.is_valid()) {
          if (profile_result.elapsed_time_in_ms() <
              best_result.elapsed_time_in_ms()) {
            best_result = profile_result;
          }
          if (scratch_allocator.TotalByteSize() == 0 &&
              profile_result.elapsed_time_in_ms() <
                  best_result_no_scratch.elapsed_time_in_ms()) {
            best_result_no_scratch = profile_result;
          }
        }
      }
    }
    OP_REQUIRES(ctx,
                best_result.is_valid() || best_result_no_scratch.is_valid(),
                errors::NotFound("No algorithm worked!"));
    if (best_result.is_valid()) {
      algorithm_config.set_algorithm(best_result.algorithm());
    }
    if (best_result_no_scratch.is_valid()) {
      algorithm_config.set_algorithm_no_scratch(
          best_result_no_scratch.algorithm());
    }
    AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters,
                                                 algorithm_config);
  }
  CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
                                          ctx);
  bool cudnn_launch_status =
      stream
          ->ThenConvolveBackwardFilterWithAlgorithm(
              input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
              filter_desc, &filter_backprop_ptr, &scratch_allocator,
              algorithm_config, nullptr)
          .ok();

  if (!cudnn_launch_status) {
    ctx->SetStatus(errors::Internal(
        "cuDNN Backward Filter function launch failure : input shape(",
        input.shape().DebugString(), ") filter shape(",
        filter_shape.DebugString(), ")"));
    return;
  }

  auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
  functor::ReverseTransformFilter<GPUDevice, T, 4>()(
      ctx->eigen_device<GPUDevice>(),
      toConstTensor(pre_transformed_filter_backprop).template tensor<T, 4>(),
      filter_backprop->tensor<T, 4>());
}

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                              \
  template <>                                                            \
  void ShuffleAndReverse<GPUDevice, T, 4, int>::operator()(              \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
      const Eigen::DSizes<int, 4>& order,                                \
      const Eigen::array<bool, 4>& reverse_dims,                         \
      typename TTypes<T, 4, int>::Tensor output);                        \
  extern template struct ShuffleAndReverse<GPUDevice, T, 4, int>;        \
  template <>                                                            \
  void InflatePadAndShuffle<GPUDevice, T, 4, int>::operator()(           \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor input, \
      const Eigen::DSizes<int, 4>& strides,                              \
      const Eigen::array<Eigen::IndexPair<int>, 4>& pad_dims,            \
      const Eigen::DSizes<int, 4>& order,                                \
      typename TTypes<T, 4, int>::Tensor output);                        \
  extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>;     \
  template <>                                                            \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(                \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
      typename TTypes<T, 4, int>::Tensor out);                           \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;          \
  template <>                                                            \
  void TransformDepth<GPUDevice, T, int>::operator()(                    \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
      const Eigen::DSizes<int, 4>& shuffle,                              \
      typename TTypes<T, 4, int>::Tensor out);                           \
  extern template struct TransformDepth<GPUDevice, T, int>;              \
  template <>                                                            \
  void PadInput<GPUDevice, T, int, 4>::operator()(                       \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,    \
      const std::array<int, 2>& padding_left,                            \
      const std::array<int, 2>& padding_right,                           \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format); \
  extern template struct PadInput<GPUDevice, T, int, 4>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
#undef DECLARE_GPU_SPEC
}  // namespace functor

REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .HostMemory("filter_sizes"),
                        Conv2DSlowBackpropFilterOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .HostMemory("filter_sizes"),
                        Conv2DSlowBackpropFilterOp<GPUDevice, Eigen::half>);
#endif  // GOOGLE_CUDA

}  // namespace tensorflow