// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// output_neon.h: optimized NEON specializations of the templates in output.h.

#ifndef GEMMLOWP_INTERNAL_OUTPUT_NEON_H_
#define GEMMLOWP_INTERNAL_OUTPUT_NEON_H_

#include "output.h"

#include <arm_neon.h>

namespace gemmlowp {

// Saturating cast of 4 int32 values to uint8.
// The int32x4 input is narrowed with saturation to int16 (vqmovn_s32), then
// to uint8 (vqmovun_s16), and the 4 resulting bytes are packed into a single
// 32-bit lane stored in output.reg[0].
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferUint8<4> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x4_t res_16 = vqmovn_s32(input.reg[0]);
    // vqmovun_s16 needs a full int16x8; duplicating res_16 in both halves is
    // harmless since only the low 4 bytes of the result are kept below.
    uint8x8_t res_8 = vqmovun_s16(vcombine_s16(res_16, res_16));
    output.reg[0] = vget_lane_u32(vreinterpret_u32_u8(res_8), 0);
    return output;
  }
};

// Saturating cast of 8 int32 values to uint8: two int32x4 registers are
// narrowed and combined into one int16x8, then narrowed to a uint8x8.
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferUint8<8> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16 =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    output.reg[0] = vqmovun_s16(res_16);
    return output;
  }
};

// Saturating cast of 16 int32 values to uint8: same narrowing scheme as the
// 8-wide case, applied to two pairs of input registers.
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<16>> {
  typedef RegBufferInt32<16> InputType;
  typedef RegBufferUint8<16> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16_0 =
        vcombine_s16(vqmovn_s32(input.reg[0]), vqmovn_s32(input.reg[1]));
    int16x8_t res_16_1 =
        vcombine_s16(vqmovn_s32(input.reg[2]), vqmovn_s32(input.reg[3]));
    output.reg[0] = vqmovun_s16(res_16_0);
    output.reg[1] = vqmovun_s16(res_16_1);
    return output;
  }
};

// Saturating cast of 32 int32 values to uint8: loops the 8-wide narrowing
// scheme over four register pairs.
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<32>> {
  typedef RegBufferInt32<32> InputType;
  typedef RegBufferUint8<32> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    int16x8_t res_16[4];
    for (int i = 0; i < 4; i++) {
      res_16[i] = vcombine_s16(vqmovn_s32(input.reg[2 * i]),
                               vqmovn_s32(input.reg[2 * i + 1]));
    }
    for (int i = 0; i < 4; i++) {
      output.reg[i] = vqmovun_s16(res_16[i]);
    }
    return output;
  }
};

// Store of an 8x1 int32 block. Column-major destinations take two contiguous
// 4-lane stores; otherwise each lane is extracted and stored individually.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
  static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
      StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
    } else {
      *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
      *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
      *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
      *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
      *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
      *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
      *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
      *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
    }
  }
};

// In-register transpose of a 4x4 int32 block: vtrnq_s32 swaps 2x2 sub-blocks
// between register pairs, then vcombine recombines low/high halves to finish
// the transpose.
inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
  const int32x4x2_t t0 = vtrnq_s32(src.buf.reg[0], src.buf.reg[1]);
  const int32x4x2_t t1 = vtrnq_s32(src.buf.reg[2], src.buf.reg[3]);
  RegBlockInt32<4, 4> result;
  result.buf.reg[0] =
      vcombine_s32(vget_low_s32(t0.val[0]), vget_low_s32(t1.val[0]));
  result.buf.reg[1] =
      vcombine_s32(vget_low_s32(t0.val[1]), vget_low_s32(t1.val[1]));
  result.buf.reg[2] =
      vcombine_s32(vget_high_s32(t0.val[0]), vget_high_s32(t1.val[0]));
  result.buf.reg[3] =
      vcombine_s32(vget_high_s32(t0.val[1]), vget_high_s32(t1.val[1]));
  return result;
}

// Store of a 4x4 int32 block: transpose first when the destination is not
// column-major, so both orders reduce to four contiguous 4-lane stores.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
  static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
                  int col) {
    const auto& block =
        DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
    std::int32_t* dst_ptr = dst->data(row, col);
    int stride = dst->stride();
    for (int i = 0; i < 4; i++) {
      vst1q_s32(dst_ptr + i * stride, block.buf.reg[i]);
    }
  }
};

// Store of an 8x4 int32 block. Column-major: two 4-lane stores per column.
// Row-major: the block is split into top and bottom 4x4 halves (even-indexed
// registers hold the top rows, odd-indexed the bottom), each transposed and
// stored row by row.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
  static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * col_stride + 0, src.buf.reg[2 * i + 0]);
        vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
      }
    } else {
      int row_stride = dst->rows_stride();
      RegBlockInt32<4, 4> top;
      top.buf.reg[0] = src.buf.reg[0];
      top.buf.reg[1] = src.buf.reg[2];
      top.buf.reg[2] = src.buf.reg[4];
      top.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top = Transpose(top);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride, transpose_top.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom;
      bottom.buf.reg[0] = src.buf.reg[1];
      bottom.buf.reg[1] = src.buf.reg[3];
      bottom.buf.reg[2] = src.buf.reg[5];
      bottom.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom = Transpose(bottom);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride, transpose_bottom.buf.reg[i]);
      }
    }
  }
};

// Store of an 8x8 int32 block. Column-major: two 4-lane stores per column.
// Row-major: the block is split into four 4x4 quadrants (even/odd register
// index selects top/bottom; registers 0-7 are the left columns, 8-15 the
// right), each transposed and stored row by row.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
  static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 8; i++) {
        vst1q_s32(dst_ptr + i * col_stride, src.buf.reg[2 * i]);
        vst1q_s32(dst_ptr + i * col_stride + 4, src.buf.reg[2 * i + 1]);
      }
    } else {
      int row_stride = dst->rows_stride();
      RegBlockInt32<4, 4> top_left;
      top_left.buf.reg[0] = src.buf.reg[0];
      top_left.buf.reg[1] = src.buf.reg[2];
      top_left.buf.reg[2] = src.buf.reg[4];
      top_left.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top_left = Transpose(top_left);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride, transpose_top_left.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom_left;
      bottom_left.buf.reg[0] = src.buf.reg[1];
      bottom_left.buf.reg[1] = src.buf.reg[3];
      bottom_left.buf.reg[2] = src.buf.reg[5];
      bottom_left.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom_left = Transpose(bottom_left);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride,
                  transpose_bottom_left.buf.reg[i]);
      }
      RegBlockInt32<4, 4> top_right;
      top_right.buf.reg[0] = src.buf.reg[8];
      top_right.buf.reg[1] = src.buf.reg[10];
      top_right.buf.reg[2] = src.buf.reg[12];
      top_right.buf.reg[3] = src.buf.reg[14];
      const auto transpose_top_right = Transpose(top_right);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + i * row_stride + 4, transpose_top_right.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom_right;
      bottom_right.buf.reg[0] = src.buf.reg[9];
      bottom_right.buf.reg[1] = src.buf.reg[11];
      bottom_right.buf.reg[2] = src.buf.reg[13];
      bottom_right.buf.reg[3] = src.buf.reg[15];
      const auto transpose_bottom_right = Transpose(bottom_right);
      for (int i = 0; i < 4; i++) {
        vst1q_s32(dst_ptr + (i + 4) * row_stride + 4,
                  transpose_bottom_right.buf.reg[i]);
      }
    }
  }
};

// Store of a 4x1 int32 block: one contiguous store when column-major,
// otherwise one lane-store per destination row.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
  static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1q_s32(dst_ptr, src.buf.reg[0]);
    } else {
      int row_stride = dst->rows_stride();
      vst1q_lane_s32(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
      vst1q_lane_s32(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
      vst1q_lane_s32(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
      vst1q_lane_s32(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
    }
  }
};

// Store of a 1x4 int32 block: mirror image of the 4x1 case — contiguous when
// row-major, per-lane with the column stride otherwise.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
  static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
                  int col) {
    std::int32_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::RowMajor) {
      vst1q_s32(dst_ptr, src.buf.reg[0]);
    } else {
      int col_stride = dst->cols_stride();
      vst1q_lane_s32(dst_ptr + 0 * col_stride, src.buf.reg[0], 0);
      vst1q_lane_s32(dst_ptr + 1 * col_stride, src.buf.reg[0], 1);
      vst1q_lane_s32(dst_ptr + 2 * col_stride, src.buf.reg[0], 2);
      vst1q_lane_s32(dst_ptr + 3 * col_stride, src.buf.reg[0], 3);
    }
  }
};

// Store of a 4x1 uint8 block: the 4 bytes are packed into one 32-bit scalar
// register, so they are extracted with shifts rather than NEON lane stores.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
  static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
                  int col) {
    const std::uint32_t src_reg = src.buf.reg[0];
    for (int i = 0; i < 4; i++) {
      *dst->data(row + i, col) = (src_reg >> (8 * i));
    }
  }
};

// Store of a 1x4 uint8 block: same shift-based byte extraction as the 4x1
// case, writing along a row instead of a column.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
  static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
                  int col) {
    for (int i = 0; i < 4; i++) {
      *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
    }
  }
};

// Store of an 8x1 uint8 block: one contiguous 8-byte store when column-major,
// otherwise one lane-store per destination row.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
  static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      vst1_u8(dst_ptr, src.buf.reg[0]);
    } else {
      const int row_stride = dst->rows_stride();
      vst1_lane_u8(dst_ptr + 0 * row_stride, src.buf.reg[0], 0);
      vst1_lane_u8(dst_ptr + 1 * row_stride, src.buf.reg[0], 1);
      vst1_lane_u8(dst_ptr + 2 * row_stride, src.buf.reg[0], 2);
      vst1_lane_u8(dst_ptr + 3 * row_stride, src.buf.reg[0], 3);
      vst1_lane_u8(dst_ptr + 4 * row_stride, src.buf.reg[0], 4);
      vst1_lane_u8(dst_ptr + 5 * row_stride, src.buf.reg[0], 5);
      vst1_lane_u8(dst_ptr + 6 * row_stride, src.buf.reg[0], 6);
      vst1_lane_u8(dst_ptr + 7 * row_stride, src.buf.reg[0], 7);
    }
  }
};

// Store of a 4x4 uint8 block. Each uint8x8 register holds two 4-byte columns
// (lanes 0-3 = column 2*i, lanes 4-7 = column 2*i+1), so every byte is
// written with an individual lane-store regardless of destination order.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
  static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t* dst_ptr = dst->data(row, col);
    const int row_stride = dst->rows_stride();
    const int col_stride = dst->cols_stride();
    for (int i = 0; i < 2; i++) {
      vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 0);
      vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 1);
      vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 2);
      vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 0) * col_stride,
                   src.buf.reg[i], 3);
      vst1_lane_u8(dst_ptr + 0 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 4);
      vst1_lane_u8(dst_ptr + 1 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 5);
      vst1_lane_u8(dst_ptr + 2 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 6);
      vst1_lane_u8(dst_ptr + 3 * row_stride + (2 * i + 1) * col_stride,
                   src.buf.reg[i], 7);
    }
  }
};

// Store of an 8x4 uint8 block (one uint8x8 register per column): contiguous
// 8-byte stores when column-major, otherwise one lane-store per row for each
// column.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
  static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
                  int col) {
    std::uint8_t* dst_ptr = dst->data(row, col);
    if (DstType::kOrder == MapOrder::ColMajor) {
      int col_stride = dst->cols_stride();
      for (int i = 0; i < 4; i++) {
        vst1_u8(dst_ptr + i * col_stride, src.buf.reg[i]);
      }
    } else {
      for (int i = 0; i < 4; i++) {
        int row_stride = dst->rows_stride();
        std::uint8_t* col_ptr = dst_ptr + i;
        vst1_lane_u8(col_ptr + 0 * row_stride, src.buf.reg[i], 0);
        vst1_lane_u8(col_ptr + 1 * row_stride, src.buf.reg[i], 1);
        vst1_lane_u8(col_ptr + 2 * row_stride, src.buf.reg[i], 2);
        vst1_lane_u8(col_ptr + 3 * row_stride, src.buf.reg[i], 3);
        vst1_lane_u8(col_ptr + 4 * row_stride, src.buf.reg[i], 4);
        vst1_lane_u8(col_ptr + 5 * row_stride, src.buf.reg[i], 5);
        vst1_lane_u8(col_ptr + 6 * row_stride, src.buf.reg[i], 6);
        vst1_lane_u8(col_ptr + 7 * row_stride, src.buf.reg[i], 7);
      }
    }
  }
};

// In-register transpose of an 8x8 uint8 block via a three-stage butterfly
// network: vtrn_u8 transposes 1x1 elements within 2x2 sub-blocks, vtrn_u16
// transposes 2x2 sub-blocks within 4x4 blocks, and vtrn_u32 transposes 4x4
// sub-blocks within the full 8x8 block.
inline RegBlockUint8<8, 8> Transpose(const RegBlockUint8<8, 8>& src) {
  uint8x8x2_t a[4];
  a[0] = vtrn_u8(src.buf.reg[0], src.buf.reg[1]);
  a[1] = vtrn_u8(src.buf.reg[2], src.buf.reg[3]);
  a[2] = vtrn_u8(src.buf.reg[4], src.buf.reg[5]);
  a[3] = vtrn_u8(src.buf.reg[6], src.buf.reg[7]);
  uint16x4x2_t b[4];
  b[0] = vtrn_u16(vreinterpret_u16_u8(a[0].val[0]),
                  vreinterpret_u16_u8(a[1].val[0]));
  b[1] = vtrn_u16(vreinterpret_u16_u8(a[0].val[1]),
                  vreinterpret_u16_u8(a[1].val[1]));
  b[2] = vtrn_u16(vreinterpret_u16_u8(a[2].val[0]),
                  vreinterpret_u16_u8(a[3].val[0]));
  b[3] = vtrn_u16(vreinterpret_u16_u8(a[2].val[1]),
                  vreinterpret_u16_u8(a[3].val[1]));
  uint32x2x2_t c[4];
  c[0] = vtrn_u32(vreinterpret_u32_u16(b[0].val[0]),
                  vreinterpret_u32_u16(b[2].val[0]));
  c[1] = vtrn_u32(vreinterpret_u32_u16(b[1].val[0]),
                  vreinterpret_u32_u16(b[3].val[0]));
  c[2] = vtrn_u32(vreinterpret_u32_u16(b[0].val[1]),
                  vreinterpret_u32_u16(b[2].val[1]));
  c[3] = vtrn_u32(vreinterpret_u32_u16(b[1].val[1]),
                  vreinterpret_u32_u16(b[3].val[1]));
  RegBlockUint8<8, 8> result;
  result.buf.reg[0] = vreinterpret_u8_u32(c[0].val[0]);
  result.buf.reg[1] = vreinterpret_u8_u32(c[1].val[0]);
  result.buf.reg[2] = vreinterpret_u8_u32(c[2].val[0]);
  result.buf.reg[3] = vreinterpret_u8_u32(c[3].val[0]);
  result.buf.reg[4] = vreinterpret_u8_u32(c[0].val[1]);
  result.buf.reg[5] = vreinterpret_u8_u32(c[1].val[1]);
  result.buf.reg[6] = vreinterpret_u8_u32(c[2].val[1]);
  result.buf.reg[7] = vreinterpret_u8_u32(c[3].val[1]);
  return result;
}

// Store of an 8x8 uint8 block: transpose first when the destination is not
// column-major, so both orders reduce to eight contiguous 8-byte stores.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
  static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
                  int col) {
    const auto& block =
        DstType::kOrder == MapOrder::ColMajor ? src : Transpose(src);
    std::uint8_t* dst_ptr = dst->data(row, col);
    int stride = dst->stride();
    for (int i = 0; i < 8; i++) {
      vst1_u8(dst_ptr + i * stride, block.buf.reg[i]);
    }
  }
};

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_OUTPUT_NEON_H_