// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// pack_neon.h: optimized NEON specializations of the templates in pack.h.
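//
// It provides three PackingRegisterBlock specializations: one producing
// DepthMajor 4x2 cells, one producing WidthMajor 4x2 cells, and one for the
// WidthMajor 16-deep format consumed by the int8 fast kernels.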

#ifndef GEMMLOWP_INTERNAL_PACK_NEON_H_
#define GEMMLOWP_INTERNAL_PACK_NEON_H_

#include "pack.h"

#include <arm_neon.h>

namespace gemmlowp {

typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
    WidthMajorUint8SideMap;

template <int Cells>
using DepthMajorSideFormatNCells4x2 = KernelSideFormat<CellFormat<4, 2>, Cells>;

template <int Cells>
class PackingRegisterBlock<
    WidthMajorUint8SideMap,
    PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>>
    : public PackingRegisterBlockBase<
          WidthMajorUint8SideMap,
          PackedSideBlock<DepthMajorSideFormatNCells4x2<Cells>>> {
 public:
  typedef DepthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
  typedef typename KernelSideFormat::Cell CellFormat;
  static const int kCells = KernelSideFormat::kCells;
  static const int kCellWidth = CellFormat::kWidth;
  static const int kKernelWidth = CellFormat::kWidth * kCells;
  static const int kCellDepth = CellFormat::kDepth;
  static const int kCellSize = CellFormat::kSize;

  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
    std::uint8_t* dst_ptr = dst->current_data();
    const std::uint8_t* const src_ptr = this->complete_src_.data();
    const int stride = this->complete_src_.stride();
    // Load source WidthMajor data
    uint8x16_t src_lines[4 * kCells];
    for (int i = 0; i < 4 * kCells; i++) {
      src_lines[i] = vld1q_u8(src_ptr + i * stride);
    }
    // Reorder the data within registers to make DepthMajor 4x2 cells
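    // Each src_lines register holds 16 consecutive depth entries of one
    // source row; two rounds of vzipq_u8 interleave four rows so that every
    // 8-byte half-register ends up as one 4 (width) x 2 (depth) cell.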
    uint8x16x2_t src_lines_intertwined_2x[2 * kCells];
    for (int i = 0; i < kCells; i++) {
      src_lines_intertwined_2x[2 * i] =
          vzipq_u8(src_lines[4 * i], src_lines[4 * i + 2]);
      src_lines_intertwined_2x[2 * i + 1] =
          vzipq_u8(src_lines[4 * i + 1], src_lines[4 * i + 3]);
    }
    uint8x16x2_t src_lines_intertwined_4x[2 * kCells];
    for (int i = 0; i < kCells; i++) {
      src_lines_intertwined_4x[2 * i] =
          vzipq_u8(src_lines_intertwined_2x[2 * i].val[0],
                   src_lines_intertwined_2x[2 * i + 1].val[0]);
      src_lines_intertwined_4x[2 * i + 1] =
          vzipq_u8(src_lines_intertwined_2x[2 * i].val[1],
                   src_lines_intertwined_2x[2 * i + 1].val[1]);
    }
    // Store the resulting DepthMajor 4x2 cells in the destination packed block
    for (int outer = 0; outer < 2; outer++) {
      for (int inner = 0; inner < 2; inner++) {
        for (int cell = 0; cell < kCells; cell++) {
          uint8x8_t value = vget_low_u8(
              src_lines_intertwined_4x[2 * cell + outer].val[inner]);
          vst1_u8(dst_ptr, value);
          dst_ptr += 8;
        }
        for (int cell = 0; cell < kCells; cell++) {
          uint8x8_t value = vget_high_u8(
              src_lines_intertwined_4x[2 * cell + outer].val[inner]);
          vst1_u8(dst_ptr, value);
          dst_ptr += 8;
        }
      }
    }
    // Compute sums across the depth dimension
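    // These per-slice sums are accumulated into sums_of_each_slice, which is
    // consumed later (in the unpacking stage) to apply the quantization
    // offset terms.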
    uint16x8_t sums_of_2_cells[kCells][4];
    for (int outer = 0; outer < 2; outer++) {
      for (int inner = 0; inner < 2; inner++) {
        int i = 2 * outer + inner;
        for (int cell = 0; cell < kCells; cell++) {
          sums_of_2_cells[cell][i] = vaddl_u8(
              vget_low_u8(
                  src_lines_intertwined_4x[2 * cell + outer].val[inner]),
              vget_high_u8(
                  src_lines_intertwined_4x[2 * cell + outer].val[inner]));
        }
      }
    }
    int32x4_t sums_of_4_cells[kCells][4];
    for (int i = 0; i < 4; i++) {
      for (int cell = 0; cell < kCells; cell++) {
        sums_of_4_cells[cell][i] = vreinterpretq_s32_u32(
            vaddl_u16(vget_low_u16(sums_of_2_cells[cell][i]),
                      vget_high_u16(sums_of_2_cells[cell][i])));
      }
    }
    // Update the sums_of_each_slice vector
    for (int cell = 0; cell < kCells; cell++) {
      int32x4_t s01 =
          vaddq_s32(sums_of_4_cells[cell][0], sums_of_4_cells[cell][1]);
      int32x4_t s23 =
          vaddq_s32(sums_of_4_cells[cell][2], sums_of_4_cells[cell][3]);
      int32x4_t s = vaddq_s32(s01, s23);
      std::int32_t* sums_of_each_slice_ptr =
          dst->sums_of_each_slice() + start_width + 4 * cell;
      vst1q_s32(sums_of_each_slice_ptr,
                vaddq_s32(s, vld1q_s32(sums_of_each_slice_ptr)));
    }
    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
  }
};

template <int Cells>
using WidthMajorSideFormatNCells4x2 =
    KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;

template <int Cells>
class PackingRegisterBlock<
    WidthMajorUint8SideMap,
    PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>>
    : public PackingRegisterBlockBase<
          WidthMajorUint8SideMap,
          PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells>>> {
 public:
  typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
  typedef typename KernelSideFormat::Cell CellFormat;
  static const int kCells = KernelSideFormat::kCells;
  static const int kCellWidth = CellFormat::kWidth;
  static const int kKernelWidth = CellFormat::kWidth * kCells;
  static const int kCellDepth = CellFormat::kDepth;
  static const int kCellSize = CellFormat::kSize;

  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
    std::uint8_t* dst_ptr = dst->current_data();
    const std::uint8_t* src_ptr = this->complete_src_.data();
    const int stride = this->complete_src_.stride();
    // Load source WidthMajor data
    uint16x8_t src_lines[kCells * 4];
    for (int i = 0; i < kCells; i++) {
      // This packing path is used with our current
      // less-than-8-bit kernel, and the partial unrolling of this loop
      // results in substantially faster code (thanks to better
      // register allocation) on Nexus 5.

#define GEMMLOWP_UNROLLED_LOOP_ITER(k)                            \
  src_lines[4 * i + k] = vreinterpretq_u16_u8(vld1q_u8(src_ptr)); \
  src_ptr += stride;

      GEMMLOWP_UNROLLED_LOOP_ITER(0)
      GEMMLOWP_UNROLLED_LOOP_ITER(1)
      GEMMLOWP_UNROLLED_LOOP_ITER(2)
      GEMMLOWP_UNROLLED_LOOP_ITER(3)

#undef GEMMLOWP_UNROLLED_LOOP_ITER
    }
    // Reorder the data within registers to make WidthMajor 4x2 cells
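    // After the uint8->uint16 reinterpretation, each lane holds a (d, d+1)
    // pair from one source row, so the uint16 zips below move whole 2-deep
    // pairs and leave every 8-byte half-register as one 4 (width) x
    // 2 (depth) WidthMajor cell.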
    uint16x8x2_t src_lines_intertwined_2x[2 * kCells];
    for (int i = 0; i < kCells; i++) {
      src_lines_intertwined_2x[2 * i] =
          vzipq_u16(src_lines[4 * i], src_lines[4 * i + 2]);
      src_lines_intertwined_2x[2 * i + 1] =
          vzipq_u16(src_lines[4 * i + 1], src_lines[4 * i + 3]);
    }
    uint16x8x2_t src_lines_intertwined_4x[2 * kCells];
    for (int i = 0; i < kCells; i++) {
      src_lines_intertwined_4x[2 * i] =
          vzipq_u16(src_lines_intertwined_2x[2 * i].val[0],
                    src_lines_intertwined_2x[2 * i + 1].val[0]);
      src_lines_intertwined_4x[2 * i + 1] =
          vzipq_u16(src_lines_intertwined_2x[2 * i].val[1],
                    src_lines_intertwined_2x[2 * i + 1].val[1]);
    }
    // Store the resulting WidthMajor 4x2 cells in the destination packed block
    for (int outer = 0; outer < 2; outer++) {
      for (int inner = 0; inner < 2; inner++) {
        for (int cell = 0; cell < kCells; cell++) {
          uint8x8_t value = vreinterpret_u8_u16(vget_low_u16(
              src_lines_intertwined_4x[2 * cell + outer].val[inner]));
          vst1_u8(dst_ptr, value);
          dst_ptr += 8;
        }
        for (int cell = 0; cell < kCells; cell++) {
          uint8x8_t value = vreinterpret_u8_u16(vget_high_u16(
              src_lines_intertwined_4x[2 * cell + outer].val[inner]));
          vst1_u8(dst_ptr, value);
          dst_ptr += 8;
        }
      }
    }
    // Compute sums across the depth dimension
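    // vpaddlq_u8 adds the two depth entries of each width slot; the additions
    // below then fold the remaining depth positions together until a single
    // 16-bit sum per width slot is left in sums_of_16.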
    uint16x8_t sums_of_2[kCells][4];
    for (int outer = 0; outer < 2; outer++) {
      for (int inner = 0; inner < 2; inner++) {
        int i = 2 * outer + inner;
        for (int cell = 0; cell < kCells; cell++) {
          sums_of_2[cell][i] = vpaddlq_u8(vreinterpretq_u8_u16(
              src_lines_intertwined_4x[2 * cell + outer].val[inner]));
        }
      }
    }
    uint16x8_t sums_of_4[kCells][2];
    for (int i = 0; i < 2; i++) {
      for (int cell = 0; cell < kCells; cell++) {
        sums_of_4[cell][i] =
            vaddq_u16(sums_of_2[cell][2 * i], sums_of_2[cell][2 * i + 1]);
      }
    }
    uint16x8_t sums_of_8[kCells];
    for (int cell = 0; cell < kCells; cell++) {
      sums_of_8[cell] = vaddq_u16(sums_of_4[cell][0], sums_of_4[cell][1]);
    }

    uint16x4_t sums_of_16[kCells];
    for (int cell = 0; cell < kCells; cell++) {
      sums_of_16[cell] = vadd_u16(vget_low_u16(sums_of_8[cell]),
                                  vget_high_u16(sums_of_8[cell]));
    }
    // Update the sums_of_each_slice vector
    for (int cell = 0; cell < kCells; cell++) {
      int32x4_t s = vreinterpretq_s32_u32(vmovl_u16(sums_of_16[cell]));
      std::int32_t* sums_of_each_slice_ptr =
          dst->sums_of_each_slice() + start_width + 4 * cell;
      vst1q_s32(sums_of_each_slice_ptr,
                vaddq_s32(s, vld1q_s32(sums_of_each_slice_ptr)));
    }
    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
  }
};

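// vpaddq_s16 is only provided by the AArch64 intrinsics; on 32-bit NEON we
// emulate it with two vpadd_s16 on the 64-bit halves.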
#ifdef GEMMLOWP_NEON_32
inline int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
  const int16x4_t c = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
  const int16x4_t d = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
  return vcombine_s16(c, d);
}
#endif

template <int Width>
using Int8FastKernelFormat =
    KernelSideFormatInt8<CellFormat<Width, 16, CellOrder::WidthMajor>, 1>;

template <int Width>
class PackingRegisterBlock<WidthMajorUint8SideMap,
                           PackedSideBlock<Int8FastKernelFormat<Width>>>
    : public PackingRegisterBlockBase<
          WidthMajorUint8SideMap,
          PackedSideBlock<Int8FastKernelFormat<Width>>> {
 public:
  static_assert(Width == 2 || Width == 4, "Width must be 2 or 4");
  typedef Int8FastKernelFormat<Width> KernelSideFormat;
  typedef typename KernelSideFormat::Cell CellFormat;
  static const int kCells = KernelSideFormat::kCells;
  static const int kCellWidth = CellFormat::kWidth;
  static const int kKernelWidth = CellFormat::kWidth * kCells;
  static const int kCellDepth = CellFormat::kDepth;
  static const int kCellSize = CellFormat::kSize;

  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width) {
    std::int32_t* sums_ptr = dst->sums_of_each_slice() + start_width;
    std::uint8_t* dst_ptr = dst->current_data();
    const std::uint8_t* const src_ptr = this->complete_src_.data();
    const int stride = this->complete_src_.stride();
    // Load source WidthMajor data
    uint8x16_t src_lines[Width];
    for (int i = 0; i < Width; i++) {
      src_lines[i] = vld1q_u8(src_ptr + i * stride);
    }
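    // XORing with 0x80 flips the sign bit, i.e. maps each uint8 value x to
    // the int8 value x - 128, which is the form expected by the int8 kernels
    // that consume this packed block.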
    const uint8x16_t sign_bit_dup = vdupq_n_u8(0x80);
    for (int i = 0; i < Width; i++) {
      src_lines[i] = veorq_u8(src_lines[i], sign_bit_dup);
    }
    for (int i = 0; i < Width; i++) {
      vst1q_u8(dst_ptr + 16 * i, src_lines[i]);
    }
    int16x8_t sums2[Width];
    for (int i = 0; i < Width; i++) {
      const int8x8_t lo = vreinterpret_s8_u8(vget_low_u8(src_lines[i]));
      const int8x8_t hi = vreinterpret_s8_u8(vget_high_u8(src_lines[i]));
      sums2[i] = vaddl_s8(lo, hi);
    }
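    // Each sums2[i] now holds eight 16-bit partial sums of row i; the
    // pairwise additions below reduce them to one 32-bit sum per row,
    // accumulated into sums_ptr.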
    int16x8_t sums4[Width / 2];
    for (int i = 0; i < Width / 2; i++) {
      sums4[i] = vpaddq_s16(sums2[2 * i], sums2[2 * i + 1]);
    }
    if (Width == 4) {
      int32x4_t sum = vld1q_s32(sums_ptr);
      int16x8_t sums8 = vpaddq_s16(sums4[0], sums4[1]);
      sum = vpadalq_s16(sum, sums8);
      vst1q_s32(sums_ptr, sum);
    } else {
      assert(Width == 2);
      int32x2_t sum = vld1_s32(sums_ptr);
      int16x4_t sums8 =
          vpadd_s16(vget_low_s16(sums4[0]), vget_high_s16(sums4[0]));
      sum = vpadal_s16(sum, sums8);
      vst1_s32(sums_ptr, sum);
    }
    dst->seek_forward_n_cells(1);
  }
};

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_PACK_NEON_H_