Home | History | Annotate | Download | only in internal
      1 // Copyright 2015 Google Inc. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // single_thread_gemm.h: Single-threaded GEMM implementation.
     16 // This is a good place to start reading code, as it shows the overall
     17 // structure of a GEMM and is much simpler than multi_thread_gemm.h.
     18 
     19 #ifndef GEMMLOWP_INTERNAL_SINGLE_THREAD_GEMM_H_
     20 #define GEMMLOWP_INTERNAL_SINGLE_THREAD_GEMM_H_
     21 
     22 #include <cassert>
     23 
     24 #include "../public/map.h"
     25 #include "allocator.h"
     26 #include "compute.h"
     27 #include "kernel.h"
     28 #include "pack.h"
     29 #include "unpack.h"
     30 
     31 namespace gemmlowp {
     32 
     33 class SingleThreadGemmContext {
     34  public:
     35   Allocator* allocator() { return &allocator_; }
     36 
     37  protected:
     38   Allocator allocator_;
     39 };
     40 
     41 typedef VectorMap<const int32_t, VectorShape::Col> OffsetColMap;
     42 typedef VectorMap<const int32_t, VectorShape::Row> OffsetRowMap;
     43 typedef VectorDup<const int32_t, VectorShape::Col> OffsetColDup;
     44 typedef VectorDup<const int32_t, VectorShape::Row> OffsetRowDup;
     45 
     46 template <typename KernelFormat, typename InputScalar, typename OutputScalar,
     47           typename BitDepthParams, MapOrder LhsOrder, MapOrder RhsOrder,
     48           MapOrder ResultOrder, typename LhsOffset, typename RhsOffset,
     49           typename OutputPipelineType>
     50 void SingleThreadGemm(SingleThreadGemmContext* context,
     51                       const KernelBase& kernel,
     52                       const MatrixMap<const InputScalar, LhsOrder>& lhs,
     53                       const MatrixMap<const InputScalar, RhsOrder>& rhs,
     54                       MatrixMap<OutputScalar, ResultOrder>* result,
     55                       const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
     56                       const OutputPipelineType& output_pipeline) {
     57   ScopedProfilingLabel label("gemmlowp::SingleThreadGemm");
     58 
     59   assert(lhs.cols() == rhs.rows());
     60 
     61   int rows = result->rows();
     62   int cols = result->cols();
     63   int depth = lhs.cols();
     64 
     65   assert(rows > 0);
     66   assert(cols > 0);
     67   assert(depth > 0);
     68 
     69   Allocator* allocator = context->allocator();
     70 
     71   BlockParams block_params;
     72   block_params.Init<KernelFormat>(rows, cols, depth, 1);
     73 
     74   PackedSideBlock<typename KernelFormat::Lhs> packed_lhs(
     75       Side::Lhs, allocator, block_params);
     76   PackedSideBlock<typename KernelFormat::Rhs> packed_rhs(
     77       Side::Rhs, allocator, block_params);
     78 
     79   PackedResult packed_result(allocator, block_params);
     80 
     81   allocator->Commit();
     82 
     83   const bool pack_rhs_once = block_params.l2_cols == cols;
     84 
     85   if (pack_rhs_once) {
     86     PackRhs<BitDepthParams>(&packed_rhs, rhs);
     87   }
     88 
     89   for (int r = 0; r < rows; r += block_params.l2_rows) {
     90     int rs = std::min(block_params.l2_rows, rows - r);
     91 
     92     PackLhs<BitDepthParams>(&packed_lhs, lhs.block(r, 0, rs, depth));
     93 
     94     for (int c = 0; c < cols; c += block_params.l2_cols) {
     95       int cs = std::min(block_params.l2_cols, cols - c);
     96 
     97       if (!pack_rhs_once) {
     98         PackRhs<BitDepthParams>(&packed_rhs, rhs.block(0, c, depth, cs));
     99       }
    100 
    101       Compute(kernel, block_params, &packed_result, packed_lhs, packed_rhs);
    102 
    103       auto result_block = result->block(r, c, rs, cs);
    104       UnpackResult<BitDepthParams>(&result_block, packed_result, depth,
    105                                    packed_lhs.sums_of_each_slice(),
    106                                    packed_rhs.sums_of_each_slice(),
    107                                    lhs_offset, rhs_offset, output_pipeline);
    108     }
    109   }
    110 
    111   allocator->Decommit();
    112 }
    113 
    114 }  // namespace gemmlowp
    115 
    116 #endif  // GEMMLOWP_INTERNAL_SINGLE_THREAD_GEMM_H_
    117