1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Exposes the family of BLAS routines as pre-canned high performance calls for 17 // use in conjunction with the StreamExecutor abstraction. 18 // 19 // Note that this interface is optionally supported by platforms; see 20 // StreamExecutor::SupportsBlas() for details. 21 // 22 // This abstraction makes it simple to entrain BLAS operations on GPU data into 23 // a Stream -- users typically will not use this API directly, but will use the 24 // Stream builder methods to entrain these operations "under the hood". For 25 // example: 26 // 27 // DeviceMemory<float> x = stream_exec->AllocateArray<float>(1024); 28 // DeviceMemory<float> y = stream_exec->AllocateArray<float>(1024); 29 // // ... populate x and y ... 30 // Stream stream{stream_exec}; 31 // stream 32 // .Init() 33 // .ThenBlasAxpy(1024, 5.5, x, 1, &y, 1); 34 // SE_CHECK_OK(stream.BlockHostUntilDone()); 35 // 36 // By using stream operations in this manner the user can easily intermix custom 37 // kernel launches (via StreamExecutor::ThenLaunch()) with these pre-canned BLAS 38 // routines. 
39 40 #ifndef TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ 41 #define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ 42 43 #include <complex> 44 #include <vector> 45 46 #include "tensorflow/stream_executor/host_or_device_scalar.h" 47 #include "tensorflow/stream_executor/lib/array_slice.h" 48 #include "tensorflow/stream_executor/platform/port.h" 49 50 namespace Eigen { 51 struct half; 52 } // namespace Eigen 53 54 namespace stream_executor { 55 56 class Stream; 57 class ScratchAllocator; 58 59 template <typename ElemT> 60 class DeviceMemory; 61 62 namespace blas { 63 64 // Specifies whether the input matrix will be transposed or 65 // transposed+conjugated before any BLAS operations. 66 enum class Transpose { kNoTranspose, kTranspose, kConjugateTranspose }; 67 68 // Returns a name for t. 69 string TransposeString(Transpose t); 70 71 // Specifies whether the upper or lower triangular part of a 72 // symmetric/Hermitian matrix is used. 73 enum class UpperLower { kUpper, kLower }; 74 75 // Returns a name for ul. 76 string UpperLowerString(UpperLower ul); 77 78 // Specifies whether a matrix is unit triangular. 79 enum class Diagonal { kUnit, kNonUnit }; 80 81 // Returns a name for d. 82 string DiagonalString(Diagonal d); 83 84 // Specifies whether a Hermitian matrix appears on the left or right in 85 // operation. 86 enum class Side { kLeft, kRight }; 87 88 // Returns a name for s. 89 string SideString(Side s); 90 91 // Type with which intermediate computations of a blas routine are performed. 92 // 93 // Some blas calls can perform computations with a type that's different than 94 // the type of their inputs/outputs. This lets you e.g. multiply two matricies 95 // of int8s using float32s to store the matmul's intermediate values. 96 enum class ComputationType { 97 kF16, // 16-bit floating-point 98 kF32, // 32-bit floating-point 99 kF64, // 64-bit floating-point 100 kI32, // 32-bit integer 101 kComplexF32, // Complex number comprised of two f32s. 
102 kComplexF64, // Complex number comprised of two f64s. 103 }; 104 105 // Converts a ComputationType to a string. 106 string ComputationTypeString(ComputationType ty); 107 108 std::ostream &operator<<(std::ostream &os, ComputationType ty); 109 110 // Opaque identifier for an "algorithm" used by a blas routine. This functions 111 // as a hint to the blas library. 112 typedef int64 AlgorithmType; 113 constexpr AlgorithmType kDefaultAlgorithm = -1; 114 constexpr AlgorithmType kDefaultBlasGemm = -2; 115 constexpr AlgorithmType kDefaultBlasGemv = -3; 116 constexpr AlgorithmType kNoAlgorithm = -4; 117 118 // blas uses -1 to represent the default algorithm. This happens to match up 119 // with the CUBLAS_GEMM_DFALT constant, so cuda_blas.cc is using static_cast 120 // to convert from AlgorithmType to cublasGemmAlgo_t, and uses a static_assert 121 // to ensure that this assumption does not break. 122 // If another blas implementation uses a different value for the default 123 // algorithm, then it needs to convert kDefaultGemmAlgo to that value 124 // (e.g. via a function called ToWhateverGemmAlgo). 125 constexpr AlgorithmType kDefaultGemmAlgo = -1; 126 127 // Describes the result of a performance experiment, usually timing the speed of 128 // a particular AlgorithmType. 129 // 130 // If the call we were benchmarking failed (a common occurrence; not all 131 // algorithms are valid for all calls), is_valid() will be false. 
132 class ProfileResult { 133 public: 134 bool is_valid() const { return is_valid_; } 135 void set_is_valid(bool val) { is_valid_ = val; } 136 AlgorithmType algorithm() const { return algorithm_; } 137 void set_algorithm(AlgorithmType val) { algorithm_ = val; } 138 float elapsed_time_in_ms() const { return elapsed_time_in_ms_; } 139 void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; } 140 141 private: 142 bool is_valid_ = false; 143 AlgorithmType algorithm_ = kDefaultAlgorithm; 144 float elapsed_time_in_ms_ = std::numeric_limits<float>::max(); 145 }; 146 147 class AlgorithmConfig { 148 public: 149 AlgorithmConfig() : algorithm_(kDefaultAlgorithm) {} 150 explicit AlgorithmConfig(AlgorithmType algorithm) : algorithm_(algorithm) {} 151 AlgorithmType algorithm() const { return algorithm_; } 152 void set_algorithm(AlgorithmType val) { algorithm_ = val; } 153 bool operator==(const AlgorithmConfig &other) const { 154 return this->algorithm_ == other.algorithm_; 155 } 156 bool operator!=(const AlgorithmConfig &other) const { 157 return !(*this == other); 158 } 159 string ToString() const; 160 161 private: 162 AlgorithmType algorithm_; 163 }; 164 165 // BLAS support interface -- this can be derived from a GPU executor when the 166 // underlying platform has an BLAS library implementation available. See 167 // StreamExecutor::AsBlas(). 168 // 169 // Thread-hostile: CUDA associates a CUDA-context with a particular thread in 170 // the system. Any operation that a user attempts to perform by enqueueing BLAS 171 // operations on a thread not-associated with the CUDA-context has unknown 172 // behavior at the current time; see b/13176597 173 class BlasSupport { 174 public: 175 virtual ~BlasSupport() {} 176 177 // Computes the sum of magnitudes of the vector elements. 178 // result <- |Re x(1)| + |Im x(1)| + |Re x(2)| + |Im x(2)|+ ... + |Re x(n)| 179 // + |Im x(n)|. 180 // Note that Im x(i) = 0 for real types float/double. 
181 virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, 182 const DeviceMemory<float> &x, int incx, 183 DeviceMemory<float> *result) = 0; 184 virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, 185 const DeviceMemory<double> &x, int incx, 186 DeviceMemory<double> *result) = 0; 187 virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, 188 const DeviceMemory<std::complex<float>> &x, int incx, 189 DeviceMemory<float> *result) = 0; 190 virtual bool DoBlasAsum(Stream *stream, uint64 elem_count, 191 const DeviceMemory<std::complex<double>> &x, int incx, 192 DeviceMemory<double> *result) = 0; 193 194 // Performs a BLAS y <- ax+y operation. 195 virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, 196 const DeviceMemory<float> &x, int incx, 197 DeviceMemory<float> *y, int incy) = 0; 198 virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, 199 const DeviceMemory<double> &x, int incx, 200 DeviceMemory<double> *y, int incy) = 0; 201 virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, 202 std::complex<float> alpha, 203 const DeviceMemory<std::complex<float>> &x, int incx, 204 DeviceMemory<std::complex<float>> *y, int incy) = 0; 205 virtual bool DoBlasAxpy(Stream *stream, uint64 elem_count, 206 std::complex<double> alpha, 207 const DeviceMemory<std::complex<double>> &x, int incx, 208 DeviceMemory<std::complex<double>> *y, int incy) = 0; 209 210 // Copies vector to another vector: y <- x. 
211 virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, 212 const DeviceMemory<float> &x, int incx, 213 DeviceMemory<float> *y, int incy) = 0; 214 virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, 215 const DeviceMemory<double> &x, int incx, 216 DeviceMemory<double> *y, int incy) = 0; 217 virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, 218 const DeviceMemory<std::complex<float>> &x, int incx, 219 DeviceMemory<std::complex<float>> *y, int incy) = 0; 220 virtual bool DoBlasCopy(Stream *stream, uint64 elem_count, 221 const DeviceMemory<std::complex<double>> &x, int incx, 222 DeviceMemory<std::complex<double>> *y, int incy) = 0; 223 224 // Performs a BLAS dot product result <- x . y. 225 virtual bool DoBlasDot(Stream *stream, uint64 elem_count, 226 const DeviceMemory<float> &x, int incx, 227 const DeviceMemory<float> &y, int incy, 228 DeviceMemory<float> *result) = 0; 229 virtual bool DoBlasDot(Stream *stream, uint64 elem_count, 230 const DeviceMemory<double> &x, int incx, 231 const DeviceMemory<double> &y, int incy, 232 DeviceMemory<double> *result) = 0; 233 234 // Performs a BLAS dot product result <- conj(x) . y for complex types. 235 virtual bool DoBlasDotc(Stream *stream, uint64 elem_count, 236 const DeviceMemory<std::complex<float>> &x, int incx, 237 const DeviceMemory<std::complex<float>> &y, int incy, 238 DeviceMemory<std::complex<float>> *result) = 0; 239 virtual bool DoBlasDotc(Stream *stream, uint64 elem_count, 240 const DeviceMemory<std::complex<double>> &x, int incx, 241 const DeviceMemory<std::complex<double>> &y, int incy, 242 DeviceMemory<std::complex<double>> *result) = 0; 243 244 // Performs a BLAS dot product result <- x . y for complex types. Note that 245 // x is unconjugated in this routine. 
246 virtual bool DoBlasDotu(Stream *stream, uint64 elem_count, 247 const DeviceMemory<std::complex<float>> &x, int incx, 248 const DeviceMemory<std::complex<float>> &y, int incy, 249 DeviceMemory<std::complex<float>> *result) = 0; 250 virtual bool DoBlasDotu(Stream *stream, uint64 elem_count, 251 const DeviceMemory<std::complex<double>> &x, int incx, 252 const DeviceMemory<std::complex<double>> &y, int incy, 253 DeviceMemory<std::complex<double>> *result) = 0; 254 255 // Computes the Euclidean norm of a vector: result <- ||x||. 256 // See the following link for more information of Euclidean norm: 257 // http://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm 258 virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, 259 const DeviceMemory<float> &x, int incx, 260 DeviceMemory<float> *result) = 0; 261 virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, 262 const DeviceMemory<double> &x, int incx, 263 DeviceMemory<double> *result) = 0; 264 virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, 265 const DeviceMemory<std::complex<float>> &x, int incx, 266 DeviceMemory<float> *result) = 0; 267 virtual bool DoBlasNrm2(Stream *stream, uint64 elem_count, 268 const DeviceMemory<std::complex<double>> &x, int incx, 269 DeviceMemory<double> *result) = 0; 270 271 // Performs rotation of points in the plane: 272 // x(i) = c*x(i) + s*y(i) 273 // y(i) = c*y(i) - s*x(i). 
274 virtual bool DoBlasRot(Stream *stream, uint64 elem_count, 275 DeviceMemory<float> *x, int incx, 276 DeviceMemory<float> *y, int incy, float c, 277 float s) = 0; 278 virtual bool DoBlasRot(Stream *stream, uint64 elem_count, 279 DeviceMemory<double> *x, int incx, 280 DeviceMemory<double> *y, int incy, double c, 281 double s) = 0; 282 virtual bool DoBlasRot(Stream *stream, uint64 elem_count, 283 DeviceMemory<std::complex<float>> *x, int incx, 284 DeviceMemory<std::complex<float>> *y, int incy, 285 float c, float s) = 0; 286 virtual bool DoBlasRot(Stream *stream, uint64 elem_count, 287 DeviceMemory<std::complex<double>> *x, int incx, 288 DeviceMemory<std::complex<double>> *y, int incy, 289 double c, double s) = 0; 290 291 // Computes the parameters for a Givens rotation. 292 // Given the Cartesian coordinates (a, b) of a point, these routines return 293 // the parameters c, s, r, and z associated with the Givens rotation. The 294 // parameters c and s define a unitary matrix such that: 295 // 296 // | c s |.| a | = | r | 297 // | -s c | | b | | 0 | 298 // 299 // The parameter z is defined such that if |a| > |b|, z is s; otherwise if 300 // c is not 0 z is 1/c; otherwise z is 1. 301 virtual bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a, 302 DeviceMemory<float> *b, DeviceMemory<float> *c, 303 DeviceMemory<float> *s) = 0; 304 virtual bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a, 305 DeviceMemory<double> *b, DeviceMemory<double> *c, 306 DeviceMemory<double> *s) = 0; 307 virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a, 308 DeviceMemory<std::complex<float>> *b, 309 DeviceMemory<float> *c, 310 DeviceMemory<std::complex<float>> *s) = 0; 311 virtual bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a, 312 DeviceMemory<std::complex<double>> *b, 313 DeviceMemory<double> *c, 314 DeviceMemory<std::complex<double>> *s) = 0; 315 316 // Performs modified Givens rotation of points in the plane. 
317 // Given two vectors x and y, each vector element of these vectors is replaced 318 // as follows: 319 // 320 // | x(i) | = H | x(i) | 321 // | y(i) | | y(i) | 322 // 323 // for i=1 to n, where H is a modified Givens transformation matrix whose 324 // values are stored in the param[1] through param[4] array. 325 // For more information please Google this routine. 326 virtual bool DoBlasRotm(Stream *stream, uint64 elem_count, 327 DeviceMemory<float> *x, int incx, 328 DeviceMemory<float> *y, int incy, 329 const DeviceMemory<float> ¶m) = 0; 330 virtual bool DoBlasRotm(Stream *stream, uint64 elem_count, 331 DeviceMemory<double> *x, int incx, 332 DeviceMemory<double> *y, int incy, 333 const DeviceMemory<double> ¶m) = 0; 334 335 // Computes the parameters for a modified Givens rotation. 336 // Given Cartesian coordinates (x1, y1) of an input vector, these routines 337 // compute the components of a modified Givens transformation matrix H that 338 // zeros the y-component of the resulting vector: 339 // 340 // | x1 | = H | x1 * sqrt(d1) | 341 // | 0 | | y1 * sqrt(d1) | 342 // 343 // For more information please Google this routine. 344 virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1, 345 DeviceMemory<float> *d2, DeviceMemory<float> *x1, 346 const DeviceMemory<float> &y1, 347 DeviceMemory<float> *param) = 0; 348 virtual bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1, 349 DeviceMemory<double> *d2, DeviceMemory<double> *x1, 350 const DeviceMemory<double> &y1, 351 DeviceMemory<double> *param) = 0; 352 353 // Computes the product of a vector by a scalar: x <- a*x. 
354 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, 355 DeviceMemory<float> *x, int incx) = 0; 356 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, 357 DeviceMemory<double> *x, int incx) = 0; 358 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, 359 DeviceMemory<std::complex<float>> *x, int incx) = 0; 360 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, 361 DeviceMemory<std::complex<double>> *x, int incx) = 0; 362 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, 363 std::complex<float> alpha, 364 DeviceMemory<std::complex<float>> *x, int incx) = 0; 365 virtual bool DoBlasScal(Stream *stream, uint64 elem_count, 366 std::complex<double> alpha, 367 DeviceMemory<std::complex<double>> *x, int incx) = 0; 368 369 // Swaps a vector with another vector. 370 virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, 371 DeviceMemory<float> *x, int incx, 372 DeviceMemory<float> *y, int incy) = 0; 373 virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, 374 DeviceMemory<double> *x, int incx, 375 DeviceMemory<double> *y, int incy) = 0; 376 virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, 377 DeviceMemory<std::complex<float>> *x, int incx, 378 DeviceMemory<std::complex<float>> *y, int incy) = 0; 379 virtual bool DoBlasSwap(Stream *stream, uint64 elem_count, 380 DeviceMemory<std::complex<double>> *x, int incx, 381 DeviceMemory<std::complex<double>> *y, int incy) = 0; 382 383 // Finds the index of the element with maximum absolute value. 
384 virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, 385 const DeviceMemory<float> &x, int incx, 386 DeviceMemory<int> *result) = 0; 387 virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, 388 const DeviceMemory<double> &x, int incx, 389 DeviceMemory<int> *result) = 0; 390 virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, 391 const DeviceMemory<std::complex<float>> &x, int incx, 392 DeviceMemory<int> *result) = 0; 393 virtual bool DoBlasIamax(Stream *stream, uint64 elem_count, 394 const DeviceMemory<std::complex<double>> &x, 395 int incx, DeviceMemory<int> *result) = 0; 396 397 // Finds the index of the element with minimum absolute value. 398 virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, 399 const DeviceMemory<float> &x, int incx, 400 DeviceMemory<int> *result) = 0; 401 virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, 402 const DeviceMemory<double> &x, int incx, 403 DeviceMemory<int> *result) = 0; 404 virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, 405 const DeviceMemory<std::complex<float>> &x, int incx, 406 DeviceMemory<int> *result) = 0; 407 virtual bool DoBlasIamin(Stream *stream, uint64 elem_count, 408 const DeviceMemory<std::complex<double>> &x, 409 int incx, DeviceMemory<int> *result) = 0; 410 411 // Computes a matrix-vector product using a general band matrix: 412 // 413 // y <- alpha * a * x + beta * y, 414 // or 415 // y <- alpha * a' * x + beta * y, 416 // or 417 // y <- alpha * conj(a') * x + beta * y, 418 // 419 // alpha and beta are scalars; a is an m-by-n general band matrix, with kl 420 // sub-diagonals and ku super-diagonals; x is a vector with 421 // n(trans==kNoTranspose)/m(otherwise) elements; 422 // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements. 
423 virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 424 uint64 n, uint64 kl, uint64 ku, float alpha, 425 const DeviceMemory<float> &a, int lda, 426 const DeviceMemory<float> &x, int incx, float beta, 427 DeviceMemory<float> *y, int incy) = 0; 428 virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 429 uint64 n, uint64 kl, uint64 ku, double alpha, 430 const DeviceMemory<double> &a, int lda, 431 const DeviceMemory<double> &x, int incx, double beta, 432 DeviceMemory<double> *y, int incy) = 0; 433 virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 434 uint64 n, uint64 kl, uint64 ku, 435 std::complex<float> alpha, 436 const DeviceMemory<std::complex<float>> &a, int lda, 437 const DeviceMemory<std::complex<float>> &x, int incx, 438 std::complex<float> beta, 439 DeviceMemory<std::complex<float>> *y, int incy) = 0; 440 virtual bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, 441 uint64 n, uint64 kl, uint64 ku, 442 std::complex<double> alpha, 443 const DeviceMemory<std::complex<double>> &a, int lda, 444 const DeviceMemory<std::complex<double>> &x, int incx, 445 std::complex<double> beta, 446 DeviceMemory<std::complex<double>> *y, int incy) = 0; 447 448 // Computes a matrix-vector product using a general matrix. 449 // 450 // y <- alpha * a * x + beta * y, 451 // or 452 // y <- alpha * a' * x + beta * y, 453 // or 454 // y <- alpha * conj(a') * x + beta * y, 455 // 456 // alpha and beta are scalars; a is an m-by-n general matrix; x is a vector 457 // with n(trans==kNoTranspose)/m(otherwise) elements; 458 // y is a vector with m(trans==kNoTranspose)/n(otherwise) elements. 
459 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 460 uint64 n, float alpha, const DeviceMemory<float> &a, 461 int lda, const DeviceMemory<float> &x, int incx, 462 float beta, DeviceMemory<float> *y, int incy) = 0; 463 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 464 uint64 n, double alpha, const DeviceMemory<double> &a, 465 int lda, const DeviceMemory<double> &x, int incx, 466 double beta, DeviceMemory<double> *y, int incy) = 0; 467 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 468 uint64 n, std::complex<float> alpha, 469 const DeviceMemory<std::complex<float>> &a, int lda, 470 const DeviceMemory<std::complex<float>> &x, int incx, 471 std::complex<float> beta, 472 DeviceMemory<std::complex<float>> *y, int incy) = 0; 473 virtual bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, 474 uint64 n, std::complex<double> alpha, 475 const DeviceMemory<std::complex<double>> &a, int lda, 476 const DeviceMemory<std::complex<double>> &x, int incx, 477 std::complex<double> beta, 478 DeviceMemory<std::complex<double>> *y, int incy) = 0; 479 480 virtual bool DoBlasGemvWithProfiling( 481 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, 482 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, 483 int incx, float beta, DeviceMemory<float> *y, int incy, 484 ProfileResult *output_profile_result) = 0; 485 virtual bool DoBlasGemvWithProfiling( 486 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, 487 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, 488 int incx, double beta, DeviceMemory<double> *y, int incy, 489 ProfileResult *output_profile_result) = 0; 490 virtual bool DoBlasGemvWithProfiling( 491 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, 492 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, 493 int lda, const DeviceMemory<std::complex<float>> &x, int incx, 494 
std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy, 495 ProfileResult *output_profile_result) = 0; 496 virtual bool DoBlasGemvWithProfiling( 497 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, 498 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, 499 int lda, const DeviceMemory<std::complex<double>> &x, int incx, 500 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, 501 int incy, ProfileResult *output_profile_result) = 0; 502 503 // Performs a rank-1 update of a general matrix. 504 // 505 // a <- alpha * x * y' + a, 506 // 507 // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is 508 // an m-by-n general matrix. 509 virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, 510 const DeviceMemory<float> &x, int incx, 511 const DeviceMemory<float> &y, int incy, 512 DeviceMemory<float> *a, int lda) = 0; 513 virtual bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha, 514 const DeviceMemory<double> &x, int incx, 515 const DeviceMemory<double> &y, int incy, 516 DeviceMemory<double> *a, int lda) = 0; 517 518 // Performs a rank-1 update (conjugated) of a general matrix. 519 // 520 // a <- alpha * x * conj(y') + a, 521 // 522 // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is 523 // an m-by-n general matrix. 524 virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, 525 std::complex<float> alpha, 526 const DeviceMemory<std::complex<float>> &x, int incx, 527 const DeviceMemory<std::complex<float>> &y, int incy, 528 DeviceMemory<std::complex<float>> *a, int lda) = 0; 529 virtual bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, 530 std::complex<double> alpha, 531 const DeviceMemory<std::complex<double>> &x, int incx, 532 const DeviceMemory<std::complex<double>> &y, int incy, 533 DeviceMemory<std::complex<double>> *a, int lda) = 0; 534 535 // Performs a rank-1 update (unconjugated) of a general matrix. 
536 // 537 // a <- alpha * x * y' + a, 538 // 539 // alpha is a scalar; x is an m-element vector; y is an n-element vector; a is 540 // an m-by-n general matrix. 541 virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, 542 std::complex<float> alpha, 543 const DeviceMemory<std::complex<float>> &x, int incx, 544 const DeviceMemory<std::complex<float>> &y, int incy, 545 DeviceMemory<std::complex<float>> *a, int lda) = 0; 546 virtual bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, 547 std::complex<double> alpha, 548 const DeviceMemory<std::complex<double>> &x, int incx, 549 const DeviceMemory<std::complex<double>> &y, int incy, 550 DeviceMemory<std::complex<double>> *a, int lda) = 0; 551 552 // Computes a matrix-vector product using a Hermitian band matrix. 553 // 554 // y <- alpha * a * x + beta * y, 555 // 556 // alpha and beta are scalars; a is an n-by-n Hermitian band matrix, with k 557 // super-diagonals; x and y are n-element vectors. 558 virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, 559 uint64 k, std::complex<float> alpha, 560 const DeviceMemory<std::complex<float>> &a, int lda, 561 const DeviceMemory<std::complex<float>> &x, int incx, 562 std::complex<float> beta, 563 DeviceMemory<std::complex<float>> *y, int incy) = 0; 564 virtual bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, 565 uint64 k, std::complex<double> alpha, 566 const DeviceMemory<std::complex<double>> &a, int lda, 567 const DeviceMemory<std::complex<double>> &x, int incx, 568 std::complex<double> beta, 569 DeviceMemory<std::complex<double>> *y, int incy) = 0; 570 571 // Computes a matrix-vector product using a Hermitian matrix. 572 // 573 // y <- alpha * a * x + beta * y, 574 // 575 // alpha and beta are scalars; a is an n-by-n Hermitian matrix; x and y are 576 // n-element vectors. 
577 virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, 578 std::complex<float> alpha, 579 const DeviceMemory<std::complex<float>> &a, int lda, 580 const DeviceMemory<std::complex<float>> &x, int incx, 581 std::complex<float> beta, 582 DeviceMemory<std::complex<float>> *y, int incy) = 0; 583 virtual bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, 584 std::complex<double> alpha, 585 const DeviceMemory<std::complex<double>> &a, int lda, 586 const DeviceMemory<std::complex<double>> &x, int incx, 587 std::complex<double> beta, 588 DeviceMemory<std::complex<double>> *y, int incy) = 0; 589 590 // Performs a rank-1 update of a Hermitian matrix. 591 // 592 // a <- alpha * x * conj(x') + a, 593 // 594 // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian 595 // matrix. 596 virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, 597 float alpha, 598 const DeviceMemory<std::complex<float>> &x, int incx, 599 DeviceMemory<std::complex<float>> *a, int lda) = 0; 600 virtual bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, 601 double alpha, 602 const DeviceMemory<std::complex<double>> &x, int incx, 603 DeviceMemory<std::complex<double>> *a, int lda) = 0; 604 605 // Performs a rank-2 update of a Hermitian matrix. 606 // 607 // a <- alpha * x * conj(x') + conj(alpha) * y * conj(x') + a, 608 // 609 // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian 610 // matrix. 
611 virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, 612 std::complex<float> alpha, 613 const DeviceMemory<std::complex<float>> &x, int incx, 614 const DeviceMemory<std::complex<float>> &y, int incy, 615 DeviceMemory<std::complex<float>> *a, int lda) = 0; 616 virtual bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, 617 std::complex<double> alpha, 618 const DeviceMemory<std::complex<double>> &x, int incx, 619 const DeviceMemory<std::complex<double>> &y, int incy, 620 DeviceMemory<std::complex<double>> *a, int lda) = 0; 621 622 // Computes a matrix-vector product using a Hermitian packed matrix. 623 // 624 // y <- alpha * a * x + beta * y, 625 // 626 // alpha and beta are scalars; a is an n-by-n Hermitian matrix, supplied in 627 // packed form; x and y are n-element vectors. 628 virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, 629 std::complex<float> alpha, 630 const DeviceMemory<std::complex<float>> &ap, 631 const DeviceMemory<std::complex<float>> &x, int incx, 632 std::complex<float> beta, 633 DeviceMemory<std::complex<float>> *y, int incy) = 0; 634 virtual bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, 635 std::complex<double> alpha, 636 const DeviceMemory<std::complex<double>> &ap, 637 const DeviceMemory<std::complex<double>> &x, int incx, 638 std::complex<double> beta, 639 DeviceMemory<std::complex<double>> *y, int incy) = 0; 640 641 // Performs a rank-1 update of a Hermitian packed matrix. 642 // 643 // a <- alpha * x * conj(x') + a, 644 // 645 // alpha is a scalar; x is an n-element vector; a is an n-by-n Hermitian 646 // matrix, supplied in packed form. 
647 virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, 648 float alpha, 649 const DeviceMemory<std::complex<float>> &x, int incx, 650 DeviceMemory<std::complex<float>> *ap) = 0; 651 virtual bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, 652 double alpha, 653 const DeviceMemory<std::complex<double>> &x, int incx, 654 DeviceMemory<std::complex<double>> *ap) = 0; 655 656 // Performs a rank-2 update of a Hermitian packed matrix. 657 // 658 // a <- alpha * x * conj(x') + conj(alpha) * y * conj(x') + a, 659 // 660 // alpha is a scalar; x and y are n-element vectors; a is an n-by-n Hermitian 661 // matrix, supplied in packed form. 662 virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, 663 std::complex<float> alpha, 664 const DeviceMemory<std::complex<float>> &x, int incx, 665 const DeviceMemory<std::complex<float>> &y, int incy, 666 DeviceMemory<std::complex<float>> *ap) = 0; 667 virtual bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, 668 std::complex<double> alpha, 669 const DeviceMemory<std::complex<double>> &x, int incx, 670 const DeviceMemory<std::complex<double>> &y, int incy, 671 DeviceMemory<std::complex<double>> *ap) = 0; 672 673 // Computes a matrix-vector product using a symmetric band matrix. 674 // 675 // y <- alpha * a * x + beta * y, 676 // 677 // alpha and beta are scalars; a is an n-by-n symmetric band matrix, with k 678 // super-diagonals; x and y are n-element vectors. 
679 virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, 680 uint64 k, float alpha, const DeviceMemory<float> &a, 681 int lda, const DeviceMemory<float> &x, int incx, 682 float beta, DeviceMemory<float> *y, int incy) = 0; 683 virtual bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, 684 uint64 k, double alpha, const DeviceMemory<double> &a, 685 int lda, const DeviceMemory<double> &x, int incx, 686 double beta, DeviceMemory<double> *y, int incy) = 0; 687 688 // Computes a matrix-vector product using a symmetric packed matrix. 689 // 690 // y <- alpha * a * x + beta * y, 691 // 692 // alpha and beta are scalars; a is an n-by-n symmetric matrix, supplied in 693 // packed form; x and y are n-element vectors. 694 virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, 695 float alpha, const DeviceMemory<float> &ap, 696 const DeviceMemory<float> &x, int incx, float beta, 697 DeviceMemory<float> *y, int incy) = 0; 698 virtual bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, 699 double alpha, const DeviceMemory<double> &ap, 700 const DeviceMemory<double> &x, int incx, double beta, 701 DeviceMemory<double> *y, int incy) = 0; 702 703 // Performs a rank-1 update of a symmetric packed matrix. 704 // 705 // a <- alpha * x * x' + a, 706 // 707 // alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric 708 // matrix, supplied in packed form. 709 virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, 710 float alpha, const DeviceMemory<float> &x, int incx, 711 DeviceMemory<float> *ap) = 0; 712 virtual bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, 713 double alpha, const DeviceMemory<double> &x, int incx, 714 DeviceMemory<double> *ap) = 0; 715 716 // Performs a rank-2 update of a symmetric packed matrix. 
//
//     a <- alpha * x * y' + alpha * y * x' + a,
//
// alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
// matrix, supplied in packed form.
virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        float alpha, const DeviceMemory<float> &x, int incx,
                        const DeviceMemory<float> &y, int incy,
                        DeviceMemory<float> *ap) = 0;
virtual bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        double alpha, const DeviceMemory<double> &x, int incx,
                        const DeviceMemory<double> &y, int incy,
                        DeviceMemory<double> *ap) = 0;

// Computes a matrix-vector product for a symmetric matrix.
//
//     y <- alpha * a * x + beta * y,
//
// alpha and beta are scalars; a is an n-by-n symmetric matrix; x and y are
// n-element vectors.
//
// Declared for float and double element types.
virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        float alpha, const DeviceMemory<float> &a, int lda,
                        const DeviceMemory<float> &x, int incx, float beta,
                        DeviceMemory<float> *y, int incy) = 0;
virtual bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n,
                        double alpha, const DeviceMemory<double> &a, int lda,
                        const DeviceMemory<double> &x, int incx, double beta,
                        DeviceMemory<double> *y, int incy) = 0;

// Performs a rank-1 update of a symmetric matrix.
//
//     a <- alpha * x * x' + a,
//
// alpha is a scalar; x is an n-element vector; a is an n-by-n symmetric
// matrix.
//
// Declared for float and double element types.
virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                       float alpha, const DeviceMemory<float> &x, int incx,
                       DeviceMemory<float> *a, int lda) = 0;
virtual bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n,
                       double alpha, const DeviceMemory<double> &x, int incx,
                       DeviceMemory<double> *a, int lda) = 0;

// Performs a rank-2 update of a symmetric matrix.
//
//     a <- alpha * x * y' + alpha * y * x' + a,
//
// alpha is a scalar; x and y are n-element vectors; a is an n-by-n symmetric
// matrix.
virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        float alpha, const DeviceMemory<float> &x, int incx,
                        const DeviceMemory<float> &y, int incy,
                        DeviceMemory<float> *a, int lda) = 0;
virtual bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n,
                        double alpha, const DeviceMemory<double> &x, int incx,
                        const DeviceMemory<double> &y, int incy,
                        DeviceMemory<double> *a, int lda) = 0;

// Computes a matrix-vector product using a triangular band matrix.
//
//     x <- a * x,
// or
//     x <- a' * x,
// or
//     x <- conj(a') * x,
//
// a is an n-by-n unit, or non-unit, upper or lower triangular band matrix,
// with k+1 diagonals; x is an n-element vector.
//
// Declared for float, double, std::complex<float> and std::complex<double>
// element types.
virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *x, int incx) = 0;
virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<std::complex<float>> &a,
                        int lda, DeviceMemory<std::complex<float>> *x,
                        int incx) = 0;
virtual bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<std::complex<double>> &a,
                        int lda, DeviceMemory<std::complex<double>> *x,
                        int incx) = 0;

// Solves a system of linear equations whose coefficients are in a triangular
// band matrix as below:
//
//     a * x = b,
// or
//     a' * x = b,
// or
//     conj(a') * x = b,
//
// b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
// lower triangular band matrix, with k+1 diagonals.
//
// NOTE(review): the declaration has no separate b argument; presumably x
// holds b on entry and the solution on return, as in reference BLAS --
// confirm against the backend implementations.
virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *x, int incx) = 0;
virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<std::complex<float>> &a,
                        int lda, DeviceMemory<std::complex<float>> *x,
                        int incx) = 0;
virtual bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        uint64 k, const DeviceMemory<std::complex<double>> &a,
                        int lda, DeviceMemory<std::complex<double>> *x,
                        int incx) = 0;

// Computes a matrix-vector product using a triangular packed matrix.
//
//     x <- a * x,
// or
//     x <- a' * x,
// or
//     x <- conj(a') * x,
//
// a is an n-by-n unit, or non-unit, upper or lower triangular matrix,
// supplied in packed form; x is an n-element vector.
virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                        int incx) = 0;
virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<double> &ap,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<float>> &ap,
                        DeviceMemory<std::complex<float>> *x, int incx) = 0;
virtual bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<double>> &ap,
                        DeviceMemory<std::complex<double>> *x, int incx) = 0;

// Solves a system of linear equations whose coefficients are in a triangular
// packed matrix as below:
//
//     a * x = b,
// or
//     a' * x = b,
// or
//     conj(a') * x = b,
//
// b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
// lower triangular matrix, supplied in packed form.
//
// NOTE(review): the declaration has no separate b argument; presumably x
// holds b on entry and the solution on return, as in reference BLAS --
// confirm against the backend implementations.
virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<float> &ap, DeviceMemory<float> *x,
                        int incx) = 0;
virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<double> &ap,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<float>> &ap,
                        DeviceMemory<std::complex<float>> *x, int incx) = 0;
virtual bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<double>> &ap,
                        DeviceMemory<std::complex<double>> *x, int incx) = 0;

// Computes a matrix-vector product using a triangular matrix.
//
//     x <- a * x,
// or
//     x <- a' * x,
// or
//     x <- conj(a') * x,
//
// a is an n-by-n unit, or non-unit, upper or lower triangular matrix; x is an
// n-element vector.
virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *x, int incx) = 0;
virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        DeviceMemory<std::complex<float>> *x, int incx) = 0;
virtual bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        DeviceMemory<std::complex<double>> *x, int incx) = 0;

// Solves a system of linear equations whose coefficients are in a triangular
// matrix as below:
//
//     a * x = b,
// or
//     a' * x = b,
// or
//     conj(a') * x = b,
//
// b and x are n-element vectors; a is an n-by-n unit, or non-unit, upper or
// lower triangular matrix.
//
// NOTE(review): the declaration has no separate b argument; presumably x
// holds b on entry and the solution on return, as in reference BLAS --
// confirm against the backend implementations.
virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *x, int incx) = 0;
virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *x, int incx) = 0;
virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        DeviceMemory<std::complex<float>> *x, int incx) = 0;
virtual bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, blas::Diagonal diag, uint64 n,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        DeviceMemory<std::complex<double>> *x, int incx) = 0;

// Computes a matrix-matrix product with general matrices:
//
//     c <- alpha * op(a) * op(b) + beta * c,
//
// op(X) is one of op(X) = X, or op(X) = X', or op(X) = conj(X'); alpha and
// beta are scalars; a, b, and c are matrices; op(a) is an m-by-k matrix;
// op(b) is a k-by-n matrix; c is an m-by-n matrix.
//
// Note: The half interface uses float precision internally; the version
// that uses half precision internally is not yet supported. In the
// Eigen::half overload alpha and beta are therefore float. A batched
// Eigen::half overload is declared separately as DoBlasGemmBatched.
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                        blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                        float alpha, const DeviceMemory<Eigen::half> &a,
                        int lda, const DeviceMemory<Eigen::half> &b, int ldb,
                        float beta, DeviceMemory<Eigen::half> *c,
                        int ldc) = 0;
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                        blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                        float alpha, const DeviceMemory<float> &a, int lda,
                        const DeviceMemory<float> &b, int ldb, float beta,
                        DeviceMemory<float> *c, int ldc) = 0;
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                        blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                        double alpha, const DeviceMemory<double> &a, int lda,
                        const DeviceMemory<double> &b, int ldb, double beta,
                        DeviceMemory<double> *c, int ldc) = 0;
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                        blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc) = 0;
virtual bool DoBlasGemm(Stream *stream, blas::Transpose transa,
                        blas::Transpose transb, uint64 m, uint64 n, uint64 k,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc) = 0;

// Variants of DoBlasGemm that take the same arguments, in the same order,
// followed by a trailing ProfileResult *output_profile_result.
//
// NOTE(review): presumably output_profile_result receives profiling data for
// the call -- confirm against the backend implementations.
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
    int lda, const DeviceMemory<Eigen::half> &b, int ldb, float beta,
    DeviceMemory<Eigen::half> *c, int ldc,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
    const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
    int ldc, ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
    const DeviceMemory<double> &b, int ldb, double beta,
    DeviceMemory<double> *c, int ldc,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const DeviceMemory<std::complex<float>> &a, int lda,
    const DeviceMemory<std::complex<float>> &b, int ldb,
    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithProfiling(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const DeviceMemory<std::complex<double>> &a, int lda,
    const DeviceMemory<std::complex<double>> &b, int ldb,
    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
    ProfileResult *output_profile_result) = 0;

// Gets a list of supported algorithms for DoBlasGemmWithAlgorithm.
// The list is written to *out_algorithms.
virtual bool GetBlasGemmAlgorithms(
    std::vector<AlgorithmType> *out_algorithms) = 0;

// Like DoBlasGemm, but accepts an algorithm and a compute type.
//
// The compute type lets you say (e.g.) that the inputs and outputs are
// Eigen::halfs, but you want the internal computations to be done with
// float32 precision.
//
// Note the subtle difference in the version that accepts Eigen::half --
// alpha and beta have type const HostOrDeviceScalar<Eigen::half>&, i.e. they
// are half-typed, not float as in the Eigen::half overload of DoBlasGemm.
//
// If output_profile_result is not null, a failure here does not put the
// stream in a failure state.  Instead, success/failure is indicated by
// output_profile_result->is_valid().  This lets you use this function for
// choosing the best algorithm among many (some of which may fail) without
// creating a new Stream for each attempt.
//
// The first overload takes int8 inputs (a, b) with an int32 output (c) and
// int scalars; in the remaining overloads inputs, output and scalars share
// one element type.
virtual bool DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
    const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,
    int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c,
    int ldc, ComputationType computation_type, AlgorithmType algorithm,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
    const DeviceMemory<Eigen::half> &a, int lda,
    const DeviceMemory<Eigen::half> &b, int ldb,
    const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
    int ldc, ComputationType computation_type, AlgorithmType algorithm,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
    const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
    int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
    int ldc, ComputationType computation_type, AlgorithmType algorithm,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
    const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
    int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
    int ldc, ComputationType computation_type, AlgorithmType algorithm,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
    const DeviceMemory<std::complex<float>> &a, int lda,
    const DeviceMemory<std::complex<float>> &b, int ldb,
    const HostOrDeviceScalar<std::complex<float>> &beta,
    DeviceMemory<std::complex<float>> *c, int ldc,
    ComputationType computation_type, AlgorithmType algorithm,
    ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
    const DeviceMemory<std::complex<double>> &a, int lda,
    const DeviceMemory<std::complex<double>> &b, int ldb,
    const HostOrDeviceScalar<std::complex<double>> &beta,
    DeviceMemory<std::complex<double>> *c, int ldc,
    ComputationType computation_type, AlgorithmType algorithm,
    ProfileResult *output_profile_result) = 0;

// Computes a batch of matrix-matrix products with general matrices.
// This is a batched version of DoBlasGemm.
// The batched GEMM computes a matrix product for each input/output in a, b,
// and c, which contain batch_count DeviceMemory objects.
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha,
    const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb,
    float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c,
    int ldc, int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha,
    const port::ArraySlice<DeviceMemory<float> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta,
    const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha,
    const port::ArraySlice<DeviceMemory<double> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta,
    const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
    std::complex<float> beta,
    const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) = 0;
virtual bool DoBlasGemmBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
    std::complex<double> beta,
    const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
    int batch_count, ScratchAllocator *scratch_allocator) = 0;

// Batched gemm with strides instead of pointer arrays.
//
// a, b and c are each passed as a single DeviceMemory object, together with
// int64 stride_a, stride_b and stride_c.
// NOTE(review): presumably each stride is the element offset between
// consecutive matrices of the batch -- confirm against the backends.
virtual bool DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<Eigen::half> &a,
    int lda, int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb,
    int64 stride_b, float beta, DeviceMemory<Eigen::half> *c, int ldc,
    int64 stride_c, int batch_count) = 0;
virtual bool DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
    int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
    float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
    int batch_count) = 0;
virtual bool DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
    int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
    double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
    int batch_count) = 0;
virtual bool DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<float> alpha,
    const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
    const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
    std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
    int64 stride_c, int batch_count) = 0;
virtual bool DoBlasGemmStridedBatched(
    Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
    uint64 n, uint64 k, std::complex<double> alpha,
    const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
    const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
    std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
    int64 stride_c, int batch_count) = 0;

// Computes a matrix-matrix product where one input matrix is Hermitian:
//
//     c <- alpha * a * b + beta * c,
// or
//     c <- alpha * b * a + beta * c,
//
// alpha and beta are scalars; a is a Hermitian matrix; b and c are m-by-n
// matrices.
virtual bool DoBlasHemm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc) = 0;
virtual bool DoBlasHemm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc) = 0;

// Performs a Hermitian rank-k update.
//
//     c <- alpha * a * conj(a') + beta * c,
// or
//     c <- alpha * conj(a') * a + beta * c,
//
// alpha and beta are real scalars (float or double); c is an n-by-n Hermitian
// matrix; a is an n-by-k matrix in the first case and a k-by-n matrix in the
// second case.
virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        float alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        float beta, DeviceMemory<std::complex<float>> *c,
                        int ldc) = 0;
virtual bool DoBlasHerk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        double alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        double beta, DeviceMemory<std::complex<double>> *c,
                        int ldc) = 0;

// Performs a Hermitian rank-2k update.
//
//     c <- alpha * a * conj(b') + conj(alpha) * b * conj(a') + beta * c,
// or
//     c <- alpha * conj(b') * a + conj(alpha) * conj(a') * b + beta * c,
//
// alpha is a complex scalar and beta is a real scalar; c is an n-by-n
// Hermitian matrix; a and b are n-by-k matrices in the first case and k-by-n
// matrices in the second case.
//
// NOTE(review): reference BLAS (xHER2K) writes the second form as
// alpha * conj(a') * b + conj(alpha) * conj(b') * a, i.e. with the roles of
// alpha and conj(alpha) swapped relative to the line above -- confirm which
// convention the backends implement.
virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         std::complex<float> alpha,
                         const DeviceMemory<std::complex<float>> &a, int lda,
                         const DeviceMemory<std::complex<float>> &b, int ldb,
                         float beta, DeviceMemory<std::complex<float>> *c,
                         int ldc) = 0;
virtual bool DoBlasHer2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         std::complex<double> alpha,
                         const DeviceMemory<std::complex<double>> &a, int lda,
                         const DeviceMemory<std::complex<double>> &b, int ldb,
                         double beta, DeviceMemory<std::complex<double>> *c,
                         int ldc) = 0;

// Computes a matrix-matrix product where one input matrix is symmetric.
//
//     c <- alpha * a * b + beta * c,
// or
//     c <- alpha * b * a + beta * c,
//
// alpha and beta are scalars; a is a symmetric matrix; b and c are m-by-n
// matrices.
virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        float alpha, const DeviceMemory<float> &a, int lda,
                        const DeviceMemory<float> &b, int ldb, float beta,
                        DeviceMemory<float> *c, int ldc) = 0;
virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        double alpha, const DeviceMemory<double> &a, int lda,
                        const DeviceMemory<double> &b, int ldb, double beta,
                        DeviceMemory<double> *c, int ldc) = 0;
virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        const DeviceMemory<std::complex<float>> &b, int ldb,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc) = 0;
virtual bool DoBlasSymm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, uint64 m, uint64 n,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        const DeviceMemory<std::complex<double>> &b, int ldb,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc) = 0;

// Performs a symmetric rank-k update.
//
//     c <- alpha * a * a' + beta * c,
// or
//     c <- alpha * a' * a + beta * c,
//
// alpha and beta are scalars; c is an n-by-n symmetric matrix; a is an n-by-k
// matrix in the first case and a k-by-n matrix in the second case.
virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        float alpha, const DeviceMemory<float> &a, int lda,
                        float beta, DeviceMemory<float> *c, int ldc) = 0;
virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        double alpha, const DeviceMemory<double> &a, int lda,
                        double beta, DeviceMemory<double> *c, int ldc) = 0;
virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        std::complex<float> beta,
                        DeviceMemory<std::complex<float>> *c, int ldc) = 0;
virtual bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo,
                        blas::Transpose trans, uint64 n, uint64 k,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        std::complex<double> beta,
                        DeviceMemory<std::complex<double>> *c, int ldc) = 0;

// Performs a symmetric rank-2k update.
//
//     c <- alpha * a * b' + alpha * b * a' + beta * c,
// or
//     c <- alpha * b' * a + alpha * a' * b + beta * c,
//
// alpha and beta are scalars; c is an n-by-n symmetric matrix; a and b are
// n-by-k matrices in the first case and k-by-n matrices in the second case.
virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         float alpha, const DeviceMemory<float> &a, int lda,
                         const DeviceMemory<float> &b, int ldb, float beta,
                         DeviceMemory<float> *c, int ldc) = 0;
virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         double alpha, const DeviceMemory<double> &a, int lda,
                         const DeviceMemory<double> &b, int ldb, double beta,
                         DeviceMemory<double> *c, int ldc) = 0;
virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         std::complex<float> alpha,
                         const DeviceMemory<std::complex<float>> &a, int lda,
                         const DeviceMemory<std::complex<float>> &b, int ldb,
                         std::complex<float> beta,
                         DeviceMemory<std::complex<float>> *c, int ldc) = 0;
virtual bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo,
                         blas::Transpose trans, uint64 n, uint64 k,
                         std::complex<double> alpha,
                         const DeviceMemory<std::complex<double>> &a, int lda,
                         const DeviceMemory<std::complex<double>> &b, int ldb,
                         std::complex<double> beta,
                         DeviceMemory<std::complex<double>> *c, int ldc) = 0;

// Computes a matrix-matrix product where one input matrix is triangular.
//
//     b <- alpha * op(a) * b,
// or
//     b <- alpha * b * op(a),
//
// alpha is a scalar; b is an m-by-n matrix; a is a unit, or non-unit, upper
// or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a', or
// op(a) = conj(a').
//
// The result is written back into b, which is the only mutable matrix
// argument.
virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, blas::Transpose transa,
                        blas::Diagonal diag, uint64 m, uint64 n, float alpha,
                        const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *b, int ldb) = 0;
virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, blas::Transpose transa,
                        blas::Diagonal diag, uint64 m, uint64 n, double alpha,
                        const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *b, int ldb) = 0;
virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, blas::Transpose transa,
                        blas::Diagonal diag, uint64 m, uint64 n,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        DeviceMemory<std::complex<float>> *b, int ldb) = 0;
virtual bool DoBlasTrmm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, blas::Transpose transa,
                        blas::Diagonal diag, uint64 m, uint64 n,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        DeviceMemory<std::complex<double>> *b, int ldb) = 0;

// Solves a triangular matrix equation.
//
//     op(a) * x = alpha * b,
// or
//     x * op(a) = alpha * b,
//
// alpha is a scalar; x and b are m-by-n matrices; a is a unit, or non-unit,
// upper or lower triangular matrix; op(a) is one of op(a) = a, or op(a) = a',
// or op(a) = conj(a').
//
// NOTE(review): the declaration has no separate x argument; presumably the
// solution x is written over b, as in reference BLAS -- confirm against the
// backend implementations.
virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, blas::Transpose transa,
                        blas::Diagonal diag, uint64 m, uint64 n, float alpha,
                        const DeviceMemory<float> &a, int lda,
                        DeviceMemory<float> *b, int ldb) = 0;
virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, blas::Transpose transa,
                        blas::Diagonal diag, uint64 m, uint64 n, double alpha,
                        const DeviceMemory<double> &a, int lda,
                        DeviceMemory<double> *b, int ldb) = 0;
virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, blas::Transpose transa,
                        blas::Diagonal diag, uint64 m, uint64 n,
                        std::complex<float> alpha,
                        const DeviceMemory<std::complex<float>> &a, int lda,
                        DeviceMemory<std::complex<float>> *b, int ldb) = 0;
virtual bool DoBlasTrsm(Stream *stream, blas::Side side,
                        blas::UpperLower uplo, blas::Transpose transa,
                        blas::Diagonal diag, uint64 m, uint64 n,
                        std::complex<double> alpha,
                        const DeviceMemory<std::complex<double>> &a, int lda,
                        DeviceMemory<std::complex<double>> *b, int ldb) = 0;

protected:
// Protected: the interface is constructed only through a derived class.
BlasSupport() {}

private:
// NOTE(review): per the macro's name this disables copy construction and
// assignment; the macro itself is defined outside this file.
SE_DISALLOW_COPY_AND_ASSIGN(BlasSupport);
};

// Macro used to quickly declare overrides for abstract virtuals in the
// BlasSupport base class.
1394 #define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \ 1395 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1396 const DeviceMemory<float> &x, int incx, \ 1397 DeviceMemory<float> *result) override; \ 1398 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1399 const DeviceMemory<double> &x, int incx, \ 1400 DeviceMemory<double> *result) override; \ 1401 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1402 const DeviceMemory<std::complex<float>> &x, int incx, \ 1403 DeviceMemory<float> *result) override; \ 1404 bool DoBlasAsum(Stream *stream, uint64 elem_count, \ 1405 const DeviceMemory<std::complex<double>> &x, int incx, \ 1406 DeviceMemory<double> *result) override; \ 1407 bool DoBlasAxpy(Stream *stream, uint64 elem_count, float alpha, \ 1408 const DeviceMemory<float> &x, int incx, \ 1409 DeviceMemory<float> *y, int incy) override; \ 1410 bool DoBlasAxpy(Stream *stream, uint64 elem_count, double alpha, \ 1411 const DeviceMemory<double> &x, int incx, \ 1412 DeviceMemory<double> *y, int incy) override; \ 1413 bool DoBlasAxpy(Stream *stream, uint64 elem_count, \ 1414 std::complex<float> alpha, \ 1415 const DeviceMemory<std::complex<float>> &x, int incx, \ 1416 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1417 bool DoBlasAxpy(Stream *stream, uint64 elem_count, \ 1418 std::complex<double> alpha, \ 1419 const DeviceMemory<std::complex<double>> &x, int incx, \ 1420 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1421 bool DoBlasCopy(Stream *stream, uint64 elem_count, \ 1422 const DeviceMemory<float> &x, int incx, \ 1423 DeviceMemory<float> *y, int incy) override; \ 1424 bool DoBlasCopy(Stream *stream, uint64 elem_count, \ 1425 const DeviceMemory<double> &x, int incx, \ 1426 DeviceMemory<double> *y, int incy) override; \ 1427 bool DoBlasCopy(Stream *stream, uint64 elem_count, \ 1428 const DeviceMemory<std::complex<float>> &x, int incx, \ 1429 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1430 bool 
DoBlasCopy(Stream *stream, uint64 elem_count, \ 1431 const DeviceMemory<std::complex<double>> &x, int incx, \ 1432 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1433 bool DoBlasDot(Stream *stream, uint64 elem_count, \ 1434 const DeviceMemory<float> &x, int incx, \ 1435 const DeviceMemory<float> &y, int incy, \ 1436 DeviceMemory<float> *result) override; \ 1437 bool DoBlasDot(Stream *stream, uint64 elem_count, \ 1438 const DeviceMemory<double> &x, int incx, \ 1439 const DeviceMemory<double> &y, int incy, \ 1440 DeviceMemory<double> *result) override; \ 1441 bool DoBlasDotc(Stream *stream, uint64 elem_count, \ 1442 const DeviceMemory<std::complex<float>> &x, int incx, \ 1443 const DeviceMemory<std::complex<float>> &y, int incy, \ 1444 DeviceMemory<std::complex<float>> *result) override; \ 1445 bool DoBlasDotc(Stream *stream, uint64 elem_count, \ 1446 const DeviceMemory<std::complex<double>> &x, int incx, \ 1447 const DeviceMemory<std::complex<double>> &y, int incy, \ 1448 DeviceMemory<std::complex<double>> *result) override; \ 1449 bool DoBlasDotu(Stream *stream, uint64 elem_count, \ 1450 const DeviceMemory<std::complex<float>> &x, int incx, \ 1451 const DeviceMemory<std::complex<float>> &y, int incy, \ 1452 DeviceMemory<std::complex<float>> *result) override; \ 1453 bool DoBlasDotu(Stream *stream, uint64 elem_count, \ 1454 const DeviceMemory<std::complex<double>> &x, int incx, \ 1455 const DeviceMemory<std::complex<double>> &y, int incy, \ 1456 DeviceMemory<std::complex<double>> *result) override; \ 1457 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1458 const DeviceMemory<float> &x, int incx, \ 1459 DeviceMemory<float> *result) override; \ 1460 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1461 const DeviceMemory<double> &x, int incx, \ 1462 DeviceMemory<double> *result) override; \ 1463 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1464 const DeviceMemory<std::complex<float>> &x, int incx, \ 1465 DeviceMemory<float> *result) 
override; \ 1466 bool DoBlasNrm2(Stream *stream, uint64 elem_count, \ 1467 const DeviceMemory<std::complex<double>> &x, int incx, \ 1468 DeviceMemory<double> *result) override; \ 1469 bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \ 1470 int incx, DeviceMemory<float> *y, int incy, float c, float s) \ 1471 override; \ 1472 bool DoBlasRot(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \ 1473 int incx, DeviceMemory<double> *y, int incy, double c, \ 1474 double s) override; \ 1475 bool DoBlasRot(Stream *stream, uint64 elem_count, \ 1476 DeviceMemory<std::complex<float>> *x, int incx, \ 1477 DeviceMemory<std::complex<float>> *y, int incy, float c, \ 1478 float s) override; \ 1479 bool DoBlasRot(Stream *stream, uint64 elem_count, \ 1480 DeviceMemory<std::complex<double>> *x, int incx, \ 1481 DeviceMemory<std::complex<double>> *y, int incy, double c, \ 1482 double s) override; \ 1483 bool DoBlasRotg(Stream *stream, DeviceMemory<float> *a, \ 1484 DeviceMemory<float> *b, DeviceMemory<float> *c, \ 1485 DeviceMemory<float> *s) override; \ 1486 bool DoBlasRotg(Stream *stream, DeviceMemory<double> *a, \ 1487 DeviceMemory<double> *b, DeviceMemory<double> *c, \ 1488 DeviceMemory<double> *s) override; \ 1489 bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<float>> *a, \ 1490 DeviceMemory<std::complex<float>> *b, \ 1491 DeviceMemory<float> *c, \ 1492 DeviceMemory<std::complex<float>> *s) override; \ 1493 bool DoBlasRotg(Stream *stream, DeviceMemory<std::complex<double>> *a, \ 1494 DeviceMemory<std::complex<double>> *b, \ 1495 DeviceMemory<double> *c, \ 1496 DeviceMemory<std::complex<double>> *s) override; \ 1497 bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \ 1498 int incx, DeviceMemory<float> *y, int incy, \ 1499 const DeviceMemory<float> ¶m) override; \ 1500 bool DoBlasRotm(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \ 1501 int incx, DeviceMemory<double> *y, int incy, \ 1502 const 
DeviceMemory<double> ¶m) override; \ 1503 bool DoBlasRotmg(Stream *stream, DeviceMemory<float> *d1, \ 1504 DeviceMemory<float> *d2, DeviceMemory<float> *x1, \ 1505 const DeviceMemory<float> &y1, DeviceMemory<float> *param) \ 1506 override; \ 1507 bool DoBlasRotmg(Stream *stream, DeviceMemory<double> *d1, \ 1508 DeviceMemory<double> *d2, DeviceMemory<double> *x1, \ 1509 const DeviceMemory<double> &y1, \ 1510 DeviceMemory<double> *param) override; \ 1511 bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, \ 1512 DeviceMemory<float> *x, int incx) override; \ 1513 bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, \ 1514 DeviceMemory<double> *x, int incx) override; \ 1515 bool DoBlasScal(Stream *stream, uint64 elem_count, float alpha, \ 1516 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1517 bool DoBlasScal(Stream *stream, uint64 elem_count, double alpha, \ 1518 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1519 bool DoBlasScal(Stream *stream, uint64 elem_count, \ 1520 std::complex<float> alpha, \ 1521 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1522 bool DoBlasScal(Stream *stream, uint64 elem_count, \ 1523 std::complex<double> alpha, \ 1524 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1525 bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<float> *x, \ 1526 int incx, DeviceMemory<float> *y, int incy) override; \ 1527 bool DoBlasSwap(Stream *stream, uint64 elem_count, DeviceMemory<double> *x, \ 1528 int incx, DeviceMemory<double> *y, int incy) override; \ 1529 bool DoBlasSwap(Stream *stream, uint64 elem_count, \ 1530 DeviceMemory<std::complex<float>> *x, int incx, \ 1531 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1532 bool DoBlasSwap(Stream *stream, uint64 elem_count, \ 1533 DeviceMemory<std::complex<double>> *x, int incx, \ 1534 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1535 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1536 
const DeviceMemory<float> &x, int incx, \ 1537 DeviceMemory<int> *result) override; \ 1538 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1539 const DeviceMemory<double> &x, int incx, \ 1540 DeviceMemory<int> *result) override; \ 1541 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1542 const DeviceMemory<std::complex<float>> &x, int incx, \ 1543 DeviceMemory<int> *result) override; \ 1544 bool DoBlasIamax(Stream *stream, uint64 elem_count, \ 1545 const DeviceMemory<std::complex<double>> &x, int incx, \ 1546 DeviceMemory<int> *result) override; \ 1547 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1548 const DeviceMemory<float> &x, int incx, \ 1549 DeviceMemory<int> *result) override; \ 1550 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1551 const DeviceMemory<double> &x, int incx, \ 1552 DeviceMemory<int> *result) override; \ 1553 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1554 const DeviceMemory<std::complex<float>> &x, int incx, \ 1555 DeviceMemory<int> *result) override; \ 1556 bool DoBlasIamin(Stream *stream, uint64 elem_count, \ 1557 const DeviceMemory<std::complex<double>> &x, int incx, \ 1558 DeviceMemory<int> *result) override; \ 1559 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1560 uint64 kl, uint64 ku, float alpha, \ 1561 const DeviceMemory<float> &a, int lda, \ 1562 const DeviceMemory<float> &x, int incx, float beta, \ 1563 DeviceMemory<float> *y, int incy) override; \ 1564 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1565 uint64 kl, uint64 ku, double alpha, \ 1566 const DeviceMemory<double> &a, int lda, \ 1567 const DeviceMemory<double> &x, int incx, double beta, \ 1568 DeviceMemory<double> *y, int incy) override; \ 1569 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1570 uint64 kl, uint64 ku, std::complex<float> alpha, \ 1571 const DeviceMemory<std::complex<float>> &a, int lda, \ 1572 const 
DeviceMemory<std::complex<float>> &x, int incx, \ 1573 std::complex<float> beta, \ 1574 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1575 bool DoBlasGbmv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1576 uint64 kl, uint64 ku, std::complex<double> alpha, \ 1577 const DeviceMemory<std::complex<double>> &a, int lda, \ 1578 const DeviceMemory<std::complex<double>> &x, int incx, \ 1579 std::complex<double> beta, \ 1580 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1581 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1582 float alpha, const DeviceMemory<float> &a, int lda, \ 1583 const DeviceMemory<float> &x, int incx, float beta, \ 1584 DeviceMemory<float> *y, int incy) override; \ 1585 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1586 double alpha, const DeviceMemory<double> &a, int lda, \ 1587 const DeviceMemory<double> &x, int incx, double beta, \ 1588 DeviceMemory<double> *y, int incy) override; \ 1589 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1590 std::complex<float> alpha, \ 1591 const DeviceMemory<std::complex<float>> &a, int lda, \ 1592 const DeviceMemory<std::complex<float>> &x, int incx, \ 1593 std::complex<float> beta, \ 1594 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1595 bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1596 std::complex<double> alpha, \ 1597 const DeviceMemory<std::complex<double>> &a, int lda, \ 1598 const DeviceMemory<std::complex<double>> &x, int incx, \ 1599 std::complex<double> beta, \ 1600 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1601 bool DoBlasGemvWithProfiling( \ 1602 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, float alpha, \ 1603 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, \ 1604 int incx, float beta, DeviceMemory<float> *y, int incy, \ 1605 blas::ProfileResult *output_profile_result) 
override; \ 1606 bool DoBlasGemvWithProfiling( \ 1607 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, double alpha, \ 1608 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, \ 1609 int incx, double beta, DeviceMemory<double> *y, int incy, \ 1610 blas::ProfileResult *output_profile_result) override; \ 1611 bool DoBlasGemvWithProfiling( \ 1612 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1613 std::complex<float> alpha, const DeviceMemory<std::complex<float>> &a, \ 1614 int lda, const DeviceMemory<std::complex<float>> &x, int incx, \ 1615 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, \ 1616 int incy, blas::ProfileResult *output_profile_result) override; \ 1617 bool DoBlasGemvWithProfiling( \ 1618 Stream *stream, blas::Transpose trans, uint64 m, uint64 n, \ 1619 std::complex<double> alpha, const DeviceMemory<std::complex<double>> &a, \ 1620 int lda, const DeviceMemory<std::complex<double>> &x, int incx, \ 1621 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, \ 1622 int incy, blas::ProfileResult *output_profile_result) override; \ 1623 bool DoBlasGer(Stream *stream, uint64 m, uint64 n, float alpha, \ 1624 const DeviceMemory<float> &x, int incx, \ 1625 const DeviceMemory<float> &y, int incy, \ 1626 DeviceMemory<float> *a, int lda) override; \ 1627 bool DoBlasGer(Stream *stream, uint64 m, uint64 n, double alpha, \ 1628 const DeviceMemory<double> &x, int incx, \ 1629 const DeviceMemory<double> &y, int incy, \ 1630 DeviceMemory<double> *a, int lda) override; \ 1631 bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, \ 1632 std::complex<float> alpha, \ 1633 const DeviceMemory<std::complex<float>> &x, int incx, \ 1634 const DeviceMemory<std::complex<float>> &y, int incy, \ 1635 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1636 bool DoBlasGerc(Stream *stream, uint64 m, uint64 n, \ 1637 std::complex<double> alpha, \ 1638 const DeviceMemory<std::complex<double>> &x, int incx, \ 1639 
const DeviceMemory<std::complex<double>> &y, int incy, \ 1640 DeviceMemory<std::complex<double>> *a, int lda) override; \ 1641 bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, \ 1642 std::complex<float> alpha, \ 1643 const DeviceMemory<std::complex<float>> &x, int incx, \ 1644 const DeviceMemory<std::complex<float>> &y, int incy, \ 1645 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1646 bool DoBlasGeru(Stream *stream, uint64 m, uint64 n, \ 1647 std::complex<double> alpha, \ 1648 const DeviceMemory<std::complex<double>> &x, int incx, \ 1649 const DeviceMemory<std::complex<double>> &y, int incy, \ 1650 DeviceMemory<std::complex<double>> *a, int lda) override; \ 1651 bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1652 std::complex<float> alpha, \ 1653 const DeviceMemory<std::complex<float>> &a, int lda, \ 1654 const DeviceMemory<std::complex<float>> &x, int incx, \ 1655 std::complex<float> beta, \ 1656 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1657 bool DoBlasHbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1658 std::complex<double> alpha, \ 1659 const DeviceMemory<std::complex<double>> &a, int lda, \ 1660 const DeviceMemory<std::complex<double>> &x, int incx, \ 1661 std::complex<double> beta, \ 1662 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1663 bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1664 std::complex<float> alpha, \ 1665 const DeviceMemory<std::complex<float>> &a, int lda, \ 1666 const DeviceMemory<std::complex<float>> &x, int incx, \ 1667 std::complex<float> beta, \ 1668 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1669 bool DoBlasHemv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1670 std::complex<double> alpha, \ 1671 const DeviceMemory<std::complex<double>> &a, int lda, \ 1672 const DeviceMemory<std::complex<double>> &x, int incx, \ 1673 std::complex<double> beta, \ 1674 DeviceMemory<std::complex<double>> *y, int 
incy) override; \ 1675 bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1676 const DeviceMemory<std::complex<float>> &x, int incx, \ 1677 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1678 bool DoBlasHer(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1679 double alpha, const DeviceMemory<std::complex<double>> &x, \ 1680 int incx, DeviceMemory<std::complex<double>> *a, int lda) \ 1681 override; \ 1682 bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1683 std::complex<float> alpha, \ 1684 const DeviceMemory<std::complex<float>> &x, int incx, \ 1685 const DeviceMemory<std::complex<float>> &y, int incy, \ 1686 DeviceMemory<std::complex<float>> *a, int lda) override; \ 1687 bool DoBlasHer2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1688 std::complex<double> alpha, \ 1689 const DeviceMemory<std::complex<double>> &x, int incx, \ 1690 const DeviceMemory<std::complex<double>> &y, int incy, \ 1691 DeviceMemory<std::complex<double>> *a, int lda) override; \ 1692 bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1693 std::complex<float> alpha, \ 1694 const DeviceMemory<std::complex<float>> &ap, \ 1695 const DeviceMemory<std::complex<float>> &x, int incx, \ 1696 std::complex<float> beta, \ 1697 DeviceMemory<std::complex<float>> *y, int incy) override; \ 1698 bool DoBlasHpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1699 std::complex<double> alpha, \ 1700 const DeviceMemory<std::complex<double>> &ap, \ 1701 const DeviceMemory<std::complex<double>> &x, int incx, \ 1702 std::complex<double> beta, \ 1703 DeviceMemory<std::complex<double>> *y, int incy) override; \ 1704 bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1705 const DeviceMemory<std::complex<float>> &x, int incx, \ 1706 DeviceMemory<std::complex<float>> *ap) override; \ 1707 bool DoBlasHpr(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1708 double alpha, const 
DeviceMemory<std::complex<double>> &x, \ 1709 int incx, DeviceMemory<std::complex<double>> *ap) override; \ 1710 bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1711 std::complex<float> alpha, \ 1712 const DeviceMemory<std::complex<float>> &x, int incx, \ 1713 const DeviceMemory<std::complex<float>> &y, int incy, \ 1714 DeviceMemory<std::complex<float>> *ap) override; \ 1715 bool DoBlasHpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1716 std::complex<double> alpha, \ 1717 const DeviceMemory<std::complex<double>> &x, int incx, \ 1718 const DeviceMemory<std::complex<double>> &y, int incy, \ 1719 DeviceMemory<std::complex<double>> *ap) override; \ 1720 bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1721 float alpha, const DeviceMemory<float> &a, int lda, \ 1722 const DeviceMemory<float> &x, int incx, float beta, \ 1723 DeviceMemory<float> *y, int incy) override; \ 1724 bool DoBlasSbmv(Stream *stream, blas::UpperLower uplo, uint64 n, uint64 k, \ 1725 double alpha, const DeviceMemory<double> &a, int lda, \ 1726 const DeviceMemory<double> &x, int incx, double beta, \ 1727 DeviceMemory<double> *y, int incy) override; \ 1728 bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1729 float alpha, const DeviceMemory<float> &ap, \ 1730 const DeviceMemory<float> &x, int incx, float beta, \ 1731 DeviceMemory<float> *y, int incy) override; \ 1732 bool DoBlasSpmv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1733 double alpha, const DeviceMemory<double> &ap, \ 1734 const DeviceMemory<double> &x, int incx, double beta, \ 1735 DeviceMemory<double> *y, int incy) override; \ 1736 bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1737 const DeviceMemory<float> &x, int incx, \ 1738 DeviceMemory<float> *ap) override; \ 1739 bool DoBlasSpr(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1740 double alpha, const DeviceMemory<double> &x, int incx, \ 1741 DeviceMemory<double> *ap) 
override; \ 1742 bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1743 float alpha, const DeviceMemory<float> &x, int incx, \ 1744 const DeviceMemory<float> &y, int incy, \ 1745 DeviceMemory<float> *ap) override; \ 1746 bool DoBlasSpr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1747 double alpha, const DeviceMemory<double> &x, int incx, \ 1748 const DeviceMemory<double> &y, int incy, \ 1749 DeviceMemory<double> *ap) override; \ 1750 bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1751 float alpha, const DeviceMemory<float> &a, int lda, \ 1752 const DeviceMemory<float> &x, int incx, float beta, \ 1753 DeviceMemory<float> *y, int incy) override; \ 1754 bool DoBlasSymv(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1755 double alpha, const DeviceMemory<double> &a, int lda, \ 1756 const DeviceMemory<double> &x, int incx, double beta, \ 1757 DeviceMemory<double> *y, int incy) override; \ 1758 bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, float alpha, \ 1759 const DeviceMemory<float> &x, int incx, \ 1760 DeviceMemory<float> *a, int lda) override; \ 1761 bool DoBlasSyr(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1762 double alpha, const DeviceMemory<double> &x, int incx, \ 1763 DeviceMemory<double> *a, int lda) override; \ 1764 bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1765 float alpha, const DeviceMemory<float> &x, int incx, \ 1766 const DeviceMemory<float> &y, int incy, \ 1767 DeviceMemory<float> *a, int lda) override; \ 1768 bool DoBlasSyr2(Stream *stream, blas::UpperLower uplo, uint64 n, \ 1769 double alpha, const DeviceMemory<double> &x, int incx, \ 1770 const DeviceMemory<double> &y, int incy, \ 1771 DeviceMemory<double> *a, int lda) override; \ 1772 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1773 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1774 uint64 k, const DeviceMemory<float> &a, int lda, \ 1775 DeviceMemory<float> *x, int incx) override; 
\ 1776 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1777 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1778 uint64 k, const DeviceMemory<double> &a, int lda, \ 1779 DeviceMemory<double> *x, int incx) override; \ 1780 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1781 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1782 uint64 k, const DeviceMemory<std::complex<float>> &a, \ 1783 int lda, DeviceMemory<std::complex<float>> *x, int incx) \ 1784 override; \ 1785 bool DoBlasTbmv(Stream *stream, blas::UpperLower uplo, \ 1786 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1787 uint64 k, const DeviceMemory<std::complex<double>> &a, \ 1788 int lda, DeviceMemory<std::complex<double>> *x, int incx) \ 1789 override; \ 1790 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1791 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1792 uint64 k, const DeviceMemory<float> &a, int lda, \ 1793 DeviceMemory<float> *x, int incx) override; \ 1794 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1795 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1796 uint64 k, const DeviceMemory<double> &a, int lda, \ 1797 DeviceMemory<double> *x, int incx) override; \ 1798 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1799 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1800 uint64 k, const DeviceMemory<std::complex<float>> &a, \ 1801 int lda, DeviceMemory<std::complex<float>> *x, int incx) \ 1802 override; \ 1803 bool DoBlasTbsv(Stream *stream, blas::UpperLower uplo, \ 1804 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1805 uint64 k, const DeviceMemory<std::complex<double>> &a, \ 1806 int lda, DeviceMemory<std::complex<double>> *x, int incx) \ 1807 override; \ 1808 bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ 1809 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1810 const DeviceMemory<float> &ap, DeviceMemory<float> *x, \ 1811 int incx) override; \ 1812 bool DoBlasTpmv(Stream 
*stream, blas::UpperLower uplo, \ 1813 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1814 const DeviceMemory<double> &ap, DeviceMemory<double> *x, \ 1815 int incx) override; \ 1816 bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ 1817 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1818 const DeviceMemory<std::complex<float>> &ap, \ 1819 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1820 bool DoBlasTpmv(Stream *stream, blas::UpperLower uplo, \ 1821 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1822 const DeviceMemory<std::complex<double>> &ap, \ 1823 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1824 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1825 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1826 const DeviceMemory<float> &ap, DeviceMemory<float> *x, \ 1827 int incx) override; \ 1828 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1829 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1830 const DeviceMemory<double> &ap, DeviceMemory<double> *x, \ 1831 int incx) override; \ 1832 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1833 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1834 const DeviceMemory<std::complex<float>> &ap, \ 1835 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1836 bool DoBlasTpsv(Stream *stream, blas::UpperLower uplo, \ 1837 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1838 const DeviceMemory<std::complex<double>> &ap, \ 1839 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1840 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1841 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1842 const DeviceMemory<float> &a, int lda, \ 1843 DeviceMemory<float> *x, int incx) override; \ 1844 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1845 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1846 const DeviceMemory<double> &a, int lda, \ 1847 DeviceMemory<double> *x, int 
incx) override; \ 1848 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1849 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1850 const DeviceMemory<std::complex<float>> &a, int lda, \ 1851 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1852 bool DoBlasTrmv(Stream *stream, blas::UpperLower uplo, \ 1853 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1854 const DeviceMemory<std::complex<double>> &a, int lda, \ 1855 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1856 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 1857 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1858 const DeviceMemory<float> &a, int lda, \ 1859 DeviceMemory<float> *x, int incx) override; \ 1860 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 1861 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1862 const DeviceMemory<double> &a, int lda, \ 1863 DeviceMemory<double> *x, int incx) override; \ 1864 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 1865 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1866 const DeviceMemory<std::complex<float>> &a, int lda, \ 1867 DeviceMemory<std::complex<float>> *x, int incx) override; \ 1868 bool DoBlasTrsv(Stream *stream, blas::UpperLower uplo, \ 1869 blas::Transpose trans, blas::Diagonal diag, uint64 n, \ 1870 const DeviceMemory<std::complex<double>> &a, int lda, \ 1871 DeviceMemory<std::complex<double>> *x, int incx) override; \ 1872 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 1873 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 1874 float alpha, const DeviceMemory<Eigen::half> &a, int lda, \ 1875 const DeviceMemory<Eigen::half> &b, int ldb, float beta, \ 1876 DeviceMemory<Eigen::half> *c, int ldc) override; \ 1877 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 1878 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 1879 float alpha, const DeviceMemory<float> &a, int lda, \ 1880 const DeviceMemory<float> &b, int ldb, float 
beta, \ 1881 DeviceMemory<float> *c, int ldc) override; \ 1882 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 1883 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 1884 double alpha, const DeviceMemory<double> &a, int lda, \ 1885 const DeviceMemory<double> &b, int ldb, double beta, \ 1886 DeviceMemory<double> *c, int ldc) override; \ 1887 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 1888 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 1889 std::complex<float> alpha, \ 1890 const DeviceMemory<std::complex<float>> &a, int lda, \ 1891 const DeviceMemory<std::complex<float>> &b, int ldb, \ 1892 std::complex<float> beta, \ 1893 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 1894 bool DoBlasGemm(Stream *stream, blas::Transpose transa, \ 1895 blas::Transpose transb, uint64 m, uint64 n, uint64 k, \ 1896 std::complex<double> alpha, \ 1897 const DeviceMemory<std::complex<double>> &a, int lda, \ 1898 const DeviceMemory<std::complex<double>> &b, int ldb, \ 1899 std::complex<double> beta, \ 1900 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 1901 bool DoBlasGemmWithProfiling( \ 1902 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1903 uint64 m, uint64 n, uint64 k, float alpha, \ 1904 const DeviceMemory<Eigen::half> &a, int lda, \ 1905 const DeviceMemory<Eigen::half> &b, int ldb, float beta, \ 1906 DeviceMemory<Eigen::half> *c, int ldc, \ 1907 blas::ProfileResult *output_profile_result) override; \ 1908 bool DoBlasGemmWithProfiling( \ 1909 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1910 uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \ 1911 int lda, const DeviceMemory<float> &b, int ldb, float beta, \ 1912 DeviceMemory<float> *c, int ldc, \ 1913 blas::ProfileResult *output_profile_result) override; \ 1914 bool DoBlasGemmWithProfiling( \ 1915 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1916 uint64 m, uint64 n, uint64 k, double 
alpha, \ 1917 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \ 1918 int ldb, double beta, DeviceMemory<double> *c, int ldc, \ 1919 blas::ProfileResult *output_profile_result) override; \ 1920 bool DoBlasGemmWithProfiling( \ 1921 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1922 uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ 1923 const DeviceMemory<std::complex<float>> &a, int lda, \ 1924 const DeviceMemory<std::complex<float>> &b, int ldb, \ 1925 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \ 1926 blas::ProfileResult *output_profile_result) override; \ 1927 bool DoBlasGemmWithProfiling( \ 1928 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1929 uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ 1930 const DeviceMemory<std::complex<double>> &a, int lda, \ 1931 const DeviceMemory<std::complex<double>> &b, int ldb, \ 1932 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \ 1933 int ldc, blas::ProfileResult *output_profile_result) override; \ 1934 bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms) \ 1935 override; \ 1936 bool DoBlasGemmWithAlgorithm( \ 1937 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1938 uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha, \ 1939 const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, \ 1940 int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, \ 1941 int ldc, blas::ComputationType computation_type, \ 1942 blas::AlgorithmType algorithm, \ 1943 blas::ProfileResult *output_profile_result) override; \ 1944 bool DoBlasGemmWithAlgorithm( \ 1945 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1946 uint64 m, uint64 n, uint64 k, \ 1947 const HostOrDeviceScalar<Eigen::half> &alpha, \ 1948 const DeviceMemory<Eigen::half> &a, int lda, \ 1949 const DeviceMemory<Eigen::half> &b, int ldb, \ 1950 const 
HostOrDeviceScalar<Eigen::half> &beta, \ 1951 DeviceMemory<Eigen::half> *c, int ldc, \ 1952 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 1953 blas::ProfileResult *output_profile_result) override; \ 1954 bool DoBlasGemmWithAlgorithm( \ 1955 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1956 uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha, \ 1957 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b, \ 1958 int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c, \ 1959 int ldc, blas::ComputationType computation_type, \ 1960 blas::AlgorithmType algorithm, \ 1961 blas::ProfileResult *output_profile_result) override; \ 1962 bool DoBlasGemmWithAlgorithm( \ 1963 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1964 uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha, \ 1965 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \ 1966 int ldb, const HostOrDeviceScalar<double> &beta, \ 1967 DeviceMemory<double> *c, int ldc, \ 1968 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 1969 blas::ProfileResult *output_profile_result) override; \ 1970 bool DoBlasGemmWithAlgorithm( \ 1971 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1972 uint64 m, uint64 n, uint64 k, \ 1973 const HostOrDeviceScalar<std::complex<float>> &alpha, \ 1974 const DeviceMemory<std::complex<float>> &a, int lda, \ 1975 const DeviceMemory<std::complex<float>> &b, int ldb, \ 1976 const HostOrDeviceScalar<std::complex<float>> &beta, \ 1977 DeviceMemory<std::complex<float>> *c, int ldc, \ 1978 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 1979 blas::ProfileResult *output_profile_result) override; \ 1980 bool DoBlasGemmWithAlgorithm( \ 1981 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1982 uint64 m, uint64 n, uint64 k, \ 1983 const HostOrDeviceScalar<std::complex<double>> 
&alpha, \ 1984 const DeviceMemory<std::complex<double>> &a, int lda, \ 1985 const DeviceMemory<std::complex<double>> &b, int ldb, \ 1986 const HostOrDeviceScalar<std::complex<double>> &beta, \ 1987 DeviceMemory<std::complex<double>> *c, int ldc, \ 1988 blas::ComputationType computation_type, blas::AlgorithmType algorithm, \ 1989 blas::ProfileResult *output_profile_result) override; \ 1990 bool DoBlasGemmBatched( \ 1991 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1992 uint64 m, uint64 n, uint64 k, float alpha, \ 1993 const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda, \ 1994 const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, \ 1995 float beta, const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, \ 1996 int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ 1997 bool DoBlasGemmBatched( \ 1998 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 1999 uint64 m, uint64 n, uint64 k, float alpha, \ 2000 const port::ArraySlice<DeviceMemory<float> *> &a, int lda, \ 2001 const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, float beta, \ 2002 const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, \ 2003 int batch_count, ScratchAllocator *scratch_allocator) override; \ 2004 bool DoBlasGemmBatched( \ 2005 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2006 uint64 m, uint64 n, uint64 k, double alpha, \ 2007 const port::ArraySlice<DeviceMemory<double> *> &a, int lda, \ 2008 const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, double beta, \ 2009 const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, \ 2010 int batch_count, ScratchAllocator *scratch_allocator) override; \ 2011 bool DoBlasGemmBatched( \ 2012 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2013 uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ 2014 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, \ 2015 const 
port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, \ 2016 std::complex<float> beta, \ 2017 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, \ 2018 int batch_count, ScratchAllocator *scratch_allocator) override; \ 2019 bool DoBlasGemmBatched( \ 2020 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2021 uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ 2022 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, \ 2023 int lda, \ 2024 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, \ 2025 int ldb, std::complex<double> beta, \ 2026 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, \ 2027 int ldc, int batch_count, ScratchAllocator *scratch_allocator) override; \ 2028 bool DoBlasGemmStridedBatched( \ 2029 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2030 uint64 m, uint64 n, uint64 k, float alpha, \ 2031 const DeviceMemory<Eigen::half> &a, int lda, int64 stride_a, \ 2032 const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b, float beta, \ 2033 DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c, int batch_count); \ 2034 bool DoBlasGemmStridedBatched( \ 2035 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2036 uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \ 2037 int lda, int64 stride_a, const DeviceMemory<float> &b, int ldb, \ 2038 int64 stride_b, float beta, DeviceMemory<float> *c, int ldc, \ 2039 int64 stride_c, int batch_count); \ 2040 bool DoBlasGemmStridedBatched( \ 2041 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2042 uint64 m, uint64 n, uint64 k, double alpha, \ 2043 const DeviceMemory<double> &a, int lda, int64 stride_a, \ 2044 const DeviceMemory<double> &b, int ldb, int64 stride_b, double beta, \ 2045 DeviceMemory<double> *c, int ldc, int64 stride_c, int batch_count); \ 2046 bool DoBlasGemmStridedBatched( \ 2047 Stream *stream, blas::Transpose transa, 
blas::Transpose transb, \ 2048 uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \ 2049 const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a, \ 2050 const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b, \ 2051 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \ 2052 int64 stride_c, int batch_count); \ 2053 bool DoBlasGemmStridedBatched( \ 2054 Stream *stream, blas::Transpose transa, blas::Transpose transb, \ 2055 uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \ 2056 const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a, \ 2057 const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b, \ 2058 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \ 2059 int ldc, int64 stride_c, int batch_count); \ 2060 bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2061 uint64 m, uint64 n, std::complex<float> alpha, \ 2062 const DeviceMemory<std::complex<float>> &a, int lda, \ 2063 const DeviceMemory<std::complex<float>> &b, int ldb, \ 2064 std::complex<float> beta, \ 2065 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2066 bool DoBlasHemm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2067 uint64 m, uint64 n, std::complex<double> alpha, \ 2068 const DeviceMemory<std::complex<double>> &a, int lda, \ 2069 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2070 std::complex<double> beta, \ 2071 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2072 bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, \ 2073 blas::Transpose trans, uint64 n, uint64 k, float alpha, \ 2074 const DeviceMemory<std::complex<float>> &a, int lda, \ 2075 float beta, DeviceMemory<std::complex<float>> *c, int ldc) \ 2076 override; \ 2077 bool DoBlasHerk(Stream *stream, blas::UpperLower uplo, \ 2078 blas::Transpose trans, uint64 n, uint64 k, double alpha, \ 2079 const DeviceMemory<std::complex<double>> &a, int lda, \ 2080 double beta, 
DeviceMemory<std::complex<double>> *c, int ldc) \ 2081 override; \ 2082 bool DoBlasHer2k( \ 2083 Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, \ 2084 uint64 k, std::complex<float> alpha, \ 2085 const DeviceMemory<std::complex<float>> &a, int lda, \ 2086 const DeviceMemory<std::complex<float>> &b, int ldb, float beta, \ 2087 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2088 bool DoBlasHer2k( \ 2089 Stream *stream, blas::UpperLower uplo, blas::Transpose trans, uint64 n, \ 2090 uint64 k, std::complex<double> alpha, \ 2091 const DeviceMemory<std::complex<double>> &a, int lda, \ 2092 const DeviceMemory<std::complex<double>> &b, int ldb, double beta, \ 2093 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2094 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2095 uint64 m, uint64 n, float alpha, \ 2096 const DeviceMemory<float> &a, int lda, \ 2097 const DeviceMemory<float> &b, int ldb, float beta, \ 2098 DeviceMemory<float> *c, int ldc) override; \ 2099 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2100 uint64 m, uint64 n, double alpha, \ 2101 const DeviceMemory<double> &a, int lda, \ 2102 const DeviceMemory<double> &b, int ldb, double beta, \ 2103 DeviceMemory<double> *c, int ldc) override; \ 2104 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2105 uint64 m, uint64 n, std::complex<float> alpha, \ 2106 const DeviceMemory<std::complex<float>> &a, int lda, \ 2107 const DeviceMemory<std::complex<float>> &b, int ldb, \ 2108 std::complex<float> beta, \ 2109 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2110 bool DoBlasSymm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2111 uint64 m, uint64 n, std::complex<double> alpha, \ 2112 const DeviceMemory<std::complex<double>> &a, int lda, \ 2113 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2114 std::complex<double> beta, \ 2115 DeviceMemory<std::complex<double>> *c, int 
ldc) override; \ 2116 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2117 blas::Transpose trans, uint64 n, uint64 k, float alpha, \ 2118 const DeviceMemory<float> &a, int lda, float beta, \ 2119 DeviceMemory<float> *c, int ldc) override; \ 2120 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2121 blas::Transpose trans, uint64 n, uint64 k, double alpha, \ 2122 const DeviceMemory<double> &a, int lda, double beta, \ 2123 DeviceMemory<double> *c, int ldc) override; \ 2124 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2125 blas::Transpose trans, uint64 n, uint64 k, \ 2126 std::complex<float> alpha, \ 2127 const DeviceMemory<std::complex<float>> &a, int lda, \ 2128 std::complex<float> beta, \ 2129 DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2130 bool DoBlasSyrk(Stream *stream, blas::UpperLower uplo, \ 2131 blas::Transpose trans, uint64 n, uint64 k, \ 2132 std::complex<double> alpha, \ 2133 const DeviceMemory<std::complex<double>> &a, int lda, \ 2134 std::complex<double> beta, \ 2135 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2136 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2137 blas::Transpose trans, uint64 n, uint64 k, float alpha, \ 2138 const DeviceMemory<float> &a, int lda, \ 2139 const DeviceMemory<float> &b, int ldb, float beta, \ 2140 DeviceMemory<float> *c, int ldc) override; \ 2141 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2142 blas::Transpose trans, uint64 n, uint64 k, double alpha, \ 2143 const DeviceMemory<double> &a, int lda, \ 2144 const DeviceMemory<double> &b, int ldb, double beta, \ 2145 DeviceMemory<double> *c, int ldc) override; \ 2146 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2147 blas::Transpose trans, uint64 n, uint64 k, \ 2148 std::complex<float> alpha, \ 2149 const DeviceMemory<std::complex<float>> &a, int lda, \ 2150 const DeviceMemory<std::complex<float>> &b, int ldb, \ 2151 std::complex<float> beta, \ 2152 
DeviceMemory<std::complex<float>> *c, int ldc) override; \ 2153 bool DoBlasSyr2k(Stream *stream, blas::UpperLower uplo, \ 2154 blas::Transpose trans, uint64 n, uint64 k, \ 2155 std::complex<double> alpha, \ 2156 const DeviceMemory<std::complex<double>> &a, int lda, \ 2157 const DeviceMemory<std::complex<double>> &b, int ldb, \ 2158 std::complex<double> beta, \ 2159 DeviceMemory<std::complex<double>> *c, int ldc) override; \ 2160 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2161 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2162 uint64 n, float alpha, const DeviceMemory<float> &a, \ 2163 int lda, DeviceMemory<float> *b, int ldb) override; \ 2164 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2165 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2166 uint64 n, double alpha, const DeviceMemory<double> &a, \ 2167 int lda, DeviceMemory<double> *b, int ldb) override; \ 2168 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2169 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2170 uint64 n, std::complex<float> alpha, \ 2171 const DeviceMemory<std::complex<float>> &a, int lda, \ 2172 DeviceMemory<std::complex<float>> *b, int ldb) override; \ 2173 bool DoBlasTrmm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2174 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2175 uint64 n, std::complex<double> alpha, \ 2176 const DeviceMemory<std::complex<double>> &a, int lda, \ 2177 DeviceMemory<std::complex<double>> *b, int ldb) override; \ 2178 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2179 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2180 uint64 n, float alpha, const DeviceMemory<float> &a, \ 2181 int lda, DeviceMemory<float> *b, int ldb) override; \ 2182 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2183 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2184 uint64 n, double 
alpha, const DeviceMemory<double> &a, \ 2185 int lda, DeviceMemory<double> *b, int ldb) override; \ 2186 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2187 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2188 uint64 n, std::complex<float> alpha, \ 2189 const DeviceMemory<std::complex<float>> &a, int lda, \ 2190 DeviceMemory<std::complex<float>> *b, int ldb) override; \ 2191 bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, \ 2192 blas::Transpose transa, blas::Diagonal diag, uint64 m, \ 2193 uint64 n, std::complex<double> alpha, \ 2194 const DeviceMemory<std::complex<double>> &a, int lda, \ 2195 DeviceMemory<std::complex<double>> *b, int ldb) override; 2196 2197 } // namespace blas 2198 } // namespace stream_executor 2199 2200 #endif // TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ 2201