/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

// This file uses MKL CBLAS xGEMM for acceleration of TF Matrix-Matrix
// Multiplication (MatMul) operations.
// We currently register this kernel only for MKL supported data
// types (float, double, complex64, complex128). The macro INTEL_MKL is defined
// by the build system only when MKL is chosen as an option at configure stage
// and when it is undefined at build time, this file becomes an empty
// compilation unit.

#if defined(INTEL_MKL)

#include <limits>

#include "mkl_cblas.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

// MatMul kernel that computes out = op(a) * op(b) for two rank-2 inputs by
// dispatching to the MKL CBLAS xGEMM routine matching the element type T.
// op(x) is x^T when the corresponding "transpose_a" / "transpose_b" attribute
// is true and x otherwise. The USE_CUBLAS template parameter is not used by
// this class; it is kept so the kernel can be registered with the same
// template arity as shown in REGISTER_CPU below.
template <typename Device, typename T, bool USE_CUBLAS>
class MklMatMulOp : public OpKernel {
 public:
  // Reads the "transpose_a" and "transpose_b" attributes; construction fails
  // (via OP_REQUIRES_OK) if either attribute cannot be read.
  explicit MklMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
  }

  // Validates the two inputs, allocates output 0 and fills it with
  // op(a) * op(b).
  //
  // Reports InvalidArgument through `ctx` when either input is not a matrix,
  // when the contracted dimensions differ, or when a dimension does not fit
  // in the `int` arguments accepted by MklBlasGemm.
  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("In[0] is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("In[1] is not a matrix"));

    // dim_pair holds the index of the contracted dimension in a (first) and
    // in b (second).
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
            ", In[1]: ", b.shape().DebugString()));

    // The non-contracted dimension of each input gives the output shape.
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 || b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

    // MklBlasGemm takes its sizes and leading dimensions as `int`, whereas
    // Tensor::dim_size() yields a 64-bit value. Read the dimensions at full
    // width first and reject anything that would be silently truncated by the
    // conversion, instead of handing a wrapped size to GEMM.
    const int64 m64 = a.dim_size(a_dim_remaining);
    const int64 k64 = a.dim_size(dim_pair[0].first);
    const int64 n64 = b.dim_size(b_dim_remaining);
    const int64 max_gemm_dim = std::numeric_limits<int>::max();
    OP_REQUIRES(
        ctx, m64 <= max_gemm_dim && k64 <= max_gemm_dim && n64 <= max_gemm_dim,
        errors::InvalidArgument(
            "Matrix dimensions too large for MKL GEMM: In[0]: ",
            a.shape().DebugString(), ", In[1]: ", b.shape().DebugString()));

    const int m = static_cast<int>(m64);
    const int k = static_cast<int>(k64);
    const int n = static_cast<int>(n64);
    const bool transpose_a = dim_pair[0].first == 0;
    const bool transpose_b = dim_pair[0].second == 1;

    auto a_ptr = (a.template flat<T>().data());
    auto b_ptr = (b.template flat<T>().data());
    auto c_ptr = (out->template flat<T>().data());

    // Leading dimensions are the row strides of the matrices as stored:
    // a is stored [k, m] when transposed and [m, k] otherwise; b is stored
    // [n, k] when transposed and [k, n] otherwise; c is always [m, n].
    MklBlasGemm(transpose_a, transpose_b, m, n, k, a_ptr, transpose_a ? m : k,
                b_ptr, transpose_b ? k : n, c_ptr, n);
  }

 private:
  // Values of the "transpose_a" / "transpose_b" op attributes.
  bool transpose_a_;
  bool transpose_b_;

  // --------------------------------------------------------------------------
  //
  // @brief Matrix-Matrix Multiplication with FP32 tensors, a, b, c using CBLAS
  // interface. c = op(a) * op(b)
  //
  // @param transa  Specifies the form of op(a) used in MatMul. If transa is
  // true, then op(a) = a^T, otherwise op(a) = a
  //
  // @param transb  Specifies the form of op(b) used in MatMul. If transb is
  // true, then op(b) = b^T, otherwise op(b) = b
  //
  // @param m       Specifies the number of rows of the matrix op(a) and of the
  // matrix c. The value of m must be at least zero.
  //
  // @param n       Specifies the number of columns of the matrix op(b) and the
  // number of columns of the matrix c. The value of n must be at least zero.
  //
  // @param k       Specifies the number of columns of the matrix op(a) and the
  // number of rows of the matrix op(b)
  //
  // @param a       Address of matrix a
  //
  // @param lda     Leading dimension of 'a' matrix. This is set at calling site
  // depending on transa parameter. Since TF uses row-major
  // layout, leading dimension is the stride between consecutive rows
  // lda = max(1,k) when transa is false, otherwise lda = max(1,m)
  //
  // @param b       Address of matrix b
  //
  // @param ldb     Leading dimension of 'b' matrix. This is set at calling site
  // depending on transb parameter. Since TF uses row-major
  // layout, leading dimension is the stride between consecutive rows
  // ldb = max(1,n) when transb is false, otherwise ldb = max(1,k)
  //
  // @param c       Address of matrix c
  //
  // @param ldc     Leading dimension of 'c' matrix. Since TF uses row-major
  // layout, leading dimension is the stride between consecutive rows, max(1,n)
  //
  // --------------------------------------------------------------------------
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const float* a, const int lda, const float* b,
                   const int ldb, float* c, const int ldc) {
    // BLAS GEMM API defines Matrix Multiplication as c = alpha * op(a) * op(b)
    // + beta * c.
    // Since TF MatMul does not have parameters for alpha, beta, we set them to
    // 1.0 and 0.0 respectively.
    const float alpha = 1.0f;
    const float beta = 0.0f;
    cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
                ldb, beta, c, ldc);
  }

  // Matrix-Matrix Multiplication with FP64 tensors. For detailed info about
  // parameters, look at FP32 function description.
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const double* a, const int lda, const double* b,
                   const int ldb, double* c, const int ldc) {
    const double alpha = 1.0;
    const double beta = 0.0;
    cblas_dgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
                ldb, beta, c, ldc);
  }

  // Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors.
  // For detailed info about parameters, look at FP32 function description.
  // alpha and beta are passed by address, and the operand pointers are
  // reinterpreted as MKL_Complex8 for the cblas_cgemm call.
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const complex64* a, const int lda,
                   const complex64* b, const int ldb,
                   complex64* c, int const ldc) {
    const MKL_Complex8 alpha = {1.0f, 0.0f};
    const MKL_Complex8 beta = {0.0f, 0.0f};
    cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans,
                m, n, k, &alpha, reinterpret_cast<const MKL_Complex8*>(a), lda,
                reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta,
                reinterpret_cast<MKL_Complex8*>(c), ldc);
  }

  // Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
  // tensors. For detailed info about parameters, look at FP32 function
  // description. alpha and beta are passed by address, and the operand
  // pointers are reinterpreted as MKL_Complex16 for the cblas_zgemm call.
  void MklBlasGemm(bool transa, bool transb, const int m, const int n,
                   const int k, const complex128* a, const int lda,
                   const complex128* b, const int ldb,
                   complex128* c, const int ldc) {
    const MKL_Complex16 alpha = {1.0, 0.0};
    const MKL_Complex16 beta = {0.0, 0.0};
    cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
                transb ? CblasTrans : CblasNoTrans,
                m, n, k, &alpha, reinterpret_cast<const MKL_Complex16*>(a), lda,
                reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
                reinterpret_cast<MKL_Complex16*>(c), ldc);
  }
};

// Registers MklMatMulOp as the CPU "MatMul" kernel for element type T.
#define REGISTER_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(                                      \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);

// TODO(inteltf) Consider template specialization when adding/removing
// additional types
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);

}  // namespace tensorflow
#endif  // INTEL_MKL