Home | History | Annotate | Download | only in internal
      1 // Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // dispatch_gemm_shape.h: dispatch GEMM calls according to their shape
     16 
     17 #ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
     18 #define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
     19 
     20 #include "../internal/kernel_default.h"
     21 #include "../public/map.h"
     22 #include "../public/output_stages.h"
     23 #include "multi_thread_gemm.h"
     24 
     25 namespace gemmlowp {
     26 
     27 template <typename T>
     28 struct TransposeImpl {
     29   typedef T DstType;
     30   static T Run(const T& t) { return t; }
     31 };
     32 
     33 template <typename T>
     34 using TransposeType = typename TransposeImpl<T>::DstType;
     35 
     36 template <typename T>
     37 TransposeType<T> Transpose(const T& t) {
     38   return TransposeImpl<T>::Run(t);
     39 }
     40 
     41 template <MapOrder Order>
     42 struct TransposeMapOrder {
     43   static constexpr MapOrder Value =
     44       Order == MapOrder::RowMajor ? MapOrder::ColMajor : MapOrder::RowMajor;
     45 };
     46 
     47 template <VectorShape Shape>
     48 struct TransposeVectorShape {
     49   static constexpr VectorShape Value =
     50       Shape == VectorShape::Row ? VectorShape::Col : VectorShape::Row;
     51 };
     52 
     53 template <typename Scalar, VectorShape Shape>
     54 struct TransposeImpl<VectorMap<Scalar, Shape>> {
     55   typedef VectorMap<Scalar, Shape> SrcType;
     56   static constexpr VectorShape TransposedShape =
     57       TransposeVectorShape<Shape>::Value;
     58   typedef VectorMap<Scalar, TransposedShape> DstType;
     59   static DstType Run(const SrcType& src) {
     60     return DstType(src.data(), src.size());
     61   }
     62 };
     63 
     64 template <typename Scalar, MapOrder Order>
     65 struct TransposeImpl<MatrixMap<Scalar, Order>> {
     66   typedef MatrixMap<Scalar, Order> SrcType;
     67   static constexpr MapOrder TransposedOrder = TransposeMapOrder<Order>::Value;
     68   typedef MatrixMap<Scalar, TransposedOrder> DstType;
     69   static DstType Run(const SrcType& src) {
     70     return DstType(src.data(), src.cols(), src.rows(), src.stride());
     71   }
     72 };
     73 
     74 template <VectorShape Shape>
     75 struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
     76   typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType;
     77   static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
     78   typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType;
     79   static DstType Run(const SrcType& src) {
     80     DstType dst;
     81     dst.result_shift = src.result_shift;
     82     dst.result_offset = Transpose(src.result_offset);
     83     dst.result_mult_int = Transpose(src.result_mult_int);
     84     return dst;
     85   }
     86 };
     87 
     88 template <typename VectorMapType>
     89 struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> {
     90   typedef OutputStageBiasAddition<VectorMapType> SrcType;
     91   typedef TransposeType<VectorMapType> TransposedVectorMapType;
     92   typedef OutputStageBiasAddition<TransposedVectorMapType> DstType;
     93   static DstType Run(const SrcType& src) {
     94     DstType dst;
     95     dst.bias_vector = Transpose(src.bias_vector);
     96     return dst;
     97   }
     98 };
     99 
    100 // TODO(benoitjacob) - does anyone understand C++ variadic templates?
    101 // How to use them to implement TransposeTuple? Note: there are lots
    102 // of answers on StackOverflow but they seem to all involve either
    103 // C++14/C++17 (we can only use C++11) or lots of abstract nonsense.
    104 inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; }
    105 
    106 template <typename T0>
    107 std::tuple<TransposeType<T0>> TransposeTuple(const std::tuple<T0>& t) {
    108   return std::make_tuple(Transpose(std::get<0>(t)));
    109 }
    110 
    111 template <typename T0, typename T1>
    112 std::tuple<TransposeType<T0>, TransposeType<T1>> TransposeTuple(
    113     const std::tuple<T0, T1>& t) {
    114   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)));
    115 }
    116 
    117 template <typename T0, typename T1, typename T2>
    118 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>>
    119 TransposeTuple(const std::tuple<T0, T1, T2>& t) {
    120   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
    121                          Transpose(std::get<2>(t)));
    122 }
    123 
    124 template <typename T0, typename T1, typename T2, typename T3>
    125 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
    126            TransposeType<T3>>
    127 TransposeTuple(const std::tuple<T0, T1, T2, T3>& t) {
    128   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
    129                          Transpose(std::get<2>(t)), Transpose(std::get<3>(t)));
    130 }
    131 
    132 template <typename T0, typename T1, typename T2, typename T3, typename T4>
    133 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
    134            TransposeType<T3>, TransposeType<T4>>
    135 TransposeTuple(const std::tuple<T0, T1, T2, T3, T4>& t) {
    136   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
    137                          Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
    138                          Transpose(std::get<4>(t)));
    139 }
    140 
    141 template <typename T0, typename T1, typename T2, typename T3, typename T4,
    142           typename T5>
    143 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
    144            TransposeType<T3>, TransposeType<T4>, TransposeType<T5>>
    145 TransposeTuple(const std::tuple<T0, T1, T2, T3, T4, T5>& t) {
    146   return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
    147                          Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
    148                          Transpose(std::get<4>(t)), Transpose(std::get<5>(t)));
    149 }
    150 
    151 template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
    152           MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
    153           typename LhsOffset, typename RhsOffset, typename OutputPipelineType,
    154           typename GemmContextType>
    155 void DispatchGemmShape(GemmContextType* context,
    156                        const MatrixMap<const InputScalar, LhsOrder>& lhs,
    157                        const MatrixMap<const InputScalar, RhsOrder>& rhs,
    158                        MatrixMap<OutputScalar, ResultOrder>* result,
    159                        const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
    160                        const OutputPipelineType& output_pipeline) {
    161   assert(lhs.cols() == rhs.rows());
    162 
    163   int rows = result->rows();
    164   int cols = result->cols();
    165   int depth = lhs.cols();
    166 
    167   if (rows == 0 || cols == 0 || depth == 0) {
    168     // Vacuous GEMM, return early to avoid having to deal with
    169     // zero sizes below.
    170     return;
    171   }
    172 
    173   if (rows < cols) {
    174     auto transposed_result_map = Transpose(*result);
    175     return DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>(
    176         context, Transpose(rhs), Transpose(lhs), &transposed_result_map,
    177         Transpose(rhs_offset), Transpose(lhs_offset),
    178         TransposeTuple(output_pipeline));
    179   }
    180 
    181   typedef DefaultKernel<BitDepthParams> Kernel;
    182   MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
    183                   BitDepthParams>(context, Kernel(), lhs, rhs, result,
    184                                   lhs_offset, rhs_offset, output_pipeline);
    185 }
    186 
    187 }  // end namespace gemmlowp
    188 
    189 #endif  // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
    190