Home | History | Annotate | Download | only in internal
      1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // kernel.h: general definitions for kernels.
     16 
     17 #ifndef GEMMLOWP_INTERNAL_KERNEL_H_
     18 #define GEMMLOWP_INTERNAL_KERNEL_H_
     19 
     20 #include "../public/bit_depth.h"
     21 #include "common.h"
     22 
     23 namespace gemmlowp {
     24 
     25 // Explanation of general gemmlowp terminology
     26 // ===========================================
     27 //
     28 // We use the following abbreviations:
     29 // LHS = "left-hand side"
     30 // RHS = "right-hand side"
     31 // Sometimes when referring to either LHS or RHS, we just say a "Side".
     32 //
     33 // In a matrix product of a MxK matrix times a KxN matrix,
     34 // we call K the 'depth'. Note that M is the number of rows
     35 // of the result (and of the LHS), and N is the number of columns
     36 // of the result (and of the RHS).
     37 //
     38 // In each of the LHS and RHS matrices, we call 'width' the
     39 // other dimension, besides the depth. So in the LHS, 'width'
     40 // is the number of rows, while in the RHS, 'width' is the number
     41 // of columns.
     42 //
// So in the LHS MxK matrix, the depth is K and the width is M.
// And in the RHS KxN matrix, the depth is K and the width is N.
     45 //
     46 // This is illustrated in this picture:
     47 //
     48 //                             RHS width
     49 //                        <----------------->
     50 //                        +-----------------+ ^
     51 //                        |       RHS       | | Depth
     52 //                        +-----------------+ v
     53 //                 ^ +--+ +-----------------+
     54 //                 | |L | |                 |
     55 //       LHS width | |H | |      Result     |
     56 //                 | |S | |                 |
     57 //                 v +--+ +-----------------+
     58 //                   <-->
     59 //                   Depth
     60 
     61 // Explanation of gemmlowp kernel formats and "cells"
     62 // ==================================================
     63 //
     64 // Kernels operate on small LHS and RHS blocks that fit in registers.
     65 // These blocks are stored contiguously in memory, but not always
     66 // in a traditional column-major or row-major order; instead,
     67 // they consist of a number of sub-blocks, which we call "cells",
     68 // that are stored in column-major or row-major order. However,
     69 // what really matters to us is not so much rows vs columns, but
     70 // rather width vs depth. So we refer to "width-major" and "depth-major"
     71 // storage orders. In the LHS, width-major means row-major,
     72 // while in the RHS, width-major means column-major.
     73 // There is also a third possibility, "diagonal order",
     74 // which is unused at the moment.
     75 //
     76 // We aim to treat both sides, LHS and RHS, on an equal footing,
     77 // so we call them both 'sides'. A KernelFormat thus is just a pair
     78 // of KernelSideFormat's, one for LHS and one for RHS; each KernelSideFormat
     79 // contains a CellFormat and a number of cells; cells are only ever
     80 // stacked in the width dimension, which means stacked vertically in the
// LHS and stacked horizontally in the RHS.
     82 //
     83 // Example
     84 // =======
     85 //
     86 // Let's work out the data layout expected by a kernel having the
     87 // following format (the struct names here are defined below in this file):
     88 //
     89 // KernelFormat<
     90 //   KernelSideFormat<CellFormat<3, 4>, 3>,
     91 //   KernelSideFormat<CellFormat<5, 4>, 2>
     92 // >
     93 //
     94 // The LHS format, KernelSideFormat<CellFormat<3, 4>, 3>, means:
     95 // 3 cells, each cell having dimensions (width=3, depth=4), laid out in
     96 // DepthMajor order (the default value, see CellFormat). In the LHS,
     97 // DepthMajor means column-major, so the LHS cells are of size 3x4 in
     98 // column-major order, so the LHS layout is:
     99 //
    100 // 0  3  6  9
    101 // 1  4  7  10
    102 // 2  5  8  11
    103 // 12 15 18 21
    104 // 13 16 19 22
    105 // 14 17 20 23
    106 // 24 27 30 33
    107 // 25 28 31 34
    108 // 26 29 32 35
    109 //
    110 // The RHS format, KernelSideFormat<CellFormat<5, 4>, 2>, means:
    111 // 2 cells each having dimensions (width=5, depth=4), laid out in
    112 // DepthMajor order (the default value, see CellFormat). In the RHS,
    113 // DepthMajor means row-major, so the RHS cells are of size 4x5 in
    114 // row-major order, so the RHS layout is:
    115 //
    116 // 0  1  2  3  4  20 21 22 23 24
    117 // 5  6  7  8  9  25 26 27 28 29
    118 // 10 11 12 13 14 30 31 32 33 34
    119 // 15 16 17 18 19 35 36 37 38 39
    120 
    121 // CellOrder enumerates the possible storage orders (=layouts) for
    122 // a cell (see explanation above).
    123 enum class CellOrder { DepthMajor, WidthMajor, Diagonal };
    124 
    125 // CellFormat describes how data is laid
    126 // out in a cell. That is, a CellOrder together with actual dimensions.
    127 template <int tWidth, int tDepth, CellOrder tOrder = CellOrder::DepthMajor>
    128 struct CellFormat {
    129   static const int kWidth = tWidth;
    130   static const int kDepth = tDepth;
    131   static const CellOrder kOrder = tOrder;
    132 
    133   static const int kSize = kWidth * kDepth;
    134 };
    135 
    136 // KernelSideFormat describes how data is laid out in a kernel side
    137 // (i.e. LHS or RHS). That is, a CellFormat together with a number of
    138 // cells. These cells are always stacked in the Width dimension.
    139 // For example, in the LHS case, the Width dimension is the rows dimension,
    140 // se we're saying that in the LHS, cells are stacked vertically.
    141 // We never stack cells in the Depth dimension.
    142 template <typename tCellFormat, int tCells>
    143 struct KernelSideFormat {
    144   typedef tCellFormat Cell;
    145   static const int kCells = tCells;
    146   static const int kWidth = kCells * Cell::kWidth;
    147   static const int kDepth = Cell::kDepth;
    148   typedef std::uint8_t Scalar;
    149 };
    150 
    151 template <typename tCellFormat, int tCells>
    152 struct KernelSideFormatInt8 : KernelSideFormat<tCellFormat, tCells> {
    153   typedef std::int8_t Scalar;
    154 };
    155 
    156 // KernelFormat describes fully the input data layout that a kernel expects.
    157 // It consists of two KernelSideFormat's, one for LHS and one for RHS.
    158 template <typename tLhs, typename tRhs>
    159 struct KernelFormat {
    160   typedef tLhs Lhs;
    161   typedef tRhs Rhs;
    162 
    163   static_assert(Lhs::Cell::kDepth == Rhs::Cell::kDepth, "");
    164   static const int kDepth = Lhs::Cell::kDepth;
    165   static const int kRows = Lhs::Cell::kWidth * Lhs::kCells;
    166   static const int kCols = Rhs::Cell::kWidth * Rhs::kCells;
    167 };
    168 
    169 inline const char* CellOrderName(CellOrder o) {
    170   switch (o) {
    171     case CellOrder::DepthMajor:
    172       return "DepthMajor";
    173     case CellOrder::WidthMajor:
    174       return "WidthMajor";
    175     case CellOrder::Diagonal:
    176       return "Diagonal";
    177     default:
    178       assert(false);
    179       return nullptr;
    180   }
    181 }
    182 
    183 // Returns the offset into a cell, at which a given coefficient is stored.
    184 template <typename CellFormat>
    185 inline int OffsetIntoCell(int w, int d) {
    186   const int size = CellFormat::kWidth;
    187   switch (CellFormat::kOrder) {
    188     case CellOrder::DepthMajor:
    189       return w + d * CellFormat::kWidth;
    190     case CellOrder::WidthMajor:
    191       return d + w * CellFormat::kDepth;
    192     case CellOrder::Diagonal:
    193       assert(CellFormat::kWidth == CellFormat::kDepth);
    194       return ((size + w - d) * size + d) % (size * size);
    195     default:
    196       assert(false);
    197       return 0;
    198   }
    199 }
    200 
    201 // KernelBase is the virtual base class below all kernels.
    202 // The idea is that we don't need to templatize all our code on the exact
    203 // kernel type; we only need to templatize on kernel format. Kernels
    204 // sharing the same format can thus share the same packing/unpacking code.
    205 struct KernelBase {
    206   virtual const char* Name() const = 0;
    207 
    208   // This is the kernel implementation. We use the word 'run' consistently
    209   // throughout gemmlowp to mean an inner loop, the implementation of which
    210   // is to be provided by a separate optimized function.
    211   virtual void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
    212                    std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
    213                    const std::uint8_t* rhs_ptr, std::size_t start_depth,
    214                    std::size_t run_depth) const = 0;
    215 
    216   virtual ~KernelBase() {}
    217 };
    218 
    219 template <typename KernelScalarType>
    220 struct ZeroPointInputValue {};
    221 
    222 template <>
    223 struct ZeroPointInputValue<std::uint8_t> {
    224   static constexpr std::uint8_t kValue = 0;
    225 };
    226 
    227 template <>
    228 struct ZeroPointInputValue<std::int8_t> {
    229   static constexpr std::uint8_t kValue = 128;
    230 };
    231 
    232 }  // namespace gemmlowp
    233 
    234 #endif  // GEMMLOWP_INTERNAL_KERNEL_H_
    235