Home | History | Annotate | Download | only in public
      1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // gemmlowp.h: the main public interface header of gemmlowp.
     16 
     17 #ifndef GEMMLOWP_PUBLIC_GEMMLOWP_H_
     18 #define GEMMLOWP_PUBLIC_GEMMLOWP_H_
     19 #include "../internal/dispatch_gemm_shape.h"
     20 #include "bit_depth.h"
     21 #include "map.h"
     22 #include "output_stages.h"
     23 
     24 namespace gemmlowp {
     25 
     26 class GemmContext : public MultiThreadGemmContext {};
     27 
     28 // Computes a general matrix product ("GEMM").
     29 // This is a version that supports per channel quantization.
     30 template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
     31           MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
     32           typename LhsOffset, typename RhsOffset, typename OutputPipelineType,
     33           typename GemmContextType>
     34 void GemmWithOutputPipelinePC(GemmContextType* context,
     35                               const MatrixMap<const InputScalar, LhsOrder>& lhs,
     36                               const MatrixMap<const InputScalar, RhsOrder>& rhs,
     37                               MatrixMap<OutputScalar, ResultOrder>* result,
     38                               const LhsOffset& lhs_offset,
     39                               const RhsOffset& rhs_offset,
     40                               const OutputPipelineType& output_pipeline) {
     41   DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>(
     42       context, lhs, rhs, result, lhs_offset, rhs_offset, output_pipeline);
     43 }
     44 
     45 // Computes a general matrix product ("GEMM").
     46 // This is the legacy version that does not support per channel quantization.
     47 // The meaning of the offsets, result_mult_int and result_shift
     48 // parameters is the same as in the standard EightBitIntGemm interface
     49 // (which is also implemented in the eight_bit_int_gemm directory).
     50 template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
     51           MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
     52           typename OutputPipelineType, typename GemmContextType>
     53 void GemmWithOutputPipeline(GemmContextType* context,
     54                             const MatrixMap<const InputScalar, LhsOrder>& lhs,
     55                             const MatrixMap<const InputScalar, RhsOrder>& rhs,
     56                             MatrixMap<OutputScalar, ResultOrder>* result,
     57                             int lhs_offset, int rhs_offset,
     58                             const OutputPipelineType& output_pipeline) {
     59   typedef VectorDup<const std::int32_t, VectorShape::Col> OffsetColDup;
     60   typedef VectorDup<const std::int32_t, VectorShape::Row> OffsetRowDup;
     61   const OffsetColDup lhs_offset_vector(lhs_offset, lhs.rows());
     62   const OffsetRowDup rhs_offset_vector(rhs_offset, rhs.cols());
     63   DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>(
     64       context, lhs, rhs, result, lhs_offset_vector, rhs_offset_vector,
     65       output_pipeline);
     66 }
     67 
     68 // Computes a general matrix product ("GEMM").
     69 // The meaning of the offsets, result_mult_int and result_shift
     70 // parameters is the same as in the standard EightBitIntGemm interface
     71 // (which is also implemented in the eight_bit_int_gemm directory).
     72 template <typename Scalar, typename BitDepthParams, MapOrder LhsOrder,
     73           MapOrder RhsOrder, MapOrder ResultOrder, typename GemmContextType>
     74 void Gemm(GemmContextType* context,
     75           const MatrixMap<const Scalar, LhsOrder>& lhs,
     76           const MatrixMap<const Scalar, RhsOrder>& rhs,
     77           MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
     78           int rhs_offset, int result_offset, int result_mult_int,
     79           int result_shift) {
     80   GemmWithOutputPipeline<Scalar, Scalar, BitDepthParams>(
     81       context, lhs, rhs, result, lhs_offset, rhs_offset,
     82       MakeStandardOutputPipeline(result_offset, result_mult_int, result_shift));
     83 }
     84 
     85 }  // namespace gemmlowp
     86 
     87 #endif  // GEMMLOWP_PUBLIC_GEMMLOWP_H_
     88