/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/matmul_op.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/util/matmul_autotune.h"
#if GOOGLE_CUDA
#include "cuda/include/cuda.h"
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename T, bool USE_CUBLAS>
struct LaunchMatMul;

namespace {
// Converts a TensorFlow Tensor to an Eigen Matrix.
template <typename T>
Eigen::Map<
    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
ToEigenMatrix(const Tensor& tensor) {
  auto matrix = tensor.matrix<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
                       Eigen::RowMajor>::Map(matrix.data(),
                                             matrix.dimension(0),
                                             matrix.dimension(1));
}

// Converts a TensorFlow Tensor to an Eigen Vector.
template <typename T>
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(Tensor* tensor) {
  auto v = tensor->flat<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
template <typename T>
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(
    const Tensor& tensor) {
  auto v = tensor.flat<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
}  // namespace

// If either side can be represented as a vector, do an explicit vector-matrix
// multiply and return true; else return false.
//
// Note: this uses plain Eigen and not Eigen Tensor because it is more
// efficient.
template <typename T>
bool ExplicitVectorMatrixOptimization(
    const Tensor& a, const Tensor& b,
    const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
    Tensor* out) {
  if (out->dim_size(0) == 1) {
    if (dim_pair[0].second == 0) {
      // Note: this case is optimized in Eigen Tensors.
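      // (out has a single row here, so a acts as a 1 x k row vector. With b
      // un-transposed, Eigen Tensor's contraction already special-cases the
      // row-vector product, so we defer to the generic path; with b
      // transposed, the else-branch below computes out = b * a directly as
      // an Eigen matrix-vector product.)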
      return false;
    } else {
      auto out_v = ToEigenVector<T>(out);
      auto a_v = ToEigenVector<T>(a);
      auto b_m = ToEigenMatrix<T>(b);
      out_v.noalias() = b_m * a_v;
    }
    return true;
  } else if (out->dim_size(1) == 1) {
    auto out_v = ToEigenVector<T>(out);
    auto a_m = ToEigenMatrix<T>(a);
    auto b_v = ToEigenVector<T>(b);
    if (dim_pair[0].first == 0) {
      out_v.noalias() = a_m.transpose() * b_v;
    } else {
      out_v.noalias() = a_m * b_v;
    }
    return true;
  }
  return false;
}
// Half is not supported.
template <>
bool ExplicitVectorMatrixOptimization<Eigen::half>(
    const Tensor& a, const Tensor& b,
    const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
    Tensor* out) {
  return false;
}

template <typename Device, typename T>
struct LaunchMatMulBase {
#if GOOGLE_CUDA
  typedef perftools::gputools::blas::AlgorithmType AlgorithmType;
#else
  typedef int64 AlgorithmType;
#endif  // GOOGLE_CUDA

  static void launch(
      OpKernelContext* ctx, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      std::vector<AlgorithmType>* algorithms, bool use_autotune, Tensor* out) {
#ifndef TENSORFLOW_USE_SYCL
    // An explicit vector-matrix multiply is much better optimized than an
    // implicit one and this is a bottleneck during non-batched inference.
    bool was_vector = ExplicitVectorMatrixOptimization<T>(a, b, dim_pair, out);
    if (!was_vector) {
#endif  // TENSORFLOW_USE_SYCL
      functor::MatMulFunctor<Device, T>()(ctx->eigen_device<Device>(),
                                          out->matrix<T>(), a.matrix<T>(),
                                          b.matrix<T>(), dim_pair);
#ifndef TENSORFLOW_USE_SYCL
    }
#endif  // TENSORFLOW_USE_SYCL
  }

  static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
                                   std::vector<int64>* algorithms,
                                   bool* algorithm_set_flag) {}
};
// On CPUs, we ignore USE_CUBLAS.
template <typename T>
struct LaunchMatMulCPU : LaunchMatMulBase<CPUDevice, T> {};

template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct LaunchMatMulSYCL : LaunchMatMulBase<SYCLDevice, T> {};

template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<SYCLDevice, T, USE_CUBLAS> : public LaunchMatMulSYCL<T> {};
#endif  // TENSORFLOW_USE_SYCL

#if GOOGLE_CUDA

namespace {

template <typename T>
struct LaunchBlasGemv {
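  // Computes c = op(a) * b via cuBLAS GEMV with alpha = 1 and beta = 0.
  // StreamExecutor BLAS interprets `a` as a column-major m x n matrix, so
  // callers pass a flipped `trans` flag and swapped dimensions to compensate
  // for TensorFlow's row-major layout (see LaunchMatMul below). When
  // `output_profile` is non-null, the profiling variant is used so the
  // elapsed time can be recorded for autotuning.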
  static void Compute(
      OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
      uint64 m, uint64 n, const perftools::gputools::DeviceMemory<T>& a,
      const perftools::gputools::DeviceMemory<T>& b,
      perftools::gputools::DeviceMemory<T>* c,
      perftools::gputools::blas::ProfileResult* output_profile) {
    const auto blas_trans =
        trans ? perftools::gputools::blas::Transpose::kTranspose
              : perftools::gputools::blas::Transpose::kNoTranspose;
    if (output_profile == nullptr) {
      bool blas_launch_status =
          stream
              ->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
                             static_cast<T>(0.0), c, 1)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(
            errors::Internal("Blas GEMV launch failed: m=", m, ", n=", n));
      }
    } else {
      bool blas_launch_status =
          stream
              ->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0),
                                          a, m, b, 1, static_cast<T>(0.0), c, 1,
                                          output_profile)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(errors::Internal(
            "Blas GEMV with profiling launch failed: m=", m, ", n=", n));
      }
    }
  }

  static bool IsSupported() { return true; }
};

template <>
void LaunchBlasGemv<Eigen::half>::Compute(
    OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
    uint64 m, uint64 n, const perftools::gputools::DeviceMemory<Eigen::half>& a,
    const perftools::gputools::DeviceMemory<Eigen::half>& b,
    perftools::gputools::DeviceMemory<Eigen::half>* c,
    perftools::gputools::blas::ProfileResult* output_profile) {
  ctx->SetStatus(errors::Internal(
      "Blas GEMV launch failed: GEMV is not implemented for float16."));
}

template <>
bool LaunchBlasGemv<Eigen::half>::IsSupported() {
  return false;
}

template <typename T>
bool ShouldUseGemv(uint64 n) {
  return (LaunchBlasGemv<T>::IsSupported() && n == 1);
}

}  // namespace

bool GetCublasAutotuneComputationType(
    const DataType& dtype,
    perftools::gputools::blas::ComputationType* compute_type) {
  using perftools::gputools::blas::ComputationType;
  bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
  switch (dtype) {
    case DT_HALF:
    case DT_BFLOAT16:
      if (use_f32_for_f16_computation) {
        *compute_type = ComputationType::kF32;
      } else {
        *compute_type = ComputationType::kF16;
      }
      return false;
    case DT_FLOAT:
      *compute_type = ComputationType::kF32;
      return true;
    case DT_DOUBLE:
      *compute_type = ComputationType::kF64;
      return true;
    default:
      // Unsupported compute_type, return false.
      return false;
  }
}
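// Note that the return value above signals whether `dtype` participates in
// GEMM autotuning, not whether *compute_type was written. A sketch of the
// contract (the DT_HALF result depends on MatmulDoFP32ComputationFP16Input()):
//
//   ComputationType ct;
//   GetCublasAutotuneComputationType(DT_FLOAT, &ct);  // ct = kF32, true
//   GetCublasAutotuneComputationType(DT_HALF, &ct);   // ct set, but false
//   GetCublasAutotuneComputationType(DT_INT32, &ct);  // ct untouched, false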
// A dummy type to group matmul autotune results together.
struct MatmulAutoTuneGroup {
  static string name() { return "Matmul"; }
};
typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
                          perftools::gputools::blas::AlgorithmConfig>
    AutoTuneMatmul;

template <typename T>
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
  static void launch(
      OpKernelContext* ctx, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
    using perftools::gputools::blas::AlgorithmConfig;
    using perftools::gputools::blas::ComputationType;
    using perftools::gputools::blas::kDefaultAlgorithm;
    using perftools::gputools::blas::kDefaultBlasGemm;
    using perftools::gputools::blas::kDefaultBlasGemv;
    using perftools::gputools::blas::kNoAlgorithm;
    using perftools::gputools::blas::ProfileResult;
    using perftools::gputools::blas::Transpose;
    Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
    const uint64 m = a.dim_size(1 - dim_pair[0].first);
    const uint64 k = a.dim_size(dim_pair[0].first);
    const uint64 n = b.dim_size(1 - dim_pair[0].second);
    bool transpose_a = dim_pair[0].first == 0;
    bool transpose_b = dim_pair[0].second == 1;
    auto blas_transpose_a = trans[transpose_a];
    auto blas_transpose_b = trans[transpose_b];

    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
                                a.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
                                b.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(out->template flat<T>().data(),
                                out->template flat<T>().size());
    auto alpha = static_cast<T>(1.0);
    auto beta = static_cast<T>(0.0);

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = a.dtype();
    MatmulParameters matmul_parameters = {
        transpose_a, transpose_b, m, n, k, dtype, device_id,
    };
    AlgorithmConfig algorithm_config(kNoAlgorithm);

    ComputationType computation_type;
    bool compute_type_supported =
        GetCublasAutotuneComputationType(dtype, &computation_type);
    if (use_autotune && compute_type_supported && !algorithms->empty()) {
      ProfileResult best_result;
      // TODO(yangzihao): Unify this code with conv autotuning.
      if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters,
                                               &algorithm_config)) {
        ProfileResult profile_result;
        for (auto profile_algorithm : (*algorithms)) {
          // Cublas does
          //   C = A x B
          // where A, B and C are assumed to be in column-major order.
          // We want the output in row-major, so we compute
          //   C' = B' x A'  (' stands for transpose)
          bool cublas_launch_status =
              stream
                  ->ThenBlasGemmWithAlgorithm(
                      blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
                      transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
                      &c_ptr, n, computation_type, profile_algorithm,
                      &profile_result)
                  .ok();
          if (cublas_launch_status) {
            if (profile_result.is_valid()) {
              if (profile_result.elapsed_time_in_ms() <
                  best_result.elapsed_time_in_ms()) {
                best_result = profile_result;
              }
            }
          }
        }
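        // The two profiled runs below time the stock cuBLAS GEMM and, when
        // n == 1, the GEMV path. Their ProfileResults come back tagged with
        // kDefaultBlasGemm / kDefaultBlasGemv, so the default BLAS
        // implementations compete with the cublasGemmEx algorithms profiled
        // above; the fallback branch at the end of launch() dispatches on
        // those markers.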
        // Try BlasGemmWithProfiling.
        bool cublas_launch_status =
            stream
                ->ThenBlasGemmWithProfiling(
                    blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr,
                    transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0,
                    &c_ptr, n, &profile_result)
                .ok();
        if (cublas_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
          }
        }
        // Try BlasGemvWithProfiling.
        if (ShouldUseGemv<T>(n)) {
          LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
                                     transpose_a ? m : k, transpose_a ? k : m,
                                     a_ptr, b_ptr, &c_ptr, &profile_result);
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
          }
        }
      }
      // We make sure that each matmul parameter set gets only one pass of
      // autotuning. If a best result is found, assign it to algorithm_type
      // and insert it into the autotune map. If all internal kernels of
      // cublasGemmEx() return invalid results, we add kNoAlgorithm to the
      // autotune map.
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      AutoTuneMatmul::GetInstance()->Insert(matmul_parameters,
                                            algorithm_config);
      if (algorithm_config.algorithm() != kNoAlgorithm &&
          algorithm_config.algorithm() != kDefaultBlasGemm &&
          algorithm_config.algorithm() != kDefaultBlasGemv) {
        bool cublas_launch_status =
            stream
                ->ThenBlasGemmWithAlgorithm(
                    blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
                    transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
                    &c_ptr, n, computation_type, algorithm_config.algorithm(),
                    nullptr)
                .ok();
        if (!cublas_launch_status) {
          ctx->SetStatus(errors::Internal(
              "Blas GEMM with algorithm launch failed : a.shape=(",
              a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0),
              ", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k));
        }
      }
    }
    // We fall back to normal BlasGemm() in any of the following cases:
    // 1) the use_autotune flag was not set;
    // 2) the compute type does not support autotuning;
    // 3) no algorithm was found;
    // 4) all internal kernels in autotune returned invalid results.
    // We use normal BlasGemv() in the following cases:
    // 1) the use_autotune flag was not set, but LaunchBlasGemv is supported
    //    and n == 1;
    // 2) the use_autotune flag was set, autotuning picked BlasGemv(), and
    //    algorithm_config.algorithm() was set to kDefaultBlasGemv.
    if (!use_autotune || !compute_type_supported || algorithms->empty() ||
        algorithm_config.algorithm() == kNoAlgorithm ||
        algorithm_config.algorithm() == kDefaultBlasGemm ||
        algorithm_config.algorithm() == kDefaultBlasGemv) {
      if (algorithm_config.algorithm() == kDefaultBlasGemv ||
          ShouldUseGemv<T>(n)) {
        // This is a matrix*vector multiply, so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major.
        // TODO(yangzihao): Add Gemv as an autotuning option too.
        LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
                                   transpose_a ? m : k, transpose_a ? k : m,
                                   a_ptr, b_ptr, &c_ptr, nullptr);
      } else {
        // Use C' = B' x A' (' stands for transpose).
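        // Concretely: with transpose_a = transpose_b = false, a is m x k and
        // b is k x n in row-major order. cuBLAS reads the same buffers as
        // column-major, i.e. as a' (k x m) and b' (n x k), so requesting
        // b' x a' yields the n x m matrix out' in column-major form, which
        // occupies memory identically to the desired m x n row-major output.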
        bool blas_launch_status =
            stream
                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                               1.0f, b_ptr, transpose_b ? k : n, a_ptr,
                               transpose_a ? m : k, 0.0f, &c_ptr, n)
                .ok();
        if (!blas_launch_status) {
          ctx->SetStatus(errors::Internal(
              "Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
              a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
              "), m=", m, ", n=", n, ", k=", k));
        }
      }
    }
  }

  static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
                                   std::vector<int64>* algorithms,
                                   bool* algorithm_set_flag) {
    if (*algorithm_set_flag == false) {
      auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream;
      stream->parent()->GetBlasGemmAlgorithms(algorithms);
      *algorithm_set_flag = true;
    }
  }
};

#endif  // GOOGLE_CUDA

template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
 public:
  explicit MatMulOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), algorithms_set_already_(false) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));

    LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
        ctx, &algorithms_, &algorithms_set_already_);
    use_autotune_ = MatmulAutotuneEnable();
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("In[0] is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("In[1] is not a matrix"));
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
            ", In[1]: ", b.shape().DebugString()));
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 || b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

    LaunchMatMul<Device, T, USE_CUBLAS>::launch(
        ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
  }

 private:
  std::vector<int64> algorithms_;
  bool algorithms_set_already_;
  bool use_autotune_;
  bool transpose_a_;
  bool transpose_b_;
};

namespace functor {
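// MatMulFunctor is the device-generic compute path invoked from
// LaunchMatMulBase::launch() above. The CPU and SYCL specializations below
// forward to the Eigen-based MatMul() helper from matmul_op.h; the GPU path
// never reaches this functor, since LaunchMatMul<GPUDevice, ...> goes
// through StreamExecutor BLAS instead.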
// Partial specialization MatMulFunctor<Device=CPUDevice, T>.
template <typename T>
struct MatMulFunctor<CPUDevice, T> {
  void operator()(
      const CPUDevice& d, typename MatMulTypes<T>::out_type out,
      typename MatMulTypes<T>::in_type in0,
      typename MatMulTypes<T>::in_type in1,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    MatMul<CPUDevice>(d, out, in0, in1, dim_pair);
  }
};

#ifdef TENSORFLOW_USE_SYCL
// Partial specialization MatMulFunctor<Device=SYCLDevice, T>.
template <typename T>
struct MatMulFunctor<SYCLDevice, T> {
  void operator()(
      const SYCLDevice& d, typename MatMulTypes<T>::out_type out,
      typename MatMulTypes<T>::in_type in0,
      typename MatMulTypes<T>::in_type in1,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    MatMul<SYCLDevice>(d, out, in0, in1, dim_pair);
  }
};
#endif  // TENSORFLOW_USE_SYCL

}  // end namespace functor

#define REGISTER_CPU_EIGEN(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);

#define REGISTER_CPU(T)                                             \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"),     \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
  REGISTER_CPU_EIGEN(T);

#define REGISTER_GPU(T)                                            \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"),    \
      MatMulOp<GPUDevice, T, true /* cublas, true by default */>); \
  REGISTER_KERNEL_BUILDER(Name("MatMul")                           \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .Label("cublas"),                    \
                          MatMulOp<GPUDevice, T, true /* cublas */>)

#if defined(INTEL_MKL)
// MKL does not support the half and int32 types for matrix multiplication,
// so register kernels that use the default Eigen-based implementations for
// those types. Registration for the NO-LABEL version is in mkl_matmul_op.cc.
TF_CALL_float(REGISTER_CPU_EIGEN);
TF_CALL_double(REGISTER_CPU_EIGEN);
TF_CALL_half(REGISTER_CPU);

TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU_EIGEN);
TF_CALL_complex128(REGISTER_CPU_EIGEN);
#else
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);

TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
#endif

#if GOOGLE_CUDA
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
#if CUDA_VERSION >= 7050
TF_CALL_half(REGISTER_GPU);
#endif
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(T)                                         \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("MatMul").Device(DEVICE_SYCL).TypeConstraint<T>("T"), \
      MatMulOp<SYCLDevice, T, false /* xxblas */>);              \
  REGISTER_KERNEL_BUILDER(Name("MatMul")                         \
                              .Device(DEVICE_SYCL)               \
                              .TypeConstraint<T>("T")            \
                              .Label("eigen"),                   \
                          MatMulOp<SYCLDevice, T, false /* xxblas */>)
TF_CALL_float(REGISTER_SYCL);
TF_CALL_double(REGISTER_SYCL);

#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow