// (code-browser navigation header) Home | History | Annotate | Download | only in internal
      1 // Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // output_msa.h: optimized MSA specializations of the templates in output.h.
     16 
     17 #ifndef GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
     18 #define GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
     19 
     20 #include "output.h"
     21 
     22 #include <msa.h>
     23 
     24 namespace gemmlowp {
     25 
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferUint8<4> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  // Saturating-casts 4 int32 lanes to uint8 (clamped to [0, 255]),
  // returning them packed into a single uint32.
  OutputType Eval(InputType input) const {
    OutputType output;
    // Signed saturate each 32-bit element to 9 bits
    // (this takes full care of non-negative elements).
    // 9 signed bits hold [-256, 255], so every value in [0, 255] is exact
    // and anything above is clamped to 255; negatives are zeroed below.
    v4i32 tmp = __builtin_msa_sat_s_w(input.reg[0], 8);
    // Pack every 32-bit element into 16 bits.
    tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
        reinterpret_cast<v8i16>(tmp), reinterpret_cast<v8i16>(tmp)));
    // Detect negative elements with arithmetic shift right (we
    // get a 16-bit mask of all zeroes or all ones for every element).
    v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp), 15);
    // Zero out negative elements.
    // bseli with immediate 0 replaces bytes selected by the sign mask with 0.
    signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
        reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp), 0));
    // Pack every element into 8 bits.
    tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
        reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
    // Return 4 uint8_t elements as uint32_t.
    output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
    return output;
  }
};
     58 
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferUint8<8> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  // Saturating-casts 8 int32 lanes to uint8 (clamped to [0, 255]),
  // returning them packed into two uint32 values.
  OutputType Eval(InputType input) const {
    OutputType output;
    // Signed saturate each 32-bit element to 9 bits
    // (this takes full care of non-negative elements).
    // 9 signed bits hold [-256, 255]: values in [0, 255] are exact,
    // larger values clamp to 255, negatives are zeroed below.
    v4i32 tmp_lo = __builtin_msa_sat_s_w(input.reg[0], 8);
    v4i32 tmp_hi = __builtin_msa_sat_s_w(input.reg[1], 8);
    // Pack every 32-bit element into 16 bits,
    // combining all 8 elements into one vector.
    tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
        reinterpret_cast<v8i16>(tmp_hi), reinterpret_cast<v8i16>(tmp_lo)));
    // Detect negative elements with arithmetic shift right (we
    // get a 16-bit mask of all zeroes or all ones for every element).
    v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp_lo), 15);
    // Zero out negative elements.
    signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
        reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp_lo), 0));
    // Pack every element into 8 bits.
    tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
        reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
    // Return 8 uint8_t elements as 2 uint32_t's.
    output.reg[0] = __builtin_msa_copy_s_w(tmp_lo, 0);
    output.reg[1] = __builtin_msa_copy_s_w(tmp_lo, 1);
    return output;
  }
};
     94 
// Saturating-casts 16 int32 lanes (in0..in3) to uint8 and packs them into
// the single 16-byte vector `out`. Same scheme as the scalar-count variants
// above: signed-saturate each lane to 9 bits (exact for [0, 255], clamps
// above), pack pairs of vectors to 16-bit lanes, zero negative lanes using
// an arithmetic-shift sign mask, then pack down to 8-bit lanes.
#define GEMMLOWP_MIPS_SAT_U8_16(out, in0, in1, in2, in3)                     \
  {                                                                          \
    v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 8);                              \
    v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 8);                              \
    v4i32 tmp2 = __builtin_msa_sat_s_w(in2, 8);                              \
    v4i32 tmp3 = __builtin_msa_sat_s_w(in3, 8);                              \
    tmp0 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(                    \
        reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0)));      \
    tmp2 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(                    \
        reinterpret_cast<v8i16>(tmp3), reinterpret_cast<v8i16>(tmp2)));      \
    v8i16 signs0 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp0), 15);  \
    v8i16 signs1 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp2), 15);  \
    signs0 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(                  \
        reinterpret_cast<v16u8>(signs0), reinterpret_cast<v16u8>(tmp0), 0)); \
    signs1 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(                  \
        reinterpret_cast<v16u8>(signs1), reinterpret_cast<v16u8>(tmp2), 0)); \
    signs0 = reinterpret_cast<v8i16>(__builtin_msa_pckev_b(                  \
        reinterpret_cast<v16i8>(signs1), reinterpret_cast<v16i8>(signs0)));  \
    out = reinterpret_cast<v16i8>(signs0);                                   \
  }
    115 
    116 template <>
    117 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
    118                                  RegBufferInt32<16>> {
    119   typedef RegBufferInt32<16> InputType;
    120   typedef RegBufferUint8<16> OutputType;
    121 
    122   typedef OutputStageSaturatingCastToUint8 OutputStage;
    123 
    124   OutputStageEvalBufferImpl(const OutputStage&) {}
    125 
    126   OutputType Eval(InputType input) const {
    127     OutputType output;
    128     GEMMLOWP_MIPS_SAT_U8_16(output.reg[0], input.reg[0], input.reg[1],
    129                             input.reg[2], input.reg[3]);
    130     return output;
    131   }
    132 };
    133 
    134 template <>
    135 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
    136                                  RegBufferInt32<32>> {
    137   typedef RegBufferInt32<32> InputType;
    138   typedef RegBufferUint8<32> OutputType;
    139 
    140   typedef OutputStageSaturatingCastToUint8 OutputStage;
    141 
    142   OutputStageEvalBufferImpl(const OutputStage&) {}
    143 
    144   OutputType Eval(InputType input) const {
    145     OutputType output;
    146     GEMMLOWP_MIPS_SAT_U8_16(output.reg[0], input.reg[0], input.reg[1],
    147                             input.reg[2], input.reg[3]);
    148     GEMMLOWP_MIPS_SAT_U8_16(output.reg[1], input.reg[4], input.reg[5],
    149                             input.reg[6], input.reg[7]);
    150     return output;
    151   }
    152 };
    153 
    154 #undef GEMMLOWP_MIPS_SAT_U8_16
    155 
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferInt16<4> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  // Saturating-casts 4 int32 lanes to int16, returned as 4 scalars.
  OutputType Eval(InputType input) const {
    OutputType output;
    // Signed saturate each 32-bit element to 16 bits.
    v8i16 tmp = reinterpret_cast<v8i16>(__builtin_msa_sat_s_w(
        input.reg[0], 15));
    // After saturation each 32-bit lane fits in its low 16-bit half, which
    // is 16-bit element 2*k of lane k — hence the even extraction indices.
    output.reg[0] = __builtin_msa_copy_s_h(tmp, 0);
    output.reg[1] = __builtin_msa_copy_s_h(tmp, 2);
    output.reg[2] = __builtin_msa_copy_s_h(tmp, 4);
    output.reg[3] = __builtin_msa_copy_s_h(tmp, 6);
    return output;
  }
};
    178 
// Saturating-casts 8 int32 lanes (in0, in1) to int16 and packs them into
// the single 8-lane vector `out` (pckev_h keeps the low, saturated halves).
#define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1)                         \
  {                                                                    \
    v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15);                       \
    v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15);                       \
    out = __builtin_msa_pckev_h(                                       \
        reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0)); \
  }
    186 
    187 template <>
    188 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
    189                                  RegBufferInt32<8>> {
    190   typedef RegBufferInt32<8> InputType;
    191   typedef RegBufferInt16<8> OutputType;
    192 
    193   typedef OutputStageSaturatingCastToInt16 OutputStage;
    194 
    195   OutputStageEvalBufferImpl(const OutputStage&) {}
    196 
    197   OutputType Eval(InputType input) const {
    198     OutputType output;
    199     GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
    200     return output;
    201   }
    202 };
    203 
    204 template <>
    205 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
    206                                  RegBufferInt32<16>> {
    207   typedef RegBufferInt32<16> InputType;
    208   typedef RegBufferInt16<16> OutputType;
    209 
    210   typedef OutputStageSaturatingCastToInt16 OutputStage;
    211 
    212   OutputStageEvalBufferImpl(const OutputStage&) {}
    213 
    214   OutputType Eval(InputType input) const {
    215     OutputType output;
    216     GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
    217     GEMMLOWP_MIPS_SAT_I16_8(output.reg[1], input.reg[2], input.reg[3]);
    218     return output;
    219   }
    220 };
    221 
    222 template <>
    223 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
    224                                  RegBufferInt32<32>> {
    225   typedef RegBufferInt32<32> InputType;
    226   typedef RegBufferInt16<32> OutputType;
    227 
    228   typedef OutputStageSaturatingCastToInt16 OutputStage;
    229 
    230   OutputStageEvalBufferImpl(const OutputStage&) {}
    231 
    232   OutputType Eval(InputType input) const {
    233     OutputType output;
    234     GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
    235     GEMMLOWP_MIPS_SAT_I16_8(output.reg[1], input.reg[2], input.reg[3]);
    236     GEMMLOWP_MIPS_SAT_I16_8(output.reg[2], input.reg[4], input.reg[5]);
    237     GEMMLOWP_MIPS_SAT_I16_8(output.reg[3], input.reg[6], input.reg[7]);
    238     return output;
    239   }
    240 };
    241 
    242 #undef GEMMLOWP_MIPS_SAT_I16_8
    243 
    244 template <typename DstType>
    245 struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
    246   static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
    247                   int col) {
    248     if (DstType::kOrder == MapOrder::ColMajor) {
    249       StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
    250     } else {
    251       *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
    252       *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
    253       *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
    254       *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
    255     }
    256   }
    257 };
    258 
    259 template <typename DstType>
    260 struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
    261   static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
    262                   int col) {
    263     if (DstType::kOrder == MapOrder::ColMajor) {
    264       StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
    265       StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
    266     } else {
    267       *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
    268       *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
    269       *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
    270       *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
    271       *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
    272       *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
    273       *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
    274       *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
    275     }
    276   }
    277 };
    278 
    279 template <typename DstType>
    280 struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
    281   static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
    282                   int col) {
    283     *dst->data(row + 0, col) = src.buf.reg[0];
    284     *dst->data(row + 1, col) = src.buf.reg[1];
    285     *dst->data(row + 2, col) = src.buf.reg[2];
    286     *dst->data(row + 3, col) = src.buf.reg[3];
    287   }
    288 };
    289 
    290 template <typename DstType>
    291 struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
    292   static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
    293                   int col) {
    294     if (DstType::kOrder == MapOrder::ColMajor) {
    295       StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
    296     } else {
    297       *dst->data(row + 0, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 0);
    298       *dst->data(row + 1, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 1);
    299       *dst->data(row + 2, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 2);
    300       *dst->data(row + 3, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 3);
    301       *dst->data(row + 4, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 4);
    302       *dst->data(row + 5, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 5);
    303       *dst->data(row + 6, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 6);
    304       *dst->data(row + 7, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 7);
    305     }
    306   }
    307 };
    308 
// Transposes a 4x4 int32 block held in four vector registers (one column
// per register) using a two-stage interleave network: 32-bit interleaves
// mix pairs of columns, then 64-bit interleaves combine the halves into
// full rows of the transposed block.
inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
  RegBlockInt32<4, 4> result;
  v4i32 tmp0, tmp1;
  // Interleave the low halves of columns 0/1 and 2/3.
  tmp0 = __builtin_msa_ilvr_w(src.buf.reg[1], src.buf.reg[0]);
  tmp1 = __builtin_msa_ilvr_w(src.buf.reg[3], src.buf.reg[2]);
  // Combine 64-bit halves to form transposed rows 0 and 1.
  result.buf.reg[0] = reinterpret_cast<v4i32>(__builtin_msa_ilvr_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  result.buf.reg[1] = reinterpret_cast<v4i32>(__builtin_msa_ilvl_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  // Same for the high halves, producing transposed rows 2 and 3.
  tmp0 = __builtin_msa_ilvl_w(src.buf.reg[1], src.buf.reg[0]);
  tmp1 = __builtin_msa_ilvl_w(src.buf.reg[3], src.buf.reg[2]);
  result.buf.reg[2] = reinterpret_cast<v4i32>(__builtin_msa_ilvr_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  result.buf.reg[3] = reinterpret_cast<v4i32>(__builtin_msa_ilvl_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  return result;
}
    326 
    327 template <typename DstType>
    328 struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
    329   static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
    330                   int col) {
    331     if (DstType::kOrder == MapOrder::ColMajor) {
    332       for (int i = 0; i < 4; i++) {
    333         StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]);
    334       }
    335     } else {
    336       const auto transpose = Transpose(src);
    337       for (int i = 0; i < 4; i++) {
    338         StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]);
    339       }
    340     }
    341   }
    342 };
    343 
    344 template <typename DstType>
    345 struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
    346   static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
    347                   int col) {
    348     std::int16_t buf[16];
    349     StoreInt16x8(buf + 0, src.buf.reg[0]);
    350     StoreInt16x8(buf + 8, src.buf.reg[1]);
    351     for (int i = 0; i < 4; i++) {
    352       for (int j = 0; j < 4; j++) {
    353         *dst->data(row + i, col + j) = buf[i + 4 * j];
    354       }
    355     }
    356   }
    357 };
    358 
    359 template <typename DstType>
    360 struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
    361   static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
    362                   int col) {
    363     if (DstType::kOrder == MapOrder::ColMajor) {
    364       for (int i = 0; i < 4; i++) {
    365         StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
    366         StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
    367       }
    368     } else {
    369       RegBlockInt32<4, 4> top;
    370       top.buf.reg[0] = src.buf.reg[0];
    371       top.buf.reg[1] = src.buf.reg[2];
    372       top.buf.reg[2] = src.buf.reg[4];
    373       top.buf.reg[3] = src.buf.reg[6];
    374       const auto transpose_top = Transpose(top);
    375       for (int i = 0; i < 4; i++) {
    376         StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]);
    377       }
    378       RegBlockInt32<4, 4> bottom;
    379       bottom.buf.reg[0] = src.buf.reg[1];
    380       bottom.buf.reg[1] = src.buf.reg[3];
    381       bottom.buf.reg[2] = src.buf.reg[5];
    382       bottom.buf.reg[3] = src.buf.reg[7];
    383       const auto transpose_bottom = Transpose(bottom);
    384       for (int i = 0; i < 4; i++) {
    385         StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]);
    386       }
    387     }
    388   }
    389 };
    390 
    391 template <typename DstType>
    392 struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
    393   static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
    394                   int col) {
    395     if (DstType::kOrder == MapOrder::ColMajor) {
    396       for (int i = 0; i < 4; i++) {
    397         StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
    398       }
    399     } else {
    400       std::int16_t buf[32];
    401       StoreInt16x8(buf + 0, src.buf.reg[0]);
    402       StoreInt16x8(buf + 8, src.buf.reg[1]);
    403       StoreInt16x8(buf + 16, src.buf.reg[2]);
    404       StoreInt16x8(buf + 24, src.buf.reg[3]);
    405       for (int i = 0; i < 8; i++) {
    406         for (int j = 0; j < 4; j++) {
    407           *dst->data(row + i, col + j) = buf[i + 8 * j];
    408         }
    409       }
    410     }
    411   }
    412 };
    413 
    414 template <typename DstType>
    415 struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
    416   static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
    417                   int col) {
    418     if (DstType::kOrder == MapOrder::ColMajor) {
    419       for (int i = 0; i < 8; i++) {
    420         StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
    421         StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
    422       }
    423     } else {
    424       RegBlockInt32<4, 4> top_left;
    425       top_left.buf.reg[0] = src.buf.reg[0];
    426       top_left.buf.reg[1] = src.buf.reg[2];
    427       top_left.buf.reg[2] = src.buf.reg[4];
    428       top_left.buf.reg[3] = src.buf.reg[6];
    429       const auto transpose_top_left = Transpose(top_left);
    430       for (int i = 0; i < 4; i++) {
    431         StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]);
    432       }
    433       RegBlockInt32<4, 4> bottom_left;
    434       bottom_left.buf.reg[0] = src.buf.reg[1];
    435       bottom_left.buf.reg[1] = src.buf.reg[3];
    436       bottom_left.buf.reg[2] = src.buf.reg[5];
    437       bottom_left.buf.reg[3] = src.buf.reg[7];
    438       const auto transpose_bottom_left = Transpose(bottom_left);
    439       for (int i = 0; i < 4; i++) {
    440         StoreInt32x4(dst->data(row + 4 + i, col),
    441                      transpose_bottom_left.buf.reg[i]);
    442       }
    443       RegBlockInt32<4, 4> top_right;
    444       top_right.buf.reg[0] = src.buf.reg[8];
    445       top_right.buf.reg[1] = src.buf.reg[10];
    446       top_right.buf.reg[2] = src.buf.reg[12];
    447       top_right.buf.reg[3] = src.buf.reg[14];
    448       const auto transpose_top_right = Transpose(top_right);
    449       for (int i = 0; i < 4; i++) {
    450         StoreInt32x4(dst->data(row + i, col + 4),
    451                      transpose_top_right.buf.reg[i]);
    452       }
    453       RegBlockInt32<4, 4> bottom_right;
    454       bottom_right.buf.reg[0] = src.buf.reg[9];
    455       bottom_right.buf.reg[1] = src.buf.reg[11];
    456       bottom_right.buf.reg[2] = src.buf.reg[13];
    457       bottom_right.buf.reg[3] = src.buf.reg[15];
    458       const auto transpose_bottom_right = Transpose(bottom_right);
    459       for (int i = 0; i < 4; i++) {
    460         StoreInt32x4(dst->data(row + 4 + i, col + 4),
    461                      transpose_bottom_right.buf.reg[i]);
    462       }
    463     }
    464   }
    465 };
    466 
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
  // Stores an 8x8 int16 block (one source register per column). The
  // column-major path stores each column directly; the row-major path
  // transposes the whole block in registers with a three-stage interleave
  // network (16-bit, then 32-bit, then 64-bit interleaves) and stores one
  // full destination row per vector store.
  static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      for (int i = 0; i < 8; i++) {
        StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
      }
    } else {
      // top-left 4x4
      v4i32 t0 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[1],
          src.buf.reg[0]));
      v4i32 t1 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[3],
          src.buf.reg[2]));
      v2i64 u0 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t1, t0));
      v2i64 u1 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t1, t0));
      // top-right 4x4
      v4i32 t2 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[5],
          src.buf.reg[4]));
      v4i32 t3 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[7],
          src.buf.reg[6]));
      v2i64 u2 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t3, t2));
      v2i64 u3 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t3, t2));
      // bottom-left 4x4
      v4i32 t4 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[1],
          src.buf.reg[0]));
      v4i32 t5 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[3],
          src.buf.reg[2]));
      v2i64 u4 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t5, t4));
      v2i64 u5 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t5, t4));
      // bottom-right 4x4
      v4i32 t6 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[5],
          src.buf.reg[4]));
      v4i32 t7 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[7],
          src.buf.reg[6]));
      v2i64 u6 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t7, t6));
      v2i64 u7 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t7, t6));

      // Each ilvr_d/ilvl_d below pairs a left-half u with a right-half u,
      // yielding one complete 8-element destination row.
      StoreInt16x8(dst->data(row + 0, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvr_d(u2, u0)));
      StoreInt16x8(dst->data(row + 1, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvl_d(u2, u0)));
      StoreInt16x8(dst->data(row + 2, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvr_d(u3, u1)));
      StoreInt16x8(dst->data(row + 3, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvl_d(u3, u1)));
      StoreInt16x8(dst->data(row + 4, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvr_d(u6, u4)));
      StoreInt16x8(dst->data(row + 5, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvl_d(u6, u4)));
      StoreInt16x8(dst->data(row + 6, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvr_d(u7, u5)));
      StoreInt16x8(dst->data(row + 7, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvl_d(u7, u5)));
    }
  }
};
    524 
    525 template <typename DstType>
    526 struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
    527   static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
    528                   int col) {
    529     if (DstType::kOrder == MapOrder::ColMajor) {
    530       *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]);
    531       *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]);
    532       *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]);
    533       *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]);
    534     } else {
    535       StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
    536     }
    537   }
    538 };
    539 
    540 template <typename DstType>
    541 struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
    542   static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
    543                   int col) {
    544     const std::uint32_t src_reg = src.buf.reg[0];
    545     for (int i = 0; i < 4; i++) {
    546       *dst->data(row + i, col) = (src_reg >> (8 * i));
    547     }
    548   }
    549 };
    550 
    551 template <typename DstType>
    552 struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
    553   static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
    554                   int col) {
    555     for (int i = 0; i < 4; i++) {
    556       *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i));
    557     }
    558     for (int i = 0; i < 4; i++) {
    559       *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i));
    560     }
    561   }
    562 };
    563 
    564 template <typename DstType>
    565 struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
    566   static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
    567                   int col) {
    568     for (int i = 0; i < 4; i++) {
    569       *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
    570     }
    571   }
    572 };
    573 
    574 template <typename DstType>
    575 struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
    576   static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
    577                   int col) {
    578     std::uint8_t buf[16];
    579     StoreUint8x16(buf, src.buf.reg[0]);
    580     for (int c = 0; c < 4; c++) {
    581       for (int r = 0; r < 4; r++) {
    582         *dst->data(row + r, col + c) = buf[r + 4 * c];
    583       }
    584     }
    585   }
    586 };
    587 
    588 template <typename DstType>
    589 struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
    590   static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
    591                   int col) {
    592     std::uint8_t buf[32];
    593     StoreUint8x16(buf, src.buf.reg[0]);
    594     StoreUint8x16(buf + 16, src.buf.reg[1]);
    595     for (int c = 0; c < 4; c++) {
    596       for (int r = 0; r < 8; r++) {
    597         *dst->data(row + r, col + c) = buf[r + 8 * c];
    598       }
    599     }
    600   }
    601 };
    602 
    603 template <typename DstType>
    604 struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
    605   static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
    606                   int col) {
    607     std::uint8_t buf[64];
    608     StoreUint8x16(buf, src.buf.reg[0]);
    609     StoreUint8x16(buf + 16, src.buf.reg[1]);
    610     StoreUint8x16(buf + 32, src.buf.reg[2]);
    611     StoreUint8x16(buf + 48, src.buf.reg[3]);
    612     for (int c = 0; c < 8; c++) {
    613       for (int r = 0; r < 8; r++) {
    614         *dst->data(row + r, col + c) = buf[r + 8 * c];
    615       }
    616     }
    617   }
    618 };
    619 
    620 }  // namespace gemmlowp
    621 
    622 #endif  // GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
    623