/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements convolution operations with other kernels baked into the
// processing, to optimize latency and memory usage.

#define EIGEN_USE_THREADS

#include <string.h>
#include <map>
#include <vector>
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/gemm_functors.h"
#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

namespace {

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 16 megabytes as a reasonable limit for Android and
// other platforms using Eigen, and 1MB for iOS devices, from experimentation.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
const size_t kMaxChunkSize = (1 * 1024 * 1024);
#else
const size_t kMaxChunkSize = (16 * 1024 * 1024);
#endif
const size_t kResizeCacheSize = (8 * 1024 * 1024);

// Lookup method used when resizing.
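// BILINEAR blends the four nearest input pixels and is used by the fused
// resize-and-pad kernel registered at the bottom of this file. NEAREST copies
// the closest pixel directly and is used by the pad-only kernel, where the
// scale factors are fixed at 1.0 so the nearest pixel is exact.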
enum SamplingMode {
  BILINEAR = 0,
  NEAREST = 1,
};

// Simple utility function used by FusedConv to multithread basic workloads. To
// use it, pass begin and end values for the full workload and a std::function
// that receives a subset of that through the begin and end values for each
// worker's task. The division of the full workload into worker tasks is handled
// by the multithreading logic. Here's an example of how to use it:
// std::vector<float> my_vector(100);
// ...
// FusedConvParallelFor(context, 0, 100,
//   [&my_vector](int64 task_begin, int64 task_end) {
//     for (int64 current = task_begin; current != task_end; ++current) {
//       my_vector[current] *= 10.0f;
//     }
// });
void FusedConvParallelFor(
    OpKernelContext* context, int64 begin, int64 end,
    const std::function<void(int64, int64)>& task_function) {
// On iOS, the thread management imposes a very big performance penalty, so
// just call the function directly with no multithreading.
#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
  task_function(begin, end);
#else
  auto& worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
  thread::ThreadPool* thread_pool = worker_threads.workers;
  const int64 total_elements = end - begin;
  // This is a bit of an arbitrary number, but was found to work well for
  // typical models we've been profiling on various devices.
  const int64 element_cost = 10000000;
  thread_pool->ParallelFor(
      total_elements, element_cost,
      [begin, task_function](int64 begin_offset, int64 end_offset) {
        const int64 task_begin = begin + begin_offset;
        const int64 task_end = begin + end_offset;
        task_function(task_begin, task_end);
      });
#endif
}

// Holds the state needed for the resizing subtasks.
template <class T1>
struct ResizeTaskParameters {
  ResizeTaskParameters() : st(false) {}

  int cache_height;
  T1* resize_cache;
  int cache_line_width;
  int input_width;
  int input_depth;
  int top_padding;
  int pad_offset;
  int64 resized_height;
  ImageResizerState st;
  const T1* input_batch_start;
  int64 cache_start_x;
  int64 cache_end_x;
  int left_padding;
  int64 resized_width;
  int64 padded_width;
  int64 padded_height;
};

template <class T1>
struct PerCacheLineParameters {
  PerCacheLineParameters() {}
  PerCacheLineParameters(const PerCacheLineParameters<T1>& other)
      : cache_line_start(other.cache_line_start),
        input_top_row_start(other.input_top_row_start),
        input_bottom_row_start(other.input_bottom_row_start),
        y_lerp(other.y_lerp) {}

  T1* cache_line_start;
  const T1* input_top_row_start;
  const T1* input_bottom_row_start;
  T1 y_lerp;
};

// Helper class to simplify bilinear filtering
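// For a sample point lying between four neighboring pixels, the value of each
// channel is the standard bilinear blend:
//   top    = top_left    + (top_right    - top_left)    * x_lerp
//   bottom = bottom_left + (bottom_right - bottom_left) * x_lerp
//   result = top + (bottom - top) * y_lerp
// where x_lerp and y_lerp are the fractional distances of the sample point
// from the left and top neighbors respectively.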
template <class T1>
struct SampleRect {
  EIGEN_ALWAYS_INLINE SampleRect(const T1* in_top_left, const T1* in_top_right,
                                 const T1* in_bottom_left,
                                 const T1* in_bottom_right)
      : top_left(in_top_left),
        top_right(in_top_right),
        bottom_left(in_bottom_left),
        bottom_right(in_bottom_right) {}

  EIGEN_ALWAYS_INLINE T1 BilinearSample(int channel, T1 x_lerp,
                                        T1 y_lerp) const {
    const T1 top =
        top_left[channel] + (top_right[channel] - top_left[channel]) * x_lerp;
    const T1 bottom = bottom_left[channel] +
                      (bottom_right[channel] - bottom_left[channel]) * x_lerp;
    return top + (bottom - top) * y_lerp;
  }

  const T1* top_left;
  const T1* top_right;
  const T1* bottom_left;
  const T1* bottom_right;
};

// Calculates parameters which remain constant through a resize cache row.
template <class T1>
EIGEN_ALWAYS_INLINE PerCacheLineParameters<T1> CalculatePerCacheLineParameters(
    int64 cache_height, int64 cache_y, T1* resize_cache, int64 cache_line_width,
    int64 input_width, int64 input_depth, int64 top_padding, int64 pad_offset,
    int64 resized_height, const ImageResizerState& st,
    const T1* input_batch_start) {
  PerCacheLineParameters<T1> result;
  // The cache is organized so that the real y values of the resized image map
  // onto the actual cache values through a modulo scheme. This means that as we
  // progress downwards through the image, we keep reusing a small cache and so
  // keep memory usage down.
  int64 cache_index_y;
  if (cache_y < 0) {
    cache_index_y = cache_height + (cache_y % cache_height);
  } else {
    cache_index_y = cache_y % cache_height;
  }
  result.cache_line_start =
      resize_cache + (cache_index_y * cache_line_width * input_depth);
  // This part is implementing the mirror padding that happens before resizing.
  float in_y = (cache_y - top_padding);
  if (in_y < 0) {
    in_y = -(in_y + 1.0f - pad_offset);
  } else if (in_y >= resized_height) {
    in_y = (resized_height * 2.0f) - (in_y + 1.0f + pad_offset);
  }
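  // As a concrete example of the flip above: with top_padding == 2, the first
  // padded row (cache_y == 0) starts at in_y == -2, which reflects to row 2
  // when pad_offset == 1 (REFLECT mode) and to row 1 when pad_offset == 0
  // (SYMMETRIC mode), matching the expected mirror-padding behavior.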
  // Here's where we do the actual resize.
  in_y *= st.height_scale;
  const int64 top_y_index = static_cast<int64>(std::floor(in_y));
  const int64 bottom_y_index =
      std::min(static_cast<int64>(std::ceil(in_y)), (st.in_height - 1));
  // Lerp is used for bilinear filtering when that's needed.
  result.y_lerp = in_y - top_y_index;
  // Which rows of the original input image to pull the values from.
  result.input_top_row_start =
      input_batch_start + (top_y_index * input_width * input_depth);
  result.input_bottom_row_start =
      input_batch_start + (bottom_y_index * input_width * input_depth);
  return result;
}

template <class T1>
struct PerCachePixelParameters {
  PerCachePixelParameters() {}
  PerCachePixelParameters(const PerCachePixelParameters<T1>& other)
      : cache_line_pixel(other.cache_line_pixel),
        left_x_index(other.left_x_index),
        right_x_index(other.right_x_index),
        x_lerp(other.x_lerp) {}

  T1* cache_line_pixel;
  int64 left_x_index;
  int64 right_x_index;
  T1 x_lerp;
};

// Pulls out common parameters used for every resized pixel.
template <class T1>
EIGEN_ALWAYS_INLINE PerCachePixelParameters<T1>
CalculatePerCachePixelParameters(int64 cache_x, int64 cache_start_x,
                                 T1* cache_line_start, int64 input_depth,
                                 int64 left_padding, int64 pad_offset,
                                 int64 resized_width,
                                 const ImageResizerState& st) {
  PerCachePixelParameters<T1> result;
  // Figure out where we're going to store the results of our transform.
  const int cache_index_x = cache_x - cache_start_x;
  result.cache_line_pixel = cache_line_start + (cache_index_x * input_depth);
  // Implement mirror padding by flipping in_x if it's off the edge.
  float in_x = (cache_x - left_padding);
  if (in_x < 0) {
    in_x = -(in_x + 1.0f - pad_offset);
  } else if (in_x >= resized_width) {
    in_x = (resized_width * 2.0f) - (in_x + 1.0f + pad_offset);
  }
  // Resize the x parameters.
  in_x *= st.width_scale;
  // Get the x coordinates for the left and right pixels to pull from.
  result.left_x_index = static_cast<int64>(std::floor(in_x));
  result.right_x_index =
      std::min(static_cast<int64>(std::ceil(in_x)), (st.in_width - 1));
  // This x_lerp is used to blend pixels in bilinear filtering.
  result.x_lerp = in_x - result.left_x_index;
  return result;
}

// Combines bilinear resizing and mirror padding into the im2col transformation
// stage of convolution.
template <class T1, class T2, class T3, class TGemmFunctor,
          SamplingMode SampleMode>
class FusedResizeAndPadConvFunctor {
 public:
  void operator()(OpKernelContext* context, const Tensor& input,
                  int input_batches, int resized_height, int resized_width,
                  int padded_height, int padded_width, int input_depth,
                  const T2* filter_data, int filter_height, int filter_width,
                  int filter_count, int stride_rows, int stride_cols,
                  Padding padding, T3* output_data, int output_height,
                  int output_width, const ImageResizerState& st,
                  int top_padding, int bottom_padding, int left_padding,
                  int right_padding, int pad_offset) {
    if ((input_batches <= 0) || (padded_width <= 0) || (padded_height <= 0) ||
        (input_depth <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad input dimensions: "
                   << input_batches << ", " << padded_height << ", "
                   << padded_width << ", " << input_depth;
      return;
    }
    if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
                   << filter_width << ", " << filter_height << ", "
                   << filter_count;
      return;
    }
    if ((output_width <= 0) || (output_height <= 0)) {
      LOG(WARNING) << "Conv2D was called with bad output width or height: "
                   << output_width << ", " << output_height;
      return;
    }
    OP_REQUIRES(
        context, ((SampleMode == NEAREST) || (SampleMode == BILINEAR)),
        errors::InvalidArgument("Bad sample mode passed in", SampleMode));

    // These calculations define how the patches will be positioned within the
    // input image. The actual definitions are quite complex, and rely on the
    // previously-calculated output size.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - padded_width + 1) /
          2;
      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
                           padded_height + 1) /
                          2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride_cols + filter_width - padded_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride_rows + filter_height - padded_height) /
          2;
    }
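    // For example, with SAME padding, a 3x3 filter, stride 1, and an output
    // width equal to the padded width, filter_left_offset works out to
    // ((output_width - 1) + 3 - padded_width) / 2 == 1, centering each patch
    // on its corresponding output pixel.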

    ResizeTaskParameters<T1> task_params;
    task_params.input_depth = input_depth;
    task_params.top_padding = top_padding;
    task_params.pad_offset = pad_offset;
    task_params.resized_height = resized_height;
    task_params.st = st;
    task_params.left_padding = left_padding;
    task_params.resized_width = resized_width;
    task_params.padded_width = padded_width;
    task_params.padded_height = padded_height;

    // The im2col buffer has # of patches rows, and # of filter values cols.
    // It's laid out like this, in row major order in memory:
    //        < filter value count >
    //   ^   +---------------------+
    // patch |                     |
    // count |                     |
    //   v   +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory, followed
    // by the width, then the height. This is the standard memory order in the
    // image world if it helps to visualize it.
    const int filter_value_count = filter_width * filter_height * input_depth;

    OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= kMaxChunkSize,
                errors::InvalidArgument("Im2Col patch too large for buffer"));
    const size_t patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
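    // As a rough illustration of the chunking: a hypothetical 3x3 filter over
    // 64 input channels gives filter_value_count = 3 * 3 * 64 = 576, so with
    // float data each patch needs 2,304 bytes and the 16MB non-iOS chunk
    // holds 7,281 patches before the partial GEMM below has to be flushed.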
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, kMaxChunkSize>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, kMaxChunkSize>**)> creator =
        [](Im2ColBufferResource<T1, kMaxChunkSize>** resource) {
          *resource = new Im2ColBufferResource<T1, kMaxChunkSize>();
          return Status::OK();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));

    // Create a resize cache memory buffer that will hold the rows of
    // transformed and mirror padded input pixels, ready to be copied
    // into filter patches by im2col.
    // It's laid out like this, in row major order in memory:
    //         < cache line width >
    //   ^    +--------------------+
    // cache  |                    |
    // height |                    |
    //   v    +--------------------+
    // Each cache row contains a cache_line_width number of resized pixels,
    // each with input_depth channels. The cache height is typically less than
    // the full height the resized image would be, so it's filled up
    // incrementally as we progress downwards through the input creating im2col
    // patches.
    task_params.cache_start_x = -filter_left_offset;
    task_params.cache_end_x =
        (((output_width - 1) * stride_cols) - filter_left_offset) +
        filter_width;
    task_params.cache_line_width =
        task_params.cache_end_x - task_params.cache_start_x;
    task_params.cache_height =
        kResizeCacheSize / (task_params.cache_line_width * input_depth);
    const int needed_resize_cache_count =
        filter_height * task_params.cache_line_width * input_depth;
    OP_REQUIRES(context,
                (needed_resize_cache_count * sizeof(T1)) <= kResizeCacheSize,
                errors::InvalidArgument("Input too large for resize cache"));
    Im2ColBufferResource<T1, kResizeCacheSize>* resize_cache_resource;
    std::function<Status(Im2ColBufferResource<T1, kResizeCacheSize>**)>
        resize_creator =
            [](Im2ColBufferResource<T1, kResizeCacheSize>** resource) {
              *resource = new Im2ColBufferResource<T1, kResizeCacheSize>();
              return Status::OK();
            };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "resize_cache",
                                &resize_cache_resource, resize_creator));

    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this is
    // aimed at have intra-op parallelism as their focus though, so it shouldn't
    // be an issue.
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    // This buffer is used as a fairly heavy-weight cache for the resized and
    // mirrored inputs to the im2col operation. The problem is that we want to
    // keep memory usage down by not materializing the fully resized and padded
    // input tensor to the convolution in a single buffer. The first approach
    // to avoid this was to fold the bilinear filtering and padding spatial
    // transformations into the im2col lookup itself. This successfully reduced
    // memory usage, but because im2col can access an individual pixel for many
    // different patches, the extra overhead of doing the same bilinear lookups
    // repeatedly became too expensive.
    // The resize cache avoids this problem by keeping a precalculated
    // horizontal slice of the resized and padded input to the im2col, so that
    // repeated accesses to the same pixel from different filter patches can
    // just be copied from this cache. It's organized as a horizontal slice
    // stretching across the whole virtual image and as high as the filter
    // window, so that all the pixels needed are present as the patch
    // processing moves across a row. Before a new row of patches is started,
    // any previously calculated rows that are still needed are kept, and new
    // rows are calculated as required.
    mutex_lock resize_lock_buffer(resize_cache_resource->mu);
    core::ScopedUnref unref_resized_cache(resize_cache_resource);
    task_params.resize_cache = resize_cache_resource->data;

    const T1* input_data = input.flat<T1>().data();
    const int64 input_height = input.shape().dim_sizes()[1];
    task_params.input_width = input.shape().dim_sizes()[2];

    int end_cached_lines = std::numeric_limits<int>::min();

    for (int batch = 0; batch < input_batches; ++batch) {
      task_params.input_batch_start =
          input_data +
          (batch * input_height * task_params.input_width * input_depth);
      const int in_y_end =
          ((output_height * stride_rows) - filter_top_offset) + filter_height;
      for (int out_y = 0; out_y < output_height; ++out_y) {
        const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
        const int cache_start_y = std::max(in_y_origin, end_cached_lines);
        const int cache_end_y = std::min(
            in_y_end, std::max((in_y_origin + task_params.cache_height),
                               end_cached_lines));
        if (end_cached_lines < (in_y_origin + filter_height)) {
          // This call breaks up the work required for calculating the mirror
          // padding and resizing across multiple threads.
          FusedConvParallelFor(
              context, cache_start_y, cache_end_y,
              [task_params](int64 task_cache_start_y, int64 task_cache_end_y) {
                // This is a long and confusing function, but it's been laid out
                // this way to help with performance on some intensive models.
                // What it's doing is populating a cache of the original input
                // image, after it's been bilinear resized and had its edges
                // mirrored. This allows the following im2col code to access the
                // transformed pixels from this cache, without having to
                // repeatedly apply the expensive bilinear calculations as the
                // same pixels are accessed by different patches.
                // This is most effective when the stride is small and the
                // filter size is large, since that's when pixels are reused
                // most frequently as patches overlap.
                for (int cache_y = task_cache_start_y;
                     cache_y < task_cache_end_y; ++cache_y) {
                  // We organize the cache as a series of rows, each containing
                  // all the transformed pixels for a given line in the image.
                  // This cache is big enough to hold at least a filter's height
                  // worth of rows, but typically more, limited by the size of
                  // the cache buffer.
                  // We don't allocate an entire image's worth of rows though,
                  // because we're trying to keep memory usage down, so as we
                  // progress downwards through the im2col we periodically
                  // refresh the cache so that the next lines that are needed
                  // for that operation are always present.
                  // Work out the parameters that remain constant across the
                  // row we're calculating.
                  PerCacheLineParameters<float> line_params(
                      CalculatePerCacheLineParameters<float>(
                          task_params.cache_height, cache_y,
                          task_params.resize_cache,
                          task_params.cache_line_width, task_params.input_width,
                          task_params.input_depth, task_params.top_padding,
                          task_params.pad_offset, task_params.resized_height,
                          task_params.st, task_params.input_batch_start));
                  // Iterate through the resize cache row we're filling in.
                  for (int cache_x = task_params.cache_start_x;
                       cache_x < task_params.cache_end_x; ++cache_x) {
                    // Figure out what we need for the cache pixel we're
                    // populating.
                    PerCachePixelParameters<T1> pixel_params(
                        CalculatePerCachePixelParameters<T1>(
                            cache_x, task_params.cache_start_x,
                            line_params.cache_line_start,
                            task_params.input_depth, task_params.left_padding,
                            task_params.pad_offset, task_params.resized_width,
                            task_params.st));
                    // If the access is off the left, right, top, or bottom of
                    // the resized image, the conv padding means we should set
                    // it to zero.
                    if ((cache_x < 0) ||
                        (cache_x >= task_params.padded_width) ||
                        (cache_y < 0) ||
                        (cache_y >= task_params.padded_height)) {
                      std::fill_n(pixel_params.cache_line_pixel,
                                  task_params.input_depth, T1(0));
                    } else {
                      // There are two different sampling strategies for
                      // resizing. When using nearest, we can just do a
                      // straight copy of the pixel closest to our sample point,
                      // but bilinear requires a more complex calculation.
                      if (SampleMode == NEAREST) {
                        const T1* input_top_left_pixel =
                            line_params.input_top_row_start +
                            (pixel_params.left_x_index *
                             task_params.input_depth);

                        std::copy_n(input_top_left_pixel,
                                    task_params.input_depth,
                                    pixel_params.cache_line_pixel);
                      } else {
                        const SampleRect<T1> rect(
                            line_params.input_top_row_start +
                                (pixel_params.left_x_index *
                                 task_params.input_depth),
                            line_params.input_top_row_start +
                                (pixel_params.right_x_index *
                                 task_params.input_depth),
                            line_params.input_bottom_row_start +
                                (pixel_params.left_x_index *
                                 task_params.input_depth),
                            line_params.input_bottom_row_start +
                                (pixel_params.right_x_index *
                                 task_params.input_depth));
                        for (int in_channel = 0;
                             in_channel < task_params.input_depth;
                             ++in_channel) {
                          pixel_params.cache_line_pixel[in_channel] =
                              rect.BilinearSample(in_channel,
                                                  pixel_params.x_lerp,
                                                  line_params.y_lerp);
                        }
                      }
                    }
                  }
                }
              });
          end_cached_lines = cache_end_y;
        }
        for (int out_x = 0; out_x < output_width; ++out_x) {
          const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
          const int patch_index = (batch * output_width * output_height) +
                                  (out_y * output_width) + out_x;
          const int patch_index_within_chunk = patch_index % patches_per_chunk;
          T1* im2col_patch_start =
              im2col_buffer + (patch_index_within_chunk * filter_value_count);
          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
            T1* im2col_row_start =
                im2col_patch_start +
                (filter_y * filter_width * task_params.input_depth);
            const int conv_in_y = in_y_origin + filter_y;
            int cache_index_y;
            if (conv_in_y < 0) {
              cache_index_y = task_params.cache_height +
                              (conv_in_y % task_params.cache_height);
            } else {
              cache_index_y = conv_in_y % task_params.cache_height;
            }
            T1* cache_line_start =
                task_params.resize_cache +
                (cache_index_y * task_params.cache_line_width *
                 task_params.input_depth);
            T1* cache_filter_row_start =
                cache_line_start + ((in_x_origin - task_params.cache_start_x) *
                                    task_params.input_depth);
            std::copy_n(cache_filter_row_start,
                        (filter_width * task_params.input_depth),
                        im2col_row_start);
          }
          const bool is_last_in_chunk =
              (patch_index_within_chunk == (patches_per_chunk - 1));
          const bool is_last_overall =
              ((batch == (input_batches - 1)) &&
               (out_y == (output_height - 1)) && (out_x == (output_width - 1)));
          if (is_last_in_chunk || is_last_overall) {
            // Now that we've assembled a set of image patches into a matrix,
            // apply a GEMM matrix multiply of the patches as rows, times the
            // filter weights in columns, to get partial results in the output
            // matrix.
            const int how_many_patches = patch_index_within_chunk + 1;
            const int m = how_many_patches;
            const int n = filter_count;
            const int k = filter_value_count;
            const int lda = filter_value_count;
            const int ldb = filter_count;
            const int ldc = filter_count;
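            // In other words, the GEMM computes a row-major product
            //   output[m x n] = im2col[m x k] * filter[k x n]
            // so each row of the chunk's output holds the filter_count
            // responses for one image patch.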
            const size_t start_patch_index =
                patch_index - (how_many_patches - 1);
            T3* chunk_output_data =
                output_data + (start_patch_index * filter_count);
            TGemmFunctor gemm_functor;
            gemm_functor(context, m, n, k, im2col_buffer, lda, filter_data, ldb,
                         chunk_output_data, ldc);
          }
        }
      }
    }
  }
};

}  // namespace

// Implements a version of convolution with bilinear resizing and mirror padding
// included.
template <class T, class TConvFunctor, bool DoResize>
class FusedResizeConv2DUsingGemmOp : public OpKernel {
 public:
  explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context)
      : OpKernel(context) {
    if (DoResize) {
      OP_REQUIRES_OK(context,
                     context->GetAttr("resize_align_corners", &align_corners_));
    }
    MirrorPadMode mode;
    OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));

    switch (mode) {
      case MirrorPadMode::SYMMETRIC: {
        offset_ = 0;
        break;
      }
      case MirrorPadMode::REFLECT: {
        offset_ = 1;
        break;
      }
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "mode must be either REFLECT or SYMMETRIC."));
    }
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    const int64 stride_n = GetTensorDim(strides_, FORMAT_NHWC, 'N');
    const int64 stride_c = GetTensorDim(strides_, FORMAT_NHWC, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);
    OP_REQUIRES(context, (input.shape().num_elements() > 0),
                errors::InvalidArgument("Input tensor can't be empty"));

    ImageResizerState st(false);
    if (DoResize) {
      st = ImageResizerState(align_corners_);
      st.ValidateAndCalculateOutputSize(context, input);
      if (!context->status().ok()) return;
    } else {
      // Set up the resize parameters to do no scaling at all.
      st.batch_size = input.dim_size(0);
      st.out_height = input.dim_size(1);
      st.out_width = input.dim_size(2);
      st.in_height = input.dim_size(1);
      st.in_width = input.dim_size(2);
      st.channels = input.dim_size(3);
      st.height_scale = 1.0f;
      st.width_scale = 1.0f;
    }
    TensorShape resized_shape(
        {input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)});
    int paddings_index;
    int filter_index;
    if (DoResize) {
      paddings_index = 2;
      filter_index = 3;
    } else {
      paddings_index = 1;
      filter_index = 2;
    }
    const Tensor& paddings = context->input(paddings_index);

    const int dims = resized_shape.dims();
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsMatrix(paddings.shape()) &&
            paddings.dim_size(1) == 2,
        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
                                paddings.shape().DebugString()));
    const int fixed_dims =
        (allow_legacy_scalars() && dims == 0 && paddings.dim_size(0) == 1)
            ? 1
            : dims;
    OP_REQUIRES(
        context, fixed_dims == paddings.dim_size(0),
        errors::InvalidArgument(
            "The first dimension of paddings must be the rank of inputs: ",
            fixed_dims, " ", paddings.shape().DebugString(), " ",
            resized_shape.DebugString()));
    OP_REQUIRES(
        context, dims == paddings.dim_size(0),
        errors::InvalidArgument(
            "The first dimension of paddings must be the rank of inputs: ",
            dims, " ", paddings.shape().DebugString(), " ",
            resized_shape.DebugString()));

    OP_REQUIRES(
        context, dims == 4,
        errors::InvalidArgument(
            "Fused mirror padding only supports four-dimensional inputs, but ",
            dims, " requested"));

    // Compute the shape of the output tensor, and allocate it.
    TensorShape padded_shape;
    TTypes<int32>::ConstMatrix paddings_matrix = paddings.matrix<int32>();
    for (int d = 0; d < dims; ++d) {
      const int32 before =
          paddings_matrix(d, 0);  // Pad before existing elements.
      const int32 after =
          paddings_matrix(d, 1);  // Pad after existing elements.
      OP_REQUIRES(context, before >= 0 && after >= 0,
                  errors::InvalidArgument(
                      "paddings must be non-negative: ", before, " ", after));
      if (offset_ == 0) {  // SYMMETRIC mode.
        OP_REQUIRES(
            context,
            before <= resized_shape.dim_size(d) &&
                after <= resized_shape.dim_size(d),
            errors::InvalidArgument("paddings must be no greater "
                                    "than the dimension size: ",
                                    before, ", ", after, " greater than ",
                                    resized_shape.dim_size(d)));
      } else if (offset_ == 1) {  // REFLECT mode.
        OP_REQUIRES(
            context,
            before < resized_shape.dim_size(d) &&
                after < resized_shape.dim_size(d),
            errors::InvalidArgument("paddings must be less than"
                                    " the dimension size: ",
                                    before, ", ", after, " not less than ",
                                    resized_shape.dim_size(d)));
      }
      padded_shape.AddDim(before + resized_shape.dim_size(d) + after);
    }

    OP_REQUIRES(
        context, ((paddings_matrix(0, 0) == 0) && (paddings_matrix(0, 1) == 0)),
        errors::InvalidArgument(
            "Fused mirror padding only supports spatial padding, not batches: ",
            paddings.DebugString()));
    OP_REQUIRES(
        context, ((paddings_matrix(3, 0) == 0) && (paddings_matrix(3, 1) == 0)),
        errors::InvalidArgument(
            "Fused mirror padding only supports spatial padding, not channels: ",
            paddings.DebugString()));
    const int32 top_padding = paddings_matrix(1, 0);
    const int32 bottom_padding = paddings_matrix(1, 1);
    const int32 left_padding = paddings_matrix(2, 0);
    const int32 right_padding = paddings_matrix(2, 1);
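    // As an illustration, a paddings input of [[0, 0], [2, 2], [3, 3], [0, 0]]
    // mirrors two rows onto the top and bottom and three columns onto the left
    // and right of the resized image, while leaving the batch and channel
    // dimensions untouched, as required by the checks above.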

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth]
    const Tensor& filter = context->input(filter_index);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, padded_shape.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        padded_shape.DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    // We only check the first three dims, since the depth is accessed as an
    // int64 below.
    for (int i = 0; i < 3; i++) {
      OP_REQUIRES(
          context,
          FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
          errors::InvalidArgument("filter too large"));
    }

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64 in_depth = padded_shape.dim_size(3);
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int out_depth = static_cast<int>(filter.dim_size(3));

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64 padded_rows_raw = padded_shape.dim_size(1);
    OP_REQUIRES(
        context,
        FastBoundsCheck(padded_rows_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int padded_rows = static_cast<int>(padded_rows_raw);
    const int filter_rows = static_cast<int>(filter.dim_size(0));
    const int resized_rows = static_cast<int>(resized_shape.dim_size(1));

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64 padded_cols_raw = padded_shape.dim_size(2);
    OP_REQUIRES(
        context,
        FastBoundsCheck(padded_cols_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int padded_cols = static_cast<int>(padded_cols_raw);
    const int filter_cols = static_cast<int>(filter.dim_size(1));
    const int resized_cols = static_cast<int>(resized_shape.dim_size(2));

    // The first dimension for input is batch.
    const int64 batch_raw = padded_shape.dim_size(0);
    OP_REQUIRES(context,
                FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
                errors::InvalidArgument("batch is too large"));
    const int batch = static_cast<int>(batch_raw);

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, FORMAT_NHWC, 'H');
    const int stride_cols = GetTensorDim(strides_, FORMAT_NHWC, 'W');

    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(padded_rows, filter_rows, stride_rows,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(padded_cols, filter_cols, stride_cols,
                                         padding_, &out_cols, &pad_cols));
    TensorShape out_shape =
        ShapeFromFormat(FORMAT_NHWC, batch, out_rows, out_cols, out_depth);
    OP_REQUIRES(context, (out_shape.num_elements() > 0),
                errors::InvalidArgument("Output tensor can't be empty"));

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "FusedConv2D: " << name() << ", in_depth = " << in_depth
            << ", padded_cols = " << padded_cols
            << ", resized_cols = " << resized_cols
            << ", filter_cols = " << filter_cols
            << ", padded_rows = " << padded_rows
            << ", resized_rows = " << resized_rows
            << ", filter_rows = " << filter_rows
            << ", stride_rows = " << stride_rows
            << ", stride_cols = " << stride_cols
            << ", out_depth = " << out_depth << ", DoResize=" << DoResize;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }
    TConvFunctor conv_functor;
    conv_functor(context, input, batch, resized_rows, resized_cols, padded_rows,
                 padded_cols, in_depth, filter.flat<T>().data(), filter_rows,
                 filter_cols, out_depth, stride_rows, stride_cols, padding_,
                 output->flat<T>().data(), out_rows, out_cols, st, top_padding,
                 bottom_padding, left_padding, right_padding, offset_);
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  bool align_corners_;
  int offset_;

  TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp);
};

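// Both kernels below are registered only for float inputs on the CPU device.
// FusedResizeAndPadConv2D uses BILINEAR sampling because it rescales the
// input, while FusedPadConv2D performs no resize (its scale factors are fixed
// at 1.0 above), so NEAREST sampling is sufficient and exact.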
#define REGISTER_FUSED(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("FusedResizeAndPadConv2D")                                     \
          .Device(DEVICE_CPU)                                             \
          .TypeConstraint<T>("T"),                                        \
      FusedResizeConv2DUsingGemmOp<                                       \
          T,                                                              \
          FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
                                       BILINEAR>,                         \
          true>);

TF_CALL_float(REGISTER_FUSED);

#define REGISTER_PAD_ONLY_FUSED(T)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("FusedPadConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      FusedResizeConv2DUsingGemmOp<                                       \
          T,                                                              \
          FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
                                       NEAREST>,                          \
          false>);

TF_CALL_float(REGISTER_PAD_ONLY_FUSED);

}  // namespace tensorflow