Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
     19 #define USE_NEON
     20 #include <arm_neon.h>
     21 #endif
     22 
     23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     24 #include "tensorflow/core/framework/numeric_op.h"
     25 #include "tensorflow/core/framework/op_kernel.h"
     26 #include "tensorflow/core/framework/register_types.h"
     27 #include "tensorflow/core/framework/tensor.h"
     28 
     29 #include "tensorflow/core/kernels/quantization_utils.h"
     30 
     31 #ifdef USE_NEON
     32 namespace {
     33 
     34 // Single pass mean and variance.
     35 // Shape of `input` is [rows x cols], shape of both `mean` and `variance`
     36 // is [cols].
     37 // Note, `mean` and `variance` are of 'i' (not scaled).
     38 // The following is a straightforward implementation of the parallel algorithm
     39 // described in
     40 // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
     41 void ColMeanAndVariance(const uint8_t* input, const uint32_t rows,
     42                         const uint32_t cols, float* mean, float* variance) {
     43   // The implementation operates on for 16 columns at a time.
     44   // Assumes cols % 16 == 0
     45   for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
     46     // Vector registers to track the running sum across the rows. Since there
     47     // are 16 columns, we have 4 32x4 registers.
     48     uint32x4_t sum[4] = {0};
     49 
     50     float nA = 0.0f;
     51     // Running average and the second moment.
     52     float32x4_t xA[4] = {0.0f};
     53     float32x4_t M2A[4] = {0.0f};
     54 
     55     const uint8_t* inp_ptr = input + col_offset;
     56     // Go over the rows in chunks of 256. This is so that we can use 16 bit adds
     57     // to do the accumulation.
     58     for (uint32_t row = 0; row < rows; row += 256) {
     59       // Running sum and sum of squares for the 256 rows.
     60       uint32x4_t sub_sum[4] = {0};
     61       uint32x4_t sub_sq_sum[4] = {0};
     62       const uint32_t limit = std::min(rows, row + 256);
     63       const float nB = limit - row;
     64       for (uint32_t subrow = row; subrow < limit; ++subrow) {
     65         const uint8x16_t v = vld1q_u8(inp_ptr);
     66         inp_ptr += cols;
     67 
     68         const uint8x8_t v_high = vget_high_u8(v);
     69         const uint8x8_t v_low = vget_low_u8(v);
     70 
     71         const uint16x8_t v_high_u16 = vmovl_u8(v_high);
     72         const uint16x8_t v_low_u16 = vmovl_u8(v_low);
     73 
     74         const uint16x4_t v_high_high = vget_high_u16(v_high_u16);
     75         const uint16x4_t v_high_low = vget_low_u16(v_high_u16);
     76         const uint16x4_t v_low_high = vget_high_u16(v_low_u16);
     77         const uint16x4_t v_low_low = vget_low_u16(v_low_u16);
     78 
     79         sub_sum[0] = vaddw_u16(sub_sum[0], v_high_high);
     80         sub_sum[1] = vaddw_u16(sub_sum[1], v_high_low);
     81         sub_sum[2] = vaddw_u16(sub_sum[2], v_low_high);
     82         sub_sum[3] = vaddw_u16(sub_sum[3], v_low_low);
     83 
     84         sub_sq_sum[0] = vmlal_u16(sub_sq_sum[0], v_high_high, v_high_high);
     85         sub_sq_sum[1] = vmlal_u16(sub_sq_sum[1], v_high_low, v_high_low);
     86         sub_sq_sum[2] = vmlal_u16(sub_sq_sum[2], v_low_high, v_low_high);
     87         sub_sq_sum[3] = vmlal_u16(sub_sq_sum[3], v_low_low, v_low_low);
     88       }
     89 
     90       // Update the full running sum and moment from the ones for 256 rows.
     91       for (int i = 0; i < 4; ++i) {
     92         sum[i] = vaddq_u32(sum[i], sub_sum[i]);
     93         const float nX = nA + nB;
     94         // xB is the average of up to 256 elements.
     95         const float32x4_t xB =
     96             vmulq_n_f32(vcvtq_f32_u32(sub_sum[i]), 1.0f / nB);
     97 
     98         // delta = xB - xA
     99         const float32x4_t delta = vsubq_f32(xB, xA[i]);
    100         // xA = (nA * xA + nB * xB) / (nA + nB)
    101         xA[i] = vmulq_n_f32(
    102             vaddq_f32(vmulq_n_f32(xA[i], nA), vmulq_n_f32(xB, nB)), 1.0f / nX);
    103 
    104         const float32x4_t sub_sum_f32 = vcvtq_f32_u32(sub_sum[i]);
    105         const float32x4_t sub_sum_sq = vmulq_f32(sub_sum_f32, sub_sum_f32);
    106 
    107         // M2B = sum(xB^2) - sum(xB)^2/nB
    108         const float32x4_t M2B = vsubq_f32(vcvtq_f32_u32(sub_sq_sum[i]),
    109                                           vmulq_n_f32(sub_sum_sq, 1.0f / nB));
    110         const float32x4_t last_term =
    111             vmulq_n_f32(vmulq_f32(delta, delta), nA * nB / nX);
    112         // M2A = oldM2A + M2B + delta^2 * nA*nB/nX
    113         M2A[i] = vaddq_f32(vaddq_f32(M2A[i], M2B), last_term);
    114       }
    115       nA += limit;
    116     }
    117 
    118     // Write the final mean and variance for the 16 columns.
    119     const float inv_rows = 1.0f / static_cast<float>(rows);
    120     vst1q_f32(mean + col_offset, vmulq_n_f32(vcvtq_f32_u32(sum[3]), inv_rows));
    121     vst1q_f32(mean + col_offset + 4,
    122               vmulq_n_f32(vcvtq_f32_u32(sum[2]), inv_rows));
    123     vst1q_f32(mean + col_offset + 8,
    124               vmulq_n_f32(vcvtq_f32_u32(sum[1]), inv_rows));
    125     vst1q_f32(mean + col_offset + 12,
    126               vmulq_n_f32(vcvtq_f32_u32(sum[0]), inv_rows));
    127 
    128     vst1q_f32(variance + col_offset, vmulq_n_f32(M2A[3], inv_rows));
    129     vst1q_f32(variance + col_offset + 4, vmulq_n_f32(M2A[2], inv_rows));
    130     vst1q_f32(variance + col_offset + 8, vmulq_n_f32(M2A[1], inv_rows));
    131     vst1q_f32(variance + col_offset + 12, vmulq_n_f32(M2A[0], inv_rows));
    132   }
    133 }
    134 
    135 // Compute min and max of (input - mean) / sqrt(variance + epsilon).
    136 // This is done in a separate pass so that the normalized value can be
    137 // temporarily computed in floating point precision and not stored anywhere.
    138 void MinAndMax(const uint8_t* input, const uint32_t rows, const uint32_t cols,
    139                const float* mean_ptr, const float* variance_ptr,
    140                float variance_epsilon, float* minimum, float* maximum) {
    141   float v_maximum = std::numeric_limits<float>::min();
    142   float v_minimum = std::numeric_limits<float>::max();
    143   const float32x4_t eps = vdupq_n_f32(variance_epsilon);
    144 
    145   for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    146     const float32x4_t mean[4] = {vld1q_f32(mean_ptr + col_offset),
    147                                  vld1q_f32(mean_ptr + col_offset + 4),
    148                                  vld1q_f32(mean_ptr + col_offset + 8),
    149                                  vld1q_f32(mean_ptr + col_offset + 12)};
    150     const float32x4_t variance[4] = {vld1q_f32(variance_ptr + col_offset),
    151                                      vld1q_f32(variance_ptr + col_offset + 4),
    152                                      vld1q_f32(variance_ptr + col_offset + 8),
    153                                      vld1q_f32(variance_ptr + col_offset + 12)};
    154     const float32x4_t inv_stddev[4] = {
    155         vrsqrteq_f32(vaddq_f32(variance[0], eps)),
    156         vrsqrteq_f32(vaddq_f32(variance[1], eps)),
    157         vrsqrteq_f32(vaddq_f32(variance[2], eps)),
    158         vrsqrteq_f32(vaddq_f32(variance[3], eps))};
    159 
    160     const uint8_t* inp_ptr = input + col_offset;
    161     for (uint32_t row = 0; row < rows; ++row) {
    162       const uint8x16_t v = vld1q_u8(inp_ptr);
    163       inp_ptr += cols;
    164 
    165       const uint16x8_t v_high = vmovl_u8(vget_high_u8(v));
    166       const uint16x8_t v_low = vmovl_u8(vget_low_u8(v));
    167 
    168       const float32x4_t v_float[4] = {
    169           vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high))),
    170           vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))),
    171           vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))),
    172           vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low)))};
    173 
    174       for (int i = 0; i < 4; ++i) {
    175         const float32x4_t normed =
    176             vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]);
    177         const float32x2_t high = vget_high_f32(normed);
    178         const float32x2_t low = vget_low_f32(normed);
    179         float32x2_t tmp_max = vpmax_f32(low, high);
    180         tmp_max = vpmax_f32(tmp_max, tmp_max);
    181         v_maximum = std::max(v_maximum, vget_lane_f32(tmp_max, 0));
    182         float32x2_t tmp_min = vpmin_f32(low, high);
    183         tmp_min = vpmin_f32(tmp_min, tmp_min);
    184         v_minimum = std::min(v_minimum, vget_lane_f32(tmp_min, 0));
    185       }
    186     }
    187   }
    188   *minimum = v_minimum;
    189   *maximum = v_maximum;
    190 }
    191 
    192 // Compute (input - mean) / sqrt(variance + epsilon) in floating point, quantize
    193 // it in the range (minimum, maximum) and store the result as quint8.
    194 void InstanceNorm(const uint8_t* input, const uint32_t rows,
    195                   const uint32_t cols, const float* mean_ptr,
    196                   const float* variance_ptr, float variance_epsilon,
    197                   float minimum, float maximum, uint8_t* output) {
    198   const float32x4_t eps = vdupq_n_f32(variance_epsilon);
    199   const float32x4_t out_min = vdupq_n_f32(minimum);
    200   const float out_scale = 255.0f / (maximum - minimum);
    201 
    202   for (uint32_t col_offset = 0; col_offset < cols; col_offset += 16) {
    203     const float32x4_t mean[4] = {vld1q_f32(mean_ptr + col_offset + 12),
    204                                  vld1q_f32(mean_ptr + col_offset + 8),
    205                                  vld1q_f32(mean_ptr + col_offset + 4),
    206                                  vld1q_f32(mean_ptr + col_offset)};
    207     const float32x4_t variance[4] = {vld1q_f32(variance_ptr + col_offset + 12),
    208                                      vld1q_f32(variance_ptr + col_offset + 8),
    209                                      vld1q_f32(variance_ptr + col_offset + 4),
    210                                      vld1q_f32(variance_ptr + col_offset)};
    211     const float32x4_t inv_stddev[4] = {
    212         vrsqrteq_f32(vaddq_f32(variance[0], eps)),
    213         vrsqrteq_f32(vaddq_f32(variance[1], eps)),
    214         vrsqrteq_f32(vaddq_f32(variance[2], eps)),
    215         vrsqrteq_f32(vaddq_f32(variance[3], eps))};
    216     const uint8_t* inp_ptr = input + col_offset;
    217     uint8_t* out_ptr = output + col_offset;
    218     for (uint32_t row = 0; row < rows; ++row) {
    219       const uint8x16_t v = vld1q_u8(inp_ptr);
    220       inp_ptr += cols;
    221       const uint16x8_t v_high = vmovl_u8(vget_high_u8(v));
    222       const uint16x8_t v_low = vmovl_u8(vget_low_u8(v));
    223 
    224       const float32x4_t v_float[4] = {
    225           vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_high))),
    226           vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_high))),
    227           vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_low))),
    228           vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_low)))};
    229 
    230       uint16x4_t normed_uint16[4];
    231       for (int i = 0; i < 4; ++i) {
    232         const float32x4_t normed =
    233             vmulq_f32(vsubq_f32(v_float[i], mean[i]), inv_stddev[i]);
    234         const int32x4_t normed_int32 =
    235             vcvtq_s32_f32(vmulq_n_f32(vsubq_f32(normed, out_min), out_scale));
    236         normed_uint16[i] = vqmovun_s32(normed_int32);
    237       }
    238       vst1_u8(out_ptr,
    239               vqmovn_u16(vcombine_u16(normed_uint16[3], normed_uint16[2])));
    240       vst1_u8(out_ptr + 8,
    241               vqmovn_u16(vcombine_u16(normed_uint16[1], normed_uint16[0])));
    242       out_ptr += cols;
    243     }
    244   }
    245 }
    246 
    247 }  // end namespace
    248 #endif  // USE_NEON
    249 
    250 namespace tensorflow {
    251 
    252 typedef Eigen::ThreadPoolDevice CPUDevice;
    253 
    254 class QuantizedInstanceNorm : public OpKernel {
    255  public:
    256   explicit QuantizedInstanceNorm(OpKernelConstruction* context)
    257       : OpKernel(context) {
    258     OP_REQUIRES_OK(context,
    259                    context->GetAttr("variance_epsilon", &variance_epsilon_));
    260     OP_REQUIRES_OK(context,
    261                    context->GetAttr("min_separation", &min_separation_));
    262     OP_REQUIRES_OK(
    263         context, context->GetAttr("output_range_given", &output_range_given_));
    264     if (output_range_given_) {
    265       OP_REQUIRES_OK(context, context->GetAttr("given_y_min", &given_y_min_));
    266       OP_REQUIRES_OK(context, context->GetAttr("given_y_max", &given_y_max_));
    267       OP_REQUIRES(context, given_y_min_ < given_y_max_,
    268                   errors::InvalidArgument(
    269                       "given_y_min must be less than given_y_max : ",
    270                       given_y_min_, " >= ", given_y_max_));
    271     }
    272   }
    273 
    274   void Compute(OpKernelContext* context) override {
    275     const Tensor& input = context->input(0);
    276 
    277     float input_min = context->input(1).flat<float>()(0);
    278     float input_max = context->input(2).flat<float>()(0);
    279     float input_scale = (input_max - input_min) / 255.0f;
    280 
    281     OP_REQUIRES(context, input_min < input_max,
    282                 errors::InvalidArgument(
    283                     "input_min must be less than input_max : ", input_min,
    284                     " >= ", input_max));
    285 
    286     auto input_tensor = input.tensor<quint8, 4>();
    287     auto N = input_tensor.dimension(0);
    288     auto H = input_tensor.dimension(1);
    289     auto W = input_tensor.dimension(2);
    290     auto C = input_tensor.dimension(3);
    291 
    292     Tensor* output = nullptr;
    293     OP_REQUIRES_OK(context,
    294                    context->allocate_output(0, input.shape(), &output));
    295 
    296     Tensor* output_min = nullptr;
    297     OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
    298     Tensor* output_max = nullptr;
    299     OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));
    300 
    301     typedef TTypes<float>::Tensor::Index Index;
    302 
    303 #if defined(EIGEN_HAS_INDEX_LIST)
    304     const Eigen::IndexList<Eigen::type2index<1>, Eigen::type2index<2>>
    305         reduction_indices;
    306     Eigen::IndexList<Eigen::type2index<1>, Index, Index, Eigen::type2index<1>>
    307         broadcast_spec;
    308     broadcast_spec.set(1, H);
    309     broadcast_spec.set(2, W);
    310     Eigen::IndexList<Index, Eigen::type2index<1>, Eigen::type2index<1>, Index>
    311         expand_spec;
    312     expand_spec.set(0, N);
    313     expand_spec.set(3, C);
    314 #else
    315     const Eigen::array<Index, 2> reduction_indices{1, 2};
    316     const Eigen::array<Index, 4> broadcast_spec{1, H, W, 1};
    317     const Eigen::array<Index, 4> expand_spec{N, 1, 1, C};
    318 #endif
    319 
    320     Eigen::Tensor<float, 2, Eigen::RowMajor> float_mean(N, C);
    321     Eigen::Tensor<float, 2, Eigen::RowMajor> float_variance(N, C);
    322 
    323 #ifdef USE_NEON
    324     if (N == 1 && (C % 16 == 0)) {
    325       VLOG(2) << "Calling optimized";
    326       ColMeanAndVariance(reinterpret_cast<const uint8_t*>(input_tensor.data()),
    327                          H * W, C, float_mean.data(), float_variance.data());
    328 
    329       float minimum = given_y_min_, maximum = given_y_max_;
    330       if (!output_range_given_) {
    331         MinAndMax(reinterpret_cast<const uint8_t*>(input_tensor.data()), H * W,
    332                   C, float_mean.data(), float_variance.data(),
    333                   variance_epsilon_, &minimum, &maximum);
    334       }
    335 
    336       if (maximum - minimum < min_separation_) {
    337         maximum = minimum + min_separation_;
    338       }
    339 
    340       InstanceNorm(reinterpret_cast<const uint8_t*>(input_tensor.data()), H * W,
    341                    C, float_mean.data(), float_variance.data(),
    342                    variance_epsilon_, minimum, maximum,
    343                    reinterpret_cast<uint8_t*>(output->flat<quint8>().data()));
    344       output_min->scalar<float>()() = minimum;
    345       output_max->scalar<float>()() = maximum;
    346     } else  // NOLINT(readability/braces)
    347 #endif
    348     {
    349       VLOG(2) << "Calling unoptimized";
    350       float_mean = input_tensor.cast<float>().reduce(
    351           reduction_indices, Eigen::internal::MeanReducer<float>());
    352 
    353       float_variance =
    354           (input_scale *
    355            ((input_tensor.cast<float>() -
    356              float_mean.reshape(expand_spec).broadcast(broadcast_spec))))
    357               .square()
    358               .reduce(reduction_indices, Eigen::internal::MeanReducer<float>());
    359 
    360       Eigen::Tensor<float, 4, Eigen::RowMajor> instance_normed =
    361           input_scale *
    362           (input_tensor.cast<float>() -
    363            float_mean.reshape(expand_spec).broadcast(broadcast_spec)) *
    364           (float_variance + variance_epsilon_)
    365               .rsqrt()
    366               .reshape(expand_spec)
    367               .broadcast(broadcast_spec);
    368 
    369       Eigen::Tensor<float, 0, Eigen::RowMajor> normed_min;
    370       Eigen::Tensor<float, 0, Eigen::RowMajor> normed_max;
    371 
    372       if (!output_range_given_) {
    373         normed_min = instance_normed.minimum();
    374         normed_max = instance_normed.maximum();
    375       } else {
    376         normed_min() = given_y_min_;
    377         normed_max() = given_y_max_;
    378       }
    379 
    380       if (normed_max() - normed_min() < min_separation_) {
    381         normed_max() = normed_min() + min_separation_;
    382       }
    383 
    384       FloatToQuantizedStruct<quint8> output_f2q(normed_min(), normed_max());
    385       auto instance_normed_quantized =
    386           QUANTIZE_WITH_EIGEN(instance_normed, output_f2q, quint8);
    387 
    388       output->tensor<quint8, 4>().device(
    389           context->template eigen_device<CPUDevice>()) =
    390           instance_normed_quantized;
    391       output_min->flat<float>()(0) = normed_min();
    392       output_max->flat<float>()(0) = normed_max();
    393     }
    394   }
    395 
    396  private:
    397   float variance_epsilon_;
    398   float min_separation_;
    399   bool output_range_given_;
    400   float given_y_min_;
    401   float given_y_max_;
    402 };
    403 
    404 REGISTER_KERNEL_BUILDER(Name("QuantizedInstanceNorm")
    405                             .Device(DEVICE_CPU)
    406                             .TypeConstraint<quint8>("T"),
    407                         QuantizedInstanceNorm);
    408 
    409 }  // namespace tensorflow
    410