Home | History | Annotate | Download | only in internal
      1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // output.h: processing the 32-bit accumulators output by the unpack
     16 // stage, obtaining the final result matrix entries and storing them into
     17 // the destination matrix.
     18 
     19 #ifndef GEMMLOWP_INTERNAL_OUTPUT_H_
     20 #define GEMMLOWP_INTERNAL_OUTPUT_H_
     21 
     22 #include <cmath>
     23 #include <tuple>
     24 #include <type_traits>
     25 
     26 #include "../fixedpoint/fixedpoint.h"
     27 #include "../public/output_stages.h"
     28 #include "simd_wrappers.h"
     29 
     30 namespace gemmlowp {
     31 
     32 template <typename OutputStage, typename InputBufferType>
     33 struct OutputStageEvalBufferImpl {
     34   // This generic template body should never be hit.
     35   static_assert(
     36       std::is_same<InputBufferType, void>::value,
     37       "Unimplemented: missing implementation of this output pipeline stage "
     38       "for this data type. This would happen if some architecture-specific "
     39       "SIMD back-end (output_$arch.h) were incomplete.");
     40 };
     41 
     42 template <typename OutputStage, typename InputType>
     43 struct OutputStageEvalImpl {
     44   static constexpr int kRows = InputType::kRows;
     45   static constexpr int kCols = InputType::kCols;
     46   using InputBufferType = typename InputType::BufferType;
     47   using BufferEvalImplType =
     48       OutputStageEvalBufferImpl<OutputStage, InputBufferType>;
     49   using OutputBufferType = typename BufferEvalImplType::OutputType;
     50   using OutputScalarType = typename OutputBufferType::ScalarType;
     51   using OutputType = RegisterBlock<OutputScalarType, kRows, kCols>;
     52 
     53   OutputStageEvalImpl(const OutputStage& s) : buffer_eval_impl(s) {}
     54 
     55   OutputType Eval(InputType input, int, int) const {
     56     OutputType output;
     57     output.buf = buffer_eval_impl.Eval(input.buf);
     58     return output;
     59   }
     60 
     61   const BufferEvalImplType buffer_eval_impl;
     62 };
     63 
     64 template <int Size>
     65 struct OutputStageEvalBufferImpl<OutputStageQuantizeDownInt32ToUint8Scale,
     66                                  RegisterBuffer<std::int32_t, Size>> {
     67   using InputType = RegisterBuffer<std::int32_t, Size>;
     68   using OutputType = RegisterBuffer<std::int32_t, Size>;
     69 
     70   typedef OutputStageQuantizeDownInt32ToUint8Scale OutputStage;
     71 
     72   OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {}
     73 
     74   OutputType Eval(InputType input) const {
     75     const int result_shift = output_stage.result_shift;
     76     const std::int32_t result_mult_int = output_stage.result_mult_int;
     77     using RegisterType = typename InputType::RegisterType;
     78     const RegisterType result_offset =
     79         Dup<RegisterType>(output_stage.result_offset);
     80     OutputType output;
     81     for (int i = 0; i < InputType::kRegisterCount; i++) {
     82       output.reg[i] = RoundingDivideByPOT(
     83           Mul(Add(input.reg[i], result_offset), result_mult_int), result_shift);
     84     }
     85     return output;
     86   }
     87 
     88   const OutputStage& output_stage;
     89 };
     90 
     91 template <int Rows, int Cols, VectorShape Shape>
     92 struct OutputStageEvalImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>,
     93                            RegisterBlock<std::int32_t, Rows, Cols>> {
     94   typedef RegisterBlock<std::int32_t, Rows, Cols> InputType;
     95   typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType;
     96   typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> OutputStage;
     97 
     98   OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
     99 
    100   OutputType Eval(InputType input, int row, int col) const {
    101     OutputType output;
    102     const int result_shift = output_stage.result_shift;
    103     const int pos = Shape == VectorShape::Col ? row : col;
    104     const auto result_mult_int =
    105         LoadForBroadcasting<InputType>(output_stage.result_mult_int, pos);
    106     const auto result_offset =
    107         LoadForBroadcasting<InputType>(output_stage.result_offset, pos);
    108     const auto dividend = BroadcastMul<InputType>(
    109         BroadcastAdd<InputType>(input, result_offset), result_mult_int);
    110     for (int i = 0; i < InputType::kRegisterCount; i++) {
    111       output.buf.reg[i] =
    112           RoundingDivideByPOT(dividend.buf.reg[i], result_shift);
    113     }
    114     return output;
    115   }
    116 
    117   const OutputStage& output_stage;
    118 };
    119 
    120 template <int Size>
    121 struct OutputStageEvalBufferImpl<
    122     OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
    123     RegisterBuffer<std::int32_t, Size>> {
    124   typedef RegisterBuffer<std::int32_t, Size> InputType;
    125   typedef RegisterBuffer<std::int32_t, Size> OutputType;
    126 
    127   typedef OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint OutputStage;
    128 
    129   OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {}
    130 
    131   OutputType Eval(InputType input) const {
    132     OutputType output;
    133     using RegisterType = typename InputType::RegisterType;
    134     const RegisterType result_offset_after_shift =
    135         Dup<RegisterType>(output_stage.result_offset_after_shift);
    136     for (int i = 0; i < InputType::kRegisterCount; i++) {
    137       const RegisterType mulhigh_val = SaturatingRoundingDoublingHighMul(
    138           input.reg[i], output_stage.result_fixedpoint_multiplier);
    139       output.reg[i] =
    140           Add(RoundingDivideByPOT(mulhigh_val, output_stage.result_shift),
    141               result_offset_after_shift);
    142     }
    143     return output;
    144   }
    145 
    146   const OutputStage& output_stage;
    147 };
    148 
    149 // Implementation of OutputStageSaturatingCastToUint8 for scalar data
    150 template <int Size>
    151 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
    152                                  RegisterBuffer<std::int32_t, Size>> {
    153   typedef RegisterBuffer<std::int32_t, Size> InputType;
    154   typedef RegisterBuffer<std::uint8_t, Size> OutputType;
    155   static_assert(InputType::kRegisterLanes == 1,
    156                 "This path is only for scalar values");
    157 
    158   typedef OutputStageSaturatingCastToUint8 OutputStage;
    159 
    160   OutputStageEvalBufferImpl(const OutputStage&) {}
    161 
    162   OutputType Eval(InputType input) const {
    163     OutputType output;
    164     for (int i = 0; i < InputType::kRegisterCount; i++) {
    165       std::int32_t data = input.reg[i];
    166       output.reg[i] = data > 255 ? 255 : data < 0 ? 0 : data;
    167     }
    168     return output;
    169   }
    170 };
    171 
    172 template <int Rows, int Cols, typename VectorType>
    173 struct OutputStageEvalImpl<OutputStageBiasAddition<VectorType>,
    174                            RegisterBlock<std::int32_t, Rows, Cols>> {
    175   typedef RegisterBlock<std::int32_t, Rows, Cols> InputType;
    176   typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType;
    177   typedef OutputStageBiasAddition<VectorType> OutputStage;
    178 
    179   OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {}
    180 
    181   OutputType Eval(InputType input, int row, int col) const {
    182     const int pos = VectorType::kShape == VectorShape::Row ? col : row;
    183     return BroadcastAdd<InputType>(
    184         input, LoadForBroadcasting<InputType>(output_stage.bias_vector, pos));
    185   }
    186 
    187   const OutputStage& output_stage;
    188 };
    189 
    190 template <int Size>
    191 struct OutputStageEvalBufferImpl<OutputStageClamp,
    192                                  RegisterBuffer<std::int32_t, Size>> {
    193   typedef RegisterBuffer<std::int32_t, Size> InputType;
    194   typedef RegisterBuffer<std::int32_t, Size> OutputType;
    195 
    196   typedef OutputStageClamp OutputStage;
    197 
    198   OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {}
    199 
    200   OutputType Eval(InputType input) const {
    201     using RegisterType = typename InputType::RegisterType;
    202     const RegisterType min = Dup<RegisterType>(output_stage.min);
    203     const RegisterType max = Dup<RegisterType>(output_stage.max);
    204     OutputType output;
    205     for (int i = 0; i < InputType::kRegisterCount; i++) {
    206       output.reg[i] = Min(Max(input.reg[i], min), max);
    207     }
    208     return output;
    209   }
    210 
    211   const OutputStage& output_stage;
    212 };
    213 
    214 template <int Size>
    215 struct OutputStageEvalBufferImpl<OutputStageTanh,
    216                                  RegisterBuffer<std::int32_t, Size>> {
    217   typedef RegisterBuffer<std::int32_t, Size> InputType;
    218   typedef RegisterBuffer<std::int32_t, Size> OutputType;
    219   using RegisterType = typename InputType::RegisterType;
    220   typedef RegisterType DataType;
    221   typedef OutputStageTanh OutputStage;
    222 
    223   OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {
    224     const std::int32_t real_zero_as_int32 = output_stage.real_zero_as_int32;
    225     const std::int32_t real_amplitude_as_int32 =
    226         output_stage.real_amplitude_as_int32;
    227 
    228     input_cutoff_min = real_zero_as_int32 - 8 * real_amplitude_as_int32;
    229     input_cutoff_max = real_zero_as_int32 + 8 * real_amplitude_as_int32;
    230     output_min = real_zero_as_int32 - real_amplitude_as_int32;
    231     output_max = real_zero_as_int32 + real_amplitude_as_int32;
    232 
    233     double inverse_amplitude_normalized_double = 1.0 / real_amplitude_as_int32;
    234     inverse_amplitude_neg_exponent = 0;
    235     while (inverse_amplitude_normalized_double < 0.5) {
    236       inverse_amplitude_normalized_double *= 2;
    237       inverse_amplitude_neg_exponent++;
    238     }
    239     inverse_amplitude_normalized = FixedPoint<DataType, 0>::FromDouble(
    240         inverse_amplitude_normalized_double);
    241 
    242     double amplitude_normalized_double = real_amplitude_as_int32;
    243     amplitude_exponent = 0;
    244     while (amplitude_normalized_double >= 1.0) {
    245       amplitude_normalized_double *= 0.5;
    246       amplitude_exponent++;
    247     }
    248     amplitude_normalized =
    249         FixedPoint<DataType, 0>::FromDouble(amplitude_normalized_double);
    250   }
    251 
    252   OutputType Eval(InputType input) const {
    253     const std::int32_t real_zero_as_int32 = output_stage.real_zero_as_int32;
    254 
    255     typedef FixedPoint<DataType, 3> F3;
    256     typedef FixedPoint<DataType, 0> F0;
    257 
    258     OutputType output;
    259 
    260     for (int i = 0; i < OutputType::kRegisterCount; i++) {
    261       // fixed-point affine transformation
    262       DataType input_centered =
    263           Sub(input.reg[i], Dup<DataType>(real_zero_as_int32));
    264       F3 fixedpoint_input =
    265           F3::FromRaw(input_centered) * inverse_amplitude_normalized;
    266       // left shift
    267       fixedpoint_input.raw() = ShiftLeft(fixedpoint_input.raw(),
    268                                          28 - inverse_amplitude_neg_exponent);
    269       // fixed-point tanh and multiplication
    270       F0 fixedpoint_output = tanh(fixedpoint_input) * amplitude_normalized;
    271       // right shift
    272       DataType int32_output =
    273           Add(Dup<DataType>(real_zero_as_int32),
    274               ShiftRight(fixedpoint_output.raw(), 31 - amplitude_exponent));
    275 
    276       DataType mask_if_below_cutoff_min =
    277           MaskIfLessThanOrEqual(input.reg[i], Dup<DataType>(input_cutoff_min));
    278       DataType mask_if_above_cutoff_max = MaskIfGreaterThanOrEqual(
    279           input.reg[i], Dup<DataType>(input_cutoff_max));
    280 
    281       output.reg[i] = SelectUsingMask(
    282           mask_if_below_cutoff_min, Dup<DataType>(output_min),
    283           SelectUsingMask(mask_if_above_cutoff_max, Dup<DataType>(output_max),
    284                           int32_output));
    285     }
    286     return output;
    287   }
    288 
    289   const OutputStage& output_stage;
    290   std::int32_t input_cutoff_min, input_cutoff_max;
    291   std::int32_t output_min, output_max;
    292   FixedPoint<DataType, 0> inverse_amplitude_normalized;
    293   int inverse_amplitude_neg_exponent;
    294   FixedPoint<DataType, 0> amplitude_normalized;
    295   int amplitude_exponent;
    296 };
    297 
    298 // OutputPipelineOutputType is a helper to determine the output data type of a
    299 // pipeline, for a
    300 // given input data type. It is a recursive template; see the explanation on
    301 // OutputPipelineEvalImpl below.
    302 template <typename OutputPipelineType, int FirstStage, typename InputType,
    303           bool StopRecursion =
    304               FirstStage == std::tuple_size<OutputPipelineType>::value>
    305 struct OutputPipelineOutputType {
    306   typedef typename std::tuple_element<FirstStage, OutputPipelineType>::type
    307       FirstStageType;
    308   typedef typename OutputStageEvalImpl<FirstStageType, InputType>::OutputType
    309       FirstStageOutputType;
    310   typedef typename OutputPipelineOutputType<OutputPipelineType, FirstStage + 1,
    311                                             FirstStageOutputType>::Type Type;
    312 };
    313 
    314 template <typename OutputPipelineType, int FirstStage, typename InputType>
    315 struct OutputPipelineOutputType<OutputPipelineType, FirstStage, InputType,
    316                                 true> {
    317   typedef InputType Type;
    318 };
    319 
    320 // OutputPipelineEvalImpl is a helper to implement the evaluation of
    321 // the whole pipeline. It is a recursive template to implement compile-time
    322 // unrolling of the loop over all pipeline stages. The 'FirstStage' parameter
    323 // is how we implement recursion: each specialization implements only
    324 // evaluation starting at 'FirstStage'. The StopRecursion parameter is just a
    325 // helper to implement the termination of the recursion as a partial
    326 // specialization below.
    327 template <typename OutputPipelineType, int FirstStage, typename InputType,
    328           bool StopRecursion =
    329               FirstStage == std::tuple_size<OutputPipelineType>::value>
    330 struct OutputPipelineEvalImpl {
    331   typedef typename std::tuple_element<FirstStage, OutputPipelineType>::type
    332       FirstStageType;
    333   typedef typename OutputStageEvalImpl<FirstStageType, InputType>::OutputType
    334       FirstStageOutputType;
    335   typedef typename OutputPipelineOutputType<OutputPipelineType, FirstStage,
    336                                             InputType>::Type OutputType;
    337 
    338   OutputPipelineEvalImpl(const OutputPipelineType& output_pipeline)
    339       : head_impl(std::get<FirstStage>(output_pipeline)),
    340         tail_impl(output_pipeline) {}
    341 
    342   OutputType Eval(InputType input, int row, int col) const {
    343     // Evaluate the first stage.
    344     FirstStageOutputType first_stage_output = head_impl.Eval(input, row, col);
    345     // Recurse into the remaining stages.
    346     return tail_impl.Eval(first_stage_output, row, col);
    347   }
    348 
    349   const OutputStageEvalImpl<FirstStageType, InputType> head_impl;
    350   const OutputPipelineEvalImpl<OutputPipelineType, FirstStage + 1,
    351                                FirstStageOutputType>
    352       tail_impl;
    353 };
    354 
    355 // Specialization on 'StopRecursion' for terminating the recursion.
    356 template <typename OutputPipelineType, int FirstStage, typename InputType>
    357 struct OutputPipelineEvalImpl<OutputPipelineType, FirstStage, InputType, true> {
    358   OutputPipelineEvalImpl(const OutputPipelineType&) {}
    359 
    360   InputType Eval(InputType input, int, int) const {
    361     // Terminating the recursion.
    362     return input;
    363   }
    364 };
    365 
    366 template <typename RegisterBlockType, typename DstType>
    367 struct StoreFinalOutputImpl {
    368   static_assert(std::is_same<RegisterBlockType, void>::value,
    369                 "This generic impl should never be hit");
    370 };
    371 
    372 template <typename ScalarType, int Rows, int Cols, typename DstType>
    373 struct StoreFinalOutputImpl<RegisterBlock<ScalarType, Rows, Cols>, DstType> {
    374   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
    375   static void Run(const RegisterBlockType& src, DstType* dst, int row,
    376                   int col) {
    377     for (int r = 0; r < Rows; r++) {
    378       for (int c = 0; c < Cols; c++) {
    379         *dst->data(row + r, col + c) = src.buf.reg[r + c * Rows];
    380       }
    381     }
    382   }
    383 };
    384 
    385 // StoreFinalOutput takes the final value at the end of the output pipeline and
    386 // stores it into the destination matrix. It can be specialized for different
    387 // data types; the generic implementation here is typically used only for plain
    388 // old scalar (not SIMD) types.
    389 template <typename RegisterBlockType, typename DstType>
    390 void StoreFinalOutput(RegisterBlockType src, DstType* dst, int row, int col) {
    391   StoreFinalOutputImpl<RegisterBlockType, DstType>::Run(src, dst, row, col);
    392 }
    393 
    394 template <typename OutputPipelineType, typename InputType>
    395 struct OutputPipelineExecutor {
    396   OutputPipelineExecutor(const OutputPipelineType& output_pipeline)
    397       : output_pipeline_eval_impl_(output_pipeline) {}
    398 
    399   // RunOutputPipeline is the entry point into the output pipeline evaluation
    400   // code. It should be the only thing that unpack code calls. It takes the
    401   // result
    402   // of the unpack stage and stores it into the destination matrix.
    403   template <typename DstType>
    404   void Execute(InputType input, DstType* dst, int src_global_row,
    405                int src_global_col, int dst_row, int dst_col) const {
    406     // Statically assert that the output pipeline matches the given destination
    407     // matrix's scalar type.
    408     typedef typename OutputPipelineOutputType<
    409         OutputPipelineType, 0, InputType>::Type::BufferType::ScalarType
    410 
    411         ScalarOutputType;
    412     typedef typename DstType::Scalar ScalarDstType;
    413     static_assert(std::is_same<ScalarOutputType, ScalarDstType>::value,
    414                   "mismatched destination scalar type and output pipeline");
    415 
    416     // Evaluate the output pipeline.
    417     auto output =
    418         output_pipeline_eval_impl_.Eval(input, src_global_row, src_global_col);
    419     // Store the result into the destination matrix.
    420     StoreFinalOutput(output, dst, dst_row, dst_col);
    421   }
    422 
    423   const OutputPipelineEvalImpl<OutputPipelineType, 0, InputType>
    424       output_pipeline_eval_impl_;
    425 };
    426 
    427 }  // namespace gemmlowp
    428 
    429 #ifdef GEMMLOWP_NEON
    430 #include "output_neon.h"
    431 #elif defined(GEMMLOWP_SSE4)
    432 #include "output_sse.h"
    433 #endif
    434 
    435 #endif  // GEMMLOWP_INTERNAL_OUTPUT_H_
    436