// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fixedpoint_neon.h: optimized NEON specializations of the templates
// in fixedpoint.h.

#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_

#include <arm_neon.h>

namespace gemmlowp {

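// Raw-type traits: these describe each NEON register type to the generic
// fixed-point machinery in fixedpoint.h -- the scalar type held in each
// lane and the number of lanes per vector.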
template <>
struct FixedPointRawTypeTraits<int32x4_t> {
  typedef std::int32_t ScalarRawType;
  static const int kLanes = 4;
};

template <>
struct FixedPointRawTypeTraits<int16x8_t> {
  typedef std::int16_t ScalarRawType;
  static const int kLanes = 8;
};

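// Each specialization below replaces the generic scalar implementation in
// fixedpoint.h with the corresponding NEON intrinsic, once for int32x4_t
// and once for int16x8_t.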
template <>
inline int32x4_t BitAnd(int32x4_t a, int32x4_t b) {
  return vandq_s32(a, b);
}

template <>
inline int16x8_t BitAnd(int16x8_t a, int16x8_t b) {
  return vandq_s16(a, b);
}

template <>
inline int32x4_t BitOr(int32x4_t a, int32x4_t b) {
  return vorrq_s32(a, b);
}

template <>
inline int16x8_t BitOr(int16x8_t a, int16x8_t b) {
  return vorrq_s16(a, b);
}

template <>
inline int32x4_t BitXor(int32x4_t a, int32x4_t b) {
  return veorq_s32(a, b);
}

template <>
inline int16x8_t BitXor(int16x8_t a, int16x8_t b) {
  return veorq_s16(a, b);
}

template <>
inline int32x4_t BitNot(int32x4_t a) {
  return veorq_s32(a, vdupq_n_s32(-1));
}

template <>
inline int16x8_t BitNot(int16x8_t a) {
  return veorq_s16(a, vdupq_n_s16(-1));
}

template <>
inline int32x4_t Add(int32x4_t a, int32x4_t b) {
  return vaddq_s32(a, b);
}

template <>
inline int16x8_t Add(int16x8_t a, int16x8_t b) {
  return vaddq_s16(a, b);
}

template <>
inline int32x4_t Sub(int32x4_t a, int32x4_t b) {
  return vsubq_s32(a, b);
}

template <>
inline int16x8_t Sub(int16x8_t a, int16x8_t b) {
  return vsubq_s16(a, b);
}

template <>
inline int32x4_t Neg(int32x4_t a) {
  return vnegq_s32(a);
}

template <>
inline int16x8_t Neg(int16x8_t a) {
  return vnegq_s16(a);
}

template <>
inline int32x4_t ShiftLeft(int32x4_t a, int offset) {
  return vshlq_s32(a, vdupq_n_s32(offset));
}

template <>
inline int16x8_t ShiftLeft(int16x8_t a, int offset) {
  return vshlq_s16(a, vdupq_n_s16(offset));
}

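// NEON has no right-shift-by-register instruction; vshlq with a negated
// shift count performs the right shift instead (arithmetic, since the
// operands are signed).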
template <>
inline int32x4_t ShiftRight(int32x4_t a, int offset) {
  return vshlq_s32(a, vdupq_n_s32(-offset));
}

template <>
inline int16x8_t ShiftRight(int16x8_t a, int offset) {
  return vshlq_s16(a, vdupq_n_s16(-offset));
}

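// The masks passed to SelectUsingMask are all-ones or all-zeros per lane
// (as produced by the MaskIf* functions below), so a bitwise select
// (vbslq) acts as a lane-wise conditional move.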
template <>
inline int32x4_t SelectUsingMask(int32x4_t if_mask, int32x4_t then_val,
                                 int32x4_t else_val) {
  return vbslq_s32(vreinterpretq_u32_s32(if_mask), then_val, else_val);
}

template <>
inline int16x8_t SelectUsingMask(int16x8_t if_mask, int16x8_t then_val,
                                 int16x8_t else_val) {
  return vbslq_s16(vreinterpretq_u16_s16(if_mask), then_val, else_val);
}

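// The NEON comparison intrinsics return unsigned all-ones / all-zeros
// masks; they are reinterpreted back to the signed raw type so that the
// results compose with BitAnd, BitNot and SelectUsingMask above.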
template <>
inline int32x4_t MaskIfEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vceqq_s32(a, b));
}

template <>
inline int16x8_t MaskIfEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vceqq_s16(a, b));
}

template <>
inline int32x4_t MaskIfNotEqual(int32x4_t a, int32x4_t b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline int16x8_t MaskIfNotEqual(int16x8_t a, int16x8_t b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline int32x4_t MaskIfZero(int32x4_t a) {
  return MaskIfEqual(a, vdupq_n_s32(0));
}

template <>
inline int16x8_t MaskIfZero(int16x8_t a) {
  return MaskIfEqual(a, vdupq_n_s16(0));
}

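// vtstq computes (a & a) != 0 per lane, i.e. it sets a lane to all-ones
// exactly when that lane is nonzero.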
template <>
inline int32x4_t MaskIfNonZero(int32x4_t a) {
  return vreinterpretq_s32_u32(vtstq_s32(a, a));
}

template <>
inline int16x8_t MaskIfNonZero(int16x8_t a) {
  return vreinterpretq_s16_u16(vtstq_s16(a, a));
}

template <>
inline int32x4_t MaskIfGreaterThan(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcgtq_s32(a, b));
}

template <>
inline int16x8_t MaskIfGreaterThan(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcgtq_s16(a, b));
}

template <>
inline int32x4_t MaskIfGreaterThanOrEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcgeq_s32(a, b));
}

template <>
inline int16x8_t MaskIfGreaterThanOrEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcgeq_s16(a, b));
}

template <>
inline int32x4_t MaskIfLessThan(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcltq_s32(a, b));
}

template <>
inline int16x8_t MaskIfLessThan(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcltq_s16(a, b));
}

template <>
inline int32x4_t MaskIfLessThanOrEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcleq_s32(a, b));
}

template <>
inline int16x8_t MaskIfLessThanOrEqual(int16x8_t a, int16x8_t b) {
  return vreinterpretq_s16_u16(vcleq_s16(a, b));
}

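// Horizontal reductions of mask vectors. Combining the vector with rotated
// copies of itself (vextq) folds all lanes into lane 0 in log2(kLanes)
// AND/OR steps, after which lane 0 is extracted as the boolean result.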
template <>
inline bool All(int32x4_t a) {
  a = vandq_s32(a, vextq_s32(a, a, 1));
  a = vandq_s32(a, vextq_s32(a, a, 2));
  return vgetq_lane_s32(a, 0);
}

template <>
inline bool All(int16x8_t a) {
  a = vandq_s16(a, vextq_s16(a, a, 1));
  a = vandq_s16(a, vextq_s16(a, a, 2));
  a = vandq_s16(a, vextq_s16(a, a, 4));
  return vgetq_lane_s16(a, 0);
}

template <>
inline bool Any(int32x4_t a) {
  a = vorrq_s32(a, vextq_s32(a, a, 1));
  a = vorrq_s32(a, vextq_s32(a, a, 2));
  return vgetq_lane_s32(a, 0);
}

template <>
inline bool Any(int16x8_t a) {
  a = vorrq_s16(a, vextq_s16(a, a, 1));
  a = vorrq_s16(a, vextq_s16(a, a, 2));
  a = vorrq_s16(a, vextq_s16(a, a, 4));
  return vgetq_lane_s16(a, 0);
}

template <>
inline int32x4_t RoundingHalfSum(int32x4_t a, int32x4_t b) {
  return vrhaddq_s32(a, b);
}

template <>
inline int16x8_t RoundingHalfSum(int16x8_t a, int16x8_t b) {
  return vrhaddq_s16(a, b);
}

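// vqrdmulh maps directly onto this operation: it returns the high half of
// 2*a*b with rounding, saturating in the single overflow case where both
// operands equal the most negative representable value.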
template <>
inline int32x4_t SaturatingRoundingDoublingHighMul(int32x4_t a, int32x4_t b) {
  return vqrdmulhq_s32(a, b);
}

template <>
inline int16x8_t SaturatingRoundingDoublingHighMul(int16x8_t a, int16x8_t b) {
  return vqrdmulhq_s16(a, b);
}

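// vrshlq by -exponent is a rounding arithmetic right shift that rounds ties
// upward. The fixup term is -1 for negative inputs (when exponent > 0) and
// 0 otherwise; saturating-adding it first makes negative ties round away
// from zero, matching the scalar RoundingDivideByPOT in fixedpoint.h.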
template <>
inline int32x4_t RoundingDivideByPOT(int32x4_t x, int exponent) {
  const int32x4_t shift_vec = vdupq_n_s32(-exponent);
  const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31);
  const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
  return vrshlq_s32(fixed_up_x, shift_vec);
}

template <>
inline int16x8_t RoundingDivideByPOT(int16x8_t x, int exponent) {
  const int16x8_t shift_vec = vdupq_n_s16(-exponent);
  const int16x8_t fixup = vshrq_n_s16(vandq_s16(x, shift_vec), 15);
  const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
  return vrshlq_s16(fixed_up_x, shift_vec);
}

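// The third template parameter selects on the sign of Exponent: the +1
// specializations multiply by 2^Exponent with a saturating left shift, and
// the -1 specializations divide by 2^-Exponent with the same rounding
// fixup as RoundingDivideByPOT above.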
template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> {
  static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, -1> {
  static int32x4_t eval(int32x4_t x) {
    const int32x4_t fixup = vshrq_n_s32(x, 31);
    const int32x4_t fixed_up_x = vqaddq_s32(x, fixup);
    return vrshrq_n_s32(fixed_up_x, -Exponent);
  }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int16x8_t, 1> {
  static int16x8_t eval(int16x8_t x) { return vqshlq_n_s16(x, Exponent); }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int16x8_t, -1> {
  static int16x8_t eval(int16x8_t x) {
    const int16x8_t fixup = vshrq_n_s16(x, 15);
    const int16x8_t fixed_up_x = vqaddq_s16(x, fixup);
    return vrshrq_n_s16(fixed_up_x, -Exponent);
  }
};

template <>
inline int32x4_t Dup<int32x4_t>(std::int32_t x) {
  return vdupq_n_s32(x);
}

template <>
inline int16x8_t Dup<int16x8_t>(std::int16_t x) {
  return vdupq_n_s16(x);
}

// So far this is only needed for int16.
template <>
inline int16x8_t SaturatingAdd(int16x8_t a, int16x8_t b) {
  return vqaddq_s16(a, b);
}

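// Illustrative sketch (not part of this header, using only the primitives
// defined above): a typical requantization step multiplies an int32
// accumulator vector by a fixed-point multiplier, then rescales by a power
// of two. The function name below is hypothetical.
//
//   int32x4_t MultiplyByQuantizedMultiplierSketch(int32x4_t acc,
//                                                 std::int32_t multiplier,
//                                                 int right_shift) {
//     const int32x4_t scaled =
//         SaturatingRoundingDoublingHighMul(acc, Dup<int32x4_t>(multiplier));
//     return RoundingDivideByPOT(scaled, right_shift);
//   }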
}  // end namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_