// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// pack_neon.h: optimized NEON specializations of the templates in pack.h.

#ifndef GEMMLOWP_INTERNAL_PACK_NEON_H_
#define GEMMLOWP_INTERNAL_PACK_NEON_H_

#include "pack.h"

#include <arm_neon.h>

namespace gemmlowp {

typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
    WidthMajorUint8SideMap;

template <int Cells>
using DepthMajorSideFormatNCells4x2 = KernelSideFormat<CellFormat<4, 2>, Cells>;

template <int Cells>
class PackingRegisterBlock<
    WidthMajorUint8SideMap,
    PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>>
    : public PackingRegisterBlockBase<
          WidthMajorUint8SideMap,
          PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>> {
 public:
  typedef DepthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
  typedef typename KernelSideFormat::Cell CellFormat;
  static const int kCells = KernelSideFormat::kCells;
  static const int kCellWidth = CellFormat::kWidth;
  static const int kKernelWidth = CellFormat::kWidth * kCells;
  static const int kCellDepth = CellFormat::kDepth;
  static const int kCellSize = CellFormat::kSize;

  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
    std::uint8_t* dst_ptr = dst->current_data();
    const std::uint8_t* const src_ptr = this->complete_src_.data();
    const int stride = this->complete_src_.stride();
    // Load source WidthMajor data
    uint8x16_t src_lines[4 * kCells];
    for (int i = 0; i < 4 * kCells; i++) {
      src_lines[i] = vld1q_u8(src_ptr + i * stride);
    }
    // Reorder the data within registers to make DepthMajor 4x2 cells
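    // Each vzipq_u8 interleaves the bytes of two source rows. Zipping rows
    // (0,2) and (1,3) first, then zipping those two results together, leaves
    // every group of 4 consecutive bytes holding the 4 width positions at one
    // depth level. Each 64-bit half of the resulting registers is therefore
    // one complete DepthMajor 4x2 cell, ready to be stored with a single
    // vst1_u8 below.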
    uint8x16x2_t src_lines_intertwined_2x[2 * kCells];
    for (int i = 0; i < kCells; i++) {
      src_lines_intertwined_2x[2 * i] =
          vzipq_u8(src_lines[4 * i], src_lines[4 * i + 2]);
      src_lines_intertwined_2x[2 * i + 1] =
          vzipq_u8(src_lines[4 * i + 1], src_lines[4 * i + 3]);
    }
    uint8x16x2_t src_lines_intertwined_4x[2 * kCells];
    for (int i = 0; i < kCells; i++) {
      src_lines_intertwined_4x[2 * i] =
          vzipq_u8(src_lines_intertwined_2x[2 * i].val[0],
                   src_lines_intertwined_2x[2 * i + 1].val[0]);
      src_lines_intertwined_4x[2 * i + 1] =
          vzipq_u8(src_lines_intertwined_2x[2 * i].val[1],
                   src_lines_intertwined_2x[2 * i + 1].val[1]);
    }
    // Store the resulting DepthMajor 4x2 cells in the destination packed block
    for (int outer = 0; outer < 2; outer++) {
      for (int inner = 0; inner < 2; inner++) {
        for (int cell = 0; cell < kCells; cell++) {
          uint8x8_t value = vget_low_u8(
              src_lines_intertwined_4x[2 * cell + outer].val[inner]);
          vst1_u8(dst_ptr, value);
          dst_ptr += 8;
        }
        for (int cell = 0; cell < kCells; cell++) {
          uint8x8_t value = vget_high_u8(
              src_lines_intertwined_4x[2 * cell + outer].val[inner]);
          vst1_u8(dst_ptr, value);
          dst_ptr += 8;
        }
      }
    }
    // Compute sums across the depth dimension
    uint16x8_t sums_of_2_cells[kCells][4];
    for (int outer = 0; outer < 2; outer++) {
      for (int inner = 0; inner < 2; inner++) {
        int i = 2 * outer + inner;
        for (int cell = 0; cell < kCells; cell++) {
          sums_of_2_cells[cell][i] = vaddl_u8(
              vget_low_u8(
                  src_lines_intertwined_4x[2 * cell + outer].val[inner]),
              vget_high_u8(
                  src_lines_intertwined_4x[2 * cell + outer].val[inner]));
        }
      }
    }
    int32x4_t sums_of_4_cells[kCells][4];
    for (int i = 0; i < 4; i++) {
      for (int cell = 0; cell < kCells; cell++) {
        sums_of_4_cells[cell][i] = vreinterpretq_s32_u32(
            vaddl_u16(vget_low_u16(sums_of_2_cells[cell][i]),
                      vget_high_u16(sums_of_2_cells[cell][i])));
      }
    }
    // Update the sums_of_each_slice vector
    for (int cell = 0; cell < kCells; cell++) {
      int32x4_t s01 =
          vaddq_s32(sums_of_4_cells[cell][0], sums_of_4_cells[cell][1]);
      int32x4_t s23 =
          vaddq_s32(sums_of_4_cells[cell][2], sums_of_4_cells[cell][3]);
      int32x4_t s = vaddq_s32(s01, s23);
      std::int32_t* sums_of_each_slice_ptr =
          dst->sums_of_each_slice() + start_width + 4 * cell;
      vst1q_s32(sums_of_each_slice_ptr,
                vaddq_s32(s, vld1q_s32(sums_of_each_slice_ptr)));
    }
    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
  }
};

template <int Cells>
using WidthMajorSideFormatNCells4x2 =
    KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;

template <int Cells>
class PackingRegisterBlock<
    WidthMajorUint8SideMap,
    PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>>
    : public PackingRegisterBlockBase<
          WidthMajorUint8SideMap,
          PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> {
 public:
  typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
  typedef typename KernelSideFormat::Cell CellFormat;
  static const int kCells = KernelSideFormat::kCells;
  static const int kCellWidth = CellFormat::kWidth;
  static const int kKernelWidth = CellFormat::kWidth * kCells;
  static const int kCellDepth = CellFormat::kDepth;
  static const int kCellSize = CellFormat::kSize;

  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
    std::uint8_t* dst_ptr = dst->current_data();
    const std::uint8_t* src_ptr = this->complete_src_.data();
    const int stride = this->complete_src_.stride();
    // Load source WidthMajor data
    uint16x8_t src_lines[kCells * 4];
    for (int i = 0; i < kCells; i++) {
      // This packing path is used with our current
      // less-than-8-bit kernel, and the partial unrolling of this loop
      // results in substantially faster code (thanks to better
      // register allocation) on Nexus 5.

#define GEMMLOWP_UNROLLED_LOOP_ITER(k)                            \
  src_lines[4 * i + k] = vreinterpretq_u16_u8(vld1q_u8(src_ptr)); \
  src_ptr += stride;

      GEMMLOWP_UNROLLED_LOOP_ITER(0)
      GEMMLOWP_UNROLLED_LOOP_ITER(1)
      GEMMLOWP_UNROLLED_LOOP_ITER(2)
      GEMMLOWP_UNROLLED_LOOP_ITER(3)

#undef GEMMLOWP_UNROLLED_LOOP_ITER
    }
    // Reorder the data within registers to make WidthMajor 4x2 cells
    uint16x8x2_t src_lines_intertwined_2x[2 * kCells];
    for (int i = 0; i < kCells; i++) {
      src_lines_intertwined_2x[2 * i] =
          vzipq_u16(src_lines[4 * i], src_lines[4 * i + 2]);
      src_lines_intertwined_2x[2 * i + 1] =
          vzipq_u16(src_lines[4 * i + 1], src_lines[4 * i + 3]);
    }
    uint16x8x2_t src_lines_intertwined_4x[2 * kCells];
    for (int i = 0; i < kCells; i++) {
      src_lines_intertwined_4x[2 * i] =
          vzipq_u16(src_lines_intertwined_2x[2 * i].val[0],
                    src_lines_intertwined_2x[2 * i + 1].val[0]);
      src_lines_intertwined_4x[2 * i + 1] =
          vzipq_u16(src_lines_intertwined_2x[2 * i].val[1],
                    src_lines_intertwined_2x[2 * i + 1].val[1]);
    }
    // Store the resulting WidthMajor 4x2 cells in the destination packed block
    for (int outer = 0; outer < 2; outer++) {
      for (int inner = 0; inner < 2; inner++) {
        for (int cell = 0; cell < kCells; cell++) {
          uint8x8_t value = vreinterpret_u8_u16(vget_low_u16(
              src_lines_intertwined_4x[2 * cell + outer].val[inner]));
          vst1_u8(dst_ptr, value);
          dst_ptr += 8;
        }
        for (int cell = 0; cell < kCells; cell++) {
          uint8x8_t value = vreinterpret_u8_u16(vget_high_u16(
              src_lines_intertwined_4x[2 * cell + outer].val[inner]));
          vst1_u8(dst_ptr, value);
          dst_ptr += 8;
        }
      }
    }
    // Compute sums across the depth dimension
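    // Within each WidthMajor 4x2 cell, adjacent bytes belong to the same
    // width position at consecutive depth levels, so vpaddlq_u8 (pairwise
    // add-long) directly yields 16-bit per-width sums over 2 depth levels.
    // The additions below then reduce those partial sums until sums_of_16
    // holds, for each of the 4 width positions, the sum over all 16 depth
    // levels handled by this register block.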
    uint16x8_t sums_of_2[kCells][4];
    for (int outer = 0; outer < 2; outer++) {
      for (int inner = 0; inner < 2; inner++) {
        int i = 2 * outer + inner;
        for (int cell = 0; cell < kCells; cell++) {
          sums_of_2[cell][i] = vpaddlq_u8(vreinterpretq_u8_u16(
              src_lines_intertwined_4x[2 * cell + outer].val[inner]));
        }
      }
    }
    uint16x8_t sums_of_4[kCells][2];
    for (int i = 0; i < 2; i++) {
      for (int cell = 0; cell < kCells; cell++) {
        sums_of_4[cell][i] =
            vaddq_u16(sums_of_2[cell][2 * i], sums_of_2[cell][2 * i + 1]);
      }
    }
    uint16x8_t sums_of_8[kCells];
    for (int cell = 0; cell < kCells; cell++) {
      sums_of_8[cell] = vaddq_u16(sums_of_4[cell][0], sums_of_4[cell][1]);
    }

    uint16x4_t sums_of_16[kCells];
    for (int cell = 0; cell < kCells; cell++) {
      sums_of_16[cell] = vadd_u16(vget_low_u16(sums_of_8[cell]),
                                  vget_high_u16(sums_of_8[cell]));
    }
    // Update the sums_of_each_slice vector
    for (int cell = 0; cell < kCells; cell++) {
      int32x4_t s = vreinterpretq_s32_u32(vmovl_u16(sums_of_16[cell]));
      std::int32_t* sums_of_each_slice_ptr =
          dst->sums_of_each_slice() + start_width + 4 * cell;
      vst1q_s32(sums_of_each_slice_ptr,
                vaddq_s32(s, vld1q_s32(sums_of_each_slice_ptr)));
    }
    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
  }
};

#ifdef GEMMLOWP_NEON_32
inline int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
  const int16x4_t c = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
  const int16x4_t d = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
  return vcombine_s16(c, d);
}
#endif
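
// The specialization below packs data for the int8 fast-kernel format: a
// single WidthMajor cell of Width x 16. The uint8 source values are XOR-ed
// with 0x80, which flips the sign bit and is equivalent to subtracting 128,
// so the packed block holds int8 data; the per-slice sums accumulated into
// sums_of_each_slice are computed on these converted signed values.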
template <int Width>
using Int8FastKernelFormat =
    KernelSideFormatInt8<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>;

template <int Width>
class PackingRegisterBlock<WidthMajorUint8SideMap,
                           PackedSideBlock<Int8FastKernelFormat<Width>>>
    : public PackingRegisterBlockBase<
          WidthMajorUint8SideMap,
          PackedSideBlock<Int8FastKernelFormat<Width>>> {
 public:
  static_assert(Width == 2 || Width == 4, "");
  typedef Int8FastKernelFormat<Width> KernelSideFormat;
  typedef typename KernelSideFormat::Cell CellFormat;
  static const int kCells = KernelSideFormat::kCells;
  static const int kCellWidth = CellFormat::kWidth;
  static const int kKernelWidth = CellFormat::kWidth * kCells;
  static const int kCellDepth = CellFormat::kDepth;
  static const int kCellSize = CellFormat::kSize;

  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
    std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width;
    std::uint8_t* dst_ptr = dst->current_data();
    const std::uint8_t* const src_ptr = this->complete_src_.data();
    const int stride = this->complete_src_.stride();
    // Load source WidthMajor data
    uint8x16_t src_lines[Width];
    for (int i = 0; i < Width; i++) {
      src_lines[i] = vld1q_u8(src_ptr + i * stride);
    }
    const uint8x16_t sign_bit_dup = vdupq_n_u8(0x80);
    for (int i = 0; i < Width; i++) {
      src_lines[i] = veorq_u8(src_lines[i], sign_bit_dup);
    }
    for (int i = 0; i < Width; i++) {
      vst1q_u8(dst_ptr + 16 * i, src_lines[i]);
    }
    int16x8_t sums2[Width];
    for (int i = 0; i < Width; i++) {
      const int8x8_t lo = vreinterpret_s8_u8(vget_low_u8(src_lines[i]));
      const int8x8_t hi = vreinterpret_s8_u8(vget_high_u8(src_lines[i]));
      sums2[i] = vaddl_s8(lo, hi);
    }
    int16x8_t sums4[Width / 2];
    for (int i = 0; i < Width / 2; i++) {
      sums4[i] = vpaddq_s16(sums2[2 * i], sums2[2 * i + 1]);
    }
    if (Width == 4) {
      int32x4_t sum = vld1q_s32(sums_ptr);
      int16x8_t sums8 = vpaddq_s16(sums4[0], sums4[1]);
      sum = vpadalq_s16(sum, sums8);
      vst1q_s32(sums_ptr, sum);
    } else {
      assert(Width == 2);
      int32x2_t sum = vld1_s32(sums_ptr);
      int16x4_t sums8 =
          vpadd_s16(vget_low_s16(sums4[0]), vget_high_s16(sums4[0]));
      sum = vpadal_s16(sum, sums8);
      vst1_s32(sums_ptr, sum);
    }
    dst->seek_forward_n_cells(1);
  }
};

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_PACK_NEON_H_