      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
     16 #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
     17 
     18 #include <assert.h>
     19 #include <stdint.h>
     20 #include <sys/types.h>
     21 #include <algorithm>
     22 #include <cmath>
     23 #include <limits>
     24 #include <memory>
     25 #include <tuple>
     26 #include <type_traits>
     27 
     28 #include "Eigen/Core"
     29 #include "unsupported/Eigen/CXX11/Tensor"
     30 #include "fixedpoint/fixedpoint.h"
     31 #include "public/gemmlowp.h"
     32 #include "tensorflow/contrib/lite/kernels/internal/common.h"
     33 #include "tensorflow/contrib/lite/kernels/internal/round.h"
     34 #include "tensorflow/contrib/lite/kernels/internal/types.h"
     35 
     36 namespace tflite {
     37 namespace optimized_ops {
     38 
      39 // Make a local VectorMap typedef allowing us to map a float array
      40 // as an Eigen vector expression. The std::conditional here is to
     41 // construct the suitable Eigen type for the constness of the
     42 // data. Indeed, for const data, we need to produce
     43 //    Eigen::Map<const Eigen::Matrix<float, ...>>
     44 // and not the more straightforward
     45 //    Eigen::Map<Eigen::Matrix<const float, ...>>
     46 template <typename Scalar>
     47 using VectorMap = typename std::conditional<
     48     std::is_const<Scalar>::value,
     49     Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
     50                                    Eigen::Dynamic, 1>>,
     51     Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, 1>>>::type;
     52 
     53 template <typename Scalar, int N>
     54 VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
     55   const int size = RequiredBufferSizeForDims(dims);
     56   return VectorMap<Scalar>(data, size, 1);
     57 }
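
// Usage sketch (illustrative only, not called by the kernels below):
// mapping a packed const float buffer as an Eigen column vector. Because
// the pointer is const, VectorMap<const float> resolves to the
// Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, 1>> branch of the
// std::conditional above.
//
//   const float* data = ...;           // RequiredBufferSizeForDims(dims) floats
//   auto v = MapAsVector(data, dims);  // read-only Eigen vector view
//   const float total = v.sum();       // any Eigen expression works
//
// The same const handling applies to the MatrixMap and ArrayMap helpers below.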
     58 
      59 // Make a local MatrixMap typedef allowing us to map a float array
      60 // as an Eigen matrix expression. The same explanation as for VectorMap
     61 // above also applies here.
     62 template <typename Scalar>
     63 using MatrixMap = typename std::conditional<
     64     std::is_const<Scalar>::value,
     65     Eigen::Map<const Eigen::Matrix<typename std::remove_const<Scalar>::type,
     66                                    Eigen::Dynamic, Eigen::Dynamic>>,
     67     Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
     68 
     69 template <typename Scalar, int N>
     70 MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
     71                                                 const Dims<N>& dims) {
     72   const int rows = dims.sizes[0];
     73   int cols = 1;
     74   for (int d = 1; d < N; d++) {
     75     cols *= dims.sizes[d];
     76   }
     77   return MatrixMap<Scalar>(data, rows, cols);
     78 }
     79 
     80 template <typename Scalar, int N>
     81 MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
     82                                                const Dims<N>& dims) {
     83   const int cols = dims.sizes[N - 1];
     84   int rows = 1;
     85   for (int d = 0; d < N - 1; d++) {
     86     rows *= dims.sizes[d];
     87   }
     88   return MatrixMap<Scalar>(data, rows, cols);
     89 }
     90 
     91 template <typename Scalar>
     92 using ArrayMap = typename std::conditional<
     93     std::is_const<Scalar>::value,
     94     Eigen::Map<const Eigen::Array<typename std::remove_const<Scalar>::type,
     95                                   Eigen::Dynamic, Eigen::Dynamic>>,
     96     Eigen::Map<Eigen::Array<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
     97 
     98 template <typename Scalar, int N>
     99 ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
    100                                               const Dims<N>& dims) {
    101   const int rows = dims.sizes[0];
    102   int cols = 1;
    103   for (int d = 1; d < N; d++) {
    104     cols *= dims.sizes[d];
    105   }
    106   return ArrayMap<Scalar>(data, rows, cols);
    107 }
    108 
    109 // TODO(b/62193649): this function is only needed as long
    110 // as we have the --variable_batch hack.
    111 template <typename Scalar, int N>
    112 MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
    113                                                    const Dims<N>& dims,
    114                                                    int rows) {
    115   int cols = 1;
    116   bool matched_rows = false;
    117   for (int d = 0; d < N; d++) {
    118     cols *= dims.sizes[d];
    119     if (cols == rows) {
    120       matched_rows = true;
    121       cols = 1;
    122     }
    123   }
    124   TFLITE_DCHECK(matched_rows);
    125   return MatrixMap<Scalar>(data, rows, cols);
    126 }
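
// Worked example (made-up sizes) for the row-matching logic above: with
// dims.sizes = {10, 4, 6, 1} and rows = 40, the loop accumulates cols = 10,
// then 40 (== rows, so matched_rows is set and cols is reset to 1), then 6,
// then 6 * 1 = 6, yielding a 40 x 6 matrix map over the same 240 elements.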
    127 
    128 // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE
    129 // BROADCASTING.
    130 //
    131 // NdArrayDesc<N> describes the shape and memory layout of an N-dimensional
    132 // rectangular array of numbers.
    133 //
    134 // NdArrayDesc<N> is basically identical to Dims<N> defined in types.h.
    135 // However, as Dims<N> is to be deprecated, this class exists as an adaptor
    136 // to enable simple unoptimized implementations of element-wise broadcasting
    137 // operations.
    138 template <int N>
    139 struct NdArrayDesc {
    140   // The "extent" of each dimension. Indices along dimension d must be in the
    141   // half-open interval [0, extents[d]).
    142   int extents[N];
    143 
    144   // The number of *elements* (not bytes) between consecutive indices of each
    145   // dimension.
    146   int strides[N];
    147 };
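
// Worked example: a packed array of shape (1, 16, 16, 64) in
// (batch, height, width, depth) order, which Dims<4>/NdArrayDesc<4> store
// innermost-first, has
//   extents = {64, 16, 16, 1}
//   strides = {1, 64, 1024, 16384}
// i.e. each stride is the product of the extents of all inner dimensions.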
    148 
    149 // DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
    150 // ELEMENT-WISE BROADCASTING.
    151 //
     152 // Same as Offset(), except takes an NdArrayDesc<N> instead of Dims<N>.
    153 inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2,
    154                             int i3) {
    155   TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]);
    156   TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]);
    157   TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]);
    158   TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]);
    159   return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] +
    160          i3 * desc.strides[3];
    161 }
    162 
    163 // Given the dimensions of the operands for an element-wise binary broadcast,
    164 // adjusts them so that they can be directly iterated over with simple loops.
    165 // Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and
    166 // 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr.
    167 //
    168 // This function assumes that the two input shapes are compatible up to
    169 // broadcasting and the shorter one has already been prepended with 1s to be the
    170 // same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64),
    171 // shape1 must already have been prepended to be (1, 1, 1, 64). Recall that
     172 // Dims<N> refers to shapes in reverse order. In this case, input0_dims will be
    173 // (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1).
    174 //
    175 // When two shapes are compatible up to broadcasting, for each dimension d,
    176 // the input extents are either equal, or one of them is 1.
    177 //
    178 // This function performs the following for each dimension d:
    179 // - If the extents are equal, then do nothing since the loop that walks over
    180 //   both of the input arrays is correct.
    181 // - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1
    182 //   and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows
    183 //   array0 to be referenced *at any index* in dimension d and still access the
    184 //   same slice.
    185 template <int N>
    186 inline void NdArrayDescsForElementwiseBroadcast(const Dims<N>& input0_dims,
    187                                                 const Dims<N>& input1_dims,
    188                                                 NdArrayDesc<N>* desc0_out,
    189                                                 NdArrayDesc<N>* desc1_out) {
    190   TFLITE_DCHECK(desc0_out != nullptr);
    191   TFLITE_DCHECK(desc1_out != nullptr);
    192 
    193   // Copy dims to desc.
    194   for (int i = 0; i < N; ++i) {
    195     desc0_out->extents[i] = input0_dims.sizes[i];
    196     desc0_out->strides[i] = input0_dims.strides[i];
    197     desc1_out->extents[i] = input1_dims.sizes[i];
    198     desc1_out->strides[i] = input1_dims.strides[i];
    199   }
    200 
    201   // Walk over each dimension. If the extents are equal do nothing.
    202   // Otherwise, set the desc with extent 1 to have extent equal to the other and
    203   // stride 0.
    204   for (int i = 0; i < N; ++i) {
    205     const int extent0 = ArraySize(input0_dims, i);
    206     const int extent1 = ArraySize(input1_dims, i);
    207     if (extent0 != extent1) {
    208       if (extent0 == 1) {
    209         desc0_out->strides[i] = 0;
    210         desc0_out->extents[i] = extent1;
    211       } else {
    212         TFLITE_DCHECK_EQ(extent1, 1);
    213         desc1_out->strides[i] = 0;
    214         desc1_out->extents[i] = extent0;
    215       }
    216     }
    217   }
    218 }
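
// Illustrative reference implementation (our sketch, not used by the
// optimized kernels below): an element-wise broadcast Add written directly
// against the adjusted descriptors, showing how
// NdArrayDescsForElementwiseBroadcast and SubscriptToIndex are meant to be
// consumed together. The function name is ours and not part of any API.
inline void ExampleBroadcastAdd(const float* input0_data,
                                const Dims<4>& input0_dims,
                                const float* input1_data,
                                const Dims<4>& input1_dims,
                                float* output_data,
                                const Dims<4>& output_dims) {
  NdArrayDesc<4> desc0;
  NdArrayDesc<4> desc1;
  NdArrayDescsForElementwiseBroadcast(input0_dims, input1_dims, &desc0,
                                      &desc1);
  // Dims<4> stores dimensions innermost-first: index 3 is the outermost
  // (batch) dimension, index 0 the innermost (depth) dimension.
  for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
    for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
      for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
        for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
          // A descriptor with stride 0 in some dimension makes the same
          // input slice visible at every index of that dimension.
          output_data[Offset(output_dims, c, x, y, b)] =
              input0_data[SubscriptToIndex(desc0, c, x, y, b)] +
              input1_data[SubscriptToIndex(desc1, c, x, y, b)];
        }
      }
    }
  }
}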
    219 
    220 inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
    221   for (int i = 0; i < 4; i++) {
    222     if (dims1.sizes[i] != dims2.sizes[i]) {
    223       return false;
    224     }
    225   }
    226   return true;
    227 }
    228 
    229 inline void AddBiasAndEvalActivationFunction(const float* bias_data,
    230                                              const Dims<4>& bias_dims,
    231                                              float* array_data,
    232                                              const Dims<4>& array_dims,
    233                                              float output_activation_min,
    234                                              float output_activation_max) {
    235 #ifdef USE_NEON
    236   gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
    237   const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3];
    238   const int array_size = array_dims.sizes[3] * array_dims.strides[3];
    239   TFLITE_DCHECK_EQ((array_size % bias_size), 0);
    240   float* array_ptr = array_data;
    241   float* array_end_ptr = array_ptr + array_size;
    242   const auto activation_min = vdupq_n_f32(output_activation_min);
    243   const auto activation_max = vdupq_n_f32(output_activation_max);
    244   for (; array_ptr != array_end_ptr; array_ptr += bias_size) {
    245     int i = 0;
    246     for (; i <= bias_size - 16; i += 16) {
    247       auto b0 = vld1q_f32(bias_data + i);
    248       auto b1 = vld1q_f32(bias_data + i + 4);
    249       auto b2 = vld1q_f32(bias_data + i + 8);
    250       auto b3 = vld1q_f32(bias_data + i + 12);
    251       auto a0 = vld1q_f32(array_ptr + i);
    252       auto a1 = vld1q_f32(array_ptr + i + 4);
    253       auto a2 = vld1q_f32(array_ptr + i + 8);
    254       auto a3 = vld1q_f32(array_ptr + i + 12);
    255       auto x0 = vaddq_f32(a0, b0);
    256       auto x1 = vaddq_f32(a1, b1);
    257       auto x2 = vaddq_f32(a2, b2);
    258       auto x3 = vaddq_f32(a3, b3);
    259       x0 = vmaxq_f32(activation_min, x0);
    260       x1 = vmaxq_f32(activation_min, x1);
    261       x2 = vmaxq_f32(activation_min, x2);
    262       x3 = vmaxq_f32(activation_min, x3);
    263       x0 = vminq_f32(activation_max, x0);
    264       x1 = vminq_f32(activation_max, x1);
    265       x2 = vminq_f32(activation_max, x2);
    266       x3 = vminq_f32(activation_max, x3);
    267       vst1q_f32(array_ptr + i, x0);
    268       vst1q_f32(array_ptr + i + 4, x1);
    269       vst1q_f32(array_ptr + i + 8, x2);
    270       vst1q_f32(array_ptr + i + 12, x3);
    271     }
    272     for (; i <= bias_size - 4; i += 4) {
    273       auto b = vld1q_f32(bias_data + i);
    274       auto a = vld1q_f32(array_ptr + i);
    275       auto x = vaddq_f32(a, b);
    276       x = vmaxq_f32(activation_min, x);
    277       x = vminq_f32(activation_max, x);
    278       vst1q_f32(array_ptr + i, x);
    279     }
    280     for (; i < bias_size; i++) {
    281       array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i],
    282                                                   output_activation_min,
    283                                                   output_activation_max);
    284     }
    285   }
    286 #else  // not NEON
    287   gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
    288   const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3];
    289   const int array_size = array_dims.sizes[3] * array_dims.strides[3];
    290   TFLITE_DCHECK_EQ((array_size % bias_size), 0);
    291   for (int array_offset = 0; array_offset < array_size;
    292        array_offset += bias_size) {
    293     for (int i = 0; i < bias_size; i++) {
    294       array_data[array_offset + i] = ActivationFunctionWithMinMax(
    295           array_data[array_offset + i] + bias_data[i], output_activation_min,
    296           output_activation_max);
    297     }
    298   }
    299 #endif
    300 }
    301 
    302 // legacy, for compatibility with old checked-in code
    303 template <FusedActivationFunctionType Ac>
    304 void AddBiasAndEvalActivationFunction(const float* bias_data,
    305                                       const Dims<4>& bias_dims,
    306                                       float* array_data,
    307                                       const Dims<4>& array_dims) {
    308   float output_activation_min, output_activation_max;
    309   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
    310   AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims,
    311                                    output_activation_min,
    312                                    output_activation_max);
    313 }
    314 
    315 template <typename Lhs, typename Rhs, typename Result>
    316 void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
    317           Eigen::MatrixBase<Result>* result) {
    318   if (rhs.cols() == 1) {
    319     gemmlowp::ScopedProfilingLabel label("GEMV");
    320     result->col(0).noalias() = lhs * rhs.col(0);
    321   } else {
    322     gemmlowp::ScopedProfilingLabel label("GEMM");
    323     result->noalias() = lhs * rhs;
    324   }
    325 }
    326 
    327 inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
    328                            const float* weights_data,
    329                            const Dims<4>& weights_dims, const float* bias_data,
    330                            const Dims<4>& bias_dims,
    331                            float output_activation_min,
    332                            float output_activation_max, float* output_data,
    333                            const Dims<4>& output_dims) {
    334   gemmlowp::ScopedProfilingLabel label("FullyConnected");
    335   // TODO(b/62193649): this convoluted shape computation (determining
    336   // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows)
    337   // is because the current --variable_batch hack consists in overwriting the
    338   // 3rd dimension with the runtime batch size, as we don't keep track for each
    339   // array of which dimension is the batch dimension in it.
    340   // When that is fixed, this should become:
    341   // const auto input_matrix_map =
    342   //     MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
    343   const int input_rows = ArraySize(weights_dims, 0);
    344   const auto input_matrix_map =
    345       MapAsMatrixWithGivenNumberOfRows(input_data, input_dims, input_rows);
    346   const auto filter_matrix_map =
    347       MapAsMatrixWithFirstDimAsRows(weights_data, weights_dims);
    348   auto output_matrix_map =
    349       MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
    350 
    351   Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
    352   AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
    353                                    output_dims, output_activation_min,
    354                                    output_activation_max);
    355 }
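
// Shape sketch (made-up sizes, our reading of the layout above): for a fully
// connected layer with input_size = 64, output_size = 16 and batch = 8,
// weights_dims has sizes[0] = 64 and sizes[1] = 16, so filter_matrix_map is
// 64 x 16 and its transpose 16 x 64; input_matrix_map is 64 x 8; the GEMM
// therefore yields the 16 x 8 output_matrix_map described by output_dims.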
    356 
    357 // legacy, for compatibility with old checked-in code
    358 template <FusedActivationFunctionType Ac>
    359 void FullyConnected(const float* input_data, const Dims<4>& input_dims,
    360                     const float* weights_data, const Dims<4>& weights_dims,
    361                     const float* bias_data, const Dims<4>& bias_dims,
    362                     float* output_data, const Dims<4>& output_dims) {
    363   float output_activation_min, output_activation_max;
    364   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
    365   FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
    366                  bias_dims, output_activation_min, output_activation_max,
    367                  output_data, output_dims);
    368 }
    369 
    370 inline void preload_l1_stream(const uint8* ptr) {
    371 #ifdef GEMMLOWP_ARM_64
    372   asm volatile("prfm pldl1strm, [%[ptr]]\n" ::[ptr] "r"(ptr) :);
    373 #else
    374   gemmlowp::Prefetch(ptr);
    375 #endif
    376 }
    377 
    378 #ifdef USE_NEON
    379 inline void FullyConnectedAsGEMV(
    380     const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
    381     const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
    382     const int32* bias_data, const Dims<4>& bias_dims, int32 output_offset,
    383     int32 output_multiplier, int output_shift, int32 output_activation_min,
    384     int32 output_activation_max, uint8* output_data,
    385     const Dims<4>& output_dims) {
    386   gemmlowp::ScopedProfilingLabel label("FullyConnectedAsGEMV/8bit");
    387   TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
    388   TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
    389   TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
    390   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
    391   TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
    392                        ArraySize(output_dims, 3),
    393                    1);
    394   const int input_size = input_dims.strides[3];
    395   const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
    396   static constexpr int kPeel = 4;
    397   for (int k = 0; k < input_size; k += 64) {
    398     preload_l1_stream(input_data + k);
    399   }
    400   for (int k = 0; k < kPeel * input_size; k += 64) {
    401     preload_l1_stream(filter_data + k);
    402   }
    403   TFLITE_DCHECK(!(output_size % kPeel));
    404   const int32* bias_ptr = bias_data;
    405   uint8* output_ptr = output_data;
    406   for (int out = 0; out < output_size; out += kPeel) {
    407     int32x4_t acc[kPeel];
    408     for (int k = 0; k < kPeel; k++) {
    409       acc[k] = vdupq_n_s32(0);
    410     }
    411     const int16x8_t input_offset_vec = vdupq_n_s16(input_offset);
    412     const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset);
    413     int in = 0;
    414     for (; in <= input_size - 16; in += 16) {
    415       const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
    416       uint8x16_t filter_val_u8[kPeel];
    417       for (int k = 0; k < kPeel; k++) {
    418         const uint8* filter_ptr = filter_data + in + (out + k) * input_size;
    419         filter_val_u8[k] = vld1q_u8(filter_ptr);
    420         preload_l1_stream(filter_ptr + 64);
    421       }
    422       int16x8_t input_val[2];
    423       const uint8x8_t low = vget_low_u8(input_val_u8);
    424       const uint8x8_t high = vget_high_u8(input_val_u8);
    425       input_val[0] = vreinterpretq_s16_u16(vmovl_u8(low));
    426       input_val[1] = vreinterpretq_s16_u16(vmovl_u8(high));
    427       input_val[0] = vaddq_s16(input_val[0], input_offset_vec);
    428       input_val[1] = vaddq_s16(input_val[1], input_offset_vec);
    429       int16x8_t filter_val[kPeel][2];
    430       for (int k = 0; k < kPeel; k++) {
    431         const uint8x8_t low = vget_low_u8(filter_val_u8[k]);
    432         const uint8x8_t high = vget_high_u8(filter_val_u8[k]);
    433         filter_val[k][0] = vreinterpretq_s16_u16(vmovl_u8(low));
    434         filter_val[k][1] = vreinterpretq_s16_u16(vmovl_u8(high));
    435         filter_val[k][0] = vaddq_s16(filter_val[k][0], filter_offset_vec);
    436         filter_val[k][1] = vaddq_s16(filter_val[k][1], filter_offset_vec);
    437       }
    438       for (int p = 0; p < 2; p++) {
    439         for (int k = 0; k < kPeel; k++) {
    440           acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k][p]),
    441                              vget_low_s16(input_val[p]));
    442         }
    443         for (int k = 0; k < kPeel; k++) {
    444           acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k][p]),
    445                              vget_high_s16(input_val[p]));
    446         }
    447       }
    448     }
    449     for (; in <= input_size - 8; in += 8) {
    450       const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
    451       uint8x8_t filter_val_u8[kPeel];
    452       for (int k = 0; k < kPeel; k++) {
    453         const uint8* filter_ptr = filter_data + in + (out + k) * input_size;
    454         filter_val_u8[k] = vld1_u8(filter_ptr);
    455       }
    456       int16x8_t input_val;
    457       input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
    458       input_val = vaddq_s16(input_val, input_offset_vec);
    459       int16x8_t filter_val[kPeel];
    460       for (int k = 0; k < kPeel; k++) {
    461         filter_val[k] = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8[k]));
    462         filter_val[k] = vaddq_s16(filter_val[k], filter_offset_vec);
    463       }
    464       for (int k = 0; k < kPeel; k++) {
    465         acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k]),
    466                            vget_low_s16(input_val));
    467       }
    468       for (int k = 0; k < kPeel; k++) {
    469         acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k]),
    470                            vget_high_s16(input_val));
    471       }
    472     }
    473     if (in < input_size) {
    474       int32 buf[4 * kPeel];
    475       for (int k = 0; k < 4; k++) {
    476         vst1q_s32(buf + 4 * k, acc[k]);
    477       }
    478       for (; in < input_size; in++) {
    479         int lane = (in + 8 - input_size) % 4;
    480         const int32 input_val = input_data[in] + input_offset;
    481         for (int k = 0; k < kPeel; k++) {
    482           int32 filter_val =
    483               filter_data[in + (out + k) * input_size] + filter_offset;
    484           buf[lane + 4 * k] += filter_val * input_val;
    485         }
    486       }
    487       for (int k = 0; k < 4; k++) {
    488         acc[k] = vld1q_s32(buf + 4 * k);
    489       }
    490     }
    491 
    492     // Horizontally reduce accumulators
    493     int32x2_t pairwise_reduced_acc[kPeel];
    494     for (int k = 0; k < kPeel; k++) {
    495       pairwise_reduced_acc[k] =
    496           vpadd_s32(vget_low_s32(acc[k]), vget_high_s32(acc[k]));
    497     }
    498     static_assert(kPeel == 4, "the code below currently assumes kPeel = 4");
    499     const int32x2_t reduced_lo =
    500         vpadd_s32(pairwise_reduced_acc[0], pairwise_reduced_acc[1]);
    501     const int32x2_t reduced_hi =
    502         vpadd_s32(pairwise_reduced_acc[2], pairwise_reduced_acc[3]);
    503     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
    504     // Add bias values.
    505     int32x4_t bias_vec = vld1q_s32(bias_ptr);
    506     bias_ptr += 4;
    507     reduced = vaddq_s32(reduced, bias_vec);
    508     // Multiply by the fixed-point multiplier.
    509     reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
    510     // Rounding-shift-right.
    511     using gemmlowp::RoundingDivideByPOT;
    512     reduced = RoundingDivideByPOT(reduced, output_shift);
    513     // Add the output offset.
    514     const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
    515     reduced = vaddq_s32(reduced, output_offset_vec);
    516     // Narrow values down to 16 bit signed.
    517     const int16x4_t res16 = vqmovn_s32(reduced);
    518     // Narrow values down to 8 bit unsigned, saturating.
    519     uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16));
    520     // Apply the clamping from the activation function
    521     res8 = vmax_u8(res8, vdup_n_u8(output_activation_min));
    522     res8 = vmin_u8(res8, vdup_n_u8(output_activation_max));
     523     // Store results to destination. Assumes 32-bit alignment.
    524     vst1_lane_u32(reinterpret_cast<uint32*>(output_ptr),
    525                   vreinterpret_u32_u8(res8), 0);
    526     output_ptr += kPeel;
    527   }
    528 }
    529 #endif  // USE_NEON
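
// For reference, ignoring NEON saturation details, each output value
// produced by FullyConnectedAsGEMV above corresponds to the following
// scalar computation (our sketch of the kernel, not a separate code path):
//
//   int32 acc = bias_data[out];
//   for (int in = 0; in < input_size; ++in) {
//     acc += (filter_data[out * input_size + in] + filter_offset) *
//            (input_data[in] + input_offset);
//   }
//   acc = gemmlowp::RoundingDivideByPOT(
//       gemmlowp::SaturatingRoundingDoublingHighMul(acc, output_multiplier),
//       output_shift);
//   acc += output_offset;
//   acc = std::max(acc, output_activation_min);
//   acc = std::min(acc, output_activation_max);
//   output_data[out] = static_cast<uint8>(acc);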
    530 
    531 struct GemmlowpOutputPipeline {
    532   typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
    533       ColVectorMap;
    534   typedef std::tuple<
    535       gemmlowp::OutputStageBiasAddition<ColVectorMap>,
    536       gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
    537       gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8>
    538       Pipeline;
    539   static Pipeline Make(const int32* bias_data, int output_rows,
    540                        int32 output_offset, int32 output_multiplier,
    541                        int output_shift, int32 output_activation_min,
    542                        int32 output_activation_max) {
    543     ColVectorMap bias_vector(bias_data, output_rows);
    544     gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
    545     bias_addition_stage.bias_vector = bias_vector;
    546     gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
    547         quantize_down_stage;
    548     quantize_down_stage.result_offset_after_shift = output_offset;
    549     quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
    550     quantize_down_stage.result_shift = output_shift;
    551     gemmlowp::OutputStageClamp clamp_stage;
    552     clamp_stage.min = output_activation_min;
    553     clamp_stage.max = output_activation_max;
    554     gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
    555     return std::make_tuple(bias_addition_stage, quantize_down_stage,
    556                            clamp_stage, saturating_cast_stage);
    557   }
    558 };
    559 
    560 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
    561                            int32 input_offset, const uint8* filter_data,
    562                            const Dims<4>& filter_dims, int32 filter_offset,
    563                            const int32* bias_data, const Dims<4>& bias_dims,
    564                            int32 output_offset, int32 output_multiplier,
    565                            int output_shift, int32 output_activation_min,
    566                            int32 output_activation_max, uint8* output_data,
    567                            const Dims<4>& output_dims,
    568                            gemmlowp::GemmContext* gemm_context) {
    569   gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit");
    570   // TODO(benoitjacob): This really should be:
    571   //     const int batches = ArraySize(output_dims, 1);
    572   // but the current --variable_batch hack consists in overwriting the 3rd
    573   // dimension with the runtime batch size, as we don't keep track for each
    574   // array of which dimension is the batch dimension in it.
    575   const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
    576                       ArraySize(output_dims, 3);
    577 #ifdef USE_NEON
    578   const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
    579   if (batches == 1 && !(output_size % 4)) {
    580     return FullyConnectedAsGEMV(
    581         input_data, input_dims, input_offset, filter_data, filter_dims,
    582         filter_offset, bias_data, bias_dims, output_offset, output_multiplier,
    583         output_shift, output_activation_min, output_activation_max, output_data,
    584         output_dims);
    585   }
    586 #endif  // USE_NEON
    587   const int filter_rows = filter_dims.sizes[1];
    588   const int filter_cols = filter_dims.sizes[0];
    589   TFLITE_DCHECK_EQ(filter_dims.sizes[2], 1);
    590   TFLITE_DCHECK_EQ(filter_dims.sizes[3], 1);
    591   const int output_rows = output_dims.sizes[0];
    592   TFLITE_DCHECK_EQ(output_rows, filter_rows);
    593   TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
    594   TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
    595   TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
    596   TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
    597 
    598   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
    599       filter_data, output_rows, filter_cols, filter_cols);
    600   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
    601       input_data, filter_cols, batches, filter_cols);
    602   gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
    603       output_data, output_rows, batches, output_rows);
    604   const auto& output_pipeline = GemmlowpOutputPipeline::Make(
    605       bias_data, output_rows, output_offset, output_multiplier, output_shift,
    606       output_activation_min, output_activation_max);
    607   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
    608                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
    609       gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
    610       input_offset, output_pipeline);
    611 }
    612 
    613 // legacy, for compatibility with old checked-in code
    614 template <FusedActivationFunctionType Ac>
    615 void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
    616                     int32 input_offset, const uint8* filter_data,
    617                     const Dims<4>& filter_dims, int32 filter_offset,
    618                     const int32* bias_data, const Dims<4>& bias_dims,
    619                     int32 output_offset, int32 output_multiplier,
    620                     int output_shift, int32 output_activation_min,
    621                     int32 output_activation_max, uint8* output_data,
    622                     const Dims<4>& output_dims,
    623                     gemmlowp::GemmContext* gemm_context) {
    624   static_assert(Ac == FusedActivationFunctionType::kNone ||
    625                     Ac == FusedActivationFunctionType::kRelu ||
    626                     Ac == FusedActivationFunctionType::kRelu6 ||
    627                     Ac == FusedActivationFunctionType::kRelu1,
    628                 "");
    629   FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
    630                  filter_offset, bias_data, bias_dims, output_offset,
    631                  output_multiplier, output_shift, output_activation_min,
    632                  output_activation_max, output_data, output_dims, gemm_context);
    633 }
    634 
    635 template <typename T>
    636 inline void ExtractPatchIntoBufferColumn(
    637     const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
    638     int stride_width, int stride_height, int pad_width, int pad_height,
    639     int in_width, int in_height, int in_depth, int single_buffer_length,
    640     int buffer_id, const T* in_data, T* conv_buffer_data, uint8 byte_zero) {
    641   gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn");
    642   // This chunk of code reshapes all the inputs corresponding to
    643   // output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
    644   const int kwidth_times_indepth = kwidth * in_depth;
    645   const int inwidth_times_indepth = in_width * in_depth;
    646   const int ih_ungated_start = h * stride_height - pad_height;
    647   const int ih_ungated_end = (ih_ungated_start + kheight);
    648   const int ih_end = std::min(ih_ungated_end, in_height);
    649   const int iw_ungated_start = w * stride_width - pad_width;
    650   const int iw_ungated_end = (iw_ungated_start + kwidth);
    651   const int iw_end = std::min(iw_ungated_end, in_width);
    652   // If the patch is off the edge of the input image, skip writing those rows
    653   // and columns from the patch into the output array.
    654   const int h_offset = std::max(0, -ih_ungated_start);
    655   const int w_offset = std::max(0, -iw_ungated_start);
    656   const int ih_start = std::max(0, ih_ungated_start);
    657   const int iw_start = std::max(0, iw_ungated_start);
    658   const int single_row_num =
    659       std::min(kwidth - w_offset, in_width - iw_start) * in_depth;
    660   const int output_row_offset = (buffer_id * single_buffer_length);
    661   int out_offset =
    662       output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
    663   int in_offset = Offset(input_dims, 0, iw_start, ih_start, b);
    664 
    665   // Express all of the calculations as padding around the input patch.
    666   const int top_padding = h_offset;
    667   const int bottom_padding = (ih_ungated_end - ih_end);
    668   const int left_padding = w_offset;
    669   const int right_padding = (iw_ungated_end - iw_end);
    670   assert(single_row_num ==
    671          ((kwidth - (left_padding + right_padding)) * in_depth));
    672 
    673   // Write out zeroes to the elements representing the top rows of the input
    674   // patch that are off the edge of the input image.
    675   if (top_padding > 0) {
    676     const int top_row_elements = (top_padding * kwidth * in_depth);
    677     memset(conv_buffer_data + output_row_offset, byte_zero,
    678            (top_row_elements * sizeof(T)));
    679   }
    680 
     681   // If the patch is in the interior of the input image horizontally, just copy
     682   // over the rows sequentially; otherwise, add zero padding at the start or end.
    683   if ((left_padding == 0) && (right_padding == 0)) {
    684     for (int ih = ih_start; ih < ih_end; ++ih) {
    685       memcpy(conv_buffer_data + out_offset, in_data + in_offset,
    686              single_row_num * sizeof(T));
    687       out_offset += kwidth_times_indepth;
    688       in_offset += inwidth_times_indepth;
    689     }
    690   } else {
    691     for (int ih = ih_start; ih < ih_end; ++ih) {
    692       if (left_padding > 0) {
    693         const int left_start = (out_offset - (left_padding * in_depth));
    694         memset(conv_buffer_data + left_start, byte_zero,
    695                (left_padding * in_depth * sizeof(T)));
    696       }
    697       memcpy(conv_buffer_data + out_offset, in_data + in_offset,
    698              single_row_num * sizeof(T));
    699       if (right_padding > 0) {
    700         const int right_start = (out_offset + single_row_num);
    701         memset(conv_buffer_data + right_start, byte_zero,
    702                (right_padding * in_depth * sizeof(T)));
    703       }
    704       out_offset += kwidth_times_indepth;
    705       in_offset += inwidth_times_indepth;
    706     }
    707   }
    708 
    709   // If the bottom of the patch falls off the input image, pad the values
    710   // representing those input rows with zeroes.
    711   if (bottom_padding > 0) {
    712     const int bottom_row_elements = (bottom_padding * kwidth * in_depth);
    713     const int bottom_start =
    714         output_row_offset +
    715         ((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
    716     memset(conv_buffer_data + bottom_start, byte_zero,
    717            (bottom_row_elements * sizeof(T)));
    718   }
    719 }
    720 
    721 template <typename T>
    722 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
    723             int stride_height, int pad_width, int pad_height, int kheight,
    724             int kwidth, uint8 byte_zero, T* output_data,
    725             const Dims<4>& output_dims) {
    726   gemmlowp::ScopedProfilingLabel label("Im2col");
    727   TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
    728   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
    729   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
    730   const int input_depth = ArraySize(input_dims, 0);
    731   const int input_width = ArraySize(input_dims, 1);
    732   const int input_height = ArraySize(input_dims, 2);
    733   const int output_depth = ArraySize(output_dims, 0);
    734   const int output_width = ArraySize(output_dims, 1);
    735   const int output_height = ArraySize(output_dims, 2);
    736 
    737   int buffer_id = 0;
    738   // Loop over the output nodes.
    739   for (int b = 0; b < batches; ++b) {
    740     for (int h = 0; h < output_height; ++h) {
    741       for (int w = 0; w < output_width; ++w) {
    742         ExtractPatchIntoBufferColumn(
    743             input_dims, w, h, b, kheight, kwidth, stride_width, stride_height,
    744             pad_width, pad_height, input_width, input_height, input_depth,
    745             output_depth, buffer_id, input_data, output_data, byte_zero);
    746         ++buffer_id;
    747       }
    748     }
    749   }
    750 }
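
// How the Im2col buffer is consumed: each output position (b, h, w) owns one
// column of length kheight * kwidth * input_depth holding the (zero-padded)
// input patch under the filter window, with depth innermost, then kernel x,
// then kernel y (the same ordering as the flattened filter). Convolution then
// reduces to a single GEMM, conceptually
//
//   output[out_channel, position] =
//       sum over k of filter[out_channel, k] * im2col[k, position],
//
// which is what the Conv functions below do by mapping the filter, im2col and
// output buffers as matrices and calling Gemm().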
    751 
    752 // legacy, for compatibility with old checked-in code
    753 template <typename T>
    754 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
    755             int pad_width, int pad_height, int kheight, int kwidth,
    756             uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
    757   Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
    758          kwidth, byte_zero, output_data, output_dims);
    759 }
    760 
    761 inline void Conv(const float* input_data, const Dims<4>& input_dims,
    762                  const float* filter_data, const Dims<4>& filter_dims,
    763                  const float* bias_data, const Dims<4>& bias_dims,
    764                  int stride_width, int stride_height, int pad_width,
    765                  int pad_height, float output_activation_min,
    766                  float output_activation_max, float* output_data,
    767                  const Dims<4>& output_dims, float* im2col_data,
    768                  const Dims<4>& im2col_dims) {
    769   (void)im2col_data;
    770   (void)im2col_dims;
    771   gemmlowp::ScopedProfilingLabel label("Conv");
    772 
    773   const float* gemm_input_data = nullptr;
    774   const Dims<4>* gemm_input_dims = nullptr;
    775   const int filter_width = ArraySize(filter_dims, 1);
    776   const int filter_height = ArraySize(filter_dims, 2);
    777   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
    778                            filter_width != 1 || filter_height != 1;
    779   if (need_im2col) {
    780     TFLITE_DCHECK(im2col_data);
    781     Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
    782            pad_height, filter_height, filter_width, 0, im2col_data,
    783            im2col_dims);
    784     gemm_input_data = im2col_data;
    785     gemm_input_dims = &im2col_dims;
    786   } else {
    787     // TODO(aselle): We need to make sure to not send im2col if it is not
    788     // needed.
    789     TFLITE_DCHECK(!im2col_data);
    790     gemm_input_data = input_data;
    791     gemm_input_dims = &input_dims;
    792   }
    793 
    794   const auto im2col_matrix_map =
    795       MapAsMatrixWithFirstDimAsRows(gemm_input_data, *gemm_input_dims);
    796   const auto filter_matrix_map =
    797       MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
    798   auto output_matrix_map =
    799       MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
    800 
    801   Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
    802 
    803   AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
    804                                    output_dims, output_activation_min,
    805                                    output_activation_max);
    806 }
    807 
    808 // legacy, for compatibility with old checked-in code
    809 template <FusedActivationFunctionType Ac>
    810 void Conv(const float* input_data, const Dims<4>& input_dims,
    811           const float* filter_data, const Dims<4>& filter_dims,
    812           const float* bias_data, const Dims<4>& bias_dims, int stride_width,
    813           int stride_height, int pad_width, int pad_height, float* output_data,
    814           const Dims<4>& output_dims, float* im2col_data,
    815           const Dims<4>& im2col_dims) {
    816   float output_activation_min, output_activation_max;
    817   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
    818   Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
    819        stride_width, stride_height, pad_width, pad_height,
    820        output_activation_min, output_activation_max, output_data, output_dims,
    821        im2col_data, im2col_dims);
    822 }
    823 
    824 // legacy, for compatibility with old checked-in code
    825 template <FusedActivationFunctionType Ac>
    826 void Conv(const float* input_data, const Dims<4>& input_dims,
    827           const float* filter_data, const Dims<4>& filter_dims,
    828           const float* bias_data, const Dims<4>& bias_dims, int stride,
    829           int pad_width, int pad_height, float* output_data,
    830           const Dims<4>& output_dims, float* im2col_data,
    831           const Dims<4>& im2col_dims) {
    832   Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
    833            bias_dims, stride, stride, pad_width, pad_height, output_data,
    834            output_dims, im2col_data, im2col_dims);
    835 }
    836 
    837 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
    838                  int32 input_offset, const uint8* filter_data,
    839                  const Dims<4>& filter_dims, int32 filter_offset,
    840                  const int32* bias_data, const Dims<4>& bias_dims,
    841                  int stride_width, int stride_height, int pad_width,
    842                  int pad_height, int32 output_offset, int32 output_multiplier,
    843                  int output_shift, int32 output_activation_min,
    844                  int32 output_activation_max, uint8* output_data,
    845                  const Dims<4>& output_dims, uint8* im2col_data,
    846                  const Dims<4>& im2col_dims,
    847                  gemmlowp::GemmContext* gemm_context) {
    848   gemmlowp::ScopedProfilingLabel label("Conv/8bit");
    849 
    850   TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
    851   TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
    852   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
    853 
    854   const uint8* gemm_input_data = nullptr;
    855   const Dims<4>* gemm_input_dims = nullptr;
    856   const int filter_width = ArraySize(filter_dims, 1);
    857   const int filter_height = ArraySize(filter_dims, 2);
    858   const bool need_im2col = stride_width != 1 || stride_height != 1 ||
    859                            filter_width != 1 || filter_height != 1;
    860   if (need_im2col) {
    861     TFLITE_DCHECK(im2col_data);
    862     const int input_zero_point = -input_offset;
    863     TFLITE_DCHECK_GE(input_zero_point, 0);
    864     TFLITE_DCHECK_LE(input_zero_point, 255);
    865     Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
    866            pad_height, filter_height, filter_width, input_zero_point,
    867            im2col_data, im2col_dims);
    868     gemm_input_data = im2col_data;
    869     gemm_input_dims = &im2col_dims;
    870   } else {
    871     TFLITE_DCHECK(!im2col_data);
    872     gemm_input_data = input_data;
    873     gemm_input_dims = &input_dims;
    874   }
    875 
    876   const int gemm_input_rows = gemm_input_dims->sizes[0];
    877   const int gemm_input_cols = gemm_input_dims->sizes[1] *
    878                               gemm_input_dims->sizes[2] *
    879                               gemm_input_dims->sizes[3];
    880   const int filter_rows = filter_dims.sizes[3];
    881   const int filter_cols =
    882       filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
    883   const int output_rows = output_dims.sizes[0];
    884   const int output_cols =
    885       output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
    886   TFLITE_DCHECK_EQ(output_rows, filter_rows);
    887   TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
    888   TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
    889   TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
    890   TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
    891   TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
    892   TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
    893   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
    894       filter_data, filter_rows, filter_cols);
    895   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
    896       gemm_input_data, gemm_input_rows, gemm_input_cols);
    897   gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
    898       output_data, output_rows, output_cols);
    899   const auto& output_pipeline = GemmlowpOutputPipeline::Make(
    900       bias_data, output_rows, output_offset, output_multiplier, output_shift,
    901       output_activation_min, output_activation_max);
    902   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
    903                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
    904       gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
    905       input_offset, output_pipeline);
    906 }
    907 
    908 // legacy, for compatibility with old checked-in code
    909 template <FusedActivationFunctionType Ac>
    910 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
    911                  int32 input_offset, const uint8* filter_data,
    912                  const Dims<4>& filter_dims, int32 filter_offset,
    913                  const int32* bias_data, const Dims<4>& bias_dims,
    914                  int stride_width, int stride_height, int pad_width,
    915                  int pad_height, int32 output_offset, int32 output_multiplier,
    916                  int output_shift, int32 output_activation_min,
    917                  int32 output_activation_max, uint8* output_data,
    918                  const Dims<4>& output_dims, uint8* im2col_data,
    919                  const Dims<4>& im2col_dims,
    920                  gemmlowp::GemmContext* gemm_context) {
    921   static_assert(Ac == FusedActivationFunctionType::kNone ||
    922                     Ac == FusedActivationFunctionType::kRelu ||
    923                     Ac == FusedActivationFunctionType::kRelu6 ||
    924                     Ac == FusedActivationFunctionType::kRelu1,
    925                 "");
    926   if (Ac == FusedActivationFunctionType::kNone) {
    927     TFLITE_DCHECK_EQ(output_activation_min, 0);
    928     TFLITE_DCHECK_EQ(output_activation_max, 255);
    929   }
    930   Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
    931        filter_offset, bias_data, bias_dims, stride_width, stride_height,
    932        pad_width, pad_height, output_offset, output_multiplier, output_shift,
    933        output_activation_min, output_activation_max, output_data, output_dims,
    934        im2col_data, im2col_dims, gemm_context);
    935 }
    936 
    937 // legacy, for compatibility with old checked-in code
    938 template <FusedActivationFunctionType Ac>
    939 void Conv(const uint8* input_data, const Dims<4>& input_dims,
    940           int32 input_offset, const uint8* filter_data,
    941           const Dims<4>& filter_dims, int32 filter_offset,
    942           const int32* bias_data, const Dims<4>& bias_dims, int stride,
    943           int pad_width, int pad_height, int32 output_offset,
    944           int32 output_multiplier, int output_shift,
    945           int32 output_activation_min, int32 output_activation_max,
    946           uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
    947           const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
    948   static_assert(Ac == FusedActivationFunctionType::kNone ||
    949                     Ac == FusedActivationFunctionType::kRelu ||
    950                     Ac == FusedActivationFunctionType::kRelu6 ||
    951                     Ac == FusedActivationFunctionType::kRelu1,
    952                 "");
    953   Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
    954        filter_offset, bias_data, bias_dims, stride, stride, pad_width,
    955        pad_height, output_offset, output_multiplier, output_shift,
    956        output_activation_min, output_activation_max, output_data, output_dims,
    957        im2col_data, im2col_dims, gemm_context);
    958 }
    959 
    960 template <typename T>
    961 inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
    962                          int block_size, T* output_data,
    963                          const Dims<4>& output_dims) {
    964   gemmlowp::ScopedProfilingLabel label("DepthToSpace");
    965 
    966   const int input_depth = ArraySize(input_dims, 0);
    967   const int input_width = ArraySize(input_dims, 1);
    968   const int input_height = ArraySize(input_dims, 2);
    969 
    970   const int output_depth = ArraySize(output_dims, 0);
    971   const int batch_size = ArraySize(output_dims, 3);
    972 
     973   // Number of contiguous values that we can copy in one iteration.
    974   const int stride = block_size * output_depth;
    975 
    976   for (int batch = 0; batch < batch_size; ++batch) {
    977     for (int in_h = 0; in_h < input_height; ++in_h) {
    978       const T* input_ptr = input_data + Offset(input_dims, 0, 0, in_h, batch);
    979       for (int offset_h = 0; offset_h < block_size; ++offset_h) {
    980         const T* src = input_ptr;
    981         for (int in_w = 0; in_w < input_width; ++in_w) {
    982           memcpy(output_data, src, stride * sizeof(T));
    983           output_data += stride;
    984           src += input_depth;
    985         }
    986         input_ptr += stride;
    987       }
    988     }
    989   }
    990 }
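
// Index mapping implemented above (standard NHWC DepthToSpace): for block
// size S, output element (b, h, w, d) is read from input element
//   (b, h / S, w / S, ((h % S) * S + w % S) * output_depth + d).
// With a packed layout, every (in_h, offset_h, in_w) triple contributes
// block_size * output_depth contiguous values, which is why a single memcpy
// per inner iteration suffices.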
    991 
    992 // legacy, for compatibility with old checked-in code
    993 template <FusedActivationFunctionType Ac, typename T>
    994 void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
    995             int pad_width, int pad_height, int kheight, int kwidth,
    996             uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
    997   Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
    998          kwidth, byte_zero, output_data, output_dims);
    999 }
   1000 
   1001 // legacy, for compatibility with old checked-in code
   1002 template <FusedActivationFunctionType Ac>
   1003 void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
   1004                 const float* filter_data, const Dims<4>& filter_dims,
   1005                 const float* bias_data, const Dims<4>& bias_dims,
   1006                 float* output_data, const Dims<4>& output_dims) {
   1007   gemmlowp::ScopedProfilingLabel label("ConvAsGemm");
   1008 
   1009   const auto input_matrix_map =
   1010       MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
   1011   const auto filter_matrix_map =
   1012       MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
   1013   auto output_matrix_map =
   1014       MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
   1015 
   1016   Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
   1017 
   1018   AddBiasAndEvalActivationFunction<Ac>(bias_data, bias_dims, output_data,
   1019                                        output_dims);
   1020 }
   1021 
   1022 // legacy, for compatibility with old checked-in code
   1023 template <FusedActivationFunctionType Ac>
   1024 void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
   1025                 int32 input_offset, const uint8* filter_data,
   1026                 const Dims<4>& filter_dims, int32 filter_offset,
   1027                 const int32* bias_data, const Dims<4>& bias_dims,
   1028                 int32 output_offset, int32 output_multiplier, int output_shift,
   1029                 int32 output_activation_min, int32 output_activation_max,
   1030                 uint8* output_data, const Dims<4>& output_dims,
   1031                 gemmlowp::GemmContext* gemm_context) {
   1032   gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit");
   1033   static_assert(Ac == FusedActivationFunctionType::kNone ||
   1034                     Ac == FusedActivationFunctionType::kRelu ||
   1035                     Ac == FusedActivationFunctionType::kRelu6 ||
   1036                     Ac == FusedActivationFunctionType::kRelu1,
   1037                 "");
   1038   const int input_rows = input_dims.sizes[0];
   1039   const int input_cols =
   1040       input_dims.sizes[1] * input_dims.sizes[2] * input_dims.sizes[3];
   1041   const int filter_rows = filter_dims.sizes[3];
   1042   const int filter_cols =
   1043       filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
   1044   const int output_rows = output_dims.sizes[0];
   1045   const int output_cols =
   1046       output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
   1047   TFLITE_DCHECK_EQ(output_rows, filter_rows);
   1048   TFLITE_DCHECK_EQ(output_cols, input_cols);
   1049   TFLITE_DCHECK_EQ(filter_cols, input_rows);
   1050   TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
   1051   TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
   1052   TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
   1053   TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
   1054   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
   1055       filter_data, output_rows, filter_cols, filter_cols);
   1056   gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
   1057       input_data, filter_cols, output_cols, filter_cols);
   1058   gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
   1059       output_data, output_rows, output_cols, output_rows);
   1060   const auto& output_pipeline = GemmlowpOutputPipeline::Make(
   1061       bias_data, output_rows, output_offset, output_multiplier, output_shift,
   1062       output_activation_min, output_activation_max);
   1063   gemmlowp::GemmWithOutputPipeline<uint8, uint8,
   1064                                    gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
   1065       gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
   1066       input_offset, output_pipeline);
   1067 }
   1068 
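         // SpaceToDepth rearranges each block_size x block_size spatial block of
         // the input into the depth dimension of a single output position, so
         // output_depth = input_depth * block_size^2 and the spatial extents
         // shrink by block_size. Each memcpy below moves one contiguous run of
         // block_size * input_depth values.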
   1069 template <typename T>
   1070 inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
   1071                          int block_size, T* output_data,
   1072                          const Dims<4>& output_dims) {
   1073   gemmlowp::ScopedProfilingLabel label("SpaceToDepth");
   1074 
   1075   const int output_depth = ArraySize(output_dims, 0);
   1076   const int output_width = ArraySize(output_dims, 1);
   1077   const int output_height = ArraySize(output_dims, 2);
   1078 
   1079   const int input_depth = ArraySize(input_dims, 0);
   1080   const int batch_size = ArraySize(input_dims, 3);
   1081 
    1082   // Number of contiguous values that we can copy in one iteration.
   1083   const int stride = block_size * input_depth;
   1084 
   1085   for (int batch = 0; batch < batch_size; ++batch) {
   1086     for (int out_h = 0; out_h < output_height; ++out_h) {
   1087       T* output_ptr = output_data + Offset(output_dims, 0, 0, out_h, batch);
   1088       for (int offset_h = 0; offset_h < block_size; ++offset_h) {
   1089         T* dst = output_ptr;
   1090         for (int out_w = 0; out_w < output_width; ++out_w) {
   1091           memcpy(dst, input_data, stride * sizeof(T));
   1092           input_data += stride;
   1093           dst += output_depth;
   1094         }
   1095         output_ptr += stride;
   1096       }
   1097     }
   1098   }
   1099 }
   1100 
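         // Applies y = (x - mean) * multiplier + offset followed by the fused
         // activation. In this "non-global" variant the mean/multiplier/offset
         // arrays vary per (channel, x, y) position; GlobalBatchNormalization
         // below indexes them by channel only.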
   1101 template <FusedActivationFunctionType Ac>
   1102 void NonGlobalBatchNormalization(
   1103     const float* input_data, const Dims<4>& input_dims, const float* mean_data,
   1104     const Dims<4>& mean_dims, const float* multiplier_data,
   1105     const Dims<4>& multiplier_dims, const float* offset_data,
   1106     const Dims<4>& offset_dims, float* output_data,
   1107     const Dims<4>& output_dims) {
   1108   gemmlowp::ScopedProfilingLabel label("NonGlobalBatchNormalization");
   1109   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   1110   const int height =
   1111       MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2,
   1112                         offset_dims, 2, output_dims, 2);
   1113   const int width =
   1114       MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1,
   1115                         offset_dims, 1, output_dims, 1);
   1116   const int depth =
   1117       MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
   1118                         offset_dims, 0, output_dims, 0);
   1119 
   1120   for (int b = 0; b < batches; ++b) {
   1121     for (int y = 0; y < height; ++y) {
   1122       for (int x = 0; x < width; ++x) {
   1123         for (int c = 0; c < depth; ++c) {
   1124           output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
   1125               (input_data[Offset(input_dims, c, x, y, b)] -
   1126                mean_data[Offset(mean_dims, c, x, y, 0)]) *
   1127                   multiplier_data[Offset(multiplier_dims, c, x, y, 0)] +
   1128               offset_data[Offset(offset_dims, c, x, y, 0)]);
   1129         }
   1130       }
   1131     }
   1132   }
   1133 }
   1134 
   1135 template <FusedActivationFunctionType Ac>
   1136 void GlobalBatchNormalization(const float* input_data,
   1137                               const Dims<4>& input_dims, const float* mean_data,
   1138                               const Dims<4>& mean_dims,
   1139                               const float* multiplier_data,
   1140                               const Dims<4>& multiplier_dims,
   1141                               const float* offset_data,
   1142                               const Dims<4>& offset_dims, float* output_data,
   1143                               const Dims<4>& output_dims) {
   1144   gemmlowp::ScopedProfilingLabel label("GlobalBatchNormalization");
   1145   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   1146   const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
   1147   const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
   1148   const int depth =
   1149       MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
   1150                         offset_dims, 0, output_dims, 0);
   1151 
   1152   for (int b = 0; b < batches; ++b) {
   1153     for (int y = 0; y < height; ++y) {
   1154       for (int x = 0; x < width; ++x) {
   1155         for (int c = 0; c < depth; ++c) {
   1156           output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction<Ac>(
   1157               (input_data[Offset(input_dims, c, x, y, b)] -
   1158                mean_data[Offset(mean_dims, c, 0, 0, 0)]) *
   1159                   multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] +
   1160               offset_data[Offset(offset_dims, c, 0, 0, 0)]);
   1161         }
   1162       }
   1163     }
   1164   }
   1165 }
   1166 
   1167 inline void Relu(const float* input_data, const Dims<4>& input_dims,
   1168                  float* output_data, const Dims<4>& output_dims) {
   1169   gemmlowp::ScopedProfilingLabel label("Relu (not fused)");
   1170 
   1171   const auto input = MapAsVector(input_data, input_dims);
   1172   auto output = MapAsVector(output_data, output_dims);
   1173   output = input.cwiseMax(0.0f);
   1174 }
   1175 
   1176 inline void Relu1(const float* input_data, const Dims<4>& input_dims,
   1177                   float* output_data, const Dims<4>& output_dims) {
   1178   gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)");
   1179   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   1180   const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
   1181   const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
   1182   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
   1183   for (int b = 0; b < batches; ++b) {
   1184     for (int y = 0; y < height; ++y) {
   1185       for (int x = 0; x < width; ++x) {
   1186         for (int c = 0; c < depth; ++c) {
   1187           float val = input_data[Offset(input_dims, c, x, y, b)];
   1188           const float upper = 1;
   1189           const float lower = -1;
   1190           float clamped = val > upper ? upper : val < lower ? lower : val;
   1191           output_data[Offset(output_dims, c, x, y, b)] = clamped;
   1192         }
   1193       }
   1194     }
   1195   }
   1196 }
   1197 
   1198 inline void Relu6(const float* input_data, const Dims<4>& input_dims,
   1199                   float* output_data, const Dims<4>& output_dims) {
   1200   gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)");
   1201   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   1202   const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
   1203   const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
   1204   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
   1205   for (int b = 0; b < batches; ++b) {
   1206     for (int y = 0; y < height; ++y) {
   1207       for (int x = 0; x < width; ++x) {
   1208         for (int c = 0; c < depth; ++c) {
   1209           float val = input_data[Offset(input_dims, c, x, y, b)];
   1210           const float upper = 6;
   1211           const float lower = 0;
   1212           float clamped = val > upper ? upper : val < lower ? lower : val;
   1213           output_data[Offset(output_dims, c, x, y, b)] = clamped;
   1214         }
   1215       }
   1216     }
   1217   }
   1218 }
   1219 
   1220 template <FusedActivationFunctionType Ac>
   1221 void L2Normalization(const float* input_data, const Dims<4>& input_dims,
   1222                      float* output_data, const Dims<4>& output_dims) {
   1223   gemmlowp::ScopedProfilingLabel label("L2Normalization");
   1224   static_assert(Ac == FusedActivationFunctionType::kNone, "");
   1225   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   1226   const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
   1227   const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
   1228   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
   1229   for (int b = 0; b < batches; ++b) {
   1230     for (int y = 0; y < height; ++y) {
   1231       for (int x = 0; x < width; ++x) {
   1232         float squared_l2_norm = 0;
   1233         for (int c = 0; c < depth; ++c) {
   1234           float val = input_data[Offset(input_dims, c, x, y, b)];
   1235           squared_l2_norm += val * val;
   1236         }
   1237         float inverse_l2_norm = 1.0f / std::sqrt(squared_l2_norm);
   1238         for (int c = 0; c < depth; ++c) {
   1239           output_data[Offset(output_dims, c, x, y, b)] =
   1240               input_data[Offset(input_dims, c, x, y, b)] * inverse_l2_norm;
   1241         }
   1242       }
   1243     }
   1244   }
   1245 }
   1246 
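         // Computes a fixed-point multiplier and shift pair approximating
         // multiplication by 1 / sqrt(input), using a Newton-Raphson iteration
         // in gemmlowp fixed-point arithmetic. Used by the quantized
         // L2Normalization below.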
   1247 inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
   1248                                           int* output_shift) {
   1249   *output_shift = 11;
   1250   while (input >= (1 << 29)) {
   1251     input /= 4;
   1252     ++*output_shift;
   1253   }
   1254   TFLITE_DCHECK_GT(input, 0);
   1255   const unsigned max_left_shift_bits = __builtin_clz(input) - 1;
   1256   const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
   1257   const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
   1258   *output_shift -= left_shift_bit_pairs;
   1259   input <<= 2 * left_shift_bit_pairs;
   1260   TFLITE_DCHECK_GE(input, (1 << 27));
   1261   TFLITE_DCHECK_LT(input, (1 << 29));
   1262   using gemmlowp::FixedPoint;
   1263   using gemmlowp::Rescale;
   1264   using gemmlowp::SaturatingRoundingMultiplyByPOT;
   1265   // Using 3 integer bits gives us enough room for the internal arithmetic in
   1266   // this Newton-Raphson iteration.
   1267   using F3 = FixedPoint<int32, 3>;
   1268   using F0 = FixedPoint<int32, 0>;
   1269   const F3 fixedpoint_input = F3::FromRaw(input >> 1);
   1270   const F3 fixedpoint_half_input =
   1271       SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
   1272   const F3 fixedpoint_half_three =
   1273       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
   1274   // Newton-Raphson iteration
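           // Each step applies x <- 1.5 * x - 0.5 * input * x^3, the
           // Newton-Raphson update for f(x) = 1/x^2 - input, whose positive
           // root is 1/sqrt(input).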
   1275   // Naive unoptimized starting guess: x = 1
   1276   F3 x = F3::One();
   1277   // Naive unoptimized number of iterations: 5
   1278   for (int i = 0; i < 5; i++) {
   1279     const F3 x3 = Rescale<3>(x * x * x);
   1280     x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
   1281   }
   1282   const F0 fixedpoint_half_sqrt_2 =
   1283       GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
   1284   x = x * fixedpoint_half_sqrt_2;
   1285   *output_inv_sqrt = x.raw();
   1286   if (*output_shift < 0) {
   1287     *output_inv_sqrt <<= -*output_shift;
   1288     *output_shift = 0;
   1289   }
   1290 }
   1291 
   1292 inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
   1293                             int32 input_zero_point, uint8* output_data,
   1294                             const Dims<4>& output_dims) {
   1295   gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit");
   1296   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   1297   const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
   1298   const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
   1299   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
   1300   TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
   1301   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
   1302   TFLITE_DCHECK_EQ(batches, 1);
   1303   TFLITE_DCHECK_EQ(height, 1);
   1304   TFLITE_DCHECK_EQ(width, 1);
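           // The output is approximately 128 + 128 * (x - zero_point) / l2_norm,
           // i.e. it uses an output zero point of 128 and a scale of 1/128,
           // clamped to the uint8 range below.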
   1305   int32 square_l2_norm = 0;
   1306   for (int i = 0; i < depth; i++) {
   1307     int32 diff = input_data[i] - input_zero_point;
   1308     square_l2_norm += diff * diff;
   1309   }
   1310   int32 inv_l2norm_multiplier;
   1311   int inv_l2norm_shift;
   1312   GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier,
   1313                                 &inv_l2norm_shift);
   1314 
   1315   for (int i = 0; i < depth; i++) {
   1316     int32 diff = input_data[i] - input_zero_point;
   1317     int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne(
   1318         128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
   1319     int32 unclamped_output_val = 128 + rescaled_diff;
   1320     int32 output_val = std::min(255, std::max(0, unclamped_output_val));
   1321     output_data[i] = static_cast<uint8>(output_val);
   1322   }
   1323 }
   1324 
   1325 inline void Add(const float* input1_data, const Dims<4>& input1_dims,
   1326                 const float* input2_data, const Dims<4>& input2_dims,
   1327                 float output_activation_min, float output_activation_max,
   1328                 float* output_data, const Dims<4>& output_dims) {
   1329   gemmlowp::ScopedProfilingLabel label("Add");
   1330   /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3,
   1331                                               output_dims, 3);
   1332   /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2,
   1333                                              output_dims, 2);
   1334   /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1,
   1335                                             output_dims, 1);
   1336   /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0,
   1337                                             output_dims, 0);
   1338   TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
   1339   TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
   1340   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
   1341 
   1342   int i = 0;
   1343   const int size = input1_dims.sizes[3] * input1_dims.strides[3];
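           // With NEON, handle 16 floats per iteration, then 4 at a time; the
           // scalar loop at the end processes whatever tail remains.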
   1344 #ifdef USE_NEON
   1345   const auto activation_min = vdupq_n_f32(output_activation_min);
   1346   const auto activation_max = vdupq_n_f32(output_activation_max);
   1347   for (; i <= size - 16; i += 16) {
   1348     auto a10 = vld1q_f32(input1_data + i);
   1349     auto a11 = vld1q_f32(input1_data + i + 4);
   1350     auto a12 = vld1q_f32(input1_data + i + 8);
   1351     auto a13 = vld1q_f32(input1_data + i + 12);
   1352     auto a20 = vld1q_f32(input2_data + i);
   1353     auto a21 = vld1q_f32(input2_data + i + 4);
   1354     auto a22 = vld1q_f32(input2_data + i + 8);
   1355     auto a23 = vld1q_f32(input2_data + i + 12);
   1356     auto x0 = vaddq_f32(a10, a20);
   1357     auto x1 = vaddq_f32(a11, a21);
   1358     auto x2 = vaddq_f32(a12, a22);
   1359     auto x3 = vaddq_f32(a13, a23);
   1360     x0 = vmaxq_f32(activation_min, x0);
   1361     x1 = vmaxq_f32(activation_min, x1);
   1362     x2 = vmaxq_f32(activation_min, x2);
   1363     x3 = vmaxq_f32(activation_min, x3);
   1364     x0 = vminq_f32(activation_max, x0);
   1365     x1 = vminq_f32(activation_max, x1);
   1366     x2 = vminq_f32(activation_max, x2);
   1367     x3 = vminq_f32(activation_max, x3);
   1368     vst1q_f32(output_data + i, x0);
   1369     vst1q_f32(output_data + i + 4, x1);
   1370     vst1q_f32(output_data + i + 8, x2);
   1371     vst1q_f32(output_data + i + 12, x3);
   1372   }
   1373   for (; i <= size - 4; i += 4) {
   1374     auto a1 = vld1q_f32(input1_data + i);
   1375     auto a2 = vld1q_f32(input2_data + i);
   1376     auto x = vaddq_f32(a1, a2);
   1377     x = vmaxq_f32(activation_min, x);
   1378     x = vminq_f32(activation_max, x);
   1379     vst1q_f32(output_data + i, x);
   1380   }
    1381 #endif  // USE_NEON
   1382 
   1383   for (; i < size; i++) {
   1384     auto x = input1_data[i] + input2_data[i];
   1385     output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
   1386                                                   output_activation_max);
   1387   }
   1388 }
   1389 
   1390 // legacy, for compatibility with old checked-in code
   1391 template <FusedActivationFunctionType Ac>
   1392 void Add(const float* input1_data, const Dims<4>& input1_dims,
   1393          const float* input2_data, const Dims<4>& input2_dims,
   1394          float* output_data, const Dims<4>& output_dims) {
   1395   float output_activation_min, output_activation_max;
   1396   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
   1397 
   1398   Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
   1399       output_activation_max, output_data, output_dims);
   1400 }
   1401 
   1402 template <FusedActivationFunctionType Ac>
   1403 inline void Add(int left_shift, const uint8* input1_data,
   1404                 const Dims<4>& input1_dims, int32 input1_offset,
   1405                 int32 input1_multiplier, int input1_shift,
   1406                 const uint8* input2_data, const Dims<4>& input2_dims,
   1407                 int32 input2_offset, int32 input2_multiplier, int input2_shift,
   1408                 int32 output_offset, int32 output_multiplier, int output_shift,
   1409                 int32 output_activation_min, int32 output_activation_max,
   1410                 uint8* output_data, const Dims<4>& output_dims) {
   1411   static_assert(Ac == FusedActivationFunctionType::kNone ||
   1412                     Ac == FusedActivationFunctionType::kRelu ||
   1413                     Ac == FusedActivationFunctionType::kRelu6 ||
   1414                     Ac == FusedActivationFunctionType::kRelu1,
   1415                 "");
   1416   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
   1417   if (Ac == FusedActivationFunctionType::kNone) {
   1418     TFLITE_DCHECK_EQ(output_activation_min, 0);
   1419     TFLITE_DCHECK_EQ(output_activation_max, 255);
   1420   }
   1421   gemmlowp::ScopedProfilingLabel label("Add/8bit");
   1422   /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3,
   1423                                               output_dims, 3);
   1424   /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2,
   1425                                              output_dims, 2);
   1426   /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1,
   1427                                             output_dims, 1);
   1428   /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0,
   1429                                             output_dims, 0);
   1430   TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
   1431   TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
   1432   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
   1433 
   1434   int i = 0;
   1435   const int size = input1_dims.sizes[3] * input1_dims.strides[3];
   1436   TFLITE_DCHECK_GT(input1_offset, -256);
   1437   TFLITE_DCHECK_GT(input2_offset, -256);
   1438   TFLITE_DCHECK_LT(input1_offset, 256);
   1439   TFLITE_DCHECK_LT(input2_offset, 256);
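           // Both the NEON and scalar paths follow the same recipe: widen each
           // uint8 value, add its input offset, left-shift for headroom, rescale
           // by the per-input quantized multiplier and shift, add, rescale to
           // the output scale, add the output offset, and saturate back to the
           // uint8 output range.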
   1440 #ifdef USE_NEON
   1441   for (; i <= size - 8; i += 8) {
   1442     const auto input1_val_original = vld1_u8(input1_data + i);
   1443     const auto input2_val_original = vld1_u8(input2_data + i);
   1444     const auto input1_val_s16 =
   1445         vreinterpretq_s16_u16(vmovl_u8(input1_val_original));
   1446     const auto input2_val_s16 =
   1447         vreinterpretq_s16_u16(vmovl_u8(input2_val_original));
   1448     const auto input1_val =
   1449         vaddq_s16(input1_val_s16, vdupq_n_s16(input1_offset));
   1450     const auto input2_val =
   1451         vaddq_s16(input2_val_s16, vdupq_n_s16(input2_offset));
   1452     const auto input1_val_high = vget_high_s16(input1_val);
   1453     const auto input1_val_low = vget_low_s16(input1_val);
   1454     const auto input2_val_high = vget_high_s16(input2_val);
   1455     const auto input2_val_low = vget_low_s16(input2_val);
   1456     auto x11 = vmovl_s16(input1_val_low);
   1457     auto x12 = vmovl_s16(input1_val_high);
   1458     auto x21 = vmovl_s16(input2_val_low);
   1459     auto x22 = vmovl_s16(input2_val_high);
   1460     const auto left_shift_dup = vdupq_n_s32(left_shift);
   1461     x11 = vshlq_s32(x11, left_shift_dup);
   1462     x12 = vshlq_s32(x12, left_shift_dup);
   1463     x21 = vshlq_s32(x21, left_shift_dup);
   1464     x22 = vshlq_s32(x22, left_shift_dup);
   1465     x11 = vqrdmulhq_n_s32(x11, input1_multiplier);
   1466     x12 = vqrdmulhq_n_s32(x12, input1_multiplier);
   1467     x21 = vqrdmulhq_n_s32(x21, input2_multiplier);
   1468     x22 = vqrdmulhq_n_s32(x22, input2_multiplier);
   1469     const auto input1_shift_dup = vdupq_n_s32(-input1_shift);
   1470     const auto input2_shift_dup = vdupq_n_s32(-input2_shift);
   1471     x11 = vshlq_s32(x11, input1_shift_dup);
   1472     x12 = vshlq_s32(x12, input1_shift_dup);
   1473     x21 = vshlq_s32(x21, input2_shift_dup);
   1474     x22 = vshlq_s32(x22, input2_shift_dup);
   1475     auto s1 = vaddq_s32(x11, x21);
   1476     auto s2 = vaddq_s32(x12, x22);
   1477     s1 = vqrdmulhq_n_s32(s1, output_multiplier);
   1478     s2 = vqrdmulhq_n_s32(s2, output_multiplier);
   1479     using gemmlowp::RoundingDivideByPOT;
   1480     s1 = RoundingDivideByPOT(s1, output_shift);
   1481     s2 = RoundingDivideByPOT(s2, output_shift);
   1482     const auto s1_narrowed = vmovn_s32(s1);
   1483     const auto s2_narrowed = vmovn_s32(s2);
   1484     const auto s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed),
   1485                              vdupq_n_s16(output_offset));
   1486     vst1_u8(output_data + i, vqmovun_s16(s));
   1487   }
    1488 #endif  // USE_NEON
   1489 
   1490   for (; i < size; i++) {
   1491     const int32 input1_val = input1_offset + input1_data[i];
   1492     const int32 input2_val = input2_offset + input2_data[i];
   1493     const int32 shifted_input1_val = input1_val * (1 << left_shift);
   1494     const int32 shifted_input2_val = input2_val * (1 << left_shift);
   1495     const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne(
   1496         shifted_input1_val, input1_multiplier, input1_shift);
   1497     const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne(
   1498         shifted_input2_val, input2_multiplier, input2_shift);
   1499     const int32 raw_sum = scaled_input1_val + scaled_input2_val;
   1500     const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne(
   1501                                  raw_sum, output_multiplier, output_shift) +
   1502                              output_offset;
   1503     const int32 clamped_output = std::min(
   1504         output_activation_max, std::max(output_activation_min, raw_output));
   1505     output_data[i] = static_cast<uint8>(clamped_output);
   1506   }
   1507 }
   1508 
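         // Elementwise int32 addition. Besides the same-shape case this only
         // handles the degenerate broadcasts where one operand is a single
         // scalar; general broadcasting goes through BroadcastAdd below.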
   1509 template <FusedActivationFunctionType Ac>
   1510 void Add(const int32* input1_data, const Dims<4>& input1_dims,
   1511          const int32* input2_data, const Dims<4>& input2_dims,
   1512          int32* output_data, const Dims<4>& output_dims) {
   1513   gemmlowp::ScopedProfilingLabel label("Add/int32");
   1514   TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
   1515 
   1516   auto input1_map = MapAsVector(input1_data, input1_dims);
   1517   auto input2_map = MapAsVector(input2_data, input2_dims);
   1518   auto output_map = MapAsVector(output_data, output_dims);
   1519   if (AreSameDims(input1_dims, input2_dims)) {
   1520     output_map.array() = input1_map.array() + input2_map.array();
   1521   } else if (RequiredBufferSizeForDims(input2_dims) == 1) {
   1522     auto scalar = input2_data[0];
   1523     output_map.array() = input1_map.array() + scalar;
   1524   } else if (RequiredBufferSizeForDims(input1_dims) == 1) {
   1525     auto scalar = input1_data[0];
   1526     output_map.array() = scalar + input2_map.array();
   1527   } else {
    1528     // Should not get here: inputs must match or one must be a scalar.
   1529     TFLITE_DCHECK(false);
   1530   }
   1531 }
   1532 
   1533 // TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
   1534 // dimensionality if the runtime code does a single loop over one dimension
   1535 // that handles broadcasting as the base case. The code generator would then
   1536 // generate max(D1, D2) nested for loops.
   1537 // TODO(benoitjacob): BroadcastAdd is intentionally duplicated from
   1538 // reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
   1539 // is no longer referenced in this file, move NdArrayDesc<T> from types.h to
   1540 // reference_ops.h.
   1541 template <typename T>
   1542 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
   1543                   const T* input2_data, const Dims<4>& input2_dims,
   1544                   T output_activation_min, T output_activation_max,
   1545                   T* output_data, const Dims<4>& output_dims) {
   1546   gemmlowp::ScopedProfilingLabel label("BroadcastAdd");
   1547 
   1548   NdArrayDesc<4> desc1;
   1549   NdArrayDesc<4> desc2;
   1550   NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
   1551 
   1552   // In Tensorflow, the dimensions are canonically named (batch_number, row,
   1553   // col, channel), with extents (batches, height, width, depth), with the
   1554   // trailing dimension changing most rapidly (channels has the smallest stride,
   1555   // typically 1 element).
   1556   //
   1557   // In generated C code, we store arrays with the dimensions reversed. The
   1558   // first dimension has smallest stride.
   1559   //
   1560   // We name our variables by their Tensorflow convention, but generate C code
   1561   // nesting loops such that the innermost loop has the smallest stride for the
   1562   // best cache behavior.
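           // For example, with output extents (batches=2, height=3, width=4,
           // depth=8), the loops below iterate b in [0,2), y in [0,3),
           // x in [0,4) and c in [0,8), with the innermost c loop touching the
           // dimension that has the smallest stride.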
   1563   for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
   1564     for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
   1565       for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
   1566         for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
   1567           output_data[Offset(output_dims, c, x, y, b)] =
   1568               ActivationFunctionWithMinMax(
   1569                   input1_data[SubscriptToIndex(desc1, c, x, y, b)] +
   1570                       input2_data[SubscriptToIndex(desc2, c, x, y, b)],
   1571                   output_activation_min, output_activation_max);
   1572         }
   1573       }
   1574     }
   1575   }
   1576 }
   1577 
   1578 // legacy, for compatibility with old checked-in code
   1579 template <FusedActivationFunctionType Ac, typename T>
   1580 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
   1581                   const T* input2_data, const Dims<4>& input2_dims,
   1582                   T* output_data, const Dims<4>& output_dims) {
   1583   T output_activation_min, output_activation_max;
   1584   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
   1585 
   1586   BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
   1587                output_activation_min, output_activation_max, output_data,
   1588                output_dims);
   1589 }
   1590 
   1591 inline void BroadcastAdd(int left_shift, const uint8* input1_data,
   1592                          const Dims<4>& input1_dims, int32 input1_offset,
   1593                          int32 input1_multiplier, int input1_shift,
   1594                          const uint8* input2_data, const Dims<4>& input2_dims,
   1595                          int32 input2_offset, int32 input2_multiplier,
   1596                          int input2_shift, int32 output_offset,
   1597                          int32 output_multiplier, int output_shift,
   1598                          int32 output_activation_min,
   1599                          int32 output_activation_max, uint8* output_data,
   1600                          const Dims<4>& output_dims) {
   1601   gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit");
   1602 
   1603   NdArrayDesc<4> desc1;
   1604   NdArrayDesc<4> desc2;
   1605   NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
   1606 
   1607   // In Tensorflow, the dimensions are canonically named (batch_number, row,
   1608   // col, channel), with extents (batches, height, width, depth), with the
   1609   // trailing dimension changing most rapidly (channels has the smallest stride,
   1610   // typically 1 element).
   1611   //
   1612   // In generated C code, we store arrays with the dimensions reversed. The
   1613   // first dimension has smallest stride.
   1614   //
   1615   // We name our variables by their Tensorflow convention, but generate C code
   1616   // nesting loops such that the innermost loop has the smallest stride for the
   1617   // best cache behavior.
   1618   for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
   1619     for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
   1620       for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
   1621         for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
   1622           const int32 input1_val =
   1623               input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
   1624           const int32 input2_val =
   1625               input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
   1626           const int32 shifted_input1_val = input1_val * (1 << left_shift);
   1627           const int32 shifted_input2_val = input2_val * (1 << left_shift);
   1628           const int32 scaled_input1_val =
   1629               MultiplyByQuantizedMultiplierSmallerThanOne(
   1630                   shifted_input1_val, input1_multiplier, input1_shift);
   1631           const int32 scaled_input2_val =
   1632               MultiplyByQuantizedMultiplierSmallerThanOne(
   1633                   shifted_input2_val, input2_multiplier, input2_shift);
   1634           const int32 raw_sum = scaled_input1_val + scaled_input2_val;
   1635           const int32 raw_output =
   1636               MultiplyByQuantizedMultiplierSmallerThanOne(
   1637                   raw_sum, output_multiplier, output_shift) +
   1638               output_offset;
   1639           const int32 clamped_output =
   1640               std::min(output_activation_max,
   1641                        std::max(output_activation_min, raw_output));
   1642           output_data[Offset(output_dims, c, x, y, b)] =
   1643               static_cast<uint8>(clamped_output);
   1644         }
   1645       }
   1646     }
   1647   }
   1648 }
   1649 
   1650 template <FusedActivationFunctionType Ac>
   1651 inline void BroadcastAdd(int left_shift, const uint8* input1_data,
   1652                          const Dims<4>& input1_dims, int32 input1_offset,
   1653                          int32 input1_multiplier, int input1_shift,
   1654                          const uint8* input2_data, const Dims<4>& input2_dims,
   1655                          int32 input2_offset, int32 input2_multiplier,
   1656                          int input2_shift, int32 output_offset,
   1657                          int32 output_multiplier, int output_shift,
   1658                          int32 output_activation_min,
   1659                          int32 output_activation_max, uint8* output_data,
   1660                          const Dims<4>& output_dims) {
   1661   static_assert(Ac == FusedActivationFunctionType::kNone ||
   1662                     Ac == FusedActivationFunctionType::kRelu ||
   1663                     Ac == FusedActivationFunctionType::kRelu6 ||
   1664                     Ac == FusedActivationFunctionType::kRelu1,
   1665                 "");
   1666   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
   1667   if (Ac == FusedActivationFunctionType::kNone) {
   1668     TFLITE_DCHECK_EQ(output_activation_min, 0);
   1669     TFLITE_DCHECK_EQ(output_activation_max, 255);
   1670   }
   1671   BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset,
   1672                input1_multiplier, input1_shift, input2_data, input2_dims,
   1673                input2_offset, input2_multiplier, input2_shift, output_offset,
   1674                output_multiplier, output_shift, output_activation_min,
   1675                output_activation_max, output_data, output_dims);
   1676 }
   1677 
   1678 inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
   1679                 const float* input2_data, const Dims<4>& input2_dims,
   1680                 float output_activation_min, float output_activation_max,
   1681                 float* output_data, const Dims<4>& output_dims) {
   1682   gemmlowp::ScopedProfilingLabel label("Mul");
   1683   /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3,
   1684                                               output_dims, 3);
   1685   /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2,
   1686                                              output_dims, 2);
   1687   /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1,
   1688                                             output_dims, 1);
   1689   /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0,
   1690                                             output_dims, 0);
   1691   TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
   1692   TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
   1693   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
   1694 
   1695   int i = 0;
   1696   const int size = input1_dims.sizes[3] * input1_dims.strides[3];
   1697 #ifdef USE_NEON
   1698   const auto activation_min = vdupq_n_f32(output_activation_min);
   1699   const auto activation_max = vdupq_n_f32(output_activation_max);
   1700   for (; i <= size - 16; i += 16) {
   1701     auto a10 = vld1q_f32(input1_data + i);
   1702     auto a11 = vld1q_f32(input1_data + i + 4);
   1703     auto a12 = vld1q_f32(input1_data + i + 8);
   1704     auto a13 = vld1q_f32(input1_data + i + 12);
   1705     auto a20 = vld1q_f32(input2_data + i);
   1706     auto a21 = vld1q_f32(input2_data + i + 4);
   1707     auto a22 = vld1q_f32(input2_data + i + 8);
   1708     auto a23 = vld1q_f32(input2_data + i + 12);
   1709     auto x0 = vmulq_f32(a10, a20);
   1710     auto x1 = vmulq_f32(a11, a21);
   1711     auto x2 = vmulq_f32(a12, a22);
   1712     auto x3 = vmulq_f32(a13, a23);
   1713 
   1714     x0 = vmaxq_f32(activation_min, x0);
   1715     x1 = vmaxq_f32(activation_min, x1);
   1716     x2 = vmaxq_f32(activation_min, x2);
   1717     x3 = vmaxq_f32(activation_min, x3);
   1718     x0 = vminq_f32(activation_max, x0);
   1719     x1 = vminq_f32(activation_max, x1);
   1720     x2 = vminq_f32(activation_max, x2);
   1721     x3 = vminq_f32(activation_max, x3);
   1722 
   1723     vst1q_f32(output_data + i, x0);
   1724     vst1q_f32(output_data + i + 4, x1);
   1725     vst1q_f32(output_data + i + 8, x2);
   1726     vst1q_f32(output_data + i + 12, x3);
   1727   }
   1728   for (; i <= size - 4; i += 4) {
   1729     auto a1 = vld1q_f32(input1_data + i);
   1730     auto a2 = vld1q_f32(input2_data + i);
   1731     auto x = vmulq_f32(a1, a2);
   1732 
   1733     x = vmaxq_f32(activation_min, x);
   1734     x = vminq_f32(activation_max, x);
   1735 
   1736     vst1q_f32(output_data + i, x);
   1737   }
    1738 #endif  // USE_NEON
   1739 
   1740   for (; i < size; i++) {
   1741     auto x = input1_data[i] * input2_data[i];
   1742     output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min,
   1743                                                   output_activation_max);
   1744   }
   1745 }
   1746 
   1747 // legacy, for compatibility with old checked-in code
   1748 template <FusedActivationFunctionType Ac>
   1749 void Mul(const float* input1_data, const Dims<4>& input1_dims,
   1750          const float* input2_data, const Dims<4>& input2_dims,
   1751          float* output_data, const Dims<4>& output_dims) {
   1752   float output_activation_min, output_activation_max;
   1753   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
   1754 
   1755   Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
   1756       output_activation_max, output_data, output_dims);
   1757 }
   1758 
   1759 template <FusedActivationFunctionType Ac>
   1760 void Mul(const int32* input1_data, const Dims<4>& input1_dims,
   1761          const int32* input2_data, const Dims<4>& input2_dims,
   1762          int32* output_data, const Dims<4>& output_dims) {
   1763   gemmlowp::ScopedProfilingLabel label("Mul/int32");
   1764   TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
   1765 
   1766   auto input1_map = MapAsVector(input1_data, input1_dims);
   1767   auto input2_map = MapAsVector(input2_data, input2_dims);
   1768   auto output_map = MapAsVector(output_data, output_dims);
   1769   if (AreSameDims(input1_dims, input2_dims)) {
   1770     output_map.array() = input1_map.array() * input2_map.array();
   1771   } else if (RequiredBufferSizeForDims(input2_dims) == 1) {
   1772     auto scalar = input2_data[0];
   1773     output_map.array() = input1_map.array() * scalar;
   1774   } else if (RequiredBufferSizeForDims(input1_dims) == 1) {
   1775     auto scalar = input1_data[0];
   1776     output_map.array() = scalar * input2_map.array();
   1777   } else {
    1778     // Should not get here: inputs must match or one must be a scalar.
   1779     TFLITE_DCHECK(false);
   1780   }
   1781 }
   1782 
   1783 // TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
   1784 // dimensionality if the runtime code does a single loop over one dimension
   1785 // that handles broadcasting as the base case. The code generator would then
   1786 // generate max(D1, D2) nested for loops.
   1787 // TODO(benoitjacob): BroadcastMul is intentionally duplicated from
   1788 // reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
   1789 // is no longer referenced in this file, move NdArrayDesc<T> from types.h to
   1790 // reference_ops.h.
   1791 template <typename T>
   1792 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
   1793                   const T* input2_data, const Dims<4>& input2_dims,
   1794                   T output_activation_min, T output_activation_max,
   1795                   T* output_data, const Dims<4>& output_dims) {
   1796   gemmlowp::ScopedProfilingLabel label("BroadcastMul");
   1797 
   1798   NdArrayDesc<4> desc1;
   1799   NdArrayDesc<4> desc2;
   1800   NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
   1801 
   1802   // In Tensorflow, the dimensions are canonically named (batch_number, row,
   1803   // col, channel), with extents (batches, height, width, depth), with the
   1804   // trailing dimension changing most rapidly (channels has the smallest stride,
   1805   // typically 1 element).
   1806   //
   1807   // In generated C code, we store arrays with the dimensions reversed. The
   1808   // first dimension has smallest stride.
   1809   //
   1810   // We name our variables by their Tensorflow convention, but generate C code
   1811   // nesting loops such that the innermost loop has the smallest stride for the
   1812   // best cache behavior.
   1813   for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
   1814     for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
   1815       for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
   1816         for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
   1817           output_data[Offset(output_dims, c, x, y, b)] =
   1818               ActivationFunctionWithMinMax(
   1819                   input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
   1820                       input2_data[SubscriptToIndex(desc2, c, x, y, b)],
   1821                   output_activation_min, output_activation_max);
   1822         }
   1823       }
   1824     }
   1825   }
   1826 }
   1827 
   1828 // legacy, for compatibility with old checked-in code
   1829 template <FusedActivationFunctionType Ac, typename T>
   1830 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
   1831                   const T* input2_data, const Dims<4>& input2_dims,
   1832                   T* output_data, const Dims<4>& output_dims) {
   1833   T output_activation_min, output_activation_max;
   1834   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
   1835 
   1836   BroadcastMul(input1_data, input1_dims, input2_data, input2_dims,
   1837                output_activation_min, output_activation_max, output_data,
   1838                output_dims);
   1839 }
   1840 
   1841 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
   1842                          int32 input1_offset, const uint8* input2_data,
   1843                          const Dims<4>& input2_dims, int32 input2_offset,
   1844                          int32 output_offset, int32 output_multiplier,
   1845                          int output_shift, int32 output_activation_min,
   1846                          int32 output_activation_max, uint8* output_data,
   1847                          const Dims<4>& output_dims) {
   1848   gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit");
   1849 
   1850   NdArrayDesc<4> desc1;
   1851   NdArrayDesc<4> desc2;
   1852   NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
   1853 
   1854   // In Tensorflow, the dimensions are canonically named (batch_number, row,
   1855   // col, channel), with extents (batches, height, width, depth), with the
   1856   // trailing dimension changing most rapidly (channels has the smallest stride,
   1857   // typically 1 element).
   1858   //
   1859   // In generated C code, we store arrays with the dimensions reversed. The
   1860   // first dimension has smallest stride.
   1861   //
   1862   // We name our variables by their Tensorflow convention, but generate C code
   1863   // nesting loops such that the innermost loop has the smallest stride for the
   1864   // best cache behavior.
   1865   for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
   1866     for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
   1867       for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
   1868         for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
   1869           const int32 input1_val =
   1870               input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
   1871           const int32 input2_val =
   1872               input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
   1873           const int32 unclamped_result =
   1874               output_offset +
   1875               MultiplyByQuantizedMultiplierSmallerThanOne(
   1876                   input1_val * input2_val, output_multiplier, output_shift);
   1877           const int32 clamped_output =
   1878               std::min(output_activation_max,
   1879                        std::max(output_activation_min, unclamped_result));
   1880           output_data[Offset(output_dims, c, x, y, b)] =
   1881               static_cast<uint8>(clamped_output);
   1882         }
   1883       }
   1884     }
   1885   }
   1886 }
   1887 
   1888 // legacy, for compatibility with old checked-in code
   1889 template <FusedActivationFunctionType Ac>
   1890 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
   1891                          int32 input1_offset, const uint8* input2_data,
   1892                          const Dims<4>& input2_dims, int32 input2_offset,
   1893                          int32 output_offset, int32 output_multiplier,
   1894                          int output_shift, int32 output_activation_min,
   1895                          int32 output_activation_max, uint8* output_data,
   1896                          const Dims<4>& output_dims) {
   1897   BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
   1898                input2_dims, input2_offset, output_offset, output_multiplier,
   1899                output_shift, output_activation_min, output_activation_max,
   1900                output_data, output_dims);
   1901 }
   1902 
   1903 // TODO(aselle): This is not actually optimized yet.
   1904 inline void Div(const float* input1_data, const Dims<4>& input1_dims,
   1905                 const float* input2_data, const Dims<4>& input2_dims,
   1906                 float output_activation_min, float output_activation_max,
   1907                 float* output_data, const Dims<4>& output_dims) {
   1908   const int batches =
   1909       MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
   1910   const int height =
   1911       MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
   1912   const int width =
   1913       MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
   1914   const int depth =
   1915       MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
   1916   for (int b = 0; b < batches; ++b) {
   1917     for (int y = 0; y < height; ++y) {
   1918       for (int x = 0; x < width; ++x) {
   1919         for (int c = 0; c < depth; ++c) {
   1920           output_data[Offset(output_dims, c, x, y, b)] =
   1921               ActivationFunctionWithMinMax(
   1922                   input1_data[Offset(input1_dims, c, x, y, b)] /
   1923                       input2_data[Offset(input2_dims, c, x, y, b)],
   1924                   output_activation_min, output_activation_max);
   1925         }
   1926       }
   1927     }
   1928   }
   1929 }
   1930 
   1931 // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
   1932 // dimensionality if the runtime code does a single loop over one dimension
   1933 // that handles broadcasting as the base case. The code generator would then
   1934 // generate max(D1, D2) nested for loops.
   1935 // TODO(benoitjacob): BroadcastDiv is intentionally duplicated from
   1936 // reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
   1937 // is no longer referenced in this file, move NdArrayDesc<T> from types.h to
   1938 // reference_ops.h.
   1939 template <typename T>
   1940 void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
   1941                   const T* input2_data, const Dims<4>& input2_dims,
   1942                   T output_activation_min, T output_activation_max,
   1943                   T* output_data, const Dims<4>& output_dims) {
   1944   gemmlowp::ScopedProfilingLabel label("BroadcastDiv");
   1945 
   1946   NdArrayDesc<4> desc1;
   1947   NdArrayDesc<4> desc2;
   1948   NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
   1949 
   1950   // In Tensorflow, the dimensions are canonically named (batch_number, row,
   1951   // col, channel), with extents (batches, height, width, depth), with the
   1952   // trailing dimension changing most rapidly (channels has the smallest stride,
   1953   // typically 1 element).
   1954   //
   1955   // In generated C code, we store arrays with the dimensions reversed. The
   1956   // first dimension has smallest stride.
   1957   //
   1958   // We name our variables by their Tensorflow convention, but generate C code
   1959   // nesting loops such that the innermost loop has the smallest stride for the
   1960   // best cache behavior.
   1961   for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
   1962     for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
   1963       for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
   1964         for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
   1965           output_data[Offset(output_dims, c, x, y, b)] =
   1966               ActivationFunctionWithMinMax(
   1967                   input1_data[SubscriptToIndex(desc1, c, x, y, b)] /
   1968                       input2_data[SubscriptToIndex(desc2, c, x, y, b)],
   1969                   output_activation_min, output_activation_max);
   1970         }
   1971       }
   1972     }
   1973   }
   1974 }
   1975 
   1976 // TODO(aselle): This is not actually optimized yet.
   1977 inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
   1978                 const float* input2_data, const Dims<4>& input2_dims,
   1979                 float output_activation_min, float output_activation_max,
   1980                 float* output_data, const Dims<4>& output_dims) {
   1981   const int batches =
   1982       MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
   1983   const int height =
   1984       MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
   1985   const int width =
   1986       MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
   1987   const int depth =
   1988       MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
   1989   for (int b = 0; b < batches; ++b) {
   1990     for (int y = 0; y < height; ++y) {
   1991       for (int x = 0; x < width; ++x) {
   1992         for (int c = 0; c < depth; ++c) {
   1993           output_data[Offset(output_dims, c, x, y, b)] =
   1994               ActivationFunctionWithMinMax(
   1995                   input1_data[Offset(input1_dims, c, x, y, b)] -
   1996                       input2_data[Offset(input2_dims, c, x, y, b)],
   1997                   output_activation_min, output_activation_max);
   1998         }
   1999       }
   2000     }
   2001   }
   2002 }
   2003 
   2004 // TODO(jiawen): We can implement BroadcastSub on buffers of arbitrary
   2005 // dimensionality if the runtime code does a single loop over one dimension
   2006 // that handles broadcasting as the base case. The code generator would then
   2007 // generate max(D1, D2) nested for loops.
   2008 // TODO(benoitjacob): BroadcastSub is intentionally duplicated from
   2009 // reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
   2010 // is no longer referenced in this file, move NdArrayDesc<T> from types.h to
   2011 // reference_ops.h.
   2012 template <typename T>
   2013 void BroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
   2014                   const T* input2_data, const Dims<4>& input2_dims,
   2015                   T output_activation_min, T output_activation_max,
   2016                   T* output_data, const Dims<4>& output_dims) {
   2017   gemmlowp::ScopedProfilingLabel label("BroadcastSub");
   2018 
   2019   NdArrayDesc<4> desc1;
   2020   NdArrayDesc<4> desc2;
   2021   NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
   2022 
   2023   // In Tensorflow, the dimensions are canonically named (batch_number, row,
   2024   // col, channel), with extents (batches, height, width, depth), with the
   2025   // trailing dimension changing most rapidly (channels has the smallest stride,
   2026   // typically 1 element).
   2027   //
   2028   // In generated C code, we store arrays with the dimensions reversed. The
   2029   // first dimension has smallest stride.
   2030   //
   2031   // We name our variables by their Tensorflow convention, but generate C code
   2032   // nesting loops such that the innermost loop has the smallest stride for the
   2033   // best cache behavior.
   2034   for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
   2035     for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
   2036       for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
   2037         for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
   2038           output_data[Offset(output_dims, c, x, y, b)] =
   2039               ActivationFunctionWithMinMax(
   2040                   input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
   2041                       input2_data[SubscriptToIndex(desc2, c, x, y, b)],
   2042                   output_activation_min, output_activation_max);
   2043         }
   2044       }
   2045     }
   2046   }
   2047 }
   2048 
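         // Concatenates inputs_count arrays along dimension concat_dim into a
         // packed output; all other dimensions must match, and no fused
         // activation is supported (see the DCHECK below).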
   2049 template <FusedActivationFunctionType Ac, typename Scalar>
   2050 void Concatenation(int concat_dim, const Scalar* const* input_data,
   2051                    const Dims<4>* const* input_dims, int inputs_count,
   2052                    Scalar* output_data, const Dims<4>& output_dims) {
   2053   gemmlowp::ScopedProfilingLabel label("Concatenation");
   2054   int concat_size = 0;
   2055   for (int i = 0; i < inputs_count; i++) {
   2056     for (int j = 0; j < 4; j++) {
   2057       if (j != concat_dim) {
   2058         MatchingArraySize(*input_dims[i], j, output_dims, j);
   2059       }
   2060     }
   2061     concat_size += ArraySize(*input_dims[i], concat_dim);
   2062   }
   2063   TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
   2064   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
    2065   // For now we don't have a model with a Concatenation with a fused
    2066   // activation function.
   2067   TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
   2068   int outer_size = 1;
   2069   for (int i = concat_dim + 1; i < 4; i++) {
   2070     outer_size *= output_dims.sizes[i];
   2071   }
   2072   Scalar* output_ptr = output_data;
   2073   for (int k = 0; k < outer_size; k++) {
   2074     for (int i = 0; i < inputs_count; ++i) {
   2075       const int copy_size =
   2076           input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
   2077       memcpy(output_ptr, input_data[i] + k * copy_size,
   2078              copy_size * sizeof(Scalar));
   2079       output_ptr += copy_size;
   2080     }
   2081   }
   2082 }
   2083 
   2084 template <FusedActivationFunctionType Ac, typename Scalar>
   2085 void DepthConcatenation(const Scalar* const* input_data,
   2086                         const Dims<4>* const* input_dims, int inputs_count,
   2087                         Scalar* output_data, const Dims<4>& output_dims) {
   2088   Concatenation<Ac, Scalar>(0, input_data, input_dims, inputs_count,
   2089                             output_data, output_dims);
   2090 }
   2091 
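         // Float LSTM cell: the input and previous activations are concatenated
         // and run through one fully-connected layer producing 4 * output_depth
         // pre-activations, laid out as [input gate, new input, forget gate,
         // output gate]. The new state is
         //   state = sigmoid(input_gate) * tanh(new_input)
         //         + sigmoid(forget_gate) * prev_state
         // and the emitted activation is sigmoid(output_gate) * tanh(state).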
   2092 inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
   2093                      const float* prev_activ_data,
   2094                      const Dims<4>& prev_activ_dims, const float* weights_data,
   2095                      const Dims<4>& weights_dims, const float* bias_data,
   2096                      const Dims<4>& bias_dims, const float* prev_state_data,
   2097                      const Dims<4>& prev_state_dims, float* output_state_data,
   2098                      const Dims<4>& output_state_dims, float* output_activ_data,
   2099                      const Dims<4>& output_activ_dims, float* concat_temp_data,
   2100                      const Dims<4>& concat_temp_dims, float* activ_temp_data,
   2101                      const Dims<4>& activ_temp_dims) {
   2102   gemmlowp::ScopedProfilingLabel label("LstmCell");
   2103   MatchingArraySize(  // batches
   2104       input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, output_state_dims,
   2105       3, output_activ_dims, 3);
   2106   MatchingArraySize(  // height
   2107       input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, output_state_dims,
   2108       2, output_activ_dims, 2);
   2109   MatchingArraySize(  // width
   2110       input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, output_state_dims,
   2111       1, output_activ_dims, 1);
   2112   TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
   2113   TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
   2114   const int input_depth = ArraySize(input_dims, 0);
   2115   const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
   2116   const int total_input_depth = prev_activ_depth + input_depth;
   2117   TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
   2118   TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
   2119                   1);
   2120   const int intern_activ_depth =
   2121       MatchingArraySize(weights_dims, 1, bias_dims, 0);
   2122   TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
   2123   const int output_depth =
   2124       MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
   2125                         output_state_dims, 0, output_activ_dims, 0);
   2126   TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
   2127 
   2128   // Concatenate prev_activ and input data together
   2129   std::vector<float const*> concat_input_arrays_data;
   2130   std::vector<Dims<4> const*> concat_input_arrays_dims;
   2131   concat_input_arrays_data.push_back(input_data);
   2132   concat_input_arrays_data.push_back(prev_activ_data);
   2133   concat_input_arrays_dims.push_back(&input_dims);
   2134   concat_input_arrays_dims.push_back(&prev_activ_dims);
   2135   Concatenation<FusedActivationFunctionType::kNone, float>(
   2136       0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
   2137       concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
   2138 
   2139   // Fully connected
   2140   FullyConnected<FusedActivationFunctionType::kNone>(
   2141       concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
   2142       bias_dims, activ_temp_data, activ_temp_dims);
   2143 
   2144   // Map raw arrays to Eigen arrays so we can use Eigen's optimized array
   2145   // operations.
   2146   ArrayMap<float> activ_temp_map =
   2147       MapAsArrayWithFirstDimAsRows(activ_temp_data, activ_temp_dims);
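           // The fully-connected output stacks the four gate pre-activations along
           // the depth dimension, each output_depth rows deep: input gate, new-input
           // (cell candidate), forget gate, output gate, in that order.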
   2148   auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
   2149                                             activ_temp_map.cols());
   2150   auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
   2151                                            activ_temp_map.cols());
   2152   auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth,
   2153                                              activ_temp_map.cols());
   2154   auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
   2155                                              activ_temp_map.cols());
   2156   ArrayMap<const float> prev_state_map =
   2157       MapAsArrayWithFirstDimAsRows(prev_state_data, prev_state_dims);
   2158   ArrayMap<float> output_state_map =
   2159       MapAsArrayWithFirstDimAsRows(output_state_data, output_state_dims);
   2160   ArrayMap<float> output_activ_map =
   2161       MapAsArrayWithFirstDimAsRows(output_activ_data, output_activ_dims);
   2162 
   2163   // Combined memory state and final output calculation
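           // This implements the standard LSTM update:
           //   output_state = sigmoid(input_gate) * tanh(new_input)
           //                + sigmoid(forget_gate) * prev_state
           //   output_activ = sigmoid(output_gate) * tanh(output_state)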
   2164   gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput");
   2165   output_state_map =
   2166       input_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
   2167           new_input_sm.tanh() +
   2168       forget_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
   2169           prev_state_map;
   2170   output_activ_map =
   2171       output_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op<float>()) *
   2172       output_state_map.tanh();
   2173 }
   2174 
   2175 #ifdef GEMMLOWP_NEON
    2176 // In the common case of batch size 1, a fully-connected node degenerates
    2177 // to a matrix*vector product. LSTM cells contain a fully-connected node;
    2178 // when quantized, this becomes a special type of GEMV operation where
    2179 // the output is 16-bit quantized and thus needs its own special path.
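         // Sketch of what the NEON loops below compute, per output row r:
         //   acc[r]   = sum_j (weights[r][j] - weights_zero_point) * (input[j] - 128)
         //   out16[r] = saturate_int16(rescale(acc[r] + bias[r]))
         // where rescale multiplies by the fixed-point accum_multiplier and applies
         // the accum_shift exponent.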
   2180 inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims,
   2181                             const uint8* weights_data,
   2182                             const Dims<4>& weights_dims,
   2183                             uint8 weights_zero_point, const int32* bias_data,
   2184                             const Dims<4>& bias_dims, int32 accum_multiplier,
   2185                             int accum_shift, int16* output_data,
   2186                             const Dims<4>& output_dims) {
   2187   gemmlowp::ScopedProfilingLabel label("GEMVForLstmCell");
   2188   TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
   2189   TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
   2190   TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
   2191   TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
   2192   TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
   2193                        ArraySize(output_dims, 3),
   2194                    1);
   2195   const int input_size = input_dims.strides[3];
   2196   const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0);
    2197   // This special fast path for quantized LSTM cells does not try to support
    2198   // odd sizes that we haven't encountered in any LSTM cell; those would
    2199   // require special code (which would go untested until some LSTM cell
    2200   // exercised it). We simply guard our assumptions about size evenness with
    2201   // the following assertions.
   2202   TFLITE_DCHECK(!(output_size % 4));
   2203   TFLITE_DCHECK(!(input_size % 8));
   2204   const int32* bias_ptr = bias_data;
   2205   int16* output_ptr = output_data;
   2206   for (int out = 0; out < output_size; out += 4) {
   2207     int32x4_t acc_0 = vdupq_n_s32(0);
   2208     int32x4_t acc_1 = vdupq_n_s32(0);
   2209     int32x4_t acc_2 = vdupq_n_s32(0);
   2210     int32x4_t acc_3 = vdupq_n_s32(0);
   2211     const int16x8_t input_offset_vec = vdupq_n_s16(-128);
   2212     const int16x8_t weights_offset_vec = vdupq_n_s16(-weights_zero_point);
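             // The offsets recenter the uint8 operands: inputs use a fixed zero
             // point of 128 (matching the +128 applied when activations are
             // stored), weights use the given weights_zero_point. After adding
             // these offsets the values are zero-centered int16.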
   2213     int in = 0;
   2214     // Handle 16 levels of depth at a time.
   2215     for (; in <= input_size - 16; in += 16) {
   2216       const uint8x16_t input_val_u8 = vld1q_u8(input_data + in);
   2217       const uint8* weights_ptr = weights_data + in + out * input_size;
   2218       uint8x16_t weights_val_u8_0 = vld1q_u8(weights_ptr + 0 * input_size);
   2219       uint8x16_t weights_val_u8_1 = vld1q_u8(weights_ptr + 1 * input_size);
   2220       uint8x16_t weights_val_u8_2 = vld1q_u8(weights_ptr + 2 * input_size);
   2221       uint8x16_t weights_val_u8_3 = vld1q_u8(weights_ptr + 3 * input_size);
   2222       int16x8_t input_val_0, input_val_1;
   2223       const uint8x8_t low = vget_low_u8(input_val_u8);
   2224       const uint8x8_t high = vget_high_u8(input_val_u8);
   2225       input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low));
   2226       input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high));
   2227       input_val_0 = vaddq_s16(input_val_0, input_offset_vec);
   2228       input_val_1 = vaddq_s16(input_val_1, input_offset_vec);
   2229       int16x8_t weights_val_0_0, weights_val_1_0, weights_val_2_0,
   2230           weights_val_3_0;
   2231       int16x8_t weights_val_0_1, weights_val_1_1, weights_val_2_1,
   2232           weights_val_3_1;
   2233       weights_val_0_0 = vaddq_s16(
   2234           vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_0))),
   2235           weights_offset_vec);
   2236       weights_val_0_1 = vaddq_s16(
   2237           vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_0))),
   2238           weights_offset_vec);
   2239       weights_val_1_0 = vaddq_s16(
   2240           vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_1))),
   2241           weights_offset_vec);
   2242       weights_val_1_1 = vaddq_s16(
   2243           vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_1))),
   2244           weights_offset_vec);
   2245       weights_val_2_0 = vaddq_s16(
   2246           vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_2))),
   2247           weights_offset_vec);
   2248       weights_val_2_1 = vaddq_s16(
   2249           vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_2))),
   2250           weights_offset_vec);
   2251       weights_val_3_0 = vaddq_s16(
   2252           vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_3))),
   2253           weights_offset_vec);
   2254       weights_val_3_1 = vaddq_s16(
   2255           vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_3))),
   2256           weights_offset_vec);
   2257       acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_0),
   2258                         vget_low_s16(input_val_0));
   2259       acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_0),
   2260                         vget_low_s16(input_val_0));
   2261       acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_0),
   2262                         vget_low_s16(input_val_0));
   2263       acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_0),
   2264                         vget_low_s16(input_val_0));
   2265       acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_0),
   2266                         vget_high_s16(input_val_0));
   2267       acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_0),
   2268                         vget_high_s16(input_val_0));
   2269       acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_0),
   2270                         vget_high_s16(input_val_0));
   2271       acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_0),
   2272                         vget_high_s16(input_val_0));
   2273       acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_1),
   2274                         vget_low_s16(input_val_1));
   2275       acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_1),
   2276                         vget_low_s16(input_val_1));
   2277       acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_1),
   2278                         vget_low_s16(input_val_1));
   2279       acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_1),
   2280                         vget_low_s16(input_val_1));
   2281       acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_1),
   2282                         vget_high_s16(input_val_1));
   2283       acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_1),
   2284                         vget_high_s16(input_val_1));
   2285       acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_1),
   2286                         vget_high_s16(input_val_1));
   2287       acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_1),
   2288                         vget_high_s16(input_val_1));
   2289     }
   2290     // Handle 8 levels of depth at a time.
   2291     for (; in < input_size; in += 8) {
   2292       const uint8x8_t input_val_u8 = vld1_u8(input_data + in);
   2293       const uint8* weights_ptr = weights_data + in + out * input_size;
   2294       uint8x8_t weights_val_u8_0 = vld1_u8(weights_ptr + 0 * input_size);
   2295       uint8x8_t weights_val_u8_1 = vld1_u8(weights_ptr + 1 * input_size);
   2296       uint8x8_t weights_val_u8_2 = vld1_u8(weights_ptr + 2 * input_size);
   2297       uint8x8_t weights_val_u8_3 = vld1_u8(weights_ptr + 3 * input_size);
   2298       int16x8_t input_val;
   2299       input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8));
   2300       input_val = vaddq_s16(input_val, input_offset_vec);
   2301       int16x8_t weights_val_0, weights_val_1, weights_val_2, weights_val_3;
   2302       weights_val_0 =
   2303           vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_0)),
   2304                     weights_offset_vec);
   2305       weights_val_1 =
   2306           vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_1)),
   2307                     weights_offset_vec);
   2308       weights_val_2 =
   2309           vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_2)),
   2310                     weights_offset_vec);
   2311       weights_val_3 =
   2312           vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_3)),
   2313                     weights_offset_vec);
   2314       acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0),
   2315                         vget_low_s16(input_val));
   2316       acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1),
   2317                         vget_low_s16(input_val));
   2318       acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2),
   2319                         vget_low_s16(input_val));
   2320       acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3),
   2321                         vget_low_s16(input_val));
   2322       acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0),
   2323                         vget_high_s16(input_val));
   2324       acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1),
   2325                         vget_high_s16(input_val));
   2326       acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2),
   2327                         vget_high_s16(input_val));
   2328       acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3),
   2329                         vget_high_s16(input_val));
   2330     }
   2331     // Horizontally reduce accumulators
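             // Each acc_k holds four partial sums for output row out + k; the
             // pairwise adds collapse each accumulator to a single lane, and the
             // final combine yields {sum_0, sum_1, sum_2, sum_3}.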
   2332     int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
   2333         pairwise_reduced_acc_2, pairwise_reduced_acc_3;
   2334     pairwise_reduced_acc_0 =
   2335         vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0));
   2336     pairwise_reduced_acc_1 =
   2337         vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1));
   2338     pairwise_reduced_acc_2 =
   2339         vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2));
   2340     pairwise_reduced_acc_3 =
   2341         vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3));
   2342     const int32x2_t reduced_lo =
   2343         vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
   2344     const int32x2_t reduced_hi =
   2345         vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
   2346     int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
   2347     // Add bias values.
   2348     int32x4_t bias_vec = vld1q_s32(bias_ptr);
   2349     bias_ptr += 4;
   2350     reduced = vaddq_s32(reduced, bias_vec);
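             // accum_shift is a signed exponent: apply it as a left shift when
             // positive, and as a rounding right shift (after the fixed-point
             // multiply below) when negative.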
   2351     int left_shift = accum_shift > 0 ? accum_shift : 0;
   2352     int right_shift = accum_shift > 0 ? 0 : -accum_shift;
   2353     reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
   2354     // Multiply by the fixed-point multiplier.
   2355     reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
   2356     // Rounding-shift-right.
   2357     using gemmlowp::RoundingDivideByPOT;
   2358     reduced = RoundingDivideByPOT(reduced, right_shift);
    2359     // Narrow values down to 16-bit signed.
   2360     const int16x4_t res16 = vqmovn_s32(reduced);
   2361     vst1_s16(output_ptr, res16);
   2362     output_ptr += 4;
   2363   }
   2364 }
   2365 #endif
   2366 
    2367 // Quantized LSTM cell. Currently just a copy of the reference implementation
    2368 // in reference_ops.h. See the big function comment there; it is not
    2369 // replicated here.
   2370 template <int StateIntegerBits>
   2371 void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
   2372               const uint8* prev_activ_data_uint8,
   2373               const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
   2374               const Dims<4>& weights_dims, const int32* bias_data_int32,
   2375               const Dims<4>& bias_dims, const int16* prev_state_data_int16,
   2376               const Dims<4>& prev_state_dims, int16* output_state_data_int16,
   2377               const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
   2378               const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
   2379               const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
   2380               const Dims<4>& activ_temp_dims, int32 weights_zero_point,
   2381               int32 accum_multiplier, int accum_shift,
   2382               gemmlowp::GemmContext* gemm_context) {
   2383   gemmlowp::ScopedProfilingLabel label(
   2384       "LstmCell/quantized (8bit external, 16bit internal)");
   2385   // Gather dimensions information, and perform consistency checks.
   2386   const int batches =
   2387       MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3,
   2388                         output_state_dims, 3, output_activ_dims, 3);
   2389   const int height =
   2390       MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2,
   2391                         output_state_dims, 2, output_activ_dims, 2);
   2392   const int width =
   2393       MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1,
   2394                         output_state_dims, 1, output_activ_dims, 1);
   2395   TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
   2396   TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
   2397   const int input_depth = ArraySize(input_dims, 0);
   2398   const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
   2399   const int total_input_depth = prev_activ_depth + input_depth;
   2400   TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
   2401   TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
   2402                   1);
   2403   const int intern_activ_depth =
   2404       MatchingArraySize(weights_dims, 1, bias_dims, 0);
   2405   TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
   2406   const int output_depth =
   2407       MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
   2408                         output_state_dims, 0, output_activ_dims, 0);
   2409   TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
   2410   const int fc_batches = ArraySize(activ_temp_dims, 1) *
   2411                          ArraySize(activ_temp_dims, 2) *
   2412                          ArraySize(activ_temp_dims, 3);
   2413   const int fc_output_depth =
   2414       MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
   2415   const int fc_accum_depth = ArraySize(weights_dims, 0);
   2416   TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);
   2417 
   2418   // Depth-concatenate prev_activ and input data together.
   2419   uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
   2420                                               prev_activ_data_uint8};
   2421   Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
   2422   Concatenation<FusedActivationFunctionType::kNone, uint8>(
   2423       0, concat_input_arrays_data, concat_input_arrays_dims, 2,
   2424       concat_temp_data_uint8, concat_temp_dims);
   2425 
    2426   // Implementation of the fully connected node inside the LSTM cell.
    2427   // The operands are 8-bit integers, the accumulators are internally 32-bit
    2428   // integers, and the output is 16-bit fixed-point with 3 integer bits, so
    2429   // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
    2430   // is explained in the big function comment in reference_ops.h.
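           // As a worked example of that format: with 3 integer bits there are
           // 15 - 3 = 12 fractional bits, so a raw int16 value of 4096 represents
           // 4096 / 2^12 = 1.0, and the full raw range [-32768, 32767] covers
           // approximately [-8, 8).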
   2431   bool gemm_already_performed = false;
   2432 #ifdef GEMMLOWP_NEON
   2433   if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) {
   2434     GEMVForLstmCell(concat_temp_data_uint8, concat_temp_dims,
   2435                     weights_data_uint8, weights_dims, weights_zero_point,
   2436                     bias_data_int32, bias_dims, accum_multiplier, accum_shift,
   2437                     activ_temp_data_int16, activ_temp_dims);
   2438     gemm_already_performed = true;
   2439   }
   2440 #endif
   2441   if (!gemm_already_performed) {
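             // General path: run gemmlowp with an output pipeline that adds the
             // bias, rescales the int32 accumulators by accum_multiplier and the
             // accum_shift exponent, and saturating-casts to int16. The trailing
             // -weights_zero_point and -128 arguments are the lhs/rhs offsets that
             // subtract the zero points before multiplication.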
   2442     gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor>
   2443         weights_matrix(weights_data_uint8, fc_output_depth, fc_accum_depth);
   2444     gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
   2445         concat_temp_data_uint8, fc_accum_depth, fc_batches);
   2446     gemmlowp::MatrixMap<int16, gemmlowp::MapOrder::ColMajor> output_matrix(
   2447         activ_temp_data_int16, fc_output_depth, fc_batches);
   2448     typedef gemmlowp::VectorMap<const int32, gemmlowp::VectorShape::Col>
   2449         ColVectorMap;
   2450     ColVectorMap bias_vector(bias_data_int32, fc_output_depth);
   2451     gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
   2452     bias_addition_stage.bias_vector = bias_vector;
   2453     gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage;
   2454     scale_stage.result_offset_after_shift = 0;
   2455     scale_stage.result_fixedpoint_multiplier = accum_multiplier;
   2456     scale_stage.result_exponent = accum_shift;
   2457     gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage;
   2458     auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage,
   2459                                            saturating_cast_int16_stage);
   2460     gemmlowp::GemmWithOutputPipeline<
   2461         uint8, int16, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
   2462         gemm_context, weights_matrix, input_matrix, &output_matrix,
   2463         -weights_zero_point, -128, output_pipeline);
   2464   }
   2465 
   2466   // Rest of the LSTM cell: tanh and logistic math functions, and some adds
   2467   // and muls, all done in 16-bit fixed-point.
   2468   const int outer_size = batches * width * height;
   2469   const int16* input_gate_input_ptr = activ_temp_data_int16;
   2470   const int16* input_modulation_gate_input_ptr =
   2471       activ_temp_data_int16 + output_depth;
   2472   const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth;
   2473   const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth;
   2474   const int16* prev_state_ptr = prev_state_data_int16;
   2475   int16* output_state_data_ptr = output_state_data_int16;
   2476   uint8* output_activ_data_ptr = output_activ_data_uint8;
   2477 
   2478   for (int b = 0; b < outer_size; ++b) {
   2479     int c = 0;
   2480 #ifdef GEMMLOWP_NEON
   2481     for (; c <= output_depth - 8; c += 8) {
   2482       // Define the fixed-point data types that we will use here. All use
   2483       // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
   2484       // They only differ by the number of integral vs. fractional bits,
   2485       // determining the range of values that they can represent.
   2486       //
   2487       // F0 uses 0 integer bits, range [-1, 1].
   2488       // This is the return type of math functions such as tanh, logistic,
   2489       // whose range is in [-1, 1].
   2490       using F0 = gemmlowp::FixedPoint<int16x8_t, 0>;
   2491       // F3 uses 3 integer bits, range [-8, 8].
   2492       // This is the range of the previous fully-connected node's output,
   2493       // which is our input here.
   2494       using F3 = gemmlowp::FixedPoint<int16x8_t, 3>;
   2495       // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
   2496       // 2^StateIntegerBits]. It's used to represent the internal state, whose
   2497       // number of integer bits is currently dictated by the model. See comment
   2498       // on the StateIntegerBits template parameter above.
   2499       using FS = gemmlowp::FixedPoint<int16x8_t, StateIntegerBits>;
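               // These formats are applied to 8 int16 lanes at a time here; the
               // scalar loop below uses the same formats for any remaining
               // channels.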
   2500       // Implementation of input gate, using fixed-point logistic function.
   2501       F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr));
   2502       input_gate_input_ptr += 8;
   2503       F0 input_gate_output = gemmlowp::logistic(input_gate_input);
   2504       // Implementation of input modulation gate, using fixed-point tanh
   2505       // function.
   2506       F3 input_modulation_gate_input =
   2507           F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr));
   2508       input_modulation_gate_input_ptr += 8;
   2509       F0 input_modulation_gate_output =
   2510           gemmlowp::tanh(input_modulation_gate_input);
   2511       // Implementation of forget gate, using fixed-point logistic function.
   2512       F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr));
   2513       forget_gate_input_ptr += 8;
   2514       F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
   2515       // Implementation of output gate, using fixed-point logistic function.
   2516       F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr));
   2517       output_gate_input_ptr += 8;
   2518       F0 output_gate_output = gemmlowp::logistic(output_gate_input);
   2519       // Implementation of internal multiplication nodes, still in fixed-point.
   2520       F0 input_times_input_modulation =
   2521           input_gate_output * input_modulation_gate_output;
   2522       FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr));
   2523       prev_state_ptr += 8;
   2524       FS prev_state_times_forget_state = forget_gate_output * prev_state;
   2525       // Implementation of internal addition node, saturating.
   2526       FS new_state = gemmlowp::SaturatingAdd(
   2527           gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
   2528           prev_state_times_forget_state);
   2529       // Implementation of last internal tanh node, still in fixed-point.
   2530       F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state);
   2531       // Store the new internal state back to memory, as 16-bit integers.
   2532       vst1q_s16(output_state_data_ptr, new_state.raw());
   2533       output_state_data_ptr += 8;
   2534       // Down-scale the output activations to 8-bit integers, saturating,
   2535       // and store back to memory.
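               // The F0 value has 15 fractional bits; dividing the raw value by 2^8
               // leaves 7 fractional bits, so [-1, 1) maps onto [-128, 128). The
               // saturating narrow then clamps to int8, and adding 128 recenters to
               // uint8 with zero point 128.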
   2536       int16x8_t rescaled_output_activ =
   2537           gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
   2538       int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ);
   2539       uint8x8_t uint8_output_activ =
   2540           vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ));
   2541       vst1_u8(output_activ_data_ptr, uint8_output_activ);
   2542       output_activ_data_ptr += 8;
   2543     }
   2544 #endif
   2545     for (; c < output_depth; ++c) {
   2546       // Define the fixed-point data types that we will use here. All use
   2547       // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
   2548       // They only differ by the number of integral vs. fractional bits,
   2549       // determining the range of values that they can represent.
   2550       //
   2551       // F0 uses 0 integer bits, range [-1, 1].
   2552       // This is the return type of math functions such as tanh, logistic,
   2553       // whose range is in [-1, 1].
   2554       using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
   2555       // F3 uses 3 integer bits, range [-8, 8].
   2556       // This is the range of the previous fully-connected node's output,
   2557       // which is our input here.
   2558       using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
   2559       // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
   2560       // 2^StateIntegerBits]. It's used to represent the internal state, whose
   2561       // number of integer bits is currently dictated by the model. See comment
   2562       // on the StateIntegerBits template parameter above.
   2563       using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
   2564       // Implementation of input gate, using fixed-point logistic function.
   2565       F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++);
   2566       F0 input_gate_output = gemmlowp::logistic(input_gate_input);
   2567       // Implementation of input modulation gate, using fixed-point tanh
   2568       // function.
   2569       F3 input_modulation_gate_input =
   2570           F3::FromRaw(*input_modulation_gate_input_ptr++);
   2571       F0 input_modulation_gate_output =
   2572           gemmlowp::tanh(input_modulation_gate_input);
   2573       // Implementation of forget gate, using fixed-point logistic function.
   2574       F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++);
   2575       F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
   2576       // Implementation of output gate, using fixed-point logistic function.
   2577       F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++);
   2578       F0 output_gate_output = gemmlowp::logistic(output_gate_input);
   2579       // Implementation of internal multiplication nodes, still in fixed-point.
   2580       F0 input_times_input_modulation =
   2581           input_gate_output * input_modulation_gate_output;
   2582       FS prev_state = FS::FromRaw(*prev_state_ptr++);
   2583       FS prev_state_times_forget_state = forget_gate_output * prev_state;
   2584       // Implementation of internal addition node, saturating.
   2585       FS new_state = gemmlowp::SaturatingAdd(
   2586           gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
   2587           prev_state_times_forget_state);
   2588       // Implementation of last internal tanh node, still in fixed-point.
   2589       F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state);
   2590       // Store the new internal state back to memory, as 16-bit integers.
   2591       *output_state_data_ptr++ = new_state.raw();
   2592       // Down-scale the output activations to 8-bit integers, saturating,
   2593       // and store back to memory.
   2594       int16 rescaled_output_activ =
   2595           gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
   2596       int16 clamped_output_activ =
   2597           std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
   2598       *output_activ_data_ptr++ = 128 + clamped_output_activ;
   2599     }
   2600     input_gate_input_ptr += 3 * output_depth;
   2601     input_modulation_gate_input_ptr += 3 * output_depth;
   2602     forget_gate_input_ptr += 3 * output_depth;
   2603     output_gate_input_ptr += 3 * output_depth;
   2604   }
   2605 }
   2606 
   2607 template <FusedActivationFunctionType Ac, typename Scalar>
   2608 void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
   2609                      int outputs_count, Scalar* const* output_data,
   2610                      const Dims<4>* const* output_dims) {
   2611   gemmlowp::ScopedProfilingLabel label("TensorFlowSplit");
   2612   TFLITE_DCHECK_GE(outputs_count, 1);
   2613   for (int i = 0; i < outputs_count; i++) {
   2614     /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
   2615     /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
   2616     /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
   2617   }
   2618   const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3);
   2619   const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2);
   2620   const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1);
   2621   TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
    2622   // For now we don't have a model with a TensorFlowSplit
    2623   // with a fused activation function.
   2624   TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
   2625   const int whb = width * height * batches;
   2626   const Scalar* input_ptr = input_data;
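           // For each of the width*height*batches outer positions, the input's depth
           // slice is consumed in order, handing each output i its contiguous chunk
           // of output_dims[i]->sizes[0] scalars.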
   2627   for (int k = 0; k < whb; k++) {
   2628     for (int i = 0; i < outputs_count; ++i) {
   2629       memcpy(output_data[i] + k * output_dims[i]->sizes[0], input_ptr,
   2630              output_dims[i]->sizes[0] * sizeof(Scalar));
   2631       input_ptr += output_dims[i]->sizes[0];
   2632     }
   2633   }
   2634 }
   2635 
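         // Returns the column index of the (b, h, w) spatial position in the matrix
         // views created by MapAsMatrixWithFirstDimAsRows, where each column is one
         // depth vector.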
   2636 inline int NodeOffset(int b, int h, int w, int height, int width) {
   2637   return (b * height + h) * width + w;
   2638 }
   2639 
   2640 inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
   2641                         int stride_width, int stride_height, int pad_width,
   2642                         int pad_height, int kwidth, int kheight,
   2643                         float output_activation_min,
   2644                         float output_activation_max, float* output_data,
   2645                         const Dims<4>& output_dims) {
   2646   gemmlowp::ScopedProfilingLabel label("AveragePool");
   2647   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   2648   const int input_height = ArraySize(input_dims, 2);
   2649   const int input_width = ArraySize(input_dims, 1);
   2650   const int output_height = ArraySize(output_dims, 2);
   2651   const int output_width = ArraySize(output_dims, 1);
   2652   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
   2653 
   2654   // TODO(benoitjacob) make this a proper reference impl without Eigen!
   2655   const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
   2656   auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
   2657   // TODO(benoitjacob) get rid of the dynamic memory allocation here!
   2658   Eigen::VectorXf out_count(out_mat.cols());
   2659   out_count.setZero();
   2660   // Prefill the output to 0.
   2661   out_mat.setZero();
   2662   for (int b = 0; b < batches; ++b) {
   2663     for (int h = 0; h < input_height; ++h) {
   2664       for (int w = 0; w < input_width; ++w) {
   2665         // (h_start, h_end) * (w_start, w_end) is the range that the input
   2666         // vector projects to.
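                 // An input row h contributes to output rows ph with
                 //   ph * stride_height - pad_height <= h
                 //     < ph * stride_height - pad_height + kheight,
                 // which is the half-open range [h_start, h_end) computed below
                 // (and similarly for columns).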
   2667         int hpad = h + pad_height;
   2668         int wpad = w + pad_width;
   2669         int h_start =
   2670             (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1;
   2671         int h_end = std::min(hpad / stride_height + 1, output_height);
   2672         int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1;
   2673         int w_end = std::min(wpad / stride_width + 1, output_width);
   2674         // compute elementwise sum
   2675         for (int ph = h_start; ph < h_end; ++ph) {
   2676           for (int pw = w_start; pw < w_end; ++pw) {
   2677             int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
   2678             out_mat.col(out_offset) +=
   2679                 in_mat.col(NodeOffset(b, h, w, input_height, input_width));
   2680             out_count(out_offset)++;
   2681           }
   2682         }
   2683       }
   2684     }
   2685   }
   2686   // Divide the output by the actual number of elements being averaged over
   2687   TFLITE_DCHECK_GT(out_count.minCoeff(), 0);
   2688   out_mat.array().rowwise() /= out_count.transpose().array();
   2689 
   2690   for (int b = 0; b < batches; ++b) {
   2691     for (int y = 0; y < output_height; ++y) {
   2692       for (int x = 0; x < output_width; ++x) {
   2693         for (int c = 0; c < depth; ++c) {
   2694           output_data[Offset(output_dims, c, x, y, b)] =
   2695               ActivationFunctionWithMinMax(
   2696                   output_data[Offset(output_dims, c, x, y, b)],
   2697                   output_activation_min, output_activation_max);
   2698         }
   2699       }
   2700     }
   2701   }
   2702 }
   2703 
   2704 // legacy, for compatibility with old checked-in code
   2705 template <FusedActivationFunctionType Ac>
   2706 void AveragePool(const float* input_data, const Dims<4>& input_dims,
   2707                  int stride_width, int stride_height, int pad_width,
   2708                  int pad_height, int kwidth, int kheight, float* output_data,
   2709                  const Dims<4>& output_dims) {
   2710   float output_activation_min, output_activation_max;
   2711   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
   2712 
   2713   AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
   2714               pad_height, kwidth, kheight, output_activation_min,
   2715               output_activation_max, output_data, output_dims);
   2716 }
   2717 
   2718 // legacy, for compatibility with old checked-in code
   2719 template <FusedActivationFunctionType Ac>
   2720 void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
   2721                  int pad_width, int pad_height, int filter_width,
   2722                  int filter_height, float* output_data,
   2723                  const Dims<4>& output_dims) {
   2724   AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
   2725                   filter_width, filter_height, output_data, output_dims);
   2726 }
   2727 
   2728 inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
   2729                         int stride_width, int stride_height, int pad_width,
   2730                         int pad_height, int filter_width, int filter_height,
   2731                         int32 output_activation_min,
   2732                         int32 output_activation_max, uint8* output_data,
   2733                         const Dims<4>& output_dims) {
   2734   gemmlowp::ScopedProfilingLabel label("AveragePool/8bit");
   2735   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
   2736   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   2737   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
   2738   const int input_height = ArraySize(input_dims, 2);
   2739   const int input_width = ArraySize(input_dims, 1);
   2740   const int output_height = ArraySize(output_dims, 2);
   2741   const int output_width = ArraySize(output_dims, 1);
   2742   for (int batch = 0; batch < batches; ++batch) {
   2743     for (int out_y = 0; out_y < output_height; ++out_y) {
   2744       for (int out_x = 0; out_x < output_width; ++out_x) {
   2745         const int in_x_origin = (out_x * stride_width) - pad_width;
   2746         const int in_y_origin = (out_y * stride_height) - pad_height;
   2747         const int filter_x_start = std::max(0, -in_x_origin);
   2748         const int filter_x_end =
   2749             std::min(filter_width, input_width - in_x_origin);
   2750         const int filter_y_start = std::max(0, -in_y_origin);
   2751         const int filter_y_end =
   2752             std::min(filter_height, input_height - in_y_origin);
   2753         const int filter_count =
   2754             (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
   2755         // 1280 required by Inception v3
   2756         static constexpr int kAccBufferMaxSize = 2048;
   2757         TFLITE_DCHECK_LE(depth, kAccBufferMaxSize);
   2758         uint16 acc[kAccBufferMaxSize];
   2759         memset(acc, 0, depth * sizeof(acc[0]));
   2760         const uint8* input_ptr =
   2761             input_data + input_dims.strides[1] * in_x_origin +
   2762             input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch;
   2763         for (int fy = filter_y_start; fy < filter_y_end; fy++) {
   2764           const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] +
   2765                                        filter_x_start * input_dims.strides[1];
   2766           for (int fx = filter_x_start; fx < filter_x_end; fx++) {
   2767             int channel = 0;
   2768 #ifdef USE_NEON
   2769             for (; channel <= depth - 16; channel += 16) {
   2770               uint16x8_t acc_reg[2];
   2771               for (int i = 0; i < 2; i++) {
   2772                 acc_reg[i] = vld1q_u16(acc + channel + 8 * i);
   2773               }
   2774               uint8x16_t input_reg = vld1q_u8(input_row_ptr);
   2775               input_row_ptr += 16;
   2776               acc_reg[0] = vaddw_u8(acc_reg[0], vget_low_u8(input_reg));
   2777               acc_reg[1] = vaddw_u8(acc_reg[1], vget_high_u8(input_reg));
   2778               for (int i = 0; i < 2; i++) {
   2779                 vst1q_u16(acc + channel + 8 * i, acc_reg[i]);
   2780               }
   2781             }
   2782             for (; channel <= depth - 8; channel += 8) {
   2783               uint16x8_t acc_reg = vld1q_u16(acc + channel);
   2784               uint8x8_t input_reg = vld1_u8(input_row_ptr);
   2785               input_row_ptr += 8;
   2786               acc_reg = vaddw_u8(acc_reg, input_reg);
   2787               vst1q_u16(acc + channel, acc_reg);
   2788             }
   2789 #endif
   2790             for (; channel < depth; ++channel) {
   2791               acc[channel] += *input_row_ptr++;
   2792             }
   2793           }
   2794         }
   2795         uint8* output_ptr =
   2796             output_data + Offset(output_dims, 0, out_x, out_y, batch);
   2797         int channel = 0;
   2798 #ifdef USE_NEON
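         // AVGPOOL_DIVIDING_BY specializes the rounding average for filter_count
         // values seen in common models (9 for 3x3 windows, 15 for 3x5), presumably
         // so the division by a compile-time constant can be optimized; the generic
         // NEON loop below handles any other filter_count.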
   2799 #define AVGPOOL_DIVIDING_BY(FILTER_COUNT)                              \
   2800   if (filter_count == FILTER_COUNT) {                                  \
   2801     for (; channel <= depth - 8; channel += 8) {                       \
   2802       uint16 buf[8];                                                   \
   2803       for (int i = 0; i < 8; i++) {                                    \
   2804         buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \
   2805       }                                                                \
   2806       uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));                     \
   2807       buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max));          \
   2808       buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min));          \
   2809       vst1_u8(output_ptr + channel, buf8);                             \
   2810     }                                                                  \
   2811   }
   2812         AVGPOOL_DIVIDING_BY(9)
   2813         AVGPOOL_DIVIDING_BY(15)
   2814 #undef AVGPOOL_DIVIDING_BY
   2815         for (; channel <= depth - 8; channel += 8) {
   2816           uint16 buf[8];
   2817           for (int i = 0; i < 8; i++) {
   2818             buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
   2819           }
   2820           uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
   2821           buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max));
   2822           buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min));
   2823           vst1_u8(output_ptr + channel, buf8);
   2824         }
   2825 #endif
   2826         for (; channel < depth; ++channel) {
   2827           uint16 a = (acc[channel] + filter_count / 2) / filter_count;
   2828           a = std::max<uint16>(a, output_activation_min);
   2829           a = std::min<uint16>(a, output_activation_max);
   2830           output_ptr[channel] = static_cast<uint8>(a);
   2831         }
   2832       }
   2833     }
   2834   }
   2835 }
   2836 
   2837 // legacy, for compatibility with old checked-in code
   2838 template <FusedActivationFunctionType Ac>
   2839 void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
   2840                  int stride_width, int stride_height, int pad_width,
   2841                  int pad_height, int filter_width, int filter_height,
   2842                  int32 output_activation_min, int32 output_activation_max,
   2843                  uint8* output_data, const Dims<4>& output_dims) {
   2844   static_assert(Ac == FusedActivationFunctionType::kNone ||
   2845                     Ac == FusedActivationFunctionType::kRelu ||
   2846                     Ac == FusedActivationFunctionType::kRelu6 ||
   2847                     Ac == FusedActivationFunctionType::kRelu1,
   2848                 "");
   2849   if (Ac == FusedActivationFunctionType::kNone) {
   2850     TFLITE_DCHECK_EQ(output_activation_min, 0);
   2851     TFLITE_DCHECK_EQ(output_activation_max, 255);
   2852   }
   2853   AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
   2854               pad_height, filter_width, filter_height, output_activation_min,
   2855               output_activation_max, output_data, output_dims);
   2856 }
   2857 
   2858 // legacy, for compatibility with old checked-in code
   2859 template <FusedActivationFunctionType Ac>
   2860 void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
   2861                  int pad_width, int pad_height, int filter_width,
   2862                  int filter_height, int32 output_activation_min,
   2863                  int32 output_activation_max, uint8* output_data,
   2864                  const Dims<4>& output_dims) {
   2865   AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
   2866                   filter_width, filter_height, output_activation_min,
   2867                   output_activation_max, output_data, output_dims);
   2868 }
   2869 
   2870 inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
   2871                     int stride_width, int stride_height, int pad_width,
   2872                     int pad_height, int kwidth, int kheight,
   2873                     float output_activation_min, float output_activation_max,
   2874                     float* output_data, const Dims<4>& output_dims) {
   2875   gemmlowp::ScopedProfilingLabel label("MaxPool");
   2876   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   2877   const int input_height = ArraySize(input_dims, 2);
   2878   const int input_width = ArraySize(input_dims, 1);
   2879   const int output_height = ArraySize(output_dims, 2);
   2880   const int output_width = ArraySize(output_dims, 1);
   2881   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
   2882 
   2883   const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
   2884   auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
    2885   // Prefill the output to the minimum representable float value.
   2886   out_mat.setConstant(std::numeric_limits<float>::lowest());
   2887   for (int b = 0; b < batches; ++b) {
   2888     for (int h = 0; h < input_height; ++h) {
   2889       for (int w = 0; w < input_width; ++w) {
   2890         // (h_start, h_end) * (w_start, w_end) is the range that the input
   2891         // vector projects to.
   2892         int hpad = h + pad_height;
   2893         int wpad = w + pad_width;
   2894         int h_start =
   2895             (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1;
   2896         int h_end = std::min(hpad / stride_height + 1, output_height);
   2897         int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1;
   2898         int w_end = std::min(wpad / stride_width + 1, output_width);
    2899         // compute elementwise max
   2900         for (int ph = h_start; ph < h_end; ++ph) {
   2901           for (int pw = w_start; pw < w_end; ++pw) {
   2902             int out_offset = NodeOffset(b, ph, pw, output_height, output_width);
   2903             out_mat.col(out_offset) =
   2904                 out_mat.col(out_offset)
   2905                     .cwiseMax(in_mat.col(
   2906                         NodeOffset(b, h, w, input_height, input_width)));
   2907           }
   2908         }
   2909       }
   2910     }
   2911   }
   2912 
   2913   for (int b = 0; b < batches; ++b) {
   2914     for (int y = 0; y < output_height; ++y) {
   2915       for (int x = 0; x < output_width; ++x) {
   2916         for (int c = 0; c < depth; ++c) {
   2917           output_data[Offset(output_dims, c, x, y, b)] =
   2918               ActivationFunctionWithMinMax(
   2919                   output_data[Offset(output_dims, c, x, y, b)],
   2920                   output_activation_min, output_activation_max);
   2921         }
   2922       }
   2923     }
   2924   }
   2925 }
   2926 
   2927 // legacy, for compatibility with old checked-in code
   2928 template <FusedActivationFunctionType Ac>
   2929 void MaxPool(const float* input_data, const Dims<4>& input_dims,
   2930              int stride_width, int stride_height, int pad_width, int pad_height,
   2931              int kwidth, int kheight, float* output_data,
   2932              const Dims<4>& output_dims) {
   2933   float output_activation_min, output_activation_max;
   2934   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
   2935   MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
   2936           pad_height, kwidth, kheight, output_activation_min,
   2937           output_activation_max, output_data, output_dims);
   2938 }
   2939 
   2940 // legacy, for compatibility with old checked-in code
   2941 template <FusedActivationFunctionType Ac>
   2942 void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
   2943              int pad_width, int pad_height, int filter_width, int filter_height,
   2944              float* output_data, const Dims<4>& output_dims) {
   2945   MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
   2946               filter_width, filter_height, output_data, output_dims);
   2947 }
   2948 
   2949 inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
   2950                     int stride_width, int stride_height, int pad_width,
   2951                     int pad_height, int filter_width, int filter_height,
   2952                     int32 output_activation_min, int32 output_activation_max,
   2953                     uint8* output_data, const Dims<4>& output_dims) {
   2954   gemmlowp::ScopedProfilingLabel label("MaxPool/8bit");
   2955   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
   2956   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   2957   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
   2958   const int input_height = ArraySize(input_dims, 2);
   2959   const int input_width = ArraySize(input_dims, 1);
   2960   const int output_height = ArraySize(output_dims, 2);
   2961   const int output_width = ArraySize(output_dims, 1);
   2962   for (int batch = 0; batch < batches; ++batch) {
   2963     for (int out_y = 0; out_y < output_height; ++out_y) {
   2964       for (int out_x = 0; out_x < output_width; ++out_x) {
   2965         const int in_x_origin = (out_x * stride_width) - pad_width;
   2966         const int in_y_origin = (out_y * stride_height) - pad_height;
   2967         const int filter_x_start = std::max(0, -in_x_origin);
   2968         const int filter_x_end =
   2969             std::min(filter_width, input_width - in_x_origin);
   2970         const int filter_y_start = std::max(0, -in_y_origin);
   2971         const int filter_y_end =
   2972             std::min(filter_height, input_height - in_y_origin);
   2973         // 2048 required by Inception v3
   2974         static constexpr int kAccBufferMaxSize = 2048;
   2975         TFLITE_DCHECK_LE(depth, kAccBufferMaxSize);
   2976         uint8 acc[kAccBufferMaxSize];
   2977         memset(acc, 0, depth * sizeof(acc[0]));
   2978         const uint8* input_ptr =
   2979             input_data + input_dims.strides[1] * in_x_origin +
   2980             input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch;
   2981         for (int fy = filter_y_start; fy < filter_y_end; fy++) {
   2982           const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] +
   2983                                        filter_x_start * input_dims.strides[1];
   2984           for (int fx = filter_x_start; fx < filter_x_end; fx++) {
   2985             int channel = 0;
   2986 #ifdef USE_NEON
   2987             for (; channel <= depth - 16; channel += 16) {
   2988               uint8x16_t acc_reg = vld1q_u8(acc + channel);
   2989               uint8x16_t input_reg = vld1q_u8(input_row_ptr);
   2990               input_row_ptr += 16;
   2991               acc_reg = vmaxq_u8(acc_reg, input_reg);
   2992               vst1q_u8(acc + channel, acc_reg);
   2993             }
   2994 
   2995             for (; channel <= depth - 8; channel += 8) {
   2996               uint8x8_t acc_reg = vld1_u8(acc + channel);
   2997               uint8x8_t input_reg = vld1_u8(input_row_ptr);
   2998               input_row_ptr += 8;
   2999               acc_reg = vmax_u8(acc_reg, input_reg);
   3000               vst1_u8(acc + channel, acc_reg);
   3001             }
   3002 #endif
   3003             for (; channel < depth; ++channel) {
   3004               acc[channel] = std::max(acc[channel], *input_row_ptr++);
   3005             }
   3006           }
   3007         }
   3008         uint8* output_ptr =
   3009             output_data + Offset(output_dims, 0, out_x, out_y, batch);
   3010         int channel = 0;
   3011 #ifdef USE_NEON
   3012         for (; channel <= depth - 16; channel += 16) {
   3013           uint8x16_t a = vld1q_u8(acc + channel);
   3014           a = vminq_u8(a, vdupq_n_u8(output_activation_max));
   3015           a = vmaxq_u8(a, vdupq_n_u8(output_activation_min));
   3016           vst1q_u8(output_ptr + channel, a);
   3017         }
   3018         for (; channel <= depth - 8; channel += 8) {
   3019           uint8x8_t a = vld1_u8(acc + channel);
   3020           a = vmin_u8(a, vdup_n_u8(output_activation_max));
   3021           a = vmax_u8(a, vdup_n_u8(output_activation_min));
   3022           vst1_u8(output_ptr + channel, a);
   3023         }
   3024 #endif
   3025         for (; channel < depth; ++channel) {
   3026           uint8 a = acc[channel];
   3027           a = std::max<uint8>(a, output_activation_min);
   3028           a = std::min<uint8>(a, output_activation_max);
   3029           output_ptr[channel] = static_cast<uint8>(a);
   3030         }
   3031       }
   3032     }
   3033   }
   3034 }
   3035 
   3036 // legacy, for compatibility with old checked-in code
   3037 template <FusedActivationFunctionType Ac>
   3038 void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
   3039              int stride_width, int stride_height, int pad_width, int pad_height,
   3040              int filter_width, int filter_height, int32 output_activation_min,
   3041              int32 output_activation_max, uint8* output_data,
   3042              const Dims<4>& output_dims) {
   3043   static_assert(Ac == FusedActivationFunctionType::kNone ||
   3044                     Ac == FusedActivationFunctionType::kRelu ||
   3045                     Ac == FusedActivationFunctionType::kRelu6 ||
   3046                     Ac == FusedActivationFunctionType::kRelu1,
   3047                 "");
   3048   if (Ac == FusedActivationFunctionType::kNone) {
   3049     TFLITE_DCHECK_EQ(output_activation_min, 0);
   3050     TFLITE_DCHECK_EQ(output_activation_max, 255);
   3051   }
   3052   MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
   3053           pad_height, filter_width, filter_height, output_activation_min,
   3054           output_activation_max, output_data, output_dims);
   3055 }
   3056 
   3057 // legacy, for compatibility with old checked-in code
   3058 template <FusedActivationFunctionType Ac>
   3059 void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
   3060              int pad_width, int pad_height, int filter_width, int filter_height,
   3061              int32 output_activation_min, int32 output_activation_max,
   3062              uint8* output_data, const Dims<4>& output_dims) {
   3063   MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
   3064               filter_width, filter_height, output_activation_min,
   3065               output_activation_max, output_data, output_dims);
   3066 }
   3067 
   3068 inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
   3069                    int stride_width, int stride_height, int pad_width,
   3070                    int pad_height, int filter_width, int filter_height,
   3071                    float output_activation_min, float output_activation_max,
   3072                    float* output_data, const Dims<4>& output_dims) {
   3073   gemmlowp::ScopedProfilingLabel label("L2Pool");
   3074   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   3075   const int input_height = ArraySize(input_dims, 2);
   3076   const int input_width = ArraySize(input_dims, 1);
   3077   const int output_height = ArraySize(output_dims, 2);
   3078   const int output_width = ArraySize(output_dims, 1);
    3079   // Actually carry out L2 Pool. Code is written in forward mode: we go through
    3080   // the input values once, and write to all the pooled regions that each maps to.
   3081   const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
   3082   auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
   3083   Eigen::VectorXf in_square(in_mat.rows());
   3084   Eigen::VectorXf out_count(out_mat.cols());
   3085   out_count.setZero();
   3086   // Prefill the output to 0.
   3087   out_mat.setZero();
   3088   for (int b = 0; b < batches; ++b) {
   3089     for (int h = 0; h < input_height; ++h) {
   3090       for (int w = 0; w < input_width; ++w) {
   3091         // [h_start, h_end) x [w_start, w_end) is the range of output
   3092         // positions that this input vector contributes to.
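                // For example, with filter_height = 2, stride_height = 2,
                // pad_height = 0 and output_height = 2, input row h = 3 gives
                // hpad = 3, h_start = (3 - 2) / 2 + 1 = 1 and
                // h_end = min(3 / 2 + 1, 2) = 2, so that row contributes only to
                // output row 1.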
   3093         const int hpad = h + pad_height;
   3094         const int wpad = w + pad_width;
   3095         const int h_start = (hpad < filter_height)
   3096                                 ? 0
   3097                                 : (hpad - filter_height) / stride_height + 1;
   3098         const int h_end = std::min(hpad / stride_height + 1, output_height);
   3099         const int w_start = (wpad < filter_width)
   3100                                 ? 0
   3101                                 : (wpad - filter_width) / stride_width + 1;
   3102         const int w_end = std::min(wpad / stride_width + 1, output_width);
   3103         // Pre-compute the squares of this input vector.
   3104         const int in_offset = w + input_width * (h + input_height * b);
   3105         in_square =
   3106             in_mat.col(in_offset).array() * in_mat.col(in_offset).array();
   3107         // Accumulate the squares into each pooled region this input maps to.
   3108         for (int ph = h_start; ph < h_end; ++ph) {
   3109           for (int pw = w_start; pw < w_end; ++pw) {
   3110             const int out_offset = pw + output_width * (ph + output_height * b);
   3111             out_mat.col(out_offset) += in_square;
   3112             out_count(out_offset)++;
   3113           }
   3114         }
   3115       }
   3116     }
   3117   }
   3118 
   3119   out_count = out_count.array().inverse();
   3120   out_mat =
   3121       (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
          // Apply the fused activation bounds.
          out_mat = out_mat.cwiseMax(output_activation_min)
                        .cwiseMin(output_activation_max);
   3122 }
   3123 
   3124 // legacy, for compatibility with old checked-in code
   3125 template <FusedActivationFunctionType Ac>
   3126 void L2Pool(const float* input_data, const Dims<4>& input_dims,
   3127             int stride_width, int stride_height, int pad_width, int pad_height,
   3128             int filter_width, int filter_height, float* output_data,
   3129             const Dims<4>& output_dims) {
   3130   float output_activation_min, output_activation_max;
   3131   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
   3132   L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
   3133          pad_height, filter_width, filter_height, output_activation_min,
   3134          output_activation_max, output_data, output_dims);
   3135 }
   3136 
   3137 // legacy, for compatibility with old checked-in code
   3138 template <FusedActivationFunctionType Ac>
   3139 void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
   3140             int pad_width, int pad_height, int filter_width, int filter_height,
   3141             float* output_data, const Dims<4>& output_dims) {
   3142   L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
   3143              filter_width, filter_height, output_data, output_dims);
   3144 }
   3145 
   3146 inline void LocalResponseNormalization(const float* input_data,
   3147                                        const Dims<4>& input_dims, int range,
   3148                                        float bias, float alpha, float beta,
   3149                                        float* output_data,
   3150                                        const Dims<4>& output_dims) {
   3151   gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization");
   3152   /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3);
   3153   /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2);
   3154   /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1);
   3155   /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0);
   3156 
   3157   const auto data_in = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
   3158   auto data_out = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
   3159 
   3160   // Carry out local response normalization, vector by vector.
   3161   // Since the data are stored column-major, a row-wise operation would
   3162   // probably not be memory-efficient anyway, so we do an explicit for loop
   3163   // over the columns.
   3164   const int double_range = range * 2;
   3165   Eigen::VectorXf padded_square(data_in.rows() + double_range);
   3166   padded_square.setZero();
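          // padded_square holds `range` zeros, then alpha * x_i^2 for each entry of
          // the current column, then `range` more zeros. The sliding-window loop
          // below therefore computes data_out(i, r) = bias + alpha * sum of x_j^2
          // over j in [i - range, i + range], with out-of-range terms contributing
          // zero.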
   3167   for (int r = 0; r < data_in.cols(); ++r) {
   3168     // Do local response normalization for data_in(:, r)
   3169     // First, compute the squares and store them in a buffer for repeated use.
   3170     padded_square.block(range, 0, data_in.rows(), 1) =
   3171         data_in.col(r).cwiseProduct(data_in.col(r)) * alpha;
   3172     // Then compute the scale and write it to data_out.
   3173     float accumulated_scale = 0;
   3174     for (int i = 0; i < double_range; ++i) {
   3175       accumulated_scale += padded_square(i);
   3176     }
   3177     for (int i = 0; i < data_in.rows(); ++i) {
   3178       accumulated_scale += padded_square(i + double_range);
   3179       data_out(i, r) = bias + accumulated_scale;
   3180       accumulated_scale -= padded_square(i);
   3181     }
   3182   }
   3183 
   3184   // In a few cases, the pow computation could benefit from speedups.
   3185   if (beta == 1) {
   3186     data_out.array() = data_in.array() * data_out.array().inverse();
   3187   } else if (beta == 0.5) {
   3188     data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
   3189   } else {
   3190     data_out.array() = data_in.array() * data_out.array().pow(-beta);
   3191   }
   3192 }
   3193 
   3194 inline void Softmax(const float* input_data, const Dims<4>& input_dims,
   3195                     float beta, float* output_data,
   3196                     const Dims<4>& output_dims) {
   3197   gemmlowp::ScopedProfilingLabel label("Softmax");
   3198   /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3);
   3199   /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2);
   3200   /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1);
   3201   /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0);
   3202 
   3203   const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
   3204   auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
   3205   // Compute the exponential first, removing the max coefficient for numerical
   3206   // stability.
   3207   out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta;
   3208   // We are separating out the exp function so that exp can be vectorized.
   3209   out_mat = out_mat.array().exp();
   3210   // Normalize to get the activations.
   3211   Eigen::Array<float, 1, Eigen::Dynamic> scale =
   3212       out_mat.array().colwise().sum().inverse();
   3213   out_mat.array().rowwise() *= scale;
   3214 }
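        // Illustrative usage sketch (not part of the library): assuming a packed
        // Dims<4> as defined in types.h (innermost dimension first, with strides
        // built up from the sizes), the float Softmax over a single 8-channel
        // vector could be invoked as:
        //
        //   Dims<4> dims;
        //   dims.sizes[0] = 8;
        //   dims.sizes[1] = dims.sizes[2] = dims.sizes[3] = 1;
        //   dims.strides[0] = 1;
        //   for (int i = 1; i < 4; ++i) {
        //     dims.strides[i] = dims.strides[i - 1] * dims.sizes[i - 1];
        //   }
        //   Softmax(input, dims, /*beta=*/1.0f, output, dims);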
   3215 
   3216 inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
   3217                     int32 input_beta_multiplier, int32 input_beta_left_shift,
   3218                     int diff_min, uint8* output_data,
   3219                     const Dims<4>& output_dims) {
   3220   // The representation chosen for the input to the exp() function is Q5.26.
   3221   // We need to leave extra space since values that we skip might be as large as
   3222   // -32 before multiplying by input_beta_multiplier, and therefore as large as
   3223   // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
   3224   // accumulation, but exp(-16) definitely is.
   3225   static const int kScaledDiffIntegerBits = 5;
   3226   static const int kAccumulationIntegerBits = 12;
   3227   using FixedPointScaledDiff =
   3228       gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
   3229   using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
   3230   using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
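          // In gemmlowp::FixedPoint<int32, N>, a raw int32 value r represents
          // r / 2^(31 - N). For example, FixedPointScaledDiff (Q5.26) stores a
          // scaled diff of -16.0 as -16 * 2^26, and FixedPointAccum (Q12.19) can
          // hold sums of exponentials up to about 2^12 = 4096.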
   3231 
   3232   gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
   3233   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   3234   const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
   3235   const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
   3236   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
   3237 
   3238   const int outer_size = batches * height * width;
   3239 
   3240   for (int b = 0; b < outer_size; ++b) {
   3241     const uint8* input_data_ptr = input_data + b * depth;
   3242     uint8* output_data_ptr = output_data + b * depth;
   3243 
   3244     // Determine the largest entry in the current row
   3245     uint8 max_in_row = 0;
   3246     {
   3247       int c = 0;
   3248 #ifdef USE_NEON
   3249       uint8x16_t max16_0 = vdupq_n_u8(0);
   3250       uint8x16_t max16_1 = vdupq_n_u8(0);
   3251       for (; c <= depth - 32; c += 32) {
   3252         max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0));
   3253         max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16));
   3254       }
   3255       uint8x16_t max16 = vmaxq_u8(max16_0, max16_1);
   3256       if (c <= depth - 16) {
   3257         max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c));
   3258         c += 16;
   3259       }
   3260       uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16));
   3261       if (c <= depth - 8) {
   3262         max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c));
   3263         c += 8;
   3264       }
   3265       uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4));
   3266       uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2));
   3267       uint8x8_t max1 = vpmax_u8(max2, max2);
   3268       max_in_row = vget_lane_u8(max1, 0);
   3269 #endif
   3270       for (; c < depth; ++c) {
   3271         max_in_row = std::max(max_in_row, input_data_ptr[c]);
   3272       }
   3273     }
   3274 
   3275 #ifdef USE_NEON
   3276     using FixedPointAccumInt32x4 =
   3277         gemmlowp::FixedPoint<int32x4_t, kAccumulationIntegerBits>;
   3278     using FixedPointScaledDiffInt32x4 =
   3279         gemmlowp::FixedPoint<int32x4_t, kScaledDiffIntegerBits>;
   3280     using FixedPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
   3281     FixedPoint0Int32x4 input_beta_multiplier_f0 =
   3282         FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier);
   3283     int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row);
   3284 #endif
   3285 
   3286     // Compute the sum of exponentials of the differences of entries in the
   3287     // current row from the largest entry in the current row.
   3288     FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
   3289     {
   3290       int c = 0;
   3291 #ifdef USE_NEON
   3292       int32x4_t diff_min_s32 = vdupq_n_s32(diff_min);
   3293       FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero();
   3294       FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero();
   3295       FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero();
   3296       for (; c <= depth - 8; c += 8) {
   3297         uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
   3298         int16x8_t input_diff_s16 =
   3299             vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
   3300         int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
   3301         int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
   3302         int32x4_t mask_0 =
   3303             gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32);
   3304         int32x4_t mask_1 =
   3305             gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32);
   3306         FixedPointScaledDiffInt32x4 scaled_diff_0 =
   3307             input_beta_multiplier_f0 *
   3308             FixedPointScaledDiffInt32x4::FromRaw(
   3309                 gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
   3310         FixedPointScaledDiffInt32x4 scaled_diff_1 =
   3311             input_beta_multiplier_f0 *
   3312             FixedPointScaledDiffInt32x4::FromRaw(
   3313                 gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
   3314         FixedPointAccumInt32x4 exps_0 =
   3315             gemmlowp::Rescale<kAccumulationIntegerBits>(
   3316                 exp_on_negative_values(scaled_diff_0));
   3317         FixedPointAccumInt32x4 exps_1 =
   3318             gemmlowp::Rescale<kAccumulationIntegerBits>(
   3319                 exp_on_negative_values(scaled_diff_1));
   3320         FixedPointAccumInt32x4 masked_exps_0 =
   3321             SelectUsingMask(mask_0, exps_0, zeros);
   3322         FixedPointAccumInt32x4 masked_exps_1 =
   3323             SelectUsingMask(mask_1, exps_1, zeros);
   3324         sum_of_exps_0 = sum_of_exps_0 + masked_exps_0;
   3325         sum_of_exps_1 = sum_of_exps_1 + masked_exps_1;
   3326       }
   3327       int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw();
   3328       int32x2_t sum_of_exps_reduced_2 =
   3329           vadd_s32(vget_low_s32(sum_of_exps_reduced_4),
   3330                    vget_high_s32(sum_of_exps_reduced_4));
   3331       int32x2_t sum_of_exps_reduced_1 =
   3332           vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2);
   3333       sum_of_exps =
   3334           FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0));
   3335 #endif
   3336       for (; c < depth; ++c) {
   3337         int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
   3338         if (input_diff >= diff_min) {
   3339           const int32 input_diff_rescaled =
   3340               MultiplyByQuantizedMultiplierGreaterThanOne(
   3341                   input_diff, input_beta_multiplier, input_beta_left_shift);
   3342           const FixedPointScaledDiff scaled_diff_f8 =
   3343               FixedPointScaledDiff::FromRaw(input_diff_rescaled);
   3344           sum_of_exps =
   3345               sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
   3346                                 exp_on_negative_values(scaled_diff_f8));
   3347         }
   3348       }
   3349     }
   3350 
   3351     // Compute the fixed-point multiplier and shift that we need to apply to
   3352     // perform a division by the above-computed sum-of-exponentials.
   3353     int32 fixed_sum_of_exps = sum_of_exps.raw();
   3354     int headroom_plus_one =
   3355         __builtin_clz(static_cast<uint32>(fixed_sum_of_exps));
   3356     // This is the number of bits to the left of the binary point above 1.0.
   3357     // Consider fixed_sum_of_exps=1.25.  In that case shifted_scale=0.8 and
   3358     // no later adjustment will be needed.
   3359     int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
   3360     int32 shifted_sum_minus_one = static_cast<int32>(
   3361         (static_cast<uint32>(fixed_sum_of_exps) << headroom_plus_one) -
   3362         (static_cast<uint32>(1) << 31));
   3363     FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
   3364         FixedPoint0::FromRaw(shifted_sum_minus_one));
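            // Worked example: if the sum of exponentials is about 1.25, its Q12.19
            // raw value is about 1.25 * 2^19, whose highest set bit is bit 19, so
            // headroom_plus_one = 31 - 19 = 12 and num_bits_over_unit = 0. Shifting
            // left by 12 gives 1.25 in Q0.31; subtracting 2^31 (i.e. 1.0) leaves
            // 0.25, and one_over_one_plus_x_for_x_in_0_1 yields
            // shifted_scale = 1 / 1.25 = 0.8.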
   3365 
   3366     // Compute the quotients of exponentials of differences of entries in the
   3367     // current row from the largest entry, over the previously-computed sum of
   3368     // exponentials.
   3369     {
   3370       int c = 0;
   3371 #ifdef USE_NEON
   3372       int16x8_t diff_min_s16 = vdupq_n_s16(diff_min);
   3373       for (; c <= depth - 8; c += 8) {
   3374         uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c));
   3375         int16x8_t input_diff_s16 =
   3376             vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16);
   3377         int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16));
   3378         int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16));
   3379         uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16));
   3380         FixedPointScaledDiffInt32x4 scaled_diff_0 =
   3381             input_beta_multiplier_f0 *
   3382             FixedPointScaledDiffInt32x4::FromRaw(
   3383                 gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift));
   3384         FixedPointScaledDiffInt32x4 scaled_diff_1 =
   3385             input_beta_multiplier_f0 *
   3386             FixedPointScaledDiffInt32x4::FromRaw(
   3387                 gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift));
   3388         FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0);
   3389         FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1);
   3390         int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT(
   3391             vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()),
   3392             num_bits_over_unit + 31 - 8);
   3393         int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT(
   3394             vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()),
   3395             num_bits_over_unit + 31 - 8);
   3396         int16x8_t output_s16 =
   3397             vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1));
   3398         uint8x8_t output_u8 = vqmovun_s16(output_s16);
   3399         uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0));
   3400         vst1_u8(output_data_ptr + c, masked_output);
   3401       }
   3402 #endif
   3403       for (; c < depth; ++c) {
   3404         int32 input_diff = static_cast<int32>(input_data_ptr[c]) - max_in_row;
   3405         if (input_diff >= diff_min) {
   3406           const int32 input_diff_rescaled =
   3407               MultiplyByQuantizedMultiplierGreaterThanOne(
   3408                   input_diff, input_beta_multiplier, input_beta_left_shift);
   3409           const FixedPointScaledDiff scaled_diff_f8 =
   3410               FixedPointScaledDiff::FromRaw(input_diff_rescaled);
   3411 
   3412           FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
   3413           int32 unsat_output = gemmlowp::RoundingDivideByPOT(
   3414               (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
   3415 
   3416           output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0);
   3417 
   3418         } else {
   3419           output_data_ptr[c] = 0;
   3420         }
   3421       }
   3422     }
   3423   }
   3424 }
   3425 
   3426 inline void Logistic(const float* input_data, const Dims<4>& input_dims,
   3427                      float* output_data, const Dims<4>& output_dims) {
   3428   gemmlowp::ScopedProfilingLabel label("Logistic");
   3429   auto input_map = MapAsVector(input_data, input_dims);
   3430   auto output_map = MapAsVector(output_data, output_dims);
   3431   output_map.array() =
   3432       input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
   3433 }
   3434 
   3435 inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
   3436                      int32 input_zero_point, int32 input_range_radius,
   3437                      int32 input_multiplier, int input_left_shift,
   3438                      uint8* output_data, const Dims<4>& output_dims) {
   3439   gemmlowp::ScopedProfilingLabel label("Logistic");
   3440   /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3);
   3441   /* height */ MatchingArraySize(input_dims, 2, output_dims, 2);
   3442   /* width */ MatchingArraySize(input_dims, 1, output_dims, 1);
   3443   /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0);
   3444   const int size = RequiredBufferSizeForDims(input_dims);
   3445 
   3446   int c = 0;
   3447 #ifdef USE_NEON
   3448   // Handle 16 values at a time
   3449   for (; c <= size - 16; c += 16) {
   3450     // Read input uint8 values, cast to int16 and subtract input_zero_point
   3451     uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
   3452     int16x8_t input_val_centered_0 =
   3453         vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
   3454                   vdupq_n_s16(input_zero_point));
   3455     int16x8_t input_val_centered_1 =
   3456         vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
   3457                   vdupq_n_s16(input_zero_point));
   3458 
   3459     // Prepare the bit masks that we will use at the end to implement the logic
   3460     // that was expressed in the scalar code with branching:
   3461     //   if (input_val_centered < -input_range_radius) {
   3462     //     output_val = 0;
   3463     //   } else if (input_val_centered > input_range_radius) {
   3464     //     output_val = 255;
   3465     //   } else {
   3466     //     ...
   3467     uint16x8_t mask_rightclamp_0 =
   3468         vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
   3469     uint16x8_t mask_rightclamp_1 =
   3470         vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
   3471     uint16x8_t mask_leftclamp_0 =
   3472         vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
   3473     uint16x8_t mask_leftclamp_1 =
   3474         vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
   3475     uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
   3476                                              vshrn_n_u16(mask_rightclamp_1, 8));
   3477     uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
   3478                                             vshrn_n_u16(mask_leftclamp_1, 8));
   3479 
   3480     // This performs what is expressed in the scalar code as
   3481     // const int32 input_val_rescaled =
   3482     //     MultiplyByQuantizedMultiplierGreaterThanOne(
   3483     //         input_val_centered, input_multiplier, input_left_shift);
   3484     int32x4_t input_val_rescaled_0 =
   3485         vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
   3486                   vdupq_n_s32(input_left_shift));
   3487     int32x4_t input_val_rescaled_1 =
   3488         vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
   3489                   vdupq_n_s32(input_left_shift));
   3490     int32x4_t input_val_rescaled_2 =
   3491         vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
   3492                   vdupq_n_s32(input_left_shift));
   3493     int32x4_t input_val_rescaled_3 =
   3494         vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
   3495                   vdupq_n_s32(input_left_shift));
   3496     input_val_rescaled_0 =
   3497         vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
   3498     input_val_rescaled_1 =
   3499         vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
   3500     input_val_rescaled_2 =
   3501         vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
   3502     input_val_rescaled_3 =
   3503         vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
   3504 
   3505     // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t
   3506     using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
   3507     using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
   3508     const FixedPoint4 input_val_f4_0 =
   3509         FixedPoint4::FromRaw(input_val_rescaled_0);
   3510     const FixedPoint4 input_val_f4_1 =
   3511         FixedPoint4::FromRaw(input_val_rescaled_1);
   3512     const FixedPoint4 input_val_f4_2 =
   3513         FixedPoint4::FromRaw(input_val_rescaled_2);
   3514     const FixedPoint4 input_val_f4_3 =
   3515         FixedPoint4::FromRaw(input_val_rescaled_3);
   3516     const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0);
   3517     const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1);
   3518     const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2);
   3519     const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3);
   3520 
   3521     // Divide by 2^23 as in the scalar code
   3522     using gemmlowp::RoundingDivideByPOT;
   3523     int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23);
   3524     int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23);
   3525     int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23);
   3526     int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23);
   3527 
   3528     // Cast output values to uint8, saturating
   3529     int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
   3530                                               vqmovn_s32(output_val_s32_1));
   3531     int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
   3532                                               vqmovn_s32(output_val_s32_3));
   3533     uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
   3534                                            vqmovun_s16(output_val_s16_1));
   3535 
   3536     // Perform the bit-masking with the bit masks computed at the beginning,
   3537     // see the comment there.
   3538     output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
   3539     output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
   3540 
   3541     // Store back to memory
   3542     vst1q_u8(output_data + c, output_val_u8);
   3543   }
   3544 #endif
   3545   // Leftover loop: handle one value at a time with scalar code.
   3546   for (; c < size; ++c) {
   3547     const uint8 input_val_u8 = input_data[c];
   3548     const int32 input_val_centered =
   3549         static_cast<int32>(input_val_u8) - input_zero_point;
   3550     uint8 output_val;
   3551     if (input_val_centered < -input_range_radius) {
   3552       output_val = 0;
   3553     } else if (input_val_centered > input_range_radius) {
   3554       output_val = 255;
   3555     } else {
   3556       const int32 input_val_rescaled =
   3557           MultiplyByQuantizedMultiplierGreaterThanOne(
   3558               input_val_centered, input_multiplier, input_left_shift);
   3559       using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
   3560       using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
   3561       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
   3562       const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
   3563       using gemmlowp::RoundingDivideByPOT;
   3564       int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
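              // output_val_f0 is Q0.31 in [0, 1], so dividing its raw value by 2^23
              // yields logistic(x) * 2^8, i.e. a value in [0, 256]. The value 256
              // (logistic(x) == 1) is clamped to 255 below; the uint8 output is thus
              // quantized with zero_point 0 and scale 1/256.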
   3565       if (output_val_s32 == 256) {
   3566         output_val_s32 = 255;
   3567       }
   3568       TFLITE_DCHECK_GE(output_val_s32, 0);
   3569       TFLITE_DCHECK_LE(output_val_s32, 255);
   3570       output_val = static_cast<uint8>(output_val_s32);
   3571     }
   3572     output_data[c] = output_val;
   3573   }
   3574 }
   3575 
   3576 inline void Tanh(const float* input_data, const Dims<4>& input_dims,
   3577                  float* output_data, const Dims<4>& output_dims) {
   3578   gemmlowp::ScopedProfilingLabel label("Tanh");
   3579   auto input_map = MapAsVector(input_data, input_dims);
   3580   auto output_map = MapAsVector(output_data, output_dims);
   3581   output_map.array() = input_map.array().tanh();
   3582 }
   3583 
   3584 inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
   3585                  int32 input_zero_point, int32 input_range_radius,
   3586                  int32 input_multiplier, int input_left_shift,
   3587                  uint8* output_data, const Dims<4>& output_dims) {
   3588   // Note that this is almost the exact same code as in Logistic().
   3589   gemmlowp::ScopedProfilingLabel label("Tanh");
   3590   /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3);
   3591   /* height */ MatchingArraySize(input_dims, 2, output_dims, 2);
   3592   /* width */ MatchingArraySize(input_dims, 1, output_dims, 1);
   3593   /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0);
   3594   const int size = RequiredBufferSizeForDims(input_dims);
   3595 
   3596   int c = 0;
   3597   const int32 output_zero_point = 128;
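          // The uint8 output is quantized as real = (q - 128) / 128 (zero_point 128,
          // scale 1/128): the Q0.31 tanh value in [-1, 1] is divided by 2^24 to give
          // tanh(x) * 128, and output_zero_point is then added.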
   3598 #ifdef USE_NEON
   3599   // Handle 16 values at a time
   3600   for (; c <= size - 16; c += 16) {
   3601     // Read input uint8 values, cast to int16 and subtract input_zero_point
   3602     uint8x16_t input_val_u8 = vld1q_u8(input_data + c);
   3603     int16x8_t input_val_centered_0 =
   3604         vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))),
   3605                   vdupq_n_s16(input_zero_point));
   3606     int16x8_t input_val_centered_1 =
   3607         vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))),
   3608                   vdupq_n_s16(input_zero_point));
   3609 
   3610     // Prepare the bit masks that we will use at the end to implement the logic
   3611     // that was expressed in the scalar code with branching:
   3612     //   if (input_val_centered < -input_range_radius) {
   3613     //     output_val = 0;
   3614     //   } else if (input_val_centered > input_range_radius) {
   3615     //     output_val = 255;
   3616     //   } else {
   3617     //     ...
   3618     uint16x8_t mask_rightclamp_0 =
   3619         vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius));
   3620     uint16x8_t mask_rightclamp_1 =
   3621         vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius));
   3622     uint16x8_t mask_leftclamp_0 =
   3623         vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius));
   3624     uint16x8_t mask_leftclamp_1 =
   3625         vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius));
   3626     uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8),
   3627                                              vshrn_n_u16(mask_rightclamp_1, 8));
   3628     uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8),
   3629                                             vshrn_n_u16(mask_leftclamp_1, 8));
   3630 
   3631     // This performs what is expressed in the scalar code as
   3632     // const int32 input_val_rescaled =
   3633     //     MultiplyByQuantizedMultiplierGreaterThanOne(
   3634     //         input_val_centered, input_multiplier, input_left_shift);
   3635     int32x4_t input_val_rescaled_0 =
   3636         vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)),
   3637                   vdupq_n_s32(input_left_shift));
   3638     int32x4_t input_val_rescaled_1 =
   3639         vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)),
   3640                   vdupq_n_s32(input_left_shift));
   3641     int32x4_t input_val_rescaled_2 =
   3642         vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)),
   3643                   vdupq_n_s32(input_left_shift));
   3644     int32x4_t input_val_rescaled_3 =
   3645         vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)),
   3646                   vdupq_n_s32(input_left_shift));
   3647     input_val_rescaled_0 =
   3648         vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier);
   3649     input_val_rescaled_1 =
   3650         vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier);
   3651     input_val_rescaled_2 =
   3652         vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier);
   3653     input_val_rescaled_3 =
   3654         vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier);
   3655 
   3656     // Invoke gemmlowp::tanh on FixedPoint wrapping int32x4_t
   3657     using FixedPoint4 = gemmlowp::FixedPoint<int32x4_t, 4>;
   3658     using FixedPoint0 = gemmlowp::FixedPoint<int32x4_t, 0>;
   3659     const FixedPoint4 input_val_f4_0 =
   3660         FixedPoint4::FromRaw(input_val_rescaled_0);
   3661     const FixedPoint4 input_val_f4_1 =
   3662         FixedPoint4::FromRaw(input_val_rescaled_1);
   3663     const FixedPoint4 input_val_f4_2 =
   3664         FixedPoint4::FromRaw(input_val_rescaled_2);
   3665     const FixedPoint4 input_val_f4_3 =
   3666         FixedPoint4::FromRaw(input_val_rescaled_3);
   3667     const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0);
   3668     const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1);
   3669     const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2);
   3670     const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3);
   3671 
   3672     // Divide by 2^24 as in the scalar code
   3673     using gemmlowp::RoundingDivideByPOT;
   3674     int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 24);
   3675     int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 24);
   3676     int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 24);
   3677     int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 24);
   3678 
   3679     // Add the output zero point
   3680     int32x4_t output_zero_point_s32 = vdupq_n_s32(output_zero_point);
   3681     output_val_s32_0 = vaddq_s32(output_val_s32_0, output_zero_point_s32);
   3682     output_val_s32_1 = vaddq_s32(output_val_s32_1, output_zero_point_s32);
   3683     output_val_s32_2 = vaddq_s32(output_val_s32_2, output_zero_point_s32);
   3684     output_val_s32_3 = vaddq_s32(output_val_s32_3, output_zero_point_s32);
   3685 
   3686     // Cast output values to uint8, saturating
   3687     int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0),
   3688                                               vqmovn_s32(output_val_s32_1));
   3689     int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2),
   3690                                               vqmovn_s32(output_val_s32_3));
   3691     uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0),
   3692                                            vqmovun_s16(output_val_s16_1));
   3693 
   3694     // Perform the bit-masking with the bit masks computed at the beginning,
   3695     // see the comment there.
   3696     output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp);
   3697     output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp);
   3698 
   3699     // Store back to memory
   3700     vst1q_u8(output_data + c, output_val_u8);
   3701   }
   3702 #endif
   3703   // Leftover loop: handle one value at a time with scalar code.
   3704   for (; c < size; ++c) {
   3705     const uint8 input_val_u8 = input_data[c];
   3706     const int32 input_val_centered =
   3707         static_cast<int32>(input_val_u8) - input_zero_point;
   3708     uint8 output_val;
   3709     if (input_val_centered < -input_range_radius) {
   3710       output_val = 0;
   3711     } else if (input_val_centered > input_range_radius) {
   3712       output_val = 255;
   3713     } else {
   3714       const int32 input_val_rescaled =
   3715           MultiplyByQuantizedMultiplierGreaterThanOne(
   3716               input_val_centered, input_multiplier, input_left_shift);
   3717       using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
   3718       using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
   3719       const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
   3720       const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4);
   3721       using gemmlowp::RoundingDivideByPOT;
   3722       int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24);
   3723       output_val_s32 += output_zero_point;
   3724       if (output_val_s32 == 256) {
   3725         output_val_s32 = 255;
   3726       }
   3727       TFLITE_DCHECK_GE(output_val_s32, 0);
   3728       TFLITE_DCHECK_LE(output_val_s32, 255);
   3729       output_val = static_cast<uint8>(output_val_s32);
   3730     }
   3731     output_data[c] = output_val;
   3732   }
   3733 }

   3734 inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
   3735                        int32 zero_point, double scale, float* output_data,
   3736                        const Dims<4>& output_dims) {
   3737   gemmlowp::ScopedProfilingLabel label("Dequantize");
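          // Dequantization is the affine map real = scale * (quantized - zero_point);
          // e.g. with zero_point = 128 and scale = 0.5, a stored value of 130
          // dequantizes to 0.5 * (130 - 128) = 1.0f.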
   3738   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   3739   const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
   3740   const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
   3741   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
   3742   for (int b = 0; b < batches; ++b) {
   3743     for (int y = 0; y < height; ++y) {
   3744       for (int x = 0; x < width; ++x) {
   3745         for (int c = 0; c < depth; ++c) {
   3746           int32 val = input_data[Offset(input_dims, c, x, y, b)];
   3747           float result = static_cast<float>(scale * (val - zero_point));
   3748           output_data[Offset(output_dims, c, x, y, b)] = result;
   3749         }
   3750       }
   3751     }
   3752   }
   3753 }
   3754 
   3755 inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
   3756                       float rmin, float rmax, float* output_data,
   3757                       const Dims<4>& output_dims) {
   3758   gemmlowp::ScopedProfilingLabel label("FakeQuant");
   3759 
   3760   // 0 should always be a representable value. Let's assume that the initial
   3761   // min,max range contains 0.
   3762   TFLITE_DCHECK_LE(rmin, 0.);
   3763   TFLITE_DCHECK_GE(rmax, 0.);
   3764 
   3765   // Determine quantization parameters: zero_point, scale.
   3766   using Integer = uint8;
   3767   const Integer qmin = std::numeric_limits<Integer>::min();
   3768   const Integer qmax = std::numeric_limits<Integer>::max();
   3769   const float qmin_float = qmin;
   3770   const float qmax_float = qmax;
   3771   int32 zero_point = 0;
   3772   float scale = 0.f;
   3773   // If rmin==rmax, both must be zero per the above assertion,
   3774   // so we are done.
   3775   if (rmin != rmax) {
   3776     // First determine the scale.
   3777     scale = (rmax - rmin) / (qmax_float - qmin_float);
   3778 
   3779     // Zero-point computation.
   3780     // First the initial floating-point computation. The zero-point can be
   3781     // determined from solving an affine equation for any known pair
   3782     // (real value, corresponding quantized value).
   3783     // We know two such pairs: (rmin, qmin) and (rmax, qmax).
   3784     // The arithmetic error on the zero point computed from either pair
   3785     // will be roughly machine_epsilon * (sum of absolute values of terms)
   3786     // so we want to use the variant that adds the smaller terms.
   3787     const float zero_point_from_min = qmin_float - rmin / scale;
   3788     const float zero_point_from_max = qmax_float - rmax / scale;
   3789     const float zero_point_from_min_error =
   3790         std::abs(qmin_float) + std::abs(rmin / scale);
   3791     const float zero_point_from_max_error =
   3792         std::abs(qmax_float) + std::abs(rmax / scale);
   3793 
   3794     const float zero_point_float =
   3795         zero_point_from_min_error < zero_point_from_max_error
   3796             ? zero_point_from_min
   3797             : zero_point_from_max;
   3798 
   3799     // Now we need to nudge the zero point to be an integer
   3800     // (our zero points are integer, and this is motivated by the requirement
   3801     // to be able to represent the real value "0" exactly as a quantized value,
   3802     // which is required in multiple places, for example in Im2col with SAME
   3803     // padding).
   3804     if (zero_point_float < qmin_float) {
   3805       zero_point = qmin;
   3806     } else if (zero_point_float > qmax_float) {
   3807       zero_point = qmax;
   3808     } else {
   3809       zero_point = static_cast<int32>(TfLiteRound(zero_point_float));
   3810     }
   3811     // The zero point should always be in the range of quantized value,
   3812     // [qmin, qmax].
   3813     TFLITE_DCHECK_GE(zero_point, qmin);
   3814     TFLITE_DCHECK_LE(zero_point, qmax);
   3815   }
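          // Worked example (illustrative only): rmin = -1.f, rmax = 2.f gives
          // scale = 3.f / 255, zero_point_from_min = 0 - (-1.f / scale) = 85 and
          // zero_point_from_max = 255 - 2.f / scale = 85, so zero_point = 85 and the
          // real value 0 is representable exactly.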
   3816 
   3817   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   3818   const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
   3819   const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
   3820   const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
   3821   for (int b = 0; b < batches; ++b) {
   3822     for (int y = 0; y < height; ++y) {
   3823       for (int x = 0; x < width; ++x) {
   3824         for (int c = 0; c < depth; ++c) {
   3825           const float src_val = input_data[Offset(input_dims, c, x, y, b)];
   3826           const float unclamped_quantized_val =
   3827               TfLiteRound(zero_point + src_val / scale);
   3828           const float quantized_val = std::min(
   3829               qmax_float, std::max(qmin_float, unclamped_quantized_val));
   3830           const float dst_val = scale * (quantized_val - zero_point);
   3831           output_data[Offset(output_dims, c, x, y, b)] = dst_val;
   3832         }
   3833       }
   3834     }
   3835   }
   3836 }
   3837 
   3838 template <typename SrcT, typename DstT>
   3839 inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
   3840                  DstT* output_data, const Dims<4>& output_dims) {
   3841   gemmlowp::ScopedProfilingLabel label("Cast");
   3842   auto input_map = MapAsVector(input_data, input_dims);
   3843   auto output_map = MapAsVector(output_data, output_dims);
   3844   output_map.array() = input_map.array().template cast<DstT>();
   3845 }
   3846 
   3847 inline void Floor(const float* input_data, const Dims<4>& input_dims,
   3848                   float* output_data, const Dims<4>& output_dims) {
   3849   gemmlowp::ScopedProfilingLabel label("Floor");
   3850   auto input_map = MapAsVector(input_data, input_dims);
   3851   auto output_map = MapAsVector(output_data, output_dims);
   3852   output_map.array() = Eigen::floor(input_map.array());
   3853 }
   3854 
   3855 template <typename T>
   3856 inline void Gather(const T* input_data, const Dims<4>& input_dims,
   3857                    int input_rank, const int32* coords_data,
   3858                    const Dims<4>& coords_dims, T* output_data,
   3859                    const Dims<4>& output_dims) {
   3860   gemmlowp::ScopedProfilingLabel label("Gather");
   3861 
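          // Illustrative example (assuming packed dims): with input_rank == 2 and an
          // innermost (depth) size of 3, stride == 3; coords_data == {2, 0} copies
          // input rows 2 and 0, each 3 elements long, back to back into output_data.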
   3862   TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]);
   3863   int stride = input_dims.strides[input_rank - 1];
   3864   T* out = output_data;
   3865 
   3866   for (int i = 0; i < coords_dims.sizes[0]; i++) {
   3867     TFLITE_DCHECK_GE(coords_data[i], 0);
   3868     TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]);
   3869     const T* in = input_data + coords_data[i] * stride;
   3870     memcpy(out, in, sizeof(T) * stride);
   3871     out += stride;
   3872   }
   3873 }
   3874 
   3875 #ifdef USE_NEON
   3876 inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
   3877                                  float scale, float* output_ptr) {
   3878   int ic = 0;
   3879   // Handle 32 input channels at a time.
   3880   for (; ic <= depth - 32; ic += 32) {
   3881     float32x4x2_t input[4];
   3882     for (int i = 0; i < 4; i++) {
   3883       input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
   3884       input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
   3885     }
   3886     float32x4x2_t acc[4];
   3887     for (int i = 0; i < 4; i++) {
   3888       acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
   3889       acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
   3890     }
   3891     for (int i = 0; i < 4; i++) {
   3892       acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
   3893       acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
   3894     }
   3895     for (int i = 0; i < 4; i++) {
   3896       vst1q_f32(output_ptr, acc[i].val[0]);
   3897       vst1q_f32(output_ptr + 4, acc[i].val[1]);
   3898       output_ptr += 8;
   3899     }
   3900     input_ptr += 32;
   3901   }
   3902   // Handle 16 input channels at a time.
   3903   for (; ic <= depth - 16; ic += 16) {
   3904     float32x4x2_t input[2];
   3905     for (int i = 0; i < 2; i++) {
   3906       input[i].val[0] = vld1q_f32(input_ptr + 8 * i);
   3907       input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4);
   3908     }
   3909     float32x4x2_t acc[2];
   3910     for (int i = 0; i < 2; i++) {
   3911       acc[i].val[0] = vld1q_f32(output_ptr + 8 * i);
   3912       acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4);
   3913     }
   3914     for (int i = 0; i < 2; i++) {
   3915       acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale);
   3916       acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale);
   3917     }
   3918     for (int i = 0; i < 2; i++) {
   3919       vst1q_f32(output_ptr, acc[i].val[0]);
   3920       vst1q_f32(output_ptr + 4, acc[i].val[1]);
   3921       output_ptr += 8;
   3922     }
   3923     input_ptr += 16;
   3924   }
   3925   // Handle 8 input channels at a time.
   3926   for (; ic <= depth - 8; ic += 8) {
   3927     float32x4x2_t input;
   3928     input.val[0] = vld1q_f32(input_ptr);
   3929     input.val[1] = vld1q_f32(input_ptr + 4);
   3930 
   3931     float32x4x2_t acc;
   3932     acc.val[0] = vld1q_f32(output_ptr);
   3933     acc.val[1] = vld1q_f32(output_ptr + 4);
   3934     acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale);
   3935     acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale);
   3936 
   3937     vst1q_f32(output_ptr, acc.val[0]);
   3938     vst1q_f32(output_ptr + 4, acc.val[1]);
   3939 
   3940     input_ptr += 8;
   3941     output_ptr += 8;
   3942   }
   3943   // Handle 4 input channels at a time.
   3944   for (; ic <= depth - 4; ic += 4) {
   3945     float32x4_t input = vld1q_f32(input_ptr);
   3946     float32x4_t acc = vld1q_f32(output_ptr);
   3947 
   3948     acc = vmlaq_n_f32(acc, input, scale);
   3949     vst1q_f32(output_ptr, acc);
   3950 
   3951     input_ptr += 4;
   3952     output_ptr += 4;
   3953   }
   3954   // Handle 1 input channel at a time.
   3955   for (; ic < depth; ic++) {
   3956     *output_ptr += *input_ptr * scale;
   3957     output_ptr++;
   3958     input_ptr++;
   3959   }
   3960 }
   3961 #else
   3962 inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
   3963                                  float scale, float* output_ptr) {
   3964   for (int32 i = 0; i < depth; i++) {
   3965     *output_ptr += *input_ptr * scale;
   3966     output_ptr++;
   3967     input_ptr++;
   3968   }
   3969 }
   3970 #endif
   3971 
   3972 inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
   3973                                     int32 x, int32 y, int32 depth, int32 batch,
   3974                                     const float* input_data,
   3975                                     const Dims<4>& input_dims,
   3976                                     float* output_data,
   3977                                     const Dims<4>& output_dims) {
   3978   const int32 input_width = ArraySize(input_dims, 1);
   3979   const int32 output_width = ArraySize(output_dims, 1);
   3980 
   3981   const int32 input_x_offset = (x1 - x0) * depth;
   3982   const int32 input_y_offset = (y1 - y0) * depth * input_width;
   3983   const int32 output_x_offset = depth;
   3984   const int32 output_y_offset = depth * output_width;
   3985 
   3986 #ifdef USE_NEON
   3987   TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
   3988   TFLITE_DCHECK(x1 >= x0);
   3989   TFLITE_DCHECK(y1 >= y0);
   3990 
   3991   int ic = 0;
   3992   // Handle 8 input channels at a time.
   3993   for (; ic <= depth - 8; ic += 8) {
   3994     const float* input_ptr = nullptr;
   3995 
   3996     float32x4x2_t x0y0;
   3997     input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)];
   3998     x0y0.val[0] = vld1q_f32(input_ptr);
   3999     x0y0.val[1] = vld1q_f32(input_ptr + 4);
   4000 
   4001     float32x4x2_t x1y0;
   4002     input_ptr += input_x_offset;
   4003     x1y0.val[0] = vld1q_f32(input_ptr);
   4004     x1y0.val[1] = vld1q_f32(input_ptr + 4);
   4005 
   4006     float32x4x2_t x0y1;
   4007     input_ptr += -input_x_offset + input_y_offset;
   4008     x0y1.val[0] = vld1q_f32(input_ptr);
   4009     x0y1.val[1] = vld1q_f32(input_ptr + 4);
   4010 
   4011     float32x4x2_t x1y1;
   4012     input_ptr += input_x_offset;
   4013     x1y1.val[0] = vld1q_f32(input_ptr);
   4014     x1y1.val[1] = vld1q_f32(input_ptr + 4);
   4015 
   4016     // Top left corner.
   4017     float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)];
   4018     vst1q_f32(output_ptr, x0y0.val[0]);
   4019     vst1q_f32(output_ptr + 4, x0y0.val[1]);
   4020 
   4021     // Top right corner.
   4022     output_ptr += output_x_offset;
   4023     float32x4x2_t tr;
   4024     tr.val[0] = vaddq_f32(x0y0.val[0], x1y0.val[0]);
   4025     tr.val[1] = vaddq_f32(x0y0.val[1], x1y0.val[1]);
   4026     tr.val[0] = vmulq_n_f32(tr.val[0], 0.5f);
   4027     tr.val[1] = vmulq_n_f32(tr.val[1], 0.5f);
   4028 
   4029     vst1q_f32(output_ptr, tr.val[0]);
   4030     vst1q_f32(output_ptr + 4, tr.val[1]);
   4031 
   4032     // Bottom left corner.
   4033     output_ptr += -output_x_offset + output_y_offset;
   4034     float32x4x2_t bl;
   4035     bl.val[0] = vaddq_f32(x0y0.val[0], x0y1.val[0]);
   4036     bl.val[1] = vaddq_f32(x0y0.val[1], x0y1.val[1]);
   4037     bl.val[0] = vmulq_n_f32(bl.val[0], 0.5f);
   4038     bl.val[1] = vmulq_n_f32(bl.val[1], 0.5f);
   4039     vst1q_f32(output_ptr, bl.val[0]);
   4040     vst1q_f32(output_ptr + 4, bl.val[1]);
   4041 
   4042     // Bottom right corner.
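            // As in the scalar path below, the bottom-right output is computed as
            // (bl + (x1y0 + x1y1) / 2) / 2, so the already-halved bl is reused as
            // the accumulator base of the vmlaq_n_f32 below.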
   4043     output_ptr += output_x_offset;
   4044     float32x4x2_t br;
   4045     br.val[0] = vaddq_f32(x1y0.val[0], x1y1.val[0]);
   4046     br.val[1] = vaddq_f32(x1y0.val[1], x1y1.val[1]);
   4047     br.val[0] = vmlaq_n_f32(bl.val[0], br.val[0], 0.5f);
   4048     br.val[1] = vmlaq_n_f32(bl.val[1], br.val[1], 0.5f);
   4049     br.val[0] = vmulq_n_f32(br.val[0], 0.5f);
   4050     br.val[1] = vmulq_n_f32(br.val[1], 0.5f);
   4051     vst1q_f32(output_ptr, br.val[0]);
   4052     vst1q_f32(output_ptr + 4, br.val[1]);
   4053   }
   4054   // Handle 4 input channels at a time.
   4055   for (; ic <= depth - 4; ic += 4) {
   4056     const float* input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)];
   4057     float32x4_t x0y0 = vld1q_f32(input_ptr);
   4058     float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset);
   4059     float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset);
   4060     float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset);
   4061 
   4062     // Top left corner.
   4063     float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)];
   4064     vst1q_f32(output_ptr, x0y0);
   4065 
   4066     // Top right corner.
   4067     output_ptr += output_x_offset;
   4068     float32x4_t tr = vaddq_f32(x0y0, x1y0);
   4069     tr = vmulq_n_f32(tr, 0.5f);
   4070     vst1q_f32(output_ptr, tr);
   4071 
   4072     // Bottom left corner.
   4073     output_ptr += -output_x_offset + output_y_offset;
   4074     float32x4_t bl = vaddq_f32(x0y0, x0y1);
   4075     bl = vmulq_n_f32(bl, 0.5f);
   4076     vst1q_f32(output_ptr, bl);
   4077 
   4078     // Bottom right corner.
   4079     output_ptr += output_x_offset;
   4080     float32x4_t br = vaddq_f32(x1y0, x1y1);
   4081     br = vmlaq_n_f32(bl, br, 0.5f);
   4082     br = vmulq_n_f32(br, 0.5f);
   4083     vst1q_f32(output_ptr, br);
   4084   }
   4085   // Handle one input channel at a time.
   4086   for (; ic < depth; ic++) {
   4087     const int32 input_offset = Offset(input_dims, ic, x0, y0, batch);
   4088 
   4089     float x0y0 = input_data[input_offset];
   4090     float x1y0 = input_data[input_offset + input_x_offset];
   4091     float x0y1 = input_data[input_offset + input_y_offset];
   4092     float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
   4093 
   4094     // Top left corner.
   4095     const int32 output_offset = Offset(output_dims, ic, x, y, batch);
   4096     output_data[output_offset] = x0y0;
   4097 
   4098     // Top right corner.
   4099     output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
   4100 
   4101     // Bottom left corner.
   4102     float output = (x0y0 + x0y1) / 2;
   4103     output_data[output_offset + output_y_offset] = output;
   4104 
   4105     // Bottom right corner.
   4106     output_data[output_offset + output_x_offset + output_y_offset] =
   4107         (output + ((x1y0 + x1y1) / 2)) / 2;
   4108   }
   4109 #else
   4110   for (int ch = 0; ch < depth; ch++) {
   4111     const int32 input_offset = Offset(input_dims, ch, x0, y0, batch);
   4112 
   4113     float x0y0 = input_data[input_offset];
   4114     float x1y0 = input_data[input_offset + input_x_offset];
   4115     float x0y1 = input_data[input_offset + input_y_offset];
   4116     float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
   4117 
   4118     // Top left corner.
   4119     const int32 output_offset = Offset(output_dims, ch, x, y, batch);
   4120     output_data[output_offset] = x0y0;
   4121 
   4122     // Top right corner.
   4123     output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
   4124 
   4125     // Bottom left corner.
   4126     float output = (x0y0 + x0y1) / 2;
   4127     output_data[output_offset + output_y_offset] = output;
   4128 
   4129     // Bottom right corner.
   4130     output_data[output_offset + output_x_offset + output_y_offset] =
   4131         (output + ((x1y0 + x1y1) / 2)) / 2;
   4132   }
   4133 #endif
   4134 }
   4135 
   4136 inline void ResizeBilinear2x2(const float* input_data,
   4137                               const Dims<4>& input_dims, float* output_data,
   4138                               const Dims<4>& output_dims, int32 batches,
   4139                               int32 input_height, int32 input_width,
   4140                               int32 depth, int32 output_height,
   4141                               int32 output_width) {
   4142   for (int b = 0; b < batches; b++) {
   4143     for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) {
   4144       for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) {
   4145         int32 x1 = std::min(x0 + 1, input_width - 1);
   4146         int32 y1 = std::min(y0 + 1, input_height - 1);
   4147         ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_data,
   4148                                 input_dims, output_data, output_dims);
   4149       }
   4150     }
   4151   }
   4152 }
   4153 
   4154 inline void ResizeBilinearGeneric(const float* input_data,
   4155                                   const Dims<4>& input_dims, float* output_data,
   4156                                   const Dims<4>& output_dims, int32 batches,
   4157                                   int32 input_height, int32 input_width,
   4158                                   int32 depth, int32 output_height,
   4159                                   int32 output_width, float height_scale,
   4160                                   float width_scale) {
   4161   memset(output_data, 0,
   4162          batches * output_height * output_width * depth * sizeof(float));
   4163 
   4164   int32 output_offset = 0;
   4165   for (int b = 0; b < batches; ++b) {
   4166     for (int y = 0; y < output_height; ++y) {
   4167       float input_y = y * height_scale;
   4168       int32 y0 = static_cast<int32>(std::floor(input_y));
   4169       int32 y1 = std::min(y0 + 1, input_height - 1);
   4170       for (int x = 0; x < output_width; ++x) {
   4171         float input_x = x * width_scale;
   4172         int32 x0 = static_cast<int32>(std::floor(input_x));
   4173         int32 x1 = std::min(x0 + 1, input_width - 1);
   4174         float* output_ptr = &output_data[output_offset];
   4175 
   4176         // Run kernel on the 4 corners of the bilinear resize algorithm.
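                // With dy = input_y - y0 and dx = input_x - x0, the four scales are
                // (1 - dy) * (1 - dx), (1 - dy) * dx, dy * (1 - dx) and dy * dx;
                // they sum to 1, so each output pixel is the bilinear interpolation
                // of its four neighboring input pixels, accumulated one weighted
                // corner at a time.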
   4177         int32 input_offset = Offset(input_dims, 0, x0, y0, b);
   4178         float scale = (1 - (input_y - y0)) * (1 - (input_x - x0));
   4179         const float* input_ptr = &input_data[input_offset];
   4180         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
   4181 
   4182         input_offset = Offset(input_dims, 0, x1, y0, b);
   4183         scale = (1 - (input_y - y0)) * (input_x - x0);
   4184         input_ptr = &input_data[input_offset];
   4185         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
   4186 
   4187         input_offset = Offset(input_dims, 0, x0, y1, b);
   4188         scale = (input_y - y0) * (1 - (input_x - x0));
   4189         input_ptr = &input_data[input_offset];
   4190         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
   4191 
   4192         input_offset = Offset(input_dims, 0, x1, y1, b);
   4193         scale = (input_y - y0) * (input_x - x0);
   4194         input_ptr = &input_data[input_offset];
   4195         ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
   4196 
   4197         output_offset += depth;
   4198       }
   4199     }
   4200   }
   4201 }
   4202 
   4203 inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
   4204                            const int32* output_size_data,
   4205                            const Dims<4>& output_size_dims, float* output_data,
   4206                            const Dims<4>& output_dims, bool align_corners) {
   4207   gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
   4208   int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   4209   int32 input_height = ArraySize(input_dims, 2);
   4210   int32 input_width = ArraySize(input_dims, 1);
   4211   int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
   4212 
   4213   TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
   4214   TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
   4215   TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
   4216   TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
   4217   int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
   4218   int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
   4219 
   4220   // Specialize for 2x2 upsample.
   4221   if (!align_corners && output_height == 2 * input_height &&
   4222       output_width == 2 * input_width) {
   4223     ResizeBilinear2x2(input_data, input_dims, output_data, output_dims, batches,
   4224                       input_height, input_width, depth, output_height,
   4225                       output_width);
   4226   } else {
   4227     float height_scale = static_cast<float>(input_height) / output_height;
   4228     float width_scale = static_cast<float>(input_width) / output_width;
   4229     if (align_corners && output_height > 1) {
   4230       height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
   4231     }
   4232     if (align_corners && output_width > 1) {
   4233       width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
   4234     }
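    // For example, resizing a height of 3 to a height of 6 gives
    // height_scale = 3 / 6 = 0.5 by default, but (3 - 1) / (6 - 1) = 0.4 when
    // align_corners is true, so that the corner pixels of the input and output
    // grids line up.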
   4235 
   4236     ResizeBilinearGeneric(input_data, input_dims, output_data, output_dims,
   4237                           batches, input_height, input_width, depth,
   4238                           output_height, output_width, height_scale,
   4239                           width_scale);
   4240   }
   4241 }
   4242 
   4243 // legacy, for compatibility with old checked-in code
   4244 inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
   4245                            const int32* output_size_data,
   4246                            const Dims<4>& output_size_dims, float* output_data,
   4247                            const Dims<4>& output_dims) {
   4248   ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
   4249                  output_data, output_dims, /*align_corners=*/false);
   4250 }
   4251 
   4252 template <typename T>
   4253 inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
   4254                            const int32* block_shape_data,
   4255                            const Dims<4>& block_shape_dims,
   4256                            const int32* paddings_data,
   4257                            const Dims<4>& paddings_dims, T* output_data,
   4258                            const Dims<4>& output_dims) {
   4259   gemmlowp::ScopedProfilingLabel label("SpaceToBatchND");
   4260 
   4261   const int output_batch_size = ArraySize(output_dims, 3);
   4262   const int output_height = ArraySize(output_dims, 2);
   4263   const int output_width = ArraySize(output_dims, 1);
   4264   const int input_batch_size = ArraySize(input_dims, 3);
   4265   const int input_height = ArraySize(input_dims, 2);
   4266   const int input_width = ArraySize(input_dims, 1);
   4267   const int depth = ArraySize(input_dims, 0);
   4268   const int block_shape_height = block_shape_data[0];
   4269   const int block_shape_width = block_shape_data[1];
   4270   const int padding_top = paddings_data[0];
   4271   const int padding_left = paddings_data[2];
   4272 
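  // Each output batch index decomposes as
  //   out_b = input_batch + input_batch_size *
  //               (shift_h * block_shape_width + shift_w),
  // so the extra batch factor encodes the spatial offset (shift_h, shift_w) of
  // the block element that this output batch holds.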
   4273   for (int out_b = 0; out_b < output_batch_size; ++out_b) {
   4274     int input_batch = out_b % input_batch_size;
   4275     int shift_w = (out_b / input_batch_size) % block_shape_width;
   4276     int shift_h = (out_b / input_batch_size) / block_shape_width;
   4277     for (int out_h = 0; out_h < output_height; ++out_h) {
   4278       for (int out_w = 0; out_w < output_width; ++out_w) {
   4279         T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b);
   4280         if (out_h * block_shape_height + shift_h < padding_top ||
   4281             out_h * block_shape_height + shift_h >=
   4282                 padding_top + input_height ||
   4283             out_w * block_shape_width + shift_w < padding_left ||
   4284             out_w * block_shape_width + shift_w >= padding_left + input_width) {
   4285           memset(out, 0, depth * sizeof(T));
   4286         } else {
   4287           const T* in =
   4288               input_data +
   4289               Offset(input_dims, 0,
   4290                      (out_w * block_shape_width + shift_w) - padding_left,
   4291                      (out_h * block_shape_height + shift_h) - padding_top,
   4292                      input_batch);
   4293           memcpy(out, in, depth * sizeof(T));
   4294         }
   4295       }
   4296     }
   4297   }
   4298 }
   4299 
   4300 template <typename T>
   4301 inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
   4302                            const int32* block_shape_data,
   4303                            const Dims<4>& block_shape_dims, T* output_data,
   4304                            const Dims<4>& output_dims) {
   4305   gemmlowp::ScopedProfilingLabel label("BatchToSpaceND");
   4306 
   4307   const int output_batch_size = ArraySize(output_dims, 3);
   4308   const int input_batch_size = ArraySize(input_dims, 3);
   4309   const int input_height = ArraySize(input_dims, 2);
   4310   const int input_width = ArraySize(input_dims, 1);
   4311   const int depth = ArraySize(input_dims, 0);
   4312   const int block_shape_width = block_shape_data[1];
   4313   const int block_shape_height = block_shape_data[0];
   4314 
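  // Mirror of the SpaceToBatchND mapping above (without padding/cropping): the
  // input batch index decomposes as
  //   in_batch = out_batch + output_batch_size *
  //                  (offset_h * block_shape_width + offset_w),
  // and each input pixel is written back to its block position
  // (in_h * block_shape_height + offset_h, in_w * block_shape_width + offset_w)
  // in the output.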
   4315   for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) {
   4316     for (int in_h = 0; in_h < input_height; ++in_h) {
   4317       for (int in_w = 0; in_w < input_width; ++in_w) {
   4318         int out_batch = in_batch % output_batch_size;
   4319         int out_w = in_w * block_shape_width +
   4320                     (in_batch / output_batch_size) % block_shape_width;
   4321         int out_h = in_h * block_shape_height +
   4322                     (in_batch / output_batch_size) / block_shape_width;
   4323         T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch);
   4324         const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch);
   4325         memcpy(out, in, depth * sizeof(T));
   4326       }
   4327     }
   4328   }
   4329 }
   4330 
   4331 template <typename T>
   4332 inline void Pad(const T* input_data, const Dims<4>& input_dims,
   4333                 const std::vector<int>& left_paddings,
   4334                 const std::vector<int>& right_paddings, T* output_data,
   4335                 const Dims<4>& output_dims) {
   4336   gemmlowp::ScopedProfilingLabel label("Pad");
   4337   const int output_batch = ArraySize(output_dims, 3);
   4338   const int output_height = ArraySize(output_dims, 2);
   4339   const int output_width = ArraySize(output_dims, 1);
   4340   const int output_depth = ArraySize(output_dims, 0);
   4341 
   4342   const int left_b_padding = left_paddings[3];
   4343   const int left_h_padding = left_paddings[2];
   4344   const int left_w_padding = left_paddings[1];
   4345   const int left_d_padding = left_paddings[0];
   4346 
   4347   const int right_b_padding = right_paddings[3];
   4348   const int right_h_padding = right_paddings[2];
   4349   const int right_w_padding = right_paddings[1];
   4350   const int right_d_padding = right_paddings[0];
   4351 
   4352   const int input_depth = ArraySize(input_dims, 0);
   4353 
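  // The zero padding is written region by region in memory order: whole
  // leading batches first, then leading rows within each batch, leading pixels
  // within each row, and leading channels within each pixel. The interior is
  // copied with one memcpy per pixel, and the trailing paddings are zero-filled
  // in the same way.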
   4354   if (left_b_padding != 0) {
   4355     memset(output_data, 0,
   4356            left_b_padding * output_height * output_width * output_depth *
   4357                sizeof(T));
   4358   }
   4359   for (int out_b = left_b_padding; out_b < output_batch - right_b_padding;
   4360        ++out_b) {
   4361     if (left_h_padding != 0) {
   4362       memset(output_data + Offset(output_dims, 0, 0, 0, out_b), 0,
   4363              left_h_padding * output_width * output_depth * sizeof(T));
   4364     }
   4365     for (int out_h = left_h_padding; out_h < output_height - right_h_padding;
   4366          ++out_h) {
   4367       if (left_w_padding != 0) {
   4368         memset(output_data + Offset(output_dims, 0, 0, out_h, out_b), 0,
   4369                left_w_padding * output_depth * sizeof(T));
   4370       }
   4371       for (int out_w = left_w_padding; out_w < output_width - right_w_padding;
   4372            ++out_w) {
   4373         if (left_d_padding != 0) {
   4374           memset(output_data + Offset(output_dims, 0, out_w, out_h, out_b), 0,
   4375                  left_d_padding * sizeof(T));
   4376         }
   4377 
   4378         T* out = output_data +
   4379                  Offset(output_dims, left_d_padding, out_w, out_h, out_b);
   4380         const T* in =
   4381             input_data + Offset(input_dims, 0, out_w - left_w_padding,
   4382                                 out_h - left_h_padding, out_b - left_b_padding);
   4383         memcpy(out, in, input_depth * sizeof(T));
   4384 
   4385         if (right_d_padding != 0) {
   4386           memset(
   4387               output_data + Offset(output_dims, output_depth - right_d_padding,
   4388                                    out_w, out_h, out_b),
   4389               0, right_d_padding * sizeof(T));
   4390         }
   4391       }
   4392       if (right_w_padding != 0) {
   4393         memset(
   4394             output_data + Offset(output_dims, 0, output_width - right_w_padding,
   4395                                  out_h, out_b),
   4396             0, right_w_padding * output_depth * sizeof(T));
   4397       }
   4398     }
   4399     if (right_h_padding != 0) {
   4400       memset(output_data + Offset(output_dims, 0, 0,
   4401                                   output_height - right_h_padding, out_b),
   4402              0, right_h_padding * output_width * output_depth * sizeof(T));
   4403     }
   4404   }
   4405   if (right_b_padding != 0) {
   4406     memset(output_data +
   4407                Offset(output_dims, 0, 0, 0, output_batch - right_b_padding),
   4408            0,
   4409            right_b_padding * output_height * output_width * output_depth *
   4410                sizeof(T));
   4411   }
   4412 }
   4413 
   4414 template <typename T>
   4415 inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
   4416                          int begin_mask, int end_mask,
   4417                          const std::vector<int>& starts,
   4418                          const std::vector<int>& stops,
   4419                          const std::vector<int>& strides, T* output_data,
   4420                          const Dims<4>& output_dims) {
   4421   gemmlowp::ScopedProfilingLabel label("StridedSlice");
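  // Bit i of begin_mask / end_mask corresponds to dimension i in the reversed
  // Dims<4> storage order (bit 0 = depth, bit 3 = batch). A set begin_mask bit
  // means "start at 0"; a set end_mask bit means "stop at the full extent of
  // that dimension".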
   4422   const int start_b = (begin_mask & 8) ? 0 : starts[3];
   4423   const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3];
   4424   const int start_h = (begin_mask & 4) ? 0 : starts[2];
   4425   const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2];
   4426   const int start_w = (begin_mask & 2) ? 0 : starts[1];
   4427   const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1];
   4428   const int start_d = (begin_mask & 1) ? 0 : starts[0];
   4429   const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0];
   4430 
   4431   T* out_ptr = output_data;
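  // A depth stride of 0 selects the dense path: the whole [start_d, stop_d)
  // channel range is copied with a single memcpy per (w, h, b) position
  // instead of being gathered element by element.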
   4432   if (strides[0] == 0) {
   4433     for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) {
   4434       for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) {
   4435         for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) {
   4436           const int len = stop_d - start_d;
   4437           memcpy(out_ptr,
   4438                  input_data + Offset(input_dims, start_d, in_w, in_h, in_b),
   4439                  len * sizeof(T));
   4440           out_ptr += len;
   4441         }
   4442       }
   4443     }
   4444   } else {
   4445     for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) {
   4446       for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) {
   4447         for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) {
   4448           for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) {
   4449             *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
   4450           }
   4451         }
   4452       }
   4453     }
   4454   }
   4455 }
   4456 
   4457 template <typename T>
   4458 inline void Slice(const T* input_data, const Dims<4>& input_dims,
   4459                   const std::vector<int>& begin, const std::vector<int>& size,
   4460                   T* output_data, const Dims<4>& output_dims) {
   4461   // TODO(dkalenichenko): This op only supports 4D tensors.
   4462   TFLITE_DCHECK_EQ(begin.size(), 4);
   4463   TFLITE_DCHECK_EQ(size.size(), 4);
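  // A size of -1 in a dimension means "all remaining elements" in that
  // dimension, i.e. the slice runs from begin up to the full extent.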
   4464   const int start_b = begin[3];
    4465   const int stop_b =
    4466       size[3] == -1 ? input_dims.sizes[3] : start_b + size[3];
    4467   const int start_h = begin[2];
    4468   const int stop_h =
    4469       size[2] == -1 ? input_dims.sizes[2] : start_h + size[2];
    4470   const int start_w = begin[1];
    4471   const int stop_w =
    4472       size[1] == -1 ? input_dims.sizes[1] : start_w + size[1];
    4473   const int start_d = begin[0];
    4474   const int stop_d =
    4475       size[0] == -1 ? input_dims.sizes[0] : start_d + size[0];
   4476 
   4477   T* out_ptr = output_data;
   4478   for (int in_b = start_b; in_b < stop_b; ++in_b) {
   4479     for (int in_h = start_h; in_h < stop_h; ++in_h) {
   4480       for (int in_w = start_w; in_w < stop_w; ++in_w) {
   4481         const int len = stop_d - start_d;
   4482         memcpy(out_ptr,
   4483                input_data + Offset(input_dims, start_d, in_w, in_h, in_b),
   4484                len * sizeof(T));
   4485         out_ptr += len;
   4486       }
   4487     }
   4488   }
   4489 }
   4490 
   4491 template <typename T>
   4492 inline void Mean(const T* input_data, const Dims<4>& input_dims,
   4493                  const std::vector<int>& reduction_indices, T* output_data,
   4494                  const Dims<4>& output_dims) {
   4495   gemmlowp::ScopedProfilingLabel label("Mean");
   4496   const int output_batch = ArraySize(output_dims, 3);
   4497   const int output_height = ArraySize(output_dims, 2);
   4498   const int output_width = ArraySize(output_dims, 1);
   4499   const int output_depth = ArraySize(output_dims, 0);
   4500 
   4501   const int input_height = ArraySize(input_dims, 2);
   4502   const int input_width = ArraySize(input_dims, 1);
   4503 
   4504   // The current implementation only supports simultaneous reduction over
   4505   // width and height.
   4506   TFLITE_DCHECK_EQ(reduction_indices.size(), 2);
   4507   TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) ||
   4508                 (reduction_indices[0] == 2 && reduction_indices[1] == 1));
   4509   TFLITE_DCHECK_EQ(output_height, 1);
   4510   TFLITE_DCHECK_EQ(output_width, 1);
   4511 
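  // Accumulation is performed in float regardless of T, and the mean is then
  // converted back to T on assignment.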
   4512   for (int out_b = 0; out_b < output_batch; ++out_b) {
   4513     for (int out_d = 0; out_d < output_depth; ++out_d) {
   4514       float value = 0;
   4515       for (int in_h = 0; in_h < input_height; ++in_h) {
   4516         for (int in_w = 0; in_w < input_width; ++in_w) {
   4517           value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)];
   4518         }
   4519       }
   4520       output_data[Offset(output_dims, out_d, 0, 0, out_b)] =
   4521           value / (input_width * input_height);
   4522     }
   4523   }
   4524 }
   4525 
   4526 template <typename T>
   4527 void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
   4528                          const T* input2_data, const Dims<4>& input2_dims,
   4529                          T* output_data, const Dims<4>& output_dims) {
   4530   gemmlowp::ScopedProfilingLabel label("GenericBroadcastSub");
   4531 
   4532   NdArrayDesc<4> desc1;
   4533   NdArrayDesc<4> desc2;
   4534   NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
   4535 
    4536   // In TensorFlow, the dimensions are canonically named (batch_number, row,
    4537   // col, channel), with extents (batches, height, width, depth), and with the
    4538   // trailing dimension changing most rapidly (the channel dimension has the
    4539   // smallest stride, typically 1 element).
    4540   //
    4541   // In generated C code, we store arrays with the dimensions reversed: the
    4542   // first dimension has the smallest stride.
    4543   //
    4544   // We name our variables by the TensorFlow convention, but generate C code
    4545   // nesting the loops such that the innermost loop has the smallest stride,
    4546   // for the best cache behavior.
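  // Broadcasting follows the usual elementwise rule: a dimension of extent 1
  // in either input is repeated to match the other input's extent, which is
  // what NdArrayDescsForElementwiseBroadcast encodes into desc1 and desc2.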
   4547   for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
   4548     for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
   4549       for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
   4550         for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
   4551           output_data[Offset(output_dims, c, x, y, b)] =
   4552               input1_data[SubscriptToIndex(desc1, c, x, y, b)] -
   4553               input2_data[SubscriptToIndex(desc2, c, x, y, b)];
   4554         }
   4555       }
   4556     }
   4557   }
   4558 }
   4559 
   4560 template <typename T>
   4561 void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
   4562          const Dims<4>& input2_dims, T* output_data,
   4563          const Dims<4>& output_dims) {
   4564   gemmlowp::ScopedProfilingLabel label("Sub");
   4565 
   4566   auto input1_map = MapAsVector(input1_data, input1_dims);
   4567   auto input2_map = MapAsVector(input2_data, input2_dims);
   4568   auto output_map = MapAsVector(output_data, output_dims);
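  // Fast paths: identical shapes use a single Eigen elementwise expression,
  // and a one-element operand on either side is treated as a scalar; all other
  // shape combinations fall back to GenericBroadcastSub.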
   4569   if (AreSameDims(input1_dims, input2_dims)) {
   4570     output_map.array() = input1_map.array() - input2_map.array();
   4571   } else if (RequiredBufferSizeForDims(input1_dims) == 1) {
   4572     auto scalar = input1_data[0];
   4573     output_map.array() = scalar - input2_map.array();
   4574   } else if (RequiredBufferSizeForDims(input2_dims) == 1) {
   4575     auto scalar = input2_data[0];
   4576     output_map.array() = input1_map.array() - scalar;
   4577   } else {
   4578     GenericBroadcastSub(input1_data, input1_dims, input2_data, input2_dims,
   4579                         output_data, output_dims);
   4580   }
   4581 }
   4582 
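// TensorFlowMinimum and TensorFlowMaximum below treat input2 as a scalar: only
// input2_data[0] is read, and it is applied elementwise against input1.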
   4583 template <typename T>
   4584 void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
   4585                        const T* input2_data, T* output_data,
   4586                        const Dims<4>& output_dims) {
   4587   gemmlowp::ScopedProfilingLabel label("TensorFlowMinimum");
   4588   auto input1_map = MapAsVector(input1_data, input1_dims);
   4589   auto output_map = MapAsVector(output_data, output_dims);
   4590   auto min_value = input2_data[0];
   4591   output_map.array() = input1_map.array().min(min_value);
   4592 }
   4593 
   4594 template <typename T>
   4595 void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
   4596                        const T* input2_data, T* output_data,
   4597                        const Dims<4>& output_dims) {
   4598   gemmlowp::ScopedProfilingLabel label("TensorFlowMaximum");
   4599   auto input1_map = MapAsVector(input1_data, input1_dims);
   4600   auto output_map = MapAsVector(output_data, output_dims);
   4601   auto max_value = input2_data[0];
   4602   output_map.array() = input1_map.array().max(max_value);
   4603 }
   4604 
   4605 template <typename T1, typename T2, typename T3>
   4606 void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
   4607             T2* output_data, const Dims<4>& output_dims) {
   4608   gemmlowp::ScopedProfilingLabel label("ArgMax");
   4609 
    4610   // The current ArgMax implementation can only determine the index of the
    4611   // maximum value in the last dimension, so the axis argument must be 3.
   4612   TFLITE_DCHECK_EQ(axis[0], 3);
   4613 
    4614   // Conceptually ArgMax removes the reduced dimension, so the output has one
    4615   // fewer dimension than the input. For simplicity, the output here keeps all
    4616   // four dimensions and the reduced (last) dimension is constrained to have
    4617   // size 1.
   4618   TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1);
   4619   const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
   4620   const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
   4621   const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
   4622   const int depth = ArraySize(input_dims, 0);
   4623   for (int b = 0; b < batches; ++b) {
   4624     for (int y = 0; y < height; ++y) {
   4625       for (int x = 0; x < width; ++x) {
   4626         auto max_value = input_data[Offset(input_dims, 0, x, y, b)];
   4627         int max_index = 0;
   4628         for (int d = 1; d < depth; ++d) {
   4629           const auto& curr_value = input_data[Offset(input_dims, d, x, y, b)];
   4630           if (curr_value > max_value) {
   4631             max_value = curr_value;
   4632             max_index = d;
   4633           }
   4634         }
   4635         output_data[Offset(output_dims, 0, x, y, b)] = max_index;
   4636       }
   4637     }
   4638   }
   4639 }
   4640 
   4641 }  // namespace optimized_ops
   4642 }  // namespace tflite
   4643 
   4644 #if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
   4645 #undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS
   4646 #pragma GCC diagnostic pop
   4647 #endif
   4648 
   4649 #endif  // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_
   4650