1 // Copyright 2017 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // simd_wrappers.h: some inline functions wrapping SIMD intrinsics, 16 // extending the set of such functions from fixedpoint.h. 17 18 #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ 19 #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ 20 21 #include <algorithm> 22 #include <type_traits> 23 #include "../fixedpoint/fixedpoint.h" 24 25 namespace gemmlowp { 26 27 template <typename ScalarType, int ScalarCount> 28 struct RegisterType { 29 using Type = ScalarType; 30 }; 31 32 inline std::int32_t Min(std::int32_t a, std::int32_t b) { 33 return std::min(a, b); 34 } 35 36 inline std::int32_t Max(std::int32_t a, std::int32_t b) { 37 return std::max(a, b); 38 } 39 40 inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) { 41 *acc += lhs * rhs; 42 } 43 44 template <typename tScalarType, int tScalarCount> 45 struct RegisterBuffer { 46 using ScalarType = tScalarType; 47 static constexpr int kScalarCount = tScalarCount; 48 using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type; 49 static_assert((kScalarCount & (kScalarCount - 1)) == 0, 50 "kScalarCount must be a power of two"); 51 static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, ""); 52 static constexpr int kRegisterLanes = 53 sizeof(RegisterType) / sizeof(ScalarType); 54 static constexpr int kRegisterCount = 55 (kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) / 56 sizeof(RegisterType); 57 58 RegisterType reg[kRegisterCount]; 59 }; 60 61 template <typename tScalarType, int tRows, int tCols> 62 struct RegisterBlock { 63 using ScalarType = tScalarType; 64 static constexpr int kRows = tRows; 65 static constexpr int kCols = tCols; 66 static constexpr int kScalarCount = kRows * kCols; 67 using BufferType = RegisterBuffer<ScalarType, kScalarCount>; 68 using RegisterType = typename BufferType::RegisterType; 69 static constexpr int kRegisterCount = BufferType::kRegisterCount; 70 static constexpr int kRegisterLanes = BufferType::kRegisterLanes; 71 72 BufferType buf; 73 }; 74 75 template <typename RegisterBlockType> 76 struct RegisterBlockAddImpl { 77 static RegisterBlockType Run(const RegisterBlockType& lhs, 78 const RegisterBlockType& rhs) { 79 RegisterBlockType result; 80 for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) { 81 result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]); 82 } 83 return result; 84 } 85 }; 86 87 template <typename RegisterBlockType> 88 RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs, 89 const RegisterBlockType& rhs) { 90 return RegisterBlockAddImpl<RegisterBlockType>::Run(lhs, rhs); 91 } 92 93 template <typename LhsType, typename RhsType> 94 struct ShouldFlipLhsRhs { 95 static constexpr bool kValue = 96 (LhsType::kScalarCount < RhsType::kScalarCount) || 97 (LhsType::kScalarCount == RhsType::kScalarCount && 98 (LhsType::kRows < RhsType::kRows)); 99 }; 100 101 template <typename LhsType, typename RhsType, 102 bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue> 103 struct FlipLhsRhs { 104 using FlippedLhsType = LhsType; 105 using FlippedRhsType = RhsType; 106 static const FlippedLhsType& FlippedLhs(const LhsType& lhs, 107 const RhsType& rhs) { 108 return lhs; 109 } 110 static const FlippedRhsType& FlippedRhs(const LhsType& lhs, 111 const RhsType& rhs) { 112 return rhs; 113 } 114 }; 115 116 template <typename LhsType, typename RhsType> 117 struct FlipLhsRhs<LhsType, RhsType, true> { 118 using FlippedLhsType = RhsType; 119 using FlippedRhsType = LhsType; 120 static const FlippedLhsType& FlippedLhs(const LhsType& lhs, 121 const RhsType& rhs) { 122 return rhs; 123 } 124 static const FlippedRhsType& FlippedRhs(const LhsType& lhs, 125 const RhsType& rhs) { 126 return lhs; 127 } 128 }; 129 130 template <typename Lhs, typename Rhs> 131 struct BroadcastBinaryOpShape { 132 static constexpr int kRows = 133 Lhs::kRows > Rhs::kRows ? Lhs::kRows : Rhs::kRows; 134 static constexpr int kCols = 135 Lhs::kCols > Rhs::kCols ? Lhs::kCols : Rhs::kCols; 136 }; 137 138 template <typename Lhs, typename Rhs> 139 struct BroadcastBinaryOpRegisterBlock { 140 using Shape = BroadcastBinaryOpShape<Lhs, Rhs>; 141 using ScalarType = typename Lhs::ScalarType; 142 using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>; 143 }; 144 145 template <typename Lhs, typename Rhs> 146 struct BroadcastAddImpl { 147 using ResultBlockType = 148 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; 149 static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { 150 ResultBlockType result; 151 static constexpr int Rows = ResultBlockType::kRows; 152 static constexpr int Cols = ResultBlockType::kCols; 153 static constexpr int LhsRows = Lhs::kRows; 154 static constexpr int LhsCols = Lhs::kCols; 155 static constexpr int RhsRows = Rhs::kRows; 156 static constexpr int RhsCols = Rhs::kCols; 157 158 static_assert(LhsRows == Rows || LhsRows == 1, ""); 159 static_assert(RhsRows == Rows || RhsRows == 1, ""); 160 static_assert(LhsCols == Cols || LhsCols == 1, ""); 161 static_assert(RhsCols == Cols || RhsCols == 1, ""); 162 static_assert(ResultBlockType::kRegisterLanes == 1, 163 "This path is only for scalar values"); 164 static_assert(Lhs::kRegisterLanes == 1, 165 "This path is only for scalar values"); 166 static_assert(Rhs::kRegisterLanes == 1, 167 "This path is only for scalar values"); 168 169 for (int c = 0; c < Cols; c++) { 170 const int lhs_c = LhsCols == Cols ? c : 0; 171 const int rhs_c = RhsCols == Cols ? c : 0; 172 for (int r = 0; r < Rows; r++) { 173 const int lhs_r = LhsRows == Rows ? r : 0; 174 const int rhs_r = RhsRows == Rows ? r : 0; 175 result.buf.reg[r + c * Rows] = 176 Add(lhs.buf.reg[lhs_r + lhs_c * LhsRows], 177 rhs.buf.reg[rhs_r + rhs_c * RhsRows]); 178 } 179 } 180 return result; 181 } 182 }; 183 184 template <typename Lhs, typename Rhs> 185 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd( 186 const Lhs& lhs, const Rhs& rhs) { 187 using Flip = FlipLhsRhs<Lhs, Rhs>; 188 return BroadcastAddImpl< 189 typename Flip::FlippedLhsType, 190 typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), 191 Flip::FlippedRhs(lhs, rhs)); 192 } 193 194 template <typename Lhs, typename Rhs> 195 struct BroadcastMulImpl { 196 using ResultBlockType = 197 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; 198 static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { 199 ResultBlockType result; 200 static constexpr int Rows = ResultBlockType::kRows; 201 static constexpr int Cols = ResultBlockType::kCols; 202 static constexpr int LhsRows = Lhs::kRows; 203 static constexpr int LhsCols = Lhs::kCols; 204 static constexpr int RhsRows = Rhs::kRows; 205 static constexpr int RhsCols = Rhs::kCols; 206 static_assert(ResultBlockType::kRegisterLanes == 1, 207 "This path is only for scalar values"); 208 static_assert(Lhs::kRegisterLanes == 1, 209 "This path is only for scalar values"); 210 static_assert(Rhs::kRegisterLanes == 1, 211 "This path is only for scalar values"); 212 213 static_assert(LhsRows == Rows || LhsRows == 1, ""); 214 static_assert(RhsRows == Rows || RhsRows == 1, ""); 215 static_assert(LhsCols == Cols || LhsCols == 1, ""); 216 static_assert(RhsCols == Cols || RhsCols == 1, ""); 217 for (int c = 0; c < Cols; c++) { 218 const int lhs_c = LhsCols == Cols ? c : 0; 219 const int rhs_c = RhsCols == Cols ? c : 0; 220 for (int r = 0; r < Rows; r++) { 221 const int lhs_r = LhsRows == Rows ? r : 0; 222 const int rhs_r = RhsRows == Rows ? r : 0; 223 result.buf.reg[r + c * Rows] = 224 Mul(lhs.buf.reg[lhs_r + lhs_c * LhsRows], 225 rhs.buf.reg[rhs_r + rhs_c * RhsRows]); 226 } 227 } 228 return result; 229 } 230 }; 231 232 template <typename Lhs, typename Rhs> 233 typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul( 234 const Lhs& lhs, const Rhs& rhs) { 235 using Flip = FlipLhsRhs<Lhs, Rhs>; 236 return BroadcastMulImpl< 237 typename Flip::FlippedLhsType, 238 typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), 239 Flip::FlippedRhs(lhs, rhs)); 240 } 241 242 template <typename Lhs, typename Rhs, typename Acc> 243 struct BroadcastMulAddImpl { 244 static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) { 245 static constexpr int Rows = Acc::kRows; 246 static constexpr int Cols = Acc::kCols; 247 static constexpr int LhsRows = Lhs::kRows; 248 static constexpr int LhsCols = Lhs::kCols; 249 static constexpr int RhsRows = Rhs::kRows; 250 static constexpr int RhsCols = Rhs::kCols; 251 static_assert(Acc::kRegisterLanes == 1, 252 "This path is only for scalar values"); 253 static_assert(Lhs::kRegisterLanes == 1, 254 "This path is only for scalar values"); 255 static_assert(Rhs::kRegisterLanes == 1, 256 "This path is only for scalar values"); 257 258 static_assert(LhsRows == Rows || LhsRows == 1, ""); 259 static_assert(RhsRows == Rows || RhsRows == 1, ""); 260 static_assert(LhsCols == Cols || LhsCols == 1, ""); 261 static_assert(RhsCols == Cols || RhsCols == 1, ""); 262 for (int c = 0; c < Cols; c++) { 263 const int lhs_c = LhsCols == Cols ? c : 0; 264 const int rhs_c = RhsCols == Cols ? c : 0; 265 for (int r = 0; r < Rows; r++) { 266 const int lhs_r = LhsRows == Rows ? r : 0; 267 const int rhs_r = RhsRows == Rows ? r : 0; 268 MulAdd(lhs.buf.reg[lhs_r + lhs_c * LhsRows], 269 rhs.buf.reg[rhs_r + rhs_c * RhsRows], 270 &acc->buf.reg[r + c * Rows]); 271 } 272 } 273 } 274 }; 275 276 template <typename Lhs, typename Rhs, typename Acc> 277 void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) { 278 using Flip = FlipLhsRhs<Lhs, Rhs>; 279 BroadcastMulAddImpl<typename Flip::FlippedLhsType, 280 typename Flip::FlippedRhsType, 281 Acc>::Run(Flip::FlippedLhs(lhs, rhs), 282 Flip::FlippedRhs(lhs, rhs), acc); 283 } 284 285 template <typename RegisterBlockType, typename SrcObjectType> 286 struct LoadImpl { 287 static_assert(std::is_same<SrcObjectType, void>::value, 288 "This generic impl should never be hit"); 289 }; 290 291 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType> 292 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, 293 MatrixMap<SrcScalarType, MapOrder::ColMajor>> { 294 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; 295 using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>; 296 static RegisterBlockType Run(const SrcObjectType& src, int row, int col) { 297 RegisterBlockType result; 298 int i = 0; 299 for (int c = 0; c < Cols; c++) { 300 const ScalarType* src_ptr = src.data(row, col + c); 301 for (int r = 0; r < Rows; r++) { 302 result.buf.reg[i++] = *src_ptr++; 303 } 304 } 305 return result; 306 } 307 }; 308 309 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, 310 VectorShape Shape> 311 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, 312 VectorMap<SrcScalarType, Shape>> { 313 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; 314 using SrcObjectType = VectorMap<SrcScalarType, Shape>; 315 static RegisterBlockType Run(const SrcObjectType& src, int pos) { 316 static_assert(Shape == VectorShape::Col || Rows == 1, ""); 317 static_assert(Shape == VectorShape::Row || Cols == 1, ""); 318 RegisterBlockType result; 319 for (int i = 0; i < Rows * Cols; i++) { 320 result.buf.reg[i] = src(pos + i); 321 } 322 return result; 323 } 324 }; 325 326 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, 327 VectorShape Shape> 328 struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, 329 VectorDup<SrcScalarType, Shape>> { 330 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; 331 using SrcObjectType = VectorDup<SrcScalarType, Shape>; 332 static RegisterBlockType Run(const SrcObjectType& src, int) { 333 static_assert(Shape == VectorShape::Col || Rows == 1, ""); 334 static_assert(Shape == VectorShape::Row || Cols == 1, ""); 335 RegisterBlockType result; 336 for (int i = 0; i < Rows * Cols; i++) { 337 result.buf.reg[i] = src(0); 338 } 339 return result; 340 } 341 }; 342 343 template <typename RegisterBlockType, typename SrcObjectType> 344 RegisterBlockType Load(const SrcObjectType& src, int row, int col) { 345 return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, row, col); 346 } 347 348 template <typename RegisterBlockType, typename SrcObjectType> 349 RegisterBlockType Load(const SrcObjectType& src, int pos) { 350 return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, pos); 351 } 352 353 template <typename RegisterBlockType> 354 struct LoadContiguousImpl { 355 using ScalarType = typename RegisterBlockType::ScalarType; 356 static_assert(RegisterBlockType::kRegisterLanes == 1, 357 "This path is only for scalar values"); 358 static RegisterBlockType Run(const ScalarType* src) { 359 RegisterBlockType result; 360 for (int i = 0; i < RegisterBlockType::kScalarCount; i++) { 361 result.buf.reg[i] = src[i]; 362 } 363 return result; 364 } 365 }; 366 367 template <typename RegisterBlockType> 368 RegisterBlockType LoadContiguous( 369 const typename RegisterBlockType::ScalarType* src) { 370 return LoadContiguousImpl<RegisterBlockType>::Run(src); 371 } 372 373 template <int BroadcastRows, int BroadcastCols, typename SrcObjectType> 374 struct LoadForBroadcastingShape {}; 375 376 template <int BroadcastRows, int BroadcastCols, typename ScalarType, 377 VectorShape Shape> 378 struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols, 379 VectorMap<ScalarType, Shape>> { 380 static constexpr int kRows = Shape == VectorShape::Col ? BroadcastRows : 1; 381 static constexpr int kCols = Shape == VectorShape::Row ? BroadcastCols : 1; 382 }; 383 384 template <int BroadcastRows, int BroadcastCols, typename ScalarType, 385 VectorShape Shape> 386 struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols, 387 VectorDup<ScalarType, Shape>> { 388 static constexpr int kRows = 1; 389 static constexpr int kCols = 1; 390 }; 391 392 template <typename RegisterBlockType, typename SrcObjectType> 393 struct LoadForBroadcastingRegisterBlock { 394 using Shape = 395 LoadForBroadcastingShape<RegisterBlockType::kRows, 396 RegisterBlockType::kCols, SrcObjectType>; 397 using ScalarType = typename RegisterBlockType::ScalarType; 398 using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>; 399 }; 400 401 template <typename RegisterBlockType, typename SrcObjectType> 402 struct LoadForBroadcastingImpl { 403 static_assert(std::is_same<SrcObjectType, void>::value, 404 "This generic impl should never be hit"); 405 }; 406 407 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, 408 VectorShape Shape> 409 struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>, 410 VectorMap<SrcScalarType, Shape>> { 411 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; 412 using SrcObjectType = VectorMap<SrcScalarType, Shape>; 413 using ResultBlockType = 414 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 415 SrcObjectType>::Type; 416 static_assert(ResultBlockType::kRegisterLanes == 1, 417 "This path is only for scalar values"); 418 static ResultBlockType Run(const SrcObjectType& src, int pos) { 419 ResultBlockType result; 420 for (int c = 0; c < ResultBlockType::kCols; c++) { 421 for (int r = 0; r < ResultBlockType::kRows; r++) { 422 const int i = Shape == VectorShape::Col ? r : c; 423 result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i); 424 } 425 } 426 return result; 427 } 428 }; 429 430 template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, 431 VectorShape Shape> 432 struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>, 433 VectorDup<SrcScalarType, Shape>> { 434 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; 435 using SrcObjectType = VectorDup<SrcScalarType, Shape>; 436 using ResultBlockType = 437 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 438 SrcObjectType>::Type; 439 static_assert(ResultBlockType::kRegisterLanes == 1, 440 "This path is only for scalar values"); 441 static ResultBlockType Run(const SrcObjectType& src, int) { 442 ResultBlockType result; 443 for (int c = 0; c < ResultBlockType::kCols; c++) { 444 for (int r = 0; r < ResultBlockType::kRows; r++) { 445 result.buf.reg[r + c * ResultBlockType::kRows] = src(0); 446 } 447 } 448 return result; 449 } 450 }; 451 452 template <typename RegisterBlockType, typename SrcObjectType> 453 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 454 SrcObjectType>::Type 455 LoadForBroadcasting(const SrcObjectType& src, int row, int col) { 456 return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run( 457 src, row, col); 458 } 459 460 template <typename RegisterBlockType, typename SrcObjectType> 461 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 462 SrcObjectType>::Type 463 LoadForBroadcasting(const SrcObjectType& src, int pos) { 464 return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src, 465 pos); 466 } 467 468 template <int ConstantValue, typename RegisterBlockType> 469 struct AddConstantImpl { 470 static void Run(RegisterBlockType* block) { 471 using RegisterType = typename RegisterBlockType::RegisterType; 472 const RegisterType dup = Dup<RegisterType>(ConstantValue); 473 for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) { 474 block->buf.reg[i] = Add(block->buf.reg[i], dup); 475 } 476 } 477 }; 478 479 template <typename RegisterBlockType> 480 struct AddConstantImpl<0, RegisterBlockType> { 481 static void Run(RegisterBlockType*) { 482 // This is a no-op. 483 } 484 }; 485 486 template <int ConstantValue, typename RegisterBlockType> 487 void AddConstant(RegisterBlockType* block) { 488 AddConstantImpl<ConstantValue, RegisterBlockType>::Run(block); 489 } 490 491 template <int N> 492 using RegBufferInt32 = RegisterBuffer<std::int32_t, N>; 493 template <int N> 494 using RegBufferInt16 = RegisterBuffer<std::int16_t, N>; 495 template <int N> 496 using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>; 497 template <int R, int C> 498 using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>; 499 template <int R, int C> 500 using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>; 501 template <int R, int C> 502 using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>; 503 504 } // end namespace gemmlowp 505 506 #if defined GEMMLOWP_NEON 507 #include "simd_wrappers_neon.h" 508 #elif defined GEMMLOWP_SSE4 509 #include "simd_wrappers_sse.h" 510 #elif defined GEMMLOWP_MSA 511 #include "simd_wrappers_msa.h" 512 #endif 513 514 #endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ 515