Home | History | Annotate | Download | only in internal
      1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // unpack.h: unpacking the result blocks computed by compute.h,
     16 // storing them into the destination matrix.
     17 
     18 #ifndef GEMMLOWP_INTERNAL_UNPACK_H_
     19 #define GEMMLOWP_INTERNAL_UNPACK_H_
     20 
     21 #include "allocator.h"
     22 #include "block_params.h"
     23 #include "output.h"
     24 #include "pack.h"
     25 
     26 #include <cmath>
     27 
     28 namespace gemmlowp {
     29 
     30 class PackedResult {
     31  public:
     32   PackedResult(Allocator* _allocator, const BlockParams& _block_params)
     33       : allocator_(_allocator), block_params_(_block_params) {
     34     matrix_handle_ = allocator_->Reserve<std::int32_t>(block_params_.l2_rows *
     35                                                        block_params_.l2_cols);
     36   }
     37 
     38   ~PackedResult() {}
     39 
     40   MatrixMap<std::int32_t, MapOrder::ColMajor> Map() {
     41     return MatrixMap<std::int32_t, MapOrder::ColMajor>(
     42         allocator_->GetPointer<std::int32_t>(matrix_handle_),
     43         block_params_.l2_rows, block_params_.l2_cols, block_params_.l2_rows);
     44   }
     45 
     46   MatrixMap<const std::int32_t, MapOrder::ColMajor> Map() const {
     47     return MatrixMap<const std::int32_t, MapOrder::ColMajor>(
     48         allocator_->GetPointer<const std::int32_t>(matrix_handle_),
     49         block_params_.l2_rows, block_params_.l2_cols, block_params_.l2_rows);
     50   }
     51 
     52  private:
     53   Allocator* allocator_;
     54   Allocator::Handle matrix_handle_;
     55   const BlockParams& block_params_;
     56 };
     57 
     58 struct MatrixBlockBounds {
     59   int start_row;
     60   int start_col;
     61   int rows;
     62   int cols;
     63 
     64   MatrixBlockBounds(int start_row_, int start_col_, int rows_, int cols_)
     65       : start_row(start_row_),
     66         start_col(start_col_),
     67         rows(rows_),
     68         cols(cols_) {}
     69 };
     70 
     71 template <int Rows, int Cols, typename SrcMapType>
     72 void PrefetchResultBlock(const SrcMapType& src,
     73                          const VectorMap<const std::int32_t, VectorShape::Col>&
     74                              lhs_sums_of_each_slice,
     75                          int src_row, int src_col) {
     76   const std::int32_t* src_data = src.data(src_row, src_col);
     77   const int src_stride = src.stride();
     78   const std::int32_t* lhs_sums_data = lhs_sums_of_each_slice.data(src_row);
     79   for (int r = 0; r < Rows; r += 4) {
     80     Prefetch(lhs_sums_data + r);
     81   }
     82   for (int c = 0; c < Cols; c++) {
     83     for (int r = 0; r < Rows; r += 4) {
     84       Prefetch(src_data + r + c * src_stride);
     85     }
     86   }
     87 }
     88 
     89 template <typename KernelFormat, typename RegisterBlockType,
     90           typename SrcMapType, typename LhsOffset, typename RhsOffset,
     91           typename OutputPipelineExecutorType, typename DstType>
     92 void UnpackResultBlock(const SrcMapType& src,
     93                        const OutputPipelineExecutorType& executor, DstType* dst,
     94                        const VectorMap<const std::int32_t, VectorShape::Col>&
     95                            lhs_sums_of_each_slice,
     96                        const VectorMap<const std::int32_t, VectorShape::Row>&
     97                            rhs_sums_of_each_slice,
     98                        const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
     99                        int depth, int src_row, int src_col, int src_global_row,
    100                        int src_global_col, int dst_row, int dst_col) {
    101   using KernelLhsScalar = typename KernelFormat::Lhs::Scalar;
    102   using KernelRhsScalar = typename KernelFormat::Rhs::Scalar;
    103   static constexpr int KernelLhsZeroPointInput =
    104       ZeroPointInputValue<KernelLhsScalar>::kValue;
    105   static constexpr int KernelRhsZeroPointInput =
    106       ZeroPointInputValue<KernelRhsScalar>::kValue;
    107   auto acc = Load<RegisterBlockType>(src, src_row, src_col);
    108   const auto& lhs_sums_of_each_slice_block =
    109       LoadForBroadcasting<RegisterBlockType>(lhs_sums_of_each_slice, src_row);
    110   const auto& rhs_sums_of_each_slice_block =
    111       LoadForBroadcasting<RegisterBlockType>(rhs_sums_of_each_slice, src_col);
    112   auto lhs_offset_block =
    113       LoadForBroadcasting<RegisterBlockType>(lhs_offset, src_row);
    114   auto rhs_offset_block =
    115       LoadForBroadcasting<RegisterBlockType>(rhs_offset, src_col);
    116   AddConstant<KernelLhsZeroPointInput>(&lhs_offset_block);
    117   AddConstant<KernelRhsZeroPointInput>(&rhs_offset_block);
    118   BroadcastMulAdd(lhs_sums_of_each_slice_block, rhs_offset_block, &acc);
    119   for (int i = 0; i < decltype(rhs_offset_block)::kRegisterCount; i++) {
    120     rhs_offset_block.buf.reg[i] = Mul(rhs_offset_block.buf.reg[i], depth);
    121   }
    122   BroadcastMulAdd(BroadcastAdd(rhs_sums_of_each_slice_block, rhs_offset_block),
    123                   lhs_offset_block, &acc);
    124   executor.Execute(acc, dst, src_global_row, src_global_col, dst_row, dst_col);
    125 }
    126 
    127 template <typename KernelFormat, typename ResultBlockType,
    128           typename PackedResultType, typename LhsOffset, typename RhsOffset,
    129           typename OutputPipelineType>
    130 void UnpackResult(ResultBlockType* dst, const MatrixBlockBounds& dst_block,
    131                   const PackedResultType& src, int depth,
    132                   const std::int32_t* lhs_sums_of_each_slice_ptr,
    133                   const std::int32_t* rhs_sums_of_each_slice_ptr,
    134                   const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
    135                   const OutputPipelineType& output_pipeline) {
    136   ScopedProfilingLabel label(ResultBlockType::kOrder == MapOrder::ColMajor
    137                                  ? "unpack to column-major"
    138                                  : "unpack to row-major");
    139   assert(dst_block.start_row >= 0);
    140   assert(dst_block.start_row + dst_block.rows <= dst->rows());
    141   assert(dst_block.start_col >= 0);
    142   assert(dst_block.start_col + dst_block.cols <= dst->cols());
    143   const auto src_map = src.Map();
    144   const VectorMap<const std::int32_t, VectorShape::Col> lhs_sums_of_each_slice(
    145       lhs_sums_of_each_slice_ptr, dst_block.rows);
    146   const VectorMap<const std::int32_t, VectorShape::Row> rhs_sums_of_each_slice(
    147       rhs_sums_of_each_slice_ptr, dst_block.cols);
    148   using Int32x1x1 = RegisterBlock<std::int32_t, 1, 1>;
    149   using Int32x4x1 = RegisterBlock<std::int32_t, 4, 1>;
    150   using Int32x8x1 = RegisterBlock<std::int32_t, 8, 1>;
    151   using Int32x1x4 = RegisterBlock<std::int32_t, 1, 4>;
    152   using Int32x4x4 = RegisterBlock<std::int32_t, 4, 4>;
    153   using Int32x8x4 = RegisterBlock<std::int32_t, 8, 4>;
    154 
    155   using DstScalarType = typename ResultBlockType::Scalar;
    156   using DstScalarx8x8 = RegisterBlock<DstScalarType, 8, 8>;
    157 
    158   OutputPipelineExecutor<OutputPipelineType, Int32x1x1>
    159       output_pipeline_executor_1x1(output_pipeline);
    160   OutputPipelineExecutor<OutputPipelineType, Int32x4x1>
    161       output_pipeline_executor_4x1(output_pipeline);
    162   OutputPipelineExecutor<OutputPipelineType, Int32x8x1>
    163       output_pipeline_executor_8x1(output_pipeline);
    164   OutputPipelineExecutor<OutputPipelineType, Int32x1x4>
    165       output_pipeline_executor_1x4(output_pipeline);
    166   OutputPipelineExecutor<OutputPipelineType, Int32x4x4>
    167       output_pipeline_executor_4x4(output_pipeline);
    168   OutputPipelineExecutor<OutputPipelineType, Int32x8x4>
    169       output_pipeline_executor_8x4(output_pipeline);
    170 
    171   int c8 = 0;
    172   if (ResultBlockType::kOrder == MapOrder::RowMajor) {
    173     for (; c8 <= dst_block.cols - 8; c8 += 8) {
    174       PrefetchResultBlock<8, 8>(src_map, lhs_sums_of_each_slice, 0, c8);
    175       int r = 0;
    176       for (; r <= dst_block.rows - 8; r += 8) {
    177         const int global_row = r + dst_block.start_row;
    178         PrefetchResultBlock<8, 8>(src_map, lhs_sums_of_each_slice, r + 8, c8);
    179         DstScalarType dst_colmajor_buf[64];
    180         MatrixMap<DstScalarType, MapOrder::ColMajor> dst_colmajor_map(
    181             dst_colmajor_buf, 8, 8);
    182         for (int cx = 0; cx < 8; cx += 4) {
    183           const int c = c8 + cx;
    184           const int global_col = c + dst_block.start_col;
    185           UnpackResultBlock<KernelFormat, Int32x8x4>(
    186               src_map, output_pipeline_executor_8x4, &dst_colmajor_map,
    187               lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
    188               rhs_offset, depth, r, c, global_row, global_col, 0, cx);
    189         }
    190         StoreFinalOutput(LoadContiguous<DstScalarx8x8>(dst_colmajor_buf), dst,
    191                          r + dst_block.start_row, c8 + dst_block.start_col);
    192       }
    193       for (; r <= dst_block.rows - 4; r += 4) {
    194         const int global_row = r + dst_block.start_row;
    195         for (int cx = 0; cx < 8; cx += 4) {
    196           const int c = c8 + cx;
    197           const int global_col = c + dst_block.start_col;
    198           UnpackResultBlock<KernelFormat, Int32x4x4>(
    199               src_map, output_pipeline_executor_4x4, dst,
    200               lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
    201               rhs_offset, depth, r, c, global_row, global_col, global_row,
    202               global_col);
    203         }
    204       }
    205       for (; r < dst_block.rows; r++) {
    206         const int global_row = r + dst_block.start_row;
    207         for (int cx = 0; cx < 8; cx += 4) {
    208           const int c = c8 + cx;
    209           const int global_col = c + dst_block.start_col;
    210           UnpackResultBlock<KernelFormat, Int32x1x4>(
    211               src_map, output_pipeline_executor_1x4, dst,
    212               lhs_sums_of_each_slice, rhs_sums_of_each_slice, lhs_offset,
    213               rhs_offset, depth, r, c, global_row, global_col, global_row,
    214               global_col);
    215         }
    216       }
    217     }
    218   }
    219   int c = c8;
    220   for (; c <= dst_block.cols - 4; c += 4) {
    221     const int global_col = c + dst_block.start_col;
    222     PrefetchResultBlock<8, 4>(src_map, lhs_sums_of_each_slice, 0, c);
    223     int r = 0;
    224     for (; r <= dst_block.rows - 8; r += 8) {
    225       const int global_row = r + dst_block.start_row;
    226       PrefetchResultBlock<8, 4>(src_map, lhs_sums_of_each_slice, r + 8, c);
    227       UnpackResultBlock<KernelFormat, Int32x8x4>(
    228           src_map, output_pipeline_executor_8x4, dst, lhs_sums_of_each_slice,
    229           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
    230           global_row, global_col, global_row, global_col);
    231     }
    232     for (; r <= dst_block.rows - 4; r += 4) {
    233       const int global_row = r + dst_block.start_row;
    234       UnpackResultBlock<KernelFormat, Int32x4x4>(
    235           src_map, output_pipeline_executor_4x4, dst, lhs_sums_of_each_slice,
    236           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
    237           global_row, global_col, global_row, global_col);
    238     }
    239     for (; r < dst_block.rows; r++) {
    240       const int global_row = r + dst_block.start_row;
    241       UnpackResultBlock<KernelFormat, Int32x1x4>(
    242           src_map, output_pipeline_executor_1x4, dst, lhs_sums_of_each_slice,
    243           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
    244           global_row, global_col, global_row, global_col);
    245     }
    246   }
    247   for (; c < dst_block.cols; c++) {
    248     const int global_col = c + dst_block.start_col;
    249     PrefetchResultBlock<8, 1>(src_map, lhs_sums_of_each_slice, 0, c);
    250     int r = 0;
    251     for (; r <= dst_block.rows - 8; r += 8) {
    252       const int global_row = r + dst_block.start_row;
    253       PrefetchResultBlock<8, 1>(src_map, lhs_sums_of_each_slice, r + 8, c);
    254       UnpackResultBlock<KernelFormat, Int32x8x1>(
    255           src_map, output_pipeline_executor_8x1, dst, lhs_sums_of_each_slice,
    256           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
    257           global_row, global_col, global_row, global_col);
    258     }
    259     for (; r <= dst_block.rows - 4; r += 4) {
    260       const int global_row = r + dst_block.start_row;
    261       UnpackResultBlock<KernelFormat, Int32x4x1>(
    262           src_map, output_pipeline_executor_4x1, dst, lhs_sums_of_each_slice,
    263           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
    264           global_row, global_col, global_row, global_col);
    265     }
    266     for (; r < dst_block.rows; r++) {
    267       const int global_row = r + dst_block.start_row;
    268       UnpackResultBlock<KernelFormat, Int32x1x1>(
    269           src_map, output_pipeline_executor_1x1, dst, lhs_sums_of_each_slice,
    270           rhs_sums_of_each_slice, lhs_offset, rhs_offset, depth, r, c,
    271           global_row, global_col, global_row, global_col);
    272     }
    273   }
    274 }
    275 
    276 }  // end namespace gemmlowp
    277 
    278 #endif  // GEMMLOWP_INTERNAL_UNPACK_H_
    279