1 // Copyright 2015 Google Inc. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // kernel_reference.h: a reference kernel for CPU architectures where we don't 16 // have optimized kernels yet. Also useful for testing, as it's templatized 17 // to have any arbitrary format, allowing tests to cover all sorts of corner 18 // cases. 19 20 #ifndef GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ 21 #define GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ 22 23 #include "kernel.h" 24 25 #include <cstdio> 26 #include <cstring> 27 28 namespace gemmlowp { 29 30 // This kernel is templatized in an arbitrary Format template parameter, 31 // allowing it to have any arbitrary format. 32 template <typename tFormat> 33 struct ReferenceKernel : KernelBase { 34 typedef tFormat Format; 35 36 const char* Name() const override { 37 static char buf[256]; 38 snprintf(buf, sizeof(buf), 39 "reference(Lhs: %d cells %dx%d %s, Rhs: %d cells %dx%d %s)", 40 Format::Lhs::kCells, Format::Lhs::Cell::kWidth, 41 Format::Lhs::Cell::kDepth, 42 CellOrderName(Format::Lhs::Cell::kOrder), Format::Rhs::kCells, 43 Format::Rhs::Cell::kDepth, Format::Rhs::Cell::kWidth, 44 CellOrderName(Format::Rhs::Cell::kOrder)); 45 return buf; 46 } 47 48 void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, 49 std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, 50 const std::uint8_t* rhs_ptr, std::size_t start_depth, 51 std::size_t run_depth) const override { 52 std::int32_t accumulator[Format::kRows * Format::kCols]; 53 memset(accumulator, 0, sizeof(accumulator)); 54 55 const int run_depth_cells = static_cast<int>(run_depth / Format::kDepth); 56 57 // The outer loop is over the depth dimension. 58 for (int dc = 0; dc < run_depth_cells; dc++) { 59 // The next two loops are over cells of the Lhs (stacked vertically), 60 // and over cells of the Rhs (stacked horizontally). 61 for (int rc = 0; rc < Format::Lhs::kCells; rc++) { 62 const std::uint8_t* lhs_cell_ptr = lhs_ptr + 63 (dc * Format::Lhs::kCells + rc) * 64 Format::Lhs::Cell::kWidth * 65 Format::kDepth; 66 for (int cc = 0; cc < Format::Rhs::kCells; cc++) { 67 const std::uint8_t* rhs_cell_ptr = rhs_ptr + 68 (dc * Format::Rhs::kCells + cc) * 69 Format::Rhs::Cell::kWidth * 70 Format::kDepth; 71 72 // Now we are inside one cell of the Lhs and inside one cell 73 // of the Rhs, so the remaining inner loops are just 74 // traditional three loops of matrix multiplication. 75 for (int di = 0; di < Format::kDepth; di++) { 76 for (int ri = 0; ri < Format::Lhs::Cell::kWidth; ri++) { 77 for (int ci = 0; ci < Format::Rhs::Cell::kWidth; ci++) { 78 const std::uint8_t* lhs_coeff_ptr = 79 lhs_cell_ptr + 80 OffsetIntoCell<typename Format::Lhs::Cell>(ri, di); 81 const std::uint8_t* rhs_coeff_ptr = 82 rhs_cell_ptr + 83 OffsetIntoCell<typename Format::Rhs::Cell>(ci, di); 84 std::int32_t* accumulator_coeff_ptr = 85 accumulator + (ri + rc * Format::Lhs::Cell::kWidth) + 86 (ci + cc * Format::Rhs::Cell::kWidth) * Format::kRows; 87 *accumulator_coeff_ptr += 88 std::int32_t(*lhs_coeff_ptr) * std::int32_t(*rhs_coeff_ptr); 89 } 90 } 91 } 92 } 93 } 94 } 95 96 if (start_depth == 0) { 97 // start_depth == 0 means we haven't accumulated anything yet, so we need 98 // to overwrite the accumulator, as it hasn't been initialized to zero. 99 for (int r = 0; r < Format::kRows; r++) { 100 for (int c = 0; c < Format::kCols; c++) { 101 dst_ptr[r * dst_row_stride + c * dst_col_stride] = 102 accumulator[r + c * Format::kRows]; 103 } 104 } 105 } else { 106 // We have already accumulated stuff, so we need to continue accumulating 107 // instead of just overwriting. 108 for (int r = 0; r < Format::kRows; r++) { 109 for (int c = 0; c < Format::kCols; c++) { 110 dst_ptr[r * dst_row_stride + c * dst_col_stride] += 111 accumulator[r + c * Format::kRows]; 112 } 113 } 114 } 115 } 116 }; 117 118 } // namespace gemmlowp 119 120 #endif // GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ 121