// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// pack_SSE.h: optimized SSE specializations of the templates in pack.h.

#ifndef GEMMLOWP_INTERNAL_PACK_SSE_H_
#define GEMMLOWP_INTERNAL_PACK_SSE_H_

#include <smmintrin.h>
#include "pack.h"

namespace gemmlowp {

// Requantizes the source values pointed to by raw_src_ptr from the
// [0..255] range to the range specified by BitDepth, [0..(2^bits - 1)].
// The requantization is done in place; when 8-bit integers are used,
// the input is left unmodified. SSE currently has no kernels for bit
// depths below 8. Although an SSE register holds 16 uint8_t elements,
// only the first 8 are requantized: the packing below only uses the
// first 8 uint8_t elements of each register, so requantizing all 16
// would be wasted computation.
template <typename QuantizationParams>
void SSERequantize(
    __m128i* raw_src_ptr,
    ScalarRoundingOffsetGenerator<QuantizationParams::kRoundingMode>*
        rounding_offset_generator) {
  static const int kBits = QuantizationParams::BitDepth::kBits;
  static const std::uint8_t kMaxVal = (1 << kBits) - 1;
  if (kBits == 8) {
    return;
  }

  std::uint8_t* raw_src_ui8_ptr = reinterpret_cast<std::uint8_t*>(raw_src_ptr);

  // Modify only the first 8 elements in the register (see note above).
  for (int i = 0; i < 8; ++i) {
    std::uint16_t scaled =
        static_cast<std::uint16_t>(raw_src_ui8_ptr[i]) * kMaxVal;
    std::uint8_t rounding_offset = rounding_offset_generator->get();
    raw_src_ui8_ptr[i] = (scaled + rounding_offset) / 255;
  }
}
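// For illustration, a scalar sketch of the same requantization, one value
// at a time. This is a hypothetical helper, not used by the packing path
// below; it assumes the rounding scheme of SSERequantize above. For
// example, with bits == 4 (max_val == 15), a source value of 200 scales
// to 200 * 15 == 3000, and (3000 + rounding_offset) / 255 yields 11 or 12
// depending on the rounding offset.
inline std::uint8_t ScalarRequantizeOneValue(std::uint8_t src_val, int bits,
                                             std::uint8_t rounding_offset) {
  const std::uint8_t max_val = (1 << bits) - 1;
  const std::uint16_t scaled = static_cast<std::uint16_t>(src_val) * max_val;
  return static_cast<std::uint8_t>((scaled + rounding_offset) / 255);
}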
// TODO: Add DepthMajorUint8SideMap

typedef SideMap<const std::uint8_t, SideMapOrder::WidthMajor>
    WidthMajorUint8SideMap;

template <int Cells>
using WidthMajorSideFormatNCells4x2 =
    KernelSideFormat<CellFormat<4, 2, CellOrder::WidthMajor>, Cells>;

template <typename QuantizationParams, int Cells>
class PackingRegisterBlock<
    QuantizationParams, WidthMajorUint8SideMap,
    PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > >
    : public PackingRegisterBlockBase<
          QuantizationParams, WidthMajorUint8SideMap,
          PackedSideBlock<WidthMajorSideFormatNCells4x2<Cells> > > {
 public:
  typedef WidthMajorSideFormatNCells4x2<Cells> KernelSideFormat;
  typedef typename KernelSideFormat::Cell CellFormat;
  static const int kCells = KernelSideFormat::kCells;
  static const int kCellWidth = CellFormat::kWidth;
  static const int kKernelWidth = CellFormat::kWidth * kCells;
  static const int kCellDepth = CellFormat::kDepth;
  static const int kCellSize = CellFormat::kSize;

  typedef ScalarRoundingOffsetGenerator<QuantizationParams::kRoundingMode>
      RoundingOffsetGenerator;

  void Pack(PackedSideBlock<KernelSideFormat>* dst, int start_width,
            RoundingOffsetGenerator* rounding_offset_generator) {
    std::uint8_t* dst_ptr = dst->current_data();
    const int width_stride = this->complete_src_.width_stride();
    const int depth_step = 8;

    __m128i one = _mm_set1_epi16(1);
    for (int cell_start_depth = 0; cell_start_depth < kRegisterSize;
         cell_start_depth += depth_step) {
      for (int cell_start_width = 0; cell_start_width < kKernelWidth;
           cell_start_width += kCellWidth) {
        std::int32_t* cell_sums_of_each_slice_ptr =
            dst->sums_of_each_slice() + start_width + cell_start_width;
        const std::uint8_t* src_data =
            this->complete_src_.data(cell_start_width, cell_start_depth);

        // Load 8 consecutive depth values from each of the 4 width-rows
        // of the current cell.
        __m128i xmm1 =
            _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&src_data[0]));
        __m128i xmm2 = _mm_loadl_epi64(
            reinterpret_cast<const __m128i*>(&src_data[1 * width_stride]));
        __m128i xmm3 = _mm_loadl_epi64(
            reinterpret_cast<const __m128i*>(&src_data[2 * width_stride]));
        __m128i xmm4 = _mm_loadl_epi64(
            reinterpret_cast<const __m128i*>(&src_data[3 * width_stride]));

        // Interleave the 4 rows at 16-bit (depth-pair) granularity: after
        // the blends below, the low halves of xmm9 and xmm10 hold the 4x2
        // WidthMajor cells for depths 0-1 and 2-3.
        __m128i xmm5 = _mm_unpacklo_epi16(xmm1, xmm2);
        __m128i xmm8 = _mm_shuffle_epi32(xmm5, 0x31);

        __m128i xmm6 = _mm_unpacklo_epi16(xmm3, xmm4);
        __m128i xmm7 = _mm_shuffle_epi32(xmm6, 0x80);

        __m128i xmm9 = _mm_blend_epi16(xmm5, xmm7, 0xcc);
        SSERequantize<QuantizationParams>(&xmm9, rounding_offset_generator);

        __m128i xmm10 = _mm_blend_epi16(xmm8, xmm6, 0xcc);
        SSERequantize<QuantizationParams>(&xmm10, rounding_offset_generator);

        _mm_storel_epi64(reinterpret_cast<__m128i*>(&dst_ptr[0]), xmm9);
        _mm_storel_epi64(
            reinterpret_cast<__m128i*>(&dst_ptr[kCellSize * kCells]), xmm10);

        // Bring the high halves down to get the cells for depths 4-5
        // and 6-7.
        __m128i xmm11 = _mm_shuffle_epi32(xmm9, 0xee);
        SSERequantize<QuantizationParams>(&xmm11, rounding_offset_generator);

        __m128i xmm12 = _mm_shuffle_epi32(xmm10, 0xee);
        SSERequantize<QuantizationParams>(&xmm12, rounding_offset_generator);

        _mm_storel_epi64(
            reinterpret_cast<__m128i*>(&dst_ptr[2 * kCellSize * kCells]),
            xmm11);
        _mm_storel_epi64(
            reinterpret_cast<__m128i*>(&dst_ptr[3 * kCellSize * kCells]),
            xmm12);

        // Accumulate the sums of each width-slice: zero-extend the packed
        // uint8 values to 16 bit, then _mm_madd_epi16 against a vector of
        // ones adds adjacent depth pairs into 32-bit per-row sums.
        xmm1 = _mm_cvtepu8_epi16(xmm9);
        xmm2 = _mm_madd_epi16(xmm1, one);
        __m128i sums_of_each_slice_xmm = _mm_loadu_si128(
            reinterpret_cast<const __m128i*>(&cell_sums_of_each_slice_ptr[0]));
        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);

        xmm1 = _mm_cvtepu8_epi16(xmm10);
        xmm2 = _mm_madd_epi16(xmm1, one);
        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);

        xmm1 = _mm_cvtepu8_epi16(xmm11);
        xmm2 = _mm_madd_epi16(xmm1, one);
        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);

        xmm1 = _mm_cvtepu8_epi16(xmm12);
        xmm2 = _mm_madd_epi16(xmm1, one);
        sums_of_each_slice_xmm = _mm_add_epi32(sums_of_each_slice_xmm, xmm2);

        _mm_storeu_si128(
            reinterpret_cast<__m128i*>(&cell_sums_of_each_slice_ptr[0]),
            sums_of_each_slice_xmm);
        dst_ptr += kCellSize;
      }
      dst_ptr += 3 * kCellSize * kCells;
    }
    dst->seek_forward_n_cells(kCells * kRegisterSize / kCellDepth);
  }
};

}  // namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_PACK_SSE_H_