/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements a quantized eight-bit version of the add operation.

#define EIGEN_USE_THREADS

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define USE_NEON
#define QUANTIZED_ADD_USE_NEON
#include <arm_neon.h>
#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/bcast.h"

// There are implementations for three broadcast patterns for add:
//  - Scalar + Array
//  - Array + Array
//  - Array + Shorter Array (repeated to match the first)
//
// These handle many common broadcast patterns, and we have NEON SIMD
// versions to accelerate performance on ARM platforms.

namespace tensorflow {
namespace {

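// Requantizes the scalar operand once, then requantizes each element of the
// array operand into the shared output range and adds the two.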
template <class T, class Toutput>
void ScalarAddition(OpKernelContext* context, const T* full_input,
                    float full_input_min, float full_input_max,
                    int64 num_elements, T scalar_input, float scalar_input_min,
                    float scalar_input_max, float output_min, float output_max,
                    Toutput* output) {
  const Toutput scalar_in_output_range = RequantizeInNewRange<T, Toutput>(
      scalar_input, scalar_input_min, scalar_input_max, output_min, output_max);
  for (int64 i = 0; i < num_elements; ++i) {
    const Toutput full_input_in_output_range = RequantizeInNewRange<T, Toutput>(
        full_input[i], full_input_min, full_input_max, output_min, output_max);
    output[i] = full_input_in_output_range + scalar_in_output_range;
  }
}

#ifdef QUANTIZED_ADD_USE_NEON

template <>
void ScalarAddition(OpKernelContext* context, const quint8* full_input,
                    float full_input_min, float full_input_max,
                    int64 num_elements, quint8 scalar_input,
                    float scalar_input_min, float scalar_input_max,
                    float output_min, float output_max, qint32* output) {
  const int32 scalar_in_output_range = RequantizeInNewRange<quint8, qint32>(
      scalar_input, scalar_input_min, scalar_input_max, output_min, output_max);

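  // Quantization is affine, so a quint8 value v lands in the qint32 output
  // range at input_0_int64 + (v * input_mult_int32), where input_0_int64 is
  // the output-range image of the quint8 code 0 and input_mult_int32 is the
  // step between adjacent quint8 codes after remapping.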
  const float input_0_float =
      QuantizedToFloat<quint8>(0, full_input_min, full_input_max);
  const float input_1_float =
      QuantizedToFloat<quint8>(1, full_input_min, full_input_max);
  const int64 input_0_int64 =
      FloatToQuantizedUnclamped<qint32>(input_0_float, output_min, output_max);
  const int64 input_1_int64 =
      FloatToQuantizedUnclamped<qint32>(input_1_float, output_min, output_max);
  const int32 input_mult_int32 = input_1_int64 - input_0_int64;

  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  const int64x2_t input_0_64x2 = vmovq_n_s64(input_0_int64);
  const int32x2_t input_mult_32x2 = vmov_n_s32(input_mult_int32);
  const int32x4_t scalar_in_output_range_32x4 =
      vmovq_n_s32(scalar_in_output_range);
  int64 i = 0;
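  // Requantize eight input bytes per pass and add the broadcast scalar to all
  // eight lanes at once.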
  for (; i < (num_elements - 7); i += 8) {
    const uint8* full_input_ptr = &(full_input->value) + i;
    const std::array<int32x4_t, 2> output_value =
        Requantize8x8To32Neon(full_input_ptr, input_0_64x2, input_mult_32x2);
    const int32x4_t result_low_32x4 =
        vaddq_s32(output_value[0], scalar_in_output_range_32x4);
    const int32x4_t result_high_32x4 =
        vaddq_s32(output_value[1], scalar_in_output_range_32x4);
    int32* output_ptr = &(output->value) + i;
    vst1q_s32(output_ptr + 0, result_low_32x4);
    vst1q_s32(output_ptr + 4, result_high_32x4);
  }
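  // Requantize, clamp, and add any remaining elements one at a time.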
  for (; i < num_elements; ++i) {
    const int64 full_input_value = static_cast<int64>(full_input[i]);
    int64 full_input_in_output_range_64 =
        input_0_int64 + (full_input_value * input_mult_int32);
    full_input_in_output_range_64 =
        std::max(full_input_in_output_range_64, lowest_quantized);
    full_input_in_output_range_64 =
        std::min(full_input_in_output_range_64, highest_quantized);
    const int32 full_input_in_output_range =
        static_cast<int32>(full_input_in_output_range_64);
    output[i] = full_input_in_output_range + scalar_in_output_range;
  }
}

#else  // QUANTIZED_ADD_USE_NEON

template <>
void ScalarAddition(OpKernelContext* context, const quint8* full_input,
                    float full_input_min, float full_input_max,
                    int64 num_elements, quint8 scalar_input,
                    float scalar_input_min, float scalar_input_max,
                    float output_min, float output_max, qint32* output) {
  const int32 scalar_in_output_range = RequantizeInNewRange<quint8, qint32>(
      scalar_input, scalar_input_min, scalar_input_max, output_min, output_max);

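  // Portable reference path: the same affine requantization as the NEON
  // specialization above, applied one element at a time.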
  const float input_0_float =
      QuantizedToFloat<quint8>(0, full_input_min, full_input_max);
  const float input_1_float =
      QuantizedToFloat<quint8>(1, full_input_min, full_input_max);
  const int64 input_0_int64 =
      FloatToQuantizedUnclamped<qint32>(input_0_float, output_min, output_max);
  const int64 input_1_int64 =
      FloatToQuantizedUnclamped<qint32>(input_1_float, output_min, output_max);
  const int32 input_mult_int32 = input_1_int64 - input_0_int64;

  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  for (int64 i = 0; i < num_elements; ++i) {
    const int64 full_input_value = static_cast<int64>(full_input[i]);
    int64 full_input_in_output_range_64 =
        input_0_int64 + (full_input_value * input_mult_int32);
    full_input_in_output_range_64 =
        std::max(full_input_in_output_range_64, lowest_quantized);
    full_input_in_output_range_64 =
        std::min(full_input_in_output_range_64, highest_quantized);
    const int32 full_input_in_output_range =
        static_cast<int32>(full_input_in_output_range_64);
    output[i] = full_input_in_output_range + scalar_in_output_range;
  }
}

#endif  // QUANTIZED_ADD_USE_NEON

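// Adds two equal-length arrays element-wise, requantizing both into the shared
// output range first.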
template <class T, class Toutput>
void VectorAddition(OpKernelContext* context, const T* x_data, float min_x,
                    float max_x, const T* y_data, float min_y, float max_y,
                    int64 num_elements, float output_min, float output_max,
                    Toutput* output) {
  for (int64 i = 0; i < num_elements; ++i) {
    const Toutput x_in_output_range = RequantizeInNewRange<T, Toutput>(
        x_data[i], min_x, max_x, output_min, output_max);
    const Toutput y_in_output_range = RequantizeInNewRange<T, Toutput>(
        y_data[i], min_y, max_y, output_min, output_max);
    output[i] = x_in_output_range + y_in_output_range;
  }
}

#ifdef QUANTIZED_ADD_USE_NEON

template <>
void VectorAddition(OpKernelContext* context, const quint8* x_data, float min_x,
                    float max_x, const quint8* y_data, float min_y, float max_y,
                    int64 num_elements, float output_min, float output_max,
                    qint32* output) {
  const float x_0_float = QuantizedToFloat<quint8>(0, min_x, max_x);
  const float x_1_float = QuantizedToFloat<quint8>(1, min_x, max_x);
  const int64 x_0_int64 =
      FloatToQuantizedUnclamped<qint32>(x_0_float, output_min, output_max);
  const int64 x_1_int64 =
      FloatToQuantizedUnclamped<qint32>(x_1_float, output_min, output_max);
  const int32 x_mult_int32 = x_1_int64 - x_0_int64;

  const float y_0_float = QuantizedToFloat<quint8>(0, min_y, max_y);
  const float y_1_float = QuantizedToFloat<quint8>(1, min_y, max_y);
  const int64 y_0_int64 =
      FloatToQuantizedUnclamped<qint32>(y_0_float, output_min, output_max);
  const int64 y_1_int64 =
      FloatToQuantizedUnclamped<qint32>(y_1_float, output_min, output_max);
  const int32 y_mult_int32 = y_1_int64 - y_0_int64;

  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  const int64x2_t x_0_64x2 = vmovq_n_s64(x_0_int64);
  const int32x2_t x_mult_32x2 = vmov_n_s32(x_mult_int32);

  const int64x2_t y_0_64x2 = vmovq_n_s64(y_0_int64);
  const int32x2_t y_mult_32x2 = vmov_n_s32(y_mult_int32);

  int64 i = 0;
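  // Requantize eight elements of each input per pass and add them in two
  // four-lane chunks.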
  for (; i < (num_elements - 7); i += 8) {
    const uint8* x_ptr = &(x_data->value) + i;
    const std::array<int32x4_t, 2> x_output_value =
        Requantize8x8To32Neon(x_ptr, x_0_64x2, x_mult_32x2);
    const uint8* y_ptr = &(y_data->value) + i;
    const std::array<int32x4_t, 2> y_output_value =
        Requantize8x8To32Neon(y_ptr, y_0_64x2, y_mult_32x2);

    const int32x4_t result_low_32x4 =
        vaddq_s32(x_output_value[0], y_output_value[0]);
    const int32x4_t result_high_32x4 =
        vaddq_s32(x_output_value[1], y_output_value[1]);
    int32* output_ptr = &(output->value) + i;
    vst1q_s32(output_ptr + 0, result_low_32x4);
    vst1q_s32(output_ptr + 4, result_high_32x4);
  }

  for (; i < num_elements; ++i) {
    const int64 x_value = static_cast<int64>(x_data[i]);
    int64 x_in_output_range_64 = x_0_int64 + (x_value * x_mult_int32);
    x_in_output_range_64 = std::max(x_in_output_range_64, lowest_quantized);
    x_in_output_range_64 = std::min(x_in_output_range_64, highest_quantized);
    const int32 x_in_output_range = static_cast<int32>(x_in_output_range_64);

    const int64 y_value = static_cast<int64>(y_data[i]);
    int64 y_in_output_range_64 = y_0_int64 + (y_value * y_mult_int32);
    y_in_output_range_64 = std::max(y_in_output_range_64, lowest_quantized);
    y_in_output_range_64 = std::min(y_in_output_range_64, highest_quantized);
    const int32 y_in_output_range = static_cast<int32>(y_in_output_range_64);

    output[i] = x_in_output_range + y_in_output_range;
  }
}

#else  // QUANTIZED_ADD_USE_NEON

template <>
void VectorAddition(OpKernelContext* context, const quint8* x_data, float min_x,
                    float max_x, const quint8* y_data, float min_y, float max_y,
                    int64 num_elements, float output_min, float output_max,
                    qint32* output) {
  const float x_0_float = QuantizedToFloat<quint8>(0, min_x, max_x);
  const float x_1_float = QuantizedToFloat<quint8>(1, min_x, max_x);
  const int64 x_0_int64 =
      FloatToQuantizedUnclamped<qint32>(x_0_float, output_min, output_max);
  const int64 x_1_int64 =
      FloatToQuantizedUnclamped<qint32>(x_1_float, output_min, output_max);
  const int32 x_mult_int32 = x_1_int64 - x_0_int64;

  const float y_0_float = QuantizedToFloat<quint8>(0, min_y, max_y);
  const float y_1_float = QuantizedToFloat<quint8>(1, min_y, max_y);
  const int64 y_0_int64 =
      FloatToQuantizedUnclamped<qint32>(y_0_float, output_min, output_max);
  const int64 y_1_int64 =
      FloatToQuantizedUnclamped<qint32>(y_1_float, output_min, output_max);
  const int32 y_mult_int32 = y_1_int64 - y_0_int64;

  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  for (int64 i = 0; i < num_elements; ++i) {
    const int64 x_value = static_cast<int64>(x_data[i]);
    int64 x_in_output_range_64 = x_0_int64 + (x_value * x_mult_int32);
    x_in_output_range_64 = std::max(x_in_output_range_64, lowest_quantized);
    x_in_output_range_64 = std::min(x_in_output_range_64, highest_quantized);
    const int32 x_in_output_range = static_cast<int32>(x_in_output_range_64);

    const int64 y_value = static_cast<int64>(y_data[i]);
    int64 y_in_output_range_64 = y_0_int64 + (y_value * y_mult_int32);
    y_in_output_range_64 = std::max(y_in_output_range_64, lowest_quantized);
    y_in_output_range_64 = std::min(y_in_output_range_64, highest_quantized);
    const int32 y_in_output_range = static_cast<int32>(y_in_output_range_64);

    output[i] = x_in_output_range + y_in_output_range;
  }
}

#endif  // QUANTIZED_ADD_USE_NEON

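// Broadcasts the shorter "vector" input across the longer "tensor" input,
// cycling through the vector with modulo indexing and requantizing both into
// the output range before adding.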
template <class T, class Toutput>
void VectorTensorAddition(const T* vector_data, float min_vector,
                          float max_vector, int64 vector_num_elements,
                          const T* tensor_data, float min_tensor,
                          float max_tensor, int64 tensor_num_elements,
                          float output_min, float output_max, Toutput* output) {
  for (int64 i = 0; i < tensor_num_elements; ++i) {
    const int64 vector_i = i % vector_num_elements;
    const Toutput vector_in_output_range = RequantizeInNewRange<T, Toutput>(
        vector_data[vector_i], min_vector, max_vector, output_min, output_max);
    const Toutput tensor_in_output_range = RequantizeInNewRange<T, Toutput>(
        tensor_data[i], min_tensor, max_tensor, output_min, output_max);
    output[i] = vector_in_output_range + tensor_in_output_range;
  }
}

#ifdef QUANTIZED_ADD_USE_NEON

template <>
void VectorTensorAddition(const quint8* vector_data, float min_vector,
                          float max_vector, int64 vector_num_elements,
                          const quint8* tensor_data, float min_tensor,
                          float max_tensor, int64 tensor_num_elements,
                          float output_min, float output_max, qint32* output) {
  const float vector_0_float =
      QuantizedToFloat<quint8>(0, min_vector, max_vector);
  const float vector_1_float =
      QuantizedToFloat<quint8>(1, min_vector, max_vector);
  const int64 vector_0_int64 =
      FloatToQuantizedUnclamped<qint32>(vector_0_float, output_min, output_max);
  const int64 vector_1_int64 =
      FloatToQuantizedUnclamped<qint32>(vector_1_float, output_min, output_max);
  const int32 vector_mult_int32 = vector_1_int64 - vector_0_int64;

  const float tensor_0_float =
      QuantizedToFloat<quint8>(0, min_tensor, max_tensor);
  const float tensor_1_float =
      QuantizedToFloat<quint8>(1, min_tensor, max_tensor);
  const int64 tensor_0_int64 =
      FloatToQuantizedUnclamped<qint32>(tensor_0_float, output_min, output_max);
  const int64 tensor_1_int64 =
      FloatToQuantizedUnclamped<qint32>(tensor_1_float, output_min, output_max);
  const int32 tensor_mult_int32 = tensor_1_int64 - tensor_0_int64;

  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  const int64x2_t vector_0_64x2 = vmovq_n_s64(vector_0_int64);
  const int32x2_t vector_mult_32x2 = vmov_n_s32(vector_mult_int32);

  const int64x2_t tensor_0_64x2 = vmovq_n_s64(tensor_0_int64);
  const int32x2_t tensor_mult_32x2 = vmov_n_s32(tensor_mult_int32);

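  // Walk the tensor in strides of the vector length so the vector restarts on
  // each pass; the inner loop handles eight elements at a time, with a scalar
  // loop for the remainder.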
  for (int64 base_i = 0; base_i < tensor_num_elements;
       base_i += vector_num_elements) {
    int64 i = base_i;
    int64 vector_i = 0;
    for (; vector_i < (vector_num_elements - 7); vector_i += 8, i += 8) {
      const uint8* vector_ptr = &(vector_data->value) + vector_i;
      const std::array<int32x4_t, 2> vector_output_value =
          Requantize8x8To32Neon(vector_ptr, vector_0_64x2, vector_mult_32x2);
      const uint8* tensor_ptr = &(tensor_data->value) + i;
      const std::array<int32x4_t, 2> tensor_output_value =
          Requantize8x8To32Neon(tensor_ptr, tensor_0_64x2, tensor_mult_32x2);

      const int32x4_t result_low_32x4 =
          vaddq_s32(vector_output_value[0], tensor_output_value[0]);
      const int32x4_t result_high_32x4 =
          vaddq_s32(vector_output_value[1], tensor_output_value[1]);
      int32* output_ptr = &(output->value) + i;
      vst1q_s32(output_ptr + 0, result_low_32x4);
      vst1q_s32(output_ptr + 4, result_high_32x4);
    }
    for (; vector_i < vector_num_elements; ++vector_i, ++i) {
      const int64 vector_value = static_cast<int64>(vector_data[vector_i]);
      int64 vector_in_output_range_64 =
          vector_0_int64 + (vector_value * vector_mult_int32);
      vector_in_output_range_64 =
          std::max(vector_in_output_range_64, lowest_quantized);
      vector_in_output_range_64 =
          std::min(vector_in_output_range_64, highest_quantized);
      const int32 vector_in_output_range =
          static_cast<int32>(vector_in_output_range_64);

      const int64 tensor_value = static_cast<int64>(tensor_data[i]);
      int64 tensor_in_output_range_64 =
          tensor_0_int64 + (tensor_value * tensor_mult_int32);
      tensor_in_output_range_64 =
          std::max(tensor_in_output_range_64, lowest_quantized);
      tensor_in_output_range_64 =
          std::min(tensor_in_output_range_64, highest_quantized);
      const int32 tensor_in_output_range =
          static_cast<int32>(tensor_in_output_range_64);

      output[i] = vector_in_output_range + tensor_in_output_range;
    }
  }
}

#else  // QUANTIZED_ADD_USE_NEON

template <>
void VectorTensorAddition(const quint8* vector_data, float min_vector,
                          float max_vector, int64 vector_num_elements,
                          const quint8* tensor_data, float min_tensor,
                          float max_tensor, int64 tensor_num_elements,
                          float output_min, float output_max, qint32* output) {
  const float vector_0_float =
      QuantizedToFloat<quint8>(0, min_vector, max_vector);
  const float vector_1_float =
      QuantizedToFloat<quint8>(1, min_vector, max_vector);
  const int64 vector_0_int64 =
      FloatToQuantizedUnclamped<qint32>(vector_0_float, output_min, output_max);
  const int64 vector_1_int64 =
      FloatToQuantizedUnclamped<qint32>(vector_1_float, output_min, output_max);
  const int32 vector_mult_int32 = vector_1_int64 - vector_0_int64;

  const float tensor_0_float =
      QuantizedToFloat<quint8>(0, min_tensor, max_tensor);
  const float tensor_1_float =
      QuantizedToFloat<quint8>(1, min_tensor, max_tensor);
  const int64 tensor_0_int64 =
      FloatToQuantizedUnclamped<qint32>(tensor_0_float, output_min, output_max);
  const int64 tensor_1_int64 =
      FloatToQuantizedUnclamped<qint32>(tensor_1_float, output_min, output_max);
  const int32 tensor_mult_int32 = tensor_1_int64 - tensor_0_int64;

  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  for (int64 i = 0; i < tensor_num_elements; ++i) {
    const int64 vector_i = i % vector_num_elements;
    const int64 vector_value = static_cast<int64>(vector_data[vector_i]);
    int64 vector_in_output_range_64 =
        vector_0_int64 + (vector_value * vector_mult_int32);
    vector_in_output_range_64 =
        std::max(vector_in_output_range_64, lowest_quantized);
    vector_in_output_range_64 =
        std::min(vector_in_output_range_64, highest_quantized);
    const int32 vector_in_output_range =
        static_cast<int32>(vector_in_output_range_64);

    const int64 tensor_value = static_cast<int64>(tensor_data[i]);
    int64 tensor_in_output_range_64 =
        tensor_0_int64 + (tensor_value * tensor_mult_int32);
    tensor_in_output_range_64 =
        std::max(tensor_in_output_range_64, lowest_quantized);
    tensor_in_output_range_64 =
        std::min(tensor_in_output_range_64, highest_quantized);
    const int32 tensor_in_output_range =
        static_cast<int32>(tensor_in_output_range_64);

    output[i] = vector_in_output_range + tensor_in_output_range;
  }
}

#endif  // QUANTIZED_ADD_USE_NEON

}  // namespace

template <class T, class Toutput>
class QuantizedAddOp : public OpKernel {
 public:
  explicit QuantizedAddOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
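    // Inputs 0 and 1 are the quantized tensors; inputs 2 through 5 are scalar
    // tensors holding the float min/max ranges for x and then y.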
    const Tensor& x = context->input(0);
    const Tensor& y = context->input(1);
    const float min_x = context->input(2).flat<float>()(0);
    const float max_x = context->input(3).flat<float>()(0);
    const float min_y = context->input(4).flat<float>()(0);
    const float max_y = context->input(5).flat<float>()(0);

    BCast bcast(BCast::FromShape(x.shape()), BCast::FromShape(y.shape()));
    if (!bcast.IsValid()) {
      context->SetStatus(errors::InvalidArgument(
          "Incompatible shapes: ", x.shape().DebugString(), " vs. ",
          y.shape().DebugString()));
      return;
    }
    Tensor* z = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, BCast::ToShape(bcast.output_shape()), &z));

    // Make sure that we have valid quantization ranges for the input buffers.
    // If the difference between the min and max is negative or zero, it makes
    // it hard to do meaningful intermediate operations on the values.
    OP_REQUIRES(context, (max_x > min_x),
                errors::InvalidArgument("max_x must be larger than min_x."));
    OP_REQUIRES(context, (max_y > min_y),
                errors::InvalidArgument("max_y must be larger than min_y."));
    const T* x_data = x.flat<T>().data();
    const T* y_data = y.flat<T>().data();
    Toutput* z_data = z->flat<Toutput>().data();

    // We want the range of the output to be symmetrical around zero so that
    // adding zero leaves the result unchanged, and to contain the largest of
    // the two input values with some room to spare.
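    // Multiplying the largest input magnitude by 2^14 gives the qint32 output
    // a much finer grid than the 8-bit inputs, so the sum of two requantized
    // values always fits well inside the output range.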
    const float smallest_min = std::min(min_x, min_y);
    const float largest_max = std::max(max_x, max_y);
    const float biggest_range =
        std::max(std::abs(smallest_min), std::abs(largest_max));
    const float output_range = (biggest_range * (1 << 14));
    const float min_z_value = -output_range;
    const float max_z_value = output_range;

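    // Dispatch on the broadcast shape: scalar + array, equal-length arrays, or
    // a vector tiled across a larger tensor. Anything else is unimplemented.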
    const int ndims = bcast.x_reshape().size();
    if (ndims <= 1) {
      if (x.NumElements() == 1) {
        ScalarAddition<T, Toutput>(context, y_data, min_y, max_y,
                                   y.NumElements(), x_data[0], min_x, max_x,
                                   min_z_value, max_z_value, z_data);
      } else if (y.NumElements() == 1) {
        ScalarAddition<T, Toutput>(context, x_data, min_x, max_x,
                                   x.NumElements(), y_data[0], min_y, max_y,
                                   min_z_value, max_z_value, z_data);
      } else {
        VectorAddition<T, Toutput>(context, x_data, min_x, max_x, y_data, min_y,
                                   max_y, x.NumElements(), min_z_value,
                                   max_z_value, z_data);
      }
    } else if (ndims == 2) {
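      // Whichever input has fewer elements is treated as the vector that gets
      // tiled across the larger tensor.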
      const T* vector_data;
      int64 vector_num_elements;
      float vector_min;
      float vector_max;
      const T* tensor_data;
      int64 tensor_num_elements;
      float tensor_min;
      float tensor_max;
      if (x.NumElements() < y.NumElements()) {
        vector_data = x_data;
        vector_num_elements = x.NumElements();
        vector_min = min_x;
        vector_max = max_x;
        tensor_data = y_data;
        tensor_num_elements = y.NumElements();
        tensor_min = min_y;
        tensor_max = max_y;
      } else {
        vector_data = y_data;
        vector_num_elements = y.NumElements();
        vector_min = min_y;
        vector_max = max_y;
        tensor_data = x_data;
        tensor_num_elements = x.NumElements();
        tensor_min = min_x;
        tensor_max = max_x;
      }
      VectorTensorAddition<T, Toutput>(
          vector_data, vector_min, vector_max, vector_num_elements, tensor_data,
          tensor_min, tensor_max, tensor_num_elements, min_z_value, max_z_value,
          z_data);
    } else {
      LOG(INFO) << "ndims=" << ndims;
      LOG(INFO) << "bcast.x_reshape()="
                << TensorShape(bcast.x_reshape()).DebugString();
      LOG(INFO) << "bcast.y_reshape()="
                << TensorShape(bcast.y_reshape()).DebugString();
      LOG(INFO) << "bcast.x_bcast()="
                << TensorShape(bcast.x_bcast()).DebugString();
      LOG(INFO) << "bcast.y_bcast()="
                << TensorShape(bcast.y_bcast()).DebugString();

      context->SetStatus(errors::Unimplemented(
          "Broadcast between ", context->input(0).shape().DebugString(),
          " and ", context->input(1).shape().DebugString(),
          " is not supported yet."));
      return;
    }

    Tensor* z_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &z_min));
    z_min->flat<float>()(0) = min_z_value;

    Tensor* z_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &z_max));
    z_max->flat<float>()(0) = max_z_value;
  }
};

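// Registers the CPU kernel: quint8 inputs ("T1", "T2") producing a qint32
// output ("Toutput").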
REGISTER_KERNEL_BUILDER(Name("QuantizedAdd")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T1")
                            .TypeConstraint<quint8>("T2")
                            .TypeConstraint<qint32>("Toutput"),
                        QuantizedAddOp<quint8, qint32>);

}  // namespace tensorflow