1 // Copyright 2016 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef GEMMLOWP_META_BASE_H_ 16 #define GEMMLOWP_META_BASE_H_ 17 18 #include <cassert> 19 #include <cstdint> 20 21 #include "../internal/common.h" 22 23 namespace gemmlowp { 24 namespace meta { 25 26 template <int align> 27 inline int AlignTo(int value) { 28 return ((value + align - 1) / align) * align; 29 } 30 31 inline int AlignTo(int align, int value) { 32 return ((value + align - 1) / align) * align; 33 } 34 35 template <typename Kernel_, typename OutputStream_> 36 struct FusedKernelParams { 37 public: 38 typedef Kernel_ Kernel; 39 typedef OutputStream_ OutputStream; 40 41 Kernel kernel; 42 OutputStream output_stream; 43 }; 44 45 template <typename InType_, typename OutType_, typename LeftStream_, 46 typename RightStream_, typename Kernel_, typename OutputStream_> 47 struct GemmParams { 48 public: 49 typedef InType_ InType; 50 typedef OutType_ OutType; 51 typedef LeftStream_ LeftStream; 52 typedef RightStream_ RightStream; 53 typedef Kernel_ Kernel; 54 typedef OutputStream_ OutputStream; 55 56 typedef FusedKernelParams<Kernel, OutputStream> FusedKernel; 57 58 // Common parameters. 59 60 int m; 61 int n; 62 int k; 63 64 const InType* lhs; 65 const InType* rhs; 66 OutType* result; 67 std::uint8_t* scratch; 68 69 // Specialized parameters. 70 71 LeftStream left_stream; 72 RightStream right_stream; 73 FusedKernel fused_kernel; 74 }; 75 76 template <typename InType, int lanes_count, int pack_size, int leftovers, 77 typename StreamParams> 78 class Stream { 79 public: 80 static void Pack(const InType* in, const StreamParams& params, InType* out); 81 82 static int UnpackedAdvance(const StreamParams& params); 83 84 static int PackedAdvance(const StreamParams& params); 85 86 static int UnpackedStride(const StreamParams& params); 87 88 static int PackedStride(const StreamParams& params); 89 }; 90 91 template <typename InType, typename StreamType> 92 class StreamUtil { 93 public: 94 static const InType* Offset(const StreamType& params, const InType* source, 95 int offset_stride, int offset_advance); 96 97 static int Scratch(const StreamType& params, int lanes); 98 }; 99 100 template <typename InType, typename OutType, typename Kernel, 101 typename OutputStream, int kernel_m, int kernel_n, int pack_size> 102 class MulKernel { 103 public: 104 static void Multiply(const InType* lhs, const InType* rhs, 105 const FusedKernelParams<Kernel, OutputStream>& params, 106 OutType* result); 107 }; 108 109 template <typename InType_, typename OutType_, typename Kernel_> 110 struct Transform1DParams { 111 typedef InType_ InType; 112 typedef OutType_ OutType; 113 typedef Kernel_ Kernel; 114 115 const InType* input; 116 OutType* output; 117 std::uint8_t* scratch; 118 119 Kernel kernel; 120 }; 121 122 template <typename InType, typename OutType, typename Kernel, int kernel_size, 123 int leftovers> 124 class Transform1DKernel { 125 public: 126 static void Transform(const InType* input, const Kernel& params, 127 OutType* output); 128 }; 129 130 template <typename InType, typename OutType, typename Transform> 131 class Transform1DUtil { 132 public: 133 static int EstimateComputeCost(const Transform& params); 134 135 static const InType* OffsetInput(const Transform& params, const InType* input, 136 int offset); 137 138 static OutType* OffsetOutput(const Transform& params, OutType* output, 139 int offset); 140 }; 141 142 } // namespace meta 143 } // namespace gemmlowp 144 145 #endif // GEMMLOWP_META_BASE_H_ 146