// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fixedpoint_msa.h: optimized MSA specializations of the templates
// in fixedpoint.h.

#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_

#include <msa.h>

namespace gemmlowp {

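// Raw-type traits for the two MSA vector types used here: a 128-bit MSA
// register viewed as v4i32 holds 4 lanes of int32, and viewed as v8i16 it
// holds 8 lanes of int16.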
template <>
struct FixedPointRawTypeTraits<v4i32> {
  typedef std::int32_t ScalarRawType;
  static const int kLanes = 4;
};

template <>
struct FixedPointRawTypeTraits<v8i16> {
  typedef std::int16_t ScalarRawType;
  static const int kLanes = 8;
};

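// The MSA bitwise builtins (and_v, or_v, xor_v, nor_v) operate on whole
// 128-bit registers typed as v16u8, independently of the element width,
// hence the reinterpret_casts to and from the element-typed vectors below.
// BitNot is expressed as nor_v of a value with itself.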
template <>
inline v4i32 BitAnd(v4i32 a, v4i32 b) {
  return reinterpret_cast<v4i32>(__builtin_msa_and_v(reinterpret_cast<v16u8>(a),
                                                     reinterpret_cast<v16u8>(b)));
}

template <>
inline v8i16 BitAnd(v8i16 a, v8i16 b) {
  return reinterpret_cast<v8i16>(__builtin_msa_and_v(reinterpret_cast<v16u8>(a),
                                                     reinterpret_cast<v16u8>(b)));
}

template <>
inline v4i32 BitOr(v4i32 a, v4i32 b) {
  return reinterpret_cast<v4i32>(__builtin_msa_or_v(reinterpret_cast<v16u8>(a),
                                                    reinterpret_cast<v16u8>(b)));
}

template <>
inline v8i16 BitOr(v8i16 a, v8i16 b) {
  return reinterpret_cast<v8i16>(__builtin_msa_or_v(reinterpret_cast<v16u8>(a),
                                                    reinterpret_cast<v16u8>(b)));
}

template <>
inline v4i32 BitXor(v4i32 a, v4i32 b) {
  return reinterpret_cast<v4i32>(__builtin_msa_xor_v(reinterpret_cast<v16u8>(a),
                                                     reinterpret_cast<v16u8>(b)));
}

template <>
inline v8i16 BitXor(v8i16 a, v8i16 b) {
  return reinterpret_cast<v8i16>(__builtin_msa_xor_v(reinterpret_cast<v16u8>(a),
                                                     reinterpret_cast<v16u8>(b)));
}

template <>
inline v4i32 BitNot(v4i32 a) {
  return reinterpret_cast<v4i32>(__builtin_msa_nor_v(reinterpret_cast<v16u8>(a),
                                                     reinterpret_cast<v16u8>(a)));
}

template <>
inline v8i16 BitNot(v8i16 a) {
  return reinterpret_cast<v8i16>(__builtin_msa_nor_v(reinterpret_cast<v16u8>(a),
                                                     reinterpret_cast<v16u8>(a)));
}

template <>
inline v4i32 Add(v4i32 a, v4i32 b) {
  return __builtin_msa_addv_w(a, b);
}

template <>
inline v8i16 Add(v8i16 a, v8i16 b) {
  return __builtin_msa_addv_h(a, b);
}

template <>
inline v4i32 Sub(v4i32 a, v4i32 b) {
  return __builtin_msa_subv_w(a, b);
}

template <>
inline v8i16 Sub(v8i16 a, v8i16 b) {
  return __builtin_msa_subv_h(a, b);
}

template <>
inline v4i32 Neg(v4i32 a) {
  v4i32 zeroes = __builtin_msa_ldi_w(0);
  return __builtin_msa_subv_w(zeroes, a);
}

template <>
inline v8i16 Neg(v8i16 a) {
  v8i16 zeroes = __builtin_msa_ldi_h(0);
  return __builtin_msa_subv_h(zeroes, a);
}

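// sll/sra take a per-lane shift amount in a vector register, so the scalar
// offset is first replicated into every lane with fill_w/fill_h.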
template <>
inline v4i32 ShiftLeft(v4i32 a, int offset) {
  return __builtin_msa_sll_w(a, __builtin_msa_fill_w(offset));
}

template <>
inline v8i16 ShiftLeft(v8i16 a, int offset) {
  return __builtin_msa_sll_h(a, __builtin_msa_fill_h(offset));
}

template <>
inline v4i32 ShiftRight(v4i32 a, int offset) {
  return __builtin_msa_sra_w(a, __builtin_msa_fill_w(offset));
}

template <>
inline v8i16 ShiftRight(v8i16 a, int offset) {
  return __builtin_msa_sra_h(a, __builtin_msa_fill_h(offset));
}

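// bsel_v is a bit-wise select: for each bit, a 1 in the first operand picks
// the corresponding bit from the third operand and a 0 picks it from the
// second, which is why the arguments are passed as (mask, else, then).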
template <>
inline v4i32 SelectUsingMask(v4i32 if_mask, v4i32 then_val, v4i32 else_val) {
  if_mask = reinterpret_cast<v4i32>(__builtin_msa_bsel_v(reinterpret_cast<v16u8>(if_mask),
                                                         reinterpret_cast<v16u8>(else_val),
                                                         reinterpret_cast<v16u8>(then_val)));
  return if_mask;
}

template <>
inline v8i16 SelectUsingMask(v8i16 if_mask, v8i16 then_val, v8i16 else_val) {
  if_mask = reinterpret_cast<v8i16>(__builtin_msa_bsel_v(reinterpret_cast<v16u8>(if_mask),
                                                         reinterpret_cast<v16u8>(else_val),
                                                         reinterpret_cast<v16u8>(then_val)));
  return if_mask;
}

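// The MSA compare builtins (ceq/ceqi, clt_s, cle_s) set every bit of a lane
// when the comparison holds and clear the lane otherwise, which matches the
// all-ones/all-zeros mask convention used by fixedpoint.h.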
template <>
inline v4i32 MaskIfEqual(v4i32 a, v4i32 b) {
  return __builtin_msa_ceq_w(a, b);
}

template <>
inline v8i16 MaskIfEqual(v8i16 a, v8i16 b) {
  return __builtin_msa_ceq_h(a, b);
}

template <>
inline v4i32 MaskIfNotEqual(v4i32 a, v4i32 b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline v8i16 MaskIfNotEqual(v8i16 a, v8i16 b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline v4i32 MaskIfZero(v4i32 a) {
  return __builtin_msa_ceqi_w(a, 0);
}

template <>
inline v8i16 MaskIfZero(v8i16 a) {
  return __builtin_msa_ceqi_h(a, 0);
}

template <>
inline v4i32 MaskIfNonZero(v4i32 a) {
  return BitNot(MaskIfZero(a));
}

template <>
inline v8i16 MaskIfNonZero(v8i16 a) {
  return BitNot(MaskIfZero(a));
}

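// MSA only provides "less than" and "less than or equal" compares, so the
// greater-than variants are obtained by swapping the operands of clt_s/cle_s.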
template <>
inline v4i32 MaskIfGreaterThan(v4i32 a, v4i32 b) {
  return __builtin_msa_clt_s_w(b, a);
}

template <>
inline v8i16 MaskIfGreaterThan(v8i16 a, v8i16 b) {
  return __builtin_msa_clt_s_h(b, a);
}

template <>
inline v4i32 MaskIfGreaterThanOrEqual(v4i32 a, v4i32 b) {
  return __builtin_msa_cle_s_w(b, a);
}

template <>
inline v8i16 MaskIfGreaterThanOrEqual(v8i16 a, v8i16 b) {
  return __builtin_msa_cle_s_h(b, a);
}

template <>
inline v4i32 MaskIfLessThan(v4i32 a, v4i32 b) {
  return __builtin_msa_clt_s_w(a, b);
}

template <>
inline v8i16 MaskIfLessThan(v8i16 a, v8i16 b) {
  return __builtin_msa_clt_s_h(a, b);
}

template <>
inline v4i32 MaskIfLessThanOrEqual(v4i32 a, v4i32 b) {
  return __builtin_msa_cle_s_w(a, b);
}

template <>
inline v8i16 MaskIfLessThanOrEqual(v8i16 a, v8i16 b) {
  return __builtin_msa_cle_s_h(a, b);
}

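// bz_v tests whether every bit of the register is zero and bnz_v whether any
// bit is set. All() therefore checks that the mask is all ones by testing
// that its complement is all zeros.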
template <>
inline bool All(v4i32 a) {
  return __builtin_msa_bz_v(reinterpret_cast<v16u8>(BitNot(a)));
}

template <>
inline bool All(v8i16 a) {
  return __builtin_msa_bz_v(reinterpret_cast<v16u8>(BitNot(a)));
}

template <>
inline bool Any(v4i32 a) {
  return __builtin_msa_bnz_v(reinterpret_cast<v16u8>(a));
}

template <>
inline bool Any(v8i16 a) {
  return __builtin_msa_bnz_v(reinterpret_cast<v16u8>(a));
}

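// aver_s computes the rounded average (a + b + 1) >> 1 per lane, with the
// intermediate sum evaluated at full precision, i.e. without overflow.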
template <>
inline v4i32 RoundingHalfSum(v4i32 a, v4i32 b) {
  return __builtin_msa_aver_s_w(a, b);
}

template <>
inline v8i16 RoundingHalfSum(v8i16 a, v8i16 b) {
  return __builtin_msa_aver_s_h(a, b);
}

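// mulr_q is the MSA Q31 (resp. Q15) fixed-point multiply with rounding: it
// returns the high half of 2*a*b after adding the rounding bit, saturating
// the single overflow case where both inputs are the minimum value, matching
// the contract of SaturatingRoundingDoublingHighMul.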
template <>
inline v4i32 SaturatingRoundingDoublingHighMul(v4i32 a, v4i32 b) {
  return __builtin_msa_mulr_q_w(a, b);
}

template <>
inline v8i16 SaturatingRoundingDoublingHighMul(v8i16 a, v8i16 b) {
  return __builtin_msa_mulr_q_h(a, b);
}

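// Saturating multiplication by a positive power of two. For exponents below
// 5 the value is simply doubled Exponent times with a saturating add. For
// larger exponents the input is first saturated to (width - Exponent) bits
// and then shifted left; lanes that saturated to a positive value would end
// up with Exponent trailing zero bits after the shift, so those bits are
// patched back to ones to yield the maximum representable value.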
template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, 1> {
  static v4i32 eval(v4i32 x) {
    static_assert(Exponent >= 0 && Exponent < 32, "");
    if (Exponent < 5) {
      for (int i = 0; i < Exponent; i++) {
        x = __builtin_msa_adds_s_w(x, x);
      }
      return x;
    } else {
      // Saturate each signed 32-bit element to (32 - Exponent)
      // bits (this takes full care of negative elements).
      v4i32 res = __builtin_msa_sat_s_w(x, 31 - Exponent);
      // Set tmp to 0x7FFFFFFF for those elements which saturated
      // to smaller (positive) values and 0 for all others.
      v4i32 tmp = __builtin_msa_srli_w(__builtin_msa_clt_s_w(res, x), 1);
      // Shift the saturated elements. The positive saturated elements
      // will have Exponent trailing zero bits after the shift. Those
      // need to be ones, not zeroes.
      res = __builtin_msa_slli_w(res, Exponent);
      // Finally, set those trailing zero bits to ones.
      res = reinterpret_cast<v4i32>(__builtin_msa_or_v(reinterpret_cast<v16u8>(res),
                                                       reinterpret_cast<v16u8>(tmp)));
      return res;
    }
  }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, 1> {
  static v8i16 eval(v8i16 x) {
    static_assert(Exponent >= 0 && Exponent < 16, "");
    if (Exponent < 5) {
      for (int i = 0; i < Exponent; i++) {
        x = __builtin_msa_adds_s_h(x, x);
      }
      return x;
    } else {
      // Saturate each signed 16-bit element to (16 - Exponent)
      // bits (this takes full care of negative elements).
      v8i16 res = __builtin_msa_sat_s_h(x, 15 - Exponent);
      // Set tmp to 0x7FFF for those elements which saturated
      // to smaller (positive) values and 0 for all others.
      v8i16 tmp = __builtin_msa_srli_h(__builtin_msa_clt_s_h(res, x), 1);
      // Shift the saturated elements. The positive saturated elements
      // will have Exponent trailing zero bits after the shift. Those
      // need to be ones, not zeroes.
      res = __builtin_msa_slli_h(res, Exponent);
      // Finally, set those trailing zero bits to ones.
      res = reinterpret_cast<v8i16>(__builtin_msa_or_v(reinterpret_cast<v16u8>(res),
                                                       reinterpret_cast<v16u8>(tmp)));
      return res;
    }
  }
};

// TODO: possibly implement:
// template <> v4i32 RoundingDivideByPOT(v4i32, int)
// template <> v8i16 RoundingDivideByPOT(v8i16, int)
// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v4i32, -1>
// template <int Exponent> struct ImplSaturatingRoundingMultiplyByPOT<Exponent, v8i16, -1>

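// Dup broadcasts a scalar into every lane of the vector using fill_w/fill_h.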
template <>
inline v4i32 Dup<v4i32>(std::int32_t x) {
  return __builtin_msa_fill_w(x);
}

template <>
inline v8i16 Dup<v8i16>(std::int16_t x) {
  return __builtin_msa_fill_h(x);
}

// So far this is only needed for int16.
template <>
inline v8i16 SaturatingAdd(v8i16 a, v8i16 b) {
  return __builtin_msa_adds_s_h(a, b);
}

}  // end namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_MSA_H_