// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// simd_wrappers_neon.h: NEON specialization of simd_wrappers.h

#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_
#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_

#include <arm_neon.h>

#include <cstdint>      // std::int32_t, std::int16_t, std::uint8_t
#include <type_traits>  // std::conditional

namespace gemmlowp {

using Int32x4 = int32x4_t;
using Int16x4 = int16x4_t;
using Int16x8 = int16x8_t;
using Uint8x8 = uint8x8_t;

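// RegisterType maps a (scalar type, ScalarCount) pair to the widest register
// type that ScalarCount fills: a full NEON register where possible, a plain
// scalar otherwise. Note that four std::uint8_t values are carried in a
// single std::uint32_t rather than in a vector register.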
template <int ScalarCount>
struct RegisterType<std::int32_t, ScalarCount> {
  using Type =
      typename std::conditional<ScalarCount >= 4, Int32x4, std::int32_t>::type;
};

template <int ScalarCount>
struct RegisterType<std::int16_t, ScalarCount> {
  using Type = typename std::conditional<
      ScalarCount >= 8, Int16x8,
      typename std::conditional<ScalarCount >= 4, Int16x4,
                                std::int16_t>::type>::type;
};

template <int ScalarCount>
struct RegisterType<std::uint8_t, ScalarCount> {
  using Type = typename std::conditional<
      ScalarCount >= 8, Uint8x8,
      typename std::conditional<ScalarCount >= 4, std::uint32_t,
                                std::uint8_t>::type>::type;
};

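// Thin wrappers over the NEON load/store intrinsics for whole registers.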
inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); }
inline Int16x4 LoadInt16x4(const std::int16_t* src) { return vld1_s16(src); }
inline Int16x8 LoadInt16x8(const std::int16_t* src) { return vld1q_s16(src); }

inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {
  vst1q_s32(dst, value);
}

inline void StoreInt16x4(std::int16_t* dst, Int16x4 value) {
  vst1_s16(dst, value);
}

inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {
  vst1q_s16(dst, value);
}

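// Lane-indexed helpers. The NEON lane intrinsics require a compile-time
// constant lane index, so the template parameter Lane is dispatched through
// a switch whose cases pass literal indices; the static_assert in the
// default case rejects out-of-range lanes at compile time.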
template <int Lane>
std::int32_t GetLane(Int32x4 value) {
  return vgetq_lane_s32(value, Lane);
}

template <int Lane>
Int32x4 DupLane(Int32x4 value) {
  switch (Lane) {
    case 0:
      return vdupq_lane_s32(vget_low_s32(value), 0);
    case 1:
      return vdupq_lane_s32(vget_low_s32(value), 1);
    case 2:
      return vdupq_lane_s32(vget_high_s32(value), 0);
    case 3:
      return vdupq_lane_s32(vget_high_s32(value), 1);
    default:
      static_assert(Lane >= 0 && Lane <= 3, "");
      return vdupq_n_s32(0);
  }
}

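// Elementwise arithmetic: broadcast multiply by a scalar, and lanewise
// min/max.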
inline Int32x4 Mul(Int32x4 a, std::int32_t b) { return vmulq_n_s32(a, b); }

inline Int32x4 Min(Int32x4 a, Int32x4 b) { return vminq_s32(a, b); }

inline Int32x4 Max(Int32x4 a, Int32x4 b) { return vmaxq_s32(a, b); }

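// vqrdmulhq_n_s32 returns the high 32 bits of 2*a*b per lane, with rounding,
// saturating the single overflow case a == b == INT32_MIN to INT32_MAX. This
// is the fixed-point multiplication used by gemmlowp's quantized output
// pipelines.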
inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) {
  return vqrdmulhq_n_s32(a, b);
}

template <int Lane>
Int32x4 MulByRhsLane(Int32x4 a, Int32x4 b) {
  switch (Lane) {
    case 0:
      return vmulq_lane_s32(a, vget_low_s32(b), 0);
    case 1:
      return vmulq_lane_s32(a, vget_low_s32(b), 1);
    case 2:
      return vmulq_lane_s32(a, vget_high_s32(b), 0);
    case 3:
      return vmulq_lane_s32(a, vget_high_s32(b), 1);
    default:
      static_assert(Lane >= 0 && Lane <= 3, "");
      return vdupq_n_s32(0);
  }
}

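// Multiply-accumulate: *acc += lhs * rhs, via vmlaq.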
inline void MulAdd(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
  *acc = vmlaq_s32(*acc, lhs, rhs);
}

inline void MulAdd(Int32x4 lhs, std::int32_t rhs, Int32x4* acc) {
  *acc = vmlaq_n_s32(*acc, lhs, rhs);
}

template <int Lane>
inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
  switch (Lane) {
    case 0:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_low_s32(rhs), 0);
      break;
    case 1:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_low_s32(rhs), 1);
      break;
    case 2:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_high_s32(rhs), 0);
      break;
    case 3:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_high_s32(rhs), 1);
      break;
    default:
      static_assert(Lane >= 0 && Lane <= 3, "");
  }
}

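// Contiguous loads of 8x8 blocks. The 64 scalars of a block occupy eight
// 8-lane registers for int16 and uint8, but sixteen 4-lane registers for
// int32.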
template <>
struct LoadContiguousImpl<RegBlockInt16<8, 8>> {
  static RegBlockInt16<8, 8> Run(const std::int16_t* src) {
    RegBlockInt16<8, 8> result;
    for (int i = 0; i < 8; i++) {
      result.buf.reg[i] = vld1q_s16(src + 8 * i);
    }
    return result;
  }
};

template <>
struct LoadContiguousImpl<RegBlockUint8<8, 8>> {
  static RegBlockUint8<8, 8> Run(const std::uint8_t* src) {
    RegBlockUint8<8, 8> result;
    for (int i = 0; i < 8; i++) {
      result.buf.reg[i] = vld1_u8(src + 8 * i);
    }
    return result;
  }
};

template <>
struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
  static RegBlockInt32<8, 8> Run(const std::int32_t* src) {
    RegBlockInt32<8, 8> result;
    for (int i = 0; i < 16; i++) {
      result.buf.reg[i] = vld1q_s32(src + 4 * i);
    }
    return result;
  }
};
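
// Usage sketch (illustrative only; gemmlowp reaches these wrappers through
// the generic code in simd_wrappers.h rather than calling them directly, and
// the variable names below are hypothetical):
//
//   Int32x4 acc = LoadInt32x4(acc_ptr);
//   MulAddByRhsLane<2>(lhs_vec, rhs_vec, &acc);  // acc += lhs_vec * rhs[2]
//   StoreInt32x4(acc_ptr, Min(Max(acc, min_vec), max_vec));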

}  // end namespace gemmlowp

#include "simd_wrappers_common_neon_sse.h"

#endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_