/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace {

// Returns the pair of dimensions along which to contract a single matrix
// product, given the adjoint flags: we contract the columns of op(x) (dim 1,
// or dim 0 if x is adjointed) with the rows of op(y).
Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
  if (!adj_x) {
    if (!adj_y) {
      return Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
    } else {
      return Eigen::IndexPair<Eigen::DenseIndex>(1, 1);
    }
  } else {
    if (!adj_y) {
      return Eigen::IndexPair<Eigen::DenseIndex>(0, 0);
    } else {
      return Eigen::IndexPair<Eigen::DenseIndex>(0, 1);
    }
  }
}

// Parallel batch matmul kernel based on the multi-threaded tensor contraction
// in Eigen.
template <typename Scalar, bool IsComplex = true>
struct ParallelMatMulKernel {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    auto z = out->tensor<Scalar, 3>();
    z.device(d) = z.conjugate();
  }

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
                  int start, int limit) {
    static_assert(IsComplex, "Complex type expected.");
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    // We use the identities
    //   conj(a) * conj(b) = conj(a * b)
    //   conj(a) * b = conj(a * conj(b))
    // to halve the number of cases. The final conjugation of the result is
    // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
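    // Concretely, with ^T = transpose and ^H = conjugate transpose, the four
    // (adj_x, adj_y) cases reduce as follows:
    //   adj_x=0, adj_y=0:  z = x * y                    (no fixup)
    //   adj_x=0, adj_y=1:  z = x * conj(y)^T            (conjugate y)
    //   adj_x=1, adj_y=0:  z = conj(x^T * conj(y))      (conjugate y, result)
    //   adj_x=1, adj_y=1:  z = conj(x^T * y^T)          (conjugate result)
    // Only the conjugation of y is handled here; the caller conjugates the
    // result whenever adj_x is set.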
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    for (int i = start; i < limit; ++i) {
      auto x = Tx.template chip<0>(i);
      auto z = Tz.template chip<0>(i);
      if (adj_x != adj_y) {
        auto y = Ty.template chip<0>(i).conjugate();
        z.device(d) = x.contract(y, contract_pairs);
      } else {
        auto y = Ty.template chip<0>(i);
        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// The Eigen contraction kernel used here is very large and slow to compile,
// so we partially specialize ParallelMatMulKernel for real types to avoid all
// but one of the instantiations.
template <typename Scalar>
struct ParallelMatMulKernel<Scalar, false> {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {}

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
                  int start, int limit) {
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    for (int i = start; i < limit; ++i) {
      auto x = Tx.template chip<0>(i);
      auto y = Ty.template chip<0>(i);
      auto z = Tz.template chip<0>(i);
      z.device(d) = x.contract(y, contract_pairs);
    }
  }
};

// TODO(rmlarsen): Get rid of this when we have upstreamed improvements
// for matrix*vector and vector*matrix to Eigen's general matrix product.
template <typename Tx, typename Ty, typename Tz>
static void Multiply(bool adj_x, bool adj_y, Tx x, Ty y, Tz z) {
  if (!adj_x) {
    if (!adj_y) {
      z.noalias() = x * y;
    } else {
      z.noalias() = x * y.adjoint();
    }
  } else {
    if (!adj_y) {
      z.noalias() = x.adjoint() * y;
    } else {
      z.noalias() = x.adjoint() * y.adjoint();
    }
  }
}

// Sequential batch matmul kernel that calls the regular Eigen matmul.
// We prefer this over the tensor contraction because it performs
// better on vector-matrix and matrix-vector products.
template <typename Scalar>
struct SequentialMatMulKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }

  static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
                  bool adj_y, Tensor* out, int start, int limit) {
    for (int i = start; i < limit; ++i) {
      auto x = ConstTensorSliceToEigenMatrix(in_x, i);
      auto y = ConstTensorSliceToEigenMatrix(in_y, i);
      auto z = TensorSliceToEigenMatrix(out, i);
      // TODO(rmlarsen): Get rid of the special casing here when we have
      // upstreamed improvements for matrix*vector and vector*matrix to
      // Eigen's general matrix product.
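      // A dimension of size 1 after the requested adjoint means one operand
      // is effectively a vector, so we route those cases to Eigen's faster
      // matrix*vector and vector*matrix paths via the row/col views below.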
      if (!adj_x && x.rows() == 1) {
        Multiply(adj_x, adj_y, x.row(0), y, z);
      } else if (adj_x && x.cols() == 1) {
        Multiply(adj_x, adj_y, x.col(0), y, z);
      } else if (!adj_y && y.cols() == 1) {
        Multiply(adj_x, adj_y, x, y.col(0), z);
      } else if (adj_y && y.rows() == 1) {
        Multiply(adj_x, adj_y, x, y.row(0), z);
      } else {
        Multiply(adj_x, adj_y, x, y, z);
      }
    }
  }
};

}  // namespace

template <typename Device, typename Scalar>
struct LaunchBatchMatMul;

template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
    typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
        ParallelMatMulKernel;
    bool conjugate_result = false;

    // Number of matrix multiplies, i.e. the size of the batch.
    const int64 batch_size = in_x.dim_size(0);
    const int64 cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
    const int64 small_dim = std::min(
        std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2));
    const int64 kMaxCostOuterParallelism = 128 * 128 * 256;  // heuristic.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    if (small_dim > 1 &&
        (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
      // Parallelize over inner dims.
      // For large matrix products it is counter-productive to parallelize
      // over the batch dimension.
      ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, out, 0,
                                batch_size);
      conjugate_result = adj_x;
    } else {
      // Parallelize over outer dims. For small matrices and large batches, it
      // is counter-productive to parallelize the inner matrix multiplies.
      Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
            cost_per_unit,
            [&in_x, &in_y, adj_x, adj_y, out](int start, int limit) {
              SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y, out,
                                                  start, limit);
            });
    }
    if (conjugate_result) {
      // Since we used one of the identities
      //   conj(a) * conj(b) = conj(a * b)
      //   conj(a) * b = conj(a * conj(b))
      // above, we need to conjugate the final output. This is a
      // no-op for non-complex types.
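      // For example, for adj_x=1, adj_y=0 the contraction above computed
      // x^T * conj(y); conjugating it here yields the desired x^H * y.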
      ParallelMatMulKernel::Conjugate(context, out);
    }
  }
};

#if GOOGLE_CUDA

namespace {
template <typename T>
perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
  perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
  perftools::gputools::DeviceMemory<T> typed(wrapped);
  return typed;
}

class CublasScratchAllocator : public perftools::gputools::ScratchAllocator {
 public:
  using Stream = ::perftools::gputools::Stream;
  using DeviceMemoryBytes = ::perftools::gputools::DeviceMemory<uint8>;

  CublasScratchAllocator(OpKernelContext* context) : context_(context) {}

  int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; }

  perftools::gputools::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
      Stream* stream, int64 byte_size) override {
    Tensor temporary_memory;

    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
    if (!allocation_status.ok()) {
      return perftools::gputools::port::StatusOr<DeviceMemoryBytes>(
          DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
    }
    // Hold references to the allocated tensors until the allocator is
    // destroyed, so the backing memory stays live for the duration of the
    // BLAS call.
    allocated_tensors_.push_back(temporary_memory);
    return perftools::gputools::port::StatusOr<DeviceMemoryBytes>(
        DeviceMemoryBytes::MakeFromByteSize(
            temporary_memory.flat<uint8>().data(),
            temporary_memory.flat<uint8>().size()));
  }

 private:
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};
}  // namespace

template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
    constexpr perftools::gputools::blas::Transpose kTranspose =
        is_complex<Scalar>::value
            ? perftools::gputools::blas::Transpose::kConjugateTranspose
            : perftools::gputools::blas::Transpose::kTranspose;
    perftools::gputools::blas::Transpose trans[] = {
        perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
    const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
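    // m, k, and n describe the logical (post-adjoint) product: op(A) is
    // m x k and op(B) is k x n, so every output matrix in the batch is m x n.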
    const uint64 batch_size = in_x.dim_size(0);
    auto blas_transpose_a = trans[adj_x];
    auto blas_transpose_b = trans[adj_y];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(batch_size);
    b_device_memory.reserve(batch_size);
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();
    for (int64 i = 0; i < batch_size; ++i) {
      a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
      b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
      c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
      a_ptrs.push_back(&a_device_memory.back());
      b_ptrs.push_back(&b_device_memory.back());
      c_ptrs.push_back(&c_device_memory.back());
    }

    // Cublas does
    //   C = A x B
    // where A, B and C are assumed to be in column-major order.
    // We want the output to be in row-major order, so we can compute
    //   C' = B' x A', where ' stands for transpose (not adjoint).
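    // Concretely: an m x n row-major matrix and an n x m column-major matrix
    // share the same memory layout, so handing cuBLAS our row-major buffers
    // as column-major operands in swapped order (B' x A') produces C'
    // directly, which is exactly C stored row-major.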
    // TODO(yangzihao): Choose the best of the three strategies using autotune.
    if (batch_size == 1) {
      // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
      // overhead of the scratch allocator and the batch interface.
      if (n == 1 &&
          blas_transpose_b !=
              perftools::gputools::blas::Transpose::kConjugateTranspose &&
          blas_transpose_a !=
              perftools::gputools::blas::Transpose::kConjugateTranspose) {
        // This is a matrix*vector multiply so use GEMV to compute A * b.
        // Here we are multiplying in the natural order, so we have to flip
        // the transposition flag to compensate for the tensor being stored
        // row-major. Since GEMV doesn't provide a way to just conjugate an
        // argument, we have to defer those cases to GEMM below.
        auto gemv_trans_a =
            blas_transpose_a == perftools::gputools::blas::Transpose::kTranspose
                ? perftools::gputools::blas::Transpose::kNoTranspose
                : perftools::gputools::blas::Transpose::kTranspose;
        bool blas_launch_status =
            stream
                ->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
                               static_cast<Scalar>(1.0), *(a_ptrs[0]),
                               adj_x ? m : k, *(b_ptrs[0]), 1,
                               static_cast<Scalar>(0.0), c_ptrs[0], 1)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMV launch failed : a.shape=",
              in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      } else {
        bool blas_launch_status =
            stream
                ->ThenBlasGemm(blas_transpose_b, blas_transpose_a, n, m, k,
                               static_cast<Scalar>(1.0), *(b_ptrs[0]),
                               adj_y ? k : n, *(a_ptrs[0]), adj_x ? m : k,
                               static_cast<Scalar>(0.0), c_ptrs[0], n)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMM launch failed : a.shape=",
              in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k));
        }
      }
    } else {
      CublasScratchAllocator scratch_allocator(context);
      bool blas_launch_status =
          stream
              ->ThenBlasGemmBatchedWithScratch(
                  blas_transpose_b, blas_transpose_a, n, m, k,
                  static_cast<Scalar>(1.0), b_ptrs, adj_y ? k : n, a_ptrs,
                  adj_x ? m : k, static_cast<Scalar>(0.0), c_ptrs, n,
                  batch_size, &scratch_allocator)
              .ok();
      if (!blas_launch_status) {
        context->SetStatus(errors::Internal(
            "Blas xGEMMBatched launch failed : a.shape=",
            in_x.shape().DebugString(),
            ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
            ", k=", k, ", batch_size=", batch_size));
      }
    }
  }
};

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
template <typename Scalar>
struct ParallelMatMulKernelSYCL {
  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out,
                  int start, int limit) {
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x, adj_y);
    auto d = context->eigen_sycl_device();
    for (int i = start; i < limit; ++i) {
      auto x = Tx.template chip<0>(i);
      auto y = Ty.template chip<0>(i);
      auto z = Tz.template chip<0>(i);
      z.device(d) = x.contract(y, contract_pairs);
    }
  }
};

template <typename Scalar>
struct LaunchBatchMatMul<SYCLDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
    // Number of matrix multiplies, i.e. the size of the batch.
    const int64 batch_size = in_x.dim_size(0);
    ParallelMatMulKernelSYCL<Scalar>::Run(context, in_x, in_y, adj_x, adj_y,
                                          out, 0, batch_size);
  }
};
#endif  // TENSORFLOW_USE_SYCL

template <typename Device, typename Scalar>
class BatchMatMul : public OpKernel {
 public:
  explicit BatchMatMul(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
    OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
  }

  virtual ~BatchMatMul() {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);
    OP_REQUIRES(ctx, in0.dims() == in1.dims(),
                errors::InvalidArgument(
                    "In[0] and In[1] have different ndims: ",
                    in0.shape().DebugString(), " vs. ",
                    in1.shape().DebugString()));
    const int ndims = in0.dims();
    OP_REQUIRES(
        ctx, ndims >= 2,
        errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
    TensorShape out_shape;
    for (int i = 0; i < ndims - 2; ++i) {
      OP_REQUIRES(ctx, in0.dim_size(i) == in1.dim_size(i),
                  errors::InvalidArgument(
                      "In[0].dim(", i, ") and In[1].dim(", i,
                      ") must be the same: ", in0.shape().DebugString(), " vs ",
                      in1.shape().DebugString()));
      out_shape.AddDim(in0.dim_size(i));
    }
    auto n = (ndims == 2) ? 1 : out_shape.num_elements();
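    // Flatten all leading dimensions into a single batch dimension: e.g. a
    // [2, 3, 4, 5] x [2, 3, 5, 6] product is handled as a batch of n = 6
    // 4x5 * 5x6 matmuls, and the result is reshaped back to [2, 3, 4, 6].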
    auto d0 = in0.dim_size(ndims - 2);
    auto d1 = in0.dim_size(ndims - 1);
    Tensor in0_reshaped;
    CHECK(in0_reshaped.CopyFrom(in0, TensorShape({n, d0, d1})));
    auto d2 = in1.dim_size(ndims - 2);
    auto d3 = in1.dim_size(ndims - 1);
    Tensor in1_reshaped;
    CHECK(in1_reshaped.CopyFrom(in1, TensorShape({n, d2, d3})));
    if (adj_x_) std::swap(d0, d1);
    if (adj_y_) std::swap(d2, d3);
    OP_REQUIRES(ctx, d1 == d2,
                errors::InvalidArgument(
                    "In[0] and In[1] must have compatible inner dimensions: ",
                    d1, " vs. ", d2, ": ", in0.shape().DebugString(), " ",
                    in1.shape().DebugString(), " ", adj_x_, " ", adj_y_));
    out_shape.AddDim(d0);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    if (in0.NumElements() == 0 || in1.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Scalar> f;
      f(ctx->eigen_device<Device>(), out->flat<Scalar>());
      return;
    }
    Tensor out_reshaped;
    CHECK(out_reshaped.CopyFrom(*out, TensorShape({n, d0, d3})));
    LaunchBatchMatMul<Device, Scalar>::Launch(ctx, in0_reshaped, in1_reshaped,
                                              adj_x_, adj_y_, &out_reshaped);
  }

 private:
  bool adj_x_;
  bool adj_y_;
};

#define REGISTER_BATCH_MATMUL_CPU(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
      BatchMatMul<CPUDevice, TYPE>)

#define REGISTER_BATCH_MATMUL_GPU(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
      BatchMatMul<GPUDevice, TYPE>)

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_BATCH_MATMUL_SYCL(TYPE)                                 \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("BatchMatMul").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
      BatchMatMul<SYCLDevice, TYPE>)
#endif  // TENSORFLOW_USE_SYCL
}  // end namespace tensorflow