/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/deep_conv2d.h"

#include <stdlib.h>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/winograd_transform.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

// DeepConv2D is a Conv2D implementation specialized for deep convolutions
// (i.e. a large 'in_depth' * 'out_depth' product; see cost models below for
// details).
//
// DeepConv2D is implemented by computing the following equation:
//
//   y = C[Ad * Bg]
//
//   C: output transform matrix
//   A: input data transform matrix
//   B: filter transform matrix
//   d: vectorized data tile
//   g: vectorized filter tile
//   y: vectorized output tile
//
// The transform matrices and input, filter and output tile sizes are all
// specified by the DeepConv2DTransform implementation selected at the
// start of the DeepConv2D call, based on convolution parameters.

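// A concrete shape sketch (assuming the F(2x2, 3x3) Winograd transform that
// this file selects below; see winograd_transform.h for the actual matrix
// values):
//
//   WinogradTransform<float> t;
//   t.input_shape().rows;   // 4: 'd' is a flattened 4x4 input tile (16).
//   t.filter_shape().rows;  // 3: 'g' is a flattened 3x3 filter tile (9).
//   t.output_shape().rows;  // 2: 'y' is a flattened 2x2 output tile (4).
//
// With these sizes, B is [16 x 9], A is [16 x 16], and C is [4 x 16] in
// y = C[Ad * Bg].
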
// Approximate cost models for direct and deep convolutions.
static int64 GetDeepConvCost(int input_tile_rows, int input_tile_cols,
                             int out_tile_rows, int out_tile_cols, int in_depth,
                             int out_depth, int out_rows, int out_cols) {
  // Input transform cost.
  const int64 input_tile_spatial_size = input_tile_rows * input_tile_cols;
  const int64 input_transform_cost =
      input_tile_spatial_size * input_tile_spatial_size * in_depth;

  // Element-wise products (each product is a MatMul across depth).
  const int64 product_cost = input_tile_spatial_size * in_depth * out_depth;

  // Output transform cost.
  const int64 output_tile_spatial_size = out_tile_rows * out_tile_cols;
  const int64 output_transform_cost =
      output_tile_spatial_size * input_tile_spatial_size * out_depth;

  // Calculate number of input tiles to process.
  const int64 row_tiles = (out_rows + out_tile_rows - 1) / out_tile_rows;
  const int64 col_tiles = (out_cols + out_tile_cols - 1) / out_tile_cols;
  const int64 num_tiles = row_tiles * col_tiles;

  // Return total cost.
  return num_tiles *
         (input_transform_cost + product_cost + output_transform_cost);
}

static int64 GetDirectConvCost(int filter_rows, int filter_cols, int in_depth,
                               int out_depth, int out_rows, int out_cols) {
  return filter_rows * filter_cols * in_depth * out_depth * out_rows * out_cols;
}
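
// Worked example (illustrative numbers only): with the F(2x2, 3x3) transform
// (4x4 input tile, 2x2 output tile), in_depth = out_depth = 64 and a 32x32
// output:
//
//   deep:   input transform   16 * 16 * 64 = 16384
//           products          16 * 64 * 64 = 65536
//           output transform   4 * 16 * 64 =  4096
//           tiles             (32/2)^2     =   256
//           total             256 * 86016  = ~22.0M
//
//   direct: 3 * 3 * 64 * 64 * 32 * 32      = ~37.7M
//
// The deep/direct ratio is ~0.58, so DeepConv2D would be preferred here.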

// Reads environment variable 'env_var_name'.
// Returns false if the variable is set to "0", true if it is set to any
// other value, and 'default_val' if the variable is unset.
static bool ReadBoolFromEnvVar(const char* env_var_name, bool default_val) {
  const char* tf_env_var_val = getenv(env_var_name);
  if (tf_env_var_val != nullptr) {
    StringPiece tf_env_var_val_str(tf_env_var_val);
    if (tf_env_var_val_str == "0") {
      return false;
    }
    return true;
  }
  return default_val;
}
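
// For example, to enable DeepConv2D for a single run (shell; the script name
// is illustrative):
//
//   TF_USE_DEEP_CONV2D=1 python train_model.py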

// Returns true if the convolution can be computed efficiently by DeepConv2D,
// false otherwise.
// TODO(andydavis) Add support for other filter sizes and strides.
// TODO(andydavis) Add support for autotuning.
bool CanUseDeepConv2D(int stride_rows, int stride_cols, int filter_rows,
                      int filter_cols, int in_depth, int out_depth,
                      int out_rows, int out_cols) {
  // Check if convolution parameters are supported.
  // TODO(andydavis) Add support for multiple filter sizes and strides.
  if (stride_rows > 1 || stride_cols > 1 || filter_rows != 3 ||
      filter_cols != 3) {
    return false;
  }

  // Check if deep convolution is enabled by environment variable.
  // NOTE: If this environment variable name changes, update conv_ops_test.py.
  if (!ReadBoolFromEnvVar("TF_USE_DEEP_CONV2D", false)) {
    return false;
  }

  // Check if the flop cost of deep convolution is less than that of direct
  // convolution.
  WinogradTransform<float> t;
  const int64 deep_conv_cost = GetDeepConvCost(
      t.input_shape().rows, t.input_shape().cols, t.output_shape().rows,
      t.output_shape().cols, in_depth, out_depth, out_rows, out_cols);
  const int64 direct_conv_cost = GetDirectConvCost(
      filter_rows, filter_cols, in_depth, out_depth, out_rows, out_cols);

  VLOG(2) << "CanUseDeepConv2D"
          << " deep_conv_cost: " << deep_conv_cost
          << " direct_conv_cost: " << direct_conv_cost << " deep_direct_ratio: "
          << (static_cast<float>(deep_conv_cost) /
              static_cast<float>(direct_conv_cost))
          << " use_deep_conv: " << (deep_conv_cost < direct_conv_cost);
  return deep_conv_cost < direct_conv_cost;
}
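
// For example, a 3x3 stride-1 convolution with in_depth = out_depth = 64 and
// a 32x32 output passes the parameter check, and (per the worked example
// above) its deep-convolution cost is lower, so with TF_USE_DEEP_CONV2D=1:
//
//   CanUseDeepConv2D(1, 1, 3, 3, 64, 64, 32, 32);  // would return true.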

typedef Eigen::ThreadPoolDevice CPUDevice;

// Copies data from 'filter_in' to 'filter_buf' along 'in_depth' dimension.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_buf:
//   [base_filter_rows, base_filter_cols, in_depth]
//
template <typename T>
struct CopyFilterDepth {
  void operator()(const Conv2DArgs& args, const T* filter_in, T* filter_buf) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static constexpr int64 kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64 vectorized_size = args.in_depth / kPacketSize;
    const int64 scalar_size = args.in_depth % kPacketSize;
    const int64 input_stride = args.out_depth * kPacketSize;

    // Copy vectorized portion of depth dimension.
    for (int64 d = 0; d < vectorized_size; ++d) {
      auto v = Eigen::internal::pgather<T, Packet>(filter_in + d * input_stride,
                                                   args.out_depth);
      Eigen::internal::pstoreu<T>(filter_buf + d * kPacketSize, v);
    }
    // Copy scalar portion of inner dimension.
    const int64 in_scalar_base = vectorized_size * input_stride;
    const int64 buf_scalar_base = vectorized_size * kPacketSize;
    for (int64 d = 0; d < scalar_size; ++d) {
      filter_buf[buf_scalar_base + d] =
          filter_in[in_scalar_base + d * args.out_depth];
    }
  }
};
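
// Gather sketch for CopyFilterDepth above (hypothetical sizes): with
// kPacketSize = 4 and out_depth = 8, the pgather call reads filter_in[0],
// filter_in[8], filter_in[16] and filter_in[24] into one packet -- four
// consecutive 'in_depth' values for a fixed output channel -- and stores
// them contiguously in 'filter_buf'.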

// Computes transform of 'num_filters' from 'filter_in' starting at 'od_start'.
// Intermediate results (i.e. output of MatMul('transform_matrix', 'filter_in'))
// are stored in 'out_buffer'. The final result is copied from 'out_buffer' to
// 'filter_out' at the coordinate stride required by the transformed filter
// data layout.
//
// filter_in:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]

template <typename T>
struct ComputeFilterRangeTransform {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;
  static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));

  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform, const int64 od_start,
                  const int64 num_filters, const int64 shard_rows,
                  const int64 shard_cols, const T* filter_in,
                  const int64 in_stride, const int64 out_stride,
                  const T* transform_matrix, T* out_buffer, T* filter_out) {
    namespace ei = Eigen::internal;

    const int64 in_depth = args.in_depth;
    const int64 base_filter_rows = transform->filter_shape().rows;
    const int64 base_filter_cols = transform->filter_shape().cols;
    const int64 base_filter_spatial_size = base_filter_rows * base_filter_cols;
    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;

    // Compute transform of 'num_filters' by 'transform_matrix'.
    ConstMatrixMap A(transform_matrix, tile_spatial_size,
                     base_filter_spatial_size);
    ConstMatrixMap B(filter_in, base_filter_spatial_size, in_stride);
    MatrixMap C(out_buffer, tile_spatial_size, in_stride);

    C.noalias() = A * B;
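    // Shape sketch for the product above (assuming F(2x2, 3x3), one filter
    // shard, in_depth = 64, num_filters = 8): A is [16 x 9] and B is
    // [9 x 8*1*1*64] = [9 x 512], so C = A * B is [16 x 512]; each of the
    // 16 tile coordinates now holds transformed values for all 8 filters
    // across depth.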

    // Copy 'out_buffer' to 'filter_out' at required filter output stride.
    const int64 scalar_size = in_depth % kPacketSize;
    const int64 vectorized_size = in_depth / kPacketSize;

    const int64 shard_stride = args.in_depth;
    const int64 out_depth_stride = shard_rows * shard_cols * shard_stride;

    for (int64 od = 0; od < num_filters; ++od) {
      const int64 out_depth_buf_base = od * out_depth_stride;
      const int64 out_depth_base = (od_start + od) * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64 s_r = 0; s_r < shard_rows; ++s_r) {
        for (int64 s_c = 0; s_c < shard_cols; ++s_c) {
          const int64 shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64 i = 0; i < tile_spatial_size; ++i) {
            const int64 in_base =
                i * in_stride + out_depth_buf_base + shard_base;
            const int64 out_base = i * out_stride + out_depth_base + shard_base;
            // Copy vectorized portion of 'in_depth'.
            for (int64 d = 0; d < vectorized_size; ++d) {
              auto v =
                  ei::ploadu<Packet>(out_buffer + in_base + d * kPacketSize);
              ei::pstoreu<T>(filter_out + out_base + d * kPacketSize, v);
            }
            // Copy scalar portion of 'in_depth'.
            const int64 scalar_base = vectorized_size * kPacketSize;
            for (int64 d = 0; d < scalar_size; ++d) {
              filter_out[out_base + scalar_base + d] =
                  out_buffer[in_base + scalar_base + d];
            }
          }
        }
      }
    }
  }
};

// Transforms 'num_filters' from 'filter_in', starting at 'od_start'.
// For each filter in 'num_filters', copies data for all filter shards from
// 'filter_in' into 'filter_buf', adding zero-padding as needed.
// Calls ComputeFilterRangeTransform to compute the filter transform of data
// in 'filter_buf' by 'transform_matrix', storing the result in 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
// filter_buffer:
//   [base_filter_rows, base_filter_cols, num_filters, shard_rows, shard_cols,
//    in_depth]
//
// transform_matrix:
//   [tile_spatial_size, base_filter_spatial_size]
//
// out_buffer:
//   [tile_spatial_size, num_filters, shard_rows, shard_cols, in_depth]

template <typename T>
struct TransformFilterRange {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform, const int64 od_start,
                  const int64 od_limit, const T* filter_in,
                  const T* transform_matrix, T* out_buffer, T* filter_buf,
                  T* filter_out) {
    const int64 num_filters = od_limit - od_start;
    const int64 base_filter_rows = transform->filter_shape().rows;
    const int64 base_filter_cols = transform->filter_shape().cols;
    const int64 base_filter_spatial_size = base_filter_rows * base_filter_cols;

    // Compute number of filter shards.
    const int64 residual_row =
        std::max(0LL, args.filter_rows - base_filter_rows);
    const int64 shard_rows = 1 + (residual_row + 2 - 1) / 2;

    const int64 residual_col =
        std::max(0LL, args.filter_cols - base_filter_cols);
    const int64 shard_cols = 1 + (residual_col + 2 - 1) / 2;
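
    // Sharding sketch: with a 3x3 base filter, a 5x5 filter leaves a
    // residual of 2 in each dimension, so shard_rows = shard_cols =
    // 1 + ceil(2/2) = 2 and the filter is split into 2x2 = 4 base-size
    // shards (zero-padded where a shard extends past the filter edge).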

    // Compute strides to be used for input and output IO.
    const int64 shard_stride = args.in_depth;
    const int64 out_depth_stride = shard_rows * shard_cols * shard_stride;
    const int64 coord_stride = out_depth_stride * args.out_depth;
    const int64 filter_buf_stride =
        num_filters * shard_rows * shard_cols * args.in_depth;
    const int64 tile_stride_rows = transform->output_shape().rows;
    const int64 tile_stride_cols = transform->output_shape().cols;

    const int64 filter_buf_size = base_filter_spatial_size * num_filters *
                                  shard_rows * shard_cols * args.in_depth;
    memset(filter_buf, 0, sizeof(T) * filter_buf_size);

    // Copy filter range into 'filter_buf'.
    for (int64 od = 0; od < num_filters; ++od) {
      const int64 out_depth_base = od * out_depth_stride;

      // TODO(andydavis) Shard filters that are multiples of base filter sizes.
      for (int64 s_r = 0; s_r < shard_rows; ++s_r) {
        const int64 row_offset = s_r == 0 ? 0 : 1;

        for (int64 s_c = 0; s_c < shard_cols; ++s_c) {
          const int64 col_offset = s_c == 0 ? 0 : 1;
          const int64 f_r_start = s_r * tile_stride_rows;
          const int64 f_c_start = s_c * tile_stride_cols;

          const int64 shard_base = shard_stride * (s_r * shard_cols + s_c);

          for (int64 b_r = row_offset; b_r < base_filter_rows; ++b_r) {
            const int64 f_r = f_r_start + b_r;
            if (f_r >= args.filter_rows) continue;

            for (int64 b_c = col_offset; b_c < base_filter_cols; ++b_c) {
              const int64 f_c = f_c_start + b_c;
              if (f_c >= args.filter_cols) continue;

              const int64 in_index =
                  args.out_depth *
                      (args.in_depth * (f_r * args.filter_cols + f_c)) +
                  (od_start + od);

              const int64 buf_index =
                  filter_buf_stride * (b_r * base_filter_cols + b_c) +
                  out_depth_base + shard_base;

              CopyFilterDepth<T>()(args, filter_in + in_index,
                                   filter_buf + buf_index);
            }
          }
        }
      }
    }

    // Compute filter transform of data in 'filter_buf' by 'transform_matrix'.
    // Intermediate results are stored in 'out_buffer'.
    // Final results are stored in 'filter_out'.
    ComputeFilterRangeTransform<T>()(args, transform, od_start, num_filters,
                                     shard_rows, shard_cols, filter_buf,
                                     filter_buf_stride, coord_stride,
                                     transform_matrix, out_buffer, filter_out);
  }
};

// Transforms all filters from 'filter_in', storing result in 'filter_out'.
//
// filter_in:
//   [filter_rows, filter_cols, in_depth, out_depth]
//
// filter_out:
//   [tile_rows, tile_cols, out_depth, shard_rows, shard_cols, in_depth]
//
template <typename T>
struct TransformFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64 filter_shards_row, const int64 filter_shards_col,
                  const T* filter_in, T* filter_out) {
    const int64 in_depth = args.in_depth;
    const int64 out_depth = args.out_depth;

    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;

    const int64 base_filter_rows = transform->filter_shape().rows;
    const int64 base_filter_cols = transform->filter_shape().cols;
    const int64 base_filter_spatial_size = base_filter_rows * base_filter_cols;

    const int64 filter_shards_total = filter_shards_row * filter_shards_col;

    // Calculate filter transform batch based on cache/filter sizes.

    // Cache budget (based on L2 cache size = 256KB).
    // TODO(andydavis) Read cache size from system.
    const int64 cache_size = (256LL << 10) / sizeof(T);

    // Fixed cost.
    const int64 filter_transform_matrix_size =
        tile_spatial_size * base_filter_spatial_size;

    // Per-filter costs.
    const int64 filter_total_size =
        base_filter_spatial_size * in_depth * filter_shards_total;

    const int64 filter_transform_buffer_size =
        base_filter_spatial_size * filter_shards_total * in_depth;

    const int64 filter_out_buf_size =
        tile_spatial_size * filter_shards_total * in_depth;

    // Total per-filter costs.
    const int64 per_filter_cost =
        filter_total_size + filter_transform_buffer_size + filter_out_buf_size;

    // Remove fixed cost and divide by per-filter cost.
    const int64 num_filters_cache = std::max(
        1LL, (cache_size - filter_transform_matrix_size) / per_filter_cost);
    const int64 num_filters_transform = std::min(out_depth, num_filters_cache);
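    // Budget sketch (illustrative): for float with F(2x2, 3x3), in_depth = 64
    // and one filter shard, cache_size = 65536 and the fixed matrix costs
    // 16 * 9 = 144; per_filter_cost = 576 + 576 + 1024 = 2176, so roughly 30
    // filters are transformed per batch.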

    // Allocate buffer for filter transform matrix:
    //   [tile_spatial_size, base_filter_spatial_size]
    Tensor filter_transform_matrix;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_spatial_size, base_filter_spatial_size}),
                 &filter_transform_matrix));
    T* transform_matrix = filter_transform_matrix.template flat<T>().data();
    transform->GetFilterTransformMatrix(
        tile_spatial_size, base_filter_spatial_size, transform_matrix);

    auto shard = [&ctx, &args, &transform, &base_filter_rows, &base_filter_cols,
                  &num_filters_transform, &in_depth, &out_depth,
                  &filter_shards_row, &filter_shards_col, &tile_spatial_size,
                  &filter_in, &transform_matrix,
                  &filter_out](int64 start, int64 limit) {
      // Allocate buffer for pre-processed filter:
      //   [base_filter_rows, base_filter_cols, num_filters_transform,
      //    filter_shards_row, filter_shards_col, in_depth]
      Tensor filter_transform_buffer;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(
                         DataTypeToEnum<T>::value,
                         TensorShape({base_filter_rows, base_filter_cols,
                                      num_filters_transform, filter_shards_row,
                                      filter_shards_col, in_depth}),
                         &filter_transform_buffer));
      T* filter_buf = filter_transform_buffer.template flat<T>().data();

      // Allocate buffer for output of filter transform matrix:
      //   [tile_spatial_size, num_filters_transform, filter_shards_row,
      //    filter_shards_col, in_depth]
      Tensor filter_output_buffer;
      OP_REQUIRES_OK(
          ctx,
          ctx->allocate_temp(
              DataTypeToEnum<T>::value,
              TensorShape({tile_spatial_size, num_filters_transform,
                           filter_shards_row, filter_shards_col, in_depth}),
              &filter_output_buffer));
      T* out_buffer = filter_output_buffer.template flat<T>().data();

      const int64 num_filters = limit - start;
      const int64 od_unroll = num_filters_transform;
      const int64 od_unroll_limit = (num_filters / od_unroll) * od_unroll;

      for (int64 od = start; od < od_unroll_limit; od += od_unroll) {
        TransformFilterRange<T>()(args, transform, od, od + od_unroll,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }

      if (od_unroll_limit < limit) {
        TransformFilterRange<T>()(args, transform, od_unroll_limit, limit,
                                  filter_in, transform_matrix, out_buffer,
                                  filter_buf, filter_out);
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    const int64 shard_cost = args.filter_rows * args.filter_cols * in_depth *
                             filter_shards_total * tile_spatial_size;
    // TODO(andydavis) Resolve performance of multi-threaded filter transforms.
    Shard(1, worker_threads.workers, out_depth, shard_cost, shard);
  }
};

// Packs transformed filters stored in 'lhs_input' into 'lhs_block' in a
// gemm-kernel friendly data layout.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth].

template <typename T>
class GemmFilterPacker {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::RowMajor>
      LhsMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;
  Eigen::internal::gemm_pack_lhs<T, int64, LhsMapper, Traits::mr,
                                 Traits::LhsProgress, Eigen::RowMajor>
      pack_lhs;

  GemmFilterPacker(const int64 rows, const int64 depth, const T* lhs_input,
                   T* lhs_block)
      : rows_(rows),
        depth_(depth),
        lhs_block_(lhs_block),
        lhs_mapper_(lhs_input, depth_) {}

  void Run() { pack_lhs(lhs_block_, lhs_mapper_, depth_, rows_); }

 private:
  const int64 rows_;
  const int64 depth_;
  T* lhs_block_;
  LhsMapper lhs_mapper_;
};

// Packs transformed filters stored in 'filter_transform_data' into
// 'packed_filters' to be used by GemmState.
template <typename T>
struct PackFilters {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args,
                  const int64 tile_spatial_size, const int64 filter_shards_row,
                  const int64 filter_shards_col, const T* filter_transform_data,
                  std::vector<Tensor>* packed_filters) {
    const int64 in_depth = args.in_depth;
    const int64 out_depth = args.out_depth;
    const int64 num_filters = filter_shards_row * filter_shards_col * out_depth;

    auto shard = [&ctx, &packed_filters, &filter_transform_data,
                  &tile_spatial_size, &in_depth, &out_depth, &filter_shards_row,
                  &filter_shards_col, &num_filters](int64 start, int64 limit) {
      const int64 filter_coord_stride = num_filters * in_depth;
      for (int64 i = start; i < limit; ++i) {
        // Allocate filter buffer [out_depth, shard_rows, shard_cols, in_depth].
        OP_REQUIRES_OK(
            ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                    TensorShape({out_depth, filter_shards_row,
                                                 filter_shards_col, in_depth}),
                                    &(*packed_filters)[i]));
        T* packed_filter = (*packed_filters)[i].template flat<T>().data();
        // Pack filters.
        GemmFilterPacker<T> packer(
            num_filters, in_depth,
            filter_transform_data + i * filter_coord_stride, packed_filter);
        packer.Run();
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, tile_spatial_size,
          num_filters * in_depth, shard);
  }
};
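
// Design note: 'packed_filters' holds one packed block per tile coordinate
// (tile_spatial_size of them), so packing runs once per Conv2D call and each
// block is then reused by every input tile of every image in the batch.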

// Computes the product of filters stored in 'lhs_block' and input tiles
// stored in 'rhs_block', storing output in 'out_buffer'.
//
// Data layout for 'lhs_block':
//   [out_depth, shard_rows, shard_cols, in_depth].
//
// Data layout for 'rhs_block':
//   [num_tiles, in_depth]
//
// Data layout for 'out_buffer':
//   [num_tiles, out_depth, shard_rows, shard_cols]

template <typename T>
class GemmState {
 public:
  typedef Eigen::internal::const_blas_data_mapper<T, int64, Eigen::ColMajor>
      RhsMapper;
  typedef Eigen::internal::blas_data_mapper<T, int64, Eigen::ColMajor>
      OutputMapper;
  typedef Eigen::internal::gebp_traits<T, T> Traits;

  Eigen::internal::gemm_pack_rhs<T, int64, RhsMapper, Traits::nr,
                                 Eigen::ColMajor>
      pack_rhs;
  Eigen::internal::gebp_kernel<T, T, int64, OutputMapper, Traits::mr,
                               Traits::nr, false, false>
      gebp;

  GemmState(const int64 rows, const int64 cols, const int64 depth,
            const int64 out_buffer_size, const T* lhs_block, const T* rhs_input,
            T* rhs_block, T* out_buffer)
      : rows_(rows),
        cols_(cols),
        depth_(depth),
        out_buffer_size_(out_buffer_size),
        lhs_block_(lhs_block),
        rhs_block_(rhs_block),
        out_buffer_(out_buffer),
        rhs_mapper_(rhs_input, depth_),
        out_mapper_(out_buffer, rows_) {}

  void PackRhs() { pack_rhs(rhs_block_, rhs_mapper_, depth_, cols_); }

  void Compute() {
    memset(out_buffer_, 0, sizeof(T) * out_buffer_size_);
    gebp(out_mapper_, lhs_block_, rhs_block_, rows_, depth_, cols_, 1.0);
  }

 private:
  const int64 rows_;
  const int64 cols_;
  const int64 depth_;
  const int64 out_buffer_size_;
  const T* lhs_block_;
  T* rhs_block_;
  T* out_buffer_;
  RhsMapper rhs_mapper_;
  OutputMapper out_mapper_;
};
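
// Shape sketch for GemmState::Compute() above: gebp forms the
// [rows_ x cols_] = [num_filters x num_tiles] product of
// lhs [num_filters x in_depth] and rhs [in_depth x num_tiles]. With a
// column-major output mapper of leading dimension 'rows_', element
// (filter f, tile t) lands at out_buffer[t * num_filters + f], which matches
// the [num_tiles, out_depth, shard_rows, shard_cols] layout documented above.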

// Copies an input tile from 'input' into 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
//
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct CopyInputTile {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64 num_tiles, const int64 in_r_start,
                  const int64 in_c_start, const T* input, T* tile_buffer) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 coord_stride = num_tiles * args.in_depth;

    // Calculate vectorized and scalar (residual) lengths for 'in_depth'.
    const int64 input_vectorized_size =
        (args.in_depth / kPacketSize) * kPacketSize;
    const int64 input_scalar_size = args.in_depth % kPacketSize;

    for (int64 r = 0; r < tile_rows; ++r) {
      const int64 in_r = in_r_start + r;
      if (in_r < 0 || in_r >= args.in_rows) continue;

      for (int64 c = 0; c < tile_cols; ++c) {
        const int64 in_c = in_c_start + c;
        if (in_c < 0 || in_c >= args.in_cols) continue;

        auto* in = input + (in_r * args.in_cols + in_c) * args.in_depth;
        auto* tile = tile_buffer + coord_stride * (r * tile_rows + c);
        // Copy vectorized portion of depth dimension.
        for (int64 d = 0; d < input_vectorized_size; d += kPacketSize) {
          auto v = Eigen::internal::ploadu<Packet>(in + d);
          Eigen::internal::pstoreu<T>(tile, v);
          tile += kPacketSize;
        }
        // Copy scalar portion of inner dimension.
        for (int64 d = 0; d < input_scalar_size; ++d) {
          tile[d] = in[input_vectorized_size + d];
        }
      }
    }
  }
};

// Transforms 'num_tiles' tiles from 'input' by 'transform_matrix', storing the
// final result in 'tile_transform'.
// Intermediate results are stored in 'tile_buffer'.
//
// input:
//   [in_rows, in_cols, in_depth]
// tile_buffer:
//   [tile_rows, tile_cols, num_tiles, in_depth]
// tile_transform_matrix:
//   [tile_spatial_size, tile_spatial_size]
// tile_transform:
//   [tile_rows, tile_cols, num_tiles, in_depth]

template <typename T>
struct TransformInputTiles {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64 num_tiles, const int64 in_r_start,
                  const int64 in_c_start, const T* input,
                  const T* transform_matrix, T* tile_buffer,
                  T* tile_transform) {
    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;
    const int64 tile_stride_cols = transform->output_shape().cols;
    const int64 coord_stride = num_tiles * args.in_depth;
    const int64 num_tiles_stride = args.in_depth;

    memset(tile_buffer, 0, sizeof(T) * tile_spatial_size * coord_stride);
    const int64 in_r = in_r_start;
    for (int64 t = 0; t < num_tiles; ++t) {
      const int64 num_tiles_base = t * num_tiles_stride;
      const int64 in_c = in_c_start + t * tile_stride_cols;
      CopyInputTile<T>()(args, transform, num_tiles, in_r, in_c, input,
                         tile_buffer + num_tiles_base);
    }

    ConstMatrixMap A(transform_matrix, tile_spatial_size, tile_spatial_size);
    ConstMatrixMap B(tile_buffer, tile_spatial_size, coord_stride);
    MatrixMap C(tile_transform, tile_spatial_size, coord_stride);

    C.noalias() = A * B;
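    // Shape sketch for the product above: A is [tile_spatial_size x
    // tile_spatial_size] (e.g. [16 x 16] for F(2x2, 3x3)), while B and C are
    // [tile_spatial_size x num_tiles * in_depth], so all tiles are
    // transformed across all depths in a single MatMul.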
  }
};

// Transforms output tiles from buffer by 'out_transform_matrix', storing
// final result in 'output' (intermediate results stored in 'out_buffer').
//
// out_buffer:
//   [tile_rows, tile_cols, num_tiles, out_depth, shard_rows, shard_cols]
//
// output transform buffer:
//  [out_tile_rows, out_tile_cols, num_tiles, out_depth, shard_rows, shard_cols]
//
// output:
//   [out_rows, out_cols, out_depth]

template <typename T>
struct TransformOutputTile {
  typedef Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      MatrixMap;
  typedef Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
      ConstMatrixMap;

  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const int64 num_tiles, const int64 in_r, const int64 in_c,
                  const int64 filter_shards_row, const int64 filter_shards_col,
                  const T* out_transform_matrix, const T* out_buffer,
                  T* out_transform_buffer, T* output) {
    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;

    const int64 out_buf_stride =
        num_tiles * args.out_depth * filter_shards_row * filter_shards_col;

    const int64 out_tile_rows = transform->output_shape().rows;
    const int64 out_tile_cols = transform->output_shape().cols;
    const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;

    // Compute output transform.
    ConstMatrixMap A(out_transform_matrix, out_tile_spatial_size,
                     tile_spatial_size);
    ConstMatrixMap B(out_buffer, tile_spatial_size, out_buf_stride);
    MatrixMap C(out_transform_buffer, out_tile_spatial_size, out_buf_stride);

    C.noalias() = A * B;

    const int64 tile_stride_rows = transform->output_shape().rows;
    const int64 tile_stride_cols = transform->output_shape().cols;

    const int64 out_depth_stride = filter_shards_row * filter_shards_col;
    const int64 num_tiles_stride = args.out_depth * out_depth_stride;

    // Copy transformed output from 'out_transform_buffer' to the proper index
    // in 'output'. Note that some outputs at boundaries can be discarded.
    for (int64 t = 0; t < num_tiles; ++t) {
      const int64 tile_base = t * num_tiles_stride;

      for (int64 od = 0; od < args.out_depth; ++od) {
        const int64 out_depth_base = od * out_depth_stride;

        // TODO(andydavis) Update filter sharding scheme in the next CL.
        for (int64 sr = 0; sr < filter_shards_row; ++sr) {
          for (int64 sc = 0; sc < filter_shards_col; ++sc) {
            const int64 shard_base = sr * filter_shards_col + sc;
            const int64 out_buf_base = tile_base + out_depth_base + shard_base;

            // Calculate output indices and outputs to drop (if needed).
            const int64 out_r_start =
                in_r + args.pad_rows - sr * tile_stride_rows;
            // NOTE: The tile index 't' is used in the calculation of
            // 'out_c_start' because tiles progress along the column dimension.
            const int64 out_c_start = (in_c + t * tile_stride_cols) +
                                      args.pad_cols - sc * tile_stride_cols;

            if (out_r_start < 0 || out_r_start >= args.out_rows ||
                out_c_start < 0 || out_c_start >= args.out_cols) {
              continue;  // Skip unneeded outputs.
            }

            // Increment the output if this is not the first filter shard.
            const bool inc_output = (sr != 0 || sc != 0);

            for (int64 ot_row = 0; ot_row < out_tile_rows; ++ot_row) {
              const int64 out_r = out_r_start + ot_row;
              if (out_r >= args.out_rows) continue;

              for (int64 ot_col = 0; ot_col < out_tile_cols; ++ot_col) {
                const int64 out_c = out_c_start + ot_col;
                if (out_c >= args.out_cols) continue;

                // Calculate output tile index.
                const int64 out_buf_index = ot_row * out_tile_cols + ot_col;
                // Read output value from buffer.
                const T out_val =
                    out_transform_buffer[out_buf_base +
                                         out_buf_index * out_buf_stride];
                // Calculate output index.
                const int64 output_index =
                    args.out_depth * (out_r * args.out_cols + out_c) + od;
                // Update output.
                if (inc_output) {
                  output[output_index] += out_val;
                } else {
                  output[output_index] = out_val;
                }
              }
            }
          }
        }
      }
    }
  }
};

template <typename T>
struct Conv2DState {
  Conv2DState(const int64 tile_spatial_size, const int64 filter_shards_row,
              const int64 filter_shards_col, const T* input,
              const T* tile_transform_matrix, const T* output_transform_matrix,
              T* buffer1, T* buffer2, T* packed_tile_buffer,
              T* gemm_output_buffer)
      : tile_spatial_size(tile_spatial_size),
        filter_shards_row(filter_shards_row),
        filter_shards_col(filter_shards_col),
        input(input),
        tile_transform_matrix(tile_transform_matrix),
        output_transform_matrix(output_transform_matrix),
        buffer1(buffer1),
        buffer2(buffer2),
        packed_tile_buffer(packed_tile_buffer),
        gemm_output_buffer(gemm_output_buffer) {}

  const int64 tile_spatial_size;
  const int64 filter_shards_row;
  const int64 filter_shards_col;
  const T* input;
  const T* tile_transform_matrix;
  const T* output_transform_matrix;
  T* buffer1;
  T* buffer2;
  T* packed_tile_buffer;
  T* gemm_output_buffer;
};

// Computes Conv2D for 'num_tiles' input tiles from 'input' starting at
// (in_r, in_c), storing the results of the computation in 'output'.
// Details:
// *) Transforms 'num_tiles' input tiles into 'tile_transform_buffer'.
// *) Computes point-wise MatMuls of 'num_tiles' input tiles with all filters.
// *) Transforms output tiles, and stores result to 'output'.

// TODO(andydavis) Maybe pass Conv2DState into TransformInput/Output functions.
template <typename T>
struct ComputeConv2D {
  void operator()(const Conv2DArgs& args,
                  const DeepConv2DTransform<T>* transform,
                  const Conv2DState<T>& cs, const int64 in_r, const int64 in_c,
                  const int64 num_tiles,
                  const std::vector<Tensor>& packed_filters, const T* input,
                  T* output) {
    // Transform input tiles.
    TransformInputTiles<T>()(args, transform, num_tiles, in_r, in_c, input,
                             cs.tile_transform_matrix, cs.buffer1, cs.buffer2);

    // Compute element-wise product (each a MatMul): input tiles X filters.
    const int64 in_depth = args.in_depth;
    const int64 out_depth = args.out_depth;
    const int64 num_filters =
        cs.filter_shards_row * cs.filter_shards_col * out_depth;
    const int64 tile_coord_stride = num_tiles * in_depth;
    const int64 gemm_out_buf_size = num_tiles * num_filters;
    const int64 gemm_out_buf_bytes = gemm_out_buf_size * sizeof(T);

    for (int64 i = 0; i < cs.tile_spatial_size; ++i) {
      GemmState<T> gemm(num_filters, num_tiles, in_depth, gemm_out_buf_size,
                        packed_filters[i].template flat<T>().data(),
                        cs.buffer2 + i * tile_coord_stride,
                        cs.packed_tile_buffer, cs.gemm_output_buffer);
      // Pack tile buffer.
      gemm.PackRhs();
      // Compute product.
      gemm.Compute();
      // Copy to larger output buffer without alignment requirements.
      memcpy(cs.buffer1 + i * gemm_out_buf_size, cs.gemm_output_buffer,
             gemm_out_buf_bytes);
    }

    // Transform output.
    TransformOutputTile<T>()(args, transform, num_tiles, in_r, in_c,
                             cs.filter_shards_row, cs.filter_shards_col,
                             cs.output_transform_matrix, cs.buffer1, cs.buffer2,
                             output);
  }
};

namespace functor {

// Conv2D operation specialized for deep convolutions (i.e. a large
// in_depth * out_depth product).
// Details:
// *) Transforms and packs filters from 'filter' in parallel.
// *) Computes Conv2D parallelized across the 'batch' dimension.
//   *) Each thread loops over images in its batch shard, copying 'num_tiles'
//      input tiles into a local buffer, and computing the Conv2D output of
//      these tiles by all filters.

// TODO(andydavis) Improve the performance of boundary cases where the input
// tile extends past the limit and wasted outputs are computed. This overhead
// is at most 2/n, where 'n' is max(out_rows, out_cols), and so is worse for
// smaller spatial sizes.
// TODO(andydavis) Improve the performance of sharded filters.
template <typename T>
struct DeepConv2D<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, const Conv2DArgs& args, const T* input,
                  const T* filter, T* output) {
    // TODO(andydavis) Add function to select transform based on conv params.
    std::unique_ptr<DeepConv2DTransform<T>> transform(new WinogradTransform<T>);

    const int64 in_depth = args.in_depth;
    const int64 out_depth = args.out_depth;

    const int64 tile_rows = transform->input_shape().rows;
    const int64 tile_cols = transform->input_shape().cols;
    const int64 tile_spatial_size = tile_rows * tile_cols;

    const int64 out_tile_rows = transform->output_shape().rows;
    const int64 out_tile_cols = transform->output_shape().cols;
    const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;

    const int64 base_filter_rows = transform->filter_shape().rows;

    const int64 filter_residual_row =
        std::max(0LL, args.filter_rows - base_filter_rows);
    const int64 filter_shards_row = 1 + (filter_residual_row + 2 - 1) / 2;

    const int64 filter_residual_col =
        std::max(0LL, args.filter_cols - base_filter_rows);
    const int64 filter_shards_col = 1 + (filter_residual_col + 2 - 1) / 2;

    // Allocate buffer for transformed filters.
    Tensor filter_transform;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 TensorShape({tile_rows, tile_cols, out_depth,
                              filter_shards_row, filter_shards_col, in_depth}),
                 &filter_transform));
    T* filter_transform_data = filter_transform.template flat<T>().data();

    // Transform filters.
    TransformFilters<T>()(ctx, args, transform.get(), filter_shards_row,
                          filter_shards_col, filter, filter_transform_data);

    // Pack filters.
    std::vector<Tensor> packed_filters(tile_spatial_size);
    PackFilters<T>()(ctx, args, tile_spatial_size, filter_shards_row,
                     filter_shards_col, filter_transform_data, &packed_filters);

    // Allocate buffer for tile transform matrix.
    Tensor tile_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
                            DataTypeToEnum<T>::value,
                            TensorShape({tile_spatial_size, tile_spatial_size}),
                            &tile_transform_matrix_tensor));
    T* tile_transform_matrix =
        tile_transform_matrix_tensor.template flat<T>().data();
    transform->GetInputTransformMatrix(tile_spatial_size, tile_spatial_size,
                                       tile_transform_matrix);

    // Allocate buffer for output transform matrix.
    Tensor output_transform_matrix_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                           TensorShape({out_tile_spatial_size,
                                                        tile_spatial_size}),
                                           &output_transform_matrix_tensor));
    T* output_transform_matrix =
        output_transform_matrix_tensor.template flat<T>().data();
    transform->GetOutputTransformMatrix(
        out_tile_spatial_size, tile_spatial_size, output_transform_matrix);

    auto shard = [&ctx, &args, &transform, &packed_filters, &in_depth,
                  out_depth, tile_rows, tile_cols, out_tile_rows, out_tile_cols,
                  filter_shards_row, filter_shards_col, tile_spatial_size,
                  &input, &tile_transform_matrix, &output_transform_matrix,
                  &output](int64 batch_start, int64 batch_limit) {
      const int64 row_tiles =
          (args.out_rows + out_tile_rows - 1) / out_tile_rows +
          filter_shards_row - 1;
      const int64 col_tiles =
          (args.out_cols + out_tile_cols - 1) / out_tile_cols +
          filter_shards_col - 1;

      // Calculate number of tiles to process together.
      const int64 filter_shard_size = filter_shards_row * filter_shards_col;
      const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;

      // Cache budget (based on L2 cache size = 256KB).
      // TODO(andydavis) Read cache size from the system.
      const int64 cache_size = (256LL << 10) / sizeof(T);

      // Fixed costs.
      const int64 tile_transform_matrix_size =
          tile_spatial_size * tile_spatial_size;
      const int64 output_transform_matrix_size =
          out_tile_spatial_size * tile_spatial_size;
      // Calculate cache reserve size.
      const int64 filter_depth_size = in_depth * out_depth * filter_shard_size;
      const bool small_filter = ((filter_depth_size * 100) / cache_size) <= 25;
      const int64 cache_reserve_size = small_filter ? filter_depth_size : 1024;
      // Calculate total fixed cost.
      const int64 total_fixed_cost = tile_transform_matrix_size +
                                     output_transform_matrix_size +
                                     cache_reserve_size;

      // Per-tile costs.
      const int64 buffer1_per_tile_size =
          tile_spatial_size * std::max(in_depth, out_depth * filter_shard_size);
      const int64 buffer2_per_tile_size =
          std::max(tile_spatial_size * in_depth,
                   out_tile_spatial_size * out_depth * filter_shard_size);
      const int64 packed_tile_per_tile_size = in_depth;
      const int64 gemm_out_per_tile_size = out_depth * filter_shard_size;
      const int64 total_per_tile_cost =
          buffer1_per_tile_size + buffer2_per_tile_size +
          packed_tile_per_tile_size + gemm_out_per_tile_size;

      const int64 num_tiles_cache =
          std::max(4LL, (cache_size - total_fixed_cost) / total_per_tile_cost);
      const int64 num_tiles = std::min(num_tiles_cache, col_tiles);
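      // Budget sketch (illustrative): for float with F(2x2, 3x3),
      // in_depth = out_depth = 64 and one filter shard, cache_size = 65536,
      // total_fixed_cost = 256 + 64 + 4096 = 4416 and total_per_tile_cost =
      // 1024 + 1024 + 64 + 64 = 2176, giving num_tiles_cache = 28.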

      // Allocate temporary buffer 'buffer1', which is first used for copying
      // input tiles, then re-used to buffer gemm output. Calculate the
      // required buffer size for 'buffer1', based on max buffer size required
      // between copying input tiles and buffering gemm product output.
      //   buffer1: [max(buf1_tile_size, buf1_out_size)]
      const int64 buffer1_tile_size = tile_spatial_size * num_tiles * in_depth;
      const int64 buffer1_out_size =
          tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64 buffer1_size = std::max(buffer1_tile_size, buffer1_out_size);
      Tensor buffer1_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer1_size}),
                                             &buffer1_tensor));
      T* buffer1 = buffer1_tensor.template flat<T>().data();

      // Allocate temporary buffer 'buffer2', which is first used for
      // transformed input tiles, then re-used for transformed output tiles.
      // Calculate required buffer size for 'buffer2' as max required buffer
      // between input and output transform buffer sizes.
      const int64 buffer2_tile_transform_size =
          tile_spatial_size * num_tiles * in_depth;
      const int64 buffer2_out_transform_size =
          out_tile_spatial_size * num_tiles * out_depth * filter_shard_size;
      const int64 buffer2_size =
          std::max(buffer2_tile_transform_size, buffer2_out_transform_size);
      Tensor buffer2_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({buffer2_size}),
                                             &buffer2_tensor));
      T* buffer2 = buffer2_tensor.template flat<T>().data();

      // Allocate temporary buffer to store packed tiles for one coordinate.
      // packed tile buffer: [num_tiles, in_depth].
      Tensor packed_tile_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, in_depth}),
                                             &packed_tile_tensor));
      T* packed_tile_buffer = packed_tile_tensor.template flat<T>().data();

      // Allocate temporary buffer for gemm output.
      // gemm output buffer [num_tiles, out_depth, shard_rows, shard_cols].
      Tensor gemm_output_tensor;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             TensorShape({num_tiles, out_depth,
                                                          filter_shards_row,
                                                          filter_shards_col}),
                                             &gemm_output_tensor));
      T* gemm_output_buffer = gemm_output_tensor.template flat<T>().data();

      // Capture state needed for ComputeConv2D inner loop.
      Conv2DState<T> conv_state(tile_spatial_size, filter_shards_row,
                                filter_shards_col, input, tile_transform_matrix,
                                output_transform_matrix, buffer1, buffer2,
                                packed_tile_buffer, gemm_output_buffer);

      const int64 row_pad = args.pad_rows;
      const int64 col_pad = args.pad_cols;
      const int64 unroll_col_limit = (col_tiles / num_tiles) * num_tiles;

      const int64 input_image_size = args.in_rows * args.in_cols * in_depth;
      const int64 output_image_size = args.out_rows * args.out_cols * out_depth;

      const int64 tile_stride_rows = transform->output_shape().rows;
      const int64 tile_stride_cols = transform->output_shape().cols;

      for (int64 b = batch_start; b < batch_limit; ++b) {
        const int64 in_base = b * input_image_size;
        const int64 out_base = b * output_image_size;

        for (int64 tile_r = 0; tile_r < row_tiles; ++tile_r) {
          const int64 in_r = tile_r * tile_stride_rows - row_pad;

          // Process unrolled tiles.
          for (int64 tile_c = 0; tile_c < unroll_col_limit;
               tile_c += num_tiles) {
            const int64 in_c = tile_c * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               num_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
          // Process remaining tiles.
          if (unroll_col_limit < col_tiles) {
            const int64 rem_tiles = col_tiles - unroll_col_limit;
            const int64 in_c = unroll_col_limit * tile_stride_cols - col_pad;
            ComputeConv2D<T>()(args, transform.get(), conv_state, in_r, in_c,
                               rem_tiles, packed_filters, input + in_base,
                               output + out_base);
          }
        }
      }
    };
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    const int64 shard_cost = args.out_rows * args.out_cols * args.out_depth *
                             tile_spatial_size * args.in_depth;
    Shard(worker_threads.num_threads, worker_threads.workers, args.batch,
          shard_cost, shard);
  }
};

}  // namespace functor

template struct functor::DeepConv2D<CPUDevice, float>;

}  // namespace tensorflow