Home | History | Annotate | Download | only in internal
      1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // pack_SSE.h: optimized SSE specializations of the templates in pack.h.
     16 
     17 #ifndef GEMMLOWP_INTERNAL_PACK_SSE_H_
     18 #define GEMMLOWP_INTERNAL_PACK_SSE_H_
     19 
     20 #include <smmintrin.h>
     21 #include "pack.h"
     22 
     23 namespace gemmlowp {
     24 
     25 // TODO: Add DepthMajorUint8SideMap
     26 
     27 typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
     28     WidthMajorUint8SideMap;
     29 
     30 template <int Cells>
     31 using WidthMajorSideFormatNCells4x2 =
     32     KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;
     33 
     34 template <int Cells>
     35 class PackingRegisterBlock<
     36     WidthMajorUint8SideMap,
     37     PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > >
     38     : public PackingRegisterBlockBase<
     39           WidthMajorUint8SideMap,
     40           PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > > {
     41  public:
     42   typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
     43   typedef typename KernelSideFormat::Cell CellFormat;
     44   static const int kCells = KernelSideFormat::kCells;
     45   static const int kCellWidth = CellFormat::kWidth;
     46   static const int kKernelWidth = CellFormat::kWidth * kCells;
     47   static const int kCellDepth = CellFormat::kDepth;
     48   static const int kCellSize = CellFormat::kSize;
     49 
     50   void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
     51     std::uint8_t* dst_ptr = dst->current_data();
     52     const int width_stride = this->complete_src_.width_stride();
     53     int depth_step = 8;
     54 
     55     __m128i one = _mm_set1_epi16(1);
     56     for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
     57          cell_start_depth += depth_step) {
     58       for (int cell_start_width = 0; cell_start_width < kKernelWidth;
     59            cell_start_width += kCellWidth) {
     60         std::int32_t* cell_sums_of_each_slice_ptr =
     61             dst->sums_of_each_slice() + start_width + cell_start_width;
     62         const std::uint8_t* src_data =
     63             this->complete_src_.data(cell_start_width, cell_start_depth);
     64 
     65         __m128i xmm1 =
     66             _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&src_data[0]));
     67         __m128i xmm2 = _mm_loadl_epi64(
     68             reinterpret_cast<const __m128i*>(&src_data[1 * width_stride]));
     69         __m128i xmm3 = _mm_loadl_epi64(
     70             reinterpret_cast<const __m128i*>(&src_data[2 * width_stride]));
     71         __m128i xmm4 = _mm_loadl_epi64(
     72             reinterpret_cast<const __m128i*>(&src_data[3 * width_stride]));
     73 
     74         __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2);
     75         __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31);
     76 
     77         __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4);
     78         __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80);
     79 
     80         __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc);
     81         __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc);
     82 
     83         _mm_storel_epi64(reinterpret_cast<__m128i*>(&dst_ptr[0]), xmm9);
     84         _mm_storel_epi64(
     85             reinterpret_cast<__m128i*>(&dst_ptr[kCellSize * kCells]), xmm10);
     86 
     87         __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee);
     88         __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee);
     89 
     90         _mm_storel_epi64(
     91             reinterpret_cast<__m128i*>(&dst_ptr[2 * kCellSize * kCells]),
     92             xmm11);
     93         _mm_storel_epi64(
     94             reinterpret_cast<__m128i*>(&dst_ptr[3 * kCellSize * kCells]),
     95             xmm12);
     96 
     97         xmm1 = _mm_cvtepu8_epi16(xmm9);
     98         xmm2 = _mm_madd_epi16(xmm1, one);
     99         __m128i sums_of_each_slice_xmm = _mm_loadu_si128(
    100             reinterpret_cast<const __m128i*>(&cell_sums_of_each_slice_ptr[0]));
    101         sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
    102 
    103         xmm1 = _mm_cvtepu8_epi16(xmm10);
    104         xmm2 = _mm_madd_epi16(xmm1, one);
    105         sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
    106 
    107         xmm1 = _mm_cvtepu8_epi16(xmm11);
    108         xmm2 = _mm_madd_epi16(xmm1, one);
    109         sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
    110 
    111         xmm1 = _mm_cvtepu8_epi16(xmm12);
    112         xmm2 = _mm_madd_epi16(xmm1, one);
    113         sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);
    114 
    115         _mm_storeu_si128(
    116             reinterpret_cast<__m128i*>(&cell_sums_of_each_slice_ptr[0]),
    117             sums_of_each_slice_xmm);
    118         dst_ptr += kCellSize;
    119       }
    120       dst_ptr += 3 * kCellSize * kCells;
    121     }
    122     dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
    123   }
    124 };
    125 
    126 }  // namespace gemmlowp
    127 
    128 #endif  // GEMMLOWP_INTERNAL_PACK_SSE_H_
    129