/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements a quantized version of the resize bilinear op.

#define EIGEN_USE_THREADS

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define USE_NEON
#define QUANTIZED_RESIZE_BILINEAR_USE_NEON
#include <arm_neon.h>
#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

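// When set to true, ResizeBilinear() below always uses the float reference
// path (ResizeImageReference) instead of the fixed-point fast paths; this is
// mainly useful for debugging or validating the optimized implementations.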
static constexpr bool USE_REFERENCE = false;

namespace {
// Compute the interpolation indices only once.
template <typename T_SCALE>
struct InterpolationCache {
  std::vector<int64> lower;  // Lower source index used in the interpolation
  std::vector<int64> upper;  // Upper source index used in the interpolation
  // 1-D linear interpolation scale (see:
  // https://en.wikipedia.org/wiki/Bilinear_interpolation)
  std::vector<float> lerp;
  std::vector<T_SCALE> ilerp;
};

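// For every output coordinate i, precomputes the two source indices that
// bracket the sample position in = i * scale, together with the fractional
// offset (in - lower) stored both as a float ('lerp') and as a fixed-point
// value scaled by 2^resolution ('ilerp').  The vectors are sized out_size + 1
// with the trailing lower/upper entries zeroed, leaving a sentinel slot past
// the last output column.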
template <typename T_SCALE>
inline void ComputeInterpolationWeights(
    const int64 out_size, const int64 in_size, const float scale,
    const int resolution, InterpolationCache<T_SCALE>* interpolation) {
  interpolation->lower.resize(out_size + 1);
  interpolation->upper.resize(out_size + 1);
  interpolation->lerp.resize(out_size + 1);
  interpolation->ilerp.resize(out_size + 1);

  interpolation->lower[out_size] = 0;
  interpolation->upper[out_size] = 0;
  for (int64 i = out_size - 1; i >= 0; --i) {
    const float in = i * scale;
    interpolation->lower[i] = static_cast<int64>(in);
    interpolation->upper[i] =
        std::min(interpolation->lower[i] + 1, in_size - 1);
    interpolation->lerp[i] = in - interpolation->lower[i];
    interpolation->ilerp[i] = static_cast<T_SCALE>(
        (in - interpolation->lower[i]) * (1 << resolution));
  }
}

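// Builds an InterpolationCache for one dimension.  index_step scales the
// cached lower/upper indices (typically by the channel count) so they can be
// used directly as element offsets into an interleaved row of pixels.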
template <typename T_SCALE>
inline InterpolationCache<T_SCALE> BuildLerpCache(const int64 out_size,
                                                  const int64 in_size,
                                                  const float scale,
                                                  const int index_step,
                                                  const int resolution) {
  InterpolationCache<T_SCALE> cache;
  // Compute the cached interpolation weights on the x and y dimensions.
  ComputeInterpolationWeights<T_SCALE>(out_size, in_size, scale, resolution,
                                       &cache);
  CHECK_GT(index_step, 0);
  if (index_step > 1) {
    for (int i = 0; i < cache.lower.size(); ++i) {
      cache.lower[i] *= index_step;
      cache.upper[i] *= index_step;
    }
  }
  return cache;
}

/**
 * Computes the bilinear interpolation from the appropriate 4 float points
 * and the linear interpolation weights.
 */
template <typename T>
inline T ComputeLerpReference(const T in_top_left, const T in_top_right,
                              const T in_bottom_left, const T in_bottom_right,
                              const float x_lerp, const float y_lerp,
                              const float min, const float max) {
  const float top_left = QuantizedToFloat<T>(in_top_left, min, max);
  const float top_right = QuantizedToFloat<T>(in_top_right, min, max);
  const float bottom_left = QuantizedToFloat<T>(in_bottom_left, min, max);
  const float bottom_right = QuantizedToFloat<T>(in_bottom_right, min, max);
  const float top = top_left + (top_right - top_left) * x_lerp;
  const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
  const float out = top + (bottom - top) * y_lerp;
  return FloatToQuantized<T>(out, min, max);
}

template <typename T, typename T_SCALE, typename T_CALC>
inline T_CALC MulOffset(T a, T b, T_SCALE c) {
  return (static_cast<T_CALC>(a) - static_cast<T_CALC>(b)) *
         static_cast<T_CALC>(c);
}

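// Fixed-point version of ComputeLerpReference that works directly on the
// quantized values: x_lerp and y_lerp are fractions scaled by 2^RESOLUTION,
// the horizontal lerps are kept at 2^RESOLUTION times their value, and the
// final result is rounded back down by dividing by 2^RESOLUTION.
// For example, with RESOLUTION = 7 (so RESOLUTION_MULT = 128), top_left = 0,
// top_right = 2, x_lerp = 64 (i.e. 0.5) and y_lerp = 0, top evaluates to
// 0 * 128 + 2 * 64 = 128 and the rounded output is (128 + 64) / 128 = 1.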
template <int RESOLUTION, typename T, typename T_SCALE, typename T_CALC>
inline T ComputeLerp(const T top_left, const T top_right, const T bottom_left,
                     const T bottom_right, const T_SCALE x_lerp,
                     const T_SCALE y_lerp) {
  constexpr T_CALC RESOLUTION_MULT = (1 << RESOLUTION);
  const T_CALC top = static_cast<T_CALC>(top_left) * RESOLUTION_MULT +
                     MulOffset<T, T_SCALE, T_CALC>(top_right, top_left, x_lerp);
  const T_CALC bottom =
      static_cast<T_CALC>(bottom_left) * RESOLUTION_MULT +
      MulOffset<T, T_SCALE, T_CALC>(bottom_right, bottom_left, x_lerp);
  const T_CALC out = top + (bottom - top) / RESOLUTION_MULT * y_lerp;
  return static_cast<T>(
      static_cast<int32>((out + RESOLUTION_MULT / 2) / RESOLUTION_MULT));
}

#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
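// The helpers below gather scattered scalar values into NEON registers, one
// lane at a time, so that the lerps for several output pixels (or channels)
// can be computed in a single vector operation.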
inline uint8x8_t ToUint8x8(const quint8* v0, const quint8* v1, const quint8* v2,
                           const quint8* v3, const quint8* v4, const quint8* v5,
                           const quint8* v6, const quint8* v7) {
  static const uint8x8_t ZERO_8x8 = vmov_n_u8(0);
  uint8x8_t ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v0), ZERO_8x8, 0);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v1), ret, 1);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v2), ret, 2);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v3), ret, 3);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v4), ret, 4);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v5), ret, 5);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v6), ret, 6);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v7), ret, 7);
  return ret;
}

inline int16x8_t ToInt16x8(const int16* v0, const int16* v1, const int16* v2,
                           const int16* v3, const int16* v4, const int16* v5,
                           const int16* v6, const int16* v7) {
  static const int16x8_t ZERO_16x8 = vmovq_n_s16(0);
  int16x8_t ret = vld1q_lane_s16(v0, ZERO_16x8, 0);
  ret = vld1q_lane_s16(v1, ret, 1);
  ret = vld1q_lane_s16(v2, ret, 2);
  ret = vld1q_lane_s16(v3, ret, 3);
  ret = vld1q_lane_s16(v4, ret, 4);
  ret = vld1q_lane_s16(v5, ret, 5);
  ret = vld1q_lane_s16(v6, ret, 6);
  ret = vld1q_lane_s16(v7, ret, 7);
  return ret;
}

inline int32x2_t ToInt32x2(const qint32* v0, const qint32* v1) {
  static const int32x2_t ZERO_32x2 = vmov_n_s32(0);
  const int32x2_t ret0 =
      vld1_lane_s32(reinterpret_cast<const int32*>(v0), ZERO_32x2, 0);
  const int32x2_t ret1 =
      vld1_lane_s32(reinterpret_cast<const int32*>(v1), ret0, 1);
  return ret1;
}

template <int RESOLUTION, bool X_LERP_SAME>
inline int32x2_t ComputeLerpx2(
    const qint32* top_left0, const qint32* top_right0,
    const qint32* bottom_left0, const qint32* bottom_right0,
    const qint32* top_left1, const qint32* top_right1,
    const qint32* bottom_left1, const qint32* bottom_right1,
    const int32* x_lerp, const int32x2_t y_lerpsx) {
  const int32x2_t x_lerpsx =
      X_LERP_SAME ? vld1_dup_s32(reinterpret_cast<const int32*>(x_lerp))
                  : vld1_s32(reinterpret_cast<const int32*>(x_lerp));

  const int32x2_t top_leftsx = ToInt32x2(top_left0, top_left1);
  const int32x2_t top_rightsx = ToInt32x2(top_right0, top_right1);
  const int32x2_t bottom_leftsx = ToInt32x2(bottom_left0, bottom_left1);
  const int32x2_t bottom_rightsx = ToInt32x2(bottom_right0, bottom_right1);

  const int32x2_t retval =
      ComputeLerp32x2<RESOLUTION>(top_leftsx, top_rightsx, bottom_leftsx,
                                  bottom_rightsx, x_lerpsx, y_lerpsx);
  return retval;
}

template <int RESOLUTION>
inline uint8x8_t ComputeLerpx8(
    const quint8* tl0, const quint8* tr0, const quint8* bl0, const quint8* br0,
    const int16* xlp0, const quint8* tl1, const quint8* tr1, const quint8* bl1,
    const quint8* br1, const int16* xlp1, const quint8* tl2, const quint8* tr2,
    const quint8* bl2, const quint8* br2, const int16* xlp2, const quint8* tl3,
    const quint8* tr3, const quint8* bl3, const quint8* br3, const int16* xlp3,
    const quint8* tl4, const quint8* tr4, const quint8* bl4, const quint8* br4,
    const int16* xlp4, const quint8* tl5, const quint8* tr5, const quint8* bl5,
    const quint8* br5, const int16* xlp5, const quint8* tl6, const quint8* tr6,
    const quint8* bl6, const quint8* br6, const int16* xlp6, const quint8* tl7,
    const quint8* tr7, const quint8* bl7, const quint8* br7, const int16* xlp7,
    const int16x8_t ys_lerpsx) {
  const uint8x8_t tl8x8 = ToUint8x8(tl0, tl1, tl2, tl3, tl4, tl5, tl6, tl7);
  const uint8x8_t tr8x8 = ToUint8x8(tr0, tr1, tr2, tr3, tr4, tr5, tr6, tr7);
  const uint8x8_t bl8x8 = ToUint8x8(bl0, bl1, bl2, bl3, bl4, bl5, bl6, bl7);
  const uint8x8_t br8x8 = ToUint8x8(br0, br1, br2, br3, br4, br5, br6, br7);
  const int16x8_t xs_lerpsx =
      ToInt16x8(xlp0, xlp1, xlp2, xlp3, xlp4, xlp5, xlp6, xlp7);
  return ComputeLerp8x8<RESOLUTION>(tl8x8, tr8x8, bl8x8, br8x8, xs_lerpsx,
                                    ys_lerpsx);
}

// Expand addresses at compile time to improve performance.
template <int RESOLUTION, int ID0, int CH0, int ID1, int CH1, int ID2, int CH2,
          int ID3, int CH3, int ID4, int CH4, int ID5, int CH5, int ID6,
          int CH6, int ID7, int CH7>
inline uint8x8_t ComputeLerpx8Tmpl(const quint8* const yl, const quint8* yu,
                                   const int64* xl, const int64* xu,
                                   const int16* xlp,
                                   const int16x8_t ys_lerpsx) {
  return ComputeLerpx8<RESOLUTION>(
      yl + xl[ID0] + CH0, yl + xu[ID0] + CH0, yu + xl[ID0] + CH0,
      yu + xu[ID0] + CH0, xlp + ID0, yl + xl[ID1] + CH1, yl + xu[ID1] + CH1,
      yu + xl[ID1] + CH1, yu + xu[ID1] + CH1, xlp + ID1, yl + xl[ID2] + CH2,
      yl + xu[ID2] + CH2, yu + xl[ID2] + CH2, yu + xu[ID2] + CH2, xlp + ID2,
      yl + xl[ID3] + CH3, yl + xu[ID3] + CH3, yu + xl[ID3] + CH3,
      yu + xu[ID3] + CH3, xlp + ID3, yl + xl[ID4] + CH4, yl + xu[ID4] + CH4,
      yu + xl[ID4] + CH4, yu + xu[ID4] + CH4, xlp + ID4, yl + xl[ID5] + CH5,
      yl + xu[ID5] + CH5, yu + xl[ID5] + CH5, yu + xu[ID5] + CH5, xlp + ID5,
      yl + xl[ID6] + CH6, yl + xu[ID6] + CH6, yu + xl[ID6] + CH6,
      yu + xu[ID6] + CH6, xlp + ID6, yl + xl[ID7] + CH7, yl + xu[ID7] + CH7,
      yu + xl[ID7] + CH7, yu + xu[ID7] + CH7, xlp + ID7, ys_lerpsx);
}

#endif

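// Computes the fixed-point lerp for every channel of one output column x and
// writes the results to output_y_ptr.  Used for the non-NEON build, for
// generic channel counts, and for the tail columns after the vectorized
// loops.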
template <int RESOLUTION, typename T, typename T_SCALE, typename T_CALC>
inline void OutputLerpForChannels(const InterpolationCache<T_SCALE>& xs,
                                  const int64 x, const T_SCALE ys_ilerp,
                                  const int channels, const float min,
                                  const float max, const T* ys_input_lower_ptr,
                                  const T* ys_input_upper_ptr,
                                  T* output_y_ptr) {
  const int64 xs_lower = xs.lower[x];
  const int64 xs_upper = xs.upper[x];
  const T_SCALE xs_ilerp = xs.ilerp[x];
  for (int c = 0; c < channels; ++c) {
    const T top_left = ys_input_lower_ptr[xs_lower + c];
    const T top_right = ys_input_lower_ptr[xs_upper + c];
    const T bottom_left = ys_input_upper_ptr[xs_lower + c];
    const T bottom_right = ys_input_upper_ptr[xs_upper + c];
    const T val = ComputeLerp<RESOLUTION, T, T_SCALE, T_CALC>(
        top_left, top_right, bottom_left, bottom_right, xs_ilerp, ys_ilerp);
    output_y_ptr[x * channels + c] = val;
  }
}

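// Produces 8 consecutive single-channel quint8 outputs starting at column
// x_start with one NEON 8-lane lerp (or the scalar fallback when NEON is
// unavailable).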
template <int RES>
inline void OutputLerp8x8x1(const InterpolationCache<int16>& xs,
                            const int64 x_start, const int16 ys_ilerp,
                            const float min, const float max,
                            const quint8* const ys_input_lower_ptr,
                            const quint8* const ys_input_upper_ptr,
                            quint8* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int16x8_t y_lerpsx = vmovq_n_s16(ys_ilerp);

  const uint8x8_t x0x7 =
      ComputeLerpx8Tmpl<RES, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start), x0x7);

#else
  for (int x = x_start; x < x_start + 8; ++x) {
    OutputLerpForChannels<RES, quint8, int16, int16>(
        xs, x, ys_ilerp, 1, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

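// Produces 8 consecutive 3-channel quint8 output pixels (24 values) starting
// at column x_start.  The 24 interleaved (pixel, channel) outputs are split
// into three 8-lane NEON lerps via the ID/CH template arguments of
// ComputeLerpx8Tmpl, e.g. the first call covers x0c0 through x2c1.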
template <int RES>
inline void OutputLerp8x8x3(const InterpolationCache<int16>& xs,
                            const int64 x_start, const int16 ys_ilerp,
                            const float min, const float max,
                            const quint8* const ys_input_lower_ptr,
                            const quint8* const ys_input_upper_ptr,
                            quint8* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int16x8_t y_lerpsx = vmovq_n_s16(ys_ilerp);

  const uint8x8_t x0c0x2c1 =
      ComputeLerpx8Tmpl<RES, 0, 0, 0, 1, 0, 2, 1, 0, 1, 1, 1, 2, 2, 0, 2, 1>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3), x0c0x2c1);

  const uint8x8_t x2c2x5c0 =
      ComputeLerpx8Tmpl<RES, 2, 2, 3, 0, 3, 1, 3, 2, 4, 0, 4, 1, 4, 2, 5, 0>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3 + 8), x2c2x5c0);

  const uint8x8_t x5c1x7c2 =
      ComputeLerpx8Tmpl<RES, 5, 1, 5, 2, 6, 0, 6, 1, 6, 2, 7, 0, 7, 1, 7, 2>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3 + 16),
          x5c1x7c2);

#else
  for (int x = x_start; x < x_start + 8; ++x) {
    OutputLerpForChannels<RES, quint8, int16, int16>(
        xs, x, ys_ilerp, 3, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

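// Produces 4 consecutive single-channel qint32 outputs starting at column
// x_start: two 2-lane lerps are combined into one 4-lane store (or the scalar
// fallback is used when NEON is unavailable).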
template <int RESOLUTION>
inline void OutputLerp32x4x1(const InterpolationCache<int32>& xs,
                             const int64 x_start, const int32 ys_ilerp,
                             const float min, const float max,
                             const qint32* const ys_input_lower_ptr,
                             const qint32* const ys_input_upper_ptr,
                             qint32* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int64 xs_lower0 = xs.lower[x_start];
  const int64 xs_upper0 = xs.upper[x_start];
  const int32* const xs_ilerp0 = &xs.ilerp[x_start];
  const int64 xs_lower1 = xs.lower[x_start + 1];
  const int64 xs_upper1 = xs.upper[x_start + 1];
  const int64 xs_lower2 = xs.lower[x_start + 2];
  const int64 xs_upper2 = xs.upper[x_start + 2];
  const int32* const xs_ilerp2 = &xs.ilerp[x_start + 2];
  const int64 xs_lower3 = xs.lower[x_start + 3];
  const int64 xs_upper3 = xs.upper[x_start + 3];

  const int32x2_t y_lerpsx = vmov_n_s32(ys_ilerp);

  const int32x2_t x0x1 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower0, ys_input_lower_ptr + xs_upper0,
      ys_input_upper_ptr + xs_lower0, ys_input_upper_ptr + xs_upper0,
      ys_input_lower_ptr + xs_lower1, ys_input_lower_ptr + xs_upper1,
      ys_input_upper_ptr + xs_lower1, ys_input_upper_ptr + xs_upper1, xs_ilerp0,
      y_lerpsx);

  const int32x2_t x2x3 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower2, ys_input_lower_ptr + xs_upper2,
      ys_input_upper_ptr + xs_lower2, ys_input_upper_ptr + xs_upper2,
      ys_input_lower_ptr + xs_lower3, ys_input_lower_ptr + xs_upper3,
      ys_input_upper_ptr + xs_lower3, ys_input_upper_ptr + xs_upper3, xs_ilerp2,
      y_lerpsx);

  const int32x4_t x0x1x2x3 = vcombine_s32(x0x1, x2x3);

  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start), x0x1x2x3);

#else
  for (int x = x_start; x < x_start + 4; ++x) {
    OutputLerpForChannels<RESOLUTION, qint32, int32, int64>(
        xs, x, ys_ilerp, 1, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

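// Produces 4 consecutive 3-channel qint32 output pixels (12 values) starting
// at column x_start: six 2-lane lerps are combined into three 4-lane stores.
// Calls whose two lanes belong to the same output pixel reuse one x_lerp
// (X_LERP_SAME = true); calls that straddle two output pixels load two
// different x_lerps.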
template <int RESOLUTION>
inline void OutputLerp32x4x3(const InterpolationCache<int32>& xs,
                             const int64 x_start, const int32 ys_ilerp,
                             const float min, const float max,
                             const qint32* const ys_input_lower_ptr,
                             const qint32* const ys_input_upper_ptr,
                             qint32* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int64 xs_lower0 = xs.lower[x_start];
  const int64 xs_upper0 = xs.upper[x_start];
  const int32* const xs_ilerp0 = &xs.ilerp[x_start];
  const int64 xs_lower1 = xs.lower[x_start + 1];
  const int64 xs_upper1 = xs.upper[x_start + 1];
  const int32* const xs_ilerp1 = &xs.ilerp[x_start + 1];
  const int64 xs_lower2 = xs.lower[x_start + 2];
  const int64 xs_upper2 = xs.upper[x_start + 2];
  const int32* const xs_ilerp2 = &xs.ilerp[x_start + 2];
  const int64 xs_lower3 = xs.lower[x_start + 3];
  const int64 xs_upper3 = xs.upper[x_start + 3];
  const int32* const xs_ilerp3 = &xs.ilerp[x_start + 3];

  const int32x2_t y_lerpsx = vmov_n_s32(ys_ilerp);

  const int32x2_t x0c0x0c1 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower0, ys_input_lower_ptr + xs_upper0,
      ys_input_upper_ptr + xs_lower0, ys_input_upper_ptr + xs_upper0,
      ys_input_lower_ptr + xs_lower0 + 1, ys_input_lower_ptr + xs_upper0 + 1,
      ys_input_upper_ptr + xs_lower0 + 1, ys_input_upper_ptr + xs_upper0 + 1,
      xs_ilerp0, y_lerpsx);

  const int32x2_t x0c2x1c0 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower0 + 2, ys_input_lower_ptr + xs_upper0 + 2,
      ys_input_upper_ptr + xs_lower0 + 2, ys_input_upper_ptr + xs_upper0 + 2,
      ys_input_lower_ptr + xs_lower1, ys_input_lower_ptr + xs_upper1,
      ys_input_upper_ptr + xs_lower1, ys_input_upper_ptr + xs_upper1, xs_ilerp0,
      y_lerpsx);

  const int32x2_t x1c1x1c2 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower1 + 1, ys_input_lower_ptr + xs_upper1 + 1,
      ys_input_upper_ptr + xs_lower1 + 1, ys_input_upper_ptr + xs_upper1 + 1,
      ys_input_lower_ptr + xs_lower1 + 2, ys_input_lower_ptr + xs_upper1 + 2,
      ys_input_upper_ptr + xs_lower1 + 2, ys_input_upper_ptr + xs_upper1 + 2,
      xs_ilerp1, y_lerpsx);

  const int32x2_t x2c0x2c1 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower2, ys_input_lower_ptr + xs_upper2,
      ys_input_upper_ptr + xs_lower2, ys_input_upper_ptr + xs_upper2,
      ys_input_lower_ptr + xs_lower2 + 1, ys_input_lower_ptr + xs_upper2 + 1,
      ys_input_upper_ptr + xs_lower2 + 1, ys_input_upper_ptr + xs_upper2 + 1,
      xs_ilerp2, y_lerpsx);

  const int32x2_t x2c2x3c0 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower2 + 2, ys_input_lower_ptr + xs_upper2 + 2,
      ys_input_upper_ptr + xs_lower2 + 2, ys_input_upper_ptr + xs_upper2 + 2,
      ys_input_lower_ptr + xs_lower3, ys_input_lower_ptr + xs_upper3,
      ys_input_upper_ptr + xs_lower3, ys_input_upper_ptr + xs_upper3, xs_ilerp2,
      y_lerpsx);

  const int32x2_t x3c1x3c2 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower3 + 1, ys_input_lower_ptr + xs_upper3 + 1,
      ys_input_upper_ptr + xs_lower3 + 1, ys_input_upper_ptr + xs_upper3 + 1,
      ys_input_lower_ptr + xs_lower3 + 2, ys_input_lower_ptr + xs_upper3 + 2,
      ys_input_upper_ptr + xs_lower3 + 2, ys_input_upper_ptr + xs_upper3 + 2,
      xs_ilerp3, y_lerpsx);

  const int32x4_t x0c0x0c1x0c2x1c0 = vcombine_s32(x0c0x0c1, x0c2x1c0);
  const int32x4_t x1c1x1c2x2c0x2c1 = vcombine_s32(x1c1x1c2, x2c0x2c1);
  const int32x4_t x2c2x3c0x3c1x3c2 = vcombine_s32(x2c2x3c0, x3c1x3c2);

  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3),
            x0c0x0c1x0c2x1c0);
  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3 + 4),
            x1c1x1c2x2c0x2c1);
  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3 + 8),
            x2c2x3c0x3c1x3c2);

#else
  for (int x = x_start; x < x_start + 4; ++x) {
    OutputLerpForChannels<RESOLUTION, qint32, int32, int64>(
        xs, x, ys_ilerp, 3, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

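// Reference implementation of the resize: for every output pixel the four
// neighbouring quantized inputs are dequantized to float, bilinearly
// interpolated, and requantized.  This path is straightforward but slow; it
// is used when USE_REFERENCE is set and as the generic fallback in
// ResizeImage().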
template <typename T>
void ResizeImageReference(typename TTypes<T, 4>::ConstTensor images,
                          const int batch_size, const int64 in_height,
                          const int64 in_width, const int64 out_height,
                          const int64 out_width, const int channels,
                          const float height_scale, const float width_scale,
                          const float in_min, const float in_max,
                          typename TTypes<T, 4>::Tensor* output) {
  CHECK_NOTNULL(output);

  const InterpolationCache<float> xs =
      BuildLerpCache<float>(out_width, in_width, width_scale, channels, 0);
  const InterpolationCache<float> ys =
      BuildLerpCache<float>(out_height, in_height, height_scale, 1, 0);

  const int64 in_row_size = in_width * channels;
  const int64 in_batch_num_values = in_height * in_row_size;
  const int64 out_row_size = out_width * channels;

  const T* input_b_ptr = images.data();

  T* output_y_ptr = output->data();
  for (int b = 0; b < batch_size; ++b) {
    for (int64 y = 0; y < out_height; ++y) {
      const T* ys_input_lower_ptr = input_b_ptr + ys.lower[y] * in_row_size;
      const T* ys_input_upper_ptr = input_b_ptr + ys.upper[y] * in_row_size;
      const float ys_lerp = ys.lerp[y];
      for (int64 x = 0; x < out_width; ++x) {
        const int64 xs_lower = xs.lower[x];
        const int64 xs_upper = xs.upper[x];
        const float xs_lerp = xs.lerp[x];
        for (int c = 0; c < channels; ++c) {
          const T top_left = ys_input_lower_ptr[xs_lower + c];
          const T top_right = ys_input_lower_ptr[xs_upper + c];
          const T bottom_left = ys_input_upper_ptr[xs_lower + c];
          const T bottom_right = ys_input_upper_ptr[xs_upper + c];
          const T val = ComputeLerpReference<T>(
              top_left, top_right, bottom_left, bottom_right, xs_lerp, ys_lerp,
              in_min, in_max);
          output_y_ptr[x * channels + c] = val;
        }
      }
      output_y_ptr += out_row_size;
    }
    input_b_ptr += in_batch_num_values;
  }
}

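// The generic ResizeImage simply forwards to the reference implementation;
// the specializations below provide fixed-point fast paths for qint32 and
// quint8.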
template <typename T>
void ResizeImage(typename TTypes<T, 4>::ConstTensor images,
                 const int batch_size, const int64 in_height,
                 const int64 in_width, const int64 out_height,
                 const int64 out_width, const int channels,
                 const float height_scale, const float width_scale,
                 const float in_min, const float in_max,
                 typename TTypes<T, 4>::Tensor* output) {
  ResizeImageReference<T>(images, batch_size, in_height, in_width, out_height,
                          out_width, channels, height_scale, width_scale,
                          in_min, in_max, output);
}

template <>
void ResizeImage<qint32>(typename TTypes<qint32, 4>::ConstTensor images,
                         const int batch_size, const int64 in_height,
                         const int64 in_width, const int64 out_height,
                         const int64 out_width, const int channels,
                         const float height_scale, const float width_scale,
                         const float in_min, const float in_max,
                         typename TTypes<qint32, 4>::Tensor* output) {
  // 30 is the maximum resolution for a signed 32-bit integer.
  constexpr int RESOLUTION = 30;
  constexpr int SIMD_STEP = 4;

  CHECK_NOTNULL(output);

  const InterpolationCache<int32> xs = BuildLerpCache<int32>(
      out_width, in_width, width_scale, channels, RESOLUTION);
  const InterpolationCache<int32> ys =
      BuildLerpCache<int32>(out_height, in_height, height_scale, 1, RESOLUTION);

  const int64 in_row_size = in_width * channels;
  const int64 in_batch_num_values = in_height * in_row_size;
  const int64 out_row_size = out_width * channels;

  const qint32* input_b_ptr = images.data();

  qint32* output_y_ptr = output->data();

  for (int b = 0; b < batch_size; ++b) {
    for (int64 y = 0; y < out_height; ++y) {
      const qint32* ys_input_lower_ptr =
          input_b_ptr + ys.lower[y] * in_row_size;
      const qint32* ys_input_upper_ptr =
          input_b_ptr + ys.upper[y] * in_row_size;
      const int32 ys_ilerp = ys.ilerp[y];
      // Optimized for channels == 1 and channels == 3, the most common
      // channel counts.
      int64 x = 0;
      if (channels == 1) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp32x4x1<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                       ys_input_lower_ptr, ys_input_upper_ptr,
                                       output_y_ptr);
        }
      } else if (channels == 3) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp32x4x3<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                       ys_input_lower_ptr, ys_input_upper_ptr,
                                       output_y_ptr);
        }
      }
      for (; x < out_width; ++x) {
        OutputLerpForChannels<RESOLUTION, qint32, int32, int64>(
            xs, x, ys_ilerp, channels, in_min, in_max, ys_input_lower_ptr,
            ys_input_upper_ptr, output_y_ptr);
      }
      output_y_ptr += out_row_size;
    }
    input_b_ptr += in_batch_num_values;
  }
}

template <>
void ResizeImage<quint8>(typename TTypes<quint8, 4>::ConstTensor images,
                         const int batch_size, const int64 in_height,
                         const int64 in_width, const int64 out_height,
                         const int64 out_width, const int channels,
                         const float height_scale, const float width_scale,
                         const float in_min, const float in_max,
                         typename TTypes<quint8, 4>::Tensor* output) {
  // 7 is the maximum resolution for an unsigned byte.
  constexpr int RESOLUTION = 7;
  constexpr int SIMD_STEP = 8;

  CHECK_NOTNULL(output);

  const InterpolationCache<int16> xs = BuildLerpCache<int16>(
      out_width, in_width, width_scale, channels, RESOLUTION);
  const InterpolationCache<int16> ys =
      BuildLerpCache<int16>(out_height, in_height, height_scale, 1, RESOLUTION);

  const int64 in_row_size = in_width * channels;
  const int64 in_batch_num_values = in_height * in_row_size;
  const int64 out_row_size = out_width * channels;

  const quint8* input_b_ptr = images.data();

  quint8* output_y_ptr = output->data();

  for (int b = 0; b < batch_size; ++b) {
    for (int64 y = 0; y < out_height; ++y) {
      const quint8* ys_input_lower_ptr =
          input_b_ptr + ys.lower[y] * in_row_size;
      const quint8* ys_input_upper_ptr =
          input_b_ptr + ys.upper[y] * in_row_size;
      const int32 ys_ilerp = ys.ilerp[y];
      // Optimized for channels == 1 and channels == 3, the most common
      // channel counts.
      // TODO(satok): Support a more generic NEON-optimized implementation
      // for other channel counts.
      int64 x = 0;
      if (channels == 1) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp8x8x1<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                      ys_input_lower_ptr, ys_input_upper_ptr,
                                      output_y_ptr);
        }
      } else if (channels == 3) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp8x8x3<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                      ys_input_lower_ptr, ys_input_upper_ptr,
                                      output_y_ptr);
        }
      }
      for (; x < out_width; ++x) {
        OutputLerpForChannels<RESOLUTION, quint8, int16, int16>(
            xs, x, ys_ilerp, channels, in_min, in_max, ys_input_lower_ptr,
            ys_input_upper_ptr, output_y_ptr);
      }
      output_y_ptr += out_row_size;
    }
    input_b_ptr += in_batch_num_values;
  }
}

template <typename T>
void ResizeBilinear(const typename TTypes<T, 4>::ConstTensor& images,
                    const float height_scale, const float width_scale,
                    const float in_min, const float in_max,
                    typename TTypes<T, 4>::Tensor* output) {
  CHECK_NOTNULL(output);

  const int batch_size = images.dimension(0);
  const int64 in_height = images.dimension(1);
  const int64 in_width = images.dimension(2);
  const int channels = images.dimension(3);

  const int64 out_height = output->dimension(1);
  const int64 out_width = output->dimension(2);

  // Handle no-op resizes efficiently.
  if (out_height == in_height && out_width == in_width) {
    *output = images.template cast<T>();
    return;
  }

  if (USE_REFERENCE) {
    ResizeImageReference<T>(images, batch_size, in_height, in_width, out_height,
                            out_width, channels, height_scale, width_scale,
                            in_min, in_max, output);
  } else {
    ResizeImage<T>(images, batch_size, in_height, in_width, out_height,
                   out_width, channels, height_scale, width_scale, in_min,
                   in_max, output);
  }
}

}  // namespace

template <class T>
class QuantizedResizeBilinearOp : public OpKernel {
 public:
  explicit QuantizedResizeBilinearOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
  }

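  // Inputs: 0 = quantized image tensor, 2 and 3 = scalar min/max of the
  // quantized range (the target size input is validated by
  // ImageResizerState).  Outputs: 0 = resized image, 1 and 2 = the unchanged
  // min/max, since resizing does not alter the quantization range.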
  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const float in_min = context->input(2).flat<float>()(0);
    const float in_max = context->input(3).flat<float>()(0);

    ImageResizerState st(align_corners_);
    st.ValidateAndCreateOutput(context, input);

    if (!context->status().ok()) return;

    // Return if the output is empty.
    if (st.output->NumElements() == 0) return;

    typename TTypes<T, 4>::ConstTensor image_data(input.tensor<T, 4>());
    typename TTypes<T, 4>::Tensor output_data(st.output->tensor<T, 4>());

    ResizeBilinear<T>(image_data, st.height_scale, st.width_scale, in_min,
                      in_max, &output_data);
    Tensor* out_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &out_min));
    out_min->flat<float>()(0) = in_min;

    Tensor* out_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &out_max));
    out_max->flat<float>()(0) = in_max;
  }

 private:
  bool align_corners_;

  TF_DISALLOW_COPY_AND_ASSIGN(QuantizedResizeBilinearOp<T>);
};

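// Registers the CPU kernel with the "size" input pinned to host memory, so
// the target dimensions can be read directly when computing the output shape.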
#define REGISTER_CPU_KERNEL(type)                         \
  REGISTER_KERNEL_BUILDER(Name("QuantizedResizeBilinear") \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("size")         \
                              .TypeConstraint<type>("T"), \
                          QuantizedResizeBilinearOp<type>)

REGISTER_CPU_KERNEL(::tensorflow::quint8);
REGISTER_CPU_KERNEL(::tensorflow::qint32);
REGISTER_CPU_KERNEL(float);

}  // namespace tensorflow