Home | History | Annotate | Download | only in internal
      1 // Copyright 2015 Google Inc. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // simd_wrappers_common_neon_sse.h: common SIMD (NEON and SSE) wrapper code
     16 
     17 #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
     18 #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
     19 
     20 #include "simd_wrappers.h"
     21 
     22 namespace gemmlowp {
     23 
     24 template <typename SrcScalarType, int N>
     25 struct LoadImpl<RegBlockInt32<4, N>,
     26                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
     27   static RegBlockInt32<4, N> Run(
     28       const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
     29       int col) {
     30     RegBlockInt32<4, N> result;
     31     for (int i = 0; i < N; i++) {
     32       result.buf.reg[i] = LoadInt32x4(src.data(row, col + i));
     33     }
     34     return result;
     35   }
     36 };
     37 
     38 template <typename SrcScalarType, int N>
     39 struct LoadImpl<RegBlockInt32<8, N>,
     40                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
     41   static RegBlockInt32<8, N> Run(
     42       const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
     43       int col) {
     44     RegBlockInt32<8, N> result;
     45     for (int i = 0; i < N; i++) {
     46       result.buf.reg[2 * i + 0] = LoadInt32x4(src.data(row + 0, col + i));
     47       result.buf.reg[2 * i + 1] = LoadInt32x4(src.data(row + 4, col + i));
     48     }
     49     return result;
     50   }
     51 };
     52 
     53 template <typename SrcScalarType>
     54 struct LoadImpl<RegBlockInt32<1, 4>,
     55                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
     56   static RegBlockInt32<1, 4> Run(
     57       const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
     58       int col) {
     59     RegBlockInt32<1, 4> result;
     60     std::int32_t buf[4];
     61     for (int i = 0; i < 4; i++) {
     62       buf[i] = src(row, col + i);
     63     }
     64     result.buf.reg[0] = LoadInt32x4(buf);
     65     return result;
     66   }
     67 };
     68 
     69 template <typename SrcScalarType>
     70 struct LoadImpl<RegBlockInt32<1, 8>,
     71                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
     72   static RegBlockInt32<1, 8> Run(
     73       const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
     74       int col) {
     75     RegBlockInt32<1, 8> result;
     76     std::int32_t buf[8];
     77     for (int i = 0; i < 8; i++) {
     78       buf[i] = src(row, col + i);
     79     }
     80     result.buf.reg[0] = LoadInt32x4(buf);
     81     result.buf.reg[1] = LoadInt32x4(buf + 4);
     82     return result;
     83   }
     84 };
     85 
     86 template <typename SrcScalarType>
     87 struct LoadImpl<RegBlockInt32<4, 1>,
     88                 VectorMap<SrcScalarType, VectorShape::Col>> {
     89   static RegBlockInt32<4, 1> Run(
     90       const VectorMap<SrcScalarType, VectorShape::Col>& src, int pos) {
     91     RegBlockInt32<4, 1> result;
     92     result.buf.reg[0] = LoadInt32x4(src.data(pos));
     93     return result;
     94   }
     95 };
     96 
     97 template <typename SrcScalarType>
     98 struct LoadImpl<RegBlockInt32<4, 1>,
     99                 VectorDup<SrcScalarType, VectorShape::Col>> {
    100   static RegBlockInt32<4, 1> Run(
    101       const VectorDup<SrcScalarType, VectorShape::Col>& src, int) {
    102     RegBlockInt32<4, 1> result;
    103     result.buf.reg[0] = LoadInt32x4(src(0));
    104     return result;
    105   }
    106 };
    107 
    108 template <typename SrcScalarType, int N>
    109 struct LoadForBroadcastingImpl<RegBlockInt32<4, N>,
    110                                VectorMap<SrcScalarType, VectorShape::Col>> {
    111   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>;
    112   using RegisterBlockType = RegBlockInt32<4, N>;
    113   using ResultBlockType =
    114       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
    115                                                 SrcObjectType>::Type;
    116 
    117   static ResultBlockType Run(const SrcObjectType& src, int pos) {
    118     ResultBlockType result;
    119     static_assert(ResultBlockType::kRegisterCount == 1, "");
    120     result.buf.reg[0] = LoadInt32x4(src.data(pos));
    121     return result;
    122   }
    123 };
    124 
    125 template <typename SrcScalarType, int N>
    126 struct LoadForBroadcastingImpl<RegBlockInt32<8, N>,
    127                                VectorMap<SrcScalarType, VectorShape::Col>> {
    128   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>;
    129   using RegisterBlockType = RegBlockInt32<8, N>;
    130   using ResultBlockType =
    131       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
    132                                                 SrcObjectType>::Type;
    133 
    134   static ResultBlockType Run(const SrcObjectType& src, int pos) {
    135     ResultBlockType result;
    136     static_assert(ResultBlockType::kRegisterCount == 2, "");
    137     result.buf.reg[0] = LoadInt32x4(src.data(pos));
    138     result.buf.reg[1] = LoadInt32x4(src.data(pos + 4));
    139     return result;
    140   }
    141 };
    142 
    143 template <typename SrcScalarType>
    144 struct LoadForBroadcastingImpl<RegBlockInt32<4, 1>,
    145                                VectorMap<SrcScalarType, VectorShape::Row>> {
    146   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
    147   using RegisterBlockType = RegBlockInt32<4, 1>;
    148   using ResultBlockType =
    149       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
    150                                                 SrcObjectType>::Type;
    151 
    152   static ResultBlockType Run(const SrcObjectType& src, int pos) {
    153     ResultBlockType result;
    154     result.buf.reg[0] = src(pos);
    155     return result;
    156   }
    157 };
    158 
    159 template <typename SrcScalarType, int N>
    160 struct LoadForBroadcastingImpl<RegBlockInt32<N, 4>,
    161                                VectorMap<SrcScalarType, VectorShape::Row>> {
    162   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
    163   using RegisterBlockType = RegBlockInt32<N, 4>;
    164   using ResultBlockType =
    165       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
    166                                                 SrcObjectType>::Type;
    167 
    168   static ResultBlockType Run(const SrcObjectType& src, int pos) {
    169     ResultBlockType result;
    170     static_assert(ResultBlockType::kRegisterCount == 1, "");
    171     result.buf.reg[0] = LoadInt32x4(src.data(pos));
    172     return result;
    173   }
    174 };
    175 
    176 template <typename SrcScalarType, int N>
    177 struct LoadForBroadcastingImpl<RegBlockInt32<N, 8>,
    178                                VectorMap<SrcScalarType, VectorShape::Row>> {
    179   using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
    180   using RegisterBlockType = RegBlockInt32<N, 8>;
    181   using ResultBlockType =
    182       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
    183                                                 SrcObjectType>::Type;
    184 
    185   static ResultBlockType Run(const SrcObjectType& src, int pos) {
    186     ResultBlockType result;
    187     static_assert(ResultBlockType::kRegisterCount == 2, "");
    188     result.buf.reg[0] = LoadInt32x4(src.data(pos));
    189     result.buf.reg[1] = LoadInt32x4(src.data(pos + 4));
    190     return result;
    191   }
    192 };
    193 
    194 // 4x1 := 4x1 + 1x1
    195 template <>
    196 struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
    197   static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
    198                                  const RegBlockInt32<1, 1>& rhs) {
    199     RegBlockInt32<4, 1> result;
    200     result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    201     return result;
    202   }
    203 };
    204 
    205 // 1x4 := 1x4 + 1x1
    206 template <>
    207 struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
    208   static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
    209                                  const RegBlockInt32<1, 1>& rhs) {
    210     RegBlockInt32<1, 4> result;
    211     result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    212     return result;
    213   }
    214 };
    215 
    216 // 4x1 := 4x1 + 4x1
    217 template <>
    218 struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
    219   static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
    220                                  const RegBlockInt32<4, 1>& rhs) {
    221     RegBlockInt32<4, 1> result;
    222     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
    223     return result;
    224   }
    225 };
    226 
    227 // 1x4 := 1x4 + 1x4
    228 template <>
    229 struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
    230   static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
    231                                  const RegBlockInt32<1, 4>& rhs) {
    232     RegBlockInt32<1, 4> result;
    233     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
    234     return result;
    235   }
    236 };
    237 
    238 // 4x4 := 4x4 + 1x4
    239 template <>
    240 struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
    241   static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
    242                                  const RegBlockInt32<1, 4>& rhs) {
    243     RegBlockInt32<4, 4> result;
    244     result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
    245     result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
    246     result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
    247     result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
    248     return result;
    249   }
    250 };
    251 
    252 // 4x4 := 4x4 + 4x1
    253 template <>
    254 struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
    255   static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
    256                                  const RegBlockInt32<4, 1>& rhs) {
    257     RegBlockInt32<4, 4> result;
    258     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
    259     result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[0]);
    260     result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]);
    261     result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[0]);
    262     return result;
    263   }
    264 };
    265 
    266 // 8x1 := 8x1 + 1x1
    267 template <>
    268 struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
    269   static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
    270                                  const RegBlockInt32<1, 1>& rhs) {
    271     RegBlockInt32<8, 1> result;
    272     const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
    273     for (int i = 0; i < 2; i++) {
    274       result.buf.reg[i] = Add(lhs.buf.reg[i], p);
    275     }
    276     return result;
    277   }
    278 };
    279 
    280 // 8x1 := 8x1 + 8x1
    281 template <>
    282 struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
    283   static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
    284                                  const RegBlockInt32<8, 1>& rhs) {
    285     RegBlockInt32<8, 1> result;
    286     for (int i = 0; i < 2; i++) {
    287       result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]);
    288     }
    289     return result;
    290   }
    291 };
    292 
    293 // 8x4 := 8x4 + 1x4
    294 template <>
    295 struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
    296   static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
    297                                  const RegBlockInt32<1, 4>& rhs) {
    298     RegBlockInt32<8, 4> result;
    299     result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
    300     result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
    301     result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
    302     result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
    303     result.buf.reg[4] = Add(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
    304     result.buf.reg[5] = Add(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
    305     result.buf.reg[6] = Add(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
    306     result.buf.reg[7] = Add(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
    307     return result;
    308   }
    309 };
    310 
    311 // 8x4 := 8x4 + 8x1
    312 template <>
    313 struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
    314   static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
    315                                  const RegBlockInt32<8, 1>& rhs) {
    316     RegBlockInt32<8, 4> result;
    317     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
    318     result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]);
    319     result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]);
    320     result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[1]);
    321     result.buf.reg[4] = Add(lhs.buf.reg[4], rhs.buf.reg[0]);
    322     result.buf.reg[5] = Add(lhs.buf.reg[5], rhs.buf.reg[1]);
    323     result.buf.reg[6] = Add(lhs.buf.reg[6], rhs.buf.reg[0]);
    324     result.buf.reg[7] = Add(lhs.buf.reg[7], rhs.buf.reg[1]);
    325     return result;
    326   }
    327 };
    328 
    329 // 1x8 := 1x8 + 1x8
    330 template <>
    331 struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> {
    332   static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
    333                                  const RegBlockInt32<1, 8>& rhs) {
    334     RegBlockInt32<1, 8> result;
    335     result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
    336     result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]);
    337     return result;
    338   }
    339 };
    340 
    341 // 1x8 := 1x8 + 1x1
    342 template <>
    343 struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> {
    344   static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
    345                                  const RegBlockInt32<1, 1>& rhs) {
    346     RegBlockInt32<1, 8> result;
    347     result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    348     result.buf.reg[1] = Add(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
    349     return result;
    350   }
    351 };
    352 
    353 // 4x1 := 4x1 * 1x1
    354 template <>
    355 struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
    356   static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
    357                                  const RegBlockInt32<1, 1>& rhs) {
    358     RegBlockInt32<4, 1> result;
    359     result.buf.reg[0] = Mul(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    360     return result;
    361   }
    362 };
    363 
    364 // 4x1 := 4x1 * 4x1
    365 template <>
    366 struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
    367   static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
    368                                  const RegBlockInt32<4, 1>& rhs) {
    369     RegBlockInt32<4, 1> result;
    370     result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
    371     return result;
    372   }
    373 };
    374 
    375 // 1x4 := 1x4 * 1x4
    376 template <>
    377 struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
    378   static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
    379                                  const RegBlockInt32<1, 4>& rhs) {
    380     RegBlockInt32<1, 4> result;
    381     result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
    382     return result;
    383   }
    384 };
    385 
    386 // 1x4 := 1x4 * 1x1
    387 template <>
    388 struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
    389   static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
    390                                  const RegBlockInt32<1, 1>& rhs) {
    391     RegBlockInt32<1, 4> result;
    392     result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
    393     return result;
    394   }
    395 };
    396 
    397 // 4x4 := 4x4 * 1x4
    398 template <>
    399 struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
    400   static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
    401                                  const RegBlockInt32<1, 4>& rhs) {
    402     RegBlockInt32<4, 4> result;
    403     const Int32x4 p = rhs.buf.reg[0];
    404     result.buf.reg[0] = MulByRhsLane<0>(lhs.buf.reg[0], p);
    405     result.buf.reg[1] = MulByRhsLane<1>(lhs.buf.reg[1], p);
    406     result.buf.reg[2] = MulByRhsLane<2>(lhs.buf.reg[2], p);
    407     result.buf.reg[3] = MulByRhsLane<3>(lhs.buf.reg[3], p);
    408     return result;
    409   }
    410 };
    411 
    412 // 4x4 := 4x4 * 4x1
    413 template <>
    414 struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
    415   static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
    416                                  const RegBlockInt32<4, 1>& rhs) {
    417     RegBlockInt32<4, 4> result;
    418     const Int32x4 p = rhs.buf.reg[0];
    419     result.buf.reg[0] = Mul(lhs.buf.reg[0], p);
    420     result.buf.reg[1] = Mul(lhs.buf.reg[1], p);
    421     result.buf.reg[2] = Mul(lhs.buf.reg[2], p);
    422     result.buf.reg[3] = Mul(lhs.buf.reg[3], p);
    423     return result;
    424   }
    425 };
    426 
    427 // 8x1 := 8x1 * 1x1
    428 template <>
    429 struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
    430   static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
    431                                  const RegBlockInt32<1, 1>& rhs) {
    432     RegBlockInt32<8, 1> result;
    433     const std::int32_t p = rhs.buf.reg[0];
    434     for (int i = 0; i < 2; i++) {
    435       result.buf.reg[i] = Mul(lhs.buf.reg[i], p);
    436     }
    437     return result;
    438   }
    439 };
    440 
    441 // 8x1 := 8x1 * 8x1
    442 template <>
    443 struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
    444   static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
    445                                  const RegBlockInt32<8, 1>& rhs) {
    446     RegBlockInt32<8, 1> result;
    447     for (int i = 0; i < 2; i++) {
    448       result.buf.reg[i] = Mul(lhs.buf.reg[i], rhs.buf.reg[i]);
    449     }
    450     return result;
    451   }
    452 };
    453 
    454 // 8x4 := 8x4 * 1x4
    455 template <>
    456 struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
    457   static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
    458                                  const RegBlockInt32<1, 4>& rhs) {
    459     RegBlockInt32<8, 4> result;
    460     const Int32x4 p = rhs.buf.reg[0];
    461     for (int i = 0; i < 2; i++) {
    462       result.buf.reg[i + 0] = MulByRhsLane<0>(lhs.buf.reg[i + 0], p);
    463       result.buf.reg[i + 2] = MulByRhsLane<1>(lhs.buf.reg[i + 2], p);
    464       result.buf.reg[i + 4] = MulByRhsLane<2>(lhs.buf.reg[i + 4], p);
    465       result.buf.reg[i + 6] = MulByRhsLane<3>(lhs.buf.reg[i + 6], p);
    466     }
    467     return result;
    468   }
    469 };
    470 
    471 // 8x4 := 8x4 * 8x1
    472 template <>
    473 struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
    474   static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
    475                                  const RegBlockInt32<8, 1>& rhs) {
    476     RegBlockInt32<8, 4> result;
    477     const Int32x4 p[2]{rhs.buf.reg[0], rhs.buf.reg[1]};
    478     for (int i = 0; i < 4; i++) {
    479       for (int j = 0; j < 2; j++) {
    480         const int k = j + 2 * i;
    481         result.buf.reg[k] = Mul(lhs.buf.reg[k], p[j]);
    482       }
    483     }
    484     return result;
    485   }
    486 };
    487 
    488 // Rx1 += Rx1 * 1x1
    489 template <int Rows>
    490 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>,
    491                            RegBlockInt32<Rows, 1>> {
    492   static void Run(const RegBlockInt32<Rows, 1>& lhs,
    493                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 1>* acc) {
    494     const std::int32_t p = rhs.buf.reg[0];
    495     for (int i = 0; i < RegBlockInt32<Rows, 1>::kRegisterCount; i++) {
    496       MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]);
    497     }
    498   }
    499 };
    500 
    501 // RxC += Rx1 * 1x1
    502 template <int Rows, int Cols>
    503 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>,
    504                            RegBlockInt32<Rows, Cols>> {
    505   static void Run(const RegBlockInt32<Rows, 1>& lhs,
    506                   const RegBlockInt32<1, 1>& rhs,
    507                   RegBlockInt32<Rows, Cols>* acc) {
    508     const std::int32_t p = rhs.buf.reg[0];
    509     static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
    510     for (int i = 0; i < kRegsPerCol; i++) {
    511       const Int32x4 q = Mul(lhs.buf.reg[i], p);
    512       for (int j = 0; j < Cols; j++) {
    513         acc->buf.reg[i + j * kRegsPerCol] =
    514             Add(acc->buf.reg[i + j * kRegsPerCol], q);
    515       }
    516     }
    517   }
    518 };
    519 
    520 // 1xC += 1xC * 1x1
    521 template <int Cols>
    522 struct BroadcastMulAddImpl<RegBlockInt32<1, Cols>, RegBlockInt32<1, 1>,
    523                            RegBlockInt32<1, Cols>> {
    524   static void Run(const RegBlockInt32<1, Cols>& lhs,
    525                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) {
    526     const std::int32_t p = rhs.buf.reg[0];
    527     for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) {
    528       MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]);
    529     }
    530   }
    531 };
    532 
    533 // RxC += 1x1 * 1x1
    534 template <int Rows, int Cols>
    535 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
    536                            RegBlockInt32<Rows, Cols>> {
    537   static void Run(const RegBlockInt32<1, 1>& lhs,
    538                   const RegBlockInt32<1, 1>& rhs,
    539                   RegBlockInt32<Rows, Cols>* acc) {
    540     const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0]));
    541     for (int i = 0; i < RegBlockInt32<Rows, Cols>::kRegisterCount; i++) {
    542       acc->buf.reg[i] = Add(acc->buf.reg[i], p);
    543     }
    544   }
    545 };
    546 
    547 // 1x1 += 1x1 * 1x1
    548 template <>
    549 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
    550                            RegBlockInt32<1, 1>> {
    551   static void Run(const RegBlockInt32<1, 1>& lhs,
    552                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 1>* acc) {
    553     MulAdd(lhs.buf.reg[0], rhs.buf.reg[0], &acc->buf.reg[0]);
    554   }
    555 };
    556 
    557 // Rx4 += Rx1 * 1x4
    558 template <int Rows>
    559 struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 4>,
    560                            RegBlockInt32<Rows, 4>> {
    561   static void Run(const RegBlockInt32<Rows, 1>& lhs,
    562                   const RegBlockInt32<1, 4>& rhs, RegBlockInt32<Rows, 4>* acc) {
    563     const Int32x4 p = rhs.buf.reg[0];
    564     static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
    565     for (int i = 0; i < kRegsPerCol; i++) {
    566       MulAddByRhsLane<0>(lhs.buf.reg[i], p, &acc->buf.reg[i + 0 * kRegsPerCol]);
    567       MulAddByRhsLane<1>(lhs.buf.reg[i], p, &acc->buf.reg[i + 1 * kRegsPerCol]);
    568       MulAddByRhsLane<2>(lhs.buf.reg[i], p, &acc->buf.reg[i + 2 * kRegsPerCol]);
    569       MulAddByRhsLane<3>(lhs.buf.reg[i], p, &acc->buf.reg[i + 3 * kRegsPerCol]);
    570     }
    571   }
    572 };
    573 
    574 // Rx4 += 1x4 * 1x1
    575 template <int Rows>
    576 struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>,
    577                            RegBlockInt32<Rows, 4>> {
    578   static void Run(const RegBlockInt32<1, 4>& lhs,
    579                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 4>* acc) {
    580     const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
    581     Int32x4 q[4];
    582     q[0] = DupLane<0>(p);
    583     q[1] = DupLane<1>(p);
    584     q[2] = DupLane<2>(p);
    585     q[3] = DupLane<3>(p);
    586     static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
    587     for (int i = 0; i < kRegsPerCol; i++) {
    588       for (int j = 0; j < 4; j++) {
    589         acc->buf.reg[i + j * kRegsPerCol] =
    590             Add(q[j], acc->buf.reg[i + j * kRegsPerCol]);
    591       }
    592     }
    593   }
    594 };
    595 
    596 // 1xC += 1x1 * 1x1
    597 template <int Cols>
    598 struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
    599                            RegBlockInt32<1, Cols>> {
    600   static void Run(const RegBlockInt32<1, 1>& lhs,
    601                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) {
    602     const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0]));
    603     for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) {
    604       acc->buf.reg[i] = Add(acc->buf.reg[i], p);
    605     }
    606   }
    607 };
    608 
    609 // 1x4 += 1x4 * 1x1
    610 template <>
    611 struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>,
    612                            RegBlockInt32<1, 4>> {
    613   static void Run(const RegBlockInt32<1, 4>& lhs,
    614                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 4>* acc) {
    615     const std::int32_t p = rhs.buf.reg[0];
    616     MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]);
    617   }
    618 };
    619 
    620 // 4xC += 4x1 * 1x1
    621 template <int Cols>
    622 struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>,
    623                            RegBlockInt32<4, Cols>> {
    624   static void Run(const RegBlockInt32<4, 1>& lhs,
    625                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, Cols>* acc) {
    626     const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
    627     for (int i = 0; i < Cols; i++) {
    628       acc->buf.reg[i] = Add(p, acc->buf.reg[i]);
    629     }
    630   }
    631 };
    632 
    633 // 4x1 += 4x1 * 1x1
    634 template <>
    635 struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>,
    636                            RegBlockInt32<4, 1>> {
    637   static void Run(const RegBlockInt32<4, 1>& lhs,
    638                   const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, 1>* acc) {
    639     const std::int32_t p = rhs.buf.reg[0];
    640     MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]);
    641   }
    642 };
    643 
    644 }  // namespace gemmlowp
    645 
    646 #endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
    647