1 // Copyright 2016 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // This is a standalone testbed and benchmark for gemmlowp-style GEMM kernels, 16 // either doing integer or float arithmetic. 17 // It verifies that a kernel produces correct results, then benchmarks it. 18 // 19 // Some benchmark results are recorded in this spreadsheet: 20 // 21 // https://docs.google.com/spreadsheets/d/1UPbzbp9rdsD6RXxOr5q6AZ0n1omgEknLYO2ogiw6Kqk/edit?usp=sharing 22 // 23 // This program is entirely self-contained, and can be compiled manually 24 // such as suggested in the command lines below. 25 // It currently supports only Android/ARM but would trivially generalize to 26 // other OSes (it's mostly standard POSIX) or architectures (each kernel 27 // targets a specific architecture, one may simply add more). 28 29 /* 30 Build and run this benchmark on Android/ARM/32bit: 31 ~/android/toolchains/arm-linux-androideabi/bin/arm-linux-androideabi-clang++ \ 32 -fPIE -pie -O3 --std=c++11 standalone/neon-gemm-kernel-benchmark.cc -o \ 33 /tmp/benchmark -mfloat-abi=softfp -mfpu=neon-vfpv4 && adb push /tmp/benchmark \ 34 /data/local/tmp && adb shell /data/local/tmp/benchmark 35 Build and run this benchmark on Android/ARM/64bit: 36 ~/android/toolchains/aarch64-linux-android/bin/aarch64-linux-android-clang++ \ 37 -fPIE -static -O3 --std=c++11 standalone/neon-gemm-kernel-benchmark.cc -o \ 38 /tmp/benchmark && adb push /tmp/benchmark /data/local/tmp && adb shell \ 39 /data/local/tmp/benchmark 40 */ 41 42 // For big.LITTLE devices, use 'taskset' to select which cores to benchmark. 43 // 44 // The syntax is: taskset <mask> <commandline> 45 // where mask is a binary mask where each bit corresponds to a core, 46 // and low bits are little cores. 47 // 48 // Examples: 49 // Nexus 5X big cores: taskset 30 50 // Nexus 5X little cores: taskset 0f 51 // Pixel XL big cores: taskset 0c 52 // Pixel XL little cores: taskset 03 53 // 54 // Full example: 55 // adb shell taskset 0c /data/local/tmp/benchmark 56 57 #include <sched.h> 58 #include <unistd.h> 59 60 #include <algorithm> 61 #include <cassert> 62 #include <cstdint> 63 #include <cstdlib> 64 #include <cstring> 65 #include <iostream> 66 #include <random> 67 #include <type_traits> 68 69 #if !defined(__arm__) && !defined(__aarch64__) && \ 70 !(defined(__mips) && (__mips_isa_rev >= 5) && defined(__mips_msa)) 71 #error This benchmark assumes ARM or MIPS (for intrinsics and inline assembly sections). 72 #endif 73 74 #if defined(__arm__) || defined(__aarch64__) 75 #include <arm_neon.h> 76 #endif 77 78 #if defined(__mips) 79 #include <msa.h> 80 81 // Some convenience macros to hide differences between MIPS32 and MIPS64. 82 #ifdef __LP64__ 83 #define GEMMLOWP_MIPS_XADDIU "daddiu" 84 #else 85 #define GEMMLOWP_MIPS_XADDIU "addiu" 86 #endif 87 #endif 88 89 // Typically one wants to fit in L1 cache, and GEMM implementations 90 // are carefully optimized to tune their access patterns to that effect. 91 // Most devices have at least 16k of L1 cache. The Kraits have exactly 16k. 92 const int kDefaultCacheSizeK = 16; 93 94 const int kCacheLineSize = 64; 95 96 // These definitions are used for labels within assembly code. Required for 97 // iOS toolchain compatibility. 98 #define GEMMLOWP_LABEL_AFTER_LOOP "1" 99 #define GEMMLOWP_LABEL_LOOP "2" 100 #define GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES "3" 101 #define GEMMLOWP_LABEL_STORE "4" 102 103 // BEGIN code copied from gemmlowp/internal/kernel.h 104 105 // Explanation of general gemmlowp terminology 106 // =========================================== 107 // 108 // We use the following abbreviations: 109 // LHS = "left-hand side" 110 // RHS = "right-hand side" 111 // Sometimes when referring to either LHS or RHS, we just say a "Side". 112 // 113 // In a matrix product of a MxK matrix times a KxN matrix, 114 // we call K the 'depth'. Note that M is the number of rows 115 // of the result (and of the LHS), and N is the number of columns 116 // of the result (and of the RHS). 117 // 118 // In each of the LHS and RHS matrices, we call 'width' the 119 // other dimension, besides the depth. So in the LHS, 'width' 120 // is the number of rows, while in the RHS, 'width' is the number 121 // of columns. 122 // 123 // So in the LHS MxK matrix, the depth is K and the width in M. 124 // And in the RHS KxN matrix, the depth is K and the width in N. 125 // 126 // This is illustrated in this picture: 127 // 128 // RHS width 129 // <-----------------> 130 // +-----------------+ ^ 131 // | RHS | | Depth 132 // +-----------------+ v 133 // ^ +--+ +-----------------+ 134 // | |L | | | 135 // LHS width | |H | | Result | 136 // | |S | | | 137 // v +--+ +-----------------+ 138 // <--> 139 // Depth 140 141 // Explanation of gemmlowp kernel formats and "cells" 142 // ================================================== 143 // 144 // Kernels operate on small LHS and RHS blocks that fit in registers. 145 // These blocks are stored contiguously in memory, but not always 146 // in a traditional column-major or row-major order; instead, 147 // they consist of a number of sub-blocks, which we call "cells", 148 // that are stored in column-major or row-major order. However, 149 // what really matters to us is not so much rows vs columns, but 150 // rather width vs depth. So we refer to "width-major" and "depth-major" 151 // storage orders. In the LHS, width-major means row-major, 152 // while in the RHS, width-major means column-major. 153 // There is also a third possibility, "diagonal order", 154 // which is unused at the moment. 155 // 156 // We aim to treat both sides, LHS and RHS, on an equal footing, 157 // so we call them both 'sides'. A KernelFormat thus is just a pair 158 // of KernelSideFormat's, one for LHS and one for RHS; each KernelSideFormat 159 // contains a CellFormat and a number of cells; cells are only ever 160 // stacked in the width dimension, which means stacked vertically in the 161 // LHS and stacked horizondally in the RHS. 162 // 163 // Example 164 // ======= 165 // 166 // Let's work out the data layout expected by a kernel having the 167 // following format (the struct names here are defined below in this file): 168 // 169 // KernelFormat< 170 // KernelSideFormat<CellFormat<3, 4>, 3>, 171 // KernelSideFormat<CellFormat<5, 4>, 2> 172 // > 173 // 174 // The LHS format, KernelSideFormat<CellFormat<3, 4>, 3>, means: 175 // 3 cells, each cell having dimensions (width=3, depth=4), laid out in 176 // DepthMajor order (the default value, see CellFormat). In the LHS, 177 // DepthMajor means column-major, so the LHS cells are of size 3x4 in 178 // column-major order, so the LHS layout is: 179 // 180 // 0 3 6 9 181 // 1 4 7 10 182 // 2 5 8 11 183 // 12 15 18 21 184 // 13 16 19 22 185 // 14 17 20 23 186 // 24 27 30 33 187 // 25 28 31 34 188 // 26 29 32 35 189 // 190 // The RHS format, KernelSideFormat<CellFormat<5, 4>, 2>, means: 191 // 2 cells each having dimensions (width=5, depth=4), laid out in 192 // DepthMajor order (the default value, see CellFormat). In the RHS, 193 // DepthMajor means row-major, so the RHS cells are of size 4x5 in 194 // row-major order, so the RHS layout is: 195 // 196 // 0 1 2 3 4 20 21 22 23 24 197 // 5 6 7 8 9 25 26 27 28 29 198 // 10 11 12 13 14 30 31 32 33 34 199 // 15 16 17 18 19 35 36 37 38 39 200 201 // CellOrder enumerates the possible storage orders (=layouts) for 202 // a cell (see explanation above). 203 enum class CellOrder { DepthMajor, WidthMajor, Diagonal }; 204 205 // CellFormat describes how data is laid 206 // out in a cell. That is, a CellOrder together with actual dimensions. 207 template <int tWidth, int tDepth, CellOrder tOrder> 208 struct CellFormat { 209 static const int kWidth = tWidth; 210 static const int kDepth = tDepth; 211 static const CellOrder kOrder = tOrder; 212 213 static const int kSize = kWidth * kDepth; 214 }; 215 216 // KernelSideFormat describes how data is laid out in a kernel side 217 // (i.e. LHS or RHS). That is, a CellFormat together with a number of 218 // cells. These cells are always stacked in the Width dimension. 219 // For example, in the LHS case, the Width dimension is the rows dimension, 220 // se we're saying that in the LHS, cells are stacked vertically. 221 // We never stack cells in the Depth dimension. 222 template <typename tCellFormat, int tCells> 223 struct KernelSideFormat { 224 typedef tCellFormat Cell; 225 static const int kCells = tCells; 226 static const int kWidth = kCells * Cell::kWidth; 227 static const int kDepth = Cell::kDepth; 228 }; 229 230 // KernelFormat describes fully the input data layout that a kernel expects. 231 // It consists of two KernelSideFormat's, one for LHS and one for RHS. 232 template <typename tLhs, typename tRhs> 233 struct KernelFormat { 234 typedef tLhs Lhs; 235 typedef tRhs Rhs; 236 237 static_assert(Lhs::Cell::kDepth == Rhs::Cell::kDepth, ""); 238 static const int kDepth = Lhs::Cell::kDepth; 239 static const int kRows = Lhs::Cell::kWidth * Lhs::kCells; 240 static const int kCols = Rhs::Cell::kWidth * Rhs::kCells; 241 }; 242 243 inline const char* CellOrderName(CellOrder o) { 244 switch (o) { 245 case CellOrder::DepthMajor: 246 return "DepthMajor"; 247 case CellOrder::WidthMajor: 248 return "WidthMajor"; 249 case CellOrder::Diagonal: 250 return "Diagonal"; 251 default: 252 assert(false); 253 return nullptr; 254 } 255 } 256 257 // Returns the offset into a cell, at which a given coefficient is stored. 258 template <typename CellFormat> 259 inline int OffsetIntoCell(int w, int d) { 260 switch (CellFormat::kOrder) { 261 case CellOrder::DepthMajor: 262 return w + d * CellFormat::kWidth; 263 case CellOrder::WidthMajor: 264 return d + w * CellFormat::kDepth; 265 case CellOrder::Diagonal: 266 assert(CellFormat::kWidth == CellFormat::kDepth); 267 static const int size = CellFormat::kWidth; 268 return ((size + w - d) * size + d) % (size * size); 269 default: 270 assert(false); 271 return 0; 272 } 273 } 274 275 // END code copied from gemmlowp/internal/kernel.h 276 277 #ifdef __arm__ 278 279 // This is the current standard kernel in gemmlowp, see: 280 // https://github.com/google/gemmlowp/blob/b1e2a29ff866680028f3080efc244e10e8dd7f46/internal/kernel_neon.h#L33 281 struct NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators { 282 typedef std::uint8_t OperandType; 283 typedef std::uint32_t AccumulatorType; 284 typedef KernelFormat< 285 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, 286 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> > 287 Format; 288 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 289 AccumulatorType* accum_ptr, int depth) { 290 asm volatile( 291 // Load 1 Rhs cell of size 2x4 292 "vld1.8 {d0}, [%[rhs_ptr]]!\n" 293 // Load 3 Lhs cells of size 4x2 each 294 "vld1.8 {d2}, [%[lhs_ptr]]!\n" 295 "vld1.8 {d4}, [%[lhs_ptr]]!\n" 296 "vld1.8 {d6}, [%[lhs_ptr]]!\n" 297 // Load accumulators 298 "mov r0, %[accum_ptr]\n" 299 "vld1.32 {d8, d9}, [r0]!\n" 300 "vld1.32 {d16, d17}, [r0]!\n" 301 "vld1.32 {d24, d25}, [r0]!\n" 302 "vld1.32 {d10, d11}, [r0]!\n" 303 "vld1.32 {d18, d19}, [r0]!\n" 304 "vld1.32 {d26, d27}, [r0]!\n" 305 "vld1.32 {d12, d13}, [r0]!\n" 306 "vld1.32 {d20, d21}, [r0]!\n" 307 "vld1.32 {d28, d29}, [r0]!\n" 308 "vld1.32 {d14, d15}, [r0]!\n" 309 "vld1.32 {d22, d23}, [r0]!\n" 310 "vld1.32 {d30, d31}, [r0]!\n" 311 312 "subs %[depth], #2\n" 313 314 "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n" 315 316 GEMMLOWP_LABEL_LOOP 317 ":\n" 318 // Overview of register layout: 319 // 320 // A 2x4 cell of Rhs is stored in 16bit in d0--d1 (q0). 321 // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in d2--d7 322 // (q1--q3). 323 // A 12x4 block of accumulators is stored in 32bit in q4--q15. 324 // 325 // +-----+-----+-----+-----+ 326 // |d0[0]|d0[1]|d0[2]|d0[3]| 327 // Rhs +-----+-----+-----+-----+ 328 // |d1[0]|d1[1]|d1[2]|d1[3]| 329 // +-----+-----+-----+-----+ 330 // 331 // | | | | | 332 // 333 // Lhs | | | | | 334 // 335 // +--+--+ - - - - +-----+-----+-----+-----+ 336 // |d2|d3| | q4 | q5 | q6 | q7 | 337 // |d2|d3| | q4 | q5 | q6 | q7 | 338 // |d2|d3| | q4 | q5 | q6 | q7 | 339 // |d2|d3| | q4 | q5 | q6 | q7 | 340 // +--+--+ - - - - +-----+-----+-----+-----+ 341 // |d4|d5| | q8 | q9 | q10 | q11 | 342 // |d4|d5| | q8 | q9 | q10 | q11 | 343 // |d4|d5| | q8 | q9 | q10 | q11 | 344 // |d4|d5| | q8 | q9 | q10 | q11 | 345 // +--+--+ - - - - +-----+-----+-----+-----+ 346 // |d6|d7| | q12 | q13 | q14 | q15 | 347 // |d6|d7| | q12 | q13 | q14 | q15 | 348 // |d6|d7| | q12 | q13 | q14 | q15 | 349 // |d6|d7| | q12 | q13 | q14 | q15 | 350 // +--+--+ - - - - +-----+-----+-----+-----+ 351 // 352 // Accumulator 353 354 // Expand Lhs/Rhs cells to 16 bit. 355 // Note: moving theses vmovls further down to allow for 356 // longer data pipelining helps a little on A57 but is 357 // harmful on A53 --- It looks as if A53 doesn't like 358 // interleaving vmovl's into the vmlal's. 359 "vmovl.u8 q0, d0\n" 360 "vmovl.u8 q1, d2\n" 361 "vmovl.u8 q2, d4\n" 362 "vmovl.u8 q3, d6\n" 363 364 // Multiply-accumulate, level of depth 0 365 "vmlal.u16 q4, d2, d0[0]\n" 366 "vmlal.u16 q5, d2, d0[1]\n" 367 "vmlal.u16 q6, d2, d0[2]\n" 368 "vmlal.u16 q7, d2, d0[3]\n" 369 "vldr d2, [%[lhs_ptr]]\n" 370 "vmlal.u16 q8, d4, d0[0]\n" 371 "vmlal.u16 q9, d4, d0[1]\n" 372 "vmlal.u16 q10, d4, d0[2]\n" 373 "vmlal.u16 q11, d4, d0[3]\n" 374 "vldr d4, [%[lhs_ptr], #8]\n" 375 "vmlal.u16 q12, d6, d0[0]\n" 376 "vmlal.u16 q13, d6, d0[1]\n" 377 "vmlal.u16 q14, d6, d0[2]\n" 378 "vmlal.u16 q15, d6, d0[3]\n" 379 "vldr d6, [%[lhs_ptr], #16]\n" 380 "vldr d0, [%[rhs_ptr]]\n" 381 382 // Multiply-accumulate, level of depth 1 383 "vmlal.u16 q4, d3, d1[0]\n" 384 "vmlal.u16 q5, d3, d1[1]\n" 385 "add %[lhs_ptr], #24\n" 386 "vmlal.u16 q6, d3, d1[2]\n" 387 "vmlal.u16 q7, d3, d1[3]\n" 388 "add %[rhs_ptr], #8\n" 389 "vmlal.u16 q8, d5, d1[0]\n" 390 "vmlal.u16 q9, d5, d1[1]\n" 391 "subs %[depth], #2\n" 392 "vmlal.u16 q10, d5, d1[2]\n" 393 "vmlal.u16 q11, d5, d1[3]\n" 394 "vmlal.u16 q12, d7, d1[0]\n" 395 "vmlal.u16 q13, d7, d1[1]\n" 396 "vmlal.u16 q14, d7, d1[2]\n" 397 "vmlal.u16 q15, d7, d1[3]\n" 398 399 "bne " GEMMLOWP_LABEL_LOOP "b\n" 400 401 GEMMLOWP_LABEL_AFTER_LOOP 402 ":\n" 403 404 // Expand Lhs/Rhs cells to 16 bit. 405 "vmovl.u8 q0, d0\n" 406 "vmovl.u8 q1, d2\n" 407 "vmovl.u8 q2, d4\n" 408 "vmovl.u8 q3, d6\n" 409 410 // Multiply-accumulate, level of depth 0 411 "vmlal.u16 q4, d2, d0[0]\n" 412 "vmlal.u16 q5, d2, d0[1]\n" 413 "vmlal.u16 q6, d2, d0[2]\n" 414 "vmlal.u16 q7, d2, d0[3]\n" 415 "vmlal.u16 q8, d4, d0[0]\n" 416 "vmlal.u16 q9, d4, d0[1]\n" 417 "vmlal.u16 q10, d4, d0[2]\n" 418 "vmlal.u16 q11, d4, d0[3]\n" 419 "vmlal.u16 q12, d6, d0[0]\n" 420 "vmlal.u16 q13, d6, d0[1]\n" 421 "vmlal.u16 q14, d6, d0[2]\n" 422 "vmlal.u16 q15, d6, d0[3]\n" 423 424 // Multiply-accumulate, level of depth 1 425 "vmlal.u16 q4, d3, d1[0]\n" 426 "vmlal.u16 q5, d3, d1[1]\n" 427 "vmlal.u16 q6, d3, d1[2]\n" 428 "vmlal.u16 q7, d3, d1[3]\n" 429 "vmlal.u16 q8, d5, d1[0]\n" 430 "vmlal.u16 q9, d5, d1[1]\n" 431 "vmlal.u16 q10, d5, d1[2]\n" 432 "vmlal.u16 q11, d5, d1[3]\n" 433 "vmlal.u16 q12, d7, d1[0]\n" 434 "vmlal.u16 q13, d7, d1[1]\n" 435 "vmlal.u16 q14, d7, d1[2]\n" 436 "vmlal.u16 q15, d7, d1[3]\n" 437 438 // Store accumulators 439 "mov r0, %[accum_ptr]\n" 440 "vst1.32 {d8, d9}, [r0]!\n" 441 "vst1.32 {d16, d17}, [r0]!\n" 442 "vst1.32 {d24, d25}, [r0]!\n" 443 "vst1.32 {d10, d11}, [r0]!\n" 444 "vst1.32 {d18, d19}, [r0]!\n" 445 "vst1.32 {d26, d27}, [r0]!\n" 446 "vst1.32 {d12, d13}, [r0]!\n" 447 "vst1.32 {d20, d21}, [r0]!\n" 448 "vst1.32 {d28, d29}, [r0]!\n" 449 "vst1.32 {d14, d15}, [r0]!\n" 450 "vst1.32 {d22, d23}, [r0]!\n" 451 "vst1.32 {d30, d31}, [r0]!\n" 452 : // outputs 453 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 454 [depth] "+r"(depth) 455 : // inputs 456 [accum_ptr] "r"(accum_ptr) 457 : // clobbers 458 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", 459 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", 460 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", 461 "d28", "d29", "d30", "d31"); 462 } 463 }; 464 465 // This is Maciek Chociej's fast kernel not expanding operands, 466 // from gemmlowp/meta/. Search for 467 // mul_3x8_3x8_int32_lhsadd_rhsadd 468 // in this file: 469 // https://raw.githubusercontent.com/google/gemmlowp/e4b9d858b6637d5d0058bfa3d869d2b95864251b/meta/single_thread_gemm.h 470 struct NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand { 471 typedef std::uint8_t OperandType; 472 typedef std::uint32_t AccumulatorType; 473 typedef KernelFormat< 474 KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1>, 475 KernelSideFormat<CellFormat<3, 8, CellOrder::WidthMajor>, 1> > 476 Format; 477 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 478 AccumulatorType* accum_ptr, int depth) { 479 asm volatile( 480 // Clear aggregators. 481 "vmov.i32 q0, #0\n" 482 "vmov.i32 q1, #0\n" 483 "vmov.i32 q2, #0\n" 484 "vmov.i32 q3, q0\n" 485 "vmov.i32 q4, q1\n" 486 "vmov.i32 q5, q2\n" 487 "vmov.i32 q6, q3\n" 488 "vmov.i32 q7, q4\n" 489 "vmov.i32 q8, q5\n" 490 491 // Loop head 492 GEMMLOWP_LABEL_LOOP 493 ":\n" 494 495 // Subtract counter. 496 "subs %[depth], %[depth], #8\n" 497 498 "vld1.8 {d18, d19, d20}, [%[rhs_ptr]]!\n" 499 "vld1.8 {d21, d22, d23}, [%[lhs_ptr]]!\n" 500 "vmull.u8 q12, d18, d21\n" 501 "vmull.u8 q13, d18, d22\n" 502 "vmull.u8 q14, d18, d23\n" 503 "vmull.u8 q15, d19, d21\n" 504 "vpadal.u16 q0, q12\n" 505 "vpadal.u16 q1, q13\n" 506 "vpadal.u16 q2, q14\n" 507 "vpadal.u16 q3, q15\n" 508 "vmull.u8 q12, d19, d22\n" 509 "vmull.u8 q13, d19, d23\n" 510 "vmull.u8 q14, d20, d21\n" 511 "vmull.u8 q15, d20, d22\n" 512 "vmull.u8 q9, d20, d23\n" 513 "vpadal.u16 q4, q12\n" 514 "vpadal.u16 q5, q13\n" 515 "vpadal.u16 q6, q14\n" 516 "vpadal.u16 q7, q15\n" 517 "vpadal.u16 q8, q9\n" 518 519 // Loop branch 520 "bne " GEMMLOWP_LABEL_LOOP 521 "b\n" 522 523 // Horizontal reduce aggregators, step 1 524 "vpadd.u32 d0, d0, d1\n" 525 "vpadd.u32 d2, d2, d3\n" 526 "vpadd.u32 d4, d4, d5\n" 527 "vpadd.u32 d6, d6, d7\n" 528 "vpadd.u32 d8, d8, d9\n" 529 "vpadd.u32 d10, d10, d11\n" 530 "vpadd.u32 d12, d12, d13\n" 531 "vpadd.u32 d14, d14, d15\n" 532 "vpadd.u32 d16, d16, d17\n" 533 534 // Horizontal reduce aggregators, step 2 535 "vpadd.u32 d0, d0, d2\n" 536 "vpadd.u32 d1, d4, d4\n" 537 "vpadd.u32 d6, d6, d8\n" 538 "vpadd.u32 d7, d10, d10\n" 539 "vpadd.u32 d12, d12, d14\n" 540 "vpadd.u32 d13, d16, d16\n" 541 542 // Load accumulators 543 "mov r0, %[accum_ptr]\n" 544 "vld1.32 {d2}, [r0]!\n" 545 "vld1.32 {d3[0]}, [r0]!\n" 546 547 "vld1.32 {d8}, [r0]!\n" 548 "vld1.32 {d9[0]}, [r0]!\n" 549 550 "vld1.32 {d14}, [r0]!\n" 551 "vld1.32 {d15[0]}, [r0]!\n" 552 553 // Accumulate 554 "vadd.s32 q0, q0, q1\n" 555 "vadd.s32 q3, q3, q4\n" 556 "vadd.s32 q6, q6, q7\n" 557 558 // Store accumulators 559 "mov r0, %[accum_ptr]\n" 560 "vst1.32 {d0}, [r0]!\n" 561 "vst1.32 {d1[0]}, [r0]!\n" 562 563 "vst1.32 {d6}, [r0]!\n" 564 "vst1.32 {d7[0]}, [r0]!\n" 565 566 "vst1.32 {d12}, [r0]!\n" 567 "vst1.32 {d13[0]}, [r0]!\n" 568 : // outputs 569 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 570 [depth] "+r"(depth) 571 : // inputs 572 [accum_ptr] "r"(accum_ptr) 573 : // clobbers 574 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", 575 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", 576 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", 577 "d28", "d29", "d30", "d31"); 578 } 579 }; 580 581 // Fast kernel operating on int8 operands. 582 // It is assumed that one of the two int8 operands only takes values 583 // in [-127, 127], while the other may freely range in [-128, 127]. 584 // The issue with both operands taking the value -128 is that: 585 // -128*-128 + -128*-128 == -32768 overflows int16. 586 // Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16 587 // range. That is the basic idea of this kernel. 588 struct NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits { 589 typedef std::int8_t OperandType; 590 typedef std::int32_t AccumulatorType; 591 typedef KernelFormat< 592 KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>, 593 KernelSideFormat<CellFormat<2, 16, CellOrder::WidthMajor>, 1> > 594 Format; 595 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 596 AccumulatorType* accum_ptr, int depth) { 597 std::size_t start_depth = 123; 598 std::size_t run_depth = depth; 599 std::size_t dst_col_stride = 4; 600 AccumulatorType* dst_ptr = accum_ptr; 601 asm volatile( 602 603 // Overview of register layout: 604 // 605 // A 2x16 block of Rhs is stored in 8 bit in d0--d3. 606 // A 4x16 block of Lhs is stored in 8 bit in d4--d7. That is only 607 // half of the register space required, so we loop over these registers 608 // twice. Only half of it, a 2x16 block, is stored in d4--d7 at 609 // any given time. 610 // 611 // A 4x2 block of accumulators is stored in q8--q15 (as 4x32 bit 612 // components which need to be horizontally-added at the end) 613 // 614 // The Lhs vectors are multiplied by the Rhs vectors with a widening 615 // multiply over the 8 first levels of depth, producing int16x8 616 // vectors of products for each position in the accumulator matrix. 617 // Here comes the special trick: since the operands are signed int8, 618 // their range being [ -2^7 , 2^7 ), their products are in range 619 // [ -2^14 , 2^14 - 1 ), meaning that we can add two such values 620 // without any risk of overflowing int16. 621 // We thus proceed with the 8 next levels of depth, multiplying 622 // again Lhs by Rhs, accumulating into this existing int16x8 vector. 623 // 624 // Only then, having processed 16 levels of depth, do we need to 625 // horizontally add these int16x8 accumulators into the final 626 // int32x4 accumulators. 627 // 628 // As we do not have enough registers to store all 16 int16x8 629 // temporary-16bit-accumulators, we have them cycle through q4--q7. 630 // 631 // 632 // Register layout (ignoring the q4--q7 temporary 16bit accumulators): 633 // 634 // +----+----+ 635 // | d0 | d2 | 636 // | . | . | 637 // | . | . | 638 // | . | . | 639 // Rhs +----+----+ 640 // | d1 | d3 | 641 // | . | . | 642 // | . | . | 643 // | . | . | 644 // +----+----+ 645 // 646 // | | | 647 // 648 // Lhs | | | 649 // 650 // +--------+--------+ - - - - +----+----+ 651 // | d4 ... | d5 ... | | q8 | q9 | 652 // | d6 ... | d7 ... | | q10| q11| 653 // | d4 ... | d5 ... | | q12| q13| 654 // | d6 ... | d7 ... | | q14| q15| 655 // +--------+--------+ - - - - +----+----+ 656 // 657 // Accumulator 658 // 659 660 // Clear accumulators, and, interleaved with it, 661 // initial loads of the first loop iteration, 662 // taken out of the loop so that in the loop itself we have 663 // optimal streaming of data from memory. 664 "vldr d0, [%[rhs_ptr], #0]\n" 665 "vmov.i32 q8, #0\n" 666 "vldr d4, [%[lhs_ptr], #0]\n" 667 "vmov.i32 q9, #0\n" 668 "vldr d2, [%[rhs_ptr], #16]\n" 669 "vmov.i32 q10, q8\n" 670 "vldr d6, [%[lhs_ptr], #16]\n" 671 "vmov.i32 q11, q8\n" 672 "vldr d1, [%[rhs_ptr], #8]\n" 673 "vmov.i32 q12, q8\n" 674 "vldr d5, [%[lhs_ptr], #8]\n" 675 "vmov.i32 q13, q8\n" 676 "vldr d3, [%[rhs_ptr], #24]\n" 677 "vmov.i32 q14, q8\n" 678 "vldr d7, [%[lhs_ptr], #24]\n" 679 "vmov.i32 q15, q8\n" 680 681 // General loop. 682 GEMMLOWP_LABEL_LOOP 683 ":\n" 684 685 // Multiply 8 first levels of depth. 686 "vmull.s8 q4, d0, d4\n" 687 "add %[rhs_ptr], %[rhs_ptr], #32\n" 688 "vmull.s8 q5, d2, d4\n" 689 "vldr d4, [%[lhs_ptr], #32]\n" 690 "vmull.s8 q6, d0, d6\n" 691 "vmull.s8 q7, d2, d6\n" 692 "vldr d6, [%[lhs_ptr], #48]\n" 693 694 // Multiply-accumulate second-half, again into the same 695 // 16bit local accumulator registers. This is where we 696 // take advantage of having int8 instead of uint8 and therefore 697 // being able to accumulate two products into int16. 698 "vmlal.s8 q4, d1, d5\n" 699 "vmlal.s8 q5, d3, d5\n" 700 "vldr d5, [%[lhs_ptr], #40]\n" 701 "vmlal.s8 q6, d1, d7\n" 702 "vmlal.s8 q7, d3, d7\n" 703 "vldr d7, [%[lhs_ptr], #56]\n" 704 705 // Add pairwise, accumulate into 32-bit accumulators. 706 "vpadal.s16 q8, q4\n" 707 "add %[lhs_ptr], %[lhs_ptr], #64\n" 708 "vpadal.s16 q9, q5\n" 709 "subs %[run_depth], %[run_depth], #16\n" 710 "vpadal.s16 q10, q6\n" 711 "vpadal.s16 q11, q7\n" 712 713 "beq " GEMMLOWP_LABEL_AFTER_LOOP 714 "f\n" 715 716 // Multiply first half. 717 "vmull.s8 q4, d0, d4\n" 718 "vmull.s8 q5, d2, d4\n" 719 "vldr d4, [%[lhs_ptr], #0]\n" 720 "vmull.s8 q6, d0, d6\n" 721 "vldr d0, [%[rhs_ptr], #0]\n" 722 "vmull.s8 q7, d2, d6\n" 723 "vldr d2, [%[rhs_ptr], #16]\n" 724 725 // Multiply-accumulate second-half, again into the same 726 // 16bit local accumulator registers. This is where we 727 // take advantage of having int8 instead of uint8 and therefore 728 // being able to accumulate two products into int16. 729 "vmlal.s8 q4, d1, d5\n" 730 "vldr d6, [%[lhs_ptr], #16]\n" 731 "vmlal.s8 q5, d3, d5\n" 732 "vldr d5, [%[lhs_ptr], #8]\n" 733 "vmlal.s8 q6, d1, d7\n" 734 "vldr d1, [%[rhs_ptr], #8]\n" 735 "vmlal.s8 q7, d3, d7\n" 736 "vldr d3, [%[rhs_ptr], #24]\n" 737 738 // Add pairwise, accumulate into 32-bit accumulators. 739 "vpadal.s16 q12, q4\n" 740 "vldr d7, [%[lhs_ptr], #24]\n" 741 "vpadal.s16 q13, q5\n" 742 "vpadal.s16 q14, q6\n" 743 "vpadal.s16 q15, q7\n" 744 745 "b " GEMMLOWP_LABEL_LOOP "b\n" 746 747 GEMMLOWP_LABEL_AFTER_LOOP 748 ":\n" 749 750 // Multiply first half. 751 "vmull.s8 q4, d0, d4\n" 752 "vmull.s8 q5, d2, d4\n" 753 "vmull.s8 q6, d0, d6\n" 754 "vmull.s8 q7, d2, d6\n" 755 756 // Multiply-accumulate second-half, again into the same 757 // 16bit local accumulator registers. This is where we 758 // take advantage of having int8 instead of uint8 and therefore 759 // being able to accumulate two products into int16. 760 "vmlal.s8 q4, d1, d5\n" 761 "vmlal.s8 q5, d3, d5\n" 762 "vmlal.s8 q6, d1, d7\n" 763 "vmlal.s8 q7, d3, d7\n" 764 765 // Add pairwise, accumulate into 32-bit accumulators. 766 "vpadal.s16 q12, q4\n" 767 "vpadal.s16 q13, q5\n" 768 "vpadal.s16 q14, q6\n" 769 "vpadal.s16 q15, q7\n" 770 "cmp %[start_depth], #0\n" 771 772 // Reduce 32bit accumulators horizontally. 773 "vpadd.s32 d0, d16, d17\n" 774 "vpadd.s32 d1, d18, d19\n" 775 "vpadd.s32 d2, d20, d21\n" 776 "vpadd.s32 d3, d22, d23\n" 777 "vpadd.s32 d4, d24, d25\n" 778 "vpadd.s32 d5, d26, d27\n" 779 "vpadd.s32 d6, d28, d29\n" 780 "vpadd.s32 d7, d30, d31\n" 781 782 "bne " GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES 783 "f\n" 784 785 // Reduce 32bit accumulators horizontally, second pass 786 // (each pass adds pairwise. we need to add 4-wise). 787 "vpadd.s32 d8, d0, d2\n" 788 "vpadd.s32 d9, d4, d6\n" 789 "vpadd.s32 d10, d1, d3\n" 790 "vpadd.s32 d11, d5, d7\n" 791 792 "b " GEMMLOWP_LABEL_STORE "f\n" 793 794 GEMMLOWP_LABEL_ACCUMULATE_EXISTING_DST_VALUES 795 ":\n" 796 797 // Reduce 32bit accumulators horizontally, second pass 798 // (each pass adds pairwise. we need to add 4-wise), 799 // and load destination values from memory. 800 "mov r0, %[dst_ptr]\n" 801 "vld1.32 {d16, d17}, [r0]!\n" 802 "vpadd.s32 d8, d0, d2\n" 803 "vpadd.s32 d9, d4, d6\n" 804 "vld1.32 {d18, d19}, [r0]\n" 805 "vpadd.s32 d10, d1, d3\n" 806 "vpadd.s32 d11, d5, d7\n" 807 808 // Add horizontally-reduced accumulators into 809 // the values loaded from memory 810 "vadd.s32 q4, q8, q4\n" 811 "vadd.s32 q5, q9, q5\n" 812 813 GEMMLOWP_LABEL_STORE 814 ":\n" 815 // Store back into memory 816 "mov r0, %[dst_ptr]\n" 817 "vst1.32 {d8, d9}, [r0]!\n" 818 "vst1.32 {d10, d11}, [r0]\n" 819 : // outputs 820 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 821 [dst_ptr] "+r"(dst_ptr), [run_depth] "+r"(run_depth) 822 : // inputs 823 [start_depth] "r"(start_depth) 824 : // clobbers 825 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", 826 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", 827 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", 828 "d28", "d29", "d30", "d31"); 829 } 830 }; 831 832 // We don't actually use int32*int32 in production. This is just an 833 // experiment to help dissociate the effect of integer-vs-float, from the 834 // effect of operands width. 835 struct NEON_32bit_GEMM_Int32_WithScalar { 836 typedef std::int32_t OperandType; 837 typedef std::int32_t AccumulatorType; 838 typedef KernelFormat< 839 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, 840 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> > 841 Format; 842 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 843 AccumulatorType* accum_ptr, int depth) { 844 asm volatile( 845 // Load accumulators 846 "mov r0, %[accum_ptr]\n" 847 "vld1.32 {d8, d9}, [r0]!\n" 848 "vld1.32 {d16, d17}, [r0]!\n" 849 "vld1.32 {d24, d25}, [r0]!\n" 850 "vld1.32 {d10, d11}, [r0]!\n" 851 "vld1.32 {d18, d19}, [r0]!\n" 852 "vld1.32 {d26, d27}, [r0]!\n" 853 "vld1.32 {d12, d13}, [r0]!\n" 854 "vld1.32 {d20, d21}, [r0]!\n" 855 "vld1.32 {d28, d29}, [r0]!\n" 856 "vld1.32 {d14, d15}, [r0]!\n" 857 "vld1.32 {d22, d23}, [r0]!\n" 858 "vld1.32 {d30, d31}, [r0]!\n" 859 860 GEMMLOWP_LABEL_LOOP 861 ":\n" 862 863 // Load 1 Rhs cell of size 1x4 864 "vld1.32 {d0, d1}, [%[rhs_ptr]]!\n" 865 866 // Load 3 Lhs cells of size 4x1 each 867 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" 868 "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n" 869 "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n" 870 871 // Multiply-accumulate 872 "vmla.s32 q4, q1, d0[0]\n" 873 "vmla.s32 q5, q1, d0[1]\n" 874 "vmla.s32 q6, q1, d1[0]\n" 875 "vmla.s32 q7, q1, d1[1]\n" 876 "vmla.s32 q8, q2, d0[0]\n" 877 "vmla.s32 q9, q2, d0[1]\n" 878 "vmla.s32 q10, q2, d1[0]\n" 879 "vmla.s32 q11, q2, d1[1]\n" 880 "vmla.s32 q12, q3, d0[0]\n" 881 "vmla.s32 q13, q3, d0[1]\n" 882 "vmla.s32 q14, q3, d1[0]\n" 883 "vmla.s32 q15, q3, d1[1]\n" 884 885 // Loop. Decrement loop index (depth) by 1, since we just handled 1 886 // level of depth. 887 "subs %[depth], #1\n" 888 "bne " GEMMLOWP_LABEL_LOOP 889 "b\n" 890 891 // Store accumulators 892 "mov r0, %[accum_ptr]\n" 893 "vst1.32 {d8, d9}, [r0]!\n" 894 "vst1.32 {d16, d17}, [r0]!\n" 895 "vst1.32 {d24, d25}, [r0]!\n" 896 "vst1.32 {d10, d11}, [r0]!\n" 897 "vst1.32 {d18, d19}, [r0]!\n" 898 "vst1.32 {d26, d27}, [r0]!\n" 899 "vst1.32 {d12, d13}, [r0]!\n" 900 "vst1.32 {d20, d21}, [r0]!\n" 901 "vst1.32 {d28, d29}, [r0]!\n" 902 "vst1.32 {d14, d15}, [r0]!\n" 903 "vst1.32 {d22, d23}, [r0]!\n" 904 "vst1.32 {d30, d31}, [r0]!\n" 905 : // outputs 906 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 907 [depth] "+r"(depth) 908 : // inputs 909 [accum_ptr] "r"(accum_ptr) 910 : // clobbers 911 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", 912 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", 913 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", 914 "d28", "d29", "d30", "d31"); 915 } 916 }; 917 918 // Not very efficient kernel, just an experiment to see what we can do 919 // without using NEON multiply-with-scalar instructions. 920 struct NEON_32bit_GEMM_Float32_MLA_WithVectorDuplicatingScalar { 921 typedef float OperandType; 922 typedef float AccumulatorType; 923 typedef KernelFormat< 924 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, 925 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> > 926 Format; 927 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 928 AccumulatorType* accum_ptr, int depth) { 929 asm volatile( 930 // Load accumulators 931 "mov r0, %[accum_ptr]\n" 932 "vld1.32 {d8, d9}, [r0]!\n" 933 "vld1.32 {d16, d17}, [r0]!\n" 934 "vld1.32 {d24, d25}, [r0]!\n" 935 "vld1.32 {d10, d11}, [r0]!\n" 936 "vld1.32 {d18, d19}, [r0]!\n" 937 "vld1.32 {d26, d27}, [r0]!\n" 938 "vld1.32 {d12, d13}, [r0]!\n" 939 "vld1.32 {d20, d21}, [r0]!\n" 940 "vld1.32 {d28, d29}, [r0]!\n" 941 "vld1.32 {d14, d15}, [r0]!\n" 942 "vld1.32 {d22, d23}, [r0]!\n" 943 "vld1.32 {d30, d31}, [r0]!\n" 944 945 GEMMLOWP_LABEL_LOOP 946 ":\n" 947 948 // Load 3 Lhs cells of size 4x1 each 949 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" 950 "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n" 951 "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n" 952 953 // Multiply-accumulate 954 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n" 955 "vmla.f32 q4, q1, q0\n" 956 "vmla.f32 q8, q2, q0\n" 957 "vmla.f32 q12, q3, q0\n" 958 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n" 959 "vmla.f32 q5, q1, q0\n" 960 "vmla.f32 q9, q2, q0\n" 961 "vmla.f32 q13, q3, q0\n" 962 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n" 963 "vmla.f32 q6, q1, q0\n" 964 "vmla.f32 q10, q2, q0\n" 965 "vmla.f32 q14, q3, q0\n" 966 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n" 967 "vmla.f32 q7, q1, q0\n" 968 "vmla.f32 q11, q2, q0\n" 969 "vmla.f32 q15, q3, q0\n" 970 971 // Loop. Decrement loop index (depth) by 1, since we just handled 1 972 // level of depth. 973 "subs %[depth], #1\n" 974 "bne " GEMMLOWP_LABEL_LOOP 975 "b\n" 976 977 // Store accumulators 978 "mov r0, %[accum_ptr]\n" 979 "vst1.32 {d8, d9}, [r0]!\n" 980 "vst1.32 {d16, d17}, [r0]!\n" 981 "vst1.32 {d24, d25}, [r0]!\n" 982 "vst1.32 {d10, d11}, [r0]!\n" 983 "vst1.32 {d18, d19}, [r0]!\n" 984 "vst1.32 {d26, d27}, [r0]!\n" 985 "vst1.32 {d12, d13}, [r0]!\n" 986 "vst1.32 {d20, d21}, [r0]!\n" 987 "vst1.32 {d28, d29}, [r0]!\n" 988 "vst1.32 {d14, d15}, [r0]!\n" 989 "vst1.32 {d22, d23}, [r0]!\n" 990 "vst1.32 {d30, d31}, [r0]!\n" 991 : // outputs 992 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 993 [depth] "+r"(depth) 994 : // inputs 995 [accum_ptr] "r"(accum_ptr) 996 : // clobbers 997 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", 998 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", 999 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", 1000 "d28", "d29", "d30", "d31"); 1001 } 1002 }; 1003 1004 // Not very efficient kernel, just an experiment to see what we can do 1005 // without using NEON multiply-with-scalar instructions. 1006 // This variant is relevant as on ARMv7 FMA does not have a with-scalar variant. 1007 struct NEON_32bit_GEMM_Float32_FMA_WithVectorDuplicatingScalar { 1008 typedef float OperandType; 1009 typedef float AccumulatorType; 1010 typedef KernelFormat< 1011 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, 1012 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> > 1013 Format; 1014 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 1015 AccumulatorType* accum_ptr, int depth) { 1016 asm volatile( 1017 // Load accumulators 1018 "mov r0, %[accum_ptr]\n" 1019 "vld1.32 {d8, d9}, [r0]!\n" 1020 "vld1.32 {d16, d17}, [r0]!\n" 1021 "vld1.32 {d24, d25}, [r0]!\n" 1022 "vld1.32 {d10, d11}, [r0]!\n" 1023 "vld1.32 {d18, d19}, [r0]!\n" 1024 "vld1.32 {d26, d27}, [r0]!\n" 1025 "vld1.32 {d12, d13}, [r0]!\n" 1026 "vld1.32 {d20, d21}, [r0]!\n" 1027 "vld1.32 {d28, d29}, [r0]!\n" 1028 "vld1.32 {d14, d15}, [r0]!\n" 1029 "vld1.32 {d22, d23}, [r0]!\n" 1030 "vld1.32 {d30, d31}, [r0]!\n" 1031 1032 GEMMLOWP_LABEL_LOOP 1033 ":\n" 1034 1035 // Load 3 Lhs cells of size 4x1 each 1036 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" 1037 "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n" 1038 "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n" 1039 1040 // Multiply-accumulate 1041 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n" 1042 "vfma.f32 q4, q1, q0\n" 1043 "vfma.f32 q8, q2, q0\n" 1044 "vfma.f32 q12, q3, q0\n" 1045 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n" 1046 "vfma.f32 q5, q1, q0\n" 1047 "vfma.f32 q9, q2, q0\n" 1048 "vfma.f32 q13, q3, q0\n" 1049 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n" 1050 "vfma.f32 q6, q1, q0\n" 1051 "vfma.f32 q10, q2, q0\n" 1052 "vfma.f32 q14, q3, q0\n" 1053 "vld1.32 {d0[], d1[]}, [%[rhs_ptr]]!\n" 1054 "vfma.f32 q7, q1, q0\n" 1055 "vfma.f32 q11, q2, q0\n" 1056 "vfma.f32 q15, q3, q0\n" 1057 1058 // Loop. Decrement loop index (depth) by 1, since we just handled 1 1059 // level of depth. 1060 "subs %[depth], #1\n" 1061 "bne " GEMMLOWP_LABEL_LOOP 1062 "b\n" 1063 1064 // Store accumulators 1065 "mov r0, %[accum_ptr]\n" 1066 "vst1.32 {d8, d9}, [r0]!\n" 1067 "vst1.32 {d16, d17}, [r0]!\n" 1068 "vst1.32 {d24, d25}, [r0]!\n" 1069 "vst1.32 {d10, d11}, [r0]!\n" 1070 "vst1.32 {d18, d19}, [r0]!\n" 1071 "vst1.32 {d26, d27}, [r0]!\n" 1072 "vst1.32 {d12, d13}, [r0]!\n" 1073 "vst1.32 {d20, d21}, [r0]!\n" 1074 "vst1.32 {d28, d29}, [r0]!\n" 1075 "vst1.32 {d14, d15}, [r0]!\n" 1076 "vst1.32 {d22, d23}, [r0]!\n" 1077 "vst1.32 {d30, d31}, [r0]!\n" 1078 : // outputs 1079 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 1080 [depth] "+r"(depth) 1081 : // inputs 1082 [accum_ptr] "r"(accum_ptr) 1083 : // clobbers 1084 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", 1085 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", 1086 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", 1087 "d28", "d29", "d30", "d31"); 1088 } 1089 }; 1090 1091 // This is the "most natural" kernel, using NEON multiply-with-scalar 1092 // instructions. 1093 struct NEON_32bit_GEMM_Float32_MLA_WithScalar { 1094 typedef float OperandType; 1095 typedef float AccumulatorType; 1096 typedef KernelFormat< 1097 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, 1098 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> > 1099 Format; 1100 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 1101 AccumulatorType* accum_ptr, int depth) { 1102 asm volatile( 1103 // Load accumulators 1104 "mov r0, %[accum_ptr]\n" 1105 "vld1.32 {d8, d9}, [r0]!\n" 1106 "vld1.32 {d16, d17}, [r0]!\n" 1107 "vld1.32 {d24, d25}, [r0]!\n" 1108 "vld1.32 {d10, d11}, [r0]!\n" 1109 "vld1.32 {d18, d19}, [r0]!\n" 1110 "vld1.32 {d26, d27}, [r0]!\n" 1111 "vld1.32 {d12, d13}, [r0]!\n" 1112 "vld1.32 {d20, d21}, [r0]!\n" 1113 "vld1.32 {d28, d29}, [r0]!\n" 1114 "vld1.32 {d14, d15}, [r0]!\n" 1115 "vld1.32 {d22, d23}, [r0]!\n" 1116 "vld1.32 {d30, d31}, [r0]!\n" 1117 1118 GEMMLOWP_LABEL_LOOP 1119 ":\n" 1120 1121 // Load 1 Rhs cell of size 1x4 1122 "vld1.32 {d0, d1}, [%[rhs_ptr]]!\n" 1123 1124 // Load 3 Lhs cells of size 4x1 each 1125 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" 1126 "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n" 1127 "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n" 1128 1129 // Multiply-accumulate 1130 "vmla.f32 q4, q1, d0[0]\n" 1131 "vmla.f32 q5, q1, d0[1]\n" 1132 "vmla.f32 q6, q1, d1[0]\n" 1133 "vmla.f32 q7, q1, d1[1]\n" 1134 "vmla.f32 q8, q2, d0[0]\n" 1135 "vmla.f32 q9, q2, d0[1]\n" 1136 "vmla.f32 q10, q2, d1[0]\n" 1137 "vmla.f32 q11, q2, d1[1]\n" 1138 "vmla.f32 q12, q3, d0[0]\n" 1139 "vmla.f32 q13, q3, d0[1]\n" 1140 "vmla.f32 q14, q3, d1[0]\n" 1141 "vmla.f32 q15, q3, d1[1]\n" 1142 1143 // Loop. Decrement loop index (depth) by 1, since we just handled 1 1144 // level of depth. 1145 "subs %[depth], #1\n" 1146 "bne " GEMMLOWP_LABEL_LOOP 1147 "b\n" 1148 1149 // Store accumulators 1150 "mov r0, %[accum_ptr]\n" 1151 "vst1.32 {d8, d9}, [r0]!\n" 1152 "vst1.32 {d16, d17}, [r0]!\n" 1153 "vst1.32 {d24, d25}, [r0]!\n" 1154 "vst1.32 {d10, d11}, [r0]!\n" 1155 "vst1.32 {d18, d19}, [r0]!\n" 1156 "vst1.32 {d26, d27}, [r0]!\n" 1157 "vst1.32 {d12, d13}, [r0]!\n" 1158 "vst1.32 {d20, d21}, [r0]!\n" 1159 "vst1.32 {d28, d29}, [r0]!\n" 1160 "vst1.32 {d14, d15}, [r0]!\n" 1161 "vst1.32 {d22, d23}, [r0]!\n" 1162 "vst1.32 {d30, d31}, [r0]!\n" 1163 : // outputs 1164 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 1165 [depth] "+r"(depth) 1166 : // inputs 1167 [accum_ptr] "r"(accum_ptr) 1168 : // clobbers 1169 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", 1170 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", 1171 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", 1172 "d28", "d29", "d30", "d31"); 1173 } 1174 }; 1175 1176 // Faster kernel contributed by ARM in 64bit form 1177 // (see NEON_64bit_GEMM_Float32_WithScalar_A53) then ported to 32bit code. 1178 // Tuned for A53. 1179 struct NEON_32bit_GEMM_Float32_WithScalar_A53 { 1180 typedef float OperandType; 1181 typedef float AccumulatorType; 1182 typedef KernelFormat< 1183 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, 1184 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> > 1185 Format; 1186 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 1187 AccumulatorType* accum_ptr, int depth) { 1188 asm volatile( 1189 // Load accumulators 1190 "mov r0, %[accum_ptr]\n" 1191 "vld1.32 {d8, d9}, [r0]!\n" 1192 "vld1.32 {d16, d17}, [r0]!\n" 1193 "vld1.32 {d24, d25}, [r0]!\n" 1194 "vld1.32 {d10, d11}, [r0]!\n" 1195 "vld1.32 {d18, d19}, [r0]!\n" 1196 "vld1.32 {d26, d27}, [r0]!\n" 1197 "vld1.32 {d12, d13}, [r0]!\n" 1198 "vld1.32 {d20, d21}, [r0]!\n" 1199 "vld1.32 {d28, d29}, [r0]!\n" 1200 "vld1.32 {d14, d15}, [r0]!\n" 1201 "vld1.32 {d22, d23}, [r0]!\n" 1202 "vld1.32 {d30, d31}, [r0]!\n" 1203 1204 // Overview of register layout: 1205 // 1206 // A 1x4 cell of Rhs is stored in d0--d1 (q0). 1207 // A 12x1 block of 3 4x1 cells Lhs is stored in d2--d7 1208 // (q1--q3). 1209 // A 12x4 block of accumulators is stored in q4--q15. 1210 // 1211 // +-----+-----+-----+-----+ 1212 // Rhs |d0[0]|d0[1]|d1[0]|d1[1]| 1213 // +-----+-----+-----+-----+ 1214 // 1215 // | | | | | 1216 // 1217 // Lhs | | | | | 1218 // 1219 // +--+- - - - - - +-----+-----+-----+-----+ 1220 // |d2| | q4 | q5 | q6 | q7 | 1221 // |d2| | q4 | q5 | q6 | q7 | 1222 // |d3| | q4 | q5 | q6 | q7 | 1223 // |d3| | q4 | q5 | q6 | q7 | 1224 // +--+- - - - - - +-----+-----+-----+-----+ 1225 // |d4| | q8 | q9 | q10 | q11 | 1226 // |d4| | q8 | q9 | q10 | q11 | 1227 // |d5| | q8 | q9 | q10 | q11 | 1228 // |d5| | q8 | q9 | q10 | q11 | 1229 // +--+ - - - - - - +-----+-----+-----+-----+ 1230 // |d6| | q12 | q13 | q14 | q15 | 1231 // |d6| | q12 | q13 | q14 | q15 | 1232 // |d7| | q12 | q13 | q14 | q15 | 1233 // |d7| | q12 | q13 | q14 | q15 | 1234 // +--+- - - - - - +-----+-----+-----+-----+ 1235 // 1236 // Accumulator 1237 1238 // Load Rhs cell 1239 "vldr d0, [%[rhs_ptr]]\n" 1240 "ldr r2, [%[rhs_ptr], #8]\n" 1241 "ldr r3, [%[rhs_ptr], #12]\n" 1242 1243 // Load 1st Lhs Cell 1244 "vld1.32 {d2, d3}, [%[lhs_ptr]]\n" 1245 1246 GEMMLOWP_LABEL_LOOP 1247 ":\n" 1248 1249 "vldr d4, [%[lhs_ptr], #16]\n" // Load 1st half of 2nd Lhs cell 1250 "vmov d1, r2, r3\n" // Prepare 2nd half of Rhs cell 1251 "vmla.f32 q4, q1, d0[0]\n" // Multiply 1st Lhs cell with column 0 1252 "ldr r2, [%[lhs_ptr], #24]\n" // Load 2nd half of 2nd Lhs cell, part 1 1253 "vmla.f32 q5, q1, d0[1]\n" // Multiply 1st Lhs cell with column 1 1254 "ldr r3, [%[lhs_ptr], #28]\n" // Load 2nd half of 2nd Lhs cell, part 2 1255 "vmla.f32 q6, q1, d1[0]\n" // Multiply 1st Lhs cell with column 2 1256 "subs %[depth], #1\n" 1257 1258 "vldr d6, [%[lhs_ptr], #32]\n" // Load 1st half of 3rd Lhs cell 1259 "vmov d5, r2, r3\n" // Prepare 2nd half of 2nd Lhs cell 1260 "vmla.f32 q7, q1, d1[1]\n" // Multiply 1st Lhs cell with column 3 1261 "ldr r2, [%[lhs_ptr], #40]\n" // Load 2nd half of 3rd Lhs cell, part 1 1262 "vmla.f32 q8, q2, d0[0]\n" // Multiply 2nd Lhs cell with column 0 1263 "ldr r3, [%[lhs_ptr], #44]\n" // Load 2nd half of 3rd Lhs cell, part 2 1264 "vmla.f32 q9, q2, d0[1]\n" // Multiply 2nd Lhs cell with column 1 1265 "add %[rhs_ptr], %[rhs_ptr], #16\n" // Move forward by 1 Rhs cell 1266 1267 "vldr d2, [%[lhs_ptr], #48]\n" // Load 1st half of 1st Lhs cell of next 1268 // iteration 1269 "vmov d7, r2, r3\n" // Prepare 2nd half of 3rd Lhs cell 1270 "vmla.f32 q10, q2, d1[0]\n" // Multiply 2nd Lhs cell with column 2 1271 "ldr r2, [%[lhs_ptr], #56]\n" // Load 2nd half of 1st Lhs cell of next 1272 // iter, part 1 1273 "vmla.f32 q12, q3, d0[0]\n" // Multiply 3rd Lhs cell with column 0 1274 "ldr r3, [%[lhs_ptr], #60]\n" // Load 2nd half of 1st Lhs cell of next 1275 // iter, part 2 1276 "vmla.f32 q13, q3, d0[1]\n" // Multiply 3rd Lhs cell with column 1 1277 "add %[lhs_ptr], %[lhs_ptr], #48\n" // Move forward by 3 Lhs cells 1278 1279 "vldr d0, [%[rhs_ptr]]\n" // Load 1st half of Rhs cell of next 1280 // iteration 1281 "vmov d3, r2, r3\n" // Prepare 2nd half of 1st Lhs cell of next 1282 // iteration 1283 "vmla.f32 q11, q2, d1[1]\n" // Multiply 2nd Lhs cell with column 3 1284 "ldr r2, [%[rhs_ptr], #8]\n" // Load 2nd half of Rhs cell of next 1285 // iteration, part 1 1286 "vmla.f32 q14, q3, d1[0]\n" // Multiply 3rd Lhs cell with column 2 1287 "ldr r3, [%[rhs_ptr], #12]\n" // Load 2nd half of Rhs cell of next 1288 // iteration, part 2 1289 "vmla.f32 q15, q3, d1[1]\n" // Multiply 3rd Lhs cell with column 3 1290 1291 // Loop branch. This will dual issue in fmla cycle 3 of the 4th block. 1292 "bne " GEMMLOWP_LABEL_LOOP 1293 "b\n" 1294 1295 // Store accumulators 1296 "mov r0, %[accum_ptr]\n" 1297 "vst1.32 {d8, d9}, [r0]!\n" 1298 "vst1.32 {d16, d17}, [r0]!\n" 1299 "vst1.32 {d24, d25}, [r0]!\n" 1300 "vst1.32 {d10, d11}, [r0]!\n" 1301 "vst1.32 {d18, d19}, [r0]!\n" 1302 "vst1.32 {d26, d27}, [r0]!\n" 1303 "vst1.32 {d12, d13}, [r0]!\n" 1304 "vst1.32 {d20, d21}, [r0]!\n" 1305 "vst1.32 {d28, d29}, [r0]!\n" 1306 "vst1.32 {d14, d15}, [r0]!\n" 1307 "vst1.32 {d22, d23}, [r0]!\n" 1308 "vst1.32 {d30, d31}, [r0]!\n" 1309 : // outputs 1310 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 1311 [depth] "+r"(depth) 1312 : // inputs 1313 [accum_ptr] "r"(accum_ptr) 1314 : // clobbers 1315 "cc", "memory", "r0", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", 1316 "d6", "d7", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", 1317 "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", 1318 "d27", "d28", "d29", "d30", "d31"); 1319 } 1320 }; 1321 1322 struct NEON_32bit_GEMM_Float32_WithScalar_A53_depth2 { 1323 typedef float OperandType; 1324 typedef float AccumulatorType; 1325 typedef KernelFormat< 1326 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, 1327 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> > 1328 Format; 1329 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 1330 AccumulatorType* accum_ptr, int depth) { 1331 asm volatile( 1332 // Load accumulators 1333 "mov r0, %[accum_ptr]\n" 1334 "vld1.32 {d8, d9}, [r0]!\n" 1335 "vld1.32 {d16, d17}, [r0]!\n" 1336 "vld1.32 {d24, d25}, [r0]!\n" 1337 "vld1.32 {d10, d11}, [r0]!\n" 1338 "vld1.32 {d18, d19}, [r0]!\n" 1339 "vld1.32 {d26, d27}, [r0]!\n" 1340 "vld1.32 {d12, d13}, [r0]!\n" 1341 "vld1.32 {d20, d21}, [r0]!\n" 1342 "vld1.32 {d28, d29}, [r0]!\n" 1343 "vld1.32 {d14, d15}, [r0]!\n" 1344 "vld1.32 {d22, d23}, [r0]!\n" 1345 "vld1.32 {d30, d31}, [r0]!\n" 1346 1347 // Overview of register layout: 1348 // 1349 // A 1x4 cell of Rhs is stored in d0--d1 (q0). 1350 // A 12x1 block of 3 4x1 cells Lhs is stored in d2--d7 1351 // (q1--q3). 1352 // A 12x4 block of accumulators is stored in q4--q15. 1353 // 1354 // +-----+-----+-----+-----+ 1355 // Rhs |d0[0]|d0[1]|d1[0]|d1[1]| 1356 // +-----+-----+-----+-----+ 1357 // 1358 // | | | | | 1359 // 1360 // Lhs | | | | | 1361 // 1362 // +--+- - - - - - +-----+-----+-----+-----+ 1363 // |d2| | q4 | q5 | q6 | q7 | 1364 // |d2| | q4 | q5 | q6 | q7 | 1365 // |d3| | q4 | q5 | q6 | q7 | 1366 // |d3| | q4 | q5 | q6 | q7 | 1367 // +--+- - - - - - +-----+-----+-----+-----+ 1368 // |d4| | q8 | q9 | q10 | q11 | 1369 // |d4| | q8 | q9 | q10 | q11 | 1370 // |d5| | q8 | q9 | q10 | q11 | 1371 // |d5| | q8 | q9 | q10 | q11 | 1372 // +--+ - - - - - - +-----+-----+-----+-----+ 1373 // |d6| | q12 | q13 | q14 | q15 | 1374 // |d6| | q12 | q13 | q14 | q15 | 1375 // |d7| | q12 | q13 | q14 | q15 | 1376 // |d7| | q12 | q13 | q14 | q15 | 1377 // +--+- - - - - - +-----+-----+-----+-----+ 1378 // 1379 // Accumulator 1380 1381 // Load Rhs cell 1382 "vldr d0, [%[rhs_ptr]]\n" 1383 "ldr r2, [%[rhs_ptr], #8]\n" 1384 "ldr r3, [%[rhs_ptr], #12]\n" 1385 1386 // Load 1st Lhs Cell 1387 "vld1.32 {d2, d3}, [%[lhs_ptr]]\n" 1388 1389 // Loop head - handling 2 levels of depth at once 1390 GEMMLOWP_LABEL_LOOP 1391 ":\n" 1392 1393 // Level of depth 1 1394 1395 "vldr d4, [%[lhs_ptr], #32]\n" // Load 1st half of 2nd Lhs cell 1396 "vmov d1, r2, r3\n" // Prepare 2nd half of Rhs cell 1397 "vmla.f32 q4, q1, d0[0]\n" // Multiply 1st Lhs cell with column 0 1398 "ldr r2, [%[lhs_ptr], #40]\n" // Load 2nd half of 2nd Lhs cell, part 1 1399 "vmla.f32 q5, q1, d0[1]\n" // Multiply 1st Lhs cell with column 1 1400 "ldr r3, [%[lhs_ptr], #44]\n" // Load 2nd half of 2nd Lhs cell, part 2 1401 "vmla.f32 q6, q1, d1[0]\n" // Multiply 1st Lhs cell with column 2 1402 1403 "vldr d6, [%[lhs_ptr], #64]\n" // Load 1st half of 3rd Lhs cell 1404 "vmov d5, r2, r3\n" // Prepare 2nd half of 2nd Lhs cell 1405 "vmla.f32 q7, q1, d1[1]\n" // Multiply 1st Lhs cell with column 3 1406 "ldr r2, [%[lhs_ptr], #72]\n" // Load 2nd half of 3rd Lhs cell, part 1 1407 "vmla.f32 q8, q2, d0[0]\n" // Multiply 2nd Lhs cell with column 0 1408 "ldr r3, [%[lhs_ptr], #76]\n" // Load 2nd half of 3rd Lhs cell, part 2 1409 "vmla.f32 q9, q2, d0[1]\n" // Multiply 2nd Lhs cell with column 1 1410 1411 "vldr d2, [%[lhs_ptr], #16]\n" // Load 1st half of 1st Lhs cell of next 1412 // iteration 1413 "vmov d7, r2, r3\n" // Prepare 2nd half of 3rd Lhs cell 1414 "vmla.f32 q10, q2, d1[0]\n" // Multiply 2nd Lhs cell with column 2 1415 "ldr r2, [%[lhs_ptr], #24]\n" // Load 2nd half of 1st Lhs cell of next 1416 // iter, part 1 1417 "vmla.f32 q12, q3, d0[0]\n" // Multiply 3rd Lhs cell with column 0 1418 "ldr r3, [%[lhs_ptr], #28]\n" // Load 2nd half of 1st Lhs cell of next 1419 // iter, part 2 1420 "vmla.f32 q13, q3, d0[1]\n" // Multiply 3rd Lhs cell with column 1 1421 1422 "vldr d0, [%[rhs_ptr], #16]\n" // Load 1st half of Rhs cell of next 1423 // iteration 1424 "vmov d3, r2, r3\n" // Prepare 2nd half of 1st Lhs cell of next 1425 // iteration 1426 "vmla.f32 q11, q2, d1[1]\n" // Multiply 2nd Lhs cell with column 3 1427 "ldr r2, [%[rhs_ptr], #24]\n" // Load 2nd half of Rhs cell of next 1428 // iteration, part 1 1429 "vmla.f32 q14, q3, d1[0]\n" // Multiply 3rd Lhs cell with column 2 1430 "ldr r3, [%[rhs_ptr], #28]\n" // Load 2nd half of Rhs cell of next 1431 // iteration, part 2 1432 "vmla.f32 q15, q3, d1[1]\n" // Multiply 3rd Lhs cell with column 3 1433 1434 // Level of depth 2 1435 "vldr d4, [%[lhs_ptr], #48]\n" // Load 1st half of 2nd Lhs cell 1436 "vmov d1, r2, r3\n" // Prepare 2nd half of Rhs cell 1437 "vmla.f32 q4, q1, d0[0]\n" // Multiply 1st Lhs cell with column 0 1438 "ldr r2, [%[lhs_ptr], #56]\n" // Load 2nd half of 2nd Lhs cell, part 1 1439 "vmla.f32 q5, q1, d0[1]\n" // Multiply 1st Lhs cell with column 1 1440 "ldr r3, [%[lhs_ptr], #60]\n" // Load 2nd half of 2nd Lhs cell, part 2 1441 "vmla.f32 q6, q1, d1[0]\n" // Multiply 1st Lhs cell with column 2 1442 "subs %[depth], #2\n" // Decrement depth counter 1443 1444 "vldr d6, [%[lhs_ptr], #80]\n" // Load 1st half of 3rd Lhs cell 1445 "vmov d5, r2, r3\n" // Prepare 2nd half of 2nd Lhs cell 1446 "vmla.f32 q7, q1, d1[1]\n" // Multiply 1st Lhs cell with column 3 1447 "ldr r2, [%[lhs_ptr], #88]\n" // Load 2nd half of 3rd Lhs cell, part 1 1448 "vmla.f32 q8, q2, d0[0]\n" // Multiply 2nd Lhs cell with column 0 1449 "ldr r3, [%[lhs_ptr], #92]\n" // Load 2nd half of 3rd Lhs cell, part 2 1450 "vmla.f32 q9, q2, d0[1]\n" // Multiply 2nd Lhs cell with column 1 1451 "add %[rhs_ptr], %[rhs_ptr], #32\n" // Move forward by 1 Rhs cell 1452 1453 "vldr d2, [%[lhs_ptr], #96]\n" // Load 1st half of 1st Lhs cell of next 1454 // iteration 1455 "vmov d7, r2, r3\n" // Prepare 2nd half of 3rd Lhs cell 1456 "vmla.f32 q10, q2, d1[0]\n" // Multiply 2nd Lhs cell with column 2 1457 "ldr r2, [%[lhs_ptr], #104]\n" // Load 2nd half of 1st Lhs cell of next 1458 // iter, part 1 1459 "vmla.f32 q12, q3, d0[0]\n" // Multiply 3rd Lhs cell with column 0 1460 "ldr r3, [%[lhs_ptr], #108]\n" // Load 2nd half of 1st Lhs cell of next 1461 // iter, part 2 1462 "vmla.f32 q13, q3, d0[1]\n" // Multiply 3rd Lhs cell with column 1 1463 "add %[lhs_ptr], %[lhs_ptr], #96\n" // Move forward by 3 Lhs cells 1464 1465 "vldr d0, [%[rhs_ptr]]\n" // Load 1st half of Rhs cell of next 1466 // iteration 1467 "vmov d3, r2, r3\n" // Prepare 2nd half of 1st Lhs cell of next 1468 // iteration 1469 "vmla.f32 q11, q2, d1[1]\n" // Multiply 2nd Lhs cell with column 3 1470 "ldr r2, [%[rhs_ptr], #8]\n" // Load 2nd half of Rhs cell of next 1471 // iteration, part 1 1472 "vmla.f32 q14, q3, d1[0]\n" // Multiply 3rd Lhs cell with column 2 1473 "ldr r3, [%[rhs_ptr], #12]\n" // Load 2nd half of Rhs cell of next 1474 // iteration, part 2 1475 "vmla.f32 q15, q3, d1[1]\n" // Multiply 3rd Lhs cell with column 3 1476 1477 // Loop branch. This will dual issue in fmla cycle 3 of the 4th block. 1478 //"bne loop_%=\n" 1479 "bne " GEMMLOWP_LABEL_LOOP 1480 "b\n" 1481 1482 // Store accumulators 1483 "mov r0, %[accum_ptr]\n" 1484 "vst1.32 {d8, d9}, [r0]!\n" 1485 "vst1.32 {d16, d17}, [r0]!\n" 1486 "vst1.32 {d24, d25}, [r0]!\n" 1487 "vst1.32 {d10, d11}, [r0]!\n" 1488 "vst1.32 {d18, d19}, [r0]!\n" 1489 "vst1.32 {d26, d27}, [r0]!\n" 1490 "vst1.32 {d12, d13}, [r0]!\n" 1491 "vst1.32 {d20, d21}, [r0]!\n" 1492 "vst1.32 {d28, d29}, [r0]!\n" 1493 "vst1.32 {d14, d15}, [r0]!\n" 1494 "vst1.32 {d22, d23}, [r0]!\n" 1495 "vst1.32 {d30, d31}, [r0]!\n" 1496 : // outputs 1497 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 1498 [depth] "+r"(depth) 1499 : // inputs 1500 [accum_ptr] "r"(accum_ptr) 1501 : // clobbers 1502 "cc", "memory", "r0", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", 1503 "d6", "d7", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", 1504 "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", 1505 "d27", "d28", "d29", "d30", "d31"); 1506 } 1507 }; 1508 1509 // This rotating variant performs well when permutations (vext) can be 1510 // dual-issued with arithmetic instructions. 1511 struct NEON_32bit_GEMM_Float32_MLA_Rotating { 1512 typedef float OperandType; 1513 typedef float AccumulatorType; 1514 typedef KernelFormat< 1515 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, 1516 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> > 1517 Format; 1518 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 1519 AccumulatorType* accum_ptr, int depth) { 1520 asm volatile( 1521 // Load accumulators 1522 "mov r0, %[accum_ptr]\n" 1523 "vld1.32 {d8, d9}, [r0]!\n" 1524 "vld1.32 {d16, d17}, [r0]!\n" 1525 "vld1.32 {d24, d25}, [r0]!\n" 1526 "vld1.32 {d10, d11}, [r0]!\n" 1527 "vld1.32 {d18, d19}, [r0]!\n" 1528 "vld1.32 {d26, d27}, [r0]!\n" 1529 "vld1.32 {d12, d13}, [r0]!\n" 1530 "vld1.32 {d20, d21}, [r0]!\n" 1531 "vld1.32 {d28, d29}, [r0]!\n" 1532 "vld1.32 {d14, d15}, [r0]!\n" 1533 "vld1.32 {d22, d23}, [r0]!\n" 1534 "vld1.32 {d30, d31}, [r0]!\n" 1535 1536 #define NEON_32BIT_ROTATING_FLOAT_KERNEL_TRANSPOSE_ACCUMULATOR_CELLS \ 1537 "vtrn.32 q4, q5\n" \ 1538 "vtrn.32 q6, q7\n" \ 1539 "vswp d9, d12\n" \ 1540 "vswp d11, d14\n" \ 1541 "vtrn.32 q8, q9\n" \ 1542 "vtrn.32 q10, q11\n" \ 1543 "vswp d17, d20\n" \ 1544 "vswp d19, d22\n" \ 1545 "vtrn.32 q12, q13\n" \ 1546 "vtrn.32 q14, q15\n" \ 1547 "vswp d25, d28\n" \ 1548 "vswp d27, d30\n" 1549 1550 #define NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(a, b, c) \ 1551 NEON_32BIT_ROTATING_FLOAT_KERNEL_TRANSPOSE_ACCUMULATOR_CELLS \ 1552 "vext.32 q5, q5, q5, #" #a \ 1553 "\n" \ 1554 "vext.32 q6, q6, q6, #" #b \ 1555 "\n" \ 1556 "vext.32 q7, q7, q7, #" #c \ 1557 "\n" \ 1558 "vext.32 q9, q9, q9, #" #a \ 1559 "\n" \ 1560 "vext.32 q10, q10, q10, #" #b \ 1561 "\n" \ 1562 "vext.32 q11, q11, q11, #" #c \ 1563 "\n" \ 1564 "vext.32 q13, q13, q13, #" #a \ 1565 "\n" \ 1566 "vext.32 q14, q14, q14, #" #b \ 1567 "\n" \ 1568 "vext.32 q15, q15, q15, #" #c \ 1569 "\n" NEON_32BIT_ROTATING_FLOAT_KERNEL_TRANSPOSE_ACCUMULATOR_CELLS 1570 1571 NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(1, 2, 3) 1572 1573 //"loop_%=:\n" 1574 GEMMLOWP_LABEL_LOOP 1575 ":\n" 1576 1577 // Load 1 Rhs cell of size 1x4 1578 "vld1.32 {d0, d1}, [%[rhs_ptr]]!\n" 1579 1580 // Load 3 Lhs cells of size 4x1 each 1581 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" 1582 "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n" 1583 "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n" 1584 1585 // Multiply-accumulate 1586 "vmla.f32 q4, q1, q0\n" 1587 "vmla.f32 q8, q2, q0\n" 1588 "vmla.f32 q12, q3, q0\n" 1589 "vext.f32 q0, q0, q0, #1\n" 1590 "vmla.f32 q5, q1, q0\n" 1591 "vmla.f32 q9, q2, q0\n" 1592 "vmla.f32 q13, q3, q0\n" 1593 "vext.f32 q0, q0, q0, #1\n" 1594 "vmla.f32 q6, q1, q0\n" 1595 "vmla.f32 q10, q2, q0\n" 1596 "vmla.f32 q14, q3, q0\n" 1597 "vext.f32 q0, q0, q0, #1\n" 1598 "vmla.f32 q7, q1, q0\n" 1599 "vmla.f32 q11, q2, q0\n" 1600 "vmla.f32 q15, q3, q0\n" 1601 1602 // Loop. Decrement loop index (depth) by 1, since we just handled 1 1603 // level of depth. 1604 "subs %[depth], #1\n" 1605 //"bne loop_%=\n" 1606 "bne " GEMMLOWP_LABEL_LOOP 1607 "b\n" 1608 1609 // Store accumulators 1610 "mov r0, %[accum_ptr]\n" 1611 1612 NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(3, 2, 1) 1613 1614 "vst1.32 {d8, d9}, [r0]!\n" 1615 "vst1.32 {d16, d17}, [r0]!\n" 1616 "vst1.32 {d24, d25}, [r0]!\n" 1617 "vst1.32 {d10, d11}, [r0]!\n" 1618 "vst1.32 {d18, d19}, [r0]!\n" 1619 "vst1.32 {d26, d27}, [r0]!\n" 1620 "vst1.32 {d12, d13}, [r0]!\n" 1621 "vst1.32 {d20, d21}, [r0]!\n" 1622 "vst1.32 {d28, d29}, [r0]!\n" 1623 "vst1.32 {d14, d15}, [r0]!\n" 1624 "vst1.32 {d22, d23}, [r0]!\n" 1625 "vst1.32 {d30, d31}, [r0]!\n" 1626 : // outputs 1627 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 1628 [depth] "+r"(depth) 1629 : // inputs 1630 [accum_ptr] "r"(accum_ptr) 1631 : // clobbers 1632 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", 1633 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", 1634 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", 1635 "d28", "d29", "d30", "d31"); 1636 } 1637 }; 1638 1639 // This rotating variant performs well when permutations (vext) can be 1640 // dual-issued with arithmetic instructions. It is relevant as the rotating 1641 // approach removes the need for multiply-with-scalar instructions, and ARMv7 1642 // FMA does not have a with-scalar variant. 1643 struct NEON_32bit_GEMM_Float32_FMA_Rotating { 1644 typedef float OperandType; 1645 typedef float AccumulatorType; 1646 typedef KernelFormat< 1647 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, 1648 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 1> > 1649 Format; 1650 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 1651 AccumulatorType* accum_ptr, int depth) { 1652 asm volatile( 1653 // Load accumulators 1654 "mov r0, %[accum_ptr]\n" 1655 "vld1.32 {d8, d9}, [r0]!\n" 1656 "vld1.32 {d16, d17}, [r0]!\n" 1657 "vld1.32 {d24, d25}, [r0]!\n" 1658 "vld1.32 {d10, d11}, [r0]!\n" 1659 "vld1.32 {d18, d19}, [r0]!\n" 1660 "vld1.32 {d26, d27}, [r0]!\n" 1661 "vld1.32 {d12, d13}, [r0]!\n" 1662 "vld1.32 {d20, d21}, [r0]!\n" 1663 "vld1.32 {d28, d29}, [r0]!\n" 1664 "vld1.32 {d14, d15}, [r0]!\n" 1665 "vld1.32 {d22, d23}, [r0]!\n" 1666 "vld1.32 {d30, d31}, [r0]!\n" 1667 1668 NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(1, 2, 3) 1669 1670 //"loop_%=:\n" 1671 GEMMLOWP_LABEL_LOOP 1672 ":\n" 1673 1674 // Load 1 Rhs cell of size 1x4 1675 "vld1.32 {d0, d1}, [%[rhs_ptr]]!\n" 1676 1677 // Load 3 Lhs cells of size 4x1 each 1678 "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" 1679 "vld1.32 {d4, d5}, [%[lhs_ptr]]!\n" 1680 "vld1.32 {d6, d7}, [%[lhs_ptr]]!\n" 1681 1682 // Multiply-accumulate 1683 "vfma.f32 q4, q1, q0\n" 1684 "vfma.f32 q8, q2, q0\n" 1685 "vfma.f32 q12, q3, q0\n" 1686 "vext.f32 q0, q0, q0, #1\n" 1687 "vfma.f32 q5, q1, q0\n" 1688 "vfma.f32 q9, q2, q0\n" 1689 "vfma.f32 q13, q3, q0\n" 1690 "vext.f32 q0, q0, q0, #1\n" 1691 "vfma.f32 q6, q1, q0\n" 1692 "vfma.f32 q10, q2, q0\n" 1693 "vfma.f32 q14, q3, q0\n" 1694 "vext.f32 q0, q0, q0, #1\n" 1695 "vfma.f32 q7, q1, q0\n" 1696 "vfma.f32 q11, q2, q0\n" 1697 "vfma.f32 q15, q3, q0\n" 1698 1699 // Loop. Decrement loop index (depth) by 1, since we just handled 1 1700 // level of depth. 1701 "subs %[depth], #1\n" 1702 //"bne loop_%=\n" 1703 "bne " GEMMLOWP_LABEL_LOOP "b\n" 1704 1705 NEON_32BIT_ROTATING_FLOAT_KERNEL_ROTATE_ACCUMULATOR_CELLS(3, 2, 1) 1706 1707 // Store accumulators 1708 "mov r0, %[accum_ptr]\n" 1709 "vst1.32 {d8, d9}, [r0]!\n" 1710 "vst1.32 {d16, d17}, [r0]!\n" 1711 "vst1.32 {d24, d25}, [r0]!\n" 1712 "vst1.32 {d10, d11}, [r0]!\n" 1713 "vst1.32 {d18, d19}, [r0]!\n" 1714 "vst1.32 {d26, d27}, [r0]!\n" 1715 "vst1.32 {d12, d13}, [r0]!\n" 1716 "vst1.32 {d20, d21}, [r0]!\n" 1717 "vst1.32 {d28, d29}, [r0]!\n" 1718 "vst1.32 {d14, d15}, [r0]!\n" 1719 "vst1.32 {d22, d23}, [r0]!\n" 1720 "vst1.32 {d30, d31}, [r0]!\n" 1721 : // outputs 1722 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 1723 [depth] "+r"(depth) 1724 : // inputs 1725 [accum_ptr] "r"(accum_ptr) 1726 : // clobbers 1727 "cc", "memory", "r0", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", 1728 "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", 1729 "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", 1730 "d28", "d29", "d30", "d31"); 1731 } 1732 }; 1733 1734 #endif // __arm__ 1735 1736 #ifdef __aarch64__ 1737 1738 // This is the current standard kernel in gemmlowp, see: 1739 // https://github.com/google/gemmlowp/blob/b1e2a29ff866680028f3080efc244e10e8dd7f46/internal/kernel_neon.h#L646 1740 struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators { 1741 typedef std::uint8_t OperandType; 1742 typedef std::uint32_t AccumulatorType; 1743 typedef KernelFormat< 1744 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, 1745 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> > 1746 Format; 1747 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 1748 AccumulatorType* accum_ptr, int depth) { 1749 asm volatile( 1750 // Load 1 Rhs cell of size 2x8 1751 "ld1 {v5.8b}, [%[rhs_ptr]], #8\n" 1752 "ld1 {v6.8b}, [%[rhs_ptr]], #8\n" 1753 1754 // Load 3 Lhs cells of size 4x2 each 1755 "ld1 {v2.8b}, [%[lhs_ptr]], #8\n" 1756 "ld1 {v3.8b}, [%[lhs_ptr]], #8\n" 1757 "ld1 {v4.8b}, [%[lhs_ptr]], #8\n" 1758 1759 "subs %w[depth], %w[depth], #2\n" 1760 1761 // Load accumulators 1762 "mov x0, %[accum_ptr]\n" 1763 "ld1 {v8.16b}, [x0], #16\n" 1764 "ld1 {v16.16b}, [x0], #16\n" 1765 "ld1 {v24.16b}, [x0], #16\n" 1766 "ld1 {v9.16b}, [x0], #16\n" 1767 "ld1 {v17.16b}, [x0], #16\n" 1768 "ld1 {v25.16b}, [x0], #16\n" 1769 "ld1 {v10.16b}, [x0], #16\n" 1770 "ld1 {v18.16b}, [x0], #16\n" 1771 "ld1 {v26.16b}, [x0], #16\n" 1772 "ld1 {v11.16b}, [x0], #16\n" 1773 "ld1 {v19.16b}, [x0], #16\n" 1774 "ld1 {v27.16b}, [x0], #16\n" 1775 "ld1 {v12.16b}, [x0], #16\n" 1776 "ld1 {v20.16b}, [x0], #16\n" 1777 "ld1 {v28.16b}, [x0], #16\n" 1778 "ld1 {v13.16b}, [x0], #16\n" 1779 "ld1 {v21.16b}, [x0], #16\n" 1780 "ld1 {v29.16b}, [x0], #16\n" 1781 "ld1 {v14.16b}, [x0], #16\n" 1782 "ld1 {v22.16b}, [x0], #16\n" 1783 "ld1 {v30.16b}, [x0], #16\n" 1784 "ld1 {v15.16b}, [x0], #16\n" 1785 "ld1 {v23.16b}, [x0], #16\n" 1786 "ld1 {v31.16b}, [x0], #16\n" 1787 1788 "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n" 1789 1790 //"loop_%=:\n" 1791 GEMMLOWP_LABEL_LOOP 1792 ":\n" 1793 1794 // Overview of register layout: 1795 // 1796 // A 2x8 block of 2 2x4 cells of Rhs is stored in 16bit in v0--v1. 1797 // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in v2--v4. 1798 // A 12x8 block of accumulators is stored in 32bit in v8--v31. 1799 // 1800 // +--------+--------+-----+--------+--------+ 1801 // |v0.h[0] |v0.h[1] | ... |v1.h[2] |v1.h[3] | 1802 // Rhs +--------+--------+-----+--------+--------+ 1803 // |v0.h[4] |v0.h[5] | ... |v1.h[6] |v1.h[7] | 1804 // +--------+--------+-----+--------+--------+ 1805 // 1806 // | | | | | | 1807 // 1808 // Lhs | | | | | | 1809 // 1810 // +-------+-------+ - - +--------+--------+-----+--------+--------+ 1811 // |v2.h[0]|v2.h[4]| |v8.s[0] |v9.s[0] | ... |v14.s[0]|v15.s[0]| 1812 // |v2.h[1]|v2.h[5]| |v8.s[1] |v9.s[1] | ... |v14.s[1]|v15.s[1]| 1813 // |v2.h[2]|v2.h[6]| |v8.s[2] |v9.s[2] | ... |v14.s[2]|v15.s[2]| 1814 // |v2.h[3]|v2.h[7]| |v8.s[3] |v9.s[3] | ... |v14.s[3]|v15.s[3]| 1815 // +-------+-------+ - - +--------+--------+-----+--------+--------+ 1816 // |v3.h[0]|v3.h[4]| |v16.s[0]|v17.s[0]| ... |v22.s[0]|v23.s[0]| 1817 // |v3.h[1]|v3.h[5]| |v16.s[1]|v17.s[1]| ... |v22.s[1]|v23.s[1]| 1818 // |v3.h[2]|v3.h[6]| |v16.s[2]|v17.s[2]| ... |v22.s[2]|v23.s[2]| 1819 // |v3.h[3]|v3.h[7]| |v16.s[3]|v17.s[3]| ... |v22.s[3]|v23.s[3]| 1820 // +-------+-------+ - - +--------+--------+-----+--------+--------+ 1821 // |v4.h[0]|v4.h[4]| |v24.s[0]|v25.s[0]| ... |v30.s[0]|v31.s[0]| 1822 // |v4.h[1]|v4.h[5]| |v24.s[1]|v25.s[1]| ... |v30.s[1]|v31.s[1]| 1823 // |v4.h[2]|v4.h[6]| |v24.s[2]|v25.s[2]| ... |v30.s[2]|v31.s[2]| 1824 // |v4.h[3]|v4.h[7]| |v24.s[3]|v25.s[3]| ... |v30.s[3]|v31.s[3]| 1825 // +-------+-------+ - - +--------+--------+-----+--------+--------+ 1826 // 1827 // Accumulator 1828 1829 // Expand Lhs/Rhs cells to 16 bit. 1830 "uxtl v0.8h, v5.8b\n" 1831 "ld1 {v5.8b}, [%[rhs_ptr]], #8\n" 1832 "uxtl v1.8h, v6.8b\n" 1833 "ld1 {v6.8b}, [%[rhs_ptr]], #8\n" 1834 "uxtl v2.8h, v2.8b\n" 1835 "uxtl v3.8h, v3.8b\n" 1836 "uxtl v4.8h, v4.8b\n" 1837 1838 // Multiply-accumulate, top third 1839 "umlal v8.4s, v2.4h, v0.h[0]\n" 1840 "umlal v9.4s, v2.4h, v0.h[1]\n" 1841 "umlal v10.4s, v2.4h, v0.h[2]\n" 1842 "umlal v11.4s, v2.4h, v0.h[3]\n" 1843 "umlal v12.4s, v2.4h, v1.h[0]\n" 1844 "umlal v13.4s, v2.4h, v1.h[1]\n" 1845 "umlal v14.4s, v2.4h, v1.h[2]\n" 1846 "umlal v15.4s, v2.4h, v1.h[3]\n" 1847 "umlal2 v8.4s, v2.8h, v0.h[4]\n" 1848 "umlal2 v9.4s, v2.8h, v0.h[5]\n" 1849 "umlal2 v10.4s, v2.8h, v0.h[6]\n" 1850 "umlal2 v11.4s, v2.8h, v0.h[7]\n" 1851 "umlal2 v12.4s, v2.8h, v1.h[4]\n" 1852 "umlal2 v13.4s, v2.8h, v1.h[5]\n" 1853 "umlal2 v14.4s, v2.8h, v1.h[6]\n" 1854 "umlal2 v15.4s, v2.8h, v1.h[7]\n" 1855 "ld1 {v2.8b}, [%[lhs_ptr]], #8\n" 1856 1857 // Multiply-accumulate, middle third 1858 "umlal v16.4s, v3.4h, v0.h[0]\n" 1859 "umlal v17.4s, v3.4h, v0.h[1]\n" 1860 "umlal v18.4s, v3.4h, v0.h[2]\n" 1861 "umlal v19.4s, v3.4h, v0.h[3]\n" 1862 "umlal v20.4s, v3.4h, v1.h[0]\n" 1863 "umlal v21.4s, v3.4h, v1.h[1]\n" 1864 "umlal v22.4s, v3.4h, v1.h[2]\n" 1865 "umlal v23.4s, v3.4h, v1.h[3]\n" 1866 "umlal2 v16.4s, v3.8h, v0.h[4]\n" 1867 "umlal2 v17.4s, v3.8h, v0.h[5]\n" 1868 "umlal2 v18.4s, v3.8h, v0.h[6]\n" 1869 "umlal2 v19.4s, v3.8h, v0.h[7]\n" 1870 "umlal2 v20.4s, v3.8h, v1.h[4]\n" 1871 "umlal2 v21.4s, v3.8h, v1.h[5]\n" 1872 "umlal2 v22.4s, v3.8h, v1.h[6]\n" 1873 "umlal2 v23.4s, v3.8h, v1.h[7]\n" 1874 "ld1 {v3.8b}, [%[lhs_ptr]], #8\n" 1875 1876 "subs %w[depth], %w[depth], #2\n" 1877 1878 // Multiply-accumulate, bottom third 1879 "umlal v24.4s, v4.4h, v0.h[0]\n" 1880 "umlal v25.4s, v4.4h, v0.h[1]\n" 1881 "umlal v26.4s, v4.4h, v0.h[2]\n" 1882 "umlal v27.4s, v4.4h, v0.h[3]\n" 1883 "umlal v28.4s, v4.4h, v1.h[0]\n" 1884 "umlal v29.4s, v4.4h, v1.h[1]\n" 1885 "umlal v30.4s, v4.4h, v1.h[2]\n" 1886 "umlal v31.4s, v4.4h, v1.h[3]\n" 1887 "umlal2 v24.4s, v4.8h, v0.h[4]\n" 1888 "umlal2 v25.4s, v4.8h, v0.h[5]\n" 1889 "umlal2 v26.4s, v4.8h, v0.h[6]\n" 1890 "umlal2 v27.4s, v4.8h, v0.h[7]\n" 1891 "umlal2 v28.4s, v4.8h, v1.h[4]\n" 1892 "umlal2 v29.4s, v4.8h, v1.h[5]\n" 1893 "umlal2 v30.4s, v4.8h, v1.h[6]\n" 1894 "umlal2 v31.4s, v4.8h, v1.h[7]\n" 1895 "ld1 {v4.8b}, [%[lhs_ptr]], #8\n" 1896 1897 "bne " GEMMLOWP_LABEL_LOOP "b\n" 1898 1899 GEMMLOWP_LABEL_AFTER_LOOP 1900 ":\n" 1901 1902 // Expand Lhs/Rhs cells to 16 bit. 1903 "uxtl v0.8h, v5.8b\n" 1904 "uxtl v1.8h, v6.8b\n" 1905 "uxtl v2.8h, v2.8b\n" 1906 "uxtl v3.8h, v3.8b\n" 1907 "uxtl v4.8h, v4.8b\n" 1908 1909 // Multiply-accumulate, level of depth 0 1910 "umlal v8.4s, v2.4h, v0.h[0]\n" 1911 "umlal v9.4s, v2.4h, v0.h[1]\n" 1912 "umlal v10.4s, v2.4h, v0.h[2]\n" 1913 "umlal v11.4s, v2.4h, v0.h[3]\n" 1914 "umlal v12.4s, v2.4h, v1.h[0]\n" 1915 "umlal v13.4s, v2.4h, v1.h[1]\n" 1916 "umlal v14.4s, v2.4h, v1.h[2]\n" 1917 "umlal v15.4s, v2.4h, v1.h[3]\n" 1918 "umlal v16.4s, v3.4h, v0.h[0]\n" 1919 "umlal v17.4s, v3.4h, v0.h[1]\n" 1920 "umlal v18.4s, v3.4h, v0.h[2]\n" 1921 "umlal v19.4s, v3.4h, v0.h[3]\n" 1922 "umlal v20.4s, v3.4h, v1.h[0]\n" 1923 "umlal v21.4s, v3.4h, v1.h[1]\n" 1924 "umlal v22.4s, v3.4h, v1.h[2]\n" 1925 "umlal v23.4s, v3.4h, v1.h[3]\n" 1926 "umlal v24.4s, v4.4h, v0.h[0]\n" 1927 "umlal v25.4s, v4.4h, v0.h[1]\n" 1928 "umlal v26.4s, v4.4h, v0.h[2]\n" 1929 "umlal v27.4s, v4.4h, v0.h[3]\n" 1930 "umlal v28.4s, v4.4h, v1.h[0]\n" 1931 "umlal v29.4s, v4.4h, v1.h[1]\n" 1932 "umlal v30.4s, v4.4h, v1.h[2]\n" 1933 "umlal v31.4s, v4.4h, v1.h[3]\n" 1934 1935 // Multiply-accumulate, level of depth 1 1936 "umlal2 v8.4s, v2.8h, v0.h[4]\n" 1937 "umlal2 v9.4s, v2.8h, v0.h[5]\n" 1938 "umlal2 v10.4s, v2.8h, v0.h[6]\n" 1939 "umlal2 v11.4s, v2.8h, v0.h[7]\n" 1940 "umlal2 v12.4s, v2.8h, v1.h[4]\n" 1941 "umlal2 v13.4s, v2.8h, v1.h[5]\n" 1942 "umlal2 v14.4s, v2.8h, v1.h[6]\n" 1943 "umlal2 v15.4s, v2.8h, v1.h[7]\n" 1944 "umlal2 v16.4s, v3.8h, v0.h[4]\n" 1945 "umlal2 v17.4s, v3.8h, v0.h[5]\n" 1946 "umlal2 v18.4s, v3.8h, v0.h[6]\n" 1947 "umlal2 v19.4s, v3.8h, v0.h[7]\n" 1948 "umlal2 v20.4s, v3.8h, v1.h[4]\n" 1949 "umlal2 v21.4s, v3.8h, v1.h[5]\n" 1950 "umlal2 v22.4s, v3.8h, v1.h[6]\n" 1951 "umlal2 v23.4s, v3.8h, v1.h[7]\n" 1952 "umlal2 v24.4s, v4.8h, v0.h[4]\n" 1953 "umlal2 v25.4s, v4.8h, v0.h[5]\n" 1954 "umlal2 v26.4s, v4.8h, v0.h[6]\n" 1955 "umlal2 v27.4s, v4.8h, v0.h[7]\n" 1956 "umlal2 v28.4s, v4.8h, v1.h[4]\n" 1957 "umlal2 v29.4s, v4.8h, v1.h[5]\n" 1958 "umlal2 v30.4s, v4.8h, v1.h[6]\n" 1959 "umlal2 v31.4s, v4.8h, v1.h[7]\n" 1960 1961 // Store accumulators 1962 "mov x0, %[accum_ptr]\n" 1963 "st1 {v8.16b}, [x0], #16\n" 1964 "st1 {v16.16b}, [x0], #16\n" 1965 "st1 {v24.16b}, [x0], #16\n" 1966 "st1 {v9.16b}, [x0], #16\n" 1967 "st1 {v17.16b}, [x0], #16\n" 1968 "st1 {v25.16b}, [x0], #16\n" 1969 "st1 {v10.16b}, [x0], #16\n" 1970 "st1 {v18.16b}, [x0], #16\n" 1971 "st1 {v26.16b}, [x0], #16\n" 1972 "st1 {v11.16b}, [x0], #16\n" 1973 "st1 {v19.16b}, [x0], #16\n" 1974 "st1 {v27.16b}, [x0], #16\n" 1975 "st1 {v12.16b}, [x0], #16\n" 1976 "st1 {v20.16b}, [x0], #16\n" 1977 "st1 {v28.16b}, [x0], #16\n" 1978 "st1 {v13.16b}, [x0], #16\n" 1979 "st1 {v21.16b}, [x0], #16\n" 1980 "st1 {v29.16b}, [x0], #16\n" 1981 "st1 {v14.16b}, [x0], #16\n" 1982 "st1 {v22.16b}, [x0], #16\n" 1983 "st1 {v30.16b}, [x0], #16\n" 1984 "st1 {v15.16b}, [x0], #16\n" 1985 "st1 {v23.16b}, [x0], #16\n" 1986 "st1 {v31.16b}, [x0], #16\n" 1987 : // outputs 1988 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 1989 [depth] "+r"(depth) 1990 : // inputs 1991 [accum_ptr] "r"(accum_ptr) 1992 : // clobbers 1993 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", 1994 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", 1995 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", 1996 "v28", "v29", "v30", "v31"); 1997 } 1998 }; 1999 2000 // Faster kernel by ARM. Not expanding operands before multiplication. 2001 // Tuned for A57. Compare to 2002 // NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand 2003 struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand_A57 { 2004 typedef std::uint8_t OperandType; 2005 typedef std::uint32_t AccumulatorType; 2006 typedef KernelFormat< 2007 KernelSideFormat<CellFormat<5, 16, CellOrder::WidthMajor>, 1>, 2008 KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> > 2009 Format; 2010 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 2011 AccumulatorType* accum_ptr, int depth) { 2012 static const int kLhsWidth = Format::Lhs::kWidth; 2013 static const int kRhsWidth = Format::Rhs::kWidth; 2014 AccumulatorType rowmajor_accumulator_buffer[kLhsWidth * kRhsWidth]; 2015 asm volatile( 2016 // Clear aggregators 2017 "dup v12.4s, wzr\n" 2018 "dup v13.4s, wzr\n" 2019 "dup v14.4s, wzr\n" 2020 "dup v15.4s, wzr\n" 2021 "dup v16.4s, wzr\n" 2022 "dup v17.4s, wzr\n" 2023 "dup v18.4s, wzr\n" 2024 "dup v19.4s, wzr\n" 2025 "dup v20.4s, wzr\n" 2026 "dup v21.4s, wzr\n" 2027 "dup v22.4s, wzr\n" 2028 "dup v23.4s, wzr\n" 2029 "dup v24.4s, wzr\n" 2030 "dup v25.4s, wzr\n" 2031 "dup v26.4s, wzr\n" 2032 "dup v27.4s, wzr\n" 2033 "dup v28.4s, wzr\n" 2034 "dup v29.4s, wzr\n" 2035 "dup v30.4s, wzr\n" 2036 "dup v31.4s, wzr\n" 2037 2038 GEMMLOWP_LABEL_LOOP 2039 ":\n" 2040 2041 // Overview of register layout: 2042 // 2043 // A 4x16 block of Rhs is stored in 8 bit in v0--v3. 2044 // A 5x16 block of Lhs is cycled through v4 and v5 in 8 bit. 2045 // 2046 // A 4x5 block of aggregators is stored in v12-v31 (as 4x32 bit 2047 // components which would need to be added at the end) 2048 // 2049 // The Lhs vectors are multiplied by the Rhs vectors with a widening 2050 // multiply to produce an intermediate result which is stored in 2051 // v6-v11. Each intermediate result is 8x16 bits so this happens 2052 // twice for each Lhs/Rhs combination (once with UMULL for elements 2053 // 0-7 and once with UMULL2 for elements 8-15). 2054 // 2055 // UADALP is used to accumulate these intermediate results into the 2056 // result aggregators. 2057 // 2058 // 2059 // 2060 // +--------+--------+--------+--------+ 2061 // |v0.b[0] |v1.b[0] |v2.b[0] |v3.b[0] | 2062 // Rhs +--------+--------+--------+--------+ 2063 // | ... | ... | ... | ... | 2064 // +--------+--------+--------+--------| 2065 // |v0.b[15]|v1.b[15]|v2.b[15]|v3.b[15]| 2066 // +--------+--------+--------+--------+ 2067 // 2068 // | | | | | 2069 // 2070 // Lhs | | | | | 2071 // 2072 // +-------+-----+--------+ - - +--------+--------+--------+--------+ 2073 // |v4.b[0]| ... |v4.b[15]| | v12.4s | v13.4s | v14.4s | v15.4s | 2074 // |v5.b[0]| ... |v5.b[15]| | v16.4s | v17.4s | v18.4s | v19.4s | 2075 // |v4.b[0]| ... |v4.b[15]| | v20.4s | v21.4s | v22.4s | v23.4s | 2076 // |v5.b[0]| ... |v5.b[15]| | v24.4s | v25.4s | v26.4s | v27.4s | 2077 // |v4.b[0]| ... |v4.b[15]| | v28.4s | v29.4s | v30.4s | v31.4s | 2078 // +-------+--------------+ - - +--------+--------+--------+--------+ 2079 // 2080 // Accumulator 2081 // 2082 // 2083 // Further possible optimisations (not tried): 2084 // - Move early loads into previous iteration (see Float32_WithScalar 2085 // for example). - Unroll loop 2x to alternate more smoothly between 2086 // v4 and v5. - A different number of temporary registers might work 2087 // better. - Pairing umull with corresponding umull2 might allow 2088 // better 2089 // register loading (e.g. at the start of the loop) 2090 // - Interleaving umull{2} and uadalp even more aggressively might 2091 // help, (not sure about latency vs. dispatch rate). 2092 // 2093 // 2094 // Start loading Rhs - further loads are interleaved amongst the 2095 // multiplies for better dispatch on A57. 2096 "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" 2097 2098 // Load first Lhs vector - further loads are interleaved amongst the 2099 // multiplies 2100 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" 2101 2102 "umull v6.8h, v0.8b, v4.8b\n" 2103 "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // 2nd RHS element 2104 "umull v7.8h, v1.8b, v4.8b\n" 2105 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" // 3rd RHS element 2106 "umull v8.8h, v2.8b, v4.8b\n" 2107 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" // 4th RHS element 2108 "umull v9.8h, v3.8b, v4.8b\n" 2109 "umull2 v10.8h, v0.16b, v4.16b\n" 2110 "umull2 v11.8h, v1.16b, v4.16b\n" 2111 "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" // 2nd LHS element 2112 2113 "uadalp v12.4s, v6.8h\n" 2114 "umull2 v6.8h, v2.16b, v4.16b\n" 2115 "uadalp v13.4s, v7.8h\n" 2116 "umull2 v7.8h, v3.16b, v4.16b\n" 2117 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // 1st LHS element done - Reuse v4 2118 // for 3rd LHS element 2119 "uadalp v14.4s, v8.8h\n" 2120 "umull v8.8h, v0.8b, v5.8b\n" 2121 "uadalp v15.4s, v9.8h\n" 2122 "umull v9.8h, v1.8b, v5.8b\n" 2123 "uadalp v12.4s, v10.8h\n" 2124 "umull v10.8h, v2.8b, v5.8b\n" 2125 "uadalp v13.4s, v11.8h\n" 2126 "umull v11.8h, v3.8b, v5.8b\n" 2127 2128 "uadalp v14.4s, v6.8h\n" 2129 "umull2 v6.8h, v0.16b, v5.16b\n" 2130 "uadalp v15.4s, v7.8h\n" 2131 "umull2 v7.8h, v1.16b, v5.16b\n" 2132 "uadalp v16.4s, v8.8h\n" 2133 "umull2 v8.8h, v2.16b, v5.16b\n" 2134 "uadalp v17.4s, v9.8h\n" 2135 "umull2 v9.8h, v3.16b, v5.16b\n" 2136 "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" // 2nd LHS element done - Reuse v5 2137 // for 4th LHS element 2138 "uadalp v18.4s, v10.8h\n" 2139 "umull v10.8h, v0.8b, v4.8b\n" 2140 "uadalp v19.4s, v11.8h\n" 2141 "umull v11.8h, v1.8b, v4.8b\n" 2142 2143 "uadalp v16.4s, v6.8h\n" 2144 "umull v6.8h, v2.8b, v4.8b\n" 2145 "uadalp v17.4s, v7.8h\n" 2146 "umull v7.8h, v3.8b, v4.8b\n" 2147 "uadalp v18.4s, v8.8h\n" 2148 "umull2 v8.8h, v0.16b, v4.16b\n" 2149 "uadalp v19.4s, v9.8h\n" 2150 "umull2 v9.8h, v1.16b, v4.16b\n" 2151 "uadalp v20.4s, v10.8h\n" 2152 "umull2 v10.8h, v2.16b, v4.16b\n" 2153 "uadalp v21.4s, v11.8h\n" 2154 "umull2 v11.8h, v3.16b, v4.16b\n" 2155 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // 3rd LHS element done - Reuse v4 2156 // for 5th LHS element 2157 2158 "uadalp v22.4s, v6.8h\n" 2159 "umull v6.8h, v0.8b, v5.8b\n" 2160 "uadalp v23.4s, v7.8h\n" 2161 "umull v7.8h, v1.8b, v5.8b\n" 2162 "uadalp v20.4s, v8.8h\n" 2163 "umull v8.8h, v2.8b, v5.8b\n" 2164 "uadalp v21.4s, v9.8h\n" 2165 "umull v9.8h, v3.8b, v5.8b\n" 2166 "uadalp v22.4s, v10.8h\n" 2167 "umull2 v10.8h, v0.16b, v5.16b\n" 2168 "uadalp v23.4s, v11.8h\n" 2169 "umull2 v11.8h, v1.16b, v5.16b\n" 2170 2171 "uadalp v24.4s, v6.8h\n" 2172 "umull2 v6.8h, v2.16b, v5.16b\n" 2173 "uadalp v25.4s, v7.8h\n" 2174 "umull2 v7.8h, v3.16b, v5.16b\n" 2175 "uadalp v26.4s, v8.8h\n" 2176 "umull v8.8h, v0.8b, v4.8b\n" 2177 "uadalp v27.4s, v9.8h\n" 2178 "umull v9.8h, v1.8b, v4.8b\n" 2179 "uadalp v24.4s, v10.8h\n" 2180 "umull v10.8h, v2.8b, v4.8b\n" 2181 "uadalp v25.4s, v11.8h\n" 2182 "umull v11.8h, v3.8b, v4.8b\n" 2183 2184 "uadalp v26.4s, v6.8h\n" 2185 "umull2 v6.8h, v0.16b, v4.16b\n" 2186 "uadalp v27.4s, v7.8h\n" 2187 "umull2 v7.8h, v1.16b, v4.16b\n" 2188 "uadalp v28.4s, v8.8h\n" 2189 "umull2 v8.8h, v2.16b, v4.16b\n" 2190 "uadalp v29.4s, v9.8h\n" 2191 "umull2 v9.8h, v3.16b, v4.16b\n" 2192 "uadalp v30.4s, v10.8h\n" 2193 "uadalp v31.4s, v11.8h\n" 2194 2195 "uadalp v28.4s, v6.8h\n" 2196 "uadalp v29.4s, v7.8h\n" 2197 // Loop. Decrement loop index (depth) by 16, since we just handled 2198 // 16 levels of depth. Do this subs a bit before the end of the loop 2199 // for better dispatch on A57. 2200 "subs %w[depth], %w[depth], #16\n" 2201 "uadalp v30.4s, v8.8h\n" 2202 "uadalp v31.4s, v9.8h\n" 2203 2204 "bne " GEMMLOWP_LABEL_LOOP 2205 "b\n" 2206 2207 // Reduce aggregators horizontally 2208 "addp v0.4s, v12.4s, v13.4s\n" 2209 "addp v1.4s, v14.4s, v15.4s\n" 2210 "addp v2.4s, v16.4s, v17.4s\n" 2211 "addp v3.4s, v18.4s, v19.4s\n" 2212 "addp v4.4s, v20.4s, v21.4s\n" 2213 "addp v5.4s, v22.4s, v23.4s\n" 2214 "addp v6.4s, v24.4s, v25.4s\n" 2215 "addp v7.4s, v26.4s, v27.4s\n" 2216 "addp v8.4s, v28.4s, v29.4s\n" 2217 "addp v9.4s, v30.4s, v31.4s\n" 2218 2219 "addp v10.4s, v0.4s, v1.4s\n" 2220 "addp v11.4s, v2.4s, v3.4s\n" 2221 "addp v12.4s, v4.4s, v5.4s\n" 2222 "addp v13.4s, v6.4s, v7.4s\n" 2223 "addp v14.4s, v8.4s, v9.4s\n" 2224 2225 "mov x0, %[rowmajor_accumulator_buffer]\n" 2226 "st1 {v10.16b}, [x0], #16\n" 2227 "st1 {v11.16b}, [x0], #16\n" 2228 "st1 {v12.16b}, [x0], #16\n" 2229 "st1 {v13.16b}, [x0], #16\n" 2230 "st1 {v14.16b}, [x0], #16\n" 2231 : // outputs 2232 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 2233 [depth] "+r"(depth) 2234 : // inputs 2235 [rowmajor_accumulator_buffer] "r"(rowmajor_accumulator_buffer) 2236 : // clobbers 2237 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", 2238 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", 2239 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", 2240 "v28", "v29", "v30", "v31"); 2241 2242 // accumulate row-major accumulators into global (column-major) accumulators 2243 for (int l = 0; l < kLhsWidth; l++) { 2244 for (int r = 0; r < kRhsWidth; r++) { 2245 accum_ptr[l + kLhsWidth * r] += 2246 rowmajor_accumulator_buffer[r + l * kRhsWidth]; 2247 } 2248 } 2249 } 2250 }; 2251 2252 // Fast kernel operating on int8 operands. 2253 // It is assumed that one of the two int8 operands only takes values 2254 // in [-127, 127], while the other may freely range in [-128, 127]. 2255 // The issue with both operands taking the value -128 is that: 2256 // -128*-128 + -128*-128 == -32768 overflows int16. 2257 // Every other expression a*b + c*d, for any int8 a,b,c,d, fits in int16 2258 // range. That is the basic idea of this kernel. 2259 struct NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits { 2260 typedef std::int8_t OperandType; 2261 typedef std::int32_t AccumulatorType; 2262 typedef KernelFormat< 2263 KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>, 2264 KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1> > 2265 Format; 2266 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 2267 AccumulatorType* accum_ptr, int depth) { 2268 std::size_t start_depth = 123; 2269 std::size_t run_depth = depth; 2270 std::size_t dst_col_stride = 4; 2271 AccumulatorType* dst_ptr = accum_ptr; 2272 asm volatile( 2273 // Overview of register layout: 2274 // 2275 // A 4x16 block of Rhs is stored in 8 bit in v0--v3. 2276 // A 4x16 block of Lhs is stored in 8 bit in v4--v7. 2277 // 2278 // A 4x4 block of accumulators is stored in v16-v31 (as 4x32 bit 2279 // components which need to be horizontally-added at the end) 2280 // 2281 // The Lhs vectors are multiplied by the Rhs vectors with a widening 2282 // multiply over the 8 first levels of depth, producing int16x8 2283 // vectors of products for each position in the accumulator matrix. 2284 // Here comes the special trick: since the operands are signed int8, 2285 // their range being [ -2^7 , 2^7 ), their products are in range 2286 // [ -2^14 , 2^14 - 1 ), meaning that we can add two such values 2287 // without any risk of overflowing int16. 2288 // We thus proceed with the 8 next levels of depth, multiplying 2289 // again Lhs by Rhs, accumulating into this existing int16x8 vector. 2290 // 2291 // Only then, having processed 16 levels of depth, do we need to 2292 // horizontally add these int16x8 accumulators into the final 2293 // int32x4 accumulators. 2294 // 2295 // As we do not have enough registers to store all 16 int16x8 2296 // temporary-16bit-accumulators, we have them cycle through v8--v15. 2297 // 2298 // 2299 // Register layout (ignoring the v8--v15 temporary 16bit accumulators): 2300 // 2301 // +--------+--------+--------+--------+ 2302 // |v0.b[0] |v1.b[0] |v2.b[0] |v3.b[0] | 2303 // Rhs +--------+--------+--------+--------+ 2304 // | ... | ... | ... | ... | 2305 // +--------+--------+--------+--------| 2306 // |v0.b[15]|v1.b[15]|v2.b[15]|v3.b[15]| 2307 // +--------+--------+--------+--------+ 2308 // 2309 // | | | | | 2310 // 2311 // Lhs | | | | | 2312 // 2313 // +-------+-----+--------+ - - +--------+--------+--------+--------+ 2314 // |v4.b[0]| ... |v4.b[15]| | v16.4s | v17.4s | v18.4s | v19.4s | 2315 // |v5.b[0]| ... |v5.b[15]| | v20.4s | v21.4s | v22.4s | v23.4s | 2316 // |v6.b[0]| ... |v6.b[15]| | v24.4s | v25.4s | v26.4s | v27.4s | 2317 // |v7.b[0]| ... |v7.b[15]| | v28.4s | v29.4s | v30.4s | v31.4s | 2318 // +-------+--------------+ - - +--------+--------+--------+--------+ 2319 // 2320 // Accumulator 2321 // 2322 2323 // Clear accumulators 2324 "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" 2325 "dup v16.4s, wzr\n" 2326 "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" 2327 "dup v17.4s, wzr\n" 2328 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" 2329 "dup v18.4s, wzr\n" 2330 "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" 2331 "dup v19.4s, wzr\n" 2332 "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" 2333 "dup v20.4s, wzr\n" 2334 "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" 2335 "dup v21.4s, wzr\n" 2336 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" 2337 "dup v22.4s, wzr\n" 2338 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" 2339 "dup v23.4s, wzr\n" 2340 "subs %[run_depth], %[run_depth], #16\n" 2341 "dup v24.4s, wzr\n" 2342 "mov x0, %[dst_ptr]\n" 2343 "dup v25.4s, wzr\n" 2344 "dup v26.4s, wzr\n" 2345 "dup v27.4s, wzr\n" 2346 "dup v28.4s, wzr\n" 2347 "dup v29.4s, wzr\n" 2348 "dup v30.4s, wzr\n" 2349 "dup v31.4s, wzr\n" 2350 2351 "smull v12.8h, v0.8b, v4.8b\n" 2352 "smull v13.8h, v1.8b, v4.8b\n" 2353 "smull v14.8h, v0.8b, v5.8b\n" 2354 "smull v15.8h, v1.8b, v5.8b\n" 2355 "smlal2 v12.8h, v0.16b, v4.16b\n" 2356 "smlal2 v13.8h, v1.16b, v4.16b\n" 2357 "smlal2 v14.8h, v0.16b, v5.16b\n" 2358 "smlal2 v15.8h, v1.16b, v5.16b\n" 2359 2360 "beq " GEMMLOWP_LABEL_AFTER_LOOP "f\n" 2361 2362 GEMMLOWP_LABEL_LOOP 2363 ":\n" 2364 2365 "subs %[run_depth], %[run_depth], #16\n" 2366 2367 "sadalp v16.4s, v12.8h\n" 2368 "smull v12.8h, v0.8b, v6.8b\n" 2369 "sadalp v17.4s, v13.8h\n" 2370 "smull v13.8h, v0.8b, v7.8b\n" 2371 "sadalp v20.4s, v14.8h\n" 2372 "smull v14.8h, v1.8b, v6.8b\n" 2373 "sadalp v21.4s, v15.8h\n" 2374 "smull v15.8h, v1.8b, v7.8b\n" 2375 "smlal2 v12.8h, v0.16b, v6.16b\n" 2376 "smlal2 v13.8h, v0.16b, v7.16b\n" 2377 "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" 2378 "smlal2 v14.8h, v1.16b, v6.16b\n" 2379 "smlal2 v15.8h, v1.16b, v7.16b\n" 2380 "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" 2381 "sadalp v24.4s, v12.8h\n" 2382 "smull v12.8h, v2.8b, v4.8b\n" 2383 "sadalp v28.4s, v13.8h\n" 2384 "smull v13.8h, v3.8b, v4.8b\n" 2385 "sadalp v25.4s, v14.8h\n" 2386 "smull v14.8h, v2.8b, v5.8b\n" 2387 "sadalp v29.4s, v15.8h\n" 2388 "smull v15.8h, v3.8b, v5.8b\n" 2389 "smlal2 v12.8h, v2.16b, v4.16b\n" 2390 "smlal2 v13.8h, v3.16b, v4.16b\n" 2391 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" 2392 "smlal2 v14.8h, v2.16b, v5.16b\n" 2393 "smlal2 v15.8h, v3.16b, v5.16b\n" 2394 "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" 2395 "sadalp v18.4s, v12.8h\n" 2396 "smull v12.8h, v2.8b, v6.8b\n" 2397 "sadalp v19.4s, v13.8h\n" 2398 "smull v13.8h, v2.8b, v7.8b\n" 2399 "sadalp v22.4s, v14.8h\n" 2400 "smull v14.8h, v3.8b, v6.8b\n" 2401 "sadalp v23.4s, v15.8h\n" 2402 "smull v15.8h, v3.8b, v7.8b\n" 2403 "smlal2 v12.8h, v2.16b, v6.16b\n" 2404 "smlal2 v13.8h, v2.16b, v7.16b\n" 2405 "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" 2406 "smlal2 v14.8h, v3.16b, v6.16b\n" 2407 "ld1 {v6.16b}, [%[lhs_ptr]], #16\n" 2408 "smlal2 v15.8h, v3.16b, v7.16b\n" 2409 "ld1 {v7.16b}, [%[lhs_ptr]], #16\n" 2410 "sadalp v26.4s, v12.8h\n" 2411 "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" 2412 "smull v12.8h, v0.8b, v4.8b\n" 2413 "sadalp v30.4s, v13.8h\n" 2414 "smull v13.8h, v1.8b, v4.8b\n" 2415 "sadalp v27.4s, v14.8h\n" 2416 "smull v14.8h, v0.8b, v5.8b\n" 2417 "sadalp v31.4s, v15.8h\n" 2418 "smull v15.8h, v1.8b, v5.8b\n" 2419 "smlal2 v12.8h, v0.16b, v4.16b\n" 2420 "smlal2 v13.8h, v1.16b, v4.16b\n" 2421 "smlal2 v14.8h, v0.16b, v5.16b\n" 2422 "smlal2 v15.8h, v1.16b, v5.16b\n" 2423 2424 "bne " GEMMLOWP_LABEL_LOOP "b\n" 2425 2426 GEMMLOWP_LABEL_AFTER_LOOP 2427 ":\n" 2428 2429 // Load accumulators from memory 2430 "ld1 {v8.16b}, [x0], #16\n" 2431 "ld1 {v9.16b}, [x0], #16\n" 2432 "ld1 {v10.16b}, [x0], #16\n" 2433 "ld1 {v11.16b}, [x0], #16\n" 2434 "mov x0, %[dst_ptr]\n" 2435 2436 // Do the remaining arithmetic for the 16 last levels of depths. 2437 // All the operands are already loaded. 2438 "sadalp v16.4s, v12.8h\n" 2439 "smull v12.8h, v0.8b, v6.8b\n" 2440 "sadalp v17.4s, v13.8h\n" 2441 "smull v13.8h, v0.8b, v7.8b\n" 2442 "sadalp v20.4s, v14.8h\n" 2443 "smull v14.8h, v1.8b, v6.8b\n" 2444 "sadalp v21.4s, v15.8h\n" 2445 "smull v15.8h, v1.8b, v7.8b\n" 2446 "smlal2 v12.8h, v0.16b, v6.16b\n" 2447 "smlal2 v13.8h, v0.16b, v7.16b\n" 2448 "smlal2 v14.8h, v1.16b, v6.16b\n" 2449 "smlal2 v15.8h, v1.16b, v7.16b\n" 2450 "sadalp v24.4s, v12.8h\n" 2451 "smull v12.8h, v2.8b, v4.8b\n" 2452 "sadalp v28.4s, v13.8h\n" 2453 "smull v13.8h, v3.8b, v4.8b\n" 2454 "sadalp v25.4s, v14.8h\n" 2455 "smull v14.8h, v2.8b, v5.8b\n" 2456 "sadalp v29.4s, v15.8h\n" 2457 "smull v15.8h, v3.8b, v5.8b\n" 2458 "smlal2 v12.8h, v2.16b, v4.16b\n" 2459 "smlal2 v13.8h, v3.16b, v4.16b\n" 2460 "smlal2 v14.8h, v2.16b, v5.16b\n" 2461 "smlal2 v15.8h, v3.16b, v5.16b\n" 2462 "sadalp v18.4s, v12.8h\n" 2463 "smull v12.8h, v2.8b, v6.8b\n" 2464 "sadalp v19.4s, v13.8h\n" 2465 "smull v13.8h, v2.8b, v7.8b\n" 2466 "sadalp v22.4s, v14.8h\n" 2467 "smull v14.8h, v3.8b, v6.8b\n" 2468 "sadalp v23.4s, v15.8h\n" 2469 "smull v15.8h, v3.8b, v7.8b\n" 2470 "smlal2 v12.8h, v2.16b, v6.16b\n" 2471 "smlal2 v13.8h, v2.16b, v7.16b\n" 2472 "smlal2 v14.8h, v3.16b, v6.16b\n" 2473 "smlal2 v15.8h, v3.16b, v7.16b\n" 2474 "sadalp v26.4s, v12.8h\n" 2475 "sadalp v30.4s, v13.8h\n" 2476 "sadalp v27.4s, v14.8h\n" 2477 "sadalp v31.4s, v15.8h\n" 2478 2479 // Reduce aggregators horizontally 2480 "addp v0.4s, v16.4s, v20.4s\n" 2481 "addp v1.4s, v17.4s, v21.4s\n" 2482 "addp v2.4s, v18.4s, v22.4s\n" 2483 "addp v3.4s, v19.4s, v23.4s\n" 2484 "addp v4.4s, v24.4s, v28.4s\n" 2485 "addp v5.4s, v25.4s, v29.4s\n" 2486 "addp v6.4s, v26.4s, v30.4s\n" 2487 "addp v7.4s, v27.4s, v31.4s\n" 2488 2489 "addp v12.4s, v0.4s, v4.4s\n" 2490 "addp v13.4s, v1.4s, v5.4s\n" 2491 "addp v14.4s, v2.4s, v6.4s\n" 2492 "addp v15.4s, v3.4s, v7.4s\n" 2493 2494 // Add to the accumulators loaded from memory 2495 "add v8.4s, v8.4s, v12.4s\n" 2496 "add v9.4s, v9.4s, v13.4s\n" 2497 "add v10.4s, v10.4s, v14.4s\n" 2498 "add v11.4s, v11.4s, v15.4s\n" 2499 2500 // Store accumulators back to memory 2501 "st1 {v8.16b}, [x0], #16\n" 2502 "st1 {v9.16b}, [x0], #16\n" 2503 "st1 {v10.16b}, [x0], #16\n" 2504 "st1 {v11.16b}, [x0], #16\n" 2505 : // outputs 2506 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 2507 [dst_ptr] "+r"(dst_ptr), [run_depth] "+r"(run_depth), 2508 [dst_col_stride] "+r"(dst_col_stride) 2509 : // inputs 2510 [start_depth] "r"(start_depth) 2511 : // clobbers 2512 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", 2513 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", 2514 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", 2515 "v28", "v29", "v30", "v31"); 2516 } 2517 }; 2518 2519 #ifdef __ARM_FEATURE_DOTPROD 2520 // Kernels utilizing the Armv8.2 Dot Product extension. 2521 // 2522 // The dot product instructions work by taking 4 consecutive 8-bit depth 2523 // values from each operand, multiplying the 4 pairs together and 2524 // accumulating all the results into the corresponding 32-bit accumulator 2525 // lane. As such, the operation is identical to a 32-bit instruction (like 2526 // FMLA used in SGEMM), except that 4 depth values are processed at a time 2527 // instead of 1. 2528 2529 // Thus, this first kernel is a carbon copy of 2530 // "NEON_64bit_GEMM_Float32_WithScalar_A57" (which should provide good 2531 // performance for most processors) below with the opcode (fmla -> udot) and 2532 // types (float32 -> uint8/uint32) changed. 2533 // 2534 // A signed version of this kernel could be produced by replacing "udot" 2535 // with "sdot" - performance should be identical to this udot kernel. 2536 struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct { 2537 typedef std::uint8_t OperandType; 2538 typedef std::uint32_t AccumulatorType; 2539 typedef KernelFormat< 2540 KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>, 2541 KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> > 2542 Format; 2543 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 2544 AccumulatorType* accum_ptr, int depth) { 2545 asm volatile( 2546 // Load accumulators 2547 "mov x0, %[accum_ptr]\n" 2548 "ld1 {v8.4s}, [x0], #16\n" 2549 "ld1 {v16.4s}, [x0], #16\n" 2550 "ld1 {v24.4s}, [x0], #16\n" 2551 "ld1 {v9.4s}, [x0], #16\n" 2552 "ld1 {v17.4s}, [x0], #16\n" 2553 "ld1 {v25.4s}, [x0], #16\n" 2554 "ld1 {v10.4s}, [x0], #16\n" 2555 "ld1 {v18.4s}, [x0], #16\n" 2556 "ld1 {v26.4s}, [x0], #16\n" 2557 "ld1 {v11.4s}, [x0], #16\n" 2558 "ld1 {v19.4s}, [x0], #16\n" 2559 "ld1 {v27.4s}, [x0], #16\n" 2560 "ld1 {v12.4s}, [x0], #16\n" 2561 "ld1 {v20.4s}, [x0], #16\n" 2562 "ld1 {v28.4s}, [x0], #16\n" 2563 "ld1 {v13.4s}, [x0], #16\n" 2564 "ld1 {v21.4s}, [x0], #16\n" 2565 "ld1 {v29.4s}, [x0], #16\n" 2566 "ld1 {v14.4s}, [x0], #16\n" 2567 "ld1 {v22.4s}, [x0], #16\n" 2568 "ld1 {v30.4s}, [x0], #16\n" 2569 "ld1 {v15.4s}, [x0], #16\n" 2570 "ld1 {v23.4s}, [x0], #16\n" 2571 "ld1 {v31.4s}, [x0], #16\n" 2572 2573 // The start of the loop assumes first Rhs cell is already loaded, so 2574 // do it here for first iteration. 2575 "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" 2576 2577 // And the same for the first Lhs cell. 2578 "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" 2579 2580 GEMMLOWP_LABEL_LOOP 2581 ":\n" 2582 2583 // Start the MACs at the head of the loop - 1st cell from each side 2584 // already loaded. 2585 "udot v8.4s, v2.16b, v0.b[0]\n" 2586 "udot v9.4s, v2.16b, v0.b[1]\n" 2587 "ld1 {v1.16b}, [%[rhs_ptr]], #16\n" // Load second Rhs cell. 2588 "udot v10.4s, v2.16b, v0.b[2]\n" 2589 "udot v11.4s, v2.16b, v0.b[3]\n" 2590 "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" // Load second Lhs cell. 2591 "udot v12.4s, v2.16b, v1.b[0]\n" 2592 "udot v13.4s, v2.16b, v1.b[1]\n" 2593 "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" // Load third Lhs cell. 2594 "udot v14.4s, v2.16b, v1.b[2]\n" 2595 "udot v15.4s, v2.16b, v1.b[3]\n" 2596 "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load 2597 // for the next iteration early. 2598 "udot v16.4s, v3.16b, v0.b[0]\n" 2599 "udot v17.4s, v3.16b, v0.b[1]\n" 2600 "udot v18.4s, v3.16b, v0.b[2]\n" 2601 "udot v19.4s, v3.16b, v0.b[3]\n" 2602 "udot v20.4s, v3.16b, v1.b[0]\n" 2603 "udot v21.4s, v3.16b, v1.b[1]\n" 2604 "udot v22.4s, v3.16b, v1.b[2]\n" 2605 "udot v23.4s, v3.16b, v1.b[3]\n" 2606 "udot v24.4s, v4.16b, v0.b[0]\n" 2607 "udot v25.4s, v4.16b, v0.b[1]\n" 2608 "udot v26.4s, v4.16b, v0.b[2]\n" 2609 "udot v27.4s, v4.16b, v0.b[3]\n" 2610 "ld1 {v0.16b}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell - 2611 // load for the next iteration early. 2612 "udot v28.4s, v4.16b, v1.b[0]\n" 2613 "udot v29.4s, v4.16b, v1.b[1]\n" 2614 2615 // Loop. Decrement loop index (depth) by 4 as udot processes 4 2616 // depth values. 2617 "subs %w[depth], %w[depth], #4\n" 2618 "udot v30.4s, v4.16b, v1.b[2]\n" 2619 "udot v31.4s, v4.16b, v1.b[3]\n" 2620 2621 "bne " GEMMLOWP_LABEL_LOOP 2622 "b\n" 2623 2624 // Store accumulators 2625 "mov x0, %[accum_ptr]\n" 2626 "st1 {v8.16b}, [x0], #16\n" 2627 "st1 {v16.16b}, [x0], #16\n" 2628 "st1 {v24.16b}, [x0], #16\n" 2629 "st1 {v9.16b}, [x0], #16\n" 2630 "st1 {v17.16b}, [x0], #16\n" 2631 "st1 {v25.16b}, [x0], #16\n" 2632 "st1 {v10.16b}, [x0], #16\n" 2633 "st1 {v18.16b}, [x0], #16\n" 2634 "st1 {v26.16b}, [x0], #16\n" 2635 "st1 {v11.16b}, [x0], #16\n" 2636 "st1 {v19.16b}, [x0], #16\n" 2637 "st1 {v27.16b}, [x0], #16\n" 2638 "st1 {v12.16b}, [x0], #16\n" 2639 "st1 {v20.16b}, [x0], #16\n" 2640 "st1 {v28.16b}, [x0], #16\n" 2641 "st1 {v13.16b}, [x0], #16\n" 2642 "st1 {v21.16b}, [x0], #16\n" 2643 "st1 {v29.16b}, [x0], #16\n" 2644 "st1 {v14.16b}, [x0], #16\n" 2645 "st1 {v22.16b}, [x0], #16\n" 2646 "st1 {v30.16b}, [x0], #16\n" 2647 "st1 {v15.16b}, [x0], #16\n" 2648 "st1 {v23.16b}, [x0], #16\n" 2649 "st1 {v31.16b}, [x0], #16\n" 2650 : // outputs 2651 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 2652 [depth] "+r"(depth) 2653 : // inputs 2654 [accum_ptr] "r"(accum_ptr) 2655 : // clobbers 2656 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", 2657 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", 2658 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", 2659 "v28", "v29", "v30", "v31"); 2660 } 2661 }; 2662 2663 // As above, except tuned for Cortex-A55r1. 2664 // 2665 // Similarly, this is a clone of NEON_64bit_GEMM_Float32_WithScalar_A55r1 2666 // with the names changed. 2667 struct NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1 { 2668 typedef std::uint8_t OperandType; 2669 typedef std::uint32_t AccumulatorType; 2670 typedef KernelFormat< 2671 KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 3>, 2672 KernelSideFormat<CellFormat<4, 4, CellOrder::WidthMajor>, 2> > 2673 Format; 2674 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 2675 AccumulatorType* accum_ptr, int depth) { 2676 asm volatile( 2677 // Load accumulators 2678 "mov x0, %[accum_ptr]\n" 2679 "ld1 {v8.4s}, [x0], #16\n" 2680 "ld1 {v16.4s}, [x0], #16\n" 2681 "ld1 {v24.4s}, [x0], #16\n" 2682 "ld1 {v9.4s}, [x0], #16\n" 2683 "ld1 {v17.4s}, [x0], #16\n" 2684 "ld1 {v25.4s}, [x0], #16\n" 2685 "ld1 {v10.4s}, [x0], #16\n" 2686 "ld1 {v18.4s}, [x0], #16\n" 2687 "ld1 {v26.4s}, [x0], #16\n" 2688 "ld1 {v11.4s}, [x0], #16\n" 2689 "ld1 {v19.4s}, [x0], #16\n" 2690 "ld1 {v27.4s}, [x0], #16\n" 2691 "ld1 {v12.4s}, [x0], #16\n" 2692 "ld1 {v20.4s}, [x0], #16\n" 2693 "ld1 {v28.4s}, [x0], #16\n" 2694 "ld1 {v13.4s}, [x0], #16\n" 2695 "ld1 {v21.4s}, [x0], #16\n" 2696 "ld1 {v29.4s}, [x0], #16\n" 2697 "ld1 {v14.4s}, [x0], #16\n" 2698 "ld1 {v22.4s}, [x0], #16\n" 2699 "ld1 {v30.4s}, [x0], #16\n" 2700 "ld1 {v15.4s}, [x0], #16\n" 2701 "ld1 {v23.4s}, [x0], #16\n" 2702 "ld1 {v31.4s}, [x0], #16\n" 2703 2704 // For details on how this kernel works, see the Float32 kernel below. 2705 2706 "ldr d0, [%[rhs_ptr]]\n" 2707 "ldr x18, [%[rhs_ptr], #8]\n" 2708 2709 "ldr q2, [%[lhs_ptr]]\n" 2710 "ldr q3, [%[lhs_ptr], #16]\n" 2711 2712 GEMMLOWP_LABEL_LOOP 2713 ":\n" 2714 2715 "udot v8.4s, v2.16b, v0.b[0]\n" 2716 "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1 2717 "udot v9.4s, v2.16b, v0.b[1]\n" 2718 "ins v0.d[1], x18\n" // Finish loading v0 2719 "udot v16.4s, v3.16b, v0.b[0]\n" // out of sequence - used to reduce load/use pressure. 2720 "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register 2721 "udot v17.4s, v3.16b, v0.b[1]\n" // out of sequence - used to reduce load/use pressure. 2722 "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer. 2723 "udot v10.4s, v2.16b, v0.b[2]\n" 2724 "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4 2725 "udot v11.4s, v2.16b, v0.b[3]\n" 2726 "ins v1.d[1], x18\n" // Finish loading v1 2727 "udot v12.4s, v2.16b, v1.b[0]\n" 2728 "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register 2729 "udot v13.4s, v2.16b, v1.b[1]\n" 2730 "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer. 2731 "udot v14.4s, v2.16b, v1.b[2]\n" 2732 2733 "udot v15.4s, v2.16b, v1.b[3]\n" 2734 "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time) 2735 "udot v18.4s, v3.16b, v0.b[2]\n" 2736 "ins v4.d[1], x18\n" // Finish loading v4 2737 "udot v19.4s, v3.16b, v0.b[3]\n" 2738 "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register 2739 "udot v20.4s, v3.16b, v1.b[0]\n" 2740 "subs %w[depth], %w[depth], #4\n" 2741 "udot v21.4s, v3.16b, v1.b[1]\n" 2742 2743 "udot v22.4s, v3.16b, v1.b[2]\n" 2744 2745 "udot v23.4s, v3.16b, v1.b[3]\n" 2746 "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time) 2747 "udot v24.4s, v4.16b, v0.b[0]\n" 2748 "ins v2.d[1], x18\n" // Finish loading next v2 2749 "udot v25.4s, v4.16b, v0.b[1]\n" 2750 "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register 2751 "udot v26.4s, v4.16b, v0.b[2]\n" 2752 2753 "udot v27.4s, v4.16b, v0.b[3]\n" 2754 "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time) 2755 "udot v28.4s, v4.16b, v1.b[0]\n" 2756 "ins v3.d[1], x18\n" // Finish loading next v3 2757 "udot v29.4s, v4.16b, v1.b[1]\n" 2758 "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register 2759 "udot v30.4s, v4.16b, v1.b[2]\n" 2760 2761 "udot v31.4s, v4.16b, v1.b[3]\n" 2762 "bne " GEMMLOWP_LABEL_LOOP "b\n" 2763 2764 // Store accumulators 2765 "mov x0, %[accum_ptr]\n" 2766 "st1 {v8.4s}, [x0], #16\n" 2767 "st1 {v16.4s}, [x0], #16\n" 2768 "st1 {v24.4s}, [x0], #16\n" 2769 "st1 {v9.4s}, [x0], #16\n" 2770 "st1 {v17.4s}, [x0], #16\n" 2771 "st1 {v25.4s}, [x0], #16\n" 2772 "st1 {v10.4s}, [x0], #16\n" 2773 "st1 {v18.4s}, [x0], #16\n" 2774 "st1 {v26.4s}, [x0], #16\n" 2775 "st1 {v11.4s}, [x0], #16\n" 2776 "st1 {v19.4s}, [x0], #16\n" 2777 "st1 {v27.4s}, [x0], #16\n" 2778 "st1 {v12.4s}, [x0], #16\n" 2779 "st1 {v20.4s}, [x0], #16\n" 2780 "st1 {v28.4s}, [x0], #16\n" 2781 "st1 {v13.4s}, [x0], #16\n" 2782 "st1 {v21.4s}, [x0], #16\n" 2783 "st1 {v29.4s}, [x0], #16\n" 2784 "st1 {v14.4s}, [x0], #16\n" 2785 "st1 {v22.4s}, [x0], #16\n" 2786 "st1 {v30.4s}, [x0], #16\n" 2787 "st1 {v15.4s}, [x0], #16\n" 2788 "st1 {v23.4s}, [x0], #16\n" 2789 "st1 {v31.4s}, [x0], #16\n" 2790 : // outputs 2791 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 2792 [depth] "+r"(depth) 2793 : // inputs 2794 [accum_ptr] "r"(accum_ptr) 2795 : // clobbers 2796 "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6", 2797 "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", 2798 "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", 2799 "v27", "v28", "v29", "v30", "v31"); 2800 } 2801 }; 2802 #endif // __ARM_FEATURE_DOTPROD 2803 2804 // We don't actually use int32*int32 in production. This is just an 2805 // experiment to help dissociate the effect of integer-vs-float, from the 2806 // effect of operands width. 2807 struct NEON_64bit_GEMM_Int32_WithScalar { 2808 typedef std::int32_t OperandType; 2809 typedef std::int32_t AccumulatorType; 2810 typedef KernelFormat< 2811 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, 2812 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> > 2813 Format; 2814 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 2815 AccumulatorType* accum_ptr, int depth) { 2816 asm volatile( 2817 // Load accumulators 2818 "mov x0, %[accum_ptr]\n" 2819 "ld1 {v8.16b}, [x0], #16\n" 2820 "ld1 {v16.16b}, [x0], #16\n" 2821 "ld1 {v24.16b}, [x0], #16\n" 2822 "ld1 {v9.16b}, [x0], #16\n" 2823 "ld1 {v17.16b}, [x0], #16\n" 2824 "ld1 {v25.16b}, [x0], #16\n" 2825 "ld1 {v10.16b}, [x0], #16\n" 2826 "ld1 {v18.16b}, [x0], #16\n" 2827 "ld1 {v26.16b}, [x0], #16\n" 2828 "ld1 {v11.16b}, [x0], #16\n" 2829 "ld1 {v19.16b}, [x0], #16\n" 2830 "ld1 {v27.16b}, [x0], #16\n" 2831 "ld1 {v12.16b}, [x0], #16\n" 2832 "ld1 {v20.16b}, [x0], #16\n" 2833 "ld1 {v28.16b}, [x0], #16\n" 2834 "ld1 {v13.16b}, [x0], #16\n" 2835 "ld1 {v21.16b}, [x0], #16\n" 2836 "ld1 {v29.16b}, [x0], #16\n" 2837 "ld1 {v14.16b}, [x0], #16\n" 2838 "ld1 {v22.16b}, [x0], #16\n" 2839 "ld1 {v30.16b}, [x0], #16\n" 2840 "ld1 {v15.16b}, [x0], #16\n" 2841 "ld1 {v23.16b}, [x0], #16\n" 2842 "ld1 {v31.16b}, [x0], #16\n" 2843 2844 GEMMLOWP_LABEL_LOOP 2845 ":\n" 2846 2847 // Load 2 Rhs cell of size 1x4 each 2848 "ld1 {v0.4s}, [%[rhs_ptr]], #16\n" 2849 "ld1 {v1.4s}, [%[rhs_ptr]], #16\n" 2850 2851 // Load 3 Lhs cells of size 4x1 each 2852 "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" 2853 "ld1 {v3.4s}, [%[lhs_ptr]], #16\n" 2854 "ld1 {v4.4s}, [%[lhs_ptr]], #16\n" 2855 2856 // Multiply-accumulate 2857 "mla v8.4s, v2.4s, v0.s[0]\n" 2858 "mla v9.4s, v2.4s, v0.s[1]\n" 2859 "mla v10.4s, v2.4s, v0.s[2]\n" 2860 "mla v11.4s, v2.4s, v0.s[3]\n" 2861 "mla v12.4s, v2.4s, v1.s[0]\n" 2862 "mla v13.4s, v2.4s, v1.s[1]\n" 2863 "mla v14.4s, v2.4s, v1.s[2]\n" 2864 "mla v15.4s, v2.4s, v1.s[3]\n" 2865 "mla v16.4s, v3.4s, v0.s[0]\n" 2866 "mla v17.4s, v3.4s, v0.s[1]\n" 2867 "mla v18.4s, v3.4s, v0.s[2]\n" 2868 "mla v19.4s, v3.4s, v0.s[3]\n" 2869 "mla v20.4s, v3.4s, v1.s[0]\n" 2870 "mla v21.4s, v3.4s, v1.s[1]\n" 2871 "mla v22.4s, v3.4s, v1.s[2]\n" 2872 "mla v23.4s, v3.4s, v1.s[3]\n" 2873 "mla v24.4s, v4.4s, v0.s[0]\n" 2874 "mla v25.4s, v4.4s, v0.s[1]\n" 2875 "mla v26.4s, v4.4s, v0.s[2]\n" 2876 "mla v27.4s, v4.4s, v0.s[3]\n" 2877 "mla v28.4s, v4.4s, v1.s[0]\n" 2878 "mla v29.4s, v4.4s, v1.s[1]\n" 2879 "mla v30.4s, v4.4s, v1.s[2]\n" 2880 "mla v31.4s, v4.4s, v1.s[3]\n" 2881 2882 // Loop. Decrement loop index (depth) by 1, since we just handled 1 2883 // level of depth. 2884 "subs %w[depth], %w[depth], #1\n" 2885 "bne " GEMMLOWP_LABEL_LOOP 2886 "b\n" 2887 2888 // Store accumulators 2889 "mov x0, %[accum_ptr]\n" 2890 "st1 {v8.16b}, [x0], #16\n" 2891 "st1 {v16.16b}, [x0], #16\n" 2892 "st1 {v24.16b}, [x0], #16\n" 2893 "st1 {v9.16b}, [x0], #16\n" 2894 "st1 {v17.16b}, [x0], #16\n" 2895 "st1 {v25.16b}, [x0], #16\n" 2896 "st1 {v10.16b}, [x0], #16\n" 2897 "st1 {v18.16b}, [x0], #16\n" 2898 "st1 {v26.16b}, [x0], #16\n" 2899 "st1 {v11.16b}, [x0], #16\n" 2900 "st1 {v19.16b}, [x0], #16\n" 2901 "st1 {v27.16b}, [x0], #16\n" 2902 "st1 {v12.16b}, [x0], #16\n" 2903 "st1 {v20.16b}, [x0], #16\n" 2904 "st1 {v28.16b}, [x0], #16\n" 2905 "st1 {v13.16b}, [x0], #16\n" 2906 "st1 {v21.16b}, [x0], #16\n" 2907 "st1 {v29.16b}, [x0], #16\n" 2908 "st1 {v14.16b}, [x0], #16\n" 2909 "st1 {v22.16b}, [x0], #16\n" 2910 "st1 {v30.16b}, [x0], #16\n" 2911 "st1 {v15.16b}, [x0], #16\n" 2912 "st1 {v23.16b}, [x0], #16\n" 2913 "st1 {v31.16b}, [x0], #16\n" 2914 : // outputs 2915 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 2916 [depth] "+r"(depth) 2917 : // inputs 2918 [accum_ptr] "r"(accum_ptr) 2919 : // clobbers 2920 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", 2921 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", 2922 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", 2923 "v28", "v29", "v30", "v31"); 2924 } 2925 }; 2926 2927 // Not very efficient kernel, just an experiment to see what we can do 2928 // without using NEON multiply-with-scalar instructions. 2929 struct NEON_64bit_GEMM_Float32_WithVectorDuplicatingScalar { 2930 typedef float OperandType; 2931 typedef float AccumulatorType; 2932 typedef KernelFormat< 2933 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, 2934 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> > 2935 Format; 2936 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 2937 AccumulatorType* accum_ptr, int depth) { 2938 asm volatile( 2939 // Load accumulators 2940 "mov x0, %[accum_ptr]\n" 2941 "ld1 {v8.16b}, [x0], #16\n" 2942 "ld1 {v16.16b}, [x0], #16\n" 2943 "ld1 {v24.16b}, [x0], #16\n" 2944 "ld1 {v9.16b}, [x0], #16\n" 2945 "ld1 {v17.16b}, [x0], #16\n" 2946 "ld1 {v25.16b}, [x0], #16\n" 2947 "ld1 {v10.16b}, [x0], #16\n" 2948 "ld1 {v18.16b}, [x0], #16\n" 2949 "ld1 {v26.16b}, [x0], #16\n" 2950 "ld1 {v11.16b}, [x0], #16\n" 2951 "ld1 {v19.16b}, [x0], #16\n" 2952 "ld1 {v27.16b}, [x0], #16\n" 2953 "ld1 {v12.16b}, [x0], #16\n" 2954 "ld1 {v20.16b}, [x0], #16\n" 2955 "ld1 {v28.16b}, [x0], #16\n" 2956 "ld1 {v13.16b}, [x0], #16\n" 2957 "ld1 {v21.16b}, [x0], #16\n" 2958 "ld1 {v29.16b}, [x0], #16\n" 2959 "ld1 {v14.16b}, [x0], #16\n" 2960 "ld1 {v22.16b}, [x0], #16\n" 2961 "ld1 {v30.16b}, [x0], #16\n" 2962 "ld1 {v15.16b}, [x0], #16\n" 2963 "ld1 {v23.16b}, [x0], #16\n" 2964 "ld1 {v31.16b}, [x0], #16\n" 2965 2966 GEMMLOWP_LABEL_LOOP 2967 ":\n" 2968 2969 // Load 2 Rhs cell of size 1x4 each 2970 "ld1 {v5.4s}, [%[rhs_ptr]], #16\n" 2971 "ld1 {v6.4s}, [%[rhs_ptr]], #16\n" 2972 2973 // Load 3 Lhs cells of size 4x1 each 2974 "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" 2975 "ld1 {v3.4s}, [%[lhs_ptr]], #16\n" 2976 "ld1 {v4.4s}, [%[lhs_ptr]], #16\n" 2977 2978 // Multiply-accumulate 2979 "dup v0.4s, v5.s[0]\n" 2980 "dup v1.4s, v5.s[1]\n" 2981 "fmla v8.4s, v2.4s, v0.4s\n" 2982 "fmla v16.4s, v3.4s, v0.4s\n" 2983 "fmla v24.4s, v4.4s, v0.4s\n" 2984 "fmla v9.4s, v2.4s, v1.4s\n" 2985 "fmla v17.4s, v3.4s, v1.4s\n" 2986 "fmla v25.4s, v4.4s, v1.4s\n" 2987 "dup v0.4s, v5.s[2]\n" 2988 "dup v1.4s, v5.s[3]\n" 2989 "fmla v10.4s, v2.4s, v0.4s\n" 2990 "fmla v18.4s, v3.4s, v0.4s\n" 2991 "fmla v26.4s, v4.4s, v0.4s\n" 2992 "fmla v11.4s, v2.4s, v1.4s\n" 2993 "fmla v19.4s, v3.4s, v1.4s\n" 2994 "fmla v27.4s, v4.4s, v1.4s\n" 2995 "dup v0.4s, v6.s[0]\n" 2996 "dup v1.4s, v6.s[1]\n" 2997 "fmla v12.4s, v2.4s, v0.4s\n" 2998 "fmla v20.4s, v3.4s, v0.4s\n" 2999 "fmla v28.4s, v4.4s, v0.4s\n" 3000 "fmla v13.4s, v2.4s, v1.4s\n" 3001 "fmla v21.4s, v3.4s, v1.4s\n" 3002 "fmla v29.4s, v4.4s, v1.4s\n" 3003 "dup v0.4s, v6.s[2]\n" 3004 "dup v1.4s, v6.s[3]\n" 3005 "fmla v14.4s, v2.4s, v0.4s\n" 3006 "fmla v22.4s, v3.4s, v0.4s\n" 3007 "fmla v30.4s, v4.4s, v0.4s\n" 3008 "fmla v15.4s, v2.4s, v1.4s\n" 3009 "fmla v23.4s, v3.4s, v1.4s\n" 3010 "fmla v31.4s, v4.4s, v1.4s\n" 3011 3012 // Loop. Decrement loop index (depth) by 1, since we just handled 1 3013 // level of depth. 3014 "subs %w[depth], %w[depth], #1\n" 3015 "bne " GEMMLOWP_LABEL_LOOP 3016 "b\n" 3017 3018 // Store accumulators 3019 "mov x0, %[accum_ptr]\n" 3020 "st1 {v8.16b}, [x0], #16\n" 3021 "st1 {v16.16b}, [x0], #16\n" 3022 "st1 {v24.16b}, [x0], #16\n" 3023 "st1 {v9.16b}, [x0], #16\n" 3024 "st1 {v17.16b}, [x0], #16\n" 3025 "st1 {v25.16b}, [x0], #16\n" 3026 "st1 {v10.16b}, [x0], #16\n" 3027 "st1 {v18.16b}, [x0], #16\n" 3028 "st1 {v26.16b}, [x0], #16\n" 3029 "st1 {v11.16b}, [x0], #16\n" 3030 "st1 {v19.16b}, [x0], #16\n" 3031 "st1 {v27.16b}, [x0], #16\n" 3032 "st1 {v12.16b}, [x0], #16\n" 3033 "st1 {v20.16b}, [x0], #16\n" 3034 "st1 {v28.16b}, [x0], #16\n" 3035 "st1 {v13.16b}, [x0], #16\n" 3036 "st1 {v21.16b}, [x0], #16\n" 3037 "st1 {v29.16b}, [x0], #16\n" 3038 "st1 {v14.16b}, [x0], #16\n" 3039 "st1 {v22.16b}, [x0], #16\n" 3040 "st1 {v30.16b}, [x0], #16\n" 3041 "st1 {v15.16b}, [x0], #16\n" 3042 "st1 {v23.16b}, [x0], #16\n" 3043 "st1 {v31.16b}, [x0], #16\n" 3044 : // outputs 3045 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 3046 [depth] "+r"(depth) 3047 : // inputs 3048 [accum_ptr] "r"(accum_ptr) 3049 : // clobbers 3050 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", 3051 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", 3052 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", 3053 "v28", "v29", "v30", "v31"); 3054 } 3055 }; 3056 3057 // This is the "most natural" kernel, using NEON multiply-with-scalar 3058 // instructions. 3059 struct NEON_64bit_GEMM_Float32_WithScalar { 3060 typedef float OperandType; 3061 typedef float AccumulatorType; 3062 typedef KernelFormat< 3063 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, 3064 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> > 3065 Format; 3066 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 3067 AccumulatorType* accum_ptr, int depth) { 3068 asm volatile( 3069 // Load accumulators 3070 "mov x0, %[accum_ptr]\n" 3071 "ld1 {v8.16b}, [x0], #16\n" 3072 "ld1 {v16.16b}, [x0], #16\n" 3073 "ld1 {v24.16b}, [x0], #16\n" 3074 "ld1 {v9.16b}, [x0], #16\n" 3075 "ld1 {v17.16b}, [x0], #16\n" 3076 "ld1 {v25.16b}, [x0], #16\n" 3077 "ld1 {v10.16b}, [x0], #16\n" 3078 "ld1 {v18.16b}, [x0], #16\n" 3079 "ld1 {v26.16b}, [x0], #16\n" 3080 "ld1 {v11.16b}, [x0], #16\n" 3081 "ld1 {v19.16b}, [x0], #16\n" 3082 "ld1 {v27.16b}, [x0], #16\n" 3083 "ld1 {v12.16b}, [x0], #16\n" 3084 "ld1 {v20.16b}, [x0], #16\n" 3085 "ld1 {v28.16b}, [x0], #16\n" 3086 "ld1 {v13.16b}, [x0], #16\n" 3087 "ld1 {v21.16b}, [x0], #16\n" 3088 "ld1 {v29.16b}, [x0], #16\n" 3089 "ld1 {v14.16b}, [x0], #16\n" 3090 "ld1 {v22.16b}, [x0], #16\n" 3091 "ld1 {v30.16b}, [x0], #16\n" 3092 "ld1 {v15.16b}, [x0], #16\n" 3093 "ld1 {v23.16b}, [x0], #16\n" 3094 "ld1 {v31.16b}, [x0], #16\n" 3095 3096 GEMMLOWP_LABEL_LOOP 3097 ":\n" 3098 3099 // Load 2 Rhs cell of size 1x4 each 3100 "ld1 {v0.4s}, [%[rhs_ptr]], #16\n" 3101 "ld1 {v1.4s}, [%[rhs_ptr]], #16\n" 3102 3103 // Load 3 Lhs cells of size 4x1 each 3104 "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" 3105 "ld1 {v3.4s}, [%[lhs_ptr]], #16\n" 3106 "ld1 {v4.4s}, [%[lhs_ptr]], #16\n" 3107 3108 // Multiply-accumulate 3109 "fmla v8.4s, v2.4s, v0.s[0]\n" 3110 "fmla v9.4s, v2.4s, v0.s[1]\n" 3111 "fmla v10.4s, v2.4s, v0.s[2]\n" 3112 "fmla v11.4s, v2.4s, v0.s[3]\n" 3113 "fmla v12.4s, v2.4s, v1.s[0]\n" 3114 "fmla v13.4s, v2.4s, v1.s[1]\n" 3115 "fmla v14.4s, v2.4s, v1.s[2]\n" 3116 "fmla v15.4s, v2.4s, v1.s[3]\n" 3117 "fmla v16.4s, v3.4s, v0.s[0]\n" 3118 "fmla v17.4s, v3.4s, v0.s[1]\n" 3119 "fmla v18.4s, v3.4s, v0.s[2]\n" 3120 "fmla v19.4s, v3.4s, v0.s[3]\n" 3121 "fmla v20.4s, v3.4s, v1.s[0]\n" 3122 "fmla v21.4s, v3.4s, v1.s[1]\n" 3123 "fmla v22.4s, v3.4s, v1.s[2]\n" 3124 "fmla v23.4s, v3.4s, v1.s[3]\n" 3125 "fmla v24.4s, v4.4s, v0.s[0]\n" 3126 "fmla v25.4s, v4.4s, v0.s[1]\n" 3127 "fmla v26.4s, v4.4s, v0.s[2]\n" 3128 "fmla v27.4s, v4.4s, v0.s[3]\n" 3129 "fmla v28.4s, v4.4s, v1.s[0]\n" 3130 "fmla v29.4s, v4.4s, v1.s[1]\n" 3131 "fmla v30.4s, v4.4s, v1.s[2]\n" 3132 "fmla v31.4s, v4.4s, v1.s[3]\n" 3133 3134 // Loop. Decrement loop index (depth) by 1, since we just handled 1 3135 // level of depth. 3136 "subs %w[depth], %w[depth], #1\n" 3137 "bne " GEMMLOWP_LABEL_LOOP 3138 "b\n" 3139 3140 // Store accumulators 3141 "mov x0, %[accum_ptr]\n" 3142 "st1 {v8.16b}, [x0], #16\n" 3143 "st1 {v16.16b}, [x0], #16\n" 3144 "st1 {v24.16b}, [x0], #16\n" 3145 "st1 {v9.16b}, [x0], #16\n" 3146 "st1 {v17.16b}, [x0], #16\n" 3147 "st1 {v25.16b}, [x0], #16\n" 3148 "st1 {v10.16b}, [x0], #16\n" 3149 "st1 {v18.16b}, [x0], #16\n" 3150 "st1 {v26.16b}, [x0], #16\n" 3151 "st1 {v11.16b}, [x0], #16\n" 3152 "st1 {v19.16b}, [x0], #16\n" 3153 "st1 {v27.16b}, [x0], #16\n" 3154 "st1 {v12.16b}, [x0], #16\n" 3155 "st1 {v20.16b}, [x0], #16\n" 3156 "st1 {v28.16b}, [x0], #16\n" 3157 "st1 {v13.16b}, [x0], #16\n" 3158 "st1 {v21.16b}, [x0], #16\n" 3159 "st1 {v29.16b}, [x0], #16\n" 3160 "st1 {v14.16b}, [x0], #16\n" 3161 "st1 {v22.16b}, [x0], #16\n" 3162 "st1 {v30.16b}, [x0], #16\n" 3163 "st1 {v15.16b}, [x0], #16\n" 3164 "st1 {v23.16b}, [x0], #16\n" 3165 "st1 {v31.16b}, [x0], #16\n" 3166 : // outputs 3167 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 3168 [depth] "+r"(depth) 3169 : // inputs 3170 [accum_ptr] "r"(accum_ptr) 3171 : // clobbers 3172 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", 3173 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", 3174 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", 3175 "v28", "v29", "v30", "v31"); 3176 } 3177 }; 3178 3179 // Faster kernel contributed by ARM. Tuned for A57. 3180 struct NEON_64bit_GEMM_Float32_WithScalar_A57 { 3181 typedef float OperandType; 3182 typedef float AccumulatorType; 3183 typedef KernelFormat< 3184 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, 3185 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> > 3186 Format; 3187 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 3188 AccumulatorType* accum_ptr, int depth) { 3189 asm volatile( 3190 // Load accumulators 3191 "mov x0, %[accum_ptr]\n" 3192 "ld1 {v8.16b}, [x0], #16\n" 3193 "ld1 {v16.16b}, [x0], #16\n" 3194 "ld1 {v24.16b}, [x0], #16\n" 3195 "ld1 {v9.16b}, [x0], #16\n" 3196 "ld1 {v17.16b}, [x0], #16\n" 3197 "ld1 {v25.16b}, [x0], #16\n" 3198 "ld1 {v10.16b}, [x0], #16\n" 3199 "ld1 {v18.16b}, [x0], #16\n" 3200 "ld1 {v26.16b}, [x0], #16\n" 3201 "ld1 {v11.16b}, [x0], #16\n" 3202 "ld1 {v19.16b}, [x0], #16\n" 3203 "ld1 {v27.16b}, [x0], #16\n" 3204 "ld1 {v12.16b}, [x0], #16\n" 3205 "ld1 {v20.16b}, [x0], #16\n" 3206 "ld1 {v28.16b}, [x0], #16\n" 3207 "ld1 {v13.16b}, [x0], #16\n" 3208 "ld1 {v21.16b}, [x0], #16\n" 3209 "ld1 {v29.16b}, [x0], #16\n" 3210 "ld1 {v14.16b}, [x0], #16\n" 3211 "ld1 {v22.16b}, [x0], #16\n" 3212 "ld1 {v30.16b}, [x0], #16\n" 3213 "ld1 {v15.16b}, [x0], #16\n" 3214 "ld1 {v23.16b}, [x0], #16\n" 3215 "ld1 {v31.16b}, [x0], #16\n" 3216 3217 // The start of the loop assumes first Rhs cell is already loaded, so 3218 // do it here for first iteration. 3219 "ld1 {v0.4s}, [%[rhs_ptr]], #16\n" 3220 3221 // And the same for the first Lhs cell. 3222 "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" 3223 3224 GEMMLOWP_LABEL_LOOP 3225 ":\n" 3226 3227 // Start the MACs at the head of the loop - 1st cell from each side 3228 // already loaded. 3229 "fmla v8.4s, v2.4s, v0.s[0]\n" 3230 "fmla v9.4s, v2.4s, v0.s[1]\n" 3231 "ld1 {v1.4s}, [%[rhs_ptr]], #16\n" // Load second Rhs cell. 3232 "fmla v10.4s, v2.4s, v0.s[2]\n" 3233 "fmla v11.4s, v2.4s, v0.s[3]\n" 3234 "ld1 {v3.4s}, [%[lhs_ptr]], #16\n" // Load second Lhs cell. 3235 "fmla v12.4s, v2.4s, v1.s[0]\n" 3236 "fmla v13.4s, v2.4s, v1.s[1]\n" 3237 "ld1 {v4.4s}, [%[lhs_ptr]], #16\n" // Load third Lhs cell. 3238 "fmla v14.4s, v2.4s, v1.s[2]\n" 3239 "fmla v15.4s, v2.4s, v1.s[3]\n" 3240 "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" // Done with first Lhs cell - load 3241 // for the next iteration early. 3242 "fmla v16.4s, v3.4s, v0.s[0]\n" 3243 "fmla v17.4s, v3.4s, v0.s[1]\n" 3244 "fmla v18.4s, v3.4s, v0.s[2]\n" 3245 "fmla v19.4s, v3.4s, v0.s[3]\n" 3246 "fmla v20.4s, v3.4s, v1.s[0]\n" 3247 "fmla v21.4s, v3.4s, v1.s[1]\n" 3248 "fmla v22.4s, v3.4s, v1.s[2]\n" 3249 "fmla v23.4s, v3.4s, v1.s[3]\n" 3250 "fmla v24.4s, v4.4s, v0.s[0]\n" 3251 "fmla v25.4s, v4.4s, v0.s[1]\n" 3252 "fmla v26.4s, v4.4s, v0.s[2]\n" 3253 "fmla v27.4s, v4.4s, v0.s[3]\n" 3254 "ld1 {v0.4s}, [%[rhs_ptr]], #16\n" // Done with the first Rhs cell - 3255 // load for the next iteration 3256 // early. 3257 "fmla v28.4s, v4.4s, v1.s[0]\n" 3258 "fmla v29.4s, v4.4s, v1.s[1]\n" 3259 // Loop. Decrement loop index (depth) by 1, since we just handled 3260 // 1 level of depth. Do this a bit before the end of the loop for 3261 // better dispatch on A57. 3262 "subs %w[depth], %w[depth], #1\n" 3263 "fmla v30.4s, v4.4s, v1.s[2]\n" 3264 "fmla v31.4s, v4.4s, v1.s[3]\n" 3265 3266 "bne " GEMMLOWP_LABEL_LOOP 3267 "b\n" 3268 3269 // Store accumulators 3270 "mov x0, %[accum_ptr]\n" 3271 "st1 {v8.16b}, [x0], #16\n" 3272 "st1 {v16.16b}, [x0], #16\n" 3273 "st1 {v24.16b}, [x0], #16\n" 3274 "st1 {v9.16b}, [x0], #16\n" 3275 "st1 {v17.16b}, [x0], #16\n" 3276 "st1 {v25.16b}, [x0], #16\n" 3277 "st1 {v10.16b}, [x0], #16\n" 3278 "st1 {v18.16b}, [x0], #16\n" 3279 "st1 {v26.16b}, [x0], #16\n" 3280 "st1 {v11.16b}, [x0], #16\n" 3281 "st1 {v19.16b}, [x0], #16\n" 3282 "st1 {v27.16b}, [x0], #16\n" 3283 "st1 {v12.16b}, [x0], #16\n" 3284 "st1 {v20.16b}, [x0], #16\n" 3285 "st1 {v28.16b}, [x0], #16\n" 3286 "st1 {v13.16b}, [x0], #16\n" 3287 "st1 {v21.16b}, [x0], #16\n" 3288 "st1 {v29.16b}, [x0], #16\n" 3289 "st1 {v14.16b}, [x0], #16\n" 3290 "st1 {v22.16b}, [x0], #16\n" 3291 "st1 {v30.16b}, [x0], #16\n" 3292 "st1 {v15.16b}, [x0], #16\n" 3293 "st1 {v23.16b}, [x0], #16\n" 3294 "st1 {v31.16b}, [x0], #16\n" 3295 : // outputs 3296 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 3297 [depth] "+r"(depth) 3298 : // inputs 3299 [accum_ptr] "r"(accum_ptr) 3300 : // clobbers 3301 "cc", "memory", "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", 3302 "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", 3303 "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", 3304 "v28", "v29", "v30", "v31"); 3305 } 3306 }; 3307 3308 #ifndef __APPLE__ 3309 // Faster kernel contributed by ARM. Tuned for A53. 3310 struct NEON_64bit_GEMM_Float32_WithScalar_A53 { 3311 typedef float OperandType; 3312 typedef float AccumulatorType; 3313 typedef KernelFormat< 3314 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, 3315 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> > 3316 Format; 3317 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 3318 AccumulatorType* accum_ptr, int depth) { 3319 asm volatile( 3320 // Load accumulators 3321 "mov x0, %[accum_ptr]\n" 3322 "ld1 {v8.16b}, [x0], #16\n" 3323 "ld1 {v16.16b}, [x0], #16\n" 3324 "ld1 {v24.16b}, [x0], #16\n" 3325 "ld1 {v9.16b}, [x0], #16\n" 3326 "ld1 {v17.16b}, [x0], #16\n" 3327 "ld1 {v25.16b}, [x0], #16\n" 3328 "ld1 {v10.16b}, [x0], #16\n" 3329 "ld1 {v18.16b}, [x0], #16\n" 3330 "ld1 {v26.16b}, [x0], #16\n" 3331 "ld1 {v11.16b}, [x0], #16\n" 3332 "ld1 {v19.16b}, [x0], #16\n" 3333 "ld1 {v27.16b}, [x0], #16\n" 3334 "ld1 {v12.16b}, [x0], #16\n" 3335 "ld1 {v20.16b}, [x0], #16\n" 3336 "ld1 {v28.16b}, [x0], #16\n" 3337 "ld1 {v13.16b}, [x0], #16\n" 3338 "ld1 {v21.16b}, [x0], #16\n" 3339 "ld1 {v29.16b}, [x0], #16\n" 3340 "ld1 {v14.16b}, [x0], #16\n" 3341 "ld1 {v22.16b}, [x0], #16\n" 3342 "ld1 {v30.16b}, [x0], #16\n" 3343 "ld1 {v15.16b}, [x0], #16\n" 3344 "ld1 {v23.16b}, [x0], #16\n" 3345 "ld1 {v31.16b}, [x0], #16\n" 3346 3347 // For A53, a very different-looking loop is needed. 3348 // 3349 // The main reason for this is that on A53 128-bit loads take two 3350 // cycles during which no dual issue can occur. Doing two separate 3351 // 64-bit loads avoids this issue - they each take one cycle and are 3352 // able to dual issue. Since vector register loads don't dual issue 3353 // with FMLA, we load half the register as normal and the other half 3354 // into an integer register. This second half can then be moved into 3355 // place later with an INS instruction - which will dual issue with a 3356 // later FP load. 3357 // 3358 // For this kernel there are approximately 3 times as many multiplies 3359 // as loads, so it makes sense to structure the loop into blocks of 4 3360 // cycles, with 1 dedicated "load cycle" and 3 "multiply cycles" per 3361 // block. Strictly preserving this structure with NOPs where no load 3362 // is needed seems to result in higher performance. 3363 // 3364 // Choice of x18 to store the upper halves on their way into the 3365 // vector registers is arbitrary. Added to the clobber list so that 3366 // the compiler will make it available. 3367 // 3368 // 3369 // At the start of the loop, it is assumed that v0 is "half loaded" - 3370 // bottom half in place in d0 and the upper half in x18 ready to 3371 // insert. So set that up here for the first iteration: 3372 "ldr d0, [%[rhs_ptr]]\n" // Bottom half of first Rhs cell 3373 "ldr x18, [%[rhs_ptr], #8]\n" // Upper half 3374 "add %[rhs_ptr], %[rhs_ptr], #16\n" // Separate increment (needed as 3375 // there is no operation to load at 3376 // reg + 8 but then increment reg 3377 // by 16). 3378 3379 // v2 should be fully loaded - as it's outside the loop proper it's fine 3380 // to use a 128-bit load here. 3381 "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" // first Lhs cell 3382 3383 GEMMLOWP_LABEL_LOOP 3384 ":\n" 3385 3386 // First block of four cycles. Multplies all require v2 and v0; v2 is 3387 // loaded earlier and v0 is half loaded and completed in the load 3388 // cycle at the start. 3389 "ldr d1, [%[rhs_ptr]]\n" // "load" cycle - loading bottom half of v1 3390 // (second Rhs cell). 3391 "ins v0.d[1], x18\n" // "load" cycle - moving the upper half of v0 into 3392 // place. 3393 "fmla v8.4s, v2.4s, v0.s[0]\n" // "fmla" cycle 1 - first multiply. 3394 "ldr x18, [%[rhs_ptr], #8]\n" // "fmla" cycle 1 - load upper half of v1 3395 // into x18. 3396 "fmla v9.4s, v2.4s, v0.s[1]\n" // "fmla" cycle 2 - second multiply 3397 "add %[rhs_ptr], %[rhs_ptr], #16\n" // "fmla" cycle 2 - increment Rhs 3398 // pointer (if needed) 3399 "fmla v10.4s, v2.4s, v0.s[2]\n" // "fmla" cycle 3 - third multiply. No 3400 // more work to dual issue. 3401 3402 // Second block. Start loading v3 (second Lhs cell), finish loading v1. 3403 "ldr d3, [%[lhs_ptr]]\n" 3404 "ins v1.d[1], x18\n" // v1 ready here. 3405 "fmla v11.4s, v2.4s, v0.s[3]\n" 3406 "ldr x18, [%[lhs_ptr], #8]\n" 3407 "fmla v12.4s, v2.4s, v1.s[0]\n" // First use of v1. 3408 "add %[lhs_ptr], %[lhs_ptr], #16\n" 3409 "fmla v13.4s, v2.4s, v1.s[1]\n" 3410 3411 // Third block. Start loading v4 (third Lhs cell), finish loading v3. 3412 "ldr d4, [%[lhs_ptr]]\n" 3413 "ins v3.d[1], x18\n" // v3 ready here. 3414 "fmla v14.4s, v2.4s, v1.s[2]\n" 3415 "ldr x18, [%[lhs_ptr], #8]\n" 3416 "fmla v15.4s, v2.4s, v1.s[3]\n" 3417 "add %[lhs_ptr], %[lhs_ptr], #16\n" 3418 "fmla v16.4s, v3.4s, v0.s[0]\n" // First use of v3. 3419 3420 // Fourth block. v2 (first Lhs cell) is now finished with, so start 3421 // loading value for next iteration. Finish loading v4. 3422 "ldr d2, [%[lhs_ptr]]\n" 3423 "ins v4.d[1], x18\n" // v4 ready here. 3424 "fmla v17.4s, v3.4s, v0.s[1]\n" 3425 "ldr x18, [%[lhs_ptr], #8]\n" 3426 "fmla v18.4s, v3.4s, v0.s[2]\n" 3427 "add %[lhs_ptr], %[lhs_ptr], #16\n" 3428 "fmla v19.4s, v3.4s, v0.s[3]\n" 3429 3430 // Fifth block, finish loading v2. No new load to start as the other 3431 // registers are all still live. 3432 "ins v2.d[1], x18\n" 3433 "fmla v20.4s, v3.4s, v1.s[0]\n" 3434 "fmla v21.4s, v3.4s, v1.s[1]\n" 3435 "fmla v22.4s, v3.4s, v1.s[2]\n" 3436 3437 // Sixth block, nothing to load. 2 nops needed as a single nop would 3438 // dual issue with the FMLA and break the timing. 3439 "nop\n" 3440 "nop\n" 3441 "fmla v23.4s, v3.4s, v1.s[3]\n" 3442 "fmla v24.4s, v4.4s, v0.s[0]\n" // First use of v4. 3443 "fmla v25.4s, v4.4s, v0.s[1]\n" 3444 3445 // Seventh block, nothing to load. Decrement the loop counter in this 3446 // block as the last block is very full. 3447 "nop\n" 3448 "nop\n" 3449 "fmla v26.4s, v4.4s, v0.s[2]\n" 3450 "subs %w[depth], %w[depth], #1\n" 3451 "fmla v27.4s, v4.4s, v0.s[3]\n" 3452 "fmla v28.4s, v4.4s, v1.s[0]\n" 3453 3454 // Eighth block - start loading v0 for next iteration. 3455 "ldr d0, [%[rhs_ptr]]\n" 3456 "fmla v29.4s, v4.4s, v1.s[1]\n" 3457 "ldr x18, [%[rhs_ptr], #8]\n" 3458 "fmla v30.4s, v4.4s, v1.s[2]\n" 3459 "add %[rhs_ptr], %[rhs_ptr], #16\n" 3460 "fmla v31.4s, v4.4s, v1.s[3]\n" 3461 3462 // Loop branch. This will dual issue in fmla cycle 3 of the 8th block. 3463 "bne " GEMMLOWP_LABEL_LOOP 3464 "b\n" 3465 3466 // Store accumulators 3467 "mov x0, %[accum_ptr]\n" 3468 "st1 {v8.16b}, [x0], #16\n" 3469 "st1 {v16.16b}, [x0], #16\n" 3470 "st1 {v24.16b}, [x0], #16\n" 3471 "st1 {v9.16b}, [x0], #16\n" 3472 "st1 {v17.16b}, [x0], #16\n" 3473 "st1 {v25.16b}, [x0], #16\n" 3474 "st1 {v10.16b}, [x0], #16\n" 3475 "st1 {v18.16b}, [x0], #16\n" 3476 "st1 {v26.16b}, [x0], #16\n" 3477 "st1 {v11.16b}, [x0], #16\n" 3478 "st1 {v19.16b}, [x0], #16\n" 3479 "st1 {v27.16b}, [x0], #16\n" 3480 "st1 {v12.16b}, [x0], #16\n" 3481 "st1 {v20.16b}, [x0], #16\n" 3482 "st1 {v28.16b}, [x0], #16\n" 3483 "st1 {v13.16b}, [x0], #16\n" 3484 "st1 {v21.16b}, [x0], #16\n" 3485 "st1 {v29.16b}, [x0], #16\n" 3486 "st1 {v14.16b}, [x0], #16\n" 3487 "st1 {v22.16b}, [x0], #16\n" 3488 "st1 {v30.16b}, [x0], #16\n" 3489 "st1 {v15.16b}, [x0], #16\n" 3490 "st1 {v23.16b}, [x0], #16\n" 3491 "st1 {v31.16b}, [x0], #16\n" 3492 : // outputs 3493 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 3494 [depth] "+r"(depth) 3495 : // inputs 3496 [accum_ptr] "r"(accum_ptr) 3497 : // clobbers 3498 "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6", 3499 "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", 3500 "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", 3501 "v27", "v28", "v29", "v30", "v31"); 3502 } 3503 }; 3504 #endif 3505 3506 // Faster kernel contributed by ARM. Tuned for A55r1. 3507 struct NEON_64bit_GEMM_Float32_WithScalar_A55r1 { 3508 typedef float OperandType; 3509 typedef float AccumulatorType; 3510 typedef KernelFormat< 3511 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, 3512 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 2> > 3513 Format; 3514 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 3515 AccumulatorType* accum_ptr, int depth) { 3516 asm volatile( 3517 // Load accumulators 3518 "mov x0, %[accum_ptr]\n" 3519 "ld1 {v8.4s}, [x0], #16\n" 3520 "ld1 {v16.4s}, [x0], #16\n" 3521 "ld1 {v24.4s}, [x0], #16\n" 3522 "ld1 {v9.4s}, [x0], #16\n" 3523 "ld1 {v17.4s}, [x0], #16\n" 3524 "ld1 {v25.4s}, [x0], #16\n" 3525 "ld1 {v10.4s}, [x0], #16\n" 3526 "ld1 {v18.4s}, [x0], #16\n" 3527 "ld1 {v26.4s}, [x0], #16\n" 3528 "ld1 {v11.4s}, [x0], #16\n" 3529 "ld1 {v19.4s}, [x0], #16\n" 3530 "ld1 {v27.4s}, [x0], #16\n" 3531 "ld1 {v12.4s}, [x0], #16\n" 3532 "ld1 {v20.4s}, [x0], #16\n" 3533 "ld1 {v28.4s}, [x0], #16\n" 3534 "ld1 {v13.4s}, [x0], #16\n" 3535 "ld1 {v21.4s}, [x0], #16\n" 3536 "ld1 {v29.4s}, [x0], #16\n" 3537 "ld1 {v14.4s}, [x0], #16\n" 3538 "ld1 {v22.4s}, [x0], #16\n" 3539 "ld1 {v30.4s}, [x0], #16\n" 3540 "ld1 {v15.4s}, [x0], #16\n" 3541 "ld1 {v23.4s}, [x0], #16\n" 3542 "ld1 {v31.4s}, [x0], #16\n" 3543 3544 // A55r1 requires a hybrid of the A53 and standard approaches. 3545 // 3546 // Like A53, this processor prefers 64-bit loads. 3547 // 3548 // Unlike A53, it is capable of dual-issuing a 64-bit vector load 3549 // (or INS) with a FMLA instruction. 3550 // 3551 // Therefore we aim to issue an FMLA instruction every cycle. 3552 // Alongside three FMLAs we can dual issue a (vector) 64-bit load, a 3553 // scalar 64-bit load and finally an INS to replicate the effect of 3554 // a single 128-bit load. 3555 // 3556 // The loop contains 24 FMLA instructions, and 5 vector registers 3557 // need to be loaded, consuming 15 dual issue slots. This leaves 9 3558 // dual issue slots. Four of these are used for loop housekeeping 3559 // (2 pointer adds, 1 counter update and 1 branch), leaving 5 left 3560 // over (marked by blank lines). 3561 // 3562 // Choice of x18 to store the upper halves on their way into the 3563 // vector registers is arbitrary. Added to the clobber list so that 3564 // the compiler will make it available. 3565 3566 3567 // At the start of the loop, it is assumed that v0 is "half loaded" - 3568 // bottom half in place in d0 and the upper half in x18 ready to 3569 // insert. So set that up here for the first iteration: 3570 "ldr d0, [%[rhs_ptr]]\n" // Bottom half of first Rhs cell 3571 "ldr x18, [%[rhs_ptr], #8]\n" // Upper half 3572 3573 // v2-v3 should be fully loaded - as it's outside the loop proper it's fine 3574 // to use a 128-bit load here. 3575 "ldr q2, [%[lhs_ptr]]\n" // first Lhs cell 3576 "ldr q3, [%[lhs_ptr], #16]\n" // second Lhs cell 3577 3578 GEMMLOWP_LABEL_LOOP 3579 ":\n" 3580 3581 "fmla v8.4s, v2.4s, v0.s[0]\n" 3582 "ldr d1, [%[rhs_ptr], #16]\n" // Bottom half of v1 3583 "fmla v9.4s, v2.4s, v0.s[1]\n" 3584 "ins v0.d[1], x18\n" // Finish loading v0 3585 "fmla v16.4s, v3.4s, v0.s[0]\n" // out of sequence - used to reduce load/use pressure. 3586 "ldr x18, [%[rhs_ptr], #24]\n" // Top half of v1 to X register 3587 "fmla v17.4s, v3.4s, v0.s[1]\n" // out of sequence - used to reduce load/use pressure. 3588 "add %[rhs_ptr], %[rhs_ptr], #32\n" // RHS loads complete - increment pointer. 3589 "fmla v10.4s, v2.4s, v0.s[2]\n" 3590 "ldr d4, [%[lhs_ptr], #32]\n" // Bottom half of v4 3591 "fmla v11.4s, v2.4s, v0.s[3]\n" 3592 "ins v1.d[1], x18\n" // Finish loading v1 3593 "fmla v12.4s, v2.4s, v1.s[0]\n" 3594 "ldr x18, [%[lhs_ptr], #40]\n" // Top half of v4 to X register 3595 "fmla v13.4s, v2.4s, v1.s[1]\n" 3596 "add %[lhs_ptr], %[lhs_ptr], #48\n" // LHS loads complete - increment pointer. 3597 "fmla v14.4s, v2.4s, v1.s[2]\n" 3598 3599 "fmla v15.4s, v2.4s, v1.s[3]\n" 3600 "ldr d2, [%[lhs_ptr]]\n" // Bottom half of v2 (for next time) 3601 "fmla v18.4s, v3.4s, v0.s[2]\n" 3602 "ins v4.d[1], x18\n" // Finish loading v4 3603 "fmla v19.4s, v3.4s, v0.s[3]\n" 3604 "ldr x18, [%[lhs_ptr], #8]\n" // Top half of next v2 to X register 3605 "fmla v20.4s, v3.4s, v1.s[0]\n" 3606 "subs %w[depth], %w[depth], #1\n" 3607 "fmla v21.4s, v3.4s, v1.s[1]\n" 3608 3609 "fmla v22.4s, v3.4s, v1.s[2]\n" 3610 3611 "fmla v23.4s, v3.4s, v1.s[3]\n" 3612 "ldr d3, [%[lhs_ptr], #16]\n" // Bottom half of v3 (for next time) 3613 "fmla v24.4s, v4.4s, v0.s[0]\n" 3614 "ins v2.d[1], x18\n" // Finish loading next v2 3615 "fmla v25.4s, v4.4s, v0.s[1]\n" 3616 "ldr x18, [%[lhs_ptr], #24]\n" // Top half of next v3 to X register 3617 "fmla v26.4s, v4.4s, v0.s[2]\n" 3618 3619 "fmla v27.4s, v4.4s, v0.s[3]\n" 3620 "ldr d0, [%[rhs_ptr]]\n" // Bottom half of v0 (for next time) 3621 "fmla v28.4s, v4.4s, v1.s[0]\n" 3622 "ins v3.d[1], x18\n" // Finish loading next v3 3623 "fmla v29.4s, v4.4s, v1.s[1]\n" 3624 "ldr x18, [%[rhs_ptr], #8]\n" // Top half of next v0 to X register 3625 "fmla v30.4s, v4.4s, v1.s[2]\n" 3626 3627 "fmla v31.4s, v4.4s, v1.s[3]\n" 3628 "bne " GEMMLOWP_LABEL_LOOP "b\n" 3629 3630 // Store accumulators 3631 "mov x0, %[accum_ptr]\n" 3632 "st1 {v8.4s}, [x0], #16\n" 3633 "st1 {v16.4s}, [x0], #16\n" 3634 "st1 {v24.4s}, [x0], #16\n" 3635 "st1 {v9.4s}, [x0], #16\n" 3636 "st1 {v17.4s}, [x0], #16\n" 3637 "st1 {v25.4s}, [x0], #16\n" 3638 "st1 {v10.4s}, [x0], #16\n" 3639 "st1 {v18.4s}, [x0], #16\n" 3640 "st1 {v26.4s}, [x0], #16\n" 3641 "st1 {v11.4s}, [x0], #16\n" 3642 "st1 {v19.4s}, [x0], #16\n" 3643 "st1 {v27.4s}, [x0], #16\n" 3644 "st1 {v12.4s}, [x0], #16\n" 3645 "st1 {v20.4s}, [x0], #16\n" 3646 "st1 {v28.4s}, [x0], #16\n" 3647 "st1 {v13.4s}, [x0], #16\n" 3648 "st1 {v21.4s}, [x0], #16\n" 3649 "st1 {v29.4s}, [x0], #16\n" 3650 "st1 {v14.4s}, [x0], #16\n" 3651 "st1 {v22.4s}, [x0], #16\n" 3652 "st1 {v30.4s}, [x0], #16\n" 3653 "st1 {v15.4s}, [x0], #16\n" 3654 "st1 {v23.4s}, [x0], #16\n" 3655 "st1 {v31.4s}, [x0], #16\n" 3656 : // outputs 3657 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 3658 [depth] "+r"(depth) 3659 : // inputs 3660 [accum_ptr] "r"(accum_ptr) 3661 : // clobbers 3662 "cc", "memory", "x0", "x18", "v0", "v1", "v2", "v3", "v4", "v5", "v6", 3663 "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", 3664 "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", 3665 "v27", "v28", "v29", "v30", "v31"); 3666 } 3667 }; 3668 3669 #endif // __aarch64__ 3670 3671 #if defined(__arm__) || defined(__aarch64__) 3672 #ifndef __aarch64__ 3673 inline int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { 3674 const int32x2_t c = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); 3675 const int32x2_t d = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); 3676 return vcombine_s32(c, d); 3677 } 3678 #endif 3679 3680 // C++ intrinsics-based variant of the deep, int8, fast kernel 3681 template <int Cols> 3682 struct NEON_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics { 3683 typedef std::int8_t OperandType; 3684 typedef std::int32_t AccumulatorType; 3685 typedef KernelFormat< 3686 KernelSideFormat<CellFormat<4, 16, CellOrder::WidthMajor>, 1>, 3687 KernelSideFormat<CellFormat<Cols, 16, CellOrder::WidthMajor>, 1> > 3688 Format; 3689 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 3690 AccumulatorType* accum_ptr, int depth) { 3691 int32x4_t acc[4][Cols]; 3692 for (int i = 0; i < 4; i++) { 3693 for (int j = 0; j < Cols; j++) { 3694 acc[i][j] = vdupq_n_s32(0); 3695 } 3696 } 3697 for (int d = 0; d < depth; d += 16) { 3698 int8x16_t lhs[4]; 3699 for (int i = 0; i < 4; i++) { 3700 lhs[i] = vld1q_s8(lhs_ptr + 16 * i); 3701 } 3702 int8x16_t rhs[Cols]; 3703 for (int i = 0; i < Cols; i++) { 3704 rhs[i] = vld1q_s8(rhs_ptr + 16 * i); 3705 } 3706 for (int i = 0; i < 4; i++) { 3707 for (int j = 0; j < Cols; j++) { 3708 int16x8_t local_acc = 3709 vmull_s8(vget_low_s8(lhs[i]), vget_low_s8(rhs[j])); 3710 local_acc = 3711 vmlal_s8(local_acc, vget_high_s8(lhs[i]), vget_high_s8(rhs[j])); 3712 acc[i][j] = vpadalq_s16(acc[i][j], local_acc); 3713 } 3714 } 3715 lhs_ptr += 64; 3716 rhs_ptr += 16 * Cols; 3717 } 3718 for (int i = 0; i < Cols; i++) { 3719 int32x4_t acc_2x_0 = vpaddq_s32(acc[0][i], acc[1][i]); 3720 int32x4_t acc_2x_1 = vpaddq_s32(acc[2][i], acc[3][i]); 3721 int32x4_t acc_4x = vpaddq_s32(acc_2x_0, acc_2x_1); 3722 int32x4_t dst_val = vld1q_s32(accum_ptr + 4 * i); 3723 dst_val = vaddq_s32(dst_val, acc_4x); 3724 vst1q_s32(accum_ptr + 4 * i, dst_val); 3725 } 3726 } 3727 }; 3728 3729 using NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics = 3730 NEON_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics<4>; 3731 3732 using NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics = 3733 NEON_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics<2>; 3734 3735 // C++ intrinsics-based variant of the wide, uint8, general kernel 3736 template <int RhsCells> 3737 struct NEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics { 3738 typedef std::uint8_t OperandType; 3739 typedef std::int32_t AccumulatorType; 3740 typedef KernelFormat< 3741 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, 3742 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, RhsCells> > 3743 Format; 3744 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 3745 AccumulatorType* accum_ptr, int depth) { 3746 int32x4_t acc[3][4 * RhsCells]; 3747 for (int i = 0; i < 3; i++) { 3748 for (int j = 0; j < 4 * RhsCells; j++) { 3749 acc[i][j] = vld1q_s32(accum_ptr + 4 * (i + 3 * j)); 3750 } 3751 } 3752 for (int d = 0; d < depth; d += 2) { 3753 int16x8_t lhs[3]; 3754 for (int i = 0; i < 3; i++) { 3755 lhs[i] = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(lhs_ptr + 8 * i))); 3756 } 3757 int16x8_t rhs[RhsCells]; 3758 for (int i = 0; i < RhsCells; i++) { 3759 rhs[i] = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(rhs_ptr + 8 * i))); 3760 } 3761 for (int i = 0; i < 3; i++) { 3762 for (int j = 0; j < RhsCells; j++) { 3763 acc[i][4 * j + 0] = vmlal_lane_s16( 3764 acc[i][4 * j + 0], vget_low_s16(lhs[i]), vget_low_s16(rhs[j]), 0); 3765 acc[i][4 * j + 1] = vmlal_lane_s16( 3766 acc[i][4 * j + 1], vget_low_s16(lhs[i]), vget_low_s16(rhs[j]), 1); 3767 acc[i][4 * j + 2] = vmlal_lane_s16( 3768 acc[i][4 * j + 2], vget_low_s16(lhs[i]), vget_low_s16(rhs[j]), 2); 3769 acc[i][4 * j + 3] = vmlal_lane_s16( 3770 acc[i][4 * j + 3], vget_low_s16(lhs[i]), vget_low_s16(rhs[j]), 3); 3771 acc[i][4 * j + 0] = 3772 vmlal_lane_s16(acc[i][4 * j + 0], vget_high_s16(lhs[i]), 3773 vget_high_s16(rhs[j]), 0); 3774 acc[i][4 * j + 1] = 3775 vmlal_lane_s16(acc[i][4 * j + 1], vget_high_s16(lhs[i]), 3776 vget_high_s16(rhs[j]), 1); 3777 acc[i][4 * j + 2] = 3778 vmlal_lane_s16(acc[i][4 * j + 2], vget_high_s16(lhs[i]), 3779 vget_high_s16(rhs[j]), 2); 3780 acc[i][4 * j + 3] = 3781 vmlal_lane_s16(acc[i][4 * j + 3], vget_high_s16(lhs[i]), 3782 vget_high_s16(rhs[j]), 3); 3783 } 3784 } 3785 lhs_ptr += 24; 3786 rhs_ptr += 8 * RhsCells; 3787 } 3788 for (int i = 0; i < 3; i++) { 3789 for (int j = 0; j < 4 * RhsCells; j++) { 3790 vst1q_s32(accum_ptr + 4 * (i + 3 * j), acc[i][j]); 3791 } 3792 } 3793 } 3794 }; 3795 3796 using NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics = 3797 NEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics<1>; 3798 3799 using NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics = 3800 NEON_GEMM_Uint8Operands_Uint32Accumulators_intrinsics<2>; 3801 3802 template <int RhsCells> 3803 struct NEON_GEMM_Float32_WithScalar_intrinsics { 3804 typedef float OperandType; 3805 typedef float AccumulatorType; 3806 typedef KernelFormat< 3807 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, 3>, 3808 KernelSideFormat<CellFormat<4, 1, CellOrder::DepthMajor>, RhsCells> > 3809 Format; 3810 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 3811 AccumulatorType* accum_ptr, int depth) { 3812 float32x4_t acc[3][4 * RhsCells]; 3813 for (int i = 0; i < 3; i++) { 3814 for (int j = 0; j < 4 * RhsCells; j++) { 3815 acc[i][j] = vld1q_f32(accum_ptr + 4 * (i + 3 * j)); 3816 } 3817 } 3818 for (int d = 0; d < depth; d++) { 3819 float32x4_t lhs[3]; 3820 for (int i = 0; i < 3; i++) { 3821 lhs[i] = vld1q_f32(lhs_ptr + 4 * i); 3822 } 3823 float32x4_t rhs[RhsCells]; 3824 for (int i = 0; i < RhsCells; i++) { 3825 rhs[i] = vld1q_f32(rhs_ptr + 4 * i); 3826 } 3827 for (int i = 0; i < 3; i++) { 3828 for (int j = 0; j < RhsCells; j++) { 3829 acc[i][4 * j + 0] = vmlaq_lane_f32(acc[i][4 * j + 0], lhs[i], 3830 vget_low_f32(rhs[j]), 0); 3831 acc[i][4 * j + 1] = vmlaq_lane_f32(acc[i][4 * j + 1], lhs[i], 3832 vget_low_f32(rhs[j]), 1); 3833 acc[i][4 * j + 2] = vmlaq_lane_f32(acc[i][4 * j + 2], lhs[i], 3834 vget_high_f32(rhs[j]), 0); 3835 acc[i][4 * j + 3] = vmlaq_lane_f32(acc[i][4 * j + 3], lhs[i], 3836 vget_high_f32(rhs[j]), 1); 3837 } 3838 } 3839 lhs_ptr += 12; 3840 rhs_ptr += 4 * RhsCells; 3841 } 3842 for (int i = 0; i < 3; i++) { 3843 for (int j = 0; j < 4 * RhsCells; j++) { 3844 vst1q_f32(accum_ptr + 4 * (i + 3 * j), acc[i][j]); 3845 } 3846 } 3847 } 3848 }; 3849 3850 using NEON_32bit_GEMM_Float32_WithScalar_intrinsics = 3851 NEON_GEMM_Float32_WithScalar_intrinsics<1>; 3852 3853 using NEON_64bit_GEMM_Float32_WithScalar_intrinsics = 3854 NEON_GEMM_Float32_WithScalar_intrinsics<2>; 3855 #endif // __arm__ || __aarch64__ 3856 3857 #ifdef __mips 3858 static inline v4i32 workaround_msa_maddv_w(v4i32 a, v4i32 b, v4i32 c) { 3859 // Workaround for incorrect encoding of maddv.df in gcc (a exchanged with c). 3860 #if 0 3861 return __builtin_msa_maddv_w(a, b, c); 3862 #else 3863 asm volatile("maddv.w %w[a], %w[b], %w[c]\n" 3864 // Outputs 3865 : [a] "+f"(a) 3866 // Inputs 3867 : [b] "f"(b), [c] "f"(c)); 3868 return a; 3869 #endif 3870 } 3871 3872 // Using 32x32=32 multiplications. 3873 // 20 MSA regs used: 3874 // - 12 accumulators 3875 // - 6 lhs 3876 // - 1 rhs 3877 // - 1 temps/zeroes 3878 // ~55 instructions in the loop. 3879 struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics { 3880 typedef std::uint8_t OperandType; 3881 typedef std::int32_t AccumulatorType; 3882 typedef KernelFormat< 3883 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, 3884 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> > 3885 Format; 3886 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 3887 AccumulatorType* accum_ptr, int depth) { 3888 const v16i8 zeroes = __builtin_msa_ldi_b(0); 3889 v4i32 acc[3][4]; 3890 // Load accumulators. 3891 for (int i = 0; i < 3; i++) { 3892 for (int j = 0; j < 4; j++) { 3893 acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0); 3894 } 3895 } 3896 3897 while (depth > 0) { 3898 // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. 3899 v8i16 lhs[6]; 3900 lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0)); 3901 lhs[1] = 3902 reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0)); 3903 3904 // Zero-extend 8-bit elements of lhs[] to 16 bits. 3905 lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes, 3906 reinterpret_cast<v16i8>(lhs[0]))); 3907 lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes, 3908 reinterpret_cast<v16i8>(lhs[1]))); 3909 lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes, 3910 reinterpret_cast<v16i8>(lhs[1]))); 3911 3912 // Zero-extend 16-bit elements of lhs[] to 32 bits. 3913 lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]); 3914 lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]); 3915 lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]); 3916 lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]); 3917 lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]); 3918 lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]); 3919 3920 // Depth 0. 3921 for (int j = 0; j < 4; j++) { 3922 // Load 1 byte of rhs, making 4 32-bit replicas of it. 3923 v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j])); 3924 // Multiply-add into accumulators. 3925 for (int i = 0; i < 3; i++) { 3926 acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs); 3927 } 3928 } 3929 3930 // Depth 1. 3931 for (int j = 0; j < 4; j++) { 3932 // Load 1 byte of rhs, making 4 32-bit replicas of it. 3933 v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4])); 3934 // Multiply-add into accumulators. 3935 for (int i = 0; i < 3; i++) { 3936 acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs); 3937 } 3938 } 3939 3940 lhs_ptr += 24; 3941 rhs_ptr += 8; 3942 depth -= 2; 3943 } 3944 3945 // Store accumulators. 3946 for (int i = 0; i < 3; i++) { 3947 for (int j = 0; j < 4; j++) { 3948 __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0); 3949 } 3950 } 3951 } 3952 }; 3953 3954 // Assembly implementation of the above 3955 // MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics. 3956 // Using 32x32=32 multiplications. 3957 // 20 MSA regs used: 3958 // - 12 accumulators 3959 // - 6 lhs 3960 // - 1 rhs 3961 // - 1 temps/zeroes 3962 // ~55 instructions in the loop. 3963 struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly { 3964 typedef std::uint8_t OperandType; 3965 typedef std::int32_t AccumulatorType; 3966 typedef KernelFormat< 3967 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, 3968 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> > 3969 Format; 3970 static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, 3971 AccumulatorType* accum_ptr, int depth) { 3972 asm volatile( 3973 // Load accumulators 3974 "ld.w $w0, (0*16)(%[accum_ptr])\n" 3975 "ld.w $w4, (1*16)(%[accum_ptr])\n" 3976 "ld.w $w8, (2*16)(%[accum_ptr])\n" 3977 "ld.w $w1, (3*16)(%[accum_ptr])\n" 3978 "ld.w $w5, (4*16)(%[accum_ptr])\n" 3979 "ld.w $w9, (5*16)(%[accum_ptr])\n" 3980 "ld.w $w2, (6*16)(%[accum_ptr])\n" 3981 "ld.w $w6, (7*16)(%[accum_ptr])\n" 3982 "ld.w $w10, (8*16)(%[accum_ptr])\n" 3983 "ld.w $w3, (9*16)(%[accum_ptr])\n" 3984 "ld.w $w7, (10*16)(%[accum_ptr])\n" 3985 "ld.w $w11, (11*16)(%[accum_ptr])\n" 3986 // Set a temp to all zeroes. 3987 "ldi.b $w19, 0\n" 3988 3989 GEMMLOWP_LABEL_LOOP ":\n" 3990 // Overview of register layout: 3991 // 3992 // A half of the 2x4 cell of Rhs is stored in 32bit in w18. 3993 // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w12-w17. 3994 // A 12x4 block of accumulators is stored in 32bit in w0-w11. 3995 // 3996 // +------+------+------+------+ 3997 // Rhs |w18[0]|w18[1]|w18[2]|w18[3]| 3998 // +------+------+------+------+ 3999 // 4000 // | | | | | 4001 // 4002 // Lhs | | | | | 4003 // 4004 // +---+---+ - - - - +------+------+------+------+ 4005 // |w12|w15| | w0 | w1 | w2 | w3 | 4006 // |w12|w15| | w0 | w1 | w2 | w3 | 4007 // |w12|w15| | w0 | w1 | w2 | w3 | 4008 // |w12|w15| | w0 | w1 | w2 | w3 | 4009 // +---+---+ - - - - +------+------+------+------+ 4010 // |w13|w16| | w4 | w5 | w6 | w7 | 4011 // |w13|w16| | w4 | w5 | w6 | w7 | 4012 // |w13|w16| | w4 | w5 | w6 | w7 | 4013 // |w13|w16| | w4 | w5 | w6 | w7 | 4014 // +---+---+ - - - - +------+------+------+------+ 4015 // |w14|w17| | w8 | w9 | w10 | w11 | 4016 // |w14|w17| | w8 | w9 | w10 | w11 | 4017 // |w14|w17| | w8 | w9 | w10 | w11 | 4018 // |w14|w17| | w8 | w9 | w10 | w11 | 4019 // +---+---+ - - - - +------+------+------+------+ 4020 // 4021 // Accumulator 4022 4023 // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. 4024 "ld.b $w12, 0(%[lhs_ptr])\n" 4025 "ld.b $w13, 8(%[lhs_ptr])\n" 4026 4027 // Load 4 bytes of rhs[] for depth 0. 4028 "lbu $a0, 0(%[rhs_ptr])\n" 4029 "lbu $a1, 1(%[rhs_ptr])\n" 4030 "lbu $a2, 2(%[rhs_ptr])\n" 4031 "lbu $a3, 3(%[rhs_ptr])\n" 4032 4033 // Zero-extend 8-bit elements of lhs[] to 16 bits. 4034 "ilvr.b $w12, $w19, $w12\n" 4035 "ilvl.b $w14, $w19, $w13\n" 4036 "ilvr.b $w13, $w19, $w13\n" 4037 // Zero-extend 16-bit elements of lhs[] to 32 bits. 4038 "ilvl.h $w15, $w19, $w12\n" 4039 "ilvl.h $w16, $w19, $w13\n" 4040 "ilvl.h $w17, $w19, $w14\n" 4041 "ilvr.h $w12, $w19, $w12\n" 4042 "ilvr.h $w13, $w19, $w13\n" 4043 "ilvr.h $w14, $w19, $w14\n" 4044 4045 // Depth 0. 4046 "fill.w $w18, $a0\n" 4047 "lbu $a0, 4(%[rhs_ptr])\n" 4048 "maddv.w $w0, $w12, $w18\n" 4049 "maddv.w $w4, $w13, $w18\n" 4050 "maddv.w $w8, $w14, $w18\n" 4051 "fill.w $w18, $a1\n" 4052 "lbu $a1, 5(%[rhs_ptr])\n" 4053 "maddv.w $w1, $w12, $w18\n" 4054 "maddv.w $w5, $w13, $w18\n" 4055 "maddv.w $w9, $w14, $w18\n" 4056 "fill.w $w18, $a2\n" 4057 "lbu $a2, 6(%[rhs_ptr])\n" 4058 "maddv.w $w2, $w12, $w18\n" 4059 "maddv.w $w6, $w13, $w18\n" 4060 "maddv.w $w10, $w14, $w18\n" 4061 "fill.w $w18, $a3\n" 4062 "lbu $a3, 7(%[rhs_ptr])\n" 4063 "maddv.w $w3, $w12, $w18\n" 4064 "maddv.w $w7, $w13, $w18\n" 4065 "maddv.w $w11, $w14, $w18\n" 4066 4067 // Depth 1. 4068 "fill.w $w18, $a0\n" 4069 "maddv.w $w0, $w15, $w18\n" 4070 "maddv.w $w4, $w16, $w18\n" 4071 "maddv.w $w8, $w17, $w18\n" 4072 "fill.w $w18, $a1\n" 4073 "maddv.w $w1, $w15, $w18\n" 4074 "maddv.w $w5, $w16, $w18\n" 4075 "maddv.w $w9, $w17, $w18\n" 4076 "fill.w $w18, $a2\n" 4077 "maddv.w $w2, $w15, $w18\n" 4078 "maddv.w $w6, $w16, $w18\n" 4079 "maddv.w $w10, $w17, $w18\n" 4080 "fill.w $w18, $a3\n" 4081 "maddv.w $w3, $w15, $w18\n" 4082 "maddv.w $w7, $w16, $w18\n" 4083 "maddv.w $w11, $w17, $w18\n" 4084 4085 "addiu %[depth], -2\n" 4086 GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" 4087 GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n" 4088 "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n" 4089 4090 // Store accumulators. 4091 "st.w $w0, (0*16)(%[accum_ptr])\n" 4092 "st.w $w4, (1*16)(%[accum_ptr])\n" 4093 "st.w $w8, (2*16)(%[accum_ptr])\n" 4094 "st.w $w1, (3*16)(%[accum_ptr])\n" 4095 "st.w $w5, (4*16)(%[accum_ptr])\n" 4096 "st.w $w9, (5*16)(%[accum_ptr])\n" 4097 "st.w $w2, (6*16)(%[accum_ptr])\n" 4098 "st.w $w6, (7*16)(%[accum_ptr])\n" 4099 "st.w $w10, (8*16)(%[accum_ptr])\n" 4100 "st.w $w3, (9*16)(%[accum_ptr])\n" 4101 "st.w $w7, (10*16)(%[accum_ptr])\n" 4102 "st.w $w11, (11*16)(%[accum_ptr])\n" 4103 : // outputs 4104 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 4105 [depth] "+r"(depth) 4106 : // inputs 4107 [accum_ptr] "r"(accum_ptr) 4108 : // clobbers 4109 "memory", 4110 "a0", "a1", "a2", "a3", 4111 "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", 4112 "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", 4113 "$f16", "$f17", "$f18", "$f19"); 4114 } 4115 }; 4116 4117 // Assembly implementation of the above 4118 // MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO). 4119 // Using 16x16=32 multiplications. 4120 // 20 MSA regs used: 4121 // - 12 accumulators 4122 // - 3 lhs 4123 // - 4 rhs 4124 // - 1 temps/zeroes 4125 // ~45 instructions in the loop. 4126 struct MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2 { 4127 typedef std::uint8_t OperandType; 4128 typedef std::int32_t AccumulatorType; 4129 typedef KernelFormat< 4130 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, 4131 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 1> > 4132 Format; 4133 static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, 4134 AccumulatorType* accum_ptr, int depth) { 4135 asm volatile( 4136 // Load accumulators 4137 "ld.w $w0, (0*16)(%[accum_ptr])\n" 4138 "ld.w $w4, (1*16)(%[accum_ptr])\n" 4139 "ld.w $w8, (2*16)(%[accum_ptr])\n" 4140 "ld.w $w1, (3*16)(%[accum_ptr])\n" 4141 "ld.w $w5, (4*16)(%[accum_ptr])\n" 4142 "ld.w $w9, (5*16)(%[accum_ptr])\n" 4143 "ld.w $w2, (6*16)(%[accum_ptr])\n" 4144 "ld.w $w6, (7*16)(%[accum_ptr])\n" 4145 "ld.w $w10, (8*16)(%[accum_ptr])\n" 4146 "ld.w $w3, (9*16)(%[accum_ptr])\n" 4147 "ld.w $w7, (10*16)(%[accum_ptr])\n" 4148 "ld.w $w11, (11*16)(%[accum_ptr])\n" 4149 // Set a temp to all zeroes. 4150 "ldi.b $w19, 0\n" 4151 4152 GEMMLOWP_LABEL_LOOP ":\n" 4153 // Overview of register layout: 4154 // 4155 // A 2x4 cell of Rhs is stored in 16bit in w15-w18 (each register 4156 // contains 4 replicas of a pair of elements). 4157 // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w12-w14. 4158 // A 12x4 block of accumulators is stored in 32bit in w0-w11. 4159 // 4160 // +-----+-----+-----+-----+ 4161 // Rhs | w15 | w16 | w17 | w18 | 4162 // +-----+-----+-----+-----+ 4163 // 4164 // | | | | | 4165 // 4166 // Lhs | | | | | 4167 // 4168 // +---+ - - - - +-----+-----+-----+-----+ 4169 // |w12| | w0 | w1 | w2 | w3 | 4170 // |w12| | w0 | w1 | w2 | w3 | 4171 // |w12| | w0 | w1 | w2 | w3 | 4172 // |w12| | w0 | w1 | w2 | w3 | 4173 // +---+ - - - - +-----+-----+-----+-----+ 4174 // |w13| | w4 | w5 | w6 | w7 | 4175 // |w13| | w4 | w5 | w6 | w7 | 4176 // |w13| | w4 | w5 | w6 | w7 | 4177 // |w13| | w4 | w5 | w6 | w7 | 4178 // +---+ - - - - +-----+-----+-----+-----+ 4179 // |w14| | w8 | w9 | w10 | w11 | 4180 // |w14| | w8 | w9 | w10 | w11 | 4181 // |w14| | w8 | w9 | w10 | w11 | 4182 // |w14| | w8 | w9 | w10 | w11 | 4183 // +---+ - - - - +-----+-----+-----+-----+ 4184 // 4185 // Accumulators 4186 4187 // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. 4188 "ld.b $w12, 0(%[lhs_ptr])\n" 4189 "ld.b $w13, 8(%[lhs_ptr])\n" 4190 4191 // Load 4 bytes of rhs[] for depth 0. 4192 "lbu $a0, 0(%[rhs_ptr])\n" 4193 "lbu $a1, 1(%[rhs_ptr])\n" 4194 "lbu $a2, 2(%[rhs_ptr])\n" 4195 "lbu $a3, 3(%[rhs_ptr])\n" 4196 // Load 4 bytes of rhs[] for depth 1. 4197 "lbu $v0, 4(%[rhs_ptr])\n" 4198 "lbu $v1, 5(%[rhs_ptr])\n" 4199 "lbu $t8, 6(%[rhs_ptr])\n" 4200 "lbu $t9, 7(%[rhs_ptr])\n" 4201 4202 // Zero-extend 8-bit elements of lhs[] to 16 bits. 4203 "ilvr.b $w12, $w19, $w12\n" 4204 "ilvl.b $w14, $w19, $w13\n" 4205 "ilvr.b $w13, $w19, $w13\n" 4206 // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w. 4207 "ilvl.d $w15, $w19, $w12\n" 4208 "ilvl.d $w16, $w19, $w13\n" 4209 "ilvl.d $w17, $w19, $w14\n" 4210 "ilvr.h $w12, $w15, $w12\n" 4211 "ilvr.h $w13, $w16, $w13\n" 4212 "ilvr.h $w14, $w17, $w14\n" 4213 4214 // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w. 4215 "ins $a0, $v0, 16, 8\n" 4216 "ins $a1, $v1, 16, 8\n" 4217 "ins $a2, $t8, 16, 8\n" 4218 "ins $a3, $t9, 16, 8\n" 4219 // Make 4 replicas of every pair of rhs[] elements. 4220 "fill.w $w15, $a0\n" 4221 "fill.w $w16, $a1\n" 4222 "fill.w $w17, $a2\n" 4223 "fill.w $w18, $a3\n" 4224 4225 // Depths 0 and 1. 4226 // Dot-product-(and)-add doubles multiplicand width. 4227 "dpadd_u.w $w0, $w12, $w15\n" 4228 "dpadd_u.w $w4, $w13, $w15\n" 4229 "dpadd_u.w $w8, $w14, $w15\n" 4230 "dpadd_u.w $w1, $w12, $w16\n" 4231 "dpadd_u.w $w5, $w13, $w16\n" 4232 "dpadd_u.w $w9, $w14, $w16\n" 4233 "dpadd_u.w $w2, $w12, $w17\n" 4234 "dpadd_u.w $w6, $w13, $w17\n" 4235 "dpadd_u.w $w10, $w14, $w17\n" 4236 "dpadd_u.w $w3, $w12, $w18\n" 4237 "dpadd_u.w $w7, $w13, $w18\n" 4238 "dpadd_u.w $w11, $w14, $w18\n" 4239 4240 "addiu %[depth], -2\n" 4241 GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" 4242 GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 8\n" 4243 "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n" 4244 4245 // Store accumulators. 4246 "st.w $w0, (0*16)(%[accum_ptr])\n" 4247 "st.w $w4, (1*16)(%[accum_ptr])\n" 4248 "st.w $w8, (2*16)(%[accum_ptr])\n" 4249 "st.w $w1, (3*16)(%[accum_ptr])\n" 4250 "st.w $w5, (4*16)(%[accum_ptr])\n" 4251 "st.w $w9, (5*16)(%[accum_ptr])\n" 4252 "st.w $w2, (6*16)(%[accum_ptr])\n" 4253 "st.w $w6, (7*16)(%[accum_ptr])\n" 4254 "st.w $w10, (8*16)(%[accum_ptr])\n" 4255 "st.w $w3, (9*16)(%[accum_ptr])\n" 4256 "st.w $w7, (10*16)(%[accum_ptr])\n" 4257 "st.w $w11, (11*16)(%[accum_ptr])\n" 4258 : // outputs 4259 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 4260 [depth] "+r"(depth) 4261 : // inputs 4262 [accum_ptr] "r"(accum_ptr) 4263 : // clobbers 4264 "memory", 4265 "v0", "v1", 4266 "a0", "a1", "a2", "a3", 4267 "t8", "t9", 4268 "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", 4269 "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", 4270 "$f16", "$f17", "$f18", "$f19"); 4271 } 4272 }; 4273 4274 // Using 32x32=32 multiplications. 4275 // 32 MSA regs used: 4276 // - 24 accumulators 4277 // - 6 lhs 4278 // - 1 rhs 4279 // - 1 temps/zeroes 4280 // ~95 instructions in the loop. 4281 struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics { 4282 typedef std::uint8_t OperandType; 4283 typedef std::uint32_t AccumulatorType; 4284 typedef KernelFormat< 4285 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, 4286 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> > 4287 Format; 4288 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 4289 AccumulatorType* accum_ptr, int depth) { 4290 const v16i8 zeroes = __builtin_msa_ldi_b(0); 4291 v4i32 acc[3][8]; 4292 // Load accumulators. 4293 for (int i = 0; i < 3; i++) { 4294 for (int j = 0; j < 8; j++) { 4295 acc[i][j] = __builtin_msa_ld_w(accum_ptr + 4 * (i + 3 * j), 0); 4296 } 4297 } 4298 4299 while (depth > 0) { 4300 // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. 4301 v8i16 lhs[6]; 4302 lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr), 0)); 4303 lhs[1] = 4304 reinterpret_cast<v8i16>(__builtin_msa_ld_b(const_cast<OperandType*>(lhs_ptr + 8), 0)); 4305 4306 // Zero-extend 8-bit elements of lhs[] to 16 bits. 4307 lhs[0] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes, 4308 reinterpret_cast<v16i8>(lhs[0]))); 4309 lhs[2] = reinterpret_cast<v8i16>(__builtin_msa_ilvl_b(zeroes, 4310 reinterpret_cast<v16i8>(lhs[1]))); 4311 lhs[1] = reinterpret_cast<v8i16>(__builtin_msa_ilvr_b(zeroes, 4312 reinterpret_cast<v16i8>(lhs[1]))); 4313 4314 // Zero-extend 16-bit elements of lhs[] to 32 bits. 4315 lhs[3] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[0]); 4316 lhs[4] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[1]); 4317 lhs[5] = __builtin_msa_ilvl_h(reinterpret_cast<v8i16>(zeroes), lhs[2]); 4318 lhs[0] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[0]); 4319 lhs[1] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[1]); 4320 lhs[2] = __builtin_msa_ilvr_h(reinterpret_cast<v8i16>(zeroes), lhs[2]); 4321 4322 // Depth 0. 4323 for (int j = 0; j < 4; j++) { 4324 // Load 1 byte of rhs, making 4 32-bit replicas of it. 4325 v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j])); 4326 // Multiply-add into accumulators. 4327 for (int i = 0; i < 3; i++) { 4328 acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs); 4329 } 4330 } 4331 for (int j = 4; j < 8; j++) { 4332 // Load 1 byte of rhs, making 4 32-bit replicas of it. 4333 v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4])); 4334 // Multiply-add into accumulators. 4335 for (int i = 0; i < 3; i++) { 4336 acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i]), rhs); 4337 } 4338 } 4339 4340 // Depth 1. 4341 for (int j = 0; j < 4; j++) { 4342 // Load 1 byte of rhs, making 4 32-bit replicas of it. 4343 v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 4])); 4344 // Multiply-add into accumulators. 4345 for (int i = 0; i < 3; i++) { 4346 acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs); 4347 } 4348 } 4349 for (int j = 4; j < 8; j++) { 4350 // Load 1 byte of rhs, making 4 32-bit replicas of it. 4351 v4i32 rhs = reinterpret_cast<v4i32>(__builtin_msa_fill_w(rhs_ptr[j + 8])); 4352 // Multiply-add into accumulators. 4353 for (int i = 0; i < 3; i++) { 4354 acc[i][j] = workaround_msa_maddv_w(acc[i][j], reinterpret_cast<v4i32>(lhs[i + 3]), rhs); 4355 } 4356 } 4357 4358 lhs_ptr += 24; 4359 rhs_ptr += 16; 4360 depth -= 2; 4361 } 4362 4363 // Store accumulators. 4364 for (int i = 0; i < 3; i++) { 4365 for (int j = 0; j < 8; j++) { 4366 __builtin_msa_st_w(acc[i][j], accum_ptr + 4 * (i + 3 * j), 0); 4367 } 4368 } 4369 } 4370 }; 4371 4372 // Assembly implementation of the above 4373 // MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics. 4374 // Using 32x32=32 multiplications. 4375 // 32 MSA regs used: 4376 // - 24 accumulators 4377 // - 6 lhs 4378 // - 1 rhs 4379 // - 1 temps/zeroes 4380 // ~95 instructions in the loop. 4381 struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly { 4382 typedef std::uint8_t OperandType; 4383 typedef std::uint32_t AccumulatorType; 4384 typedef KernelFormat< 4385 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, 4386 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> > 4387 Format; 4388 static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, 4389 AccumulatorType* accum_ptr, int depth) { 4390 asm volatile( 4391 // Load accumulators 4392 "ld.w $w0, (0*16)(%[accum_ptr])\n" 4393 "ld.w $w4, (1*16)(%[accum_ptr])\n" 4394 "ld.w $w8, (2*16)(%[accum_ptr])\n" 4395 "ld.w $w1, (3*16)(%[accum_ptr])\n" 4396 "ld.w $w5, (4*16)(%[accum_ptr])\n" 4397 "ld.w $w9, (5*16)(%[accum_ptr])\n" 4398 "ld.w $w2, (6*16)(%[accum_ptr])\n" 4399 "ld.w $w6, (7*16)(%[accum_ptr])\n" 4400 "ld.w $w10, (8*16)(%[accum_ptr])\n" 4401 "ld.w $w3, (9*16)(%[accum_ptr])\n" 4402 "ld.w $w7, (10*16)(%[accum_ptr])\n" 4403 "ld.w $w11, (11*16)(%[accum_ptr])\n" 4404 "ld.w $w12, (12*16)(%[accum_ptr])\n" 4405 "ld.w $w16, (13*16)(%[accum_ptr])\n" 4406 "ld.w $w20, (14*16)(%[accum_ptr])\n" 4407 "ld.w $w13, (15*16)(%[accum_ptr])\n" 4408 "ld.w $w17, (16*16)(%[accum_ptr])\n" 4409 "ld.w $w21, (17*16)(%[accum_ptr])\n" 4410 "ld.w $w14, (18*16)(%[accum_ptr])\n" 4411 "ld.w $w18, (19*16)(%[accum_ptr])\n" 4412 "ld.w $w22, (20*16)(%[accum_ptr])\n" 4413 "ld.w $w15, (21*16)(%[accum_ptr])\n" 4414 "ld.w $w19, (22*16)(%[accum_ptr])\n" 4415 "ld.w $w23, (23*16)(%[accum_ptr])\n" 4416 // Set a temp to all zeroes. 4417 "ldi.b $w31, 0\n" 4418 4419 GEMMLOWP_LABEL_LOOP ":\n" 4420 // Overview of register layout: 4421 // 4422 // A quarter of the 2 2x4 cells of Rhs is stored in 32bit in w30. 4423 // A 12x2 block of 3 4x2 cells Lhs is stored in 32bit in w24-w29. 4424 // A 12x8 block of accumulators is stored in 32bit in w0-w23. 4425 // 4426 // +------+------+------+------+ 4427 // Rhs |w30[0]|w30[1]|w30[2]|w30[3]| 4428 // +------+------+------+------+ 4429 // 4430 // | | | | | 4431 // 4432 // Lhs | | | | | 4433 // 4434 // +---+---+ - - - - +------+------+------+------+ 4435 // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | 4436 // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | 4437 // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | 4438 // |w24|w27| |w0/12 |w1/13 |w2/14 |w3/15 | 4439 // +---+---+ - - - - +------+------+------+------+ 4440 // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | 4441 // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | 4442 // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | 4443 // |w25|w28| |w4/16 |w5/17 |w6/18 |w7/19 | 4444 // +---+---+ - - - - +------+------+------+------+ 4445 // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| 4446 // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| 4447 // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| 4448 // |w26|w29| |w8/20 |w9/21 |w10/22|w11/23| 4449 // +---+---+ - - - - +------+------+------+------+ 4450 // 4451 // Accumulator 4452 4453 // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. 4454 "ld.b $w24, 0(%[lhs_ptr])\n" 4455 "ld.b $w25, 8(%[lhs_ptr])\n" 4456 4457 // Load 4 bytes of rhs[] for the first half of depth 0. 4458 "lbu $a0, 0(%[rhs_ptr])\n" 4459 "lbu $a1, 1(%[rhs_ptr])\n" 4460 "lbu $a2, 2(%[rhs_ptr])\n" 4461 "lbu $a3, 3(%[rhs_ptr])\n" 4462 4463 // Zero-extend 8-bit elements of lhs[] to 16 bits. 4464 "ilvr.b $w24, $w31, $w24\n" 4465 "ilvl.b $w26, $w31, $w25\n" 4466 "ilvr.b $w25, $w31, $w25\n" 4467 // Zero-extend 16-bit elements of lhs[] to 32 bits. 4468 "ilvl.h $w27, $w31, $w24\n" 4469 "ilvl.h $w28, $w31, $w25\n" 4470 "ilvl.h $w29, $w31, $w26\n" 4471 "ilvr.h $w24, $w31, $w24\n" 4472 "ilvr.h $w25, $w31, $w25\n" 4473 "ilvr.h $w26, $w31, $w26\n" 4474 4475 // Depth 0. 4476 "fill.w $w30, $a0\n" 4477 "lbu $a0, 8(%[rhs_ptr])\n" 4478 "maddv.w $w0, $w24, $w30\n" 4479 "maddv.w $w4, $w25, $w30\n" 4480 "maddv.w $w8, $w26, $w30\n" 4481 "fill.w $w30, $a1\n" 4482 "lbu $a1, 9(%[rhs_ptr])\n" 4483 "maddv.w $w1, $w24, $w30\n" 4484 "maddv.w $w5, $w25, $w30\n" 4485 "maddv.w $w9, $w26, $w30\n" 4486 "fill.w $w30, $a2\n" 4487 "lbu $a2, 10(%[rhs_ptr])\n" 4488 "maddv.w $w2, $w24, $w30\n" 4489 "maddv.w $w6, $w25, $w30\n" 4490 "maddv.w $w10, $w26, $w30\n" 4491 "fill.w $w30, $a3\n" 4492 "lbu $a3, 11(%[rhs_ptr])\n" 4493 "maddv.w $w3, $w24, $w30\n" 4494 "maddv.w $w7, $w25, $w30\n" 4495 "maddv.w $w11, $w26, $w30\n" 4496 4497 "fill.w $w30, $a0\n" 4498 "lbu $a0, 4(%[rhs_ptr])\n" 4499 "maddv.w $w12, $w24, $w30\n" 4500 "maddv.w $w16, $w25, $w30\n" 4501 "maddv.w $w20, $w26, $w30\n" 4502 "fill.w $w30, $a1\n" 4503 "lbu $a1, 5(%[rhs_ptr])\n" 4504 "maddv.w $w13, $w24, $w30\n" 4505 "maddv.w $w17, $w25, $w30\n" 4506 "maddv.w $w21, $w26, $w30\n" 4507 "fill.w $w30, $a2\n" 4508 "lbu $a2, 6(%[rhs_ptr])\n" 4509 "maddv.w $w14, $w24, $w30\n" 4510 "maddv.w $w18, $w25, $w30\n" 4511 "maddv.w $w22, $w26, $w30\n" 4512 "fill.w $w30, $a3\n" 4513 "lbu $a3, 7(%[rhs_ptr])\n" 4514 "maddv.w $w15, $w24, $w30\n" 4515 "maddv.w $w19, $w25, $w30\n" 4516 "maddv.w $w23, $w26, $w30\n" 4517 4518 // Depth 1. 4519 "fill.w $w30, $a0\n" 4520 "lbu $a0, 12(%[rhs_ptr])\n" 4521 "maddv.w $w0, $w27, $w30\n" 4522 "maddv.w $w4, $w28, $w30\n" 4523 "maddv.w $w8, $w29, $w30\n" 4524 "fill.w $w30, $a1\n" 4525 "lbu $a1, 13(%[rhs_ptr])\n" 4526 "maddv.w $w1, $w27, $w30\n" 4527 "maddv.w $w5, $w28, $w30\n" 4528 "maddv.w $w9, $w29, $w30\n" 4529 "fill.w $w30, $a2\n" 4530 "lbu $a2, 14(%[rhs_ptr])\n" 4531 "maddv.w $w2, $w27, $w30\n" 4532 "maddv.w $w6, $w28, $w30\n" 4533 "maddv.w $w10, $w29, $w30\n" 4534 "fill.w $w30, $a3\n" 4535 "lbu $a3, 15(%[rhs_ptr])\n" 4536 "maddv.w $w3, $w27, $w30\n" 4537 "maddv.w $w7, $w28, $w30\n" 4538 "maddv.w $w11, $w29, $w30\n" 4539 4540 "fill.w $w30, $a0\n" 4541 "maddv.w $w12, $w27, $w30\n" 4542 "maddv.w $w16, $w28, $w30\n" 4543 "maddv.w $w20, $w29, $w30\n" 4544 "fill.w $w30, $a1\n" 4545 "maddv.w $w13, $w27, $w30\n" 4546 "maddv.w $w17, $w28, $w30\n" 4547 "maddv.w $w21, $w29, $w30\n" 4548 "fill.w $w30, $a2\n" 4549 "maddv.w $w14, $w27, $w30\n" 4550 "maddv.w $w18, $w28, $w30\n" 4551 "maddv.w $w22, $w29, $w30\n" 4552 "fill.w $w30, $a3\n" 4553 "maddv.w $w15, $w27, $w30\n" 4554 "maddv.w $w19, $w28, $w30\n" 4555 "maddv.w $w23, $w29, $w30\n" 4556 4557 "addiu %[depth], -2\n" 4558 GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" 4559 GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n" 4560 "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n" 4561 4562 // Store accumulators. 4563 "st.w $w0, (0*16)(%[accum_ptr])\n" 4564 "st.w $w4, (1*16)(%[accum_ptr])\n" 4565 "st.w $w8, (2*16)(%[accum_ptr])\n" 4566 "st.w $w1, (3*16)(%[accum_ptr])\n" 4567 "st.w $w5, (4*16)(%[accum_ptr])\n" 4568 "st.w $w9, (5*16)(%[accum_ptr])\n" 4569 "st.w $w2, (6*16)(%[accum_ptr])\n" 4570 "st.w $w6, (7*16)(%[accum_ptr])\n" 4571 "st.w $w10, (8*16)(%[accum_ptr])\n" 4572 "st.w $w3, (9*16)(%[accum_ptr])\n" 4573 "st.w $w7, (10*16)(%[accum_ptr])\n" 4574 "st.w $w11, (11*16)(%[accum_ptr])\n" 4575 "st.w $w12, (12*16)(%[accum_ptr])\n" 4576 "st.w $w16, (13*16)(%[accum_ptr])\n" 4577 "st.w $w20, (14*16)(%[accum_ptr])\n" 4578 "st.w $w13, (15*16)(%[accum_ptr])\n" 4579 "st.w $w17, (16*16)(%[accum_ptr])\n" 4580 "st.w $w21, (17*16)(%[accum_ptr])\n" 4581 "st.w $w14, (18*16)(%[accum_ptr])\n" 4582 "st.w $w18, (19*16)(%[accum_ptr])\n" 4583 "st.w $w22, (20*16)(%[accum_ptr])\n" 4584 "st.w $w15, (21*16)(%[accum_ptr])\n" 4585 "st.w $w19, (22*16)(%[accum_ptr])\n" 4586 "st.w $w23, (23*16)(%[accum_ptr])\n" 4587 : // outputs 4588 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 4589 [depth] "+r"(depth) 4590 : // inputs 4591 [accum_ptr] "r"(accum_ptr) 4592 : // clobbers 4593 "memory", 4594 "a0", "a1", "a2", "a3", 4595 "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", 4596 "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", 4597 "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", 4598 "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31"); 4599 } 4600 }; 4601 4602 // Assembly implementation of the above 4603 // MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics2 (TODO). 4604 // Using 16x16=32 multiplications. 4605 // 32 MSA regs used: 4606 // - 24 accumulators 4607 // - 3 lhs 4608 // - 4 rhs 4609 // - 1 temps/zeroes 4610 // ~70 instructions in the loop. 4611 struct MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2 { 4612 typedef std::uint8_t OperandType; 4613 typedef std::uint32_t AccumulatorType; 4614 typedef KernelFormat< 4615 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 3>, 4616 KernelSideFormat<CellFormat<4, 2, CellOrder::DepthMajor>, 2> > 4617 Format; 4618 static void Run(OperandType* lhs_ptr, OperandType* rhs_ptr, 4619 AccumulatorType* accum_ptr, int depth) { 4620 asm volatile( 4621 // Load accumulators 4622 "ld.w $w0, (0*16)(%[accum_ptr])\n" 4623 "ld.w $w4, (1*16)(%[accum_ptr])\n" 4624 "ld.w $w8, (2*16)(%[accum_ptr])\n" 4625 "ld.w $w1, (3*16)(%[accum_ptr])\n" 4626 "ld.w $w5, (4*16)(%[accum_ptr])\n" 4627 "ld.w $w9, (5*16)(%[accum_ptr])\n" 4628 "ld.w $w2, (6*16)(%[accum_ptr])\n" 4629 "ld.w $w6, (7*16)(%[accum_ptr])\n" 4630 "ld.w $w10, (8*16)(%[accum_ptr])\n" 4631 "ld.w $w3, (9*16)(%[accum_ptr])\n" 4632 "ld.w $w7, (10*16)(%[accum_ptr])\n" 4633 "ld.w $w11, (11*16)(%[accum_ptr])\n" 4634 "ld.w $w12, (12*16)(%[accum_ptr])\n" 4635 "ld.w $w16, (13*16)(%[accum_ptr])\n" 4636 "ld.w $w20, (14*16)(%[accum_ptr])\n" 4637 "ld.w $w13, (15*16)(%[accum_ptr])\n" 4638 "ld.w $w17, (16*16)(%[accum_ptr])\n" 4639 "ld.w $w21, (17*16)(%[accum_ptr])\n" 4640 "ld.w $w14, (18*16)(%[accum_ptr])\n" 4641 "ld.w $w18, (19*16)(%[accum_ptr])\n" 4642 "ld.w $w22, (20*16)(%[accum_ptr])\n" 4643 "ld.w $w15, (21*16)(%[accum_ptr])\n" 4644 "ld.w $w19, (22*16)(%[accum_ptr])\n" 4645 "ld.w $w23, (23*16)(%[accum_ptr])\n" 4646 // Set a temp to all zeroes. 4647 "ldi.b $w31, 0\n" 4648 4649 GEMMLOWP_LABEL_LOOP ":\n" 4650 // Overview of register layout: 4651 // 4652 // A half of the 2 2x4 cells of Rhs is stored in 16bit in w27-w30 4653 // (each register contains 4 replicas of a pair of elements). 4654 // A 12x2 block of 3 4x2 cells Lhs is stored in 16bit in w24-w26. 4655 // A 12x8 block of accumulators is stored in 32bit in w0-w23. 4656 // 4657 // +------+------+------+------+ 4658 // Rhs |w27 |w28 |w29 |w30 | 4659 // +------+------+------+------+ 4660 // 4661 // | | | | | 4662 // 4663 // Lhs | | | | | 4664 // 4665 // +---+ - - - - +------+------+------+------+ 4666 // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | 4667 // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | 4668 // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | 4669 // |w24| |w0/12 |w1/13 |w2/14 |w3/15 | 4670 // +---+ - - - - +------+------+------+------+ 4671 // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | 4672 // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | 4673 // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | 4674 // |w25| |w4/16 |w5/17 |w6/18 |w7/19 | 4675 // +---+ - - - - +------+------+------+------+ 4676 // |w26| |w8/20 |w9/21 |w10/22|w11/23| 4677 // |w26| |w8/20 |w9/21 |w10/22|w11/23| 4678 // |w26| |w8/20 |w9/21 |w10/22|w11/23| 4679 // |w26| |w8/20 |w9/21 |w10/22|w11/23| 4680 // +---+ - - - - +------+------+------+------+ 4681 // 4682 // Accumulators 4683 4684 // Load 3 x 8 bytes of lhs[] with 2 16-byte overlapped loads. 4685 "ld.b $w24, 0(%[lhs_ptr])\n" 4686 "ld.b $w25, 8(%[lhs_ptr])\n" 4687 4688 // Load 4 bytes of rhs[] for the first half of depth 0. 4689 "lbu $a0, 0(%[rhs_ptr])\n" 4690 "lbu $a1, 1(%[rhs_ptr])\n" 4691 "lbu $a2, 2(%[rhs_ptr])\n" 4692 "lbu $a3, 3(%[rhs_ptr])\n" 4693 // Load 4 bytes of rhs[] for the first half of depth 1. 4694 "lbu $v0, 4(%[rhs_ptr])\n" 4695 "lbu $v1, 5(%[rhs_ptr])\n" 4696 "lbu $t8, 6(%[rhs_ptr])\n" 4697 "lbu $t9, 7(%[rhs_ptr])\n" 4698 4699 // Zero-extend 8-bit elements of lhs[] to 16 bits. 4700 "ilvr.b $w24, $w31, $w24\n" 4701 "ilvl.b $w26, $w31, $w25\n" 4702 "ilvr.b $w25, $w31, $w25\n" 4703 // Interleave depth 0 and depth 1 elements of lhs[] for dpadd_u.w. 4704 "ilvl.d $w27, $w31, $w24\n" 4705 "ilvl.d $w28, $w31, $w25\n" 4706 "ilvl.d $w29, $w31, $w26\n" 4707 "ilvr.h $w24, $w27, $w24\n" 4708 "ilvr.h $w25, $w28, $w25\n" 4709 "ilvr.h $w26, $w29, $w26\n" 4710 4711 // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w 4712 // (for the first half). 4713 "ins $a0, $v0, 16, 8\n" 4714 "ins $a1, $v1, 16, 8\n" 4715 "ins $a2, $t8, 16, 8\n" 4716 "ins $a3, $t9, 16, 8\n" 4717 // Make 4 replicas of every pair of rhs[] elements. 4718 "fill.w $w27, $a0\n" 4719 "fill.w $w28, $a1\n" 4720 "fill.w $w29, $a2\n" 4721 "fill.w $w30, $a3\n" 4722 4723 // Load 4 bytes of rhs[] for the second half of depth 0. 4724 "lbu $a0, 8(%[rhs_ptr])\n" 4725 "lbu $a1, 9(%[rhs_ptr])\n" 4726 "lbu $a2, 10(%[rhs_ptr])\n" 4727 "lbu $a3, 11(%[rhs_ptr])\n" 4728 // Load 4 bytes of rhs[] for the second half of depth 1. 4729 "lbu $v0, 12(%[rhs_ptr])\n" 4730 "lbu $v1, 13(%[rhs_ptr])\n" 4731 "lbu $t8, 14(%[rhs_ptr])\n" 4732 "lbu $t9, 15(%[rhs_ptr])\n" 4733 4734 // First half of depths 0 and 1. 4735 // Dot-product-(and)-add doubles multiplicand width. 4736 "dpadd_u.w $w0, $w24, $w27\n" 4737 "dpadd_u.w $w4, $w25, $w27\n" 4738 "dpadd_u.w $w8, $w26, $w27\n" 4739 "dpadd_u.w $w1, $w24, $w28\n" 4740 "dpadd_u.w $w5, $w25, $w28\n" 4741 "dpadd_u.w $w9, $w26, $w28\n" 4742 "dpadd_u.w $w2, $w24, $w29\n" 4743 "dpadd_u.w $w6, $w25, $w29\n" 4744 "dpadd_u.w $w10, $w26, $w29\n" 4745 "dpadd_u.w $w3, $w24, $w30\n" 4746 "dpadd_u.w $w7, $w25, $w30\n" 4747 "dpadd_u.w $w11, $w26, $w30\n" 4748 4749 // Combine and interleave depth 0 and depth 1 elements of rhs[] for dpadd_u.w 4750 // (for the second half). 4751 "ins $a0, $v0, 16, 8\n" 4752 "ins $a1, $v1, 16, 8\n" 4753 "ins $a2, $t8, 16, 8\n" 4754 "ins $a3, $t9, 16, 8\n" 4755 // Make 4 replicas of every pair of rhs[] elements. 4756 "fill.w $w27, $a0\n" 4757 "fill.w $w28, $a1\n" 4758 "fill.w $w29, $a2\n" 4759 "fill.w $w30, $a3\n" 4760 4761 // Second half of depths 0 and 1. 4762 // Dot-product-(and)-add doubles multiplicand width. 4763 "dpadd_u.w $w12, $w24, $w27\n" 4764 "dpadd_u.w $w16, $w25, $w27\n" 4765 "dpadd_u.w $w20, $w26, $w27\n" 4766 "dpadd_u.w $w13, $w24, $w28\n" 4767 "dpadd_u.w $w17, $w25, $w28\n" 4768 "dpadd_u.w $w21, $w26, $w28\n" 4769 "dpadd_u.w $w14, $w24, $w29\n" 4770 "dpadd_u.w $w18, $w25, $w29\n" 4771 "dpadd_u.w $w22, $w26, $w29\n" 4772 "dpadd_u.w $w15, $w24, $w30\n" 4773 "dpadd_u.w $w19, $w25, $w30\n" 4774 "dpadd_u.w $w23, $w26, $w30\n" 4775 4776 "addiu %[depth], -2\n" 4777 GEMMLOWP_MIPS_XADDIU " %[lhs_ptr], 24\n" 4778 GEMMLOWP_MIPS_XADDIU " %[rhs_ptr], 16\n" 4779 "bnez %[depth]," GEMMLOWP_LABEL_LOOP "b\n" 4780 4781 // Store accumulators. 4782 "st.w $w0, (0*16)(%[accum_ptr])\n" 4783 "st.w $w4, (1*16)(%[accum_ptr])\n" 4784 "st.w $w8, (2*16)(%[accum_ptr])\n" 4785 "st.w $w1, (3*16)(%[accum_ptr])\n" 4786 "st.w $w5, (4*16)(%[accum_ptr])\n" 4787 "st.w $w9, (5*16)(%[accum_ptr])\n" 4788 "st.w $w2, (6*16)(%[accum_ptr])\n" 4789 "st.w $w6, (7*16)(%[accum_ptr])\n" 4790 "st.w $w10, (8*16)(%[accum_ptr])\n" 4791 "st.w $w3, (9*16)(%[accum_ptr])\n" 4792 "st.w $w7, (10*16)(%[accum_ptr])\n" 4793 "st.w $w11, (11*16)(%[accum_ptr])\n" 4794 "st.w $w12, (12*16)(%[accum_ptr])\n" 4795 "st.w $w16, (13*16)(%[accum_ptr])\n" 4796 "st.w $w20, (14*16)(%[accum_ptr])\n" 4797 "st.w $w13, (15*16)(%[accum_ptr])\n" 4798 "st.w $w17, (16*16)(%[accum_ptr])\n" 4799 "st.w $w21, (17*16)(%[accum_ptr])\n" 4800 "st.w $w14, (18*16)(%[accum_ptr])\n" 4801 "st.w $w18, (19*16)(%[accum_ptr])\n" 4802 "st.w $w22, (20*16)(%[accum_ptr])\n" 4803 "st.w $w15, (21*16)(%[accum_ptr])\n" 4804 "st.w $w19, (22*16)(%[accum_ptr])\n" 4805 "st.w $w23, (23*16)(%[accum_ptr])\n" 4806 : // outputs 4807 [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), 4808 [depth] "+r"(depth) 4809 : // inputs 4810 [accum_ptr] "r"(accum_ptr) 4811 : // clobbers 4812 "memory", 4813 "v0", "v1", 4814 "a0", "a1", "a2", "a3", 4815 "t8", "t9", 4816 "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f6", "$f7", 4817 "$f8", "$f9", "$f10", "$f11", "$f12", "$f13", "$f14", "$f15", 4818 "$f16", "$f17", "$f18", "$f19", "$f20", "$f21", "$f22", "$f23", 4819 "$f24", "$f25", "$f26", "$f27", "$f28", "$f29", "$f30", "$f31"); 4820 } 4821 }; 4822 #endif // __mips 4823 4824 // BEGIN code copied from gemmlowp/internal/kernel_reference.h 4825 4826 // This kernel is templatized in an arbitrary Format template parameter, 4827 // allowing it to have any arbitrary format. 4828 template <typename tOperandType, typename tAccumulatorType, typename tFormat> 4829 struct ReferenceKernel { 4830 typedef tOperandType OperandType; 4831 typedef tAccumulatorType AccumulatorType; 4832 typedef tFormat Format; 4833 4834 static void Run(const OperandType* lhs_ptr, const OperandType* rhs_ptr, 4835 AccumulatorType* accum_ptr, int depth) { 4836 const int depth_cells = static_cast<int>(depth / Format::kDepth); 4837 4838 // The outer loop is over the depth dimension. 4839 for (int dc = 0; dc < depth_cells; dc++) { 4840 // The next two loops are over cells of the Lhs (stacked vertically), 4841 // and over cells of the Rhs (stacked horizontally). 4842 for (int rc = 0; rc < Format::Lhs::kCells; rc++) { 4843 const OperandType* lhs_cell_ptr = 4844 lhs_ptr + (dc * Format::Lhs::kCells + rc) * 4845 Format::Lhs::Cell::kWidth * Format::kDepth; 4846 for (int cc = 0; cc < Format::Rhs::kCells; cc++) { 4847 const OperandType* rhs_cell_ptr = 4848 rhs_ptr + (dc * Format::Rhs::kCells + cc) * 4849 Format::Rhs::Cell::kWidth * Format::kDepth; 4850 4851 // Now we are inside one cell of the Lhs and inside one cell 4852 // of the Rhs, so the remaining inner loops are just 4853 // traditional three loops of matrix multiplication. 4854 for (int di = 0; di < Format::kDepth; di++) { 4855 for (int ri = 0; ri < Format::Lhs::Cell::kWidth; ri++) { 4856 for (int ci = 0; ci < Format::Rhs::Cell::kWidth; ci++) { 4857 const OperandType* lhs_coeff_ptr = 4858 lhs_cell_ptr + 4859 OffsetIntoCell<typename Format::Lhs::Cell>(ri, di); 4860 const OperandType* rhs_coeff_ptr = 4861 rhs_cell_ptr + 4862 OffsetIntoCell<typename Format::Rhs::Cell>(ci, di); 4863 AccumulatorType* accumulator_coeff_ptr = 4864 accum_ptr + (ri + rc * Format::Lhs::Cell::kWidth) + 4865 (ci + cc * Format::Rhs::Cell::kWidth) * Format::kRows; 4866 *accumulator_coeff_ptr += AccumulatorType(*lhs_coeff_ptr) * 4867 AccumulatorType(*rhs_coeff_ptr); 4868 } 4869 } 4870 } 4871 } 4872 } 4873 } 4874 } 4875 }; 4876 4877 // END code copied from gemmlowp/internal/kernel_reference.h 4878 4879 template <typename DataType> 4880 class CacheLineAlignedBuffer { 4881 public: 4882 CacheLineAlignedBuffer(std::size_t size) : size_(size) { 4883 data_ = nullptr; 4884 // Adds a few bytes of padding here, because the 64-bit 'A57' kernel 4885 // reads one iteration past the end the buffer, causing a crash on iOS. 4886 int res = posix_memalign(reinterpret_cast<void**>(&data_), kCacheLineSize, 4887 size_ * sizeof(DataType) + 16); 4888 (void)res; 4889 } 4890 4891 ~CacheLineAlignedBuffer() { free(data_); } 4892 4893 const DataType* data() const { return data_; } 4894 DataType* data() { return data_; } 4895 4896 std::size_t size() const { return size_; } 4897 4898 private: 4899 const std::size_t size_; 4900 DataType* data_; 4901 }; 4902 4903 template <typename DataType> 4904 void FillRandom(CacheLineAlignedBuffer<DataType>* buffer) { 4905 static std::mt19937 generator(0); 4906 // 100 is smaller than any nonzero bound of the range of any data type. 4907 const DataType kMaxVal = DataType(100); 4908 const DataType kMinVal = 4909 std::is_signed<DataType>::value ? -kMaxVal : DataType(0); 4910 std::uniform_real_distribution<float> dist(kMinVal, kMaxVal); 4911 for (std::size_t i = 0; i < buffer->size(); i++) { 4912 buffer->data()[i] = DataType(dist(generator)); 4913 } 4914 } 4915 4916 template <typename DataType> 4917 void FillZero(CacheLineAlignedBuffer<DataType>* buffer) { 4918 for (std::size_t i = 0; i < buffer->size(); i++) { 4919 buffer->data()[i] = DataType(0); 4920 } 4921 } 4922 4923 template <typename DataType> 4924 void Copy(CacheLineAlignedBuffer<DataType>* dst, 4925 const CacheLineAlignedBuffer<DataType>& src) { 4926 assert(dst->size() == src.size()); 4927 memcpy(dst->data(), src.data(), src.size() * sizeof(DataType)); 4928 } 4929 4930 template <typename DataType> 4931 void PrintMatrix(int rows, int cols, int rowstride, int colstride, 4932 const DataType* data) { 4933 for (int r = 0; r < rows; r++) { 4934 for (int c = 0; c < cols; c++) { 4935 std::cerr << double(data[r * rowstride + c * colstride]) << " "; 4936 } 4937 std::cerr << std::endl; 4938 } 4939 std::cerr << std::endl; 4940 } 4941 4942 template <typename DataType> 4943 bool approx_equals(DataType a, DataType b) { 4944 return a == b; 4945 } 4946 4947 template <> 4948 bool approx_equals(float a, float b) { 4949 if (!a && !b) { 4950 return true; 4951 } 4952 // 1e-1 is very coarse accuracy, we should switch to an overall L2 metric 4953 // and tighten the tolerance on that metric. 4954 return std::abs(a - b) < 1e-1f * std::min(std::abs(a), std::abs(b)); 4955 } 4956 4957 template <typename Kernel> 4958 void test_kernel(int depth, const char* kernel_name) { 4959 typedef typename Kernel::OperandType OperandType; 4960 typedef typename Kernel::AccumulatorType AccumulatorType; 4961 typedef typename Kernel::Format Format; 4962 static const int kLhsWidth = Format::Lhs::kWidth; 4963 static const int kRhsWidth = Format::Rhs::kWidth; 4964 4965 typedef ReferenceKernel<OperandType, AccumulatorType, Format> ReferenceKernel; 4966 4967 CacheLineAlignedBuffer<OperandType> lhs(kLhsWidth * depth); 4968 CacheLineAlignedBuffer<OperandType> rhs(kRhsWidth * depth); 4969 CacheLineAlignedBuffer<AccumulatorType> accum_initial(kLhsWidth * kRhsWidth); 4970 CacheLineAlignedBuffer<AccumulatorType> accum(kLhsWidth * kRhsWidth); 4971 CacheLineAlignedBuffer<AccumulatorType> accum_reference(kLhsWidth * 4972 kRhsWidth); 4973 4974 FillRandom(&lhs); 4975 FillRandom(&rhs); 4976 FillRandom(&accum_initial); 4977 Copy(&accum, accum_initial); 4978 Copy(&accum_reference, accum_initial); 4979 4980 ReferenceKernel::Run(lhs.data(), rhs.data(), accum_reference.data(), depth); 4981 Kernel::Run(lhs.data(), rhs.data(), accum.data(), depth); 4982 4983 for (int l = 0; l < kLhsWidth; l++) { 4984 for (int r = 0; r < kRhsWidth; r++) { 4985 const int index = l + kLhsWidth * r; 4986 if (!approx_equals(accum.data()[index], accum_reference.data()[index])) { 4987 std::cerr << "Arithmetic error in kernel:" << std::endl 4988 << " " << kernel_name << std::endl 4989 << "Wrong accumulator for depth=" << depth << ", " 4990 << "at l = " << l << ", r = " << r << std::endl; 4991 std::cerr << "reference value: " << accum_reference.data()[index] 4992 << std::endl; 4993 std::cerr << "actual value: " << accum.data()[index] << std::endl; 4994 if (depth <= 16) { 4995 std::cerr << "LHS matrix:" << std::endl; 4996 PrintMatrix(kLhsWidth, depth, 1, kLhsWidth, lhs.data()); 4997 std::cerr << "RHS matrix:" << std::endl; 4998 PrintMatrix(depth, kRhsWidth, kRhsWidth, 1, rhs.data()); 4999 std::cerr << "Initial Accumulator matrix:" << std::endl; 5000 PrintMatrix(kLhsWidth, kRhsWidth, 1, kLhsWidth, accum_initial.data()); 5001 std::cerr << "Reference Accumulator matrix:" << std::endl; 5002 PrintMatrix(kLhsWidth, kRhsWidth, 1, kLhsWidth, 5003 accum_reference.data()); 5004 std::cerr << "Actual Accumulator matrix:" << std::endl; 5005 PrintMatrix(kLhsWidth, kRhsWidth, 1, kLhsWidth, accum.data()); 5006 } 5007 abort(); 5008 } 5009 } 5010 } 5011 } 5012 5013 template <typename Kernel> 5014 int ops(int depth) { 5015 // 2x the number of multiply-accumulate scalar ops. 5016 return 2 * Kernel::Format::Lhs::kWidth * Kernel::Format::Rhs::kWidth * depth; 5017 } 5018 5019 template <unsigned Modulus, typename Integer> 5020 Integer RoundDown(Integer i) { 5021 return i - (i % Modulus); 5022 } 5023 5024 int CacheSizeInKB() { 5025 static const char* cache_size_k_env = getenv("CACHE_SIZE_KB"); 5026 static const int cache_size_k = 5027 cache_size_k_env ? atoi(cache_size_k_env) : kDefaultCacheSizeK; 5028 return cache_size_k; 5029 } 5030 5031 template <typename Kernel> 5032 int BenchmarkDepthToFitInCache() { 5033 const int cache_size_bytes = 1024 * CacheSizeInKB(); 5034 5035 // Subtract the typical size of a few cache lines, so 5036 // we don't need to worry too hard about e.g. some stack data. 5037 const int conservative_cache_size_bytes = 5038 cache_size_bytes - 2 * kCacheLineSize; 5039 5040 // We will subtract the memory occupied by accumulators. 5041 typedef typename Kernel::AccumulatorType AccumulatorType; 5042 const int kAccumulatorBytes = sizeof(AccumulatorType) * 5043 Kernel::Format::Lhs::kWidth * 5044 Kernel::Format::Rhs::kWidth; 5045 5046 // Compute the depth. 5047 typedef typename Kernel::OperandType OperandType; 5048 const int kBytesPerUnitOfDepth = 5049 sizeof(OperandType) * 5050 (Kernel::Format::Lhs::kWidth + Kernel::Format::Rhs::kWidth); 5051 const int unrounded_depth = 5052 (conservative_cache_size_bytes - kAccumulatorBytes) / 5053 kBytesPerUnitOfDepth; 5054 5055 // Cap depth, to avoid unfairly favoring narrower kernels 5056 const int kMaxDepth = 1024; 5057 const int clamped_unrounded_depth = std::min(kMaxDepth, unrounded_depth); 5058 5059 // Round depth down to a multiple of cache line size, which helps because 5060 // our kernels may crash if depth is not a multiple of the number of 5061 // depth level that they want to 5062 // handle at each loop iteration, and we don't want to require kernels 5063 // to be more complex. Currently all kernels process 1, 2 or 8 levels of 5064 // depth at a time. The main reason why that might increase in the future 5065 // is if registers get wider, but I don't suppose that register could 5066 // ever get wider than cache lines. 5067 return RoundDown<kCacheLineSize>(clamped_unrounded_depth); 5068 } 5069 5070 double current_time_in_seconds() { 5071 timespec t; 5072 clock_gettime(CLOCK_REALTIME, &t); 5073 return t.tv_sec + 1e-9 * t.tv_nsec; 5074 } 5075 5076 template <typename Kernel> 5077 double benchmark(int depth) { 5078 // Minimum duration for this benchmark to run. If the workload finishes 5079 // sooner, we retry with double the number of iterations. 5080 static const double min_benchmark_time_in_seconds = 1.0; 5081 5082 typedef typename Kernel::OperandType OperandType; 5083 typedef typename Kernel::AccumulatorType AccumulatorType; 5084 5085 CacheLineAlignedBuffer<OperandType> lhs(Kernel::Format::Lhs::kWidth * depth); 5086 CacheLineAlignedBuffer<OperandType> rhs(Kernel::Format::Rhs::kWidth * depth); 5087 CacheLineAlignedBuffer<AccumulatorType> accum(Kernel::Format::Lhs::kWidth * 5088 Kernel::Format::Rhs::kWidth); 5089 5090 for (std::uint64_t iters_at_a_time = 1;; iters_at_a_time *= 2) { 5091 const double t_start = current_time_in_seconds(); 5092 for (std::uint64_t i = 0; i < iters_at_a_time; i++) { 5093 Kernel::Run(lhs.data(), rhs.data(), accum.data(), depth); 5094 } 5095 const double t_end = current_time_in_seconds(); 5096 const double elapsed = t_end - t_start; 5097 if (elapsed > min_benchmark_time_in_seconds) { 5098 return iters_at_a_time * ops<Kernel>(depth) / elapsed; 5099 } 5100 } 5101 } 5102 5103 template <typename Kernel> 5104 void benchmark_and_print_results(const char* kernel_name) { 5105 if (getenv("BENCHMARK_KERNEL")) { 5106 if (strcmp(getenv("BENCHMARK_KERNEL"), kernel_name)) { 5107 return; 5108 } 5109 } 5110 const int kKernelDepth = Kernel::Format::kDepth; 5111 for (int depth = kKernelDepth; depth <= 1024; depth += kKernelDepth) { 5112 test_kernel<Kernel>(depth, kernel_name); 5113 } 5114 5115 if (getenv("BENCHMARK_ALL_DEPTHS")) { 5116 for (int depth = kKernelDepth; 5117 depth <= BenchmarkDepthToFitInCache<Kernel>(); depth *= 2) { 5118 std::cout << kernel_name << "," << depth << "," 5119 << benchmark<Kernel>(depth) * 1e-9f << std::endl; 5120 } 5121 } else { 5122 const int depth = BenchmarkDepthToFitInCache<Kernel>(); 5123 std::cout << kernel_name << "," << benchmark<Kernel>(depth) * 1e-9f 5124 << std::endl; 5125 } 5126 } 5127 5128 #define BENCHMARK(Kernel) \ 5129 do { \ 5130 benchmark_and_print_results<Kernel>(#Kernel); \ 5131 } while (false) 5132 5133 int main() { 5134 if (getenv("BENCHMARK_ALL_DEPTHS")) { 5135 std::cout << "kernel,depth,Gop/s" << std::endl; 5136 } else { 5137 std::cout << "kernel,Gop/s" << std::endl; 5138 } 5139 5140 #ifdef __arm__ 5141 BENCHMARK(NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits); 5142 BENCHMARK(NEON_32bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics); 5143 BENCHMARK(NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators); 5144 BENCHMARK(NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics); 5145 BENCHMARK(NEON_32bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand); 5146 BENCHMARK(NEON_32bit_GEMM_Int32_WithScalar); 5147 BENCHMARK(NEON_32bit_GEMM_Float32_MLA_WithVectorDuplicatingScalar); 5148 #ifdef __ARM_FEATURE_FMA 5149 BENCHMARK(NEON_32bit_GEMM_Float32_FMA_WithVectorDuplicatingScalar); 5150 #endif 5151 BENCHMARK(NEON_32bit_GEMM_Float32_MLA_WithScalar); 5152 BENCHMARK(NEON_32bit_GEMM_Float32_WithScalar_intrinsics); 5153 BENCHMARK(NEON_32bit_GEMM_Float32_WithScalar_A53); 5154 BENCHMARK(NEON_32bit_GEMM_Float32_WithScalar_A53_depth2); 5155 BENCHMARK(NEON_32bit_GEMM_Float32_MLA_Rotating); 5156 #ifdef __ARM_FEATURE_FMA 5157 BENCHMARK(NEON_32bit_GEMM_Float32_FMA_Rotating); 5158 #endif 5159 #endif 5160 5161 #ifdef __aarch64__ 5162 BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits); 5163 BENCHMARK(NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits_intrinsics); 5164 BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators); 5165 BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_intrinsics); 5166 BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_noexpand_A57); 5167 #ifdef __ARM_FEATURE_DOTPROD 5168 BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct); 5169 BENCHMARK(NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct_A55r1); 5170 #endif 5171 BENCHMARK(NEON_64bit_GEMM_Int32_WithScalar); 5172 BENCHMARK(NEON_64bit_GEMM_Float32_WithVectorDuplicatingScalar); 5173 BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar); 5174 BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_intrinsics); 5175 BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A57); 5176 #ifndef __APPLE__ 5177 BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A53); 5178 #endif 5179 BENCHMARK(NEON_64bit_GEMM_Float32_WithScalar_A55r1); 5180 #endif 5181 5182 #ifdef __mips 5183 BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_intrinsics); 5184 BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly); 5185 BENCHMARK(MSA_GEMM_12x4_Uint8Operands_Uint32Accumulators_assembly2); 5186 BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_intrinsics); 5187 BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly); 5188 BENCHMARK(MSA_GEMM_12x8_Uint8Operands_Uint32Accumulators_assembly2); 5189 #endif 5190 5191 return 0; 5192 } 5193