1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // kernel.h: general definitions for kernels. 16 17 #ifndef GEMMLOWP_INTERNAL_KERNEL_H_ 18 #define GEMMLOWP_INTERNAL_KERNEL_H_ 19 20 #include "../public/bit_depth.h" 21 #include "common.h" 22 23 namespace gemmlowp { 24 25 // Explanation of general gemmlowp terminology 26 // =========================================== 27 // 28 // We use the following abbreviations: 29 // LHS = "left-hand side" 30 // RHS = "right-hand side" 31 // Sometimes when referring to either LHS or RHS, we just say a "Side". 32 // 33 // In a matrix product of a MxK matrix times a KxN matrix, 34 // we call K the 'depth'. Note that M is the number of rows 35 // of the result (and of the LHS), and N is the number of columns 36 // of the result (and of the RHS). 37 // 38 // In each of the LHS and RHS matrices, we call 'width' the 39 // other dimension, besides the depth. So in the LHS, 'width' 40 // is the number of rows, while in the RHS, 'width' is the number 41 // of columns. 42 // 43 // So in the LHS MxK matrix, the depth is K and the width in M. 44 // And in the RHS KxN matrix, the depth is K and the width in N. 45 // 46 // This is illustrated in this picture: 47 // 48 // RHS width 49 // <-----------------> 50 // +-----------------+ ^ 51 // | RHS | | Depth 52 // +-----------------+ v 53 // ^ +--+ +-----------------+ 54 // | |L | | | 55 // LHS width | |H | | Result | 56 // | |S | | | 57 // v +--+ +-----------------+ 58 // <--> 59 // Depth 60 61 // Explanation of gemmlowp kernel formats and "cells" 62 // ================================================== 63 // 64 // Kernels operate on small LHS and RHS blocks that fit in registers. 65 // These blocks are stored contiguously in memory, but not always 66 // in a traditional column-major or row-major order; instead, 67 // they consist of a number of sub-blocks, which we call "cells", 68 // that are stored in column-major or row-major order. However, 69 // what really matters to us is not so much rows vs columns, but 70 // rather width vs depth. So we refer to "width-major" and "depth-major" 71 // storage orders. In the LHS, width-major means row-major, 72 // while in the RHS, width-major means column-major. 73 // There is also a third possibility, "diagonal order", 74 // which is unused at the moment. 75 // 76 // We aim to treat both sides, LHS and RHS, on an equal footing, 77 // so we call them both 'sides'. A KernelFormat thus is just a pair 78 // of KernelSideFormat's, one for LHS and one for RHS; each KernelSideFormat 79 // contains a CellFormat and a number of cells; cells are only ever 80 // stacked in the width dimension, which means stacked vertically in the 81 // LHS and stacked horizondally in the RHS. 82 // 83 // Example 84 // ======= 85 // 86 // Let's work out the data layout expected by a kernel having the 87 // following format (the struct names here are defined below in this file): 88 // 89 // KernelFormat< 90 // KernelSideFormat<CellFormat<3, 4>, 3>, 91 // KernelSideFormat<CellFormat<5, 4>, 2> 92 // > 93 // 94 // The LHS format, KernelSideFormat<CellFormat<3, 4>, 3>, means: 95 // 3 cells, each cell having dimensions (width=3, depth=4), laid out in 96 // DepthMajor order (the default value, see CellFormat). In the LHS, 97 // DepthMajor means column-major, so the LHS cells are of size 3x4 in 98 // column-major order, so the LHS layout is: 99 // 100 // 0 3 6 9 101 // 1 4 7 10 102 // 2 5 8 11 103 // 12 15 18 21 104 // 13 16 19 22 105 // 14 17 20 23 106 // 24 27 30 33 107 // 25 28 31 34 108 // 26 29 32 35 109 // 110 // The RHS format, KernelSideFormat<CellFormat<5, 4>, 2>, means: 111 // 2 cells each having dimensions (width=5, depth=4), laid out in 112 // DepthMajor order (the default value, see CellFormat). In the RHS, 113 // DepthMajor means row-major, so the RHS cells are of size 4x5 in 114 // row-major order, so the RHS layout is: 115 // 116 // 0 1 2 3 4 20 21 22 23 24 117 // 5 6 7 8 9 25 26 27 28 29 118 // 10 11 12 13 14 30 31 32 33 34 119 // 15 16 17 18 19 35 36 37 38 39 120 121 // CellOrder enumerates the possible storage orders (=layouts) for 122 // a cell (see explanation above). 123 enum class CellOrder { DepthMajor, WidthMajor, Diagonal }; 124 125 // CellFormat describes how data is laid 126 // out in a cell. That is, a CellOrder together with actual dimensions. 127 template <int tWidth, int tDepth, CellOrder tOrder = CellOrder::DepthMajor> 128 struct CellFormat { 129 static const int kWidth = tWidth; 130 static const int kDepth = tDepth; 131 static const CellOrder kOrder = tOrder; 132 133 static const int kSize = kWidth * kDepth; 134 }; 135 136 // KernelSideFormat describes how data is laid out in a kernel side 137 // (i.e. LHS or RHS). That is, a CellFormat together with a number of 138 // cells. These cells are always stacked in the Width dimension. 139 // For example, in the LHS case, the Width dimension is the rows dimension, 140 // se we're saying that in the LHS, cells are stacked vertically. 141 // We never stack cells in the Depth dimension. 142 template <typename tCellFormat, int tCells> 143 struct KernelSideFormat { 144 typedef tCellFormat Cell; 145 static const int kCells = tCells; 146 static const int kWidth = kCells * Cell::kWidth; 147 static const int kDepth = Cell::kDepth; 148 typedef std::uint8_t Scalar; 149 }; 150 151 template <typename tCellFormat, int tCells> 152 struct KernelSideFormatInt8 : KernelSideFormat<tCellFormat, tCells> { 153 typedef std::int8_t Scalar; 154 }; 155 156 // KernelFormat describes fully the input data layout that a kernel expects. 157 // It consists of two KernelSideFormat's, one for LHS and one for RHS. 158 template <typename tLhs, typename tRhs> 159 struct KernelFormat { 160 typedef tLhs Lhs; 161 typedef tRhs Rhs; 162 163 static_assert(Lhs::Cell::kDepth == Rhs::Cell::kDepth, ""); 164 static const int kDepth = Lhs::Cell::kDepth; 165 static const int kRows = Lhs::Cell::kWidth * Lhs::kCells; 166 static const int kCols = Rhs::Cell::kWidth * Rhs::kCells; 167 }; 168 169 inline const char* CellOrderName(CellOrder o) { 170 switch (o) { 171 case CellOrder::DepthMajor: 172 return "DepthMajor"; 173 case CellOrder::WidthMajor: 174 return "WidthMajor"; 175 case CellOrder::Diagonal: 176 return "Diagonal"; 177 default: 178 assert(false); 179 return nullptr; 180 } 181 } 182 183 // Returns the offset into a cell, at which a given coefficient is stored. 184 template <typename CellFormat> 185 inline int OffsetIntoCell(int w, int d) { 186 const int size = CellFormat::kWidth; 187 switch (CellFormat::kOrder) { 188 case CellOrder::DepthMajor: 189 return w + d * CellFormat::kWidth; 190 case CellOrder::WidthMajor: 191 return d + w * CellFormat::kDepth; 192 case CellOrder::Diagonal: 193 assert(CellFormat::kWidth == CellFormat::kDepth); 194 return ((size + w - d) * size + d) % (size * size); 195 default: 196 assert(false); 197 return 0; 198 } 199 } 200 201 // KernelBase is the virtual base class below all kernels. 202 // The idea is that we don't need to templatize all our code on the exact 203 // kernel type; we only need to templatize on kernel format. Kernels 204 // sharing the same format can thus share the same packing/unpacking code. 205 struct KernelBase { 206 virtual const char* Name() const = 0; 207 208 // This is the kernel implementation. We use the word 'run' consistently 209 // throughout gemmlowp to mean an inner loop, the implementation of which 210 // is to be provided by a separate optimized function. 211 virtual void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, 212 std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, 213 const std::uint8_t* rhs_ptr, std::size_t start_depth, 214 std::size_t run_depth) const = 0; 215 216 virtual ~KernelBase() {} 217 }; 218 219 template <typename KernelScalarType> 220 struct ZeroPointInputValue {}; 221 222 template <> 223 struct ZeroPointInputValue<std::uint8_t> { 224 static constexpr std::uint8_t kValue = 0; 225 }; 226 227 template <> 228 struct ZeroPointInputValue<std::int8_t> { 229 static constexpr std::uint8_t kValue = 128; 230 }; 231 232 } // namespace gemmlowp 233 234 #endif // GEMMLOWP_INTERNAL_KERNEL_H_ 235