/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/quantization_utils.h"

namespace tensorflow {

namespace {

// A slow but straightforward implementation of batch normalization.
template <typename T1, typename T2>
void ReferenceBatchNorm(const Tensor& input, const float input_min,
                        const float input_max, const Tensor& mean,
                        float mean_min, float mean_max, const Tensor& var,
                        float var_min, float var_max, const Tensor& beta,
                        float beta_min, float beta_max, const Tensor& gamma,
                        float gamma_min, float gamma_max,
                        float variance_epsilon, bool scale_after_normalization,
                        Tensor* output, float* output_min, float* output_max) {
  auto input_flat = input.flat<T1>();
  auto mean_flat = mean.flat<T1>();
  auto var_flat = var.flat<T1>();
  auto beta_flat = beta.flat<T1>();
  auto gamma_flat = gamma.flat<T1>();
  auto output_flat = output->flat<T2>();

  const int depth = mean.dim_size(0);
  const int row_count = input_flat.size() / depth;

  *output_min = std::numeric_limits<float>::max();
  *output_max = std::numeric_limits<float>::lowest();
  // Two passes over the data: the first finds the float range of the
  // results, and the second quantizes the results into that range.
  for (int pass = 0; pass < 2; ++pass) {
    const bool is_range_pass = (pass == 0);
    for (int row_index = 0; row_index < row_count; ++row_index) {
      for (int channel = 0; channel < depth; ++channel) {
        const int input_index = (row_index * depth) + channel;
        // Convert all of the quantized operands back to floats.
        const float input_value =
            QuantizedToFloat(input_flat(input_index), input_min, input_max);
        const float mean_value =
            QuantizedToFloat(mean_flat(channel), mean_min, mean_max);
        const float var_value =
            QuantizedToFloat(var_flat(channel), var_min, var_max);
        const float beta_value =
            QuantizedToFloat(beta_flat(channel), beta_min, beta_max);
        const float gamma_value =
            QuantizedToFloat(gamma_flat(channel), gamma_min, gamma_max);
        float output_value;
        if (scale_after_normalization) {
          output_value = (((input_value - mean_value) /
                           sqrtf(var_value + variance_epsilon)) *
                          gamma_value) +
                         beta_value;
        } else {
          output_value = ((input_value - mean_value) /
                          sqrtf(var_value + variance_epsilon)) +
                         beta_value;
        }
        if (is_range_pass) {
          *output_min = std::min(output_value, *output_min);
          *output_max = std::max(output_value, *output_max);
        } else {
          output_flat(input_index) =
              FloatToQuantized<T2>(output_value, *output_min, *output_max);
        }
      }
    }
  }
}
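
// A minimal illustrative sketch (a hypothetical helper, not called anywhere
// in this file): the plain float batch normalization formula that
// ReferenceBatchNorm evaluates per element once the quantized operands have
// been converted back to floats.
inline float BatchNormFloatSketch(float input, float mean, float var,
                                  float beta, float gamma,
                                  float variance_epsilon,
                                  bool scale_after_normalization) {
  // Normalize to zero mean and unit variance, optionally scale by gamma,
  // then shift by beta.
  const float normalized = (input - mean) / sqrtf(var + variance_epsilon);
  return scale_after_normalization ? (normalized * gamma) + beta
                                   : normalized + beta;
}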

// An implementation of batch normalization that does the main calculations
// using only fixed-point arithmetic. There's a prologue with some
// floating-point calculations, but assuming the weights are constant these
// could be hoisted to an offline process, or baked into the weights.
template <typename T1, typename T2>
void FixedPointBatchNorm(const Tensor& input, const float input_min,
                         const float input_max, const Tensor& mean,
                         float mean_min, float mean_max, const Tensor& var,
                         float var_min, float var_max, const Tensor& beta,
                         float beta_min, float beta_max, const Tensor& gamma,
                         float gamma_min, float gamma_max,
                         float variance_epsilon, bool scale_after_normalization,
                         Tensor* output, float* output_min, float* output_max) {
  auto input_flat = input.flat<T1>();
  auto mean_flat = mean.flat<T1>();
  auto var_flat = var.flat<T1>();
  auto beta_flat = beta.flat<T1>();
  auto gamma_flat = gamma.flat<T1>();
  auto output_flat = output->flat<T2>();

  const int depth = mean.dim_size(0);
  const int row_count = input_flat.size() / depth;

  // The output range is fixed rather than computed from the data: it runs
  // from -1M to +1M with 10 bits of fixed-point precision, chosen so that
  // typical input values fit without overflow or loss of precision.
  *output_min = -(1 << 20);
  *output_max = (1 << 20);

  // Fold mean, variance, gamma, and beta into a single per-channel scale and
  // offset, so the per-element loop below is just a multiply and an add.
  Tensor scale_tensor(DataTypeToEnum<T2>::v(), {depth});
  auto scale_flat = scale_tensor.flat<T2>();
  Tensor offset_tensor(DataTypeToEnum<T2>::v(), {depth});
  auto offset_flat = offset_tensor.flat<T2>();
  for (int channel = 0; channel < depth; ++channel) {
    const float mean_value =
        QuantizedToFloat(mean_flat(channel), mean_min, mean_max);
    const float var_value =
        QuantizedToFloat(var_flat(channel), var_min, var_max);
    const float beta_value =
        QuantizedToFloat(beta_flat(channel), beta_min, beta_max);
    const float gamma_value =
        QuantizedToFloat(gamma_flat(channel), gamma_min, gamma_max);
    float scale_value;
    if (scale_after_normalization) {
      scale_value = (1.0f / sqrtf(var_value + variance_epsilon)) * gamma_value;
    } else {
      scale_value = (1.0f / sqrtf(var_value + variance_epsilon));
    }
    const float offset_value = (-mean_value * scale_value) + beta_value;
    scale_flat(channel) =
        FloatToQuantized<T2>(scale_value, *output_min, *output_max);
    offset_flat(channel) =
        FloatToQuantized<T2>(offset_value, *output_min, *output_max);
  }

  const T2 one_in_output_space =
      FloatToQuantized<T2>(1.0f, *output_min, *output_max);
  for (int row_index = 0; row_index < row_count; ++row_index) {
    for (int channel = 0; channel < depth; ++channel) {
      const int input_index = (row_index * depth) + channel;
      const T2 input_value =
          RequantizeInNewRange<T1, T2>(input_flat(input_index), input_min,
                                       input_max, *output_min, *output_max);
      const T2 scale_value = scale_flat(channel);
      const T2 offset_value = offset_flat(channel);
      // Both operands are quantized in the output range, so their product is
      // scaled up by the quantized representation of 1.0; divide it back out.
      const T2 output_value =
          ((input_value * scale_value) / one_in_output_space) + offset_value;
      output_flat(input_index) = output_value;
    }
  }
}
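
// An illustrative sketch of the algebra behind the per-channel folding above
// (a hypothetical helper, not called anywhere in this file):
//   ((x - mean) / sqrt(var + eps)) * gamma + beta  ==  (x * scale) + offset
// with scale = gamma / sqrt(var + eps) and offset = beta - (mean * scale),
// which is why the quantized inner loop needs only one multiply and one add
// per element.
inline void FoldBatchNormParamsSketch(float mean, float var, float beta,
                                      float gamma, float variance_epsilon,
                                      bool scale_after_normalization,
                                      float* scale, float* offset) {
  const float inv_stddev = 1.0f / sqrtf(var + variance_epsilon);
  *scale = scale_after_normalization ? (inv_stddev * gamma) : inv_stddev;
  *offset = (-mean * *scale) + beta;
}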

}  // namespace

template <typename T1, typename T2>
class QuantizedBatchNormOp : public OpKernel {
 public:
  explicit QuantizedBatchNormOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon_));
    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
                                             &scale_after_normalization_));
  }

  void Compute(OpKernelContext* context) override {
    // Each quantized operand arrives as a triple: the value tensor followed
    // by scalar float tensors holding the min and max of its range.
    const Tensor& input = context->input(0);
    const float input_min = context->input(1).flat<float>()(0);
    const float input_max = context->input(2).flat<float>()(0);
    const Tensor& mean = context->input(3);
    const float mean_min = context->input(4).flat<float>()(0);
    const float mean_max = context->input(5).flat<float>()(0);
    const Tensor& var = context->input(6);
    const float var_min = context->input(7).flat<float>()(0);
    const float var_max = context->input(8).flat<float>()(0);
    const Tensor& beta = context->input(9);
    const float beta_min = context->input(10).flat<float>()(0);
    const float beta_max = context->input(11).flat<float>()(0);
    const Tensor& gamma = context->input(12);
    const float gamma_min = context->input(13).flat<float>()(0);
    const float gamma_max = context->input(14).flat<float>()(0);

    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, mean.dims() == 1,
                errors::InvalidArgument("mean must be 1-dimensional: ",
                                        mean.shape().DebugString()));
    OP_REQUIRES(context, var.dims() == 1,
                errors::InvalidArgument("var must be 1-dimensional: ",
                                        var.shape().DebugString()));
    OP_REQUIRES(context, beta.dims() == 1,
                errors::InvalidArgument("beta must be 1-dimensional: ",
                                        beta.shape().DebugString()));
    OP_REQUIRES(context, gamma.dims() == 1,
                errors::InvalidArgument("gamma must be 1-dimensional: ",
                                        gamma.shape().DebugString()));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    float output_min;
    float output_max;
    FixedPointBatchNorm<T1, T2>(input, input_min, input_max, mean, mean_min,
                                mean_max, var, var_min, var_max, beta, beta_min,
                                beta_max, gamma, gamma_min, gamma_max,
                                variance_epsilon_, scale_after_normalization_,
                                output, &output_min, &output_max);

    // Emit the float range of the quantized result as two scalar outputs.
    Tensor* output_min_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, {}, &output_min_tensor));
    output_min_tensor->flat<float>()(0) = output_min;

    Tensor* output_max_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, {}, &output_max_tensor));
    output_max_tensor->flat<float>()(0) = output_max;
  }

 private:
  float variance_epsilon_;
  bool scale_after_normalization_;
};

REGISTER_KERNEL_BUILDER(Name("QuantizedBatchNormWithGlobalNormalization")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint32>("out_type"),
                        QuantizedBatchNormOp<quint8, qint32>);
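
// For reference: the registered op consumes fifteen inputs (each of input,
// mean, var, beta, and gamma as a value tensor plus its scalar float min and
// max) and produces three outputs: the qint32 result and the scalar float
// min and max of its range, as written by Compute() above.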

}  // namespace tensorflow