Home | History | Annotate | Download | only in stream_executor
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/stream_executor/stream.h"
     17 
     18 #include "tensorflow/stream_executor/platform/port.h"
     19 
     20 #include "tensorflow/stream_executor/blas.h"
     21 #include "tensorflow/stream_executor/host_buffer.h"
     22 #include "tensorflow/stream_executor/lib/stacktrace.h"
     23 #include "tensorflow/stream_executor/lib/strcat.h"
     24 #include "tensorflow/stream_executor/platform.h"
     25 #include "tensorflow/stream_executor/platform/logging.h"
     26 #include "tensorflow/stream_executor/rng.h"
     27 #include "tensorflow/stream_executor/stream_executor_internal.h"
     28 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
     29 
     30 namespace perftools {
     31 namespace gputools {
     32 
     33 namespace {
     34 // Code to turn parameters to functions on stream into strings that
     35 // will be VLOG'ed. We need overloads, instead of
     36 // e.g. BatchDescriptorToVlogString(), as the code that calls these
     37 // functions does not know what the type of the parameter is.
     38 string ToVlogString(const dnn::BatchDescriptor &descriptor) {
     39   return descriptor.ToShortString();
     40 }
     41 
     42 string ToVlogString(const dnn::FilterDescriptor &descriptor) {
     43   return descriptor.ToShortString();
     44 }
     45 
     46 string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
     47   return descriptor.ToShortString();
     48 }
     49 
     50 string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
     51   return descriptor.ToShortString();
     52 }
     53 
     54 string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
     55   return descriptor.ToShortString();
     56 }
     57 
     58 string ToVlogString(dnn::ActivationMode mode) {
     59   return dnn::ActivationModeString(mode);
     60 }
     61 
     62 string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
     63   return algo_config.ToString();
     64 }
     65 
     66 string ToVlogString(dnn::ElementwiseOperation op) {
     67   return dnn::ElementwiseOperationString(op);
     68 }
     69 
     70 string ToVlogString(dnn::QuantizedActivationMode mode) {
     71   return dnn::QuantizedActivationModeString(mode);
     72 }
     73 
     74 string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
     75 
     76 string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); }
     77 
     78 string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
     79 
     80 string ToVlogString(blas::Side s) { return blas::SideString(s); }
     81 
     82 string ToVlogString(blas::ComputationType ty) {
     83   return blas::ComputationTypeString(ty);
     84 }
     85 
     86 string ToVlogString(const void *ptr) {
     87   if (ptr == nullptr) {
     88     return "null";
     89   }
     90 
     91   // StrCat does not convert pointers to text.
     92   std::ostringstream out;
     93   out << ptr;
     94   return out.str();
     95 }
     96 
     97 string ToVlogString(const HostBuffer &buffer) { return buffer.AsString(); }
     98 
     99 template <class T>
    100 string ToVlogString(const std::complex<T> &c) {
    101   // StrCat does not convert std::complex to text.
    102   std::ostringstream out;
    103   out << c;
    104   return out.str();
    105 }
    106 
    107 template <class T>
    108 string ToVlogString(const std::function<T> &f) {
    109   return f == nullptr ? "null" : "<non-null function>";
    110 }
    111 
    112 string ToVlogString(const DeviceMemoryBase &memory) {
    113   return ToVlogString(memory.opaque());
    114 }
    115 
    116 string ToVlogString(const DeviceMemoryBase *memory) {
    117   return ToVlogString(*memory);
    118 }
    119 
    120 string ToVlogString(const Eigen::half &h) { return port::StrCat(h); }
    121 
    122 string ToVlogString(int i) { return port::StrCat(i); }
    123 
    124 string ToVlogString(uint32 i) { return port::StrCat(i); }
    125 
    126 string ToVlogString(uint64 i) { return port::StrCat(i); }
    127 
    128 string ToVlogString(int64 i) { return port::StrCat(i); }
    129 
    130 string ToVlogString(float f) { return port::StrCat(f); }
    131 
    132 string ToVlogString(double d) { return port::StrCat(d); }
    133 
    134 template <class T>
    135 string ToVlogString(port::ArraySlice<T> elements) {
    136   string str = port::StrCat(
    137       ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
    138       elements.size(), "]{");
    139   const char *separator = "";
    140   size_t max_to_show = std::numeric_limits<size_t>::max();
    141   if (!VLOG_IS_ON(2)) {
    142     max_to_show = 5;
    143   } else if (!VLOG_IS_ON(3)) {
    144     max_to_show = 20;
    145   } else if (!VLOG_IS_ON(11)) {
    146     max_to_show = 1000;
    147   }
    148   for (size_t i = 0; i < elements.size(); ++i) {
    149     if (i == max_to_show) {
    150       str += ", ...";
    151       break;
    152     }
    153     port::StrAppend(&str, separator, ToVlogString(elements[i]));
    154     separator = ", ";
    155   }
    156   str += "}";
    157   return str;
    158 }
    159 
    160 template <class T>
    161 string ToVlogString(port::MutableArraySlice<T> elements) {
    162   return ToVlogString(port::ArraySlice<T>(elements));
    163 }
    164 
    165 string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
    166   switch (depth_to_space_layout) {
    167     case dnn::DepthToSpaceLayout::DepthHeightWidth:
    168       return "DepthToSpaceLayout::DepthHeightWidth";
    169   }
    170   return "unknown DepthToSpaceLayout";
    171 }
    172 
    173 string ToVlogString(dnn::DataType data_type) {
    174   switch (data_type) {
    175     case dnn::DataType::kFloat:
    176       return "dnn::DataType::kFloat";
    177     case dnn::DataType::kDouble:
    178       return "dnn::DataType::kDouble";
    179     case dnn::DataType::kHalf:
    180       return "dnn::DataType::kHalf";
    181     case dnn::DataType::kInt8:
    182       return "dnn::DataType::kInt8";
    183   }
    184 }
    185 
    186 // Used together with PARAM to VLOG calls made to the stream. Intended
    187 // to be used like this:
    188 //
    189 //   VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
    190 //
    191 // where a and b are the parameters to MyFunction.
    192 //
    193 // See VLOG_CALL for a short-hand for this. This way of doing it saves
    194 // a tremendous amount of boilerplate code given how many functions
    195 // there are on Stream and how many parameters they each have.
    196 string CallStr(const char *function_name, Stream *stream,
    197                std::vector<std::pair<const char *, string>> params) {
    198   // Do not call this function unless VLOG is on since just
    199   // constructing all the strings in params is expensive.
    200   CHECK(VLOG_IS_ON(1));
    201 
    202   string str = port::StrCat("Called Stream::", function_name, "(");
    203   const char *separator = "";
    204   for (const auto &param : params) {
    205     port::StrAppend(&str, separator, param.first, "=", param.second);
    206     separator = ", ";
    207   }
    208   port::StrAppend(&str, ") stream=", ToVlogString(stream));
    209   if (VLOG_IS_ON(10)) {
    210     port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
    211   }
    212   return str;
    213 }
    214 
    215 // Use this macro to avoid having to type every parameter twice to log
    216 // it with VLOG and CallStr.
    217 #define PARAM(parameter) \
    218   { #parameter, ToVlogString(parameter) }
    219 
    220 // Use this macro to avoid having to type out the name of each
    221 // function and to save some boilerplate. Intended to be used like this:
    222 //
    223 //   VLOG_CALL(PARAM(a), PARAM(b))
    224 //
    225 // This saves a tremendous amount of boilerplate compared to the alternative:
    226 //
    227 //   VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
    228 //           << ", b=" << ToVlogString(b);
    229 //
    230 // Note here that most of the parameter names are not short and that
    231 // most of the functions take many more than 2 parameters.
    232 #define VLOG_CALL(...) VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
    233 
    234 }  // namespace
    235 
    236 Stream::Stream(StreamExecutor *parent)
    237     : parent_(parent),
    238       implementation_(parent->implementation()->GetStreamImplementation()),
    239       allocated_(false),
    240       ok_(false),
    241       temporary_memory_manager_(this) {
    242   VLOG_CALL(PARAM(parent));
    243 }
    244 
    245 Stream::Stream(StreamExecutor *parent,
    246                internal::StreamInterface *implementation)
    247     : parent_(parent),
    248       implementation_(implementation),
    249       allocated_(false),
    250       ok_(false),
    251       temporary_memory_manager_(this) {
    252   VLOG_CALL(PARAM(parent), PARAM(implementation));
    253 }
    254 
    255 Stream::~Stream() {
    256   VLOG_CALL();
    257 
    258   temporary_memory_manager_.ForceDeallocateAll();
    259 
    260   if (allocated_) {
    261     parent_->DeallocateStream(this);
    262   }
    263 }
    264 
    265 Stream &Stream::Init() {
    266   VLOG_CALL();
    267 
    268   mutex_lock lock{mu_};
    269   CHECK_EQ(false, allocated_)
    270       << "stream appears to already have been initialized";
    271   CHECK(!ok_) << "stream should be in !ok() state pre-initialization";
    272 
    273   if (parent_->AllocateStream(this)) {
    274     // Successful initialization!
    275     allocated_ = true;
    276     ok_ = true;
    277   } else {
    278     LOG(ERROR) << "failed to allocate stream during initialization";
    279   }
    280 
    281   return *this;
    282 }
    283 
    284 Stream &Stream::InitTimer(Timer *timer) {
    285   VLOG_CALL(PARAM(timer));
    286 
    287   if (ok()) {
    288     CheckError(parent_->AllocateTimer(timer));
    289   } else {
    290     LOG(INFO) << "did not allocate timer: " << timer;
    291   }
    292   return *this;
    293 }
    294 
    295 Stream &Stream::InitWithTimer(Timer *timer) {
    296   VLOG_CALL(PARAM(timer));
    297 
    298   return Init().InitTimer(timer);
    299 }
    300 
    301 Stream &Stream::ThenRecordEvent(Event *event) {
    302   VLOG_CALL(PARAM(event));
    303 
    304   port::Status status = parent_->RecordEvent(this, event);
    305   if (!status.ok()) {
    306     LOG(ERROR) << "Error recording event in stream: " << status.error_message()
    307                << "; not marking stream as bad, as the Event object may be "
    308                << "at fault. Monitor for further errors.";
    309   }
    310 
    311   return *this;
    312 }
    313 
    314 Stream &Stream::ThenBatchNormalizationForward(
    315     const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
    316     const DeviceMemory<float> &offset,
    317     const DeviceMemory<float> &estimated_mean,
    318     const DeviceMemory<float> &estimated_variance,
    319     const dnn::BatchDescriptor &x_desc,
    320     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    321     DeviceMemory<float> *y, DeviceMemory<float> *batch_mean,
    322     DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
    323     DeviceMemory<float> *saved_inv_var, bool is_training,
    324     std::function<const DeviceMemory<float> &()> var_to_inv_var,
    325     std::function<void()> inv_var_to_var) {
    326   VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
    327             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
    328   if (ok()) {
    329     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    330       CheckError(dnn->DoBatchNormalizationForward(
    331           this, x, scale, offset, estimated_mean, estimated_variance, x_desc,
    332           scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean,
    333           saved_inv_var, is_training, std::move(var_to_inv_var),
    334           std::move(inv_var_to_var)));
    335     } else {
    336       SetErrorAndLogNoDnnSupport();
    337     }
    338   }
    339   return *this;
    340 }
    341 
    342 Stream &Stream::ThenBatchNormalizationBackward(
    343     const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
    344     const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
    345     const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
    346     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    347     DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
    348     DeviceMemory<float> *offset_backprop) {
    349   VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
    350             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
    351             PARAM(scale_backprop), PARAM(offset_backprop));
    352   if (ok()) {
    353     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    354       CheckError(dnn->DoBatchNormalizationBackward(
    355           this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
    356           epsilon, x_backprop, scale_backprop, offset_backprop));
    357     } else {
    358       SetErrorAndLogNoDnnSupport();
    359     }
    360   }
    361   return *this;
    362 }
    363 
    364 Stream &Stream::ThenBatchNormalizationForward(
    365     const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    366     const DeviceMemory<float> &offset,
    367     const DeviceMemory<float> &estimated_mean,
    368     const DeviceMemory<float> &estimated_variance,
    369     const dnn::BatchDescriptor &x_desc,
    370     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    371     DeviceMemory<Eigen::half> *y, DeviceMemory<float> *batch_mean,
    372     DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
    373     DeviceMemory<float> *saved_inv_var, bool is_training,
    374     std::function<const DeviceMemory<float> &()> var_to_inv_var,
    375     std::function<void()> inv_var_to_var) {
    376   VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
    377             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
    378   if (ok()) {
    379     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    380       CheckError(dnn->DoBatchNormalizationForward(
    381           this, x, scale, offset, estimated_mean, estimated_variance, x_desc,
    382           scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean,
    383           saved_inv_var, is_training, std::move(var_to_inv_var),
    384           std::move(inv_var_to_var)));
    385     } else {
    386       SetErrorAndLogNoDnnSupport();
    387     }
    388   }
    389   return *this;
    390 }
    391 
    392 Stream &Stream::ThenBatchNormalizationBackward(
    393     const DeviceMemory<Eigen::half> &y_backprop,
    394     const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    395     const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
    396     const dnn::BatchDescriptor &x_desc,
    397     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    398     DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
    399     DeviceMemory<float> *offset_backprop) {
    400   VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
    401             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
    402             PARAM(scale_backprop), PARAM(offset_backprop));
    403   if (ok()) {
    404     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    405       CheckError(dnn->DoBatchNormalizationBackward(
    406           this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
    407           epsilon, x_backprop, scale_backprop, offset_backprop));
    408     } else {
    409       SetErrorAndLogNoDnnSupport();
    410     }
    411   }
    412   return *this;
    413 }
    414 
    415 Stream &Stream::ThenFusedConvolveWithScratch(
    416     const dnn::BatchDescriptor &conv_input_descriptor,
    417     const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
    418     const dnn::FilterDescriptor &filter_descriptor,
    419     const DeviceMemory<int8> &filter_data,
    420     const dnn::ConvolutionDescriptor &convolution_descriptor,
    421     const DeviceMemory<int8> &side_input_data, float side_input_scale,
    422     const dnn::BatchDescriptor &bias_descriptor,
    423     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    424     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
    425     ScratchAllocator *scratch_allocator) {
    426   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
    427             PARAM(conv_input_scale), PARAM(filter_descriptor),
    428             PARAM(filter_data), PARAM(convolution_descriptor),
    429             PARAM(side_input_data), PARAM(side_input_scale),
    430             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
    431             PARAM(output_descriptor), PARAM(output));
    432 
    433   if (ok()) {
    434     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    435       CheckError(dnn->DoFusedConvolve(
    436           this, conv_input_descriptor, conv_input_data, conv_input_scale,
    437           filter_descriptor, filter_data, convolution_descriptor,
    438           side_input_data, side_input_scale, bias_descriptor, biases,
    439           activation_mode, output_descriptor, output, scratch_allocator,
    440           dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
    441     } else {
    442       SetErrorAndLogNoDnnSupport();
    443     }
    444   }
    445   return *this;
    446 }
    447 
    448 Stream &Stream::ThenFusedConvolveWithScratch(
    449     const dnn::BatchDescriptor &conv_input_descriptor,
    450     const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
    451     const dnn::FilterDescriptor &filter_descriptor,
    452     const DeviceMemory<Eigen::half> &filter_data,
    453     const dnn::ConvolutionDescriptor &convolution_descriptor,
    454     const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
    455     const dnn::BatchDescriptor &bias_descriptor,
    456     const DeviceMemory<Eigen::half> &biases,
    457     dnn::ActivationMode activation_mode,
    458     const dnn::BatchDescriptor &output_descriptor,
    459     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) {
    460   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
    461             PARAM(conv_input_scale), PARAM(filter_descriptor),
    462             PARAM(filter_data), PARAM(convolution_descriptor),
    463             PARAM(side_input_data), PARAM(side_input_scale),
    464             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
    465             PARAM(output_descriptor), PARAM(output));
    466 
    467   if (ok()) {
    468     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    469       CheckError(dnn->DoFusedConvolve(
    470           this, conv_input_descriptor, conv_input_data, conv_input_scale,
    471           filter_descriptor, filter_data, convolution_descriptor,
    472           side_input_data, side_input_scale, bias_descriptor, biases,
    473           activation_mode, output_descriptor, output, scratch_allocator,
    474           dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
    475     } else {
    476       SetErrorAndLogNoDnnSupport();
    477     }
    478   }
    479   return *this;
    480 }
    481 
    482 Stream &Stream::ThenFusedConvolveWithScratch(
    483     const dnn::BatchDescriptor &conv_input_descriptor,
    484     const DeviceMemory<float> &conv_input_data, float conv_input_scale,
    485     const dnn::FilterDescriptor &filter_descriptor,
    486     const DeviceMemory<float> &filter_data,
    487     const dnn::ConvolutionDescriptor &convolution_descriptor,
    488     const DeviceMemory<float> &side_input_data, float side_input_scale,
    489     const dnn::BatchDescriptor &bias_descriptor,
    490     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    491     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
    492     ScratchAllocator *scratch_allocator) {
    493   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
    494             PARAM(conv_input_scale), PARAM(filter_descriptor),
    495             PARAM(filter_data), PARAM(convolution_descriptor),
    496             PARAM(side_input_data), PARAM(side_input_scale),
    497             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
    498             PARAM(output_descriptor), PARAM(output));
    499 
    500   if (ok()) {
    501     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    502       CheckError(dnn->DoFusedConvolve(
    503           this, conv_input_descriptor, conv_input_data, conv_input_scale,
    504           filter_descriptor, filter_data, convolution_descriptor,
    505           side_input_data, side_input_scale, bias_descriptor, biases,
    506           activation_mode, output_descriptor, output, scratch_allocator,
    507           dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
    508     } else {
    509       SetErrorAndLogNoDnnSupport();
    510     }
    511   }
    512   return *this;
    513 }
    514 
    515 Stream &Stream::ThenConvolveWithScratch(
    516     const dnn::BatchDescriptor &input_descriptor,
    517     const DeviceMemory<Eigen::half> &input_data,
    518     const dnn::FilterDescriptor &filter_descriptor,
    519     const DeviceMemory<Eigen::half> &filter_data,
    520     const dnn::ConvolutionDescriptor &convolution_descriptor,
    521     const dnn::BatchDescriptor &output_descriptor,
    522     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) {
    523   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
    524             PARAM(filter_descriptor), PARAM(filter_data),
    525             PARAM(convolution_descriptor), PARAM(output_descriptor),
    526             PARAM(output));
    527 
    528   if (ok()) {
    529     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    530       CheckError(dnn->DoConvolve(
    531           this, input_descriptor, input_data, filter_descriptor, filter_data,
    532           convolution_descriptor, output_descriptor, output, scratch_allocator,
    533           dnn::AlgorithmConfig(),
    534           /*output_profile_result=*/nullptr));
    535     } else {
    536       SetErrorAndLogNoDnnSupport();
    537     }
    538   }
    539   return *this;
    540 }
    541 
    542 Stream &Stream::ThenConvolveWithScratch(
    543     const dnn::BatchDescriptor &input_descriptor,
    544     const DeviceMemory<float> &input_data,
    545     const dnn::FilterDescriptor &filter_descriptor,
    546     const DeviceMemory<float> &filter_data,
    547     const dnn::ConvolutionDescriptor &convolution_descriptor,
    548     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
    549     ScratchAllocator *scratch_allocator) {
    550   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
    551             PARAM(filter_descriptor), PARAM(filter_data),
    552             PARAM(convolution_descriptor), PARAM(output_descriptor),
    553             PARAM(output));
    554 
    555   if (ok()) {
    556     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    557       CheckError(dnn->DoConvolve(
    558           this, input_descriptor, input_data, filter_descriptor, filter_data,
    559           convolution_descriptor, output_descriptor, output, scratch_allocator,
    560           dnn::AlgorithmConfig(),
    561           /*output_profile_result=*/nullptr));
    562     } else {
    563       SetErrorAndLogNoDnnSupport();
    564     }
    565   }
    566   return *this;
    567 }
    568 
    569 Stream &Stream::ThenFusedConvolveWithAlgorithm(
    570     const dnn::BatchDescriptor &conv_input_descriptor,
    571     const DeviceMemory<float> &conv_input_data, float conv_input_scale,
    572     const dnn::FilterDescriptor &filter_descriptor,
    573     const DeviceMemory<float> &filter_data,
    574     const dnn::ConvolutionDescriptor &convolution_descriptor,
    575     const DeviceMemory<float> &side_input_data, float side_input_scale,
    576     const dnn::BatchDescriptor &bias_descriptor,
    577     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    578     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
    579     ScratchAllocator *scratch_allocator,
    580     const dnn::AlgorithmConfig &algorithm_config,
    581     dnn::ProfileResult *output_profile_result) {
    582   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
    583             PARAM(conv_input_scale), PARAM(filter_descriptor),
    584             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
    585             PARAM(side_input_data), PARAM(side_input_scale),
    586             PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
    587             PARAM(algorithm_config));
    588 
    589   if (ok()) {
    590     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    591       auto status = dnn->DoFusedConvolve(
    592           this, conv_input_descriptor, conv_input_data, conv_input_scale,
    593           filter_descriptor, filter_data, convolution_descriptor,
    594           side_input_data, side_input_scale, bias_descriptor, biases,
    595           activation_mode, output_descriptor, output, scratch_allocator,
    596           algorithm_config, output_profile_result);
    597       if (!status && !output_profile_result) {
    598         SetError();
    599       }
    600     } else {
    601       SetErrorAndLogNoDnnSupport();
    602     }
    603   }
    604   return *this;
    605 }
    606 
    607 Stream &Stream::ThenFusedConvolveWithAlgorithm(
    608     const dnn::BatchDescriptor &conv_input_descriptor,
    609     const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
    610     const dnn::FilterDescriptor &filter_descriptor,
    611     const DeviceMemory<Eigen::half> &filter_data,
    612     const dnn::ConvolutionDescriptor &convolution_descriptor,
    613     const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
    614     const dnn::BatchDescriptor &bias_descriptor,
    615     const DeviceMemory<Eigen::half> &biases,
    616     dnn::ActivationMode activation_mode,
    617     const dnn::BatchDescriptor &output_descriptor,
    618     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
    619     const dnn::AlgorithmConfig &algorithm_config,
    620     dnn::ProfileResult *output_profile_result) {
    621   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
    622             PARAM(conv_input_scale), PARAM(filter_descriptor),
    623             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
    624             PARAM(side_input_data), PARAM(side_input_scale),
    625             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
    626             PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
    627 
    628   if (ok()) {
    629     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    630       auto status = dnn->DoFusedConvolve(
    631           this, conv_input_descriptor, conv_input_data, conv_input_scale,
    632           filter_descriptor, filter_data, convolution_descriptor,
    633           side_input_data, side_input_scale, bias_descriptor, biases,
    634           activation_mode, output_descriptor, output, scratch_allocator,
    635           algorithm_config, output_profile_result);
    636       if (!status && !output_profile_result) {
    637         SetError();
    638       }
    639     } else {
    640       SetErrorAndLogNoDnnSupport();
    641     }
    642   }
    643   return *this;
    644 }
    645 
    646 Stream &Stream::ThenFusedConvolveWithAlgorithm(
    647     const dnn::BatchDescriptor &conv_input_descriptor,
    648     const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
    649     const dnn::FilterDescriptor &filter_descriptor,
    650     const DeviceMemory<int8> &filter_data,
    651     const dnn::ConvolutionDescriptor &convolution_descriptor,
    652     const DeviceMemory<int8> &side_input_data, float side_input_scale,
    653     const dnn::BatchDescriptor &bias_descriptor,
    654     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    655     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
    656     ScratchAllocator *scratch_allocator,
    657     const dnn::AlgorithmConfig &algorithm_config,
    658     dnn::ProfileResult *output_profile_result) {
    659   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
    660             PARAM(conv_input_scale), PARAM(filter_descriptor),
    661             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
    662             PARAM(side_input_data), PARAM(side_input_scale),
    663             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
    664             PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
    665 
    666   if (ok()) {
    667     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    668       auto status = dnn->DoFusedConvolve(
    669           this, conv_input_descriptor, conv_input_data, conv_input_scale,
    670           filter_descriptor, filter_data, convolution_descriptor,
    671           side_input_data, side_input_scale, bias_descriptor, biases,
    672           activation_mode, output_descriptor, output, scratch_allocator,
    673           algorithm_config, output_profile_result);
    674       if (!status && !output_profile_result) {
    675         SetError();
    676       }
    677     } else {
    678       SetErrorAndLogNoDnnSupport();
    679     }
    680   }
    681   return *this;
    682 }
    683 
    684 Stream &Stream::ThenConvolveWithAlgorithm(
    685     const dnn::BatchDescriptor &input_descriptor,
    686     const DeviceMemory<float> &input_data,
    687     const dnn::FilterDescriptor &filter_descriptor,
    688     const DeviceMemory<float> &filter_data,
    689     const dnn::ConvolutionDescriptor &convolution_descriptor,
    690     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
    691     ScratchAllocator *scratch_allocator,
    692     const dnn::AlgorithmConfig &algorithm_config,
    693     dnn::ProfileResult *output_profile_result) {
    694   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
    695             PARAM(filter_descriptor), PARAM(filter_data),
    696             PARAM(convolution_descriptor), PARAM(output_descriptor),
    697             PARAM(output), PARAM(algorithm_config));
    698 
    699   if (ok()) {
    700     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    701       auto status = dnn->DoConvolve(
    702           this, input_descriptor, input_data, filter_descriptor, filter_data,
    703           convolution_descriptor, output_descriptor, output, scratch_allocator,
    704           algorithm_config, output_profile_result);
    705       if (!status && !output_profile_result) {
    706         SetError();
    707       }
    708     } else {
    709       SetErrorAndLogNoDnnSupport();
    710     }
    711   }
    712   return *this;
    713 }
    714 
    715 Stream &Stream::ThenConvolveWithAlgorithm(
    716     const dnn::BatchDescriptor &input_descriptor,
    717     const DeviceMemory<Eigen::half> &input_data,
    718     const dnn::FilterDescriptor &filter_descriptor,
    719     const DeviceMemory<Eigen::half> &filter_data,
    720     const dnn::ConvolutionDescriptor &convolution_descriptor,
    721     const dnn::BatchDescriptor &output_descriptor,
    722     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
    723     const dnn::AlgorithmConfig &algorithm_config,
    724     dnn::ProfileResult *output_profile_result) {
    725   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
    726             PARAM(filter_descriptor), PARAM(filter_data),
    727             PARAM(convolution_descriptor), PARAM(output_descriptor),
    728             PARAM(output), PARAM(algorithm_config));
    729 
    730   if (ok()) {
    731     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    732       auto status = dnn->DoConvolve(
    733           this, input_descriptor, input_data, filter_descriptor, filter_data,
    734           convolution_descriptor, output_descriptor, output, scratch_allocator,
    735           algorithm_config, output_profile_result);
    736       if (!status && !output_profile_result) {
    737         SetError();
    738       }
    739     } else {
    740       SetErrorAndLogNoDnnSupport();
    741     }
    742   }
    743   return *this;
    744 }
    745 
    746 Stream &Stream::ThenFusedConvolve(
    747     const dnn::BatchDescriptor &conv_input_descriptor,
    748     const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
    749     const dnn::FilterDescriptor &filter_descriptor,
    750     const DeviceMemory<int8> &filter_data,
    751     const dnn::ConvolutionDescriptor &convolution_descriptor,
    752     const DeviceMemory<int8> &side_input_data, float side_input_scale,
    753     const dnn::BatchDescriptor &bias_descriptor,
    754     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    755     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output) {
    756   return ThenFusedConvolveWithScratch(
    757       conv_input_descriptor, conv_input_data, conv_input_scale,
    758       filter_descriptor, filter_data, convolution_descriptor, side_input_data,
    759       side_input_scale, bias_descriptor, biases, activation_mode,
    760       output_descriptor, output,
    761       /*scratch_allocator=*/nullptr);
    762 }
    763 
    764 Stream &Stream::ThenConvolve(
    765     const dnn::BatchDescriptor &input_descriptor,
    766     const DeviceMemory<float> &input_data,
    767     const dnn::FilterDescriptor &filter_descriptor,
    768     const DeviceMemory<float> &filter_data,
    769     const dnn::ConvolutionDescriptor &convolution_descriptor,
    770     const dnn::BatchDescriptor &output_descriptor,
    771     DeviceMemory<float> *output) {
    772   return ThenConvolveWithScratch(input_descriptor, input_data,
    773                                  filter_descriptor, filter_data,
    774                                  convolution_descriptor, output_descriptor,
    775                                  output, /*scratch_allocator=*/nullptr);
    776 }
    777 
    778 Stream &Stream::ThenConvolveQuantized(
    779     const dnn::BatchDescriptor &input_descriptor,
    780     const DeviceMemory<float> &input_data,
    781     const dnn::FilterDescriptor &filter_descriptor,
    782     const DeviceMemory<int8> &filter_coefficients,
    783     const DeviceMemory<float> &coefficient_scales,
    784     const dnn::ConvolutionDescriptor &convolution_descriptor,
    785     const dnn::BatchDescriptor &output_descriptor,
    786     DeviceMemory<float> *output) {
    787   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
    788             PARAM(filter_descriptor), PARAM(filter_coefficients),
    789             PARAM(coefficient_scales), PARAM(convolution_descriptor),
    790             PARAM(output_descriptor), PARAM(output));
    791 
    792   if (ok()) {
    793     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    794       CheckError(dnn->DoConvolveQuantized(
    795           this, input_descriptor, input_data, filter_descriptor,
    796           filter_coefficients, coefficient_scales, convolution_descriptor,
    797           output_descriptor, output));
    798     } else {
    799       SetError();
    800       LOG(WARNING)
    801           << "attempting to perform DNN operation using StreamExecutor "
    802              "without DNN support";
    803     }
    804   }
    805   return *this;
    806 }
    807 
    808 Stream &Stream::ThenConvolveQuantized(
    809     const dnn::BatchDescriptor &input_descriptor,
    810     const DeviceMemory<float> &input_data,
    811     const dnn::FilterDescriptor &filter_descriptor,
    812     const DeviceMemory<int16> &filter_coefficients,
    813     const DeviceMemory<float> &coefficient_scales,
    814     const dnn::ConvolutionDescriptor &convolution_descriptor,
    815     const dnn::BatchDescriptor &output_descriptor,
    816     DeviceMemory<float> *output) {
    817   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
    818             PARAM(filter_descriptor), PARAM(filter_coefficients),
    819             PARAM(coefficient_scales), PARAM(convolution_descriptor),
    820             PARAM(output_descriptor), PARAM(output));
    821 
    822   if (ok()) {
    823     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    824       CheckError(dnn->DoConvolveQuantized(
    825           this, input_descriptor, input_data, filter_descriptor,
    826           filter_coefficients, coefficient_scales, convolution_descriptor,
    827           output_descriptor, output));
    828     } else {
    829       SetError();
    830       LOG(WARNING)
    831           << "attempting to perform DNN operation using StreamExecutor "
    832              "without DNN support";
    833     }
    834   }
    835   return *this;
    836 }
    837 
    838 Stream &Stream::ThenSeparableConvolve(
    839     const dnn::BatchDescriptor &batch_descriptor,
    840     const DeviceMemory<float> &input_data,
    841     const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
    842     const DeviceMemory<float> &first_weights,
    843     const DeviceMemory<float> &second_weights,
    844     const dnn::ConvolutionDescriptor &convolution_descriptor,
    845     const dnn::BatchDescriptor &output_descriptor,
    846     DeviceMemory<float> *output) {
    847   VLOG_CALL(
    848       PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
    849       PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
    850       PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));
    851 
    852   if (ok()) {
    853     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    854       CheckError(dnn->DoSeparableConvolve(
    855           this, batch_descriptor, input_data, filter_descriptor,
    856           depth_multiplier, first_weights, second_weights,
    857           convolution_descriptor, output_descriptor, output));
    858     } else {
    859       SetErrorAndLogNoDnnSupport();
    860     }
    861   }
    862   return *this;
    863 }
    864 
    865 Stream &Stream::ThenConvolveBackwardDataWithScratch(
    866     const dnn::FilterDescriptor &filter_descriptor,
    867     const DeviceMemory<float> &filter_data,
    868     const dnn::BatchDescriptor &output_descriptor,
    869     DeviceMemory<float> backward_output_data,
    870     const dnn::ConvolutionDescriptor &convolution_descriptor,
    871     const dnn::BatchDescriptor &input_descriptor,
    872     DeviceMemory<float> *backward_input_data,
    873     ScratchAllocator *scratch_allocator) {
    874   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
    875             PARAM(output_descriptor), PARAM(backward_output_data),
    876             PARAM(convolution_descriptor), PARAM(input_descriptor),
    877             PARAM(backward_input_data));
    878 
    879   if (ok()) {
    880     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    881       CheckError(dnn->DoConvolveBackwardData(
    882           this, filter_descriptor, filter_data, output_descriptor,
    883           backward_output_data, convolution_descriptor, input_descriptor,
    884           backward_input_data, scratch_allocator, dnn::AlgorithmConfig(),
    885           /*output_profile_result=*/nullptr));
    886     } else {
    887       SetErrorAndLogNoDnnSupport();
    888     }
    889   }
    890   return *this;
    891 }
    892 
    893 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
    894     const dnn::FilterDescriptor &filter_descriptor,
    895     const DeviceMemory<float> &filter_data,
    896     const dnn::BatchDescriptor &output_descriptor,
    897     DeviceMemory<float> backward_output_data,
    898     const dnn::ConvolutionDescriptor &convolution_descriptor,
    899     const dnn::BatchDescriptor &input_descriptor,
    900     DeviceMemory<float> *backward_input_data,
    901     ScratchAllocator *scratch_allocator,
    902     const dnn::AlgorithmConfig &algorithm_config,
    903     dnn::ProfileResult *output_profile_result) {
    904   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
    905             PARAM(output_descriptor), PARAM(backward_output_data),
    906             PARAM(convolution_descriptor), PARAM(input_descriptor),
    907             PARAM(backward_input_data));
    908 
    909   if (ok()) {
    910     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    911       auto status = dnn->DoConvolveBackwardData(
    912           this, filter_descriptor, filter_data, output_descriptor,
    913           backward_output_data, convolution_descriptor, input_descriptor,
    914           backward_input_data, scratch_allocator, algorithm_config,
    915           output_profile_result);
    916       if (!status && !output_profile_result) {
    917         SetError();
    918       }
    919     } else {
    920       SetErrorAndLogNoDnnSupport();
    921     }
    922   }
    923   return *this;
    924 }
    925 
    926 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
    927     const dnn::FilterDescriptor &filter_descriptor,
    928     const DeviceMemory<Eigen::half> &filter_data,
    929     const dnn::BatchDescriptor &output_descriptor,
    930     DeviceMemory<Eigen::half> backward_output_data,
    931     const dnn::ConvolutionDescriptor &convolution_descriptor,
    932     const dnn::BatchDescriptor &input_descriptor,
    933     DeviceMemory<Eigen::half> *backward_input_data,
    934     ScratchAllocator *scratch_allocator,
    935     const dnn::AlgorithmConfig &algorithm_config,
    936     dnn::ProfileResult *output_profile_result) {
    937   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
    938             PARAM(output_descriptor), PARAM(backward_output_data),
    939             PARAM(convolution_descriptor), PARAM(input_descriptor),
    940             PARAM(backward_input_data));
    941 
    942   if (ok()) {
    943     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    944       auto status = dnn->DoConvolveBackwardData(
    945           this, filter_descriptor, filter_data, output_descriptor,
    946           backward_output_data, convolution_descriptor, input_descriptor,
    947           backward_input_data, scratch_allocator, algorithm_config,
    948           output_profile_result);
    949       if (!status && !output_profile_result) {
    950         SetError();
    951       }
    952     } else {
    953       SetErrorAndLogNoDnnSupport();
    954     }
    955   }
    956   return *this;
    957 }
    958 
    959 Stream &Stream::ThenConvolveBackwardDataWithScratch(
    960     const dnn::FilterDescriptor &filter_descriptor,
    961     const DeviceMemory<Eigen::half> &filter_data,
    962     const dnn::BatchDescriptor &output_descriptor,
    963     DeviceMemory<Eigen::half> backward_output_data,
    964     const dnn::ConvolutionDescriptor &convolution_descriptor,
    965     const dnn::BatchDescriptor &input_descriptor,
    966     DeviceMemory<Eigen::half> *backward_input_data,
    967     ScratchAllocator *scratch_allocator) {
    968   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
    969             PARAM(output_descriptor), PARAM(backward_output_data),
    970             PARAM(convolution_descriptor), PARAM(input_descriptor),
    971             PARAM(backward_input_data));
    972 
    973   if (ok()) {
    974     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    975       CheckError(dnn->DoConvolveBackwardData(
    976           this, filter_descriptor, filter_data, output_descriptor,
    977           backward_output_data, convolution_descriptor, input_descriptor,
    978           backward_input_data, scratch_allocator, dnn::AlgorithmConfig(),
    979           /*output_profile_result=*/nullptr));
    980     } else {
    981       SetErrorAndLogNoDnnSupport();
    982     }
    983   }
    984   return *this;
    985 }
    986 
    987 Stream &Stream::ThenConvolveBackwardData(
    988     const dnn::FilterDescriptor &filter_descriptor,
    989     const DeviceMemory<float> &filter_data,
    990     const dnn::BatchDescriptor &output_descriptor,
    991     DeviceMemory<float> backward_output_data,
    992     const dnn::ConvolutionDescriptor &convolution_descriptor,
    993     const dnn::BatchDescriptor &input_descriptor,
    994     DeviceMemory<float> *backward_input_data) {
    995   return ThenConvolveBackwardDataWithScratch(
    996       filter_descriptor, filter_data, output_descriptor, backward_output_data,
    997       convolution_descriptor, input_descriptor, backward_input_data,
    998       /*scratch_allocator=*/nullptr);
    999 }
   1000 
   1001 Stream &Stream::ThenConvolveBackwardFilterWithScratch(
   1002     const dnn::BatchDescriptor &input_descriptor,
   1003     const DeviceMemory<float> &input_data,
   1004     const dnn::BatchDescriptor &output_descriptor,
   1005     DeviceMemory<float> backward_output_data,
   1006     const dnn::ConvolutionDescriptor &convolution_descriptor,
   1007     const dnn::FilterDescriptor &filter_descriptor,
   1008     DeviceMemory<float> *backward_filter_data,
   1009     ScratchAllocator *scratch_allocator) {
   1010   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
   1011             PARAM(output_descriptor), PARAM(backward_output_data),
   1012             PARAM(convolution_descriptor), PARAM(filter_descriptor),
   1013             PARAM(backward_filter_data));
   1014 
   1015   if (ok()) {
   1016     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1017       CheckError(dnn->DoConvolveBackwardFilter(
   1018           this, input_descriptor, input_data, output_descriptor,
   1019           backward_output_data, convolution_descriptor, filter_descriptor,
   1020           backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(),
   1021           /*output_profile_result=*/nullptr));
   1022     } else {
   1023       SetErrorAndLogNoDnnSupport();
   1024     }
   1025   }
   1026   return *this;
   1027 }
   1028 
   1029 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
   1030     const dnn::BatchDescriptor &input_descriptor,
   1031     const DeviceMemory<float> &input_data,
   1032     const dnn::BatchDescriptor &output_descriptor,
   1033     DeviceMemory<float> backward_output_data,
   1034     const dnn::ConvolutionDescriptor &convolution_descriptor,
   1035     const dnn::FilterDescriptor &filter_descriptor,
   1036     DeviceMemory<float> *backward_filter_data,
   1037     ScratchAllocator *scratch_allocator,
   1038     const dnn::AlgorithmConfig &algorithm_config,
   1039     dnn::ProfileResult *output_profile_result) {
   1040   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
   1041             PARAM(output_descriptor), PARAM(backward_output_data),
   1042             PARAM(convolution_descriptor), PARAM(filter_descriptor),
   1043             PARAM(backward_filter_data));
   1044 
   1045   if (ok()) {
   1046     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1047       auto status = dnn->DoConvolveBackwardFilter(
   1048           this, input_descriptor, input_data, output_descriptor,
   1049           backward_output_data, convolution_descriptor, filter_descriptor,
   1050           backward_filter_data, scratch_allocator, algorithm_config,
   1051           output_profile_result);
   1052       if (!status && !output_profile_result) {
   1053         SetError();
   1054       }
   1055     } else {
   1056       SetErrorAndLogNoDnnSupport();
   1057     }
   1058   }
   1059   return *this;
   1060 }
   1061 
   1062 Stream &Stream::ThenConvolveBackwardFilterWithScratch(
   1063     const dnn::BatchDescriptor &input_descriptor,
   1064     const DeviceMemory<Eigen::half> &input_data,
   1065     const dnn::BatchDescriptor &output_descriptor,
   1066     DeviceMemory<Eigen::half> backward_output_data,
   1067     const dnn::ConvolutionDescriptor &convolution_descriptor,
   1068     const dnn::FilterDescriptor &filter_descriptor,
   1069     DeviceMemory<Eigen::half> *backward_filter_data,
   1070     ScratchAllocator *scratch_allocator) {
   1071   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
   1072             PARAM(output_descriptor), PARAM(backward_output_data),
   1073             PARAM(convolution_descriptor), PARAM(filter_descriptor),
   1074             PARAM(backward_filter_data));
   1075 
   1076   if (ok()) {
   1077     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1078       CheckError(dnn->DoConvolveBackwardFilter(
   1079           this, input_descriptor, input_data, output_descriptor,
   1080           backward_output_data, convolution_descriptor, filter_descriptor,
   1081           backward_filter_data, scratch_allocator, dnn::AlgorithmConfig(),
   1082           /*output_profile_result=*/nullptr));
   1083     } else {
   1084       SetErrorAndLogNoDnnSupport();
   1085     }
   1086   }
   1087   return *this;
   1088 }
   1089 
   1090 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
   1091     const dnn::BatchDescriptor &input_descriptor,
   1092     const DeviceMemory<Eigen::half> &input_data,
   1093     const dnn::BatchDescriptor &output_descriptor,
   1094     DeviceMemory<Eigen::half> backward_output_data,
   1095     const dnn::ConvolutionDescriptor &convolution_descriptor,
   1096     const dnn::FilterDescriptor &filter_descriptor,
   1097     DeviceMemory<Eigen::half> *backward_filter_data,
   1098     ScratchAllocator *scratch_allocator,
   1099     const dnn::AlgorithmConfig &algorithm_config,
   1100     dnn::ProfileResult *output_profile_result) {
   1101   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
   1102             PARAM(output_descriptor), PARAM(backward_output_data),
   1103             PARAM(convolution_descriptor), PARAM(filter_descriptor),
   1104             PARAM(backward_filter_data));
   1105 
   1106   if (ok()) {
   1107     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1108       auto status = dnn->DoConvolveBackwardFilter(
   1109           this, input_descriptor, input_data, output_descriptor,
   1110           backward_output_data, convolution_descriptor, filter_descriptor,
   1111           backward_filter_data, scratch_allocator, algorithm_config,
   1112           output_profile_result);
   1113       if (!status && !output_profile_result) {
   1114         SetError();
   1115       }
   1116     } else {
   1117       SetErrorAndLogNoDnnSupport();
   1118     }
   1119   }
   1120   return *this;
   1121 }
   1122 
   1123 Stream &Stream::ThenConvolveBackwardFilter(
   1124     const dnn::BatchDescriptor &input_descriptor,
   1125     const DeviceMemory<float> &input_data,
   1126     const dnn::BatchDescriptor &output_descriptor,
   1127     DeviceMemory<float> backward_output_data,
   1128     const dnn::ConvolutionDescriptor &convolution_descriptor,
   1129     const dnn::FilterDescriptor &filter_descriptor,
   1130     DeviceMemory<float> *backward_filter_data) {
   1131   return ThenConvolveBackwardFilterWithScratch(
   1132       input_descriptor, input_data, output_descriptor, backward_output_data,
   1133       convolution_descriptor, filter_descriptor, backward_filter_data,
   1134       /*scratch_allocator=*/nullptr);
   1135 }
   1136 
   1137 template <typename T>
   1138 Stream &Stream::ThenConvolveBackwardBiasImpl(
   1139     const dnn::BatchDescriptor &input_descriptor,
   1140     const DeviceMemory<T> &input_data,
   1141     const dnn::BatchDescriptor &bias_descriptor,
   1142     DeviceMemory<T> *backward_bias_data) {
   1143   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(bias_descriptor),
   1144             PARAM(backward_bias_data));
   1145 
   1146   if (ok()) {
   1147     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1148       CheckError(dnn->DoConvolveBackwardBias(this, input_descriptor, input_data,
   1149                                              bias_descriptor,
   1150                                              backward_bias_data));
   1151     } else {
   1152       SetErrorAndLogNoDnnSupport();
   1153     }
   1154   }
   1155   return *this;
   1156 }
   1157 
   1158 Stream &Stream::ThenConvolveBackwardBias(
   1159     const dnn::BatchDescriptor &input_descriptor,
   1160     const DeviceMemory<double> &input_data,
   1161     const dnn::BatchDescriptor &bias_descriptor,
   1162     DeviceMemory<double> *backward_bias_data) {
   1163   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
   1164                                       bias_descriptor, backward_bias_data);
   1165 }
   1166 
   1167 Stream &Stream::ThenConvolveBackwardBias(
   1168     const dnn::BatchDescriptor &input_descriptor,
   1169     const DeviceMemory<float> &input_data,
   1170     const dnn::BatchDescriptor &bias_descriptor,
   1171     DeviceMemory<float> *backward_bias_data) {
   1172   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
   1173                                       bias_descriptor, backward_bias_data);
   1174 }
   1175 
   1176 Stream &Stream::ThenConvolveBackwardBias(
   1177     const dnn::BatchDescriptor &input_descriptor,
   1178     const DeviceMemory<Eigen::half> &input_data,
   1179     const dnn::BatchDescriptor &bias_descriptor,
   1180     DeviceMemory<Eigen::half> *backward_bias_data) {
   1181   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
   1182                                       bias_descriptor, backward_bias_data);
   1183 }
   1184 
   1185 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
   1186                            const DeviceMemory<float> &weights,
   1187                            const dnn::BatchDescriptor &input_dimensions,
   1188                            const dnn::BatchDescriptor &output_dimensions,
   1189                            DeviceMemory<float> *output_data) {
   1190   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
   1191             PARAM(output_dimensions), PARAM(output_data));
   1192 
   1193   if (ok()) {
   1194     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1195       CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
   1196                                output_dimensions, output_data));
   1197     } else {
   1198       SetErrorAndLogNoDnnSupport();
   1199     }
   1200   }
   1201   return *this;
   1202 }
   1203 
   1204 Stream &Stream::ThenMatMulQuantized(
   1205     const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
   1206     const DeviceMemory<float> &weight_scales,
   1207     const dnn::BatchDescriptor &input_dimensions,
   1208     const dnn::BatchDescriptor &output_dimensions,
   1209     DeviceMemory<float> *output_data) {
   1210   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
   1211             PARAM(input_dimensions), PARAM(output_dimensions),
   1212             PARAM(output_data));
   1213 
   1214   if (ok()) {
   1215     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1216       CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
   1217                                         weight_scales, input_dimensions,
   1218                                         output_dimensions, output_data));
   1219     } else {
   1220       SetErrorAndLogNoDnnSupport();
   1221     }
   1222   }
   1223   return *this;
   1224 }
   1225 
   1226 Stream &Stream::ThenMatMulQuantized(
   1227     const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
   1228     const DeviceMemory<float> &weight_scales,
   1229     const dnn::BatchDescriptor &input_dimensions,
   1230     const dnn::BatchDescriptor &output_dimensions,
   1231     DeviceMemory<float> *output_data) {
   1232   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
   1233             PARAM(input_dimensions), PARAM(output_dimensions),
   1234             PARAM(output_data));
   1235 
   1236   if (ok()) {
   1237     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1238       CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
   1239                                         weight_scales, input_dimensions,
   1240                                         output_dimensions, output_data));
   1241     } else {
   1242       SetErrorAndLogNoDnnSupport();
   1243     }
   1244   }
   1245   return *this;
   1246 }
   1247 
   1248 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
   1249                             const DeviceMemory<float> &biases,
   1250                             const dnn::BatchDescriptor &dimensions,
   1251                             DeviceMemory<float> *output_data) {
   1252   VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
   1253             PARAM(output_data));
   1254 
   1255   if (ok()) {
   1256     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1257       CheckError(
   1258           dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
   1259     } else {
   1260       SetErrorAndLogNoDnnSupport();
   1261     }
   1262   }
   1263   return *this;
   1264 }
   1265 
   1266 Stream &Stream::ThenPoolForward(
   1267     const dnn::PoolingDescriptor &pooling_dimensions,
   1268     const dnn::BatchDescriptor &input_dimensions,
   1269     const DeviceMemory<double> &input_data,
   1270     const dnn::BatchDescriptor &output_dimensions,
   1271     DeviceMemory<double> *output_data) {
   1272   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
   1273             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
   1274 
   1275   if (ok()) {
   1276     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1277       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
   1278                                     input_data, output_dimensions,
   1279                                     output_data));
   1280     } else {
   1281       SetError();
   1282       LOG(WARNING)
   1283           << "attempting to perform DNN operation using StreamExecutor "
   1284              "without DNN support";
   1285     }
   1286   }
   1287   return *this;
   1288 }
   1289 
   1290 Stream &Stream::ThenPoolForward(
   1291     const dnn::PoolingDescriptor &pooling_dimensions,
   1292     const dnn::BatchDescriptor &input_dimensions,
   1293     const DeviceMemory<float> &input_data,
   1294     const dnn::BatchDescriptor &output_dimensions,
   1295     DeviceMemory<float> *output_data) {
   1296   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
   1297             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
   1298 
   1299   if (ok()) {
   1300     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1301       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
   1302                                     input_data, output_dimensions,
   1303                                     output_data));
   1304     } else {
   1305       SetErrorAndLogNoDnnSupport();
   1306     }
   1307   }
   1308   return *this;
   1309 }
   1310 
   1311 Stream &Stream::ThenPoolForward(
   1312     const dnn::PoolingDescriptor &pooling_dimensions,
   1313     const dnn::BatchDescriptor &input_dimensions,
   1314     const DeviceMemory<Eigen::half> &input_data,
   1315     const dnn::BatchDescriptor &output_dimensions,
   1316     DeviceMemory<Eigen::half> *output_data) {
   1317   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
   1318             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
   1319 
   1320   if (ok()) {
   1321     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1322       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
   1323                                     input_data, output_dimensions,
   1324                                     output_data));
   1325     } else {
   1326       SetErrorAndLogNoDnnSupport();
   1327     }
   1328   }
   1329   return *this;
   1330 }
   1331 
   1332 Stream &Stream::ThenPoolBackward(
   1333     const dnn::PoolingDescriptor &pooling_dimensions,
   1334     const dnn::BatchDescriptor &input_dimensions,
   1335     const DeviceMemory<double> &input_data,
   1336     const dnn::BatchDescriptor &output_dimensions,
   1337     const DeviceMemory<double> &output_data,
   1338     const DeviceMemory<double> &input_diff_data,
   1339     DeviceMemory<double> *output_diff_data) {
   1340   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
   1341             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
   1342             PARAM(input_diff_data), PARAM(output_diff_data));
   1343 
   1344   if (ok()) {
   1345     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1346       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
   1347                                      input_data, output_dimensions, output_data,
   1348                                      input_diff_data, output_diff_data));
   1349     } else {
   1350       SetError();
   1351       LOG(WARNING)
   1352           << "attempting to perform DNN operation using StreamExecutor "
   1353              "without DNN support";
   1354     }
   1355   }
   1356   return *this;
   1357 }
   1358 
   1359 Stream &Stream::ThenPoolBackward(
   1360     const dnn::PoolingDescriptor &pooling_dimensions,
   1361     const dnn::BatchDescriptor &input_dimensions,
   1362     const DeviceMemory<float> &input_data,
   1363     const dnn::BatchDescriptor &output_dimensions,
   1364     const DeviceMemory<float> &output_data,
   1365     const DeviceMemory<float> &input_diff_data,
   1366     DeviceMemory<float> *output_diff_data) {
   1367   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
   1368             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
   1369             PARAM(input_diff_data), PARAM(output_diff_data));
   1370 
   1371   if (ok()) {
   1372     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1373       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
   1374                                      input_data, output_dimensions, output_data,
   1375                                      input_diff_data, output_diff_data));
   1376     } else {
   1377       SetErrorAndLogNoDnnSupport();
   1378     }
   1379   }
   1380   return *this;
   1381 }
   1382 
   1383 Stream &Stream::ThenPoolBackward(
   1384     const dnn::PoolingDescriptor &pooling_dimensions,
   1385     const dnn::BatchDescriptor &input_dimensions,
   1386     const DeviceMemory<Eigen::half> &input_data,
   1387     const dnn::BatchDescriptor &output_dimensions,
   1388     const DeviceMemory<Eigen::half> &output_data,
   1389     const DeviceMemory<Eigen::half> &input_diff_data,
   1390     DeviceMemory<Eigen::half> *output_diff_data) {
   1391   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
   1392             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
   1393             PARAM(input_diff_data), PARAM(output_diff_data));
   1394 
   1395   if (ok()) {
   1396     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1397       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
   1398                                      input_data, output_dimensions, output_data,
   1399                                      input_diff_data, output_diff_data));
   1400     } else {
   1401       SetErrorAndLogNoDnnSupport();
   1402     }
   1403   }
   1404   return *this;
   1405 }
   1406 
   1407 Stream &Stream::ThenNormalize(
   1408     const dnn::NormalizeDescriptor &normalize_descriptor,
   1409     const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
   1410   VLOG_CALL(PARAM(normalize_descriptor), PARAM(input_data), PARAM(output_data));
   1411 
   1412   if (ok()) {
   1413     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1414       CheckError(dnn->DoNormalize(this, normalize_descriptor, input_data,
   1415                                   output_data));
   1416     } else {
   1417       SetErrorAndLogNoDnnSupport();
   1418     }
   1419   }
   1420   return *this;
   1421 }
   1422 
   1423 Stream &Stream::ThenNormalizeWithDimensions(
   1424     const dnn::NormalizeDescriptor &normalize_descriptor,
   1425     const dnn::BatchDescriptor &dimensions,
   1426     const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
   1427   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
   1428             PARAM(output_data));
   1429 
   1430   if (ok()) {
   1431     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1432       CheckError(dnn->DoNormalizeWithDimensions(
   1433           this, normalize_descriptor, dimensions, input_data, output_data));
   1434     } else {
   1435       SetErrorAndLogNoDnnSupport();
   1436     }
   1437   }
   1438   return *this;
   1439 }
   1440 
   1441 Stream &Stream::ThenNormalizeBackwardWithDimensions(
   1442     const dnn::NormalizeDescriptor &normalize_descriptor,
   1443     const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
   1444     const DeviceMemory<float> &normalized_data,
   1445     const DeviceMemory<float> &normalized_variable_gradient,
   1446     DeviceMemory<float> *raw_variable_gradient) {
   1447   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
   1448             PARAM(normalized_data), PARAM(normalized_variable_gradient),
   1449             PARAM(raw_variable_gradient));
   1450 
   1451   if (ok()) {
   1452     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1453       CheckError(dnn->DoNormalizeBackwardWithDimensions(
   1454           this, normalize_descriptor, dimensions, raw_data, normalized_data,
   1455           normalized_variable_gradient, raw_variable_gradient));
   1456     } else {
   1457       SetErrorAndLogNoDnnSupport();
   1458     }
   1459   }
   1460   return *this;
   1461 }
   1462 
   1463 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
   1464                              const dnn::BatchDescriptor &dimensions,
   1465                              const DeviceMemory<float> &input_data,
   1466                              DeviceMemory<float> *output_data) {
   1467   return ThenActivateWithOptions(activation_mode, dimensions, input_data,
   1468                                  output_data, /*options=*/0);
   1469 }
   1470 
   1471 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
   1472                                         const dnn::BatchDescriptor &dimensions,
   1473                                         const DeviceMemory<float> &input_data,
   1474                                         DeviceMemory<float> *output_data,
   1475                                         uint64 options) {
   1476   VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
   1477             PARAM(output_data), PARAM(options));
   1478 
   1479   if (ok()) {
   1480     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1481       CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
   1482                                  output_data, options));
   1483     } else {
   1484       SetErrorAndLogNoDnnSupport();
   1485     }
   1486   }
   1487   return *this;
   1488 }
   1489 
   1490 Stream &Stream::ThenDepthConcatenate(
   1491     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
   1492     port::ArraySlice<const DeviceMemory<float> *> input_data,
   1493     DeviceMemory<float> *output_data) {
   1494   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
   1495 
   1496   for (size_t i = 1; i < input_dimensions.size(); ++i) {
   1497     if (input_dimensions[i].count() != input_dimensions[0].count() ||
   1498         input_dimensions[i].height() != input_dimensions[0].height() ||
   1499         input_dimensions[i].width() != input_dimensions[0].width()) {
   1500       SetError();
   1501       LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
   1502                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
   1503                  << "input_dimensions[" << i
   1504                  << "]: " << input_dimensions[i].ToString();
   1505       return *this;
   1506     }
   1507   }
   1508 
   1509   if (ok()) {
   1510     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1511       CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
   1512                                          output_data));
   1513     } else {
   1514       SetErrorAndLogNoDnnSupport();
   1515     }
   1516   }
   1517   return *this;
   1518 }
   1519 
   1520 Stream &Stream::ThenSpaceConcatenate(
   1521     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
   1522     port::ArraySlice<const DeviceMemory<float> *> input_data,
   1523     DeviceMemory<float> *output_data,
   1524     dnn::SpaceConcatenateMode concat_direction) {
   1525   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
   1526 
   1527   // Check that the input dimensions of all the other batches match those of the
   1528   // first batch.
   1529   for (size_t i = 1; i < input_dimensions.size(); ++i) {
   1530     if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
   1531         (input_dimensions[i].count() != input_dimensions[0].count() ||
   1532          input_dimensions[i].height() != input_dimensions[0].height() ||
   1533          input_dimensions[i].feature_map_count() !=
   1534              input_dimensions[0].feature_map_count())) {
   1535       SetError();
   1536       LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
   1537                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
   1538                  << "input_dimensions[" << i
   1539                  << "]: " << input_dimensions[i].ToString();
   1540       return *this;
   1541     }
   1542 
   1543     if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
   1544         (input_dimensions[i].count() != input_dimensions[0].count() ||
   1545          input_dimensions[i].width() != input_dimensions[0].width() ||
   1546          input_dimensions[i].feature_map_count() !=
   1547              input_dimensions[0].feature_map_count())) {
   1548       SetError();
   1549       LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
   1550                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
   1551                  << "input_dimensions[" << i
   1552                  << "]: " << input_dimensions[i].ToString();
   1553       return *this;
   1554     }
   1555   }
   1556   if (ok()) {
   1557     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1558       CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
   1559                                          output_data, concat_direction));
   1560     } else {
   1561       SetErrorAndLogNoDnnSupport();
   1562     }
   1563   }
   1564   return *this;
   1565 }
   1566 
   1567 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
   1568                             const DeviceMemory<float> &input_data,
   1569                             const dnn::BatchDescriptor &output_dimensions,
   1570                             DeviceMemory<float> *output_data) {
   1571   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
   1572             PARAM(output_dimensions), PARAM(output_data));
   1573 
   1574   if (ok()) {
   1575     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1576       CheckError(dnn->DoReshape(this, input_dimensions, input_data,
   1577                                 output_dimensions, output_data));
   1578     } else {
   1579       SetErrorAndLogNoDnnSupport();
   1580     }
   1581   }
   1582   return *this;
   1583 }
   1584 
   1585 Stream &Stream::ThenDepthToSpace(
   1586     const dnn::BatchDescriptor &input_dimensions,
   1587     const DeviceMemory<float> &input_data,
   1588     const dnn::DepthToSpaceLayout &depth_to_space_layout,
   1589     const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
   1590   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
   1591             PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
   1592             PARAM(output_data));
   1593 
   1594   if (ok()) {
   1595     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1596       CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
   1597                                      depth_to_space_layout,
   1598                                      sqrt_depth_reduction, output_data));
   1599     } else {
   1600       SetErrorAndLogNoDnnSupport();
   1601     }
   1602   }
   1603   return *this;
   1604 }
   1605 
   1606 Stream &Stream::ThenSpaceToDepth(
   1607     const dnn::BatchDescriptor &input_dimensions,
   1608     const DeviceMemory<float> &input_data,
   1609     const dnn::DepthToSpaceLayout &space_to_depth_layout,
   1610     const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
   1611   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
   1612             PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
   1613             PARAM(output_data));
   1614 
   1615   if (ok()) {
   1616     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1617       CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
   1618                                      space_to_depth_layout, sqrt_depth_increase,
   1619                                      output_data));
   1620     } else {
   1621       SetErrorAndLogNoDnnSupport();
   1622     }
   1623   }
   1624   return *this;
   1625 }
   1626 
   1627 Stream &Stream::ThenElementwiseOperate(
   1628     dnn::ElementwiseOperation operation,
   1629     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
   1630     port::ArraySlice<const DeviceMemory<float> *> input_data,
   1631     const dnn::BatchDescriptor &output_dimensions,
   1632     DeviceMemory<float> *output_data) {
   1633   VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
   1634             PARAM(output_dimensions), PARAM(output_data));
   1635 
   1636   if (ok()) {
   1637     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1638       CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
   1639                                            input_data, output_dimensions,
   1640                                            output_data));
   1641     } else {
   1642       SetErrorAndLogNoDnnSupport();
   1643     }
   1644   }
   1645   return *this;
   1646 }
   1647 
   1648 Stream &Stream::ThenElementwiseOperateScaledQuantized(
   1649     dnn::ElementwiseOperation operation,
   1650     port::ArraySlice<int> input_multiplicands, int output_divisor,
   1651     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
   1652     port::ArraySlice<const DeviceMemory<float> *> input_data,
   1653     const dnn::BatchDescriptor &output_dimensions,
   1654     DeviceMemory<float> *output_data) {
   1655   VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
   1656             PARAM(input_dimensions), PARAM(input_data),
   1657             PARAM(output_dimensions), PARAM(output_data));
   1658 
   1659   if (ok()) {
   1660     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1661       CheckError(dnn->DoElementwiseOperateScaledQuantized(
   1662           this, operation, input_multiplicands, output_divisor,
   1663           input_dimensions, input_data, output_dimensions, output_data));
   1664     } else {
   1665       SetErrorAndLogNoDnnSupport();
   1666     }
   1667   }
   1668   return *this;
   1669 }
   1670 
   1671 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
   1672                           const DeviceMemory<float> &input_data, int64 left_pad,
   1673                           int64 right_pad, int64 top_pad, int64 bottom_pad,
   1674                           DeviceMemory<float> *output_data) {
   1675   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
   1676             PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
   1677             PARAM(output_data));
   1678 
   1679   if (ok()) {
   1680     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1681       CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
   1682                               top_pad, bottom_pad, output_data));
   1683     } else {
   1684       SetErrorAndLogNoDnnSupport();
   1685     }
   1686   }
   1687   return *this;
   1688 }
   1689 
   1690 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
   1691                             const DeviceMemory<float> &input_data,
   1692                             int64 left_trim, int64 right_trim, int64 top_trim,
   1693                             int64 bottom_trim,
   1694                             DeviceMemory<float> *output_data) {
   1695   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
   1696             PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
   1697             PARAM(output_data));
   1698 
   1699   if (ok()) {
   1700     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1701       CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
   1702                                 right_trim, top_trim, bottom_trim,
   1703                                 output_data));
   1704     } else {
   1705       SetErrorAndLogNoDnnSupport();
   1706     }
   1707   }
   1708   return *this;
   1709 }
   1710 
   1711 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
   1712                                 const DeviceMemory<float> &input_data,
   1713                                 int64 replicate_x, int64 replicate_y,
   1714                                 DeviceMemory<float> *output_data) {
   1715   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
   1716             PARAM(replicate_y), PARAM(output_data));
   1717 
   1718   if (ok()) {
   1719     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1720       CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
   1721                                     replicate_y, output_data));
   1722     } else {
   1723       SetErrorAndLogNoDnnSupport();
   1724     }
   1725   }
   1726   return *this;
   1727 }
   1728 
   1729 Stream &Stream::ThenMemcpyD2HQuantized(
   1730     const DeviceMemory<float> &gpu_unquantized_src,
   1731     dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) {
   1732   VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
   1733             PARAM(size));
   1734 
   1735   if (ok()) {
   1736     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1737       CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
   1738                                            host_dst, size));
   1739     } else {
   1740       SetErrorAndLogNoDnnSupport();
   1741     }
   1742   }
   1743   return *this;
   1744 }
   1745 
   1746 Stream &Stream::ThenMemcpyH2DQuantized(
   1747     const void *host_src, uint64 size, dnn::QuantizedActivationMode mode,
   1748     DeviceMemory<float> *gpu_unquantized_dst) {
   1749   VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
   1750             PARAM(gpu_unquantized_dst));
   1751 
   1752   if (ok()) {
   1753     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1754       CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
   1755                                            gpu_unquantized_dst));
   1756     } else {
   1757       SetErrorAndLogNoDnnSupport();
   1758     }
   1759   }
   1760   return *this;
   1761 }
   1762 
   1763 Stream &Stream::ThenCopyHostBuffer2Device(
   1764     HostBuffer *buffer_src, DeviceMemory<float> *gpu_unquantized_dst) {
   1765   VLOG_CALL(PARAM(*buffer_src), PARAM(gpu_unquantized_dst));
   1766 
   1767   if (ok()) {
   1768     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1769       CheckError(
   1770           dnn->DoCopyHostBuffer2Device(this, buffer_src, gpu_unquantized_dst));
   1771     } else {
   1772       SetErrorAndLogNoDnnSupport();
   1773     }
   1774   }
   1775   return *this;
   1776 }
   1777 
   1778 Stream &Stream::ThenCopyDevice2HostBuffer(
   1779     const DeviceMemory<float> &gpu_unquantized_src, HostBuffer *buffer_dst) {
   1780   VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(*buffer_dst));
   1781 
   1782   if (ok()) {
   1783     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1784       CheckError(
   1785           dnn->DoCopyDevice2HostBuffer(this, gpu_unquantized_src, buffer_dst));
   1786     } else {
   1787       SetErrorAndLogNoDnnSupport();
   1788     }
   1789   }
   1790   return *this;
   1791 }
   1792 
   1793 Stream *Stream::GetOrCreateSubStream() {
   1794   mutex_lock lock{mu_};
   1795   for (auto &stream : sub_streams_) {
   1796     if (stream.second) {
   1797       stream.second = false;
   1798       return stream.first.get();
   1799     }
   1800   }
   1801   sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
   1802                             false);
   1803   Stream *sub_stream = sub_streams_.back().first.get();
   1804   sub_stream->Init();
   1805   CHECK(ok_) << "sub-stream failed to be initialized";
   1806 
   1807   return sub_stream;
   1808 }
   1809 
   1810 void Stream::ReturnSubStream(Stream *sub_stream) {
   1811   mutex_lock lock{mu_};
   1812   for (auto &stream : sub_streams_) {
   1813     if (stream.first.get() == sub_stream) {
   1814       stream.second = true;
   1815       return;
   1816     }
   1817   }
   1818   LOG(FATAL) << "the sub-stream to be returned is not created by this stream";
   1819 }
   1820 
   1821 Stream &Stream::ThenStartTimer(Timer *t) {
   1822   VLOG_CALL(PARAM(t));
   1823 
   1824   if (ok()) {
   1825     CheckError(parent_->StartTimer(this, t));
   1826   } else {
   1827     LOG(INFO) << "stream " << this << " did not enqueue 'start timer': " << t;
   1828   }
   1829   return *this;
   1830 }
   1831 
   1832 Stream &Stream::ThenStopTimer(Timer *t) {
   1833   VLOG_CALL(PARAM(t));
   1834 
   1835   if (ok()) {
   1836     CheckError(parent_->StopTimer(this, t));
   1837   } else {
   1838     LOG(INFO) << "stream " << this << " did not enqueue 'stop timer': " << t;
   1839   }
   1840   return *this;
   1841 }
   1842 
   1843 Stream &Stream::ThenWaitFor(Stream *other) {
   1844   VLOG_CALL(PARAM(other));
   1845 
   1846   CHECK(this != other) << "stream cannot wait for itself";
   1847   if (ok() && other->ok()) {
   1848     CheckError(parent_->CreateStreamDependency(this, other));
   1849   } else {
   1850     SetError();
   1851     LOG(INFO) << "stream " << this << " did not wait for stream: " << other;
   1852   }
   1853   return *this;
   1854 }
   1855 
   1856 Stream &Stream::ThenWaitFor(Event *event) {
   1857   VLOG_CALL(PARAM(event));
   1858 
   1859   if (ok()) {
   1860     port::Status status = parent_->WaitForEvent(this, event);
   1861     if (!status.ok()) {
   1862       LOG(ERROR) << "Error waiting for event in stream: "
   1863                  << status.error_message()
   1864                  << "; not marking stream as bad, as the Event object may be "
   1865                  << "at fault. Monitor for further errors.";
   1866     }
   1867   } else {
   1868     LOG(INFO) << "stream " << this << " did not wait for an event.";
   1869   }
   1870   return *this;
   1871 }
   1872 
   1873 // A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
   1874 // functions and logs for errors.
   1875 template <typename... Args>
   1876 struct ThenBlasImpl {
   1877   // blas_func is the DoBlasXXX member function pointer, and args are its
   1878   // arguments except the first one of Stream* type.
   1879   Stream &operator()(Stream *stream,
   1880                      bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
   1881                      Args... args) {
   1882     return Run(stream, blas_func, /*record_error=*/true, args...);
   1883   }
   1884 
   1885   // Like operator(), but only calls stream->CheckError() if record_error is
   1886   // true.
   1887   Stream &Run(Stream *stream,
   1888               bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
   1889               bool record_error, Args... args);
   1890 };
   1891 
   1892 template <typename... Args>
   1893 Stream &ThenBlasImpl<Args...>::Run(
   1894     Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
   1895     bool record_error, Args... args) {
   1896   if (stream->ok()) {
   1897     bool ok;
   1898     if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
   1899       ok = (blas->*blas_func)(stream, args...);
   1900     } else {
   1901       LOG(WARNING)
   1902           << "attempting to perform BLAS operation using StreamExecutor "
   1903              "without BLAS support";
   1904       ok = false;
   1905     }
   1906     if (record_error) {
   1907       stream->CheckError(ok);
   1908     }
   1909   }
   1910   return *stream;
   1911 }
   1912 
   1913 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
   1914                              int incx, DeviceMemory<float> *result) {
   1915   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   1916 
   1917   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
   1918       impl;
   1919   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
   1920               result);
   1921 }
   1922 
   1923 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
   1924                              int incx, DeviceMemory<double> *result) {
   1925   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   1926 
   1927   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
   1928                DeviceMemory<double> *> impl;
   1929   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
   1930               result);
   1931 }
   1932 
   1933 Stream &Stream::ThenBlasAsum(uint64 elem_count,
   1934                              const DeviceMemory<std::complex<float>> &x,
   1935                              int incx, DeviceMemory<float> *result) {
   1936   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   1937 
   1938   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
   1939                DeviceMemory<float> *> impl;
   1940   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
   1941               result);
   1942 }
   1943 
   1944 Stream &Stream::ThenBlasAsum(uint64 elem_count,
   1945                              const DeviceMemory<std::complex<double>> &x,
   1946                              int incx, DeviceMemory<double> *result) {
   1947   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   1948 
   1949   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
   1950                DeviceMemory<double> *> impl;
   1951   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
   1952               result);
   1953 }
   1954 
   1955 Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha,
   1956                              const DeviceMemory<float> &x, int incx,
   1957                              DeviceMemory<float> *y, int incy) {
   1958   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   1959             PARAM(incy));
   1960 
   1961   ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int,
   1962                DeviceMemory<float> *, int> impl;
   1963   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
   1964               y, incy);
   1965 }
   1966 
   1967 Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha,
   1968                              const DeviceMemory<double> &x, int incx,
   1969                              DeviceMemory<double> *y, int incy) {
   1970   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   1971             PARAM(incy));
   1972 
   1973   ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int,
   1974                DeviceMemory<double> *, int> impl;
   1975   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
   1976               y, incy);
   1977 }
   1978 
   1979 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
   1980                              const DeviceMemory<std::complex<float>> &x,
   1981                              int incx, DeviceMemory<std::complex<float>> *y,
   1982                              int incy) {
   1983   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   1984             PARAM(incy));
   1985 
   1986   ThenBlasImpl<uint64, std::complex<float>,
   1987                const DeviceMemory<std::complex<float>> &, int,
   1988                DeviceMemory<std::complex<float>> *, int> impl;
   1989   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
   1990               y, incy);
   1991 }
   1992 
   1993 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
   1994                              const DeviceMemory<std::complex<double>> &x,
   1995                              int incx, DeviceMemory<std::complex<double>> *y,
   1996                              int incy) {
   1997   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   1998             PARAM(incy));
   1999 
   2000   ThenBlasImpl<uint64, std::complex<double>,
   2001                const DeviceMemory<std::complex<double>> &, int,
   2002                DeviceMemory<std::complex<double>> *, int> impl;
   2003   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
   2004               y, incy);
   2005 }
   2006 
   2007 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
   2008                              int incx, DeviceMemory<float> *y, int incy) {
   2009   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2010 
   2011   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
   2012                int> impl;
   2013   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
   2014               incy);
   2015 }
   2016 
   2017 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
   2018                              int incx, DeviceMemory<double> *y, int incy) {
   2019   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2020 
   2021   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
   2022                DeviceMemory<double> *, int> impl;
   2023   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
   2024               incy);
   2025 }
   2026 
   2027 Stream &Stream::ThenBlasCopy(uint64 elem_count,
   2028                              const DeviceMemory<std::complex<float>> &x,
   2029                              int incx, DeviceMemory<std::complex<float>> *y,
   2030                              int incy) {
   2031   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2032 
   2033   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
   2034                DeviceMemory<std::complex<float>> *, int> impl;
   2035   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
   2036               incy);
   2037 }
   2038 
   2039 Stream &Stream::ThenBlasCopy(uint64 elem_count,
   2040                              const DeviceMemory<std::complex<double>> &x,
   2041                              int incx, DeviceMemory<std::complex<double>> *y,
   2042                              int incy) {
   2043   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2044 
   2045   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
   2046                DeviceMemory<std::complex<double>> *, int> impl;
   2047   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
   2048               incy);
   2049 }
   2050 
   2051 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x,
   2052                             int incx, const DeviceMemory<float> &y, int incy,
   2053                             DeviceMemory<float> *result) {
   2054   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2055             PARAM(result));
   2056 
   2057   ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
   2058                const DeviceMemory<float> &, int, DeviceMemory<float> *> impl;
   2059   return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
   2060               result);
   2061 }
   2062 
   2063 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
   2064                             int incx, const DeviceMemory<double> &y, int incy,
   2065                             DeviceMemory<double> *result) {
   2066   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2067             PARAM(result));
   2068 
   2069   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
   2070                const DeviceMemory<double> &, int, DeviceMemory<double> *> impl;
   2071   return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
   2072               result);
   2073 }
   2074 
   2075 Stream &Stream::ThenBlasDotc(uint64 elem_count,
   2076                              const DeviceMemory<std::complex<float>> &x,
   2077                              int incx,
   2078                              const DeviceMemory<std::complex<float>> &y,
   2079                              int incy,
   2080                              DeviceMemory<std::complex<float>> *result) {
   2081   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2082             PARAM(result));
   2083 
   2084   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
   2085                const DeviceMemory<std::complex<float>> &, int,
   2086                DeviceMemory<std::complex<float>> *> impl;
   2087   return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
   2088               incy, result);
   2089 }
   2090 
   2091 Stream &Stream::ThenBlasDotc(uint64 elem_count,
   2092                              const DeviceMemory<std::complex<double>> &x,
   2093                              int incx,
   2094                              const DeviceMemory<std::complex<double>> &y,
   2095                              int incy,
   2096                              DeviceMemory<std::complex<double>> *result) {
   2097   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2098             PARAM(result));
   2099 
   2100   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
   2101                const DeviceMemory<std::complex<double>> &, int,
   2102                DeviceMemory<std::complex<double>> *> impl;
   2103   return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
   2104               incy, result);
   2105 }
   2106 
   2107 Stream &Stream::ThenBlasDotu(uint64 elem_count,
   2108                              const DeviceMemory<std::complex<float>> &x,
   2109                              int incx,
   2110                              const DeviceMemory<std::complex<float>> &y,
   2111                              int incy,
   2112                              DeviceMemory<std::complex<float>> *result) {
   2113   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2114             PARAM(result));
   2115 
   2116   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
   2117                const DeviceMemory<std::complex<float>> &, int,
   2118                DeviceMemory<std::complex<float>> *> impl;
   2119   return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
   2120               incy, result);
   2121 }
   2122 
   2123 Stream &Stream::ThenBlasDotu(uint64 elem_count,
   2124                              const DeviceMemory<std::complex<double>> &x,
   2125                              int incx,
   2126                              const DeviceMemory<std::complex<double>> &y,
   2127                              int incy,
   2128                              DeviceMemory<std::complex<double>> *result) {
   2129   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2130             PARAM(result));
   2131 
   2132   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
   2133                const DeviceMemory<std::complex<double>> &, int,
   2134                DeviceMemory<std::complex<double>> *> impl;
   2135   return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
   2136               incy, result);
   2137 }
   2138 
   2139 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
   2140                              int incx, DeviceMemory<float> *result) {
   2141   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2142 
   2143   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
   2144       impl;
   2145   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
   2146               result);
   2147 }
   2148 
   2149 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
   2150                              int incx, DeviceMemory<double> *result) {
   2151   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2152 
   2153   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
   2154                DeviceMemory<double> *> impl;
   2155   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
   2156               result);
   2157 }
   2158 
   2159 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
   2160                              const DeviceMemory<std::complex<float>> &x,
   2161                              int incx, DeviceMemory<float> *result) {
   2162   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2163 
   2164   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
   2165                DeviceMemory<float> *> impl;
   2166   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
   2167               result);
   2168 }
   2169 
   2170 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
   2171                              const DeviceMemory<std::complex<double>> &x,
   2172                              int incx, DeviceMemory<double> *result) {
   2173   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2174 
   2175   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
   2176                DeviceMemory<double> *> impl;
   2177   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
   2178               result);
   2179 }
   2180 
   2181 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
   2182                             DeviceMemory<float> *y, int incy, float c,
   2183                             float s) {
   2184   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2185             PARAM(c), PARAM(s));
   2186 
   2187   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
   2188                float, float> impl;
   2189   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
   2190               c, s);
   2191 }
   2192 
   2193 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x,
   2194                             int incx, DeviceMemory<double> *y, int incy,
   2195                             double c, double s) {
   2196   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2197             PARAM(c), PARAM(s));
   2198 
   2199   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
   2200                double, double> impl;
   2201   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
   2202               c, s);
   2203 }
   2204 
   2205 Stream &Stream::ThenBlasRot(uint64 elem_count,
   2206                             DeviceMemory<std::complex<float>> *x, int incx,
   2207                             DeviceMemory<std::complex<float>> *y, int incy,
   2208                             float c, float s) {
   2209   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2210             PARAM(c), PARAM(s));
   2211 
   2212   ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
   2213                DeviceMemory<std::complex<float>> *, int, float, float> impl;
   2214   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
   2215               c, s);
   2216 }
   2217 
   2218 Stream &Stream::ThenBlasRot(uint64 elem_count,
   2219                             DeviceMemory<std::complex<double>> *x, int incx,
   2220                             DeviceMemory<std::complex<double>> *y, int incy,
   2221                             double c, double s) {
   2222   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2223             PARAM(c), PARAM(s));
   2224 
   2225   ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
   2226                DeviceMemory<std::complex<double>> *, int, double, double> impl;
   2227   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
   2228               c, s);
   2229 }
   2230 
   2231 Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
   2232                              DeviceMemory<float> *c, DeviceMemory<float> *s) {
   2233   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
   2234 
   2235   ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
   2236                DeviceMemory<float> *, DeviceMemory<float> *> impl;
   2237   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
   2238 }
   2239 
   2240 Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
   2241                              DeviceMemory<double> *c, DeviceMemory<double> *s) {
   2242   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
   2243 
   2244   ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
   2245                DeviceMemory<double> *, DeviceMemory<double> *> impl;
   2246   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
   2247 }
   2248 
   2249 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
   2250                              DeviceMemory<std::complex<float>> *b,
   2251                              DeviceMemory<float> *c,
   2252                              DeviceMemory<std::complex<float>> *s) {
   2253   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
   2254 
   2255   ThenBlasImpl<DeviceMemory<std::complex<float>> *,
   2256                DeviceMemory<std::complex<float>> *, DeviceMemory<float> *,
   2257                DeviceMemory<std::complex<float>> *> impl;
   2258   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
   2259 }
   2260 
   2261 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
   2262                              DeviceMemory<std::complex<double>> *b,
   2263                              DeviceMemory<double> *c,
   2264                              DeviceMemory<std::complex<double>> *s) {
   2265   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
   2266 
   2267   ThenBlasImpl<DeviceMemory<std::complex<double>> *,
   2268                DeviceMemory<std::complex<double>> *, DeviceMemory<double> *,
   2269                DeviceMemory<std::complex<double>> *> impl;
   2270   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
   2271 }
   2272 
   2273 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x,
   2274                              int incx, DeviceMemory<float> *y, int incy,
   2275                              const DeviceMemory<float> &param) {
   2276   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2277             PARAM(param));
   2278 
   2279   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
   2280                const DeviceMemory<float> &> impl;
   2281   return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
   2282               incy, param);
   2283 }
   2284 
   2285 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x,
   2286                              int incx, DeviceMemory<double> *y, int incy,
   2287                              const DeviceMemory<double> &param) {
   2288   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2289             PARAM(param));
   2290 
   2291   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
   2292                const DeviceMemory<double> &> impl;
   2293   return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
   2294               incy, param);
   2295 }
   2296 
   2297 Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
   2298                               DeviceMemory<float> *x1,
   2299                               const DeviceMemory<float> &y1,
   2300                               DeviceMemory<float> *param) {
   2301   VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
   2302 
   2303   ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
   2304                DeviceMemory<float> *, const DeviceMemory<float> &,
   2305                DeviceMemory<float> *> impl;
   2306   return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
   2307 }
   2308 
   2309 Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1,
   2310                               DeviceMemory<double> *d2,
   2311                               DeviceMemory<double> *x1,
   2312                               const DeviceMemory<double> &y1,
   2313                               DeviceMemory<double> *param) {
   2314   VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
   2315 
   2316   ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
   2317                DeviceMemory<double> *, const DeviceMemory<double> &,
   2318                DeviceMemory<double> *> impl;
   2319   return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
   2320 }
   2321 
   2322 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
   2323                              DeviceMemory<float> *x, int incx) {
   2324   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
   2325 
   2326   ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl;
   2327   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
   2328 }
   2329 
   2330 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
   2331                              DeviceMemory<double> *x, int incx) {
   2332   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
   2333 
   2334   ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl;
   2335   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
   2336 }
   2337 
   2338 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
   2339                              DeviceMemory<std::complex<float>> *x, int incx) {
   2340   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
   2341 
   2342   ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl;
   2343   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
   2344 }
   2345 
   2346 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
   2347                              DeviceMemory<std::complex<double>> *x, int incx) {
   2348   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
   2349 
   2350   ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl;
   2351   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
   2352 }
   2353 
   2354 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
   2355                              DeviceMemory<std::complex<float>> *x, int incx) {
   2356   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
   2357 
   2358   ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *,
   2359                int> impl;
   2360   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
   2361 }
   2362 
   2363 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
   2364                              DeviceMemory<std::complex<double>> *x, int incx) {
   2365   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
   2366 
   2367   ThenBlasImpl<uint64, std::complex<double>,
   2368                DeviceMemory<std::complex<double>> *, int> impl;
   2369   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
   2370 }
   2371 
   2372 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x,
   2373                              int incx, DeviceMemory<float> *y, int incy) {
   2374   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2375 
   2376   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int>
   2377       impl;
   2378   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
   2379               incy);
   2380 }
   2381 
   2382 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x,
   2383                              int incx, DeviceMemory<double> *y, int incy) {
   2384   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2385 
   2386   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int>
   2387       impl;
   2388   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
   2389               incy);
   2390 }
   2391 
   2392 Stream &Stream::ThenBlasSwap(uint64 elem_count,
   2393                              DeviceMemory<std::complex<float>> *x, int incx,
   2394                              DeviceMemory<std::complex<float>> *y, int incy) {
   2395   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2396 
   2397   ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
   2398                DeviceMemory<std::complex<float>> *, int> impl;
   2399   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
   2400               incy);
   2401 }
   2402 
   2403 Stream &Stream::ThenBlasSwap(uint64 elem_count,
   2404                              DeviceMemory<std::complex<double>> *x, int incx,
   2405                              DeviceMemory<std::complex<double>> *y, int incy) {
   2406   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2407 
   2408   ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
   2409                DeviceMemory<std::complex<double>> *, int> impl;
   2410   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
   2411               incy);
   2412 }
   2413 
   2414 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
   2415                               int incx, DeviceMemory<int> *result) {
   2416   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2417 
   2418   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
   2419       impl;
   2420   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
   2421               result);
   2422 }
   2423 
   2424 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
   2425                               int incx, DeviceMemory<int> *result) {
   2426   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2427 
   2428   ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
   2429       impl;
   2430   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
   2431               result);
   2432 }
   2433 
   2434 Stream &Stream::ThenBlasIamax(uint64 elem_count,
   2435                               const DeviceMemory<std::complex<float>> &x,
   2436                               int incx, DeviceMemory<int> *result) {
   2437   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2438 
   2439   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
   2440                DeviceMemory<int> *> impl;
   2441   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
   2442               result);
   2443 }
   2444 
   2445 Stream &Stream::ThenBlasIamax(uint64 elem_count,
   2446                               const DeviceMemory<std::complex<double>> &x,
   2447                               int incx, DeviceMemory<int> *result) {
   2448   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2449 
   2450   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
   2451                DeviceMemory<int> *> impl;
   2452   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
   2453               result);
   2454 }
   2455 
   2456 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
   2457                               int incx, DeviceMemory<int> *result) {
   2458   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2459 
   2460   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
   2461       impl;
   2462   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
   2463               result);
   2464 }
   2465 
   2466 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
   2467                               int incx, DeviceMemory<int> *result) {
   2468   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2469 
   2470   ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
   2471       impl;
   2472   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
   2473               result);
   2474 }
   2475 
   2476 Stream &Stream::ThenBlasIamin(uint64 elem_count,
   2477                               const DeviceMemory<std::complex<float>> &x,
   2478                               int incx, DeviceMemory<int> *result) {
   2479   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2480 
   2481   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
   2482                DeviceMemory<int> *> impl;
   2483   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
   2484               result);
   2485 }
   2486 
   2487 Stream &Stream::ThenBlasIamin(uint64 elem_count,
   2488                               const DeviceMemory<std::complex<double>> &x,
   2489                               int incx, DeviceMemory<int> *result) {
   2490   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2491 
   2492   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
   2493                DeviceMemory<int> *> impl;
   2494   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
   2495               result);
   2496 }
   2497 
   2498 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
   2499                              uint64 kl, uint64 ku, float alpha,
   2500                              const DeviceMemory<float> &a, int lda,
   2501                              const DeviceMemory<float> &x, int incx, float beta,
   2502                              DeviceMemory<float> *y, int incy) {
   2503   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
   2504             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
   2505             PARAM(beta), PARAM(y), PARAM(incy));
   2506 
   2507   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float,
   2508                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
   2509                int, float, DeviceMemory<float> *, int> impl;
   2510   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
   2511               a, lda, x, incx, beta, y, incy);
   2512 }
   2513 
   2514 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
   2515                              uint64 kl, uint64 ku, double alpha,
   2516                              const DeviceMemory<double> &a, int lda,
   2517                              const DeviceMemory<double> &x, int incx,
   2518                              double beta, DeviceMemory<double> *y, int incy) {
   2519   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
   2520             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
   2521             PARAM(beta), PARAM(y), PARAM(incy));
   2522 
   2523   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double,
   2524                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
   2525                int, double, DeviceMemory<double> *, int> impl;
   2526   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
   2527               a, lda, x, incx, beta, y, incy);
   2528 }
   2529 
   2530 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
   2531                              uint64 kl, uint64 ku, std::complex<float> alpha,
   2532                              const DeviceMemory<std::complex<float>> &a,
   2533                              int lda,
   2534                              const DeviceMemory<std::complex<float>> &x,
   2535                              int incx, std::complex<float> beta,
   2536                              DeviceMemory<std::complex<float>> *y, int incy) {
   2537   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
   2538             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
   2539             PARAM(beta), PARAM(y), PARAM(incy));
   2540 
   2541   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
   2542                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   2543                int, const DeviceMemory<std::complex<float>> &, int,
   2544                std::complex<float>, DeviceMemory<std::complex<float>> *,
   2545                int> impl;
   2546   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
   2547               a, lda, x, incx, beta, y, incy);
   2548 }
   2549 
   2550 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
   2551                              uint64 kl, uint64 ku, std::complex<double> alpha,
   2552                              const DeviceMemory<std::complex<double>> &a,
   2553                              int lda,
   2554                              const DeviceMemory<std::complex<double>> &x,
   2555                              int incx, std::complex<double> beta,
   2556                              DeviceMemory<std::complex<double>> *y, int incy) {
   2557   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
   2558             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
   2559             PARAM(beta), PARAM(y), PARAM(incy));
   2560 
   2561   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
   2562                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   2563                int, const DeviceMemory<std::complex<double>> &, int,
   2564                std::complex<double>, DeviceMemory<std::complex<double>> *,
   2565                int> impl;
   2566   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
   2567               a, lda, x, incx, beta, y, incy);
   2568 }
   2569 
   2570 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
   2571                              float alpha, const DeviceMemory<float> &a, int lda,
   2572                              const DeviceMemory<float> &x, int incx, float beta,
   2573                              DeviceMemory<float> *y, int incy) {
   2574   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   2575             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   2576             PARAM(incy));
   2577 
   2578   ThenBlasImpl<blas::Transpose, uint64, uint64, float,
   2579                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
   2580                int, float, DeviceMemory<float> *, int> impl;
   2581   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
   2582               x, incx, beta, y, incy);
   2583 }
   2584 
   2585 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
   2586                              double alpha, const DeviceMemory<double> &a,
   2587                              int lda, const DeviceMemory<double> &x, int incx,
   2588                              double beta, DeviceMemory<double> *y, int incy) {
   2589   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   2590             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   2591             PARAM(incy));
   2592 
   2593   ThenBlasImpl<blas::Transpose, uint64, uint64, double,
   2594                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
   2595                int, double, DeviceMemory<double> *, int> impl;
   2596   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
   2597               x, incx, beta, y, incy);
   2598 }
   2599 
   2600 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
   2601                              std::complex<float> alpha,
   2602                              const DeviceMemory<std::complex<float>> &a,
   2603                              int lda,
   2604                              const DeviceMemory<std::complex<float>> &x,
   2605                              int incx, std::complex<float> beta,
   2606                              DeviceMemory<std::complex<float>> *y, int incy) {
   2607   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   2608             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   2609             PARAM(incy));
   2610 
   2611   ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>,
   2612                const DeviceMemory<std::complex<float>> &, int,
   2613                const DeviceMemory<std::complex<float>> &, int,
   2614                std::complex<float>, DeviceMemory<std::complex<float>> *,
   2615                int> impl;
   2616   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
   2617               x, incx, beta, y, incy);
   2618 }
   2619 
   2620 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
   2621                              std::complex<double> alpha,
   2622                              const DeviceMemory<std::complex<double>> &a,
   2623                              int lda,
   2624                              const DeviceMemory<std::complex<double>> &x,
   2625                              int incx, std::complex<double> beta,
   2626                              DeviceMemory<std::complex<double>> *y, int incy) {
   2627   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   2628             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   2629             PARAM(incy));
   2630 
   2631   ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>,
   2632                const DeviceMemory<std::complex<double>> &, int,
   2633                const DeviceMemory<std::complex<double>> &, int,
   2634                std::complex<double>, DeviceMemory<std::complex<double>> *,
   2635                int> impl;
   2636   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
   2637               x, incx, beta, y, incy);
   2638 }
   2639 
   2640 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha,
   2641                             const DeviceMemory<float> &x, int incx,
   2642                             const DeviceMemory<float> &y, int incy,
   2643                             DeviceMemory<float> *a, int lda) {
   2644   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   2645             PARAM(incy), PARAM(a), PARAM(lda));
   2646 
   2647   ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int,
   2648                const DeviceMemory<float> &, int, DeviceMemory<float> *,
   2649                int> impl;
   2650   return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
   2651               incy, a, lda);
   2652 }
   2653 
   2654 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha,
   2655                             const DeviceMemory<double> &x, int incx,
   2656                             const DeviceMemory<double> &y, int incy,
   2657                             DeviceMemory<double> *a, int lda) {
   2658   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   2659             PARAM(incy), PARAM(a), PARAM(lda));
   2660 
   2661   ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int,
   2662                const DeviceMemory<double> &, int, DeviceMemory<double> *,
   2663                int> impl;
   2664   return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
   2665               incy, a, lda);
   2666 }
   2667 
   2668 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
   2669                              const DeviceMemory<std::complex<float>> &x,
   2670                              int incx,
   2671                              const DeviceMemory<std::complex<float>> &y,
   2672                              int incy, DeviceMemory<std::complex<float>> *a,
   2673                              int lda) {
   2674   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   2675             PARAM(incy), PARAM(a), PARAM(lda));
   2676 
   2677   ThenBlasImpl<uint64, uint64, std::complex<float>,
   2678                const DeviceMemory<std::complex<float>> &, int,
   2679                const DeviceMemory<std::complex<float>> &, int,
   2680                DeviceMemory<std::complex<float>> *, int> impl;
   2681   return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
   2682               incy, a, lda);
   2683 }
   2684 
   2685 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
   2686                              const DeviceMemory<std::complex<double>> &x,
   2687                              int incx,
   2688                              const DeviceMemory<std::complex<double>> &y,
   2689                              int incy, DeviceMemory<std::complex<double>> *a,
   2690                              int lda) {
   2691   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   2692             PARAM(incy), PARAM(a), PARAM(lda));
   2693 
   2694   ThenBlasImpl<uint64, uint64, std::complex<double>,
   2695                const DeviceMemory<std::complex<double>> &, int,
   2696                const DeviceMemory<std::complex<double>> &, int,
   2697                DeviceMemory<std::complex<double>> *, int> impl;
   2698   return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
   2699               incy, a, lda);
   2700 }
   2701 
   2702 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
   2703                              const DeviceMemory<std::complex<float>> &x,
   2704                              int incx,
   2705                              const DeviceMemory<std::complex<float>> &y,
   2706                              int incy, DeviceMemory<std::complex<float>> *a,
   2707                              int lda) {
   2708   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   2709             PARAM(incy), PARAM(a), PARAM(lda));
   2710 
   2711   ThenBlasImpl<uint64, uint64, std::complex<float>,
   2712                const DeviceMemory<std::complex<float>> &, int,
   2713                const DeviceMemory<std::complex<float>> &, int,
   2714                DeviceMemory<std::complex<float>> *, int> impl;
   2715   return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
   2716               incy, a, lda);
   2717 }
   2718 
   2719 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
   2720                              const DeviceMemory<std::complex<double>> &x,
   2721                              int incx,
   2722                              const DeviceMemory<std::complex<double>> &y,
   2723                              int incy, DeviceMemory<std::complex<double>> *a,
   2724                              int lda) {
   2725   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   2726             PARAM(incy), PARAM(a), PARAM(lda));
   2727 
   2728   ThenBlasImpl<uint64, uint64, std::complex<double>,
   2729                const DeviceMemory<std::complex<double>> &, int,
   2730                const DeviceMemory<std::complex<double>> &, int,
   2731                DeviceMemory<std::complex<double>> *, int> impl;
   2732   return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
   2733               incy, a, lda);
   2734 }
   2735 
   2736 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
   2737                              std::complex<float> alpha,
   2738                              const DeviceMemory<std::complex<float>> &a,
   2739                              int lda,
   2740                              const DeviceMemory<std::complex<float>> &x,
   2741                              int incx, std::complex<float> beta,
   2742                              DeviceMemory<std::complex<float>> *y, int incy) {
   2743   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
   2744             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2745 
   2746   ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>,
   2747                const DeviceMemory<std::complex<float>> &, int,
   2748                const DeviceMemory<std::complex<float>> &, int,
   2749                std::complex<float>, DeviceMemory<std::complex<float>> *,
   2750                int> impl;
   2751   return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
   2752               x, incx, beta, y, incy);
   2753 }
   2754 
   2755 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
   2756                              std::complex<double> alpha,
   2757                              const DeviceMemory<std::complex<double>> &a,
   2758                              int lda,
   2759                              const DeviceMemory<std::complex<double>> &x,
   2760                              int incx, std::complex<double> beta,
   2761                              DeviceMemory<std::complex<double>> *y, int incy) {
   2762   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
   2763             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2764 
   2765   ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>,
   2766                const DeviceMemory<std::complex<double>> &, int,
   2767                const DeviceMemory<std::complex<double>> &, int,
   2768                std::complex<double>, DeviceMemory<std::complex<double>> *,
   2769                int> impl;
   2770   return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
   2771               x, incx, beta, y, incy);
   2772 }
   2773 
   2774 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
   2775                              std::complex<float> alpha,
   2776                              const DeviceMemory<std::complex<float>> &a,
   2777                              int lda,
   2778                              const DeviceMemory<std::complex<float>> &x,
   2779                              int incx, std::complex<float> beta,
   2780                              DeviceMemory<std::complex<float>> *y, int incy) {
   2781   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
   2782             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2783 
   2784   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
   2785                const DeviceMemory<std::complex<float>> &, int,
   2786                const DeviceMemory<std::complex<float>> &, int,
   2787                std::complex<float>, DeviceMemory<std::complex<float>> *,
   2788                int> impl;
   2789   return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
   2790               incx, beta, y, incy);
   2791 }
   2792 
   2793 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
   2794                              std::complex<double> alpha,
   2795                              const DeviceMemory<std::complex<double>> &a,
   2796                              int lda,
   2797                              const DeviceMemory<std::complex<double>> &x,
   2798                              int incx, std::complex<double> beta,
   2799                              DeviceMemory<std::complex<double>> *y, int incy) {
   2800   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
   2801             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2802 
   2803   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
   2804                const DeviceMemory<std::complex<double>> &, int,
   2805                const DeviceMemory<std::complex<double>> &, int,
   2806                std::complex<double>, DeviceMemory<std::complex<double>> *,
   2807                int> impl;
   2808   return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
   2809               incx, beta, y, incy);
   2810 }
   2811 
   2812 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
   2813                             const DeviceMemory<std::complex<float>> &x,
   2814                             int incx, DeviceMemory<std::complex<float>> *a,
   2815                             int lda) {
   2816   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2817             PARAM(a), PARAM(lda));
   2818 
   2819   ThenBlasImpl<blas::UpperLower, uint64, float,
   2820                const DeviceMemory<std::complex<float>> &, int,
   2821                DeviceMemory<std::complex<float>> *, int> impl;
   2822   return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
   2823               lda);
   2824 }
   2825 
   2826 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
   2827                             const DeviceMemory<std::complex<double>> &x,
   2828                             int incx, DeviceMemory<std::complex<double>> *a,
   2829                             int lda) {
   2830   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2831             PARAM(a), PARAM(lda));
   2832 
   2833   ThenBlasImpl<blas::UpperLower, uint64, double,
   2834                const DeviceMemory<std::complex<double>> &, int,
   2835                DeviceMemory<std::complex<double>> *, int> impl;
   2836   return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
   2837               lda);
   2838 }
   2839 
   2840 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
   2841                              std::complex<float> alpha,
   2842                              const DeviceMemory<std::complex<float>> &x,
   2843                              int incx,
   2844                              const DeviceMemory<std::complex<float>> &y,
   2845                              int incy, DeviceMemory<std::complex<float>> *a,
   2846                              int lda) {
   2847   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2848             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
   2849 
   2850   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
   2851                const DeviceMemory<std::complex<float>> &, int,
   2852                const DeviceMemory<std::complex<float>> &, int,
   2853                DeviceMemory<std::complex<float>> *, int> impl;
   2854   return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
   2855               incy, a, lda);
   2856 }
   2857 
   2858 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
   2859                              std::complex<double> alpha,
   2860                              const DeviceMemory<std::complex<double>> &x,
   2861                              int incx,
   2862                              const DeviceMemory<std::complex<double>> &y,
   2863                              int incy, DeviceMemory<std::complex<double>> *a,
   2864                              int lda) {
   2865   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2866             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
   2867 
   2868   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
   2869                const DeviceMemory<std::complex<double>> &, int,
   2870                const DeviceMemory<std::complex<double>> &, int,
   2871                DeviceMemory<std::complex<double>> *, int> impl;
   2872   return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
   2873               incy, a, lda);
   2874 }
   2875 
   2876 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
   2877                              std::complex<float> alpha,
   2878                              const DeviceMemory<std::complex<float>> &ap,
   2879                              const DeviceMemory<std::complex<float>> &x,
   2880                              int incx, std::complex<float> beta,
   2881                              DeviceMemory<std::complex<float>> *y, int incy) {
   2882   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
   2883             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2884 
   2885   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
   2886                const DeviceMemory<std::complex<float>> &,
   2887                const DeviceMemory<std::complex<float>> &, int,
   2888                std::complex<float>, DeviceMemory<std::complex<float>> *,
   2889                int> impl;
   2890   return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
   2891               beta, y, incy);
   2892 }
   2893 
   2894 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
   2895                              std::complex<double> alpha,
   2896                              const DeviceMemory<std::complex<double>> &ap,
   2897                              const DeviceMemory<std::complex<double>> &x,
   2898                              int incx, std::complex<double> beta,
   2899                              DeviceMemory<std::complex<double>> *y, int incy) {
   2900   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
   2901             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2902 
   2903   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
   2904                const DeviceMemory<std::complex<double>> &,
   2905                const DeviceMemory<std::complex<double>> &, int,
   2906                std::complex<double>, DeviceMemory<std::complex<double>> *,
   2907                int> impl;
   2908   return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
   2909               beta, y, incy);
   2910 }
   2911 
   2912 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
   2913                             const DeviceMemory<std::complex<float>> &x,
   2914                             int incx, DeviceMemory<std::complex<float>> *ap) {
   2915   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2916             PARAM(ap));
   2917 
   2918   ThenBlasImpl<blas::UpperLower, uint64, float,
   2919                const DeviceMemory<std::complex<float>> &, int,
   2920                DeviceMemory<std::complex<float>> *> impl;
   2921   return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
   2922 }
   2923 
   2924 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
   2925                             const DeviceMemory<std::complex<double>> &x,
   2926                             int incx, DeviceMemory<std::complex<double>> *ap) {
   2927   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2928             PARAM(ap));
   2929 
   2930   ThenBlasImpl<blas::UpperLower, uint64, double,
   2931                const DeviceMemory<std::complex<double>> &, int,
   2932                DeviceMemory<std::complex<double>> *> impl;
   2933   return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
   2934 }
   2935 
   2936 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
   2937                              std::complex<float> alpha,
   2938                              const DeviceMemory<std::complex<float>> &x,
   2939                              int incx,
   2940                              const DeviceMemory<std::complex<float>> &y,
   2941                              int incy, DeviceMemory<std::complex<float>> *ap) {
   2942   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2943             PARAM(y), PARAM(incy), PARAM(ap));
   2944 
   2945   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
   2946                const DeviceMemory<std::complex<float>> &, int,
   2947                const DeviceMemory<std::complex<float>> &, int,
   2948                DeviceMemory<std::complex<float>> *> impl;
   2949   return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
   2950               incy, ap);
   2951 }
   2952 
   2953 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
   2954                              std::complex<double> alpha,
   2955                              const DeviceMemory<std::complex<double>> &x,
   2956                              int incx,
   2957                              const DeviceMemory<std::complex<double>> &y,
   2958                              int incy, DeviceMemory<std::complex<double>> *ap) {
   2959   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2960             PARAM(y), PARAM(incy), PARAM(ap));
   2961 
   2962   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
   2963                const DeviceMemory<std::complex<double>> &, int,
   2964                const DeviceMemory<std::complex<double>> &, int,
   2965                DeviceMemory<std::complex<double>> *> impl;
   2966   return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
   2967               incy, ap);
   2968 }
   2969 
   2970 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
   2971                              float alpha, const DeviceMemory<float> &a, int lda,
   2972                              const DeviceMemory<float> &x, int incx, float beta,
   2973                              DeviceMemory<float> *y, int incy) {
   2974   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
   2975             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2976 
   2977   ThenBlasImpl<blas::UpperLower, uint64, uint64, float,
   2978                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
   2979                int, float, DeviceMemory<float> *, int> impl;
   2980   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
   2981               x, incx, beta, y, incy);
   2982 }
   2983 
   2984 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
   2985                              double alpha, const DeviceMemory<double> &a,
   2986                              int lda, const DeviceMemory<double> &x, int incx,
   2987                              double beta, DeviceMemory<double> *y, int incy) {
   2988   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
   2989             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2990 
   2991   ThenBlasImpl<blas::UpperLower, uint64, uint64, double,
   2992                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
   2993                int, double, DeviceMemory<double> *, int> impl;
   2994   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
   2995               x, incx, beta, y, incy);
   2996 }
   2997 
   2998 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
   2999                              const DeviceMemory<float> &ap,
   3000                              const DeviceMemory<float> &x, int incx, float beta,
   3001                              DeviceMemory<float> *y, int incy) {
   3002   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
   3003             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   3004 
   3005   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
   3006                const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
   3007                int> impl;
   3008   return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
   3009               beta, y, incy);
   3010 }
   3011 
   3012 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
   3013                              const DeviceMemory<double> &ap,
   3014                              const DeviceMemory<double> &x, int incx,
   3015                              double beta, DeviceMemory<double> *y, int incy) {
   3016   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
   3017             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   3018 
   3019   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
   3020                const DeviceMemory<double> &, int, double,
   3021                DeviceMemory<double> *, int> impl;
   3022   return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
   3023               beta, y, incy);
   3024 }
   3025 
   3026 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
   3027                             const DeviceMemory<float> &x, int incx,
   3028                             DeviceMemory<float> *ap) {
   3029   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3030             PARAM(ap));
   3031 
   3032   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
   3033                int, DeviceMemory<float> *> impl;
   3034   return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
   3035 }
   3036 
   3037 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
   3038                             const DeviceMemory<double> &x, int incx,
   3039                             DeviceMemory<double> *ap) {
   3040   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3041             PARAM(ap));
   3042 
   3043   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
   3044                int, DeviceMemory<double> *> impl;
   3045   return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
   3046 }
   3047 
   3048 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
   3049                              const DeviceMemory<float> &x, int incx,
   3050                              const DeviceMemory<float> &y, int incy,
   3051                              DeviceMemory<float> *ap) {
   3052   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3053             PARAM(y), PARAM(incy), PARAM(ap));
   3054 
   3055   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
   3056                int, const DeviceMemory<float> &, int,
   3057                DeviceMemory<float> *> impl;
   3058   return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
   3059               incy, ap);
   3060 }
   3061 
   3062 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
   3063                              const DeviceMemory<double> &x, int incx,
   3064                              const DeviceMemory<double> &y, int incy,
   3065                              DeviceMemory<double> *ap) {
   3066   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3067             PARAM(y), PARAM(incy), PARAM(ap));
   3068 
   3069   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
   3070                int, const DeviceMemory<double> &, int,
   3071                DeviceMemory<double> *> impl;
   3072   return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
   3073               incy, ap);
   3074 }
   3075 
   3076 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
   3077                              const DeviceMemory<float> &a, int lda,
   3078                              const DeviceMemory<float> &x, int incx, float beta,
   3079                              DeviceMemory<float> *y, int incy) {
   3080   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
   3081             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   3082 
   3083   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
   3084                int, const DeviceMemory<float> &, int, float,
   3085                DeviceMemory<float> *, int> impl;
   3086   return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
   3087               incx, beta, y, incy);
   3088 }
   3089 
   3090 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
   3091                              const DeviceMemory<double> &a, int lda,
   3092                              const DeviceMemory<double> &x, int incx,
   3093                              double beta, DeviceMemory<double> *y, int incy) {
   3094   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
   3095             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   3096 
   3097   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
   3098                int, const DeviceMemory<double> &, int, double,
   3099                DeviceMemory<double> *, int> impl;
   3100   return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
   3101               incx, beta, y, incy);
   3102 }
   3103 
   3104 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
   3105                             const DeviceMemory<float> &x, int incx,
   3106                             DeviceMemory<float> *a, int lda) {
   3107   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3108             PARAM(a), PARAM(lda));
   3109 
   3110   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
   3111                int, DeviceMemory<float> *, int> impl;
   3112   return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
   3113               lda);
   3114 }
   3115 
   3116 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
   3117                             const DeviceMemory<double> &x, int incx,
   3118                             DeviceMemory<double> *a, int lda) {
   3119   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3120             PARAM(a), PARAM(lda));
   3121 
   3122   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
   3123                int, DeviceMemory<double> *, int> impl;
   3124   return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
   3125               lda);
   3126 }
   3127 
   3128 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
   3129                              const DeviceMemory<float> &x, int incx,
   3130                              const DeviceMemory<float> &y, int incy,
   3131                              DeviceMemory<float> *a, int lda) {
   3132   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3133             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
   3134 
   3135   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
   3136                int, const DeviceMemory<float> &, int, DeviceMemory<float> *,
   3137                int> impl;
   3138   return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
   3139               incy, a, lda);
   3140 }
   3141 
   3142 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
   3143                              const DeviceMemory<double> &x, int incx,
   3144                              const DeviceMemory<double> &y, int incy,
   3145                              DeviceMemory<double> *a, int lda) {
   3146   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3147             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
   3148 
   3149   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
   3150                int, const DeviceMemory<double> &, int, DeviceMemory<double> *,
   3151                int> impl;
   3152   return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
   3153               incy, a, lda);
   3154 }
   3155 
   3156 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
   3157                              blas::Diagonal diag, uint64 n, uint64 k,
   3158                              const DeviceMemory<float> &a, int lda,
   3159                              DeviceMemory<float> *x, int incx) {
   3160   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3161             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3162 
   3163   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3164                uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
   3165                int> impl;
   3166   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
   3167               lda, x, incx);
   3168 }
   3169 
   3170 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
   3171                              blas::Diagonal diag, uint64 n, uint64 k,
   3172                              const DeviceMemory<double> &a, int lda,
   3173                              DeviceMemory<double> *x, int incx) {
   3174   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3175             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3176 
   3177   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3178                uint64, const DeviceMemory<double> &, int,
   3179                DeviceMemory<double> *, int> impl;
   3180   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
   3181               lda, x, incx);
   3182 }
   3183 
   3184 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
   3185                              blas::Diagonal diag, uint64 n, uint64 k,
   3186                              const DeviceMemory<std::complex<float>> &a,
   3187                              int lda, DeviceMemory<std::complex<float>> *x,
   3188                              int incx) {
   3189   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3190             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3191 
   3192   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3193                uint64, const DeviceMemory<std::complex<float>> &, int,
   3194                DeviceMemory<std::complex<float>> *, int> impl;
   3195   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
   3196               lda, x, incx);
   3197 }
   3198 
   3199 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
   3200                              blas::Diagonal diag, uint64 n, uint64 k,
   3201                              const DeviceMemory<std::complex<double>> &a,
   3202                              int lda, DeviceMemory<std::complex<double>> *x,
   3203                              int incx) {
   3204   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3205             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3206 
   3207   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3208                uint64, const DeviceMemory<std::complex<double>> &, int,
   3209                DeviceMemory<std::complex<double>> *, int> impl;
   3210   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
   3211               lda, x, incx);
   3212 }
   3213 
   3214 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
   3215                              blas::Diagonal diag, uint64 n, uint64 k,
   3216                              const DeviceMemory<float> &a, int lda,
   3217                              DeviceMemory<float> *x, int incx) {
   3218   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3219             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3220 
   3221   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3222                uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
   3223                int> impl;
   3224   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
   3225               lda, x, incx);
   3226 }
   3227 
   3228 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
   3229                              blas::Diagonal diag, uint64 n, uint64 k,
   3230                              const DeviceMemory<double> &a, int lda,
   3231                              DeviceMemory<double> *x, int incx) {
   3232   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3233             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3234 
   3235   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3236                uint64, const DeviceMemory<double> &, int,
   3237                DeviceMemory<double> *, int> impl;
   3238   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
   3239               lda, x, incx);
   3240 }
   3241 
   3242 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
   3243                              blas::Diagonal diag, uint64 n, uint64 k,
   3244                              const DeviceMemory<std::complex<float>> &a,
   3245                              int lda, DeviceMemory<std::complex<float>> *x,
   3246                              int incx) {
   3247   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3248             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3249 
   3250   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3251                uint64, const DeviceMemory<std::complex<float>> &, int,
   3252                DeviceMemory<std::complex<float>> *, int> impl;
   3253   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
   3254               lda, x, incx);
   3255 }
   3256 
   3257 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
   3258                              blas::Diagonal diag, uint64 n, uint64 k,
   3259                              const DeviceMemory<std::complex<double>> &a,
   3260                              int lda, DeviceMemory<std::complex<double>> *x,
   3261                              int incx) {
   3262   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3263             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3264 
   3265   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3266                uint64, const DeviceMemory<std::complex<double>> &, int,
   3267                DeviceMemory<std::complex<double>> *, int> impl;
   3268   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
   3269               lda, x, incx);
   3270 }
   3271 
   3272 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
   3273                              blas::Diagonal diag, uint64 n,
   3274                              const DeviceMemory<float> &ap,
   3275                              DeviceMemory<float> *x, int incx) {
   3276   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3277             PARAM(x), PARAM(incx));
   3278 
   3279   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3280                const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
   3281   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
   3282               incx);
   3283 }
   3284 
   3285 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
   3286                              blas::Diagonal diag, uint64 n,
   3287                              const DeviceMemory<double> &ap,
   3288                              DeviceMemory<double> *x, int incx) {
   3289   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3290             PARAM(x), PARAM(incx));
   3291 
   3292   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3293                const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
   3294   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
   3295               incx);
   3296 }
   3297 
   3298 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
   3299                              blas::Diagonal diag, uint64 n,
   3300                              const DeviceMemory<std::complex<float>> &ap,
   3301                              DeviceMemory<std::complex<float>> *x, int incx) {
   3302   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3303             PARAM(x), PARAM(incx));
   3304 
   3305   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3306                const DeviceMemory<std::complex<float>> &,
   3307                DeviceMemory<std::complex<float>> *, int> impl;
   3308   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
   3309               incx);
   3310 }
   3311 
   3312 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
   3313                              blas::Diagonal diag, uint64 n,
   3314                              const DeviceMemory<std::complex<double>> &ap,
   3315                              DeviceMemory<std::complex<double>> *x, int incx) {
   3316   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3317             PARAM(x), PARAM(incx));
   3318 
   3319   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3320                const DeviceMemory<std::complex<double>> &,
   3321                DeviceMemory<std::complex<double>> *, int> impl;
   3322   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
   3323               incx);
   3324 }
   3325 
   3326 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
   3327                              blas::Diagonal diag, uint64 n,
   3328                              const DeviceMemory<float> &ap,
   3329                              DeviceMemory<float> *x, int incx) {
   3330   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3331             PARAM(x), PARAM(incx));
   3332 
   3333   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3334                const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
   3335   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
   3336               incx);
   3337 }
   3338 
   3339 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
   3340                              blas::Diagonal diag, uint64 n,
   3341                              const DeviceMemory<double> &ap,
   3342                              DeviceMemory<double> *x, int incx) {
   3343   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3344             PARAM(x), PARAM(incx));
   3345 
   3346   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3347                const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
   3348   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
   3349               incx);
   3350 }
   3351 
   3352 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
   3353                              blas::Diagonal diag, uint64 n,
   3354                              const DeviceMemory<std::complex<float>> &ap,
   3355                              DeviceMemory<std::complex<float>> *x, int incx) {
   3356   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3357             PARAM(x), PARAM(incx));
   3358 
   3359   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3360                const DeviceMemory<std::complex<float>> &,
   3361                DeviceMemory<std::complex<float>> *, int> impl;
   3362   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
   3363               incx);
   3364 }
   3365 
   3366 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
   3367                              blas::Diagonal diag, uint64 n,
   3368                              const DeviceMemory<std::complex<double>> &ap,
   3369                              DeviceMemory<std::complex<double>> *x, int incx) {
   3370   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3371             PARAM(x), PARAM(incx));
   3372 
   3373   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3374                const DeviceMemory<std::complex<double>> &,
   3375                DeviceMemory<std::complex<double>> *, int> impl;
   3376   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
   3377               incx);
   3378 }
   3379 
   3380 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
   3381                              blas::Diagonal diag, uint64 n,
   3382                              const DeviceMemory<float> &a, int lda,
   3383                              DeviceMemory<float> *x, int incx) {
   3384   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3385             PARAM(lda), PARAM(x), PARAM(incx));
   3386 
   3387   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3388                const DeviceMemory<float> &, int, DeviceMemory<float> *,
   3389                int> impl;
   3390   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
   3391               lda, x, incx);
   3392 }
   3393 
   3394 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
   3395                              blas::Diagonal diag, uint64 n,
   3396                              const DeviceMemory<double> &a, int lda,
   3397                              DeviceMemory<double> *x, int incx) {
   3398   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3399             PARAM(lda), PARAM(x), PARAM(incx));
   3400 
   3401   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3402                const DeviceMemory<double> &, int, DeviceMemory<double> *,
   3403                int> impl;
   3404   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
   3405               lda, x, incx);
   3406 }
   3407 
   3408 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
   3409                              blas::Diagonal diag, uint64 n,
   3410                              const DeviceMemory<std::complex<float>> &a,
   3411                              int lda, DeviceMemory<std::complex<float>> *x,
   3412                              int incx) {
   3413   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3414             PARAM(lda), PARAM(x), PARAM(incx));
   3415 
   3416   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3417                const DeviceMemory<std::complex<float>> &, int,
   3418                DeviceMemory<std::complex<float>> *, int> impl;
   3419   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
   3420               lda, x, incx);
   3421 }
   3422 
   3423 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
   3424                              blas::Diagonal diag, uint64 n,
   3425                              const DeviceMemory<std::complex<double>> &a,
   3426                              int lda, DeviceMemory<std::complex<double>> *x,
   3427                              int incx) {
   3428   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3429             PARAM(lda), PARAM(x), PARAM(incx));
   3430 
   3431   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3432                const DeviceMemory<std::complex<double>> &, int,
   3433                DeviceMemory<std::complex<double>> *, int> impl;
   3434   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
   3435               lda, x, incx);
   3436 }
   3437 
   3438 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
   3439                              blas::Diagonal diag, uint64 n,
   3440                              const DeviceMemory<float> &a, int lda,
   3441                              DeviceMemory<float> *x, int incx) {
   3442   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3443             PARAM(lda), PARAM(x), PARAM(incx));
   3444 
   3445   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3446                const DeviceMemory<float> &, int, DeviceMemory<float> *,
   3447                int> impl;
   3448   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
   3449               lda, x, incx);
   3450 }
   3451 
   3452 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
   3453                              blas::Diagonal diag, uint64 n,
   3454                              const DeviceMemory<double> &a, int lda,
   3455                              DeviceMemory<double> *x, int incx) {
   3456   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3457             PARAM(lda), PARAM(x), PARAM(incx));
   3458 
   3459   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3460                const DeviceMemory<double> &, int, DeviceMemory<double> *,
   3461                int> impl;
   3462   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
   3463               lda, x, incx);
   3464 }
   3465 
   3466 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
   3467                              blas::Diagonal diag, uint64 n,
   3468                              const DeviceMemory<std::complex<float>> &a,
   3469                              int lda, DeviceMemory<std::complex<float>> *x,
   3470                              int incx) {
   3471   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3472             PARAM(lda), PARAM(x), PARAM(incx));
   3473 
   3474   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3475                const DeviceMemory<std::complex<float>> &, int,
   3476                DeviceMemory<std::complex<float>> *, int> impl;
   3477   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
   3478               lda, x, incx);
   3479 }
   3480 
   3481 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
   3482                              blas::Diagonal diag, uint64 n,
   3483                              const DeviceMemory<std::complex<double>> &a,
   3484                              int lda, DeviceMemory<std::complex<double>> *x,
   3485                              int incx) {
   3486   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3487             PARAM(lda), PARAM(x), PARAM(incx));
   3488 
   3489   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3490                const DeviceMemory<std::complex<double>> &, int,
   3491                DeviceMemory<std::complex<double>> *, int> impl;
   3492   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
   3493               lda, x, incx);
   3494 }
   3495 
   3496 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
   3497                              uint64 m, uint64 n, uint64 k, float alpha,
   3498                              const DeviceMemory<Eigen::half> &a, int lda,
   3499                              const DeviceMemory<Eigen::half> &b, int ldb,
   3500                              float beta,
   3501                              DeviceMemory<Eigen::half> *c, int ldc) {
   3502   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3503             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3504             PARAM(beta), PARAM(c), PARAM(ldc));
   3505 
   3506   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
   3507                const DeviceMemory<Eigen::half> &, int,
   3508                const DeviceMemory<Eigen::half> &, int,
   3509                float, DeviceMemory<Eigen::half> *, int> impl;
   3510   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
   3511               alpha, a, lda, b, ldb, beta, c, ldc);
   3512 }
   3513 
   3514 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
   3515                              uint64 m, uint64 n, uint64 k, float alpha,
   3516                              const DeviceMemory<float> &a, int lda,
   3517                              const DeviceMemory<float> &b, int ldb, float beta,
   3518                              DeviceMemory<float> *c, int ldc) {
   3519   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3520             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3521             PARAM(beta), PARAM(c), PARAM(ldc));
   3522 
   3523   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
   3524                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
   3525                int, float, DeviceMemory<float> *, int> impl;
   3526   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
   3527               alpha, a, lda, b, ldb, beta, c, ldc);
   3528 }
   3529 
   3530 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
   3531                              uint64 m, uint64 n, uint64 k, double alpha,
   3532                              const DeviceMemory<double> &a, int lda,
   3533                              const DeviceMemory<double> &b, int ldb,
   3534                              double beta, DeviceMemory<double> *c, int ldc) {
   3535   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3536             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3537             PARAM(beta), PARAM(c), PARAM(ldc));
   3538 
   3539   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
   3540                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
   3541                int, double, DeviceMemory<double> *, int> impl;
   3542   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
   3543               alpha, a, lda, b, ldb, beta, c, ldc);
   3544 }
   3545 
   3546 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
   3547                              uint64 m, uint64 n, uint64 k,
   3548                              std::complex<float> alpha,
   3549                              const DeviceMemory<std::complex<float>> &a,
   3550                              int lda,
   3551                              const DeviceMemory<std::complex<float>> &b,
   3552                              int ldb, std::complex<float> beta,
   3553                              DeviceMemory<std::complex<float>> *c, int ldc) {
   3554   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3555             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3556             PARAM(beta), PARAM(c), PARAM(ldc));
   3557 
   3558   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   3559                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   3560                int, const DeviceMemory<std::complex<float>> &, int,
   3561                std::complex<float>, DeviceMemory<std::complex<float>> *,
   3562                int> impl;
   3563   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
   3564               alpha, a, lda, b, ldb, beta, c, ldc);
   3565 }
   3566 
   3567 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
   3568                              uint64 m, uint64 n, uint64 k,
   3569                              std::complex<double> alpha,
   3570                              const DeviceMemory<std::complex<double>> &a,
   3571                              int lda,
   3572                              const DeviceMemory<std::complex<double>> &b,
   3573                              int ldb, std::complex<double> beta,
   3574                              DeviceMemory<std::complex<double>> *c, int ldc) {
   3575   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3576             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3577             PARAM(beta), PARAM(c), PARAM(ldc));
   3578 
   3579   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   3580                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   3581                int, const DeviceMemory<std::complex<double>> &, int,
   3582                std::complex<double>, DeviceMemory<std::complex<double>> *,
   3583                int> impl;
   3584   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
   3585               alpha, a, lda, b, ldb, beta, c, ldc);
   3586 }
   3587 
   3588 namespace {
   3589 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
   3590 // blas::ProfileResult*.  This functor doesn't put the stream into an error
   3591 // state if the op fails and the profile result is non-null.  Instead, the
   3592 // error-ness is returned in the profile result itself.
   3593 template <typename... Args>
   3594 struct ThenBlasWithProfileImpl {
   3595   Stream &operator()(Stream *stream,
   3596                      bool (blas::BlasSupport::*blas_func)(
   3597                          Stream *, Args..., blas::ProfileResult *),
   3598                      Args... args, blas::ProfileResult *profile_result) {
   3599     ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
   3600     bool record_error = profile_result == nullptr;
   3601     return Runner.Run(stream, blas_func, record_error, args..., profile_result);
   3602   }
   3603 };
   3604 }  // anonymous namespace
   3605 
   3606 Stream &Stream::ThenBlasGemvWithProfiling(
   3607     blas::Transpose trans, uint64 m, uint64 n, float alpha,
   3608     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
   3609     int incx, float beta, DeviceMemory<float> *y, int incy,
   3610     blas::ProfileResult *output_profile_result) {
   3611   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   3612             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   3613             PARAM(incy));
   3614 
   3615   ThenBlasWithProfileImpl<
   3616       blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
   3617       const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
   3618       impl;
   3619   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
   3620               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
   3621 }
   3622 
   3623 Stream &Stream::ThenBlasGemvWithProfiling(
   3624     blas::Transpose trans, uint64 m, uint64 n, double alpha,
   3625     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
   3626     int incx, double beta, DeviceMemory<double> *y, int incy,
   3627     blas::ProfileResult *output_profile_result) {
   3628   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   3629             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   3630             PARAM(incy));
   3631 
   3632   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
   3633                           const DeviceMemory<double> &, int,
   3634                           const DeviceMemory<double> &, int, double,
   3635                           DeviceMemory<double> *, int>
   3636       impl;
   3637   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
   3638               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
   3639 }
   3640 
   3641 Stream &Stream::ThenBlasGemvWithProfiling(
   3642     blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
   3643     const DeviceMemory<std::complex<float>> &a, int lda,
   3644     const DeviceMemory<std::complex<float>> &x, int incx,
   3645     std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
   3646     blas::ProfileResult *output_profile_result) {
   3647   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   3648             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   3649             PARAM(incy));
   3650 
   3651   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
   3652                           const DeviceMemory<std::complex<float>> &, int,
   3653                           const DeviceMemory<std::complex<float>> &, int,
   3654                           std::complex<float>,
   3655                           DeviceMemory<std::complex<float>> *, int>
   3656       impl;
   3657   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
   3658               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
   3659 }
   3660 
   3661 Stream &Stream::ThenBlasGemvWithProfiling(
   3662     blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
   3663     const DeviceMemory<std::complex<double>> &a, int lda,
   3664     const DeviceMemory<std::complex<double>> &x, int incx,
   3665     std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
   3666     blas::ProfileResult *output_profile_result) {
   3667   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   3668             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   3669             PARAM(incy));
   3670 
   3671   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
   3672                           const DeviceMemory<std::complex<double>> &, int,
   3673                           const DeviceMemory<std::complex<double>> &, int,
   3674                           std::complex<double>,
   3675                           DeviceMemory<std::complex<double>> *, int>
   3676       impl;
   3677   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
   3678               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
   3679 }
   3680 
   3681 Stream &Stream::ThenBlasGemmWithProfiling(
   3682     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3683     uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
   3684     const DeviceMemory<Eigen::half> &b, int ldb, float beta,
   3685     DeviceMemory<Eigen::half> *c, int ldc,
   3686     blas::ProfileResult *output_profile_result) {
   3687   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3688             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3689             PARAM(beta), PARAM(c), PARAM(ldc));
   3690 
   3691   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
   3692                           uint64, float, const DeviceMemory<Eigen::half> &, int,
   3693                           const DeviceMemory<Eigen::half> &, int, float,
   3694                           DeviceMemory<Eigen::half> *, int>
   3695       impl;
   3696   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
   3697               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   3698               output_profile_result);
   3699 }
   3700 
   3701 Stream &Stream::ThenBlasGemmWithProfiling(
   3702     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3703     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
   3704     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
   3705     int ldc, blas::ProfileResult *output_profile_result) {
   3706   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3707             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3708             PARAM(beta), PARAM(c), PARAM(ldc));
   3709 
   3710   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
   3711                           uint64, float, const DeviceMemory<float> &, int,
   3712                           const DeviceMemory<float> &, int, float,
   3713                           DeviceMemory<float> *, int>
   3714       impl;
   3715   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
   3716               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   3717               output_profile_result);
   3718 }
   3719 
   3720 Stream &Stream::ThenBlasGemmWithProfiling(
   3721     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3722     uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
   3723     const DeviceMemory<double> &b, int ldb, double beta,
   3724     DeviceMemory<double> *c, int ldc,
   3725     blas::ProfileResult *output_profile_result) {
   3726   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3727             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3728             PARAM(beta), PARAM(c), PARAM(ldc));
   3729 
   3730   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
   3731                           uint64, double, const DeviceMemory<double> &, int,
   3732                           const DeviceMemory<double> &, int, double,
   3733                           DeviceMemory<double> *, int>
   3734       impl;
   3735   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
   3736               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   3737               output_profile_result);
   3738 }
   3739 
   3740 Stream &Stream::ThenBlasGemmWithProfiling(
   3741     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3742     uint64 k, std::complex<float> alpha,
   3743     const DeviceMemory<std::complex<float>> &a, int lda,
   3744     const DeviceMemory<std::complex<float>> &b, int ldb,
   3745     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
   3746     blas::ProfileResult *output_profile_result) {
   3747   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3748             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3749             PARAM(beta), PARAM(c), PARAM(ldc));
   3750 
   3751   ThenBlasWithProfileImpl<
   3752       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   3753       std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
   3754       const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
   3755       DeviceMemory<std::complex<float>> *, int>
   3756       impl;
   3757   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
   3758               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   3759               output_profile_result);
   3760 }
   3761 
   3762 Stream &Stream::ThenBlasGemmWithProfiling(
   3763     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3764     uint64 k, std::complex<double> alpha,
   3765     const DeviceMemory<std::complex<double>> &a, int lda,
   3766     const DeviceMemory<std::complex<double>> &b, int ldb,
   3767     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
   3768     blas::ProfileResult *output_profile_result) {
   3769   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3770             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3771             PARAM(beta), PARAM(c), PARAM(ldc));
   3772 
   3773   ThenBlasWithProfileImpl<
   3774       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   3775       std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
   3776       const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
   3777       DeviceMemory<std::complex<double>> *, int>
   3778       impl;
   3779   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
   3780               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   3781               output_profile_result);
   3782 }
   3783 
   3784 Stream &Stream::ThenBlasGemmWithAlgorithm(
   3785     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3786     uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a,
   3787     int lda, const DeviceMemory<Eigen::half> &b, int ldb,
   3788     const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
   3789     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
   3790     blas::ProfileResult *output_profile_result) {
   3791   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3792             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3793             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
   3794             PARAM(algorithm));
   3795 
   3796   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
   3797                           uint64, const Eigen::half &,
   3798                           const DeviceMemory<Eigen::half> &, int,
   3799                           const DeviceMemory<Eigen::half> &, int,
   3800                           const Eigen::half &, DeviceMemory<Eigen::half> *, int,
   3801                           blas::ComputationType, blas::AlgorithmType>
   3802       impl;
   3803   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
   3804               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
   3805               algorithm, output_profile_result);
   3806 }
   3807 
   3808 Stream &Stream::ThenBlasGemmWithAlgorithm(
   3809     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3810     uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
   3811     const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
   3812     int ldc, blas::ComputationType computation_type,
   3813     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
   3814   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3815             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3816             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
   3817             PARAM(algorithm));
   3818 
   3819   ThenBlasWithProfileImpl<
   3820       blas::Transpose, blas::Transpose, uint64, uint64, uint64, int,
   3821       const DeviceMemory<int8> &, int, const DeviceMemory<int8> &, int, int,
   3822       DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
   3823       impl;
   3824   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
   3825               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
   3826               algorithm, output_profile_result);
   3827 }
   3828 
   3829 Stream &Stream::ThenBlasGemmWithAlgorithm(
   3830     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3831     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
   3832     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
   3833     int ldc, blas::ComputationType computation_type,
   3834     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
   3835   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3836             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3837             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
   3838             PARAM(algorithm));
   3839 
   3840   ThenBlasWithProfileImpl<
   3841       blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
   3842       const DeviceMemory<float> &, int, const DeviceMemory<float> &, int, float,
   3843       DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType>
   3844       impl;
   3845   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
   3846               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
   3847               algorithm, output_profile_result);
   3848 }
   3849 
   3850 Stream &Stream::ThenBlasGemmWithAlgorithm(
   3851     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3852     uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
   3853     const DeviceMemory<double> &b, int ldb, double beta,
   3854     DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
   3855     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
   3856   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3857             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3858             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
   3859             PARAM(algorithm));
   3860 
   3861   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
   3862                           uint64, double, const DeviceMemory<double> &, int,
   3863                           const DeviceMemory<double> &, int, double,
   3864                           DeviceMemory<double> *, int, blas::ComputationType,
   3865                           blas::AlgorithmType>
   3866       impl;
   3867   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
   3868               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
   3869               algorithm, output_profile_result);
   3870 }
   3871 
   3872 Stream &Stream::ThenBlasGemmWithAlgorithm(
   3873     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3874     uint64 k, std::complex<float> alpha,
   3875     const DeviceMemory<std::complex<float>> &a, int lda,
   3876     const DeviceMemory<std::complex<float>> &b, int ldb,
   3877     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
   3878     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
   3879     blas::ProfileResult *output_profile_result) {
   3880   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3881             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3882             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
   3883             PARAM(algorithm));
   3884 
   3885   ThenBlasWithProfileImpl<
   3886       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   3887       std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
   3888       const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
   3889       DeviceMemory<std::complex<float>> *, int, blas::ComputationType,
   3890       blas::AlgorithmType>
   3891       impl;
   3892   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
   3893               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
   3894               algorithm, output_profile_result);
   3895 }
   3896 
   3897 Stream &Stream::ThenBlasGemmWithAlgorithm(
   3898     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3899     uint64 k, std::complex<double> alpha,
   3900     const DeviceMemory<std::complex<double>> &a, int lda,
   3901     const DeviceMemory<std::complex<double>> &b, int ldb,
   3902     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
   3903     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
   3904     blas::ProfileResult *output_profile_result) {
   3905   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3906             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3907             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
   3908             PARAM(algorithm));
   3909 
   3910   ThenBlasWithProfileImpl<
   3911       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   3912       std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
   3913       const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
   3914       DeviceMemory<std::complex<double>> *, int, blas::ComputationType,
   3915       blas::AlgorithmType>
   3916       impl;
   3917   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
   3918               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
   3919               algorithm, output_profile_result);
   3920 }
   3921 
   3922 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
   3923                              uint64 n, std::complex<float> alpha,
   3924                              const DeviceMemory<std::complex<float>> &a,
   3925                              int lda,
   3926                              const DeviceMemory<std::complex<float>> &b,
   3927                              int ldb, std::complex<float> beta,
   3928                              DeviceMemory<std::complex<float>> *c, int ldc) {
   3929   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
   3930             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   3931             PARAM(ldc));
   3932 
   3933   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
   3934                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   3935                int, const DeviceMemory<std::complex<float>> &, int,
   3936                std::complex<float>, DeviceMemory<std::complex<float>> *,
   3937                int> impl;
   3938   return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
   3939               lda, b, ldb, beta, c, ldc);
   3940 }
   3941 
   3942 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
   3943                              uint64 n, std::complex<double> alpha,
   3944                              const DeviceMemory<std::complex<double>> &a,
   3945                              int lda,
   3946                              const DeviceMemory<std::complex<double>> &b,
   3947                              int ldb, std::complex<double> beta,
   3948                              DeviceMemory<std::complex<double>> *c, int ldc) {
   3949   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
   3950             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   3951             PARAM(ldc));
   3952 
   3953   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
   3954                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   3955                int, const DeviceMemory<std::complex<double>> &, int,
   3956                std::complex<double>, DeviceMemory<std::complex<double>> *,
   3957                int> impl;
   3958   return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
   3959               lda, b, ldb, beta, c, ldc);
   3960 }
   3961 
   3962 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
   3963                              uint64 n, uint64 k, float alpha,
   3964                              const DeviceMemory<std::complex<float>> &a,
   3965                              int lda, float beta,
   3966                              DeviceMemory<std::complex<float>> *c, int ldc) {
   3967   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   3968             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
   3969 
   3970   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
   3971                const DeviceMemory<std::complex<float>> &, int, float,
   3972                DeviceMemory<std::complex<float>> *, int> impl;
   3973   return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
   3974               lda, beta, c, ldc);
   3975 }
   3976 
   3977 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
   3978                              uint64 n, uint64 k, double alpha,
   3979                              const DeviceMemory<std::complex<double>> &a,
   3980                              int lda, double beta,
   3981                              DeviceMemory<std::complex<double>> *c, int ldc) {
   3982   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   3983             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
   3984 
   3985   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
   3986                const DeviceMemory<std::complex<double>> &, int, double,
   3987                DeviceMemory<std::complex<double>> *, int> impl;
   3988   return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
   3989               lda, beta, c, ldc);
   3990 }
   3991 
   3992 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
   3993                               uint64 n, uint64 k, std::complex<float> alpha,
   3994                               const DeviceMemory<std::complex<float>> &a,
   3995                               int lda,
   3996                               const DeviceMemory<std::complex<float>> &b,
   3997                               int ldb, float beta,
   3998                               DeviceMemory<std::complex<float>> *c, int ldc) {
   3999   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4000             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4001             PARAM(ldc));
   4002 
   4003   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
   4004                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   4005                int, const DeviceMemory<std::complex<float>> &, int, float,
   4006                DeviceMemory<std::complex<float>> *, int> impl;
   4007   return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
   4008               a, lda, b, ldb, beta, c, ldc);
   4009 }
   4010 
   4011 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
   4012                               uint64 n, uint64 k, std::complex<double> alpha,
   4013                               const DeviceMemory<std::complex<double>> &a,
   4014                               int lda,
   4015                               const DeviceMemory<std::complex<double>> &b,
   4016                               int ldb, double beta,
   4017                               DeviceMemory<std::complex<double>> *c, int ldc) {
   4018   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4019             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4020             PARAM(ldc));
   4021 
   4022   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
   4023                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   4024                int, const DeviceMemory<std::complex<double>> &, int, double,
   4025                DeviceMemory<std::complex<double>> *, int> impl;
   4026   return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
   4027               a, lda, b, ldb, beta, c, ldc);
   4028 }
   4029 
   4030 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
   4031                              uint64 n, float alpha,
   4032                              const DeviceMemory<float> &a, int lda,
   4033                              const DeviceMemory<float> &b, int ldb, float beta,
   4034                              DeviceMemory<float> *c, int ldc) {
   4035   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
   4036             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4037             PARAM(ldc));
   4038 
   4039   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float,
   4040                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
   4041                int, float, DeviceMemory<float> *, int> impl;
   4042   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
   4043               lda, b, ldb, beta, c, ldc);
   4044 }
   4045 
   4046 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
   4047                              uint64 n, double alpha,
   4048                              const DeviceMemory<double> &a, int lda,
   4049                              const DeviceMemory<double> &b, int ldb,
   4050                              double beta, DeviceMemory<double> *c, int ldc) {
   4051   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
   4052             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4053             PARAM(ldc));
   4054 
   4055   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double,
   4056                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
   4057                int, double, DeviceMemory<double> *, int> impl;
   4058   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
   4059               lda, b, ldb, beta, c, ldc);
   4060 }
   4061 
   4062 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
   4063                              uint64 n, std::complex<float> alpha,
   4064                              const DeviceMemory<std::complex<float>> &a,
   4065                              int lda,
   4066                              const DeviceMemory<std::complex<float>> &b,
   4067                              int ldb, std::complex<float> beta,
   4068                              DeviceMemory<std::complex<float>> *c, int ldc) {
   4069   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
   4070             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4071             PARAM(ldc));
   4072 
   4073   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
   4074                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   4075                int, const DeviceMemory<std::complex<float>> &, int,
   4076                std::complex<float>, DeviceMemory<std::complex<float>> *,
   4077                int> impl;
   4078   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
   4079               lda, b, ldb, beta, c, ldc);
   4080 }
   4081 
   4082 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
   4083                              uint64 n, std::complex<double> alpha,
   4084                              const DeviceMemory<std::complex<double>> &a,
   4085                              int lda,
   4086                              const DeviceMemory<std::complex<double>> &b,
   4087                              int ldb, std::complex<double> beta,
   4088                              DeviceMemory<std::complex<double>> *c, int ldc) {
   4089   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
   4090             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4091             PARAM(ldc));
   4092 
   4093   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
   4094                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   4095                int, const DeviceMemory<std::complex<double>> &, int,
   4096                std::complex<double>, DeviceMemory<std::complex<double>> *,
   4097                int> impl;
   4098   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
   4099               lda, b, ldb, beta, c, ldc);
   4100 }
   4101 
   4102 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
   4103                              uint64 n, uint64 k, float alpha,
   4104                              const DeviceMemory<float> &a, int lda, float beta,
   4105                              DeviceMemory<float> *c, int ldc) {
   4106   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4107             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
   4108 
   4109   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
   4110                const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
   4111                int> impl;
   4112   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
   4113               lda, beta, c, ldc);
   4114 }
   4115 
   4116 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
   4117                              uint64 n, uint64 k, double alpha,
   4118                              const DeviceMemory<double> &a, int lda,
   4119                              double beta, DeviceMemory<double> *c, int ldc) {
   4120   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4121             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
   4122 
   4123   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
   4124                const DeviceMemory<double> &, int, double,
   4125                DeviceMemory<double> *, int> impl;
   4126   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
   4127               lda, beta, c, ldc);
   4128 }
   4129 
   4130 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
   4131                              uint64 n, uint64 k, std::complex<float> alpha,
   4132                              const DeviceMemory<std::complex<float>> &a,
   4133                              int lda, std::complex<float> beta,
   4134                              DeviceMemory<std::complex<float>> *c, int ldc) {
   4135   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4136             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
   4137 
   4138   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
   4139                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   4140                int, std::complex<float>, DeviceMemory<std::complex<float>> *,
   4141                int> impl;
   4142   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
   4143               lda, beta, c, ldc);
   4144 }
   4145 
   4146 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
   4147                              uint64 n, uint64 k, std::complex<double> alpha,
   4148                              const DeviceMemory<std::complex<double>> &a,
   4149                              int lda, std::complex<double> beta,
   4150                              DeviceMemory<std::complex<double>> *c, int ldc) {
   4151   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4152             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
   4153 
   4154   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
   4155                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   4156                int, std::complex<double>, DeviceMemory<std::complex<double>> *,
   4157                int> impl;
   4158   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
   4159               lda, beta, c, ldc);
   4160 }
   4161 
   4162 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
   4163                               uint64 n, uint64 k, float alpha,
   4164                               const DeviceMemory<float> &a, int lda,
   4165                               const DeviceMemory<float> &b, int ldb, float beta,
   4166                               DeviceMemory<float> *c, int ldc) {
   4167   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4168             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4169             PARAM(ldc));
   4170 
   4171   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
   4172                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
   4173                int, float, DeviceMemory<float> *, int> impl;
   4174   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
   4175               a, lda, b, ldb, beta, c, ldc);
   4176 }
   4177 
   4178 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
   4179                               uint64 n, uint64 k, double alpha,
   4180                               const DeviceMemory<double> &a, int lda,
   4181                               const DeviceMemory<double> &b, int ldb,
   4182                               double beta, DeviceMemory<double> *c, int ldc) {
   4183   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4184             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4185             PARAM(ldc));
   4186 
   4187   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
   4188                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
   4189                int, double, DeviceMemory<double> *, int> impl;
   4190   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
   4191               a, lda, b, ldb, beta, c, ldc);
   4192 }
   4193 
   4194 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
   4195                               uint64 n, uint64 k, std::complex<float> alpha,
   4196                               const DeviceMemory<std::complex<float>> &a,
   4197                               int lda,
   4198                               const DeviceMemory<std::complex<float>> &b,
   4199                               int ldb, std::complex<float> beta,
   4200                               DeviceMemory<std::complex<float>> *c, int ldc) {
   4201   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4202             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4203             PARAM(ldc));
   4204 
   4205   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
   4206                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   4207                int, const DeviceMemory<std::complex<float>> &, int,
   4208                std::complex<float>, DeviceMemory<std::complex<float>> *,
   4209                int> impl;
   4210   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
   4211               a, lda, b, ldb, beta, c, ldc);
   4212 }
   4213 
   4214 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
   4215                               uint64 n, uint64 k, std::complex<double> alpha,
   4216                               const DeviceMemory<std::complex<double>> &a,
   4217                               int lda,
   4218                               const DeviceMemory<std::complex<double>> &b,
   4219                               int ldb, std::complex<double> beta,
   4220                               DeviceMemory<std::complex<double>> *c, int ldc) {
   4221   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4222             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4223             PARAM(ldc));
   4224 
   4225   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
   4226                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   4227                int, const DeviceMemory<std::complex<double>> &, int,
   4228                std::complex<double>, DeviceMemory<std::complex<double>> *,
   4229                int> impl;
   4230   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
   4231               a, lda, b, ldb, beta, c, ldc);
   4232 }
   4233 
   4234 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
   4235                              blas::Transpose transa, blas::Diagonal diag,
   4236                              uint64 m, uint64 n, float alpha,
   4237                              const DeviceMemory<float> &a, int lda,
   4238                              DeviceMemory<float> *b, int ldb) {
   4239   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4240             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4241 
   4242   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4243                uint64, uint64, float, const DeviceMemory<float> &, int,
   4244                DeviceMemory<float> *, int> impl;
   4245   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
   4246               n, alpha, a, lda, b, ldb);
   4247 }
   4248 
   4249 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
   4250                              blas::Transpose transa, blas::Diagonal diag,
   4251                              uint64 m, uint64 n, double alpha,
   4252                              const DeviceMemory<double> &a, int lda,
   4253                              DeviceMemory<double> *b, int ldb) {
   4254   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4255             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4256 
   4257   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4258                uint64, uint64, double, const DeviceMemory<double> &, int,
   4259                DeviceMemory<double> *, int> impl;
   4260   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
   4261               n, alpha, a, lda, b, ldb);
   4262 }
   4263 
   4264 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
   4265                              blas::Transpose transa, blas::Diagonal diag,
   4266                              uint64 m, uint64 n, std::complex<float> alpha,
   4267                              const DeviceMemory<std::complex<float>> &a,
   4268                              int lda, DeviceMemory<std::complex<float>> *b,
   4269                              int ldb) {
   4270   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4271             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4272 
   4273   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4274                uint64, uint64, std::complex<float>,
   4275                const DeviceMemory<std::complex<float>> &, int,
   4276                DeviceMemory<std::complex<float>> *, int> impl;
   4277   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
   4278               n, alpha, a, lda, b, ldb);
   4279 }
   4280 
   4281 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
   4282                              blas::Transpose transa, blas::Diagonal diag,
   4283                              uint64 m, uint64 n, std::complex<double> alpha,
   4284                              const DeviceMemory<std::complex<double>> &a,
   4285                              int lda, DeviceMemory<std::complex<double>> *b,
   4286                              int ldb) {
   4287   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4288             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4289 
   4290   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4291                uint64, uint64, std::complex<double>,
   4292                const DeviceMemory<std::complex<double>> &, int,
   4293                DeviceMemory<std::complex<double>> *, int> impl;
   4294   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
   4295               n, alpha, a, lda, b, ldb);
   4296 }
   4297 
   4298 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
   4299                              blas::Transpose transa, blas::Diagonal diag,
   4300                              uint64 m, uint64 n, float alpha,
   4301                              const DeviceMemory<float> &a, int lda,
   4302                              DeviceMemory<float> *b, int ldb) {
   4303   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4304             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4305 
   4306   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4307                uint64, uint64, float, const DeviceMemory<float> &, int,
   4308                DeviceMemory<float> *, int> impl;
   4309   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
   4310               n, alpha, a, lda, b, ldb);
   4311 }
   4312 
   4313 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
   4314                              blas::Transpose transa, blas::Diagonal diag,
   4315                              uint64 m, uint64 n, double alpha,
   4316                              const DeviceMemory<double> &a, int lda,
   4317                              DeviceMemory<double> *b, int ldb) {
   4318   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4319             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4320 
   4321   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4322                uint64, uint64, double, const DeviceMemory<double> &, int,
   4323                DeviceMemory<double> *, int> impl;
   4324   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
   4325               n, alpha, a, lda, b, ldb);
   4326 }
   4327 
   4328 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
   4329                              blas::Transpose transa, blas::Diagonal diag,
   4330                              uint64 m, uint64 n, std::complex<float> alpha,
   4331                              const DeviceMemory<std::complex<float>> &a,
   4332                              int lda, DeviceMemory<std::complex<float>> *b,
   4333                              int ldb) {
   4334   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4335             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4336 
   4337   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4338                uint64, uint64, std::complex<float>,
   4339                const DeviceMemory<std::complex<float>> &, int,
   4340                DeviceMemory<std::complex<float>> *, int> impl;
   4341   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
   4342               n, alpha, a, lda, b, ldb);
   4343 }
   4344 
   4345 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
   4346                              blas::Transpose transa, blas::Diagonal diag,
   4347                              uint64 m, uint64 n, std::complex<double> alpha,
   4348                              const DeviceMemory<std::complex<double>> &a,
   4349                              int lda, DeviceMemory<std::complex<double>> *b,
   4350                              int ldb) {
   4351   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4352             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4353 
   4354   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4355                uint64, uint64, std::complex<double>,
   4356                const DeviceMemory<std::complex<double>> &, int,
   4357                DeviceMemory<std::complex<double>> *, int> impl;
   4358   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
   4359               n, alpha, a, lda, b, ldb);
   4360 }
   4361 
   4362 Stream &Stream::ThenBlasGemmBatched(
   4363     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4364     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
   4365     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
   4366     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
   4367     int batch_count) {
   4368   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
   4369                                         b, ldb, beta, c, ldc, batch_count,
   4370                                         /*scratch_allocator=*/nullptr);
   4371 }
   4372 
   4373 Stream &Stream::ThenBlasGemmBatchedWithScratch(
   4374     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4375     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
   4376     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
   4377     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
   4378     int batch_count, ScratchAllocator *scratch_allocator) {
   4379   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   4380             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   4381             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
   4382 
   4383   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
   4384                const port::ArraySlice<DeviceMemory<float> *> &, int,
   4385                const port::ArraySlice<DeviceMemory<float> *> &, int, float,
   4386                const port::ArraySlice<DeviceMemory<float> *> &, int, int,
   4387                ScratchAllocator *>
   4388       impl;
   4389   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
   4390               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
   4391               scratch_allocator);
   4392 }
   4393 
   4394 Stream &Stream::ThenBlasGemmBatched(
   4395     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4396     uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
   4397     int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
   4398     double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
   4399     int batch_count) {
   4400   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
   4401                                         b, ldb, beta, c, ldc, batch_count,
   4402                                         /*scratch_allocator=*/nullptr);
   4403 }
   4404 
   4405 Stream &Stream::ThenBlasGemmBatchedWithScratch(
   4406     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4407     uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
   4408     int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
   4409     double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
   4410     int batch_count, ScratchAllocator *scratch_allocator) {
   4411   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   4412             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   4413             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
   4414 
   4415   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
   4416                const port::ArraySlice<DeviceMemory<double> *> &, int,
   4417                const port::ArraySlice<DeviceMemory<double> *> &, int, double,
   4418                const port::ArraySlice<DeviceMemory<double> *> &, int, int,
   4419                ScratchAllocator *>
   4420       impl;
   4421   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
   4422               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
   4423               scratch_allocator);
   4424 }
   4425 
   4426 Stream &Stream::ThenBlasGemmBatched(
   4427     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4428     uint64 k, std::complex<float> alpha,
   4429     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
   4430     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
   4431     std::complex<float> beta,
   4432     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
   4433     int batch_count) {
   4434   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
   4435                                         b, ldb, beta, c, ldc, batch_count,
   4436                                         /*scratch_allocator=*/nullptr);
   4437 }
   4438 
   4439 Stream &Stream::ThenBlasGemmBatchedWithScratch(
   4440     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4441     uint64 k, std::complex<float> alpha,
   4442     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
   4443     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
   4444     std::complex<float> beta,
   4445     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
   4446     int batch_count, ScratchAllocator *scratch_allocator) {
   4447   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   4448             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   4449             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
   4450 
   4451   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   4452                std::complex<float>,
   4453                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
   4454                int,
   4455                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
   4456                int, std::complex<float>,
   4457                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
   4458                int, int, ScratchAllocator *>
   4459       impl;
   4460   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
   4461               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
   4462               scratch_allocator);
   4463 }
   4464 
   4465 Stream &Stream::ThenBlasGemmBatched(
   4466     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4467     uint64 k, std::complex<double> alpha,
   4468     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
   4469     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
   4470     std::complex<double> beta,
   4471     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
   4472     int batch_count) {
   4473   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
   4474                                         b, ldb, beta, c, ldc, batch_count,
   4475                                         /*scratch_allocator=*/nullptr);
   4476 }
   4477 
   4478 Stream &Stream::ThenBlasGemmBatchedWithScratch(
   4479     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4480     uint64 k, std::complex<double> alpha,
   4481     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
   4482     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
   4483     std::complex<double> beta,
   4484     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
   4485     int batch_count, ScratchAllocator *scratch_allocator) {
   4486   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   4487             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   4488             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
   4489 
   4490   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   4491                std::complex<double>,
   4492                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
   4493                int,
   4494                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
   4495                int, std::complex<double>,
   4496                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
   4497                int, int, ScratchAllocator *>
   4498       impl;
   4499   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
   4500               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
   4501               scratch_allocator);
   4502 }
   4503 
   4504 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
   4505   VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
   4506 
   4507   if (ok()) {
   4508     if (rng::RngSupport *rng = parent_->AsRng()) {
   4509       CheckError(rng->SetSeed(this, seed, seed_bytes));
   4510     } else {
   4511       SetError();
   4512       LOG(INFO) << "stream " << this << " unable to initialize RNG";
   4513     }
   4514   } else {
   4515     LOG(INFO) << "stream " << this
   4516               << " did not set RNG seed: " << static_cast<const void *>(seed)
   4517               << "; bytes: " << seed_bytes;
   4518   }
   4519   return *this;
   4520 }
   4521 
   4522 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
   4523   VLOG_CALL(PARAM(values));
   4524 
   4525   if (ok()) {
   4526     if (rng::RngSupport *rng = parent_->AsRng()) {
   4527       CheckError(rng->DoPopulateRandUniform(this, values));
   4528     } else {
   4529       SetError();
   4530       LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
   4531                    "without RNG support.";
   4532     }
   4533   }
   4534   return *this;
   4535 }
   4536 
   4537 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
   4538                                          DeviceMemory<float> *values) {
   4539   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
   4540 
   4541   if (ok()) {
   4542     if (rng::RngSupport *rng = parent_->AsRng()) {
   4543       CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
   4544     } else {
   4545       SetError();
   4546       LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
   4547                    "without RNG support.";
   4548     }
   4549   }
   4550   return *this;
   4551 }
   4552 
   4553 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
   4554                                          DeviceMemory<double> *values) {
   4555   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
   4556 
   4557   if (ok()) {
   4558     if (rng::RngSupport *rng = parent_->AsRng()) {
   4559       CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
   4560     } else {
   4561       SetError();
   4562       LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
   4563                    "without RNG support.";
   4564     }
   4565   }
   4566   return *this;
   4567 }
   4568 
   4569 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
   4570   VLOG_CALL(PARAM(values));
   4571 
   4572   if (ok()) {
   4573     if (rng::RngSupport *rng = parent_->AsRng()) {
   4574       CheckError(rng->DoPopulateRandUniform(this, values));
   4575     } else {
   4576       SetError();
   4577       LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
   4578                    "without RNG support.";
   4579     }
   4580   }
   4581   return *this;
   4582 }
   4583 
   4584 Stream &Stream::ThenPopulateRandUniform(
   4585     DeviceMemory<std::complex<float>> *values) {
   4586   VLOG_CALL(PARAM(values));
   4587 
   4588   if (ok()) {
   4589     if (rng::RngSupport *rng = parent_->AsRng()) {
   4590       CheckError(rng->DoPopulateRandUniform(this, values));
   4591     } else {
   4592       SetError();
   4593       LOG(INFO) << "attempting to perform RNG operation using StreamExecutor "
   4594                    "without RNG support.";
   4595     }
   4596   }
   4597   return *this;
   4598 }
   4599 
   4600 Stream &Stream::ThenPopulateRandUniform(
   4601     DeviceMemory<std::complex<double>> *values) {
   4602   VLOG_CALL(PARAM(values));
   4603 
   4604   if (ok()) {
   4605     if (rng::RngSupport *rng = parent_->AsRng()) {
   4606       CheckError(rng->DoPopulateRandUniform(this, values));
   4607     } else {
   4608       SetError();
   4609       LOG(INFO) << "stream " << this
   4610                 << " attempting to perform RNG operation using StreamExecutor "
   4611                    "without RNG support.";
   4612     }
   4613   }
   4614   return *this;
   4615 }
   4616 
   4617 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
   4618                            uint64 size) {
   4619   VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
   4620 
   4621   if (ok()) {
   4622     CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
   4623   } else {
   4624     LOG(INFO) << "stream " << this
   4625               << " did not memcpy device-to-host; source: " << gpu_src.opaque();
   4626   }
   4627   return *this;
   4628 }
   4629 
   4630 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
   4631                            uint64 size) {
   4632   VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
   4633 
   4634   if (ok()) {
   4635     CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
   4636   } else {
   4637     LOG(INFO) << "stream " << this
   4638               << " did not memcpy host-to-device; source: " << host_src;
   4639   }
   4640   return *this;
   4641 }
   4642 
   4643 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
   4644                            const DeviceMemoryBase &gpu_src, uint64 size) {
   4645   VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
   4646 
   4647   if (ok()) {
   4648     CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
   4649   } else {
   4650     LOG(INFO) << "stream " << this
   4651               << " did not memcpy gpu-to-gpu; source: " << &gpu_src;
   4652   }
   4653   return *this;
   4654 }
   4655 
   4656 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
   4657   VLOG_CALL(PARAM(location), PARAM(size));
   4658 
   4659   if (ok()) {
   4660     CheckError(parent_->MemZero(this, location, size));
   4661   } else {
   4662     LOG(INFO) << "stream " << this
   4663               << " did not memzero GPU location; source: " << location;
   4664   }
   4665   return *this;
   4666 }
   4667 
   4668 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
   4669                              uint64 size) {
   4670   VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
   4671 
   4672   if (ok()) {
   4673     CheckError(parent_->Memset32(this, location, pattern, size));
   4674   } else {
   4675     LOG(INFO) << "stream " << this
   4676               << " did not memset GPU location; source: " << location
   4677               << "; size: " << size << "; pattern: " << std::hex << pattern;
   4678   }
   4679   return *this;
   4680 }
   4681 
   4682 Stream &Stream::ThenRnnForward(
   4683     const dnn::RnnDescriptor &rnn_desc,
   4684     const dnn::RnnSequenceTensorDescriptor &input_desc,
   4685     const DeviceMemory<Eigen::half> &input_data,
   4686     const dnn::RnnStateTensorDescriptor &input_h_desc,
   4687     const DeviceMemory<Eigen::half> &input_h_data,
   4688     const dnn::RnnStateTensorDescriptor &input_c_desc,
   4689     const DeviceMemory<Eigen::half> &input_c_data,
   4690     const DeviceMemory<Eigen::half> &params,
   4691     const dnn::RnnSequenceTensorDescriptor &output_desc,
   4692     DeviceMemory<Eigen::half> *output_data,
   4693     const dnn::RnnStateTensorDescriptor &output_h_desc,
   4694     DeviceMemory<Eigen::half> *output_h_data,
   4695     const dnn::RnnStateTensorDescriptor &output_c_desc,
   4696     DeviceMemory<Eigen::half> *output_c_data, bool is_training,
   4697     ScratchAllocator *reserve_space_allocator,
   4698     ScratchAllocator *workspace_allocator) {
   4699   // TODO(zhengxq): add VLOG PARAM calls.
   4700   if (ok()) {
   4701     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   4702       CheckError(dnn->DoRnnForward(
   4703           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
   4704           input_c_desc, input_c_data, params, output_desc, output_data,
   4705           output_h_desc, output_h_data, output_c_desc, output_c_data,
   4706           is_training, reserve_space_allocator, workspace_allocator));
   4707     } else {
   4708       SetError();
   4709       LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
   4710     }
   4711   }
   4712   return *this;
   4713 }
   4714 
   4715 Stream &Stream::ThenRnnForward(
   4716     const dnn::RnnDescriptor &rnn_desc,
   4717     const dnn::RnnSequenceTensorDescriptor &input_desc,
   4718     const DeviceMemory<float> &input_data,
   4719     const dnn::RnnStateTensorDescriptor &input_h_desc,
   4720     const DeviceMemory<float> &input_h_data,
   4721     const dnn::RnnStateTensorDescriptor &input_c_desc,
   4722     const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
   4723     const dnn::RnnSequenceTensorDescriptor &output_desc,
   4724     DeviceMemory<float> *output_data,
   4725     const dnn::RnnStateTensorDescriptor &output_h_desc,
   4726     DeviceMemory<float> *output_h_data,
   4727     const dnn::RnnStateTensorDescriptor &output_c_desc,
   4728     DeviceMemory<float> *output_c_data, bool is_training,
   4729     ScratchAllocator *reserve_space_allocator,
   4730     ScratchAllocator *workspace_allocator) {
   4731   // TODO(zhengxq): add VLOG PARAM calls.
   4732   if (ok()) {
   4733     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   4734       CheckError(dnn->DoRnnForward(
   4735           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
   4736           input_c_desc, input_c_data, params, output_desc, output_data,
   4737           output_h_desc, output_h_data, output_c_desc, output_c_data,
   4738           is_training, reserve_space_allocator, workspace_allocator));
   4739     } else {
   4740       SetError();
   4741       LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
   4742     }
   4743   }
   4744   return *this;
   4745 }
   4746 
   4747 Stream &Stream::ThenRnnForward(
   4748     const dnn::RnnDescriptor &rnn_desc,
   4749     const dnn::RnnSequenceTensorDescriptor &input_desc,
   4750     const DeviceMemory<double> &input_data,
   4751     const dnn::RnnStateTensorDescriptor &input_h_desc,
   4752     const DeviceMemory<double> &input_h_data,
   4753     const dnn::RnnStateTensorDescriptor &input_c_desc,
   4754     const DeviceMemory<double> &input_c_data,
   4755     const DeviceMemory<double> &params,
   4756     const dnn::RnnSequenceTensorDescriptor &output_desc,
   4757     DeviceMemory<double> *output_data,
   4758     const dnn::RnnStateTensorDescriptor &output_h_desc,
   4759     DeviceMemory<double> *output_h_data,
   4760     const dnn::RnnStateTensorDescriptor &output_c_desc,
   4761     DeviceMemory<double> *output_c_data, bool is_training,
   4762     ScratchAllocator *reserve_space_allocator,
   4763     ScratchAllocator *workspace_allocator) {
   4764   // TODO(zhengxq): add VLOG PARAM calls.
   4765   if (ok()) {
   4766     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   4767       CheckError(dnn->DoRnnForward(
   4768           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
   4769           input_c_desc, input_c_data, params, output_desc, output_data,
   4770           output_h_desc, output_h_data, output_c_desc, output_c_data,
   4771           is_training, reserve_space_allocator, workspace_allocator));
   4772     } else {
   4773       SetError();
   4774       LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
   4775     }
   4776   }
   4777   return *this;
   4778 }
   4779 
   4780 Stream &Stream::ThenRnnBackward(
   4781     const dnn::RnnDescriptor &rnn_desc,
   4782     const dnn::RnnSequenceTensorDescriptor &input_desc,
   4783     const DeviceMemory<Eigen::half> &input_data,
   4784     const dnn::RnnStateTensorDescriptor &input_h_desc,
   4785     const DeviceMemory<Eigen::half> &input_h_data,
   4786     const dnn::RnnStateTensorDescriptor &input_c_desc,
   4787     const DeviceMemory<Eigen::half> &input_c_data,
   4788     const DeviceMemory<Eigen::half> &params,
   4789     const dnn::RnnSequenceTensorDescriptor &output_desc,
   4790     const DeviceMemory<Eigen::half> &output_data,
   4791     const dnn::RnnStateTensorDescriptor &output_h_desc,
   4792     const DeviceMemory<Eigen::half> &output_h_data,
   4793     const dnn::RnnStateTensorDescriptor &output_c_desc,
   4794     const DeviceMemory<Eigen::half> &output_c_data,
   4795     const DeviceMemory<Eigen::half> &output_backprop_data,
   4796     const DeviceMemory<Eigen::half> &output_h_backprop_data,
   4797     const DeviceMemory<Eigen::half> &output_c_backprop_data,
   4798     DeviceMemory<Eigen::half> *input_backprop_data,
   4799     DeviceMemory<Eigen::half> *input_h_backprop_data,
   4800     DeviceMemory<Eigen::half> *input_c_backprop_data,
   4801     DeviceMemory<Eigen::half> *params_backprop_data,
   4802     DeviceMemory<uint8> *reserve_space_data,
   4803     ScratchAllocator *workspace_allocator) {
   4804   // TODO(zhengxq): add VLOG PARAM calls.
   4805   if (ok()) {
   4806     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   4807       CheckError(dnn->DoRnnBackward(
   4808           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
   4809           input_c_desc, input_c_data, params, output_desc, output_data,
   4810           output_h_desc, output_h_data, output_c_desc, output_c_data,
   4811           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
   4812           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
   4813           params_backprop_data, reserve_space_data, workspace_allocator));
   4814     } else {
   4815       SetError();
   4816       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
   4817     }
   4818   }
   4819   return *this;
   4820 }
   4821 
   4822 Stream &Stream::ThenRnnBackward(
   4823     const dnn::RnnDescriptor &rnn_desc,
   4824     const dnn::RnnSequenceTensorDescriptor &input_desc,
   4825     const DeviceMemory<float> &input_data,
   4826     const dnn::RnnStateTensorDescriptor &input_h_desc,
   4827     const DeviceMemory<float> &input_h_data,
   4828     const dnn::RnnStateTensorDescriptor &input_c_desc,
   4829     const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
   4830     const dnn::RnnSequenceTensorDescriptor &output_desc,
   4831     const DeviceMemory<float> &output_data,
   4832     const dnn::RnnStateTensorDescriptor &output_h_desc,
   4833     const DeviceMemory<float> &output_h_data,
   4834     const dnn::RnnStateTensorDescriptor &output_c_desc,
   4835     const DeviceMemory<float> &output_c_data,
   4836     const DeviceMemory<float> &output_backprop_data,
   4837     const DeviceMemory<float> &output_h_backprop_data,
   4838     const DeviceMemory<float> &output_c_backprop_data,
   4839     DeviceMemory<float> *input_backprop_data,
   4840     DeviceMemory<float> *input_h_backprop_data,
   4841     DeviceMemory<float> *input_c_backprop_data,
   4842     DeviceMemory<float> *params_backprop_data,
   4843     DeviceMemory<uint8> *reserve_space_data,
   4844     ScratchAllocator *workspace_allocator) {
   4845   // TODO(zhengxq): add VLOG PARAM calls.
   4846   if (ok()) {
   4847     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   4848       CheckError(dnn->DoRnnBackward(
   4849           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
   4850           input_c_desc, input_c_data, params, output_desc, output_data,
   4851           output_h_desc, output_h_data, output_c_desc, output_c_data,
   4852           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
   4853           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
   4854           params_backprop_data, reserve_space_data, workspace_allocator));
   4855     } else {
   4856       SetError();
   4857       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
   4858     }
   4859   }
   4860   return *this;
   4861 }
   4862 
   4863 Stream &Stream::ThenRnnBackward(
   4864     const dnn::RnnDescriptor &rnn_desc,
   4865     const dnn::RnnSequenceTensorDescriptor &input_desc,
   4866     const DeviceMemory<double> &input_data,
   4867     const dnn::RnnStateTensorDescriptor &input_h_desc,
   4868     const DeviceMemory<double> &input_h_data,
   4869     const dnn::RnnStateTensorDescriptor &input_c_desc,
   4870     const DeviceMemory<double> &input_c_data,
   4871     const DeviceMemory<double> &params,
   4872     const dnn::RnnSequenceTensorDescriptor &output_desc,
   4873     const DeviceMemory<double> &output_data,
   4874     const dnn::RnnStateTensorDescriptor &output_h_desc,
   4875     const DeviceMemory<double> &output_h_data,
   4876     const dnn::RnnStateTensorDescriptor &output_c_desc,
   4877     const DeviceMemory<double> &output_c_data,
   4878     const DeviceMemory<double> &output_backprop_data,
   4879     const DeviceMemory<double> &output_h_backprop_data,
   4880     const DeviceMemory<double> &output_c_backprop_data,
   4881     DeviceMemory<double> *input_backprop_data,
   4882     DeviceMemory<double> *input_h_backprop_data,
   4883     DeviceMemory<double> *input_c_backprop_data,
   4884     DeviceMemory<double> *params_backprop_data,
   4885     DeviceMemory<uint8> *reserve_space_data,
   4886     ScratchAllocator *workspace_allocator) {
   4887   // TODO(zhengxq): add VLOG PARAM calls.
   4888   if (ok()) {
   4889     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   4890       CheckError(dnn->DoRnnBackward(
   4891           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
   4892           input_c_desc, input_c_data, params, output_desc, output_data,
   4893           output_h_desc, output_h_data, output_c_desc, output_c_data,
   4894           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
   4895           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
   4896           params_backprop_data, reserve_space_data, workspace_allocator));
   4897     } else {
   4898       SetError();
   4899       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
   4900     }
   4901   }
   4902   return *this;
   4903 }
   4904 
   4905 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
   4906                                     dnn::DataType input_type,
   4907                                     const DeviceMemoryBase &input_data,
   4908                                     const dnn::BatchDescriptor &output_desc,
   4909                                     dnn::DataType output_type, float scale,
   4910                                     DeviceMemoryBase *output_data) {
   4911   VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
   4912             PARAM(output_desc), PARAM(output_type), PARAM(scale),
   4913             PARAM(output_data));
   4914   if (ok()) {
   4915     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   4916       CheckError(dnn->DoTransformTensor(this, input_desc, input_type,
   4917                                         input_data, output_desc, output_type,
   4918                                         scale, output_data));
   4919     } else {
   4920       SetErrorAndLogNoDnnSupport();
   4921     }
   4922   }
   4923   return *this;
   4924 }
   4925 
   4926 Stream &Stream::ThenDoHostCallbackForTest(std::function<void()> callback) {
   4927   VLOG_CALL(PARAM(callback));
   4928 
   4929   return ThenDoHostCallback(callback);
   4930 }
   4931 
   4932 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
   4933   VLOG_CALL(PARAM(callback));
   4934 
   4935   if (ok()) {
   4936     CheckError(parent_->HostCallback(this, callback));
   4937   } else {
   4938     LOG(INFO) << "stream " << this
   4939               << " was in error state before adding host callback";
   4940   }
   4941   return *this;
   4942 }
   4943 
   4944 Stream &Stream::ThenFft(fft::Plan *plan,
   4945                         const DeviceMemory<std::complex<float>> &input,
   4946                         DeviceMemory<std::complex<float>> *output) {
   4947   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
   4948 
   4949   if (ok()) {
   4950     if (fft::FftSupport *fft = parent_->AsFft()) {
   4951       CheckError(fft->DoFft(this, plan, input, output));
   4952     } else {
   4953       SetError();
   4954       LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
   4955                    "without FFT support";
   4956     }
   4957   }
   4958   return *this;
   4959 }
   4960 
   4961 Stream &Stream::ThenFft(fft::Plan *plan,
   4962                         const DeviceMemory<std::complex<double>> &input,
   4963                         DeviceMemory<std::complex<double>> *output) {
   4964   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
   4965 
   4966   if (ok()) {
   4967     if (fft::FftSupport *fft = parent_->AsFft()) {
   4968       CheckError(fft->DoFft(this, plan, input, output));
   4969     } else {
   4970       SetError();
   4971       LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
   4972                    "without FFT support";
   4973     }
   4974   }
   4975   return *this;
   4976 }
   4977 
   4978 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
   4979                         DeviceMemory<std::complex<float>> *output) {
   4980   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
   4981 
   4982   if (ok()) {
   4983     if (fft::FftSupport *fft = parent_->AsFft()) {
   4984       CheckError(fft->DoFft(this, plan, input, output));
   4985     } else {
   4986       SetError();
   4987       LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
   4988                    "without FFT support";
   4989     }
   4990   }
   4991   return *this;
   4992 }
   4993 
   4994 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
   4995                         DeviceMemory<std::complex<double>> *output) {
   4996   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
   4997 
   4998   if (ok()) {
   4999     if (fft::FftSupport *fft = parent_->AsFft()) {
   5000       CheckError(fft->DoFft(this, plan, input, output));
   5001     } else {
   5002       SetError();
   5003       LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
   5004                    "without FFT support";
   5005     }
   5006   }
   5007   return *this;
   5008 }
   5009 
   5010 Stream &Stream::ThenFft(fft::Plan *plan,
   5011                         const DeviceMemory<std::complex<float>> &input,
   5012                         DeviceMemory<float> *output) {
   5013   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
   5014 
   5015   if (ok()) {
   5016     if (fft::FftSupport *fft = parent_->AsFft()) {
   5017       CheckError(fft->DoFft(this, plan, input, output));
   5018     } else {
   5019       SetError();
   5020       LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
   5021                    "without FFT support";
   5022     }
   5023   }
   5024   return *this;
   5025 }
   5026 
   5027 Stream &Stream::ThenFft(fft::Plan *plan,
   5028                         const DeviceMemory<std::complex<double>> &input,
   5029                         DeviceMemory<double> *output) {
   5030   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
   5031 
   5032   if (ok()) {
   5033     if (fft::FftSupport *fft = parent_->AsFft()) {
   5034       CheckError(fft->DoFft(this, plan, input, output));
   5035     } else {
   5036       SetError();
   5037       LOG(INFO) << "attempting to perform FFT operation using StreamExecutor "
   5038                    "without FFT support";
   5039     }
   5040   }
   5041   return *this;
   5042 }
   5043 
   5044 // It looks confusing, but all this is doing is inserting a callback at the
   5045 // present point in the stream to then enqueue a task on the host executor.
   5046 Stream &Stream::ThenEnqueueOnBackgroundThread(
   5047     std::function<void(StreamExecutor *)> task) {
   5048   VLOG_CALL(PARAM(task));
   5049 
   5050   StreamExecutor *stream_executor = this->parent_;
   5051   std::function<void()> bound_task = std::bind(task, stream_executor);
   5052 
   5053   return ThenDoHostCallback([stream_executor, bound_task]() {
   5054     stream_executor->EnqueueOnBackgroundThread(bound_task);
   5055   });
   5056 }
   5057 
   5058 port::Status Stream::BlockHostUntilDone() {
   5059   VLOG_CALL();
   5060 
   5061   if (!ok()) {
   5062     port::Status status = port::Status(
   5063         port::error::INTERNAL,
   5064         "stream did not block host until done; was already in an error state");
   5065     LOG(INFO) << status << " " << this;
   5066     return status;
   5067   }
   5068 
   5069   port::Status first_error;
   5070   {
   5071     // Wait until all active sub-streams have done their tasks.
   5072     mutex_lock lock{mu_};
   5073     for (auto &stream : sub_streams_) {
   5074       if (!stream.second) {
   5075         first_error.Update(stream.first->BlockHostUntilDone());
   5076         // Set this sub-stream as available.
   5077         stream.second = true;
   5078       }
   5079     }
   5080   }
   5081 
   5082   temporary_memory_manager_.DeallocateFinalizedTemporaries();
   5083 
   5084   first_error.Update(parent_->BlockHostUntilDone(this));
   5085   CheckError(first_error.ok());
   5086   return first_error;
   5087 }
   5088 
   5089 }  // namespace gputools
   5090 }  // namespace perftools
   5091