1 // Copyright 2017 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // dispatch_gemm_shape.h: dispatch GEMM calls according to their shape 16 17 #ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_ 18 #define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_ 19 20 #include "../internal/kernel_default.h" 21 #include "../public/map.h" 22 #include "../public/output_stages.h" 23 #include "multi_thread_gemm.h" 24 25 namespace gemmlowp { 26 27 template <typename T> 28 struct TransposeImpl { 29 typedef T DstType; 30 static T Run(const T& t) { return t; } 31 }; 32 33 template <typename T> 34 using TransposeType = typename TransposeImpl<T>::DstType; 35 36 template <typename T> 37 TransposeType<T> Transpose(const T& t) { 38 return TransposeImpl<T>::Run(t); 39 } 40 41 template <MapOrder Order> 42 struct TransposeMapOrder { 43 static constexpr MapOrder Value = 44 Order == MapOrder::RowMajor ? MapOrder::ColMajor : MapOrder::RowMajor; 45 }; 46 47 template <VectorShape Shape> 48 struct TransposeVectorShape { 49 static constexpr VectorShape Value = 50 Shape == VectorShape::Row ? VectorShape::Col : VectorShape::Row; 51 }; 52 53 template <typename Scalar, VectorShape Shape> 54 struct TransposeImpl<VectorMap<Scalar, Shape>> { 55 typedef VectorMap<Scalar, Shape> SrcType; 56 static constexpr VectorShape TransposedShape = 57 TransposeVectorShape<Shape>::Value; 58 typedef VectorMap<Scalar, TransposedShape> DstType; 59 static DstType Run(const SrcType& src) { 60 return DstType(src.data(), src.size()); 61 } 62 }; 63 64 template <typename Scalar, MapOrder Order> 65 struct TransposeImpl<MatrixMap<Scalar, Order>> { 66 typedef MatrixMap<Scalar, Order> SrcType; 67 static constexpr MapOrder TransposedOrder = TransposeMapOrder<Order>::Value; 68 typedef MatrixMap<Scalar, TransposedOrder> DstType; 69 static DstType Run(const SrcType& src) { 70 return DstType(src.data(), src.cols(), src.rows(), src.stride()); 71 } 72 }; 73 74 template <VectorShape Shape> 75 struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> { 76 typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType; 77 static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value; 78 typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType; 79 static DstType Run(const SrcType& src) { 80 DstType dst; 81 dst.result_shift = src.result_shift; 82 dst.result_offset = Transpose(src.result_offset); 83 dst.result_mult_int = Transpose(src.result_mult_int); 84 return dst; 85 } 86 }; 87 88 template <typename VectorMapType> 89 struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> { 90 typedef OutputStageBiasAddition<VectorMapType> SrcType; 91 typedef TransposeType<VectorMapType> TransposedVectorMapType; 92 typedef OutputStageBiasAddition<TransposedVectorMapType> DstType; 93 static DstType Run(const SrcType& src) { 94 DstType dst; 95 dst.bias_vector = Transpose(src.bias_vector); 96 return dst; 97 } 98 }; 99 100 // TODO(benoitjacob) - does anyone understand C++ variadic templates? 101 // How to use them to implement TransposeTuple? Note: there are lots 102 // of answers on StackOverflow but they seem to all involve either 103 // C++14/C++17 (we can only use C++11) or lots of abstract nonsense. 104 inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; } 105 106 template <typename T0> 107 std::tuple<TransposeType<T0>> TransposeTuple(const std::tuple<T0>& t) { 108 return std::make_tuple(Transpose(std::get<0>(t))); 109 } 110 111 template <typename T0, typename T1> 112 std::tuple<TransposeType<T0>, TransposeType<T1>> TransposeTuple( 113 const std::tuple<T0, T1>& t) { 114 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t))); 115 } 116 117 template <typename T0, typename T1, typename T2> 118 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>> 119 TransposeTuple(const std::tuple<T0, T1, T2>& t) { 120 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), 121 Transpose(std::get<2>(t))); 122 } 123 124 template <typename T0, typename T1, typename T2, typename T3> 125 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>, 126 TransposeType<T3>> 127 TransposeTuple(const std::tuple<T0, T1, T2, T3>& t) { 128 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), 129 Transpose(std::get<2>(t)), Transpose(std::get<3>(t))); 130 } 131 132 template <typename T0, typename T1, typename T2, typename T3, typename T4> 133 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>, 134 TransposeType<T3>, TransposeType<T4>> 135 TransposeTuple(const std::tuple<T0, T1, T2, T3, T4>& t) { 136 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), 137 Transpose(std::get<2>(t)), Transpose(std::get<3>(t)), 138 Transpose(std::get<4>(t))); 139 } 140 141 template <typename T0, typename T1, typename T2, typename T3, typename T4, 142 typename T5> 143 std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>, 144 TransposeType<T3>, TransposeType<T4>, TransposeType<T5>> 145 TransposeTuple(const std::tuple<T0, T1, T2, T3, T4, T5>& t) { 146 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), 147 Transpose(std::get<2>(t)), Transpose(std::get<3>(t)), 148 Transpose(std::get<4>(t)), Transpose(std::get<5>(t))); 149 } 150 151 template <typename InputScalar, typename OutputScalar, typename BitDepthParams, 152 MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder, 153 typename LhsOffset, typename RhsOffset, typename OutputPipelineType, 154 typename GemmContextType> 155 void DispatchGemmShape(GemmContextType* context, 156 const MatrixMap<const InputScalar, LhsOrder>& lhs, 157 const MatrixMap<const InputScalar, RhsOrder>& rhs, 158 MatrixMap<OutputScalar, ResultOrder>* result, 159 const LhsOffset& lhs_offset, const RhsOffset& rhs_offset, 160 const OutputPipelineType& output_pipeline) { 161 assert(lhs.cols() == rhs.rows()); 162 163 int rows = result->rows(); 164 int cols = result->cols(); 165 int depth = lhs.cols(); 166 167 if (rows == 0 || cols == 0 || depth == 0) { 168 // Vacuous GEMM, return early to avoid having to deal with 169 // zero sizes below. 170 return; 171 } 172 173 if (rows < cols) { 174 auto transposed_result_map = Transpose(*result); 175 return DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>( 176 context, Transpose(rhs), Transpose(lhs), &transposed_result_map, 177 Transpose(rhs_offset), Transpose(lhs_offset), 178 TransposeTuple(output_pipeline)); 179 } 180 181 typedef DefaultKernel<BitDepthParams> Kernel; 182 MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar, 183 BitDepthParams>(context, Kernel(), lhs, rhs, result, 184 lhs_offset, rhs_offset, output_pipeline); 185 } 186 187 } // end namespace gemmlowp 188 189 #endif // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_ 190