Home | History | Annotate | Download | only in stream_executor
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // The Stream is used in conjunction with the StreamExecutor "parent" to
     17 // perform actions with a linear stream of dependencies. Dependencies can also
     18 // be created between Streams to do task management (i.e. limit which tasks
     19 // can be performed concurrently and specify what task dependencies exist).
     20 
     21 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
     22 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
     23 
     24 #include <complex>
     25 #include <functional>
     26 #include <memory>
     27 
     28 #include "tensorflow/stream_executor/blas.h"
     29 #include "tensorflow/stream_executor/device_memory.h"
     30 #include "tensorflow/stream_executor/dnn.h"
     31 #include "tensorflow/stream_executor/event.h"
     32 #include "tensorflow/stream_executor/fft.h"
     33 #include "tensorflow/stream_executor/kernel.h"
     34 #include "tensorflow/stream_executor/launch_dim.h"
     35 #include "tensorflow/stream_executor/lib/array_slice.h"
     36 #include "tensorflow/stream_executor/platform/mutex.h"
     37 #include "tensorflow/stream_executor/platform/port.h"
     38 #include "tensorflow/stream_executor/platform/thread_annotations.h"
     39 #include "tensorflow/stream_executor/temporary_memory_manager.h"
     40 
     41 namespace perftools {
     42 namespace gputools {
     43 
     44 namespace host {
     45 class HostBlas;
     46 class HostFft;
     47 class HostRng;
     48 class HostTimer;
     49 }  // namespace host
     50 
     51 namespace ocl {
     52 class CLBlas;
     53 }  // namespace ocl
     54 
     55 namespace internal {
     56 class StreamInterface;
     57 }  // namespace internal
     58 
     59 class DeviceMemoryBase;
     60 template <typename ElemT>
     61 class DeviceMemory;
     62 
     63 class Timer;
     64 
     65 namespace dnn {
     66 class BatchDescriptor;
     67 class FilterDescriptor;
     68 class ConvolutionDescriptor;
     69 class BatchDescriptor;
     70 class FilterDescriptor;
     71 class ConvolutionDescriptor;
     72 class ProfileResult;
     73 class AlgorithmDesc;
     74 }  // namespace dnn
     75 
     76 class StreamExecutor;
     77 class ScratchAllocator;
     78 
     79 // Convert a type to the corresponding QuantizedActivationMode.
     80 template <typename ElementType>
     81 struct Quantization;
     82 
     83 // Represents a stream of dependent computations on a GPU device.
     84 //
     85 // The operations within a stream execute linearly and asynchronously until
     86 // BlockHostUntilDone() is invoked, which synchronously joins host code with
     87 // the execution of the stream.
     88 //
     89 // If any given operation fails when entraining work for the stream, ok() will
     90 // indicate that an error has occurred. After initialization, once a stream is
     91 // !ok(), it will never be ok().
     92 //
     93 // Thread-safe post-initialization.
     94 class Stream {
     95  public:
     96   // Instantiate a stream tied to parent as a platform executor. Work
     97   // entrained onto this stream will be launched/managed on that
     98   // StreamExecutor's platform.
     99   explicit Stream(StreamExecutor *parent);
    100 
    101   // Test only. Use an externally-populated value (like a mock) for the
    102   // platform-specific stream implementation.
    103   Stream(StreamExecutor *parent, internal::StreamInterface *implementation);
    104 
  // Deallocates any stream resources that the parent StreamExecutor has
  // bestowed upon this object.
    108   ~Stream();
    109 
    110   // Returns whether any errors have occurred while entraining work for this
    111   // stream.
    112   bool ok() const { return !InErrorState(); }
    113 
    114   // Initialize the stream. This must be performed before entraining any other
    115   // operations.
    116   Stream &Init() LOCKS_EXCLUDED(mu_);
    117 
    118   // Initializes timer t via the StreamExecutor.
    119   Stream &InitTimer(Timer *t);
    120 
    121   // Convenience wrapper around Init() and InitTimer().
    122   Stream &InitWithTimer(Timer *t);
    123 
    124   // Get or create a sub-stream from this stream. If there is any sub-stream in
    125   // the pool that can be reused then just return this sub-stream.  Otherwise
    126   // create a new sub-stream.
    127   Stream *GetOrCreateSubStream() LOCKS_EXCLUDED(mu_);
    128 
    129   // Return the sub-stream back to the host stream so that it can be reused
    130   // later.
    131   void ReturnSubStream(Stream *sub_stream) LOCKS_EXCLUDED(mu_);
    132 
    133   // Allocate temporary memories. The stream will deallocate them when blocked
    134   // or destroyed.
    135   template <typename T>
    136   port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
    137   AllocateTemporaryArray(uint64 element_count);
    138 
    139   // Entrains onto the stream of operations: a kernel launch with the given
    140   // (variadic) parameters for the invocation. These arguments can be things
    141   // like DeviceMemory or primitive types such as int. What arguments you may
    142   // pass to a given kernel are noted as the template parameters to the
    143   // TypedKernel type that the machocc compiler generates.
    144   //
    145   // Template parameters:
    146   //  Params...   The type list of formal parameters that the typed kernel
    147   //              expects, which is matched against Args...
    148   //  Args...     The deduced type list for passed actual arguments
    149   //
    150   // Implementation: A compile-time compatibility check is performed that has
    151   // some leniency versus an exact parameter pack match -- for example,
    152   // `const DeviceMemory<T>` is considered "pack compatible" with a
    153   // `const DeviceMemory<T>&` formal parameter; in part, because we don't have
    154   // perfect forwarding support without rvalue references. It also attempts to
    155   // spit out helpful static_assert error traces with information as to the
    156   // argument number and types that were mismatched.
    157   template <typename... Params, typename... Args>
    158   Stream &ThenLaunch(ThreadDim thread_dims, BlockDim block_dims,
    159                      const TypedKernel<Params...> &kernel, Args... args);
    160 
  // Record a "start" event for the interval timer at this point in the
  // stream's execution (relative to the previously and subsequently enqueued
  // items in the stream's execution). Streams may be started/stopped multiple
  // times.
    165   Stream &ThenStartTimer(Timer *t);
    166 
  // Record a "stop" event for the interval timer at this point in the
  // stream's execution. See also Stream::ThenStartTimer.
    170   Stream &ThenStopTimer(Timer *t);
    171 
    172   // TODO(leary) If work is added to the stream that is being depended upon,
    173   //              then what? Have to describe what happens.
    174   template <typename... Params>
    175   Stream &ThenWaitFor(Stream *other, Params... more_streams) {
    176     return ThenWaitFor(more_streams...).ThenWaitFor(other);
    177   }
    178 
    179   // Create a dependency for this stream's next work on the other stream
    180   // completing. Does not take ownership of other, and other must not be
    181   // null.
    182   //
  // Checks that a stream does not wait for itself, and it is up to the
  // user to guarantee that a stream does not come to wait on itself in a
  // cyclic manner; in that case, behavior is undefined.
    187   //
    188   // N.B. Base recursion case for the variadic ThenWaitFor.
    189   Stream &ThenWaitFor(Stream *other);
    190 
    191   // Waits for all streams values in others.
    192   // Checks that there is no shallow circular wait (i.e. that "this" is not in
    193   // others)
    194   template <typename P>
    195   Stream &ThenWaitFor(P others) {
    196     for (auto &stream : *others) {
    197       CHECK_NE(stream.get(), this);
    198       ThenWaitFor(stream.get());
    199     }
    200     return *this;
    201   }
    202 
    203   // Waits for an event object to be set.
    204   // Note that ThenRecordEvent must have been called on the event before
    205   // you call this function; otherwise the event will be considered complete
    206   // and this wait will do nothing.
    207   Stream &ThenWaitFor(Event *event);
    208 
    209   // Inserts the specified event into the end of this stream. Once the stream
    210   // has processed all events prior to the insertion point, the event will be
    211   // marked as completed.
    212   // The stream does not take ownership of event - meaning that event's lifetime
    213   // must extend past the point at which it is marked complete!
    214   Stream &ThenRecordEvent(Event *event);
    215 
    216   ////////////////
    217   // DNN support
    218   //
    219   // See DnnSupport::* for comments on the following methods.
    220 
    221   Stream &ThenBatchNormalizationForward(
    222       const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
    223       const DeviceMemory<float> &offset,
    224       const DeviceMemory<float> &estimated_mean,
    225       const DeviceMemory<float> &estimated_variance,
    226       const dnn::BatchDescriptor &x_desc,
    227       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    228       DeviceMemory<float> *y, DeviceMemory<float> *batch_mean,
    229       DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
    230       DeviceMemory<float> *saved_inv_var, bool is_training,
    231       std::function<const DeviceMemory<float> &()> var_to_inv_var,
    232       std::function<void()> inv_var_to_var);
    233 
    234   Stream &ThenBatchNormalizationBackward(
    235       const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
    236       const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
    237       const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
    238       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    239       DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
    240       DeviceMemory<float> *offset_backprop);
    241 
    242   Stream &ThenBatchNormalizationForward(
    243       const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    244       const DeviceMemory<float> &offset,
    245       const DeviceMemory<float> &estimated_mean,
    246       const DeviceMemory<float> &estimated_variance,
    247       const dnn::BatchDescriptor &x_desc,
    248       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    249       DeviceMemory<Eigen::half> *y, DeviceMemory<float> *batch_mean,
    250       DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
    251       DeviceMemory<float> *saved_inv_var, bool is_training,
    252       std::function<const DeviceMemory<float> &()> var_to_inv_var,
    253       std::function<void()> inv_var_to_var);
    254 
    255   Stream &ThenBatchNormalizationBackward(
    256       const DeviceMemory<Eigen::half> &y_backprop,
    257       const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    258       const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
    259       const dnn::BatchDescriptor &x_desc,
    260       const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    261       DeviceMemory<Eigen::half> *x_backprop,
    262       DeviceMemory<float> *scale_backprop,
    263       DeviceMemory<float> *offset_backprop);
    264 
    265   // TODO(leary) add double-precision version of this interface.
    266   Stream &ThenFusedConvolve(
    267       const dnn::BatchDescriptor &conv_input_descriptor,
    268       const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
    269       const dnn::FilterDescriptor &filter_descriptor,
    270       const DeviceMemory<int8> &filter_data,
    271       const dnn::ConvolutionDescriptor &convolution_descriptor,
    272       const DeviceMemory<int8> &side_input_data, float side_input_scale,
    273       const dnn::BatchDescriptor &bias_descriptor,
    274       const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    275       const dnn::BatchDescriptor &output_descriptor,
    276       DeviceMemory<int8> *output);
    277 
    278   Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
    279                        const DeviceMemory<float> &input_data,
    280                        const dnn::FilterDescriptor &filter_descriptor,
    281                        const DeviceMemory<float> &filter_data,
    282                        const dnn::ConvolutionDescriptor &convolution_descriptor,
    283                        const dnn::BatchDescriptor &output_descriptor,
    284                        DeviceMemory<float> *output);
    285 
    286   Stream &ThenConvolveQuantized(
    287       const dnn::BatchDescriptor &input_descriptor,
    288       const DeviceMemory<float> &input_data,
    289       const dnn::FilterDescriptor &filter_descriptor,
    290       const DeviceMemory<int8> &filter_coefficients,
    291       const DeviceMemory<float> &coefficient_scales,
    292       const dnn::ConvolutionDescriptor &convolution_descriptor,
    293       const dnn::BatchDescriptor &output_descriptor,
    294       DeviceMemory<float> *output_data);
    295 
    296   Stream &ThenConvolveQuantized(
    297       const dnn::BatchDescriptor &input_descriptor,
    298       const DeviceMemory<float> &input_data,
    299       const dnn::FilterDescriptor &filter_descriptor,
    300       const DeviceMemory<int16> &filter_coefficients,
    301       const DeviceMemory<float> &coefficient_scales,
    302       const dnn::ConvolutionDescriptor &convolution_descriptor,
    303       const dnn::BatchDescriptor &output_descriptor,
    304       DeviceMemory<float> *output_data);
    305 
    306   Stream &ThenFusedConvolveWithScratch(
    307       const dnn::BatchDescriptor &conv_input_descriptor,
    308       const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
    309       const dnn::FilterDescriptor &filter_descriptor,
    310       const DeviceMemory<int8> &filter_data,
    311       const dnn::ConvolutionDescriptor &convolution_descriptor,
    312       const DeviceMemory<int8> &side_input_data, float side_input_scale,
    313       const dnn::BatchDescriptor &bias_descriptor,
    314       const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    315       const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
    316       ScratchAllocator *scratch_allocator);
    317 
    318   Stream &ThenFusedConvolveWithScratch(
    319       const dnn::BatchDescriptor &conv_input_descriptor,
    320       const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
    321       const dnn::FilterDescriptor &filter_descriptor,
    322       const DeviceMemory<Eigen::half> &filter_data,
    323       const dnn::ConvolutionDescriptor &convolution_descriptor,
    324       const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
    325       const dnn::BatchDescriptor &bias_descriptor,
    326       const DeviceMemory<Eigen::half> &biases,
    327       dnn::ActivationMode activation_mode,
    328       const dnn::BatchDescriptor &output_descriptor,
    329       DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator);
    330 
    331   Stream &ThenFusedConvolveWithScratch(
    332       const dnn::BatchDescriptor &conv_input_descriptor,
    333       const DeviceMemory<float> &conv_input_data, float conv_input_scale,
    334       const dnn::FilterDescriptor &filter_descriptor,
    335       const DeviceMemory<float> &filter_data,
    336       const dnn::ConvolutionDescriptor &convolution_descriptor,
    337       const DeviceMemory<float> &side_input_data, float side_input_scale,
    338       const dnn::BatchDescriptor &bias_descriptor,
    339       const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    340       const dnn::BatchDescriptor &output_descriptor,
    341       DeviceMemory<float> *output, ScratchAllocator *scratch_allocator);
    342 
    343   Stream &ThenConvolveWithScratch(
    344       const dnn::BatchDescriptor &input_descriptor,
    345       const DeviceMemory<Eigen::half> &input_data,
    346       const dnn::FilterDescriptor &filter_descriptor,
    347       const DeviceMemory<Eigen::half> &filter_data,
    348       const dnn::ConvolutionDescriptor &convolution_descriptor,
    349       const dnn::BatchDescriptor &output_descriptor,
    350       DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator);
    351 
    352   Stream &ThenConvolveWithScratch(
    353       const dnn::BatchDescriptor &input_descriptor,
    354       const DeviceMemory<float> &input_data,
    355       const dnn::FilterDescriptor &filter_descriptor,
    356       const DeviceMemory<float> &filter_data,
    357       const dnn::ConvolutionDescriptor &convolution_descriptor,
    358       const dnn::BatchDescriptor &output_descriptor,
    359       DeviceMemory<float> *output, ScratchAllocator *scratch_allocator);
    360 
    361   Stream &ThenConvolveWithAlgorithm(
    362       const dnn::BatchDescriptor &input_descriptor,
    363       const DeviceMemory<float> &input_data,
    364       const dnn::FilterDescriptor &filter_descriptor,
    365       const DeviceMemory<float> &filter_data,
    366       const dnn::ConvolutionDescriptor &convolution_descriptor,
    367       const dnn::BatchDescriptor &output_descriptor,
    368       DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
    369       const dnn::AlgorithmConfig &algorithm_config,
    370       dnn::ProfileResult *output_profile_result);
    371 
    372   Stream &ThenConvolveWithAlgorithm(
    373       const dnn::BatchDescriptor &input_descriptor,
    374       const DeviceMemory<Eigen::half> &input_data,
    375       const dnn::FilterDescriptor &filter_descriptor,
    376       const DeviceMemory<Eigen::half> &filter_data,
    377       const dnn::ConvolutionDescriptor &convolution_descriptor,
    378       const dnn::BatchDescriptor &output_descriptor,
    379       DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
    380       const dnn::AlgorithmConfig &algorithm_config,
    381       dnn::ProfileResult *output_profile_result);
    382 
    383   Stream &ThenFusedConvolveWithAlgorithm(
    384       const dnn::BatchDescriptor &conv_input_descriptor,
    385       const DeviceMemory<double> &conv_input_data, double conv_input_scale,
    386       const dnn::FilterDescriptor &filter_descriptor,
    387       const DeviceMemory<double> &filter_data,
    388       const dnn::ConvolutionDescriptor &convolution_descriptor,
    389       const DeviceMemory<double> &side_input_data, double side_input_scale,
    390       const dnn::BatchDescriptor &bias_descriptor,
    391       const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
    392       const dnn::BatchDescriptor &output_descriptor,
    393       DeviceMemory<double> *output, ScratchAllocator *scratch_allocator,
    394       const dnn::AlgorithmConfig &algorithm_config,
    395       dnn::ProfileResult *output_profile_result);
    396 
    397   Stream &ThenFusedConvolveWithAlgorithm(
    398       const dnn::BatchDescriptor &conv_input_descriptor,
    399       const DeviceMemory<float> &conv_input_data, float conv_input_scale,
    400       const dnn::FilterDescriptor &filter_descriptor,
    401       const DeviceMemory<float> &filter_data,
    402       const dnn::ConvolutionDescriptor &convolution_descriptor,
    403       const DeviceMemory<float> &side_input_data, float side_input_scale,
    404       const dnn::BatchDescriptor &bias_descriptor,
    405       const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    406       const dnn::BatchDescriptor &output_descriptor,
    407       DeviceMemory<float> *output, ScratchAllocator *scratch_allocator,
    408       const dnn::AlgorithmConfig &algorithm_config,
    409       dnn::ProfileResult *output_profile_result);
    410 
    411   Stream &ThenFusedConvolveWithAlgorithm(
    412       const dnn::BatchDescriptor &conv_input_descriptor,
    413       const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
    414       const dnn::FilterDescriptor &filter_descriptor,
    415       const DeviceMemory<Eigen::half> &filter_data,
    416       const dnn::ConvolutionDescriptor &convolution_descriptor,
    417       const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
    418       const dnn::BatchDescriptor &bias_descriptor,
    419       const DeviceMemory<Eigen::half> &biases,
    420       dnn::ActivationMode activation_mode,
    421       const dnn::BatchDescriptor &output_descriptor,
    422       DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
    423       const dnn::AlgorithmConfig &algorithm_config,
    424       dnn::ProfileResult *output_profile_result);
    425 
    426   Stream &ThenFusedConvolveWithAlgorithm(
    427       const dnn::BatchDescriptor &conv_input_descriptor,
    428       const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
    429       const dnn::FilterDescriptor &filter_descriptor,
    430       const DeviceMemory<int8> &filter_data,
    431       const dnn::ConvolutionDescriptor &convolution_descriptor,
    432       const DeviceMemory<int8> &side_input_data, float side_input_scale,
    433       const dnn::BatchDescriptor &bias_descriptor,
    434       const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    435       const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
    436       ScratchAllocator *scratch_allocator,
    437       const dnn::AlgorithmConfig &algorithm_config,
    438       dnn::ProfileResult *output_profile_result);
    439 
    440   Stream &ThenSeparableConvolve(
    441       const dnn::BatchDescriptor &input_descriptor,
    442       const DeviceMemory<float> &input_data,
    443       const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
    444       const DeviceMemory<float> &first_weights,
    445       const DeviceMemory<float> &second_weights,
    446       const dnn::ConvolutionDescriptor &convolution_descriptor,
    447       const dnn::BatchDescriptor &output_descriptor,
    448       DeviceMemory<float> *output);
    449 
    450   Stream &ThenConvolveBackwardData(
    451       const dnn::FilterDescriptor &filter_descriptor,
    452       const DeviceMemory<float> &filter_data,
    453       const dnn::BatchDescriptor &output_descriptor,
    454       DeviceMemory<float> backward_output_data,
    455       const dnn::ConvolutionDescriptor &convolution_descriptor,
    456       const dnn::BatchDescriptor &input_descriptor,
    457       DeviceMemory<float> *backward_input_data);
    458 
    459   Stream &ThenConvolveBackwardDataWithScratch(
    460       const dnn::FilterDescriptor &filter_descriptor,
    461       const DeviceMemory<float> &filter_data,
    462       const dnn::BatchDescriptor &output_descriptor,
    463       DeviceMemory<float> backward_output_data,
    464       const dnn::ConvolutionDescriptor &convolution_descriptor,
    465       const dnn::BatchDescriptor &input_descriptor,
    466       DeviceMemory<float> *backward_input_data,
    467       ScratchAllocator *scratch_allocator);
    468 
    469   Stream &ThenConvolveBackwardDataWithScratch(
    470       const dnn::FilterDescriptor &filter_descriptor,
    471       const DeviceMemory<Eigen::half> &filter_data,
    472       const dnn::BatchDescriptor &output_descriptor,
    473       DeviceMemory<Eigen::half> backward_output_data,
    474       const dnn::ConvolutionDescriptor &convolution_descriptor,
    475       const dnn::BatchDescriptor &input_descriptor,
    476       DeviceMemory<Eigen::half> *backward_input_data,
    477       ScratchAllocator *scratch_allocator);
    478 
    479   Stream &ThenConvolveBackwardDataWithAlgorithm(
    480       const dnn::FilterDescriptor &filter_descriptor,
    481       const DeviceMemory<float> &filter_data,
    482       const dnn::BatchDescriptor &output_descriptor,
    483       DeviceMemory<float> backward_output_data,
    484       const dnn::ConvolutionDescriptor &convolution_descriptor,
    485       const dnn::BatchDescriptor &input_descriptor,
    486       DeviceMemory<float> *backward_input_data,
    487       ScratchAllocator *scratch_allocator,
    488       const dnn::AlgorithmConfig &algorithm_config,
    489       dnn::ProfileResult *output_profile_result);
    490 
    491   Stream &ThenConvolveBackwardDataWithAlgorithm(
    492       const dnn::FilterDescriptor &filter_descriptor,
    493       const DeviceMemory<Eigen::half> &filter_data,
    494       const dnn::BatchDescriptor &output_descriptor,
    495       DeviceMemory<Eigen::half> backward_output_data,
    496       const dnn::ConvolutionDescriptor &convolution_descriptor,
    497       const dnn::BatchDescriptor &input_descriptor,
    498       DeviceMemory<Eigen::half> *backward_input_data,
    499       ScratchAllocator *scratch_allocator,
    500       const dnn::AlgorithmConfig &algorithm_config,
    501       dnn::ProfileResult *output_profile_result);
    502 
    503   Stream &ThenConvolveBackwardFilter(
    504       const dnn::BatchDescriptor &input_descriptor,
    505       const DeviceMemory<float> &input_data,
    506       const dnn::BatchDescriptor &output_descriptor,
    507       DeviceMemory<float> backward_output_data,
    508       const dnn::ConvolutionDescriptor &convolution_descriptor,
    509       const dnn::FilterDescriptor &filter_descriptor,
    510       DeviceMemory<float> *backward_filter_data);
    511 
    512   Stream &ThenConvolveBackwardFilterWithScratch(
    513       const dnn::BatchDescriptor &input_descriptor,
    514       const DeviceMemory<float> &input_data,
    515       const dnn::BatchDescriptor &output_descriptor,
    516       DeviceMemory<float> backward_output_data,
    517       const dnn::ConvolutionDescriptor &convolution_descriptor,
    518       const dnn::FilterDescriptor &filter_descriptor,
    519       DeviceMemory<float> *backward_filter_data,
    520       ScratchAllocator *scratch_allocator);
    521 
    522   Stream &ThenConvolveBackwardFilterWithScratch(
    523       const dnn::BatchDescriptor &input_descriptor,
    524       const DeviceMemory<Eigen::half> &input_data,
    525       const dnn::BatchDescriptor &output_descriptor,
    526       DeviceMemory<Eigen::half> backward_output_data,
    527       const dnn::ConvolutionDescriptor &convolution_descriptor,
    528       const dnn::FilterDescriptor &filter_descriptor,
    529       DeviceMemory<Eigen::half> *backward_filter_data,
    530       ScratchAllocator *scratch_allocator);
    531 
    532   Stream &ThenConvolveBackwardFilterWithAlgorithm(
    533       const dnn::BatchDescriptor &input_descriptor,
    534       const DeviceMemory<float> &input_data,
    535       const dnn::BatchDescriptor &output_descriptor,
    536       DeviceMemory<float> backward_output_data,
    537       const dnn::ConvolutionDescriptor &convolution_descriptor,
    538       const dnn::FilterDescriptor &filter_descriptor,
    539       DeviceMemory<float> *backward_filter_data,
    540       ScratchAllocator *scratch_allocator,
    541       const dnn::AlgorithmConfig &algorithm_config,
    542       dnn::ProfileResult *output_profile_result);
    543 
    544   Stream &ThenConvolveBackwardFilterWithAlgorithm(
    545       const dnn::BatchDescriptor &input_descriptor,
    546       const DeviceMemory<Eigen::half> &input_data,
    547       const dnn::BatchDescriptor &output_descriptor,
    548       DeviceMemory<Eigen::half> backward_output_data,
    549       const dnn::ConvolutionDescriptor &convolution_descriptor,
    550       const dnn::FilterDescriptor &filter_descriptor,
    551       DeviceMemory<Eigen::half> *backward_filter_data,
    552       ScratchAllocator *scratch_allocator,
    553       const dnn::AlgorithmConfig &algorithm_config,
    554       dnn::ProfileResult *output_profile_result);
    555 
    556   Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
    557                                    const DeviceMemory<double> &input_data,
    558                                    const dnn::BatchDescriptor &bias_descriptor,
    559                                    DeviceMemory<double> *backward_bias_data);
    560 
    561   Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
    562                                    const DeviceMemory<float> &input_data,
    563                                    const dnn::BatchDescriptor &bias_descriptor,
    564                                    DeviceMemory<float> *backward_bias_data);
    565 
    566   Stream &ThenConvolveBackwardBias(
    567       const dnn::BatchDescriptor &input_descriptor,
    568       const DeviceMemory<Eigen::half> &input_data,
    569       const dnn::BatchDescriptor &bias_descriptor,
    570       DeviceMemory<Eigen::half> *backward_bias_data);
    571 
    572   Stream &ThenMatMul(const DeviceMemory<float> &input_data,
    573                      const DeviceMemory<float> &weights,
    574                      const dnn::BatchDescriptor &input_dimensions,
    575                      const dnn::BatchDescriptor &output_dimensions,
    576                      DeviceMemory<float> *output_data);
    577 
    578   Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
    579                               const DeviceMemory<int8> &weights,
    580                               const DeviceMemory<float> &weight_scales,
    581                               const dnn::BatchDescriptor &input_dimensions,
    582                               const dnn::BatchDescriptor &output_dimensions,
    583                               DeviceMemory<float> *output_data);
    584 
    585   Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
    586                               const DeviceMemory<int16> &weights,
    587                               const DeviceMemory<float> &weight_scales,
    588                               const dnn::BatchDescriptor &input_dimensions,
    589                               const dnn::BatchDescriptor &output_dimensions,
    590                               DeviceMemory<float> *output_data);
    591 
          // Float-only. input_data and biases are read-only; one descriptor
          // (`dimensions`) is supplied and output_data is the mutable buffer.
    592   Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
    593                       const DeviceMemory<float> &biases,
    594                       const dnn::BatchDescriptor &dimensions,
    595                       DeviceMemory<float> *output_data);
    596 
          // Three overloads, one per element type (double, float, Eigen::half);
          // the parameter lists are otherwise identical. output_data is the only
          // mutable argument.
    597   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
    598                           const dnn::BatchDescriptor &input_dimensions,
    599                           const DeviceMemory<double> &input_data,
    600                           const dnn::BatchDescriptor &output_dimensions,
    601                           DeviceMemory<double> *output_data);
    602 
    603   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
    604                           const dnn::BatchDescriptor &input_dimensions,
    605                           const DeviceMemory<float> &input_data,
    606                           const dnn::BatchDescriptor &output_dimensions,
    607                           DeviceMemory<float> *output_data);
    608 
    609   Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
    610                           const dnn::BatchDescriptor &input_dimensions,
    611                           const DeviceMemory<Eigen::half> &input_data,
    612                           const dnn::BatchDescriptor &output_dimensions,
    613                           DeviceMemory<Eigen::half> *output_data);
    614 
          // Three overloads, one per element type (double, float, Eigen::half).
          // output_diff_data is the only mutable argument.
          // NOTE(review): presumably input_diff_data is the incoming gradient
          // and output_diff_data the computed one - confirm against DnnSupport.
    615   Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
    616                            const dnn::BatchDescriptor &input_dimensions,
    617                            const DeviceMemory<double> &input_data,
    618                            const dnn::BatchDescriptor &output_dimensions,
    619                            const DeviceMemory<double> &output_data,
    620                            const DeviceMemory<double> &input_diff_data,
    621                            DeviceMemory<double> *output_diff_data);
    622 
    623   Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
    624                            const dnn::BatchDescriptor &input_dimensions,
    625                            const DeviceMemory<float> &input_data,
    626                            const dnn::BatchDescriptor &output_dimensions,
    627                            const DeviceMemory<float> &output_data,
    628                            const DeviceMemory<float> &input_diff_data,
    629                            DeviceMemory<float> *output_diff_data);
    630 
    631   Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
    632                            const dnn::BatchDescriptor &input_dimensions,
    633                            const DeviceMemory<Eigen::half> &input_data,
    634                            const dnn::BatchDescriptor &output_dimensions,
    635                            const DeviceMemory<Eigen::half> &output_data,
    636                            const DeviceMemory<Eigen::half> &input_diff_data,
    637                            DeviceMemory<Eigen::half> *output_diff_data);
    638 
          // Float-only. Configured solely by a NormalizeDescriptor; no
          // BatchDescriptor is passed, unlike ThenNormalizeWithDimensions.
    639   Stream &ThenNormalize(const dnn::NormalizeDescriptor &normalize_descriptor,
    640                         const DeviceMemory<float> &input_data,
    641                         DeviceMemory<float> *output_data);
    642 
    643   // Similar to ThenNormalize, but normalizes across feature maps and allows for
    644   // specifying the dimensions of the tensor.
          // Float-only; output_data is the only mutable argument.
    645   Stream &ThenNormalizeWithDimensions(
    646       const dnn::NormalizeDescriptor &normalize_descriptor,
    647       const dnn::BatchDescriptor &dimensions,
    648       const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data);
    649 
          // Backward counterpart (by name) of ThenNormalizeWithDimensions.
          // raw_data, normalized_data and normalized_variable_gradient are
          // read-only; raw_variable_gradient is the only mutable argument.
    650   Stream &ThenNormalizeBackwardWithDimensions(
    651       const dnn::NormalizeDescriptor &normalize_descriptor,
    652       const dnn::BatchDescriptor &dimensions,
    653       const DeviceMemory<float> &raw_data,
    654       const DeviceMemory<float> &normalized_data,
    655       const DeviceMemory<float> &normalized_variable_gradient,
    656       DeviceMemory<float> *raw_variable_gradient);
    657 
          // Float-only. activation_mode is a dnn::ActivationMode value; a single
          // descriptor is passed, so input and output presumably share it.
    658   Stream &ThenActivate(dnn::ActivationMode activation_mode,
    659                        const dnn::BatchDescriptor &dimensions,
    660                        const DeviceMemory<float> &input_data,
    661                        DeviceMemory<float> *output_data);
    662 
    663   // Same as ThenActivate, but also takes an options argument that can be used
    664   // for platform-specific option flags.
          // NOTE(review): `options` is an opaque uint64 here; its flag values
          // are not visible in this part of the header.
    665   Stream &ThenActivateWithOptions(dnn::ActivationMode activation_mode,
    666                                   const dnn::BatchDescriptor &dimensions,
    667                                   const DeviceMemory<float> &input_data,
    668                                   DeviceMemory<float> *output_data,
    669                                   uint64 options);
    670 
          // Takes a slice of descriptors and a slice of read-only float buffers,
          // plus one mutable output_data.
          // NOTE(review): presumably input_dimensions[i] describes
          // input_data[i] - confirm.
    671   Stream &ThenDepthConcatenate(
    672       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    673       port::ArraySlice<const DeviceMemory<float> *> input_data,
    674       DeviceMemory<float> *output_data);
    675 
          // Same arguments as ThenDepthConcatenate plus concat_direction, a
          // dnn::SpaceConcatenateMode passed by value.
    676   Stream &ThenSpaceConcatenate(
    677       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    678       port::ArraySlice<const DeviceMemory<float> *> input_data,
    679       DeviceMemory<float> *output_data,
    680       dnn::SpaceConcatenateMode concat_direction);
    681 
    682   // Change the layout of the data by shrinking one dimension (or set of
    683   // dimensions) and growing another dimension (or set of dimensions), while
    684   // keeping the total number of data elements constant, and maintaining the
    685   // current data ordering.
          // Float-only; input and output each get their own descriptor.
    686   Stream &ThenReshape(const dnn::BatchDescriptor &input_dimensions,
    687                       const DeviceMemory<float> &input_data,
    688                       const dnn::BatchDescriptor &output_dimensions,
    689                       DeviceMemory<float> *output_data);
    690 
    691   // Depth to space takes an X by Y image with depth D*M^2 and changes it to
    692   // an MX x MY image with depth D. Each input location (x,y) with depth D*M^2
    693   // in the input image is changed to an MxM contiguous area in the output
    694   // image, with the values being laid out in raster order specified by
    695   // DepthToSpaceLayout, and will have a new depth of D.
    696   // See the DoDepthToSpace comment for more information.
          // NOTE(review): M is presumably sqrt_depth_reduction (depth shrinks by
          // its square, which keeps X*Y*D*M^2 elements constant) - confirm.
    697   Stream &ThenDepthToSpace(const dnn::BatchDescriptor &input_dimensions,
    698                            const DeviceMemory<float> &input_data,
    699                            const dnn::DepthToSpaceLayout &depth_to_space_layout,
    700                            const int sqrt_depth_reduction,
    701                            DeviceMemory<float> *output_data);
    702 
    703   // Space to depth is the inverse of depth to space. Space to depth takes each
    704   // non-overlapping M by M patch (in the X and Y dimensions) with depth D of
    705   // the input, and transforms it to a 1 by 1 patch with depth D*M^2. If the
    706   // input has size (MX, MY, D), the output has size (X, Y, D*M^2). The number
    707   // of data elements is not changed.
          // The layout argument reuses the dnn::DepthToSpaceLayout type.
          // NOTE(review): M is presumably sqrt_depth_increase - confirm.
    708   Stream &ThenSpaceToDepth(const dnn::BatchDescriptor &input_dimensions,
    709                            const DeviceMemory<float> &input_data,
    710                            const dnn::DepthToSpaceLayout &space_to_depth_layout,
    711                            const int sqrt_depth_increase,
    712                            DeviceMemory<float> *output_data);
    713 
          // Takes `operation` (a dnn::ElementwiseOperation), a slice of read-only
          // float inputs with a matching slice of descriptors, and a mutable
          // output_data described by output_dimensions.
    714   Stream &ThenElementwiseOperate(
    715       dnn::ElementwiseOperation operation,
    716       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    717       port::ArraySlice<const DeviceMemory<float> *> input_data,
    718       const dnn::BatchDescriptor &output_dimensions,
    719       DeviceMemory<float> *output_data);
    720 
          // Variant of ThenElementwiseOperate with two extra arguments: a slice
          // of int input_multiplicands and an int output_divisor.
          // NOTE(review): presumably one multiplicand per input - confirm.
    721   Stream &ThenElementwiseOperateScaledQuantized(
    722       dnn::ElementwiseOperation operation,
    723       port::ArraySlice<int> input_multiplicands, int output_divisor,
    724       port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    725       port::ArraySlice<const DeviceMemory<float> *> input_data,
    726       const dnn::BatchDescriptor &output_dimensions,
    727       DeviceMemory<float> *output_data);
    728 
          // Float-only. The four int64 pad amounts are passed in the order left,
          // right, top, bottom; output_data is the only mutable argument.
    729   Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions,
    730                     const DeviceMemory<float> &input_data, int64 left_pad,
    731                     int64 right_pad, int64 top_pad, int64 bottom_pad,
    732                     DeviceMemory<float> *output_data);
    733 
          // Float-only. The four int64 trim amounts are passed in the order left,
          // right, top, bottom; output_data is the only mutable argument.
    734   Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions,
    735                       const DeviceMemory<float> &input_data, int64 left_trim,
    736                       int64 right_trim, int64 top_trim, int64 bottom_trim,
    737                       DeviceMemory<float> *output_data);
    738 
    739   // Grows the input tensor by replicating the X and Y dimensions. The batch and
    740   // depth/feature_map dimensions are unchanged. Currently, the input tensor is
    741   // limited to X=1 and Y=1.
          // Float-only; replicate_x and replicate_y are int64 values.
    742   Stream &ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
    743                           const DeviceMemory<float> &input_data,
    744                           int64 replicate_x, int64 replicate_y,
    745                           DeviceMemory<float> *output_data);
    746 
    747   // See DnnSupport::DoMemcpyD2HQuantized.
          // Untyped form: host_dst is a raw void* and `size` is a uint64. The
          // template overload below passes a byte count here
          // (element count * sizeof(ElementType)).
    748   Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
    749                                  dnn::QuantizedActivationMode mode,
    750                                  void *host_dst, uint64 size);
    751 
    752   // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice
    753   // and uses the Quantization trait to call the generic version of
    754   // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode.
          // The mode comes from Quantization<ElementType>::kModeId, and the slice
          // length is converted to a byte count before forwarding.
    755   template <typename ElementType>
    756   Stream &ThenMemcpyD2HQuantized(
    757       const DeviceMemory<float> &gpu_unquantized_src,
    758       port::MutableArraySlice<ElementType> host_dst) {
            // Delegates to the untyped overload: (src, mode, void *dst, bytes).
    759     return ThenMemcpyD2HQuantized(
    760         gpu_unquantized_src, Quantization<ElementType>::kModeId,
    761         host_dst.data(), host_dst.size() * sizeof(ElementType));
    762   }
    763 
    764   // See DnnSupport::DoMemcpyH2DQuantized.
          // Untyped form: host_src is a raw const void* and `size` is a uint64.
          // The template overload below passes a byte count here
          // (element count * sizeof(ElementType)).
    765   Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64 size,
    766                                  dnn::QuantizedActivationMode mode,
    767                                  DeviceMemory<float> *gpu_unquantized_dst);
    768 
    769   // Template version of ThenMemcpyH2DQuantized that takes an ArraySlice
    770   // and uses the Quantization trait to call the generic version of
    771   // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode.
          // The mode comes from Quantization<ElementType>::kModeId, and the slice
          // length is converted to a byte count before forwarding.
    772   template <typename ElementType>
    773   Stream &ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,
    774                                  DeviceMemory<float> *gpu_unquantized_dst) {
            // Delegates to the untyped overload: (const void *src, bytes, mode, dst).
    775     return ThenMemcpyH2DQuantized(
    776         host_src.data(), host_src.size() * sizeof(ElementType),
    777         Quantization<ElementType>::kModeId, gpu_unquantized_dst);
    778   }
    779 
    780   // See DnnSupport::DoCopyHostBuffer2Device.
          // Note: buffer_src is a non-const HostBuffer pointer even though it is
          // the source by name; gpu_unquantized_dst is the device-side buffer.
    781   Stream &ThenCopyHostBuffer2Device(HostBuffer *buffer_src,
    782                                     DeviceMemory<float> *gpu_unquantized_dst);
    783 
    784   // See DnnSupport::DoCopyDevice2HostBuffer.
          // gpu_unquantized_src is a read-only float device buffer; buffer_dst is
          // the mutable HostBuffer.
    785   Stream &ThenCopyDevice2HostBuffer(
    786       const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst);
    787 
    788   /////////////////
    789   // BLAS support
    790 
    791   // See BlasSupport::DoBlasAsum.
          // Four overloads (float, double and their std::complex forms). The
          // complex overloads write a real-typed result (float / double).
    792   Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
    793                        int incx, DeviceMemory<float> *result);
    794   Stream &ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
    795                        int incx, DeviceMemory<double> *result);
    796   Stream &ThenBlasAsum(uint64 elem_count,
    797                        const DeviceMemory<std::complex<float>> &x, int incx,
    798                        DeviceMemory<float> *result);
    799   Stream &ThenBlasAsum(uint64 elem_count,
    800                        const DeviceMemory<std::complex<double>> &x, int incx,
    801                        DeviceMemory<double> *result);
    802 
    803   // See BlasSupport::DoBlasAxpy. Note that, even for the case where alpha is
    804   // present in DeviceMemory, it must be an execution-time constant (i.e. a
    805   // value that the stream does not change or populate during the course of
    806   // execution). The value is effectively captured at stream-enqueue time.
    807   // NOTE(review): every overload declared here takes alpha by value.
          // Four overloads (float, double and their std::complex forms); x is
          // read-only and y is the mutable argument.
    808   Stream &ThenBlasAxpy(uint64 elem_count, float alpha,
    809                        const DeviceMemory<float> &x, int incx,
    810                        DeviceMemory<float> *y, int incy);
    811   Stream &ThenBlasAxpy(uint64 elem_count, double alpha,
    812                        const DeviceMemory<double> &x, int incx,
    813                        DeviceMemory<double> *y, int incy);
    814   Stream &ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
    815                        const DeviceMemory<std::complex<float>> &x, int incx,
    816                        DeviceMemory<std::complex<float>> *y, int incy);
    817   Stream &ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
    818                        const DeviceMemory<std::complex<double>> &x, int incx,
    819                        DeviceMemory<std::complex<double>> *y, int incy);
    820 
    821   // See BlasSupport::DoBlasCopy.
          // Four overloads (float, double and their std::complex forms); x is
          // read-only and y is the mutable argument.
    822   Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
    823                        int incx, DeviceMemory<float> *y, int incy);
    824   Stream &ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
    825                        int incx, DeviceMemory<double> *y, int incy);
    826   Stream &ThenBlasCopy(uint64 elem_count,
    827                        const DeviceMemory<std::complex<float>> &x, int incx,
    828                        DeviceMemory<std::complex<float>> *y, int incy);
    829   Stream &ThenBlasCopy(uint64 elem_count,
    830                        const DeviceMemory<std::complex<double>> &x, int incx,
    831                        DeviceMemory<std::complex<double>> *y, int incy);
    832 
    833   // See BlasSupport::DoBlasDot.
          // Real-valued only (float, double); the complex forms are declared
          // separately as ThenBlasDotc and ThenBlasDotu.
    834   Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x, int incx,
    835                       const DeviceMemory<float> &y, int incy,
    836                       DeviceMemory<float> *result);
    837   Stream &ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
    838                       int incx, const DeviceMemory<double> &y, int incy,
    839                       DeviceMemory<double> *result);
    840 
    841   // See BlasSupport::DoBlasDotc.
          // Complex-only (std::complex<float> / std::complex<double>); result
          // has the same complex element type as x and y.
    842   Stream &ThenBlasDotc(uint64 elem_count,
    843                        const DeviceMemory<std::complex<float>> &x, int incx,
    844                        const DeviceMemory<std::complex<float>> &y, int incy,
    845                        DeviceMemory<std::complex<float>> *result);
    846   Stream &ThenBlasDotc(uint64 elem_count,
    847                        const DeviceMemory<std::complex<double>> &x, int incx,
    848                        const DeviceMemory<std::complex<double>> &y, int incy,
    849                        DeviceMemory<std::complex<double>> *result);
    850 
    851   // See BlasSupport::DoBlasDotu.
          // Complex-only (std::complex<float> / std::complex<double>); the
          // parameter lists match those of ThenBlasDotc.
    852   Stream &ThenBlasDotu(uint64 elem_count,
    853                        const DeviceMemory<std::complex<float>> &x, int incx,
    854                        const DeviceMemory<std::complex<float>> &y, int incy,
    855                        DeviceMemory<std::complex<float>> *result);
    856   Stream &ThenBlasDotu(uint64 elem_count,
    857                        const DeviceMemory<std::complex<double>> &x, int incx,
    858                        const DeviceMemory<std::complex<double>> &y, int incy,
    859                        DeviceMemory<std::complex<double>> *result);
    860 
    861   // See BlasSupport::DoBlasNrm2.
          // Four overloads (float, double and their std::complex forms). The
          // complex overloads write a real-typed result (float / double).
    862   Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
    863                        int incx, DeviceMemory<float> *result);
    864   Stream &ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
    865                        int incx, DeviceMemory<double> *result);
    866   Stream &ThenBlasNrm2(uint64 elem_count,
    867                        const DeviceMemory<std::complex<float>> &x, int incx,
    868                        DeviceMemory<float> *result);
    869   Stream &ThenBlasNrm2(uint64 elem_count,
    870                        const DeviceMemory<std::complex<double>> &x, int incx,
    871                        DeviceMemory<double> *result);
    872 
    873   // See BlasSupport::DoBlasRot.
          // Four overloads. x and y are both mutable. In the complex overloads
          // c and s remain real-typed (float / double).
    874   Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
    875                       DeviceMemory<float> *y, int incy, float c, float s);
    876   Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x, int incx,
    877                       DeviceMemory<double> *y, int incy, double c, double s);
    878   Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
    879                       int incx, DeviceMemory<std::complex<float>> *y, int incy,
    880                       float c, float s);
    881   Stream &ThenBlasRot(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
    882                       int incx, DeviceMemory<std::complex<double>> *y, int incy,
    883                       double c, double s);
    884 
    885   // See BlasSupport::DoBlasRotg.
          // All four arguments are mutable device buffers. In the complex
          // overloads c is real-typed while a, b and s are complex.
    886   Stream &ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
    887                        DeviceMemory<float> *c, DeviceMemory<float> *s);
    888   Stream &ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
    889                        DeviceMemory<double> *c, DeviceMemory<double> *s);
    890   Stream &ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
    891                        DeviceMemory<std::complex<float>> *b,
    892                        DeviceMemory<float> *c,
    893                        DeviceMemory<std::complex<float>> *s);
    894   Stream &ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
    895                        DeviceMemory<std::complex<double>> *b,
    896                        DeviceMemory<double> *c,
    897                        DeviceMemory<std::complex<double>> *s);
    898 
    899   // See BlasSupport::DoBlasRotm.
          // Real-valued only (float, double). x and y are mutable; param is
          // read-only.
    900   Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x, int incx,
    901                        DeviceMemory<float> *y, int incy,
    902                        const DeviceMemory<float> &param);
    903   Stream &ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x, int incx,
    904                        DeviceMemory<double> *y, int incy,
    905                        const DeviceMemory<double> &param);
    906 
    907   // See BlasSupport::DoBlasRotmg.
          // Real-valued only (float, double). y1 is the only read-only
          // argument; d1, d2, x1 and param are mutable.
    908   Stream &ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
    909                         DeviceMemory<float> *x1, const DeviceMemory<float> &y1,
    910                         DeviceMemory<float> *param);
    911   Stream &ThenBlasRotmg(DeviceMemory<double> *d1, DeviceMemory<double> *d2,
    912                         DeviceMemory<double> *x1,
    913                         const DeviceMemory<double> &y1,
    914                         DeviceMemory<double> *param);
    915 
    916   // See BlasSupport::DoBlasScal.
          // Six overloads: real alpha with real x, real alpha with complex x,
          // and complex alpha with complex x, each in float and double flavors.
          // x is scaled in place (it is the only buffer argument).
    917   Stream &ThenBlasScal(uint64 elem_count, float alpha, DeviceMemory<float> *x,
    918                        int incx);
    919   Stream &ThenBlasScal(uint64 elem_count, double alpha, DeviceMemory<double> *x,
    920                        int incx);
    921   Stream &ThenBlasScal(uint64 elem_count, float alpha,
    922                        DeviceMemory<std::complex<float>> *x, int incx);
    923   Stream &ThenBlasScal(uint64 elem_count, double alpha,
    924                        DeviceMemory<std::complex<double>> *x, int incx);
    925   Stream &ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
    926                        DeviceMemory<std::complex<float>> *x, int incx);
    927   Stream &ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
    928                        DeviceMemory<std::complex<double>> *x, int incx);
    929 
    930   // See BlasSupport::DoBlasSwap.
          // Four overloads (float, double and their std::complex forms); both x
          // and y are mutable.
    931   Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x, int incx,
    932                        DeviceMemory<float> *y, int incy);
    933   Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x, int incx,
    934                        DeviceMemory<double> *y, int incy);
    935   Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<float>> *x,
    936                        int incx, DeviceMemory<std::complex<float>> *y,
    937                        int incy);
    938   Stream &ThenBlasSwap(uint64 elem_count, DeviceMemory<std::complex<double>> *x,
    939                        int incx, DeviceMemory<std::complex<double>> *y,
    940                        int incy);
    941 
    942   // See BlasSupport::DoBlasIamax.
          // Four overloads by element type; result is always DeviceMemory<int>.
          // NOTE(review): the index base (0 or 1) is not visible here - check
          // DoBlasIamax.
    943   Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
    944                         int incx, DeviceMemory<int> *result);
    945   Stream &ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
    946                         int incx, DeviceMemory<int> *result);
    947   Stream &ThenBlasIamax(uint64 elem_count,
    948                         const DeviceMemory<std::complex<float>> &x, int incx,
    949                         DeviceMemory<int> *result);
    950   Stream &ThenBlasIamax(uint64 elem_count,
    951                         const DeviceMemory<std::complex<double>> &x, int incx,
    952                         DeviceMemory<int> *result);
    953 
    954   // See BlasSupport::DoBlasIamin.
          // Four overloads by element type; result is always DeviceMemory<int>.
          // NOTE(review): the index base (0 or 1) is not visible here - check
          // DoBlasIamin.
    955   Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
    956                         int incx, DeviceMemory<int> *result);
    957   Stream &ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
    958                         int incx, DeviceMemory<int> *result);
    959   Stream &ThenBlasIamin(uint64 elem_count,
    960                         const DeviceMemory<std::complex<float>> &x, int incx,
    961                         DeviceMemory<int> *result);
    962   Stream &ThenBlasIamin(uint64 elem_count,
    963                         const DeviceMemory<std::complex<double>> &x, int incx,
    964                         DeviceMemory<int> *result);
    965 
    966   // See BlasSupport::DoBlasGbmv.
          // Four overloads (float, double and their std::complex forms). a and x
          // are read-only; y is the mutable argument. The sizes m, n, kl and ku
          // are uint64 while lda, incx and incy are int.
    967   Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
    968                        uint64 ku, float alpha, const DeviceMemory<float> &a,
    969                        int lda, const DeviceMemory<float> &x, int incx,
    970                        float beta, DeviceMemory<float> *y, int incy);
    971   Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
    972                        uint64 ku, double alpha, const DeviceMemory<double> &a,
    973                        int lda, const DeviceMemory<double> &x, int incx,
    974                        double beta, DeviceMemory<double> *y, int incy);
    975   Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
    976                        uint64 ku, std::complex<float> alpha,
    977                        const DeviceMemory<std::complex<float>> &a, int lda,
    978                        const DeviceMemory<std::complex<float>> &x, int incx,
    979                        std::complex<float> beta,
    980                        DeviceMemory<std::complex<float>> *y, int incy);
    981   Stream &ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n, uint64 kl,
    982                        uint64 ku, std::complex<double> alpha,
    983                        const DeviceMemory<std::complex<double>> &a, int lda,
    984                        const DeviceMemory<std::complex<double>> &x, int incx,
    985                        std::complex<double> beta,
    986                        DeviceMemory<std::complex<double>> *y, int incy);
    987 
    988   // See BlasSupport::DoBlasGemv.
          // Four overloads (float, double and their std::complex forms). a and x
          // are read-only; y is the mutable argument.
    989   Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, float alpha,
    990                        const DeviceMemory<float> &a, int lda,
    991                        const DeviceMemory<float> &x, int incx, float beta,
    992                        DeviceMemory<float> *y, int incy);
    993   Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n, double alpha,
    994                        const DeviceMemory<double> &a, int lda,
    995                        const DeviceMemory<double> &x, int incx, double beta,
    996                        DeviceMemory<double> *y, int incy);
    997   Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
    998                        std::complex<float> alpha,
    999                        const DeviceMemory<std::complex<float>> &a, int lda,
   1000                        const DeviceMemory<std::complex<float>> &x, int incx,
   1001                        std::complex<float> beta,
   1002                        DeviceMemory<std::complex<float>> *y, int incy);
   1003   Stream &ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
   1004                        std::complex<double> alpha,
   1005                        const DeviceMemory<std::complex<double>> &a, int lda,
   1006                        const DeviceMemory<std::complex<double>> &x, int incx,
   1007                        std::complex<double> beta,
   1008                        DeviceMemory<std::complex<double>> *y, int incy);
   1009 
          // Same parameter lists as the ThenBlasGemv overloads, with a trailing
          // blas::ProfileResult *output_profile_result.
          // NOTE(review): whether a null pointer is accepted is not visible here.
   1010   Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
   1011                                     float alpha, const DeviceMemory<float> &a,
   1012                                     int lda, const DeviceMemory<float> &x,
   1013                                     int incx, float beta,
   1014                                     DeviceMemory<float> *y, int incy,
   1015                                     blas::ProfileResult *output_profile_result);
   1016   Stream &ThenBlasGemvWithProfiling(blas::Transpose trans, uint64 m, uint64 n,
   1017                                     double alpha, const DeviceMemory<double> &a,
   1018                                     int lda, const DeviceMemory<double> &x,
   1019                                     int incx, double beta,
   1020                                     DeviceMemory<double> *y, int incy,
   1021                                     blas::ProfileResult *output_profile_result);
   1022   Stream &ThenBlasGemvWithProfiling(
   1023       blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
   1024       const DeviceMemory<std::complex<float>> &a, int lda,
   1025       const DeviceMemory<std::complex<float>> &x, int incx,
   1026       std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
   1027       blas::ProfileResult *output_profile_result);
   1028   Stream &ThenBlasGemvWithProfiling(
   1029       blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
   1030       const DeviceMemory<std::complex<double>> &a, int lda,
   1031       const DeviceMemory<std::complex<double>> &x, int incx,
   1032       std::complex<double> beta, DeviceMemory<std::complex<double>> *y,
   1033       int incy, blas::ProfileResult *output_profile_result);
   1034 
   1035   // See BlasSupport::DoBlasGer.
          // Real-valued only (float, double); the complex forms are declared as
          // ThenBlasGerc and ThenBlasGeru. x and y are read-only; a is the
          // mutable argument.
   1036   Stream &ThenBlasGer(uint64 m, uint64 n, float alpha,
   1037                       const DeviceMemory<float> &x, int incx,
   1038                       const DeviceMemory<float> &y, int incy,
   1039                       DeviceMemory<float> *a, int lda);
   1040   Stream &ThenBlasGer(uint64 m, uint64 n, double alpha,
   1041                       const DeviceMemory<double> &x, int incx,
   1042                       const DeviceMemory<double> &y, int incy,
   1043                       DeviceMemory<double> *a, int lda);
   1044 
   1045   // See BlasSupport::DoBlasGerc.
          // Complex-only (std::complex<float> / std::complex<double>). x and y
          // are read-only; a is the mutable argument.
   1046   Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
   1047                        const DeviceMemory<std::complex<float>> &x, int incx,
   1048                        const DeviceMemory<std::complex<float>> &y, int incy,
   1049                        DeviceMemory<std::complex<float>> *a, int lda);
   1050   Stream &ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
   1051                        const DeviceMemory<std::complex<double>> &x, int incx,
   1052                        const DeviceMemory<std::complex<double>> &y, int incy,
   1053                        DeviceMemory<std::complex<double>> *a, int lda);
   1054 
   1055   // See BlasSupport::DoBlasGeru.
          // Complex-only (std::complex<float> / std::complex<double>); the
          // parameter lists match those of ThenBlasGerc.
   1056   Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
   1057                        const DeviceMemory<std::complex<float>> &x, int incx,
   1058                        const DeviceMemory<std::complex<float>> &y, int incy,
   1059                        DeviceMemory<std::complex<float>> *a, int lda);
   1060   Stream &ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
   1061                        const DeviceMemory<std::complex<double>> &x, int incx,
   1062                        const DeviceMemory<std::complex<double>> &y, int incy,
   1063                        DeviceMemory<std::complex<double>> *a, int lda);
   1064 
   1065   // See BlasSupport::DoBlasHbmv.
          // Complex-only (std::complex<float> / std::complex<double>). a and x
          // are read-only; y is the mutable argument.
   1066   Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
   1067                        std::complex<float> alpha,
   1068                        const DeviceMemory<std::complex<float>> &a, int lda,
   1069                        const DeviceMemory<std::complex<float>> &x, int incx,
   1070                        std::complex<float> beta,
   1071                        DeviceMemory<std::complex<float>> *y, int incy);
   1072   Stream &ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
   1073                        std::complex<double> alpha,
   1074                        const DeviceMemory<std::complex<double>> &a, int lda,
   1075                        const DeviceMemory<std::complex<double>> &x, int incx,
   1076                        std::complex<double> beta,
   1077                        DeviceMemory<std::complex<double>> *y, int incy);
   1078 
   1079   // See BlasSupport::DoBlasHemv.
          // Complex-only (std::complex<float> / std::complex<double>). a and x
          // are read-only; y is the mutable argument.
   1080   Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
   1081                        std::complex<float> alpha,
   1082                        const DeviceMemory<std::complex<float>> &a, int lda,
   1083                        const DeviceMemory<std::complex<float>> &x, int incx,
   1084                        std::complex<float> beta,
   1085                        DeviceMemory<std::complex<float>> *y, int incy);
   1086   Stream &ThenBlasHemv(blas::UpperLower uplo, uint64 n,
   1087                        std::complex<double> alpha,
   1088                        const DeviceMemory<std::complex<double>> &a, int lda,
   1089                        const DeviceMemory<std::complex<double>> &x, int incx,
   1090                        std::complex<double> beta,
   1091                        DeviceMemory<std::complex<double>> *y, int incy);
   1092 
   1093   // See BlasSupport::DoBlasHer.
          //
          // Note that `alpha` is a real scalar (float/double) while `x` and `a` hold
          // complex elements. `x` is a const input; `a` is a mutable pointer.
          // NOTE(review): presumably the standard BLAS xHER rank-1 update, i.e.
          // A := alpha*x*x^H + A -- confirm against BlasSupport::DoBlasHer.
   1094   Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
   1095                       const DeviceMemory<std::complex<float>> &x, int incx,
   1096                       DeviceMemory<std::complex<float>> *a, int lda);
   1097   Stream &ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
   1098                       const DeviceMemory<std::complex<double>> &x, int incx,
   1099                       DeviceMemory<std::complex<double>> *a, int lda);
   1100 
   1101   // See BlasSupport::DoBlasHer2.
          //
          // Overloads for std::complex<float> and std::complex<double>. `x` and `y`
          // are const inputs; `a` is passed as a mutable DeviceMemory pointer.
          // NOTE(review): presumably the standard BLAS xHER2 rank-2 update, i.e.
          // A := alpha*x*y^H + conj(alpha)*y*x^H + A -- confirm against
          // BlasSupport::DoBlasHer2.
   1102   Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
   1103                        std::complex<float> alpha,
   1104                        const DeviceMemory<std::complex<float>> &x, int incx,
   1105                        const DeviceMemory<std::complex<float>> &y, int incy,
   1106                        DeviceMemory<std::complex<float>> *a, int lda);
   1107   Stream &ThenBlasHer2(blas::UpperLower uplo, uint64 n,
   1108                        std::complex<double> alpha,
   1109                        const DeviceMemory<std::complex<double>> &x, int incx,
   1110                        const DeviceMemory<std::complex<double>> &y, int incy,
   1111                        DeviceMemory<std::complex<double>> *a, int lda);
   1112 
   1113   // See BlasSupport::DoBlasHpmv.
          //
          // Overloads for std::complex<float> and std::complex<double>. Unlike the
          // Hemv overloads there is no leading-dimension argument for `ap`.
          // NOTE(review): presumably the standard BLAS xHPMV contract, i.e.
          // y := alpha*A*x + beta*y with the Hermitian matrix supplied in packed
          // storage via `ap` -- confirm against BlasSupport::DoBlasHpmv.
   1114   Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
   1115                        std::complex<float> alpha,
   1116                        const DeviceMemory<std::complex<float>> &ap,
   1117                        const DeviceMemory<std::complex<float>> &x, int incx,
   1118                        std::complex<float> beta,
   1119                        DeviceMemory<std::complex<float>> *y, int incy);
   1120   Stream &ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
   1121                        std::complex<double> alpha,
   1122                        const DeviceMemory<std::complex<double>> &ap,
   1123                        const DeviceMemory<std::complex<double>> &x, int incx,
   1124                        std::complex<double> beta,
   1125                        DeviceMemory<std::complex<double>> *y, int incy);
   1126 
   1127   // See BlasSupport::DoBlasHpr.
          //
          // `alpha` is a real scalar (float/double) while `x` and `ap` hold complex
          // elements. `x` is a const input; `ap` is a mutable pointer.
          // NOTE(review): presumably the standard BLAS xHPR rank-1 update on a
          // packed Hermitian matrix, i.e. A := alpha*x*x^H + A -- confirm against
          // BlasSupport::DoBlasHpr.
   1128   Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
   1129                       const DeviceMemory<std::complex<float>> &x, int incx,
   1130                       DeviceMemory<std::complex<float>> *ap);
   1131   Stream &ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
   1132                       const DeviceMemory<std::complex<double>> &x, int incx,
   1133                       DeviceMemory<std::complex<double>> *ap);
   1134 
   1135   // See BlasSupport::DoBlasHpr2.
          //
          // Overloads for std::complex<float> and std::complex<double>. `x` and `y`
          // are const inputs; `ap` is passed as a mutable DeviceMemory pointer.
          // NOTE(review): presumably the standard BLAS xHPR2 rank-2 update on a
          // packed Hermitian matrix, i.e.
          // A := alpha*x*y^H + conj(alpha)*y*x^H + A -- confirm against
          // BlasSupport::DoBlasHpr2.
   1136   Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
   1137                        std::complex<float> alpha,
   1138                        const DeviceMemory<std::complex<float>> &x, int incx,
   1139                        const DeviceMemory<std::complex<float>> &y, int incy,
   1140                        DeviceMemory<std::complex<float>> *ap);
   1141   Stream &ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
   1142                        std::complex<double> alpha,
   1143                        const DeviceMemory<std::complex<double>> &x, int incx,
   1144                        const DeviceMemory<std::complex<double>> &y, int incy,
   1145                        DeviceMemory<std::complex<double>> *ap);
   1146 
   1147   // See BlasSupport::DoBlasSbmv.
          //
          // Overloads for float and double. `a` and `x` are const inputs; `y` is
          // passed as a mutable DeviceMemory pointer.
          // NOTE(review): presumably the standard BLAS xSBMV contract, i.e.
          // y := alpha*A*x + beta*y for an n-by-n symmetric band matrix with k
          // super-diagonals -- confirm against BlasSupport::DoBlasSbmv.
   1148   Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, float alpha,
   1149                        const DeviceMemory<float> &a, int lda,
   1150                        const DeviceMemory<float> &x, int incx, float beta,
   1151                        DeviceMemory<float> *y, int incy);
   1152   Stream &ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k, double alpha,
   1153                        const DeviceMemory<double> &a, int lda,
   1154                        const DeviceMemory<double> &x, int incx, double beta,
   1155                        DeviceMemory<double> *y, int incy);
   1156 
   1157   // See BlasSupport::DoBlasSpmv.
          //
          // Overloads for float and double. `ap` and `x` are const inputs; `y` is
          // passed as a mutable DeviceMemory pointer.
          // NOTE(review): presumably the standard BLAS xSPMV contract, i.e.
          // y := alpha*A*x + beta*y with the symmetric matrix supplied in packed
          // storage via `ap` -- confirm against BlasSupport::DoBlasSpmv.
   1158   Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
   1159                        const DeviceMemory<float> &ap,
   1160                        const DeviceMemory<float> &x, int incx, float beta,
   1161                        DeviceMemory<float> *y, int incy);
   1162   Stream &ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
   1163                        const DeviceMemory<double> &ap,
   1164                        const DeviceMemory<double> &x, int incx, double beta,
   1165                        DeviceMemory<double> *y, int incy);
   1166 
   1167   // See BlasSupport::DoBlasSpr.
          //
          // Overloads for float and double. `x` is a const input; `ap` is passed as
          // a mutable DeviceMemory pointer.
          // NOTE(review): presumably the standard BLAS xSPR rank-1 update on a
          // packed symmetric matrix, i.e. A := alpha*x*x^T + A -- confirm against
          // BlasSupport::DoBlasSpr.
   1168   Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
   1169                       const DeviceMemory<float> &x, int incx,
   1170                       DeviceMemory<float> *ap);
   1171   Stream &ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
   1172                       const DeviceMemory<double> &x, int incx,
   1173                       DeviceMemory<double> *ap);
   1174 
   1175   // See BlasSupport::DoBlasSpr2.
          //
          // Overloads for float and double. `x` and `y` are const inputs; `ap` is
          // passed as a mutable DeviceMemory pointer.
          // NOTE(review): presumably the standard BLAS xSPR2 rank-2 update on a
          // packed symmetric matrix, i.e. A := alpha*x*y^T + alpha*y*x^T + A --
          // confirm against BlasSupport::DoBlasSpr2.
   1176   Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
   1177                        const DeviceMemory<float> &x, int incx,
   1178                        const DeviceMemory<float> &y, int incy,
   1179                        DeviceMemory<float> *ap);
   1180   Stream &ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
   1181                        const DeviceMemory<double> &x, int incx,
   1182                        const DeviceMemory<double> &y, int incy,
   1183                        DeviceMemory<double> *ap);
   1184 
   1185   // See BlasSupport::DoBlasSymv.
          //
          // Overloads for float and double. `a` and `x` are const inputs; `y` is
          // passed as a mutable DeviceMemory pointer.
          // NOTE(review): presumably the standard BLAS xSYMV contract, i.e.
          // y := alpha*A*x + beta*y for an n-by-n symmetric matrix A -- confirm
          // against BlasSupport::DoBlasSymv.
   1186   Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
   1187                        const DeviceMemory<float> &a, int lda,
   1188                        const DeviceMemory<float> &x, int incx, float beta,
   1189                        DeviceMemory<float> *y, int incy);
   1190   Stream &ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
   1191                        const DeviceMemory<double> &a, int lda,
   1192                        const DeviceMemory<double> &x, int incx, double beta,
   1193                        DeviceMemory<double> *y, int incy);
   1194 
   1195   // See BlasSupport::DoBlasSyr.
          //
          // Overloads for float and double. `x` is a const input; `a` is passed as
          // a mutable DeviceMemory pointer.
          // NOTE(review): presumably the standard BLAS xSYR rank-1 update, i.e.
          // A := alpha*x*x^T + A -- confirm against BlasSupport::DoBlasSyr.
   1196   Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
   1197                       const DeviceMemory<float> &x, int incx,
   1198                       DeviceMemory<float> *a, int lda);
   1199   Stream &ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
   1200                       const DeviceMemory<double> &x, int incx,
   1201                       DeviceMemory<double> *a, int lda);
   1202 
   1203   // See BlasSupport::DoBlasSyr2.
          //
          // Overloads for float and double. `x` and `y` are const inputs; `a` is
          // passed as a mutable DeviceMemory pointer.
          // NOTE(review): presumably the standard BLAS xSYR2 rank-2 update, i.e.
          // A := alpha*x*y^T + alpha*y*x^T + A -- confirm against
          // BlasSupport::DoBlasSyr2.
   1204   Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
   1205                        const DeviceMemory<float> &x, int incx,
   1206                        const DeviceMemory<float> &y, int incy,
   1207                        DeviceMemory<float> *a, int lda);
   1208   Stream &ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
   1209                        const DeviceMemory<double> &x, int incx,
   1210                        const DeviceMemory<double> &y, int incy,
   1211                        DeviceMemory<double> *a, int lda);
   1212 
   1213   // See BlasSupport::DoBlasTbmv.
          //
          // Overloads for float, double, std::complex<float> and
          // std::complex<double>. `a` is a const input; `x` is the only mutable
          // DeviceMemory argument and there is no separate output parameter.
          // NOTE(review): presumably the standard BLAS xTBMV contract, i.e.
          // x := op(A)*x for a triangular band matrix with k off-diagonals --
          // confirm against BlasSupport::DoBlasTbmv.
   1214   Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
   1215                        blas::Diagonal diag, uint64 n, uint64 k,
   1216                        const DeviceMemory<float> &a, int lda,
   1217                        DeviceMemory<float> *x, int incx);
   1218   Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
   1219                        blas::Diagonal diag, uint64 n, uint64 k,
   1220                        const DeviceMemory<double> &a, int lda,
   1221                        DeviceMemory<double> *x, int incx);
   1222   Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
   1223                        blas::Diagonal diag, uint64 n, uint64 k,
   1224                        const DeviceMemory<std::complex<float>> &a, int lda,
   1225                        DeviceMemory<std::complex<float>> *x, int incx);
   1226   Stream &ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
   1227                        blas::Diagonal diag, uint64 n, uint64 k,
   1228                        const DeviceMemory<std::complex<double>> &a, int lda,
   1229                        DeviceMemory<std::complex<double>> *x, int incx);
   1230 
   1231   // See BlasSupport::DoBlasTbsv.
          //
          // Overloads for float, double, std::complex<float> and
          // std::complex<double>. `a` is a const input; `x` is the only mutable
          // DeviceMemory argument and there is no separate output parameter.
          // NOTE(review): presumably the standard BLAS xTBSV contract, i.e. solve
          // op(A)*x = b for a triangular band matrix, with `x` holding b on entry
          // and the solution on exit -- confirm against BlasSupport::DoBlasTbsv.
   1232   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
   1233                        blas::Diagonal diag, uint64 n, uint64 k,
   1234                        const DeviceMemory<float> &a, int lda,
   1235                        DeviceMemory<float> *x, int incx);
   1236   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
   1237                        blas::Diagonal diag, uint64 n, uint64 k,
   1238                        const DeviceMemory<double> &a, int lda,
   1239                        DeviceMemory<double> *x, int incx);
   1240   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
   1241                        blas::Diagonal diag, uint64 n, uint64 k,
   1242                        const DeviceMemory<std::complex<float>> &a, int lda,
   1243                        DeviceMemory<std::complex<float>> *x, int incx);
   1244   Stream &ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
   1245                        blas::Diagonal diag, uint64 n, uint64 k,
   1246                        const DeviceMemory<std::complex<double>> &a, int lda,
   1247                        DeviceMemory<std::complex<double>> *x, int incx);
   1248 
   1249   // See BlasSupport::DoBlasTpmv.
          //
          // Overloads for float, double, std::complex<float> and
          // std::complex<double>. `ap` is a const input with no leading-dimension
          // argument; `x` is the only mutable DeviceMemory argument.
          // NOTE(review): presumably the standard BLAS xTPMV contract, i.e.
          // x := op(A)*x with the triangular matrix in packed storage -- confirm
          // against BlasSupport::DoBlasTpmv.
   1250   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
   1251                        blas::Diagonal diag, uint64 n,
   1252                        const DeviceMemory<float> &ap, DeviceMemory<float> *x,
   1253                        int incx);
   1254   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
   1255                        blas::Diagonal diag, uint64 n,
   1256                        const DeviceMemory<double> &ap, DeviceMemory<double> *x,
   1257                        int incx);
   1258   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
   1259                        blas::Diagonal diag, uint64 n,
   1260                        const DeviceMemory<std::complex<float>> &ap,
   1261                        DeviceMemory<std::complex<float>> *x, int incx);
   1262   Stream &ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
   1263                        blas::Diagonal diag, uint64 n,
   1264                        const DeviceMemory<std::complex<double>> &ap,
   1265                        DeviceMemory<std::complex<double>> *x, int incx);
   1266 
   1267   // See BlasSupport::DoBlasTpsv.
          //
          // Overloads for float, double, std::complex<float> and
          // std::complex<double>. `ap` is a const input with no leading-dimension
          // argument; `x` is the only mutable DeviceMemory argument.
          // NOTE(review): presumably the standard BLAS xTPSV contract, i.e. solve
          // op(A)*x = b with the triangular matrix in packed storage and `x`
          // overwritten by the solution -- confirm against
          // BlasSupport::DoBlasTpsv.
   1268   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
   1269                        blas::Diagonal diag, uint64 n,
   1270                        const DeviceMemory<float> &ap, DeviceMemory<float> *x,
   1271                        int incx);
   1272   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
   1273                        blas::Diagonal diag, uint64 n,
   1274                        const DeviceMemory<double> &ap, DeviceMemory<double> *x,
   1275                        int incx);
   1276   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
   1277                        blas::Diagonal diag, uint64 n,
   1278                        const DeviceMemory<std::complex<float>> &ap,
   1279                        DeviceMemory<std::complex<float>> *x, int incx);
   1280   Stream &ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
   1281                        blas::Diagonal diag, uint64 n,
   1282                        const DeviceMemory<std::complex<double>> &ap,
   1283                        DeviceMemory<std::complex<double>> *x, int incx);
   1284 
   1285   // See BlasSupport::DoBlasTrmv.
          //
          // Overloads for float, double, std::complex<float> and
          // std::complex<double>. `a` is a const input; `x` is the only mutable
          // DeviceMemory argument and there is no separate output parameter.
          // NOTE(review): presumably the standard BLAS xTRMV contract, i.e.
          // x := op(A)*x for an n-by-n triangular matrix -- confirm against
          // BlasSupport::DoBlasTrmv.
   1286   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
   1287                        blas::Diagonal diag, uint64 n,
   1288                        const DeviceMemory<float> &a, int lda,
   1289                        DeviceMemory<float> *x, int incx);
   1290   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
   1291                        blas::Diagonal diag, uint64 n,
   1292                        const DeviceMemory<double> &a, int lda,
   1293                        DeviceMemory<double> *x, int incx);
   1294   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
   1295                        blas::Diagonal diag, uint64 n,
   1296                        const DeviceMemory<std::complex<float>> &a, int lda,
   1297                        DeviceMemory<std::complex<float>> *x, int incx);
   1298   Stream &ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
   1299                        blas::Diagonal diag, uint64 n,
   1300                        const DeviceMemory<std::complex<double>> &a, int lda,
   1301                        DeviceMemory<std::complex<double>> *x, int incx);
   1302 
   1303   // See BlasSupport::DoBlasTrsv.
          //
          // Overloads for float, double, std::complex<float> and
          // std::complex<double>. `a` is a const input; `x` is the only mutable
          // DeviceMemory argument and there is no separate output parameter.
          // NOTE(review): presumably the standard BLAS xTRSV contract, i.e. solve
          // op(A)*x = b for an n-by-n triangular matrix, with `x` holding b on
          // entry and the solution on exit -- confirm against
          // BlasSupport::DoBlasTrsv.
   1304   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
   1305                        blas::Diagonal diag, uint64 n,
   1306                        const DeviceMemory<float> &a, int lda,
   1307                        DeviceMemory<float> *x, int incx);
   1308   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
   1309                        blas::Diagonal diag, uint64 n,
   1310                        const DeviceMemory<double> &a, int lda,
   1311                        DeviceMemory<double> *x, int incx);
   1312   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
   1313                        blas::Diagonal diag, uint64 n,
   1314                        const DeviceMemory<std::complex<float>> &a, int lda,
   1315                        DeviceMemory<std::complex<float>> *x, int incx);
   1316   Stream &ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
   1317                        blas::Diagonal diag, uint64 n,
   1318                        const DeviceMemory<std::complex<double>> &a, int lda,
   1319                        DeviceMemory<std::complex<double>> *x, int incx);
   1320 
   1321   // See BlasSupport::DoBlasGemm.
          //
          // Overloads for Eigen::half, float, double, std::complex<float> and
          // std::complex<double>. The Eigen::half overload takes `alpha` and `beta`
          // as float rather than Eigen::half. `a` and `b` are const inputs; `c` is
          // passed as a mutable DeviceMemory pointer.
          // NOTE(review): presumably the standard BLAS xGEMM contract, i.e.
          // C := alpha*op(A)*op(B) + beta*C with op(A) m-by-k and op(B) k-by-n --
          // confirm against BlasSupport::DoBlasGemm.
   1322   Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
   1323                        uint64 n, uint64 k, float alpha,
   1324                        const DeviceMemory<Eigen::half> &a, int lda,
   1325                        const DeviceMemory<Eigen::half> &b, int ldb, float beta,
   1326                        DeviceMemory<Eigen::half> *c, int ldc);
   1327   Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
   1328                        uint64 n, uint64 k, float alpha,
   1329                        const DeviceMemory<float> &a, int lda,
   1330                        const DeviceMemory<float> &b, int ldb, float beta,
   1331                        DeviceMemory<float> *c, int ldc);
   1332   Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
   1333                        uint64 n, uint64 k, double alpha,
   1334                        const DeviceMemory<double> &a, int lda,
   1335                        const DeviceMemory<double> &b, int ldb, double beta,
   1336                        DeviceMemory<double> *c, int ldc);
   1337   Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
   1338                        uint64 n, uint64 k, std::complex<float> alpha,
   1339                        const DeviceMemory<std::complex<float>> &a, int lda,
   1340                        const DeviceMemory<std::complex<float>> &b, int ldb,
   1341                        std::complex<float> beta,
   1342                        DeviceMemory<std::complex<float>> *c, int ldc);
   1343   Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
   1344                        uint64 n, uint64 k, std::complex<double> alpha,
   1345                        const DeviceMemory<std::complex<double>> &a, int lda,
   1346                        const DeviceMemory<std::complex<double>> &b, int ldb,
   1347                        std::complex<double> beta,
   1348                        DeviceMemory<std::complex<double>> *c, int ldc);
   1349 
          // Same parameter lists as the ThenBlasGemm overloads (Eigen::half, float,
          // double, std::complex<float>, std::complex<double>) with one trailing
          // `output_profile_result` pointer added.
          // NOTE(review): presumably performs the same GEMM and records profiling
          // data into `output_profile_result`; whether that pointer may be null is
          // not visible from these declarations -- confirm in the implementation.
   1350   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
   1351                                     blas::Transpose transb, uint64 m, uint64 n,
   1352                                     uint64 k, float alpha,
   1353                                     const DeviceMemory<Eigen::half> &a, int lda,
   1354                                     const DeviceMemory<Eigen::half> &b, int ldb,
   1355                                     float beta, DeviceMemory<Eigen::half> *c,
   1356                                     int ldc,
   1357                                     blas::ProfileResult *output_profile_result);
   1358   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
   1359                                     blas::Transpose transb, uint64 m, uint64 n,
   1360                                     uint64 k, float alpha,
   1361                                     const DeviceMemory<float> &a, int lda,
   1362                                     const DeviceMemory<float> &b, int ldb,
   1363                                     float beta, DeviceMemory<float> *c, int ldc,
   1364                                     blas::ProfileResult *output_profile_result);
   1365   Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
   1366                                     blas::Transpose transb, uint64 m, uint64 n,
   1367                                     uint64 k, double alpha,
   1368                                     const DeviceMemory<double> &a, int lda,
   1369                                     const DeviceMemory<double> &b, int ldb,
   1370                                     double beta, DeviceMemory<double> *c,
   1371                                     int ldc,
   1372                                     blas::ProfileResult *output_profile_result);
   1373   Stream &ThenBlasGemmWithProfiling(
   1374       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   1375       uint64 k, std::complex<float> alpha,
   1376       const DeviceMemory<std::complex<float>> &a, int lda,
   1377       const DeviceMemory<std::complex<float>> &b, int ldb,
   1378       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
   1379       blas::ProfileResult *output_profile_result);
   1380   Stream &ThenBlasGemmWithProfiling(
   1381       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   1382       uint64 k, std::complex<double> alpha,
   1383       const DeviceMemory<std::complex<double>> &a, int lda,
   1384       const DeviceMemory<std::complex<double>> &b, int ldb,
   1385       std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
   1386       blas::ProfileResult *output_profile_result);
   1387 
   1388   // See BlasSupport::DoBlasGemmWithAlgorithm.
          //
          // Overloads for Eigen::half, int8, float, double, std::complex<float> and
          // std::complex<double>. Differences visible in the signatures:
          //   - the Eigen::half overload takes `alpha`/`beta` as const Eigen::half&
          //     (the ThenBlasGemm half overload takes them as float);
          //   - the int8 overload reads int8 inputs `a`/`b`, takes int `alpha`/
          //     `beta`, and writes to a DeviceMemory<int> `c`.
          // Every overload additionally takes `computation_type`, `algorithm` and
          // an `output_profile_result` pointer.
          // NOTE(review): presumably computes C := alpha*op(A)*op(B) + beta*C with
          // the requested algorithm; null-handling of `output_profile_result` is
          // not visible here -- confirm against
          // BlasSupport::DoBlasGemmWithAlgorithm.
   1389   Stream &ThenBlasGemmWithAlgorithm(
   1390       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   1391       uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a,
   1392       int lda, const DeviceMemory<Eigen::half> &b, int ldb,
   1393       const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
   1394       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
   1395       blas::ProfileResult *output_profile_result);
   1396   Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
   1397                                     blas::Transpose transb, uint64 m, uint64 n,
   1398                                     uint64 k, int alpha,
   1399                                     const DeviceMemory<int8> &a, int lda,
   1400                                     const DeviceMemory<int8> &b, int ldb,
   1401                                     int beta, DeviceMemory<int> *c, int ldc,
   1402                                     blas::ComputationType computation_type,
   1403                                     blas::AlgorithmType algorithm,
   1404                                     blas::ProfileResult *output_profile_result);
   1405   Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
   1406                                     blas::Transpose transb, uint64 m, uint64 n,
   1407                                     uint64 k, float alpha,
   1408                                     const DeviceMemory<float> &a, int lda,
   1409                                     const DeviceMemory<float> &b, int ldb,
   1410                                     float beta, DeviceMemory<float> *c, int ldc,
   1411                                     blas::ComputationType computation_type,
   1412                                     blas::AlgorithmType algorithm,
   1413                                     blas::ProfileResult *output_profile_result);
   1414   Stream &ThenBlasGemmWithAlgorithm(
   1415       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   1416       uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
   1417       const DeviceMemory<double> &b, int ldb, double beta,
   1418       DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
   1419       blas::AlgorithmType algorithm,
   1420       blas::ProfileResult *output_profile_result);
   1421   Stream &ThenBlasGemmWithAlgorithm(
   1422       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   1423       uint64 k, std::complex<float> alpha,
   1424       const DeviceMemory<std::complex<float>> &a, int lda,
   1425       const DeviceMemory<std::complex<float>> &b, int ldb,
   1426       std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
   1427       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
   1428       blas::ProfileResult *output_profile_result);
   1429   Stream &ThenBlasGemmWithAlgorithm(
   1430       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   1431       uint64 k, std::complex<double> alpha,
   1432       const DeviceMemory<std::complex<double>> &a, int lda,
   1433       const DeviceMemory<std::complex<double>> &b, int ldb,
   1434       std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
   1435       blas::ComputationType computation_type, blas::AlgorithmType algorithm,
   1436       blas::ProfileResult *output_profile_result);
   1437 
   1438   // See BlasSupport::DoBlasGemmBatched.
          //
          // Overloads for float, double, std::complex<float> and
          // std::complex<double>. `a`, `b` and `c` are each a port::ArraySlice of
          // DeviceMemory pointers, and `batch_count` is passed explicitly. Note
          // that `c` is a const reference to the slice; the pointed-to
          // DeviceMemory objects are non-const.
          // NOTE(review): presumably runs batch_count independent GEMMs,
          // C[i] := alpha*op(A[i])*op(B[i]) + beta*C[i], and expects each slice to
          // hold batch_count entries -- confirm against
          // BlasSupport::DoBlasGemmBatched.
   1439   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
   1440                               uint64 m, uint64 n, uint64 k, float alpha,
   1441                               const port::ArraySlice<DeviceMemory<float> *> &a,
   1442                               int lda,
   1443                               const port::ArraySlice<DeviceMemory<float> *> &b,
   1444                               int ldb, float beta,
   1445                               const port::ArraySlice<DeviceMemory<float> *> &c,
   1446                               int ldc, int batch_count);
   1447   Stream &ThenBlasGemmBatched(blas::Transpose transa, blas::Transpose transb,
   1448                               uint64 m, uint64 n, uint64 k, double alpha,
   1449                               const port::ArraySlice<DeviceMemory<double> *> &a,
   1450                               int lda,
   1451                               const port::ArraySlice<DeviceMemory<double> *> &b,
   1452                               int ldb, double beta,
   1453                               const port::ArraySlice<DeviceMemory<double> *> &c,
   1454                               int ldc, int batch_count);
   1455   Stream &ThenBlasGemmBatched(
   1456       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   1457       uint64 k, std::complex<float> alpha,
   1458       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
   1459       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
   1460       std::complex<float> beta,
   1461       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
   1462       int batch_count);
   1463   Stream &ThenBlasGemmBatched(
   1464       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   1465       uint64 k, std::complex<double> alpha,
   1466       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
   1467       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
   1468       std::complex<double> beta,
   1469       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
   1470       int batch_count);
          // Same parameter lists as the ThenBlasGemmBatched overloads, with one
          // trailing ScratchAllocator pointer added.
          // NOTE(review): presumably `scratch_allocator` supplies temporary device
          // memory for the batched call; whether it may be null is not visible
          // here -- confirm in the implementation.
   1471   Stream &ThenBlasGemmBatchedWithScratch(
   1472       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   1473       uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
   1474       int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
   1475       float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
   1476       int batch_count, ScratchAllocator *scratch_allocator);
   1477   Stream &ThenBlasGemmBatchedWithScratch(
   1478       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   1479       uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
   1480       int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
   1481       double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
   1482       int batch_count, ScratchAllocator *scratch_allocator);
   1483   Stream &ThenBlasGemmBatchedWithScratch(
   1484       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   1485       uint64 k, std::complex<float> alpha,
   1486       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
   1487       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
   1488       std::complex<float> beta,
   1489       const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
   1490       int batch_count, ScratchAllocator *scratch_allocator);
   1491   Stream &ThenBlasGemmBatchedWithScratch(
   1492       blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   1493       uint64 k, std::complex<double> alpha,
   1494       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
   1495       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
   1496       std::complex<double> beta,
   1497       const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
   1498       int batch_count, ScratchAllocator *scratch_allocator);
   1499 
   1500   // See BlasSupport::DoBlasHemm.
          //
          // Overloads for std::complex<float> and std::complex<double>. `a` and `b`
          // are const inputs; `c` is passed as a mutable DeviceMemory pointer.
          // NOTE(review): presumably the standard BLAS xHEMM contract, i.e.
          // C := alpha*A*B + beta*C or C := alpha*B*A + beta*C depending on `side`,
          // with A Hermitian -- confirm against BlasSupport::DoBlasHemm.
   1501   Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
   1502                        uint64 n, std::complex<float> alpha,
   1503                        const DeviceMemory<std::complex<float>> &a, int lda,
   1504                        const DeviceMemory<std::complex<float>> &b, int ldb,
   1505                        std::complex<float> beta,
   1506                        DeviceMemory<std::complex<float>> *c, int ldc);
   1507   Stream &ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
   1508                        uint64 n, std::complex<double> alpha,
   1509                        const DeviceMemory<std::complex<double>> &a, int lda,
   1510                        const DeviceMemory<std::complex<double>> &b, int ldb,
   1511                        std::complex<double> beta,
   1512                        DeviceMemory<std::complex<double>> *c, int ldc);
   1513 
   1514   // See BlasSupport::DoBlasHerk.
          //
          // Both `alpha` and `beta` are real scalars (float/double) while `a` and
          // `c` hold complex elements. `a` is a const input; `c` is a mutable
          // pointer.
          // NOTE(review): presumably the standard BLAS xHERK rank-k update, i.e.
          // C := alpha*op(A)*op(A)^H + beta*C with C Hermitian -- confirm against
          // BlasSupport::DoBlasHerk.
   1515   Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
   1516                        uint64 k, float alpha,
   1517                        const DeviceMemory<std::complex<float>> &a, int lda,
   1518                        float beta, DeviceMemory<std::complex<float>> *c,
   1519                        int ldc);
   1520   Stream &ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
   1521                        uint64 k, double alpha,
   1522                        const DeviceMemory<std::complex<double>> &a, int lda,
   1523                        double beta, DeviceMemory<std::complex<double>> *c,
   1524                        int ldc);
   1525 
   1526   // See BlasSupport::DoBlasHer2k.
          //
          // `alpha` is complex but `beta` is a real scalar (float/double). `a` and
          // `b` are const inputs; `c` is passed as a mutable DeviceMemory pointer.
          // NOTE(review): presumably the standard BLAS xHER2K rank-2k update, i.e.
          // C := alpha*op(A)*op(B)^H + conj(alpha)*op(B)*op(A)^H + beta*C --
          // confirm against BlasSupport::DoBlasHer2k.
   1527   Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
   1528                         uint64 k, std::complex<float> alpha,
   1529                         const DeviceMemory<std::complex<float>> &a, int lda,
   1530                         const DeviceMemory<std::complex<float>> &b, int ldb,
   1531                         float beta, DeviceMemory<std::complex<float>> *c,
   1532                         int ldc);
   1533   Stream &ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
   1534                         uint64 k, std::complex<double> alpha,
   1535                         const DeviceMemory<std::complex<double>> &a, int lda,
   1536                         const DeviceMemory<std::complex<double>> &b, int ldb,
   1537                         double beta, DeviceMemory<std::complex<double>> *c,
   1538                         int ldc);
   1539 
   1540   // See BlasSupport::DoBlasSymm.
          //
          // Overloads for float, double, std::complex<float> and
          // std::complex<double>. `a` and `b` are const inputs; `c` is passed as a
          // mutable DeviceMemory pointer.
          // NOTE(review): presumably the standard BLAS xSYMM contract, i.e.
          // C := alpha*A*B + beta*C or C := alpha*B*A + beta*C depending on `side`,
          // with A symmetric -- confirm against BlasSupport::DoBlasSymm.
   1541   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
   1542                        uint64 n, float alpha, const DeviceMemory<float> &a,
   1543                        int lda, const DeviceMemory<float> &b, int ldb,
   1544                        float beta, DeviceMemory<float> *c, int ldc);
   1545   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
   1546                        uint64 n, double alpha, const DeviceMemory<double> &a,
   1547                        int lda, const DeviceMemory<double> &b, int ldb,
   1548                        double beta, DeviceMemory<double> *c, int ldc);
   1549   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
   1550                        uint64 n, std::complex<float> alpha,
   1551                        const DeviceMemory<std::complex<float>> &a, int lda,
   1552                        const DeviceMemory<std::complex<float>> &b, int ldb,
   1553                        std::complex<float> beta,
   1554                        DeviceMemory<std::complex<float>> *c, int ldc);
   1555   Stream &ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
   1556                        uint64 n, std::complex<double> alpha,
   1557                        const DeviceMemory<std::complex<double>> &a, int lda,
   1558                        const DeviceMemory<std::complex<double>> &b, int ldb,
   1559                        std::complex<double> beta,
   1560                        DeviceMemory<std::complex<double>> *c, int ldc);
   1561 
   1562   // See BlasSupport::DoBlasSyrk.
   1563   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
   1564                        uint64 k, float alpha, const DeviceMemory<float> &a,
   1565                        int lda, float beta, DeviceMemory<float> *c, int ldc);
   1566   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
   1567                        uint64 k, double alpha, const DeviceMemory<double> &a,
   1568                        int lda, double beta, DeviceMemory<double> *c, int ldc);
   1569   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
   1570                        uint64 k, std::complex<float> alpha,
   1571                        const DeviceMemory<std::complex<float>> &a, int lda,
   1572                        std::complex<float> beta,
   1573                        DeviceMemory<std::complex<float>> *c, int ldc);
   1574   Stream &ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
   1575                        uint64 k, std::complex<double> alpha,
   1576                        const DeviceMemory<std::complex<double>> &a, int lda,
   1577                        std::complex<double> beta,
   1578                        DeviceMemory<std::complex<double>> *c, int ldc);
   1579 
   1580   // See BlasSupport::DoBlasSyr2k.
   1581   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
   1582                         uint64 k, float alpha, const DeviceMemory<float> &a,
   1583                         int lda, const DeviceMemory<float> &b, int ldb,
   1584                         float beta, DeviceMemory<float> *c, int ldc);
   1585   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
   1586                         uint64 k, double alpha, const DeviceMemory<double> &a,
   1587                         int lda, const DeviceMemory<double> &b, int ldb,
   1588                         double beta, DeviceMemory<double> *c, int ldc);
   1589   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
   1590                         uint64 k, std::complex<float> alpha,
   1591                         const DeviceMemory<std::complex<float>> &a, int lda,
   1592                         const DeviceMemory<std::complex<float>> &b, int ldb,
   1593                         std::complex<float> beta,
   1594                         DeviceMemory<std::complex<float>> *c, int ldc);
   1595   Stream &ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans, uint64 n,
   1596                         uint64 k, std::complex<double> alpha,
   1597                         const DeviceMemory<std::complex<double>> &a, int lda,
   1598                         const DeviceMemory<std::complex<double>> &b, int ldb,
   1599                         std::complex<double> beta,
   1600                         DeviceMemory<std::complex<double>> *c, int ldc);
   1601 
   1602   // See BlasSupport::DoBlasTrmm.
   1603   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
   1604                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
   1605                        uint64 n, float alpha, const DeviceMemory<float> &a,
   1606                        int lda, DeviceMemory<float> *b, int ldb);
   1607   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
   1608                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
   1609                        uint64 n, double alpha, const DeviceMemory<double> &a,
   1610                        int lda, DeviceMemory<double> *b, int ldb);
   1611   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
   1612                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
   1613                        uint64 n, std::complex<float> alpha,
   1614                        const DeviceMemory<std::complex<float>> &a, int lda,
   1615                        DeviceMemory<std::complex<float>> *b, int ldb);
   1616   Stream &ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
   1617                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
   1618                        uint64 n, std::complex<double> alpha,
   1619                        const DeviceMemory<std::complex<double>> &a, int lda,
   1620                        DeviceMemory<std::complex<double>> *b, int ldb);
   1621 
   1622   // See BlasSupport::DoBlasTrsm.
   1623   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
   1624                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
   1625                        uint64 n, float alpha, const DeviceMemory<float> &a,
   1626                        int lda, DeviceMemory<float> *b, int ldb);
   1627   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
   1628                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
   1629                        uint64 n, double alpha, const DeviceMemory<double> &a,
   1630                        int lda, DeviceMemory<double> *b, int ldb);
   1631   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
   1632                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
   1633                        uint64 n, std::complex<float> alpha,
   1634                        const DeviceMemory<std::complex<float>> &a, int lda,
   1635                        DeviceMemory<std::complex<float>> *b, int ldb);
   1636   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
   1637                        blas::Transpose transa, blas::Diagonal diag, uint64 m,
   1638                        uint64 n, std::complex<double> alpha,
   1639                        const DeviceMemory<std::complex<double>> &a, int lda,
   1640                        DeviceMemory<std::complex<double>> *b, int ldb);
   1641 
   1642   // See FftSupport::DoFft.
   1643   Stream &ThenFft(fft::Plan *plan,
   1644                   const DeviceMemory<std::complex<float>> &input,
   1645                   DeviceMemory<std::complex<float>> *output);
   1646   Stream &ThenFft(fft::Plan *plan,
   1647                   const DeviceMemory<std::complex<double>> &input,
   1648                   DeviceMemory<std::complex<double>> *output);
   1649   Stream &ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
   1650                   DeviceMemory<std::complex<float>> *output);
   1651   Stream &ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
   1652                   DeviceMemory<std::complex<double>> *output);
   1653   Stream &ThenFft(fft::Plan *plan,
   1654                   const DeviceMemory<std::complex<float>> &input,
   1655                   DeviceMemory<float> *output);
   1656   Stream &ThenFft(fft::Plan *plan,
   1657                   const DeviceMemory<std::complex<double>> &input,
   1658                   DeviceMemory<double> *output);
   1659 
   1660   // Makes the RNG use the provided value as the basis for further generation.
   1661   // /dev/urandom (good) and /dev/random (better, but sometimes slow) are good
   1662   // sources of seed data if the default (high quality) sources are not
   1663   // desired.
   1664   // For most use cases, this function will not be necessary; each provided
   1665   // back-end implementation will be appropriately seeded by default.
   1666   // At a minimum 16 bytes of data are required in the seed buffer.
   1667   //
   1668   // To seed with good (non-reproducible) data:
   1669   //   File* f = File::Open("/dev/random", "r");
   1670   //   int64 bytes_read = f->Read(seed_data, bytes_to_read);
   1671   //   < error checking >
   1672   //   stream.ThenSetRngSeed(seed_data, bytes_read);
   1673   //
  // To seed with reproducible data:
  //   uint8 seed_data[16] = { <data> };
  //   stream.ThenSetRngSeed(seed_data, 16);
   1677   Stream &ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes);
   1678 
   1679   // Populates the memory indicated by values with uniform-random-distribution
   1680   // values. TODO(leary) seeding API/description
   1681   //
   1682   // Uses the type and size of the DeviceMemory to infer what data should be
   1683   // populated.
   1684   Stream &ThenPopulateRandUniform(DeviceMemory<float> *values);
   1685   Stream &ThenPopulateRandUniform(DeviceMemory<double> *values);
   1686   Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<float>> *values);
   1687   Stream &ThenPopulateRandUniform(DeviceMemory<std::complex<double>> *values);
   1688   Stream &ThenPopulateRandGaussian(float mean, float stddev,
   1689                                    DeviceMemory<float> *values);
   1690   Stream &ThenPopulateRandGaussian(double mean, double stddev,
   1691                                    DeviceMemory<double> *values);
   1692 
   1693   // Entrain onto the stream: a memcpy to a host destination from a GPU source
   1694   // of the given target size. host_dst must be a pointer to host memory
   1695   // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
   1696   // then registered with StreamExecutor::HostMemoryRegister.
   1697   Stream &ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
   1698                      uint64 size);
   1699 
   1700   // Entrain onto the stream: a memcpy to a GPU destination from a host source
   1701   // of the given target size. host_src must be a pointer to host memory
   1702   // allocated by StreamExecutor::HostMemoryAllocate or otherwise allocated and
   1703   // then registered with StreamExecutor::HostMemoryRegister.
   1704   Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
   1705                      uint64 size);
   1706 
   1707   // Alternative interface for memcpying from device to host that takes an
   1708   // array slice. Checks that the destination size can accommodate the host
   1709   // slice size.
   1710   template <typename T>
   1711   Stream &ThenMemcpyD2H(const DeviceMemory<T> &gpu_src,
   1712                         port::MutableArraySlice<T> host_dst) {
   1713     auto host_size = host_dst.size() * sizeof(T);
   1714     CHECK(gpu_src.size() == 0 || host_size >= gpu_src.size());
   1715     return ThenMemcpy(host_dst.begin(), gpu_src, host_size);
   1716   }
   1717 
   1718   // Alternative interface for memcpying from host to device that takes an
   1719   // array slice. Checks that the destination size can accommodate the host
   1720   // slice size.
   1721   template <typename T>
   1722   Stream &ThenMemcpyH2D(port::ArraySlice<T> host_src,
   1723                         DeviceMemory<T> *gpu_dst) {
   1724     auto host_size = host_src.size() * sizeof(T);
   1725     CHECK(gpu_dst->size() == 0 || gpu_dst->size() >= host_size);
   1726     return ThenMemcpy(gpu_dst, host_src.begin(), host_size);
   1727   }
   1728 
   1729   // Entrain onto the stream: a memcpy to a GPU destination from a GPU source
   1730   // of the given target size. gpu_src/dst must be pointers to GPU memory and
   1731   // peer access must be enabled between their owning StreamExecutors.
   1732   Stream &ThenMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src,
   1733                      uint64 size);
   1734 
   1735   // Calls to the device-to-device copy overload of ThenMemcpy -- useful for
   1736   // ensuring that the host pointer isn't getting confused accidentally with a
   1737   // device pointer if you're not doing metaprogramming against the API.
   1738   Stream &ThenMemcpyD2D(DeviceMemoryBase *gpu_dst,
   1739                         const DeviceMemoryBase &gpu_src, uint64 size) {
   1740     return ThenMemcpy(gpu_dst, gpu_src, size);
   1741   }
   1742 
   1743   // Entrain onto the stream: a memset of zero at a GPU location of size bytes.
   1744   // The location must not be null.
   1745   Stream &ThenMemZero(DeviceMemoryBase *location, uint64 size);
   1746 
  // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location of
  // size bytes, where size must be a whole number of 32-bit words (i.e. evenly
  // divisible by 4). The location must not be null.
   1750   Stream &ThenMemset32(DeviceMemoryBase *location, uint32 pattern, uint64 size);
   1751 
   1752   // Enqueue a forward operation of the RNN model onto the stream.
   1753   // See DnnSupport::DoRnnForward for more details.
   1754   Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
   1755                          const dnn::RnnSequenceTensorDescriptor &input_desc,
   1756                          const DeviceMemory<Eigen::half> &input_data,
   1757                          const dnn::RnnStateTensorDescriptor &input_h_desc,
   1758                          const DeviceMemory<Eigen::half> &input_h_data,
   1759                          const dnn::RnnStateTensorDescriptor &input_c_desc,
   1760                          const DeviceMemory<Eigen::half> &input_c_data,
   1761                          const DeviceMemory<Eigen::half> &params,
   1762                          const dnn::RnnSequenceTensorDescriptor &output_desc,
   1763                          DeviceMemory<Eigen::half> *output_data,
   1764                          const dnn::RnnStateTensorDescriptor &output_h_desc,
   1765                          DeviceMemory<Eigen::half> *output_h_data,
   1766                          const dnn::RnnStateTensorDescriptor &output_c_desc,
   1767                          DeviceMemory<Eigen::half> *output_c_data,
   1768                          bool is_training,
   1769                          ScratchAllocator *reserve_space_allocator,
   1770                          ScratchAllocator *workspace_allocator);
   1771 
   1772   Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
   1773                          const dnn::RnnSequenceTensorDescriptor &input_desc,
   1774                          const DeviceMemory<float> &input_data,
   1775                          const dnn::RnnStateTensorDescriptor &input_h_desc,
   1776                          const DeviceMemory<float> &input_h_data,
   1777                          const dnn::RnnStateTensorDescriptor &input_c_desc,
   1778                          const DeviceMemory<float> &input_c_data,
   1779                          const DeviceMemory<float> &params,
   1780                          const dnn::RnnSequenceTensorDescriptor &output_desc,
   1781                          DeviceMemory<float> *output_data,
   1782                          const dnn::RnnStateTensorDescriptor &output_h_desc,
   1783                          DeviceMemory<float> *output_h_data,
   1784                          const dnn::RnnStateTensorDescriptor &output_c_desc,
   1785                          DeviceMemory<float> *output_c_data, bool is_training,
   1786                          ScratchAllocator *reserve_space_allocator,
   1787                          ScratchAllocator *workspace_allocator);
   1788 
   1789   Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
   1790                          const dnn::RnnSequenceTensorDescriptor &input_desc,
   1791                          const DeviceMemory<double> &input_data,
   1792                          const dnn::RnnStateTensorDescriptor &input_h_desc,
   1793                          const DeviceMemory<double> &input_h_data,
   1794                          const dnn::RnnStateTensorDescriptor &input_c_desc,
   1795                          const DeviceMemory<double> &input_c_data,
   1796                          const DeviceMemory<double> &params,
   1797                          const dnn::RnnSequenceTensorDescriptor &output_desc,
   1798                          DeviceMemory<double> *output_data,
   1799                          const dnn::RnnStateTensorDescriptor &output_h_desc,
   1800                          DeviceMemory<double> *output_h_data,
   1801                          const dnn::RnnStateTensorDescriptor &output_c_desc,
   1802                          DeviceMemory<double> *output_c_data, bool is_training,
   1803                          ScratchAllocator *reserve_space_allocator,
   1804                          ScratchAllocator *workspace_allocator);
   1805 
   1806   // Enqueue a backward operation of the RNN model onto the stream.
   1807   // See DnnSupport::DoRnnBackward for more details.
   1808   Stream &ThenRnnBackward(
   1809       const dnn::RnnDescriptor &rnn_desc,
   1810       const dnn::RnnSequenceTensorDescriptor &input_desc,
   1811       const DeviceMemory<Eigen::half> &input_data,
   1812       const dnn::RnnStateTensorDescriptor &input_h_desc,
   1813       const DeviceMemory<Eigen::half> &input_h_data,
   1814       const dnn::RnnStateTensorDescriptor &input_c_desc,
   1815       const DeviceMemory<Eigen::half> &input_c_data,
   1816       const DeviceMemory<Eigen::half> &params,
   1817       const dnn::RnnSequenceTensorDescriptor &output_desc,
   1818       const DeviceMemory<Eigen::half> &output_data,
   1819       const dnn::RnnStateTensorDescriptor &output_h_desc,
   1820       const DeviceMemory<Eigen::half> &output_h_data,
   1821       const dnn::RnnStateTensorDescriptor &output_c_desc,
   1822       const DeviceMemory<Eigen::half> &output_c_data,
   1823       const DeviceMemory<Eigen::half> &output_backprop_data,
   1824       const DeviceMemory<Eigen::half> &output_h_backprop_data,
   1825       const DeviceMemory<Eigen::half> &output_c_backprop_data,
   1826       DeviceMemory<Eigen::half> *input_backprop_data,
   1827       DeviceMemory<Eigen::half> *input_h_backprop_data,
   1828       DeviceMemory<Eigen::half> *input_c_backprop_data,
   1829       DeviceMemory<Eigen::half> *params_backprop_data,
   1830       DeviceMemory<uint8> *reserve_space_data,
   1831       ScratchAllocator *workspace_allocator);
   1832 
   1833   Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
   1834                           const dnn::RnnSequenceTensorDescriptor &input_desc,
   1835                           const DeviceMemory<float> &input_data,
   1836                           const dnn::RnnStateTensorDescriptor &input_h_desc,
   1837                           const DeviceMemory<float> &input_h_data,
   1838                           const dnn::RnnStateTensorDescriptor &input_c_desc,
   1839                           const DeviceMemory<float> &input_c_data,
   1840                           const DeviceMemory<float> &params,
   1841                           const dnn::RnnSequenceTensorDescriptor &output_desc,
   1842                           const DeviceMemory<float> &output_data,
   1843                           const dnn::RnnStateTensorDescriptor &output_h_desc,
   1844                           const DeviceMemory<float> &output_h_data,
   1845                           const dnn::RnnStateTensorDescriptor &output_c_desc,
   1846                           const DeviceMemory<float> &output_c_data,
   1847                           const DeviceMemory<float> &output_backprop_data,
   1848                           const DeviceMemory<float> &output_h_backprop_data,
   1849                           const DeviceMemory<float> &output_c_backprop_data,
   1850                           DeviceMemory<float> *input_backprop_data,
   1851                           DeviceMemory<float> *input_h_backprop_data,
   1852                           DeviceMemory<float> *input_c_backprop_data,
   1853                           DeviceMemory<float> *params_backprop_data,
   1854                           DeviceMemory<uint8> *reserve_space_data,
   1855                           ScratchAllocator *workspace_allocator);
   1856 
   1857   Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
   1858                           const dnn::RnnSequenceTensorDescriptor &input_desc,
   1859                           const DeviceMemory<double> &input_data,
   1860                           const dnn::RnnStateTensorDescriptor &input_h_desc,
   1861                           const DeviceMemory<double> &input_h_data,
   1862                           const dnn::RnnStateTensorDescriptor &input_c_desc,
   1863                           const DeviceMemory<double> &input_c_data,
   1864                           const DeviceMemory<double> &params,
   1865                           const dnn::RnnSequenceTensorDescriptor &output_desc,
   1866                           const DeviceMemory<double> &output_data,
   1867                           const dnn::RnnStateTensorDescriptor &output_h_desc,
   1868                           const DeviceMemory<double> &output_h_data,
   1869                           const dnn::RnnStateTensorDescriptor &output_c_desc,
   1870                           const DeviceMemory<double> &output_c_data,
   1871                           const DeviceMemory<double> &output_backprop_data,
   1872                           const DeviceMemory<double> &output_h_backprop_data,
   1873                           const DeviceMemory<double> &output_c_backprop_data,
   1874                           DeviceMemory<double> *input_backprop_data,
   1875                           DeviceMemory<double> *input_h_backprop_data,
   1876                           DeviceMemory<double> *input_c_backprop_data,
   1877                           DeviceMemory<double> *params_backprop_data,
   1878                           DeviceMemory<uint8> *reserve_space_data,
   1879                           ScratchAllocator *workspace_allocator);
   1880 
  // Enqueue onto the stream an operation that transforms a tensor.
   1882   // See DnnSupport::DoTransformTensor for more details.
   1883   Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
   1884                               dnn::DataType input_type,
   1885                               const DeviceMemoryBase &input_data,
   1886                               const dnn::BatchDescriptor &output_desc,
   1887                               dnn::DataType output_type, float scale,
   1888                               DeviceMemoryBase *output_data);
   1889 
   1890   // The templated version of the above ThenTransformTensor. Useful when the
   1891   // input and output types are statically known.
   1892   template <typename InElemT, typename OutElemT>
   1893   Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
   1894                               const DeviceMemory<InElemT> &input_data,
   1895                               const dnn::BatchDescriptor &output_desc,
   1896                               DeviceMemory<OutElemT> *output_data) {
   1897     return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>(),
   1898                                input_data, output_desc,
   1899                                dnn::ToDataType<OutElemT>(), output_data);
   1900   }
   1901 
   1902   // (Synchronously) block the host code waiting for the operations
   1903   // entrained on the stream (enqueued to this point in program
   1904   // execution) to complete.
   1905   //
   1906   // Returns an OK status if the blocking was successful and the stream is ok().
   1907   // Otherwise returns an error describing why the blocking failed.
   1908   port::Status BlockHostUntilDone() LOCKS_EXCLUDED(mu_);
   1909 
  // Warning! This method interacts with internal threads in
  // sometimes-unpredictable ways and is intended for GPU-Executor-internal
  // use only. Please check with a member of the FASTR team before making use
  // of this method.
   1915   //
   1916   // Entrains onto the stream a function to be executed on the host at some
   1917   // point in the future.
   1918   // Async host callbacks DO NOT block the stream as device functions (or as
   1919   // synchronous host callbacks). No synchronization is possible with
   1920   // asynchronous callbacks; they are strictly fire-and-forget.
   1921   // This method is private due to the potential for undefined behavior with
   1922   // synchronization using OpenCL user events.
   1923   // The ONLY lifetime guarantee in these calls is that the StreamExecutor
   1924   // parameter will still be valid - this Stream may not be!
   1925   // Any callbacks requiring device API calls must use this method.
   1926   Stream &ThenEnqueueOnBackgroundThread(
   1927       std::function<void(StreamExecutor *)> task);
   1928 
   1929   // Returns the (opaque) platform-specific backing object. Ownership is not
   1930   // transferred to the caller.
   1931   internal::StreamInterface *implementation() { return implementation_.get(); }
   1932 
   1933   // Entrains onto the stream a callback to the host (from the device).
   1934   // Host callbacks block/occupy the stream just as device functions
   1935   // (execute one at a time, block later stream operations).
   1936   // Behavior is undefined when synchronizing using OpenCL user events.
   1937   // Behavior is undefined if host callbacks call device routines or insert
   1938   // them into any stream.
   1939   // On certain platforms, ThenDoHostCallback is expected to have significant
   1940   // negative effects on performance.
   1941   Stream &ThenDoHostCallback(std::function<void()> callback);
   1942 
   1943   // Identical to ThenDoHostCallback; only exposed for testing purposes.
   1944   Stream &ThenDoHostCallbackForTest(std::function<void()> callback);
   1945 
   1946   // Returns the StreamExecutor (parent object) associated with this stream.
   1947   StreamExecutor *parent() const {
   1948     CHECK(parent_ != nullptr);
   1949     return parent_;
   1950   }
   1951 
   1952   // Returns the (internal usage) temporary-memory-allocation manager associated
   1953   // with this stream.
   1954   internal::TemporaryMemoryManager *temporary_memory_manager();
   1955 
   1956  private:
   1957   friend class host::HostBlas;  // for parent_.
   1958   friend class host::HostFft;   // for parent_.
   1959   friend class host::HostRng;   // for parent_.
   1960   template <typename... Args>
   1961   friend struct ThenBlasImpl;  // for implementing ThenBlasXXX.
   1962   friend class ocl::CLBlas;    // for parent_.
   1963 
   1964   bool InErrorState() const LOCKS_EXCLUDED(mu_) {
   1965     tf_shared_lock lock{mu_};
   1966     return !ok_;
   1967   }
   1968 
   1969   // Sets the error state if operation_retcode is false.
   1970   // This is a useful shorthand for many stream routines.
   1971   void CheckError(bool operation_retcode) LOCKS_EXCLUDED(mu_) {
   1972     if (operation_retcode) {
   1973       return;
   1974     }
   1975     mutex_lock lock{mu_};
   1976     ok_ = false;
   1977   }
   1978 
   1979   void SetError() { CheckError(false /* = operation_retcode */); }
   1980 
   1981   void SetErrorAndLogNoDnnSupport() {
   1982     SetError();
   1983     LOG(WARNING) << "attempting to perform DNN operation using StreamExecutor "
   1984                     "without DNN support";
   1985   }
   1986 
   1987   // The StreamExecutor that supports the operation of this stream.
   1988   StreamExecutor *parent_;
   1989 
   1990   // The platform-dependent implementation that the StreamExecutor interface
   1991   // delegates to.
   1992   std::unique_ptr<internal::StreamInterface> implementation_;
   1993 
   1994   // mutex that guards the allocation / error state flags.
   1995   // Mutable so that it can be obtained via const reader lock.
   1996   mutable mutex mu_;
   1997 
  // Whether Init() was successfully called to allocate this stream on the
  // underlying platform. It simply flips from false to true with a sanity
  // check. See StreamExecutor::AllocateStream.
   2001   bool allocated_ GUARDED_BY(mu_);
   2002 
   2003   // Whether all operations have entrained successfully to the current program
   2004   // point.
   2005   bool ok_ GUARDED_BY(mu_);
   2006 
  // Sub-streams that are generated from this stream. Each element pairs a
  // pointer to the sub-stream with a boolean value indicating whether that
  // sub-stream is ready to be reused.
   2010   std::vector<std::pair<std::unique_ptr<Stream>, bool>> sub_streams_
   2011       GUARDED_BY(mu_);
   2012 
   2013   // Streams can allocate temporary memories to help with work they enqueue
   2014   // (e.g. for scratch memory spaces). This member tracks those allocations and
   2015   // notes when they can be reclaimed -- reclamation is attempted when
   2016   // BlockHostUntilDone() is called.
   2017   internal::TemporaryMemoryManager temporary_memory_manager_;
   2018 
   2019   // Implementation of ThenConvolveBackwardBias that is shared by all types.
   2020   template <typename T>
   2021   Stream &ThenConvolveBackwardBiasImpl(
   2022       const dnn::BatchDescriptor &input_descriptor,
   2023       const DeviceMemory<T> &input_data,
   2024       const dnn::BatchDescriptor &bias_descriptor,
   2025       DeviceMemory<T> *backward_bias_data);
   2026 
   2027   SE_DISALLOW_COPY_AND_ASSIGN(Stream);
   2028 };
   2029 
   2030 ////////////
   2031 // Inlines
   2032 
   2033 template <typename T>
   2034 inline port::StatusOr<std::unique_ptr<TemporaryDeviceMemory<T>>>
   2035 Stream::AllocateTemporaryArray(uint64 element_count) {
   2036   return temporary_memory_manager_.AllocateArray<T>(element_count);
   2037 }
   2038 
   2039 inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
   2040   return &temporary_memory_manager_;
   2041 }
   2042 
   2043 template <>
   2044 struct Quantization<uint8> {
   2045   static constexpr dnn::QuantizedActivationMode kModeId =
   2046       dnn::QuantizedActivationMode::k8Bit;
   2047 };
   2048 
   2049 template <>
   2050 struct Quantization<uint16> {
   2051   static constexpr dnn::QuantizedActivationMode kModeId =
   2052       dnn::QuantizedActivationMode::k16Bit;
   2053 };
   2054 
   2055 template <>
   2056 struct Quantization<int32> {
   2057   static constexpr dnn::QuantizedActivationMode kModeId =
   2058       dnn::QuantizedActivationMode::k32Bit;
   2059 };
   2060 
   2061 }  // namespace gputools
   2062 }  // namespace perftools
   2063 
   2064 #endif  // TENSORFLOW_STREAM_EXECUTOR_STREAM_H_
   2065