Home | History | Annotate | Download | only in internal
      1 // Copyright 2015 Google Inc. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // kernel_reference.h: a reference kernel for CPU architectures where we don't
     16 // have optimized kernels yet. Also useful for testing, as it's templatized
     17 // to have any arbitrary format, allowing tests to cover all sorts of corner
     18 // cases.
     19 
     20 #ifndef GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_
     21 #define GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_
     22 
     23 #include "kernel.h"
     24 
     25 #include <cstdio>
     26 #include <cstring>
     27 
     28 namespace gemmlowp {
     29 
     30 // This kernel is templatized in an arbitrary Format template parameter,
     31 // allowing it to have any arbitrary format.
     32 template <typename tFormat>
     33 struct ReferenceKernel : KernelBase {
     34   typedef tFormat Format;
     35 
     36   const char* Name() const override {
     37     static char buf[256];
     38     snprintf(buf, sizeof(buf),
     39              "reference(Lhs: %d cells %dx%d %s, Rhs: %d cells %dx%d %s)",
     40              Format::Lhs::kCells, Format::Lhs::Cell::kWidth,
     41              Format::Lhs::Cell::kDepth,
     42              CellOrderName(Format::Lhs::Cell::kOrder), Format::Rhs::kCells,
     43              Format::Rhs::Cell::kDepth, Format::Rhs::Cell::kWidth,
     44              CellOrderName(Format::Rhs::Cell::kOrder));
     45     return buf;
     46   }
     47 
     48   void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
     49            std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
     50            const std::uint8_t* rhs_ptr, std::size_t start_depth,
     51            std::size_t run_depth) const override {
     52     std::int32_t accumulator[Format::kRows * Format::kCols];
     53     memset(accumulator, 0, sizeof(accumulator));
     54 
     55     const int run_depth_cells = static_cast<int>(run_depth / Format::kDepth);
     56 
     57     // The outer loop is over the depth dimension.
     58     for (int dc = 0; dc < run_depth_cells; dc++) {
     59       // The next two loops are over cells of the Lhs (stacked vertically),
     60       // and over cells of the Rhs (stacked horizontally).
     61       for (int rc = 0; rc < Format::Lhs::kCells; rc++) {
     62         const std::uint8_t* lhs_cell_ptr = lhs_ptr +
     63                                            (dc * Format::Lhs::kCells + rc) *
     64                                                Format::Lhs::Cell::kWidth *
     65                                                Format::kDepth;
     66         for (int cc = 0; cc < Format::Rhs::kCells; cc++) {
     67           const std::uint8_t* rhs_cell_ptr = rhs_ptr +
     68                                              (dc * Format::Rhs::kCells + cc) *
     69                                                  Format::Rhs::Cell::kWidth *
     70                                                  Format::kDepth;
     71 
     72           // Now we are inside one cell of the Lhs and inside one cell
     73           // of the Rhs, so the remaining inner loops are just
     74           // traditional three loops of matrix multiplication.
     75           for (int di = 0; di < Format::kDepth; di++) {
     76             for (int ri = 0; ri < Format::Lhs::Cell::kWidth; ri++) {
     77               for (int ci = 0; ci < Format::Rhs::Cell::kWidth; ci++) {
     78                 const std::uint8_t* lhs_coeff_ptr =
     79                     lhs_cell_ptr +
     80                     OffsetIntoCell<typename Format::Lhs::Cell>(ri, di);
     81                 const std::uint8_t* rhs_coeff_ptr =
     82                     rhs_cell_ptr +
     83                     OffsetIntoCell<typename Format::Rhs::Cell>(ci, di);
     84                 std::int32_t* accumulator_coeff_ptr =
     85                     accumulator + (ri + rc * Format::Lhs::Cell::kWidth) +
     86                     (ci + cc * Format::Rhs::Cell::kWidth) * Format::kRows;
     87                 *accumulator_coeff_ptr +=
     88                     std::int32_t(*lhs_coeff_ptr) * std::int32_t(*rhs_coeff_ptr);
     89               }
     90             }
     91           }
     92         }
     93       }
     94     }
     95 
     96     if (start_depth == 0) {
     97       // start_depth == 0 means we haven't accumulated anything yet, so we need
     98       // to overwrite the accumulator, as it hasn't been initialized to zero.
     99       for (int r = 0; r < Format::kRows; r++) {
    100         for (int c = 0; c < Format::kCols; c++) {
    101           dst_ptr[r * dst_row_stride + c * dst_col_stride] =
    102               accumulator[r + c * Format::kRows];
    103         }
    104       }
    105     } else {
    106       // We have already accumulated stuff, so we need to continue accumulating
    107       // instead of just overwriting.
    108       for (int r = 0; r < Format::kRows; r++) {
    109         for (int c = 0; c < Format::kCols; c++) {
    110           dst_ptr[r * dst_row_stride + c * dst_col_stride] +=
    111               accumulator[r + c * Format::kRows];
    112         }
    113       }
    114     }
    115   }
    116 };
    117 
    118 }  // namespace gemmlowp
    119 
    120 #endif  // GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_
    121