Home | History | Annotate | Download | only in meta
      1 // Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #ifndef GEMMLOWP_META_BASE_H_
     16 #define GEMMLOWP_META_BASE_H_
     17 
     18 #include <cassert>
     19 #include <cstdint>
     20 
     21 #include "../internal/common.h"
     22 
     23 namespace gemmlowp {
     24 namespace meta {
     25 
     26 template <int align>
     27 inline int AlignTo(int value) {
     28   return ((value + align - 1) / align) * align;
     29 }
     30 
     31 inline int AlignTo(int align, int value) {
     32   return ((value + align - 1) / align) * align;
     33 }
     34 
     35 template <typename Kernel_, typename OutputStream_>
     36 struct FusedKernelParams {
     37  public:
     38   typedef Kernel_ Kernel;
     39   typedef OutputStream_ OutputStream;
     40 
     41   Kernel kernel;
     42   OutputStream output_stream;
     43 };
     44 
     45 template <typename InType_, typename OutType_, typename LeftStream_,
     46           typename RightStream_, typename Kernel_, typename OutputStream_>
     47 struct GemmParams {
     48  public:
     49   typedef InType_ InType;
     50   typedef OutType_ OutType;
     51   typedef LeftStream_ LeftStream;
     52   typedef RightStream_ RightStream;
     53   typedef Kernel_ Kernel;
     54   typedef OutputStream_ OutputStream;
     55 
     56   typedef FusedKernelParams<Kernel, OutputStream> FusedKernel;
     57 
     58   // Common parameters.
     59 
     60   int m;
     61   int n;
     62   int k;
     63 
     64   const InType* lhs;
     65   const InType* rhs;
     66   OutType* result;
     67   std::uint8_t* scratch;
     68 
     69   // Specialized parameters.
     70 
     71   LeftStream left_stream;
     72   RightStream right_stream;
     73   FusedKernel fused_kernel;
     74 };
     75 
     76 template <typename InType, int lanes_count, int pack_size, int leftovers,
     77           typename StreamParams>
     78 class Stream {
     79  public:
     80   static void Pack(const InType* in, const StreamParams& params, InType* out);
     81 
     82   static int UnpackedAdvance(const StreamParams& params);
     83 
     84   static int PackedAdvance(const StreamParams& params);
     85 
     86   static int UnpackedStride(const StreamParams& params);
     87 
     88   static int PackedStride(const StreamParams& params);
     89 };
     90 
     91 template <typename InType, typename StreamType>
     92 class StreamUtil {
     93  public:
     94   static const InType* Offset(const StreamType& params, const InType* source,
     95                               int offset_stride, int offset_advance);
     96 
     97   static int Scratch(const StreamType& params, int lanes);
     98 };
     99 
    100 template <typename InType, typename OutType, typename Kernel,
    101           typename OutputStream, int kernel_m, int kernel_n, int pack_size>
    102 class MulKernel {
    103  public:
    104   static void Multiply(const InType* lhs, const InType* rhs,
    105                        const FusedKernelParams<Kernel, OutputStream>& params,
    106                        OutType* result);
    107 };
    108 
    109 template <typename InType_, typename OutType_, typename Kernel_>
    110 struct Transform1DParams {
    111   typedef InType_ InType;
    112   typedef OutType_ OutType;
    113   typedef Kernel_ Kernel;
    114 
    115   const InType* input;
    116   OutType* output;
    117   std::uint8_t* scratch;
    118 
    119   Kernel kernel;
    120 };
    121 
    122 template <typename InType, typename OutType, typename Kernel, int kernel_size,
    123           int leftovers>
    124 class Transform1DKernel {
    125  public:
    126   static void Transform(const InType* input, const Kernel& params,
    127                         OutType* output);
    128 };
    129 
    130 template <typename InType, typename OutType, typename Transform>
    131 class Transform1DUtil {
    132  public:
    133   static int EstimateComputeCost(const Transform& params);
    134 
    135   static const InType* OffsetInput(const Transform& params, const InType* input,
    136                                    int offset);
    137 
    138   static OutType* OffsetOutput(const Transform& params, OutType* output,
    139                                int offset);
    140 };
    141 
    142 }  // namespace meta
    143 }  // namespace gemmlowp
    144 
    145 #endif  // GEMMLOWP_META_BASE_H_
    146