Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_
     17 #define TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_
     18 
     19 #include <stdlib.h>
     20 
     21 #include "third_party/eigen3/Eigen/Core"
     22 #include "tensorflow/core/platform/types.h"
     23 
     24 // This is an unoptimized but debuggable implementation of the GEMM matrix
     25 // multiply function, used to compare to faster but more opaque versions, or
     26 // for bit depths or argument combinations that aren't supported by optimized
     27 // code.
     28 // It assumes the row-major convention used by TensorFlow, and implements
     29 // C = A * B, like the standard BLAS GEMM interface. If the transpose flags are
     30 // true, then the relevant matrix is treated as stored in column-major order.
     31 
     32 namespace tensorflow {
     33 template <class T1, class T2, class T3>
     34 void ReferenceGemm(bool transpose_a, bool transpose_b, bool transpose_c,
     35                    size_t m, size_t n, size_t k, const T1* a, int32 offset_a,
     36                    size_t lda, const T2* b, int32 offset_b, size_t ldb, T3* c,
     37                    int32 shift_c, int32 offset_c, int32 mult_c, size_t ldc) {
     38   int a_i_stride;
     39   int a_l_stride;
     40   if (transpose_a) {
     41     a_i_stride = 1;
     42     a_l_stride = lda;
     43   } else {
     44     a_i_stride = lda;
     45     a_l_stride = 1;
     46   }
     47   int b_j_stride;
     48   int b_l_stride;
     49   if (transpose_b) {
     50     b_j_stride = ldb;
     51     b_l_stride = 1;
     52   } else {
     53     b_j_stride = 1;
     54     b_l_stride = ldb;
     55   }
     56   int c_i_stride;
     57   int c_j_stride;
     58   if (transpose_c) {
     59     c_i_stride = 1;
     60     c_j_stride = ldc;
     61   } else {
     62     c_i_stride = ldc;
     63     c_j_stride = 1;
     64   }
     65 
     66   const int32 highest = static_cast<int32>(Eigen::NumTraits<T3>::highest());
     67   const int32 lowest = static_cast<int32>(Eigen::NumTraits<T3>::lowest());
     68   const int32 rounding = (shift_c < 1) ? 0 : (1 << (shift_c - 1));
     69 
     70   int i, j, l;
     71   for (j = 0; j < n; j++) {
     72     for (i = 0; i < m; i++) {
     73       int32 total = 0;
     74       for (l = 0; l < k; l++) {
     75         const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
     76         const int32 a_value = static_cast<int32>(a[a_index]) - offset_a;
     77         const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
     78         const int32 b_value = static_cast<int32>(b[b_index]) - offset_b;
     79         total += (a_value * b_value);
     80       }
     81       const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
     82       int32_t output = ((((total + offset_c) * mult_c) + rounding) >> shift_c);
     83       if (output > highest) {
     84         output = highest;
     85       }
     86       if (output < lowest) {
     87         output = lowest;
     88       }
     89       c[c_index] = static_cast<T3>(output);
     90     }
     91   }
     92 }
     93 }  // namespace tensorflow
     94 
     95 #endif  // TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_
     96