Home | History | Annotate | Download | only in internal
      1 // Copyright 2015 Google Inc. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // output_sse.h: optimized SSE4.2 specializations of the templates in output.h.
     16 
     17 #ifndef GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
     18 #define GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
     19 
     20 #include "output.h"
     21 
     22 #include <smmintrin.h>
     23 
     24 namespace gemmlowp {
     25 
     26 template <>
     27 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
     28                                  RegBufferInt32<4>> {
     29   typedef RegBufferInt32<4> InputType;
     30   typedef RegBufferUint8<4> OutputType;
     31 
     32   typedef OutputStageSaturatingCastToUint8 OutputStage;
     33 
     34   OutputStageEvalBufferImpl(const OutputStage&) {}
     35 
     36   OutputType Eval(InputType input) const {
     37     OutputType output;
     38     __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]);
     39     __m128i res_8 = _mm_packus_epi16(res_16, res_16);
     40     output.reg[0] = _mm_cvtsi128_si32(res_8);
     41     return output;
     42   }
     43 };
     44 
     45 template <>
     46 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
     47                                  RegBufferInt32<8>> {
     48   typedef RegBufferInt32<8> InputType;
     49   typedef RegBufferUint8<8> OutputType;
     50 
     51   typedef OutputStageSaturatingCastToUint8 OutputStage;
     52 
     53   OutputStageEvalBufferImpl(const OutputStage&) {}
     54 
     55   OutputType Eval(InputType input) const {
     56     OutputType output;
     57     __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[1]);
     58     __m128i res_8 = _mm_packus_epi16(res_16, res_16);
     59     output.reg[0] = _mm_extract_epi32(res_8, 0);
     60     output.reg[1] = _mm_extract_epi32(res_8, 1);
     61     return output;
     62   }
     63 };
     64 
     65 template <>
     66 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
     67                                  RegBufferInt32<16>> {
     68   typedef RegBufferInt32<16> InputType;
     69   typedef RegBufferUint8<16> OutputType;
     70 
     71   typedef OutputStageSaturatingCastToUint8 OutputStage;
     72 
     73   OutputStageEvalBufferImpl(const OutputStage&) {}
     74 
     75   OutputType Eval(InputType input) const {
     76     OutputType output;
     77     __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]);
     78     __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]);
     79     output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1);
     80     return output;
     81   }
     82 };
     83 
     84 template <>
     85 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
     86                                  RegBufferInt32<32>> {
     87   typedef RegBufferInt32<32> InputType;
     88   typedef RegBufferUint8<32> OutputType;
     89 
     90   typedef OutputStageSaturatingCastToUint8 OutputStage;
     91 
     92   OutputStageEvalBufferImpl(const OutputStage&) {}
     93 
     94   OutputType Eval(InputType input) const {
     95     OutputType output;
     96     __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]);
     97     __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]);
     98     output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1);
     99     __m128i res_16_2 = _mm_packs_epi32(input.reg[4], input.reg[5]);
    100     __m128i res_16_3 = _mm_packs_epi32(input.reg[6], input.reg[7]);
    101     output.reg[1] = _mm_packus_epi16(res_16_2, res_16_3);
    102     return output;
    103   }
    104 };
    105 
    106 template <typename DstType>
    107 struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
    108   static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
    109                   int col) {
    110     if (DstType::kOrder == MapOrder::ColMajor) {
    111       StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
    112     } else {
    113       *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
    114       *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
    115       *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
    116       *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
    117     }
    118   }
    119 };
    120 
    121 template <typename DstType>
    122 struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
    123   static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
    124                   int col) {
    125     if (DstType::kOrder == MapOrder::ColMajor) {
    126       StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
    127       StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
    128     } else {
    129       *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
    130       *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
    131       *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
    132       *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
    133       *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
    134       *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
    135       *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
    136       *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
    137     }
    138   }
    139 };
    140 
    141 inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
    142   __m128i t0 = _mm_unpacklo_epi32(src.buf.reg[0], src.buf.reg[1]);
    143   __m128i t1 = _mm_unpacklo_epi32(src.buf.reg[2], src.buf.reg[3]);
    144   __m128i t2 = _mm_unpackhi_epi32(src.buf.reg[0], src.buf.reg[1]);
    145   __m128i t3 = _mm_unpackhi_epi32(src.buf.reg[2], src.buf.reg[3]);
    146 
    147   RegBlockInt32<4, 4> result;
    148   result.buf.reg[0] = _mm_unpacklo_epi64(t0, t1);
    149   result.buf.reg[1] = _mm_unpackhi_epi64(t0, t1);
    150   result.buf.reg[2] = _mm_unpacklo_epi64(t2, t3);
    151   result.buf.reg[3] = _mm_unpackhi_epi64(t2, t3);
    152   return result;
    153 }
    154 
    155 template <typename DstType>
    156 struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
    157   static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
    158                   int col) {
    159     if (DstType::kOrder == MapOrder::ColMajor) {
    160       for (int i = 0; i < 4; i++) {
    161         StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]);
    162       }
    163     } else {
    164       const auto transpose = Transpose(src);
    165       for (int i = 0; i < 4; i++) {
    166         StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]);
    167       }
    168     }
    169   }
    170 };
    171 
    172 template <typename DstType>
    173 struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
    174   static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
    175                   int col) {
    176     if (DstType::kOrder == MapOrder::ColMajor) {
    177       for (int i = 0; i < 4; i++) {
    178         StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
    179         StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
    180       }
    181     } else {
    182       RegBlockInt32<4, 4> top;
    183       top.buf.reg[0] = src.buf.reg[0];
    184       top.buf.reg[1] = src.buf.reg[2];
    185       top.buf.reg[2] = src.buf.reg[4];
    186       top.buf.reg[3] = src.buf.reg[6];
    187       const auto transpose_top = Transpose(top);
    188       for (int i = 0; i < 4; i++) {
    189         StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]);
    190       }
    191       RegBlockInt32<4, 4> bottom;
    192       bottom.buf.reg[0] = src.buf.reg[1];
    193       bottom.buf.reg[1] = src.buf.reg[3];
    194       bottom.buf.reg[2] = src.buf.reg[5];
    195       bottom.buf.reg[3] = src.buf.reg[7];
    196       const auto transpose_bottom = Transpose(bottom);
    197       for (int i = 0; i < 4; i++) {
    198         StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]);
    199       }
    200     }
    201   }
    202 };
    203 
    204 template <typename DstType>
    205 struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
    206   static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
    207                   int col) {
    208     if (DstType::kOrder == MapOrder::ColMajor) {
    209       for (int i = 0; i < 8; i++) {
    210         StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
    211         StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
    212       }
    213     } else {
    214       RegBlockInt32<4, 4> top_left;
    215       top_left.buf.reg[0] = src.buf.reg[0];
    216       top_left.buf.reg[1] = src.buf.reg[2];
    217       top_left.buf.reg[2] = src.buf.reg[4];
    218       top_left.buf.reg[3] = src.buf.reg[6];
    219       const auto transpose_top_left = Transpose(top_left);
    220       for (int i = 0; i < 4; i++) {
    221         StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]);
    222       }
    223       RegBlockInt32<4, 4> bottom_left;
    224       bottom_left.buf.reg[0] = src.buf.reg[1];
    225       bottom_left.buf.reg[1] = src.buf.reg[3];
    226       bottom_left.buf.reg[2] = src.buf.reg[5];
    227       bottom_left.buf.reg[3] = src.buf.reg[7];
    228       const auto transpose_bottom_left = Transpose(bottom_left);
    229       for (int i = 0; i < 4; i++) {
    230         StoreInt32x4(dst->data(row + 4 + i, col),
    231                      transpose_bottom_left.buf.reg[i]);
    232       }
    233       RegBlockInt32<4, 4> top_right;
    234       top_right.buf.reg[0] = src.buf.reg[8];
    235       top_right.buf.reg[1] = src.buf.reg[10];
    236       top_right.buf.reg[2] = src.buf.reg[12];
    237       top_right.buf.reg[3] = src.buf.reg[14];
    238       const auto transpose_top_right = Transpose(top_right);
    239       for (int i = 0; i < 4; i++) {
    240         StoreInt32x4(dst->data(row + i, col + 4),
    241                      transpose_top_right.buf.reg[i]);
    242       }
    243       RegBlockInt32<4, 4> bottom_right;
    244       bottom_right.buf.reg[0] = src.buf.reg[9];
    245       bottom_right.buf.reg[1] = src.buf.reg[11];
    246       bottom_right.buf.reg[2] = src.buf.reg[13];
    247       bottom_right.buf.reg[3] = src.buf.reg[15];
    248       const auto transpose_bottom_right = Transpose(bottom_right);
    249       for (int i = 0; i < 4; i++) {
    250         StoreInt32x4(dst->data(row + 4 + i, col + 4),
    251                      transpose_bottom_right.buf.reg[i]);
    252       }
    253     }
    254   }
    255 };
    256 
    257 template <typename DstType>
    258 struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
    259   static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
    260                   int col) {
    261     if (DstType::kOrder == MapOrder::ColMajor) {
    262       *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]);
    263       *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]);
    264       *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]);
    265       *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]);
    266     } else {
    267       StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
    268     }
    269   }
    270 };
    271 
    272 template <typename DstType>
    273 struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
    274   static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
    275                   int col) {
    276     const std::uint32_t src_reg = src.buf.reg[0];
    277     for (int i = 0; i < 4; i++) {
    278       *dst->data(row + i, col) = (src_reg >> (8 * i));
    279     }
    280   }
    281 };
    282 
    283 template <typename DstType>
    284 struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
    285   static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
    286                   int col) {
    287     for (int i = 0; i < 4; i++) {
    288       *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i));
    289     }
    290     for (int i = 0; i < 4; i++) {
    291       *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i));
    292     }
    293   }
    294 };
    295 
    296 template <typename DstType>
    297 struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
    298   static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
    299                   int col) {
    300     for (int i = 0; i < 4; i++) {
    301       *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
    302     }
    303   }
    304 };
    305 
    306 template <typename DstType>
    307 struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
    308   static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
    309                   int col) {
    310     std::uint8_t buf[16];
    311     StoreUint8x16(buf, src.buf.reg[0]);
    312     for (int c = 0; c < 4; c++) {
    313       for (int r = 0; r < 4; r++) {
    314         *dst->data(row + r, col + c) = buf[r + 4 * c];
    315       }
    316     }
    317   }
    318 };
    319 
    320 template <typename DstType>
    321 struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
    322   static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
    323                   int col) {
    324     std::uint8_t buf[32];
    325     StoreUint8x16(buf, src.buf.reg[0]);
    326     StoreUint8x16(buf + 16, src.buf.reg[1]);
    327     for (int c = 0; c < 4; c++) {
    328       for (int r = 0; r < 8; r++) {
    329         *dst->data(row + r, col + c) = buf[r + 8 * c];
    330       }
    331     }
    332   }
    333 };
    334 
    335 template <typename DstType>
    336 struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
    337   static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
    338                   int col) {
    339     std::uint8_t buf[64];
    340     StoreUint8x16(buf, src.buf.reg[0]);
    341     StoreUint8x16(buf + 16, src.buf.reg[1]);
    342     StoreUint8x16(buf + 32, src.buf.reg[2]);
    343     StoreUint8x16(buf + 48, src.buf.reg[3]);
    344     for (int c = 0; c < 8; c++) {
    345       for (int r = 0; r < 8; r++) {
    346         *dst->data(row + r, col + c) = buf[r + 8 * c];
    347       }
    348     }
    349   }
    350 };
    351 
    352 }  // namespace gemmlowp
    353 
    354 #endif  // GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
    355