/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

// This file uses the MKL CBLAS batched xGEMM interface to accelerate the TF
// batch matrix-matrix multiplication (BatchMatMul) operation.
// We currently register this kernel only for the MKL-supported data types
// (float, double, complex64, complex128). The macro INTEL_MKL is defined by
// the build system only when MKL is chosen at the configure stage; when it
// is undefined at build time, this file compiles to an empty compilation
// unit.
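//
// For orientation, a minimal standalone sketch (illustrative only; the names
// a_ptrs/b_ptrs/c_ptrs are hypothetical pointer arrays with one entry per
// matrix) of the MKL group API this file relies on. With group_count == 1 and
// group_size[0] == batch_size, one call performs batch_size independent
// multiplications
//   C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i],
// and the per-group parameter arrays need only group_count entries:
//
//   MKL_INT m[1] = {M}, n[1] = {N}, k[1] = {K};
//   MKL_INT lda[1] = {K}, ldb[1] = {N}, ldc[1] = {N};  // row-major, no transpose
//   CBLAS_TRANSPOSE trans[1] = {CblasNoTrans};
//   float alpha[1] = {1.0f}, beta[1] = {0.0f};
//   MKL_INT group_size[1] = {batch_size};
//   cblas_sgemm_batch(CblasRowMajor, trans, trans, m, n, k, alpha,
//                     a_ptrs, lda, b_ptrs, ldb, beta, c_ptrs, ldc,
//                     /*group_count=*/1, group_size);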
#define EIGEN_USE_THREADS

#if defined(INTEL_MKL)
#include <vector>
#include "mkl_cblas.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename Scalar>
class BatchMatMulMkl : public OpKernel {
 public:
  explicit BatchMatMulMkl(OpKernelConstruction *context) : OpKernel(context) {
    // adj_x/adj_y request that the corresponding operand be transposed
    // (conjugate-transposed for complex types) before multiplication.
    OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
    OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
  }

  virtual ~BatchMatMulMkl() {}

  void Compute(OpKernelContext *ctx) override {
    const Tensor &lhs = ctx->input(0);
    const Tensor &rhs = ctx->input(1);
    OP_REQUIRES(ctx, lhs.dims() == rhs.dims(),
                errors::InvalidArgument("lhs and rhs have different ndims: ",
                                        lhs.shape().DebugString(), " vs. ",
                                        rhs.shape().DebugString()));
    const int ndims = lhs.dims();
    OP_REQUIRES(
        ctx, ndims >= 2,
        errors::InvalidArgument("lhs and rhs ndims must be >= 2: ", ndims));
    // All leading (batch) dimensions must match and are carried over to the
    // output shape unchanged.
    TensorShape out_shape;
    for (int i = 0; i < ndims - 2; ++i) {
      OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i),
                  errors::InvalidArgument(
                      "lhs.dim(", i, ") and rhs.dim(", i,
                      ") must be the same: ", lhs.shape().DebugString(), " vs ",
                      rhs.shape().DebugString()));
      out_shape.AddDim(lhs.dim_size(i));
    }
    auto batch_size = (ndims == 2) ? 1 : out_shape.num_elements();
    auto lhs_rows = lhs.dim_size(ndims - 2);
    auto lhs_cols = lhs.dim_size(ndims - 1);
    auto rhs_rows = rhs.dim_size(ndims - 2);
    auto rhs_cols = rhs.dim_size(ndims - 1);
    // adj_x/adj_y transpose the corresponding operand, so swap its
    // row/column extents here before checking compatibility.
    if (adj_x_) std::swap(lhs_rows, lhs_cols);
    if (adj_y_) std::swap(rhs_rows, rhs_cols);
    OP_REQUIRES(ctx, lhs_cols == rhs_rows,
                errors::InvalidArgument(
                    "lhs and rhs are incompatible for matmul: ", lhs_cols,
                    " vs. ", rhs_rows, ": ", lhs.shape().DebugString(), " ",
                    rhs.shape().DebugString(), " ", adj_x_, " ", adj_y_));
    out_shape.AddDim(lhs_rows);
    out_shape.AddDim(rhs_cols);
    Tensor *out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    // If either input is empty (e.g. K == 0), the result is all zeros and
    // MKL need not be called.
    if (lhs.NumElements() == 0 || rhs.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Scalar> f;
      f(ctx->eigen_device<Device>(), out->flat<Scalar>());
      return;
    }
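
    // Worked example of the shape bookkeeping above: lhs [2, 3, 4] with
    // adj_x_ == false and rhs [2, 5, 4] with adj_y_ == true satisfy
    // lhs_cols == rhs_rows == 4, so out_shape is [2, 3, 5] and each
    // per-matrix GEMM below runs with M = 3, K = 4, N = 5.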

    // View each operand as a rank-3 tensor: [batch, rows, cols].
    auto rhs_reshaped = rhs.template flat_inner_dims<Scalar, 3>();
    auto lhs_reshaped = lhs.template flat_inner_dims<Scalar, 3>();
    auto out_reshaped = out->template flat_inner_dims<Scalar, 3>();
    const uint64 M = lhs_reshaped.dimension(adj_x_ ? 2 : 1);
    const uint64 K = lhs_reshaped.dimension(adj_x_ ? 1 : 2);
    const uint64 N = rhs_reshaped.dimension(adj_y_ ? 1 : 2);

    // In row-major storage the leading dimension of each matrix is its stored
    // (untransposed) number of columns.
    std::vector<MKL_INT> m_array(batch_size, M);
    std::vector<MKL_INT> n_array(batch_size, N);
    std::vector<MKL_INT> k_array(batch_size, K);
    std::vector<MKL_INT> lda_array(batch_size, adj_x_ ? M : K);
    std::vector<MKL_INT> ldb_array(batch_size, adj_y_ ? K : N);
    std::vector<MKL_INT> ldc_array(batch_size, N);
    std::vector<MKL_INT> group_size(1, batch_size);
    std::vector<const Scalar *> a_array;
    std::vector<const Scalar *> b_array;
    std::vector<Scalar *> c_array;
    a_array.reserve(batch_size);
    b_array.reserve(batch_size);
    c_array.reserve(batch_size);
    for (int64 i = 0; i < batch_size; i++) {
      a_array.push_back(&lhs_reshaped(i, 0, 0));
      b_array.push_back(&rhs_reshaped(i, 0, 0));
      c_array.push_back(&out_reshaped(i, 0, 0));
    }

    // Run all batch_size multiplications as a single group. MKL reads only
    // the first group_count entries of the per-group parameter arrays
    // (m/n/k/ld*), so sizing them batch_size long is harmless, merely larger
    // than strictly required.
    MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, &m_array[0], &n_array[0],
                      &k_array[0], &a_array[0], &lda_array[0], &b_array[0],
                      &ldb_array[0], &c_array[0], &ldc_array[0], 1,
                      &group_size[0]);
  }

 private:
  bool adj_x_;
  bool adj_y_;

  // float overload; forwards to cblas_sgemm_batch with alpha = 1, beta = 0.
  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                         const bool TransB, const MKL_INT *M_Array,
                         const MKL_INT *N_Array, const MKL_INT *K_Array,
                         const float **A_Array, const MKL_INT *lda_Array,
                         const float **B_Array, const MKL_INT *ldb_Array,
                         float **C_Array, const MKL_INT *ldc_Array,
                         const MKL_INT group_count, const MKL_INT *group_size) {
    std::vector<CBLAS_TRANSPOSE> TransA_Array(
        group_size[0], TransA ? CblasTrans : CblasNoTrans);
    std::vector<CBLAS_TRANSPOSE> TransB_Array(
        group_size[0], TransB ? CblasTrans : CblasNoTrans);
    std::vector<float> alpha_Array(group_size[0], 1.0f);
    std::vector<float> beta_Array(group_size[0], 0.0f);
    cblas_sgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], M_Array,
                      N_Array, K_Array, &alpha_Array[0], A_Array, lda_Array,
                      B_Array, ldb_Array, &beta_Array[0], C_Array, ldc_Array,
                      group_count, group_size);
  }

  // double overload; forwards to cblas_dgemm_batch.
  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                         const bool TransB, const MKL_INT *M_Array,
                         const MKL_INT *N_Array, const MKL_INT *K_Array,
                         const double **A_Array, const MKL_INT *lda_Array,
                         const double **B_Array, const MKL_INT *ldb_Array,
                         double **C_Array, const MKL_INT *ldc_Array,
                         const MKL_INT group_count, const MKL_INT *group_size) {
    std::vector<CBLAS_TRANSPOSE> TransA_Array(
        group_size[0], TransA ? CblasTrans : CblasNoTrans);
    std::vector<CBLAS_TRANSPOSE> TransB_Array(
        group_size[0], TransB ? CblasTrans : CblasNoTrans);
    std::vector<double> alpha_Array(group_size[0], 1.0);
    std::vector<double> beta_Array(group_size[0], 0.0);
    cblas_dgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], M_Array,
                      N_Array, K_Array, &alpha_Array[0], A_Array, lda_Array,
                      B_Array, ldb_Array, &beta_Array[0], C_Array, ldc_Array,
                      group_count, group_size);
  }

  // complex64 overload; for complex types the adjoint means conjugate
  // transpose, hence CblasConjTrans rather than CblasTrans.
  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                         const bool TransB, const MKL_INT *M_Array,
                         const MKL_INT *N_Array, const MKL_INT *K_Array,
                         const complex64 **A_Array, const MKL_INT *lda_Array,
                         const complex64 **B_Array, const MKL_INT *ldb_Array,
                         complex64 **C_Array, const MKL_INT *ldc_Array,
                         const MKL_INT group_count, const MKL_INT *group_size) {
    std::vector<CBLAS_TRANSPOSE> TransA_Array(
        group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
    std::vector<CBLAS_TRANSPOSE> TransB_Array(
        group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
    std::vector<complex64> alpha_Array(group_size[0], {1.0f, 0.0f});
    std::vector<complex64> beta_Array(group_size[0], {0.0f, 0.0f});
    cblas_cgemm_batch(
        Layout, &TransA_Array[0], &TransB_Array[0], M_Array, N_Array, K_Array,
        static_cast<const void *>(&alpha_Array[0]),
        reinterpret_cast<const void **>(A_Array), lda_Array,
        reinterpret_cast<const void **>(B_Array), ldb_Array,
        static_cast<const void *>(&beta_Array[0]),
        reinterpret_cast<void **>(C_Array), ldc_Array, group_count, group_size);
  }

  // complex128 overload; same conjugate-transpose handling as complex64.
  void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
                         const bool TransB, const MKL_INT *M_Array,
                         const MKL_INT *N_Array, const MKL_INT *K_Array,
                         const complex128 **A_Array, const MKL_INT *lda_Array,
                         const complex128 **B_Array, const MKL_INT *ldb_Array,
                         complex128 **C_Array, const MKL_INT *ldc_Array,
                         const MKL_INT group_count, const MKL_INT *group_size) {
    std::vector<CBLAS_TRANSPOSE> TransA_Array(
        group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
    std::vector<CBLAS_TRANSPOSE> TransB_Array(
        group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
    std::vector<complex128> alpha_Array(group_size[0], {1.0, 0.0});
    std::vector<complex128> beta_Array(group_size[0], {0.0, 0.0});
    cblas_zgemm_batch(
        Layout, &TransA_Array[0], &TransB_Array[0], M_Array, N_Array, K_Array,
        static_cast<const void *>(&alpha_Array[0]),
        reinterpret_cast<const void **>(A_Array), lda_Array,
        reinterpret_cast<const void **>(B_Array), ldb_Array,
        static_cast<const void *>(&beta_Array[0]),
        reinterpret_cast<void **>(C_Array), ldc_Array, group_count, group_size);
  }
};

#define REGISTER_BATCH_MATMUL_MKL(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulMkl<CPUDevice, TYPE>)

TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL);
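
// A usage sketch (illustrative only; not part of this file): once registered,
// any CPU BatchMatMul node over one of the types above dispatches to
// BatchMatMulMkl, e.g. built via the C++ client API:
//
//   // Scope root = Scope::NewRootScope();
//   // auto a = ops::Const(root, ...);         // shape [B, M, K]
//   // auto b = ops::Const(root, ...);         // shape [B, K, N]
//   // auto c = ops::BatchMatMul(root, a, b);  // shape [B, M, N]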

}  // end namespace tensorflow
#endif  // INTEL_MKL