1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/stream_executor/stream.h" 17 18 #include "tensorflow/stream_executor/platform/port.h" 19 20 #include "tensorflow/stream_executor/blas.h" 21 #include "tensorflow/stream_executor/host_buffer.h" 22 #include "tensorflow/stream_executor/lib/stacktrace.h" 23 #include "tensorflow/stream_executor/lib/strcat.h" 24 #include "tensorflow/stream_executor/platform.h" 25 #include "tensorflow/stream_executor/platform/logging.h" 26 #include "tensorflow/stream_executor/rng.h" 27 #include "tensorflow/stream_executor/stream_executor_internal.h" 28 #include "tensorflow/stream_executor/stream_executor_pimpl.h" 29 30 namespace perftools { 31 namespace gputools { 32 33 namespace { 34 // Code to turn parameters to functions on stream into strings that 35 // will be VLOG'ed. We need overloads, instead of 36 // e.g. BatchDescriptorToVlogString(), as the code that calls these 37 // functions does not know what the type of the parameter is. 38 string ToVlogString(const dnn::BatchDescriptor &descriptor) { 39 return descriptor.ToShortString(); 40 } 41 42 string ToVlogString(const dnn::FilterDescriptor &descriptor) { 43 return descriptor.ToShortString(); 44 } 45 46 string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) { 47 return descriptor.ToShortString(); 48 } 49 50 string ToVlogString(const dnn::PoolingDescriptor &descriptor) { 51 return descriptor.ToShortString(); 52 } 53 54 string ToVlogString(const dnn::NormalizeDescriptor &descriptor) { 55 return descriptor.ToShortString(); 56 } 57 58 string ToVlogString(dnn::ActivationMode mode) { 59 return dnn::ActivationModeString(mode); 60 } 61 62 string ToVlogString(const dnn::AlgorithmConfig &algo_config) { 63 return algo_config.ToString(); 64 } 65 66 string ToVlogString(dnn::ElementwiseOperation op) { 67 return dnn::ElementwiseOperationString(op); 68 } 69 70 string ToVlogString(dnn::QuantizedActivationMode mode) { 71 return dnn::QuantizedActivationModeString(mode); 72 } 73 74 string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); } 75 76 string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); } 77 78 string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); } 79 80 string ToVlogString(blas::Side s) { return blas::SideString(s); } 81 82 string ToVlogString(blas::ComputationType ty) { 83 return blas::ComputationTypeString(ty); 84 } 85 86 string ToVlogString(const void *ptr) { 87 if (ptr == nullptr) { 88 return "null"; 89 } 90 91 // StrCat does not convert pointers to text. 92 std::ostringstream out; 93 out << ptr; 94 return out.str(); 95 } 96 97 string ToVlogString(const HostBuffer &buffer) { return buffer.AsString(); } 98 99 template <class T> 100 string ToVlogString(const std::complex<T> &c) { 101 // StrCat does not convert std::complex to text. 102 std::ostringstream out; 103 out << c; 104 return out.str(); 105 } 106 107 template <class T> 108 string ToVlogString(const std::function<T> &f) { 109 return f == nullptr ? "null" : "<non-null function>"; 110 } 111 112 string ToVlogString(const DeviceMemoryBase &memory) { 113 return ToVlogString(memory.opaque()); 114 } 115 116 string ToVlogString(const DeviceMemoryBase *memory) { 117 return ToVlogString(*memory); 118 } 119 120 string ToVlogString(const Eigen::half &h) { return port::StrCat(h); } 121 122 string ToVlogString(int i) { return port::StrCat(i); } 123 124 string ToVlogString(uint32 i) { return port::StrCat(i); } 125 126 string ToVlogString(uint64 i) { return port::StrCat(i); } 127 128 string ToVlogString(int64 i) { return port::StrCat(i); } 129 130 string ToVlogString(float f) { return port::StrCat(f); } 131 132 string ToVlogString(double d) { return port::StrCat(d); } 133 134 template <class T> 135 string ToVlogString(port::ArraySlice<T> elements) { 136 string str = port::StrCat( 137 ToVlogString(reinterpret_cast<const void *>(elements.data())), "[", 138 elements.size(), "]{"); 139 const char *separator = ""; 140 size_t max_to_show = std::numeric_limits<size_t>::max(); 141 if (!VLOG_IS_ON(2)) { 142 max_to_show = 5; 143 } else if (!VLOG_IS_ON(3)) { 144 max_to_show = 20; 145 } else if (!VLOG_IS_ON(11)) { 146 max_to_show = 1000; 147 } 148 for (size_t i = 0; i < elements.size(); ++i) { 149 if (i == max_to_show) { 150 str += ", ..."; 151 break; 152 } 153 port::StrAppend(&str, separator, ToVlogString(elements[i])); 154 separator = ", "; 155 } 156 str += "}"; 157 return str; 158 } 159 160 template <class T> 161 string ToVlogString(port::MutableArraySlice<T> elements) { 162 return ToVlogString(port::ArraySlice<T>(elements)); 163 } 164 165 string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) { 166 switch (depth_to_space_layout) { 167 case dnn::DepthToSpaceLayout::DepthHeightWidth: 168 return "DepthToSpaceLayout::DepthHeightWidth"; 169 } 170 return "unknown DepthToSpaceLayout"; 171 } 172 173 string ToVlogString(dnn::DataType data_type) { 174 switch (data_type) { 175 case dnn::DataType::kFloat: 176 return "dnn::DataType::kFloat"; 177 case dnn::DataType::kDouble: 178 return "dnn::DataType::kDouble"; 179 case dnn::DataType::kHalf: 180 return "dnn::DataType::kHalf"; 181 case dnn::DataType::kInt8: 182 return "dnn::DataType::kInt8"; 183 } 184 } 185 186 // Used together with PARAM to VLOG calls made to the stream. Intended 187 // to be used like this: 188 // 189 // VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)}); 190 // 191 // where a and b are the parameters to MyFunction. 192 // 193 // See VLOG_CALL for a short-hand for this. This way of doing it saves 194 // a tremendous amount of boilerplate code given how many functions 195 // there are on Stream and how many parameters they each have. 196 string CallStr(const char *function_name, Stream *stream, 197 std::vector<std::pair<const char *, string>> params) { 198 // Do not call this function unless VLOG is on since just 199 // constructing all the strings in params is expensive. 200 CHECK(VLOG_IS_ON(1)); 201 202 string str = port::StrCat("Called Stream::", function_name, "("); 203 const char *separator = ""; 204 for (const auto ¶m : params) { 205 port::StrAppend(&str, separator, param.first, "=", param.second); 206 separator = ", "; 207 } 208 port::StrAppend(&str, ") stream=", ToVlogString(stream)); 209 if (VLOG_IS_ON(10)) { 210 port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n"); 211 } 212 return str; 213 } 214 215 // Use this macro to avoid having to type every parameter twice to log 216 // it with VLOG and CallStr. 217 #define PARAM(parameter) \ 218 { #parameter, ToVlogString(parameter) } 219 220 // Use this macro to avoid having to type out the name of each 221 // function and to save some boilerplate. Intended to be used like this: 222 // 223 // VLOG_CALL(PARAM(a), PARAM(b)) 224 // 225 // This saves a tremendous amount of boilerplate compared to the alternative: 226 // 227 // VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a) 228 // << ", b=" << ToVlogString(b); 229 // 230 // Note here that most of the parameter names are not short and that 231 // most of the functions take many more than 2 parameters. 232 #define VLOG_CALL(...) VLOG(1) << CallStr(__func__, this, {__VA_ARGS__}) 233 234 } // namespace 235 236 Stream::Stream(StreamExecutor *parent) 237 : parent_(parent), 238 implementation_(parent->implementation()->GetStreamImplementation()), 239 allocated_(false), 240 ok_(false), 241 temporary_memory_manager_(this) { 242 VLOG_CALL(PARAM(parent)); 243 } 244 245 Stream::Stream(StreamExecutor *parent, 246 internal::StreamInterface *implementation) 247 : parent_(parent), 248 implementation_(implementation), 249 allocated_(false), 250 ok_(false), 251 temporary_memory_manager_(this) { 252 VLOG_CALL(PARAM(parent), PARAM(implementation)); 253 } 254 255 Stream::~Stream() { 256 VLOG_CALL(); 257 258 temporary_memory_manager_.ForceDeallocateAll(); 259 260 if (allocated_) { 261 parent_->DeallocateStream(this); 262 } 263 } 264 265 Stream &Stream::Init() { 266 VLOG_CALL(); 267 268 mutex_lock lock{mu_}; 269 CHECK_EQ(false, allocated_) 270 << "stream appears to already have been initialized"; 271 CHECK(!ok_) << "stream should be in !ok() state pre-initialization"; 272 273 if (parent_->AllocateStream(this)) { 274 // Successful initialization! 275 allocated_ = true; 276 ok_ = true; 277 } else { 278 LOG(ERROR) << "failed to allocate stream during initialization"; 279 } 280 281 return *this; 282 } 283 284 Stream &Stream::InitTimer(Timer *timer) { 285 VLOG_CALL(PARAM(timer)); 286 287 if (ok()) { 288 CheckError(parent_->AllocateTimer(timer)); 289 } else { 290 LOG(INFO) << "did not allocate timer: " << timer; 291 } 292 return *this; 293 } 294 295 Stream &Stream::InitWithTimer(Timer *timer) { 296 VLOG_CALL(PARAM(timer)); 297 298 return Init().InitTimer(timer); 299 } 300 301 Stream &Stream::ThenRecordEvent(Event *event) { 302 VLOG_CALL(PARAM(event)); 303 304 port::Status status = parent_->RecordEvent(this, event); 305 if (!status.ok()) { 306 LOG(ERROR) << "Error recording event in stream: " << status.error_message() 307 << "; not marking stream as bad, as the Event object may be " 308 << "at fault. Monitor for further errors."; 309 } 310 311 return *this; 312 } 313 314 Stream &Stream::ThenBatchNormalizationForward( 315 const DeviceMemory<float> &x, const DeviceMemory<float> &scale, 316 const DeviceMemory<float> &offset, 317 const DeviceMemory<float> &estimated_mean, 318 const DeviceMemory<float> &estimated_variance, 319 const dnn::BatchDescriptor &x_desc, 320 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, 321 DeviceMemory<float> *y, DeviceMemory<float> *batch_mean, 322 DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean, 323 DeviceMemory<float> *saved_inv_var, bool is_training, 324 std::function<const DeviceMemory<float> &()> var_to_inv_var, 325 std::function<void()> inv_var_to_var) { 326 VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc), 327 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y)); 328 if (ok()) { 329 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 330 CheckError(dnn->DoBatchNormalizationForward( 331 this, x, scale, offset, estimated_mean, estimated_variance, x_desc, 332 scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean, 333 saved_inv_var, is_training, std::move(var_to_inv_var), 334 std::move(inv_var_to_var))); 335 } else { 336 SetErrorAndLogNoDnnSupport(); 337 } 338 } 339 return *this; 340 } 341 342 Stream &Stream::ThenBatchNormalizationBackward( 343 const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x, 344 const DeviceMemory<float> &scale, const DeviceMemory<float> &mean, 345 const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc, 346 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, 347 DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop, 348 DeviceMemory<float> *offset_backprop) { 349 VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc), 350 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop), 351 PARAM(scale_backprop), PARAM(offset_backprop)); 352 if (ok()) { 353 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 354 CheckError(dnn->DoBatchNormalizationBackward( 355 this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc, 356 epsilon, x_backprop, scale_backprop, offset_backprop)); 357 } else { 358 SetErrorAndLogNoDnnSupport(); 359 } 360 } 361 return *this; 362 } 363 364 Stream &Stream::ThenBatchNormalizationForward( 365 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale, 366 const DeviceMemory<float> &offset, 367 const DeviceMemory<float> &estimated_mean, 368 const DeviceMemory<float> &estimated_variance, 369 const dnn::BatchDescriptor &x_desc, 370 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, 371 DeviceMemory<Eigen::half> *y, DeviceMemory<float> *batch_mean, 372 DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean, 373 DeviceMemory<float> *saved_inv_var, bool is_training, 374 std::function<const DeviceMemory<float> &()> var_to_inv_var, 375 std::function<void()> inv_var_to_var) { 376 VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc), 377 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y)); 378 if (ok()) { 379 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 380 CheckError(dnn->DoBatchNormalizationForward( 381 this, x, scale, offset, estimated_mean, estimated_variance, x_desc, 382 scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean, 383 saved_inv_var, is_training, std::move(var_to_inv_var), 384 std::move(inv_var_to_var))); 385 } else { 386 SetErrorAndLogNoDnnSupport(); 387 } 388 } 389 return *this; 390 } 391 392 Stream &Stream::ThenBatchNormalizationBackward( 393 const DeviceMemory<Eigen::half> &y_backprop, 394 const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale, 395 const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var, 396 const dnn::BatchDescriptor &x_desc, 397 const dnn::BatchDescriptor &scale_offset_desc, const double epsilon, 398 DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop, 399 DeviceMemory<float> *offset_backprop) { 400 VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc), 401 PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop), 402 PARAM(scale_backprop), PARAM(offset_backprop)); 403 if (ok()) { 404 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 405 CheckError(dnn->DoBatchNormalizationBackward( 406 this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc, 407 epsilon, x_backprop, scale_backprop, offset_backprop)); 408 } else { 409 SetErrorAndLogNoDnnSupport(); 410 } 411 } 412 return *this; 413 } 414 415 Stream &Stream::ThenFusedConvolveWithScratch( 416 const dnn::BatchDescriptor &conv_input_descriptor, 417 const DeviceMemory<int8> &conv_input_data, float conv_input_scale, 418 const dnn::FilterDescriptor &filter_descriptor, 419 const DeviceMemory<int8> &filter_data, 420 const dnn::ConvolutionDescriptor &convolution_descriptor, 421 const DeviceMemory<int8> &side_input_data, float side_input_scale, 422 const dnn::BatchDescriptor &bias_descriptor, 423 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode, 424 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output, 425 ScratchAllocator *scratch_allocator) { 426 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data), 427 PARAM(conv_input_scale), PARAM(filter_descriptor), 428 PARAM(filter_data), PARAM(convolution_descriptor), 429 PARAM(side_input_data), PARAM(side_input_scale), 430 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode), 431 PARAM(output_descriptor), PARAM(output)); 432 433 if (ok()) { 434 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 435 CheckError(dnn->DoFusedConvolve( 436 this, conv_input_descriptor, conv_input_data, conv_input_scale, 437 filter_descriptor, filter_data, convolution_descriptor, 438 side_input_data, side_input_scale, bias_descriptor, biases, 439 activation_mode, output_descriptor, output, scratch_allocator, 440 dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr)); 441 } else { 442 SetErrorAndLogNoDnnSupport(); 443 } 444 } 445 return *this; 446 } 447 448 Stream &Stream::ThenFusedConvolveWithScratch( 449 const dnn::BatchDescriptor &conv_input_descriptor, 450 const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale, 451 const dnn::FilterDescriptor &filter_descriptor, 452 const DeviceMemory<Eigen::half> &filter_data, 453 const dnn::ConvolutionDescriptor &convolution_descriptor, 454 const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale, 455 const dnn::BatchDescriptor &bias_descriptor, 456 const DeviceMemory<Eigen::half> &biases, 457 dnn::ActivationMode activation_mode, 458 const dnn::BatchDescriptor &output_descriptor, 459 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) { 460 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data), 461 PARAM(conv_input_scale), PARAM(filter_descriptor), 462 PARAM(filter_data), PARAM(convolution_descriptor), 463 PARAM(side_input_data), PARAM(side_input_scale), 464 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode), 465 PARAM(output_descriptor), PARAM(output)); 466 467 if (ok()) { 468 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 469 CheckError(dnn->DoFusedConvolve( 470 this, conv_input_descriptor, conv_input_data, conv_input_scale, 471 filter_descriptor, filter_data, convolution_descriptor, 472 side_input_data, side_input_scale, bias_descriptor, biases, 473 activation_mode, output_descriptor, output, scratch_allocator, 474 dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr)); 475 } else { 476 SetErrorAndLogNoDnnSupport(); 477 } 478 } 479 return *this; 480 } 481 482 Stream &Stream::ThenFusedConvolveWithScratch( 483 const dnn::BatchDescriptor &conv_input_descriptor, 484 const DeviceMemory<float> &conv_input_data, float conv_input_scale, 485 const dnn::FilterDescriptor &filter_descriptor, 486 const DeviceMemory<float> &filter_data, 487 const dnn::ConvolutionDescriptor &convolution_descriptor, 488 const DeviceMemory<float> &side_input_data, float side_input_scale, 489 const dnn::BatchDescriptor &bias_descriptor, 490 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode, 491 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output, 492 ScratchAllocator *scratch_allocator) { 493 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data), 494 PARAM(conv_input_scale), PARAM(filter_descriptor), 495 PARAM(filter_data), PARAM(convolution_descriptor), 496 PARAM(side_input_data), PARAM(side_input_scale), 497 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode), 498 PARAM(output_descriptor), PARAM(output)); 499 500 if (ok()) { 501 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 502 CheckError(dnn->DoFusedConvolve( 503 this, conv_input_descriptor, conv_input_data, conv_input_scale, 504 filter_descriptor, filter_data, convolution_descriptor, 505 side_input_data, side_input_scale, bias_descriptor, biases, 506 activation_mode, output_descriptor, output, scratch_allocator, 507 dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr)); 508 } else { 509 SetErrorAndLogNoDnnSupport(); 510 } 511 } 512 return *this; 513 } 514 515 Stream &Stream::ThenConvolveWithScratch( 516 const dnn::BatchDescriptor &input_descriptor, 517 const DeviceMemory<Eigen::half> &input_data, 518 const dnn::FilterDescriptor &filter_descriptor, 519 const DeviceMemory<Eigen::half> &filter_data, 520 const dnn::ConvolutionDescriptor &convolution_descriptor, 521 const dnn::BatchDescriptor &output_descriptor, 522 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) { 523 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), 524 PARAM(filter_descriptor), PARAM(filter_data), 525 PARAM(convolution_descriptor), PARAM(output_descriptor), 526 PARAM(output)); 527 528 if (ok()) { 529 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 530 CheckError(dnn->DoConvolve( 531 this, input_descriptor, input_data, filter_descriptor, filter_data, 532 convolution_descriptor, output_descriptor, output, scratch_allocator, 533 dnn::AlgorithmConfig(), 534 /*output_profile_result=*/nullptr)); 535 } else { 536 SetErrorAndLogNoDnnSupport(); 537 } 538 } 539 return *this; 540 } 541 542 Stream &Stream::ThenConvolveWithScratch( 543 const dnn::BatchDescriptor &input_descriptor, 544 const DeviceMemory<float> &input_data, 545 const dnn::FilterDescriptor &filter_descriptor, 546 const DeviceMemory<float> &filter_data, 547 const dnn::ConvolutionDescriptor &convolution_descriptor, 548 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output, 549 ScratchAllocator *scratch_allocator) { 550 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), 551 PARAM(filter_descriptor), PARAM(filter_data), 552 PARAM(convolution_descriptor), PARAM(output_descriptor), 553 PARAM(output)); 554 555 if (ok()) { 556 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 557 CheckError(dnn->DoConvolve( 558 this, input_descriptor, input_data, filter_descriptor, filter_data, 559 convolution_descriptor, output_descriptor, output, scratch_allocator, 560 dnn::AlgorithmConfig(), 561 /*output_profile_result=*/nullptr)); 562 } else { 563 SetErrorAndLogNoDnnSupport(); 564 } 565 } 566 return *this; 567 } 568 569 Stream &Stream::ThenFusedConvolveWithAlgorithm( 570 const dnn::BatchDescriptor &conv_input_descriptor, 571 const DeviceMemory<float> &conv_input_data, float conv_input_scale, 572 const dnn::FilterDescriptor &filter_descriptor, 573 const DeviceMemory<float> &filter_data, 574 const dnn::ConvolutionDescriptor &convolution_descriptor, 575 const DeviceMemory<float> &side_input_data, float side_input_scale, 576 const dnn::BatchDescriptor &bias_descriptor, 577 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode, 578 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output, 579 ScratchAllocator *scratch_allocator, 580 const dnn::AlgorithmConfig &algorithm_config, 581 dnn::ProfileResult *output_profile_result) { 582 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data), 583 PARAM(conv_input_scale), PARAM(filter_descriptor), 584 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases), 585 PARAM(side_input_data), PARAM(side_input_scale), 586 PARAM(activation_mode), PARAM(output_descriptor), PARAM(output), 587 PARAM(algorithm_config)); 588 589 if (ok()) { 590 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 591 auto status = dnn->DoFusedConvolve( 592 this, conv_input_descriptor, conv_input_data, conv_input_scale, 593 filter_descriptor, filter_data, convolution_descriptor, 594 side_input_data, side_input_scale, bias_descriptor, biases, 595 activation_mode, output_descriptor, output, scratch_allocator, 596 algorithm_config, output_profile_result); 597 if (!status && !output_profile_result) { 598 SetError(); 599 } 600 } else { 601 SetErrorAndLogNoDnnSupport(); 602 } 603 } 604 return *this; 605 } 606 607 Stream &Stream::ThenFusedConvolveWithAlgorithm( 608 const dnn::BatchDescriptor &conv_input_descriptor, 609 const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale, 610 const dnn::FilterDescriptor &filter_descriptor, 611 const DeviceMemory<Eigen::half> &filter_data, 612 const dnn::ConvolutionDescriptor &convolution_descriptor, 613 const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale, 614 const dnn::BatchDescriptor &bias_descriptor, 615 const DeviceMemory<Eigen::half> &biases, 616 dnn::ActivationMode activation_mode, 617 const dnn::BatchDescriptor &output_descriptor, 618 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator, 619 const dnn::AlgorithmConfig &algorithm_config, 620 dnn::ProfileResult *output_profile_result) { 621 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data), 622 PARAM(conv_input_scale), PARAM(filter_descriptor), 623 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases), 624 PARAM(side_input_data), PARAM(side_input_scale), 625 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode), 626 PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config)); 627 628 if (ok()) { 629 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 630 auto status = dnn->DoFusedConvolve( 631 this, conv_input_descriptor, conv_input_data, conv_input_scale, 632 filter_descriptor, filter_data, convolution_descriptor, 633 side_input_data, side_input_scale, bias_descriptor, biases, 634 activation_mode, output_descriptor, output, scratch_allocator, 635 algorithm_config, output_profile_result); 636 if (!status && !output_profile_result) { 637 SetError(); 638 } 639 } else { 640 SetErrorAndLogNoDnnSupport(); 641 } 642 } 643 return *this; 644 } 645 646 Stream &Stream::ThenFusedConvolveWithAlgorithm( 647 const dnn::BatchDescriptor &conv_input_descriptor, 648 const DeviceMemory<int8> &conv_input_data, float conv_input_scale, 649 const dnn::FilterDescriptor &filter_descriptor, 650 const DeviceMemory<int8> &filter_data, 651 const dnn::ConvolutionDescriptor &convolution_descriptor, 652 const DeviceMemory<int8> &side_input_data, float side_input_scale, 653 const dnn::BatchDescriptor &bias_descriptor, 654 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode, 655 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output, 656 ScratchAllocator *scratch_allocator, 657 const dnn::AlgorithmConfig &algorithm_config, 658 dnn::ProfileResult *output_profile_result) { 659 VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data), 660 PARAM(conv_input_scale), PARAM(filter_descriptor), 661 PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases), 662 PARAM(side_input_data), PARAM(side_input_scale), 663 PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode), 664 PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config)); 665 666 if (ok()) { 667 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 668 auto status = dnn->DoFusedConvolve( 669 this, conv_input_descriptor, conv_input_data, conv_input_scale, 670 filter_descriptor, filter_data, convolution_descriptor, 671 side_input_data, side_input_scale, bias_descriptor, biases, 672 activation_mode, output_descriptor, output, scratch_allocator, 673 algorithm_config, output_profile_result); 674 if (!status && !output_profile_result) { 675 SetError(); 676 } 677 } else { 678 SetErrorAndLogNoDnnSupport(); 679 } 680 } 681 return *this; 682 } 683 684 Stream &Stream::ThenConvolveWithAlgorithm( 685 const dnn::BatchDescriptor &input_descriptor, 686 const DeviceMemory<float> &input_data, 687 const dnn::FilterDescriptor &filter_descriptor, 688 const DeviceMemory<float> &filter_data, 689 const dnn::ConvolutionDescriptor &convolution_descriptor, 690 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output, 691 ScratchAllocator *scratch_allocator, 692 const dnn::AlgorithmConfig &algorithm_config, 693 dnn::ProfileResult *output_profile_result) { 694 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), 695 PARAM(filter_descriptor), PARAM(filter_data), 696 PARAM(convolution_descriptor), PARAM(output_descriptor), 697 PARAM(output), PARAM(algorithm_config)); 698 699 if (ok()) { 700 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 701 auto status = dnn->DoConvolve( 702 this, input_descriptor, input_data, filter_descriptor, filter_data, 703 convolution_descriptor, output_descriptor, output, scratch_allocator, 704 algorithm_config, output_profile_result); 705 if (!status && !output_profile_result) { 706 SetError(); 707 } 708 } else { 709 SetErrorAndLogNoDnnSupport(); 710 } 711 } 712 return *this; 713 } 714 715 Stream &Stream::ThenConvolveWithAlgorithm( 716 const dnn::BatchDescriptor &input_descriptor, 717 const DeviceMemory<Eigen::half> &input_data, 718 const dnn::FilterDescriptor &filter_descriptor, 719 const DeviceMemory<Eigen::half> &filter_data, 720 const dnn::ConvolutionDescriptor &convolution_descriptor, 721 const dnn::BatchDescriptor &output_descriptor, 722 DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator, 723 const dnn::AlgorithmConfig &algorithm_config, 724 dnn::ProfileResult *output_profile_result) { 725 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), 726 PARAM(filter_descriptor), PARAM(filter_data), 727 PARAM(convolution_descriptor), PARAM(output_descriptor), 728 PARAM(output), PARAM(algorithm_config)); 729 730 if (ok()) { 731 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 732 auto status = dnn->DoConvolve( 733 this, input_descriptor, input_data, filter_descriptor, filter_data, 734 convolution_descriptor, output_descriptor, output, scratch_allocator, 735 algorithm_config, output_profile_result); 736 if (!status && !output_profile_result) { 737 SetError(); 738 } 739 } else { 740 SetErrorAndLogNoDnnSupport(); 741 } 742 } 743 return *this; 744 } 745 746 Stream &Stream::ThenFusedConvolve( 747 const dnn::BatchDescriptor &conv_input_descriptor, 748 const DeviceMemory<int8> &conv_input_data, float conv_input_scale, 749 const dnn::FilterDescriptor &filter_descriptor, 750 const DeviceMemory<int8> &filter_data, 751 const dnn::ConvolutionDescriptor &convolution_descriptor, 752 const DeviceMemory<int8> &side_input_data, float side_input_scale, 753 const dnn::BatchDescriptor &bias_descriptor, 754 const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode, 755 const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output) { 756 return ThenFusedConvolveWithScratch( 757 conv_input_descriptor, conv_input_data, conv_input_scale, 758 filter_descriptor, filter_data, convolution_descriptor, side_input_data, 759 side_input_scale, bias_descriptor, biases, activation_mode, 760 output_descriptor, output, 761 /*scratch_allocator=*/nullptr); 762 } 763 764 Stream &Stream::ThenConvolve( 765 const dnn::BatchDescriptor &input_descriptor, 766 const DeviceMemory<float> &input_data, 767 const dnn::FilterDescriptor &filter_descriptor, 768 const DeviceMemory<float> &filter_data, 769 const dnn::ConvolutionDescriptor &convolution_descriptor, 770 const dnn::BatchDescriptor &output_descriptor, 771 DeviceMemory<float> *output) { 772 return ThenConvolveWithScratch(input_descriptor, input_data, 773 filter_descriptor, filter_data, 774 convolution_descriptor, output_descriptor, 775 output, /*scratch_allocator=*/nullptr); 776 } 777 778 Stream &Stream::ThenConvolveQuantized( 779 const dnn::BatchDescriptor &input_descriptor, 780 const DeviceMemory<float> &input_data, 781 const dnn::FilterDescriptor &filter_descriptor, 782 const DeviceMemory<int8> &filter_coefficients, 783 const DeviceMemory<float> &coefficient_scales, 784 const dnn::ConvolutionDescriptor &convolution_descriptor, 785 const dnn::BatchDescriptor &output_descriptor, 786 DeviceMemory<float> *output) { 787 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), 788 PARAM(filter_descriptor), PARAM(filter_coefficients), 789 PARAM(coefficient_scales), PARAM(convolution_descriptor), 790 PARAM(output_descriptor), PARAM(output)); 791 792 if (ok()) { 793 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 794 CheckError(dnn->DoConvolveQuantized( 795 this, input_descriptor, input_data, filter_descriptor, 796 filter_coefficients, coefficient_scales, convolution_descriptor, 797 output_descriptor, output)); 798 } else { 799 SetError(); 800 LOG(WARNING) 801 << "attempting to perform DNN operation using StreamExecutor " 802 "without DNN support"; 803 } 804 } 805 return *this; 806 } 807 808 Stream &Stream::ThenConvolveQuantized( 809 const dnn::BatchDescriptor &input_descriptor, 810 const DeviceMemory<float> &input_data, 811 const dnn::FilterDescriptor &filter_descriptor, 812 const DeviceMemory<int16> &filter_coefficients, 813 const DeviceMemory<float> &coefficient_scales, 814 const dnn::ConvolutionDescriptor &convolution_descriptor, 815 const dnn::BatchDescriptor &output_descriptor, 816 DeviceMemory<float> *output) { 817 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), 818 PARAM(filter_descriptor), PARAM(filter_coefficients), 819 PARAM(coefficient_scales), PARAM(convolution_descriptor), 820 PARAM(output_descriptor), PARAM(output)); 821 822 if (ok()) { 823 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 824 CheckError(dnn->DoConvolveQuantized( 825 this, input_descriptor, input_data, filter_descriptor, 826 filter_coefficients, coefficient_scales, convolution_descriptor, 827 output_descriptor, output)); 828 } else { 829 SetError(); 830 LOG(WARNING) 831 << "attempting to perform DNN operation using StreamExecutor " 832 "without DNN support"; 833 } 834 } 835 return *this; 836 } 837 838 Stream &Stream::ThenSeparableConvolve( 839 const dnn::BatchDescriptor &batch_descriptor, 840 const DeviceMemory<float> &input_data, 841 const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier, 842 const DeviceMemory<float> &first_weights, 843 const DeviceMemory<float> &second_weights, 844 const dnn::ConvolutionDescriptor &convolution_descriptor, 845 const dnn::BatchDescriptor &output_descriptor, 846 DeviceMemory<float> *output) { 847 VLOG_CALL( 848 PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor), 849 PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights), 850 PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output)); 851 852 if (ok()) { 853 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 854 CheckError(dnn->DoSeparableConvolve( 855 this, batch_descriptor, input_data, filter_descriptor, 856 depth_multiplier, first_weights, second_weights, 857 convolution_descriptor, output_descriptor, output)); 858 } else { 859 SetErrorAndLogNoDnnSupport(); 860 } 861 } 862 return *this; 863 } 864 865 Stream &Stream::ThenConvolveBackwardDataWithScratch( 866 const dnn::FilterDescriptor &filter_descriptor, 867 const DeviceMemory<float> &filter_data, 868 const dnn::BatchDescriptor &output_descriptor, 869 DeviceMemory<float> backward_output_data, 870 const dnn::ConvolutionDescriptor &convolution_descriptor, 871 const dnn::BatchDescriptor &input_descriptor, 872 DeviceMemory<float> *backward_input_data, 873 ScratchAllocator *scratch_allocator) { 874 VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data), 875 PARAM(output_descriptor), PARAM(backward_output_data), 876 PARAM(convolution_descriptor), PARAM(input_descriptor), 877 PARAM(backward_input_data)); 878 879 if (ok()) { 880 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 881 CheckError(dnn->DoConvolveBackwardData( 882 this, filter_descriptor, filter_data, output_descriptor, 883 backward_output_data, convolution_descriptor, input_descriptor, 884 backward_input_data, scratch_allocator, dnn::AlgorithmConfig(), 885 /*output_profile_result=*/nullptr)); 886 } else { 887 SetErrorAndLogNoDnnSupport(); 888 } 889 } 890 return *this; 891 } 892 893 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm( 894 const dnn::FilterDescriptor &filter_descriptor, 895 const DeviceMemory<float> &filter_data, 896 const dnn::BatchDescriptor &output_descriptor, 897 DeviceMemory<float> backward_output_data, 898 const dnn::ConvolutionDescriptor &convolution_descriptor, 899 const dnn::BatchDescriptor &input_descriptor, 900 DeviceMemory<float> *backward_input_data, 901 ScratchAllocator *scratch_allocator, 902 const dnn::AlgorithmConfig &algorithm_config, 903 dnn::ProfileResult *output_profile_result) { 904 VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data), 905 PARAM(output_descriptor), PARAM(backward_output_data), 906 PARAM(convolution_descriptor), PARAM(input_descriptor), 907 PARAM(backward_input_data)); 908 909 if (ok()) { 910 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 911 auto status = dnn->DoConvolveBackwardData( 912 this, filter_descriptor, filter_data, output_descriptor, 913 backward_output_data, convolution_descriptor, input_descriptor, 914 backward_input_data, scratch_allocator, algorithm_config, 915 output_profile_result); 916 if (!status && !output_profile_result) { 917 SetError(); 918 } 919 } else { 920 SetErrorAndLogNoDnnSupport(); 921 } 922 } 923 return *this; 924 } 925 926 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm( 927 const dnn::FilterDescriptor &filter_descriptor, 928 const DeviceMemory<Eigen::half> &filter_data, 929 const dnn::BatchDescriptor &output_descriptor, 930 DeviceMemory<Eigen::half> backward_output_data, 931 const dnn::ConvolutionDescriptor &convolution_descriptor, 932 const dnn::BatchDescriptor &input_descriptor, 933 DeviceMemory<Eigen::half> *backward_input_data, 934 ScratchAllocator *scratch_allocator, 935 const dnn::AlgorithmConfig &algorithm_config, 936 dnn::ProfileResult *output_profile_result) { 937 VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data), 938 PARAM(output_descriptor), PARAM(backward_output_data), 939 PARAM(convolution_descriptor), PARAM(input_descriptor), 940 PARAM(backward_input_data)); 941 942 if (ok()) { 943 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 944 auto status = dnn->DoConvolveBackwardData( 945 this, filter_descriptor, filter_data, output_descriptor, 946 backward_output_data, convolution_descriptor, input_descriptor, 947 backward_input_data, scratch_allocator, algorithm_config, 948 output_profile_result); 949 if (!status && !output_profile_result) { 950 SetError(); 951 } 952 } else { 953 SetErrorAndLogNoDnnSupport(); 954 } 955 } 956 return *this; 957 } 958 959 Stream &Stream::ThenConvolveBackwardDataWithScratch( 960 const dnn::FilterDescriptor &filter_descriptor, 961 const DeviceMemory<Eigen::half> &filter_data, 962 const dnn::BatchDescriptor &output_descriptor, 963 DeviceMemory<Eigen::half> backward_output_data, 964 const dnn::ConvolutionDescriptor &convolution_descriptor, 965 const dnn::BatchDescriptor &input_descriptor, 966 DeviceMemory<Eigen::half> *backward_input_data, 967 ScratchAllocator *scratch_allocator) { 968 VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data), 969 PARAM(output_descriptor), PARAM(backward_output_data), 970 PARAM(convolution_descriptor), PARAM(input_descriptor), 971 PARAM(backward_input_data)); 972 973 if (ok()) { 974 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 975 CheckError(dnn->DoConvolveBackwardData( 976 this, filter_descriptor, filter_data, output_descriptor, 977 backward_output_data, convolution_descriptor, input_descriptor, 978 backward_input_data, scratch_allocator, dnn::AlgorithmConfig(), 979 /*output_profile_result=*/nullptr)); 980 } else { 981 SetErrorAndLogNoDnnSupport(); 982 } 983 } 984 return *this; 985 } 986 987 Stream &Stream::ThenConvolveBackwardData( 988 const dnn::FilterDescriptor &filter_descriptor, 989 const DeviceMemory<float> &filter_data, 990 const dnn::BatchDescriptor &output_descriptor, 991 DeviceMemory<float> backward_output_data, 992 const dnn::ConvolutionDescriptor &convolution_descriptor, 993 const dnn::BatchDescriptor &input_descriptor, 994 DeviceMemory<float> *backward_input_data) { 995 return ThenConvolveBackwardDataWithScratch( 996 filter_descriptor, filter_data, output_descriptor, backward_output_data, 997 convolution_descriptor, input_descriptor, backward_input_data, 998 /*scratch_allocator=*/nullptr); 999 } 1000 1001 Stream &Stream::ThenConvolveBackwardFilterWithScratch( 1002 const dnn::BatchDescriptor &input_descriptor, 1003 const DeviceMemory<float> &input_data, 1004 const dnn::BatchDescriptor &output_descriptor, 1005 DeviceMemory<float> backward_output_data, 1006 const dnn::ConvolutionDescriptor &convolution_descriptor, 1007 const dnn::FilterDescriptor &filter_descriptor, 1008 DeviceMemory<float> *backward_filter_data, 1009 ScratchAllocator *scratch_allocator) { 1010 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), 1011 PARAM(output_descriptor), PARAM(backward_output_data), 1012 PARAM(convolution_descriptor), PARAM(filter_descriptor), 1013 PARAM(backward_filter_data)); 1014 1015 if (ok()) { 1016 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1017 CheckError(dnn->DoConvolveBackwardFilter( 1018 this, input_descriptor, input_data, output_descriptor, 1019 backward_output_data, convolution_descriptor, filter_descriptor, 1020 backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(), 1021 /*output_profile_result=*/nullptr)); 1022 } else { 1023 SetErrorAndLogNoDnnSupport(); 1024 } 1025 } 1026 return *this; 1027 } 1028 1029 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm( 1030 const dnn::BatchDescriptor &input_descriptor, 1031 const DeviceMemory<float> &input_data, 1032 const dnn::BatchDescriptor &output_descriptor, 1033 DeviceMemory<float> backward_output_data, 1034 const dnn::ConvolutionDescriptor &convolution_descriptor, 1035 const dnn::FilterDescriptor &filter_descriptor, 1036 DeviceMemory<float> *backward_filter_data, 1037 ScratchAllocator *scratch_allocator, 1038 const dnn::AlgorithmConfig &algorithm_config, 1039 dnn::ProfileResult *output_profile_result) { 1040 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), 1041 PARAM(output_descriptor), PARAM(backward_output_data), 1042 PARAM(convolution_descriptor), PARAM(filter_descriptor), 1043 PARAM(backward_filter_data)); 1044 1045 if (ok()) { 1046 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1047 auto status = dnn->DoConvolveBackwardFilter( 1048 this, input_descriptor, input_data, output_descriptor, 1049 backward_output_data, convolution_descriptor, filter_descriptor, 1050 backward_filter_data, scratch_allocator, algorithm_config, 1051 output_profile_result); 1052 if (!status && !output_profile_result) { 1053 SetError(); 1054 } 1055 } else { 1056 SetErrorAndLogNoDnnSupport(); 1057 } 1058 } 1059 return *this; 1060 } 1061 1062 Stream &Stream::ThenConvolveBackwardFilterWithScratch( 1063 const dnn::BatchDescriptor &input_descriptor, 1064 const DeviceMemory<Eigen::half> &input_data, 1065 const dnn::BatchDescriptor &output_descriptor, 1066 DeviceMemory<Eigen::half> backward_output_data, 1067 const dnn::ConvolutionDescriptor &convolution_descriptor, 1068 const dnn::FilterDescriptor &filter_descriptor, 1069 DeviceMemory<Eigen::half> *backward_filter_data, 1070 ScratchAllocator *scratch_allocator) { 1071 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), 1072 PARAM(output_descriptor), PARAM(backward_output_data), 1073 PARAM(convolution_descriptor), PARAM(filter_descriptor), 1074 PARAM(backward_filter_data)); 1075 1076 if (ok()) { 1077 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1078 CheckError(dnn->DoConvolveBackwardFilter( 1079 this, input_descriptor, input_data, output_descriptor, 1080 backward_output_data, convolution_descriptor, filter_descriptor, 1081 backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(), 1082 /*output_profile_result=*/nullptr)); 1083 } else { 1084 SetErrorAndLogNoDnnSupport(); 1085 } 1086 } 1087 return *this; 1088 } 1089 1090 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm( 1091 const dnn::BatchDescriptor &input_descriptor, 1092 const DeviceMemory<Eigen::half> &input_data, 1093 const dnn::BatchDescriptor &output_descriptor, 1094 DeviceMemory<Eigen::half> backward_output_data, 1095 const dnn::ConvolutionDescriptor &convolution_descriptor, 1096 const dnn::FilterDescriptor &filter_descriptor, 1097 DeviceMemory<Eigen::half> *backward_filter_data, 1098 ScratchAllocator *scratch_allocator, 1099 const dnn::AlgorithmConfig &algorithm_config, 1100 dnn::ProfileResult *output_profile_result) { 1101 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), 1102 PARAM(output_descriptor), PARAM(backward_output_data), 1103 PARAM(convolution_descriptor), PARAM(filter_descriptor), 1104 PARAM(backward_filter_data)); 1105 1106 if (ok()) { 1107 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1108 auto status = dnn->DoConvolveBackwardFilter( 1109 this, input_descriptor, input_data, output_descriptor, 1110 backward_output_data, convolution_descriptor, filter_descriptor, 1111 backward_filter_data, scratch_allocator, algorithm_config, 1112 output_profile_result); 1113 if (!status && !output_profile_result) { 1114 SetError(); 1115 } 1116 } else { 1117 SetErrorAndLogNoDnnSupport(); 1118 } 1119 } 1120 return *this; 1121 } 1122 1123 Stream &Stream::ThenConvolveBackwardFilter( 1124 const dnn::BatchDescriptor &input_descriptor, 1125 const DeviceMemory<float> &input_data, 1126 const dnn::BatchDescriptor &output_descriptor, 1127 DeviceMemory<float> backward_output_data, 1128 const dnn::ConvolutionDescriptor &convolution_descriptor, 1129 const dnn::FilterDescriptor &filter_descriptor, 1130 DeviceMemory<float> *backward_filter_data) { 1131 return ThenConvolveBackwardFilterWithScratch( 1132 input_descriptor, input_data, output_descriptor, backward_output_data, 1133 convolution_descriptor, filter_descriptor, backward_filter_data, 1134 /*scratch_allocator=*/nullptr); 1135 } 1136 1137 template <typename T> 1138 Stream &Stream::ThenConvolveBackwardBiasImpl( 1139 const dnn::BatchDescriptor &input_descriptor, 1140 const DeviceMemory<T> &input_data, 1141 const dnn::BatchDescriptor &bias_descriptor, 1142 DeviceMemory<T> *backward_bias_data) { 1143 VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(bias_descriptor), 1144 PARAM(backward_bias_data)); 1145 1146 if (ok()) { 1147 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1148 CheckError(dnn->DoConvolveBackwardBias(this, input_descriptor, input_data, 1149 bias_descriptor, 1150 backward_bias_data)); 1151 } else { 1152 SetErrorAndLogNoDnnSupport(); 1153 } 1154 } 1155 return *this; 1156 } 1157 1158 Stream &Stream::ThenConvolveBackwardBias( 1159 const dnn::BatchDescriptor &input_descriptor, 1160 const DeviceMemory<double> &input_data, 1161 const dnn::BatchDescriptor &bias_descriptor, 1162 DeviceMemory<double> *backward_bias_data) { 1163 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data, 1164 bias_descriptor, backward_bias_data); 1165 } 1166 1167 Stream &Stream::ThenConvolveBackwardBias( 1168 const dnn::BatchDescriptor &input_descriptor, 1169 const DeviceMemory<float> &input_data, 1170 const dnn::BatchDescriptor &bias_descriptor, 1171 DeviceMemory<float> *backward_bias_data) { 1172 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data, 1173 bias_descriptor, backward_bias_data); 1174 } 1175 1176 Stream &Stream::ThenConvolveBackwardBias( 1177 const dnn::BatchDescriptor &input_descriptor, 1178 const DeviceMemory<Eigen::half> &input_data, 1179 const dnn::BatchDescriptor &bias_descriptor, 1180 DeviceMemory<Eigen::half> *backward_bias_data) { 1181 return ThenConvolveBackwardBiasImpl(input_descriptor, input_data, 1182 bias_descriptor, backward_bias_data); 1183 } 1184 1185 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data, 1186 const DeviceMemory<float> &weights, 1187 const dnn::BatchDescriptor &input_dimensions, 1188 const dnn::BatchDescriptor &output_dimensions, 1189 DeviceMemory<float> *output_data) { 1190 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions), 1191 PARAM(output_dimensions), PARAM(output_data)); 1192 1193 if (ok()) { 1194 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1195 CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions, 1196 output_dimensions, output_data)); 1197 } else { 1198 SetErrorAndLogNoDnnSupport(); 1199 } 1200 } 1201 return *this; 1202 } 1203 1204 Stream &Stream::ThenMatMulQuantized( 1205 const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights, 1206 const DeviceMemory<float> &weight_scales, 1207 const dnn::BatchDescriptor &input_dimensions, 1208 const dnn::BatchDescriptor &output_dimensions, 1209 DeviceMemory<float> *output_data) { 1210 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales), 1211 PARAM(input_dimensions), PARAM(output_dimensions), 1212 PARAM(output_data)); 1213 1214 if (ok()) { 1215 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1216 CheckError(dnn->DoMatMulQuantized(this, input_data, weights, 1217 weight_scales, input_dimensions, 1218 output_dimensions, output_data)); 1219 } else { 1220 SetErrorAndLogNoDnnSupport(); 1221 } 1222 } 1223 return *this; 1224 } 1225 1226 Stream &Stream::ThenMatMulQuantized( 1227 const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights, 1228 const DeviceMemory<float> &weight_scales, 1229 const dnn::BatchDescriptor &input_dimensions, 1230 const dnn::BatchDescriptor &output_dimensions, 1231 DeviceMemory<float> *output_data) { 1232 VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales), 1233 PARAM(input_dimensions), PARAM(output_dimensions), 1234 PARAM(output_data)); 1235 1236 if (ok()) { 1237 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1238 CheckError(dnn->DoMatMulQuantized(this, input_data, weights, 1239 weight_scales, input_dimensions, 1240 output_dimensions, output_data)); 1241 } else { 1242 SetErrorAndLogNoDnnSupport(); 1243 } 1244 } 1245 return *this; 1246 } 1247 1248 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data, 1249 const DeviceMemory<float> &biases, 1250 const dnn::BatchDescriptor &dimensions, 1251 DeviceMemory<float> *output_data) { 1252 VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions), 1253 PARAM(output_data)); 1254 1255 if (ok()) { 1256 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1257 CheckError( 1258 dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data)); 1259 } else { 1260 SetErrorAndLogNoDnnSupport(); 1261 } 1262 } 1263 return *this; 1264 } 1265 1266 Stream &Stream::ThenPoolForward( 1267 const dnn::PoolingDescriptor &pooling_dimensions, 1268 const dnn::BatchDescriptor &input_dimensions, 1269 const DeviceMemory<double> &input_data, 1270 const dnn::BatchDescriptor &output_dimensions, 1271 DeviceMemory<double> *output_data) { 1272 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), 1273 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data)); 1274 1275 if (ok()) { 1276 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1277 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, 1278 input_data, output_dimensions, 1279 output_data)); 1280 } else { 1281 SetError(); 1282 LOG(WARNING) 1283 << "attempting to perform DNN operation using StreamExecutor " 1284 "without DNN support"; 1285 } 1286 } 1287 return *this; 1288 } 1289 1290 Stream &Stream::ThenPoolForward( 1291 const dnn::PoolingDescriptor &pooling_dimensions, 1292 const dnn::BatchDescriptor &input_dimensions, 1293 const DeviceMemory<float> &input_data, 1294 const dnn::BatchDescriptor &output_dimensions, 1295 DeviceMemory<float> *output_data) { 1296 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), 1297 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data)); 1298 1299 if (ok()) { 1300 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1301 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, 1302 input_data, output_dimensions, 1303 output_data)); 1304 } else { 1305 SetErrorAndLogNoDnnSupport(); 1306 } 1307 } 1308 return *this; 1309 } 1310 1311 Stream &Stream::ThenPoolForward( 1312 const dnn::PoolingDescriptor &pooling_dimensions, 1313 const dnn::BatchDescriptor &input_dimensions, 1314 const DeviceMemory<Eigen::half> &input_data, 1315 const dnn::BatchDescriptor &output_dimensions, 1316 DeviceMemory<Eigen::half> *output_data) { 1317 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), 1318 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data)); 1319 1320 if (ok()) { 1321 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1322 CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, 1323 input_data, output_dimensions, 1324 output_data)); 1325 } else { 1326 SetErrorAndLogNoDnnSupport(); 1327 } 1328 } 1329 return *this; 1330 } 1331 1332 Stream &Stream::ThenPoolBackward( 1333 const dnn::PoolingDescriptor &pooling_dimensions, 1334 const dnn::BatchDescriptor &input_dimensions, 1335 const DeviceMemory<double> &input_data, 1336 const dnn::BatchDescriptor &output_dimensions, 1337 const DeviceMemory<double> &output_data, 1338 const DeviceMemory<double> &input_diff_data, 1339 DeviceMemory<double> *output_diff_data) { 1340 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), 1341 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), 1342 PARAM(input_diff_data), PARAM(output_diff_data)); 1343 1344 if (ok()) { 1345 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1346 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, 1347 input_data, output_dimensions, output_data, 1348 input_diff_data, output_diff_data)); 1349 } else { 1350 SetError(); 1351 LOG(WARNING) 1352 << "attempting to perform DNN operation using StreamExecutor " 1353 "without DNN support"; 1354 } 1355 } 1356 return *this; 1357 } 1358 1359 Stream &Stream::ThenPoolBackward( 1360 const dnn::PoolingDescriptor &pooling_dimensions, 1361 const dnn::BatchDescriptor &input_dimensions, 1362 const DeviceMemory<float> &input_data, 1363 const dnn::BatchDescriptor &output_dimensions, 1364 const DeviceMemory<float> &output_data, 1365 const DeviceMemory<float> &input_diff_data, 1366 DeviceMemory<float> *output_diff_data) { 1367 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), 1368 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), 1369 PARAM(input_diff_data), PARAM(output_diff_data)); 1370 1371 if (ok()) { 1372 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1373 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, 1374 input_data, output_dimensions, output_data, 1375 input_diff_data, output_diff_data)); 1376 } else { 1377 SetErrorAndLogNoDnnSupport(); 1378 } 1379 } 1380 return *this; 1381 } 1382 1383 Stream &Stream::ThenPoolBackward( 1384 const dnn::PoolingDescriptor &pooling_dimensions, 1385 const dnn::BatchDescriptor &input_dimensions, 1386 const DeviceMemory<Eigen::half> &input_data, 1387 const dnn::BatchDescriptor &output_dimensions, 1388 const DeviceMemory<Eigen::half> &output_data, 1389 const DeviceMemory<Eigen::half> &input_diff_data, 1390 DeviceMemory<Eigen::half> *output_diff_data) { 1391 VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), 1392 PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), 1393 PARAM(input_diff_data), PARAM(output_diff_data)); 1394 1395 if (ok()) { 1396 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1397 CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, 1398 input_data, output_dimensions, output_data, 1399 input_diff_data, output_diff_data)); 1400 } else { 1401 SetErrorAndLogNoDnnSupport(); 1402 } 1403 } 1404 return *this; 1405 } 1406 1407 Stream &Stream::ThenNormalize( 1408 const dnn::NormalizeDescriptor &normalize_descriptor, 1409 const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) { 1410 VLOG_CALL(PARAM(normalize_descriptor), PARAM(input_data), PARAM(output_data)); 1411 1412 if (ok()) { 1413 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1414 CheckError(dnn->DoNormalize(this, normalize_descriptor, input_data, 1415 output_data)); 1416 } else { 1417 SetErrorAndLogNoDnnSupport(); 1418 } 1419 } 1420 return *this; 1421 } 1422 1423 Stream &Stream::ThenNormalizeWithDimensions( 1424 const dnn::NormalizeDescriptor &normalize_descriptor, 1425 const dnn::BatchDescriptor &dimensions, 1426 const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) { 1427 VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data), 1428 PARAM(output_data)); 1429 1430 if (ok()) { 1431 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1432 CheckError(dnn->DoNormalizeWithDimensions( 1433 this, normalize_descriptor, dimensions, input_data, output_data)); 1434 } else { 1435 SetErrorAndLogNoDnnSupport(); 1436 } 1437 } 1438 return *this; 1439 } 1440 1441 Stream &Stream::ThenNormalizeBackwardWithDimensions( 1442 const dnn::NormalizeDescriptor &normalize_descriptor, 1443 const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data, 1444 const DeviceMemory<float> &normalized_data, 1445 const DeviceMemory<float> &normalized_variable_gradient, 1446 DeviceMemory<float> *raw_variable_gradient) { 1447 VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data), 1448 PARAM(normalized_data), PARAM(normalized_variable_gradient), 1449 PARAM(raw_variable_gradient)); 1450 1451 if (ok()) { 1452 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1453 CheckError(dnn->DoNormalizeBackwardWithDimensions( 1454 this, normalize_descriptor, dimensions, raw_data, normalized_data, 1455 normalized_variable_gradient, raw_variable_gradient)); 1456 } else { 1457 SetErrorAndLogNoDnnSupport(); 1458 } 1459 } 1460 return *this; 1461 } 1462 1463 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode, 1464 const dnn::BatchDescriptor &dimensions, 1465 const DeviceMemory<float> &input_data, 1466 DeviceMemory<float> *output_data) { 1467 return ThenActivateWithOptions(activation_mode, dimensions, input_data, 1468 output_data, /*options=*/0); 1469 } 1470 1471 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode, 1472 const dnn::BatchDescriptor &dimensions, 1473 const DeviceMemory<float> &input_data, 1474 DeviceMemory<float> *output_data, 1475 uint64 options) { 1476 VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data), 1477 PARAM(output_data), PARAM(options)); 1478 1479 if (ok()) { 1480 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1481 CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data, 1482 output_data, options)); 1483 } else { 1484 SetErrorAndLogNoDnnSupport(); 1485 } 1486 } 1487 return *this; 1488 } 1489 1490 Stream &Stream::ThenDepthConcatenate( 1491 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, 1492 port::ArraySlice<const DeviceMemory<float> *> input_data, 1493 DeviceMemory<float> *output_data) { 1494 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data)); 1495 1496 for (size_t i = 1; i < input_dimensions.size(); ++i) { 1497 if (input_dimensions[i].count() != input_dimensions[0].count() || 1498 input_dimensions[i].height() != input_dimensions[0].height() || 1499 input_dimensions[i].width() != input_dimensions[0].width()) { 1500 SetError(); 1501 LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n" 1502 << "input_dimensions[0]: " << input_dimensions[0].ToString() 1503 << "input_dimensions[" << i 1504 << "]: " << input_dimensions[i].ToString(); 1505 return *this; 1506 } 1507 } 1508 1509 if (ok()) { 1510 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1511 CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data, 1512 output_data)); 1513 } else { 1514 SetErrorAndLogNoDnnSupport(); 1515 } 1516 } 1517 return *this; 1518 } 1519 1520 Stream &Stream::ThenSpaceConcatenate( 1521 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, 1522 port::ArraySlice<const DeviceMemory<float> *> input_data, 1523 DeviceMemory<float> *output_data, 1524 dnn::SpaceConcatenateMode concat_direction) { 1525 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data)); 1526 1527 // Check that the input dimensions of all the other batches match those of the 1528 // first batch. 1529 for (size_t i = 1; i < input_dimensions.size(); ++i) { 1530 if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) && 1531 (input_dimensions[i].count() != input_dimensions[0].count() || 1532 input_dimensions[i].height() != input_dimensions[0].height() || 1533 input_dimensions[i].feature_map_count() != 1534 input_dimensions[0].feature_map_count())) { 1535 SetError(); 1536 LOG(ERROR) << "Incompatible dimensions for X concatenation.\n" 1537 << "input_dimensions[0]: " << input_dimensions[0].ToString() 1538 << "input_dimensions[" << i 1539 << "]: " << input_dimensions[i].ToString(); 1540 return *this; 1541 } 1542 1543 if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) && 1544 (input_dimensions[i].count() != input_dimensions[0].count() || 1545 input_dimensions[i].width() != input_dimensions[0].width() || 1546 input_dimensions[i].feature_map_count() != 1547 input_dimensions[0].feature_map_count())) { 1548 SetError(); 1549 LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n" 1550 << "input_dimensions[0]: " << input_dimensions[0].ToString() 1551 << "input_dimensions[" << i 1552 << "]: " << input_dimensions[i].ToString(); 1553 return *this; 1554 } 1555 } 1556 if (ok()) { 1557 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1558 CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data, 1559 output_data, concat_direction)); 1560 } else { 1561 SetErrorAndLogNoDnnSupport(); 1562 } 1563 } 1564 return *this; 1565 } 1566 1567 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions, 1568 const DeviceMemory<float> &input_data, 1569 const dnn::BatchDescriptor &output_dimensions, 1570 DeviceMemory<float> *output_data) { 1571 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), 1572 PARAM(output_dimensions), PARAM(output_data)); 1573 1574 if (ok()) { 1575 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1576 CheckError(dnn->DoReshape(this, input_dimensions, input_data, 1577 output_dimensions, output_data)); 1578 } else { 1579 SetErrorAndLogNoDnnSupport(); 1580 } 1581 } 1582 return *this; 1583 } 1584 1585 Stream &Stream::ThenDepthToSpace( 1586 const dnn::BatchDescriptor &input_dimensions, 1587 const DeviceMemory<float> &input_data, 1588 const dnn::DepthToSpaceLayout &depth_to_space_layout, 1589 const int sqrt_depth_reduction, DeviceMemory<float> *output_data) { 1590 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), 1591 PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction), 1592 PARAM(output_data)); 1593 1594 if (ok()) { 1595 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1596 CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data, 1597 depth_to_space_layout, 1598 sqrt_depth_reduction, output_data)); 1599 } else { 1600 SetErrorAndLogNoDnnSupport(); 1601 } 1602 } 1603 return *this; 1604 } 1605 1606 Stream &Stream::ThenSpaceToDepth( 1607 const dnn::BatchDescriptor &input_dimensions, 1608 const DeviceMemory<float> &input_data, 1609 const dnn::DepthToSpaceLayout &space_to_depth_layout, 1610 const int sqrt_depth_increase, DeviceMemory<float> *output_data) { 1611 VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), 1612 PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase), 1613 PARAM(output_data)); 1614 1615 if (ok()) { 1616 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1617 CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data, 1618 space_to_depth_layout, sqrt_depth_increase, 1619 output_data)); 1620 } else { 1621 SetErrorAndLogNoDnnSupport(); 1622 } 1623 } 1624 return *this; 1625 } 1626 1627 Stream &Stream::ThenElementwiseOperate( 1628 dnn::ElementwiseOperation operation, 1629 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, 1630 port::ArraySlice<const DeviceMemory<float> *> input_data, 1631 const dnn::BatchDescriptor &output_dimensions, 1632 DeviceMemory<float> *output_data) { 1633 VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data), 1634 PARAM(output_dimensions), PARAM(output_data)); 1635 1636 if (ok()) { 1637 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1638 CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions, 1639 input_data, output_dimensions, 1640 output_data)); 1641 } else { 1642 SetErrorAndLogNoDnnSupport(); 1643 } 1644 } 1645 return *this; 1646 } 1647 1648 Stream &Stream::ThenElementwiseOperateScaledQuantized( 1649 dnn::ElementwiseOperation operation, 1650 port::ArraySlice<int> input_multiplicands, int output_divisor, 1651 port::ArraySlice<dnn::BatchDescriptor> input_dimensions, 1652 port::ArraySlice<const DeviceMemory<float> *> input_data, 1653 const dnn::BatchDescriptor &output_dimensions, 1654 DeviceMemory<float> *output_data) { 1655 VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor), 1656 PARAM(input_dimensions), PARAM(input_data), 1657 PARAM(output_dimensions), PARAM(output_data)); 1658 1659 if (ok()) { 1660 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1661 CheckError(dnn->DoElementwiseOperateScaledQuantized( 1662 this, operation, input_multiplicands, output_divisor, 1663 input_dimensions, input_data, output_dimensions, output_data)); 1664 } else { 1665 SetErrorAndLogNoDnnSupport(); 1666 } 1667 } 1668 return *this; 1669 } 1670 1671 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions, 1672 const DeviceMemory<float> &input_data, int64 left_pad, 1673 int64 right_pad, int64 top_pad, int64 bottom_pad, 1674 DeviceMemory<float> *output_data) { 1675 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad), 1676 PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad), 1677 PARAM(output_data)); 1678 1679 if (ok()) { 1680 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1681 CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad, 1682 top_pad, bottom_pad, output_data)); 1683 } else { 1684 SetErrorAndLogNoDnnSupport(); 1685 } 1686 } 1687 return *this; 1688 } 1689 1690 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions, 1691 const DeviceMemory<float> &input_data, 1692 int64 left_trim, int64 right_trim, int64 top_trim, 1693 int64 bottom_trim, 1694 DeviceMemory<float> *output_data) { 1695 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim), 1696 PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim), 1697 PARAM(output_data)); 1698 1699 if (ok()) { 1700 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1701 CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim, 1702 right_trim, top_trim, bottom_trim, 1703 output_data)); 1704 } else { 1705 SetErrorAndLogNoDnnSupport(); 1706 } 1707 } 1708 return *this; 1709 } 1710 1711 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions, 1712 const DeviceMemory<float> &input_data, 1713 int64 replicate_x, int64 replicate_y, 1714 DeviceMemory<float> *output_data) { 1715 VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x), 1716 PARAM(replicate_y), PARAM(output_data)); 1717 1718 if (ok()) { 1719 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1720 CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x, 1721 replicate_y, output_data)); 1722 } else { 1723 SetErrorAndLogNoDnnSupport(); 1724 } 1725 } 1726 return *this; 1727 } 1728 1729 Stream &Stream::ThenMemcpyD2HQuantized( 1730 const DeviceMemory<float> &gpu_unquantized_src, 1731 dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) { 1732 VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst), 1733 PARAM(size)); 1734 1735 if (ok()) { 1736 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1737 CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode, 1738 host_dst, size)); 1739 } else { 1740 SetErrorAndLogNoDnnSupport(); 1741 } 1742 } 1743 return *this; 1744 } 1745 1746 Stream &Stream::ThenMemcpyH2DQuantized( 1747 const void *host_src, uint64 size, dnn::QuantizedActivationMode mode, 1748 DeviceMemory<float> *gpu_unquantized_dst) { 1749 VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode), 1750 PARAM(gpu_unquantized_dst)); 1751 1752 if (ok()) { 1753 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1754 CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode, 1755 gpu_unquantized_dst)); 1756 } else { 1757 SetErrorAndLogNoDnnSupport(); 1758 } 1759 } 1760 return *this; 1761 } 1762 1763 Stream &Stream::ThenCopyHostBuffer2Device( 1764 HostBuffer *buffer_src, DeviceMemory<float> *gpu_unquantized_dst) { 1765 VLOG_CALL(PARAM(*buffer_src), PARAM(gpu_unquantized_dst)); 1766 1767 if (ok()) { 1768 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1769 CheckError( 1770 dnn->DoCopyHostBuffer2Device(this, buffer_src, gpu_unquantized_dst)); 1771 } else { 1772 SetErrorAndLogNoDnnSupport(); 1773 } 1774 } 1775 return *this; 1776 } 1777 1778 Stream &Stream::ThenCopyDevice2HostBuffer( 1779 const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst) { 1780 VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(*buffer_dst)); 1781 1782 if (ok()) { 1783 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 1784 CheckError( 1785 dnn->DoCopyDevice2HostBuffer(this, gpu_unquantized_src, buffer_dst)); 1786 } else { 1787 SetErrorAndLogNoDnnSupport(); 1788 } 1789 } 1790 return *this; 1791 } 1792 1793 Stream *Stream::GetOrCreateSubStream() { 1794 mutex_lock lock{mu_}; 1795 for (auto &stream : sub_streams_) { 1796 if (stream.second) { 1797 stream.second = false; 1798 return stream.first.get(); 1799 } 1800 } 1801 sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}}, 1802 false); 1803 Stream *sub_stream = sub_streams_.back().first.get(); 1804 sub_stream->Init(); 1805 CHECK(ok_) << "sub-stream failed to be initialized"; 1806 1807 return sub_stream; 1808 } 1809 1810 void Stream::ReturnSubStream(Stream *sub_stream) { 1811 mutex_lock lock{mu_}; 1812 for (auto &stream : sub_streams_) { 1813 if (stream.first.get() == sub_stream) { 1814 stream.second = true; 1815 return; 1816 } 1817 } 1818 LOG(FATAL) << "the sub-stream to be returned is not created by this stream"; 1819 } 1820 1821 Stream &Stream::ThenStartTimer(Timer *t) { 1822 VLOG_CALL(PARAM(t)); 1823 1824 if (ok()) { 1825 CheckError(parent_->StartTimer(this, t)); 1826 } else { 1827 LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t; 1828 } 1829 return *this; 1830 } 1831 1832 Stream &Stream::ThenStopTimer(Timer *t) { 1833 VLOG_CALL(PARAM(t)); 1834 1835 if (ok()) { 1836 CheckError(parent_->StopTimer(this, t)); 1837 } else { 1838 LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t; 1839 } 1840 return *this; 1841 } 1842 1843 Stream &Stream::ThenWaitFor(Stream *other) { 1844 VLOG_CALL(PARAM(other)); 1845 1846 CHECK(this != other) << "stream cannot wait for itself"; 1847 if (ok() && other->ok()) { 1848 CheckError(parent_->CreateStreamDependency(this, other)); 1849 } else { 1850 SetError(); 1851 LOG(INFO) << "stream " << this << " did not wait for stream: " << other; 1852 } 1853 return *this; 1854 } 1855 1856 Stream &Stream::ThenWaitFor(Event *event) { 1857 VLOG_CALL(PARAM(event)); 1858 1859 if (ok()) { 1860 port::Status status = parent_->WaitForEvent(this, event); 1861 if (!status.ok()) { 1862 LOG(ERROR) << "Error waiting for event in stream: " 1863 << status.error_message() 1864 << "; not marking stream as bad, as the Event object may be " 1865 << "at fault. Monitor for further errors."; 1866 } 1867 } else { 1868 LOG(INFO) << "stream " << this << " did not wait for an event."; 1869 } 1870 return *this; 1871 } 1872 1873 // A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX 1874 // functions and logs for errors. 1875 template <typename... Args> 1876 struct ThenBlasImpl { 1877 // blas_func is the DoBlasXXX member function pointer, and args are its 1878 // arguments except the first one of Stream* type. 1879 Stream &operator()(Stream *stream, 1880 bool (blas::BlasSupport::*blas_func)(Stream *, Args...), 1881 Args... args) { 1882 return Run(stream, blas_func, /*record_error=*/true, args...); 1883 } 1884 1885 // Like operator(), but only calls stream->CheckError() if record_error is 1886 // true. 1887 Stream &Run(Stream *stream, 1888 bool (blas::BlasSupport::*blas_func)(Stream *, Args...), 1889 bool record_error, Args... args); 1890 }; 1891 1892 template <typename... Args> 1893 Stream &ThenBlasImpl<Args...>::Run( 1894 Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...), 1895 bool record_error, Args... args) { 1896 if (stream->ok()) { 1897 bool ok; 1898 if (blas::BlasSupport *blas = stream->parent_->AsBlas()) { 1899 ok = (blas->*blas_func)(stream, args...); 1900 } else { 1901 LOG(WARNING) 1902 << "attempting to perform BLAS operation using StreamExecutor " 1903 "without BLAS support"; 1904 ok = false; 1905 } 1906 if (record_error) { 1907 stream->CheckError(ok); 1908 } 1909 } 1910 return *stream; 1911 } 1912 1913 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x, 1914 int incx, DeviceMemory<float> *result) { 1915 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 1916 1917 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *> 1918 impl; 1919 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx, 1920 result); 1921 } 1922 1923 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x, 1924 int incx, DeviceMemory<double> *result) { 1925 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 1926 1927 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, 1928 DeviceMemory<double> *> impl; 1929 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx, 1930 result); 1931 } 1932 1933 Stream &Stream::ThenBlasAsum(uint64 elem_count, 1934 const DeviceMemory<std::complex<float>> &x, 1935 int incx, DeviceMemory<float> *result) { 1936 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 1937 1938 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int, 1939 DeviceMemory<float> *> impl; 1940 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx, 1941 result); 1942 } 1943 1944 Stream &Stream::ThenBlasAsum(uint64 elem_count, 1945 const DeviceMemory<std::complex<double>> &x, 1946 int incx, DeviceMemory<double> *result) { 1947 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 1948 1949 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int, 1950 DeviceMemory<double> *> impl; 1951 return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx, 1952 result); 1953 } 1954 1955 Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha, 1956 const DeviceMemory<float> &x, int incx, 1957 DeviceMemory<float> *y, int incy) { 1958 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), 1959 PARAM(incy)); 1960 1961 ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int, 1962 DeviceMemory<float> *, int> impl; 1963 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx, 1964 y, incy); 1965 } 1966 1967 Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha, 1968 const DeviceMemory<double> &x, int incx, 1969 DeviceMemory<double> *y, int incy) { 1970 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), 1971 PARAM(incy)); 1972 1973 ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int, 1974 DeviceMemory<double> *, int> impl; 1975 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx, 1976 y, incy); 1977 } 1978 1979 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha, 1980 const DeviceMemory<std::complex<float>> &x, 1981 int incx, DeviceMemory<std::complex<float>> *y, 1982 int incy) { 1983 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), 1984 PARAM(incy)); 1985 1986 ThenBlasImpl<uint64, std::complex<float>, 1987 const DeviceMemory<std::complex<float>> &, int, 1988 DeviceMemory<std::complex<float>> *, int> impl; 1989 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx, 1990 y, incy); 1991 } 1992 1993 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha, 1994 const DeviceMemory<std::complex<double>> &x, 1995 int incx, DeviceMemory<std::complex<double>> *y, 1996 int incy) { 1997 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), 1998 PARAM(incy)); 1999 2000 ThenBlasImpl<uint64, std::complex<double>, 2001 const DeviceMemory<std::complex<double>> &, int, 2002 DeviceMemory<std::complex<double>> *, int> impl; 2003 return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx, 2004 y, incy); 2005 } 2006 2007 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x, 2008 int incx, DeviceMemory<float> *y, int incy) { 2009 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); 2010 2011 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *, 2012 int> impl; 2013 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y, 2014 incy); 2015 } 2016 2017 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x, 2018 int incx, DeviceMemory<double> *y, int incy) { 2019 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); 2020 2021 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, 2022 DeviceMemory<double> *, int> impl; 2023 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y, 2024 incy); 2025 } 2026 2027 Stream &Stream::ThenBlasCopy(uint64 elem_count, 2028 const DeviceMemory<std::complex<float>> &x, 2029 int incx, DeviceMemory<std::complex<float>> *y, 2030 int incy) { 2031 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); 2032 2033 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int, 2034 DeviceMemory<std::complex<float>> *, int> impl; 2035 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y, 2036 incy); 2037 } 2038 2039 Stream &Stream::ThenBlasCopy(uint64 elem_count, 2040 const DeviceMemory<std::complex<double>> &x, 2041 int incx, DeviceMemory<std::complex<double>> *y, 2042 int incy) { 2043 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); 2044 2045 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int, 2046 DeviceMemory<std::complex<double>> *, int> impl; 2047 return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y, 2048 incy); 2049 } 2050 2051 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x, 2052 int incx, const DeviceMemory<float> &y, int incy, 2053 DeviceMemory<float> *result) { 2054 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), 2055 PARAM(result)); 2056 2057 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, 2058 const DeviceMemory<float> &, int, DeviceMemory<float> *> impl; 2059 return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy, 2060 result); 2061 } 2062 2063 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x, 2064 int incx, const DeviceMemory<double> &y, int incy, 2065 DeviceMemory<double> *result) { 2066 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), 2067 PARAM(result)); 2068 2069 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, 2070 const DeviceMemory<double> &, int, DeviceMemory<double> *> impl; 2071 return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy, 2072 result); 2073 } 2074 2075 Stream &Stream::ThenBlasDotc(uint64 elem_count, 2076 const DeviceMemory<std::complex<float>> &x, 2077 int incx, 2078 const DeviceMemory<std::complex<float>> &y, 2079 int incy, 2080 DeviceMemory<std::complex<float>> *result) { 2081 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), 2082 PARAM(result)); 2083 2084 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int, 2085 const DeviceMemory<std::complex<float>> &, int, 2086 DeviceMemory<std::complex<float>> *> impl; 2087 return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y, 2088 incy, result); 2089 } 2090 2091 Stream &Stream::ThenBlasDotc(uint64 elem_count, 2092 const DeviceMemory<std::complex<double>> &x, 2093 int incx, 2094 const DeviceMemory<std::complex<double>> &y, 2095 int incy, 2096 DeviceMemory<std::complex<double>> *result) { 2097 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), 2098 PARAM(result)); 2099 2100 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int, 2101 const DeviceMemory<std::complex<double>> &, int, 2102 DeviceMemory<std::complex<double>> *> impl; 2103 return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y, 2104 incy, result); 2105 } 2106 2107 Stream &Stream::ThenBlasDotu(uint64 elem_count, 2108 const DeviceMemory<std::complex<float>> &x, 2109 int incx, 2110 const DeviceMemory<std::complex<float>> &y, 2111 int incy, 2112 DeviceMemory<std::complex<float>> *result) { 2113 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), 2114 PARAM(result)); 2115 2116 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int, 2117 const DeviceMemory<std::complex<float>> &, int, 2118 DeviceMemory<std::complex<float>> *> impl; 2119 return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y, 2120 incy, result); 2121 } 2122 2123 Stream &Stream::ThenBlasDotu(uint64 elem_count, 2124 const DeviceMemory<std::complex<double>> &x, 2125 int incx, 2126 const DeviceMemory<std::complex<double>> &y, 2127 int incy, 2128 DeviceMemory<std::complex<double>> *result) { 2129 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), 2130 PARAM(result)); 2131 2132 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int, 2133 const DeviceMemory<std::complex<double>> &, int, 2134 DeviceMemory<std::complex<double>> *> impl; 2135 return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y, 2136 incy, result); 2137 } 2138 2139 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x, 2140 int incx, DeviceMemory<float> *result) { 2141 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 2142 2143 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *> 2144 impl; 2145 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx, 2146 result); 2147 } 2148 2149 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x, 2150 int incx, DeviceMemory<double> *result) { 2151 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 2152 2153 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, 2154 DeviceMemory<double> *> impl; 2155 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx, 2156 result); 2157 } 2158 2159 Stream &Stream::ThenBlasNrm2(uint64 elem_count, 2160 const DeviceMemory<std::complex<float>> &x, 2161 int incx, DeviceMemory<float> *result) { 2162 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 2163 2164 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int, 2165 DeviceMemory<float> *> impl; 2166 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx, 2167 result); 2168 } 2169 2170 Stream &Stream::ThenBlasNrm2(uint64 elem_count, 2171 const DeviceMemory<std::complex<double>> &x, 2172 int incx, DeviceMemory<double> *result) { 2173 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 2174 2175 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int, 2176 DeviceMemory<double> *> impl; 2177 return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx, 2178 result); 2179 } 2180 2181 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx, 2182 DeviceMemory<float> *y, int incy, float c, 2183 float s) { 2184 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), 2185 PARAM(c), PARAM(s)); 2186 2187 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int, 2188 float, float> impl; 2189 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy, 2190 c, s); 2191 } 2192 2193 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x, 2194 int incx, DeviceMemory<double> *y, int incy, 2195 double c, double s) { 2196 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), 2197 PARAM(c), PARAM(s)); 2198 2199 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int, 2200 double, double> impl; 2201 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy, 2202 c, s); 2203 } 2204 2205 Stream &Stream::ThenBlasRot(uint64 elem_count, 2206 DeviceMemory<std::complex<float>> *x, int incx, 2207 DeviceMemory<std::complex<float>> *y, int incy, 2208 float c, float s) { 2209 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), 2210 PARAM(c), PARAM(s)); 2211 2212 ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int, 2213 DeviceMemory<std::complex<float>> *, int, float, float> impl; 2214 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy, 2215 c, s); 2216 } 2217 2218 Stream &Stream::ThenBlasRot(uint64 elem_count, 2219 DeviceMemory<std::complex<double>> *x, int incx, 2220 DeviceMemory<std::complex<double>> *y, int incy, 2221 double c, double s) { 2222 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), 2223 PARAM(c), PARAM(s)); 2224 2225 ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int, 2226 DeviceMemory<std::complex<double>> *, int, double, double> impl; 2227 return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy, 2228 c, s); 2229 } 2230 2231 Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b, 2232 DeviceMemory<float> *c, DeviceMemory<float> *s) { 2233 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s)); 2234 2235 ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *, 2236 DeviceMemory<float> *, DeviceMemory<float> *> impl; 2237 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s); 2238 } 2239 2240 Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b, 2241 DeviceMemory<double> *c, DeviceMemory<double> *s) { 2242 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s)); 2243 2244 ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *, 2245 DeviceMemory<double> *, DeviceMemory<double> *> impl; 2246 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s); 2247 } 2248 2249 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a, 2250 DeviceMemory<std::complex<float>> *b, 2251 DeviceMemory<float> *c, 2252 DeviceMemory<std::complex<float>> *s) { 2253 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s)); 2254 2255 ThenBlasImpl<DeviceMemory<std::complex<float>> *, 2256 DeviceMemory<std::complex<float>> *, DeviceMemory<float> *, 2257 DeviceMemory<std::complex<float>> *> impl; 2258 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s); 2259 } 2260 2261 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a, 2262 DeviceMemory<std::complex<double>> *b, 2263 DeviceMemory<double> *c, 2264 DeviceMemory<std::complex<double>> *s) { 2265 VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s)); 2266 2267 ThenBlasImpl<DeviceMemory<std::complex<double>> *, 2268 DeviceMemory<std::complex<double>> *, DeviceMemory<double> *, 2269 DeviceMemory<std::complex<double>> *> impl; 2270 return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s); 2271 } 2272 2273 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x, 2274 int incx, DeviceMemory<float> *y, int incy, 2275 const DeviceMemory<float> ¶m) { 2276 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), 2277 PARAM(param)); 2278 2279 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int, 2280 const DeviceMemory<float> &> impl; 2281 return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y, 2282 incy, param); 2283 } 2284 2285 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x, 2286 int incx, DeviceMemory<double> *y, int incy, 2287 const DeviceMemory<double> ¶m) { 2288 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy), 2289 PARAM(param)); 2290 2291 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int, 2292 const DeviceMemory<double> &> impl; 2293 return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y, 2294 incy, param); 2295 } 2296 2297 Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2, 2298 DeviceMemory<float> *x1, 2299 const DeviceMemory<float> &y1, 2300 DeviceMemory<float> *param) { 2301 VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param)); 2302 2303 ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *, 2304 DeviceMemory<float> *, const DeviceMemory<float> &, 2305 DeviceMemory<float> *> impl; 2306 return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param); 2307 } 2308 2309 Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1, 2310 DeviceMemory<double> *d2, 2311 DeviceMemory<double> *x1, 2312 const DeviceMemory<double> &y1, 2313 DeviceMemory<double> *param) { 2314 VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param)); 2315 2316 ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *, 2317 DeviceMemory<double> *, const DeviceMemory<double> &, 2318 DeviceMemory<double> *> impl; 2319 return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param); 2320 } 2321 2322 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha, 2323 DeviceMemory<float> *x, int incx) { 2324 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); 2325 2326 ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl; 2327 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); 2328 } 2329 2330 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha, 2331 DeviceMemory<double> *x, int incx) { 2332 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); 2333 2334 ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl; 2335 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); 2336 } 2337 2338 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha, 2339 DeviceMemory<std::complex<float>> *x, int incx) { 2340 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); 2341 2342 ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl; 2343 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); 2344 } 2345 2346 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha, 2347 DeviceMemory<std::complex<double>> *x, int incx) { 2348 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); 2349 2350 ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl; 2351 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); 2352 } 2353 2354 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha, 2355 DeviceMemory<std::complex<float>> *x, int incx) { 2356 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); 2357 2358 ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *, 2359 int> impl; 2360 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); 2361 } 2362 2363 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha, 2364 DeviceMemory<std::complex<double>> *x, int incx) { 2365 VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx)); 2366 2367 ThenBlasImpl<uint64, std::complex<double>, 2368 DeviceMemory<std::complex<double>> *, int> impl; 2369 return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx); 2370 } 2371 2372 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x, 2373 int incx, DeviceMemory<float> *y, int incy) { 2374 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); 2375 2376 ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int> 2377 impl; 2378 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y, 2379 incy); 2380 } 2381 2382 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x, 2383 int incx, DeviceMemory<double> *y, int incy) { 2384 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); 2385 2386 ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int> 2387 impl; 2388 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y, 2389 incy); 2390 } 2391 2392 Stream &Stream::ThenBlasSwap(uint64 elem_count, 2393 DeviceMemory<std::complex<float>> *x, int incx, 2394 DeviceMemory<std::complex<float>> *y, int incy) { 2395 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); 2396 2397 ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int, 2398 DeviceMemory<std::complex<float>> *, int> impl; 2399 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y, 2400 incy); 2401 } 2402 2403 Stream &Stream::ThenBlasSwap(uint64 elem_count, 2404 DeviceMemory<std::complex<double>> *x, int incx, 2405 DeviceMemory<std::complex<double>> *y, int incy) { 2406 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy)); 2407 2408 ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int, 2409 DeviceMemory<std::complex<double>> *, int> impl; 2410 return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y, 2411 incy); 2412 } 2413 2414 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x, 2415 int incx, DeviceMemory<int> *result) { 2416 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 2417 2418 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *> 2419 impl; 2420 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx, 2421 result); 2422 } 2423 2424 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x, 2425 int incx, DeviceMemory<int> *result) { 2426 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 2427 2428 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *> 2429 impl; 2430 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx, 2431 result); 2432 } 2433 2434 Stream &Stream::ThenBlasIamax(uint64 elem_count, 2435 const DeviceMemory<std::complex<float>> &x, 2436 int incx, DeviceMemory<int> *result) { 2437 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 2438 2439 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int, 2440 DeviceMemory<int> *> impl; 2441 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx, 2442 result); 2443 } 2444 2445 Stream &Stream::ThenBlasIamax(uint64 elem_count, 2446 const DeviceMemory<std::complex<double>> &x, 2447 int incx, DeviceMemory<int> *result) { 2448 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 2449 2450 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int, 2451 DeviceMemory<int> *> impl; 2452 return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx, 2453 result); 2454 } 2455 2456 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x, 2457 int incx, DeviceMemory<int> *result) { 2458 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 2459 2460 ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *> 2461 impl; 2462 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx, 2463 result); 2464 } 2465 2466 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x, 2467 int incx, DeviceMemory<int> *result) { 2468 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 2469 2470 ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *> 2471 impl; 2472 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx, 2473 result); 2474 } 2475 2476 Stream &Stream::ThenBlasIamin(uint64 elem_count, 2477 const DeviceMemory<std::complex<float>> &x, 2478 int incx, DeviceMemory<int> *result) { 2479 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 2480 2481 ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int, 2482 DeviceMemory<int> *> impl; 2483 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx, 2484 result); 2485 } 2486 2487 Stream &Stream::ThenBlasIamin(uint64 elem_count, 2488 const DeviceMemory<std::complex<double>> &x, 2489 int incx, DeviceMemory<int> *result) { 2490 VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result)); 2491 2492 ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int, 2493 DeviceMemory<int> *> impl; 2494 return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx, 2495 result); 2496 } 2497 2498 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, 2499 uint64 kl, uint64 ku, float alpha, 2500 const DeviceMemory<float> &a, int lda, 2501 const DeviceMemory<float> &x, int incx, float beta, 2502 DeviceMemory<float> *y, int incy) { 2503 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku), 2504 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx), 2505 PARAM(beta), PARAM(y), PARAM(incy)); 2506 2507 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float, 2508 const DeviceMemory<float> &, int, const DeviceMemory<float> &, 2509 int, float, DeviceMemory<float> *, int> impl; 2510 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha, 2511 a, lda, x, incx, beta, y, incy); 2512 } 2513 2514 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, 2515 uint64 kl, uint64 ku, double alpha, 2516 const DeviceMemory<double> &a, int lda, 2517 const DeviceMemory<double> &x, int incx, 2518 double beta, DeviceMemory<double> *y, int incy) { 2519 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku), 2520 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx), 2521 PARAM(beta), PARAM(y), PARAM(incy)); 2522 2523 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double, 2524 const DeviceMemory<double> &, int, const DeviceMemory<double> &, 2525 int, double, DeviceMemory<double> *, int> impl; 2526 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha, 2527 a, lda, x, incx, beta, y, incy); 2528 } 2529 2530 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, 2531 uint64 kl, uint64 ku, std::complex<float> alpha, 2532 const DeviceMemory<std::complex<float>> &a, 2533 int lda, 2534 const DeviceMemory<std::complex<float>> &x, 2535 int incx, std::complex<float> beta, 2536 DeviceMemory<std::complex<float>> *y, int incy) { 2537 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku), 2538 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx), 2539 PARAM(beta), PARAM(y), PARAM(incy)); 2540 2541 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, 2542 std::complex<float>, const DeviceMemory<std::complex<float>> &, 2543 int, const DeviceMemory<std::complex<float>> &, int, 2544 std::complex<float>, DeviceMemory<std::complex<float>> *, 2545 int> impl; 2546 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha, 2547 a, lda, x, incx, beta, y, incy); 2548 } 2549 2550 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, 2551 uint64 kl, uint64 ku, std::complex<double> alpha, 2552 const DeviceMemory<std::complex<double>> &a, 2553 int lda, 2554 const DeviceMemory<std::complex<double>> &x, 2555 int incx, std::complex<double> beta, 2556 DeviceMemory<std::complex<double>> *y, int incy) { 2557 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku), 2558 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx), 2559 PARAM(beta), PARAM(y), PARAM(incy)); 2560 2561 ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, 2562 std::complex<double>, const DeviceMemory<std::complex<double>> &, 2563 int, const DeviceMemory<std::complex<double>> &, int, 2564 std::complex<double>, DeviceMemory<std::complex<double>> *, 2565 int> impl; 2566 return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha, 2567 a, lda, x, incx, beta, y, incy); 2568 } 2569 2570 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, 2571 float alpha, const DeviceMemory<float> &a, int lda, 2572 const DeviceMemory<float> &x, int incx, float beta, 2573 DeviceMemory<float> *y, int incy) { 2574 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), 2575 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), 2576 PARAM(incy)); 2577 2578 ThenBlasImpl<blas::Transpose, uint64, uint64, float, 2579 const DeviceMemory<float> &, int, const DeviceMemory<float> &, 2580 int, float, DeviceMemory<float> *, int> impl; 2581 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, 2582 x, incx, beta, y, incy); 2583 } 2584 2585 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, 2586 double alpha, const DeviceMemory<double> &a, 2587 int lda, const DeviceMemory<double> &x, int incx, 2588 double beta, DeviceMemory<double> *y, int incy) { 2589 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), 2590 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), 2591 PARAM(incy)); 2592 2593 ThenBlasImpl<blas::Transpose, uint64, uint64, double, 2594 const DeviceMemory<double> &, int, const DeviceMemory<double> &, 2595 int, double, DeviceMemory<double> *, int> impl; 2596 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, 2597 x, incx, beta, y, incy); 2598 } 2599 2600 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, 2601 std::complex<float> alpha, 2602 const DeviceMemory<std::complex<float>> &a, 2603 int lda, 2604 const DeviceMemory<std::complex<float>> &x, 2605 int incx, std::complex<float> beta, 2606 DeviceMemory<std::complex<float>> *y, int incy) { 2607 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), 2608 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), 2609 PARAM(incy)); 2610 2611 ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>, 2612 const DeviceMemory<std::complex<float>> &, int, 2613 const DeviceMemory<std::complex<float>> &, int, 2614 std::complex<float>, DeviceMemory<std::complex<float>> *, 2615 int> impl; 2616 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, 2617 x, incx, beta, y, incy); 2618 } 2619 2620 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, 2621 std::complex<double> alpha, 2622 const DeviceMemory<std::complex<double>> &a, 2623 int lda, 2624 const DeviceMemory<std::complex<double>> &x, 2625 int incx, std::complex<double> beta, 2626 DeviceMemory<std::complex<double>> *y, int incy) { 2627 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), 2628 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), 2629 PARAM(incy)); 2630 2631 ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>, 2632 const DeviceMemory<std::complex<double>> &, int, 2633 const DeviceMemory<std::complex<double>> &, int, 2634 std::complex<double>, DeviceMemory<std::complex<double>> *, 2635 int> impl; 2636 return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda, 2637 x, incx, beta, y, incy); 2638 } 2639 2640 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha, 2641 const DeviceMemory<float> &x, int incx, 2642 const DeviceMemory<float> &y, int incy, 2643 DeviceMemory<float> *a, int lda) { 2644 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), 2645 PARAM(incy), PARAM(a), PARAM(lda)); 2646 2647 ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int, 2648 const DeviceMemory<float> &, int, DeviceMemory<float> *, 2649 int> impl; 2650 return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y, 2651 incy, a, lda); 2652 } 2653 2654 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha, 2655 const DeviceMemory<double> &x, int incx, 2656 const DeviceMemory<double> &y, int incy, 2657 DeviceMemory<double> *a, int lda) { 2658 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), 2659 PARAM(incy), PARAM(a), PARAM(lda)); 2660 2661 ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int, 2662 const DeviceMemory<double> &, int, DeviceMemory<double> *, 2663 int> impl; 2664 return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y, 2665 incy, a, lda); 2666 } 2667 2668 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha, 2669 const DeviceMemory<std::complex<float>> &x, 2670 int incx, 2671 const DeviceMemory<std::complex<float>> &y, 2672 int incy, DeviceMemory<std::complex<float>> *a, 2673 int lda) { 2674 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), 2675 PARAM(incy), PARAM(a), PARAM(lda)); 2676 2677 ThenBlasImpl<uint64, uint64, std::complex<float>, 2678 const DeviceMemory<std::complex<float>> &, int, 2679 const DeviceMemory<std::complex<float>> &, int, 2680 DeviceMemory<std::complex<float>> *, int> impl; 2681 return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y, 2682 incy, a, lda); 2683 } 2684 2685 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha, 2686 const DeviceMemory<std::complex<double>> &x, 2687 int incx, 2688 const DeviceMemory<std::complex<double>> &y, 2689 int incy, DeviceMemory<std::complex<double>> *a, 2690 int lda) { 2691 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), 2692 PARAM(incy), PARAM(a), PARAM(lda)); 2693 2694 ThenBlasImpl<uint64, uint64, std::complex<double>, 2695 const DeviceMemory<std::complex<double>> &, int, 2696 const DeviceMemory<std::complex<double>> &, int, 2697 DeviceMemory<std::complex<double>> *, int> impl; 2698 return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y, 2699 incy, a, lda); 2700 } 2701 2702 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha, 2703 const DeviceMemory<std::complex<float>> &x, 2704 int incx, 2705 const DeviceMemory<std::complex<float>> &y, 2706 int incy, DeviceMemory<std::complex<float>> *a, 2707 int lda) { 2708 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), 2709 PARAM(incy), PARAM(a), PARAM(lda)); 2710 2711 ThenBlasImpl<uint64, uint64, std::complex<float>, 2712 const DeviceMemory<std::complex<float>> &, int, 2713 const DeviceMemory<std::complex<float>> &, int, 2714 DeviceMemory<std::complex<float>> *, int> impl; 2715 return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y, 2716 incy, a, lda); 2717 } 2718 2719 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha, 2720 const DeviceMemory<std::complex<double>> &x, 2721 int incx, 2722 const DeviceMemory<std::complex<double>> &y, 2723 int incy, DeviceMemory<std::complex<double>> *a, 2724 int lda) { 2725 VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y), 2726 PARAM(incy), PARAM(a), PARAM(lda)); 2727 2728 ThenBlasImpl<uint64, uint64, std::complex<double>, 2729 const DeviceMemory<std::complex<double>> &, int, 2730 const DeviceMemory<std::complex<double>> &, int, 2731 DeviceMemory<std::complex<double>> *, int> impl; 2732 return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y, 2733 incy, a, lda); 2734 } 2735 2736 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k, 2737 std::complex<float> alpha, 2738 const DeviceMemory<std::complex<float>> &a, 2739 int lda, 2740 const DeviceMemory<std::complex<float>> &x, 2741 int incx, std::complex<float> beta, 2742 DeviceMemory<std::complex<float>> *y, int incy) { 2743 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), 2744 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); 2745 2746 ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>, 2747 const DeviceMemory<std::complex<float>> &, int, 2748 const DeviceMemory<std::complex<float>> &, int, 2749 std::complex<float>, DeviceMemory<std::complex<float>> *, 2750 int> impl; 2751 return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda, 2752 x, incx, beta, y, incy); 2753 } 2754 2755 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k, 2756 std::complex<double> alpha, 2757 const DeviceMemory<std::complex<double>> &a, 2758 int lda, 2759 const DeviceMemory<std::complex<double>> &x, 2760 int incx, std::complex<double> beta, 2761 DeviceMemory<std::complex<double>> *y, int incy) { 2762 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), 2763 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); 2764 2765 ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>, 2766 const DeviceMemory<std::complex<double>> &, int, 2767 const DeviceMemory<std::complex<double>> &, int, 2768 std::complex<double>, DeviceMemory<std::complex<double>> *, 2769 int> impl; 2770 return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda, 2771 x, incx, beta, y, incy); 2772 } 2773 2774 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n, 2775 std::complex<float> alpha, 2776 const DeviceMemory<std::complex<float>> &a, 2777 int lda, 2778 const DeviceMemory<std::complex<float>> &x, 2779 int incx, std::complex<float> beta, 2780 DeviceMemory<std::complex<float>> *y, int incy) { 2781 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), 2782 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); 2783 2784 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>, 2785 const DeviceMemory<std::complex<float>> &, int, 2786 const DeviceMemory<std::complex<float>> &, int, 2787 std::complex<float>, DeviceMemory<std::complex<float>> *, 2788 int> impl; 2789 return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x, 2790 incx, beta, y, incy); 2791 } 2792 2793 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n, 2794 std::complex<double> alpha, 2795 const DeviceMemory<std::complex<double>> &a, 2796 int lda, 2797 const DeviceMemory<std::complex<double>> &x, 2798 int incx, std::complex<double> beta, 2799 DeviceMemory<std::complex<double>> *y, int incy) { 2800 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), 2801 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); 2802 2803 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>, 2804 const DeviceMemory<std::complex<double>> &, int, 2805 const DeviceMemory<std::complex<double>> &, int, 2806 std::complex<double>, DeviceMemory<std::complex<double>> *, 2807 int> impl; 2808 return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x, 2809 incx, beta, y, incy); 2810 } 2811 2812 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha, 2813 const DeviceMemory<std::complex<float>> &x, 2814 int incx, DeviceMemory<std::complex<float>> *a, 2815 int lda) { 2816 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 2817 PARAM(a), PARAM(lda)); 2818 2819 ThenBlasImpl<blas::UpperLower, uint64, float, 2820 const DeviceMemory<std::complex<float>> &, int, 2821 DeviceMemory<std::complex<float>> *, int> impl; 2822 return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a, 2823 lda); 2824 } 2825 2826 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha, 2827 const DeviceMemory<std::complex<double>> &x, 2828 int incx, DeviceMemory<std::complex<double>> *a, 2829 int lda) { 2830 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 2831 PARAM(a), PARAM(lda)); 2832 2833 ThenBlasImpl<blas::UpperLower, uint64, double, 2834 const DeviceMemory<std::complex<double>> &, int, 2835 DeviceMemory<std::complex<double>> *, int> impl; 2836 return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a, 2837 lda); 2838 } 2839 2840 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n, 2841 std::complex<float> alpha, 2842 const DeviceMemory<std::complex<float>> &x, 2843 int incx, 2844 const DeviceMemory<std::complex<float>> &y, 2845 int incy, DeviceMemory<std::complex<float>> *a, 2846 int lda) { 2847 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 2848 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda)); 2849 2850 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>, 2851 const DeviceMemory<std::complex<float>> &, int, 2852 const DeviceMemory<std::complex<float>> &, int, 2853 DeviceMemory<std::complex<float>> *, int> impl; 2854 return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y, 2855 incy, a, lda); 2856 } 2857 2858 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n, 2859 std::complex<double> alpha, 2860 const DeviceMemory<std::complex<double>> &x, 2861 int incx, 2862 const DeviceMemory<std::complex<double>> &y, 2863 int incy, DeviceMemory<std::complex<double>> *a, 2864 int lda) { 2865 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 2866 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda)); 2867 2868 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>, 2869 const DeviceMemory<std::complex<double>> &, int, 2870 const DeviceMemory<std::complex<double>> &, int, 2871 DeviceMemory<std::complex<double>> *, int> impl; 2872 return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y, 2873 incy, a, lda); 2874 } 2875 2876 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n, 2877 std::complex<float> alpha, 2878 const DeviceMemory<std::complex<float>> &ap, 2879 const DeviceMemory<std::complex<float>> &x, 2880 int incx, std::complex<float> beta, 2881 DeviceMemory<std::complex<float>> *y, int incy) { 2882 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x), 2883 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); 2884 2885 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>, 2886 const DeviceMemory<std::complex<float>> &, 2887 const DeviceMemory<std::complex<float>> &, int, 2888 std::complex<float>, DeviceMemory<std::complex<float>> *, 2889 int> impl; 2890 return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx, 2891 beta, y, incy); 2892 } 2893 2894 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n, 2895 std::complex<double> alpha, 2896 const DeviceMemory<std::complex<double>> &ap, 2897 const DeviceMemory<std::complex<double>> &x, 2898 int incx, std::complex<double> beta, 2899 DeviceMemory<std::complex<double>> *y, int incy) { 2900 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x), 2901 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); 2902 2903 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>, 2904 const DeviceMemory<std::complex<double>> &, 2905 const DeviceMemory<std::complex<double>> &, int, 2906 std::complex<double>, DeviceMemory<std::complex<double>> *, 2907 int> impl; 2908 return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx, 2909 beta, y, incy); 2910 } 2911 2912 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha, 2913 const DeviceMemory<std::complex<float>> &x, 2914 int incx, DeviceMemory<std::complex<float>> *ap) { 2915 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 2916 PARAM(ap)); 2917 2918 ThenBlasImpl<blas::UpperLower, uint64, float, 2919 const DeviceMemory<std::complex<float>> &, int, 2920 DeviceMemory<std::complex<float>> *> impl; 2921 return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap); 2922 } 2923 2924 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha, 2925 const DeviceMemory<std::complex<double>> &x, 2926 int incx, DeviceMemory<std::complex<double>> *ap) { 2927 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 2928 PARAM(ap)); 2929 2930 ThenBlasImpl<blas::UpperLower, uint64, double, 2931 const DeviceMemory<std::complex<double>> &, int, 2932 DeviceMemory<std::complex<double>> *> impl; 2933 return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap); 2934 } 2935 2936 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n, 2937 std::complex<float> alpha, 2938 const DeviceMemory<std::complex<float>> &x, 2939 int incx, 2940 const DeviceMemory<std::complex<float>> &y, 2941 int incy, DeviceMemory<std::complex<float>> *ap) { 2942 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 2943 PARAM(y), PARAM(incy), PARAM(ap)); 2944 2945 ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>, 2946 const DeviceMemory<std::complex<float>> &, int, 2947 const DeviceMemory<std::complex<float>> &, int, 2948 DeviceMemory<std::complex<float>> *> impl; 2949 return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y, 2950 incy, ap); 2951 } 2952 2953 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n, 2954 std::complex<double> alpha, 2955 const DeviceMemory<std::complex<double>> &x, 2956 int incx, 2957 const DeviceMemory<std::complex<double>> &y, 2958 int incy, DeviceMemory<std::complex<double>> *ap) { 2959 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 2960 PARAM(y), PARAM(incy), PARAM(ap)); 2961 2962 ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>, 2963 const DeviceMemory<std::complex<double>> &, int, 2964 const DeviceMemory<std::complex<double>> &, int, 2965 DeviceMemory<std::complex<double>> *> impl; 2966 return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y, 2967 incy, ap); 2968 } 2969 2970 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, 2971 float alpha, const DeviceMemory<float> &a, int lda, 2972 const DeviceMemory<float> &x, int incx, float beta, 2973 DeviceMemory<float> *y, int incy) { 2974 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), 2975 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); 2976 2977 ThenBlasImpl<blas::UpperLower, uint64, uint64, float, 2978 const DeviceMemory<float> &, int, const DeviceMemory<float> &, 2979 int, float, DeviceMemory<float> *, int> impl; 2980 return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda, 2981 x, incx, beta, y, incy); 2982 } 2983 2984 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, 2985 double alpha, const DeviceMemory<double> &a, 2986 int lda, const DeviceMemory<double> &x, int incx, 2987 double beta, DeviceMemory<double> *y, int incy) { 2988 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), 2989 PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); 2990 2991 ThenBlasImpl<blas::UpperLower, uint64, uint64, double, 2992 const DeviceMemory<double> &, int, const DeviceMemory<double> &, 2993 int, double, DeviceMemory<double> *, int> impl; 2994 return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda, 2995 x, incx, beta, y, incy); 2996 } 2997 2998 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha, 2999 const DeviceMemory<float> &ap, 3000 const DeviceMemory<float> &x, int incx, float beta, 3001 DeviceMemory<float> *y, int incy) { 3002 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x), 3003 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); 3004 3005 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &, 3006 const DeviceMemory<float> &, int, float, DeviceMemory<float> *, 3007 int> impl; 3008 return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx, 3009 beta, y, incy); 3010 } 3011 3012 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha, 3013 const DeviceMemory<double> &ap, 3014 const DeviceMemory<double> &x, int incx, 3015 double beta, DeviceMemory<double> *y, int incy) { 3016 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x), 3017 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); 3018 3019 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &, 3020 const DeviceMemory<double> &, int, double, 3021 DeviceMemory<double> *, int> impl; 3022 return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx, 3023 beta, y, incy); 3024 } 3025 3026 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha, 3027 const DeviceMemory<float> &x, int incx, 3028 DeviceMemory<float> *ap) { 3029 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 3030 PARAM(ap)); 3031 3032 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &, 3033 int, DeviceMemory<float> *> impl; 3034 return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap); 3035 } 3036 3037 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha, 3038 const DeviceMemory<double> &x, int incx, 3039 DeviceMemory<double> *ap) { 3040 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 3041 PARAM(ap)); 3042 3043 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &, 3044 int, DeviceMemory<double> *> impl; 3045 return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap); 3046 } 3047 3048 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha, 3049 const DeviceMemory<float> &x, int incx, 3050 const DeviceMemory<float> &y, int incy, 3051 DeviceMemory<float> *ap) { 3052 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 3053 PARAM(y), PARAM(incy), PARAM(ap)); 3054 3055 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &, 3056 int, const DeviceMemory<float> &, int, 3057 DeviceMemory<float> *> impl; 3058 return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y, 3059 incy, ap); 3060 } 3061 3062 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha, 3063 const DeviceMemory<double> &x, int incx, 3064 const DeviceMemory<double> &y, int incy, 3065 DeviceMemory<double> *ap) { 3066 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 3067 PARAM(y), PARAM(incy), PARAM(ap)); 3068 3069 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &, 3070 int, const DeviceMemory<double> &, int, 3071 DeviceMemory<double> *> impl; 3072 return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y, 3073 incy, ap); 3074 } 3075 3076 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha, 3077 const DeviceMemory<float> &a, int lda, 3078 const DeviceMemory<float> &x, int incx, float beta, 3079 DeviceMemory<float> *y, int incy) { 3080 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), 3081 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); 3082 3083 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &, 3084 int, const DeviceMemory<float> &, int, float, 3085 DeviceMemory<float> *, int> impl; 3086 return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x, 3087 incx, beta, y, incy); 3088 } 3089 3090 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha, 3091 const DeviceMemory<double> &a, int lda, 3092 const DeviceMemory<double> &x, int incx, 3093 double beta, DeviceMemory<double> *y, int incy) { 3094 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), 3095 PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy)); 3096 3097 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &, 3098 int, const DeviceMemory<double> &, int, double, 3099 DeviceMemory<double> *, int> impl; 3100 return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x, 3101 incx, beta, y, incy); 3102 } 3103 3104 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha, 3105 const DeviceMemory<float> &x, int incx, 3106 DeviceMemory<float> *a, int lda) { 3107 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 3108 PARAM(a), PARAM(lda)); 3109 3110 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &, 3111 int, DeviceMemory<float> *, int> impl; 3112 return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a, 3113 lda); 3114 } 3115 3116 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha, 3117 const DeviceMemory<double> &x, int incx, 3118 DeviceMemory<double> *a, int lda) { 3119 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 3120 PARAM(a), PARAM(lda)); 3121 3122 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &, 3123 int, DeviceMemory<double> *, int> impl; 3124 return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a, 3125 lda); 3126 } 3127 3128 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha, 3129 const DeviceMemory<float> &x, int incx, 3130 const DeviceMemory<float> &y, int incy, 3131 DeviceMemory<float> *a, int lda) { 3132 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 3133 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda)); 3134 3135 ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &, 3136 int, const DeviceMemory<float> &, int, DeviceMemory<float> *, 3137 int> impl; 3138 return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y, 3139 incy, a, lda); 3140 } 3141 3142 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha, 3143 const DeviceMemory<double> &x, int incx, 3144 const DeviceMemory<double> &y, int incy, 3145 DeviceMemory<double> *a, int lda) { 3146 VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), 3147 PARAM(y), PARAM(incy), PARAM(a), PARAM(lda)); 3148 3149 ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &, 3150 int, const DeviceMemory<double> &, int, DeviceMemory<double> *, 3151 int> impl; 3152 return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y, 3153 incy, a, lda); 3154 } 3155 3156 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, 3157 blas::Diagonal diag, uint64 n, uint64 k, 3158 const DeviceMemory<float> &a, int lda, 3159 DeviceMemory<float> *x, int incx) { 3160 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), 3161 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); 3162 3163 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3164 uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *, 3165 int> impl; 3166 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a, 3167 lda, x, incx); 3168 } 3169 3170 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, 3171 blas::Diagonal diag, uint64 n, uint64 k, 3172 const DeviceMemory<double> &a, int lda, 3173 DeviceMemory<double> *x, int incx) { 3174 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), 3175 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); 3176 3177 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3178 uint64, const DeviceMemory<double> &, int, 3179 DeviceMemory<double> *, int> impl; 3180 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a, 3181 lda, x, incx); 3182 } 3183 3184 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, 3185 blas::Diagonal diag, uint64 n, uint64 k, 3186 const DeviceMemory<std::complex<float>> &a, 3187 int lda, DeviceMemory<std::complex<float>> *x, 3188 int incx) { 3189 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), 3190 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); 3191 3192 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3193 uint64, const DeviceMemory<std::complex<float>> &, int, 3194 DeviceMemory<std::complex<float>> *, int> impl; 3195 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a, 3196 lda, x, incx); 3197 } 3198 3199 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, 3200 blas::Diagonal diag, uint64 n, uint64 k, 3201 const DeviceMemory<std::complex<double>> &a, 3202 int lda, DeviceMemory<std::complex<double>> *x, 3203 int incx) { 3204 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), 3205 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); 3206 3207 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3208 uint64, const DeviceMemory<std::complex<double>> &, int, 3209 DeviceMemory<std::complex<double>> *, int> impl; 3210 return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a, 3211 lda, x, incx); 3212 } 3213 3214 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, 3215 blas::Diagonal diag, uint64 n, uint64 k, 3216 const DeviceMemory<float> &a, int lda, 3217 DeviceMemory<float> *x, int incx) { 3218 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), 3219 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); 3220 3221 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3222 uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *, 3223 int> impl; 3224 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a, 3225 lda, x, incx); 3226 } 3227 3228 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, 3229 blas::Diagonal diag, uint64 n, uint64 k, 3230 const DeviceMemory<double> &a, int lda, 3231 DeviceMemory<double> *x, int incx) { 3232 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), 3233 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); 3234 3235 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3236 uint64, const DeviceMemory<double> &, int, 3237 DeviceMemory<double> *, int> impl; 3238 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a, 3239 lda, x, incx); 3240 } 3241 3242 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, 3243 blas::Diagonal diag, uint64 n, uint64 k, 3244 const DeviceMemory<std::complex<float>> &a, 3245 int lda, DeviceMemory<std::complex<float>> *x, 3246 int incx) { 3247 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), 3248 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); 3249 3250 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3251 uint64, const DeviceMemory<std::complex<float>> &, int, 3252 DeviceMemory<std::complex<float>> *, int> impl; 3253 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a, 3254 lda, x, incx); 3255 } 3256 3257 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, 3258 blas::Diagonal diag, uint64 n, uint64 k, 3259 const DeviceMemory<std::complex<double>> &a, 3260 int lda, DeviceMemory<std::complex<double>> *x, 3261 int incx) { 3262 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k), 3263 PARAM(a), PARAM(lda), PARAM(x), PARAM(incx)); 3264 3265 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3266 uint64, const DeviceMemory<std::complex<double>> &, int, 3267 DeviceMemory<std::complex<double>> *, int> impl; 3268 return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a, 3269 lda, x, incx); 3270 } 3271 3272 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, 3273 blas::Diagonal diag, uint64 n, 3274 const DeviceMemory<float> &ap, 3275 DeviceMemory<float> *x, int incx) { 3276 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), 3277 PARAM(x), PARAM(incx)); 3278 3279 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3280 const DeviceMemory<float> &, DeviceMemory<float> *, int> impl; 3281 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x, 3282 incx); 3283 } 3284 3285 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, 3286 blas::Diagonal diag, uint64 n, 3287 const DeviceMemory<double> &ap, 3288 DeviceMemory<double> *x, int incx) { 3289 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), 3290 PARAM(x), PARAM(incx)); 3291 3292 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3293 const DeviceMemory<double> &, DeviceMemory<double> *, int> impl; 3294 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x, 3295 incx); 3296 } 3297 3298 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, 3299 blas::Diagonal diag, uint64 n, 3300 const DeviceMemory<std::complex<float>> &ap, 3301 DeviceMemory<std::complex<float>> *x, int incx) { 3302 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), 3303 PARAM(x), PARAM(incx)); 3304 3305 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3306 const DeviceMemory<std::complex<float>> &, 3307 DeviceMemory<std::complex<float>> *, int> impl; 3308 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x, 3309 incx); 3310 } 3311 3312 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, 3313 blas::Diagonal diag, uint64 n, 3314 const DeviceMemory<std::complex<double>> &ap, 3315 DeviceMemory<std::complex<double>> *x, int incx) { 3316 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), 3317 PARAM(x), PARAM(incx)); 3318 3319 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3320 const DeviceMemory<std::complex<double>> &, 3321 DeviceMemory<std::complex<double>> *, int> impl; 3322 return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x, 3323 incx); 3324 } 3325 3326 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, 3327 blas::Diagonal diag, uint64 n, 3328 const DeviceMemory<float> &ap, 3329 DeviceMemory<float> *x, int incx) { 3330 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), 3331 PARAM(x), PARAM(incx)); 3332 3333 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3334 const DeviceMemory<float> &, DeviceMemory<float> *, int> impl; 3335 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x, 3336 incx); 3337 } 3338 3339 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, 3340 blas::Diagonal diag, uint64 n, 3341 const DeviceMemory<double> &ap, 3342 DeviceMemory<double> *x, int incx) { 3343 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), 3344 PARAM(x), PARAM(incx)); 3345 3346 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3347 const DeviceMemory<double> &, DeviceMemory<double> *, int> impl; 3348 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x, 3349 incx); 3350 } 3351 3352 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, 3353 blas::Diagonal diag, uint64 n, 3354 const DeviceMemory<std::complex<float>> &ap, 3355 DeviceMemory<std::complex<float>> *x, int incx) { 3356 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), 3357 PARAM(x), PARAM(incx)); 3358 3359 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3360 const DeviceMemory<std::complex<float>> &, 3361 DeviceMemory<std::complex<float>> *, int> impl; 3362 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x, 3363 incx); 3364 } 3365 3366 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, 3367 blas::Diagonal diag, uint64 n, 3368 const DeviceMemory<std::complex<double>> &ap, 3369 DeviceMemory<std::complex<double>> *x, int incx) { 3370 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap), 3371 PARAM(x), PARAM(incx)); 3372 3373 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3374 const DeviceMemory<std::complex<double>> &, 3375 DeviceMemory<std::complex<double>> *, int> impl; 3376 return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x, 3377 incx); 3378 } 3379 3380 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, 3381 blas::Diagonal diag, uint64 n, 3382 const DeviceMemory<float> &a, int lda, 3383 DeviceMemory<float> *x, int incx) { 3384 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), 3385 PARAM(lda), PARAM(x), PARAM(incx)); 3386 3387 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3388 const DeviceMemory<float> &, int, DeviceMemory<float> *, 3389 int> impl; 3390 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a, 3391 lda, x, incx); 3392 } 3393 3394 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, 3395 blas::Diagonal diag, uint64 n, 3396 const DeviceMemory<double> &a, int lda, 3397 DeviceMemory<double> *x, int incx) { 3398 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), 3399 PARAM(lda), PARAM(x), PARAM(incx)); 3400 3401 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3402 const DeviceMemory<double> &, int, DeviceMemory<double> *, 3403 int> impl; 3404 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a, 3405 lda, x, incx); 3406 } 3407 3408 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, 3409 blas::Diagonal diag, uint64 n, 3410 const DeviceMemory<std::complex<float>> &a, 3411 int lda, DeviceMemory<std::complex<float>> *x, 3412 int incx) { 3413 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), 3414 PARAM(lda), PARAM(x), PARAM(incx)); 3415 3416 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3417 const DeviceMemory<std::complex<float>> &, int, 3418 DeviceMemory<std::complex<float>> *, int> impl; 3419 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a, 3420 lda, x, incx); 3421 } 3422 3423 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, 3424 blas::Diagonal diag, uint64 n, 3425 const DeviceMemory<std::complex<double>> &a, 3426 int lda, DeviceMemory<std::complex<double>> *x, 3427 int incx) { 3428 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), 3429 PARAM(lda), PARAM(x), PARAM(incx)); 3430 3431 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3432 const DeviceMemory<std::complex<double>> &, int, 3433 DeviceMemory<std::complex<double>> *, int> impl; 3434 return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a, 3435 lda, x, incx); 3436 } 3437 3438 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, 3439 blas::Diagonal diag, uint64 n, 3440 const DeviceMemory<float> &a, int lda, 3441 DeviceMemory<float> *x, int incx) { 3442 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), 3443 PARAM(lda), PARAM(x), PARAM(incx)); 3444 3445 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3446 const DeviceMemory<float> &, int, DeviceMemory<float> *, 3447 int> impl; 3448 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a, 3449 lda, x, incx); 3450 } 3451 3452 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, 3453 blas::Diagonal diag, uint64 n, 3454 const DeviceMemory<double> &a, int lda, 3455 DeviceMemory<double> *x, int incx) { 3456 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), 3457 PARAM(lda), PARAM(x), PARAM(incx)); 3458 3459 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3460 const DeviceMemory<double> &, int, DeviceMemory<double> *, 3461 int> impl; 3462 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a, 3463 lda, x, incx); 3464 } 3465 3466 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, 3467 blas::Diagonal diag, uint64 n, 3468 const DeviceMemory<std::complex<float>> &a, 3469 int lda, DeviceMemory<std::complex<float>> *x, 3470 int incx) { 3471 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), 3472 PARAM(lda), PARAM(x), PARAM(incx)); 3473 3474 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3475 const DeviceMemory<std::complex<float>> &, int, 3476 DeviceMemory<std::complex<float>> *, int> impl; 3477 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a, 3478 lda, x, incx); 3479 } 3480 3481 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, 3482 blas::Diagonal diag, uint64 n, 3483 const DeviceMemory<std::complex<double>> &a, 3484 int lda, DeviceMemory<std::complex<double>> *x, 3485 int incx) { 3486 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a), 3487 PARAM(lda), PARAM(x), PARAM(incx)); 3488 3489 ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64, 3490 const DeviceMemory<std::complex<double>> &, int, 3491 DeviceMemory<std::complex<double>> *, int> impl; 3492 return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a, 3493 lda, x, incx); 3494 } 3495 3496 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, 3497 uint64 m, uint64 n, uint64 k, float alpha, 3498 const DeviceMemory<Eigen::half> &a, int lda, 3499 const DeviceMemory<Eigen::half> &b, int ldb, 3500 float beta, 3501 DeviceMemory<Eigen::half> *c, int ldc) { 3502 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3503 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3504 PARAM(beta), PARAM(c), PARAM(ldc)); 3505 3506 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, 3507 const DeviceMemory<Eigen::half> &, int, 3508 const DeviceMemory<Eigen::half> &, int, 3509 float, DeviceMemory<Eigen::half> *, int> impl; 3510 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k, 3511 alpha, a, lda, b, ldb, beta, c, ldc); 3512 } 3513 3514 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, 3515 uint64 m, uint64 n, uint64 k, float alpha, 3516 const DeviceMemory<float> &a, int lda, 3517 const DeviceMemory<float> &b, int ldb, float beta, 3518 DeviceMemory<float> *c, int ldc) { 3519 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3520 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3521 PARAM(beta), PARAM(c), PARAM(ldc)); 3522 3523 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, 3524 const DeviceMemory<float> &, int, const DeviceMemory<float> &, 3525 int, float, DeviceMemory<float> *, int> impl; 3526 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k, 3527 alpha, a, lda, b, ldb, beta, c, ldc); 3528 } 3529 3530 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, 3531 uint64 m, uint64 n, uint64 k, double alpha, 3532 const DeviceMemory<double> &a, int lda, 3533 const DeviceMemory<double> &b, int ldb, 3534 double beta, DeviceMemory<double> *c, int ldc) { 3535 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3536 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3537 PARAM(beta), PARAM(c), PARAM(ldc)); 3538 3539 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double, 3540 const DeviceMemory<double> &, int, const DeviceMemory<double> &, 3541 int, double, DeviceMemory<double> *, int> impl; 3542 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k, 3543 alpha, a, lda, b, ldb, beta, c, ldc); 3544 } 3545 3546 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, 3547 uint64 m, uint64 n, uint64 k, 3548 std::complex<float> alpha, 3549 const DeviceMemory<std::complex<float>> &a, 3550 int lda, 3551 const DeviceMemory<std::complex<float>> &b, 3552 int ldb, std::complex<float> beta, 3553 DeviceMemory<std::complex<float>> *c, int ldc) { 3554 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3555 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3556 PARAM(beta), PARAM(c), PARAM(ldc)); 3557 3558 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, 3559 std::complex<float>, const DeviceMemory<std::complex<float>> &, 3560 int, const DeviceMemory<std::complex<float>> &, int, 3561 std::complex<float>, DeviceMemory<std::complex<float>> *, 3562 int> impl; 3563 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k, 3564 alpha, a, lda, b, ldb, beta, c, ldc); 3565 } 3566 3567 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, 3568 uint64 m, uint64 n, uint64 k, 3569 std::complex<double> alpha, 3570 const DeviceMemory<std::complex<double>> &a, 3571 int lda, 3572 const DeviceMemory<std::complex<double>> &b, 3573 int ldb, std::complex<double> beta, 3574 DeviceMemory<std::complex<double>> *c, int ldc) { 3575 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3576 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3577 PARAM(beta), PARAM(c), PARAM(ldc)); 3578 3579 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, 3580 std::complex<double>, const DeviceMemory<std::complex<double>> &, 3581 int, const DeviceMemory<std::complex<double>> &, int, 3582 std::complex<double>, DeviceMemory<std::complex<double>> *, 3583 int> impl; 3584 return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k, 3585 alpha, a, lda, b, ldb, beta, c, ldc); 3586 } 3587 3588 namespace { 3589 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a 3590 // blas::ProfileResult*. This functor doesn't put the stream into an error 3591 // state if the op fails and the profile result is non-null. Instead, the 3592 // error-ness is returned in the profile result itself. 3593 template <typename... Args> 3594 struct ThenBlasWithProfileImpl { 3595 Stream &operator()(Stream *stream, 3596 bool (blas::BlasSupport::*blas_func)( 3597 Stream *, Args..., blas::ProfileResult *), 3598 Args... args, blas::ProfileResult *profile_result) { 3599 ThenBlasImpl<Args..., blas::ProfileResult *> Runner; 3600 bool record_error = profile_result == nullptr; 3601 return Runner.Run(stream, blas_func, record_error, args..., profile_result); 3602 } 3603 }; 3604 } // anonymous namespace 3605 3606 Stream &Stream::ThenBlasGemvWithProfiling( 3607 blas::Transpose trans, uint64 m, uint64 n, float alpha, 3608 const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x, 3609 int incx, float beta, DeviceMemory<float> *y, int incy, 3610 blas::ProfileResult *output_profile_result) { 3611 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), 3612 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), 3613 PARAM(incy)); 3614 3615 ThenBlasWithProfileImpl< 3616 blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int, 3617 const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int> 3618 impl; 3619 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n, 3620 alpha, a, lda, x, incx, beta, y, incy, output_profile_result); 3621 } 3622 3623 Stream &Stream::ThenBlasGemvWithProfiling( 3624 blas::Transpose trans, uint64 m, uint64 n, double alpha, 3625 const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x, 3626 int incx, double beta, DeviceMemory<double> *y, int incy, 3627 blas::ProfileResult *output_profile_result) { 3628 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), 3629 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), 3630 PARAM(incy)); 3631 3632 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double, 3633 const DeviceMemory<double> &, int, 3634 const DeviceMemory<double> &, int, double, 3635 DeviceMemory<double> *, int> 3636 impl; 3637 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n, 3638 alpha, a, lda, x, incx, beta, y, incy, output_profile_result); 3639 } 3640 3641 Stream &Stream::ThenBlasGemvWithProfiling( 3642 blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha, 3643 const DeviceMemory<std::complex<float>> &a, int lda, 3644 const DeviceMemory<std::complex<float>> &x, int incx, 3645 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy, 3646 blas::ProfileResult *output_profile_result) { 3647 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), 3648 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), 3649 PARAM(incy)); 3650 3651 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>, 3652 const DeviceMemory<std::complex<float>> &, int, 3653 const DeviceMemory<std::complex<float>> &, int, 3654 std::complex<float>, 3655 DeviceMemory<std::complex<float>> *, int> 3656 impl; 3657 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n, 3658 alpha, a, lda, x, incx, beta, y, incy, output_profile_result); 3659 } 3660 3661 Stream &Stream::ThenBlasGemvWithProfiling( 3662 blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha, 3663 const DeviceMemory<std::complex<double>> &a, int lda, 3664 const DeviceMemory<std::complex<double>> &x, int incx, 3665 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy, 3666 blas::ProfileResult *output_profile_result) { 3667 VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a), 3668 PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), 3669 PARAM(incy)); 3670 3671 ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>, 3672 const DeviceMemory<std::complex<double>> &, int, 3673 const DeviceMemory<std::complex<double>> &, int, 3674 std::complex<double>, 3675 DeviceMemory<std::complex<double>> *, int> 3676 impl; 3677 return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n, 3678 alpha, a, lda, x, incx, beta, y, incy, output_profile_result); 3679 } 3680 3681 Stream &Stream::ThenBlasGemmWithProfiling( 3682 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 3683 uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda, 3684 const DeviceMemory<Eigen::half> &b, int ldb, float beta, 3685 DeviceMemory<Eigen::half> *c, int ldc, 3686 blas::ProfileResult *output_profile_result) { 3687 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3688 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3689 PARAM(beta), PARAM(c), PARAM(ldc)); 3690 3691 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, 3692 uint64, float, const DeviceMemory<Eigen::half> &, int, 3693 const DeviceMemory<Eigen::half> &, int, float, 3694 DeviceMemory<Eigen::half> *, int> 3695 impl; 3696 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, 3697 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, 3698 output_profile_result); 3699 } 3700 3701 Stream &Stream::ThenBlasGemmWithProfiling( 3702 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 3703 uint64 k, float alpha, const DeviceMemory<float> &a, int lda, 3704 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c, 3705 int ldc, blas::ProfileResult *output_profile_result) { 3706 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3707 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3708 PARAM(beta), PARAM(c), PARAM(ldc)); 3709 3710 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, 3711 uint64, float, const DeviceMemory<float> &, int, 3712 const DeviceMemory<float> &, int, float, 3713 DeviceMemory<float> *, int> 3714 impl; 3715 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, 3716 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, 3717 output_profile_result); 3718 } 3719 3720 Stream &Stream::ThenBlasGemmWithProfiling( 3721 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 3722 uint64 k, double alpha, const DeviceMemory<double> &a, int lda, 3723 const DeviceMemory<double> &b, int ldb, double beta, 3724 DeviceMemory<double> *c, int ldc, 3725 blas::ProfileResult *output_profile_result) { 3726 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3727 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3728 PARAM(beta), PARAM(c), PARAM(ldc)); 3729 3730 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, 3731 uint64, double, const DeviceMemory<double> &, int, 3732 const DeviceMemory<double> &, int, double, 3733 DeviceMemory<double> *, int> 3734 impl; 3735 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, 3736 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, 3737 output_profile_result); 3738 } 3739 3740 Stream &Stream::ThenBlasGemmWithProfiling( 3741 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 3742 uint64 k, std::complex<float> alpha, 3743 const DeviceMemory<std::complex<float>> &a, int lda, 3744 const DeviceMemory<std::complex<float>> &b, int ldb, 3745 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, 3746 blas::ProfileResult *output_profile_result) { 3747 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3748 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3749 PARAM(beta), PARAM(c), PARAM(ldc)); 3750 3751 ThenBlasWithProfileImpl< 3752 blas::Transpose, blas::Transpose, uint64, uint64, uint64, 3753 std::complex<float>, const DeviceMemory<std::complex<float>> &, int, 3754 const DeviceMemory<std::complex<float>> &, int, std::complex<float>, 3755 DeviceMemory<std::complex<float>> *, int> 3756 impl; 3757 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, 3758 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, 3759 output_profile_result); 3760 } 3761 3762 Stream &Stream::ThenBlasGemmWithProfiling( 3763 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 3764 uint64 k, std::complex<double> alpha, 3765 const DeviceMemory<std::complex<double>> &a, int lda, 3766 const DeviceMemory<std::complex<double>> &b, int ldb, 3767 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, 3768 blas::ProfileResult *output_profile_result) { 3769 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3770 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3771 PARAM(beta), PARAM(c), PARAM(ldc)); 3772 3773 ThenBlasWithProfileImpl< 3774 blas::Transpose, blas::Transpose, uint64, uint64, uint64, 3775 std::complex<double>, const DeviceMemory<std::complex<double>> &, int, 3776 const DeviceMemory<std::complex<double>> &, int, std::complex<double>, 3777 DeviceMemory<std::complex<double>> *, int> 3778 impl; 3779 return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb, 3780 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, 3781 output_profile_result); 3782 } 3783 3784 Stream &Stream::ThenBlasGemmWithAlgorithm( 3785 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 3786 uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a, 3787 int lda, const DeviceMemory<Eigen::half> &b, int ldb, 3788 const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc, 3789 blas::ComputationType computation_type, blas::AlgorithmType algorithm, 3790 blas::ProfileResult *output_profile_result) { 3791 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3792 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3793 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), 3794 PARAM(algorithm)); 3795 3796 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, 3797 uint64, const Eigen::half &, 3798 const DeviceMemory<Eigen::half> &, int, 3799 const DeviceMemory<Eigen::half> &, int, 3800 const Eigen::half &, DeviceMemory<Eigen::half> *, int, 3801 blas::ComputationType, blas::AlgorithmType> 3802 impl; 3803 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, 3804 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, 3805 algorithm, output_profile_result); 3806 } 3807 3808 Stream &Stream::ThenBlasGemmWithAlgorithm( 3809 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 3810 uint64 k, int alpha, const DeviceMemory<int8> &a, int lda, 3811 const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c, 3812 int ldc, blas::ComputationType computation_type, 3813 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { 3814 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3815 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3816 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), 3817 PARAM(algorithm)); 3818 3819 ThenBlasWithProfileImpl< 3820 blas::Transpose, blas::Transpose, uint64, uint64, uint64, int, 3821 const DeviceMemory<int8> &, int, const DeviceMemory<int8> &, int, int, 3822 DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType> 3823 impl; 3824 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, 3825 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, 3826 algorithm, output_profile_result); 3827 } 3828 3829 Stream &Stream::ThenBlasGemmWithAlgorithm( 3830 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 3831 uint64 k, float alpha, const DeviceMemory<float> &a, int lda, 3832 const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c, 3833 int ldc, blas::ComputationType computation_type, 3834 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { 3835 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3836 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3837 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), 3838 PARAM(algorithm)); 3839 3840 ThenBlasWithProfileImpl< 3841 blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, 3842 const DeviceMemory<float> &, int, const DeviceMemory<float> &, int, float, 3843 DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType> 3844 impl; 3845 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, 3846 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, 3847 algorithm, output_profile_result); 3848 } 3849 3850 Stream &Stream::ThenBlasGemmWithAlgorithm( 3851 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 3852 uint64 k, double alpha, const DeviceMemory<double> &a, int lda, 3853 const DeviceMemory<double> &b, int ldb, double beta, 3854 DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type, 3855 blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) { 3856 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3857 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3858 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), 3859 PARAM(algorithm)); 3860 3861 ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64, 3862 uint64, double, const DeviceMemory<double> &, int, 3863 const DeviceMemory<double> &, int, double, 3864 DeviceMemory<double> *, int, blas::ComputationType, 3865 blas::AlgorithmType> 3866 impl; 3867 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, 3868 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, 3869 algorithm, output_profile_result); 3870 } 3871 3872 Stream &Stream::ThenBlasGemmWithAlgorithm( 3873 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 3874 uint64 k, std::complex<float> alpha, 3875 const DeviceMemory<std::complex<float>> &a, int lda, 3876 const DeviceMemory<std::complex<float>> &b, int ldb, 3877 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, 3878 blas::ComputationType computation_type, blas::AlgorithmType algorithm, 3879 blas::ProfileResult *output_profile_result) { 3880 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3881 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3882 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), 3883 PARAM(algorithm)); 3884 3885 ThenBlasWithProfileImpl< 3886 blas::Transpose, blas::Transpose, uint64, uint64, uint64, 3887 std::complex<float>, const DeviceMemory<std::complex<float>> &, int, 3888 const DeviceMemory<std::complex<float>> &, int, std::complex<float>, 3889 DeviceMemory<std::complex<float>> *, int, blas::ComputationType, 3890 blas::AlgorithmType> 3891 impl; 3892 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, 3893 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, 3894 algorithm, output_profile_result); 3895 } 3896 3897 Stream &Stream::ThenBlasGemmWithAlgorithm( 3898 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 3899 uint64 k, std::complex<double> alpha, 3900 const DeviceMemory<std::complex<double>> &a, int lda, 3901 const DeviceMemory<std::complex<double>> &b, int ldb, 3902 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, 3903 blas::ComputationType computation_type, blas::AlgorithmType algorithm, 3904 blas::ProfileResult *output_profile_result) { 3905 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 3906 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 3907 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type), 3908 PARAM(algorithm)); 3909 3910 ThenBlasWithProfileImpl< 3911 blas::Transpose, blas::Transpose, uint64, uint64, uint64, 3912 std::complex<double>, const DeviceMemory<std::complex<double>> &, int, 3913 const DeviceMemory<std::complex<double>> &, int, std::complex<double>, 3914 DeviceMemory<std::complex<double>> *, int, blas::ComputationType, 3915 blas::AlgorithmType> 3916 impl; 3917 return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb, 3918 m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type, 3919 algorithm, output_profile_result); 3920 } 3921 3922 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m, 3923 uint64 n, std::complex<float> alpha, 3924 const DeviceMemory<std::complex<float>> &a, 3925 int lda, 3926 const DeviceMemory<std::complex<float>> &b, 3927 int ldb, std::complex<float> beta, 3928 DeviceMemory<std::complex<float>> *c, int ldc) { 3929 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha), 3930 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), 3931 PARAM(ldc)); 3932 3933 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, 3934 std::complex<float>, const DeviceMemory<std::complex<float>> &, 3935 int, const DeviceMemory<std::complex<float>> &, int, 3936 std::complex<float>, DeviceMemory<std::complex<float>> *, 3937 int> impl; 3938 return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a, 3939 lda, b, ldb, beta, c, ldc); 3940 } 3941 3942 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m, 3943 uint64 n, std::complex<double> alpha, 3944 const DeviceMemory<std::complex<double>> &a, 3945 int lda, 3946 const DeviceMemory<std::complex<double>> &b, 3947 int ldb, std::complex<double> beta, 3948 DeviceMemory<std::complex<double>> *c, int ldc) { 3949 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha), 3950 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), 3951 PARAM(ldc)); 3952 3953 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, 3954 std::complex<double>, const DeviceMemory<std::complex<double>> &, 3955 int, const DeviceMemory<std::complex<double>> &, int, 3956 std::complex<double>, DeviceMemory<std::complex<double>> *, 3957 int> impl; 3958 return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a, 3959 lda, b, ldb, beta, c, ldc); 3960 } 3961 3962 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, 3963 uint64 n, uint64 k, float alpha, 3964 const DeviceMemory<std::complex<float>> &a, 3965 int lda, float beta, 3966 DeviceMemory<std::complex<float>> *c, int ldc) { 3967 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), 3968 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc)); 3969 3970 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float, 3971 const DeviceMemory<std::complex<float>> &, int, float, 3972 DeviceMemory<std::complex<float>> *, int> impl; 3973 return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a, 3974 lda, beta, c, ldc); 3975 } 3976 3977 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, 3978 uint64 n, uint64 k, double alpha, 3979 const DeviceMemory<std::complex<double>> &a, 3980 int lda, double beta, 3981 DeviceMemory<std::complex<double>> *c, int ldc) { 3982 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), 3983 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc)); 3984 3985 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double, 3986 const DeviceMemory<std::complex<double>> &, int, double, 3987 DeviceMemory<std::complex<double>> *, int> impl; 3988 return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a, 3989 lda, beta, c, ldc); 3990 } 3991 3992 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, 3993 uint64 n, uint64 k, std::complex<float> alpha, 3994 const DeviceMemory<std::complex<float>> &a, 3995 int lda, 3996 const DeviceMemory<std::complex<float>> &b, 3997 int ldb, float beta, 3998 DeviceMemory<std::complex<float>> *c, int ldc) { 3999 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), 4000 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), 4001 PARAM(ldc)); 4002 4003 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, 4004 std::complex<float>, const DeviceMemory<std::complex<float>> &, 4005 int, const DeviceMemory<std::complex<float>> &, int, float, 4006 DeviceMemory<std::complex<float>> *, int> impl; 4007 return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha, 4008 a, lda, b, ldb, beta, c, ldc); 4009 } 4010 4011 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, 4012 uint64 n, uint64 k, std::complex<double> alpha, 4013 const DeviceMemory<std::complex<double>> &a, 4014 int lda, 4015 const DeviceMemory<std::complex<double>> &b, 4016 int ldb, double beta, 4017 DeviceMemory<std::complex<double>> *c, int ldc) { 4018 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), 4019 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), 4020 PARAM(ldc)); 4021 4022 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, 4023 std::complex<double>, const DeviceMemory<std::complex<double>> &, 4024 int, const DeviceMemory<std::complex<double>> &, int, double, 4025 DeviceMemory<std::complex<double>> *, int> impl; 4026 return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha, 4027 a, lda, b, ldb, beta, c, ldc); 4028 } 4029 4030 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m, 4031 uint64 n, float alpha, 4032 const DeviceMemory<float> &a, int lda, 4033 const DeviceMemory<float> &b, int ldb, float beta, 4034 DeviceMemory<float> *c, int ldc) { 4035 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha), 4036 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), 4037 PARAM(ldc)); 4038 4039 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float, 4040 const DeviceMemory<float> &, int, const DeviceMemory<float> &, 4041 int, float, DeviceMemory<float> *, int> impl; 4042 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a, 4043 lda, b, ldb, beta, c, ldc); 4044 } 4045 4046 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m, 4047 uint64 n, double alpha, 4048 const DeviceMemory<double> &a, int lda, 4049 const DeviceMemory<double> &b, int ldb, 4050 double beta, DeviceMemory<double> *c, int ldc) { 4051 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha), 4052 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), 4053 PARAM(ldc)); 4054 4055 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double, 4056 const DeviceMemory<double> &, int, const DeviceMemory<double> &, 4057 int, double, DeviceMemory<double> *, int> impl; 4058 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a, 4059 lda, b, ldb, beta, c, ldc); 4060 } 4061 4062 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m, 4063 uint64 n, std::complex<float> alpha, 4064 const DeviceMemory<std::complex<float>> &a, 4065 int lda, 4066 const DeviceMemory<std::complex<float>> &b, 4067 int ldb, std::complex<float> beta, 4068 DeviceMemory<std::complex<float>> *c, int ldc) { 4069 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha), 4070 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), 4071 PARAM(ldc)); 4072 4073 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, 4074 std::complex<float>, const DeviceMemory<std::complex<float>> &, 4075 int, const DeviceMemory<std::complex<float>> &, int, 4076 std::complex<float>, DeviceMemory<std::complex<float>> *, 4077 int> impl; 4078 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a, 4079 lda, b, ldb, beta, c, ldc); 4080 } 4081 4082 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m, 4083 uint64 n, std::complex<double> alpha, 4084 const DeviceMemory<std::complex<double>> &a, 4085 int lda, 4086 const DeviceMemory<std::complex<double>> &b, 4087 int ldb, std::complex<double> beta, 4088 DeviceMemory<std::complex<double>> *c, int ldc) { 4089 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha), 4090 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), 4091 PARAM(ldc)); 4092 4093 ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, 4094 std::complex<double>, const DeviceMemory<std::complex<double>> &, 4095 int, const DeviceMemory<std::complex<double>> &, int, 4096 std::complex<double>, DeviceMemory<std::complex<double>> *, 4097 int> impl; 4098 return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a, 4099 lda, b, ldb, beta, c, ldc); 4100 } 4101 4102 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, 4103 uint64 n, uint64 k, float alpha, 4104 const DeviceMemory<float> &a, int lda, float beta, 4105 DeviceMemory<float> *c, int ldc) { 4106 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), 4107 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc)); 4108 4109 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float, 4110 const DeviceMemory<float> &, int, float, DeviceMemory<float> *, 4111 int> impl; 4112 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a, 4113 lda, beta, c, ldc); 4114 } 4115 4116 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, 4117 uint64 n, uint64 k, double alpha, 4118 const DeviceMemory<double> &a, int lda, 4119 double beta, DeviceMemory<double> *c, int ldc) { 4120 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), 4121 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc)); 4122 4123 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double, 4124 const DeviceMemory<double> &, int, double, 4125 DeviceMemory<double> *, int> impl; 4126 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a, 4127 lda, beta, c, ldc); 4128 } 4129 4130 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, 4131 uint64 n, uint64 k, std::complex<float> alpha, 4132 const DeviceMemory<std::complex<float>> &a, 4133 int lda, std::complex<float> beta, 4134 DeviceMemory<std::complex<float>> *c, int ldc) { 4135 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), 4136 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc)); 4137 4138 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, 4139 std::complex<float>, const DeviceMemory<std::complex<float>> &, 4140 int, std::complex<float>, DeviceMemory<std::complex<float>> *, 4141 int> impl; 4142 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a, 4143 lda, beta, c, ldc); 4144 } 4145 4146 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, 4147 uint64 n, uint64 k, std::complex<double> alpha, 4148 const DeviceMemory<std::complex<double>> &a, 4149 int lda, std::complex<double> beta, 4150 DeviceMemory<std::complex<double>> *c, int ldc) { 4151 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), 4152 PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc)); 4153 4154 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, 4155 std::complex<double>, const DeviceMemory<std::complex<double>> &, 4156 int, std::complex<double>, DeviceMemory<std::complex<double>> *, 4157 int> impl; 4158 return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a, 4159 lda, beta, c, ldc); 4160 } 4161 4162 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, 4163 uint64 n, uint64 k, float alpha, 4164 const DeviceMemory<float> &a, int lda, 4165 const DeviceMemory<float> &b, int ldb, float beta, 4166 DeviceMemory<float> *c, int ldc) { 4167 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), 4168 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), 4169 PARAM(ldc)); 4170 4171 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float, 4172 const DeviceMemory<float> &, int, const DeviceMemory<float> &, 4173 int, float, DeviceMemory<float> *, int> impl; 4174 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha, 4175 a, lda, b, ldb, beta, c, ldc); 4176 } 4177 4178 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, 4179 uint64 n, uint64 k, double alpha, 4180 const DeviceMemory<double> &a, int lda, 4181 const DeviceMemory<double> &b, int ldb, 4182 double beta, DeviceMemory<double> *c, int ldc) { 4183 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), 4184 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), 4185 PARAM(ldc)); 4186 4187 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double, 4188 const DeviceMemory<double> &, int, const DeviceMemory<double> &, 4189 int, double, DeviceMemory<double> *, int> impl; 4190 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha, 4191 a, lda, b, ldb, beta, c, ldc); 4192 } 4193 4194 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, 4195 uint64 n, uint64 k, std::complex<float> alpha, 4196 const DeviceMemory<std::complex<float>> &a, 4197 int lda, 4198 const DeviceMemory<std::complex<float>> &b, 4199 int ldb, std::complex<float> beta, 4200 DeviceMemory<std::complex<float>> *c, int ldc) { 4201 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), 4202 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), 4203 PARAM(ldc)); 4204 4205 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, 4206 std::complex<float>, const DeviceMemory<std::complex<float>> &, 4207 int, const DeviceMemory<std::complex<float>> &, int, 4208 std::complex<float>, DeviceMemory<std::complex<float>> *, 4209 int> impl; 4210 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha, 4211 a, lda, b, ldb, beta, c, ldc); 4212 } 4213 4214 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, 4215 uint64 n, uint64 k, std::complex<double> alpha, 4216 const DeviceMemory<std::complex<double>> &a, 4217 int lda, 4218 const DeviceMemory<std::complex<double>> &b, 4219 int ldb, std::complex<double> beta, 4220 DeviceMemory<std::complex<double>> *c, int ldc) { 4221 VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha), 4222 PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), 4223 PARAM(ldc)); 4224 4225 ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, 4226 std::complex<double>, const DeviceMemory<std::complex<double>> &, 4227 int, const DeviceMemory<std::complex<double>> &, int, 4228 std::complex<double>, DeviceMemory<std::complex<double>> *, 4229 int> impl; 4230 return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha, 4231 a, lda, b, ldb, beta, c, ldc); 4232 } 4233 4234 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, 4235 blas::Transpose transa, blas::Diagonal diag, 4236 uint64 m, uint64 n, float alpha, 4237 const DeviceMemory<float> &a, int lda, 4238 DeviceMemory<float> *b, int ldb) { 4239 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), 4240 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); 4241 4242 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, 4243 uint64, uint64, float, const DeviceMemory<float> &, int, 4244 DeviceMemory<float> *, int> impl; 4245 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m, 4246 n, alpha, a, lda, b, ldb); 4247 } 4248 4249 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, 4250 blas::Transpose transa, blas::Diagonal diag, 4251 uint64 m, uint64 n, double alpha, 4252 const DeviceMemory<double> &a, int lda, 4253 DeviceMemory<double> *b, int ldb) { 4254 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), 4255 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); 4256 4257 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, 4258 uint64, uint64, double, const DeviceMemory<double> &, int, 4259 DeviceMemory<double> *, int> impl; 4260 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m, 4261 n, alpha, a, lda, b, ldb); 4262 } 4263 4264 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, 4265 blas::Transpose transa, blas::Diagonal diag, 4266 uint64 m, uint64 n, std::complex<float> alpha, 4267 const DeviceMemory<std::complex<float>> &a, 4268 int lda, DeviceMemory<std::complex<float>> *b, 4269 int ldb) { 4270 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), 4271 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); 4272 4273 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, 4274 uint64, uint64, std::complex<float>, 4275 const DeviceMemory<std::complex<float>> &, int, 4276 DeviceMemory<std::complex<float>> *, int> impl; 4277 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m, 4278 n, alpha, a, lda, b, ldb); 4279 } 4280 4281 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, 4282 blas::Transpose transa, blas::Diagonal diag, 4283 uint64 m, uint64 n, std::complex<double> alpha, 4284 const DeviceMemory<std::complex<double>> &a, 4285 int lda, DeviceMemory<std::complex<double>> *b, 4286 int ldb) { 4287 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), 4288 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); 4289 4290 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, 4291 uint64, uint64, std::complex<double>, 4292 const DeviceMemory<std::complex<double>> &, int, 4293 DeviceMemory<std::complex<double>> *, int> impl; 4294 return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m, 4295 n, alpha, a, lda, b, ldb); 4296 } 4297 4298 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, 4299 blas::Transpose transa, blas::Diagonal diag, 4300 uint64 m, uint64 n, float alpha, 4301 const DeviceMemory<float> &a, int lda, 4302 DeviceMemory<float> *b, int ldb) { 4303 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), 4304 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); 4305 4306 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, 4307 uint64, uint64, float, const DeviceMemory<float> &, int, 4308 DeviceMemory<float> *, int> impl; 4309 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, 4310 n, alpha, a, lda, b, ldb); 4311 } 4312 4313 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, 4314 blas::Transpose transa, blas::Diagonal diag, 4315 uint64 m, uint64 n, double alpha, 4316 const DeviceMemory<double> &a, int lda, 4317 DeviceMemory<double> *b, int ldb) { 4318 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), 4319 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); 4320 4321 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, 4322 uint64, uint64, double, const DeviceMemory<double> &, int, 4323 DeviceMemory<double> *, int> impl; 4324 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, 4325 n, alpha, a, lda, b, ldb); 4326 } 4327 4328 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, 4329 blas::Transpose transa, blas::Diagonal diag, 4330 uint64 m, uint64 n, std::complex<float> alpha, 4331 const DeviceMemory<std::complex<float>> &a, 4332 int lda, DeviceMemory<std::complex<float>> *b, 4333 int ldb) { 4334 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), 4335 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); 4336 4337 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, 4338 uint64, uint64, std::complex<float>, 4339 const DeviceMemory<std::complex<float>> &, int, 4340 DeviceMemory<std::complex<float>> *, int> impl; 4341 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, 4342 n, alpha, a, lda, b, ldb); 4343 } 4344 4345 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, 4346 blas::Transpose transa, blas::Diagonal diag, 4347 uint64 m, uint64 n, std::complex<double> alpha, 4348 const DeviceMemory<std::complex<double>> &a, 4349 int lda, DeviceMemory<std::complex<double>> *b, 4350 int ldb) { 4351 VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m), 4352 PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb)); 4353 4354 ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal, 4355 uint64, uint64, std::complex<double>, 4356 const DeviceMemory<std::complex<double>> &, int, 4357 DeviceMemory<std::complex<double>> *, int> impl; 4358 return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m, 4359 n, alpha, a, lda, b, ldb); 4360 } 4361 4362 Stream &Stream::ThenBlasGemmBatched( 4363 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 4364 uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a, 4365 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, 4366 float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, 4367 int batch_count) { 4368 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, 4369 b, ldb, beta, c, ldc, batch_count, 4370 /*scratch_allocator=*/nullptr); 4371 } 4372 4373 Stream &Stream::ThenBlasGemmBatchedWithScratch( 4374 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 4375 uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a, 4376 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, 4377 float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, 4378 int batch_count, ScratchAllocator *scratch_allocator) { 4379 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 4380 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 4381 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); 4382 4383 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float, 4384 const port::ArraySlice<DeviceMemory<float> *> &, int, 4385 const port::ArraySlice<DeviceMemory<float> *> &, int, float, 4386 const port::ArraySlice<DeviceMemory<float> *> &, int, int, 4387 ScratchAllocator *> 4388 impl; 4389 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, 4390 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, 4391 scratch_allocator); 4392 } 4393 4394 Stream &Stream::ThenBlasGemmBatched( 4395 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 4396 uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a, 4397 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, 4398 double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, 4399 int batch_count) { 4400 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, 4401 b, ldb, beta, c, ldc, batch_count, 4402 /*scratch_allocator=*/nullptr); 4403 } 4404 4405 Stream &Stream::ThenBlasGemmBatchedWithScratch( 4406 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 4407 uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a, 4408 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, 4409 double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, 4410 int batch_count, ScratchAllocator *scratch_allocator) { 4411 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 4412 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 4413 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); 4414 4415 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double, 4416 const port::ArraySlice<DeviceMemory<double> *> &, int, 4417 const port::ArraySlice<DeviceMemory<double> *> &, int, double, 4418 const port::ArraySlice<DeviceMemory<double> *> &, int, int, 4419 ScratchAllocator *> 4420 impl; 4421 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, 4422 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, 4423 scratch_allocator); 4424 } 4425 4426 Stream &Stream::ThenBlasGemmBatched( 4427 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 4428 uint64 k, std::complex<float> alpha, 4429 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, 4430 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, 4431 std::complex<float> beta, 4432 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, 4433 int batch_count) { 4434 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, 4435 b, ldb, beta, c, ldc, batch_count, 4436 /*scratch_allocator=*/nullptr); 4437 } 4438 4439 Stream &Stream::ThenBlasGemmBatchedWithScratch( 4440 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 4441 uint64 k, std::complex<float> alpha, 4442 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, 4443 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, 4444 std::complex<float> beta, 4445 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, 4446 int batch_count, ScratchAllocator *scratch_allocator) { 4447 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 4448 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 4449 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); 4450 4451 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, 4452 std::complex<float>, 4453 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &, 4454 int, 4455 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &, 4456 int, std::complex<float>, 4457 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &, 4458 int, int, ScratchAllocator *> 4459 impl; 4460 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, 4461 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, 4462 scratch_allocator); 4463 } 4464 4465 Stream &Stream::ThenBlasGemmBatched( 4466 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 4467 uint64 k, std::complex<double> alpha, 4468 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda, 4469 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb, 4470 std::complex<double> beta, 4471 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, 4472 int batch_count) { 4473 return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, 4474 b, ldb, beta, c, ldc, batch_count, 4475 /*scratch_allocator=*/nullptr); 4476 } 4477 4478 Stream &Stream::ThenBlasGemmBatchedWithScratch( 4479 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 4480 uint64 k, std::complex<double> alpha, 4481 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda, 4482 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb, 4483 std::complex<double> beta, 4484 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, 4485 int batch_count, ScratchAllocator *scratch_allocator) { 4486 VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), 4487 PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), 4488 PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); 4489 4490 ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, 4491 std::complex<double>, 4492 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &, 4493 int, 4494 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &, 4495 int, std::complex<double>, 4496 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &, 4497 int, int, ScratchAllocator *> 4498 impl; 4499 return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, 4500 k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, 4501 scratch_allocator); 4502 } 4503 4504 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) { 4505 VLOG_CALL(PARAM(seed), PARAM(seed_bytes)); 4506 4507 if (ok()) { 4508 if (rng::RngSupport *rng = parent_->AsRng()) { 4509 CheckError(rng->SetSeed(this, seed, seed_bytes)); 4510 } else { 4511 SetError(); 4512 LOG(INFO) << "stream " << this << " unable to initialize RNG"; 4513 } 4514 } else { 4515 LOG(INFO) << "stream " << this 4516 << " did not set RNG seed: " << static_cast<const void *>(seed) 4517 << "; bytes: " << seed_bytes; 4518 } 4519 return *this; 4520 } 4521 4522 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) { 4523 VLOG_CALL(PARAM(values)); 4524 4525 if (ok()) { 4526 if (rng::RngSupport *rng = parent_->AsRng()) { 4527 CheckError(rng->DoPopulateRandUniform(this, values)); 4528 } else { 4529 SetError(); 4530 LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " 4531 "without RNG support."; 4532 } 4533 } 4534 return *this; 4535 } 4536 4537 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd, 4538 DeviceMemory<float> *values) { 4539 VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values)); 4540 4541 if (ok()) { 4542 if (rng::RngSupport *rng = parent_->AsRng()) { 4543 CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); 4544 } else { 4545 SetError(); 4546 LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " 4547 "without RNG support."; 4548 } 4549 } 4550 return *this; 4551 } 4552 4553 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd, 4554 DeviceMemory<double> *values) { 4555 VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values)); 4556 4557 if (ok()) { 4558 if (rng::RngSupport *rng = parent_->AsRng()) { 4559 CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values)); 4560 } else { 4561 SetError(); 4562 LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " 4563 "without RNG support."; 4564 } 4565 } 4566 return *this; 4567 } 4568 4569 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) { 4570 VLOG_CALL(PARAM(values)); 4571 4572 if (ok()) { 4573 if (rng::RngSupport *rng = parent_->AsRng()) { 4574 CheckError(rng->DoPopulateRandUniform(this, values)); 4575 } else { 4576 SetError(); 4577 LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " 4578 "without RNG support."; 4579 } 4580 } 4581 return *this; 4582 } 4583 4584 Stream &Stream::ThenPopulateRandUniform( 4585 DeviceMemory<std::complex<float>> *values) { 4586 VLOG_CALL(PARAM(values)); 4587 4588 if (ok()) { 4589 if (rng::RngSupport *rng = parent_->AsRng()) { 4590 CheckError(rng->DoPopulateRandUniform(this, values)); 4591 } else { 4592 SetError(); 4593 LOG(INFO) << "attempting to perform RNG operation using StreamExecutor " 4594 "without RNG support."; 4595 } 4596 } 4597 return *this; 4598 } 4599 4600 Stream &Stream::ThenPopulateRandUniform( 4601 DeviceMemory<std::complex<double>> *values) { 4602 VLOG_CALL(PARAM(values)); 4603 4604 if (ok()) { 4605 if (rng::RngSupport *rng = parent_->AsRng()) { 4606 CheckError(rng->DoPopulateRandUniform(this, values)); 4607 } else { 4608 SetError(); 4609 LOG(INFO) << "stream " << this 4610 << " attempting to perform RNG operation using StreamExecutor " 4611 "without RNG support."; 4612 } 4613 } 4614 return *this; 4615 } 4616 4617 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, 4618 uint64 size) { 4619 VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size)); 4620 4621 if (ok()) { 4622 CheckError(parent_->Memcpy(this, host_dst, gpu_src, size)); 4623 } else { 4624 LOG(INFO) << "stream " << this 4625 << " did not memcpy device-to-host; source: " << gpu_src.opaque(); 4626 } 4627 return *this; 4628 } 4629 4630 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, 4631 uint64 size) { 4632 VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size)); 4633 4634 if (ok()) { 4635 CheckError(parent_->Memcpy(this, gpu_dst, host_src, size)); 4636 } else { 4637 LOG(INFO) << "stream " << this 4638 << " did not memcpy host-to-device; source: " << host_src; 4639 } 4640 return *this; 4641 } 4642 4643 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, 4644 const DeviceMemoryBase &gpu_src, uint64 size) { 4645 VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size)); 4646 4647 if (ok()) { 4648 CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size)); 4649 } else { 4650 LOG(INFO) << "stream " << this 4651 << " did not memcpy gpu-to-gpu; source: " << &gpu_src; 4652 } 4653 return *this; 4654 } 4655 4656 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) { 4657 VLOG_CALL(PARAM(location), PARAM(size)); 4658 4659 if (ok()) { 4660 CheckError(parent_->MemZero(this, location, size)); 4661 } else { 4662 LOG(INFO) << "stream " << this 4663 << " did not memzero GPU location; source: " << location; 4664 } 4665 return *this; 4666 } 4667 4668 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern, 4669 uint64 size) { 4670 VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size)); 4671 4672 if (ok()) { 4673 CheckError(parent_->Memset32(this, location, pattern, size)); 4674 } else { 4675 LOG(INFO) << "stream " << this 4676 << " did not memset GPU location; source: " << location 4677 << "; size: " << size << "; pattern: " << std::hex << pattern; 4678 } 4679 return *this; 4680 } 4681 4682 Stream &Stream::ThenRnnForward( 4683 const dnn::RnnDescriptor &rnn_desc, 4684 const dnn::RnnSequenceTensorDescriptor &input_desc, 4685 const DeviceMemory<Eigen::half> &input_data, 4686 const dnn::RnnStateTensorDescriptor &input_h_desc, 4687 const DeviceMemory<Eigen::half> &input_h_data, 4688 const dnn::RnnStateTensorDescriptor &input_c_desc, 4689 const DeviceMemory<Eigen::half> &input_c_data, 4690 const DeviceMemory<Eigen::half> ¶ms, 4691 const dnn::RnnSequenceTensorDescriptor &output_desc, 4692 DeviceMemory<Eigen::half> *output_data, 4693 const dnn::RnnStateTensorDescriptor &output_h_desc, 4694 DeviceMemory<Eigen::half> *output_h_data, 4695 const dnn::RnnStateTensorDescriptor &output_c_desc, 4696 DeviceMemory<Eigen::half> *output_c_data, bool is_training, 4697 ScratchAllocator *reserve_space_allocator, 4698 ScratchAllocator *workspace_allocator) { 4699 // TODO(zhengxq): add VLOG PARAM calls. 4700 if (ok()) { 4701 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 4702 CheckError(dnn->DoRnnForward( 4703 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, 4704 input_c_desc, input_c_data, params, output_desc, output_data, 4705 output_h_desc, output_h_data, output_c_desc, output_c_data, 4706 is_training, reserve_space_allocator, workspace_allocator)); 4707 } else { 4708 SetError(); 4709 LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support"; 4710 } 4711 } 4712 return *this; 4713 } 4714 4715 Stream &Stream::ThenRnnForward( 4716 const dnn::RnnDescriptor &rnn_desc, 4717 const dnn::RnnSequenceTensorDescriptor &input_desc, 4718 const DeviceMemory<float> &input_data, 4719 const dnn::RnnStateTensorDescriptor &input_h_desc, 4720 const DeviceMemory<float> &input_h_data, 4721 const dnn::RnnStateTensorDescriptor &input_c_desc, 4722 const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms, 4723 const dnn::RnnSequenceTensorDescriptor &output_desc, 4724 DeviceMemory<float> *output_data, 4725 const dnn::RnnStateTensorDescriptor &output_h_desc, 4726 DeviceMemory<float> *output_h_data, 4727 const dnn::RnnStateTensorDescriptor &output_c_desc, 4728 DeviceMemory<float> *output_c_data, bool is_training, 4729 ScratchAllocator *reserve_space_allocator, 4730 ScratchAllocator *workspace_allocator) { 4731 // TODO(zhengxq): add VLOG PARAM calls. 4732 if (ok()) { 4733 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 4734 CheckError(dnn->DoRnnForward( 4735 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, 4736 input_c_desc, input_c_data, params, output_desc, output_data, 4737 output_h_desc, output_h_data, output_c_desc, output_c_data, 4738 is_training, reserve_space_allocator, workspace_allocator)); 4739 } else { 4740 SetError(); 4741 LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support"; 4742 } 4743 } 4744 return *this; 4745 } 4746 4747 Stream &Stream::ThenRnnForward( 4748 const dnn::RnnDescriptor &rnn_desc, 4749 const dnn::RnnSequenceTensorDescriptor &input_desc, 4750 const DeviceMemory<double> &input_data, 4751 const dnn::RnnStateTensorDescriptor &input_h_desc, 4752 const DeviceMemory<double> &input_h_data, 4753 const dnn::RnnStateTensorDescriptor &input_c_desc, 4754 const DeviceMemory<double> &input_c_data, 4755 const DeviceMemory<double> ¶ms, 4756 const dnn::RnnSequenceTensorDescriptor &output_desc, 4757 DeviceMemory<double> *output_data, 4758 const dnn::RnnStateTensorDescriptor &output_h_desc, 4759 DeviceMemory<double> *output_h_data, 4760 const dnn::RnnStateTensorDescriptor &output_c_desc, 4761 DeviceMemory<double> *output_c_data, bool is_training, 4762 ScratchAllocator *reserve_space_allocator, 4763 ScratchAllocator *workspace_allocator) { 4764 // TODO(zhengxq): add VLOG PARAM calls. 4765 if (ok()) { 4766 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 4767 CheckError(dnn->DoRnnForward( 4768 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, 4769 input_c_desc, input_c_data, params, output_desc, output_data, 4770 output_h_desc, output_h_data, output_c_desc, output_c_data, 4771 is_training, reserve_space_allocator, workspace_allocator)); 4772 } else { 4773 SetError(); 4774 LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support"; 4775 } 4776 } 4777 return *this; 4778 } 4779 4780 Stream &Stream::ThenRnnBackward( 4781 const dnn::RnnDescriptor &rnn_desc, 4782 const dnn::RnnSequenceTensorDescriptor &input_desc, 4783 const DeviceMemory<Eigen::half> &input_data, 4784 const dnn::RnnStateTensorDescriptor &input_h_desc, 4785 const DeviceMemory<Eigen::half> &input_h_data, 4786 const dnn::RnnStateTensorDescriptor &input_c_desc, 4787 const DeviceMemory<Eigen::half> &input_c_data, 4788 const DeviceMemory<Eigen::half> ¶ms, 4789 const dnn::RnnSequenceTensorDescriptor &output_desc, 4790 const DeviceMemory<Eigen::half> &output_data, 4791 const dnn::RnnStateTensorDescriptor &output_h_desc, 4792 const DeviceMemory<Eigen::half> &output_h_data, 4793 const dnn::RnnStateTensorDescriptor &output_c_desc, 4794 const DeviceMemory<Eigen::half> &output_c_data, 4795 const DeviceMemory<Eigen::half> &output_backprop_data, 4796 const DeviceMemory<Eigen::half> &output_h_backprop_data, 4797 const DeviceMemory<Eigen::half> &output_c_backprop_data, 4798 DeviceMemory<Eigen::half> *input_backprop_data, 4799 DeviceMemory<Eigen::half> *input_h_backprop_data, 4800 DeviceMemory<Eigen::half> *input_c_backprop_data, 4801 DeviceMemory<Eigen::half> *params_backprop_data, 4802 DeviceMemory<uint8> *reserve_space_data, 4803 ScratchAllocator *workspace_allocator) { 4804 // TODO(zhengxq): add VLOG PARAM calls. 4805 if (ok()) { 4806 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 4807 CheckError(dnn->DoRnnBackward( 4808 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, 4809 input_c_desc, input_c_data, params, output_desc, output_data, 4810 output_h_desc, output_h_data, output_c_desc, output_c_data, 4811 output_backprop_data, output_h_backprop_data, output_c_backprop_data, 4812 input_backprop_data, input_h_backprop_data, input_c_backprop_data, 4813 params_backprop_data, reserve_space_data, workspace_allocator)); 4814 } else { 4815 SetError(); 4816 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support"; 4817 } 4818 } 4819 return *this; 4820 } 4821 4822 Stream &Stream::ThenRnnBackward( 4823 const dnn::RnnDescriptor &rnn_desc, 4824 const dnn::RnnSequenceTensorDescriptor &input_desc, 4825 const DeviceMemory<float> &input_data, 4826 const dnn::RnnStateTensorDescriptor &input_h_desc, 4827 const DeviceMemory<float> &input_h_data, 4828 const dnn::RnnStateTensorDescriptor &input_c_desc, 4829 const DeviceMemory<float> &input_c_data, const DeviceMemory<float> ¶ms, 4830 const dnn::RnnSequenceTensorDescriptor &output_desc, 4831 const DeviceMemory<float> &output_data, 4832 const dnn::RnnStateTensorDescriptor &output_h_desc, 4833 const DeviceMemory<float> &output_h_data, 4834 const dnn::RnnStateTensorDescriptor &output_c_desc, 4835 const DeviceMemory<float> &output_c_data, 4836 const DeviceMemory<float> &output_backprop_data, 4837 const DeviceMemory<float> &output_h_backprop_data, 4838 const DeviceMemory<float> &output_c_backprop_data, 4839 DeviceMemory<float> *input_backprop_data, 4840 DeviceMemory<float> *input_h_backprop_data, 4841 DeviceMemory<float> *input_c_backprop_data, 4842 DeviceMemory<float> *params_backprop_data, 4843 DeviceMemory<uint8> *reserve_space_data, 4844 ScratchAllocator *workspace_allocator) { 4845 // TODO(zhengxq): add VLOG PARAM calls. 4846 if (ok()) { 4847 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 4848 CheckError(dnn->DoRnnBackward( 4849 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, 4850 input_c_desc, input_c_data, params, output_desc, output_data, 4851 output_h_desc, output_h_data, output_c_desc, output_c_data, 4852 output_backprop_data, output_h_backprop_data, output_c_backprop_data, 4853 input_backprop_data, input_h_backprop_data, input_c_backprop_data, 4854 params_backprop_data, reserve_space_data, workspace_allocator)); 4855 } else { 4856 SetError(); 4857 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support"; 4858 } 4859 } 4860 return *this; 4861 } 4862 4863 Stream &Stream::ThenRnnBackward( 4864 const dnn::RnnDescriptor &rnn_desc, 4865 const dnn::RnnSequenceTensorDescriptor &input_desc, 4866 const DeviceMemory<double> &input_data, 4867 const dnn::RnnStateTensorDescriptor &input_h_desc, 4868 const DeviceMemory<double> &input_h_data, 4869 const dnn::RnnStateTensorDescriptor &input_c_desc, 4870 const DeviceMemory<double> &input_c_data, 4871 const DeviceMemory<double> ¶ms, 4872 const dnn::RnnSequenceTensorDescriptor &output_desc, 4873 const DeviceMemory<double> &output_data, 4874 const dnn::RnnStateTensorDescriptor &output_h_desc, 4875 const DeviceMemory<double> &output_h_data, 4876 const dnn::RnnStateTensorDescriptor &output_c_desc, 4877 const DeviceMemory<double> &output_c_data, 4878 const DeviceMemory<double> &output_backprop_data, 4879 const DeviceMemory<double> &output_h_backprop_data, 4880 const DeviceMemory<double> &output_c_backprop_data, 4881 DeviceMemory<double> *input_backprop_data, 4882 DeviceMemory<double> *input_h_backprop_data, 4883 DeviceMemory<double> *input_c_backprop_data, 4884 DeviceMemory<double> *params_backprop_data, 4885 DeviceMemory<uint8> *reserve_space_data, 4886 ScratchAllocator *workspace_allocator) { 4887 // TODO(zhengxq): add VLOG PARAM calls. 4888 if (ok()) { 4889 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 4890 CheckError(dnn->DoRnnBackward( 4891 this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data, 4892 input_c_desc, input_c_data, params, output_desc, output_data, 4893 output_h_desc, output_h_data, output_c_desc, output_c_data, 4894 output_backprop_data, output_h_backprop_data, output_c_backprop_data, 4895 input_backprop_data, input_h_backprop_data, input_c_backprop_data, 4896 params_backprop_data, reserve_space_data, workspace_allocator)); 4897 } else { 4898 SetError(); 4899 LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support"; 4900 } 4901 } 4902 return *this; 4903 } 4904 4905 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc, 4906 dnn::DataType input_type, 4907 const DeviceMemoryBase &input_data, 4908 const dnn::BatchDescriptor &output_desc, 4909 dnn::DataType output_type, float scale, 4910 DeviceMemoryBase *output_data) { 4911 VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data), 4912 PARAM(output_desc), PARAM(output_type), PARAM(scale), 4913 PARAM(output_data)); 4914 if (ok()) { 4915 if (dnn::DnnSupport *dnn = parent_->AsDnn()) { 4916 CheckError(dnn->DoTransformTensor(this, input_desc, input_type, 4917 input_data, output_desc, output_type, 4918 scale, output_data)); 4919 } else { 4920 SetErrorAndLogNoDnnSupport(); 4921 } 4922 } 4923 return *this; 4924 } 4925 4926 Stream &Stream::ThenDoHostCallbackForTest(std::function<void()> callback) { 4927 VLOG_CALL(PARAM(callback)); 4928 4929 return ThenDoHostCallback(callback); 4930 } 4931 4932 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) { 4933 VLOG_CALL(PARAM(callback)); 4934 4935 if (ok()) { 4936 CheckError(parent_->HostCallback(this, callback)); 4937 } else { 4938 LOG(INFO) << "stream " << this 4939 << " was in error state before adding host callback"; 4940 } 4941 return *this; 4942 } 4943 4944 Stream &Stream::ThenFft(fft::Plan *plan, 4945 const DeviceMemory<std::complex<float>> &input, 4946 DeviceMemory<std::complex<float>> *output) { 4947 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); 4948 4949 if (ok()) { 4950 if (fft::FftSupport *fft = parent_->AsFft()) { 4951 CheckError(fft->DoFft(this, plan, input, output)); 4952 } else { 4953 SetError(); 4954 LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " 4955 "without FFT support"; 4956 } 4957 } 4958 return *this; 4959 } 4960 4961 Stream &Stream::ThenFft(fft::Plan *plan, 4962 const DeviceMemory<std::complex<double>> &input, 4963 DeviceMemory<std::complex<double>> *output) { 4964 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); 4965 4966 if (ok()) { 4967 if (fft::FftSupport *fft = parent_->AsFft()) { 4968 CheckError(fft->DoFft(this, plan, input, output)); 4969 } else { 4970 SetError(); 4971 LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " 4972 "without FFT support"; 4973 } 4974 } 4975 return *this; 4976 } 4977 4978 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input, 4979 DeviceMemory<std::complex<float>> *output) { 4980 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); 4981 4982 if (ok()) { 4983 if (fft::FftSupport *fft = parent_->AsFft()) { 4984 CheckError(fft->DoFft(this, plan, input, output)); 4985 } else { 4986 SetError(); 4987 LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " 4988 "without FFT support"; 4989 } 4990 } 4991 return *this; 4992 } 4993 4994 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input, 4995 DeviceMemory<std::complex<double>> *output) { 4996 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); 4997 4998 if (ok()) { 4999 if (fft::FftSupport *fft = parent_->AsFft()) { 5000 CheckError(fft->DoFft(this, plan, input, output)); 5001 } else { 5002 SetError(); 5003 LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " 5004 "without FFT support"; 5005 } 5006 } 5007 return *this; 5008 } 5009 5010 Stream &Stream::ThenFft(fft::Plan *plan, 5011 const DeviceMemory<std::complex<float>> &input, 5012 DeviceMemory<float> *output) { 5013 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); 5014 5015 if (ok()) { 5016 if (fft::FftSupport *fft = parent_->AsFft()) { 5017 CheckError(fft->DoFft(this, plan, input, output)); 5018 } else { 5019 SetError(); 5020 LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " 5021 "without FFT support"; 5022 } 5023 } 5024 return *this; 5025 } 5026 5027 Stream &Stream::ThenFft(fft::Plan *plan, 5028 const DeviceMemory<std::complex<double>> &input, 5029 DeviceMemory<double> *output) { 5030 VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output)); 5031 5032 if (ok()) { 5033 if (fft::FftSupport *fft = parent_->AsFft()) { 5034 CheckError(fft->DoFft(this, plan, input, output)); 5035 } else { 5036 SetError(); 5037 LOG(INFO) << "attempting to perform FFT operation using StreamExecutor " 5038 "without FFT support"; 5039 } 5040 } 5041 return *this; 5042 } 5043 5044 // It looks confusing, but all this is doing is inserting a callback at the 5045 // present point in the stream to then enqueue a task on the host executor. 5046 Stream &Stream::ThenEnqueueOnBackgroundThread( 5047 std::function<void(StreamExecutor *)> task) { 5048 VLOG_CALL(PARAM(task)); 5049 5050 StreamExecutor *stream_executor = this->parent_; 5051 std::function<void()> bound_task = std::bind(task, stream_executor); 5052 5053 return ThenDoHostCallback([stream_executor, bound_task]() { 5054 stream_executor->EnqueueOnBackgroundThread(bound_task); 5055 }); 5056 } 5057 5058 port::Status Stream::BlockHostUntilDone() { 5059 VLOG_CALL(); 5060 5061 if (!ok()) { 5062 port::Status status = port::Status( 5063 port::error::INTERNAL, 5064 "stream did not block host until done; was already in an error state"); 5065 LOG(INFO) << status << " " << this; 5066 return status; 5067 } 5068 5069 port::Status first_error; 5070 { 5071 // Wait until all active sub-streams have done their tasks. 5072 mutex_lock lock{mu_}; 5073 for (auto &stream : sub_streams_) { 5074 if (!stream.second) { 5075 first_error.Update(stream.first->BlockHostUntilDone()); 5076 // Set this sub-stream as available. 5077 stream.second = true; 5078 } 5079 } 5080 } 5081 5082 temporary_memory_manager_.DeallocateFinalizedTemporaries(); 5083 5084 first_error.Update(parent_->BlockHostUntilDone(this)); 5085 CheckError(first_error.ok()); 5086 return first_error; 5087 } 5088 5089 } // namespace gputools 5090 } // namespace perftools 5091