Home | History | Annotate | Download | only in stream_executor
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/stream_executor/stream.h"
     17 
     18 #include "tensorflow/stream_executor/platform/port.h"
     19 
     20 #include "absl/strings/str_cat.h"
     21 #include "third_party/eigen3/Eigen/Core"
     22 #include "tensorflow/stream_executor/blas.h"
     23 #include "tensorflow/stream_executor/host_or_device_scalar.h"
     24 #include "tensorflow/stream_executor/lib/stacktrace.h"
     25 #include "tensorflow/stream_executor/platform.h"
     26 #include "tensorflow/stream_executor/platform/logging.h"
     27 #include "tensorflow/stream_executor/rng.h"
     28 #include "tensorflow/stream_executor/stream_executor_internal.h"
     29 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
     30 
     31 namespace stream_executor {
     32 
     33 namespace {
     34 // Code to turn parameters to functions on stream into strings that
     35 // will be VLOG'ed. We need overloads, instead of
     36 // e.g. BatchDescriptorToVlogString(), as the code that calls these
     37 // functions does not know what the type of the parameter is.
     38 string ToVlogString(const dnn::BatchDescriptor &descriptor) {
     39   return descriptor.ToShortString();
     40 }
     41 
     42 string ToVlogString(const dnn::FilterDescriptor &descriptor) {
     43   return descriptor.ToShortString();
     44 }
     45 
     46 string ToVlogString(const dnn::ConvolutionDescriptor &descriptor) {
     47   return descriptor.ToShortString();
     48 }
     49 
     50 string ToVlogString(const dnn::PoolingDescriptor &descriptor) {
     51   return descriptor.ToShortString();
     52 }
     53 
     54 string ToVlogString(const dnn::NormalizeDescriptor &descriptor) {
     55   return descriptor.ToShortString();
     56 }
     57 
     58 string ToVlogString(dnn::ActivationMode mode) {
     59   return dnn::ActivationModeString(mode);
     60 }
     61 
     62 string ToVlogString(const dnn::AlgorithmConfig &algo_config) {
     63   return algo_config.ToString();
     64 }
     65 
     66 string ToVlogString(dnn::ElementwiseOperation op) {
     67   return dnn::ElementwiseOperationString(op);
     68 }
     69 
     70 string ToVlogString(dnn::QuantizedActivationMode mode) {
     71   return dnn::QuantizedActivationModeString(mode);
     72 }
     73 
     74 string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
     75 
     76 string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); }
     77 
     78 string ToVlogString(blas::Diagonal d) { return blas::DiagonalString(d); }
     79 
     80 string ToVlogString(blas::Side s) { return blas::SideString(s); }
     81 
     82 string ToVlogString(blas::ComputationType ty) {
     83   return blas::ComputationTypeString(ty);
     84 }
     85 
     86 string ToVlogString(const void *ptr) {
     87   if (ptr == nullptr) {
     88     return "null";
     89   }
     90 
     91   // StrCat does not convert pointers to text.
     92   std::ostringstream out;
     93   out << ptr;
     94   return out.str();
     95 }
     96 
     97 template <class T>
     98 string ToVlogString(const std::complex<T> &c) {
     99   // StrCat does not convert std::complex to text.
    100   std::ostringstream out;
    101   out << c;
    102   return out.str();
    103 }
    104 
    105 template <class T>
    106 string ToVlogString(const std::function<T> &f) {
    107   return f == nullptr ? "null" : "<non-null function>";
    108 }
    109 
    110 string ToVlogString(const DeviceMemoryBase &memory) {
    111   return ToVlogString(memory.opaque());
    112 }
    113 
    114 string ToVlogString(const DeviceMemoryBase *memory) {
    115   return memory == nullptr ? "null" : ToVlogString(*memory);
    116 }
    117 
    118 string ToVlogString(const Eigen::half &h) {
    119   return absl::StrCat(static_cast<float>(h));
    120 }
    121 
    122 string ToVlogString(int i) { return absl::StrCat(i); }
    123 
    124 string ToVlogString(uint32 i) { return absl::StrCat(i); }
    125 
    126 string ToVlogString(uint64 i) { return absl::StrCat(i); }
    127 
    128 string ToVlogString(int64 i) { return absl::StrCat(i); }
    129 
    130 string ToVlogString(float f) { return absl::StrCat(f); }
    131 
    132 string ToVlogString(double d) { return absl::StrCat(d); }
    133 
    134 template <typename T>
    135 string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
    136   if (memory_or_constant.is_pointer()) {
    137     return ToVlogString(memory_or_constant.pointer());
    138   }
    139   return ToVlogString(memory_or_constant.value());
    140 }
    141 
    142 template <class T>
    143 string ToVlogString(port::ArraySlice<T> elements) {
    144   string str = absl::StrCat(
    145       ToVlogString(reinterpret_cast<const void *>(elements.data())), "[",
    146       elements.size(), "]{");
    147   const char *separator = "";
    148   size_t max_to_show = std::numeric_limits<size_t>::max();
    149   if (!VLOG_IS_ON(2)) {
    150     max_to_show = 5;
    151   } else if (!VLOG_IS_ON(3)) {
    152     max_to_show = 20;
    153   } else if (!VLOG_IS_ON(11)) {
    154     max_to_show = 1000;
    155   }
    156   for (size_t i = 0; i < elements.size(); ++i) {
    157     if (i == max_to_show) {
    158       str += ", ...";
    159       break;
    160     }
    161     absl::StrAppend(&str, separator, ToVlogString(elements[i]));
    162     separator = ", ";
    163   }
    164   str += "}";
    165   return str;
    166 }
    167 
    168 template <class T>
    169 string ToVlogString(port::MutableArraySlice<T> elements) {
    170   return ToVlogString(port::ArraySlice<T>(elements));
    171 }
    172 
    173 string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
    174   switch (depth_to_space_layout) {
    175     case dnn::DepthToSpaceLayout::DepthHeightWidth:
    176       return "DepthToSpaceLayout::DepthHeightWidth";
    177   }
    178   return "unknown DepthToSpaceLayout";
    179 }
    180 
    181 string ToVlogString(dnn::DataType data_type) {
    182   switch (data_type) {
    183     case dnn::DataType::kFloat:
    184       return "dnn::DataType::kFloat";
    185     case dnn::DataType::kDouble:
    186       return "dnn::DataType::kDouble";
    187     case dnn::DataType::kHalf:
    188       return "dnn::DataType::kHalf";
    189     case dnn::DataType::kInt8:
    190       return "dnn::DataType::kInt8";
    191     case dnn::DataType::kInt32:
    192       return "dnn::DataType::kInt32";
    193     default:
    194       return "unknown DataType";
    195   }
    196 }
    197 
    198 // Used together with PARAM to VLOG calls made to the stream. Intended
    199 // to be used like this:
    200 //
    201 //   VLOG(1) << CallStr("MyFunction", this, {PARAM(a), PARAM(b)});
    202 //
    203 // where a and b are the parameters to MyFunction.
    204 //
    205 // See VLOG_CALL for a short-hand for this. This way of doing it saves
    206 // a tremendous amount of boilerplate code given how many functions
    207 // there are on Stream and how many parameters they each have.
    208 string CallStr(const char *function_name, Stream *stream,
    209                std::vector<std::pair<const char *, string>> params) {
    210   // Do not call this function unless VLOG is on since just
    211   // constructing all the strings in params is expensive.
    212   CHECK(VLOG_IS_ON(1));
    213 
    214   string str = absl::StrCat(stream->DebugStreamPointers(),
    215                             " Called Stream::", function_name, "(");
    216   const char *separator = "";
    217   for (const auto &param : params) {
    218     absl::StrAppend(&str, separator, param.first, "=", param.second);
    219     separator = ", ";
    220   }
    221   absl::StrAppend(&str, ")");
    222   if (VLOG_IS_ON(10)) {
    223     absl::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
    224   }
    225   return str;
    226 }
    227 
    228 // Use this macro to avoid having to type every parameter twice to log
    229 // it with VLOG and CallStr.
    230 #define PARAM(parameter) \
    231   { #parameter, ToVlogString(parameter) }
    232 
    233 // Use this macro to avoid having to type out the name of each
    234 // function and to save some boilerplate. Intended to be used like this:
    235 //
    236 //   VLOG_CALL(PARAM(a), PARAM(b))
    237 //
    238 // This saves a tremendous amount of boilerplate compared to the alternative:
    239 //
    240 //   VLOG(1) << "Calling MyFunction(a=" << ToVlogString(a)
    241 //           << ", b=" << ToVlogString(b);
    242 //
    243 // Note here that most of the parameter names are not short and that
    244 // most of the functions take many more than 2 parameters.
    245 #define VLOG_CALL(...) VLOG(1) << CallStr(__func__, this, {__VA_ARGS__})
    246 
    247 }  // namespace
    248 
    249 Stream::Stream(StreamExecutor *parent)
    250     : parent_(parent),
    251       implementation_(parent->implementation()->GetStreamImplementation()),
    252       allocated_(false),
    253       ok_(false),
    254       temporary_memory_manager_(this) {
    255   VLOG_CALL(PARAM(parent));
    256 }
    257 
    258 Stream::Stream(StreamExecutor *parent,
    259                internal::StreamInterface *implementation)
    260     : parent_(parent),
    261       implementation_(implementation),
    262       allocated_(false),
    263       ok_(false),
    264       temporary_memory_manager_(this) {
    265   VLOG_CALL(PARAM(parent), PARAM(implementation));
    266 }
    267 
    268 Stream::~Stream() {
    269   VLOG_CALL();
    270 
    271   // Ensure the stream is completed.
    272   auto status = BlockHostUntilDone();
    273   if (!status.ok()) {
    274     LOG(WARNING) << "Error blocking host until done in stream destructor: "
    275                  << status;
    276   }
    277   temporary_memory_manager_.ForceDeallocateAll();
    278 
    279   if (allocated_) {
    280     parent_->DeallocateStream(this);
    281   }
    282 }
    283 
    284 port::Status Stream::RefreshStatus() {
    285   port::Status status = parent_->GetStatus(this);
    286   CheckStatus(status);
    287   return status;
    288 }
    289 
    290 Stream &Stream::Init() {
    291   VLOG_CALL();
    292 
    293   mutex_lock lock(mu_);
    294   CHECK_EQ(false, allocated_)
    295       << "stream appears to already have been initialized";
    296   CHECK(!ok_) << "stream should be in !ok() state pre-initialization";
    297 
    298   if (parent_->AllocateStream(this)) {
    299     // Successful initialization!
    300     allocated_ = true;
    301     ok_ = true;
    302   } else {
    303     LOG(ERROR) << "failed to allocate stream during initialization";
    304   }
    305 
    306   return *this;
    307 }
    308 
    309 Stream &Stream::InitTimer(Timer *timer) {
    310   VLOG_CALL(PARAM(timer));
    311 
    312   if (ok()) {
    313     CheckError(parent_->AllocateTimer(timer));
    314   } else {
    315     LOG(INFO) << "did not allocate timer: " << timer;
    316   }
    317   return *this;
    318 }
    319 
    320 Stream &Stream::InitWithTimer(Timer *timer) {
    321   VLOG_CALL(PARAM(timer));
    322 
    323   return Init().InitTimer(timer);
    324 }
    325 
    326 Stream &Stream::ThenRecordEvent(Event *event) {
    327   VLOG_CALL(PARAM(event));
    328 
    329   port::Status status = parent_->RecordEvent(this, event);
    330   if (!status.ok()) {
    331     LOG(ERROR) << "Error recording event in stream: " << status.error_message()
    332                << "; not marking stream as bad, as the Event object may be "
    333                << "at fault. Monitor for further errors.";
    334   }
    335 
    336   return *this;
    337 }
    338 
    339 Stream &Stream::ThenBatchNormalizationForward(
    340     const DeviceMemory<float> &x, const DeviceMemory<float> &scale,
    341     const DeviceMemory<float> &offset,
    342     const DeviceMemory<float> &estimated_mean,
    343     const DeviceMemory<float> &estimated_variance,
    344     const dnn::BatchDescriptor &x_desc,
    345     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    346     DeviceMemory<float> *y, DeviceMemory<float> *batch_mean,
    347     DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
    348     DeviceMemory<float> *saved_inv_var, bool is_training,
    349     std::function<const DeviceMemory<float> &()> var_to_inv_var,
    350     std::function<void()> inv_var_to_var) {
    351   VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
    352             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
    353   if (ok()) {
    354     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    355       CheckError(dnn->DoBatchNormalizationForward(
    356           this, x, scale, offset, estimated_mean, estimated_variance, x_desc,
    357           scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean,
    358           saved_inv_var, is_training, std::move(var_to_inv_var),
    359           std::move(inv_var_to_var)));
    360     } else {
    361       SetErrorAndLogNoDnnSupport();
    362     }
    363   }
    364   return *this;
    365 }
    366 
    367 Stream &Stream::ThenBatchNormalizationBackward(
    368     const DeviceMemory<float> &y_backprop, const DeviceMemory<float> &x,
    369     const DeviceMemory<float> &scale, const DeviceMemory<float> &mean,
    370     const DeviceMemory<float> &inv_var, const dnn::BatchDescriptor &x_desc,
    371     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    372     DeviceMemory<float> *x_backprop, DeviceMemory<float> *scale_backprop,
    373     DeviceMemory<float> *offset_backprop) {
    374   VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
    375             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
    376             PARAM(scale_backprop), PARAM(offset_backprop));
    377   if (ok()) {
    378     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    379       CheckError(dnn->DoBatchNormalizationBackward(
    380           this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
    381           epsilon, x_backprop, scale_backprop, offset_backprop));
    382     } else {
    383       SetErrorAndLogNoDnnSupport();
    384     }
    385   }
    386   return *this;
    387 }
    388 
    389 Stream &Stream::ThenBatchNormalizationForward(
    390     const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    391     const DeviceMemory<float> &offset,
    392     const DeviceMemory<float> &estimated_mean,
    393     const DeviceMemory<float> &estimated_variance,
    394     const dnn::BatchDescriptor &x_desc,
    395     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    396     DeviceMemory<Eigen::half> *y, DeviceMemory<float> *batch_mean,
    397     DeviceMemory<float> *batch_var, DeviceMemory<float> *saved_mean,
    398     DeviceMemory<float> *saved_inv_var, bool is_training,
    399     std::function<const DeviceMemory<float> &()> var_to_inv_var,
    400     std::function<void()> inv_var_to_var) {
    401   VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
    402             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(y));
    403   if (ok()) {
    404     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    405       CheckError(dnn->DoBatchNormalizationForward(
    406           this, x, scale, offset, estimated_mean, estimated_variance, x_desc,
    407           scale_offset_desc, epsilon, y, batch_mean, batch_var, saved_mean,
    408           saved_inv_var, is_training, std::move(var_to_inv_var),
    409           std::move(inv_var_to_var)));
    410     } else {
    411       SetErrorAndLogNoDnnSupport();
    412     }
    413   }
    414   return *this;
    415 }
    416 
    417 Stream &Stream::ThenBatchNormalizationBackward(
    418     const DeviceMemory<Eigen::half> &y_backprop,
    419     const DeviceMemory<Eigen::half> &x, const DeviceMemory<float> &scale,
    420     const DeviceMemory<float> &mean, const DeviceMemory<float> &inv_var,
    421     const dnn::BatchDescriptor &x_desc,
    422     const dnn::BatchDescriptor &scale_offset_desc, const double epsilon,
    423     DeviceMemory<Eigen::half> *x_backprop, DeviceMemory<float> *scale_backprop,
    424     DeviceMemory<float> *offset_backprop) {
    425   VLOG_CALL(PARAM(y_backprop), PARAM(x), PARAM(scale), PARAM(x_desc),
    426             PARAM(scale_offset_desc), PARAM(epsilon), PARAM(x_backprop),
    427             PARAM(scale_backprop), PARAM(offset_backprop));
    428   if (ok()) {
    429     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    430       CheckError(dnn->DoBatchNormalizationBackward(
    431           this, y_backprop, x, scale, mean, inv_var, x_desc, scale_offset_desc,
    432           epsilon, x_backprop, scale_backprop, offset_backprop));
    433     } else {
    434       SetErrorAndLogNoDnnSupport();
    435     }
    436   }
    437   return *this;
    438 }
    439 
    440 Stream &Stream::ThenFusedConvolveWithAlgorithm(
    441     const dnn::BatchDescriptor &conv_input_descriptor,
    442     const DeviceMemory<double> &conv_input_data, double conv_input_scale,
    443     const dnn::FilterDescriptor &filter_descriptor,
    444     const DeviceMemory<double> &filter_data,
    445     const dnn::ConvolutionDescriptor &convolution_descriptor,
    446     const DeviceMemory<double> &side_input_data, double side_input_scale,
    447     const dnn::BatchDescriptor &bias_descriptor,
    448     const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
    449     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
    450     ScratchAllocator *scratch_allocator,
    451     const dnn::AlgorithmConfig &algorithm_config,
    452     dnn::ProfileResult *output_profile_result) {
    453   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
    454             PARAM(conv_input_scale), PARAM(filter_descriptor),
    455             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
    456             PARAM(side_input_data), PARAM(side_input_scale),
    457             PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
    458             PARAM(algorithm_config));
    459 
    460   if (ok()) {
    461     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    462       auto status = dnn->DoFusedConvolve(
    463           this, conv_input_descriptor, conv_input_data, conv_input_scale,
    464           filter_descriptor, filter_data, convolution_descriptor,
    465           side_input_data, side_input_scale, bias_descriptor, biases,
    466           activation_mode, output_descriptor, output, scratch_allocator,
    467           algorithm_config, output_profile_result);
    468       if (!status && !output_profile_result) {
    469         SetError();
    470       }
    471     } else {
    472       SetErrorAndLogNoDnnSupport();
    473     }
    474   }
    475   return *this;
    476 }
    477 
    478 Stream &Stream::ThenFusedConvolveWithAlgorithm(
    479     const dnn::BatchDescriptor &conv_input_descriptor,
    480     const DeviceMemory<float> &conv_input_data, float conv_input_scale,
    481     const dnn::FilterDescriptor &filter_descriptor,
    482     const DeviceMemory<float> &filter_data,
    483     const dnn::ConvolutionDescriptor &convolution_descriptor,
    484     const DeviceMemory<float> &side_input_data, float side_input_scale,
    485     const dnn::BatchDescriptor &bias_descriptor,
    486     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    487     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
    488     ScratchAllocator *scratch_allocator,
    489     const dnn::AlgorithmConfig &algorithm_config,
    490     dnn::ProfileResult *output_profile_result) {
    491   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
    492             PARAM(conv_input_scale), PARAM(filter_descriptor),
    493             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
    494             PARAM(side_input_data), PARAM(side_input_scale),
    495             PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
    496             PARAM(algorithm_config));
    497 
    498   if (ok()) {
    499     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    500       auto status = dnn->DoFusedConvolve(
    501           this, conv_input_descriptor, conv_input_data, conv_input_scale,
    502           filter_descriptor, filter_data, convolution_descriptor,
    503           side_input_data, side_input_scale, bias_descriptor, biases,
    504           activation_mode, output_descriptor, output, scratch_allocator,
    505           algorithm_config, output_profile_result);
    506       if (!status && !output_profile_result) {
    507         SetError();
    508       }
    509     } else {
    510       SetErrorAndLogNoDnnSupport();
    511     }
    512   }
    513   return *this;
    514 }
    515 
    516 Stream &Stream::ThenFusedConvolveWithAlgorithm(
    517     const dnn::BatchDescriptor &conv_input_descriptor,
    518     const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
    519     const dnn::FilterDescriptor &filter_descriptor,
    520     const DeviceMemory<Eigen::half> &filter_data,
    521     const dnn::ConvolutionDescriptor &convolution_descriptor,
    522     const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
    523     const dnn::BatchDescriptor &bias_descriptor,
    524     const DeviceMemory<Eigen::half> &biases,
    525     dnn::ActivationMode activation_mode,
    526     const dnn::BatchDescriptor &output_descriptor,
    527     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
    528     const dnn::AlgorithmConfig &algorithm_config,
    529     dnn::ProfileResult *output_profile_result) {
    530   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
    531             PARAM(conv_input_scale), PARAM(filter_descriptor),
    532             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
    533             PARAM(side_input_data), PARAM(side_input_scale),
    534             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
    535             PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
    536 
    537   if (ok()) {
    538     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    539       auto status = dnn->DoFusedConvolve(
    540           this, conv_input_descriptor, conv_input_data, conv_input_scale,
    541           filter_descriptor, filter_data, convolution_descriptor,
    542           side_input_data, side_input_scale, bias_descriptor, biases,
    543           activation_mode, output_descriptor, output, scratch_allocator,
    544           algorithm_config, output_profile_result);
    545       if (!status && !output_profile_result) {
    546         SetError();
    547       }
    548     } else {
    549       SetErrorAndLogNoDnnSupport();
    550     }
    551   }
    552   return *this;
    553 }
    554 
    555 Stream &Stream::ThenFusedConvolveWithAlgorithm(
    556     const dnn::BatchDescriptor &conv_input_descriptor,
    557     const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
    558     const dnn::FilterDescriptor &filter_descriptor,
    559     const DeviceMemory<int8> &filter_data,
    560     const dnn::ConvolutionDescriptor &convolution_descriptor,
    561     const DeviceMemory<int8> &side_input_data, float side_input_scale,
    562     const dnn::BatchDescriptor &bias_descriptor,
    563     const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
    564     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
    565     ScratchAllocator *scratch_allocator,
    566     const dnn::AlgorithmConfig &algorithm_config,
    567     dnn::ProfileResult *output_profile_result) {
    568   VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
    569             PARAM(conv_input_scale), PARAM(filter_descriptor),
    570             PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
    571             PARAM(side_input_data), PARAM(side_input_scale),
    572             PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
    573             PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
    574 
    575   if (ok()) {
    576     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    577       auto status = dnn->DoFusedConvolve(
    578           this, conv_input_descriptor, conv_input_data, conv_input_scale,
    579           filter_descriptor, filter_data, convolution_descriptor,
    580           side_input_data, side_input_scale, bias_descriptor, biases,
    581           activation_mode, output_descriptor, output, scratch_allocator,
    582           algorithm_config, output_profile_result);
    583       if (!status && !output_profile_result) {
    584         SetError();
    585       }
    586     } else {
    587       SetErrorAndLogNoDnnSupport();
    588     }
    589   }
    590   return *this;
    591 }
    592 
    593 Stream &Stream::ThenConvolveWithAlgorithm(
    594     const dnn::BatchDescriptor &input_descriptor,
    595     const DeviceMemory<double> &input_data,
    596     const dnn::FilterDescriptor &filter_descriptor,
    597     const DeviceMemory<double> &filter_data,
    598     const dnn::ConvolutionDescriptor &convolution_descriptor,
    599     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
    600     ScratchAllocator *scratch_allocator,
    601     const dnn::AlgorithmConfig &algorithm_config,
    602     dnn::ProfileResult *output_profile_result) {
    603   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
    604             PARAM(filter_descriptor), PARAM(filter_data),
    605             PARAM(convolution_descriptor), PARAM(output_descriptor),
    606             PARAM(output), PARAM(algorithm_config));
    607 
    608   if (ok()) {
    609     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    610       DeviceMemory<uint8> scratch_memory;
    611       dnn::AlgorithmDesc algorithm_desc;
    612       auto status =
    613           dnn->PrepareForConvolution(
    614                  dnn::ConvolutionKind::FORWARD, this, input_descriptor,
    615                  input_data, filter_descriptor, filter_data, output_descriptor,
    616                  *output, convolution_descriptor, algorithm_config,
    617                  scratch_allocator, &algorithm_desc, &scratch_memory)
    618               .ok();
    619       if (status) {
    620         status = dnn->DoConvolve(
    621             this, input_descriptor, input_data, filter_descriptor, filter_data,
    622             convolution_descriptor, output_descriptor, output, algorithm_desc,
    623             &scratch_memory, output_profile_result);
    624       }
    625       if (!status && !output_profile_result) {
    626         SetError();
    627       }
    628     } else {
    629       SetErrorAndLogNoDnnSupport();
    630     }
    631   }
    632   return *this;
    633 }
    634 
    635 Stream &Stream::ThenConvolveWithAlgorithm(
    636     const dnn::BatchDescriptor &input_descriptor,
    637     const DeviceMemory<float> &input_data,
    638     const dnn::FilterDescriptor &filter_descriptor,
    639     const DeviceMemory<float> &filter_data,
    640     const dnn::ConvolutionDescriptor &convolution_descriptor,
    641     const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
    642     ScratchAllocator *scratch_allocator,
    643     const dnn::AlgorithmConfig &algorithm_config,
    644     dnn::ProfileResult *output_profile_result) {
    645   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
    646             PARAM(filter_descriptor), PARAM(filter_data),
    647             PARAM(convolution_descriptor), PARAM(output_descriptor),
    648             PARAM(output), PARAM(algorithm_config));
    649 
    650   if (ok()) {
    651     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    652       DeviceMemory<uint8> scratch_memory;
    653       dnn::AlgorithmDesc algorithm_desc;
    654       auto status =
    655           dnn->PrepareForConvolution(
    656                  dnn::ConvolutionKind::FORWARD, this, input_descriptor,
    657                  input_data, filter_descriptor, filter_data, output_descriptor,
    658                  *output, convolution_descriptor, algorithm_config,
    659                  scratch_allocator, &algorithm_desc, &scratch_memory)
    660               .ok();
    661       if (status) {
    662         status = dnn->DoConvolve(
    663             this, input_descriptor, input_data, filter_descriptor, filter_data,
    664             convolution_descriptor, output_descriptor, output, algorithm_desc,
    665             &scratch_memory, output_profile_result);
    666       }
    667       if (!status && !output_profile_result) {
    668         SetError();
    669       }
    670     } else {
    671       SetErrorAndLogNoDnnSupport();
    672     }
    673   }
    674   return *this;
    675 }
    676 
    677 Stream &Stream::ThenConvolveWithAlgorithm(
    678     const dnn::BatchDescriptor &input_descriptor,
    679     const DeviceMemory<Eigen::half> &input_data,
    680     const dnn::FilterDescriptor &filter_descriptor,
    681     const DeviceMemory<Eigen::half> &filter_data,
    682     const dnn::ConvolutionDescriptor &convolution_descriptor,
    683     const dnn::BatchDescriptor &output_descriptor,
    684     DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator,
    685     const dnn::AlgorithmConfig &algorithm_config,
    686     dnn::ProfileResult *output_profile_result) {
    687   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
    688             PARAM(filter_descriptor), PARAM(filter_data),
    689             PARAM(convolution_descriptor), PARAM(output_descriptor),
    690             PARAM(output), PARAM(algorithm_config));
    691 
    692   if (ok()) {
    693     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    694       DeviceMemory<uint8> scratch_memory;
    695       dnn::AlgorithmDesc algorithm_desc;
    696       auto status =
    697           dnn->PrepareForConvolution(
    698                  dnn::ConvolutionKind::FORWARD, this, input_descriptor,
    699                  input_data, filter_descriptor, filter_data, output_descriptor,
    700                  *output, convolution_descriptor, algorithm_config,
    701                  scratch_allocator, &algorithm_desc, &scratch_memory)
    702               .ok();
    703       if (status) {
    704         status = dnn->DoConvolve(
    705             this, input_descriptor, input_data, filter_descriptor, filter_data,
    706             convolution_descriptor, output_descriptor, output, algorithm_desc,
    707             &scratch_memory, output_profile_result);
    708       }
    709       if (!status && !output_profile_result) {
    710         SetError();
    711       }
    712     } else {
    713       SetErrorAndLogNoDnnSupport();
    714     }
    715   }
    716   return *this;
    717 }
    718 
    719 Stream &Stream::ThenConvolve(
    720     const dnn::BatchDescriptor &input_descriptor,
    721     const DeviceMemory<float> &input_data,
    722     const dnn::FilterDescriptor &filter_descriptor,
    723     const DeviceMemory<float> &filter_data,
    724     const dnn::ConvolutionDescriptor &convolution_descriptor,
    725     const dnn::BatchDescriptor &output_descriptor,
    726     DeviceMemory<float> *output) {
    727   return ThenConvolveWithAlgorithm(
    728       input_descriptor, input_data, filter_descriptor, filter_data,
    729       convolution_descriptor, output_descriptor, output,
    730       /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
    731       /*output_profile_result=*/nullptr);
    732 }
    733 
    734 Stream &Stream::ThenConvolveQuantized(
    735     const dnn::BatchDescriptor &input_descriptor,
    736     const DeviceMemory<float> &input_data,
    737     const dnn::FilterDescriptor &filter_descriptor,
    738     const DeviceMemory<int8> &filter_coefficients,
    739     const DeviceMemory<float> &coefficient_scales,
    740     const dnn::ConvolutionDescriptor &convolution_descriptor,
    741     const dnn::BatchDescriptor &output_descriptor,
    742     DeviceMemory<float> *output) {
    743   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
    744             PARAM(filter_descriptor), PARAM(filter_coefficients),
    745             PARAM(coefficient_scales), PARAM(convolution_descriptor),
    746             PARAM(output_descriptor), PARAM(output));
    747 
    748   if (ok()) {
    749     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    750       CheckError(dnn->DoConvolveQuantized(
    751           this, input_descriptor, input_data, filter_descriptor,
    752           filter_coefficients, coefficient_scales, convolution_descriptor,
    753           output_descriptor, output));
    754     } else {
    755       SetError();
    756       LOG(WARNING)
    757           << "attempting to perform DNN operation using StreamExecutor "
    758              "without DNN support";
    759     }
    760   }
    761   return *this;
    762 }
    763 
    764 Stream &Stream::ThenConvolveQuantized(
    765     const dnn::BatchDescriptor &input_descriptor,
    766     const DeviceMemory<float> &input_data,
    767     const dnn::FilterDescriptor &filter_descriptor,
    768     const DeviceMemory<int16> &filter_coefficients,
    769     const DeviceMemory<float> &coefficient_scales,
    770     const dnn::ConvolutionDescriptor &convolution_descriptor,
    771     const dnn::BatchDescriptor &output_descriptor,
    772     DeviceMemory<float> *output) {
    773   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
    774             PARAM(filter_descriptor), PARAM(filter_coefficients),
    775             PARAM(coefficient_scales), PARAM(convolution_descriptor),
    776             PARAM(output_descriptor), PARAM(output));
    777 
    778   if (ok()) {
    779     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    780       CheckError(dnn->DoConvolveQuantized(
    781           this, input_descriptor, input_data, filter_descriptor,
    782           filter_coefficients, coefficient_scales, convolution_descriptor,
    783           output_descriptor, output));
    784     } else {
    785       SetError();
    786       LOG(WARNING)
    787           << "attempting to perform DNN operation using StreamExecutor "
    788              "without DNN support";
    789     }
    790   }
    791   return *this;
    792 }
    793 
    794 Stream &Stream::ThenSeparableConvolve(
    795     const dnn::BatchDescriptor &batch_descriptor,
    796     const DeviceMemory<float> &input_data,
    797     const dnn::FilterDescriptor &filter_descriptor, int depth_multiplier,
    798     const DeviceMemory<float> &first_weights,
    799     const DeviceMemory<float> &second_weights,
    800     const dnn::ConvolutionDescriptor &convolution_descriptor,
    801     const dnn::BatchDescriptor &output_descriptor,
    802     DeviceMemory<float> *output) {
    803   VLOG_CALL(
    804       PARAM(batch_descriptor), PARAM(input_data), PARAM(filter_descriptor),
    805       PARAM(depth_multiplier), PARAM(first_weights), PARAM(second_weights),
    806       PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output));
    807 
    808   if (ok()) {
    809     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    810       CheckError(dnn->DoSeparableConvolve(
    811           this, batch_descriptor, input_data, filter_descriptor,
    812           depth_multiplier, first_weights, second_weights,
    813           convolution_descriptor, output_descriptor, output));
    814     } else {
    815       SetErrorAndLogNoDnnSupport();
    816     }
    817   }
    818   return *this;
    819 }
    820 
    821 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
    822     const dnn::FilterDescriptor &filter_descriptor,
    823     const DeviceMemory<double> &filter_data,
    824     const dnn::BatchDescriptor &output_descriptor,
    825     DeviceMemory<double> backward_output_data,
    826     const dnn::ConvolutionDescriptor &convolution_descriptor,
    827     const dnn::BatchDescriptor &input_descriptor,
    828     DeviceMemory<double> *backward_input_data,
    829     ScratchAllocator *scratch_allocator,
    830     const dnn::AlgorithmConfig &algorithm_config,
    831     dnn::ProfileResult *output_profile_result) {
    832   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
    833             PARAM(output_descriptor), PARAM(backward_output_data),
    834             PARAM(convolution_descriptor), PARAM(input_descriptor),
    835             PARAM(backward_input_data));
    836 
    837   if (ok()) {
    838     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    839       DeviceMemory<uint8> scratch_memory;
    840       dnn::AlgorithmDesc algorithm_desc;
    841       auto status =
    842           dnn->PrepareForConvolution(
    843                  dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
    844                  *backward_input_data, filter_descriptor, filter_data,
    845                  output_descriptor, backward_output_data,
    846                  convolution_descriptor, algorithm_config, scratch_allocator,
    847                  &algorithm_desc, &scratch_memory)
    848               .ok();
    849       if (status) {
    850         status = dnn->DoConvolveBackwardData(
    851             this, filter_descriptor, filter_data, output_descriptor,
    852             backward_output_data, convolution_descriptor, input_descriptor,
    853             backward_input_data, algorithm_desc, &scratch_memory,
    854             output_profile_result);
    855       }
    856       if (!status && !output_profile_result) {
    857         SetError();
    858       }
    859     } else {
    860       SetErrorAndLogNoDnnSupport();
    861     }
    862   }
    863   return *this;
    864 }
    865 
    866 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
    867     const dnn::FilterDescriptor &filter_descriptor,
    868     const DeviceMemory<float> &filter_data,
    869     const dnn::BatchDescriptor &output_descriptor,
    870     DeviceMemory<float> backward_output_data,
    871     const dnn::ConvolutionDescriptor &convolution_descriptor,
    872     const dnn::BatchDescriptor &input_descriptor,
    873     DeviceMemory<float> *backward_input_data,
    874     ScratchAllocator *scratch_allocator,
    875     const dnn::AlgorithmConfig &algorithm_config,
    876     dnn::ProfileResult *output_profile_result) {
    877   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
    878             PARAM(output_descriptor), PARAM(backward_output_data),
    879             PARAM(convolution_descriptor), PARAM(input_descriptor),
    880             PARAM(backward_input_data));
    881 
    882   if (ok()) {
    883     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    884       DeviceMemory<uint8> scratch_memory;
    885       dnn::AlgorithmDesc algorithm_desc;
    886       auto status =
    887           dnn->PrepareForConvolution(
    888                  dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
    889                  *backward_input_data, filter_descriptor, filter_data,
    890                  output_descriptor, backward_output_data,
    891                  convolution_descriptor, algorithm_config, scratch_allocator,
    892                  &algorithm_desc, &scratch_memory)
    893               .ok();
    894       if (status) {
    895         status = dnn->DoConvolveBackwardData(
    896             this, filter_descriptor, filter_data, output_descriptor,
    897             backward_output_data, convolution_descriptor, input_descriptor,
    898             backward_input_data, algorithm_desc, &scratch_memory,
    899             output_profile_result);
    900       }
    901       if (!status && !output_profile_result) {
    902         SetError();
    903       }
    904     } else {
    905       SetErrorAndLogNoDnnSupport();
    906     }
    907   }
    908   return *this;
    909 }
    910 
    911 Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
    912     const dnn::FilterDescriptor &filter_descriptor,
    913     const DeviceMemory<Eigen::half> &filter_data,
    914     const dnn::BatchDescriptor &output_descriptor,
    915     DeviceMemory<Eigen::half> backward_output_data,
    916     const dnn::ConvolutionDescriptor &convolution_descriptor,
    917     const dnn::BatchDescriptor &input_descriptor,
    918     DeviceMemory<Eigen::half> *backward_input_data,
    919     ScratchAllocator *scratch_allocator,
    920     const dnn::AlgorithmConfig &algorithm_config,
    921     dnn::ProfileResult *output_profile_result) {
    922   VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
    923             PARAM(output_descriptor), PARAM(backward_output_data),
    924             PARAM(convolution_descriptor), PARAM(input_descriptor),
    925             PARAM(backward_input_data));
    926 
    927   if (ok()) {
    928     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    929       DeviceMemory<uint8> scratch_memory;
    930       dnn::AlgorithmDesc algorithm_desc;
    931       auto status =
    932           dnn->PrepareForConvolution(
    933                  dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
    934                  *backward_input_data, filter_descriptor, filter_data,
    935                  output_descriptor, backward_output_data,
    936                  convolution_descriptor, algorithm_config, scratch_allocator,
    937                  &algorithm_desc, &scratch_memory)
    938               .ok();
    939       if (status) {
    940         status = dnn->DoConvolveBackwardData(
    941             this, filter_descriptor, filter_data, output_descriptor,
    942             backward_output_data, convolution_descriptor, input_descriptor,
    943             backward_input_data, algorithm_desc, &scratch_memory,
    944             output_profile_result);
    945       }
    946       if (!status && !output_profile_result) {
    947         SetError();
    948       }
    949     } else {
    950       SetErrorAndLogNoDnnSupport();
    951     }
    952   }
    953   return *this;
    954 }
    955 
    956 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
    957     const dnn::BatchDescriptor &input_descriptor,
    958     const DeviceMemory<double> &input_data,
    959     const dnn::BatchDescriptor &output_descriptor,
    960     DeviceMemory<double> backward_output_data,
    961     const dnn::ConvolutionDescriptor &convolution_descriptor,
    962     const dnn::FilterDescriptor &filter_descriptor,
    963     DeviceMemory<double> *backward_filter_data,
    964     ScratchAllocator *scratch_allocator,
    965     const dnn::AlgorithmConfig &algorithm_config,
    966     dnn::ProfileResult *output_profile_result) {
    967   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
    968             PARAM(output_descriptor), PARAM(backward_output_data),
    969             PARAM(convolution_descriptor), PARAM(filter_descriptor),
    970             PARAM(backward_filter_data));
    971 
    972   if (ok()) {
    973     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
    974       DeviceMemory<uint8> scratch_memory;
    975       dnn::AlgorithmDesc algorithm_desc;
    976       auto status =
    977           dnn->PrepareForConvolution(
    978                  dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
    979                  input_data, filter_descriptor, *backward_filter_data,
    980                  output_descriptor, backward_output_data,
    981                  convolution_descriptor, algorithm_config, scratch_allocator,
    982                  &algorithm_desc, &scratch_memory)
    983               .ok();
    984       if (status) {
    985         status = dnn->DoConvolveBackwardFilter(
    986             this, input_descriptor, input_data, output_descriptor,
    987             backward_output_data, convolution_descriptor, filter_descriptor,
    988             backward_filter_data, algorithm_desc, &scratch_memory,
    989             output_profile_result);
    990       }
    991       if (!status && !output_profile_result) {
    992         SetError();
    993       }
    994     } else {
    995       SetErrorAndLogNoDnnSupport();
    996     }
    997   }
    998   return *this;
    999 }
   1000 
   1001 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
   1002     const dnn::BatchDescriptor &input_descriptor,
   1003     const DeviceMemory<float> &input_data,
   1004     const dnn::BatchDescriptor &output_descriptor,
   1005     DeviceMemory<float> backward_output_data,
   1006     const dnn::ConvolutionDescriptor &convolution_descriptor,
   1007     const dnn::FilterDescriptor &filter_descriptor,
   1008     DeviceMemory<float> *backward_filter_data,
   1009     ScratchAllocator *scratch_allocator,
   1010     const dnn::AlgorithmConfig &algorithm_config,
   1011     dnn::ProfileResult *output_profile_result) {
   1012   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
   1013             PARAM(output_descriptor), PARAM(backward_output_data),
   1014             PARAM(convolution_descriptor), PARAM(filter_descriptor),
   1015             PARAM(backward_filter_data));
   1016 
   1017   if (ok()) {
   1018     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1019       DeviceMemory<uint8> scratch_memory;
   1020       dnn::AlgorithmDesc algorithm_desc;
   1021       auto status =
   1022           dnn->PrepareForConvolution(
   1023                  dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
   1024                  input_data, filter_descriptor, *backward_filter_data,
   1025                  output_descriptor, backward_output_data,
   1026                  convolution_descriptor, algorithm_config, scratch_allocator,
   1027                  &algorithm_desc, &scratch_memory)
   1028               .ok();
   1029       if (status) {
   1030         status = dnn->DoConvolveBackwardFilter(
   1031             this, input_descriptor, input_data, output_descriptor,
   1032             backward_output_data, convolution_descriptor, filter_descriptor,
   1033             backward_filter_data, algorithm_desc, &scratch_memory,
   1034             output_profile_result);
   1035       }
   1036       if (!status && !output_profile_result) {
   1037         SetError();
   1038       }
   1039     } else {
   1040       SetErrorAndLogNoDnnSupport();
   1041     }
   1042   }
   1043   return *this;
   1044 }
   1045 
   1046 Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
   1047     const dnn::BatchDescriptor &input_descriptor,
   1048     const DeviceMemory<Eigen::half> &input_data,
   1049     const dnn::BatchDescriptor &output_descriptor,
   1050     DeviceMemory<Eigen::half> backward_output_data,
   1051     const dnn::ConvolutionDescriptor &convolution_descriptor,
   1052     const dnn::FilterDescriptor &filter_descriptor,
   1053     DeviceMemory<Eigen::half> *backward_filter_data,
   1054     ScratchAllocator *scratch_allocator,
   1055     const dnn::AlgorithmConfig &algorithm_config,
   1056     dnn::ProfileResult *output_profile_result) {
   1057   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
   1058             PARAM(output_descriptor), PARAM(backward_output_data),
   1059             PARAM(convolution_descriptor), PARAM(filter_descriptor),
   1060             PARAM(backward_filter_data));
   1061 
   1062   if (ok()) {
   1063     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1064       DeviceMemory<uint8> scratch_memory;
   1065       dnn::AlgorithmDesc algorithm_desc;
   1066       auto status =
   1067           dnn->PrepareForConvolution(
   1068                  dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
   1069                  input_data, filter_descriptor, *backward_filter_data,
   1070                  output_descriptor, backward_output_data,
   1071                  convolution_descriptor, algorithm_config, scratch_allocator,
   1072                  &algorithm_desc, &scratch_memory)
   1073               .ok();
   1074       if (status) {
   1075         status = dnn->DoConvolveBackwardFilter(
   1076             this, input_descriptor, input_data, output_descriptor,
   1077             backward_output_data, convolution_descriptor, filter_descriptor,
   1078             backward_filter_data, algorithm_desc, &scratch_memory,
   1079             output_profile_result);
   1080       }
   1081       if (!status && !output_profile_result) {
   1082         SetError();
   1083       }
   1084     } else {
   1085       SetErrorAndLogNoDnnSupport();
   1086     }
   1087   }
   1088   return *this;
   1089 }
   1090 
   1091 template <typename T>
   1092 Stream &Stream::ThenConvolveBackwardBiasImpl(
   1093     const dnn::BatchDescriptor &input_descriptor,
   1094     const DeviceMemory<T> &input_data,
   1095     const dnn::BatchDescriptor &bias_descriptor,
   1096     DeviceMemory<T> *backward_bias_data) {
   1097   VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(bias_descriptor),
   1098             PARAM(backward_bias_data));
   1099 
   1100   if (ok()) {
   1101     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1102       CheckError(dnn->DoConvolveBackwardBias(this, input_descriptor, input_data,
   1103                                              bias_descriptor,
   1104                                              backward_bias_data));
   1105     } else {
   1106       SetErrorAndLogNoDnnSupport();
   1107     }
   1108   }
   1109   return *this;
   1110 }
   1111 
   1112 Stream &Stream::ThenConvolveBackwardBias(
   1113     const dnn::BatchDescriptor &input_descriptor,
   1114     const DeviceMemory<double> &input_data,
   1115     const dnn::BatchDescriptor &bias_descriptor,
   1116     DeviceMemory<double> *backward_bias_data) {
   1117   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
   1118                                       bias_descriptor, backward_bias_data);
   1119 }
   1120 
   1121 Stream &Stream::ThenConvolveBackwardBias(
   1122     const dnn::BatchDescriptor &input_descriptor,
   1123     const DeviceMemory<float> &input_data,
   1124     const dnn::BatchDescriptor &bias_descriptor,
   1125     DeviceMemory<float> *backward_bias_data) {
   1126   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
   1127                                       bias_descriptor, backward_bias_data);
   1128 }
   1129 
   1130 Stream &Stream::ThenConvolveBackwardBias(
   1131     const dnn::BatchDescriptor &input_descriptor,
   1132     const DeviceMemory<Eigen::half> &input_data,
   1133     const dnn::BatchDescriptor &bias_descriptor,
   1134     DeviceMemory<Eigen::half> *backward_bias_data) {
   1135   return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
   1136                                       bias_descriptor, backward_bias_data);
   1137 }
   1138 
   1139 Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
   1140                            const DeviceMemory<float> &weights,
   1141                            const dnn::BatchDescriptor &input_dimensions,
   1142                            const dnn::BatchDescriptor &output_dimensions,
   1143                            DeviceMemory<float> *output_data) {
   1144   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(input_dimensions),
   1145             PARAM(output_dimensions), PARAM(output_data));
   1146 
   1147   if (ok()) {
   1148     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1149       CheckError(dnn->DoMatMul(this, input_data, weights, input_dimensions,
   1150                                output_dimensions, output_data));
   1151     } else {
   1152       SetErrorAndLogNoDnnSupport();
   1153     }
   1154   }
   1155   return *this;
   1156 }
   1157 
   1158 Stream &Stream::ThenMatMulQuantized(
   1159     const DeviceMemory<float> &input_data, const DeviceMemory<int8> &weights,
   1160     const DeviceMemory<float> &weight_scales,
   1161     const dnn::BatchDescriptor &input_dimensions,
   1162     const dnn::BatchDescriptor &output_dimensions,
   1163     DeviceMemory<float> *output_data) {
   1164   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
   1165             PARAM(input_dimensions), PARAM(output_dimensions),
   1166             PARAM(output_data));
   1167 
   1168   if (ok()) {
   1169     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1170       CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
   1171                                         weight_scales, input_dimensions,
   1172                                         output_dimensions, output_data));
   1173     } else {
   1174       SetErrorAndLogNoDnnSupport();
   1175     }
   1176   }
   1177   return *this;
   1178 }
   1179 
   1180 Stream &Stream::ThenMatMulQuantized(
   1181     const DeviceMemory<float> &input_data, const DeviceMemory<int16> &weights,
   1182     const DeviceMemory<float> &weight_scales,
   1183     const dnn::BatchDescriptor &input_dimensions,
   1184     const dnn::BatchDescriptor &output_dimensions,
   1185     DeviceMemory<float> *output_data) {
   1186   VLOG_CALL(PARAM(input_data), PARAM(weights), PARAM(weight_scales),
   1187             PARAM(input_dimensions), PARAM(output_dimensions),
   1188             PARAM(output_data));
   1189 
   1190   if (ok()) {
   1191     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1192       CheckError(dnn->DoMatMulQuantized(this, input_data, weights,
   1193                                         weight_scales, input_dimensions,
   1194                                         output_dimensions, output_data));
   1195     } else {
   1196       SetErrorAndLogNoDnnSupport();
   1197     }
   1198   }
   1199   return *this;
   1200 }
   1201 
   1202 Stream &Stream::ThenBiasAdd(const DeviceMemory<float> &input_data,
   1203                             const DeviceMemory<float> &biases,
   1204                             const dnn::BatchDescriptor &dimensions,
   1205                             DeviceMemory<float> *output_data) {
   1206   VLOG_CALL(PARAM(input_data), PARAM(biases), PARAM(dimensions),
   1207             PARAM(output_data));
   1208 
   1209   if (ok()) {
   1210     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1211       CheckError(
   1212           dnn->DoBiasAdd(this, input_data, biases, dimensions, output_data));
   1213     } else {
   1214       SetErrorAndLogNoDnnSupport();
   1215     }
   1216   }
   1217   return *this;
   1218 }
   1219 
   1220 Stream &Stream::ThenPoolForward(
   1221     const dnn::PoolingDescriptor &pooling_dimensions,
   1222     const dnn::BatchDescriptor &input_dimensions,
   1223     const DeviceMemory<double> &input_data,
   1224     const dnn::BatchDescriptor &output_dimensions,
   1225     DeviceMemory<double> *output_data, ScratchAllocator *workspace_allocator) {
   1226   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
   1227             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
   1228             PARAM(workspace_allocator));
   1229 
   1230   if (ok()) {
   1231     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1232       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
   1233                                     input_data, output_dimensions, output_data,
   1234                                     workspace_allocator));
   1235     } else {
   1236       SetError();
   1237       LOG(WARNING)
   1238           << "attempting to perform DNN operation using StreamExecutor "
   1239              "without DNN support";
   1240     }
   1241   }
   1242   return *this;
   1243 }
   1244 
   1245 Stream &Stream::ThenPoolForward(
   1246     const dnn::PoolingDescriptor &pooling_dimensions,
   1247     const dnn::BatchDescriptor &input_dimensions,
   1248     const DeviceMemory<float> &input_data,
   1249     const dnn::BatchDescriptor &output_dimensions,
   1250     DeviceMemory<float> *output_data, ScratchAllocator *workspace_allocator) {
   1251   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
   1252             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
   1253             PARAM(workspace_allocator));
   1254 
   1255   if (ok()) {
   1256     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1257       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
   1258                                     input_data, output_dimensions, output_data,
   1259                                     workspace_allocator));
   1260     } else {
   1261       SetErrorAndLogNoDnnSupport();
   1262     }
   1263   }
   1264   return *this;
   1265 }
   1266 
   1267 Stream &Stream::ThenPoolForward(
   1268     const dnn::PoolingDescriptor &pooling_dimensions,
   1269     const dnn::BatchDescriptor &input_dimensions,
   1270     const DeviceMemory<Eigen::half> &input_data,
   1271     const dnn::BatchDescriptor &output_dimensions,
   1272     DeviceMemory<Eigen::half> *output_data,
   1273     ScratchAllocator *workspace_allocator) {
   1274   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
   1275             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
   1276             PARAM(workspace_allocator));
   1277 
   1278   if (ok()) {
   1279     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1280       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
   1281                                     input_data, output_dimensions, output_data,
   1282                                     workspace_allocator));
   1283     } else {
   1284       SetErrorAndLogNoDnnSupport();
   1285     }
   1286   }
   1287   return *this;
   1288 }
   1289 
   1290 Stream &Stream::ThenPoolForward(
   1291     const dnn::PoolingDescriptor &pooling_dimensions,
   1292     const dnn::BatchDescriptor &input_dimensions,
   1293     const DeviceMemory<int8> &input_data,
   1294     const dnn::BatchDescriptor &output_dimensions,
   1295     DeviceMemory<int8> *output_data, ScratchAllocator *workspace_allocator) {
   1296   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
   1297             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
   1298             PARAM(workspace_allocator));
   1299 
   1300   if (ok()) {
   1301     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1302       CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
   1303                                     input_data, output_dimensions, output_data,
   1304                                     workspace_allocator));
   1305     } else {
   1306       SetErrorAndLogNoDnnSupport();
   1307     }
   1308   }
   1309   return *this;
   1310 }
   1311 
   1312 Stream &Stream::ThenPoolBackward(
   1313     const dnn::PoolingDescriptor &pooling_dimensions,
   1314     const dnn::BatchDescriptor &input_dimensions,
   1315     const DeviceMemory<double> &input_data,
   1316     const dnn::BatchDescriptor &output_dimensions,
   1317     const DeviceMemory<double> &output_data,
   1318     const DeviceMemory<double> &input_diff_data,
   1319     DeviceMemory<double> *output_diff_data,
   1320     ScratchAllocator *workspace_allocator) {
   1321   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
   1322             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
   1323             PARAM(input_diff_data), PARAM(output_diff_data),
   1324             PARAM(workspace_allocator));
   1325 
   1326   if (ok()) {
   1327     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1328       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
   1329                                      input_data, output_dimensions, output_data,
   1330                                      input_diff_data, output_diff_data,
   1331                                      workspace_allocator));
   1332     } else {
   1333       SetError();
   1334       LOG(WARNING)
   1335           << "attempting to perform DNN operation using StreamExecutor "
   1336              "without DNN support";
   1337     }
   1338   }
   1339   return *this;
   1340 }
   1341 
   1342 Stream &Stream::ThenPoolBackward(
   1343     const dnn::PoolingDescriptor &pooling_dimensions,
   1344     const dnn::BatchDescriptor &input_dimensions,
   1345     const DeviceMemory<float> &input_data,
   1346     const dnn::BatchDescriptor &output_dimensions,
   1347     const DeviceMemory<float> &output_data,
   1348     const DeviceMemory<float> &input_diff_data,
   1349     DeviceMemory<float> *output_diff_data,
   1350     ScratchAllocator *workspace_allocator) {
   1351   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
   1352             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
   1353             PARAM(input_diff_data), PARAM(output_diff_data),
   1354             PARAM(workspace_allocator));
   1355 
   1356   if (ok()) {
   1357     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1358       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
   1359                                      input_data, output_dimensions, output_data,
   1360                                      input_diff_data, output_diff_data,
   1361                                      workspace_allocator));
   1362     } else {
   1363       SetErrorAndLogNoDnnSupport();
   1364     }
   1365   }
   1366   return *this;
   1367 }
   1368 
   1369 Stream &Stream::ThenPoolBackward(
   1370     const dnn::PoolingDescriptor &pooling_dimensions,
   1371     const dnn::BatchDescriptor &input_dimensions,
   1372     const DeviceMemory<Eigen::half> &input_data,
   1373     const dnn::BatchDescriptor &output_dimensions,
   1374     const DeviceMemory<Eigen::half> &output_data,
   1375     const DeviceMemory<Eigen::half> &input_diff_data,
   1376     DeviceMemory<Eigen::half> *output_diff_data,
   1377     ScratchAllocator *workspace_allocator) {
   1378   VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
   1379             PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
   1380             PARAM(input_diff_data), PARAM(output_diff_data),
   1381             PARAM(workspace_allocator));
   1382 
   1383   if (ok()) {
   1384     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1385       CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
   1386                                      input_data, output_dimensions, output_data,
   1387                                      input_diff_data, output_diff_data,
   1388                                      workspace_allocator));
   1389     } else {
   1390       SetErrorAndLogNoDnnSupport();
   1391     }
   1392   }
   1393   return *this;
   1394 }
   1395 
   1396 Stream &Stream::ThenNormalizeWithDimensions(
   1397     const dnn::NormalizeDescriptor &normalize_descriptor,
   1398     const dnn::BatchDescriptor &dimensions,
   1399     const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
   1400   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(input_data),
   1401             PARAM(output_data));
   1402 
   1403   if (ok()) {
   1404     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1405       CheckError(dnn->DoNormalizeWithDimensions(
   1406           this, normalize_descriptor, dimensions, input_data, output_data));
   1407     } else {
   1408       SetErrorAndLogNoDnnSupport();
   1409     }
   1410   }
   1411   return *this;
   1412 }
   1413 
   1414 Stream &Stream::ThenNormalizeBackwardWithDimensions(
   1415     const dnn::NormalizeDescriptor &normalize_descriptor,
   1416     const dnn::BatchDescriptor &dimensions, const DeviceMemory<float> &raw_data,
   1417     const DeviceMemory<float> &normalized_data,
   1418     const DeviceMemory<float> &normalized_variable_gradient,
   1419     DeviceMemory<float> *raw_variable_gradient,
   1420     ScratchAllocator *workspace_allocator) {
   1421   VLOG_CALL(PARAM(normalize_descriptor), PARAM(dimensions), PARAM(raw_data),
   1422             PARAM(normalized_data), PARAM(normalized_variable_gradient),
   1423             PARAM(raw_variable_gradient), PARAM(workspace_allocator));
   1424 
   1425   if (ok()) {
   1426     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1427       CheckError(dnn->DoNormalizeBackwardWithDimensions(
   1428           this, normalize_descriptor, dimensions, raw_data, normalized_data,
   1429           normalized_variable_gradient, raw_variable_gradient,
   1430           workspace_allocator));
   1431     } else {
   1432       SetErrorAndLogNoDnnSupport();
   1433     }
   1434   }
   1435   return *this;
   1436 }
   1437 
   1438 Stream &Stream::ThenActivate(dnn::ActivationMode activation_mode,
   1439                              const dnn::BatchDescriptor &dimensions,
   1440                              const DeviceMemory<float> &input_data,
   1441                              DeviceMemory<float> *output_data) {
   1442   return ThenActivateWithOptions(activation_mode, dimensions, input_data,
   1443                                  output_data, /*options=*/0);
   1444 }
   1445 
   1446 Stream &Stream::ThenActivateWithOptions(dnn::ActivationMode activation_mode,
   1447                                         const dnn::BatchDescriptor &dimensions,
   1448                                         const DeviceMemory<float> &input_data,
   1449                                         DeviceMemory<float> *output_data,
   1450                                         uint64 options) {
   1451   VLOG_CALL(PARAM(activation_mode), PARAM(dimensions), PARAM(input_data),
   1452             PARAM(output_data), PARAM(options));
   1453 
   1454   if (ok()) {
   1455     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1456       CheckError(dnn->DoActivate(this, activation_mode, dimensions, input_data,
   1457                                  output_data, options));
   1458     } else {
   1459       SetErrorAndLogNoDnnSupport();
   1460     }
   1461   }
   1462   return *this;
   1463 }
   1464 
   1465 Stream &Stream::ThenDepthConcatenate(
   1466     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
   1467     port::ArraySlice<const DeviceMemory<float> *> input_data,
   1468     DeviceMemory<float> *output_data) {
   1469   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
   1470 
   1471   for (size_t i = 1; i < input_dimensions.size(); ++i) {
   1472     if (input_dimensions[i].count() != input_dimensions[0].count() ||
   1473         input_dimensions[i].height() != input_dimensions[0].height() ||
   1474         input_dimensions[i].width() != input_dimensions[0].width()) {
   1475       SetError();
   1476       LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
   1477                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
   1478                  << "input_dimensions[" << i
   1479                  << "]: " << input_dimensions[i].ToString();
   1480       return *this;
   1481     }
   1482   }
   1483 
   1484   if (ok()) {
   1485     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1486       CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
   1487                                          output_data));
   1488     } else {
   1489       SetErrorAndLogNoDnnSupport();
   1490     }
   1491   }
   1492   return *this;
   1493 }
   1494 
   1495 Stream &Stream::ThenSpaceConcatenate(
   1496     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
   1497     port::ArraySlice<const DeviceMemory<float> *> input_data,
   1498     DeviceMemory<float> *output_data,
   1499     dnn::SpaceConcatenateMode concat_direction) {
   1500   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
   1501 
   1502   // Check that the input dimensions of all the other batches match those of the
   1503   // first batch.
   1504   for (size_t i = 1; i < input_dimensions.size(); ++i) {
   1505     if ((concat_direction == dnn::SpaceConcatenateMode::XDirection) &&
   1506         (input_dimensions[i].count() != input_dimensions[0].count() ||
   1507          input_dimensions[i].height() != input_dimensions[0].height() ||
   1508          input_dimensions[i].feature_map_count() !=
   1509              input_dimensions[0].feature_map_count())) {
   1510       SetError();
   1511       LOG(ERROR) << "Incompatible dimensions for X concatenation.\n"
   1512                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
   1513                  << "input_dimensions[" << i
   1514                  << "]: " << input_dimensions[i].ToString();
   1515       return *this;
   1516     }
   1517 
   1518     if ((concat_direction == dnn::SpaceConcatenateMode::YDirection) &&
   1519         (input_dimensions[i].count() != input_dimensions[0].count() ||
   1520          input_dimensions[i].width() != input_dimensions[0].width() ||
   1521          input_dimensions[i].feature_map_count() !=
   1522              input_dimensions[0].feature_map_count())) {
   1523       SetError();
   1524       LOG(ERROR) << "Incompatible dimensions for Y concatenation.\n"
   1525                  << "input_dimensions[0]: " << input_dimensions[0].ToString()
   1526                  << "input_dimensions[" << i
   1527                  << "]: " << input_dimensions[i].ToString();
   1528       return *this;
   1529     }
   1530   }
   1531   if (ok()) {
   1532     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1533       CheckError(dnn->DoSpaceConcatenate(this, input_dimensions, input_data,
   1534                                          output_data, concat_direction));
   1535     } else {
   1536       SetErrorAndLogNoDnnSupport();
   1537     }
   1538   }
   1539   return *this;
   1540 }
   1541 
   1542 Stream &Stream::ThenReshape(const dnn::BatchDescriptor &input_dimensions,
   1543                             const DeviceMemory<float> &input_data,
   1544                             const dnn::BatchDescriptor &output_dimensions,
   1545                             DeviceMemory<float> *output_data) {
   1546   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
   1547             PARAM(output_dimensions), PARAM(output_data));
   1548 
   1549   if (ok()) {
   1550     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1551       CheckError(dnn->DoReshape(this, input_dimensions, input_data,
   1552                                 output_dimensions, output_data));
   1553     } else {
   1554       SetErrorAndLogNoDnnSupport();
   1555     }
   1556   }
   1557   return *this;
   1558 }
   1559 
   1560 Stream &Stream::ThenDepthToSpace(
   1561     const dnn::BatchDescriptor &input_dimensions,
   1562     const DeviceMemory<float> &input_data,
   1563     const dnn::DepthToSpaceLayout &depth_to_space_layout,
   1564     const int sqrt_depth_reduction, DeviceMemory<float> *output_data) {
   1565   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
   1566             PARAM(depth_to_space_layout), PARAM(sqrt_depth_reduction),
   1567             PARAM(output_data));
   1568 
   1569   if (ok()) {
   1570     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1571       CheckError(dnn->DoDepthToSpace(this, input_dimensions, input_data,
   1572                                      depth_to_space_layout,
   1573                                      sqrt_depth_reduction, output_data));
   1574     } else {
   1575       SetErrorAndLogNoDnnSupport();
   1576     }
   1577   }
   1578   return *this;
   1579 }
   1580 
   1581 Stream &Stream::ThenSpaceToDepth(
   1582     const dnn::BatchDescriptor &input_dimensions,
   1583     const DeviceMemory<float> &input_data,
   1584     const dnn::DepthToSpaceLayout &space_to_depth_layout,
   1585     const int sqrt_depth_increase, DeviceMemory<float> *output_data) {
   1586   VLOG_CALL(PARAM(input_dimensions), PARAM(input_data),
   1587             PARAM(space_to_depth_layout), PARAM(sqrt_depth_increase),
   1588             PARAM(output_data));
   1589 
   1590   if (ok()) {
   1591     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1592       CheckError(dnn->DoSpaceToDepth(this, input_dimensions, input_data,
   1593                                      space_to_depth_layout, sqrt_depth_increase,
   1594                                      output_data));
   1595     } else {
   1596       SetErrorAndLogNoDnnSupport();
   1597     }
   1598   }
   1599   return *this;
   1600 }
   1601 
   1602 Stream &Stream::ThenElementwiseOperate(
   1603     dnn::ElementwiseOperation operation,
   1604     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
   1605     port::ArraySlice<const DeviceMemory<float> *> input_data,
   1606     const dnn::BatchDescriptor &output_dimensions,
   1607     DeviceMemory<float> *output_data) {
   1608   VLOG_CALL(PARAM(operation), PARAM(input_dimensions), PARAM(input_data),
   1609             PARAM(output_dimensions), PARAM(output_data));
   1610 
   1611   if (ok()) {
   1612     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1613       CheckError(dnn->DoElementwiseOperate(this, operation, input_dimensions,
   1614                                            input_data, output_dimensions,
   1615                                            output_data));
   1616     } else {
   1617       SetErrorAndLogNoDnnSupport();
   1618     }
   1619   }
   1620   return *this;
   1621 }
   1622 
   1623 Stream &Stream::ThenElementwiseOperateScaledQuantized(
   1624     dnn::ElementwiseOperation operation,
   1625     port::ArraySlice<int> input_multiplicands, int output_divisor,
   1626     port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
   1627     port::ArraySlice<const DeviceMemory<float> *> input_data,
   1628     const dnn::BatchDescriptor &output_dimensions,
   1629     DeviceMemory<float> *output_data) {
   1630   VLOG_CALL(PARAM(operation), PARAM(input_multiplicands), PARAM(output_divisor),
   1631             PARAM(input_dimensions), PARAM(input_data),
   1632             PARAM(output_dimensions), PARAM(output_data));
   1633 
   1634   if (ok()) {
   1635     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1636       CheckError(dnn->DoElementwiseOperateScaledQuantized(
   1637           this, operation, input_multiplicands, output_divisor,
   1638           input_dimensions, input_data, output_dimensions, output_data));
   1639     } else {
   1640       SetErrorAndLogNoDnnSupport();
   1641     }
   1642   }
   1643   return *this;
   1644 }
   1645 
   1646 Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
   1647                           const DeviceMemory<float> &input_data, int64 left_pad,
   1648                           int64 right_pad, int64 top_pad, int64 bottom_pad,
   1649                           DeviceMemory<float> *output_data) {
   1650   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
   1651             PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
   1652             PARAM(output_data));
   1653 
   1654   if (ok()) {
   1655     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1656       CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
   1657                               top_pad, bottom_pad, output_data));
   1658     } else {
   1659       SetErrorAndLogNoDnnSupport();
   1660     }
   1661   }
   1662   return *this;
   1663 }
   1664 
   1665 Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
   1666                             const DeviceMemory<float> &input_data,
   1667                             int64 left_trim, int64 right_trim, int64 top_trim,
   1668                             int64 bottom_trim,
   1669                             DeviceMemory<float> *output_data) {
   1670   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
   1671             PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
   1672             PARAM(output_data));
   1673 
   1674   if (ok()) {
   1675     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1676       CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
   1677                                 right_trim, top_trim, bottom_trim,
   1678                                 output_data));
   1679     } else {
   1680       SetErrorAndLogNoDnnSupport();
   1681     }
   1682   }
   1683   return *this;
   1684 }
   1685 
   1686 Stream &Stream::ThenXYBroadcast(const dnn::BatchDescriptor &dimensions,
   1687                                 const DeviceMemory<float> &input_data,
   1688                                 int64 replicate_x, int64 replicate_y,
   1689                                 DeviceMemory<float> *output_data) {
   1690   VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(replicate_x),
   1691             PARAM(replicate_y), PARAM(output_data));
   1692 
   1693   if (ok()) {
   1694     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1695       CheckError(dnn->DoXYBroadcast(this, dimensions, input_data, replicate_x,
   1696                                     replicate_y, output_data));
   1697     } else {
   1698       SetErrorAndLogNoDnnSupport();
   1699     }
   1700   }
   1701   return *this;
   1702 }
   1703 
   1704 Stream &Stream::ThenMemcpyD2HQuantized(
   1705     const DeviceMemory<float> &gpu_unquantized_src,
   1706     dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) {
   1707   VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
   1708             PARAM(size));
   1709 
   1710   if (ok()) {
   1711     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1712       CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
   1713                                            host_dst, size));
   1714     } else {
   1715       SetErrorAndLogNoDnnSupport();
   1716     }
   1717   }
   1718   return *this;
   1719 }
   1720 
   1721 Stream &Stream::ThenMemcpyH2DQuantized(
   1722     const void *host_src, uint64 size, dnn::QuantizedActivationMode mode,
   1723     DeviceMemory<float> *gpu_unquantized_dst) {
   1724   VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
   1725             PARAM(gpu_unquantized_dst));
   1726 
   1727   if (ok()) {
   1728     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   1729       CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
   1730                                            gpu_unquantized_dst));
   1731     } else {
   1732       SetErrorAndLogNoDnnSupport();
   1733     }
   1734   }
   1735   return *this;
   1736 }
   1737 
   1738 Stream *Stream::GetOrCreateSubStream() {
   1739   mutex_lock lock(mu_);
   1740 
   1741   // Look for the first reusable sub_stream that is ok, dropping !ok sub_streams
   1742   // we encounter along the way.
   1743   for (int64 index = 0; index < sub_streams_.size();) {
   1744     std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
   1745     if (pair.second) {
   1746       // The sub_stream is reusable.
   1747       Stream *sub_stream = pair.first.get();
   1748       if (sub_stream->ok()) {
   1749         VLOG(1) << DebugStreamPointers() << " reusing sub_stream "
   1750                 << sub_stream->DebugStreamPointers();
   1751         pair.second = false;
   1752         return sub_stream;
   1753       }
   1754 
   1755       // The stream is reusable and not ok. Streams have a monotonic state
   1756       // machine; the stream will remain in !ok forever. Swap it with the last
   1757       // stream and pop it off.
   1758       const int64 last = sub_streams_.size() - 1;
   1759       if (index != last) {
   1760         std::swap(pair, sub_streams_[last]);
   1761       }
   1762       sub_streams_.pop_back();
   1763       VLOG(1) << DebugStreamPointers() << " dropped !ok sub_stream "
   1764               << sub_stream->DebugStreamPointers();
   1765     } else {
   1766       // The sub_stream is not reusable, move on to the next one.
   1767       ++index;
   1768     }
   1769   }
   1770 
   1771   // No streams are reusable; create a new stream.
   1772   sub_streams_.emplace_back(std::unique_ptr<Stream>{new Stream{parent_}},
   1773                             false);
   1774   Stream *sub_stream = sub_streams_.back().first.get();
   1775   sub_stream->Init();
   1776   if (!sub_stream->ok_) {
   1777     LOG(ERROR) << "sub-stream failed to be initialized";
   1778   }
   1779   VLOG(1) << DebugStreamPointers() << " created new sub_stream "
   1780           << sub_stream->DebugStreamPointers();
   1781 
   1782   return sub_stream;
   1783 }
   1784 
   1785 void Stream::ReturnSubStream(Stream *sub_stream) {
   1786   mutex_lock lock(mu_);
   1787 
   1788   // Look for the sub-stream.
   1789   for (int64 index = 0; index < sub_streams_.size(); ++index) {
   1790     std::pair<std::unique_ptr<Stream>, bool> &pair = sub_streams_[index];
   1791     if (pair.first.get() != sub_stream) {
   1792       continue;
   1793     }
   1794 
   1795     // Found the sub_stream.
   1796     if (sub_stream->ok()) {
   1797       VLOG(1) << DebugStreamPointers() << " returned ok sub_stream "
   1798               << sub_stream->DebugStreamPointers();
   1799       pair.second = true;
   1800     } else {
   1801       // The returned stream is not ok. Streams have a monotonic state
   1802       // machine; the stream will remain in !ok forever. Swap it with the last
   1803       // stream and pop it off.
   1804       VLOG(1) << DebugStreamPointers() << " returned !ok sub_stream "
   1805               << sub_stream->DebugStreamPointers();
   1806       const int64 last = sub_streams_.size() - 1;
   1807       if (index != last) {
   1808         std::swap(pair, sub_streams_[last]);
   1809       }
   1810       sub_streams_.pop_back();
   1811     }
   1812     return;
   1813   }
   1814 
   1815   LOG(FATAL) << DebugStreamPointers()
   1816              << " did not create the returned sub-stream "
   1817              << sub_stream->DebugStreamPointers();
   1818 }
   1819 
   1820 Stream &Stream::ThenStartTimer(Timer *t) {
   1821   VLOG_CALL(PARAM(t));
   1822 
   1823   if (ok()) {
   1824     CheckError(parent_->StartTimer(this, t));
   1825   } else {
   1826     LOG(INFO) << DebugStreamPointers()
   1827               << " did not enqueue 'start timer': " << t;
   1828   }
   1829   return *this;
   1830 }
   1831 
   1832 Stream &Stream::ThenStopTimer(Timer *t) {
   1833   VLOG_CALL(PARAM(t));
   1834 
   1835   if (ok()) {
   1836     CheckError(parent_->StopTimer(this, t));
   1837   } else {
   1838     LOG(INFO) << DebugStreamPointers()
   1839               << " did not enqueue 'stop timer': " << t;
   1840   }
   1841   return *this;
   1842 }
   1843 
   1844 Stream &Stream::ThenWaitFor(Stream *other) {
   1845   VLOG_CALL(PARAM(other));
   1846 
   1847   CHECK(this != other) << "stream cannot wait for itself";
   1848   if (ok() && other->ok()) {
   1849     CheckError(parent_->CreateStreamDependency(this, other));
   1850   } else {
   1851     SetError();
   1852     LOG(INFO) << DebugStreamPointers() << " did not wait for "
   1853               << other->DebugStreamPointers();
   1854   }
   1855   return *this;
   1856 }
   1857 
   1858 Stream &Stream::ThenWaitFor(Event *event) {
   1859   VLOG_CALL(PARAM(event));
   1860 
   1861   if (ok()) {
   1862     port::Status status = parent_->WaitForEvent(this, event);
   1863     if (!status.ok()) {
   1864       LOG(ERROR) << "Error waiting for event in stream: "
   1865                  << status.error_message()
   1866                  << "; not marking stream as bad, as the Event object may be "
   1867                  << "at fault. Monitor for further errors.";
   1868     }
   1869   } else {
   1870     LOG(INFO) << DebugStreamPointers() << " did not wait for an event.";
   1871   }
   1872   return *this;
   1873 }
   1874 
   1875 // A functor that implements ThenBlasXXX interfaces, which calls DoBlasXXX
   1876 // functions and logs for errors.
   1877 template <typename... Args>
   1878 struct ThenBlasImpl {
   1879   // blas_func is the DoBlasXXX member function pointer, and args are its
   1880   // arguments except the first one of Stream* type.
   1881   Stream &operator()(Stream *stream,
   1882                      bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
   1883                      Args... args) {
   1884     return Run(stream, blas_func, /*record_error=*/true, args...);
   1885   }
   1886 
   1887   // Like operator(), but only calls stream->CheckError() if record_error is
   1888   // true.
   1889   Stream &Run(Stream *stream,
   1890               bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
   1891               bool record_error, Args... args);
   1892 };
   1893 
   1894 template <typename... Args>
   1895 Stream &ThenBlasImpl<Args...>::Run(
   1896     Stream *stream, bool (blas::BlasSupport::*blas_func)(Stream *, Args...),
   1897     bool record_error, Args... args) {
   1898   if (stream->ok()) {
   1899     bool ok;
   1900     if (blas::BlasSupport *blas = stream->parent_->AsBlas()) {
   1901       ok = (blas->*blas_func)(stream, args...);
   1902     } else {
   1903       LOG(WARNING)
   1904           << "attempting to perform BLAS operation using StreamExecutor "
   1905              "without BLAS support";
   1906       ok = false;
   1907     }
   1908     if (record_error) {
   1909       stream->CheckError(ok);
   1910     }
   1911   }
   1912   return *stream;
   1913 }
   1914 
   1915 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<float> &x,
   1916                              int incx, DeviceMemory<float> *result) {
   1917   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   1918 
   1919   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
   1920       impl;
   1921   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
   1922               result);
   1923 }
   1924 
   1925 Stream &Stream::ThenBlasAsum(uint64 elem_count, const DeviceMemory<double> &x,
   1926                              int incx, DeviceMemory<double> *result) {
   1927   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   1928 
   1929   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
   1930                DeviceMemory<double> *> impl;
   1931   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
   1932               result);
   1933 }
   1934 
   1935 Stream &Stream::ThenBlasAsum(uint64 elem_count,
   1936                              const DeviceMemory<std::complex<float>> &x,
   1937                              int incx, DeviceMemory<float> *result) {
   1938   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   1939 
   1940   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
   1941                DeviceMemory<float> *> impl;
   1942   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
   1943               result);
   1944 }
   1945 
   1946 Stream &Stream::ThenBlasAsum(uint64 elem_count,
   1947                              const DeviceMemory<std::complex<double>> &x,
   1948                              int incx, DeviceMemory<double> *result) {
   1949   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   1950 
   1951   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
   1952                DeviceMemory<double> *> impl;
   1953   return impl(this, &blas::BlasSupport::DoBlasAsum, elem_count, x, incx,
   1954               result);
   1955 }
   1956 
   1957 Stream &Stream::ThenBlasAxpy(uint64 elem_count, float alpha,
   1958                              const DeviceMemory<float> &x, int incx,
   1959                              DeviceMemory<float> *y, int incy) {
   1960   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   1961             PARAM(incy));
   1962 
   1963   ThenBlasImpl<uint64, float, const DeviceMemory<float> &, int,
   1964                DeviceMemory<float> *, int> impl;
   1965   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
   1966               y, incy);
   1967 }
   1968 
   1969 Stream &Stream::ThenBlasAxpy(uint64 elem_count, double alpha,
   1970                              const DeviceMemory<double> &x, int incx,
   1971                              DeviceMemory<double> *y, int incy) {
   1972   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   1973             PARAM(incy));
   1974 
   1975   ThenBlasImpl<uint64, double, const DeviceMemory<double> &, int,
   1976                DeviceMemory<double> *, int> impl;
   1977   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
   1978               y, incy);
   1979 }
   1980 
   1981 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<float> alpha,
   1982                              const DeviceMemory<std::complex<float>> &x,
   1983                              int incx, DeviceMemory<std::complex<float>> *y,
   1984                              int incy) {
   1985   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   1986             PARAM(incy));
   1987 
   1988   ThenBlasImpl<uint64, std::complex<float>,
   1989                const DeviceMemory<std::complex<float>> &, int,
   1990                DeviceMemory<std::complex<float>> *, int> impl;
   1991   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
   1992               y, incy);
   1993 }
   1994 
   1995 Stream &Stream::ThenBlasAxpy(uint64 elem_count, std::complex<double> alpha,
   1996                              const DeviceMemory<std::complex<double>> &x,
   1997                              int incx, DeviceMemory<std::complex<double>> *y,
   1998                              int incy) {
   1999   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   2000             PARAM(incy));
   2001 
   2002   ThenBlasImpl<uint64, std::complex<double>,
   2003                const DeviceMemory<std::complex<double>> &, int,
   2004                DeviceMemory<std::complex<double>> *, int> impl;
   2005   return impl(this, &blas::BlasSupport::DoBlasAxpy, elem_count, alpha, x, incx,
   2006               y, incy);
   2007 }
   2008 
   2009 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<float> &x,
   2010                              int incx, DeviceMemory<float> *y, int incy) {
   2011   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2012 
   2013   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
   2014                int> impl;
   2015   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
   2016               incy);
   2017 }
   2018 
   2019 Stream &Stream::ThenBlasCopy(uint64 elem_count, const DeviceMemory<double> &x,
   2020                              int incx, DeviceMemory<double> *y, int incy) {
   2021   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2022 
   2023   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
   2024                DeviceMemory<double> *, int> impl;
   2025   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
   2026               incy);
   2027 }
   2028 
   2029 Stream &Stream::ThenBlasCopy(uint64 elem_count,
   2030                              const DeviceMemory<std::complex<float>> &x,
   2031                              int incx, DeviceMemory<std::complex<float>> *y,
   2032                              int incy) {
   2033   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2034 
   2035   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
   2036                DeviceMemory<std::complex<float>> *, int> impl;
   2037   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
   2038               incy);
   2039 }
   2040 
   2041 Stream &Stream::ThenBlasCopy(uint64 elem_count,
   2042                              const DeviceMemory<std::complex<double>> &x,
   2043                              int incx, DeviceMemory<std::complex<double>> *y,
   2044                              int incy) {
   2045   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2046 
   2047   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
   2048                DeviceMemory<std::complex<double>> *, int> impl;
   2049   return impl(this, &blas::BlasSupport::DoBlasCopy, elem_count, x, incx, y,
   2050               incy);
   2051 }
   2052 
   2053 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<float> &x,
   2054                             int incx, const DeviceMemory<float> &y, int incy,
   2055                             DeviceMemory<float> *result) {
   2056   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2057             PARAM(result));
   2058 
   2059   ThenBlasImpl<uint64, const DeviceMemory<float> &, int,
   2060                const DeviceMemory<float> &, int, DeviceMemory<float> *> impl;
   2061   return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
   2062               result);
   2063 }
   2064 
   2065 Stream &Stream::ThenBlasDot(uint64 elem_count, const DeviceMemory<double> &x,
   2066                             int incx, const DeviceMemory<double> &y, int incy,
   2067                             DeviceMemory<double> *result) {
   2068   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2069             PARAM(result));
   2070 
   2071   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
   2072                const DeviceMemory<double> &, int, DeviceMemory<double> *> impl;
   2073   return impl(this, &blas::BlasSupport::DoBlasDot, elem_count, x, incx, y, incy,
   2074               result);
   2075 }
   2076 
   2077 Stream &Stream::ThenBlasDotc(uint64 elem_count,
   2078                              const DeviceMemory<std::complex<float>> &x,
   2079                              int incx,
   2080                              const DeviceMemory<std::complex<float>> &y,
   2081                              int incy,
   2082                              DeviceMemory<std::complex<float>> *result) {
   2083   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2084             PARAM(result));
   2085 
   2086   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
   2087                const DeviceMemory<std::complex<float>> &, int,
   2088                DeviceMemory<std::complex<float>> *> impl;
   2089   return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
   2090               incy, result);
   2091 }
   2092 
   2093 Stream &Stream::ThenBlasDotc(uint64 elem_count,
   2094                              const DeviceMemory<std::complex<double>> &x,
   2095                              int incx,
   2096                              const DeviceMemory<std::complex<double>> &y,
   2097                              int incy,
   2098                              DeviceMemory<std::complex<double>> *result) {
   2099   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2100             PARAM(result));
   2101 
   2102   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
   2103                const DeviceMemory<std::complex<double>> &, int,
   2104                DeviceMemory<std::complex<double>> *> impl;
   2105   return impl(this, &blas::BlasSupport::DoBlasDotc, elem_count, x, incx, y,
   2106               incy, result);
   2107 }
   2108 
   2109 Stream &Stream::ThenBlasDotu(uint64 elem_count,
   2110                              const DeviceMemory<std::complex<float>> &x,
   2111                              int incx,
   2112                              const DeviceMemory<std::complex<float>> &y,
   2113                              int incy,
   2114                              DeviceMemory<std::complex<float>> *result) {
   2115   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2116             PARAM(result));
   2117 
   2118   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
   2119                const DeviceMemory<std::complex<float>> &, int,
   2120                DeviceMemory<std::complex<float>> *> impl;
   2121   return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
   2122               incy, result);
   2123 }
   2124 
   2125 Stream &Stream::ThenBlasDotu(uint64 elem_count,
   2126                              const DeviceMemory<std::complex<double>> &x,
   2127                              int incx,
   2128                              const DeviceMemory<std::complex<double>> &y,
   2129                              int incy,
   2130                              DeviceMemory<std::complex<double>> *result) {
   2131   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2132             PARAM(result));
   2133 
   2134   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
   2135                const DeviceMemory<std::complex<double>> &, int,
   2136                DeviceMemory<std::complex<double>> *> impl;
   2137   return impl(this, &blas::BlasSupport::DoBlasDotu, elem_count, x, incx, y,
   2138               incy, result);
   2139 }
   2140 
   2141 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<float> &x,
   2142                              int incx, DeviceMemory<float> *result) {
   2143   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2144 
   2145   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *>
   2146       impl;
   2147   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
   2148               result);
   2149 }
   2150 
   2151 Stream &Stream::ThenBlasNrm2(uint64 elem_count, const DeviceMemory<double> &x,
   2152                              int incx, DeviceMemory<double> *result) {
   2153   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2154 
   2155   ThenBlasImpl<uint64, const DeviceMemory<double> &, int,
   2156                DeviceMemory<double> *> impl;
   2157   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
   2158               result);
   2159 }
   2160 
   2161 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
   2162                              const DeviceMemory<std::complex<float>> &x,
   2163                              int incx, DeviceMemory<float> *result) {
   2164   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2165 
   2166   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
   2167                DeviceMemory<float> *> impl;
   2168   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
   2169               result);
   2170 }
   2171 
   2172 Stream &Stream::ThenBlasNrm2(uint64 elem_count,
   2173                              const DeviceMemory<std::complex<double>> &x,
   2174                              int incx, DeviceMemory<double> *result) {
   2175   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2176 
   2177   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
   2178                DeviceMemory<double> *> impl;
   2179   return impl(this, &blas::BlasSupport::DoBlasNrm2, elem_count, x, incx,
   2180               result);
   2181 }
   2182 
   2183 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<float> *x, int incx,
   2184                             DeviceMemory<float> *y, int incy, float c,
   2185                             float s) {
   2186   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2187             PARAM(c), PARAM(s));
   2188 
   2189   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
   2190                float, float> impl;
   2191   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
   2192               c, s);
   2193 }
   2194 
   2195 Stream &Stream::ThenBlasRot(uint64 elem_count, DeviceMemory<double> *x,
   2196                             int incx, DeviceMemory<double> *y, int incy,
   2197                             double c, double s) {
   2198   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2199             PARAM(c), PARAM(s));
   2200 
   2201   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
   2202                double, double> impl;
   2203   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
   2204               c, s);
   2205 }
   2206 
   2207 Stream &Stream::ThenBlasRot(uint64 elem_count,
   2208                             DeviceMemory<std::complex<float>> *x, int incx,
   2209                             DeviceMemory<std::complex<float>> *y, int incy,
   2210                             float c, float s) {
   2211   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2212             PARAM(c), PARAM(s));
   2213 
   2214   ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
   2215                DeviceMemory<std::complex<float>> *, int, float, float> impl;
   2216   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
   2217               c, s);
   2218 }
   2219 
   2220 Stream &Stream::ThenBlasRot(uint64 elem_count,
   2221                             DeviceMemory<std::complex<double>> *x, int incx,
   2222                             DeviceMemory<std::complex<double>> *y, int incy,
   2223                             double c, double s) {
   2224   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2225             PARAM(c), PARAM(s));
   2226 
   2227   ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
   2228                DeviceMemory<std::complex<double>> *, int, double, double> impl;
   2229   return impl(this, &blas::BlasSupport::DoBlasRot, elem_count, x, incx, y, incy,
   2230               c, s);
   2231 }
   2232 
   2233 Stream &Stream::ThenBlasRotg(DeviceMemory<float> *a, DeviceMemory<float> *b,
   2234                              DeviceMemory<float> *c, DeviceMemory<float> *s) {
   2235   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
   2236 
   2237   ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
   2238                DeviceMemory<float> *, DeviceMemory<float> *> impl;
   2239   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
   2240 }
   2241 
   2242 Stream &Stream::ThenBlasRotg(DeviceMemory<double> *a, DeviceMemory<double> *b,
   2243                              DeviceMemory<double> *c, DeviceMemory<double> *s) {
   2244   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
   2245 
   2246   ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
   2247                DeviceMemory<double> *, DeviceMemory<double> *> impl;
   2248   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
   2249 }
   2250 
   2251 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<float>> *a,
   2252                              DeviceMemory<std::complex<float>> *b,
   2253                              DeviceMemory<float> *c,
   2254                              DeviceMemory<std::complex<float>> *s) {
   2255   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
   2256 
   2257   ThenBlasImpl<DeviceMemory<std::complex<float>> *,
   2258                DeviceMemory<std::complex<float>> *, DeviceMemory<float> *,
   2259                DeviceMemory<std::complex<float>> *> impl;
   2260   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
   2261 }
   2262 
   2263 Stream &Stream::ThenBlasRotg(DeviceMemory<std::complex<double>> *a,
   2264                              DeviceMemory<std::complex<double>> *b,
   2265                              DeviceMemory<double> *c,
   2266                              DeviceMemory<std::complex<double>> *s) {
   2267   VLOG_CALL(PARAM(a), PARAM(b), PARAM(c), PARAM(s));
   2268 
   2269   ThenBlasImpl<DeviceMemory<std::complex<double>> *,
   2270                DeviceMemory<std::complex<double>> *, DeviceMemory<double> *,
   2271                DeviceMemory<std::complex<double>> *> impl;
   2272   return impl(this, &blas::BlasSupport::DoBlasRotg, a, b, c, s);
   2273 }
   2274 
   2275 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<float> *x,
   2276                              int incx, DeviceMemory<float> *y, int incy,
   2277                              const DeviceMemory<float> &param) {
   2278   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2279             PARAM(param));
   2280 
   2281   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int,
   2282                const DeviceMemory<float> &> impl;
   2283   return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
   2284               incy, param);
   2285 }
   2286 
   2287 Stream &Stream::ThenBlasRotm(uint64 elem_count, DeviceMemory<double> *x,
   2288                              int incx, DeviceMemory<double> *y, int incy,
   2289                              const DeviceMemory<double> &param) {
   2290   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy),
   2291             PARAM(param));
   2292 
   2293   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int,
   2294                const DeviceMemory<double> &> impl;
   2295   return impl(this, &blas::BlasSupport::DoBlasRotm, elem_count, x, incx, y,
   2296               incy, param);
   2297 }
   2298 
   2299 Stream &Stream::ThenBlasRotmg(DeviceMemory<float> *d1, DeviceMemory<float> *d2,
   2300                               DeviceMemory<float> *x1,
   2301                               const DeviceMemory<float> &y1,
   2302                               DeviceMemory<float> *param) {
   2303   VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
   2304 
   2305   ThenBlasImpl<DeviceMemory<float> *, DeviceMemory<float> *,
   2306                DeviceMemory<float> *, const DeviceMemory<float> &,
   2307                DeviceMemory<float> *> impl;
   2308   return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
   2309 }
   2310 
   2311 Stream &Stream::ThenBlasRotmg(DeviceMemory<double> *d1,
   2312                               DeviceMemory<double> *d2,
   2313                               DeviceMemory<double> *x1,
   2314                               const DeviceMemory<double> &y1,
   2315                               DeviceMemory<double> *param) {
   2316   VLOG_CALL(PARAM(d1), PARAM(d2), PARAM(x1), PARAM(y1), PARAM(param));
   2317 
   2318   ThenBlasImpl<DeviceMemory<double> *, DeviceMemory<double> *,
   2319                DeviceMemory<double> *, const DeviceMemory<double> &,
   2320                DeviceMemory<double> *> impl;
   2321   return impl(this, &blas::BlasSupport::DoBlasRotmg, d1, d2, x1, y1, param);
   2322 }
   2323 
   2324 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
   2325                              DeviceMemory<float> *x, int incx) {
   2326   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
   2327 
   2328   ThenBlasImpl<uint64, float, DeviceMemory<float> *, int> impl;
   2329   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
   2330 }
   2331 
   2332 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
   2333                              DeviceMemory<double> *x, int incx) {
   2334   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
   2335 
   2336   ThenBlasImpl<uint64, double, DeviceMemory<double> *, int> impl;
   2337   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
   2338 }
   2339 
   2340 Stream &Stream::ThenBlasScal(uint64 elem_count, float alpha,
   2341                              DeviceMemory<std::complex<float>> *x, int incx) {
   2342   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
   2343 
   2344   ThenBlasImpl<uint64, float, DeviceMemory<std::complex<float>> *, int> impl;
   2345   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
   2346 }
   2347 
   2348 Stream &Stream::ThenBlasScal(uint64 elem_count, double alpha,
   2349                              DeviceMemory<std::complex<double>> *x, int incx) {
   2350   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
   2351 
   2352   ThenBlasImpl<uint64, double, DeviceMemory<std::complex<double>> *, int> impl;
   2353   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
   2354 }
   2355 
   2356 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<float> alpha,
   2357                              DeviceMemory<std::complex<float>> *x, int incx) {
   2358   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
   2359 
   2360   ThenBlasImpl<uint64, std::complex<float>, DeviceMemory<std::complex<float>> *,
   2361                int> impl;
   2362   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
   2363 }
   2364 
   2365 Stream &Stream::ThenBlasScal(uint64 elem_count, std::complex<double> alpha,
   2366                              DeviceMemory<std::complex<double>> *x, int incx) {
   2367   VLOG_CALL(PARAM(elem_count), PARAM(alpha), PARAM(x), PARAM(incx));
   2368 
   2369   ThenBlasImpl<uint64, std::complex<double>,
   2370                DeviceMemory<std::complex<double>> *, int> impl;
   2371   return impl(this, &blas::BlasSupport::DoBlasScal, elem_count, alpha, x, incx);
   2372 }
   2373 
   2374 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<float> *x,
   2375                              int incx, DeviceMemory<float> *y, int incy) {
   2376   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2377 
   2378   ThenBlasImpl<uint64, DeviceMemory<float> *, int, DeviceMemory<float> *, int>
   2379       impl;
   2380   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
   2381               incy);
   2382 }
   2383 
   2384 Stream &Stream::ThenBlasSwap(uint64 elem_count, DeviceMemory<double> *x,
   2385                              int incx, DeviceMemory<double> *y, int incy) {
   2386   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2387 
   2388   ThenBlasImpl<uint64, DeviceMemory<double> *, int, DeviceMemory<double> *, int>
   2389       impl;
   2390   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
   2391               incy);
   2392 }
   2393 
   2394 Stream &Stream::ThenBlasSwap(uint64 elem_count,
   2395                              DeviceMemory<std::complex<float>> *x, int incx,
   2396                              DeviceMemory<std::complex<float>> *y, int incy) {
   2397   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2398 
   2399   ThenBlasImpl<uint64, DeviceMemory<std::complex<float>> *, int,
   2400                DeviceMemory<std::complex<float>> *, int> impl;
   2401   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
   2402               incy);
   2403 }
   2404 
   2405 Stream &Stream::ThenBlasSwap(uint64 elem_count,
   2406                              DeviceMemory<std::complex<double>> *x, int incx,
   2407                              DeviceMemory<std::complex<double>> *y, int incy) {
   2408   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(y), PARAM(incy));
   2409 
   2410   ThenBlasImpl<uint64, DeviceMemory<std::complex<double>> *, int,
   2411                DeviceMemory<std::complex<double>> *, int> impl;
   2412   return impl(this, &blas::BlasSupport::DoBlasSwap, elem_count, x, incx, y,
   2413               incy);
   2414 }
   2415 
   2416 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<float> &x,
   2417                               int incx, DeviceMemory<int> *result) {
   2418   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2419 
   2420   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
   2421       impl;
   2422   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
   2423               result);
   2424 }
   2425 
   2426 Stream &Stream::ThenBlasIamax(uint64 elem_count, const DeviceMemory<double> &x,
   2427                               int incx, DeviceMemory<int> *result) {
   2428   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2429 
   2430   ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
   2431       impl;
   2432   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
   2433               result);
   2434 }
   2435 
   2436 Stream &Stream::ThenBlasIamax(uint64 elem_count,
   2437                               const DeviceMemory<std::complex<float>> &x,
   2438                               int incx, DeviceMemory<int> *result) {
   2439   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2440 
   2441   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
   2442                DeviceMemory<int> *> impl;
   2443   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
   2444               result);
   2445 }
   2446 
   2447 Stream &Stream::ThenBlasIamax(uint64 elem_count,
   2448                               const DeviceMemory<std::complex<double>> &x,
   2449                               int incx, DeviceMemory<int> *result) {
   2450   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2451 
   2452   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
   2453                DeviceMemory<int> *> impl;
   2454   return impl(this, &blas::BlasSupport::DoBlasIamax, elem_count, x, incx,
   2455               result);
   2456 }
   2457 
   2458 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<float> &x,
   2459                               int incx, DeviceMemory<int> *result) {
   2460   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2461 
   2462   ThenBlasImpl<uint64, const DeviceMemory<float> &, int, DeviceMemory<int> *>
   2463       impl;
   2464   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
   2465               result);
   2466 }
   2467 
   2468 Stream &Stream::ThenBlasIamin(uint64 elem_count, const DeviceMemory<double> &x,
   2469                               int incx, DeviceMemory<int> *result) {
   2470   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2471 
   2472   ThenBlasImpl<uint64, const DeviceMemory<double> &, int, DeviceMemory<int> *>
   2473       impl;
   2474   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
   2475               result);
   2476 }
   2477 
   2478 Stream &Stream::ThenBlasIamin(uint64 elem_count,
   2479                               const DeviceMemory<std::complex<float>> &x,
   2480                               int incx, DeviceMemory<int> *result) {
   2481   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2482 
   2483   ThenBlasImpl<uint64, const DeviceMemory<std::complex<float>> &, int,
   2484                DeviceMemory<int> *> impl;
   2485   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
   2486               result);
   2487 }
   2488 
   2489 Stream &Stream::ThenBlasIamin(uint64 elem_count,
   2490                               const DeviceMemory<std::complex<double>> &x,
   2491                               int incx, DeviceMemory<int> *result) {
   2492   VLOG_CALL(PARAM(elem_count), PARAM(x), PARAM(incx), PARAM(result));
   2493 
   2494   ThenBlasImpl<uint64, const DeviceMemory<std::complex<double>> &, int,
   2495                DeviceMemory<int> *> impl;
   2496   return impl(this, &blas::BlasSupport::DoBlasIamin, elem_count, x, incx,
   2497               result);
   2498 }
   2499 
   2500 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
   2501                              uint64 kl, uint64 ku, float alpha,
   2502                              const DeviceMemory<float> &a, int lda,
   2503                              const DeviceMemory<float> &x, int incx, float beta,
   2504                              DeviceMemory<float> *y, int incy) {
   2505   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
   2506             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
   2507             PARAM(beta), PARAM(y), PARAM(incy));
   2508 
   2509   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, float,
   2510                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
   2511                int, float, DeviceMemory<float> *, int> impl;
   2512   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
   2513               a, lda, x, incx, beta, y, incy);
   2514 }
   2515 
   2516 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
   2517                              uint64 kl, uint64 ku, double alpha,
   2518                              const DeviceMemory<double> &a, int lda,
   2519                              const DeviceMemory<double> &x, int incx,
   2520                              double beta, DeviceMemory<double> *y, int incy) {
   2521   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
   2522             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
   2523             PARAM(beta), PARAM(y), PARAM(incy));
   2524 
   2525   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64, double,
   2526                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
   2527                int, double, DeviceMemory<double> *, int> impl;
   2528   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
   2529               a, lda, x, incx, beta, y, incy);
   2530 }
   2531 
   2532 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
   2533                              uint64 kl, uint64 ku, std::complex<float> alpha,
   2534                              const DeviceMemory<std::complex<float>> &a,
   2535                              int lda,
   2536                              const DeviceMemory<std::complex<float>> &x,
   2537                              int incx, std::complex<float> beta,
   2538                              DeviceMemory<std::complex<float>> *y, int incy) {
   2539   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
   2540             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
   2541             PARAM(beta), PARAM(y), PARAM(incy));
   2542 
   2543   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
   2544                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   2545                int, const DeviceMemory<std::complex<float>> &, int,
   2546                std::complex<float>, DeviceMemory<std::complex<float>> *,
   2547                int> impl;
   2548   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
   2549               a, lda, x, incx, beta, y, incy);
   2550 }
   2551 
   2552 Stream &Stream::ThenBlasGbmv(blas::Transpose trans, uint64 m, uint64 n,
   2553                              uint64 kl, uint64 ku, std::complex<double> alpha,
   2554                              const DeviceMemory<std::complex<double>> &a,
   2555                              int lda,
   2556                              const DeviceMemory<std::complex<double>> &x,
   2557                              int incx, std::complex<double> beta,
   2558                              DeviceMemory<std::complex<double>> *y, int incy) {
   2559   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(kl), PARAM(ku),
   2560             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x), PARAM(incx),
   2561             PARAM(beta), PARAM(y), PARAM(incy));
   2562 
   2563   ThenBlasImpl<blas::Transpose, uint64, uint64, uint64, uint64,
   2564                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   2565                int, const DeviceMemory<std::complex<double>> &, int,
   2566                std::complex<double>, DeviceMemory<std::complex<double>> *,
   2567                int> impl;
   2568   return impl(this, &blas::BlasSupport::DoBlasGbmv, trans, m, n, kl, ku, alpha,
   2569               a, lda, x, incx, beta, y, incy);
   2570 }
   2571 
   2572 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
   2573                              float alpha, const DeviceMemory<float> &a, int lda,
   2574                              const DeviceMemory<float> &x, int incx, float beta,
   2575                              DeviceMemory<float> *y, int incy) {
   2576   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   2577             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   2578             PARAM(incy));
   2579 
   2580   ThenBlasImpl<blas::Transpose, uint64, uint64, float,
   2581                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
   2582                int, float, DeviceMemory<float> *, int> impl;
   2583   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
   2584               x, incx, beta, y, incy);
   2585 }
   2586 
   2587 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
   2588                              double alpha, const DeviceMemory<double> &a,
   2589                              int lda, const DeviceMemory<double> &x, int incx,
   2590                              double beta, DeviceMemory<double> *y, int incy) {
   2591   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   2592             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   2593             PARAM(incy));
   2594 
   2595   ThenBlasImpl<blas::Transpose, uint64, uint64, double,
   2596                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
   2597                int, double, DeviceMemory<double> *, int> impl;
   2598   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
   2599               x, incx, beta, y, incy);
   2600 }
   2601 
   2602 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
   2603                              std::complex<float> alpha,
   2604                              const DeviceMemory<std::complex<float>> &a,
   2605                              int lda,
   2606                              const DeviceMemory<std::complex<float>> &x,
   2607                              int incx, std::complex<float> beta,
   2608                              DeviceMemory<std::complex<float>> *y, int incy) {
   2609   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   2610             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   2611             PARAM(incy));
   2612 
   2613   ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<float>,
   2614                const DeviceMemory<std::complex<float>> &, int,
   2615                const DeviceMemory<std::complex<float>> &, int,
   2616                std::complex<float>, DeviceMemory<std::complex<float>> *,
   2617                int> impl;
   2618   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
   2619               x, incx, beta, y, incy);
   2620 }
   2621 
   2622 Stream &Stream::ThenBlasGemv(blas::Transpose trans, uint64 m, uint64 n,
   2623                              std::complex<double> alpha,
   2624                              const DeviceMemory<std::complex<double>> &a,
   2625                              int lda,
   2626                              const DeviceMemory<std::complex<double>> &x,
   2627                              int incx, std::complex<double> beta,
   2628                              DeviceMemory<std::complex<double>> *y, int incy) {
   2629   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   2630             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   2631             PARAM(incy));
   2632 
   2633   ThenBlasImpl<blas::Transpose, uint64, uint64, std::complex<double>,
   2634                const DeviceMemory<std::complex<double>> &, int,
   2635                const DeviceMemory<std::complex<double>> &, int,
   2636                std::complex<double>, DeviceMemory<std::complex<double>> *,
   2637                int> impl;
   2638   return impl(this, &blas::BlasSupport::DoBlasGemv, trans, m, n, alpha, a, lda,
   2639               x, incx, beta, y, incy);
   2640 }
   2641 
   2642 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, float alpha,
   2643                             const DeviceMemory<float> &x, int incx,
   2644                             const DeviceMemory<float> &y, int incy,
   2645                             DeviceMemory<float> *a, int lda) {
   2646   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   2647             PARAM(incy), PARAM(a), PARAM(lda));
   2648 
   2649   ThenBlasImpl<uint64, uint64, float, const DeviceMemory<float> &, int,
   2650                const DeviceMemory<float> &, int, DeviceMemory<float> *,
   2651                int> impl;
   2652   return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
   2653               incy, a, lda);
   2654 }
   2655 
   2656 Stream &Stream::ThenBlasGer(uint64 m, uint64 n, double alpha,
   2657                             const DeviceMemory<double> &x, int incx,
   2658                             const DeviceMemory<double> &y, int incy,
   2659                             DeviceMemory<double> *a, int lda) {
   2660   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   2661             PARAM(incy), PARAM(a), PARAM(lda));
   2662 
   2663   ThenBlasImpl<uint64, uint64, double, const DeviceMemory<double> &, int,
   2664                const DeviceMemory<double> &, int, DeviceMemory<double> *,
   2665                int> impl;
   2666   return impl(this, &blas::BlasSupport::DoBlasGer, m, n, alpha, x, incx, y,
   2667               incy, a, lda);
   2668 }
   2669 
   2670 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<float> alpha,
   2671                              const DeviceMemory<std::complex<float>> &x,
   2672                              int incx,
   2673                              const DeviceMemory<std::complex<float>> &y,
   2674                              int incy, DeviceMemory<std::complex<float>> *a,
   2675                              int lda) {
   2676   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   2677             PARAM(incy), PARAM(a), PARAM(lda));
   2678 
   2679   ThenBlasImpl<uint64, uint64, std::complex<float>,
   2680                const DeviceMemory<std::complex<float>> &, int,
   2681                const DeviceMemory<std::complex<float>> &, int,
   2682                DeviceMemory<std::complex<float>> *, int> impl;
   2683   return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
   2684               incy, a, lda);
   2685 }
   2686 
   2687 Stream &Stream::ThenBlasGerc(uint64 m, uint64 n, std::complex<double> alpha,
   2688                              const DeviceMemory<std::complex<double>> &x,
   2689                              int incx,
   2690                              const DeviceMemory<std::complex<double>> &y,
   2691                              int incy, DeviceMemory<std::complex<double>> *a,
   2692                              int lda) {
   2693   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   2694             PARAM(incy), PARAM(a), PARAM(lda));
   2695 
   2696   ThenBlasImpl<uint64, uint64, std::complex<double>,
   2697                const DeviceMemory<std::complex<double>> &, int,
   2698                const DeviceMemory<std::complex<double>> &, int,
   2699                DeviceMemory<std::complex<double>> *, int> impl;
   2700   return impl(this, &blas::BlasSupport::DoBlasGerc, m, n, alpha, x, incx, y,
   2701               incy, a, lda);
   2702 }
   2703 
   2704 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<float> alpha,
   2705                              const DeviceMemory<std::complex<float>> &x,
   2706                              int incx,
   2707                              const DeviceMemory<std::complex<float>> &y,
   2708                              int incy, DeviceMemory<std::complex<float>> *a,
   2709                              int lda) {
   2710   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   2711             PARAM(incy), PARAM(a), PARAM(lda));
   2712 
   2713   ThenBlasImpl<uint64, uint64, std::complex<float>,
   2714                const DeviceMemory<std::complex<float>> &, int,
   2715                const DeviceMemory<std::complex<float>> &, int,
   2716                DeviceMemory<std::complex<float>> *, int> impl;
   2717   return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
   2718               incy, a, lda);
   2719 }
   2720 
   2721 Stream &Stream::ThenBlasGeru(uint64 m, uint64 n, std::complex<double> alpha,
   2722                              const DeviceMemory<std::complex<double>> &x,
   2723                              int incx,
   2724                              const DeviceMemory<std::complex<double>> &y,
   2725                              int incy, DeviceMemory<std::complex<double>> *a,
   2726                              int lda) {
   2727   VLOG_CALL(PARAM(m), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx), PARAM(y),
   2728             PARAM(incy), PARAM(a), PARAM(lda));
   2729 
   2730   ThenBlasImpl<uint64, uint64, std::complex<double>,
   2731                const DeviceMemory<std::complex<double>> &, int,
   2732                const DeviceMemory<std::complex<double>> &, int,
   2733                DeviceMemory<std::complex<double>> *, int> impl;
   2734   return impl(this, &blas::BlasSupport::DoBlasGeru, m, n, alpha, x, incx, y,
   2735               incy, a, lda);
   2736 }
   2737 
   2738 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
   2739                              std::complex<float> alpha,
   2740                              const DeviceMemory<std::complex<float>> &a,
   2741                              int lda,
   2742                              const DeviceMemory<std::complex<float>> &x,
   2743                              int incx, std::complex<float> beta,
   2744                              DeviceMemory<std::complex<float>> *y, int incy) {
   2745   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
   2746             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2747 
   2748   ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<float>,
   2749                const DeviceMemory<std::complex<float>> &, int,
   2750                const DeviceMemory<std::complex<float>> &, int,
   2751                std::complex<float>, DeviceMemory<std::complex<float>> *,
   2752                int> impl;
   2753   return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
   2754               x, incx, beta, y, incy);
   2755 }
   2756 
   2757 Stream &Stream::ThenBlasHbmv(blas::UpperLower uplo, uint64 n, uint64 k,
   2758                              std::complex<double> alpha,
   2759                              const DeviceMemory<std::complex<double>> &a,
   2760                              int lda,
   2761                              const DeviceMemory<std::complex<double>> &x,
   2762                              int incx, std::complex<double> beta,
   2763                              DeviceMemory<std::complex<double>> *y, int incy) {
   2764   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
   2765             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2766 
   2767   ThenBlasImpl<blas::UpperLower, uint64, uint64, std::complex<double>,
   2768                const DeviceMemory<std::complex<double>> &, int,
   2769                const DeviceMemory<std::complex<double>> &, int,
   2770                std::complex<double>, DeviceMemory<std::complex<double>> *,
   2771                int> impl;
   2772   return impl(this, &blas::BlasSupport::DoBlasHbmv, uplo, n, k, alpha, a, lda,
   2773               x, incx, beta, y, incy);
   2774 }
   2775 
   2776 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
   2777                              std::complex<float> alpha,
   2778                              const DeviceMemory<std::complex<float>> &a,
   2779                              int lda,
   2780                              const DeviceMemory<std::complex<float>> &x,
   2781                              int incx, std::complex<float> beta,
   2782                              DeviceMemory<std::complex<float>> *y, int incy) {
   2783   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
   2784             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2785 
   2786   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
   2787                const DeviceMemory<std::complex<float>> &, int,
   2788                const DeviceMemory<std::complex<float>> &, int,
   2789                std::complex<float>, DeviceMemory<std::complex<float>> *,
   2790                int> impl;
   2791   return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
   2792               incx, beta, y, incy);
   2793 }
   2794 
   2795 Stream &Stream::ThenBlasHemv(blas::UpperLower uplo, uint64 n,
   2796                              std::complex<double> alpha,
   2797                              const DeviceMemory<std::complex<double>> &a,
   2798                              int lda,
   2799                              const DeviceMemory<std::complex<double>> &x,
   2800                              int incx, std::complex<double> beta,
   2801                              DeviceMemory<std::complex<double>> *y, int incy) {
   2802   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
   2803             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2804 
   2805   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
   2806                const DeviceMemory<std::complex<double>> &, int,
   2807                const DeviceMemory<std::complex<double>> &, int,
   2808                std::complex<double>, DeviceMemory<std::complex<double>> *,
   2809                int> impl;
   2810   return impl(this, &blas::BlasSupport::DoBlasHemv, uplo, n, alpha, a, lda, x,
   2811               incx, beta, y, incy);
   2812 }
   2813 
   2814 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, float alpha,
   2815                             const DeviceMemory<std::complex<float>> &x,
   2816                             int incx, DeviceMemory<std::complex<float>> *a,
   2817                             int lda) {
   2818   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2819             PARAM(a), PARAM(lda));
   2820 
   2821   ThenBlasImpl<blas::UpperLower, uint64, float,
   2822                const DeviceMemory<std::complex<float>> &, int,
   2823                DeviceMemory<std::complex<float>> *, int> impl;
   2824   return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
   2825               lda);
   2826 }
   2827 
   2828 Stream &Stream::ThenBlasHer(blas::UpperLower uplo, uint64 n, double alpha,
   2829                             const DeviceMemory<std::complex<double>> &x,
   2830                             int incx, DeviceMemory<std::complex<double>> *a,
   2831                             int lda) {
   2832   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2833             PARAM(a), PARAM(lda));
   2834 
   2835   ThenBlasImpl<blas::UpperLower, uint64, double,
   2836                const DeviceMemory<std::complex<double>> &, int,
   2837                DeviceMemory<std::complex<double>> *, int> impl;
   2838   return impl(this, &blas::BlasSupport::DoBlasHer, uplo, n, alpha, x, incx, a,
   2839               lda);
   2840 }
   2841 
   2842 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
   2843                              std::complex<float> alpha,
   2844                              const DeviceMemory<std::complex<float>> &x,
   2845                              int incx,
   2846                              const DeviceMemory<std::complex<float>> &y,
   2847                              int incy, DeviceMemory<std::complex<float>> *a,
   2848                              int lda) {
   2849   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2850             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
   2851 
   2852   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
   2853                const DeviceMemory<std::complex<float>> &, int,
   2854                const DeviceMemory<std::complex<float>> &, int,
   2855                DeviceMemory<std::complex<float>> *, int> impl;
   2856   return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
   2857               incy, a, lda);
   2858 }
   2859 
   2860 Stream &Stream::ThenBlasHer2(blas::UpperLower uplo, uint64 n,
   2861                              std::complex<double> alpha,
   2862                              const DeviceMemory<std::complex<double>> &x,
   2863                              int incx,
   2864                              const DeviceMemory<std::complex<double>> &y,
   2865                              int incy, DeviceMemory<std::complex<double>> *a,
   2866                              int lda) {
   2867   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2868             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
   2869 
   2870   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
   2871                const DeviceMemory<std::complex<double>> &, int,
   2872                const DeviceMemory<std::complex<double>> &, int,
   2873                DeviceMemory<std::complex<double>> *, int> impl;
   2874   return impl(this, &blas::BlasSupport::DoBlasHer2, uplo, n, alpha, x, incx, y,
   2875               incy, a, lda);
   2876 }
   2877 
   2878 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
   2879                              std::complex<float> alpha,
   2880                              const DeviceMemory<std::complex<float>> &ap,
   2881                              const DeviceMemory<std::complex<float>> &x,
   2882                              int incx, std::complex<float> beta,
   2883                              DeviceMemory<std::complex<float>> *y, int incy) {
   2884   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
   2885             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2886 
   2887   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
   2888                const DeviceMemory<std::complex<float>> &,
   2889                const DeviceMemory<std::complex<float>> &, int,
   2890                std::complex<float>, DeviceMemory<std::complex<float>> *,
   2891                int> impl;
   2892   return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
   2893               beta, y, incy);
   2894 }
   2895 
   2896 Stream &Stream::ThenBlasHpmv(blas::UpperLower uplo, uint64 n,
   2897                              std::complex<double> alpha,
   2898                              const DeviceMemory<std::complex<double>> &ap,
   2899                              const DeviceMemory<std::complex<double>> &x,
   2900                              int incx, std::complex<double> beta,
   2901                              DeviceMemory<std::complex<double>> *y, int incy) {
   2902   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
   2903             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2904 
   2905   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
   2906                const DeviceMemory<std::complex<double>> &,
   2907                const DeviceMemory<std::complex<double>> &, int,
   2908                std::complex<double>, DeviceMemory<std::complex<double>> *,
   2909                int> impl;
   2910   return impl(this, &blas::BlasSupport::DoBlasHpmv, uplo, n, alpha, ap, x, incx,
   2911               beta, y, incy);
   2912 }
   2913 
   2914 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, float alpha,
   2915                             const DeviceMemory<std::complex<float>> &x,
   2916                             int incx, DeviceMemory<std::complex<float>> *ap) {
   2917   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2918             PARAM(ap));
   2919 
   2920   ThenBlasImpl<blas::UpperLower, uint64, float,
   2921                const DeviceMemory<std::complex<float>> &, int,
   2922                DeviceMemory<std::complex<float>> *> impl;
   2923   return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
   2924 }
   2925 
   2926 Stream &Stream::ThenBlasHpr(blas::UpperLower uplo, uint64 n, double alpha,
   2927                             const DeviceMemory<std::complex<double>> &x,
   2928                             int incx, DeviceMemory<std::complex<double>> *ap) {
   2929   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2930             PARAM(ap));
   2931 
   2932   ThenBlasImpl<blas::UpperLower, uint64, double,
   2933                const DeviceMemory<std::complex<double>> &, int,
   2934                DeviceMemory<std::complex<double>> *> impl;
   2935   return impl(this, &blas::BlasSupport::DoBlasHpr, uplo, n, alpha, x, incx, ap);
   2936 }
   2937 
   2938 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
   2939                              std::complex<float> alpha,
   2940                              const DeviceMemory<std::complex<float>> &x,
   2941                              int incx,
   2942                              const DeviceMemory<std::complex<float>> &y,
   2943                              int incy, DeviceMemory<std::complex<float>> *ap) {
   2944   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2945             PARAM(y), PARAM(incy), PARAM(ap));
   2946 
   2947   ThenBlasImpl<blas::UpperLower, uint64, std::complex<float>,
   2948                const DeviceMemory<std::complex<float>> &, int,
   2949                const DeviceMemory<std::complex<float>> &, int,
   2950                DeviceMemory<std::complex<float>> *> impl;
   2951   return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
   2952               incy, ap);
   2953 }
   2954 
   2955 Stream &Stream::ThenBlasHpr2(blas::UpperLower uplo, uint64 n,
   2956                              std::complex<double> alpha,
   2957                              const DeviceMemory<std::complex<double>> &x,
   2958                              int incx,
   2959                              const DeviceMemory<std::complex<double>> &y,
   2960                              int incy, DeviceMemory<std::complex<double>> *ap) {
   2961   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   2962             PARAM(y), PARAM(incy), PARAM(ap));
   2963 
   2964   ThenBlasImpl<blas::UpperLower, uint64, std::complex<double>,
   2965                const DeviceMemory<std::complex<double>> &, int,
   2966                const DeviceMemory<std::complex<double>> &, int,
   2967                DeviceMemory<std::complex<double>> *> impl;
   2968   return impl(this, &blas::BlasSupport::DoBlasHpr2, uplo, n, alpha, x, incx, y,
   2969               incy, ap);
   2970 }
   2971 
   2972 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
   2973                              float alpha, const DeviceMemory<float> &a, int lda,
   2974                              const DeviceMemory<float> &x, int incx, float beta,
   2975                              DeviceMemory<float> *y, int incy) {
   2976   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
   2977             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2978 
   2979   ThenBlasImpl<blas::UpperLower, uint64, uint64, float,
   2980                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
   2981                int, float, DeviceMemory<float> *, int> impl;
   2982   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
   2983               x, incx, beta, y, incy);
   2984 }
   2985 
   2986 Stream &Stream::ThenBlasSbmv(blas::UpperLower uplo, uint64 n, uint64 k,
   2987                              double alpha, const DeviceMemory<double> &a,
   2988                              int lda, const DeviceMemory<double> &x, int incx,
   2989                              double beta, DeviceMemory<double> *y, int incy) {
   2990   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda),
   2991             PARAM(x), PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   2992 
   2993   ThenBlasImpl<blas::UpperLower, uint64, uint64, double,
   2994                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
   2995                int, double, DeviceMemory<double> *, int> impl;
   2996   return impl(this, &blas::BlasSupport::DoBlasSbmv, uplo, n, k, alpha, a, lda,
   2997               x, incx, beta, y, incy);
   2998 }
   2999 
   3000 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, float alpha,
   3001                              const DeviceMemory<float> &ap,
   3002                              const DeviceMemory<float> &x, int incx, float beta,
   3003                              DeviceMemory<float> *y, int incy) {
   3004   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
   3005             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   3006 
   3007   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
   3008                const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
   3009                int> impl;
   3010   return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
   3011               beta, y, incy);
   3012 }
   3013 
   3014 Stream &Stream::ThenBlasSpmv(blas::UpperLower uplo, uint64 n, double alpha,
   3015                              const DeviceMemory<double> &ap,
   3016                              const DeviceMemory<double> &x, int incx,
   3017                              double beta, DeviceMemory<double> *y, int incy) {
   3018   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(ap), PARAM(x),
   3019             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   3020 
   3021   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
   3022                const DeviceMemory<double> &, int, double,
   3023                DeviceMemory<double> *, int> impl;
   3024   return impl(this, &blas::BlasSupport::DoBlasSpmv, uplo, n, alpha, ap, x, incx,
   3025               beta, y, incy);
   3026 }
   3027 
   3028 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, float alpha,
   3029                             const DeviceMemory<float> &x, int incx,
   3030                             DeviceMemory<float> *ap) {
   3031   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3032             PARAM(ap));
   3033 
   3034   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
   3035                int, DeviceMemory<float> *> impl;
   3036   return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
   3037 }
   3038 
   3039 Stream &Stream::ThenBlasSpr(blas::UpperLower uplo, uint64 n, double alpha,
   3040                             const DeviceMemory<double> &x, int incx,
   3041                             DeviceMemory<double> *ap) {
   3042   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3043             PARAM(ap));
   3044 
   3045   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
   3046                int, DeviceMemory<double> *> impl;
   3047   return impl(this, &blas::BlasSupport::DoBlasSpr, uplo, n, alpha, x, incx, ap);
   3048 }
   3049 
   3050 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, float alpha,
   3051                              const DeviceMemory<float> &x, int incx,
   3052                              const DeviceMemory<float> &y, int incy,
   3053                              DeviceMemory<float> *ap) {
   3054   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3055             PARAM(y), PARAM(incy), PARAM(ap));
   3056 
   3057   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
   3058                int, const DeviceMemory<float> &, int,
   3059                DeviceMemory<float> *> impl;
   3060   return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
   3061               incy, ap);
   3062 }
   3063 
   3064 Stream &Stream::ThenBlasSpr2(blas::UpperLower uplo, uint64 n, double alpha,
   3065                              const DeviceMemory<double> &x, int incx,
   3066                              const DeviceMemory<double> &y, int incy,
   3067                              DeviceMemory<double> *ap) {
   3068   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3069             PARAM(y), PARAM(incy), PARAM(ap));
   3070 
   3071   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
   3072                int, const DeviceMemory<double> &, int,
   3073                DeviceMemory<double> *> impl;
   3074   return impl(this, &blas::BlasSupport::DoBlasSpr2, uplo, n, alpha, x, incx, y,
   3075               incy, ap);
   3076 }
   3077 
   3078 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, float alpha,
   3079                              const DeviceMemory<float> &a, int lda,
   3080                              const DeviceMemory<float> &x, int incx, float beta,
   3081                              DeviceMemory<float> *y, int incy) {
   3082   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
   3083             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   3084 
   3085   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
   3086                int, const DeviceMemory<float> &, int, float,
   3087                DeviceMemory<float> *, int> impl;
   3088   return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
   3089               incx, beta, y, incy);
   3090 }
   3091 
   3092 Stream &Stream::ThenBlasSymv(blas::UpperLower uplo, uint64 n, double alpha,
   3093                              const DeviceMemory<double> &a, int lda,
   3094                              const DeviceMemory<double> &x, int incx,
   3095                              double beta, DeviceMemory<double> *y, int incy) {
   3096   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(x),
   3097             PARAM(incx), PARAM(beta), PARAM(y), PARAM(incy));
   3098 
   3099   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
   3100                int, const DeviceMemory<double> &, int, double,
   3101                DeviceMemory<double> *, int> impl;
   3102   return impl(this, &blas::BlasSupport::DoBlasSymv, uplo, n, alpha, a, lda, x,
   3103               incx, beta, y, incy);
   3104 }
   3105 
   3106 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, float alpha,
   3107                             const DeviceMemory<float> &x, int incx,
   3108                             DeviceMemory<float> *a, int lda) {
   3109   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3110             PARAM(a), PARAM(lda));
   3111 
   3112   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
   3113                int, DeviceMemory<float> *, int> impl;
   3114   return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
   3115               lda);
   3116 }
   3117 
   3118 Stream &Stream::ThenBlasSyr(blas::UpperLower uplo, uint64 n, double alpha,
   3119                             const DeviceMemory<double> &x, int incx,
   3120                             DeviceMemory<double> *a, int lda) {
   3121   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3122             PARAM(a), PARAM(lda));
   3123 
   3124   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
   3125                int, DeviceMemory<double> *, int> impl;
   3126   return impl(this, &blas::BlasSupport::DoBlasSyr, uplo, n, alpha, x, incx, a,
   3127               lda);
   3128 }
   3129 
   3130 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, float alpha,
   3131                              const DeviceMemory<float> &x, int incx,
   3132                              const DeviceMemory<float> &y, int incy,
   3133                              DeviceMemory<float> *a, int lda) {
   3134   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3135             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
   3136 
   3137   ThenBlasImpl<blas::UpperLower, uint64, float, const DeviceMemory<float> &,
   3138                int, const DeviceMemory<float> &, int, DeviceMemory<float> *,
   3139                int> impl;
   3140   return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
   3141               incy, a, lda);
   3142 }
   3143 
   3144 Stream &Stream::ThenBlasSyr2(blas::UpperLower uplo, uint64 n, double alpha,
   3145                              const DeviceMemory<double> &x, int incx,
   3146                              const DeviceMemory<double> &y, int incy,
   3147                              DeviceMemory<double> *a, int lda) {
   3148   VLOG_CALL(PARAM(uplo), PARAM(n), PARAM(alpha), PARAM(x), PARAM(incx),
   3149             PARAM(y), PARAM(incy), PARAM(a), PARAM(lda));
   3150 
   3151   ThenBlasImpl<blas::UpperLower, uint64, double, const DeviceMemory<double> &,
   3152                int, const DeviceMemory<double> &, int, DeviceMemory<double> *,
   3153                int> impl;
   3154   return impl(this, &blas::BlasSupport::DoBlasSyr2, uplo, n, alpha, x, incx, y,
   3155               incy, a, lda);
   3156 }
   3157 
   3158 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
   3159                              blas::Diagonal diag, uint64 n, uint64 k,
   3160                              const DeviceMemory<float> &a, int lda,
   3161                              DeviceMemory<float> *x, int incx) {
   3162   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3163             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3164 
   3165   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3166                uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
   3167                int> impl;
   3168   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
   3169               lda, x, incx);
   3170 }
   3171 
   3172 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
   3173                              blas::Diagonal diag, uint64 n, uint64 k,
   3174                              const DeviceMemory<double> &a, int lda,
   3175                              DeviceMemory<double> *x, int incx) {
   3176   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3177             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3178 
   3179   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3180                uint64, const DeviceMemory<double> &, int,
   3181                DeviceMemory<double> *, int> impl;
   3182   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
   3183               lda, x, incx);
   3184 }
   3185 
   3186 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
   3187                              blas::Diagonal diag, uint64 n, uint64 k,
   3188                              const DeviceMemory<std::complex<float>> &a,
   3189                              int lda, DeviceMemory<std::complex<float>> *x,
   3190                              int incx) {
   3191   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3192             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3193 
   3194   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3195                uint64, const DeviceMemory<std::complex<float>> &, int,
   3196                DeviceMemory<std::complex<float>> *, int> impl;
   3197   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
   3198               lda, x, incx);
   3199 }
   3200 
   3201 Stream &Stream::ThenBlasTbmv(blas::UpperLower uplo, blas::Transpose trans,
   3202                              blas::Diagonal diag, uint64 n, uint64 k,
   3203                              const DeviceMemory<std::complex<double>> &a,
   3204                              int lda, DeviceMemory<std::complex<double>> *x,
   3205                              int incx) {
   3206   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3207             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3208 
   3209   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3210                uint64, const DeviceMemory<std::complex<double>> &, int,
   3211                DeviceMemory<std::complex<double>> *, int> impl;
   3212   return impl(this, &blas::BlasSupport::DoBlasTbmv, uplo, trans, diag, n, k, a,
   3213               lda, x, incx);
   3214 }
   3215 
   3216 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
   3217                              blas::Diagonal diag, uint64 n, uint64 k,
   3218                              const DeviceMemory<float> &a, int lda,
   3219                              DeviceMemory<float> *x, int incx) {
   3220   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3221             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3222 
   3223   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3224                uint64, const DeviceMemory<float> &, int, DeviceMemory<float> *,
   3225                int> impl;
   3226   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
   3227               lda, x, incx);
   3228 }
   3229 
   3230 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
   3231                              blas::Diagonal diag, uint64 n, uint64 k,
   3232                              const DeviceMemory<double> &a, int lda,
   3233                              DeviceMemory<double> *x, int incx) {
   3234   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3235             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3236 
   3237   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3238                uint64, const DeviceMemory<double> &, int,
   3239                DeviceMemory<double> *, int> impl;
   3240   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
   3241               lda, x, incx);
   3242 }
   3243 
   3244 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
   3245                              blas::Diagonal diag, uint64 n, uint64 k,
   3246                              const DeviceMemory<std::complex<float>> &a,
   3247                              int lda, DeviceMemory<std::complex<float>> *x,
   3248                              int incx) {
   3249   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3250             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3251 
   3252   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3253                uint64, const DeviceMemory<std::complex<float>> &, int,
   3254                DeviceMemory<std::complex<float>> *, int> impl;
   3255   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
   3256               lda, x, incx);
   3257 }
   3258 
   3259 Stream &Stream::ThenBlasTbsv(blas::UpperLower uplo, blas::Transpose trans,
   3260                              blas::Diagonal diag, uint64 n, uint64 k,
   3261                              const DeviceMemory<std::complex<double>> &a,
   3262                              int lda, DeviceMemory<std::complex<double>> *x,
   3263                              int incx) {
   3264   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(k),
   3265             PARAM(a), PARAM(lda), PARAM(x), PARAM(incx));
   3266 
   3267   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3268                uint64, const DeviceMemory<std::complex<double>> &, int,
   3269                DeviceMemory<std::complex<double>> *, int> impl;
   3270   return impl(this, &blas::BlasSupport::DoBlasTbsv, uplo, trans, diag, n, k, a,
   3271               lda, x, incx);
   3272 }
   3273 
   3274 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
   3275                              blas::Diagonal diag, uint64 n,
   3276                              const DeviceMemory<float> &ap,
   3277                              DeviceMemory<float> *x, int incx) {
   3278   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3279             PARAM(x), PARAM(incx));
   3280 
   3281   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3282                const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
   3283   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
   3284               incx);
   3285 }
   3286 
   3287 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
   3288                              blas::Diagonal diag, uint64 n,
   3289                              const DeviceMemory<double> &ap,
   3290                              DeviceMemory<double> *x, int incx) {
   3291   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3292             PARAM(x), PARAM(incx));
   3293 
   3294   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3295                const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
   3296   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
   3297               incx);
   3298 }
   3299 
   3300 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
   3301                              blas::Diagonal diag, uint64 n,
   3302                              const DeviceMemory<std::complex<float>> &ap,
   3303                              DeviceMemory<std::complex<float>> *x, int incx) {
   3304   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3305             PARAM(x), PARAM(incx));
   3306 
   3307   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3308                const DeviceMemory<std::complex<float>> &,
   3309                DeviceMemory<std::complex<float>> *, int> impl;
   3310   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
   3311               incx);
   3312 }
   3313 
   3314 Stream &Stream::ThenBlasTpmv(blas::UpperLower uplo, blas::Transpose trans,
   3315                              blas::Diagonal diag, uint64 n,
   3316                              const DeviceMemory<std::complex<double>> &ap,
   3317                              DeviceMemory<std::complex<double>> *x, int incx) {
   3318   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3319             PARAM(x), PARAM(incx));
   3320 
   3321   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3322                const DeviceMemory<std::complex<double>> &,
   3323                DeviceMemory<std::complex<double>> *, int> impl;
   3324   return impl(this, &blas::BlasSupport::DoBlasTpmv, uplo, trans, diag, n, ap, x,
   3325               incx);
   3326 }
   3327 
   3328 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
   3329                              blas::Diagonal diag, uint64 n,
   3330                              const DeviceMemory<float> &ap,
   3331                              DeviceMemory<float> *x, int incx) {
   3332   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3333             PARAM(x), PARAM(incx));
   3334 
   3335   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3336                const DeviceMemory<float> &, DeviceMemory<float> *, int> impl;
   3337   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
   3338               incx);
   3339 }
   3340 
   3341 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
   3342                              blas::Diagonal diag, uint64 n,
   3343                              const DeviceMemory<double> &ap,
   3344                              DeviceMemory<double> *x, int incx) {
   3345   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3346             PARAM(x), PARAM(incx));
   3347 
   3348   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3349                const DeviceMemory<double> &, DeviceMemory<double> *, int> impl;
   3350   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
   3351               incx);
   3352 }
   3353 
   3354 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
   3355                              blas::Diagonal diag, uint64 n,
   3356                              const DeviceMemory<std::complex<float>> &ap,
   3357                              DeviceMemory<std::complex<float>> *x, int incx) {
   3358   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3359             PARAM(x), PARAM(incx));
   3360 
   3361   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3362                const DeviceMemory<std::complex<float>> &,
   3363                DeviceMemory<std::complex<float>> *, int> impl;
   3364   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
   3365               incx);
   3366 }
   3367 
   3368 Stream &Stream::ThenBlasTpsv(blas::UpperLower uplo, blas::Transpose trans,
   3369                              blas::Diagonal diag, uint64 n,
   3370                              const DeviceMemory<std::complex<double>> &ap,
   3371                              DeviceMemory<std::complex<double>> *x, int incx) {
   3372   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(ap),
   3373             PARAM(x), PARAM(incx));
   3374 
   3375   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3376                const DeviceMemory<std::complex<double>> &,
   3377                DeviceMemory<std::complex<double>> *, int> impl;
   3378   return impl(this, &blas::BlasSupport::DoBlasTpsv, uplo, trans, diag, n, ap, x,
   3379               incx);
   3380 }
   3381 
   3382 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
   3383                              blas::Diagonal diag, uint64 n,
   3384                              const DeviceMemory<float> &a, int lda,
   3385                              DeviceMemory<float> *x, int incx) {
   3386   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3387             PARAM(lda), PARAM(x), PARAM(incx));
   3388 
   3389   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3390                const DeviceMemory<float> &, int, DeviceMemory<float> *,
   3391                int> impl;
   3392   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
   3393               lda, x, incx);
   3394 }
   3395 
   3396 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
   3397                              blas::Diagonal diag, uint64 n,
   3398                              const DeviceMemory<double> &a, int lda,
   3399                              DeviceMemory<double> *x, int incx) {
   3400   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3401             PARAM(lda), PARAM(x), PARAM(incx));
   3402 
   3403   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3404                const DeviceMemory<double> &, int, DeviceMemory<double> *,
   3405                int> impl;
   3406   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
   3407               lda, x, incx);
   3408 }
   3409 
   3410 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
   3411                              blas::Diagonal diag, uint64 n,
   3412                              const DeviceMemory<std::complex<float>> &a,
   3413                              int lda, DeviceMemory<std::complex<float>> *x,
   3414                              int incx) {
   3415   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3416             PARAM(lda), PARAM(x), PARAM(incx));
   3417 
   3418   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3419                const DeviceMemory<std::complex<float>> &, int,
   3420                DeviceMemory<std::complex<float>> *, int> impl;
   3421   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
   3422               lda, x, incx);
   3423 }
   3424 
   3425 Stream &Stream::ThenBlasTrmv(blas::UpperLower uplo, blas::Transpose trans,
   3426                              blas::Diagonal diag, uint64 n,
   3427                              const DeviceMemory<std::complex<double>> &a,
   3428                              int lda, DeviceMemory<std::complex<double>> *x,
   3429                              int incx) {
   3430   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3431             PARAM(lda), PARAM(x), PARAM(incx));
   3432 
   3433   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3434                const DeviceMemory<std::complex<double>> &, int,
   3435                DeviceMemory<std::complex<double>> *, int> impl;
   3436   return impl(this, &blas::BlasSupport::DoBlasTrmv, uplo, trans, diag, n, a,
   3437               lda, x, incx);
   3438 }
   3439 
   3440 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
   3441                              blas::Diagonal diag, uint64 n,
   3442                              const DeviceMemory<float> &a, int lda,
   3443                              DeviceMemory<float> *x, int incx) {
   3444   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3445             PARAM(lda), PARAM(x), PARAM(incx));
   3446 
   3447   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3448                const DeviceMemory<float> &, int, DeviceMemory<float> *,
   3449                int> impl;
   3450   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
   3451               lda, x, incx);
   3452 }
   3453 
   3454 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
   3455                              blas::Diagonal diag, uint64 n,
   3456                              const DeviceMemory<double> &a, int lda,
   3457                              DeviceMemory<double> *x, int incx) {
   3458   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3459             PARAM(lda), PARAM(x), PARAM(incx));
   3460 
   3461   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3462                const DeviceMemory<double> &, int, DeviceMemory<double> *,
   3463                int> impl;
   3464   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
   3465               lda, x, incx);
   3466 }
   3467 
   3468 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
   3469                              blas::Diagonal diag, uint64 n,
   3470                              const DeviceMemory<std::complex<float>> &a,
   3471                              int lda, DeviceMemory<std::complex<float>> *x,
   3472                              int incx) {
   3473   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3474             PARAM(lda), PARAM(x), PARAM(incx));
   3475 
   3476   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3477                const DeviceMemory<std::complex<float>> &, int,
   3478                DeviceMemory<std::complex<float>> *, int> impl;
   3479   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
   3480               lda, x, incx);
   3481 }
   3482 
   3483 Stream &Stream::ThenBlasTrsv(blas::UpperLower uplo, blas::Transpose trans,
   3484                              blas::Diagonal diag, uint64 n,
   3485                              const DeviceMemory<std::complex<double>> &a,
   3486                              int lda, DeviceMemory<std::complex<double>> *x,
   3487                              int incx) {
   3488   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(diag), PARAM(n), PARAM(a),
   3489             PARAM(lda), PARAM(x), PARAM(incx));
   3490 
   3491   ThenBlasImpl<blas::UpperLower, blas::Transpose, blas::Diagonal, uint64,
   3492                const DeviceMemory<std::complex<double>> &, int,
   3493                DeviceMemory<std::complex<double>> *, int> impl;
   3494   return impl(this, &blas::BlasSupport::DoBlasTrsv, uplo, trans, diag, n, a,
   3495               lda, x, incx);
   3496 }
   3497 
   3498 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
   3499                              uint64 m, uint64 n, uint64 k, float alpha,
   3500                              const DeviceMemory<Eigen::half> &a, int lda,
   3501                              const DeviceMemory<Eigen::half> &b, int ldb,
   3502                              float beta,
   3503                              DeviceMemory<Eigen::half> *c, int ldc) {
   3504   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3505             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3506             PARAM(beta), PARAM(c), PARAM(ldc));
   3507 
   3508   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
   3509                const DeviceMemory<Eigen::half> &, int,
   3510                const DeviceMemory<Eigen::half> &, int,
   3511                float, DeviceMemory<Eigen::half> *, int> impl;
   3512   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
   3513               alpha, a, lda, b, ldb, beta, c, ldc);
   3514 }
   3515 
   3516 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
   3517                              uint64 m, uint64 n, uint64 k, float alpha,
   3518                              const DeviceMemory<float> &a, int lda,
   3519                              const DeviceMemory<float> &b, int ldb, float beta,
   3520                              DeviceMemory<float> *c, int ldc) {
   3521   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3522             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3523             PARAM(beta), PARAM(c), PARAM(ldc));
   3524 
   3525   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
   3526                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
   3527                int, float, DeviceMemory<float> *, int> impl;
   3528   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
   3529               alpha, a, lda, b, ldb, beta, c, ldc);
   3530 }
   3531 
   3532 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
   3533                              uint64 m, uint64 n, uint64 k, double alpha,
   3534                              const DeviceMemory<double> &a, int lda,
   3535                              const DeviceMemory<double> &b, int ldb,
   3536                              double beta, DeviceMemory<double> *c, int ldc) {
   3537   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3538             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3539             PARAM(beta), PARAM(c), PARAM(ldc));
   3540 
   3541   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
   3542                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
   3543                int, double, DeviceMemory<double> *, int> impl;
   3544   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
   3545               alpha, a, lda, b, ldb, beta, c, ldc);
   3546 }
   3547 
   3548 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
   3549                              uint64 m, uint64 n, uint64 k,
   3550                              std::complex<float> alpha,
   3551                              const DeviceMemory<std::complex<float>> &a,
   3552                              int lda,
   3553                              const DeviceMemory<std::complex<float>> &b,
   3554                              int ldb, std::complex<float> beta,
   3555                              DeviceMemory<std::complex<float>> *c, int ldc) {
   3556   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3557             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3558             PARAM(beta), PARAM(c), PARAM(ldc));
   3559 
   3560   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   3561                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   3562                int, const DeviceMemory<std::complex<float>> &, int,
   3563                std::complex<float>, DeviceMemory<std::complex<float>> *,
   3564                int> impl;
   3565   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
   3566               alpha, a, lda, b, ldb, beta, c, ldc);
   3567 }
   3568 
   3569 Stream &Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
   3570                              uint64 m, uint64 n, uint64 k,
   3571                              std::complex<double> alpha,
   3572                              const DeviceMemory<std::complex<double>> &a,
   3573                              int lda,
   3574                              const DeviceMemory<std::complex<double>> &b,
   3575                              int ldb, std::complex<double> beta,
   3576                              DeviceMemory<std::complex<double>> *c, int ldc) {
   3577   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3578             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3579             PARAM(beta), PARAM(c), PARAM(ldc));
   3580 
   3581   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   3582                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   3583                int, const DeviceMemory<std::complex<double>> &, int,
   3584                std::complex<double>, DeviceMemory<std::complex<double>> *,
   3585                int> impl;
   3586   return impl(this, &blas::BlasSupport::DoBlasGemm, transa, transb, m, n, k,
   3587               alpha, a, lda, b, ldb, beta, c, ldc);
   3588 }
   3589 
   3590 namespace {
   3591 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
   3592 // blas::ProfileResult*.  This functor doesn't put the stream into an error
   3593 // state if the op fails and the profile result is non-null.  Instead, the
   3594 // error-ness is returned in the profile result itself.
   3595 template <typename... Args>
   3596 struct ThenBlasWithProfileImpl {
   3597   Stream &operator()(Stream *stream,
   3598                      bool (blas::BlasSupport::*blas_func)(
   3599                          Stream *, Args..., blas::ProfileResult *),
   3600                      Args... args, blas::ProfileResult *profile_result) {
   3601     ThenBlasImpl<Args..., blas::ProfileResult *> Runner;
   3602     bool record_error = profile_result == nullptr;
   3603     return Runner.Run(stream, blas_func, record_error, args..., profile_result);
   3604   }
   3605 };
   3606 }  // anonymous namespace
   3607 
   3608 Stream &Stream::ThenBlasGemvWithProfiling(
   3609     blas::Transpose trans, uint64 m, uint64 n, float alpha,
   3610     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &x,
   3611     int incx, float beta, DeviceMemory<float> *y, int incy,
   3612     blas::ProfileResult *output_profile_result) {
   3613   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   3614             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   3615             PARAM(incy));
   3616 
   3617   ThenBlasWithProfileImpl<
   3618       blas::Transpose, uint64, uint64, float, const DeviceMemory<float> &, int,
   3619       const DeviceMemory<float> &, int, float, DeviceMemory<float> *, int>
   3620       impl;
   3621   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
   3622               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
   3623 }
   3624 
   3625 Stream &Stream::ThenBlasGemvWithProfiling(
   3626     blas::Transpose trans, uint64 m, uint64 n, double alpha,
   3627     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &x,
   3628     int incx, double beta, DeviceMemory<double> *y, int incy,
   3629     blas::ProfileResult *output_profile_result) {
   3630   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   3631             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   3632             PARAM(incy));
   3633 
   3634   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, double,
   3635                           const DeviceMemory<double> &, int,
   3636                           const DeviceMemory<double> &, int, double,
   3637                           DeviceMemory<double> *, int>
   3638       impl;
   3639   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
   3640               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
   3641 }
   3642 
   3643 Stream &Stream::ThenBlasGemvWithProfiling(
   3644     blas::Transpose trans, uint64 m, uint64 n, std::complex<float> alpha,
   3645     const DeviceMemory<std::complex<float>> &a, int lda,
   3646     const DeviceMemory<std::complex<float>> &x, int incx,
   3647     std::complex<float> beta, DeviceMemory<std::complex<float>> *y, int incy,
   3648     blas::ProfileResult *output_profile_result) {
   3649   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   3650             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   3651             PARAM(incy));
   3652 
   3653   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<float>,
   3654                           const DeviceMemory<std::complex<float>> &, int,
   3655                           const DeviceMemory<std::complex<float>> &, int,
   3656                           std::complex<float>,
   3657                           DeviceMemory<std::complex<float>> *, int>
   3658       impl;
   3659   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
   3660               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
   3661 }
   3662 
   3663 Stream &Stream::ThenBlasGemvWithProfiling(
   3664     blas::Transpose trans, uint64 m, uint64 n, std::complex<double> alpha,
   3665     const DeviceMemory<std::complex<double>> &a, int lda,
   3666     const DeviceMemory<std::complex<double>> &x, int incx,
   3667     std::complex<double> beta, DeviceMemory<std::complex<double>> *y, int incy,
   3668     blas::ProfileResult *output_profile_result) {
   3669   VLOG_CALL(PARAM(trans), PARAM(m), PARAM(n), PARAM(alpha), PARAM(a),
   3670             PARAM(lda), PARAM(x), PARAM(incx), PARAM(beta), PARAM(y),
   3671             PARAM(incy));
   3672 
   3673   ThenBlasWithProfileImpl<blas::Transpose, uint64, uint64, std::complex<double>,
   3674                           const DeviceMemory<std::complex<double>> &, int,
   3675                           const DeviceMemory<std::complex<double>> &, int,
   3676                           std::complex<double>,
   3677                           DeviceMemory<std::complex<double>> *, int>
   3678       impl;
   3679   return impl(this, &blas::BlasSupport::DoBlasGemvWithProfiling, trans, m, n,
   3680               alpha, a, lda, x, incx, beta, y, incy, output_profile_result);
   3681 }
   3682 
   3683 Stream &Stream::ThenBlasGemmWithProfiling(
   3684     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3685     uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
   3686     const DeviceMemory<Eigen::half> &b, int ldb, float beta,
   3687     DeviceMemory<Eigen::half> *c, int ldc,
   3688     blas::ProfileResult *output_profile_result) {
   3689   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3690             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3691             PARAM(beta), PARAM(c), PARAM(ldc));
   3692 
   3693   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
   3694                           uint64, float, const DeviceMemory<Eigen::half> &, int,
   3695                           const DeviceMemory<Eigen::half> &, int, float,
   3696                           DeviceMemory<Eigen::half> *, int>
   3697       impl;
   3698   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
   3699               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   3700               output_profile_result);
   3701 }
   3702 
   3703 Stream &Stream::ThenBlasGemmWithProfiling(
   3704     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3705     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
   3706     const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
   3707     int ldc, blas::ProfileResult *output_profile_result) {
   3708   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3709             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3710             PARAM(beta), PARAM(c), PARAM(ldc));
   3711 
   3712   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
   3713                           uint64, float, const DeviceMemory<float> &, int,
   3714                           const DeviceMemory<float> &, int, float,
   3715                           DeviceMemory<float> *, int>
   3716       impl;
   3717   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
   3718               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   3719               output_profile_result);
   3720 }
   3721 
   3722 Stream &Stream::ThenBlasGemmWithProfiling(
   3723     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3724     uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
   3725     const DeviceMemory<double> &b, int ldb, double beta,
   3726     DeviceMemory<double> *c, int ldc,
   3727     blas::ProfileResult *output_profile_result) {
   3728   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3729             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3730             PARAM(beta), PARAM(c), PARAM(ldc));
   3731 
   3732   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
   3733                           uint64, double, const DeviceMemory<double> &, int,
   3734                           const DeviceMemory<double> &, int, double,
   3735                           DeviceMemory<double> *, int>
   3736       impl;
   3737   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
   3738               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   3739               output_profile_result);
   3740 }
   3741 
   3742 Stream &Stream::ThenBlasGemmWithProfiling(
   3743     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3744     uint64 k, std::complex<float> alpha,
   3745     const DeviceMemory<std::complex<float>> &a, int lda,
   3746     const DeviceMemory<std::complex<float>> &b, int ldb,
   3747     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
   3748     blas::ProfileResult *output_profile_result) {
   3749   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3750             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3751             PARAM(beta), PARAM(c), PARAM(ldc));
   3752 
   3753   ThenBlasWithProfileImpl<
   3754       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   3755       std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
   3756       const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
   3757       DeviceMemory<std::complex<float>> *, int>
   3758       impl;
   3759   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
   3760               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   3761               output_profile_result);
   3762 }
   3763 
   3764 Stream &Stream::ThenBlasGemmWithProfiling(
   3765     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3766     uint64 k, std::complex<double> alpha,
   3767     const DeviceMemory<std::complex<double>> &a, int lda,
   3768     const DeviceMemory<std::complex<double>> &b, int ldb,
   3769     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
   3770     blas::ProfileResult *output_profile_result) {
   3771   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3772             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3773             PARAM(beta), PARAM(c), PARAM(ldc));
   3774 
   3775   ThenBlasWithProfileImpl<
   3776       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   3777       std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
   3778       const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
   3779       DeviceMemory<std::complex<double>> *, int>
   3780       impl;
   3781   return impl(this, &blas::BlasSupport::DoBlasGemmWithProfiling, transa, transb,
   3782               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
   3783               output_profile_result);
   3784 }
   3785 
   3786 Stream &Stream::ThenBlasGemmWithAlgorithm(
   3787     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3788     uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
   3789     const DeviceMemory<Eigen::half> &a, int lda,
   3790     const DeviceMemory<Eigen::half> &b, int ldb,
   3791     const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
   3792     int ldc, blas::ComputationType computation_type,
   3793     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
   3794   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3795             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3796             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
   3797             PARAM(algorithm));
   3798 
   3799   ThenBlasWithProfileImpl<
   3800       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   3801       const HostOrDeviceScalar<Eigen::half> &,
   3802       const DeviceMemory<Eigen::half> &, int, const DeviceMemory<Eigen::half> &,
   3803       int, const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
   3804       int, blas::ComputationType, blas::AlgorithmType>
   3805       impl;
   3806   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
   3807               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
   3808               algorithm, output_profile_result);
   3809 }
   3810 
   3811 Stream &Stream::ThenBlasGemmWithAlgorithm(
   3812     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3813     uint64 k, const HostOrDeviceScalar<int> &alpha, const DeviceMemory<int8> &a,
   3814     int lda, const DeviceMemory<int8> &b, int ldb,
   3815     const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, int ldc,
   3816     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
   3817     blas::ProfileResult *output_profile_result) {
   3818   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3819             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3820             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
   3821             PARAM(algorithm));
   3822 
   3823   ThenBlasWithProfileImpl<
   3824       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   3825       const HostOrDeviceScalar<int> &, const DeviceMemory<int8> &, int,
   3826       const DeviceMemory<int8> &, int, const HostOrDeviceScalar<int> &,
   3827       DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
   3828       impl;
   3829   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
   3830               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
   3831               algorithm, output_profile_result);
   3832 }
   3833 
   3834 Stream &Stream::ThenBlasGemmWithAlgorithm(
   3835     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3836     uint64 k, const HostOrDeviceScalar<float> &alpha,
   3837     const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
   3838     int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
   3839     int ldc, blas::ComputationType computation_type,
   3840     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
   3841   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3842             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3843             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
   3844             PARAM(algorithm));
   3845 
   3846   ThenBlasWithProfileImpl<
   3847       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   3848       const HostOrDeviceScalar<float> &, const DeviceMemory<float> &, int,
   3849       const DeviceMemory<float> &, int, const HostOrDeviceScalar<float> &,
   3850       DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType>
   3851       impl;
   3852   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
   3853               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
   3854               algorithm, output_profile_result);
   3855 }
   3856 
   3857 Stream &Stream::ThenBlasGemmWithAlgorithm(
   3858     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3859     uint64 k, const HostOrDeviceScalar<double> &alpha,
   3860     const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
   3861     int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
   3862     int ldc, blas::ComputationType computation_type,
   3863     blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
   3864   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3865             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3866             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
   3867             PARAM(algorithm));
   3868 
   3869   ThenBlasWithProfileImpl<
   3870       blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   3871       const HostOrDeviceScalar<double> &, const DeviceMemory<double> &, int,
   3872       const DeviceMemory<double> &, int, const HostOrDeviceScalar<double> &,
   3873       DeviceMemory<double> *, int, blas::ComputationType, blas::AlgorithmType>
   3874       impl;
   3875   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
   3876               m, n, k, HostOrDeviceScalar<double>(alpha), a, lda, b, ldb,
   3877               HostOrDeviceScalar<double>(beta), c, ldc, computation_type,
   3878               algorithm, output_profile_result);
   3879 }
   3880 
   3881 Stream &Stream::ThenBlasGemmWithAlgorithm(
   3882     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3883     uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
   3884     const DeviceMemory<std::complex<float>> &a, int lda,
   3885     const DeviceMemory<std::complex<float>> &b, int ldb,
   3886     const HostOrDeviceScalar<std::complex<float>> &beta,
   3887     DeviceMemory<std::complex<float>> *c, int ldc,
   3888     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
   3889     blas::ProfileResult *output_profile_result) {
   3890   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3891             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3892             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
   3893             PARAM(algorithm));
   3894 
   3895   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
   3896                           uint64,
   3897                           const HostOrDeviceScalar<std::complex<float>> &,
   3898                           const DeviceMemory<std::complex<float>> &, int,
   3899                           const DeviceMemory<std::complex<float>> &, int,
   3900                           const HostOrDeviceScalar<std::complex<float>> &,
   3901                           DeviceMemory<std::complex<float>> *, int,
   3902                           blas::ComputationType, blas::AlgorithmType>
   3903       impl;
   3904   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
   3905               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
   3906               algorithm, output_profile_result);
   3907 }
   3908 
   3909 Stream &Stream::ThenBlasGemmWithAlgorithm(
   3910     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   3911     uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
   3912     const DeviceMemory<std::complex<double>> &a, int lda,
   3913     const DeviceMemory<std::complex<double>> &b, int ldb,
   3914     const HostOrDeviceScalar<std::complex<double>> &beta,
   3915     DeviceMemory<std::complex<double>> *c, int ldc,
   3916     blas::ComputationType computation_type, blas::AlgorithmType algorithm,
   3917     blas::ProfileResult *output_profile_result) {
   3918   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   3919             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   3920             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
   3921             PARAM(algorithm));
   3922 
   3923   ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
   3924                           uint64,
   3925                           const HostOrDeviceScalar<std::complex<double>> &,
   3926                           const DeviceMemory<std::complex<double>> &, int,
   3927                           const DeviceMemory<std::complex<double>> &, int,
   3928                           const HostOrDeviceScalar<std::complex<double>> &,
   3929                           DeviceMemory<std::complex<double>> *, int,
   3930                           blas::ComputationType, blas::AlgorithmType>
   3931       impl;
   3932   return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
   3933               m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
   3934               algorithm, output_profile_result);
   3935 }
   3936 
   3937 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
   3938                              uint64 n, std::complex<float> alpha,
   3939                              const DeviceMemory<std::complex<float>> &a,
   3940                              int lda,
   3941                              const DeviceMemory<std::complex<float>> &b,
   3942                              int ldb, std::complex<float> beta,
   3943                              DeviceMemory<std::complex<float>> *c, int ldc) {
   3944   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
   3945             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   3946             PARAM(ldc));
   3947 
   3948   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
   3949                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   3950                int, const DeviceMemory<std::complex<float>> &, int,
   3951                std::complex<float>, DeviceMemory<std::complex<float>> *,
   3952                int> impl;
   3953   return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
   3954               lda, b, ldb, beta, c, ldc);
   3955 }
   3956 
   3957 Stream &Stream::ThenBlasHemm(blas::Side side, blas::UpperLower uplo, uint64 m,
   3958                              uint64 n, std::complex<double> alpha,
   3959                              const DeviceMemory<std::complex<double>> &a,
   3960                              int lda,
   3961                              const DeviceMemory<std::complex<double>> &b,
   3962                              int ldb, std::complex<double> beta,
   3963                              DeviceMemory<std::complex<double>> *c, int ldc) {
   3964   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
   3965             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   3966             PARAM(ldc));
   3967 
   3968   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
   3969                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   3970                int, const DeviceMemory<std::complex<double>> &, int,
   3971                std::complex<double>, DeviceMemory<std::complex<double>> *,
   3972                int> impl;
   3973   return impl(this, &blas::BlasSupport::DoBlasHemm, side, uplo, m, n, alpha, a,
   3974               lda, b, ldb, beta, c, ldc);
   3975 }
   3976 
   3977 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
   3978                              uint64 n, uint64 k, float alpha,
   3979                              const DeviceMemory<std::complex<float>> &a,
   3980                              int lda, float beta,
   3981                              DeviceMemory<std::complex<float>> *c, int ldc) {
   3982   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   3983             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
   3984 
   3985   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
   3986                const DeviceMemory<std::complex<float>> &, int, float,
   3987                DeviceMemory<std::complex<float>> *, int> impl;
   3988   return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
   3989               lda, beta, c, ldc);
   3990 }
   3991 
   3992 Stream &Stream::ThenBlasHerk(blas::UpperLower uplo, blas::Transpose trans,
   3993                              uint64 n, uint64 k, double alpha,
   3994                              const DeviceMemory<std::complex<double>> &a,
   3995                              int lda, double beta,
   3996                              DeviceMemory<std::complex<double>> *c, int ldc) {
   3997   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   3998             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
   3999 
   4000   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
   4001                const DeviceMemory<std::complex<double>> &, int, double,
   4002                DeviceMemory<std::complex<double>> *, int> impl;
   4003   return impl(this, &blas::BlasSupport::DoBlasHerk, uplo, trans, n, k, alpha, a,
   4004               lda, beta, c, ldc);
   4005 }
   4006 
   4007 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
   4008                               uint64 n, uint64 k, std::complex<float> alpha,
   4009                               const DeviceMemory<std::complex<float>> &a,
   4010                               int lda,
   4011                               const DeviceMemory<std::complex<float>> &b,
   4012                               int ldb, float beta,
   4013                               DeviceMemory<std::complex<float>> *c, int ldc) {
   4014   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4015             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4016             PARAM(ldc));
   4017 
   4018   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
   4019                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   4020                int, const DeviceMemory<std::complex<float>> &, int, float,
   4021                DeviceMemory<std::complex<float>> *, int> impl;
   4022   return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
   4023               a, lda, b, ldb, beta, c, ldc);
   4024 }
   4025 
   4026 Stream &Stream::ThenBlasHer2k(blas::UpperLower uplo, blas::Transpose trans,
   4027                               uint64 n, uint64 k, std::complex<double> alpha,
   4028                               const DeviceMemory<std::complex<double>> &a,
   4029                               int lda,
   4030                               const DeviceMemory<std::complex<double>> &b,
   4031                               int ldb, double beta,
   4032                               DeviceMemory<std::complex<double>> *c, int ldc) {
   4033   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4034             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4035             PARAM(ldc));
   4036 
   4037   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
   4038                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   4039                int, const DeviceMemory<std::complex<double>> &, int, double,
   4040                DeviceMemory<std::complex<double>> *, int> impl;
   4041   return impl(this, &blas::BlasSupport::DoBlasHer2k, uplo, trans, n, k, alpha,
   4042               a, lda, b, ldb, beta, c, ldc);
   4043 }
   4044 
   4045 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
   4046                              uint64 n, float alpha,
   4047                              const DeviceMemory<float> &a, int lda,
   4048                              const DeviceMemory<float> &b, int ldb, float beta,
   4049                              DeviceMemory<float> *c, int ldc) {
   4050   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
   4051             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4052             PARAM(ldc));
   4053 
   4054   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, float,
   4055                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
   4056                int, float, DeviceMemory<float> *, int> impl;
   4057   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
   4058               lda, b, ldb, beta, c, ldc);
   4059 }
   4060 
   4061 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
   4062                              uint64 n, double alpha,
   4063                              const DeviceMemory<double> &a, int lda,
   4064                              const DeviceMemory<double> &b, int ldb,
   4065                              double beta, DeviceMemory<double> *c, int ldc) {
   4066   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
   4067             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4068             PARAM(ldc));
   4069 
   4070   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64, double,
   4071                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
   4072                int, double, DeviceMemory<double> *, int> impl;
   4073   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
   4074               lda, b, ldb, beta, c, ldc);
   4075 }
   4076 
   4077 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
   4078                              uint64 n, std::complex<float> alpha,
   4079                              const DeviceMemory<std::complex<float>> &a,
   4080                              int lda,
   4081                              const DeviceMemory<std::complex<float>> &b,
   4082                              int ldb, std::complex<float> beta,
   4083                              DeviceMemory<std::complex<float>> *c, int ldc) {
   4084   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
   4085             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4086             PARAM(ldc));
   4087 
   4088   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
   4089                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   4090                int, const DeviceMemory<std::complex<float>> &, int,
   4091                std::complex<float>, DeviceMemory<std::complex<float>> *,
   4092                int> impl;
   4093   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
   4094               lda, b, ldb, beta, c, ldc);
   4095 }
   4096 
   4097 Stream &Stream::ThenBlasSymm(blas::Side side, blas::UpperLower uplo, uint64 m,
   4098                              uint64 n, std::complex<double> alpha,
   4099                              const DeviceMemory<std::complex<double>> &a,
   4100                              int lda,
   4101                              const DeviceMemory<std::complex<double>> &b,
   4102                              int ldb, std::complex<double> beta,
   4103                              DeviceMemory<std::complex<double>> *c, int ldc) {
   4104   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(m), PARAM(n), PARAM(alpha),
   4105             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4106             PARAM(ldc));
   4107 
   4108   ThenBlasImpl<blas::Side, blas::UpperLower, uint64, uint64,
   4109                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   4110                int, const DeviceMemory<std::complex<double>> &, int,
   4111                std::complex<double>, DeviceMemory<std::complex<double>> *,
   4112                int> impl;
   4113   return impl(this, &blas::BlasSupport::DoBlasSymm, side, uplo, m, n, alpha, a,
   4114               lda, b, ldb, beta, c, ldc);
   4115 }
   4116 
   4117 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
   4118                              uint64 n, uint64 k, float alpha,
   4119                              const DeviceMemory<float> &a, int lda, float beta,
   4120                              DeviceMemory<float> *c, int ldc) {
   4121   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4122             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
   4123 
   4124   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
   4125                const DeviceMemory<float> &, int, float, DeviceMemory<float> *,
   4126                int> impl;
   4127   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
   4128               lda, beta, c, ldc);
   4129 }
   4130 
   4131 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
   4132                              uint64 n, uint64 k, double alpha,
   4133                              const DeviceMemory<double> &a, int lda,
   4134                              double beta, DeviceMemory<double> *c, int ldc) {
   4135   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4136             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
   4137 
   4138   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
   4139                const DeviceMemory<double> &, int, double,
   4140                DeviceMemory<double> *, int> impl;
   4141   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
   4142               lda, beta, c, ldc);
   4143 }
   4144 
   4145 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
   4146                              uint64 n, uint64 k, std::complex<float> alpha,
   4147                              const DeviceMemory<std::complex<float>> &a,
   4148                              int lda, std::complex<float> beta,
   4149                              DeviceMemory<std::complex<float>> *c, int ldc) {
   4150   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4151             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
   4152 
   4153   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
   4154                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   4155                int, std::complex<float>, DeviceMemory<std::complex<float>> *,
   4156                int> impl;
   4157   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
   4158               lda, beta, c, ldc);
   4159 }
   4160 
   4161 Stream &Stream::ThenBlasSyrk(blas::UpperLower uplo, blas::Transpose trans,
   4162                              uint64 n, uint64 k, std::complex<double> alpha,
   4163                              const DeviceMemory<std::complex<double>> &a,
   4164                              int lda, std::complex<double> beta,
   4165                              DeviceMemory<std::complex<double>> *c, int ldc) {
   4166   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4167             PARAM(a), PARAM(lda), PARAM(beta), PARAM(c), PARAM(ldc));
   4168 
   4169   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
   4170                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   4171                int, std::complex<double>, DeviceMemory<std::complex<double>> *,
   4172                int> impl;
   4173   return impl(this, &blas::BlasSupport::DoBlasSyrk, uplo, trans, n, k, alpha, a,
   4174               lda, beta, c, ldc);
   4175 }
   4176 
   4177 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
   4178                               uint64 n, uint64 k, float alpha,
   4179                               const DeviceMemory<float> &a, int lda,
   4180                               const DeviceMemory<float> &b, int ldb, float beta,
   4181                               DeviceMemory<float> *c, int ldc) {
   4182   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4183             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4184             PARAM(ldc));
   4185 
   4186   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, float,
   4187                const DeviceMemory<float> &, int, const DeviceMemory<float> &,
   4188                int, float, DeviceMemory<float> *, int> impl;
   4189   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
   4190               a, lda, b, ldb, beta, c, ldc);
   4191 }
   4192 
   4193 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
   4194                               uint64 n, uint64 k, double alpha,
   4195                               const DeviceMemory<double> &a, int lda,
   4196                               const DeviceMemory<double> &b, int ldb,
   4197                               double beta, DeviceMemory<double> *c, int ldc) {
   4198   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4199             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4200             PARAM(ldc));
   4201 
   4202   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64, double,
   4203                const DeviceMemory<double> &, int, const DeviceMemory<double> &,
   4204                int, double, DeviceMemory<double> *, int> impl;
   4205   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
   4206               a, lda, b, ldb, beta, c, ldc);
   4207 }
   4208 
   4209 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
   4210                               uint64 n, uint64 k, std::complex<float> alpha,
   4211                               const DeviceMemory<std::complex<float>> &a,
   4212                               int lda,
   4213                               const DeviceMemory<std::complex<float>> &b,
   4214                               int ldb, std::complex<float> beta,
   4215                               DeviceMemory<std::complex<float>> *c, int ldc) {
   4216   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4217             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4218             PARAM(ldc));
   4219 
   4220   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
   4221                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   4222                int, const DeviceMemory<std::complex<float>> &, int,
   4223                std::complex<float>, DeviceMemory<std::complex<float>> *,
   4224                int> impl;
   4225   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
   4226               a, lda, b, ldb, beta, c, ldc);
   4227 }
   4228 
   4229 Stream &Stream::ThenBlasSyr2k(blas::UpperLower uplo, blas::Transpose trans,
   4230                               uint64 n, uint64 k, std::complex<double> alpha,
   4231                               const DeviceMemory<std::complex<double>> &a,
   4232                               int lda,
   4233                               const DeviceMemory<std::complex<double>> &b,
   4234                               int ldb, std::complex<double> beta,
   4235                               DeviceMemory<std::complex<double>> *c, int ldc) {
   4236   VLOG_CALL(PARAM(uplo), PARAM(trans), PARAM(n), PARAM(k), PARAM(alpha),
   4237             PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c),
   4238             PARAM(ldc));
   4239 
   4240   ThenBlasImpl<blas::UpperLower, blas::Transpose, uint64, uint64,
   4241                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   4242                int, const DeviceMemory<std::complex<double>> &, int,
   4243                std::complex<double>, DeviceMemory<std::complex<double>> *,
   4244                int> impl;
   4245   return impl(this, &blas::BlasSupport::DoBlasSyr2k, uplo, trans, n, k, alpha,
   4246               a, lda, b, ldb, beta, c, ldc);
   4247 }
   4248 
   4249 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
   4250                              blas::Transpose transa, blas::Diagonal diag,
   4251                              uint64 m, uint64 n, float alpha,
   4252                              const DeviceMemory<float> &a, int lda,
   4253                              DeviceMemory<float> *b, int ldb) {
   4254   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4255             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4256 
   4257   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4258                uint64, uint64, float, const DeviceMemory<float> &, int,
   4259                DeviceMemory<float> *, int> impl;
   4260   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
   4261               n, alpha, a, lda, b, ldb);
   4262 }
   4263 
   4264 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
   4265                              blas::Transpose transa, blas::Diagonal diag,
   4266                              uint64 m, uint64 n, double alpha,
   4267                              const DeviceMemory<double> &a, int lda,
   4268                              DeviceMemory<double> *b, int ldb) {
   4269   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4270             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4271 
   4272   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4273                uint64, uint64, double, const DeviceMemory<double> &, int,
   4274                DeviceMemory<double> *, int> impl;
   4275   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
   4276               n, alpha, a, lda, b, ldb);
   4277 }
   4278 
   4279 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
   4280                              blas::Transpose transa, blas::Diagonal diag,
   4281                              uint64 m, uint64 n, std::complex<float> alpha,
   4282                              const DeviceMemory<std::complex<float>> &a,
   4283                              int lda, DeviceMemory<std::complex<float>> *b,
   4284                              int ldb) {
   4285   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4286             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4287 
   4288   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4289                uint64, uint64, std::complex<float>,
   4290                const DeviceMemory<std::complex<float>> &, int,
   4291                DeviceMemory<std::complex<float>> *, int> impl;
   4292   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
   4293               n, alpha, a, lda, b, ldb);
   4294 }
   4295 
   4296 Stream &Stream::ThenBlasTrmm(blas::Side side, blas::UpperLower uplo,
   4297                              blas::Transpose transa, blas::Diagonal diag,
   4298                              uint64 m, uint64 n, std::complex<double> alpha,
   4299                              const DeviceMemory<std::complex<double>> &a,
   4300                              int lda, DeviceMemory<std::complex<double>> *b,
   4301                              int ldb) {
   4302   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4303             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4304 
   4305   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4306                uint64, uint64, std::complex<double>,
   4307                const DeviceMemory<std::complex<double>> &, int,
   4308                DeviceMemory<std::complex<double>> *, int> impl;
   4309   return impl(this, &blas::BlasSupport::DoBlasTrmm, side, uplo, transa, diag, m,
   4310               n, alpha, a, lda, b, ldb);
   4311 }
   4312 
   4313 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
   4314                              blas::Transpose transa, blas::Diagonal diag,
   4315                              uint64 m, uint64 n, float alpha,
   4316                              const DeviceMemory<float> &a, int lda,
   4317                              DeviceMemory<float> *b, int ldb) {
   4318   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4319             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4320 
   4321   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4322                uint64, uint64, float, const DeviceMemory<float> &, int,
   4323                DeviceMemory<float> *, int> impl;
   4324   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
   4325               n, alpha, a, lda, b, ldb);
   4326 }
   4327 
   4328 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
   4329                              blas::Transpose transa, blas::Diagonal diag,
   4330                              uint64 m, uint64 n, double alpha,
   4331                              const DeviceMemory<double> &a, int lda,
   4332                              DeviceMemory<double> *b, int ldb) {
   4333   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4334             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4335 
   4336   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4337                uint64, uint64, double, const DeviceMemory<double> &, int,
   4338                DeviceMemory<double> *, int> impl;
   4339   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
   4340               n, alpha, a, lda, b, ldb);
   4341 }
   4342 
   4343 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
   4344                              blas::Transpose transa, blas::Diagonal diag,
   4345                              uint64 m, uint64 n, std::complex<float> alpha,
   4346                              const DeviceMemory<std::complex<float>> &a,
   4347                              int lda, DeviceMemory<std::complex<float>> *b,
   4348                              int ldb) {
   4349   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4350             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4351 
   4352   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4353                uint64, uint64, std::complex<float>,
   4354                const DeviceMemory<std::complex<float>> &, int,
   4355                DeviceMemory<std::complex<float>> *, int> impl;
   4356   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
   4357               n, alpha, a, lda, b, ldb);
   4358 }
   4359 
   4360 Stream &Stream::ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
   4361                              blas::Transpose transa, blas::Diagonal diag,
   4362                              uint64 m, uint64 n, std::complex<double> alpha,
   4363                              const DeviceMemory<std::complex<double>> &a,
   4364                              int lda, DeviceMemory<std::complex<double>> *b,
   4365                              int ldb) {
   4366   VLOG_CALL(PARAM(side), PARAM(uplo), PARAM(transa), PARAM(diag), PARAM(m),
   4367             PARAM(n), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb));
   4368 
   4369   ThenBlasImpl<blas::Side, blas::UpperLower, blas::Transpose, blas::Diagonal,
   4370                uint64, uint64, std::complex<double>,
   4371                const DeviceMemory<std::complex<double>> &, int,
   4372                DeviceMemory<std::complex<double>> *, int> impl;
   4373   return impl(this, &blas::BlasSupport::DoBlasTrsm, side, uplo, transa, diag, m,
   4374               n, alpha, a, lda, b, ldb);
   4375 }
   4376 
   4377 Stream &Stream::ThenBlasGemmBatched(
   4378     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4379     uint64 k, float alpha,
   4380     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
   4381     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
   4382     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
   4383     int batch_count) {
   4384   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
   4385                                         b, ldb, beta, c, ldc, batch_count,
   4386                                         /*scratch_allocator=*/nullptr);
   4387 }
   4388 
   4389 Stream &Stream::ThenBlasGemmBatchedWithScratch(
   4390     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4391     uint64 k, float alpha,
   4392     const port::ArraySlice<DeviceMemory<Eigen::half> *> &a, int lda,
   4393     const port::ArraySlice<DeviceMemory<Eigen::half> *> &b, int ldb, float beta,
   4394     const port::ArraySlice<DeviceMemory<Eigen::half> *> &c, int ldc,
   4395     int batch_count, ScratchAllocator *scratch_allocator) {
   4396   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   4397             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   4398             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
   4399 
   4400   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
   4401                const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
   4402                const port::ArraySlice<DeviceMemory<Eigen::half> *> &, int,
   4403                float, const port::ArraySlice<DeviceMemory<Eigen::half> *> &,
   4404                int, int, ScratchAllocator *>
   4405       impl;
   4406   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
   4407               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
   4408               scratch_allocator);
   4409 }
   4410 
   4411 Stream &Stream::ThenBlasGemmBatched(
   4412     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4413     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
   4414     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
   4415     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
   4416     int batch_count) {
   4417   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
   4418                                         b, ldb, beta, c, ldc, batch_count,
   4419                                         /*scratch_allocator=*/nullptr);
   4420 }
   4421 
   4422 Stream &Stream::ThenBlasGemmBatchedWithScratch(
   4423     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4424     uint64 k, float alpha, const port::ArraySlice<DeviceMemory<float> *> &a,
   4425     int lda, const port::ArraySlice<DeviceMemory<float> *> &b, int ldb,
   4426     float beta, const port::ArraySlice<DeviceMemory<float> *> &c, int ldc,
   4427     int batch_count, ScratchAllocator *scratch_allocator) {
   4428   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   4429             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   4430             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
   4431 
   4432   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
   4433                const port::ArraySlice<DeviceMemory<float> *> &, int,
   4434                const port::ArraySlice<DeviceMemory<float> *> &, int, float,
   4435                const port::ArraySlice<DeviceMemory<float> *> &, int, int,
   4436                ScratchAllocator *>
   4437       impl;
   4438   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
   4439               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
   4440               scratch_allocator);
   4441 }
   4442 
   4443 Stream &Stream::ThenBlasGemmBatched(
   4444     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4445     uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
   4446     int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
   4447     double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
   4448     int batch_count) {
   4449   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
   4450                                         b, ldb, beta, c, ldc, batch_count,
   4451                                         /*scratch_allocator=*/nullptr);
   4452 }
   4453 
   4454 Stream &Stream::ThenBlasGemmBatchedWithScratch(
   4455     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4456     uint64 k, double alpha, const port::ArraySlice<DeviceMemory<double> *> &a,
   4457     int lda, const port::ArraySlice<DeviceMemory<double> *> &b, int ldb,
   4458     double beta, const port::ArraySlice<DeviceMemory<double> *> &c, int ldc,
   4459     int batch_count, ScratchAllocator *scratch_allocator) {
   4460   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   4461             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   4462             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
   4463 
   4464   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
   4465                const port::ArraySlice<DeviceMemory<double> *> &, int,
   4466                const port::ArraySlice<DeviceMemory<double> *> &, int, double,
   4467                const port::ArraySlice<DeviceMemory<double> *> &, int, int,
   4468                ScratchAllocator *>
   4469       impl;
   4470   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
   4471               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
   4472               scratch_allocator);
   4473 }
   4474 
   4475 Stream &Stream::ThenBlasGemmBatched(
   4476     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4477     uint64 k, std::complex<float> alpha,
   4478     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
   4479     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
   4480     std::complex<float> beta,
   4481     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
   4482     int batch_count) {
   4483   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
   4484                                         b, ldb, beta, c, ldc, batch_count,
   4485                                         /*scratch_allocator=*/nullptr);
   4486 }
   4487 
   4488 Stream &Stream::ThenBlasGemmBatchedWithScratch(
   4489     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4490     uint64 k, std::complex<float> alpha,
   4491     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &a, int lda,
   4492     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &b, int ldb,
   4493     std::complex<float> beta,
   4494     const port::ArraySlice<DeviceMemory<std::complex<float>> *> &c, int ldc,
   4495     int batch_count, ScratchAllocator *scratch_allocator) {
   4496   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   4497             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   4498             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
   4499 
   4500   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   4501                std::complex<float>,
   4502                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
   4503                int,
   4504                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
   4505                int, std::complex<float>,
   4506                const port::ArraySlice<DeviceMemory<std::complex<float>> *> &,
   4507                int, int, ScratchAllocator *>
   4508       impl;
   4509   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
   4510               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
   4511               scratch_allocator);
   4512 }
   4513 
   4514 Stream &Stream::ThenBlasGemmBatched(
   4515     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4516     uint64 k, std::complex<double> alpha,
   4517     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
   4518     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
   4519     std::complex<double> beta,
   4520     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
   4521     int batch_count) {
   4522   return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
   4523                                         b, ldb, beta, c, ldc, batch_count,
   4524                                         /*scratch_allocator=*/nullptr);
   4525 }
   4526 
   4527 Stream &Stream::ThenBlasGemmBatchedWithScratch(
   4528     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4529     uint64 k, std::complex<double> alpha,
   4530     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &a, int lda,
   4531     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &b, int ldb,
   4532     std::complex<double> beta,
   4533     const port::ArraySlice<DeviceMemory<std::complex<double>> *> &c, int ldc,
   4534     int batch_count, ScratchAllocator *scratch_allocator) {
   4535   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   4536             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
   4537             PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));
   4538 
   4539   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   4540                std::complex<double>,
   4541                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
   4542                int,
   4543                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
   4544                int, std::complex<double>,
   4545                const port::ArraySlice<DeviceMemory<std::complex<double>> *> &,
   4546                int, int, ScratchAllocator *>
   4547       impl;
   4548   return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
   4549               k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
   4550               scratch_allocator);
   4551 }
   4552 
   4553 Stream &Stream::ThenBlasGemmStridedBatched(
   4554     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4555     uint64 k, float alpha, const DeviceMemory<Eigen::half> &a, int lda,
   4556     int64 stride_a, const DeviceMemory<Eigen::half> &b, int ldb, int64 stride_b,
   4557     float beta, DeviceMemory<Eigen::half> *c, int ldc, int64 stride_c,
   4558     int batch_count) {
   4559   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   4560             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
   4561             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
   4562             PARAM(stride_c), PARAM(batch_count));
   4563 
   4564   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
   4565                const DeviceMemory<Eigen::half> &, int, int64,
   4566                const DeviceMemory<Eigen::half> &, int, int64, float,
   4567                DeviceMemory<Eigen::half> *, int, int64, int>
   4568       impl;
   4569   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
   4570               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
   4571               c, ldc, stride_c, batch_count);
   4572 }
   4573 
   4574 Stream &Stream::ThenBlasGemmStridedBatched(
   4575     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4576     uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
   4577     int64 stride_a, const DeviceMemory<float> &b, int ldb, int64 stride_b,
   4578     float beta, DeviceMemory<float> *c, int ldc, int64 stride_c,
   4579     int batch_count) {
   4580   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   4581             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
   4582             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
   4583             PARAM(stride_c), PARAM(batch_count));
   4584 
   4585   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
   4586                const DeviceMemory<float> &, int, int64,
   4587                const DeviceMemory<float> &, int, int64, float,
   4588                DeviceMemory<float> *, int, int64, int>
   4589       impl;
   4590   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
   4591               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
   4592               c, ldc, stride_c, batch_count);
   4593 }
   4594 
   4595 Stream &Stream::ThenBlasGemmStridedBatched(
   4596     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4597     uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
   4598     int64 stride_a, const DeviceMemory<double> &b, int ldb, int64 stride_b,
   4599     double beta, DeviceMemory<double> *c, int ldc, int64 stride_c,
   4600     int batch_count) {
   4601   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   4602             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
   4603             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
   4604             PARAM(stride_c), PARAM(batch_count));
   4605 
   4606   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64, double,
   4607                const DeviceMemory<double> &, int, int64,
   4608                const DeviceMemory<double> &, int, int64, double,
   4609                DeviceMemory<double> *, int, int64, int>
   4610       impl;
   4611   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
   4612               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
   4613               c, ldc, stride_c, batch_count);
   4614 }
   4615 
   4616 Stream &Stream::ThenBlasGemmStridedBatched(
   4617     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4618     uint64 k, std::complex<float> alpha,
   4619     const DeviceMemory<std::complex<float>> &a, int lda, int64 stride_a,
   4620     const DeviceMemory<std::complex<float>> &b, int ldb, int64 stride_b,
   4621     std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
   4622     int64 stride_c, int batch_count) {
   4623   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   4624             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
   4625             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
   4626             PARAM(stride_c), PARAM(batch_count));
   4627 
   4628   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   4629                std::complex<float>, const DeviceMemory<std::complex<float>> &,
   4630                int, int64, const DeviceMemory<std::complex<float>> &, int,
   4631                int64, std::complex<float>, DeviceMemory<std::complex<float>> *,
   4632                int, int64, int>
   4633       impl;
   4634   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
   4635               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
   4636               c, ldc, stride_c, batch_count);
   4637 }
   4638 
   4639 Stream &Stream::ThenBlasGemmStridedBatched(
   4640     blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
   4641     uint64 k, std::complex<double> alpha,
   4642     const DeviceMemory<std::complex<double>> &a, int lda, int64 stride_a,
   4643     const DeviceMemory<std::complex<double>> &b, int ldb, int64 stride_b,
   4644     std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
   4645     int64 stride_c, int batch_count) {
   4646   VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
   4647             PARAM(alpha), PARAM(a), PARAM(lda), PARAM(stride_a), PARAM(b),
   4648             PARAM(ldb), PARAM(stride_b), PARAM(beta), PARAM(c), PARAM(ldc),
   4649             PARAM(stride_c), PARAM(batch_count));
   4650 
   4651   ThenBlasImpl<blas::Transpose, blas::Transpose, uint64, uint64, uint64,
   4652                std::complex<double>, const DeviceMemory<std::complex<double>> &,
   4653                int, int64, const DeviceMemory<std::complex<double>> &, int,
   4654                int64, std::complex<double>,
   4655                DeviceMemory<std::complex<double>> *, int, int64, int>
   4656       impl;
   4657   return impl(this, &blas::BlasSupport::DoBlasGemmStridedBatched, transa,
   4658               transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta,
   4659               c, ldc, stride_c, batch_count);
   4660 }
   4661 
   4662 Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64 seed_bytes) {
   4663   VLOG_CALL(PARAM(seed), PARAM(seed_bytes));
   4664 
   4665   if (ok()) {
   4666     if (rng::RngSupport *rng = parent_->AsRng()) {
   4667       CheckError(rng->SetSeed(this, seed, seed_bytes));
   4668     } else {
   4669       SetError();
   4670       LOG(INFO) << DebugStreamPointers() << " unable to initialize RNG";
   4671     }
   4672   } else {
   4673     LOG(INFO) << DebugStreamPointers()
   4674               << " did not set RNG seed: " << static_cast<const void *>(seed)
   4675               << "; bytes: " << seed_bytes;
   4676   }
   4677   return *this;
   4678 }
   4679 
   4680 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<float> *values) {
   4681   VLOG_CALL(PARAM(values));
   4682 
   4683   if (ok()) {
   4684     if (rng::RngSupport *rng = parent_->AsRng()) {
   4685       CheckError(rng->DoPopulateRandUniform(this, values));
   4686     } else {
   4687       SetError();
   4688       LOG(INFO) << DebugStreamPointers()
   4689                 << " attempting to perform RNG operation using StreamExecutor"
   4690                    " without RNG support.";
   4691     }
   4692   }
   4693   return *this;
   4694 }
   4695 
   4696 Stream &Stream::ThenPopulateRandGaussian(float mean, float sd,
   4697                                          DeviceMemory<float> *values) {
   4698   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
   4699 
   4700   if (ok()) {
   4701     if (rng::RngSupport *rng = parent_->AsRng()) {
   4702       CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
   4703     } else {
   4704       SetError();
   4705       LOG(INFO) << DebugStreamPointers()
   4706                 << " attempting to perform RNG operation using StreamExecutor"
   4707                    " without RNG support.";
   4708     }
   4709   }
   4710   return *this;
   4711 }
   4712 
   4713 Stream &Stream::ThenPopulateRandGaussian(double mean, double sd,
   4714                                          DeviceMemory<double> *values) {
   4715   VLOG_CALL(PARAM(mean), PARAM(sd), PARAM(values));
   4716 
   4717   if (ok()) {
   4718     if (rng::RngSupport *rng = parent_->AsRng()) {
   4719       CheckError(rng->DoPopulateRandGaussian(this, mean, sd, values));
   4720     } else {
   4721       SetError();
   4722       LOG(INFO) << DebugStreamPointers()
   4723                 << " attempting to perform RNG operation using StreamExecutor"
   4724                    " without RNG support.";
   4725     }
   4726   }
   4727   return *this;
   4728 }
   4729 
   4730 Stream &Stream::ThenPopulateRandUniform(DeviceMemory<double> *values) {
   4731   VLOG_CALL(PARAM(values));
   4732 
   4733   if (ok()) {
   4734     if (rng::RngSupport *rng = parent_->AsRng()) {
   4735       CheckError(rng->DoPopulateRandUniform(this, values));
   4736     } else {
   4737       SetError();
   4738       LOG(INFO) << DebugStreamPointers()
   4739                 << " attempting to perform RNG operation using StreamExecutor"
   4740                    " without RNG support.";
   4741     }
   4742   }
   4743   return *this;
   4744 }
   4745 
   4746 Stream &Stream::ThenPopulateRandUniform(
   4747     DeviceMemory<std::complex<float>> *values) {
   4748   VLOG_CALL(PARAM(values));
   4749 
   4750   if (ok()) {
   4751     if (rng::RngSupport *rng = parent_->AsRng()) {
   4752       CheckError(rng->DoPopulateRandUniform(this, values));
   4753     } else {
   4754       SetError();
   4755       LOG(INFO) << DebugStreamPointers()
   4756                 << " attempting to perform RNG operation using StreamExecutor"
   4757                    " without RNG support.";
   4758     }
   4759   }
   4760   return *this;
   4761 }
   4762 
   4763 Stream &Stream::ThenPopulateRandUniform(
   4764     DeviceMemory<std::complex<double>> *values) {
   4765   VLOG_CALL(PARAM(values));
   4766 
   4767   if (ok()) {
   4768     if (rng::RngSupport *rng = parent_->AsRng()) {
   4769       CheckError(rng->DoPopulateRandUniform(this, values));
   4770     } else {
   4771       SetError();
   4772       LOG(INFO) << DebugStreamPointers()
   4773                 << " attempting to perform RNG operation using StreamExecutor"
   4774                    " without RNG support.";
   4775     }
   4776   }
   4777   return *this;
   4778 }
   4779 
   4780 Stream &Stream::ThenMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src,
   4781                            uint64 size) {
   4782   VLOG_CALL(PARAM(host_dst), PARAM(gpu_src), PARAM(size));
   4783 
   4784   if (ok()) {
   4785     CheckError(parent_->Memcpy(this, host_dst, gpu_src, size));
   4786   } else {
   4787     LOG(INFO) << DebugStreamPointers()
   4788               << " did not memcpy device-to-host; source: " << gpu_src.opaque();
   4789   }
   4790   return *this;
   4791 }
   4792 
   4793 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src,
   4794                            uint64 size) {
   4795   VLOG_CALL(PARAM(gpu_dst), PARAM(host_src), PARAM(size));
   4796 
   4797   if (ok()) {
   4798     CheckError(parent_->Memcpy(this, gpu_dst, host_src, size));
   4799   } else {
   4800     LOG(INFO) << DebugStreamPointers()
   4801               << " did not memcpy host-to-device; source: " << host_src;
   4802   }
   4803   return *this;
   4804 }
   4805 
   4806 Stream &Stream::ThenMemcpy(DeviceMemoryBase *gpu_dst,
   4807                            const DeviceMemoryBase &gpu_src, uint64 size) {
   4808   VLOG_CALL(PARAM(gpu_dst), PARAM(gpu_src), PARAM(size));
   4809 
   4810   if (ok()) {
   4811     CheckError(parent_->MemcpyDeviceToDevice(this, gpu_dst, gpu_src, size));
   4812   } else {
   4813     LOG(INFO) << DebugStreamPointers()
   4814               << " did not memcpy gpu-to-gpu; source: " << &gpu_src;
   4815   }
   4816   return *this;
   4817 }
   4818 
   4819 Stream &Stream::ThenMemZero(DeviceMemoryBase *location, uint64 size) {
   4820   VLOG_CALL(PARAM(location), PARAM(size));
   4821 
   4822   if (ok()) {
   4823     CheckError(parent_->MemZero(this, location, size));
   4824   } else {
   4825     LOG(INFO) << DebugStreamPointers()
   4826               << " did not memzero GPU location; source: " << location;
   4827   }
   4828   return *this;
   4829 }
   4830 
   4831 Stream &Stream::ThenMemset32(DeviceMemoryBase *location, uint32 pattern,
   4832                              uint64 size) {
   4833   VLOG_CALL(PARAM(location), PARAM(pattern), PARAM(size));
   4834 
   4835   if (ok()) {
   4836     CheckError(parent_->Memset32(this, location, pattern, size));
   4837   } else {
   4838     LOG(INFO) << DebugStreamPointers()
   4839               << " did not memset GPU location; source: " << location
   4840               << "; size: " << size << "; pattern: " << std::hex << pattern;
   4841   }
   4842   return *this;
   4843 }
   4844 
   4845 Stream &Stream::ThenRnnForward(
   4846     const dnn::RnnDescriptor &rnn_desc,
   4847     const dnn::RnnSequenceTensorDescriptor &input_desc,
   4848     const DeviceMemory<Eigen::half> &input_data,
   4849     const dnn::RnnStateTensorDescriptor &input_h_desc,
   4850     const DeviceMemory<Eigen::half> &input_h_data,
   4851     const dnn::RnnStateTensorDescriptor &input_c_desc,
   4852     const DeviceMemory<Eigen::half> &input_c_data,
   4853     const DeviceMemory<Eigen::half> &params,
   4854     const dnn::RnnSequenceTensorDescriptor &output_desc,
   4855     DeviceMemory<Eigen::half> *output_data,
   4856     const dnn::RnnStateTensorDescriptor &output_h_desc,
   4857     DeviceMemory<Eigen::half> *output_h_data,
   4858     const dnn::RnnStateTensorDescriptor &output_c_desc,
   4859     DeviceMemory<Eigen::half> *output_c_data, bool is_training,
   4860     ScratchAllocator *reserve_space_allocator,
   4861     ScratchAllocator *workspace_allocator,
   4862     dnn::ProfileResult *output_profile_result) {
   4863   // TODO(zhengxq): add VLOG PARAM calls.
   4864   if (ok()) {
   4865     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   4866       auto status = dnn->DoRnnForward(
   4867           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
   4868           input_c_desc, input_c_data, params, output_desc, output_data,
   4869           output_h_desc, output_h_data, output_c_desc, output_c_data,
   4870           is_training, reserve_space_allocator, workspace_allocator,
   4871           output_profile_result);
   4872       if (!status && !output_profile_result) {
   4873         SetError();
   4874       }
   4875     } else {
   4876       SetErrorAndLogNoDnnSupport();
   4877     }
   4878   }
   4879   return *this;
   4880 }
   4881 
   4882 Stream &Stream::ThenRnnForward(
   4883     const dnn::RnnDescriptor &rnn_desc,
   4884     const dnn::RnnSequenceTensorDescriptor &input_desc,
   4885     const DeviceMemory<float> &input_data,
   4886     const dnn::RnnStateTensorDescriptor &input_h_desc,
   4887     const DeviceMemory<float> &input_h_data,
   4888     const dnn::RnnStateTensorDescriptor &input_c_desc,
   4889     const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
   4890     const dnn::RnnSequenceTensorDescriptor &output_desc,
   4891     DeviceMemory<float> *output_data,
   4892     const dnn::RnnStateTensorDescriptor &output_h_desc,
   4893     DeviceMemory<float> *output_h_data,
   4894     const dnn::RnnStateTensorDescriptor &output_c_desc,
   4895     DeviceMemory<float> *output_c_data, bool is_training,
   4896     ScratchAllocator *reserve_space_allocator,
   4897     ScratchAllocator *workspace_allocator,
   4898     dnn::ProfileResult *output_profile_result) {
   4899   // TODO(zhengxq): add VLOG PARAM calls.
   4900   if (ok()) {
   4901     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   4902       auto status = dnn->DoRnnForward(
   4903           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
   4904           input_c_desc, input_c_data, params, output_desc, output_data,
   4905           output_h_desc, output_h_data, output_c_desc, output_c_data,
   4906           is_training, reserve_space_allocator, workspace_allocator,
   4907           output_profile_result);
   4908       if (!status && !output_profile_result) {
   4909         SetError();
   4910       }
   4911     } else {
   4912       SetErrorAndLogNoDnnSupport();
   4913     }
   4914   }
   4915   return *this;
   4916 }
   4917 
   4918 Stream &Stream::ThenRnnForward(
   4919     const dnn::RnnDescriptor &rnn_desc,
   4920     const dnn::RnnSequenceTensorDescriptor &input_desc,
   4921     const DeviceMemory<double> &input_data,
   4922     const dnn::RnnStateTensorDescriptor &input_h_desc,
   4923     const DeviceMemory<double> &input_h_data,
   4924     const dnn::RnnStateTensorDescriptor &input_c_desc,
   4925     const DeviceMemory<double> &input_c_data,
   4926     const DeviceMemory<double> &params,
   4927     const dnn::RnnSequenceTensorDescriptor &output_desc,
   4928     DeviceMemory<double> *output_data,
   4929     const dnn::RnnStateTensorDescriptor &output_h_desc,
   4930     DeviceMemory<double> *output_h_data,
   4931     const dnn::RnnStateTensorDescriptor &output_c_desc,
   4932     DeviceMemory<double> *output_c_data, bool is_training,
   4933     ScratchAllocator *reserve_space_allocator,
   4934     ScratchAllocator *workspace_allocator,
   4935     dnn::ProfileResult *output_profile_result) {
   4936   // TODO(zhengxq): add VLOG PARAM calls.
   4937   if (ok()) {
   4938     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   4939       auto status = dnn->DoRnnForward(
   4940           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
   4941           input_c_desc, input_c_data, params, output_desc, output_data,
   4942           output_h_desc, output_h_data, output_c_desc, output_c_data,
   4943           is_training, reserve_space_allocator, workspace_allocator,
   4944           output_profile_result);
   4945       if (!status && !output_profile_result) {
   4946         SetError();
   4947       }
   4948     } else {
   4949       SetErrorAndLogNoDnnSupport();
   4950     }
   4951   }
   4952   return *this;
   4953 }
   4954 
   4955 Stream &Stream::ThenRnnBackward(
   4956     const dnn::RnnDescriptor &rnn_desc,
   4957     const dnn::RnnSequenceTensorDescriptor &input_desc,
   4958     const DeviceMemory<Eigen::half> &input_data,
   4959     const dnn::RnnStateTensorDescriptor &input_h_desc,
   4960     const DeviceMemory<Eigen::half> &input_h_data,
   4961     const dnn::RnnStateTensorDescriptor &input_c_desc,
   4962     const DeviceMemory<Eigen::half> &input_c_data,
   4963     const DeviceMemory<Eigen::half> &params,
   4964     const dnn::RnnSequenceTensorDescriptor &output_desc,
   4965     const DeviceMemory<Eigen::half> &output_data,
   4966     const dnn::RnnStateTensorDescriptor &output_h_desc,
   4967     const DeviceMemory<Eigen::half> &output_h_data,
   4968     const dnn::RnnStateTensorDescriptor &output_c_desc,
   4969     const DeviceMemory<Eigen::half> &output_c_data,
   4970     const DeviceMemory<Eigen::half> &output_backprop_data,
   4971     const DeviceMemory<Eigen::half> &output_h_backprop_data,
   4972     const DeviceMemory<Eigen::half> &output_c_backprop_data,
   4973     DeviceMemory<Eigen::half> *input_backprop_data,
   4974     DeviceMemory<Eigen::half> *input_h_backprop_data,
   4975     DeviceMemory<Eigen::half> *input_c_backprop_data,
   4976     DeviceMemory<Eigen::half> *params_backprop_data,
   4977     DeviceMemory<uint8> *reserve_space_data,
   4978     ScratchAllocator *workspace_allocator,
   4979     dnn::ProfileResult *output_profile_result) {
   4980   // TODO(zhengxq): add VLOG PARAM calls.
   4981   if (ok()) {
   4982     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   4983       auto status = dnn->DoRnnBackward(
   4984           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
   4985           input_c_desc, input_c_data, params, output_desc, output_data,
   4986           output_h_desc, output_h_data, output_c_desc, output_c_data,
   4987           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
   4988           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
   4989           params_backprop_data, reserve_space_data, workspace_allocator,
   4990           output_profile_result);
   4991       if (!status && !output_profile_result) {
   4992         SetError();
   4993       }
   4994     } else {
   4995       SetError();
   4996       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
   4997     }
   4998   }
   4999   return *this;
   5000 }
   5001 
   5002 Stream &Stream::ThenRnnBackward(
   5003     const dnn::RnnDescriptor &rnn_desc,
   5004     const dnn::RnnSequenceTensorDescriptor &input_desc,
   5005     const DeviceMemory<float> &input_data,
   5006     const dnn::RnnStateTensorDescriptor &input_h_desc,
   5007     const DeviceMemory<float> &input_h_data,
   5008     const dnn::RnnStateTensorDescriptor &input_c_desc,
   5009     const DeviceMemory<float> &input_c_data, const DeviceMemory<float> &params,
   5010     const dnn::RnnSequenceTensorDescriptor &output_desc,
   5011     const DeviceMemory<float> &output_data,
   5012     const dnn::RnnStateTensorDescriptor &output_h_desc,
   5013     const DeviceMemory<float> &output_h_data,
   5014     const dnn::RnnStateTensorDescriptor &output_c_desc,
   5015     const DeviceMemory<float> &output_c_data,
   5016     const DeviceMemory<float> &output_backprop_data,
   5017     const DeviceMemory<float> &output_h_backprop_data,
   5018     const DeviceMemory<float> &output_c_backprop_data,
   5019     DeviceMemory<float> *input_backprop_data,
   5020     DeviceMemory<float> *input_h_backprop_data,
   5021     DeviceMemory<float> *input_c_backprop_data,
   5022     DeviceMemory<float> *params_backprop_data,
   5023     DeviceMemory<uint8> *reserve_space_data,
   5024     ScratchAllocator *workspace_allocator,
   5025     dnn::ProfileResult *output_profile_result) {
   5026   // TODO(zhengxq): add VLOG PARAM calls.
   5027   if (ok()) {
   5028     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   5029       auto status = dnn->DoRnnBackward(
   5030           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
   5031           input_c_desc, input_c_data, params, output_desc, output_data,
   5032           output_h_desc, output_h_data, output_c_desc, output_c_data,
   5033           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
   5034           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
   5035           params_backprop_data, reserve_space_data, workspace_allocator,
   5036           output_profile_result);
   5037       if (!status && !output_profile_result) {
   5038         SetError();
   5039       }
   5040     } else {
   5041       SetError();
   5042       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
   5043     }
   5044   }
   5045   return *this;
   5046 }
   5047 
   5048 Stream &Stream::ThenRnnBackward(
   5049     const dnn::RnnDescriptor &rnn_desc,
   5050     const dnn::RnnSequenceTensorDescriptor &input_desc,
   5051     const DeviceMemory<double> &input_data,
   5052     const dnn::RnnStateTensorDescriptor &input_h_desc,
   5053     const DeviceMemory<double> &input_h_data,
   5054     const dnn::RnnStateTensorDescriptor &input_c_desc,
   5055     const DeviceMemory<double> &input_c_data,
   5056     const DeviceMemory<double> &params,
   5057     const dnn::RnnSequenceTensorDescriptor &output_desc,
   5058     const DeviceMemory<double> &output_data,
   5059     const dnn::RnnStateTensorDescriptor &output_h_desc,
   5060     const DeviceMemory<double> &output_h_data,
   5061     const dnn::RnnStateTensorDescriptor &output_c_desc,
   5062     const DeviceMemory<double> &output_c_data,
   5063     const DeviceMemory<double> &output_backprop_data,
   5064     const DeviceMemory<double> &output_h_backprop_data,
   5065     const DeviceMemory<double> &output_c_backprop_data,
   5066     DeviceMemory<double> *input_backprop_data,
   5067     DeviceMemory<double> *input_h_backprop_data,
   5068     DeviceMemory<double> *input_c_backprop_data,
   5069     DeviceMemory<double> *params_backprop_data,
   5070     DeviceMemory<uint8> *reserve_space_data,
   5071     ScratchAllocator *workspace_allocator,
   5072     dnn::ProfileResult *output_profile_result) {
   5073   // TODO(zhengxq): add VLOG PARAM calls.
   5074   if (ok()) {
   5075     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   5076       auto status = dnn->DoRnnBackward(
   5077           this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
   5078           input_c_desc, input_c_data, params, output_desc, output_data,
   5079           output_h_desc, output_h_data, output_c_desc, output_c_data,
   5080           output_backprop_data, output_h_backprop_data, output_c_backprop_data,
   5081           input_backprop_data, input_h_backprop_data, input_c_backprop_data,
   5082           params_backprop_data, reserve_space_data, workspace_allocator,
   5083           output_profile_result);
   5084       if (!status && !output_profile_result) {
   5085         SetError();
   5086       }
   5087     } else {
   5088       SetError();
   5089       LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
   5090     }
   5091   }
   5092   return *this;
   5093 }
   5094 
   5095 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
   5096                                     dnn::DataType input_type,
   5097                                     const DeviceMemoryBase &input_data,
   5098                                     const dnn::BatchDescriptor &output_desc,
   5099                                     dnn::DataType output_type, float scale,
   5100                                     DeviceMemoryBase *output_data) {
   5101   VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
   5102             PARAM(output_desc), PARAM(output_type), PARAM(scale),
   5103             PARAM(output_data));
   5104   if (ok()) {
   5105     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
   5106       CheckError(dnn->DoTransformTensor(this, input_desc, input_type,
   5107                                         input_data, output_desc, output_type,
   5108                                         scale, output_data));
   5109     } else {
   5110       SetErrorAndLogNoDnnSupport();
   5111     }
   5112   }
   5113   return *this;
   5114 }
   5115 
   5116 Stream &Stream::ThenDoHostCallback(std::function<void()> callback) {
   5117   VLOG_CALL(PARAM(callback));
   5118 
   5119   if (!ok()) {
   5120     LOG(INFO) << DebugStreamPointers()
   5121               << " was in error state before adding host callback";
   5122   }
   5123   CheckError(parent_->HostCallback(this, std::move(callback)));
   5124   return *this;
   5125 }
   5126 
   5127 Stream &Stream::ThenDoHostCallbackWithStatus(
   5128     std::function<port::Status()> callback) {
   5129   VLOG_CALL(PARAM(callback));
   5130 
   5131   if (!ok()) {
   5132     LOG(INFO) << DebugStreamPointers()
   5133               << " was in error state before adding host callback";
   5134   }
   5135   CheckError(parent_->HostCallback(this, std::move(callback)));
   5136   return *this;
   5137 }
   5138 
   5139 Stream &Stream::ThenFft(fft::Plan *plan,
   5140                         const DeviceMemory<std::complex<float>> &input,
   5141                         DeviceMemory<std::complex<float>> *output) {
   5142   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
   5143 
   5144   if (ok()) {
   5145     if (fft::FftSupport *fft = parent_->AsFft()) {
   5146       CheckError(fft->DoFft(this, plan, input, output));
   5147     } else {
   5148       SetError();
   5149       LOG(INFO) << DebugStreamPointers()
   5150                 << " attempting to perform FFT operation using StreamExecutor"
   5151                    " without FFT support";
   5152     }
   5153   }
   5154   return *this;
   5155 }
   5156 
   5157 Stream &Stream::ThenFft(fft::Plan *plan,
   5158                         const DeviceMemory<std::complex<double>> &input,
   5159                         DeviceMemory<std::complex<double>> *output) {
   5160   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
   5161 
   5162   if (ok()) {
   5163     if (fft::FftSupport *fft = parent_->AsFft()) {
   5164       CheckError(fft->DoFft(this, plan, input, output));
   5165     } else {
   5166       SetError();
   5167       LOG(INFO) << DebugStreamPointers()
   5168                 << " attempting to perform FFT operation using StreamExecutor"
   5169                    " without FFT support";
   5170     }
   5171   }
   5172   return *this;
   5173 }
   5174 
   5175 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<float> &input,
   5176                         DeviceMemory<std::complex<float>> *output) {
   5177   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
   5178 
   5179   if (ok()) {
   5180     if (fft::FftSupport *fft = parent_->AsFft()) {
   5181       CheckError(fft->DoFft(this, plan, input, output));
   5182     } else {
   5183       SetError();
   5184       LOG(INFO) << DebugStreamPointers()
   5185                 << " attempting to perform FFT operation using StreamExecutor"
   5186                    " without FFT support";
   5187     }
   5188   }
   5189   return *this;
   5190 }
   5191 
   5192 Stream &Stream::ThenFft(fft::Plan *plan, const DeviceMemory<double> &input,
   5193                         DeviceMemory<std::complex<double>> *output) {
   5194   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
   5195 
   5196   if (ok()) {
   5197     if (fft::FftSupport *fft = parent_->AsFft()) {
   5198       CheckError(fft->DoFft(this, plan, input, output));
   5199     } else {
   5200       SetError();
   5201       LOG(INFO) << DebugStreamPointers()
   5202                 << " attempting to perform FFT operation using StreamExecutor"
   5203                    " without FFT support";
   5204     }
   5205   }
   5206   return *this;
   5207 }
   5208 
   5209 Stream &Stream::ThenFft(fft::Plan *plan,
   5210                         const DeviceMemory<std::complex<float>> &input,
   5211                         DeviceMemory<float> *output) {
   5212   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
   5213 
   5214   if (ok()) {
   5215     if (fft::FftSupport *fft = parent_->AsFft()) {
   5216       CheckError(fft->DoFft(this, plan, input, output));
   5217     } else {
   5218       SetError();
   5219       LOG(INFO) << DebugStreamPointers()
   5220                 << " attempting to perform FFT operation using StreamExecutor"
   5221                    " without FFT support";
   5222     }
   5223   }
   5224   return *this;
   5225 }
   5226 
   5227 Stream &Stream::ThenFft(fft::Plan *plan,
   5228                         const DeviceMemory<std::complex<double>> &input,
   5229                         DeviceMemory<double> *output) {
   5230   VLOG_CALL(PARAM(plan), PARAM(input), PARAM(output));
   5231 
   5232   if (ok()) {
   5233     if (fft::FftSupport *fft = parent_->AsFft()) {
   5234       CheckError(fft->DoFft(this, plan, input, output));
   5235     } else {
   5236       SetError();
   5237       LOG(INFO) << DebugStreamPointers()
   5238                 << " attempting to perform FFT operation using StreamExecutor"
   5239                    " without FFT support";
   5240     }
   5241   }
   5242   return *this;
   5243 }
   5244 
   5245 // It looks confusing, but all this is doing is inserting a callback at the
   5246 // present point in the stream to then enqueue a task on the host executor.
   5247 Stream &Stream::ThenEnqueueOnBackgroundThread(
   5248     std::function<void(StreamExecutor *)> task) {
   5249   VLOG_CALL(PARAM(task));
   5250 
   5251   StreamExecutor *stream_executor = this->parent_;
   5252   std::function<void()> bound_task = std::bind(task, stream_executor);
   5253 
   5254   return ThenDoHostCallback([stream_executor, bound_task]() {
   5255     stream_executor->EnqueueOnBackgroundThread(bound_task);
   5256   });
   5257 }
   5258 
   5259 port::Status Stream::BlockHostUntilDone() {
   5260   VLOG_CALL();
   5261 
   5262   if (!ok()) {
   5263     port::Status status = port::Status(
   5264         port::error::INTERNAL,
   5265         "stream did not block host until done; was already in an error state");
   5266     LOG(INFO) << DebugStreamPointers() << " " << status;
   5267     return status;
   5268   }
   5269 
   5270   temporary_memory_manager_.DeallocateFinalizedTemporaries();
   5271 
   5272   port::Status error = parent_->BlockHostUntilDone(this);
   5273   CheckError(error.ok());
   5274   return error;
   5275 }
   5276 
   5277 string Stream::DebugStreamPointers() const {
   5278   // Relies on the ToVlogString(const void*) overload above.
   5279   return absl::StrCat("[stream=", ToVlogString(this),
   5280                       ",impl=", ToVlogString(implementation_.get()), "]");
   5281 }
   5282 
   5283 void Stream::CheckStatus(port::Status status) {
   5284   if (status.ok()) {
   5285     return;
   5286   }
   5287   LOG(ERROR) << status;
   5288   mutex_lock lock(mu_);
   5289   ok_ = false;
   5290 }
   5291 
   5292 }  // namespace stream_executor
   5293