/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The Stream is used in conjunction with the StreamExecutor "parent" to
// perform actions with a linear stream of dependencies. Dependencies can also
// be created between Streams to do task management (i.e. limit which tasks
// can be performed concurrently and specify what task dependencies exist).
20 21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_H_ 22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_H_ 23 24 #include <complex> 25 #include <functional> 26 #include <memory> 27 28 #include "tensorflow/stream_executor/blas.h" 29 #include "tensorflow/stream_executor/device_memory.h" 30 #include "tensorflow/stream_executor/dnn.h" 31 #include "tensorflow/stream_executor/event.h" 32 #include "tensorflow/stream_executor/fft.h" 33 #include "tensorflow/stream_executor/kernel.h" 34 #include "tensorflow/stream_executor/launch_dim.h" 35 #include "tensorflow/stream_executor/lib/array_slice.h" 36 #include "tensorflow/stream_executor/platform/mutex.h" 37 #include "tensorflow/stream_executor/platform/port.h" 38 #include "tensorflow/stream_executor/platform/thread_annotations.h" 39 #include "tensorflow/stream_executor/temporary_memory_manager.h" 40 41 namespace perftools { 42 namespace gputools { 43 44 namespace host { 45 class HostBlas; 46 class HostFft; 47 class HostRng; 48 class HostTimer; 49 } // namespace host 50 51 namespace ocl { 52 class CLBlas; 53 } // namespace ocl 54 55 namespace internal { 56 class StreamInterface; 57 } // namespace internal 58 59 class DeviceMemoryBase; 60 template <typename ElemT> 61 class DeviceMemory; 62 63 class Timer; 64 65 namespace dnn { 66 class BatchDescriptor; 67 class FilterDescriptor; 68 class ConvolutionDescriptor; 69 class BatchDescriptor; 70 class FilterDescriptor; 71 class ConvolutionDescriptor; 72 class ProfileResult; 73 class AlgorithmDesc; 74 } // namespace dnn 75 76 class StreamExecutor; 77 class ScratchAllocator; 78 79 // Convert a type to the corresponding QuantizedActivationMode. 80 template <typename ElementType> 81 struct Quantization; 82 83 // Represents a stream of dependent computations on a GPU device. 84 // 85 // The operations within a stream execute linearly and asynchronously until 86 // BlockHostUntilDone() is invoked, which synchronously joins host code with 87 // the execution of the stream. 
//
// If any given operation fails when entraining work for the stream, ok() will
// indicate that an error has occurred. After initialization, once a stream is
// !ok(), it will never be ok().
//
// Thread-safe post-initialization.
class Stream {
public:
// Instantiate a stream tied to parent as a platform executor. Work
// entrained onto this stream will be launched/managed on that
// StreamExecutor's platform.
explicit Stream(StreamExecutor *parent);

// Test only. Use an externally-populated value (like a mock) for the
// platform-specific stream implementation.
Stream(StreamExecutor *parent, internal::StreamInterface *implementation);

// Deallocates any stream resources that the parent StreamExecutor has
// bestowed upon this object.
~Stream();

// Returns whether any errors have occurred while entraining work for this
// stream: true exactly when the stream is not in an error state.
bool ok() const { return !InErrorState(); }

// Initialize the stream. This must be performed before entraining any other
// operations.
// Annotated LOCKS_EXCLUDED(mu_): the caller must not already hold mu_.
Stream &Init() LOCKS_EXCLUDED(mu_);

// Initializes timer t via the StreamExecutor.
Stream &InitTimer(Timer *t);

// Convenience wrapper around Init() and InitTimer().
Stream &InitWithTimer(Timer *t);

// Get or create a sub-stream from this stream. If there is any sub-stream in
// the pool that can be reused then just return that sub-stream. Otherwise
// create a new sub-stream.
// Annotated LOCKS_EXCLUDED(mu_): the caller must not already hold mu_.
// NOTE(review): the returned raw pointer appears to stay owned by this
// stream's pool (it is handed back via ReturnSubStream) -- confirm.
Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_);

// Return the sub-stream back to the host stream so that it can be reused
// later.
// Annotated LOCKS_EXCLUDED(mu_): the caller must not already hold mu_.
void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_);

// Allocate temporary memories. The stream will deallocate them when blocked
// or destroyed.
template <typename T>
port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
AllocateTemporaryArray(uint64 element_count);

// Entrains onto the stream of operations: a kernel launch with the given
// (variadic) parameters for the invocation. These arguments can be things
// like DeviceMemory or primitive types such as int. What arguments you may
// pass to a given kernel are noted as the template parameters to the
// TypedKernel type that the machocc compiler generates.
//
// Template parameters:
//  Params...   The type list of formal parameters that the typed kernel
//              expects, which is matched against Args...
//  Args...     The deduced type list for passed actual arguments
//
// Note that the actual arguments are taken by value (Args... args).
//
// Implementation: A compile-time compatibility check is performed that has
// some leniency versus an exact parameter pack match -- for example,
// `const DeviceMemory<T>` is considered "pack compatible" with a
// `const DeviceMemory<T>&` formal parameter; in part, because we don't have
// perfect forwarding support without rvalue references. It also attempts to
// spit out helpful static_assert error traces with information as to the
// argument number and types that were mismatched.
template <typename... Params, typename... Args>
Stream &ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
                   const TypedKernel<Params...> &kernel, Args... args);

// Record a "start" event for the interval timer at this point in the
// stream's execution (relative to the previously and subsequently enqueued
// items in the stream's execution). Streams may be started/stopped multiple
// times.
Stream &ThenStartTimer(Timer *t);

// Record a "stop" event for the interval timer at this point in the
// stream's execution. See also Stream::ThenStartTimer.
170 Stream &ThenStopTimer(Timer *t); 171 172 // TODO(leary) If work is added to the stream that is being depended upon, 173 // then what? Have to describe what happens. 174 template <typename... Params> 175 Stream &ThenWaitFor(Stream *other, Params... more_streams) { 176 return ThenWaitFor(more_streams...).ThenWaitFor(other); 177 } 178 179 // Create a dependency for this stream's next work on the other stream 180 // completing. Does not take ownership of other, and other must not be 181 // null. 182 // 183 // Checks that a stream does not wait for itself, and it is up to the 184 // user to guarantee that a stream does not come to wait on itself in a 185 // cyclic 186 // manner; in that case, behavior is undefined. 187 // 188 // N.B. Base recursion case for the variadic ThenWaitFor. 189 Stream &ThenWaitFor(Stream *other); 190 191 // Waits for all streams values in others. 192 // Checks that there is no shallow circular wait (i.e. that "this" is not in 193 // others) 194 template <typename P> 195 Stream &ThenWaitFor(P others) { 196 for (auto &stream : *others) { 197 CHECK_NE(stream.get(), this); 198 ThenWaitFor(stream.get()); 199 } 200 return *this; 201 } 202 203 // Waits for an event object to be set. 204 // Note that ThenRecordEvent must have been called on the event before 205 // you call this function; otherwise the event will be considered complete 206 // and this wait will do nothing. 207 Stream &ThenWaitFor(Event *event); 208 209 // Inserts the specified event into the end of this stream. Once the stream 210 // has processed all events prior to the insertion point, the event will be 211 // marked as completed. 212 // The stream does not take ownership of event - meaning that event's lifetime 213 // must extend past the point at which it is marked complete! 214 Stream &ThenRecordEvent(Event *event); 215 216 //////////////// 217 // DNN support 218 // 219 // See DnnSupport::* for comments on the following methods. 

// Batch normalization. Overloads exist for float and Eigen::half
// activations (x, y, and their backprops); scale, offset and the
// mean/variance buffers are float in every overload. var_to_inv_var and
// inv_var_to_var are caller-supplied callbacks (see DnnSupport for when they
// are invoked).
Stream &ThenBatchNormalizationForward(
    const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &offset,
    const DeviceMemory<float> &estimated_mean,
    const DeviceMemory<float> &estimated_variance,
    const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    DeviceMemory<float> *y, DeviceMemory<float> *batch_mean,
    DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
    DeviceMemory<float> *saved_inv_var, bool is_training,
    std::function<const DeviceMemory<float> &()> var_to_inv_var,
    std::function<void()> inv_var_to_var);

Stream &ThenBatchNormalizationBackward(
    const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
    const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
    const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
    DeviceMemory<float> *offset_backprop);

Stream &ThenBatchNormalizationForward(
    const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &offset,
    const DeviceMemory<float> &estimated_mean,
    const DeviceMemory<float> &estimated_variance,
    const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    DeviceMemory<Eigen::half> *y, DeviceMemory<float> *batch_mean,
    DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
    DeviceMemory<float> *saved_inv_var, bool is_training,
    std::function<const DeviceMemory<float> &()> var_to_inv_var,
    std::function<void()> inv_var_to_var);

Stream &ThenBatchNormalizationBackward(
    const DeviceMemory<Eigen::half> &y_backprop,
    const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
    const dnn::BatchDescriptor &x_desc,
    const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    DeviceMemory<Eigen::half> *x_backprop,
    DeviceMemory<float> *scale_backprop,
    DeviceMemory<float> *offset_backprop);

// Fused convolution with int8 input/filter/side-input/output and float
// biases and scales.
// TODO(leary) add double-precision version of this interface.
Stream &ThenFusedConvolve(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int8> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<int8> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<int8> *output);

Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
                     const DeviceMemory<float> &input_data,
                     const dnn::FilterDescriptor &filter_descriptor,
                     const DeviceMemory<float> &filter_data,
                     const dnn::ConvolutionDescriptor &convolution_descriptor,
                     const dnn::BatchDescriptor &output_descriptor,
                     DeviceMemory<float> *output);

// Convolution with quantized filter coefficients (int8 or int16) accompanied
// by float coefficient_scales; input and output data are float.
Stream &ThenConvolveQuantized(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int8> &filter_coefficients,
    const DeviceMemory<float> &coefficient_scales,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output_data);

Stream &ThenConvolveQuantized(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int16> &filter_coefficients,
    const DeviceMemory<float> &coefficient_scales,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output_data);

// *WithScratch variants additionally take a ScratchAllocator.
Stream &ThenFusedConvolveWithScratch(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int8> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<int8> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
    ScratchAllocator *scratch_allocator);

Stream &ThenFusedConvolveWithScratch(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<Eigen::half> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<Eigen::half> &biases,
    dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator);

Stream &ThenFusedConvolveWithScratch(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<float> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<float> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output, ScratchAllocator *scratch_allocator);

Stream &ThenConvolveWithScratch(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<Eigen::half> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<Eigen::half> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator);

Stream &ThenConvolveWithScratch(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output, ScratchAllocator *scratch_allocator);

// *WithAlgorithm variants additionally take a ScratchAllocator, a
// dnn::AlgorithmConfig and a dnn::ProfileResult* out-parameter.
Stream &ThenConvolveWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result);

Stream &ThenConvolveWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<Eigen::half> &input_data,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<Eigen::half> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result);

// Fused convolution with an explicit algorithm. Note the double overload
// takes double scales; all other overloads take float scales.
Stream &ThenFusedConvolveWithAlgorithm(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<double> &conv_input_data, double conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<double> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<double> &side_input_data, double side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result);

Stream &ThenFusedConvolveWithAlgorithm(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<float> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<float> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result);

Stream &ThenFusedConvolveWithAlgorithm(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<Eigen::half> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<Eigen::half> &biases,
    dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result);

Stream &ThenFusedConvolveWithAlgorithm(
    const dnn::BatchDescriptor &conv_input_descriptor,
    const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<int8> &filter_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const DeviceMemory<int8> &side_input_data, float side_input_scale,
    const dnn::BatchDescriptor &bias_descriptor,
    const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result);

Stream &ThenSeparableConvolve(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
    const DeviceMemory<float> &first_weights,
    const DeviceMemory<float> &second_weights,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> *output);

// Backward-data convolution.
// NOTE(review): backward_output_data is passed by value (a DeviceMemory
// handle) here and in the backward-filter family, unlike the const-reference
// inputs elsewhere. Changing it would require matching changes to the
// out-of-line definitions.
Stream &ThenConvolveBackwardData(
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &input_descriptor,
    DeviceMemory<float> *backward_input_data);

Stream &ThenConvolveBackwardDataWithScratch(
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &input_descriptor,
    DeviceMemory<float> *backward_input_data,
    ScratchAllocator *scratch_allocator);

Stream &ThenConvolveBackwardDataWithScratch(
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<Eigen::half> &filter_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<Eigen::half> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &input_descriptor,
    DeviceMemory<Eigen::half> *backward_input_data,
    ScratchAllocator *scratch_allocator);

Stream &ThenConvolveBackwardDataWithAlgorithm(
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<float> &filter_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &input_descriptor,
    DeviceMemory<float> *backward_input_data,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result);

Stream &ThenConvolveBackwardDataWithAlgorithm(
    const dnn::FilterDescriptor &filter_descriptor,
    const DeviceMemory<Eigen::half> &filter_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<Eigen::half> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &input_descriptor,
    DeviceMemory<Eigen::half> *backward_input_data,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result);

// Backward-filter convolution.
Stream &ThenConvolveBackwardFilter(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    DeviceMemory<float> *backward_filter_data);

Stream &ThenConvolveBackwardFilterWithScratch(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    DeviceMemory<float> *backward_filter_data,
    ScratchAllocator *scratch_allocator);

Stream &ThenConvolveBackwardFilterWithScratch(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<Eigen::half> &input_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<Eigen::half> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    DeviceMemory<Eigen::half> *backward_filter_data,
    ScratchAllocator *scratch_allocator);

Stream &ThenConvolveBackwardFilterWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<float> &input_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<float> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    DeviceMemory<float> *backward_filter_data,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result);

Stream &ThenConvolveBackwardFilterWithAlgorithm(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<Eigen::half> &input_data,
    const dnn::BatchDescriptor &output_descriptor,
    DeviceMemory<Eigen::half> backward_output_data,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    DeviceMemory<Eigen::half> *backward_filter_data,
    ScratchAllocator *scratch_allocator,
    const dnn::AlgorithmConfig &algorithm_config,
    dnn::ProfileResult *output_profile_result);

// Backward-bias convolution; overloads for double, float and Eigen::half.
Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
                                 const DeviceMemory<double> &input_data,
                                 const dnn::BatchDescriptor &bias_descriptor,
                                 DeviceMemory<double> *backward_bias_data);

Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
                                 const DeviceMemory<float> &input_data,
                                 const dnn::BatchDescriptor &bias_descriptor,
                                 DeviceMemory<float> *backward_bias_data);

Stream &ThenConvolveBackwardBias(
    const dnn::BatchDescriptor &input_descriptor,
    const DeviceMemory<Eigen::half> &input_data,
    const dnn::BatchDescriptor &bias_descriptor,
    DeviceMemory<Eigen::half> *backward_bias_data);

Stream &ThenMatMul(const DeviceMemory<float> &input_data,
                   const DeviceMemory<float> &weights,
                   const dnn::BatchDescriptor &input_dimensions,
                   const dnn::BatchDescriptor &output_dimensions,
                   DeviceMemory<float> *output_data);

// Matrix multiply with quantized weights (int8 or int16) accompanied by float
// weight_scales; input and output data are float.
Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
                            const DeviceMemory<int8> &weights,
                            const DeviceMemory<float> &weight_scales,
                            const dnn::BatchDescriptor &input_dimensions,
                            const dnn::BatchDescriptor &output_dimensions,
                            DeviceMemory<float> *output_data);

Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
                            const DeviceMemory<int16> &weights,
                            const DeviceMemory<float> &weight_scales,
                            const dnn::BatchDescriptor &input_dimensions,
                            const dnn::BatchDescriptor &output_dimensions,
                            DeviceMemory<float> *output_data);

Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
                    const DeviceMemory<float> &biases,
                    const dnn::BatchDescriptor &dimensions,
                    DeviceMemory<float> *output_data);

// Pooling (forward); overloads for double, float and Eigen::half data.
Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
                        const dnn::BatchDescriptor &input_dimensions,
                        const DeviceMemory<double> &input_data,
                        const dnn::BatchDescriptor &output_dimensions,
                        DeviceMemory<double> *output_data);

Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
                        const dnn::BatchDescriptor &input_dimensions,
                        const DeviceMemory<float> &input_data,
                        const dnn::BatchDescriptor &output_dimensions,
                        DeviceMemory<float> *output_data);

Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
                        const dnn::BatchDescriptor &input_dimensions,
                        const DeviceMemory<Eigen::half> &input_data,
                        const dnn::BatchDescriptor &output_dimensions,
                        DeviceMemory<Eigen::half> *output_data);

// Pooling (backward); overloads for double, float and Eigen::half data.
Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
                         const dnn::BatchDescriptor &input_dimensions,
                         const DeviceMemory<double> &input_data,
                         const dnn::BatchDescriptor &output_dimensions,
                         const DeviceMemory<double> &output_data,
                         const DeviceMemory<double> &input_diff_data,
                         DeviceMemory<double> *output_diff_data);

Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
                         const dnn::BatchDescriptor &input_dimensions,
                         const DeviceMemory<float> &input_data,
                         const dnn::BatchDescriptor &output_dimensions,
                         const DeviceMemory<float> &output_data,
                         const DeviceMemory<float> &input_diff_data,
                         DeviceMemory<float> *output_diff_data);

Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
                         const dnn::BatchDescriptor &input_dimensions,
                         const DeviceMemory<Eigen::half> &input_data,
                         const dnn::BatchDescriptor &output_dimensions,
                         const DeviceMemory<Eigen::half> &output_data,
                         const DeviceMemory<Eigen::half> &input_diff_data,
                         DeviceMemory<Eigen::half> *output_diff_data);

Stream &ThenNormalize(const dnn::NormalizeDescriptor &normalize_descriptor,
                      const DeviceMemory<float> &input_data,
                      DeviceMemory<float> *output_data);

// Similar to ThenNormalize, but normalizes across feature maps and allows for
// specifying the dimensions of the tensor.
Stream &ThenNormalizeWithDimensions(
    const dnn::NormalizeDescriptor &normalize_descriptor,
    const dnn::BatchDescriptor &dimensions,
    const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data);

// Backward pass of ThenNormalizeWithDimensions.
Stream &ThenNormalizeBackwardWithDimensions(
    const dnn::NormalizeDescriptor &normalize_descriptor,
    const dnn::BatchDescriptor &dimensions,
    const DeviceMemory<float> &raw_data,
    const DeviceMemory<float> &normalized_data,
    const DeviceMemory<float> &normalized_variable_gradient,
    DeviceMemory<float> *raw_variable_gradient);

Stream &ThenActivate(dnn::ActivationMode activation_mode,
                     const dnn::BatchDescriptor &dimensions,
                     const DeviceMemory<float> &input_data,
                     DeviceMemory<float> *output_data);

// Same as ThenActivate, but also takes an options argument that can be used
// for platform-specific option flags.
Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode,
                                const dnn::BatchDescriptor &dimensions,
                                const DeviceMemory<float> &input_data,
                                DeviceMemory<float> *output_data,
                                uint64 options);

// NOTE(review): for the concatenate/elementwise calls below, input_dimensions
// and input_data appear to be parallel arrays (one descriptor per input
// buffer) -- confirm against DnnSupport.
Stream &ThenDepthConcatenate(
    port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    port::ArraySlice<const DeviceMemory<float> *> input_data,
    DeviceMemory<float> *output_data);

Stream &ThenSpaceConcatenate(
    port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    port::ArraySlice<const DeviceMemory<float> *> input_data,
    DeviceMemory<float> *output_data,
    dnn::SpaceConcatenateMode concat_direction);

// Change the layout of the data by shrinking one dimension (or set of
// dimensions) and growing another dimension (or set of dimensions), while
// keeping the total number of data elements constant, and maintaining the
// current data ordering.
Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions,
                    const DeviceMemory<float> &input_data,
                    const dnn::BatchDescriptor &output_dimensions,
                    DeviceMemory<float> *output_data);

// Depth to space takes an X by Y image with depth D*M*M and changes it to an
// MX x MY image with depth D. Each input location (x,y) with depth D*M*M in
// the input image is changed to an MxM contiguous area in the output image,
// with the values being laid out in raster order specified by
// DepthToSpaceLayout, and will have a new depth of D. Here M is
// sqrt_depth_reduction (the depth shrinks by a factor of M*M, so the number
// of data elements is unchanged).
// See the DoDepthToSpace comment for more information.
Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions,
                         const DeviceMemory<float> &input_data,
                         const dnn::DepthToSpaceLayout &depth_to_space_layout,
                         const int sqrt_depth_reduction,
                         DeviceMemory<float> *output_data);

// Space to depth is the inverse of depth to space. Space to depth takes each
// non-overlapping M by M patch (in the X and Y dimensions) with depth D of
// the input, and transforms it to a 1 by 1 patch with depth D*M*M. If the
// input has size (MX, MY, D), the output has size (X, Y, D*M*M). The number
// of data elements is not changed. Here M is sqrt_depth_increase.
Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions,
                         const DeviceMemory<float> &input_data,
                         const dnn::DepthToSpaceLayout &space_to_depth_layout,
                         const int sqrt_depth_increase,
                         DeviceMemory<float> *output_data);

Stream &ThenElementwiseOperate(
    dnn::ElementwiseOperation operation,
    port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    port::ArraySlice<const DeviceMemory<float> *> input_data,
    const dnn::BatchDescriptor &output_dimensions,
    DeviceMemory<float> *output_data);

Stream &ThenElementwiseOperateScaledQuantized(
    dnn::ElementwiseOperation operation,
    port::ArraySlice<int> input_multiplicands, int output_divisor,
    port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    port::ArraySlice<const DeviceMemory<float> *> input_data,
    const dnn::BatchDescriptor &output_dimensions,
    DeviceMemory<float> *output_data);

Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions,
                  const DeviceMemory<float> &input_data, int64 left_pad,
                  int64 right_pad, int64 top_pad, int64 bottom_pad,
                  DeviceMemory<float> *output_data);

Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions,
                    const DeviceMemory<float> &input_data, int64 left_trim,
                    int64 right_trim, int64 top_trim, int64 bottom_trim,
                    DeviceMemory<float> *output_data);

// Grows the input tensor by replicating the X and Y dimensions. The batch and
// depth/feature_map dimensions are unchanged. Currently, the input tensor is
// limited to X=1 and Y=1.
742 Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions, 743 const DeviceMemory<float> &input_data, 744 int64 replicate_x, int64 replicate_y, 745 DeviceMemory<float> *output_data); 746 747 // See DnnSupport::DoMemcpyD2HQuantized. 748 Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src, 749 dnn::QuantizedActivationMode mode, 750 void *host_dst, uint64 size); 751 752 // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice 753 // and uses the Quantization trait to call the generic version of 754 // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode. 755 template <typename ElementType> 756 Stream &ThenMemcpyD2HQuantized( 757 const DeviceMemory<float> &gpu_unquantized_src, 758 port::MutableArraySlice<ElementType> host_dst) { 759 return ThenMemcpyD2HQuantized( 760 gpu_unquantized_src, Quantization<ElementType>::kModeId, 761 host_dst.data(), host_dst.size() * sizeof(ElementType)); 762 } 763 764 // See DnnSupport::DoMemcpyH2DQuantized. 765 Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64 size, 766 dnn::QuantizedActivationMode mode, 767 DeviceMemory<float> *gpu_unquantized_dst); 768 769 // Template version of ThenMemcpyH2DQuantized that takes an ArraySlice 770 // and uses the Quantization trait to call the generic version of 771 // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode. 772 template <typename ElementType> 773 Stream &ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src, 774 DeviceMemory<float> *gpu_unquantized_dst) { 775 return ThenMemcpyH2DQuantized( 776 host_src.data(), host_src.size() * sizeof(ElementType), 777 Quantization<ElementType>::kModeId, gpu_unquantized_dst); 778 } 779 780 // See DnnSupport::DoCopyHostBuffer2Device. 781 Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src, 782 DeviceMemory<float> *gpu_unquantized_dst); 783 784 // See DnnSupport::DoCopyDevice2HostBuffer. 
785 Stream &ThenCopyDevice2HostBuffer( 786 const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst); 787 788 ///////////////// 789 // BLAS support 790 791 // See BlasSupport::DoBlasAsum. 792 Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x, 793 int incx, DeviceMemory<float> *result); 794 Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x, 795 int incx, DeviceMemory<double> *result); 796 Stream &ThenBlasAsum(uint64 elem_count, 797 const DeviceMemory<std::complex<float>> &x, int incx, 798 DeviceMemory<float> *result); 799 Stream &ThenBlasAsum(uint64 elem_count, 800 const DeviceMemory<std::complex<double>> &x, int incx, 801 DeviceMemory<double> *result); 802 803 // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is 804 // present in DeviceMemory, it must be an execution-time constant (i.e. a 805 // value 806 // that the stream does not change or populate during the course of 807 // execution). The value is effectively captured at stream-enqueue time. 808 Stream &ThenBlasAxpy(uint64 elem_count, float alpha, 809 const DeviceMemory<float> &x, int incx, 810 DeviceMemory<float> *y, int incy); 811 Stream &ThenBlasAxpy(uint64 elem_count, double alpha, 812 const DeviceMemory<double> &x, int incx, 813 DeviceMemory<double> *y, int incy); 814 Stream &ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha, 815 const DeviceMemory<std::complex<float>> &x, int incx, 816 DeviceMemory<std::complex<float>> *y, int incy); 817 Stream &ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha, 818 const DeviceMemory<std::complex<double>> &x, int incx, 819 DeviceMemory<std::complex<double>> *y, int incy); 820 821 // See BlasSupport::DoBlasCopy. 
822 Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x, 823 int incx, DeviceMemory<float> *y, int incy); 824 Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x, 825 int incx, DeviceMemory<double> *y, int incy); 826 Stream &ThenBlasCopy(uint64 elem_count, 827 const DeviceMemory<std::complex<float>> &x, int incx, 828 DeviceMemory<std::complex<float>> *y, int incy); 829 Stream &ThenBlasCopy(uint64 elem_count, 830 const DeviceMemory<std::complex<double>> &x, int incx, 831 DeviceMemory<std::complex<double>> *y, int incy); 832 833 // See BlasSupport::DoBlasDot. 834 Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x, int incx, 835 const DeviceMemory<float> &y, int incy, 836 DeviceMemory<float> *result); 837 Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x, 838 int incx, const DeviceMemory<double> &y, int incy, 839 DeviceMemory<double> *result); 840 841 // See BlasSupport::DoBlasDotc. 842 Stream &ThenBlasDotc(uint64 elem_count, 843 const DeviceMemory<std::complex<float>> &x, int incx, 844 const DeviceMemory<std::complex<float>> &y, int incy, 845 DeviceMemory<std::complex<float>> *result); 846 Stream &ThenBlasDotc(uint64 elem_count, 847 const DeviceMemory<std::complex<double>> &x, int incx, 848 const DeviceMemory<std::complex<double>> &y, int incy, 849 DeviceMemory<std::complex<double>> *result); 850 851 // See BlasSupport::DoBlasDotu. 852 Stream &ThenBlasDotu(uint64 elem_count, 853 const DeviceMemory<std::complex<float>> &x, int incx, 854 const DeviceMemory<std::complex<float>> &y, int incy, 855 DeviceMemory<std::complex<float>> *result); 856 Stream &ThenBlasDotu(uint64 elem_count, 857 const DeviceMemory<std::complex<double>> &x, int incx, 858 const DeviceMemory<std::complex<double>> &y, int incy, 859 DeviceMemory<std::complex<double>> *result); 860 861 // See BlasSupport::DoBlasNrm2. 
862 Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x, 863 int incx, DeviceMemory<float> *result); 864 Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x, 865 int incx, DeviceMemory<double> *result); 866 Stream &ThenBlasNrm2(uint64 elem_count, 867 const DeviceMemory<std::complex<float>> &x, int incx, 868 DeviceMemory<float> *result); 869 Stream &ThenBlasNrm2(uint64 elem_count, 870 const DeviceMemory<std::complex<double>> &x, int incx, 871 DeviceMemory<double> *result); 872 873 // See BlasSupport::DoBlasRot. 874 Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx, 875 DeviceMemory<float> *y, int incy, float c, float s); 876 Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x, int incx, 877 DeviceMemory<double> *y, int incy, double c, double s); 878 Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<float>> *x, 879 int incx, DeviceMemory<std::complex<float>> *y, int incy, 880 float c, float s); 881 Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<double>> *x, 882 int incx, DeviceMemory<std::complex<double>> *y, int incy, 883 double c, double s); 884 885 // See BlasSupport::DoBlasRotg. 886 Stream &ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b, 887 DeviceMemory<float> *c, DeviceMemory<float> *s); 888 Stream &ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b, 889 DeviceMemory<double> *c, DeviceMemory<double> *s); 890 Stream &ThenBlasRotg(DeviceMemory<std::complex<float>> *a, 891 DeviceMemory<std::complex<float>> *b, 892 DeviceMemory<float> *c, 893 DeviceMemory<std::complex<float>> *s); 894 Stream &ThenBlasRotg(DeviceMemory<std::complex<double>> *a, 895 DeviceMemory<std::complex<double>> *b, 896 DeviceMemory<double> *c, 897 DeviceMemory<std::complex<double>> *s); 898 899 // See BlasSupport::DoBlasRotm. 
900 Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x, int incx, 901 DeviceMemory<float> *y, int incy, 902 const DeviceMemory<float> ¶m); 903 Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x, int incx, 904 DeviceMemory<double> *y, int incy, 905 const DeviceMemory<double> ¶m); 906 907 // See BlasSupport::DoBlasRotmg. 908 Stream &ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2, 909 DeviceMemory<float> *x1, const DeviceMemory<float> &y1, 910 DeviceMemory<float> *param); 911 Stream &ThenBlasRotmg(DeviceMemory<double> *d1, DeviceMemory<double> *d2, 912 DeviceMemory<double> *x1, 913 const DeviceMemory<double> &y1, 914 DeviceMemory<double> *param); 915 916 // See BlasSupport::DoBlasScal. 917 Stream &ThenBlasScal(uint64 elem_count, float alpha, DeviceMemory<float> *x, 918 int incx); 919 Stream &ThenBlasScal(uint64 elem_count, double alpha, DeviceMemory<double> *x, 920 int incx); 921 Stream &ThenBlasScal(uint64 elem_count, float alpha, 922 DeviceMemory<std::complex<float>> *x, int incx); 923 Stream &ThenBlasScal(uint64 elem_count, double alpha, 924 DeviceMemory<std::complex<double>> *x, int incx); 925 Stream &ThenBlasScal(uint64 elem_count, std::complex<float> alpha, 926 DeviceMemory<std::complex<float>> *x, int incx); 927 Stream &ThenBlasScal(uint64 elem_count, std::complex<double> alpha, 928 DeviceMemory<std::complex<double>> *x, int incx); 929 930 // See BlasSupport::DoBlasSwap. 
931 Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x, int incx, 932 DeviceMemory<float> *y, int incy); 933 Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x, int incx, 934 DeviceMemory<double> *y, int incy); 935 Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<float>> *x, 936 int incx, DeviceMemory<std::complex<float>> *y, 937 int incy); 938 Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<double>> *x, 939 int incx, DeviceMemory<std::complex<double>> *y, 940 int incy); 941 942 // See BlasSupport::DoBlasIamax. 943 Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x, 944 int incx, DeviceMemory<int> *result); 945 Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x, 946 int incx, DeviceMemory<int> *result); 947 Stream &ThenBlasIamax(uint64 elem_count, 948 const DeviceMemory<std::complex<float>> &x, int incx, 949 DeviceMemory<int> *result); 950 Stream &ThenBlasIamax(uint64 elem_count, 951 const DeviceMemory<std::complex<double>> &x, int incx, 952 DeviceMemory<int> *result); 953 954 // See BlasSupport::DoBlasIamin. 955 Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x, 956 int incx, DeviceMemory<int> *result); 957 Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x, 958 int incx, DeviceMemory<int> *result); 959 Stream &ThenBlasIamin(uint64 elem_count, 960 const DeviceMemory<std::complex<float>> &x, int incx, 961 DeviceMemory<int> *result); 962 Stream &ThenBlasIamin(uint64 elem_count, 963 const DeviceMemory<std::complex<double>> &x, int incx, 964 DeviceMemory<int> *result); 965 966 // See BlasSupport::DoBlasGbmv. 
967 Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl, 968 uint64 ku, float alpha, const DeviceMemory<float> &a, 969 int lda, const DeviceMemory<float> &x, int incx, 970 float beta, DeviceMemory<float> *y, int incy); 971 Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl, 972 uint64 ku, double alpha, const DeviceMemory<double> &a, 973 int lda, const DeviceMemory<double> &x, int incx, 974 double beta, DeviceMemory<double> *y, int incy); 975 Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl, 976 uint64 ku, std::complex<float> alpha, 977 const DeviceMemory<std::complex<float>> &a, int lda, 978 const DeviceMemory<std::complex<float>> &x, int incx, 979 std::complex<float> beta, 980 DeviceMemory<std::complex<float>> *y, int incy); 981 Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl, 982 uint64 ku, std::complex<double> alpha, 983 const DeviceMemory<std::complex<double>> &a, int lda, 984 const DeviceMemory<std::complex<double>> &x, int incx, 985 std::complex<double> beta, 986 DeviceMemory<std::complex<double>> *y, int incy); 987 988 // See BlasSupport::DoBlasGemv. 
989 Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, float alpha, 990 const DeviceMemory<float> &a, int lda, 991 const DeviceMemory<float> &x, int incx, float beta, 992 DeviceMemory<float> *y, int incy); 993 Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, double alpha, 994 const DeviceMemory<double> &a, int lda, 995 const DeviceMemory<double> &x, int incx, double beta, 996 DeviceMemory<double> *y, int incy); 997 Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, 998 std::complex<float> alpha, 999 const DeviceMemory<std::complex<float>> &a, int lda, 1000 const DeviceMemory<std::complex<float>> &x, int incx, 1001 std::complex<float> beta, 1002 DeviceMemory<std::complex<float>> *y, int incy); 1003 Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, 1004 std::complex<double> alpha, 1005 const DeviceMemory<std::complex<double>> &a, int lda, 1006 const DeviceMemory<std::complex<double>> &x, int incx, 1007 std::complex<double> beta, 1008 DeviceMemory<std::complex<double>> *y, int incy); 1009 1010 Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n, 1011 float alpha, const DeviceMemory<float> &a, 1012 int lda, const DeviceMemory<float> &x, 1013 int incx, float beta, 1014 DeviceMemory<float> *y, int incy, 1015 blas::ProfileResult *output_profile_result); 1016 Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n, 1017 double alpha, const DeviceMemory<double> &a, 1018 int lda, const DeviceMemory<double> &x, 1019 int incx, double beta, 1020 DeviceMemory<double> *y, int incy, 1021 blas::ProfileResult *output_profile_result); 1022 Stream &ThenBlasGemvWithProfiling( 1023 blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha, 1024 const DeviceMemory<std::complex<float>> &a, int lda, 1025 const DeviceMemory<std::complex<float>> &x, int incx, 1026 std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy, 1027 blas::ProfileResult 
*output_profile_result); 1028 Stream &ThenBlasGemvWithProfiling( 1029 blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha, 1030 const DeviceMemory<std::complex<double>> &a, int lda, 1031 const DeviceMemory<std::complex<double>> &x, int incx, 1032 std::complex<double> beta, DeviceMemory<std::complex<double>> *y, 1033 int incy, blas::ProfileResult *output_profile_result); 1034 1035 // See BlasSupport::DoBlasGer. 1036 Stream &ThenBlasGer(uint64 m, uint64 n, float alpha, 1037 const DeviceMemory<float> &x, int incx, 1038 const DeviceMemory<float> &y, int incy, 1039 DeviceMemory<float> *a, int lda); 1040 Stream &ThenBlasGer(uint64 m, uint64 n, double alpha, 1041 const DeviceMemory<double> &x, int incx, 1042 const DeviceMemory<double> &y, int incy, 1043 DeviceMemory<double> *a, int lda); 1044 1045 // See BlasSupport::DoBlasGerc. 1046 Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha, 1047 const DeviceMemory<std::complex<float>> &x, int incx, 1048 const DeviceMemory<std::complex<float>> &y, int incy, 1049 DeviceMemory<std::complex<float>> *a, int lda); 1050 Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha, 1051 const DeviceMemory<std::complex<double>> &x, int incx, 1052 const DeviceMemory<std::complex<double>> &y, int incy, 1053 DeviceMemory<std::complex<double>> *a, int lda); 1054 1055 // See BlasSupport::DoBlasGeru. 1056 Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha, 1057 const DeviceMemory<std::complex<float>> &x, int incx, 1058 const DeviceMemory<std::complex<float>> &y, int incy, 1059 DeviceMemory<std::complex<float>> *a, int lda); 1060 Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha, 1061 const DeviceMemory<std::complex<double>> &x, int incx, 1062 const DeviceMemory<std::complex<double>> &y, int incy, 1063 DeviceMemory<std::complex<double>> *a, int lda); 1064 1065 // See BlasSupport::DoBlasHbmv. 
1066 Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k, 1067 std::complex<float> alpha, 1068 const DeviceMemory<std::complex<float>> &a, int lda, 1069 const DeviceMemory<std::complex<float>> &x, int incx, 1070 std::complex<float> beta, 1071 DeviceMemory<std::complex<float>> *y, int incy); 1072 Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k, 1073 std::complex<double> alpha, 1074 const DeviceMemory<std::complex<double>> &a, int lda, 1075 const DeviceMemory<std::complex<double>> &x, int incx, 1076 std::complex<double> beta, 1077 DeviceMemory<std::complex<double>> *y, int incy); 1078 1079 // See BlasSupport::DoBlasHemv. 1080 Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n, 1081 std::complex<float> alpha, 1082 const DeviceMemory<std::complex<float>> &a, int lda, 1083 const DeviceMemory<std::complex<float>> &x, int incx, 1084 std::complex<float> beta, 1085 DeviceMemory<std::complex<float>> *y, int incy); 1086 Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n, 1087 std::complex<double> alpha, 1088 const DeviceMemory<std::complex<double>> &a, int lda, 1089 const DeviceMemory<std::complex<double>> &x, int incx, 1090 std::complex<double> beta, 1091 DeviceMemory<std::complex<double>> *y, int incy); 1092 1093 // See BlasSupport::DoBlasHer. 1094 Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha, 1095 const DeviceMemory<std::complex<float>> &x, int incx, 1096 DeviceMemory<std::complex<float>> *a, int lda); 1097 Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha, 1098 const DeviceMemory<std::complex<double>> &x, int incx, 1099 DeviceMemory<std::complex<double>> *a, int lda); 1100 1101 // See BlasSupport::DoBlasHer2. 
1102 Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n, 1103 std::complex<float> alpha, 1104 const DeviceMemory<std::complex<float>> &x, int incx, 1105 const DeviceMemory<std::complex<float>> &y, int incy, 1106 DeviceMemory<std::complex<float>> *a, int lda); 1107 Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n, 1108 std::complex<double> alpha, 1109 const DeviceMemory<std::complex<double>> &x, int incx, 1110 const DeviceMemory<std::complex<double>> &y, int incy, 1111 DeviceMemory<std::complex<double>> *a, int lda); 1112 1113 // See BlasSupport::DoBlasHpmv. 1114 Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n, 1115 std::complex<float> alpha, 1116 const DeviceMemory<std::complex<float>> &ap, 1117 const DeviceMemory<std::complex<float>> &x, int incx, 1118 std::complex<float> beta, 1119 DeviceMemory<std::complex<float>> *y, int incy); 1120 Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n, 1121 std::complex<double> alpha, 1122 const DeviceMemory<std::complex<double>> &ap, 1123 const DeviceMemory<std::complex<double>> &x, int incx, 1124 std::complex<double> beta, 1125 DeviceMemory<std::complex<double>> *y, int incy); 1126 1127 // See BlasSupport::DoBlasHpr. 1128 Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha, 1129 const DeviceMemory<std::complex<float>> &x, int incx, 1130 DeviceMemory<std::complex<float>> *ap); 1131 Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha, 1132 const DeviceMemory<std::complex<double>> &x, int incx, 1133 DeviceMemory<std::complex<double>> *ap); 1134 1135 // See BlasSupport::DoBlasHpr2. 
1136 Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n, 1137 std::complex<float> alpha, 1138 const DeviceMemory<std::complex<float>> &x, int incx, 1139 const DeviceMemory<std::complex<float>> &y, int incy, 1140 DeviceMemory<std::complex<float>> *ap); 1141 Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n, 1142 std::complex<double> alpha, 1143 const DeviceMemory<std::complex<double>> &x, int incx, 1144 const DeviceMemory<std::complex<double>> &y, int incy, 1145 DeviceMemory<std::complex<double>> *ap); 1146 1147 // See BlasSupport::DoBlasSbmv. 1148 Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, float alpha, 1149 const DeviceMemory<float> &a, int lda, 1150 const DeviceMemory<float> &x, int incx, float beta, 1151 DeviceMemory<float> *y, int incy); 1152 Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, double alpha, 1153 const DeviceMemory<double> &a, int lda, 1154 const DeviceMemory<double> &x, int incx, double beta, 1155 DeviceMemory<double> *y, int incy); 1156 1157 // See BlasSupport::DoBlasSpmv. 1158 Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha, 1159 const DeviceMemory<float> &ap, 1160 const DeviceMemory<float> &x, int incx, float beta, 1161 DeviceMemory<float> *y, int incy); 1162 Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha, 1163 const DeviceMemory<double> &ap, 1164 const DeviceMemory<double> &x, int incx, double beta, 1165 DeviceMemory<double> *y, int incy); 1166 1167 // See BlasSupport::DoBlasSpr. 1168 Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha, 1169 const DeviceMemory<float> &x, int incx, 1170 DeviceMemory<float> *ap); 1171 Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha, 1172 const DeviceMemory<double> &x, int incx, 1173 DeviceMemory<double> *ap); 1174 1175 // See BlasSupport::DoBlasSpr2. 
1176 Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha, 1177 const DeviceMemory<float> &x, int incx, 1178 const DeviceMemory<float> &y, int incy, 1179 DeviceMemory<float> *ap); 1180 Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha, 1181 const DeviceMemory<double> &x, int incx, 1182 const DeviceMemory<double> &y, int incy, 1183 DeviceMemory<double> *ap); 1184 1185 // See BlasSupport::DoBlasSymv. 1186 Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha, 1187 const DeviceMemory<float> &a, int lda, 1188 const DeviceMemory<float> &x, int incx, float beta, 1189 DeviceMemory<float> *y, int incy); 1190 Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha, 1191 const DeviceMemory<double> &a, int lda, 1192 const DeviceMemory<double> &x, int incx, double beta, 1193 DeviceMemory<double> *y, int incy); 1194 1195 // See BlasSupport::DoBlasSyr. 1196 Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha, 1197 const DeviceMemory<float> &x, int incx, 1198 DeviceMemory<float> *a, int lda); 1199 Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha, 1200 const DeviceMemory<double> &x, int incx, 1201 DeviceMemory<double> *a, int lda); 1202 1203 // See BlasSupport::DoBlasSyr2. 1204 Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha, 1205 const DeviceMemory<float> &x, int incx, 1206 const DeviceMemory<float> &y, int incy, 1207 DeviceMemory<float> *a, int lda); 1208 Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha, 1209 const DeviceMemory<double> &x, int incx, 1210 const DeviceMemory<double> &y, int incy, 1211 DeviceMemory<double> *a, int lda); 1212 1213 // See BlasSupport::DoBlasTbmv. 
1214 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, 1215 blas::Diagonal diag, uint64 n, uint64 k, 1216 const DeviceMemory<float> &a, int lda, 1217 DeviceMemory<float> *x, int incx); 1218 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, 1219 blas::Diagonal diag, uint64 n, uint64 k, 1220 const DeviceMemory<double> &a, int lda, 1221 DeviceMemory<double> *x, int incx); 1222 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, 1223 blas::Diagonal diag, uint64 n, uint64 k, 1224 const DeviceMemory<std::complex<float>> &a, int lda, 1225 DeviceMemory<std::complex<float>> *x, int incx); 1226 Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans, 1227 blas::Diagonal diag, uint64 n, uint64 k, 1228 const DeviceMemory<std::complex<double>> &a, int lda, 1229 DeviceMemory<std::complex<double>> *x, int incx); 1230 1231 // See BlasSupport::DoBlasTbsv. 1232 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, 1233 blas::Diagonal diag, uint64 n, uint64 k, 1234 const DeviceMemory<float> &a, int lda, 1235 DeviceMemory<float> *x, int incx); 1236 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, 1237 blas::Diagonal diag, uint64 n, uint64 k, 1238 const DeviceMemory<double> &a, int lda, 1239 DeviceMemory<double> *x, int incx); 1240 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, 1241 blas::Diagonal diag, uint64 n, uint64 k, 1242 const DeviceMemory<std::complex<float>> &a, int lda, 1243 DeviceMemory<std::complex<float>> *x, int incx); 1244 Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans, 1245 blas::Diagonal diag, uint64 n, uint64 k, 1246 const DeviceMemory<std::complex<double>> &a, int lda, 1247 DeviceMemory<std::complex<double>> *x, int incx); 1248 1249 // See BlasSupport::DoBlasTpmv. 
1250 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, 1251 blas::Diagonal diag, uint64 n, 1252 const DeviceMemory<float> &ap, DeviceMemory<float> *x, 1253 int incx); 1254 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, 1255 blas::Diagonal diag, uint64 n, 1256 const DeviceMemory<double> &ap, DeviceMemory<double> *x, 1257 int incx); 1258 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, 1259 blas::Diagonal diag, uint64 n, 1260 const DeviceMemory<std::complex<float>> &ap, 1261 DeviceMemory<std::complex<float>> *x, int incx); 1262 Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans, 1263 blas::Diagonal diag, uint64 n, 1264 const DeviceMemory<std::complex<double>> &ap, 1265 DeviceMemory<std::complex<double>> *x, int incx); 1266 1267 // See BlasSupport::DoBlasTpsv. 1268 Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, 1269 blas::Diagonal diag, uint64 n, 1270 const DeviceMemory<float> &ap, DeviceMemory<float> *x, 1271 int incx); 1272 Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, 1273 blas::Diagonal diag, uint64 n, 1274 const DeviceMemory<double> &ap, DeviceMemory<double> *x, 1275 int incx); 1276 Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, 1277 blas::Diagonal diag, uint64 n, 1278 const DeviceMemory<std::complex<float>> &ap, 1279 DeviceMemory<std::complex<float>> *x, int incx); 1280 Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans, 1281 blas::Diagonal diag, uint64 n, 1282 const DeviceMemory<std::complex<double>> &ap, 1283 DeviceMemory<std::complex<double>> *x, int incx); 1284 1285 // See BlasSupport::DoBlasTrmv. 
1286 Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, 1287 blas::Diagonal diag, uint64 n, 1288 const DeviceMemory<float> &a, int lda, 1289 DeviceMemory<float> *x, int incx); 1290 Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, 1291 blas::Diagonal diag, uint64 n, 1292 const DeviceMemory<double> &a, int lda, 1293 DeviceMemory<double> *x, int incx); 1294 Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, 1295 blas::Diagonal diag, uint64 n, 1296 const DeviceMemory<std::complex<float>> &a, int lda, 1297 DeviceMemory<std::complex<float>> *x, int incx); 1298 Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans, 1299 blas::Diagonal diag, uint64 n, 1300 const DeviceMemory<std::complex<double>> &a, int lda, 1301 DeviceMemory<std::complex<double>> *x, int incx); 1302 1303 // See BlasSupport::DoBlasTrsv. 1304 Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, 1305 blas::Diagonal diag, uint64 n, 1306 const DeviceMemory<float> &a, int lda, 1307 DeviceMemory<float> *x, int incx); 1308 Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, 1309 blas::Diagonal diag, uint64 n, 1310 const DeviceMemory<double> &a, int lda, 1311 DeviceMemory<double> *x, int incx); 1312 Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, 1313 blas::Diagonal diag, uint64 n, 1314 const DeviceMemory<std::complex<float>> &a, int lda, 1315 DeviceMemory<std::complex<float>> *x, int incx); 1316 Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans, 1317 blas::Diagonal diag, uint64 n, 1318 const DeviceMemory<std::complex<double>> &a, int lda, 1319 DeviceMemory<std::complex<double>> *x, int incx); 1320 1321 // See BlasSupport::DoBlasGemm. 
1322 Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, 1323 uint64 n, uint64 k, float alpha, 1324 const DeviceMemory<Eigen::half> &a, int lda, 1325 const DeviceMemory<Eigen::half> &b, int ldb, float beta, 1326 DeviceMemory<Eigen::half> *c, int ldc); 1327 Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, 1328 uint64 n, uint64 k, float alpha, 1329 const DeviceMemory<float> &a, int lda, 1330 const DeviceMemory<float> &b, int ldb, float beta, 1331 DeviceMemory<float> *c, int ldc); 1332 Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, 1333 uint64 n, uint64 k, double alpha, 1334 const DeviceMemory<double> &a, int lda, 1335 const DeviceMemory<double> &b, int ldb, double beta, 1336 DeviceMemory<double> *c, int ldc); 1337 Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, 1338 uint64 n, uint64 k, std::complex<float> alpha, 1339 const DeviceMemory<std::complex<float>> &a, int lda, 1340 const DeviceMemory<std::complex<float>> &b, int ldb, 1341 std::complex<float> beta, 1342 DeviceMemory<std::complex<float>> *c, int ldc); 1343 Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m, 1344 uint64 n, uint64 k, std::complex<double> alpha, 1345 const DeviceMemory<std::complex<double>> &a, int lda, 1346 const DeviceMemory<std::complex<double>> &b, int ldb, 1347 std::complex<double> beta, 1348 DeviceMemory<std::complex<double>> *c, int ldc); 1349 1350 Stream &ThenBlasGemmWithProfiling(blas::Transpose transa, 1351 blas::Transpose transb, uint64 m, uint64 n, 1352 uint64 k, float alpha, 1353 const DeviceMemory<Eigen::half> &a, int lda, 1354 const DeviceMemory<Eigen::half> &b, int ldb, 1355 float beta, DeviceMemory<Eigen::half> *c, 1356 int ldc, 1357 blas::ProfileResult *output_profile_result); 1358 Stream &ThenBlasGemmWithProfiling(blas::Transpose transa, 1359 blas::Transpose transb, uint64 m, uint64 n, 1360 uint64 k, float alpha, 1361 const 
DeviceMemory<float> &a, int lda, 1362 const DeviceMemory<float> &b, int ldb, 1363 float beta, DeviceMemory<float> *c, int ldc, 1364 blas::ProfileResult *output_profile_result); 1365 Stream &ThenBlasGemmWithProfiling(blas::Transpose transa, 1366 blas::Transpose transb, uint64 m, uint64 n, 1367 uint64 k, double alpha, 1368 const DeviceMemory<double> &a, int lda, 1369 const DeviceMemory<double> &b, int ldb, 1370 double beta, DeviceMemory<double> *c, 1371 int ldc, 1372 blas::ProfileResult *output_profile_result); 1373 Stream &ThenBlasGemmWithProfiling( 1374 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 1375 uint64 k, std::complex<float> alpha, 1376 const DeviceMemory<std::complex<float>> &a, int lda, 1377 const DeviceMemory<std::complex<float>> &b, int ldb, 1378 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, 1379 blas::ProfileResult *output_profile_result); 1380 Stream &ThenBlasGemmWithProfiling( 1381 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 1382 uint64 k, std::complex<double> alpha, 1383 const DeviceMemory<std::complex<double>> &a, int lda, 1384 const DeviceMemory<std::complex<double>> &b, int ldb, 1385 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, 1386 blas::ProfileResult *output_profile_result); 1387 1388 // See BlasSupport::DoBlasGemmWithAlgorithm. 
1389 Stream &ThenBlasGemmWithAlgorithm( 1390 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 1391 uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a, 1392 int lda, const DeviceMemory<Eigen::half> &b, int ldb, 1393 const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc, 1394 blas::ComputationType computation_type, blas::AlgorithmType algorithm, 1395 blas::ProfileResult *output_profile_result); 1396 Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa, 1397 blas::Transpose transb, uint64 m, uint64 n, 1398 uint64 k, int alpha, 1399 const DeviceMemory<int8> &a, int lda, 1400 const DeviceMemory<int8> &b, int ldb, 1401 int beta, DeviceMemory<int> *c, int ldc, 1402 blas::ComputationType computation_type, 1403 blas::AlgorithmType algorithm, 1404 blas::ProfileResult *output_profile_result); 1405 Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa, 1406 blas::Transpose transb, uint64 m, uint64 n, 1407 uint64 k, float alpha, 1408 const DeviceMemory<float> &a, int lda, 1409 const DeviceMemory<float> &b, int ldb, 1410 float beta, DeviceMemory<float> *c, int ldc, 1411 blas::ComputationType computation_type, 1412 blas::AlgorithmType algorithm, 1413 blas::ProfileResult *output_profile_result); 1414 Stream &ThenBlasGemmWithAlgorithm( 1415 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 1416 uint64 k, double alpha, const DeviceMemory<double> &a, int lda, 1417 const DeviceMemory<double> &b, int ldb, double beta, 1418 DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type, 1419 blas::AlgorithmType algorithm, 1420 blas::ProfileResult *output_profile_result); 1421 Stream &ThenBlasGemmWithAlgorithm( 1422 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 1423 uint64 k, std::complex<float> alpha, 1424 const DeviceMemory<std::complex<float>> &a, int lda, 1425 const DeviceMemory<std::complex<float>> &b, int ldb, 1426 std::complex<float> beta, DeviceMemory<std::complex<float>> *c, 
int ldc, 1427 blas::ComputationType computation_type, blas::AlgorithmType algorithm, 1428 blas::ProfileResult *output_profile_result); 1429 Stream &ThenBlasGemmWithAlgorithm( 1430 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 1431 uint64 k, std::complex<double> alpha, 1432 const DeviceMemory<std::complex<double>> &a, int lda, 1433 const DeviceMemory<std::complex<double>> &b, int ldb, 1434 std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc, 1435 blas::ComputationType computation_type, blas::AlgorithmType algorithm, 1436 blas::ProfileResult *output_profile_result); 1437 1438 // See BlasSupport::DoBlasGemmBatched. 1439 Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, 1440 uint64 m, uint64 n, uint64 k, float alpha, 1441 const port::ArraySlice<DeviceMemory<float> *> &a, 1442 int lda, 1443 const port::ArraySlice<DeviceMemory<float> *> &b, 1444 int ldb, float beta, 1445 const port::ArraySlice<DeviceMemory<float> *> &c, 1446 int ldc, int batch_count); 1447 Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb, 1448 uint64 m, uint64 n, uint64 k, double alpha, 1449 const port::ArraySlice<DeviceMemory<double> *> &a, 1450 int lda, 1451 const port::ArraySlice<DeviceMemory<double> *> &b, 1452 int ldb, double beta, 1453 const port::ArraySlice<DeviceMemory<double> *> &c, 1454 int ldc, int batch_count); 1455 Stream &ThenBlasGemmBatched( 1456 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 1457 uint64 k, std::complex<float> alpha, 1458 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, 1459 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, 1460 std::complex<float> beta, 1461 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, 1462 int batch_count); 1463 Stream &ThenBlasGemmBatched( 1464 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 1465 uint64 k, std::complex<double> alpha, 1466 const 
port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda, 1467 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb, 1468 std::complex<double> beta, 1469 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc, 1470 int batch_count); 1471 Stream &ThenBlasGemmBatchedWithScratch( 1472 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 1473 uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a, 1474 int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb, 1475 float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc, 1476 int batch_count, ScratchAllocator *scratch_allocator); 1477 Stream &ThenBlasGemmBatchedWithScratch( 1478 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 1479 uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a, 1480 int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb, 1481 double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc, 1482 int batch_count, ScratchAllocator *scratch_allocator); 1483 Stream &ThenBlasGemmBatchedWithScratch( 1484 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 1485 uint64 k, std::complex<float> alpha, 1486 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda, 1487 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb, 1488 std::complex<float> beta, 1489 const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc, 1490 int batch_count, ScratchAllocator *scratch_allocator); 1491 Stream &ThenBlasGemmBatchedWithScratch( 1492 blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n, 1493 uint64 k, std::complex<double> alpha, 1494 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda, 1495 const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb, 1496 std::complex<double> beta, 1497 const port::ArraySlice<DeviceMemory<std::complex<double>> *> 
&c, int ldc, 1498 int batch_count, ScratchAllocator *scratch_allocator); 1499 1500 // See BlasSupport::DoBlasHemm. 1501 Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m, 1502 uint64 n, std::complex<float> alpha, 1503 const DeviceMemory<std::complex<float>> &a, int lda, 1504 const DeviceMemory<std::complex<float>> &b, int ldb, 1505 std::complex<float> beta, 1506 DeviceMemory<std::complex<float>> *c, int ldc); 1507 Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m, 1508 uint64 n, std::complex<double> alpha, 1509 const DeviceMemory<std::complex<double>> &a, int lda, 1510 const DeviceMemory<std::complex<double>> &b, int ldb, 1511 std::complex<double> beta, 1512 DeviceMemory<std::complex<double>> *c, int ldc); 1513 1514 // See BlasSupport::DoBlasHerk. 1515 Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n, 1516 uint64 k, float alpha, 1517 const DeviceMemory<std::complex<float>> &a, int lda, 1518 float beta, DeviceMemory<std::complex<float>> *c, 1519 int ldc); 1520 Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n, 1521 uint64 k, double alpha, 1522 const DeviceMemory<std::complex<double>> &a, int lda, 1523 double beta, DeviceMemory<std::complex<double>> *c, 1524 int ldc); 1525 1526 // See BlasSupport::DoBlasHer2k. 1527 Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n, 1528 uint64 k, std::complex<float> alpha, 1529 const DeviceMemory<std::complex<float>> &a, int lda, 1530 const DeviceMemory<std::complex<float>> &b, int ldb, 1531 float beta, DeviceMemory<std::complex<float>> *c, 1532 int ldc); 1533 Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n, 1534 uint64 k, std::complex<double> alpha, 1535 const DeviceMemory<std::complex<double>> &a, int lda, 1536 const DeviceMemory<std::complex<double>> &b, int ldb, 1537 double beta, DeviceMemory<std::complex<double>> *c, 1538 int ldc); 1539 1540 // See BlasSupport::DoBlasSymm. 
1541 Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m, 1542 uint64 n, float alpha, const DeviceMemory<float> &a, 1543 int lda, const DeviceMemory<float> &b, int ldb, 1544 float beta, DeviceMemory<float> *c, int ldc); 1545 Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m, 1546 uint64 n, double alpha, const DeviceMemory<double> &a, 1547 int lda, const DeviceMemory<double> &b, int ldb, 1548 double beta, DeviceMemory<double> *c, int ldc); 1549 Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m, 1550 uint64 n, std::complex<float> alpha, 1551 const DeviceMemory<std::complex<float>> &a, int lda, 1552 const DeviceMemory<std::complex<float>> &b, int ldb, 1553 std::complex<float> beta, 1554 DeviceMemory<std::complex<float>> *c, int ldc); 1555 Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m, 1556 uint64 n, std::complex<double> alpha, 1557 const DeviceMemory<std::complex<double>> &a, int lda, 1558 const DeviceMemory<std::complex<double>> &b, int ldb, 1559 std::complex<double> beta, 1560 DeviceMemory<std::complex<double>> *c, int ldc); 1561 1562 // See BlasSupport::DoBlasSyrk. 
1563 Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n, 1564 uint64 k, float alpha, const DeviceMemory<float> &a, 1565 int lda, float beta, DeviceMemory<float> *c, int ldc); 1566 Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n, 1567 uint64 k, double alpha, const DeviceMemory<double> &a, 1568 int lda, double beta, DeviceMemory<double> *c, int ldc); 1569 Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n, 1570 uint64 k, std::complex<float> alpha, 1571 const DeviceMemory<std::complex<float>> &a, int lda, 1572 std::complex<float> beta, 1573 DeviceMemory<std::complex<float>> *c, int ldc); 1574 Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n, 1575 uint64 k, std::complex<double> alpha, 1576 const DeviceMemory<std::complex<double>> &a, int lda, 1577 std::complex<double> beta, 1578 DeviceMemory<std::complex<double>> *c, int ldc); 1579 1580 // See BlasSupport::DoBlasSyr2k. 1581 Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n, 1582 uint64 k, float alpha, const DeviceMemory<float> &a, 1583 int lda, const DeviceMemory<float> &b, int ldb, 1584 float beta, DeviceMemory<float> *c, int ldc); 1585 Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n, 1586 uint64 k, double alpha, const DeviceMemory<double> &a, 1587 int lda, const DeviceMemory<double> &b, int ldb, 1588 double beta, DeviceMemory<double> *c, int ldc); 1589 Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n, 1590 uint64 k, std::complex<float> alpha, 1591 const DeviceMemory<std::complex<float>> &a, int lda, 1592 const DeviceMemory<std::complex<float>> &b, int ldb, 1593 std::complex<float> beta, 1594 DeviceMemory<std::complex<float>> *c, int ldc); 1595 Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n, 1596 uint64 k, std::complex<double> alpha, 1597 const DeviceMemory<std::complex<double>> &a, int lda, 1598 const 
DeviceMemory<std::complex<double>> &b, int ldb, 1599 std::complex<double> beta, 1600 DeviceMemory<std::complex<double>> *c, int ldc); 1601 1602 // See BlasSupport::DoBlasTrmm. 1603 Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, 1604 blas::Transpose transa, blas::Diagonal diag, uint64 m, 1605 uint64 n, float alpha, const DeviceMemory<float> &a, 1606 int lda, DeviceMemory<float> *b, int ldb); 1607 Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, 1608 blas::Transpose transa, blas::Diagonal diag, uint64 m, 1609 uint64 n, double alpha, const DeviceMemory<double> &a, 1610 int lda, DeviceMemory<double> *b, int ldb); 1611 Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, 1612 blas::Transpose transa, blas::Diagonal diag, uint64 m, 1613 uint64 n, std::complex<float> alpha, 1614 const DeviceMemory<std::complex<float>> &a, int lda, 1615 DeviceMemory<std::complex<float>> *b, int ldb); 1616 Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo, 1617 blas::Transpose transa, blas::Diagonal diag, uint64 m, 1618 uint64 n, std::complex<double> alpha, 1619 const DeviceMemory<std::complex<double>> &a, int lda, 1620 DeviceMemory<std::complex<double>> *b, int ldb); 1621 1622 // See BlasSupport::DoBlasTrsm. 
1623 Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, 1624 blas::Transpose transa, blas::Diagonal diag, uint64 m, 1625 uint64 n, float alpha, const DeviceMemory<float> &a, 1626 int lda, DeviceMemory<float> *b, int ldb); 1627 Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, 1628 blas::Transpose transa, blas::Diagonal diag, uint64 m, 1629 uint64 n, double alpha, const DeviceMemory<double> &a, 1630 int lda, DeviceMemory<double> *b, int ldb); 1631 Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, 1632 blas::Transpose transa, blas::Diagonal diag, uint64 m, 1633 uint64 n, std::complex<float> alpha, 1634 const DeviceMemory<std::complex<float>> &a, int lda, 1635 DeviceMemory<std::complex<float>> *b, int ldb); 1636 Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo, 1637 blas::Transpose transa, blas::Diagonal diag, uint64 m, 1638 uint64 n, std::complex<double> alpha, 1639 const DeviceMemory<std::complex<double>> &a, int lda, 1640 DeviceMemory<std::complex<double>> *b, int ldb); 1641 1642 // See FftSupport::DoFft. 1643 Stream &ThenFft(fft::Plan *plan, 1644 const DeviceMemory<std::complex<float>> &input, 1645 DeviceMemory<std::complex<float>> *output); 1646 Stream &ThenFft(fft::Plan *plan, 1647 const DeviceMemory<std::complex<double>> &input, 1648 DeviceMemory<std::complex<double>> *output); 1649 Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input, 1650 DeviceMemory<std::complex<float>> *output); 1651 Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input, 1652 DeviceMemory<std::complex<double>> *output); 1653 Stream &ThenFft(fft::Plan *plan, 1654 const DeviceMemory<std::complex<float>> &input, 1655 DeviceMemory<float> *output); 1656 Stream &ThenFft(fft::Plan *plan, 1657 const DeviceMemory<std::complex<double>> &input, 1658 DeviceMemory<double> *output); 1659 1660 // Makes the RNG use the provided value as the basis for further generation. 
1661 // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good 1662 // sources of seed data if the default (high quality) sources are not 1663 // desired. 1664 // For most use cases, this function will not be necessary; each provided 1665 // back-end implementation will be appropriately seeded by default. 1666 // At a minimum 16 bytes of data are required in the seed buffer. 1667 // 1668 // To seed with good (non-reproducible) data: 1669 // File* f = File::Open("/dev/random", "r"); 1670 // int64 bytes_read = f->Read(seed_data, bytes_to_read); 1671 // < error checking > 1672 // stream.ThenSetRngSeed(seed_data, bytes_read); 1673 // 1674 // To seed with reproducible data: 1675 // uint64_t seed_data[2] = { <data> }; 1676 // stream.ThenSetRngSeed(seed_data, 16); 1677 Stream &ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes); 1678 1679 // Populates the memory indicated by values with uniform-random-distribution 1680 // values. TODO(leary) seeding API/description 1681 // 1682 // Uses the type and size of the DeviceMemory to infer what data should be 1683 // populated. 1684 Stream &ThenPopulateRandUniform(DeviceMemory<float> *values); 1685 Stream &ThenPopulateRandUniform(DeviceMemory<double> *values); 1686 Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values); 1687 Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values); 1688 Stream &ThenPopulateRandGaussian(float mean, float stddev, 1689 DeviceMemory<float> *values); 1690 Stream &ThenPopulateRandGaussian(double mean, double stddev, 1691 DeviceMemory<double> *values); 1692 1693 // Entrain onto the stream: a memcpy to a host destination from a GPU source 1694 // of the given target size. host_dst must be a pointer to host memory 1695 // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and 1696 // then registered with StreamExecutor::HostMemoryRegister. 
1697 Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, 1698 uint64 size); 1699 1700 // Entrain onto the stream: a memcpy to a GPU destination from a host source 1701 // of the given target size. host_src must be a pointer to host memory 1702 // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and 1703 // then registered with StreamExecutor::HostMemoryRegister. 1704 Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, 1705 uint64 size); 1706 1707 // Alternative interface for memcpying from device to host that takes an 1708 // array slice. Checks that the destination size can accommodate the host 1709 // slice size. 1710 template <typename T> 1711 Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src, 1712 port::MutableArraySlice<T> host_dst) { 1713 auto host_size = host_dst.size() * sizeof(T); 1714 CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size()); 1715 return ThenMemcpy(host_dst.begin(), gpu_src, host_size); 1716 } 1717 1718 // Alternative interface for memcpying from host to device that takes an 1719 // array slice. Checks that the destination size can accommodate the host 1720 // slice size. 1721 template <typename T> 1722 Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src, 1723 DeviceMemory<T> *gpu_dst) { 1724 auto host_size = host_src.size() * sizeof(T); 1725 CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size); 1726 return ThenMemcpy(gpu_dst, host_src.begin(), host_size); 1727 } 1728 1729 // Entrain onto the stream: a memcpy to a GPU destination from a GPU source 1730 // of the given target size. gpu_src/dst must be pointers to GPU memory and 1731 // peer access must be enabled between their owning StreamExecutors. 
1732 Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, 1733 uint64 size); 1734 1735 // Calls to the device-to-device copy overload of ThenMemcpy -- useful for 1736 // ensuring that the host pointer isn't getting confused accidentally with a 1737 // device pointer if you're not doing metaprogramming against the API. 1738 Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst, 1739 const DeviceMemoryBase &gpu_src, uint64 size) { 1740 return ThenMemcpy(gpu_dst, gpu_src, size); 1741 } 1742 1743 // Entrain onto the stream: a memset of zero at a GPU location of size bytes. 1744 // The location must not be null. 1745 Stream &ThenMemZero(DeviceMemoryBase *location, uint64 size); 1746 1747 // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of 1748 // size bytes, where bytes must be evenly 32-bit sized (i.e. evenly divisible 1749 // by 4). The location must not be null. 1750 Stream &ThenMemset32(DeviceMemoryBase *location, uint32 pattern, uint64 size); 1751 1752 // Enqueue a forward operation of the RNN model onto the stream. 1753 // See DnnSupport::DoRnnForward for more details. 
1754 Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, 1755 const dnn::RnnSequenceTensorDescriptor &input_desc, 1756 const DeviceMemory<Eigen::half> &input_data, 1757 const dnn::RnnStateTensorDescriptor &input_h_desc, 1758 const DeviceMemory<Eigen::half> &input_h_data, 1759 const dnn::RnnStateTensorDescriptor &input_c_desc, 1760 const DeviceMemory<Eigen::half> &input_c_data, 1761 const DeviceMemory<Eigen::half> ¶ms, 1762 const dnn::RnnSequenceTensorDescriptor &output_desc, 1763 DeviceMemory<Eigen::half> *output_data, 1764 const dnn::RnnStateTensorDescriptor &output_h_desc, 1765 DeviceMemory<Eigen::half> *output_h_data, 1766 const dnn::RnnStateTensorDescriptor &output_c_desc, 1767 DeviceMemory<Eigen::half> *output_c_data, 1768 bool is_training, 1769 ScratchAllocator *reserve_space_allocator, 1770 ScratchAllocator *workspace_allocator); 1771 1772 Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, 1773 const dnn::RnnSequenceTensorDescriptor &input_desc, 1774 const DeviceMemory<float> &input_data, 1775 const dnn::RnnStateTensorDescriptor &input_h_desc, 1776 const DeviceMemory<float> &input_h_data, 1777 const dnn::RnnStateTensorDescriptor &input_c_desc, 1778 const DeviceMemory<float> &input_c_data, 1779 const DeviceMemory<float> ¶ms, 1780 const dnn::RnnSequenceTensorDescriptor &output_desc, 1781 DeviceMemory<float> *output_data, 1782 const dnn::RnnStateTensorDescriptor &output_h_desc, 1783 DeviceMemory<float> *output_h_data, 1784 const dnn::RnnStateTensorDescriptor &output_c_desc, 1785 DeviceMemory<float> *output_c_data, bool is_training, 1786 ScratchAllocator *reserve_space_allocator, 1787 ScratchAllocator *workspace_allocator); 1788 1789 Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc, 1790 const dnn::RnnSequenceTensorDescriptor &input_desc, 1791 const DeviceMemory<double> &input_data, 1792 const dnn::RnnStateTensorDescriptor &input_h_desc, 1793 const DeviceMemory<double> &input_h_data, 1794 const dnn::RnnStateTensorDescriptor &input_c_desc, 
1795 const DeviceMemory<double> &input_c_data, 1796 const DeviceMemory<double> ¶ms, 1797 const dnn::RnnSequenceTensorDescriptor &output_desc, 1798 DeviceMemory<double> *output_data, 1799 const dnn::RnnStateTensorDescriptor &output_h_desc, 1800 DeviceMemory<double> *output_h_data, 1801 const dnn::RnnStateTensorDescriptor &output_c_desc, 1802 DeviceMemory<double> *output_c_data, bool is_training, 1803 ScratchAllocator *reserve_space_allocator, 1804 ScratchAllocator *workspace_allocator); 1805 1806 // Enqueue a backward operation of the RNN model onto the stream. 1807 // See DnnSupport::DoRnnBackward for more details. 1808 Stream &ThenRnnBackward( 1809 const dnn::RnnDescriptor &rnn_desc, 1810 const dnn::RnnSequenceTensorDescriptor &input_desc, 1811 const DeviceMemory<Eigen::half> &input_data, 1812 const dnn::RnnStateTensorDescriptor &input_h_desc, 1813 const DeviceMemory<Eigen::half> &input_h_data, 1814 const dnn::RnnStateTensorDescriptor &input_c_desc, 1815 const DeviceMemory<Eigen::half> &input_c_data, 1816 const DeviceMemory<Eigen::half> ¶ms, 1817 const dnn::RnnSequenceTensorDescriptor &output_desc, 1818 const DeviceMemory<Eigen::half> &output_data, 1819 const dnn::RnnStateTensorDescriptor &output_h_desc, 1820 const DeviceMemory<Eigen::half> &output_h_data, 1821 const dnn::RnnStateTensorDescriptor &output_c_desc, 1822 const DeviceMemory<Eigen::half> &output_c_data, 1823 const DeviceMemory<Eigen::half> &output_backprop_data, 1824 const DeviceMemory<Eigen::half> &output_h_backprop_data, 1825 const DeviceMemory<Eigen::half> &output_c_backprop_data, 1826 DeviceMemory<Eigen::half> *input_backprop_data, 1827 DeviceMemory<Eigen::half> *input_h_backprop_data, 1828 DeviceMemory<Eigen::half> *input_c_backprop_data, 1829 DeviceMemory<Eigen::half> *params_backprop_data, 1830 DeviceMemory<uint8> *reserve_space_data, 1831 ScratchAllocator *workspace_allocator); 1832 1833 Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, 1834 const dnn::RnnSequenceTensorDescriptor 
&input_desc, 1835 const DeviceMemory<float> &input_data, 1836 const dnn::RnnStateTensorDescriptor &input_h_desc, 1837 const DeviceMemory<float> &input_h_data, 1838 const dnn::RnnStateTensorDescriptor &input_c_desc, 1839 const DeviceMemory<float> &input_c_data, 1840 const DeviceMemory<float> ¶ms, 1841 const dnn::RnnSequenceTensorDescriptor &output_desc, 1842 const DeviceMemory<float> &output_data, 1843 const dnn::RnnStateTensorDescriptor &output_h_desc, 1844 const DeviceMemory<float> &output_h_data, 1845 const dnn::RnnStateTensorDescriptor &output_c_desc, 1846 const DeviceMemory<float> &output_c_data, 1847 const DeviceMemory<float> &output_backprop_data, 1848 const DeviceMemory<float> &output_h_backprop_data, 1849 const DeviceMemory<float> &output_c_backprop_data, 1850 DeviceMemory<float> *input_backprop_data, 1851 DeviceMemory<float> *input_h_backprop_data, 1852 DeviceMemory<float> *input_c_backprop_data, 1853 DeviceMemory<float> *params_backprop_data, 1854 DeviceMemory<uint8> *reserve_space_data, 1855 ScratchAllocator *workspace_allocator); 1856 1857 Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc, 1858 const dnn::RnnSequenceTensorDescriptor &input_desc, 1859 const DeviceMemory<double> &input_data, 1860 const dnn::RnnStateTensorDescriptor &input_h_desc, 1861 const DeviceMemory<double> &input_h_data, 1862 const dnn::RnnStateTensorDescriptor &input_c_desc, 1863 const DeviceMemory<double> &input_c_data, 1864 const DeviceMemory<double> ¶ms, 1865 const dnn::RnnSequenceTensorDescriptor &output_desc, 1866 const DeviceMemory<double> &output_data, 1867 const dnn::RnnStateTensorDescriptor &output_h_desc, 1868 const DeviceMemory<double> &output_h_data, 1869 const dnn::RnnStateTensorDescriptor &output_c_desc, 1870 const DeviceMemory<double> &output_c_data, 1871 const DeviceMemory<double> &output_backprop_data, 1872 const DeviceMemory<double> &output_h_backprop_data, 1873 const DeviceMemory<double> &output_c_backprop_data, 1874 DeviceMemory<double> 
*input_backprop_data, 1875 DeviceMemory<double> *input_h_backprop_data, 1876 DeviceMemory<double> *input_c_backprop_data, 1877 DeviceMemory<double> *params_backprop_data, 1878 DeviceMemory<uint8> *reserve_space_data, 1879 ScratchAllocator *workspace_allocator); 1880 1881 // Enqueue onto the stream a operation that transforms a tensor. 1882 // See DnnSupport::DoTransformTensor for more details. 1883 Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc, 1884 dnn::DataType input_type, 1885 const DeviceMemoryBase &input_data, 1886 const dnn::BatchDescriptor &output_desc, 1887 dnn::DataType output_type, float scale, 1888 DeviceMemoryBase *output_data); 1889 1890 // The templated version of the above ThenTransformTensor. Useful when the 1891 // input and output types are statically known. 1892 template <typename InElemT, typename OutElemT> 1893 Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc, 1894 const DeviceMemory<InElemT> &input_data, 1895 const dnn::BatchDescriptor &output_desc, 1896 DeviceMemory<OutElemT> *output_data) { 1897 return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>(), 1898 input_data, output_desc, 1899 dnn::ToDataType<OutElemT>(), output_data); 1900 } 1901 1902 // (Synchronously) block the host code waiting for the operations 1903 // entrained on the stream (enqueued to this point in program 1904 // execution) to complete. 1905 // 1906 // Returns an OK status if the blocking was successful and the stream is ok(). 1907 // Otherwise returns an error describing why the blocking failed. 1908 port::Status BlockHostUntilDone() LOCKS_EXCLUDED(mu_); 1909 1910 // Warning! This method interacts with internal threads in 1911 // sometimes-unpredictable ways and is intended for GPU-Executor-internal 1912 // use 1913 // only. Please check with a member of the FASTR team before making use of 1914 // this method. 
1915 // 1916 // Entrains onto the stream a function to be executed on the host at some 1917 // point in the future. 1918 // Async host callbacks DO NOT block the stream as device functions (or as 1919 // synchronous host callbacks). No synchronization is possible with 1920 // asynchronous callbacks; they are strictly fire-and-forget. 1921 // This method is private due to the potential for undefined behavior with 1922 // synchronization using OpenCL user events. 1923 // The ONLY lifetime guarantee in these calls is that the StreamExecutor 1924 // parameter will still be valid - this Stream may not be! 1925 // Any callbacks requiring device API calls must use this method. 1926 Stream &ThenEnqueueOnBackgroundThread( 1927 std::function<void(StreamExecutor *)> task); 1928 1929 // Returns the (opaque) platform-specific backing object. Ownership is not 1930 // transferred to the caller. 1931 internal::StreamInterface *implementation() { return implementation_.get(); } 1932 1933 // Entrains onto the stream a callback to the host (from the device). 1934 // Host callbacks block/occupy the stream just as device functions 1935 // (execute one at a time, block later stream operations). 1936 // Behavior is undefined when synchronizing using OpenCL user events. 1937 // Behavior is undefined if host callbacks call device routines or insert 1938 // them into any stream. 1939 // On certain platforms, ThenDoHostCallback is expected to have significant 1940 // negative effects on performance. 1941 Stream &ThenDoHostCallback(std::function<void()> callback); 1942 1943 // Identical to ThenDoHostCallback; only exposed for testing purposes. 1944 Stream &ThenDoHostCallbackForTest(std::function<void()> callback); 1945 1946 // Returns the StreamExecutor (parent object) associated with this stream. 
1947 StreamExecutor *parent() const { 1948 CHECK(parent_ != nullptr); 1949 return parent_; 1950 } 1951 1952 // Returns the (internal usage) temporary-memory-allocation manager associated 1953 // with this stream. 1954 internal::TemporaryMemoryManager *temporary_memory_manager(); 1955 1956 private: 1957 friend class host::HostBlas; // for parent_. 1958 friend class host::HostFft; // for parent_. 1959 friend class host::HostRng; // for parent_. 1960 template <typename... Args> 1961 friend struct ThenBlasImpl; // for implementing ThenBlasXXX. 1962 friend class ocl::CLBlas; // for parent_. 1963 1964 bool InErrorState() const LOCKS_EXCLUDED(mu_) { 1965 tf_shared_lock lock{mu_}; 1966 return !ok_; 1967 } 1968 1969 // Sets the error state if operation_retcode is false. 1970 // This is a useful shorthand for many stream routines. 1971 void CheckError(bool operation_retcode) LOCKS_EXCLUDED(mu_) { 1972 if (operation_retcode) { 1973 return; 1974 } 1975 mutex_lock lock{mu_}; 1976 ok_ = false; 1977 } 1978 1979 void SetError() { CheckError(false /* = operation_retcode */); } 1980 1981 void SetErrorAndLogNoDnnSupport() { 1982 SetError(); 1983 LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor " 1984 "without DNN support"; 1985 } 1986 1987 // The StreamExecutor that supports the operation of this stream. 1988 StreamExecutor *parent_; 1989 1990 // The platform-dependent implementation that the StreamExecutor interface 1991 // delegates to. 1992 std::unique_ptr<internal::StreamInterface> implementation_; 1993 1994 // mutex that guards the allocation / error state flags. 1995 // Mutable so that it can be obtained via const reader lock. 1996 mutable mutex mu_; 1997 1998 // Whether Init() was successfully called to allocate this stream on the 1999 // underlying platform. It simply flips from 0 to 1 with a sanity check. 2000 // See StreamExecutor::AllocateStream. 
2001 bool allocated_ GUARDED_BY(mu_); 2002 2003 // Whether all operations have entrained successfully to the current program 2004 // point. 2005 bool ok_ GUARDED_BY(mu_); 2006 2007 // Sub-streams that are generated from this stream. Each element has a pointer 2008 // to sub-stream and a boolean value indicating if this substream is ready to 2009 // be reused. 2010 std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_ 2011 GUARDED_BY(mu_); 2012 2013 // Streams can allocate temporary memories to help with work they enqueue 2014 // (e.g. for scratch memory spaces). This member tracks those allocations and 2015 // notes when they can be reclaimed -- reclamation is attempted when 2016 // BlockHostUntilDone() is called. 2017 internal::TemporaryMemoryManager temporary_memory_manager_; 2018 2019 // Implementation of ThenConvolveBackwardBias that is shared by all types. 2020 template <typename T> 2021 Stream &ThenConvolveBackwardBiasImpl( 2022 const dnn::BatchDescriptor &input_descriptor, 2023 const DeviceMemory<T> &input_data, 2024 const dnn::BatchDescriptor &bias_descriptor, 2025 DeviceMemory<T> *backward_bias_data); 2026 2027 SE_DISALLOW_COPY_AND_ASSIGN(Stream); 2028 }; 2029 2030 //////////// 2031 // Inlines 2032 2033 template <typename T> 2034 inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>> 2035 Stream::AllocateTemporaryArray(uint64 element_count) { 2036 return temporary_memory_manager_.AllocateArray<T>(element_count); 2037 } 2038 2039 inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() { 2040 return &temporary_memory_manager_; 2041 } 2042 2043 template <> 2044 struct Quantization<uint8> { 2045 static constexpr dnn::QuantizedActivationMode kModeId = 2046 dnn::QuantizedActivationMode::k8Bit; 2047 }; 2048 2049 template <> 2050 struct Quantization<uint16> { 2051 static constexpr dnn::QuantizedActivationMode kModeId = 2052 dnn::QuantizedActivationMode::k16Bit; 2053 }; 2054 2055 template <> 2056 struct Quantization<int32> 
{ 2057 static constexpr dnn::QuantizedActivationMode kModeId = 2058 dnn::QuantizedActivationMode::k32Bit; 2059 }; 2060 2061 } // namespace gputools 2062 } // namespace perftools 2063 2064 #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_H_ 2065