/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

// This file uses MKL CBLAS xGEMM to accelerate TF Matrix-Matrix
// Multiplication (MatMul) operations.
// We currently register this kernel only for the MKL-supported data
// types (float, double, complex64, complex128). The macro INTEL_MKL is
// defined by the build system only when MKL is chosen as an option at the
// configure stage; when it is undefined at build time, this file becomes an
// empty compilation unit.

#if defined(INTEL_MKL)

#include "mkl_cblas.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename T, bool USE_CUBLAS>
class MklMatMulOp : public OpKernel {
 public:
  explicit MklMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("In[0] is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("In[1] is not a matrix"));
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;
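    // dim_pair[0] names the contraction axes: e.g., with transpose_a_ and
    // transpose_b_ both false, column axis 1 of a is contracted with row
    // axis 0 of b, the inner dimensions of a standard matmul.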

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
            ", In[1]: ", b.shape().DebugString()));
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 || b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

    const int m = a.dim_size(1 - dim_pair[0].first);
    const int k = a.dim_size(dim_pair[0].first);
    const int n = b.dim_size(1 - dim_pair[0].second);
    bool transpose_a = dim_pair[0].first == 0;
    bool transpose_b = dim_pair[0].second == 1;
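    // For example, with a: [2, 3], b: [3, 4], and no transposes, this yields
    // m = 2, k = 3, n = 4, and the call below passes lda = k = 3,
    // ldb = n = 4, and ldc = n = 4.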

    auto a_ptr = (a.template flat<T>().data());
    auto b_ptr = (b.template flat<T>().data());
    auto c_ptr = (out->template flat<T>().data());

    MklBlasGemm(transpose_a, transpose_b, m, n, k, a_ptr, transpose_a ? m : k,
                b_ptr, transpose_b ? k : n, c_ptr, n);
  }

 private:
  bool transpose_a_;
  bool transpose_b_;

  // --------------------------------------------------------------------------
  //
  // @brief Matrix-Matrix Multiplication with FP32 tensors a, b, and c using
  // the CBLAS interface. c = op(a) * op(b)
  //
  // @param transa  Specifies the form of op(a) used in MatMul. If transa is
  // true, then op(a) = a^T, otherwise op(a) = a
  //
  // @param transb  Specifies the form of op(b) used in MatMul. If transb is
  // true, then op(b) = b^T, otherwise op(b) = b
  //
  // @param m       Specifies the number of rows of the matrix op(a) and of the
  // matrix c. The value of m must be at least zero.
  //
  // @param n       Specifies the number of columns of the matrix op(b) and the
  // number of columns of the matrix c. The value of n must be at least zero.
  //
  // @param k       Specifies the number of columns of the matrix op(a) and the
  // number of rows of the matrix op(b).
  //
  // @param a       Address of matrix a
  //
  // @param lda     Leading dimension of matrix a. This is set at the calling
  // site depending on the transa parameter. Since TF uses row-major layout,
  // the leading dimension is the stride between consecutive rows:
  // lda = max(1, k) when transa is false, otherwise lda = max(1, m).
  //
  // @param b       Address of matrix b
  //
  // @param ldb     Leading dimension of matrix b. This is set at the calling
  // site depending on the transb parameter. Since TF uses row-major layout,
  // the leading dimension is the stride between consecutive rows:
  // ldb = max(1, n) when transb is false, otherwise ldb = max(1, k).
  //
  // @param c       Address of matrix c
  //
  // @param ldc     Leading dimension of matrix c. Since TF uses row-major
  // layout, the leading dimension is the stride between consecutive rows:
  // max(1, n).
  //
  // --------------------------------------------------------------------------
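  // Example: a row-major [2, 3] x [3, 4] product with no transposes
  // corresponds to
  //   MklBlasGemm(false, false, /*m=*/2, /*n=*/4, /*k=*/3,
  //               a, /*lda=*/3, b, /*ldb=*/4, c, /*ldc=*/4);
  // which in turn issues
  //   cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, 2, 4, 3,
  //               1.0f, a, 3, b, 4, 0.0f, c, 4);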
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const float* a, const int lda, const float* b,
                   const int ldb, float* c, const int ldc) {
    // The BLAS GEMM API defines Matrix Multiplication as
    // c = alpha * op(a) * op(b) + beta * c.
    // Since TF MatMul does not have parameters for alpha and beta, we set
    // them to 1.0 and 0.0 respectively.
    const float alpha = 1.0f;
    const float beta = 0.0f;
    cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
                ldb, beta, c, ldc);
  }

  // Matrix-Matrix Multiplication with FP64 tensors. For detailed info about
  // parameters, look at the FP32 function description.
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const double* a, const int lda, const double* b,
                   const int ldb, double* c, const int ldc) {
    const double alpha = 1.0;
    const double beta = 0.0;
    cblas_dgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
                ldb, beta, c, ldc);
  }

  // Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors.
  // For detailed info about parameters, look at the FP32 function description.
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const complex64* a, const int lda,
                   const complex64* b, const int ldb,
                   complex64* c, const int ldc) {
    const MKL_Complex8 alpha = {1.0f, 0.0f};
    const MKL_Complex8 beta = {0.0f, 0.0f};
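    // Note: std::complex<float> and MKL_Complex8 share the same layout (two
    // consecutive floats, real then imaginary), which is what makes the
    // reinterpret_casts below safe.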
    cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans,
                m, n, k, &alpha, reinterpret_cast<const MKL_Complex8*>(a), lda,
                reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta,
                reinterpret_cast<MKL_Complex8*>(c), ldc);
  }

  // Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
  // tensors. For detailed info about parameters, look at the FP32 function
  // description.
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const complex128* a, const int lda,
                   const complex128* b, const int ldb,
                   complex128* c, const int ldc) {
    const MKL_Complex16 alpha = {1.0, 0.0};
    const MKL_Complex16 beta = {0.0, 0.0};
    cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans,
                m, n, k, &alpha, reinterpret_cast<const MKL_Complex16*>(a), lda,
                reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
                reinterpret_cast<MKL_Complex16*>(c), ldc);
  }
};

#define REGISTER_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
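
// As a concrete illustration, TF_CALL_float(REGISTER_CPU) below expands to
//   REGISTER_KERNEL_BUILDER(
//       Name("MatMul").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       MklMatMulOp<CPUDevice, float, false>);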

// TODO(inteltf) Consider template specialization when adding/removing
// additional types
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);

}  // namespace tensorflow
#endif  // INTEL_MKL