// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// output_neon.h: optimized NEON specializations of the templates in output.h.
     17 #ifndef GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
     18 #define GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
     19 
     20 #include "output.h"
     21 
     22 #include <arm_neon.h>
     23 
     24 namespace gemmlowp {
     25 
     26 template <>
     27 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
     28                                  RegBufferInt32<4>> {
     29   typedef RegBufferInt32<4> InputType;
     30   typedef RegBufferUint8<4> OutputType;
     31 
     32   typedef OutputStageSaturatingCastToUint8 OutputStage;
     33 
     34   OutputStageEvalBufferImpl(const OutputStage&) {}
     35 
     36   OutputType Eval(InputType input) const {
     37     OutputType output;
     38     int16x4_t res_16 = vqmovn_s32(input.reg[0]);
     39     uint8x8_t res_8 = vqmovun_s16(vcombine_s16(res_16, res_16));
     40     output.reg[0] = vget_lane_u32(vreinterpret_u32_u8(res_8), 0);
     41     return output;
     42   }
     43 };
     44 
     45 template <>
     46 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
     47                                  RegBufferInt32<8>> {
     48   typedef RegBufferInt32<8> InputType;
     49   typedef RegBufferUint8<8> OutputType;
     50 
     51   typedef OutputStageSaturatingCastToUint8 OutputStage;
     52 
     53   OutputStageEvalBufferImpl(const OutputStage&) {}
     54 
     55   OutputType Eval(InputType input) const {
     56     OutputType output;
     57     int16x8_t res_16 =
     58         vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
     59     output.reg[0] = vqmovun_s16(res_16);
     60     return output;
     61   }
     62 };
     63 
     64 template <>
     65 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
     66                                  RegBufferInt32<16>> {
     67   typedef RegBufferInt32<16> InputType;
     68   typedef RegBufferUint8<16> OutputType;
     69 
     70   typedef OutputStageSaturatingCastToUint8 OutputStage;
     71 
     72   OutputStageEvalBufferImpl(const OutputStage&) {}
     73 
     74   OutputType Eval(InputType input) const {
     75     OutputType output;
     76     int16x8_t res_16_0 =
     77         vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
     78     int16x8_t res_16_1 =
     79         vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
     80     output.reg[0] = vqmovun_s16(res_16_0);
     81     output.reg[1] = vqmovun_s16(res_16_1);
     82     return output;
     83   }
     84 };
     85 
     86 template <>
     87 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
     88                                  RegBufferInt32<32>> {
     89   typedef RegBufferInt32<32> InputType;
     90   typedef RegBufferUint8<32> OutputType;
     91 
     92   typedef OutputStageSaturatingCastToUint8 OutputStage;
     93 
     94   OutputStageEvalBufferImpl(const OutputStage&) {}
     95 
     96   OutputType Eval(InputType input) const {
     97     OutputType output;
     98     int16x8_t res_16[4];
     99     for (int i = 0; i < 4; i++) {
    100       res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]),
    101                                vqmovn_s32(input.reg[2 * i + 1]));
    102     }
    103     for (int i = 0; i < 4; i++) {
    104       output.reg[i] = vqmovun_s16(res_16[i]);
    105     }
    106     return output;
    107   }
    108 };
    109 
// Stores an 8x1 int32 block. Column-major destinations get two contiguous
// 4-lane vector stores; row-major destinations are written lane by lane,
// since consecutive rows are strided there. GetLane takes a compile-time
// lane index, hence the unrolled form.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
  static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      // reg[0] holds rows row..row+3, reg[1] holds rows row+4..row+7.
      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
      StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
    } else {
      *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
      *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
      *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
      *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
      *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
      *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
      *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
      *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
    }
  }
};
    129 
    130 inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
    131   const int32x4x2_t t0 = vtrnq_s32(src.buf.reg[0], src.buf.reg[1]);
    132   const int32x4x2_t t1 = vtrnq_s32(src.buf.reg[2], src.buf.reg[3]);
    133   RegBlockInt32<4, 4> result;
    134   result.buf.reg[0] =
    135       vcombine_s32(vget_low_s32(t0.val[0]), vget_low_s32(t1.val[0]));
    136   result.buf.reg[1] =
    137       vcombine_s32(vget_low_s32(t0.val[1]), vget_low_s32(t1.val[1]));
    138   result.buf.reg[2] =
    139       vcombine_s32(vget_high_s32(t0.val[0]), vget_high_s32(t1.val[0]));
    140   result.buf.reg[3] =
    141       vcombine_s32(vget_high_s32(t0.val[1]), vget_high_s32(t1.val[1]));
    142   return result;
    143 }
    144 
    145 template <typename DstType>
    146 struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
    147   static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
    148                   int col) {
    149     const auto& block =
    150         DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
    151     std::int32_t* dst_ptr = dst->data(row, col);
    152     int stride = dst->stride();
    153     for (int i = 0; i < 4; i++) {
    154       vst1q_s32(dst_ptr + i * stride, block.buf.reg[i]);
    155     }
    156   }
    157 };
    158 
// Stores an 8x4 int32 block. Register layout (per the ColMajor stores):
// src.buf.reg[2*i] holds the top 4 rows of column i, src.buf.reg[2*i+1]
// the bottom 4 rows.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
  static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      // Columns are contiguous: two 4-lane stores per column.
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * col_stride + 0, src.buf.reg[2 * i + 0]);
        vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
      }
    } else {
      // Row-major: transpose the top and bottom 4x4 halves so that each
      // register becomes one contiguous 4-element destination row segment.
      int row_stride = dst->rows_stride();
      // Top half: registers 0,2,4,6 are the top 4 rows of columns 0..3.
      RegBlockInt32<4, 4> top;
      top.buf.reg[0] = src.buf.reg[0];
      top.buf.reg[1] = src.buf.reg[2];
      top.buf.reg[2] = src.buf.reg[4];
      top.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top = Transpose(top);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride, transpose_top.buf.reg[i]);
      }
      // Bottom half: registers 1,3,5,7 are the bottom 4 rows of columns 0..3.
      RegBlockInt32<4, 4> bottom;
      bottom.buf.reg[0] = src.buf.reg[1];
      bottom.buf.reg[1] = src.buf.reg[3];
      bottom.buf.reg[2] = src.buf.reg[5];
      bottom.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom = Transpose(bottom);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride, transpose_bottom.buf.reg[i]);
      }
    }
  }
};
    193 
// Stores an 8x8 int32 block. Register layout (per the ColMajor stores):
// src.buf.reg[2*i] / reg[2*i+1] hold the top / bottom 4 rows of column i,
// for i in 0..7. The row-major path processes the block as four 4x4
// quadrants, transposing each so registers become contiguous row segments.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
  static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      // Columns are contiguous: two 4-lane stores per column.
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 8; i++) {
        vst1q_s32(dst_ptr + i * col_stride, src.buf.reg[2 * i]);
        vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
      }
    } else {
      int row_stride = dst->rows_stride();
      // Top-left quadrant: top halves of columns 0..3 (regs 0,2,4,6).
      RegBlockInt32<4, 4> top_left;
      top_left.buf.reg[0] = src.buf.reg[0];
      top_left.buf.reg[1] = src.buf.reg[2];
      top_left.buf.reg[2] = src.buf.reg[4];
      top_left.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top_left = Transpose(top_left);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride, transpose_top_left.buf.reg[i]);
      }
      // Bottom-left quadrant: bottom halves of columns 0..3 (regs 1,3,5,7).
      RegBlockInt32<4, 4> bottom_left;
      bottom_left.buf.reg[0] = src.buf.reg[1];
      bottom_left.buf.reg[1] = src.buf.reg[3];
      bottom_left.buf.reg[2] = src.buf.reg[5];
      bottom_left.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom_left = Transpose(bottom_left);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride,
                  transpose_bottom_left.buf.reg[i]);
      }
      // Top-right quadrant: top halves of columns 4..7 (regs 8,10,12,14).
      RegBlockInt32<4, 4> top_right;
      top_right.buf.reg[0] = src.buf.reg[8];
      top_right.buf.reg[1] = src.buf.reg[10];
      top_right.buf.reg[2] = src.buf.reg[12];
      top_right.buf.reg[3] = src.buf.reg[14];
      const auto transpose_top_right = Transpose(top_right);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride + 4, transpose_top_right.buf.reg[i]);
      }
      // Bottom-right quadrant: bottom halves of columns 4..7 (regs 9,11,13,15).
      RegBlockInt32<4, 4> bottom_right;
      bottom_right.buf.reg[0] = src.buf.reg[9];
      bottom_right.buf.reg[1] = src.buf.reg[11];
      bottom_right.buf.reg[2] = src.buf.reg[13];
      bottom_right.buf.reg[3] = src.buf.reg[15];
      const auto transpose_bottom_right = Transpose(bottom_right);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride + 4,
                  transpose_bottom_right.buf.reg[i]);
      }
    }
  }
};
    248 
    249 template <typename DstType>
    250 struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
    251   static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
    252                   int col) {
    253     std::int32_t* dst_ptr = dst->data(row, col);
    254     if (DstType::kOrder == MapOrder::ColMajor) {
    255       vst1q_s32(dst_ptr, src.buf.reg[0]);
    256     } else {
    257       int row_stride = dst->rows_stride();
    258       vst1q_lane_s32(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
    259       vst1q_lane_s32(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
    260       vst1q_lane_s32(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
    261       vst1q_lane_s32(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
    262     }
    263   }
    264 };
    265 
    266 template <typename DstType>
    267 struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
    268   static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
    269                   int col) {
    270     std::int32_t* dst_ptr = dst->data(row, col);
    271     if (DstType::kOrder == MapOrder::RowMajor) {
    272       vst1q_s32(dst_ptr, src.buf.reg[0]);
    273     } else {
    274       int col_stride = dst->cols_stride();
    275       vst1q_lane_s32(dst_ptr + 0 * col_stride, src.buf.reg[0], 0);
    276       vst1q_lane_s32(dst_ptr + 1 * col_stride, src.buf.reg[0], 1);
    277       vst1q_lane_s32(dst_ptr + 2 * col_stride, src.buf.reg[0], 2);
    278       vst1q_lane_s32(dst_ptr + 3 * col_stride, src.buf.reg[0], 3);
    279     }
    280   }
    281 };
    282 
    283 template <typename DstType>
    284 struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
    285   static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
    286                   int col) {
    287     const std::uint32_t src_reg = src.buf.reg[0];
    288     for (int i = 0; i < 4; i++) {
    289       *dst->data(row + i, col) = (src_reg >> (8 * i));
    290     }
    291   }
    292 };
    293 
    294 template <typename DstType>
    295 struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
    296   static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
    297                   int col) {
    298     for (int i = 0; i < 4; i++) {
    299       *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
    300     }
    301   }
    302 };
    303 
// Stores an 8x1 uint8 block held in a single uint8x8 register.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
  static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      // The 8 rows are contiguous: a single 8-byte vector store.
      vst1_u8(dst_ptr, src.buf.reg[0]);
    } else {
      // Row-major: rows are strided; store each byte lane individually.
      // vst1_lane_u8 needs a compile-time lane index, hence the unrolling.
      const int row_stride = dst->rows_stride();
      vst1_lane_u8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
      vst1_lane_u8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
      vst1_lane_u8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
      vst1_lane_u8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
      vst1_lane_u8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4);
      vst1_lane_u8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5);
      vst1_lane_u8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6);
      vst1_lane_u8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7);
    }
  }
};
    324 
// Stores a 4x4 uint8 block. Each src.buf.reg[i] (a uint8x8) packs two
// columns: lanes 0..3 are rows 0..3 of column 2*i, lanes 4..7 are rows 0..3
// of column 2*i+1. Stores are always done lane by lane using both strides,
// so the same code handles either destination order.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
  static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t* dst_ptr = dst->data(row, col);
    const int row_stride = dst->rows_stride();
    const int col_stride = dst->cols_stride();
    for (int i = 0; i < 2; i++) {
      // Lanes 0..3: column 2*i.
      vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 0);
      vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 1);
      vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 2);
      vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 3);
      // Lanes 4..7: column 2*i+1.
      vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 4);
      vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 5);
      vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 6);
      vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 7);
    }
  }
};
    352 
    353 template <typename DstType>
    354 struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
    355   static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
    356                   int col) {
    357     std::uint8_t* dst_ptr = dst->data(row, col);
    358     if (DstType::kOrder == MapOrder::ColMajor) {
    359       int col_stride = dst->cols_stride();
    360       for (int i = 0; i < 4; i++) {
    361         vst1_u8(dst_ptr + i * col_stride, src.buf.reg[i]);
    362       }
    363     } else {
    364       for (int i = 0; i < 4; i++) {
    365         int row_stride = dst->rows_stride();
    366         std::uint8_t* col_ptr = dst_ptr + i;
    367         vst1_lane_u8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
    368         vst1_lane_u8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
    369         vst1_lane_u8(col_ptr + 2 * row_stride, src.buf.reg[i], 2);
    370         vst1_lane_u8(col_ptr + 3 * row_stride, src.buf.reg[i], 3);
    371         vst1_lane_u8(col_ptr + 4 * row_stride, src.buf.reg[i], 4);
    372         vst1_lane_u8(col_ptr + 5 * row_stride, src.buf.reg[i], 5);
    373         vst1_lane_u8(col_ptr + 6 * row_stride, src.buf.reg[i], 6);
    374         vst1_lane_u8(col_ptr + 7 * row_stride, src.buf.reg[i], 7);
    375       }
    376     }
    377   }
    378 };
    379 
// Transposes an 8x8 block of uint8 held in eight uint8x8 registers, using
// three rounds of pairwise 2x2 sub-block transposes at increasing element
// granularity: vtrn_u8 (1-byte), vtrn_u16 (2-byte), vtrn_u32 (4-byte).
inline RegBlockUint8<8, 8> Transpose(const RegBlockUint8<8, 8>& src) {
  // Round 1: transpose 2x2 blocks of bytes within adjacent register pairs.
  uint8x8x2_t a[4];
  a[0] = vtrn_u8(src.buf.reg[0], src.buf.reg[1]);
  a[1] = vtrn_u8(src.buf.reg[2], src.buf.reg[3]);
  a[2] = vtrn_u8(src.buf.reg[4], src.buf.reg[5]);
  a[3] = vtrn_u8(src.buf.reg[6], src.buf.reg[7]);
  // Round 2: transpose 2x2 blocks of 16-bit elements (i.e. byte pairs).
  uint16x4x2_t b[4];
  b[0] = vtrn_u16(vreinterpret_u16_u8(a[0].val[0]),
                  vreinterpret_u16_u8(a[1].val[0]));
  b[1] = vtrn_u16(vreinterpret_u16_u8(a[0].val[1]),
                  vreinterpret_u16_u8(a[1].val[1]));
  b[2] = vtrn_u16(vreinterpret_u16_u8(a[2].val[0]),
                  vreinterpret_u16_u8(a[3].val[0]));
  b[3] = vtrn_u16(vreinterpret_u16_u8(a[2].val[1]),
                  vreinterpret_u16_u8(a[3].val[1]));
  // Round 3: transpose 2x2 blocks of 32-bit elements (i.e. byte quads).
  uint32x2x2_t c[4];
  c[0] = vtrn_u32(vreinterpret_u32_u16(b[0].val[0]),
                  vreinterpret_u32_u16(b[2].val[0]));
  c[1] = vtrn_u32(vreinterpret_u32_u16(b[1].val[0]),
                  vreinterpret_u32_u16(b[3].val[0]));
  c[2] = vtrn_u32(vreinterpret_u32_u16(b[0].val[1]),
                  vreinterpret_u32_u16(b[2].val[1]));
  c[3] = vtrn_u32(vreinterpret_u32_u16(b[1].val[1]),
                  vreinterpret_u32_u16(b[3].val[1]));
  // Reinterpret back to bytes and collect the transposed registers.
  RegBlockUint8<8, 8> result;
  result.buf.reg[0] = vreinterpret_u8_u32(c[0].val[0]);
  result.buf.reg[1] = vreinterpret_u8_u32(c[1].val[0]);
  result.buf.reg[2] = vreinterpret_u8_u32(c[2].val[0]);
  result.buf.reg[3] = vreinterpret_u8_u32(c[3].val[0]);
  result.buf.reg[4] = vreinterpret_u8_u32(c[0].val[1]);
  result.buf.reg[5] = vreinterpret_u8_u32(c[1].val[1]);
  result.buf.reg[6] = vreinterpret_u8_u32(c[2].val[1]);
  result.buf.reg[7] = vreinterpret_u8_u32(c[3].val[1]);
  return result;
}
    415 
    416 template <typename DstType>
    417 struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
    418   static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
    419                   int col) {
    420     const auto& block =
    421         DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
    422     std::uint8_t* dst_ptr = dst->data(row, col);
    423     int stride = dst->stride();
    424     for (int i = 0; i < 8; i++) {
    425       vst1_u8(dst_ptr + i * stride, block.buf.reg[i]);
    426     }
    427   }
    428 };
    429 
    430 }  // namespace gemmlowp
    431 
    432 #endif  // GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
    433