/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file contains a set of different implementations of the two-dimensional
// convolution operation. The standard TensorFlow Conv2D kernel uses EigenTensor
// to implement the computation, but this module has a variety of different ways
// of producing the same result. These methods are designed to be easier to
// understand and to connect to other libraries, so that we can take advantage
// of platforms that have specialized GEMM implementations, for example.
//
// The basic interface is a Conv functor object that's templated by the types
// of the data it will be operating on, and is passed in the arguments needed to
// calculate the convolution. The simplest implementation of this functor is
// ReferenceConvFunctor, which is a readable but slow reference version.
//
// A faster version uses the approach of packing image patches into a matrix
// before calling a matrix multiply, the Im2ColConvFunctor. In turn, this can
// use a variety of different methods to calculate the matrix multiplication,
// or GEMM. The simplest but slowest is the ReferenceGemmFunctor, but the
// FastGemmFunctor will use whatever optimized libraries are available. By
// default it uses Eigen, but on Apple platforms it will take advantage of the
// system's Accelerate BLAS library to get better performance than the standard
// TensorFlow convolution kernel.
//
// The version actually used is defined at the bottom of this file using the
// REGISTER_KERNEL_BUILDER() macro. To try out different implementations (for
// example, to switch to a reference one for easier debugging) you can swap out
// the default functors in that call.
//
// The registration itself is guarded with the USE_GEMM_FOR_CONV macro. The iOS
// makefile build defines this, but if you want to enable this implementation
// and disable the standard EigenTensor one in other build setups, you'll need
// to define it there too.
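//
// As a rough sketch of that functor interface (the buffer names and sizes here
// are made up purely for illustration, not taken from any real call site), a
// direct invocation looks like:
//
//   ReferenceConvFunctor<float, float, float> conv;
//   conv(context, input_data, /*input_batches=*/1, /*input_height=*/8,
//        /*input_width=*/8, /*input_depth=*/3, filter_data,
//        /*filter_height=*/3, /*filter_width=*/3, /*filter_count=*/16,
//        /*stride_rows=*/1, /*stride_cols=*/1, SAME, output_data,
//        /*output_height=*/8, /*output_width=*/8);
//
// Im2ColConvFunctor takes exactly the same arguments, so the two can be
// swapped without touching the calling code.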

#define EIGEN_USE_THREADS

#include <string.h>
#include <map>
#include <vector>
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/gemm_functors.h"
#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

namespace {
// This function implements the convolution operation in as simple a form as
// possible. It won't give great performance, but it is very useful for
// stepping through and instrumenting for debugging, creating minimal benchmarks
// to prototype with, and sharing with teams that want to run this outside of
// our environment.
// With that in mind, I've avoided using anything except pretty standard C++
// types. This is especially noticeable in the data access through raw array
// indexing. It's deliberate in this case though, since it makes the underlying
// memory order very explicit, which is important for both inspecting memory
// contents during debugging and for specifying what we expect to others.
// The memory layout of the data is, from biggest stride to smallest:
// input_data = [input_batches, input_height, input_width, input_depth]
// filter_data = [filter_height, filter_width, input_depth, filter_count]
// output_data = [input_batches, output_height, output_width, filter_count]
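// As a concrete example of that layout (the sizes are chosen only for
// illustration): for an input of shape [2, 4, 4, 3], the element at
// (batch=1, y=2, x=3, channel=0) lives at flat index
//   ((1 * 4 + 2) * 4 + 3) * 3 + 0 = 81,
// which is the same index arithmetic the loops below spell out term by term.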
template <class T1, class T2, class T3>
class ReferenceConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, const T2* filter_data, int filter_height,
                  int filter_width, int filter_count, int stride_rows,
                  int stride_cols, Padding padding, T3* output_data,
                  int output_height, int output_width) {
    // The two different padding modes we support can be a bit confusing. SAME
    // means we're trying to produce an output image that's the same size as the
    // input. It's complicated by stride, which shrinks the output image by a
    // factor, but it also means we end up sampling from outside the borders of
    // the input. These out-of-bounds values are read as zeroes. VALID means only
    // produce output values where the filters can read all their values from
    // within the input image. It effectively removes the margins of the output
    // image compared to the one produced by SAME. Stride complicates this
    // definition though, because it can result in the right and bottom filter
    // patches sampling from outside the borders if it's greater than 1.
    // Most of the logic for sorting this all out is done before this function,
    // when we calculate the output size, but the positioning of the origin of
    // the filters is different between the two modes, since SAME positions the
    // first filter off the edge of the input.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width + 1) /
          2;
      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
                           input_height + 1) /
                          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride_rows + filter_height - input_height) /
          2;
    }
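    // As a worked example of these formulas (the sizes are illustrative only):
    // with input_width = 5, filter_width = 3, stride_cols = 1, SAME padding
    // gives output_width = 5 and
    //   filter_left_offset = ((5 - 1) * 1 + 3 - 5) / 2 = 1,
    // so the first patch starts one pixel off the left edge. VALID padding
    // gives output_width = 3 and
    //   filter_left_offset = ((3 - 1) * 1 + 3 - 5 + 1) / 2 = 0,
    // so every patch stays entirely inside the input.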

    // If we've got multiple images in our input, work through each of them.
    for (int batch = 0; batch < input_batches; ++batch) {
      // Walk through all the output image values, sliding the filter to
      // different positions in the input.
      for (int out_y = 0; out_y < output_height; ++out_y) {
        for (int out_x = 0; out_x < output_width; ++out_x) {
          // Each filter kernel produces one output channel.
          for (int out_channel = 0; out_channel < filter_count; ++out_channel) {
            // We're going to calculate a single output value, which means we
            // need to multiply a three dimensional kernel of weights against
            // the current location within the input image.
            /*
             *-------------------------------...
             |\ ^
             | \in_depth
             |  \ v
             |   *-------------------------------...
             |   |            ^
             |   |       in_y_origin
             |   |            v   \
             |   |<in_x_origin>*---*^
             |   |            \|   |filter_height
             .   |             *---*v
             .   |             <--->
             .         filter_width
             .
            */
            const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
            const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
            T3 total(0);
            for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
              for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
                for (int in_channel = 0; in_channel < input_depth;
                     ++in_channel) {
                  const int in_x = in_x_origin + filter_x;
                  const int in_y = in_y_origin + filter_y;
                  T1 input_value;
                  // If the location is outside the bounds of the input image,
                  // use zero as a default value.
                  if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                      (in_y < input_height)) {
                    input_value =
                        input_data[(batch * input_height * input_width *
                                    input_depth) +
                                   (in_y * input_width * input_depth) +
                                   (in_x * input_depth) + in_channel];
                  } else {
                    input_value = T1(0);
                  }
                  const T2 filter_value =
                      filter_data[(filter_y * filter_width * input_depth *
                                   filter_count) +
                                  (filter_x * input_depth * filter_count) +
                                  (in_channel * filter_count) + out_channel];
                  total += (input_value * filter_value);
                }
              }
            }
            output_data[(batch * output_height * output_width * filter_count) +
                        (out_y * output_width * filter_count) +
                        (out_x * filter_count) + out_channel] = total;
          }
        }
      }
    }
  }
};

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 16 megabytes as a reasonable limit for Android and
// other platforms using Eigen, and 1MB for Apple devices, from experimentation.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
const size_t kMaxChunkSize = (1 * 1024 * 1024);
#else
const size_t kMaxChunkSize = (16 * 1024 * 1024);
#endif
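
// As an illustration of the chunking arithmetic (the filter size here is just
// an example): with float data and a 3x3 filter over a depth-256 input, each
// im2col patch needs 3 * 3 * 256 * 4 = 9216 bytes, so the 16MB buffer holds
// 16777216 / 9216 = 1820 patches per chunk, and the 1MB Apple buffer holds 113.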

// Implements convolution as a two stage process, first packing the patches of
// the input image into columns (im2col) and then running GEMM to produce the
// final result.
template <class T1, class T2, class T3, class TGemmFunctor>
class Im2ColConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, const T2* filter_data, int filter_height,
                  int filter_width, int filter_count, int stride_rows,
                  int stride_cols, Padding padding, T3* output_data,
                  int output_height, int output_width) {
    if ((input_batches <= 0) || (input_width <= 0) || (input_height <= 0) ||
        (input_depth <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad input dimensions: "
                   << input_batches << ", " << input_height << ", "
                   << input_width << ", " << input_depth;
      return;
    }
    if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
                   << filter_width << ", " << filter_height << ", "
                   << filter_count;
      return;
    }
    if ((output_width <= 0) || (output_height <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad output width or height: "
                   << output_width << ", " << output_height;
      return;
    }

    // We can just use a GEMM if the im2col is the identity operator, e.g., if
    // the kernel is 1x1 or the input data and filter have the same
    // height/width.
    if (filter_height == 1 && filter_width == 1 && stride_rows == 1 &&
        stride_cols == 1) {
      // The kernel is 1x1.
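      // Every output pixel is then just a dot product of its input pixel's
      // depth values with one filter column, so (as a sanity check on the
      // dimensions below) the whole op is a single
      //   [batch * height * width, depth] x [depth, filter_count]
      // matrix multiply.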
      const int m = input_batches * input_height * input_width;
      const int n = filter_count;
      const int k = input_depth;
      const int lda = k;
      const int ldb = filter_count;
      const int ldc = filter_count;
      TGemmFunctor gemm_functor;
      gemm_functor(context, m, n, k, input_data, lda, filter_data, ldb,
                   output_data, ldc);
      return;
    } else if (filter_height == input_height && filter_width == input_width &&
               padding == VALID) {
      // The input data and filter have the same height/width.
      const int m = input_batches;
      const int n = filter_count;
      const int k = input_height * input_width * input_depth;
      const int lda = k;
      const int ldb = filter_count;
      const int ldc = filter_count;
      TGemmFunctor gemm_functor;
      gemm_functor(context, m, n, k, input_data, lda, filter_data, ldb,
                   output_data, ldc);
      return;
    }

    // These calculations define how the patches will be positioned within the
    // input image. The actual definitions are quite complex, and rely on the
    // previously-calculated output size.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width + 1) /
          2;
      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
                           input_height + 1) /
                          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride_rows + filter_height - input_height) /
          2;
    }

    // The im2col buffer has # of patches rows, and # of filter values cols.
    // It's laid out like this, in row major order in memory:
    //        < filter value count >
    //   ^   +---------------------+
    // patch |                     |
    // count |                     |
    //   v   +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory, followed
    // by the width, then the height. This is the standard memory order in the
    // image world if it helps to visualize it.
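    // For instance (the sizes are picked purely for illustration), a 3x3
    // filter over a depth-8 input gives filter_value_count = 3 * 3 * 8 = 72,
    // so each row of the buffer holds 72 values and the GEMM at the end of
    // this function multiplies a [patches x 72] matrix by a
    // [72 x filter_count] one.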
    const int filter_value_count = filter_width * filter_height * input_depth;
    OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= kMaxChunkSize,
                errors::InvalidArgument("Im2Col patch too large for buffer"));
    const int64 patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
    const int64 chunk_value_count =
        (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, chunk_value_count>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, chunk_value_count>**)>
        creator = [](Im2ColBufferResource<T1, chunk_value_count>** resource) {
          *resource = new Im2ColBufferResource<T1, chunk_value_count>();
          return Status::OK();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));
    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this is
    // aimed at have intra-op parallelism as their focus though, so it shouldn't
    // be an issue.
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    const int64 patch_count = (input_batches * output_height * output_width);
    const int64 chunk_count =
        (patch_count + (patches_per_chunk - 1)) / patches_per_chunk;
    for (int64 chunk_index = 0; chunk_index < chunk_count; ++chunk_index) {
      const int64 patch_index_start = chunk_index * patches_per_chunk;
      const int64 patch_index_end =
          std::min(patch_index_start + patches_per_chunk, patch_count);
      for (int64 patch_index = patch_index_start; patch_index < patch_index_end;
           ++patch_index) {
        const int64 batch = patch_index / (output_height * output_width);
        const int64 out_y = (patch_index / output_width) % output_height;
        const int64 out_x = patch_index % output_width;
        const T1* input_batch_start =
            input_data + (batch * input_height * input_width * input_depth);
        const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
        const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
        const int patch_index_within_chunk = patch_index % patches_per_chunk;
        T1* im2col_patch_start =
            im2col_buffer + (patch_index_within_chunk * filter_value_count);
        for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
          const int in_y = in_y_origin + filter_y;
          T1* im2col_row_start =
              im2col_patch_start + (filter_y * filter_width * input_depth);
          // If we're off the top or the bottom of the input, fill the
          // whole row with zeroes.
          if ((in_y < 0) || (in_y >= input_height)) {
            T1* im2col_row_end =
                im2col_row_start + (filter_width * input_depth);
            std::fill(im2col_row_start, im2col_row_end, T1(0));
          } else {
            // What we're doing here is trying to copy and fill the im2col
            // buffer as efficiently as possible, using functions to set or
            // duplicate values en masse. We know we don't have to worry about
            // vertical edges because we dealt with that case above, so we
            // just need to handle filters that overlap the left or right
            // edges. Here's what that looks like:
            //
            // < left_zero_count > < center_copy_count > < right_zero_count >
            // +------------------+---------------------+--------------------+
            // |     (filter)     |       (image)       |      (filter)      |
            // +------------------+---------------------+--------------------+
            // in_x_origin        0                 input_width       in_x_end
            //
            // In reality it's unlikely that a filter patch will be wider
            // than an input, but this shows all the edge cases.
            // We use std::fill() to set the left and right sections to zeroes
            // and std::copy() to copy over the input data for the center.
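            // A quick worked case (the numbers are illustrative): with
            // input_width = 5, filter_width = 3 and in_x_origin = -1, we get
            // in_x_end = 2, left_zero_count = 1, right_zero_count = 0 and
            // center_copy_count = 2, i.e. one column of zero padding followed
            // by two columns copied from the image.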
            const int in_x_end = in_x_origin + filter_width;
            const int left_zero_count = std::max(0, 0 - in_x_origin);
            const int right_zero_count = std::max(0, in_x_end - input_width);
            const int center_copy_count =
                filter_width - (left_zero_count + right_zero_count);
            if (left_zero_count > 0) {
              T1* im2col_left_start = im2col_row_start;
              T1* im2col_left_end =
                  im2col_left_start + (left_zero_count * input_depth);
              std::fill(im2col_left_start, im2col_left_end, T1(0));
            }
            if (center_copy_count > 0) {
              const T1* input_row_start =
                  input_batch_start + (in_y * input_width * input_depth) +
                  (std::max(0, in_x_origin) * input_depth);
              const T1* input_row_end =
                  input_row_start + (center_copy_count * input_depth);
              T1* im2col_center_start =
                  im2col_row_start + (left_zero_count * input_depth);
              std::copy(input_row_start, input_row_end, im2col_center_start);
            }
            if (right_zero_count > 0) {
              T1* im2col_right_start =
                  im2col_row_start +
                  ((left_zero_count + center_copy_count) * input_depth);
              T1* im2col_right_end =
                  im2col_right_start + (right_zero_count * input_depth);
              std::fill(im2col_right_start, im2col_right_end, T1(0));
            }
          }
        }
      }
      // Now that we've assembled a set of image patches into a matrix, apply
      // a GEMM matrix multiply of the patches as rows, times the filter
      // weights in columns, to get the results for this chunk of the output.
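      // In matrix terms this chunk computes
      //   [how_many_patches x filter_value_count] *
      //   [filter_value_count x filter_count] =
      //   [how_many_patches x filter_count],
      // writing its rows directly into the corresponding slice of the output.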
      const int how_many_patches = patch_index_end - patch_index_start;
      const int m = how_many_patches;
      const int n = filter_count;
      const int k = filter_value_count;
      const int lda = filter_value_count;
      const int ldb = filter_count;
      const int ldc = filter_count;
      T3* chunk_output_data = output_data + (patch_index_start * filter_count);
      TGemmFunctor gemm_functor;
      gemm_functor(context, m, n, k, im2col_buffer, lda, filter_data, ldb,
                   chunk_output_data, ldc);
    }
  }
};

}  // namespace

// This TensorFlow kernel class handles all of the IO and housekeeping for the
// functors that actually implement the underlying algorithm. To swap in
// different implementations of the main calculations, use a different
// TConvFunctor parameter when instantiating the template.
template <class T, class TConvFunctor>
class Conv2DUsingGemmOp : public BinaryOp<T> {
 public:
  explicit Conv2DUsingGemmOp(OpKernelConstruction* context)
      : BinaryOp<T>(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Data format not supported by this kernel", data_format));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    for (int i = 0; i < 3; i++) {
      OP_REQUIRES(
          context,
          FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
          errors::InvalidArgument("filter too large"));
    }

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int out_depth = static_cast<int>(filter.dim_size(3));

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int input_rows = static_cast<int>(input_rows_raw);
    const int filter_rows = static_cast<int>(filter.dim_size(0));

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int input_cols = static_cast<int>(input_cols_raw);
    const int filter_cols = static_cast<int>(filter.dim_size(1));

    // The first dimension for input is batch.
    const int64 batch_raw = GetTensorDim(input, data_format_, 'N');
    OP_REQUIRES(context,
                FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
                errors::InvalidArgument("batch is too large"));
    const int batch = static_cast<int>(batch_raw);

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');

    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_rows, filter_rows, stride_rows,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_cols, filter_cols, stride_cols,
                                         padding_, &out_cols, &pad_cols));
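    // To make the output-size rule concrete (the sizes here are only an
    // example): for input_rows = 5, filter_rows = 3, stride_rows = 2, SAME
    // padding yields out_rows = ceil(5 / 2) = 3 with pad_rows = 1, while
    // VALID yields out_rows = (5 - 3) / 2 + 1 = 2 with pad_rows = 0.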
    TensorShape out_shape =
        ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "Conv2D: in_depth = " << in_depth
            << ", input_cols = " << input_cols
            << ", filter_cols = " << filter_cols
            << ", input_rows = " << input_rows
            << ", filter_rows = " << filter_rows
            << ", stride_rows = " << stride_rows
            << ", stride_cols = " << stride_cols
            << ", out_depth = " << out_depth;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }
    TConvFunctor conv_functor;
    conv_functor(context, input.flat<T>().data(), batch, input_rows, input_cols,
                 in_depth, filter.flat<T>().data(), filter_rows, filter_cols,
                 out_depth, stride_rows, stride_cols, padding_,
                 output->flat<T>().data(), out_rows, out_cols);
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DUsingGemmOp);
};

#define REGISTER_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("Conv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      Conv2DUsingGemmOp<                                        \
          T, Im2ColConvFunctor<T, T, T, FastGemmFunctor<T, T, T>>>);
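
// To experiment with the slower but simpler implementations mentioned at the
// top of this file, the functor arguments above can be swapped out. A sketch
// (not compiled in by default) of the two obvious variations:
//
//   // im2col plus the reference GEMM:
//   Conv2DUsingGemmOp<
//       T, Im2ColConvFunctor<T, T, T, ReferenceGemmFunctor<T, T, T>>>
//   // direct reference convolution, with no im2col step at all:
//   Conv2DUsingGemmOp<T, ReferenceConvFunctor<T, T, T>>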

// Only register this GEMM-based implementation of Conv2D if the compiler flags
// request the implementation explicitly, since otherwise it will clash with the
// default EigenTensor-based kernel.
#if defined(USE_GEMM_FOR_CONV)
TF_CALL_half(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
#endif  // USE_GEMM_FOR_CONV

}  // namespace tensorflow