Home | History | Annotate | Download | only in internal
      1 // Copyright 2015 Google Inc. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // pack_SSE.h: optimized SSE specializations of the templates in pack.h.
     16 
     17 #ifndef GEMMLOWP_INTERNAL_PACK_SSE_H_
     18 #define GEMMLOWP_INTERNAL_PACK_SSE_H_
     19 
     20 #include <smmintrin.h>
     21 #include "pack.h"
     22 
     23 namespace gemmlowp {
     24 
     25 // Requantizes source values pointed by raw_src_ptr in [0..255] range
     26 // to the range specified by BitDepth, [0..((2^bits)-1)].
     27 // This is in-place requantization, where the input is
     28 // not modified if 8bit integers are used. SSE does not
     29 // have less than 8bit kernels currently. Altought SSE registers
     30 // hold 16 uint8_t elements, only first 8 uint8_t elements are
     31 // requantized. The packing only use first 8 uint8_t elements
     32 // of the SSE registers. Therefore, requantizing all 16 uint8_t
     33 // elements will be wasteful computation.
     34 template <typename QuantizationParams>
     35 void SSERequantize(
     36     __m128i* raw_src_ptr,
     37     ScalarRoundingOffsetGenerator<QuantizationParams::kRoundingMode>*
     38         rounding_offset_generator) {
     39   static const int kBits = QuantizationParams::BitDepth::kBits;
     40   static const std::uint8_t kMaxVal = (1 << kBits) - 1;
     41   if (kBits == 8) {
     42     return;
     43   }
     44 
     45   std::uint8_t* raw_src_ui8_ptr = (std::uint8_t*)&raw_src_ptr[0];
     46 
     47   // modify only first 8 elements in the register (see note above)
     48   for (int i = 0; i < 8; ++i) {
     49     std::uint16_t scaled =
     50         static_cast<std::uint16_t>(raw_src_ui8_ptr[i]) * kMaxVal;
     51     std::uint8_t rounding_offset = rounding_offset_generator->get();
     52     raw_src_ui8_ptr[i] = (scaled + rounding_offset) / 255;
     53   }
     54 }
     55 
     56 // TODO: Add DepthMajorUint8SideMap
     57 
     58 typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
     59     WidthMajorUint8SideMap;
     60 
     61 template <int Cells>
     62 using WidthMajorSideFormatNCells4x2 =
     63     KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
     64 
     65 template <typename QuantizationParams, int Cells>
     66 class PackingRegisterBlock<
     67     QuantizationParams, WidthMajorUint8SideMap,
     68     PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > >
     69     : public PackingRegisterBlockBase<
     70           QuantizationParams, WidthMajorUint8SideMap,
     71           PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > > {
     72  public:
     73   typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
     74   typedef typename KernelSideFormat::Cell CellFormat;
     75   static const int kCells = KernelSideFormat::kCells;
     76   static const int kCellWidth = CellFormat::kWidth;
     77   static const int kKernelWidth = CellFormat::kWidth * kCells;
     78   static const int kCellDepth = CellFormat::kDepth;
     79   static const int kCellSize = CellFormat::kSize;
     80 
     81   typedef ScalarRoundingOffsetGenerator<QuantizationParams::kRoundingMode>
     82       RoundingOffsetGenerator;
     83 
     84   void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width,
     85             RoundingOffsetGenerator* rounding_offset_generator) {
     86     std::uint8_t* dst_ptr = dst->current_data();
     87     const int width_stride = this->complete_src_.width_stride();
     88     int depth_step = 8;
     89 
     90     __m128i one = _mm_set1_epi16(1);
     91     for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
     92          cell_start_depth += depth_step) {
     93       for (int cell_start_width = 0; cell_start_width < kKernelWidth;
     94            cell_start_width += kCellWidth) {
     95         std::int32_t* cell_sums_of_each_slice_ptr =
     96             dst->sums_of_each_slice() + start_width + cell_start_width;
     97         const std::uint8_t* src_data =
     98             this->complete_src_.data(cell_start_width, cell_start_depth);
     99 
    100         __m128i xmm1 =
    101             _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&src_data[0]));
    102         __m128i xmm2 = _mm_loadl_epi64(
    103             reinterpret_cast<const __m128i*>(&src_data[1 * width_stride]));
    104         __m128i xmm3 = _mm_loadl_epi64(
    105             reinterpret_cast<const __m128i*>(&src_data[2 * width_stride]));
    106         __m128i xmm4 = _mm_loadl_epi64(
    107             reinterpret_cast<const __m128i*>(&src_data[3 * width_stride]));
    108 
    109         __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2);
    110         __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31);
    111 
    112         __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4);
    113         __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80);
    114 
    115         __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc);
    116         SSERequantize<QuantizationParams>(&xmm9, rounding_offset_generator);
    117 
    118         __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc);
    119         SSERequantize<QuantizationParams>(&xmm10, rounding_offset_generator);
    120 
    121         _mm_storel_epi64(reinterpret_cast<__m128i*>(&dst_ptr[0]), xmm9);
    122         _mm_storel_epi64(
    123             reinterpret_cast<__m128i*>(&dst_ptr[kCellSize * kCells]), xmm10);
    124 
    125         __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee);
    126         SSERequantize<QuantizationParams>(&xmm11, rounding_offset_generator);
    127 
    128         __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee);
    129         SSERequantize<QuantizationParams>(&xmm12, rounding_offset_generator);
    130 
    131         _mm_storel_epi64(
    132             reinterpret_cast<__m128i*>(&dst_ptr[2 * kCellSize * kCells]),
    133             xmm11);
    134         _mm_storel_epi64(
    135             reinterpret_cast<__m128i*>(&dst_ptr[3 * kCellSize * kCells]),
    136             xmm12);
    137 
    138         xmm1 = _mm_cvtepu8_epi16(xmm9);
    139         xmm2 = _mm_madd_epi16(xmm1, one);
    140         __m128i sums_of_each_slice_xmm = _mm_loadu_si128(
    141             reinterpret_cast<const __m128i*>(&cell_sums_of_each_slice_ptr[0]));
    142         sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
    143 
    144         xmm1 = _mm_cvtepu8_epi16(xmm10);
    145         xmm2 = _mm_madd_epi16(xmm1, one);
    146         sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
    147 
    148         xmm1 = _mm_cvtepu8_epi16(xmm11);
    149         xmm2 = _mm_madd_epi16(xmm1, one);
    150         sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
    151 
    152         xmm1 = _mm_cvtepu8_epi16(xmm12);
    153         xmm2 = _mm_madd_epi16(xmm1, one);
    154         sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
    155 
    156         _mm_storeu_si128(
    157             reinterpret_cast<__m128i*>(&cell_sums_of_each_slice_ptr[0]),
    158             sums_of_each_slice_xmm);
    159         dst_ptr += kCellSize;
    160       }
    161       dst_ptr += 3 * kCellSize * kCells;
    162     }
    163     dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
    164   }
    165 };
    166 
    167 }  // namespace gemmlowp
    168 
    169 #endif  // GEMMLOWP_INTERNAL_PACK_SSE_H_
    170