/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/matmul_op.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/util/matmul_autotune.h"
#if GOOGLE_CUDA
#include "cuda/include/cuda.h"
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

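// LaunchMatMul dispatches the actual multiply. It is specialized below for
// CPU (and SYCL) via LaunchMatMulBase, and for GPU via cuBLAS when
// USE_CUBLAS is true.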
template <typename Device, typename T, bool USE_CUBLAS>
struct LaunchMatMul;

namespace {
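// The helpers below wrap an existing tensor buffer in an Eigen::Map view; no
// data is copied.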
// Converts a TensorFlow Tensor to an Eigen Matrix.
template <typename T>
Eigen::Map<
    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
ToEigenMatrix(const Tensor& tensor) {
  auto matrix = tensor.matrix<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>::Map(
      matrix.data(), matrix.dimension(0), matrix.dimension(1));
}

// Converts a TensorFlow Tensor to an Eigen Vector.
template <typename T>
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(Tensor* tensor) {
  auto v = tensor->flat<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
template <typename T>
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> ToEigenVector(
    const Tensor& tensor) {
  auto v = tensor.flat<T>();
  return Eigen::Matrix<T, Eigen::Dynamic, 1>::Map(v.data(), v.dimension(0));
}
}  // namespace

// If either side can be represented as a vector, do an explicit vector
// matrix multiply and return true; else return false.
//
// Note: this uses plain Eigen and not Eigen Tensor because it is more
// efficient.
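// The two cases handled below are out having a single row (vector * matrix)
// and out having a single column (matrix * vector).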
template <typename T>
bool ExplicitVectorMatrixOptimization(
    const Tensor& a, const Tensor& b,
    const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
    Tensor* out) {
  if (out->dim_size(0) == 1) {
    if (dim_pair[0].second == 0) {
      // Note: this case is optimized in Eigen Tensors.
      return false;
    } else {
      auto out_v = ToEigenVector<T>(out);
      auto a_v = ToEigenVector<T>(a);
      auto b_m = ToEigenMatrix<T>(b);
      out_v.noalias() = b_m * a_v;
    }
    return true;
  } else if (out->dim_size(1) == 1) {
    auto out_v = ToEigenVector<T>(out);
    auto a_m = ToEigenMatrix<T>(a);
    auto b_v = ToEigenVector<T>(b);
    if (dim_pair[0].first == 0) {
      out_v.noalias() = a_m.transpose() * b_v;
    } else {
      out_v.noalias() = a_m * b_v;
    }
    return true;
  }
  return false;
}
// Half is not supported.
template <>
bool ExplicitVectorMatrixOptimization<Eigen::half>(
    const Tensor& a, const Tensor& b,
    const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
    Tensor* out) {
  return false;
}

template <typename Device, typename T>
struct LaunchMatMulBase {
#if GOOGLE_CUDA
  typedef perftools::gputools::blas::AlgorithmType AlgorithmType;
#else
  typedef int64 AlgorithmType;
#endif  // GOOGLE_CUDA

  static void launch(
      OpKernelContext* ctx, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      std::vector<AlgorithmType>* algorithms, bool use_autotune, Tensor* out) {
#ifndef TENSORFLOW_USE_SYCL
    // An explicit vector-matrix multiply is much better optimized than an
    // implicit one and this is a bottleneck during non-batched inference.
    bool was_vector = ExplicitVectorMatrixOptimization<T>(a, b, dim_pair, out);
    if (!was_vector) {
#endif  // TENSORFLOW_USE_SYCL
      functor::MatMulFunctor<Device, T>()(ctx->eigen_device<Device>(),
                                          out->matrix<T>(), a.matrix<T>(),
                                          b.matrix<T>(), dim_pair);
#ifndef TENSORFLOW_USE_SYCL
    }
#endif  // TENSORFLOW_USE_SYCL
  }

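  // No cuBLAS algorithms to enumerate on CPU/SYCL, so this is a no-op.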
  static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
                                   std::vector<int64>* algorithms,
                                   bool* algorithm_set_flag) {}
};
// On CPUs, we ignore USE_CUBLAS
template <typename T>
struct LaunchMatMulCPU : LaunchMatMulBase<CPUDevice, T> {};

template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct LaunchMatMulSYCL : LaunchMatMulBase<SYCLDevice, T> {};

template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<SYCLDevice, T, USE_CUBLAS> : public LaunchMatMulSYCL<T> {};
#endif  // TENSORFLOW_USE_SYCL

#if GOOGLE_CUDA

namespace {

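// LaunchBlasGemv wraps StreamExecutor's BLAS GEMV call: it computes
// c = op(a) * b with alpha = 1 and beta = 0, and optionally records timing in
// output_profile for autotuning.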
template <typename T>
struct LaunchBlasGemv {
  static void Compute(
      OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
      uint64 m, uint64 n, const perftools::gputools::DeviceMemory<T>& a,
      const perftools::gputools::DeviceMemory<T>& b,
      perftools::gputools::DeviceMemory<T>* c,
      perftools::gputools::blas::ProfileResult* output_profile) {
    const auto blas_trans =
        trans ? perftools::gputools::blas::Transpose::kTranspose
              : perftools::gputools::blas::Transpose::kNoTranspose;
    if (output_profile == nullptr) {
      bool blas_launch_status =
          stream
              ->ThenBlasGemv(blas_trans, m, n, static_cast<T>(1.0), a, m, b, 1,
                             static_cast<T>(0.0), c, 1)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(
            errors::Internal("Blas GEMV launch failed:  m=", m, ", n=", n));
      }
    } else {
      bool blas_launch_status =
          stream
              ->ThenBlasGemvWithProfiling(blas_trans, m, n, static_cast<T>(1.0),
                                          a, m, b, 1, static_cast<T>(0.0), c, 1,
                                          output_profile)
              .ok();
      if (!blas_launch_status) {
        ctx->SetStatus(errors::Internal(
            "Blas GEMV with profiling launch failed:  m=", m, ", n=", n));
      }
    }
  }

  static bool IsSupported() { return true; }
};

template <>
void LaunchBlasGemv<Eigen::half>::Compute(
    OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
    uint64 m, uint64 n, const perftools::gputools::DeviceMemory<Eigen::half>& a,
    const perftools::gputools::DeviceMemory<Eigen::half>& b,
    perftools::gputools::DeviceMemory<Eigen::half>* c,
    perftools::gputools::blas::ProfileResult* output_profile) {
  ctx->SetStatus(errors::Internal(
      "Blas GEMV launch failed: GEMV is not implemented for float16."));
}

template <>
bool LaunchBlasGemv<Eigen::half>::IsSupported() {
  return false;
}

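// GEMV is preferred only when the right-hand side has a single column
// (n == 1) and the element type supports it (i.e. not Eigen::half).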
template <typename T>
bool ShouldUseGemv(uint64 n) {
  return (LaunchBlasGemv<T>::IsSupported() && n == 1);
}

}  // namespace

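// Maps the TensorFlow dtype to the cuBLAS computation type used by autotuned
// GEMM. Returns true only for dtypes whose GEMM algorithms are autotuned here
// (float and double); half/bfloat16 and unsupported dtypes return false.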
bool GetCublasAutotuneComputationType(
    const DataType& dtype,
    perftools::gputools::blas::ComputationType* compute_type) {
  using perftools::gputools::blas::ComputationType;
  bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
  switch (dtype) {
    case DT_HALF:
    case DT_BFLOAT16:
      if (use_f32_for_f16_computation) {
        *compute_type = ComputationType::kF32;
      } else {
        *compute_type = ComputationType::kF16;
      }
      return false;
    case DT_FLOAT:
      *compute_type = ComputationType::kF32;
      return true;
    case DT_DOUBLE:
      *compute_type = ComputationType::kF64;
      return true;
    default:
      // Unsupported compute_type, return false.
      return false;
  }
}

// A dummy type to group matmul autotune results together.
struct MatmulAutoTuneGroup {
  static string name() { return "Matmul"; }
};
typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
                          perftools::gputools::blas::AlgorithmConfig>
    AutoTuneMatmul;

template <typename T>
struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
  static void launch(
      OpKernelContext* ctx, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
    using perftools::gputools::blas::AlgorithmConfig;
    using perftools::gputools::blas::ComputationType;
    using perftools::gputools::blas::kDefaultAlgorithm;
    using perftools::gputools::blas::kDefaultBlasGemm;
    using perftools::gputools::blas::kDefaultBlasGemv;
    using perftools::gputools::blas::kNoAlgorithm;
    using perftools::gputools::blas::ProfileResult;
    using perftools::gputools::blas::Transpose;
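    // m, n, k describe the row-major product: out is m x n and k is the
    // contracted dimension shared by op(a) and op(b).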
    Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
    const uint64 m = a.dim_size(1 - dim_pair[0].first);
    const uint64 k = a.dim_size(dim_pair[0].first);
    const uint64 n = b.dim_size(1 - dim_pair[0].second);
    bool transpose_a = dim_pair[0].first == 0;
    bool transpose_b = dim_pair[0].second == 1;
    auto blas_transpose_a = trans[transpose_a];
    auto blas_transpose_b = trans[transpose_b];

    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
                                a.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
                                b.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(out->template flat<T>().data(),
                                out->template flat<T>().size());
    auto alpha = static_cast<T>(1.0);
    auto beta = static_cast<T>(0.0);

    int device_id = stream->parent()->device_ordinal();
    DataType dtype = a.dtype();
    MatmulParameters matmul_parameters = {
        transpose_a, transpose_b, m, n, k, dtype, device_id,
    };
    AlgorithmConfig algorithm_config(kNoAlgorithm);

    ComputationType computation_type;
    bool compute_type_supported =
        GetCublasAutotuneComputationType(dtype, &computation_type);
    if (use_autotune && compute_type_supported && !algorithms->empty()) {
      ProfileResult best_result;
      // TODO(yangzihao): Unify this code with conv autotuning.
      if (!AutoTuneMatmul::GetInstance()->Find(matmul_parameters,
                                               &algorithm_config)) {
        ProfileResult profile_result;
        for (auto profile_algorithm : (*algorithms)) {
          // Cublas does
          // C = A x B
          // where A, B and C are assumed to be column-major.
          // We want the output to be row-major, so we compute
          // C' = B' x A' (' stands for transpose)
          bool cublas_launch_status =
              stream
                  ->ThenBlasGemmWithAlgorithm(
                      blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
                      transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
                      &c_ptr, n, computation_type, profile_algorithm,
                      &profile_result)
                  .ok();
          if (cublas_launch_status) {
            if (profile_result.is_valid()) {
              if (profile_result.elapsed_time_in_ms() <
                  best_result.elapsed_time_in_ms()) {
                best_result = profile_result;
              }
            }
          }
        }
        // Try BlasGemmWithProfiling
        bool cublas_launch_status =
            stream
                ->ThenBlasGemmWithProfiling(
                    blas_transpose_b, blas_transpose_a, n, m, k, 1.0, b_ptr,
                    transpose_b ? k : n, a_ptr, transpose_a ? m : k, 0.0,
                    &c_ptr, n, &profile_result)
                .ok();
        if (cublas_launch_status) {
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
          }
        }
        // Try BlasGemvWithProfiling
        if (ShouldUseGemv<T>(n)) {
          LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
                                     transpose_a ? m : k, transpose_a ? k : m,
                                     a_ptr, b_ptr, &c_ptr, &profile_result);
          if (profile_result.is_valid()) {
            if (profile_result.elapsed_time_in_ms() <
                best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
            }
          }
        }
      }
      // We make sure that each matmul parameter set only gets one pass of
      // autotune. If a best result is found, assign it to algorithm_type
      // and insert it into the autotune map. If all internal kernels of
      // cublasGemmEx() return invalid results, we add kNoAlgorithm to the
      // autotune map.
      if (best_result.is_valid()) {
        algorithm_config.set_algorithm(best_result.algorithm());
      }
      AutoTuneMatmul::GetInstance()->Insert(matmul_parameters,
                                            algorithm_config);
      if (algorithm_config.algorithm() != kNoAlgorithm &&
          algorithm_config.algorithm() != kDefaultBlasGemm &&
          algorithm_config.algorithm() != kDefaultBlasGemv) {
        bool cublas_launch_status =
            stream
                ->ThenBlasGemmWithAlgorithm(
                    blas_transpose_b, blas_transpose_a, n, m, k, alpha, b_ptr,
                    transpose_b ? k : n, a_ptr, transpose_a ? m : k, beta,
                    &c_ptr, n, computation_type, algorithm_config.algorithm(),
                    nullptr)
                .ok();
        if (!cublas_launch_status) {
          ctx->SetStatus(errors::Internal(
              "Blas GEMM with algorithm launch failed : a.shape=(",
              a.dim_size(0), ", ", a.dim_size(1), "), b.shape=(", b.dim_size(0),
              ", ", b.dim_size(1), "), m=", m, ", n=", n, ", k=", k));
        }
      }
    }
    // We fall back to plain BlasGemm() in any of the following cases:
    //  1) the use_autotune flag was not set;
    //  2) the compute type does not support autotune;
    //  3) no algorithm was found;
    //  4) all internal kernels in autotune returned invalid results.
    // We use plain BlasGemv() in either of the following cases:
    //  1) the use_autotune flag was not set, but LaunchBlasGemv is supported
    //     and n == 1;
    //  2) the use_autotune flag was set and autotuning picked BlasGemv(),
    //     setting algorithm_config.algorithm() to kDefaultBlasGemv.
    if (!use_autotune || !compute_type_supported || algorithms->empty() ||
        algorithm_config.algorithm() == kNoAlgorithm ||
        algorithm_config.algorithm() == kDefaultBlasGemm ||
        algorithm_config.algorithm() == kDefaultBlasGemv) {
      if (algorithm_config.algorithm() == kDefaultBlasGemv ||
          ShouldUseGemv<T>(n)) {
        // This is a matrix*vector multiply so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major.
        // TODO(yangzihao): Add Gemv as an autotuning option too.
        LaunchBlasGemv<T>::Compute(ctx, stream, !transpose_a,
                                   transpose_a ? m : k, transpose_a ? k : m,
                                   a_ptr, b_ptr, &c_ptr, nullptr);
      } else {
        // Use C' = B' x A' (' stands for transpose)
        bool blas_launch_status =
            stream
                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                               1.0f, b_ptr, transpose_b ? k : n, a_ptr,
                               transpose_a ? m : k, 0.0f, &c_ptr, n)
                .ok();
        if (!blas_launch_status) {
          ctx->SetStatus(errors::Internal(
              "Blas GEMM launch failed : a.shape=(", a.dim_size(0), ", ",
              a.dim_size(1), "), b.shape=(", b.dim_size(0), ", ", b.dim_size(1),
              "), m=", m, ", n=", n, ", k=", k));
        }
      }
    }
  }

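  // Queries cuBLAS (through StreamExecutor) once for the set of available
  // GEMM algorithms; algorithm_set_flag guards against repeating the query.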
  static void GetBlasGemmAlgorithm(OpKernelConstruction* ctx,
                                   std::vector<int64>* algorithms,
                                   bool* algorithm_set_flag) {
    if (*algorithm_set_flag == false) {
      auto* stream = ctx->device()->tensorflow_gpu_device_info()->stream;
      stream->parent()->GetBlasGemmAlgorithms(algorithms);
      *algorithm_set_flag = true;
    }
  }
};

#endif  // GOOGLE_CUDA

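// MatMulOp validates the input shapes, allocates the output, short-circuits
// empty inputs, and delegates the actual product to
// LaunchMatMul<Device, T, USE_CUBLAS>.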
template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
 public:
  explicit MatMulOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), algorithms_set_already_(false) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));

    LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
        ctx, &algorithms_, &algorithms_set_already_);
    use_autotune_ = MatmulAutotuneEnable();
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
                errors::InvalidArgument("In[0] is not a matrix"));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
                errors::InvalidArgument("In[1] is not a matrix"));
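    // dim_pair[0] names the axes that get contracted: for C = A * B that is
    // A's columns (axis 1) against B's rows (axis 0); a transpose flips the
    // corresponding index.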
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
            ", In[1]: ", b.shape().DebugString()));
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 || b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

    LaunchMatMul<Device, T, USE_CUBLAS>::launch(
        ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
  }

 private:
  std::vector<int64> algorithms_;
  bool algorithms_set_already_;
  bool use_autotune_;
  bool transpose_a_;
  bool transpose_b_;
};

namespace functor {

// Partial specialization MatMulFunctor<Device=CPUDevice, T>.
template <typename T>
struct MatMulFunctor<CPUDevice, T> {
  void operator()(
      const CPUDevice& d, typename MatMulTypes<T>::out_type out,
      typename MatMulTypes<T>::in_type in0,
      typename MatMulTypes<T>::in_type in1,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    MatMul<CPUDevice>(d, out, in0, in1, dim_pair);
  }
};

#ifdef TENSORFLOW_USE_SYCL
// Partial specialization MatMulFunctor<Device=SYCLDevice, T>.
template <typename T>
struct MatMulFunctor<SYCLDevice, T> {
  void operator()(
      const SYCLDevice& d, typename MatMulTypes<T>::out_type out,
      typename MatMulTypes<T>::in_type in0,
      typename MatMulTypes<T>::in_type in1,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    MatMul<SYCLDevice>(d, out, in0, in1, dim_pair);
  }
};
#endif  // TENSORFLOW_USE_SYCL

}  // end namespace functor

#define REGISTER_CPU_EIGEN(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);

#define REGISTER_CPU(T)                                             \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"),     \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
  REGISTER_CPU_EIGEN(T);
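
// REGISTER_CPU(T) registers both the default (unlabeled) CPU kernel and an
// "eigen"-labeled variant; for example, TF_CALL_float(REGISTER_CPU) below
// roughly expands to:
//   REGISTER_KERNEL_BUILDER(
//       Name("MatMul").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       MatMulOp<CPUDevice, float, false>);
//   REGISTER_KERNEL_BUILDER(
//       Name("MatMul").Device(DEVICE_CPU).TypeConstraint<float>("T").Label("eigen"),
//       MatMulOp<CPUDevice, float, false>);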

#define REGISTER_GPU(T)                                            \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"),    \
      MatMulOp<GPUDevice, T, true /* cublas, true by default */>); \
  REGISTER_KERNEL_BUILDER(Name("MatMul")                           \
                              .Device(DEVICE_GPU)                  \
                              .TypeConstraint<T>("T")              \
                              .Label("cublas"),                    \
                          MatMulOp<GPUDevice, T, true /* cublas */>)

#if defined(INTEL_MKL)
// MKL does not support half and int32 types for matrix multiplication, so
// register the kernel to use the default Eigen-based implementations for
// these types. Registration for the NO-LABEL version is in mkl_matmul_op.cc.
TF_CALL_float(REGISTER_CPU_EIGEN);
TF_CALL_double(REGISTER_CPU_EIGEN);
TF_CALL_half(REGISTER_CPU);

TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU_EIGEN);
TF_CALL_complex128(REGISTER_CPU_EIGEN);
#else
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);

TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
#endif

#if GOOGLE_CUDA
TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
#if CUDA_VERSION >= 7050
TF_CALL_half(REGISTER_GPU);
#endif
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(T)                                         \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("MatMul").Device(DEVICE_SYCL).TypeConstraint<T>("T"), \
      MatMulOp<SYCLDevice, T, false /* xxblas */>);              \
  REGISTER_KERNEL_BUILDER(Name("MatMul")                         \
                              .Device(DEVICE_SYCL)               \
                              .TypeConstraint<T>("T")            \
                              .Label("eigen"),                   \
                          MatMulOp<SYCLDevice, T, false /* xxblas */>)
TF_CALL_float(REGISTER_SYCL);
TF_CALL_double(REGISTER_SYCL);

#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow