/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is a set of different implementations of the basic matrix-by-matrix
// multiply function, commonly known as GEMM after the BLAS naming convention.
// Having a standard interface lets us swap implementations in and out on
// different platforms, to make sure we're using the optimal version. They are
// implemented as C++ template functors, so they're easy to plug into all of
// the different kernels that use them.
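//
// As an illustrative sketch (not part of the original interface docs), a call
// site for one of these functors might look like the following, assuming
// dense row-major matrices so that lda == k, ldb == n, and ldc == n:
//
//   FastGemmFunctor<float, float, float> gemm;
//   gemm(ctx, m, n, k, a_data, k, b_data, n, c_data, n);
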
#if !defined(EIGEN_USE_THREADS)
#error "EIGEN_USE_THREADS must be enabled by all .cc files including this."
#endif  // EIGEN_USE_THREADS

#include <string.h>
#include <map>
#include <vector>

#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"

// Apple provides an optimized BLAS library (Accelerate) that outperforms
// Eigen on their devices, so use it when possible.
#if defined(__APPLE__) && defined(USE_GEMM_FOR_CONV)
#include <Accelerate/Accelerate.h>
#define USE_CBLAS_GEMM
#endif  // __APPLE__ && USE_GEMM_FOR_CONV

// Older Raspberry Pi systems don't have NEON SIMD acceleration, so Eigen
// falls back to scalar code there; OpenBLAS provides much faster kernels, so
// prefer it when it's available.
#if defined(RASPBERRY_PI) && defined(USE_GEMM_FOR_CONV) && defined(USE_OPENBLAS)
#include <cblas.h>
#define USE_CBLAS_GEMM
#endif  // RASPBERRY_PI && USE_GEMM_FOR_CONV && USE_OPENBLAS

// A readable but slow implementation of matrix multiplication, useful for
// debugging and for understanding the algorithm. To enable it, substitute it
// for FastGemmFunctor in the Im2ColConvFunctor template instantiation inside
// the op registration. Assumes row-major ordering of the values in memory.
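//
// Concretely, for row-major storage this computes
//   c[i * ldc + j] = sum over l in [0, k) of a[i * lda + l] * b[l * ldb + j]
// for every i in [0, m) and j in [0, n).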
template <class T1, class T2, class T3>
class ReferenceGemmFunctor {
 public:
  void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n,
                  size_t k, const T1* a, size_t lda, const T2* b, size_t ldb,
                  T3* c, size_t ldc) {
    // Strides for row-major storage: stepping one row down `a` or `c` moves
    // by lda or ldc elements, stepping one row down `b` moves by ldb, and
    // stepping along a row always moves by one element.
    const size_t a_i_stride = lda;
    const size_t a_l_stride = 1;
    const size_t b_j_stride = 1;
    const size_t b_l_stride = ldb;
    const size_t c_i_stride = ldc;
    const size_t c_j_stride = 1;
    size_t i, j, l;
    for (j = 0; j < n; j++) {
      for (i = 0; i < m; i++) {
        // Accumulate the dot product of row i of `a` and column j of `b`.
        T3 total(0);
        for (l = 0; l < k; l++) {
          const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
          const T1 a_value = a[a_index];
          const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
          const T2 b_value = b[b_index];
          total += (a_value * b_value);
        }
        const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
        c[c_index] = total;
      }
    }
  }
};

// Uses the optimized Eigen Tensor library to implement the matrix
// multiplication required by the Im2ColConvFunctor class. We supply the two
// input types and one output type separately so that the accumulator can
// potentially be higher-precision than the inputs, even though we don't
// currently take advantage of this.
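//
// The contraction below pairs dimension 1 of a (its k columns) with dimension
// 0 of b (its k rows); for rank-2 tensors this is exactly the standard matrix
// product c = a * b.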
template <class T1, class T2, class T3>
class FastGemmFunctor {
 public:
  void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n,
                  size_t k, const T1* a, size_t lda, const T2* b, size_t ldb,
                  T3* c, size_t ldc) {
    typename tensorflow::TTypes<const T1>::Matrix a_matrix(a, m, k);
    typename tensorflow::TTypes<const T2>::Matrix b_matrix(b, k, n);
    typename tensorflow::TTypes<T3>::Matrix c_matrix(c, m, n);

    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = 1;
    dim_pair[0].second = 0;
    // Run the contraction on the context's threadpool device so the work is
    // parallelized across cores.
    c_matrix.device(ctx->eigen_device<Eigen::ThreadPoolDevice>()) =
        a_matrix.contract(b_matrix, dim_pair);
  }
};

// If we have a fast CBLAS library, use its implementation through a wrapper.
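//
// cblas_sgemm computes C = alpha*A*B + beta*C; passing alpha = 1.0f and
// beta = 0.0f below yields a plain matrix product, overwriting any existing
// contents of c.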
#if defined(USE_CBLAS_GEMM)
template <>
class FastGemmFunctor<float, float, float> {
 public:
  void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n,
                  size_t k, const float* a, size_t lda, const float* b,
                  size_t ldb, float* c, size_t ldc) {
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a,
                lda, b, ldb, 0.0f, c, ldc);
  }
};
#endif  // USE_CBLAS_GEMM
    119