/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements quantized eight-bit versions of the convolution operations.

#include <algorithm>
#include <vector>

#define EIGEN_USE_THREADS

#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#include "public/gemmlowp.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/kernels/reference_gemm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/padding.h"

namespace tensorflow {

// This functor implements the convolution operation in as simple a form as
// possible. It won't give great performance, but it is very useful for
// stepping through and instrumenting for debugging, creating minimal benchmarks
// to prototype with, and sharing with teams that want to run this outside of
// our environment.
// With that in mind, I've avoided using anything except pretty standard C++
// types. This is especially noticeable in the data access through raw array
// indexing. It's deliberate in this case though, since it makes the underlying
// memory order very explicit, which is important for both inspecting memory
// contents during debugging and for specifying what we expect to others.
// The memory layout of the data is, from biggest stride to smallest:
// input_data = [input_batches, input_height, input_width, input_depth]
// filter_data = [filter_height, filter_width, input_depth, filter_count]
// output_data = [input_batches, output_height, output_width, filter_count]
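// As a concrete illustration of that layout, the element of input_data at
// [batch, y, x, channel] sits at flat index
//   ((batch * input_height + y) * input_width + x) * input_depth + channel,
// which is exactly the indexing arithmetic used in the inner loops below. The
// filter and output tensors follow the same row-major convention.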
template <class T1, class T2, class T3>
class ReferenceConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, int input_offset, const T2* filter_data,
                  int filter_height, int filter_width, int filter_count,
                  int filter_offset, int stride, Padding padding,
                  T3* output_data, int output_height, int output_width,
                  int output_shift, int output_offset, int output_mult) {
    // Set up some constants we need for the output down-shifting and
    // saturation.
    const int32 highest = static_cast<int32>(Eigen::NumTraits<T3>::highest());
    const int32 lowest = static_cast<int32>(Eigen::NumTraits<T3>::lowest());

    // When we're converting the 32 bit accumulator to a lower bit depth, we
    // need to add on 0.5 in fixed-point terms to make the operation round half
    // up towards positive infinity, rather than a floor.
    // We also need to watch out for the case when there's no down shift,
    // because a left shift by a negative number gives undefined results.
    const int32 rounding = (output_shift < 1) ? 0 : (1 << (output_shift - 1));
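    // For example, an output_shift of 8 divides the accumulator by 256, so
    // rounding is 1 << 7 == 128 (one half in that fixed-point scale); with an
    // output_shift of 0 no bits are dropped and rounding stays at zero.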

    // The two different padding modes we support can be a bit confusing. SAME
    // means we're trying to produce an output image that's the same size as the
    // input. It's complicated by stride, which shrinks the output image by a
    // factor, but it means we end up sampling from outside the borders of the
    // input. These out-of-bounds values are read as zeroes. VALID means only
    // produce output values where the filters can read all their values from
    // within the input image. It effectively removes the margins of the output
    // image compared to the one produced by SAME. Stride complicates this
    // definition though, because it can result in the right and bottom filter
    // patches sampling from outside the borders if it's greater than 1.
    // Most of the logic for sorting this all out is done before this function,
    // when we calculate the output size, but the positioning of the origin of
    // the filters is different between the two modes, since SAME positions the
    // first filter off the edge of the input.
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width + 1) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height + 1) / 2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height) / 2;
    }
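    // As a worked example, a 5x5 input convolved with a 3x3 filter at stride 1
    // under SAME padding produces a 5x5 output, so both offsets come out as
    // ((5 - 1) * 1 + 3 - 5) / 2 == 1: the first filter position hangs one
    // pixel off the top and left edges of the input.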

    // If we've got multiple images in our input, work through each of them.
    for (int batch = 0; batch < input_batches; ++batch) {
      // Walk through all the output image values, sliding the filter to
      // different positions in the input.
      for (int out_y = 0; out_y < output_height; ++out_y) {
        for (int out_x = 0; out_x < output_width; ++out_x) {
          // Each filter kernel produces one output channel.
          for (int out_channel = 0; out_channel < filter_count; ++out_channel) {
            // We're going to calculate a single output value, which means we
            // need to multiply a three dimensional kernel of weights against
            // the current location within the input image.
            /*
              *-------------------------------...
              |\ ^
              | \in_depth
              |  \ v
              |   *-------------------------------...
              |   |            ^
              |   |       in_y_origin
              |   |            v   \
              |   |<in_x_origin>*---*^
              |   |            \|   |filter_height
              .   |             *---*v
              .   |             <--->
                  .         filter_width
                  .
            */
            const int in_x_origin = (out_x * stride) - filter_left_offset;
            const int in_y_origin = (out_y * stride) - filter_top_offset;
            int32 total = 0;
            for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
              for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
                for (int in_channel = 0; in_channel < input_depth;
                     ++in_channel) {
                  const int in_x = in_x_origin + filter_x;
                  const int in_y = in_y_origin + filter_y;
                  int32 input_value;
                  // If the location is outside the bounds of the input image,
                  // use zero as a default value.
                  if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
                      (in_y < input_height)) {
                    const T1 input_source_value =
                        input_data[(batch * input_height * input_width *
                                    input_depth) +
                                   (in_y * input_width * input_depth) +
                                   (in_x * input_depth) + in_channel];
                    // We're promoting the T1 type to a higher bit depth here as
                    // we do the subtraction.
                    input_value =
                        static_cast<int32>(input_source_value) - input_offset;
                  } else {
                    input_value = 0;
                  }
                  const T2 filter_source_value =
                      filter_data[(filter_y * filter_width * input_depth *
                                   filter_count) +
                                  (filter_x * input_depth * filter_count) +
                                  (in_channel * filter_count) + out_channel];
                  // Another promotion to 32 bit, as above.
                  const int32 filter_value =
                      static_cast<int32>(filter_source_value) - filter_offset;
                  total += (input_value * filter_value);
                }
              }
            }
            // Here we're applying scale factors to compress the 32 bit
            // accumulated total to a potentially lower bit depth.
            const int32_t output =
                ((((total + output_offset) * output_mult) + rounding) >>
                 output_shift);
            // We need to saturate the results against the largest and smallest
            // values that can be represented in this type.
            const int32 top_clamped_output = std::min(output, highest);
            const int32 clamped_output = std::max(top_clamped_output, lowest);
            output_data[(batch * output_height * output_width * filter_count) +
                        (out_y * output_width * filter_count) +
                        (out_x * filter_count) + out_channel] = clamped_output;
          }
        }
      }
    }
  }
};

// We don't want to allocate a buffer to hold all the patches if the size is
// going to be extremely large, so break it into chunks if it's bigger than
// a limit. Each chunk will be processed serially, so we can refill the
// buffer for the next chunk and reuse it, keeping maximum memory size down.
// In this case, we've picked 1 megabyte as a reasonable limit, from
// experimentation.
const size_t kMaxChunkSize = (1 * 1024 * 1024);

// Implements convolution as a two stage process, first packing the patches of
// the input image into columns (im2col) and then running GEMM to produce the
// final result.
template <class T1, class T2, class T3>
class Im2ColConvFunctor {
 public:
  void operator()(OpKernelContext* context, const T1* input_data,
                  int input_batches, int input_height, int input_width,
                  int input_depth, int input_offset, const T2* filter_data,
                  int filter_height, int filter_width, int filter_count,
                  int filter_offset, int stride, Padding padding,
                  T3* output_data, int output_height, int output_width,
                  int output_shift, int output_offset, int output_mult) {
    if (input_offset < 0) {
      // Only log the first few occurrences of this warning.
      static int warning_count = 0;
      if (warning_count < 10) {
        ++warning_count;
        LOG(WARNING)
            << "For kernel '" << context->op_kernel().name() << "' from input '"
            << context->op_kernel().requested_input(0)
            << "': Zero is not representable in the quantized range used by the"
            << " input. This means QuantizedConv2d has to fall back to a slow"
            << " implementation, since the border of zero values can't be"
            << " represented easily. You should try to construct graphs that"
            << " avoid this situation.";
      }
      ReferenceConvFunctor<T1, T2, T3> conv_functor;
      conv_functor(context, input_data, input_batches, input_height,
                   input_width, input_depth, input_offset, filter_data,
                   filter_height, filter_width, filter_count, filter_offset,
                   stride, padding, output_data, output_height, output_width,
                   output_shift, output_offset, output_mult);
      return;
    }

    CHECK_GT(output_width, 0);
    CHECK_GT(output_height, 0);
    int filter_left_offset;
    int filter_top_offset;
    if (padding == VALID) {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width + 1) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height + 1) / 2;
    } else {
      filter_left_offset =
          ((output_width - 1) * stride + filter_width - input_width) / 2;
      filter_top_offset =
          ((output_height - 1) * stride + filter_height - input_height) / 2;
    }

    // The im2col buffer has # of patches rows, and # of filter values cols.
    // It's laid out like this, in row major order in memory:
    //        < filter value count >
    //   ^   +---------------------+
    // patch |                     |
    // count |                     |
    //   v   +---------------------+
    // Each patch row contains a filter_width x filter_height patch of the
    // input, with the depth channel as the most contiguous in memory, followed
    // by the width, then the height. This is the standard memory order in the
    // image world if it helps to visualize it.
    const int filter_value_count = filter_width * filter_height * input_depth;
    const int64 patches_per_chunk =
        kMaxChunkSize / (filter_value_count * sizeof(T1));
    const int64 chunk_value_count =
        (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
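    // For instance, a 3x3 filter over 256 input channels gives
    // filter_value_count == 3 * 3 * 256 == 2304 bytes per patch for an
    // eight-bit T1, so roughly 1,048,576 / 2304 == 455 patches fit into each
    // one-megabyte chunk.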
    // TODO(petewarden) - Memory allocation can be very slow on Android. Can we
    // optimize this by keeping the scratch buffer around?
    // Because memory allocation is very expensive on mobile platforms, try to
    // allocate a persistent buffer that will be kept around between calls. We
    // use TensorFlow's resource management to ensure that the memory will be
    // released when the session is over.
    Im2ColBufferResource<T1, chunk_value_count>* im2col_buffer_resource;
    std::function<Status(Im2ColBufferResource<T1, chunk_value_count>**)>
        creator = [](Im2ColBufferResource<T1, chunk_value_count>** resource) {
#ifdef _MSC_VER
          // MSVC complains about the capture of chunk_value_count, even though
          // the same pattern compiles fine in conv_ops_using_gemm.cc, for
          // example. Define chunk_value_count inside the lambda for now.
          const int64 chunk_value_count =
              (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
#endif
          *resource = new Im2ColBufferResource<T1, chunk_value_count>();
          return Status::OK();
        };
    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
                                "Conv2d", "im2col_buffer",
                                &im2col_buffer_resource, creator));
    // This means that multiple ops can't be run simultaneously on different
    // threads, because we have a single shared resource. The platforms this is
    // aimed at have intra-op parallelism as their focus though, so it shouldn't
    // be an issue.
    mutex_lock lock_buffer(im2col_buffer_resource->mu);
    core::ScopedUnref unref_buffer(im2col_buffer_resource);
    T1* im2col_buffer = im2col_buffer_resource->data;

    const int64 patch_count = (input_batches * output_height * output_width);
    const int64 chunk_count =
        (patch_count + (patches_per_chunk - 1)) / patches_per_chunk;
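    // Continuing the example above, 1000 patches at 455 patches per chunk
    // would be processed as three chunks, with the final chunk only partially
    // full.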

    for (int64 chunk_index = 0; chunk_index < chunk_count; ++chunk_index) {
      const int64 patch_index_start = chunk_index * patches_per_chunk;
      const int64 patch_index_end =
          std::min(patch_index_start + patches_per_chunk, patch_count);
      for (int64 patch_index = patch_index_start; patch_index < patch_index_end;
           ++patch_index) {
        const int64 batch = patch_index / (output_height * output_width);
        const int64 out_y = (patch_index / output_width) % output_height;
        const int64 out_x = patch_index % output_width;
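        // For example, with a 4x4 output, patch_index 21 decomposes to
        // batch 1, out_y 1, out_x 1, since 21 / 16 == 1, (21 / 4) % 4 == 1,
        // and 21 % 4 == 1.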
        const T1* input_batch_start =
            input_data + (batch * input_height * input_width * input_depth);
        const int in_y_origin = (out_y * stride) - filter_top_offset;
        const int in_x_origin = (out_x * stride) - filter_left_offset;
        const int patch_index_within_chunk = patch_index % patches_per_chunk;
        T1* im2col_patch_start =
            im2col_buffer + (patch_index_within_chunk * filter_value_count);
        for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
          const int in_y = in_y_origin + filter_y;
          T1* im2col_row_start =
              im2col_patch_start + (filter_y * filter_width * input_depth);
          // If we're off the top or the bottom of the input, fill the
          // whole row with zeroes.
          if ((in_y < 0) || (in_y >= input_height)) {
            // On Android, memset and memcpy are significantly faster than the
            // more modern std::fill and std::copy equivalents.
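            // Note that "zero" here means input_offset, the quantized
            // representation of the real value zero, and filling it in with
            // memset relies on T1 being a one-byte type such as quint8.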
            memset(im2col_row_start, input_offset,
                   (filter_width * input_depth));
          } else {
            // What we're doing here is trying to copy and fill the im2col
            // buffer as efficiently as possible, using functions to set or
            // duplicate values en masse. We know we don't have to worry about
            // vertical edges because we dealt with that case above, so we
            // just need to handle filters that overlap the left or right
            // edges. Here's what that looks like:
            //
            // < left_zero_count > < center_copy_count > < right_zero_count >
            // +------------------+---------------------+--------------------+
            // |     (filter)     |       (image)       |      (filter)      |
            // +------------------+---------------------+--------------------+
            // in_x_origin        0                 input_width       in_x_end
            //
            // In reality it's unlikely that a filter patch will be wider
            // than an input, but this shows all the edge cases.
            // We use memset() to set the left and right sections to zeroes
            // and memcpy() to copy over the input data for the center. These
            // are preferred to std::fill and std::copy because they're much
            // faster on Android.
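            // For example, with filter_width == 3, input_width == 5 and
            // in_x_origin == -1, we get left_zero_count == 1,
            // center_copy_count == 2 and right_zero_count == 0: one padded
            // column on the left, then two columns copied from the image.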
            const int in_x_end = in_x_origin + filter_width;
            const int left_zero_count = std::max(0, 0 - in_x_origin);
            const int right_zero_count = std::max(0, in_x_end - input_width);
            const int center_copy_count =
                filter_width - (left_zero_count + right_zero_count);
            if (left_zero_count > 0) {
              T1* im2col_left_start = im2col_row_start;
              memset(im2col_left_start, input_offset,
                     (left_zero_count * input_depth));
            }
            if (center_copy_count > 0) {
              const T1* input_row_start =
                  input_batch_start + (in_y * input_width * input_depth) +
                  (std::max(0, in_x_origin) * input_depth);
              T1* im2col_center_start =
                  im2col_row_start + (left_zero_count * input_depth);
              memcpy(im2col_center_start, input_row_start,
                     (center_copy_count * input_depth));
            }
            if (right_zero_count > 0) {
              T1* im2col_right_start =
                  im2col_row_start +
                  ((left_zero_count + center_copy_count) * input_depth);
              memset(im2col_right_start, input_offset,
                     (right_zero_count * input_depth));
            }
          }
        }
      }
      // Now that we've assembled a set of image patches into a matrix, apply a
      // GEMM matrix multiply of the patches as rows, times the filter
      // weights in columns, to get partial results in the output matrix.
      const int how_many_patches = patch_index_end - patch_index_start;
      const bool transpose_a = false;
      const bool transpose_b = false;
      const bool transpose_c = false;
      const int m = how_many_patches;
      const int n = filter_count;
      const int k = filter_value_count;
      const int lda = filter_value_count;
      const int ldb = filter_count;
      const int ldc = filter_count;
      T3* chunk_output_data = output_data + (patch_index_start * filter_count);
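      // The multiply below is (m x k) * (k x n) -> (m x n), that is
      // [patches, filter values] times [filter values, filters] giving
      // [patches, filters], written straight into this chunk's slice of the
      // output tensor.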

      if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
          std::is_same<T2, quint8>() && std::is_same<T3, qint32>() &&
          (output_offset == 0) && (output_mult == 1) && (output_shift == 0) &&
          (transpose_c == false) && (k <= 2048)) {
        meta::QuantizedGemm(context, transpose_a, transpose_b, im2col_buffer,
                            filter_data, chunk_output_data, m, n, k,
                            -input_offset, -filter_offset, lda, ldb, ldc);
      } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
                 std::is_same<T3, qint32>() && (output_offset == 0) &&
                 (output_mult == 1) && (output_shift == 0)) {
        // The gemmlowp optimized library only works for a particular set of
        // data types, so check if we meet those requirements and fall back to a
        // slower reference implementation if not.
        const uint8* im2col_data_as_uint8 = &(im2col_buffer->value);
        const uint8* filter_data_as_uint8 = &(filter_data->value);
        int32* output_data_as_int32 = &(chunk_output_data->value);
        // All of the transpose_* variables are currently compile-time consts,
        // so we could just hard-code these values too, but that would break if
        // anybody changed those values in the future (e.g. to match the ability
        // of MatMul to specify them as attributes). We're using a verbose
        // approach of deriving the order values from the transpose variables to
        // be able to catch any changes like that.
        static const gemmlowp::MapOrder ResultOrder =
            !transpose_c ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        static const gemmlowp::MapOrder LhsOrder =
            !transpose_a ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        static const gemmlowp::MapOrder RhsOrder =
            !transpose_b ? gemmlowp::MapOrder::RowMajor
                         : gemmlowp::MapOrder::ColMajor;
        gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(
            im2col_data_as_uint8, m, k, lda);
        gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(
            filter_data_as_uint8, k, n, ldb);
        gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(
            output_data_as_int32, m, n, ldc);
        const std::tuple<> empty_pipeline = {};

        auto& worker_threads =
            *(context->device()->tensorflow_cpu_worker_threads());
        TensorflowGemmContext context(worker_threads.num_threads,
                                      worker_threads.workers);
        gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                                         gemmlowp::DefaultL8R8BitDepthParams>(
            &context, lhs, rhs, &result, -input_offset, -filter_offset,
            empty_pipeline);
        // Since gemmlowp uses assembly to write to the output, msan won't
        // detect the output buffer as written to, so we mark it manually.
        TF_ANNOTATE_MEMORY_IS_INITIALIZED(output_data_as_int32,
                                          m * n * sizeof(int32));
      } else {
        ReferenceGemm<T1, T2, T3>(
            transpose_a, transpose_b, transpose_c, m, n, k, im2col_buffer,
            input_offset, lda, filter_data, filter_offset, ldb,
            chunk_output_data, output_shift, output_offset, output_mult, ldc);
      }
    }
  }
};

template <class T1, class T2, class T3,
          template <class TF1, class TF2, class TF3> class ConvFunctor>
class QuantizedConv2DOp : public OpKernel {
 public:
  explicit QuantizedConv2DOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, strides_[1] == strides_[2],
                errors::InvalidArgument(
                    "Current implementation only supports equal length "
                    "strides in the row and column dimensions."));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    std::vector<int32> dilations;
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations));
    OP_REQUIRES(context, dilations.size() == 4,
                errors::InvalidArgument("Dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(context, dilations[1] == 1 && dilations[2] == 1,
                errors::InvalidArgument(
                    "Current implementation only supports a dilation rate of "
                    "1 in the row and column dimensions."));
    OP_REQUIRES(context, (dilations[0] == 1 && dilations[3] == 1),
                errors::InvalidArgument(
                    "Current implementation does not yet support "
                    "dilations in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, out_depth ]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    const float min_input = context->input(2).flat<float>()(0);
    const float max_input = context->input(3).flat<float>()(0);
    const float min_filter = context->input(4).flat<float>()(0);
    const float max_filter = context->input(5).flat<float>()(0);
    const int32 offset_input =
        FloatToQuantizedUnclamped<T1>(0.0f, min_input, max_input);
    const int32 offset_filter =
        FloatToQuantizedUnclamped<T2>(0.0f, min_filter, max_filter);
    const int32 offset_output = 0;
    const int32 mult_output = 1;
    const int32 shift_output = 0;
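    // As an example of what these offsets mean, a quint8 input covering the
    // float range [-1.0f, 3.0f] places the real value 0.0f about a quarter of
    // the way up the 0-255 code range, so offset_input comes out at roughly
    // 64.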

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64 in_depth = input.dim_size(3);
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is out_depth.
    const int64 out_depth = filter.dim_size(3);

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64 input_rows = input.dim_size(1);
    const int64 filter_rows = filter.dim_size(0);

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64 input_cols = input.dim_size(2);
    const int64 filter_cols = filter.dim_size(1);

    // The first dimension for input is batch.
    const int64 batch = input.dim_size(0);

    // For now we take the stride from the second dimension only (we
    // assume row = col stride, and do not support striding on the
    // batch or depth dimension).
    const int stride = strides_[1];

    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_rows, filter_rows, stride,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_cols, filter_cols, stride,
                                         padding_, &out_cols, &pad_cols));
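    // For instance, a 28x28 input with a 5x5 filter and stride 1 yields a
    // 24x24 output under VALID padding and a 28x28 output under SAME padding.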
    CHECK_GT(batch, 0);
    CHECK_GT(out_rows, 0);
    CHECK_GT(out_cols, 0);
    CHECK_GT(out_depth, 0);
    TensorShape out_shape({batch, out_rows, out_cols, out_depth});

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    // This will call different implementations (e.g. reference or optimized)
    // depending on the template parameter.
    ConvFunctor<T1, T2, T3> conv_functor;
    conv_functor(context, input.flat<T1>().data(), batch, input_rows,
                 input_cols, in_depth, offset_input, filter.flat<T2>().data(),
                 filter_rows, filter_cols, out_depth, offset_filter, stride,
                 padding_, output->flat<T3>().data(), out_rows, out_cols,
                 shift_output, offset_output, mult_output);

    float min_output_value;
    float max_output_value;
    QuantizationRangeForMultiplication<T1, T2, T3>(
        min_input, max_input, min_filter, max_filter, &min_output_value,
        &max_output_value);

    Tensor* output_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
    output_min->flat<float>()(0) = min_output_value;

    Tensor* output_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));
    output_max->flat<float>()(0) = max_output_value;
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
};

// Right now we only support taking two eight bit inputs, and returning the
// results as signed 32-bit integers.
REGISTER_KERNEL_BUILDER(
    Name("QuantizedConv2D")
        .Device(DEVICE_CPU)
        .TypeConstraint<quint8>("Tinput")
        .TypeConstraint<quint8>("Tfilter")
        .TypeConstraint<qint32>("out_type"),
    QuantizedConv2DOp<quint8, quint8, qint32, Im2ColConvFunctor>);

}  // namespace tensorflow