Home | History | Annotate | Download | only in public
      1 // Copyright 2015 Google Inc. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // gemmlowp.h: the main public interface header of gemmlowp.
     16 
     17 #ifndef GEMMLOWP_PUBLIC_GEMMLOWP_H_
     18 #define GEMMLOWP_PUBLIC_GEMMLOWP_H_
     19 #include "../internal/kernel_default.h"
     20 #include "../internal/multi_thread_gemm.h"
     21 #include "../internal/unpack.h"
     22 #include "bit_depth.h"
     23 #include "map.h"
     24 #include "output_stages.h"
     25 
     26 namespace gemmlowp {
     27 
     28 inline bool IsRequantizationWorthIt(int rows, int cols) {
     29   // We pack depth*(rows+cols) and compute depth*rows*cols.
     30   // Thus the ratio of compute/packing cost is rows*cols/(rows+cols)
     31   // In the square case rows==cols==N, it becomes N/2.
     32   return 2 * rows * cols >= (rows + cols) * kMinimumWidthForRequantization;
     33 }
     34 
     35 class GemmContext : public MultiThreadGemmContext {};
     36 
     37 // Computes a general matrix product ("GEMM").
     38 // This is a version that supports per channel quantization.
     39 template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
     40           MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
     41           typename LhsOffset, typename RhsOffset, typename OutputPipelineType>
     42 void GemmWithOutputPipelinePC(GemmContext* context,
     43                               const MatrixMap<const InputScalar, LhsOrder>& lhs,
     44                               const MatrixMap<const InputScalar, RhsOrder>& rhs,
     45                               MatrixMap<OutputScalar, ResultOrder>* result,
     46                               const LhsOffset& lhs_offset,
     47                               const RhsOffset& rhs_offset,
     48                               const OutputPipelineType& output_pipeline) {
     49   assert(lhs.cols() == rhs.rows());
     50 
     51   int rows = result->rows();
     52   int cols = result->cols();
     53   int depth = lhs.cols();
     54 
     55   if (rows == 0 || cols == 0 || depth == 0) {
     56     // Vacuous GEMM, return early to avoid having to deal with
     57     // zero sizes below.
     58     return;
     59   }
     60 
     61   if (cols == 1) {
     62     if (IsRequantizationWorthIt(rows, cols)) {
     63       typedef DefaultKernel<KernelFamily::Gemv, BitDepthParams> Kernel;
     64       MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
     65                       BitDepthParams>(context, Kernel(), lhs, rhs, result,
     66                                       lhs_offset, rhs_offset, output_pipeline);
     67     } else {
     68       typedef DefaultKernel<KernelFamily::Gemv, DefaultL8R8BitDepthParams>
     69           Kernel;
     70       MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
     71                       DefaultL8R8BitDepthParams>(context, Kernel(), lhs, rhs,
     72                                                  result, lhs_offset, rhs_offset,
     73                                                  output_pipeline);
     74     }
     75   } else {
     76     if (IsRequantizationWorthIt(rows, cols)) {
     77       typedef DefaultKernel<KernelFamily::Gemm, BitDepthParams> Kernel;
     78       MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
     79                       BitDepthParams>(context, Kernel(), lhs, rhs, result,
     80                                       lhs_offset, rhs_offset, output_pipeline);
     81     } else {
     82       typedef DefaultKernel<KernelFamily::Gemm, DefaultL8R8BitDepthParams>
     83           Kernel;
     84       MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
     85                       DefaultL8R8BitDepthParams>(context, Kernel(), lhs, rhs,
     86                                                  result, lhs_offset, rhs_offset,
     87                                                  output_pipeline);
     88     }
     89   }
     90 }
     91 
     92 // Computes a general matrix product ("GEMM").
     93 // This is the legacy version that does not support per channel quantization.
     94 // The meaning of the offsets, result_mult_int and result_shift
     95 // parameters is the same as in the standard EightBitIntGemm interface
     96 // (which is also implemented in the eight_bit_int_gemm directory).
     97 template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
     98           MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
     99           typename OutputPipelineType>
    100 void GemmWithOutputPipeline(GemmContext* context,
    101                             const MatrixMap<const InputScalar, LhsOrder>& lhs,
    102                             const MatrixMap<const InputScalar, RhsOrder>& rhs,
    103                             MatrixMap<OutputScalar, ResultOrder>* result,
    104                             int lhs_offset, int rhs_offset,
    105                             const OutputPipelineType& output_pipeline) {
    106   const OffsetColDup lhs_offset_vector(lhs_offset, lhs.rows());
    107   const OffsetRowDup rhs_offset_vector(rhs_offset, rhs.cols());
    108   GemmWithOutputPipelinePC<InputScalar, OutputScalar, BitDepthParams>(
    109       context, lhs, rhs, result, lhs_offset_vector, rhs_offset_vector,
    110       output_pipeline);
    111 }
    112 
    113 // Computes a general matrix product ("GEMM").
    114 // The meaning of the offsets, result_mult_int and result_shift
    115 // parameters is the same as in the standard EightBitIntGemm interface
    116 // (which is also implemented in the eight_bit_int_gemm directory).
    117 template <typename Scalar, typename BitDepthParams, MapOrder LhsOrder,
    118           MapOrder RhsOrder, MapOrder ResultOrder>
    119 void Gemm(GemmContext* context, const MatrixMap<const Scalar, LhsOrder>& lhs,
    120           const MatrixMap<const Scalar, RhsOrder>& rhs,
    121           MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
    122           int rhs_offset, int result_offset, int result_mult_int,
    123           int result_shift) {
    124   GemmWithOutputPipeline<Scalar, Scalar, BitDepthParams>(
    125       context, lhs, rhs, result, lhs_offset, rhs_offset,
    126       MakeStandardOutputPipeline(result_offset, result_mult_int, result_shift));
    127 }
    128 
    129 }  // namespace gemmlowp
    130 
    131 #endif  // GEMMLOWP_PUBLIC_GEMMLOWP_H_
    132