/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements a quantized version of the resize bilinear op.

#define EIGEN_USE_THREADS

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define USE_NEON
#define QUANTIZED_RESIZE_BILINEAR_USE_NEON
#include <arm_neon.h>
#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/image_resizer_state.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

static constexpr bool USE_REFERENCE = false;

namespace {
// Compute the interpolation indices only once.
template <typename T_SCALE>
struct InterpolationCache {
  std::vector<int64> lower;  // Lower source index used in the interpolation
  std::vector<int64> upper;  // Upper source index used in the interpolation
  // 1-D linear interpolation scale (see:
  // https://en.wikipedia.org/wiki/Bilinear_interpolation)
  std::vector<float> lerp;
  std::vector<T_SCALE> ilerp;
};

template <typename T_SCALE>
inline void ComputeInterpolationWeights(
    const int64 out_size, const int64 in_size, const float scale,
    const int resolution, InterpolationCache<T_SCALE>* interpolation) {
  interpolation->lower.resize(out_size + 1);
  interpolation->upper.resize(out_size + 1);
  interpolation->lerp.resize(out_size + 1);
  interpolation->ilerp.resize(out_size + 1);

  interpolation->lower[out_size] = 0;
  interpolation->upper[out_size] = 0;
  for (int64 i = out_size - 1; i >= 0; --i) {
    const float in = i * scale;
    interpolation->lower[i] = static_cast<int64>(in);
    interpolation->upper[i] =
        std::min(interpolation->lower[i] + 1, in_size - 1);
    interpolation->lerp[i] = in - interpolation->lower[i];
    interpolation->ilerp[i] = static_cast<T_SCALE>(
        (in - interpolation->lower[i]) * (1 << resolution));
  }
}

template <typename T_SCALE>
inline InterpolationCache<T_SCALE> BuildLerpCache(const int64 out_size,
                                                  const int64 in_size,
                                                  const float scale,
                                                  const int index_step,
                                                  const int resolution) {
  InterpolationCache<T_SCALE> cache;
  // Compute the cached interpolation weights on the x and y dimensions.
  ComputeInterpolationWeights<T_SCALE>(out_size, in_size, scale, resolution,
                                       &cache);
  CHECK(index_step > 0);
  if (index_step > 1) {
    for (int i = 0; i < cache.lower.size(); ++i) {
      cache.lower[i] *= index_step;
      cache.upper[i] *= index_step;
    }
  }
  return cache;
}

/**
 * Computes the bilinear interpolation from the appropriate 4 float points
 * and the linear interpolation weights.
 */
template <typename T>
inline T ComputeLerpReference(const T in_top_left, const T in_top_right,
                              const T in_bottom_left, const T in_bottom_right,
                              const float x_lerp, const float y_lerp,
                              const float min, const float max) {
  const float top_left = QuantizedToFloat<T>(in_top_left, min, max);
  const float top_right = QuantizedToFloat<T>(in_top_right, min, max);
  const float bottom_left = QuantizedToFloat<T>(in_bottom_left, min, max);
  const float bottom_right = QuantizedToFloat<T>(in_bottom_right, min, max);
  const float top = top_left + (top_right - top_left) * x_lerp;
  const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
  const float out = top + (bottom - top) * y_lerp;
  return FloatToQuantized<T>(out, min, max);
}

template <typename T, typename T_SCALE, typename T_CALC>
inline T_CALC MulOffset(T a, T b, T_SCALE c) {
  return (static_cast<T_CALC>(a) - static_cast<T_CALC>(b)) *
         static_cast<T_CALC>(c);
}

template <int RESOLUTION, typename T, typename T_SCALE, typename T_CALC>
inline T ComputeLerp(const T top_left, const T top_right, const T bottom_left,
                     const T bottom_right, const T_SCALE x_lerp,
                     const T_SCALE y_lerp) {
  constexpr T_CALC RESOLUTION_MULT = (1 << RESOLUTION);
  const T_CALC top = static_cast<T_CALC>(top_left) * RESOLUTION_MULT +
                     MulOffset<T, T_SCALE, T_CALC>(top_right, top_left, x_lerp);
  const T_CALC bottom =
      static_cast<T_CALC>(bottom_left) * RESOLUTION_MULT +
      MulOffset<T, T_SCALE, T_CALC>(bottom_right, bottom_left, x_lerp);
  const T_CALC out = top + (bottom - top) / RESOLUTION_MULT * y_lerp;
  return static_cast<T>(
      static_cast<int32>((out + RESOLUTION_MULT / 2) / RESOLUTION_MULT));
}

#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
inline uint8x8_t ToUint8x8(const quint8* v0, const quint8* v1, const quint8* v2,
                           const quint8* v3, const quint8* v4, const quint8* v5,
                           const quint8* v6, const quint8* v7) {
  static const uint8x8_t ZERO_8x8 = vmov_n_u8(0);
  uint8x8_t ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v0), ZERO_8x8, 0);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v1), ret, 1);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v2), ret, 2);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v3), ret, 3);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v4), ret, 4);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v5), ret, 5);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v6), ret, 6);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v7), ret, 7);
  return ret;
}

inline int16x8_t ToInt16x8(const int16* v0, const int16* v1, const int16* v2,
                           const int16* v3, const int16* v4, const int16* v5,
                           const int16* v6, const int16* v7) {
  static const int16x8_t ZERO_16x8 = vmovq_n_s16(0);
  int16x8_t ret = vld1q_lane_s16(v0, ZERO_16x8, 0);
  ret = vld1q_lane_s16(v1, ret, 1);
  ret = vld1q_lane_s16(v2, ret, 2);
  ret = vld1q_lane_s16(v3, ret, 3);
  ret = vld1q_lane_s16(v4, ret, 4);
  ret = vld1q_lane_s16(v5, ret, 5);
  ret = vld1q_lane_s16(v6, ret, 6);
  ret = vld1q_lane_s16(v7, ret, 7);
  return ret;
}

inline int32x2_t ToInt32x2(const qint32* v0, const qint32* v1) {
  static const int32x2_t ZERO_32x2 = vmov_n_s32(0);
  const int32x2_t ret0 =
      vld1_lane_s32(reinterpret_cast<const int32*>(v0), ZERO_32x2, 0);
  const int32x2_t ret1 =
      vld1_lane_s32(reinterpret_cast<const int32*>(v1), ret0, 1);
  return ret1;
}

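// Interpolates two qint32 output values at once using NEON. When X_LERP_SAME
// is true both lanes share x_lerp[0] (e.g. two channels of the same output
// pixel); otherwise lane i uses x_lerp[i].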
template <int RESOLUTION, bool X_LERP_SAME>
inline int32x2_t ComputeLerpx2(
    const qint32* top_left0, const qint32* top_right0,
    const qint32* bottom_left0, const qint32* bottom_right0,
    const qint32* top_left1, const qint32* top_right1,
    const qint32* bottom_left1, const qint32* bottom_right1,
    const int32* x_lerp, const int32x2_t y_lerpsx) {
  const int32x2_t x_lerpsx =
      X_LERP_SAME ? vld1_dup_s32(reinterpret_cast<const int32*>(x_lerp))
                  : vld1_s32(reinterpret_cast<const int32*>(x_lerp));

  const int32x2_t top_leftsx = ToInt32x2(top_left0, top_left1);
  const int32x2_t top_rightsx = ToInt32x2(top_right0, top_right1);
  const int32x2_t bottom_leftsx = ToInt32x2(bottom_left0, bottom_left1);
  const int32x2_t bottom_rightsx = ToInt32x2(bottom_right0, bottom_right1);

  const int32x2_t retval =
      ComputeLerp32x2<RESOLUTION>(top_leftsx, top_rightsx, bottom_leftsx,
                                  bottom_rightsx, x_lerpsx, y_lerpsx);
  return retval;
}

template <int RESOLUTION>
inline uint8x8_t ComputeLerpx8(
    const quint8* tl0, const quint8* tr0, const quint8* bl0, const quint8* br0,
    const int16* xlp0, const quint8* tl1, const quint8* tr1, const quint8* bl1,
    const quint8* br1, const int16* xlp1, const quint8* tl2, const quint8* tr2,
    const quint8* bl2, const quint8* br2, const int16* xlp2, const quint8* tl3,
    const quint8* tr3, const quint8* bl3, const quint8* br3, const int16* xlp3,
    const quint8* tl4, const quint8* tr4, const quint8* bl4, const quint8* br4,
    const int16* xlp4, const quint8* tl5, const quint8* tr5, const quint8* bl5,
    const quint8* br5, const int16* xlp5, const quint8* tl6, const quint8* tr6,
    const quint8* bl6, const quint8* br6, const int16* xlp6, const quint8* tl7,
    const quint8* tr7, const quint8* bl7, const quint8* br7, const int16* xlp7,
    const int16x8_t ys_lerpsx) {
  const uint8x8_t tl8x8 = ToUint8x8(tl0, tl1, tl2, tl3, tl4, tl5, tl6, tl7);
  const uint8x8_t tr8x8 = ToUint8x8(tr0, tr1, tr2, tr3, tr4, tr5, tr6, tr7);
  const uint8x8_t bl8x8 = ToUint8x8(bl0, bl1, bl2, bl3, bl4, bl5, bl6, bl7);
  const uint8x8_t br8x8 = ToUint8x8(br0, br1, br2, br3, br4, br5, br6, br7);
  const int16x8_t xs_lerpsx =
      ToInt16x8(xlp0, xlp1, xlp2, xlp3, xlp4, xlp5, xlp6, xlp7);
  return ComputeLerp8x8<RESOLUTION>(tl8x8, tr8x8, bl8x8, br8x8, xs_lerpsx,
                                    ys_lerpsx);
}

// Expand the addresses at compile time to improve performance.
template <int RESOLUTION, int ID0, int CH0, int ID1, int CH1, int ID2, int CH2,
          int ID3, int CH3, int ID4, int CH4, int ID5, int CH5, int ID6,
          int CH6, int ID7, int CH7>
inline uint8x8_t ComputeLerpx8Tmpl(const quint8* const yl, const quint8* yu,
                                   const int64* xl, const int64* xu,
                                   const int16* xlp,
                                   const int16x8_t ys_lerpsx) {
  return ComputeLerpx8<RESOLUTION>(
      yl + xl[ID0] + CH0, yl + xu[ID0] + CH0, yu + xl[ID0] + CH0,
      yu + xu[ID0] + CH0, xlp + ID0, yl + xl[ID1] + CH1, yl + xu[ID1] + CH1,
      yu + xl[ID1] + CH1, yu + xu[ID1] + CH1, xlp + ID1, yl + xl[ID2] + CH2,
      yl + xu[ID2] + CH2, yu + xl[ID2] + CH2, yu + xu[ID2] + CH2, xlp + ID2,
      yl + xl[ID3] + CH3, yl + xu[ID3] + CH3, yu + xl[ID3] + CH3,
      yu + xu[ID3] + CH3, xlp + ID3, yl + xl[ID4] + CH4, yl + xu[ID4] + CH4,
      yu + xl[ID4] + CH4, yu + xu[ID4] + CH4, xlp + ID4, yl + xl[ID5] + CH5,
      yl + xu[ID5] + CH5, yu + xl[ID5] + CH5, yu + xu[ID5] + CH5, xlp + ID5,
      yl + xl[ID6] + CH6, yl + xu[ID6] + CH6, yu + xl[ID6] + CH6,
      yu + xu[ID6] + CH6, xlp + ID6,
      yl + xl[ID7] + CH7, yl + xu[ID7] + CH7,
      yu + xl[ID7] + CH7, yu + xu[ID7] + CH7, xlp + ID7, ys_lerpsx);
}

#endif

template <int RESOLUTION, typename T, typename T_SCALE, typename T_CALC>
inline void OutputLerpForChannels(const InterpolationCache<T_SCALE>& xs,
                                  const int64 x, const T_SCALE ys_ilerp,
                                  const int channels, const float min,
                                  const float max, const T* ys_input_lower_ptr,
                                  const T* ys_input_upper_ptr,
                                  T* output_y_ptr) {
  const int64 xs_lower = xs.lower[x];
  const int64 xs_upper = xs.upper[x];
  const T_SCALE xs_ilerp = xs.ilerp[x];
  for (int c = 0; c < channels; ++c) {
    const T top_left = ys_input_lower_ptr[xs_lower + c];
    const T top_right = ys_input_lower_ptr[xs_upper + c];
    const T bottom_left = ys_input_upper_ptr[xs_lower + c];
    const T bottom_right = ys_input_upper_ptr[xs_upper + c];
    const T val = ComputeLerp<RESOLUTION, T, T_SCALE, T_CALC>(
        top_left, top_right, bottom_left, bottom_right, xs_ilerp, ys_ilerp);
    output_y_ptr[x * channels + c] = val;
  }
}

template <int RES>
inline void OutputLerp8x8x1(const InterpolationCache<int16>& xs,
                            const int64 x_start, const int16 ys_ilerp,
                            const float min, const float max,
                            const quint8* const ys_input_lower_ptr,
                            const quint8* const ys_input_upper_ptr,
                            quint8* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int16x8_t y_lerpsx = vmovq_n_s16(ys_ilerp);

  const uint8x8_t x0x7 =
      ComputeLerpx8Tmpl<RES, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start), x0x7);

#else
  for (int x = x_start; x < x_start + 8; ++x) {
    OutputLerpForChannels<RES, quint8, int16, int16>(
        xs, x, ys_ilerp, 1, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

template <int RES>
inline void OutputLerp8x8x3(const InterpolationCache<int16>& xs,
                            const int64 x_start, const int16 ys_ilerp,
                            const float min, const float max,
                            const quint8* const ys_input_lower_ptr,
                            const quint8* const ys_input_upper_ptr,
                            quint8* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int16x8_t y_lerpsx = vmovq_n_s16(ys_ilerp);

  const uint8x8_t x0c0x2c1 =
      ComputeLerpx8Tmpl<RES, 0, 0, 0, 1, 0, 2, 1, 0, 1, 1, 1, 2, 2, 0, 2, 1>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3), x0c0x2c1);

  const uint8x8_t x2c2x5c0 =
      ComputeLerpx8Tmpl<RES, 2, 2, 3, 0, 3, 1, 3, 2, 4, 0, 4, 1, 4, 2, 5, 0>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3 + 8), x2c2x5c0);

  const uint8x8_t x5c1x7c2 =
      ComputeLerpx8Tmpl<RES, 5, 1, 5, 2, 6, 0, 6, 1, 6, 2, 7, 0, 7, 1, 7, 2>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3 + 16),
          x5c1x7c2);

#else
  for (int x = x_start; x < x_start + 8; ++x) {
    OutputLerpForChannels<RES, quint8, int16, int16>(
        xs, x, ys_ilerp, 3, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

template <int RESOLUTION>
inline void OutputLerp32x4x1(const InterpolationCache<int32>& xs,
                             const int64 x_start, const int32 ys_ilerp,
                             const float min, const float max,
                             const qint32* const ys_input_lower_ptr,
                             const qint32* const ys_input_upper_ptr,
                             qint32* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int64 xs_lower0 = xs.lower[x_start];
  const int64 xs_upper0 = xs.upper[x_start];
  const int32* const xs_ilerp0 = &xs.ilerp[x_start];
  const int64 xs_lower1 = xs.lower[x_start + 1];
  const int64 xs_upper1 = xs.upper[x_start + 1];
  const int64 xs_lower2 = xs.lower[x_start + 2];
  const int64 xs_upper2 = xs.upper[x_start + 2];
  const int32* const xs_ilerp2 = &xs.ilerp[x_start + 2];
  const int64 xs_lower3 = xs.lower[x_start + 3];
  const int64 xs_upper3 = xs.upper[x_start + 3];

  const int32x2_t y_lerpsx = vmov_n_s32(ys_ilerp);

  const int32x2_t x0x1 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower0, ys_input_lower_ptr + xs_upper0,
      ys_input_upper_ptr + xs_lower0, ys_input_upper_ptr + xs_upper0,
      ys_input_lower_ptr + xs_lower1, ys_input_lower_ptr + xs_upper1,
      ys_input_upper_ptr + xs_lower1, ys_input_upper_ptr + xs_upper1, xs_ilerp0,
      y_lerpsx);

  const int32x2_t x2x3 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower2, ys_input_lower_ptr + xs_upper2,
      ys_input_upper_ptr + xs_lower2, ys_input_upper_ptr + xs_upper2,
      ys_input_lower_ptr + xs_lower3, ys_input_lower_ptr + xs_upper3,
      ys_input_upper_ptr + xs_lower3, ys_input_upper_ptr + xs_upper3, xs_ilerp2,
      y_lerpsx);

  const int32x4_t x0x1x2x3 = vcombine_s32(x0x1, x2x3);

  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start), x0x1x2x3);

#else
  for (int x = x_start; x < x_start + 4; ++x) {
    OutputLerpForChannels<RESOLUTION, qint32, int32, int64>(
        xs, x, ys_ilerp, 1, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

template <int RESOLUTION>
inline void OutputLerp32x4x3(const InterpolationCache<int32>& xs,
                             const int64 x_start, const int32 ys_ilerp,
                             const float min, const float max,
                             const qint32* const ys_input_lower_ptr,
                             const qint32* const ys_input_upper_ptr,
                             qint32* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int64 xs_lower0 = xs.lower[x_start];
  const int64 xs_upper0 = xs.upper[x_start];
  const int32* const xs_ilerp0 = &xs.ilerp[x_start];
  const int64 xs_lower1 = xs.lower[x_start + 1];
  const int64 xs_upper1 = xs.upper[x_start + 1];
  const int32* const xs_ilerp1 = &xs.ilerp[x_start + 1];
  const int64 xs_lower2 = xs.lower[x_start + 2];
  const int64 xs_upper2 = xs.upper[x_start + 2];
  const int32* const xs_ilerp2 = &xs.ilerp[x_start + 2];
  const int64 xs_lower3 = xs.lower[x_start + 3];
  const int64 xs_upper3 = xs.upper[x_start + 3];
  const int32* const xs_ilerp3 = &xs.ilerp[x_start + 3];

  const int32x2_t y_lerpsx = vmov_n_s32(ys_ilerp);

  const int32x2_t x0c0x0c1 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower0, ys_input_lower_ptr + xs_upper0,
      ys_input_upper_ptr + xs_lower0, ys_input_upper_ptr + xs_upper0,
      ys_input_lower_ptr + xs_lower0 + 1, ys_input_lower_ptr + xs_upper0 + 1,
      ys_input_upper_ptr + xs_lower0 + 1, ys_input_upper_ptr + xs_upper0 + 1,
      xs_ilerp0, y_lerpsx);

  const int32x2_t x0c2x1c0 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower0 + 2, ys_input_lower_ptr + xs_upper0 + 2,
      ys_input_upper_ptr + xs_lower0 + 2, ys_input_upper_ptr + xs_upper0 + 2,
      ys_input_lower_ptr + xs_lower1, ys_input_lower_ptr + xs_upper1,
      ys_input_upper_ptr + xs_lower1, ys_input_upper_ptr + xs_upper1, xs_ilerp0,
      y_lerpsx);

  const int32x2_t x1c1x1c2 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower1 + 1, ys_input_lower_ptr + xs_upper1 + 1,
      ys_input_upper_ptr + xs_lower1 + 1, ys_input_upper_ptr + xs_upper1 + 1,
      ys_input_lower_ptr + xs_lower1 + 2, ys_input_lower_ptr + xs_upper1 + 2,
      ys_input_upper_ptr + xs_lower1 + 2, ys_input_upper_ptr + xs_upper1 + 2,
      xs_ilerp1, y_lerpsx);

  const int32x2_t x2c0x2c1 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower2, ys_input_lower_ptr + xs_upper2,
      ys_input_upper_ptr + xs_lower2, ys_input_upper_ptr + xs_upper2,
      ys_input_lower_ptr + xs_lower2 + 1, ys_input_lower_ptr + xs_upper2 + 1,
      ys_input_upper_ptr + xs_lower2 + 1, ys_input_upper_ptr + xs_upper2 + 1,
      xs_ilerp2, y_lerpsx);

  const int32x2_t x2c2x3c0 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower2 + 2, ys_input_lower_ptr + xs_upper2 + 2,
      ys_input_upper_ptr + xs_lower2 + 2, ys_input_upper_ptr + xs_upper2 + 2,
      ys_input_lower_ptr + xs_lower3, ys_input_lower_ptr + xs_upper3,
      ys_input_upper_ptr + xs_lower3, ys_input_upper_ptr + xs_upper3, xs_ilerp2,
      y_lerpsx);

  const int32x2_t x3c1x3c2 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower3 + 1, ys_input_lower_ptr + xs_upper3 + 1,
      ys_input_upper_ptr + xs_lower3 + 1, ys_input_upper_ptr + xs_upper3 + 1,
      ys_input_lower_ptr + xs_lower3 + 2, ys_input_lower_ptr + xs_upper3 + 2,
      ys_input_upper_ptr + xs_lower3 + 2, ys_input_upper_ptr + xs_upper3 + 2,
      xs_ilerp3, y_lerpsx);

  const int32x4_t x0c0x0c1x0c2x1c0 = vcombine_s32(x0c0x0c1, x0c2x1c0);
  const int32x4_t x1c1x1c2x2c0x2c1 = vcombine_s32(x1c1x1c2, x2c0x2c1);
  const int32x4_t x2c2x3c0x3c1x3c2 = vcombine_s32(x2c2x3c0, x3c1x3c2);

  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3),
            x0c0x0c1x0c2x1c0);
  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3 + 4),
            x1c1x1c2x2c0x2c1);
  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3 + 8),
            x2c2x3c0x3c1x3c2);

#else
  for (int x = x_start; x < x_start + 4; ++x) {
    OutputLerpForChannels<RESOLUTION, qint32, int32, int64>(
        xs, x, ys_ilerp, 3, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

template <typename T>
void ResizeImageReference(typename TTypes<T, 4>::ConstTensor images,
                          const int batch_size, const int64 in_height,
                          const int64 in_width, const int64 out_height,
                          const int64 out_width, const int channels,
                          const float height_scale, const float width_scale,
                          const float in_min, const float in_max,
                          typename TTypes<T, 4>::Tensor* output) {
  CHECK_NOTNULL(output);

  const InterpolationCache<float> xs =
      BuildLerpCache<float>(out_width, in_width, width_scale, channels, 0);
  const InterpolationCache<float> ys =
      BuildLerpCache<float>(out_height, in_height, height_scale, 1, 0);

  const int64 in_row_size = in_width * channels;
  const int64 in_batch_num_values = in_height * in_row_size;
  const int64 out_row_size = out_width * channels;

  const T* input_b_ptr = images.data();

  T* output_y_ptr = output->data();
  for (int b = 0; b < batch_size; ++b) {
    for (int64 y = 0; y < out_height; ++y) {
      const T* ys_input_lower_ptr = input_b_ptr + ys.lower[y] * in_row_size;
      const T* ys_input_upper_ptr = input_b_ptr + ys.upper[y] * in_row_size;
      const float ys_lerp = ys.lerp[y];
      for (int64 x = 0; x < out_width; ++x) {
        const int64 xs_lower = xs.lower[x];
        const int64 xs_upper = xs.upper[x];
        const float xs_lerp = xs.lerp[x];
        for (int c = 0; c < channels; ++c) {
          const T top_left = ys_input_lower_ptr[xs_lower + c];
          const T top_right = ys_input_lower_ptr[xs_upper + c];
          const T bottom_left = ys_input_upper_ptr[xs_lower + c];
          const T bottom_right = ys_input_upper_ptr[xs_upper + c];
          const T val = ComputeLerpReference<T>(
              top_left, top_right, bottom_left, bottom_right, xs_lerp, ys_lerp,
              in_min, in_max);
          output_y_ptr[x * channels + c] = val;
        }
      }
      output_y_ptr += out_row_size;
    }
    input_b_ptr += in_batch_num_values;
  }
}

template <typename T>
void ResizeImage(typename TTypes<T, 4>::ConstTensor images,
                 const int batch_size, const int64 in_height,
                 const int64 in_width, const int64 out_height,
                 const int64 out_width, const int channels,
                 const float height_scale, const float width_scale,
                 const float in_min, const float in_max,
                 typename TTypes<T, 4>::Tensor* output) {
  ResizeImageReference<T>(images, batch_size, in_height, in_width, out_height,
                          out_width, channels, height_scale, width_scale,
                          in_min, in_max, output);
}

template <>
void ResizeImage<qint32>(typename TTypes<qint32, 4>::ConstTensor images,
                         const int batch_size, const int64 in_height,
                         const int64 in_width, const int64 out_height,
                         const int64 out_width, const int channels,
                         const float height_scale, const float width_scale,
                         const float in_min, const float in_max,
                         typename TTypes<qint32, 4>::Tensor* output) {
  // 30 is the maximum resolution for a signed int.
  constexpr int RESOLUTION = 30;
  constexpr int SIMD_STEP = 4;

  CHECK_NOTNULL(output);

  const InterpolationCache<int32> xs = BuildLerpCache<int32>(
      out_width, in_width, width_scale, channels, RESOLUTION);
  const InterpolationCache<int32> ys =
      BuildLerpCache<int32>(out_height, in_height, height_scale, 1, RESOLUTION);

  const int64 in_row_size = in_width * channels;
  const int64 in_batch_num_values = in_height * in_row_size;
  const int64 out_row_size = out_width * channels;

  const qint32* input_b_ptr = images.data();

  qint32* output_y_ptr = output->data();

  for (int b = 0; b < batch_size; ++b) {
    for (int64 y = 0; y < out_height; ++y) {
      const qint32* ys_input_lower_ptr =
          input_b_ptr + ys.lower[y] * in_row_size;
      const qint32* ys_input_upper_ptr =
          input_b_ptr + ys.upper[y] * in_row_size;
      const int32 ys_ilerp = ys.ilerp[y];
      // Optimized for channels == 1 and channels == 3, which are the most
      // common cases.
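      // The loops below produce SIMD_STEP (4) output pixels per iteration;
      // the trailing scalar loop handles any remaining pixels, and all pixels
      // when channels is neither 1 nor 3.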
      int64 x = 0;
      if (channels == 1) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp32x4x1<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                       ys_input_lower_ptr, ys_input_upper_ptr,
                                       output_y_ptr);
        }
      } else if (channels == 3) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp32x4x3<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                       ys_input_lower_ptr, ys_input_upper_ptr,
                                       output_y_ptr);
        }
      }
      for (; x < out_width; ++x) {
        OutputLerpForChannels<RESOLUTION, qint32, int32, int64>(
            xs, x, ys_ilerp, channels, in_min, in_max, ys_input_lower_ptr,
            ys_input_upper_ptr, output_y_ptr);
      }
      output_y_ptr += out_row_size;
    }
    input_b_ptr += in_batch_num_values;
  }
}

template <>
void ResizeImage<quint8>(typename TTypes<quint8, 4>::ConstTensor images,
                         const int batch_size, const int64 in_height,
                         const int64 in_width, const int64 out_height,
                         const int64 out_width, const int channels,
                         const float height_scale, const float width_scale,
                         const float in_min, const float in_max,
                         typename TTypes<quint8, 4>::Tensor* output) {
  // 7 is the maximum resolution for an unsigned byte.
  constexpr int RESOLUTION = 7;
  constexpr int SIMD_STEP = 8;

  CHECK_NOTNULL(output);

  const InterpolationCache<int16> xs = BuildLerpCache<int16>(
      out_width, in_width, width_scale, channels, RESOLUTION);
  const InterpolationCache<int16> ys =
      BuildLerpCache<int16>(out_height, in_height, height_scale, 1, RESOLUTION);

  const int64 in_row_size = in_width * channels;
  const int64 in_batch_num_values = in_height * in_row_size;
  const int64 out_row_size = out_width * channels;

  const quint8* input_b_ptr = images.data();

  quint8* output_y_ptr = output->data();

  for (int b = 0; b < batch_size; ++b) {
    for (int64 y = 0; y < out_height; ++y) {
      const quint8* ys_input_lower_ptr =
          input_b_ptr + ys.lower[y] * in_row_size;
      const quint8* ys_input_upper_ptr =
          input_b_ptr + ys.upper[y] * in_row_size;
      const int32 ys_ilerp = ys.ilerp[y];
      // Optimized for channels == 1 and channels == 3, which are the most
      // common cases.
      // TODO(satok): Support a more generic NEON-optimized implementation
      // for other channel counts.
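      // Same structure as the qint32 path above, but each iteration produces
      // SIMD_STEP (8) output pixels.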
      int64 x = 0;
      if (channels == 1) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp8x8x1<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                      ys_input_lower_ptr, ys_input_upper_ptr,
                                      output_y_ptr);
        }
      } else if (channels == 3) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp8x8x3<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                      ys_input_lower_ptr, ys_input_upper_ptr,
                                      output_y_ptr);
        }
      }
      for (; x < out_width; ++x) {
        OutputLerpForChannels<RESOLUTION, quint8, int16, int16>(
            xs, x, ys_ilerp, channels, in_min, in_max, ys_input_lower_ptr,
            ys_input_upper_ptr, output_y_ptr);
      }
      output_y_ptr += out_row_size;
    }
    input_b_ptr += in_batch_num_values;
  }
}

template <typename T>
void ResizeBilinear(const typename TTypes<T, 4>::ConstTensor& images,
                    const float height_scale, const float width_scale,
                    const float in_min, const float in_max,
                    typename TTypes<T, 4>::Tensor* output) {
  CHECK_NOTNULL(output);

  const int batch_size = images.dimension(0);
  const int64 in_height = images.dimension(1);
  const int64 in_width = images.dimension(2);
  const int channels = images.dimension(3);

  const int64 out_height = output->dimension(1);
  const int64 out_width = output->dimension(2);

  // Handle no-op resizes efficiently.
  if (out_height == in_height && out_width == in_width) {
    *output = images.template cast<T>();
    return;
  }

  if (USE_REFERENCE) {
    ResizeImageReference<T>(images, batch_size, in_height, in_width,
                            out_height, out_width, channels, height_scale,
                            width_scale, in_min, in_max, output);
  } else {
    ResizeImage<T>(images, batch_size, in_height, in_width, out_height,
                   out_width, channels, height_scale, width_scale, in_min,
                   in_max, output);
  }
}

}  // namespace

template <class T>
class QuantizedResizeBilinearOp : public OpKernel {
 public:
  explicit QuantizedResizeBilinearOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("align_corners", &align_corners_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const float in_min = context->input(2).flat<float>()(0);
    const float in_max = context->input(3).flat<float>()(0);

    ImageResizerState st(align_corners_);
    st.ValidateAndCreateOutput(context, input);

    if (!context->status().ok()) return;

    // Return if the output is empty.
    if (st.output->NumElements() == 0) return;

    typename TTypes<T, 4>::ConstTensor image_data(input.tensor<T, 4>());
    typename TTypes<T, 4>::Tensor output_data(st.output->tensor<T, 4>());

    ResizeBilinear<T>(image_data, st.height_scale, st.width_scale, in_min,
                      in_max, &output_data);
    Tensor* out_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &out_min));
    out_min->flat<float>()(0) = in_min;

    Tensor* out_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &out_max));
    out_max->flat<float>()(0) = in_max;
  }

 private:
  bool align_corners_;

  TF_DISALLOW_COPY_AND_ASSIGN(QuantizedResizeBilinearOp<T>);
};

#define REGISTER_CPU_KERNEL(type)                         \
  REGISTER_KERNEL_BUILDER(Name("QuantizedResizeBilinear") \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("size")         \
                              .TypeConstraint<type>("T"), \
                          QuantizedResizeBilinearOp<type>)

REGISTER_CPU_KERNEL(::tensorflow::quint8);
REGISTER_CPU_KERNEL(::tensorflow::qint32);
REGISTER_CPU_KERNEL(float);

}  // namespace tensorflow