/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace {

Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
  if (!adj_x) {
    if (!adj_y) {
      return Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
    } else {
      return Eigen::IndexPair<Eigen::DenseIndex>(1, 1);
    }
  } else {
    if (!adj_y) {
      return Eigen::IndexPair<Eigen::DenseIndex>(0, 0);
    } else {
      return Eigen::IndexPair<Eigen::DenseIndex>(0, 1);
    }
  }
}
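// Illustrative example of the mapping above: with x slices of shape [m, k]
// and y slices of shape [k, n] and adj_x == adj_y == false, the pair (1, 0)
// contracts dim 1 of each x slice against dim 0 of each y slice, i.e. the
// standard matrix product over the shared k dimension. When adj_x is true the
// x slices are stored as [k, m], so dim 0 becomes the contraction dimension
// instead (and similarly for adj_y).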

// Parallel batch matmul kernel based on the multi-threaded tensor contraction
// in Eigen.
template <typename Scalar, bool IsComplex = true>
struct ParallelMatMulKernel {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    auto z = out->tensor<Scalar, 3>();
    z.device(d) = z.conjugate();
  }

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
                  int start, int limit) {
    static_assert(IsComplex, "Complex type expected.");
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    // We use the identities
    //   conj(a) * conj(b) = conj(a * b)
    //   conj(a) * b = conj(a * conj(b))
    // to halve the number of cases. The final conjugation of the result is
    // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    for (int i = start; i < limit; ++i) {
      auto x = Tx.template chip<0>(i);
      auto z = Tz.template chip<0>(i);
      if (adj_x != adj_y) {
        auto y = Ty.template chip<0>(i).conjugate();
        z.device(d) = x.contract(y, contract_pairs);
      } else {
        auto y = Ty.template chip<0>(i);
        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};
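// Worked example of the halving trick above: for adj_x == true and
// adj_y == false we want conj(x)^T * y. Using conj(a) * b = conj(a * conj(b)),
// this equals conj(x^T * conj(y)), so Run() contracts x with conj(y) (the
// adj_x != adj_y branch) and LaunchBatchMatMul<CPUDevice, Scalar>::Launch()
// conjugates the result afterwards, since it sets conjugate_result = adj_x.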

// The Eigen contraction kernel used here is very large and slow to compile,
// so we partially specialize ParallelMatMulKernel for real types to avoid all
// but one of the instantiations.
template <typename Scalar>
struct ParallelMatMulKernel<Scalar, false> {
  // Conjugation is a no-op for real-valued types.
  static void Conjugate(const OpKernelContext* context, Tensor* out) {}

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
                  int start, int limit) {
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    for (int i = start; i < limit; ++i) {
      auto x = Tx.template chip<0>(i);
      auto y = Ty.template chip<0>(i);
      auto z = Tz.template chip<0>(i);
      z.device(d) = x.contract(y, contract_pairs);
    }
  }
};

// TODO(rmlarsen): Get rid of this when we have upstreamed improvements
// for matrix*vector and vector*matrix to Eigen's general matrix product.
template <typename Tx, typename Ty, typename Tz>
static void Multiply(bool adj_x, bool adj_y, Tx x, Ty y, Tz z) {
  if (!adj_x) {
    if (!adj_y) {
      z.noalias() = x * y;
    } else {
      z.noalias() = x * y.adjoint();
    }
  } else {
    if (!adj_y) {
      z.noalias() = x.adjoint() * y;
    } else {
      z.noalias() = x.adjoint() * y.adjoint();
    }
  }
}
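// Note: SequentialMatMulKernel::Run() below passes plain Eigen expressions
// here (x.row(0), y.col(0), or full matrix maps), so when one operand is a
// single row or column these products reduce to matrix*vector or
// vector*matrix, which is the special casing the TODO above refers to.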

// Sequential batch matmul kernel that calls the regular Eigen matmul.
// We prefer this over the tensor contraction because it performs
// better on vector-matrix and matrix-vector products.
template <typename Scalar>
struct SequentialMatMulKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

  static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
                  bool adj_y, Tensor* out, int start, int limit) {
    for (int i = start; i < limit; ++i) {
      auto x = ConstTensorSliceToEigenMatrix(in_x, i);
      auto y = ConstTensorSliceToEigenMatrix(in_y, i);
      auto z = TensorSliceToEigenMatrix(out, i);
      // TODO(rmlarsen): Get rid of the special casing here when we have
      // upstreamed improvements for matrix*vector and vector*matrix to
      // Eigen's general matrix product.
      if (!adj_x && x.rows() == 1) {
        Multiply(adj_x, adj_y, x.row(0), y, z);
      } else if (adj_x && x.cols() == 1) {
        Multiply(adj_x, adj_y, x.col(0), y, z);
      } else if (!adj_y && y.cols() == 1) {
        Multiply(adj_x, adj_y, x, y.col(0), z);
      } else if (adj_y && y.rows() == 1) {
        Multiply(adj_x, adj_y, x, y.row(0), z);
      } else {
        Multiply(adj_x, adj_y, x, y, z);
      }
    }
  }
};

}  // namespace

template <typename Device, typename Scalar>
struct LaunchBatchMatMul;

template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
    typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
        ParallelMatMulKernel;
    bool conjugate_result = false;

    // Number of matrix multiplies i.e. size of the batch.
    const int64 batch_size = in_x.dim_size(0);
    const int64 cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
    const int64 small_dim = std::min(
        std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2));
    const int64 kMaxCostOuterParallelism = 128 * 128 * 256;  // heuristic.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
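    // Illustrative examples of the dispatch below: a single large product
    // (batch_size == 1), or a batch whose per-matmul cost exceeds the
    // 128*128*256 threshold, runs through Eigen's multi-threaded contraction,
    // which parallelizes inside each matmul. A batch of many small matrices
    // (e.g. 1000 slices of 8x8 matrices), or any product whose smallest
    // dimension is 1, is instead sharded across the batch dimension with one
    // sequential matmul per slice.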
    if (small_dim > 1 &&
        (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
      // Parallelize over inner dims.
      // For large matrix products it is counter-productive to parallelize
      // over the batch dimension.
      ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, out, 0,
                                batch_size);
      conjugate_result = adj_x;
    } else {
      // Parallelize over outer dims. For small matrices and large batches, it
      // is counter-productive to parallelize the inner matrix multiplies.
      Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
            cost_per_unit,
            [&in_x, &in_y, adj_x, adj_y, out](int start, int limit) {
              SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y, out,
                                                  start, limit);
            });
    }
    if (conjugate_result) {
      // Since we used one of the identities
      //   conj(a) * conj(b) = conj(a * b)
      //   conj(a) * b = conj(a * conj(b))
      // above, we need to conjugate the final output. This is a
      // no-op for non-complex types.
      ParallelMatMulKernel::Conjugate(context, out);
    }
  }
};

#if GOOGLE_CUDA

namespace {
template <typename T>
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
  perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
  perftools::gputools::DeviceMemory<T> typed(wrapped);
  return typed;
}
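// Note: AsDeviceMemory only wraps an existing device pointer in a typed
// DeviceMemory handle for the StreamExecutor BLAS interface; it does not
// allocate memory or take ownership of the buffer.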

class CublasScratchAllocator : public perftools::gputools::ScratchAllocator {
 public:
  using Stream = ::perftools::gputools::Stream;
  using DeviceMemoryBytes = ::perftools::gputools::DeviceMemory<uint8>;

  CublasScratchAllocator(OpKernelContext* context) : context_(context) {}

  int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; }

  perftools::gputools::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
      Stream* stream, int64 byte_size) override {
    Tensor temporary_memory;

    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
    if (!allocation_status.ok()) {
      return perftools::gputools::port::StatusOr<DeviceMemoryBytes>(
          DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
    }
    // Hold references to the allocated tensors until the allocator is
    // destroyed.
    allocated_tensors_.push_back(temporary_memory);
    return perftools::gputools::port::StatusOr<DeviceMemoryBytes>(
        DeviceMemoryBytes::MakeFromByteSize(
            temporary_memory.flat<uint8>().data(),
            temporary_memory.flat<uint8>().size()));
  }

 private:
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};
}  // namespace

template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
    constexpr perftools::gputools::blas::Transpose kTranspose =
        is_complex<Scalar>::value
            ? perftools::gputools::blas::Transpose::kConjugateTranspose
            : perftools::gputools::blas::Transpose::kTranspose;
    perftools::gputools::blas::Transpose trans[] = {
        perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
    const uint64 batch_size = in_x.dim_size(0);
    auto blas_transpose_a = trans[adj_x];
    auto blas_transpose_b = trans[adj_y];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(batch_size);
    b_device_memory.reserve(batch_size);
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();
    for (int64 i = 0; i < batch_size; ++i) {
      a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
      a_ptrs.push_back(&a_device_memory.back());
      b_ptrs.push_back(&b_device_memory.back());
      c_ptrs.push_back(&c_device_memory.back());
    }
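    // Each batch slice is a contiguous row-major matrix, so slice i of A
    // starts at offset i * m * k elements from the base pointer (m * k equals
    // dim_size(1) * dim_size(2) regardless of adj_x), and similarly k * n for
    // B and m * n for C.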

    // Cublas does
    // C = A x B
    // where A, B and C are assumed to be in column major.
    // We want the output to be in row-major, so we can compute
    // C' = B' x A', where ' stands for transpose (not adjoint).
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
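    // In other words: a row-major m x n matrix C has the same memory layout as
    // a column-major n x m matrix C'. Since C' = (A x B)' = B' x A', asking
    // cuBLAS for B' x A' with the operand order and leading dimensions swapped
    // (as done below) writes exactly the row-major C we want.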
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      if (n == 1 &&
          blas_transpose_b !=
              perftools::gputools::blas::Transpose::kConjugateTranspose &&
          blas_transpose_a !=
              perftools::gputools::blas::Transpose::kConjugateTranspose) {
        // This is a matrix*vector multiply so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major. Since GEMV doesn't provide a way to just conjugate an
        // argument, we have to defer those cases to GEMM below.
        auto gemv_trans_a =
            blas_transpose_a == perftools::gputools::blas::Transpose::kTranspose
                ? perftools::gputools::blas::Transpose::kNoTranspose
                : perftools::gputools::blas::Transpose::kTranspose;
        bool blas_launch_status =
            stream
                ->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
                               static_cast<Scalar>(1.0), *(a_ptrs[0]),
                               adj_x ? m : k, *(b_ptrs[0]), 1,
                               static_cast<Scalar>(0.0), c_ptrs[0], 1)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMV launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      } else {
        bool blas_launch_status =
            stream
                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                               static_cast<Scalar>(1.0), *(b_ptrs[0]),
                               adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
                               static_cast<Scalar>(0.0), c_ptrs[0], n)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMM launch failed : a.shape=", in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      }
    } else {
      CublasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Scalar>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
                  adj_x ? m : k, static_cast<Scalar>(0.0), c_ptrs, n,
                  batch_size, &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
template <typename Scalar>
struct ParallelMatMulKernelSYCL {
  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
                  int start, int limit) {
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    auto d = context->eigen_sycl_device();
    for (int i = start; i < limit; ++i) {
      auto x = Tx.template chip<0>(i);
      auto y = Ty.template chip<0>(i);
      auto z = Tz.template chip<0>(i);
      z.device(d) = x.contract(y, contract_pairs);
    }
  }
};

template <typename Scalar>
struct LaunchBatchMatMul<SYCLDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
    // Number of matrix multiplies i.e. size of the batch.
    const int64 batch_size = in_x.dim_size(0);
    ParallelMatMulKernelSYCL<Scalar>::Run(context, in_x, in_y, adj_x, adj_y,
                                          out, 0, batch_size);
  }
};
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename Scalar>
class BatchMatMul : public OpKernel {
 public:
  explicit BatchMatMul(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
    OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
  }

  virtual ~BatchMatMul() {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    OP_REQUIRES(ctx, in0.dims() == in1.dims(),
                errors::InvalidArgument("In[0] and In[1] have different ndims: ",
                                        in0.shape().DebugString(), " vs. ",
                                        in1.shape().DebugString()));
    const int ndims = in0.dims();
    OP_REQUIRES(
        ctx, ndims >= 2,
        errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
    TensorShape out_shape;
    for (int i = 0; i < ndims - 2; ++i) {
      OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
                  errors::InvalidArgument(
                      "In[0].dim(", i, ") and In[1].dim(", i,
                      ") must be the same: ", in0.shape().DebugString(), " vs ",
                      in1.shape().DebugString()));
      out_shape.AddDim(in0.dim_size(i));
    }
    auto n = (ndims == 2) ? 1 : out_shape.num_elements();
    auto d0 = in0.dim_size(ndims - 2);
    auto d1 = in0.dim_size(ndims - 1);
    Tensor in0_reshaped;
    CHECK(in0_reshaped.CopyFrom(in0, TensorShape({n, d0, d1})));
    auto d2 = in1.dim_size(ndims - 2);
    auto d3 = in1.dim_size(ndims - 1);
    Tensor in1_reshaped;
    CHECK(in1_reshaped.CopyFrom(in1, TensorShape({n, d2, d3})));
    if (adj_x_) std::swap(d0, d1);
    if (adj_y_) std::swap(d2, d3);
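    // Illustrative example: for in0 with shape [2, 3, 4, 5], in1 with shape
    // [2, 3, 5, 6], and adj_x_ == adj_y_ == false, the leading dims give
    // n = 6, the reshaped inputs are [6, 4, 5] and [6, 5, 6], d1 == d2 == 5,
    // and the output shape below becomes [2, 3, 4, 6] (reshaped to [6, 4, 6]).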
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] mismatch In[1] shape: ", d1, " vs. ", d2, ": ",
                    in0.shape().DebugString(), " ", in1.shape().DebugString(),
                    " ", adj_x_, " ", adj_y_));
    out_shape.AddDim(d0);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    if (in0.NumElements() == 0 || in1.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Scalar> f;
      f(ctx->eigen_device<Device>(), out->flat<Scalar>());
      return;
    }
    Tensor out_reshaped;
    CHECK(out_reshaped.CopyFrom(*out, TensorShape({n, d0, d3})));
    LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
                                              adj_x_, adj_y_, &out_reshaped);
  }

 private:
  bool adj_x_;
  bool adj_y_;
};

#define REGISTER_BATCH_MATMUL_CPU(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
      BatchMatMul<CPUDevice, TYPE>)

#define REGISTER_BATCH_MATMUL_GPU(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
      BatchMatMul<GPUDevice, TYPE>)

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_BATCH_MATMUL_SYCL(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
      BatchMatMul<SYCLDevice, TYPE>)
#endif  // TENSORFLOW_USE_SYCL
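
// Note: the registration macros above are expected to be invoked with
// concrete element types elsewhere, e.g. (illustrative only):
//   REGISTER_BATCH_MATMUL_CPU(float);
//   REGISTER_BATCH_MATMUL_GPU(float);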
}  // end namespace tensorflow