/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains a set of different implementations of the two-dimensional
// convolution operation. The standard TensorFlow Conv2d kernel uses EigenTensor
// to implement the computation, but this module has a variety of different ways
// of producing the same result. These methods are designed to be easier to
// understand and connect to other libraries, so that we can take advantage of
// platforms that have specialized implementations of GEMM, for example.
//
// The basic interface is a Conv functor object that's templated by the types
// of the data it will be operating on, and is passed the arguments needed to
// calculate the convolution. The simplest implementation of this functor is
// ReferenceConvFunctor, which is a readable but slow reference version.
//
// A faster version, Im2ColConvFunctor, packs image patches into a matrix
// before calling a matrix multiply. In turn, it can use a variety of different
// methods to calculate that matrix multiplication, or GEMM. The simplest but
// slowest is ReferenceGemmFunctor, while FastGemmFunctor uses whatever
// optimized libraries are available. By default it uses Eigen, but on Apple
// platforms it takes advantage of the system's Accelerate BLAS library to get
// better performance than the standard TensorFlow convolution kernel.
//
// The version actually used is defined at the bottom of this file using the
// REGISTER_KERNEL_BUILDER() macro. To try out different implementations (for
// example, to switch to a reference one for easier debugging) you can swap out
// the default functors in that call.
//
// The registration itself is guarded with the USE_GEMM_FOR_CONV macro. The iOS
// makefile build defines this, but if you want to enable this implementation
// and disable the standard EigenTensor one in other build setups, you'll need
// to define it there too.
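//
// As an illustration of that swap (a sketch only, not an alternative that is
// compiled into the build), the registration at the bottom could be changed
// from the default
//
//   Conv2DUsingGemmOp<T, Im2ColConvFunctor<T, T, T, FastGemmFunctor<T, T, T>>>
//
// to something like
//
//   Conv2DUsingGemmOp<T, ReferenceConvFunctor<T, T, T>>
//
// to step through the simple reference implementation instead.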

#define EIGEN_USE_THREADS

#include <string.h>
#include <map>
#include <vector>
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/gemm_functors.h"
#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

namespace {
// This functor implements the convolution operation in as simple a form as
// possible. It won't give great performance, but it is very useful for
// stepping through and instrumenting for debugging, creating minimal
// benchmarks to prototype with, and sharing with teams that want to run this
// outside of our environment.
// With that in mind, I've avoided using anything except pretty standard C++
// types. This is especially noticeable in the data access through raw array
// indexing. It's deliberate in this case though, since it makes the underlying
// memory order very explicit, which is important for both inspecting memory
// contents during debugging and for specifying what we expect to others.
// The memory layout of the data is, from biggest stride to smallest:
// input_data  = [input_batches, input_height, input_width, input_depth]
// filter_data = [filter_height, filter_width, input_depth, filter_count]
// output_data = [input_batches, output_height, output_width, filter_count]
template <class T1, class T2, class T3>
class ReferenceConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, const T2* filter_data, int filter_height,
                  int filter_width, int filter_count, int stride_rows,
                  int stride_cols, Padding padding, T3* output_data,
                  int output_height, int output_width) {
    // The two different padding modes we support can be a bit confusing. SAME
    // means we're trying to produce an output image that's the same size as
    // the input. It's complicated by stride, which shrinks the output image
    // by a factor, but it means we end up sampling from outside the borders
    // of the input. These out-of-bounds values are read as zeroes. VALID
    // means only produce output values where the filters can read all their
    // values from within the input image. It effectively removes the margins
    // of the output image compared to the one produced by SAME. Stride
    // complicates this definition though, because it can result in the right
    // and bottom filter patches sampling from outside the borders if it's
    // greater than 1.
    // Most of the logic for sorting this all out is done before this function,
    // when we calculate the output size, but the positioning of the origin of
    // the filters is different between the two modes, since SAME positions the
    // first filter off the edge of the input.
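    //
    // As an illustrative example (hypothetical numbers, not taken from any
    // particular model): with input_width = 5, filter_width = 3 and
    // stride_cols = 2, SAME padding gives output_width = 3 and
    //   filter_left_offset = ((3 - 1) * 2 + 3 - 5) / 2 = 1,
    // so the first patch starts one column off the left edge of the input,
    // while VALID padding gives output_width = 2 and
    //   filter_left_offset = ((2 - 1) * 2 + 3 - 5 + 1) / 2 = 0.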
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width + 1) /
          2;
      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
                           input_height + 1) /
                          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride_rows + filter_height - input_height) /
          2;
    }

    // If we've got multiple images in our input, work through each of them.
    for (int batch = 0; batch < input_batches; ++batch) {
      // Walk through all the output image values, sliding the filter to
      // different positions in the input.
      for (int out_y = 0; out_y < output_height; ++out_y) {
        for (int out_x = 0; out_x < output_width; ++out_x) {
          // Each filter kernel produces one output channel.
          for (int out_channel = 0; out_channel < filter_count;
               ++out_channel) {
            // We're going to calculate a single output value, which means we
            // need to multiply a three dimensional kernel of weights against
            // the current location within the input image.
            /*
              *-------------------------------...
              |\ ^
              | \in_depth
              |  \ v
              |   *-------------------------------...
              |   |            ^
              |   |       in_y_origin
              |   |            v   \
              |   |<in_x_origin>*---*^
              |   |            \|   |filter_height
              .   |             *---*v
              .   |             <--->
                  .         filter_width
                  .
            */
            const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
            const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
            T3 total(0);
            for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
              for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
                for (int in_channel = 0; in_channel < input_depth;
                     ++in_channel) {
                  const int in_x = in_x_origin + filter_x;
                  const int in_y = in_y_origin + filter_y;
                  T1 input_value;
                  // If the location is outside the bounds of the input image,
                  // use zero as a default value.
                  if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                      (in_y < input_height)) {
                    input_value =
                        input_data[(batch * input_height * input_width *
                                    input_depth) +
                                   (in_y * input_width * input_depth) +
                                   (in_x * input_depth) + in_channel];
                  } else {
                    input_value = T1(0);
                  }
                  const T2 filter_value =
                      filter_data[(filter_y * filter_width * input_depth *
                                   filter_count) +
                                  (filter_x * input_depth * filter_count) +
                                  (in_channel * filter_count) + out_channel];
                  total += (input_value * filter_value);
                }
              }
            }
            output_data[(batch * output_height * output_width * filter_count) +
                        (out_y * output_width * filter_count) +
                        (out_x * filter_count) + out_channel] = total;
          }
        }
      }
    }
  }
};

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 16 megabytes as a reasonable limit for Android
// and other platforms using Eigen, and 1MB for Apple devices, from
// experimentation.
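// As an illustrative example (hypothetical sizes): a 3x3 float filter over
// 256 input channels needs 3 * 3 * 256 = 2304 values, or 9216 bytes, per
// patch, so the 16MB buffer holds roughly 1800 patches per chunk and the 1MB
// buffer roughly 110.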
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
const size_t kMaxChunkSize = (1 * 1024 * 1024);
#else
const size_t kMaxChunkSize = (16 * 1024 * 1024);
#endif

// Implements convolution as a two stage process, first packing the patches of
// the input image into columns (im2col) and then running GEMM to produce the
// final result.
template <class T1, class T2, class T3, class TGemmFunctor>
class Im2ColConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, const T2* filter_data, int filter_height,
                  int filter_width, int filter_count, int stride_rows,
                  int stride_cols, Padding padding, T3* output_data,
                  int output_height, int output_width) {
    if ((input_batches <= 0) || (input_width <= 0) || (input_height <= 0) ||
        (input_depth <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad input dimensions: "
                   << input_batches << ", " << input_height << ", "
                   << input_width << ", " << input_depth;
      return;
    }
    if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
                   << filter_width << ", " << filter_height << ", "
                   << filter_count;
      return;
    }
    if ((output_width <= 0) || (output_height <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad output width or height: "
                   << output_width << ", " << output_height;
      return;
    }

    // We can just use a GEMM if the im2col is the identity operator, e.g., if
    // the kernel is 1x1 or the input data and filter have the same
    // height/width.
    if (filter_height == 1 && filter_width == 1 && stride_rows == 1 &&
        stride_cols == 1) {
      // The kernel is 1x1.
      const int m = input_batches * input_height * input_width;
      const int n = filter_count;
      const int k = input_depth;
      const int lda = k;
      const int ldb = filter_count;
      const int ldc = filter_count;
      TGemmFunctor gemm_functor;
      gemm_functor(context, m, n, k, input_data, lda, filter_data, ldb,
                   output_data, ldc);
      return;
    } else if (filter_height == input_height && filter_width == input_width &&
               padding == VALID) {
      // The input data and filter have the same height/width.
      const int m = input_batches;
      const int n = filter_count;
      const int k = input_height * input_width * input_depth;
      const int lda = k;
      const int ldb = filter_count;
      const int ldc = filter_count;
      TGemmFunctor gemm_functor;
      gemm_functor(context, m, n, k, input_data, lda, filter_data, ldb,
                   output_data, ldc);
      return;
    }

    // These calculations define how the patches will be positioned within the
    // input image. The actual definitions are quite complex, and rely on the
    // previously-calculated output size.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width + 1) /
          2;
      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
                           input_height + 1) /
                          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride_rows + filter_height - input_height) /
          2;
    }

    // The im2col buffer has # of patches rows, and # of filter values cols.
    // It's laid out like this, in row major order in memory:
    //        < filter value count >
    //   ^   +---------------------+
    // patch |                     |
    // count |                     |
    //   v   +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory,
    // followed by the width, then the height. This is the standard memory
    // order in the image world if it helps to visualize it.
    const int filter_value_count = filter_width * filter_height * input_depth;
    OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= kMaxChunkSize,
                errors::InvalidArgument("Im2Col patch too large for buffer"));
    const int64 patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
    const int64 chunk_value_count =
        (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, chunk_value_count>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, chunk_value_count>**)>
        creator = [](Im2ColBufferResource<T1, chunk_value_count>** resource) {
          *resource = new Im2ColBufferResource<T1, chunk_value_count>();
          return Status::OK();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));
    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this
    // is aimed at have intra-op parallelism as their focus though, so it
    // shouldn't be an issue.
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    const int64 patch_count = (input_batches * output_height * output_width);
    const int64 chunk_count =
        (patch_count + (patches_per_chunk - 1)) / patches_per_chunk;
    for (int64 chunk_index = 0; chunk_index < chunk_count; ++chunk_index) {
      const int64 patch_index_start = chunk_index * patches_per_chunk;
      const int64 patch_index_end =
          std::min(patch_index_start + patches_per_chunk, patch_count);
      for (int64 patch_index = patch_index_start;
           patch_index < patch_index_end; ++patch_index) {
        const int64 batch = patch_index / (output_height * output_width);
        const int64 out_y = (patch_index / output_width) % output_height;
        const int64 out_x = patch_index % output_width;
        const T1* input_batch_start =
            input_data + (batch * input_height * input_width * input_depth);
        const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
        const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
        const int patch_index_within_chunk = patch_index % patches_per_chunk;
        T1* im2col_patch_start =
            im2col_buffer + (patch_index_within_chunk * filter_value_count);
        for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
          const int in_y = in_y_origin + filter_y;
          T1* im2col_row_start =
              im2col_patch_start + (filter_y * filter_width * input_depth);
          // If we're off the top or the bottom of the input, fill the
          // whole row with zeroes.
          if ((in_y < 0) || (in_y >= input_height)) {
            T1* im2col_row_end =
                im2col_row_start + (filter_width * input_depth);
            std::fill(im2col_row_start, im2col_row_end, T1(0));
          } else {
            // What we're doing here is trying to copy and fill the im2col
            // buffer as efficiently as possible, using functions to set or
            // duplicate values en masse. We know we don't have to worry about
            // vertical edges because we dealt with that case above, so we
            // just need to handle filters that overlap the left or right
            // edges. Here's what that looks like:
            //
            // < left_zero_count > < center_copy_count > < right_zero_count >
            // +------------------+---------------------+--------------------+
            // |     (filter)     |       (image)       |      (filter)      |
            // +------------------+---------------------+--------------------+
            // in_x_origin        0                input_width        in_x_end
            //
            // In reality it's unlikely that a filter patch will be wider
            // than an input, but this shows all the edge cases.
            // We use std::fill() to set the left and right sections to zeroes
            // and std::copy() to copy over the input data for the center.
            const int in_x_end = in_x_origin + filter_width;
            const int left_zero_count = std::max(0, 0 - in_x_origin);
            const int right_zero_count = std::max(0, in_x_end - input_width);
            const int center_copy_count =
                filter_width - (left_zero_count + right_zero_count);
            if (left_zero_count > 0) {
              T1* im2col_left_start = im2col_row_start;
              T1* im2col_left_end =
                  im2col_left_start + (left_zero_count * input_depth);
              std::fill(im2col_left_start, im2col_left_end, T1(0));
            }
            if (center_copy_count > 0) {
              const T1* input_row_start =
                  input_batch_start + (in_y * input_width * input_depth) +
                  (std::max(0, in_x_origin) * input_depth);
              const T1* input_row_end =
                  input_row_start + (center_copy_count * input_depth);
              T1* im2col_center_start =
                  im2col_row_start + (left_zero_count * input_depth);
              std::copy(input_row_start, input_row_end, im2col_center_start);
            }
            if (right_zero_count > 0) {
              T1* im2col_right_start =
                  im2col_row_start +
                  ((left_zero_count + center_copy_count) * input_depth);
              T1* im2col_right_end =
                  im2col_right_start + (right_zero_count * input_depth);
              std::fill(im2col_right_start, im2col_right_end, T1(0));
            }
          }
        }
      }
      // Now that we've assembled a set of image patches into a matrix, apply
      // a GEMM matrix multiply of the patches as rows, times the filter
      // weights in columns, to get partial results in the output matrix.
      const int how_many_patches = patch_index_end - patch_index_start;
      const int m = how_many_patches;
      const int n = filter_count;
      const int k = filter_value_count;
      const int lda = filter_value_count;
      const int ldb = filter_count;
      const int ldc = filter_count;
      T3* chunk_output_data = output_data + (patch_index_start * filter_count);
      TGemmFunctor gemm_functor;
      gemm_functor(context, m, n, k, im2col_buffer, lda, filter_data, ldb,
                   chunk_output_data, ldc);
    }
  }
};

}  // namespace

// This TensorFlow kernel class handles all of the IO and housekeeping for the
// functors that actually implement the underlying algorithm. To swap in
// different implementations of the main calculations, use a different
// TConvFunctor parameter when instantiating the template.
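// Any functor supplied as TConvFunctor is expected to be callable with the
// same argument list that ReferenceConvFunctor and Im2ColConvFunctor above
// accept: (context, input_data, input_batches, input_height, input_width,
// input_depth, filter_data, filter_height, filter_width, filter_count,
// stride_rows, stride_cols, padding, output_data, output_height,
// output_width), since that is how Compute() below invokes it.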
template <class T, class TConvFunctor>
class Conv2DUsingGemmOp : public BinaryOp<T> {
 public:
  explicit Conv2DUsingGemmOp(OpKernelConstruction* context)
      : BinaryOp<T>(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Data format not supported by this kernel", data_format));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth ]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    for (int i = 0; i < 3; i++) {
      OP_REQUIRES(
          context,
          FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
          errors::InvalidArgument("filter too large"));
    }

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int out_depth = static_cast<int>(filter.dim_size(3));

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int input_rows = static_cast<int>(input_rows_raw);
    const int filter_rows = static_cast<int>(filter.dim_size(0));

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int input_cols = static_cast<int>(input_cols_raw);
    const int filter_cols = static_cast<int>(filter.dim_size(1));

    // The first dimension for input is batch.
    const int64 batch_raw = GetTensorDim(input, data_format_, 'N');
    OP_REQUIRES(context,
                FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
                errors::InvalidArgument("batch is too large"));
    const int batch = static_cast<int>(batch_raw);

    // For now we take the stride from the second and third dimensions only
    // (we do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');

    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_rows, filter_rows, stride_rows,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_cols, filter_cols, stride_cols,
                                         padding_, &out_cols, &pad_cols));
    TensorShape out_shape =
        ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "Conv2D: in_depth = " << in_depth
            << ", input_cols = " << input_cols
            << ", filter_cols = " << filter_cols
            << ", input_rows = " << input_rows
            << ", filter_rows = " << filter_rows
            << ", stride_rows = " << stride_rows
            << ", stride_cols = " << stride_cols
            << ", out_depth = " << out_depth;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }
    TConvFunctor conv_functor;
    conv_functor(context, input.flat<T>().data(), batch, input_rows,
                 input_cols, in_depth, filter.flat<T>().data(), filter_rows,
                 filter_cols, out_depth, stride_rows, stride_cols, padding_,
                 output->flat<T>().data(), out_rows, out_cols);
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DUsingGemmOp);
};

#define REGISTER_CPU(T)                                              \
  REGISTER_KERNEL_BUILDER(                                           \
      Name("Conv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"),      \
      Conv2DUsingGemmOp<                                             \
          T, Im2ColConvFunctor<T, T, T, FastGemmFunctor<T, T, T>>>);

// Only register this GEMM-based implementation of Conv2d if the compiler
// flags request the implementation explicitly, since otherwise it will clash
// with the default EigenTensor-based kernel.
#if defined(USE_GEMM_FOR_CONV)
TF_CALL_half(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
#endif  // USE_GEMM_FOR_CONV

}  // namespace tensorflow