      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
     17 
     18 #if defined(__APPLE__)
     19 #include <mach-o/dyld.h>
     20 #endif
     21 #if defined(PLATFORM_WINDOWS)
     22 #include <windows.h>
     23 #define PATH_MAX MAX_PATH
     24 #else
     25 #include <unistd.h>
     26 #endif
     27 #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
     28 #include "tensorflow/stream_executor/cuda/cuda_driver.h"
     29 #include "tensorflow/stream_executor/cuda/cuda_event.h"
     30 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
     31 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
     32 #include "tensorflow/stream_executor/cuda/cuda_timer.h"
     33 #include "tensorflow/stream_executor/kernel_cache_config.h"
     34 #include "tensorflow/stream_executor/lib/casts.h"
     35 #include "tensorflow/stream_executor/lib/env.h"
     36 #include "tensorflow/stream_executor/lib/error.h"
     37 #include "tensorflow/stream_executor/lib/initialize.h"
     38 #include "tensorflow/stream_executor/lib/mathutil.h"
     39 #include "tensorflow/stream_executor/lib/path.h"
     40 #include "tensorflow/stream_executor/lib/process_state.h"
     41 #include "tensorflow/stream_executor/lib/ptr_util.h"
     42 #include "tensorflow/stream_executor/lib/statusor.h"
     43 #include "tensorflow/stream_executor/lib/str_util.h"
     44 #include "tensorflow/stream_executor/lib/strcat.h"
     45 #include "tensorflow/stream_executor/lib/stringprintf.h"
     46 #include "tensorflow/stream_executor/platform.h"
     47 #include "tensorflow/stream_executor/platform/logging.h"
     48 #include "tensorflow/stream_executor/platform/port.h"
     49 #include "tensorflow/stream_executor/plugin_registry.h"
     50 #include "tensorflow/stream_executor/stream.h"
     51 #include "tensorflow/stream_executor/stream_executor_internal.h"
     52 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
     53 #include "tensorflow/stream_executor/timer.h"
     54 #include "tensorflow/stream_executor/lib/numbers.h"
     55 
     56 #ifdef PLATFORMS_GPUS_CUDA_DYNAMIC_LIBCUDA_DYNAMIC_LIBCUDA_H_
     57 #error \
     58     "No driver calls in this file, wrap driver functionality in cuda_driver.cc."
     59 #endif
     60 
     61 #ifdef __CUDA_RUNTIME_H__
     62 #error \
     63     "CUDA runtime being included into CUDA GPU executor; should be driver only."
     64 #endif
     65 
     66 extern bool FLAGS_check_gpu_leaks;
     67 bool FLAGS_prefer_cubin_to_ptx = true;
     68 
     69 namespace perftools {
     70 namespace gputools {
     71 namespace cuda {
     72 
     73 // Hook that can be used to CUBIN-ate PTX before it is loaded into the driver.
     74 // It has been observed that loading both PTX and cubins into the driver library
     75 // can cause it to crash, but loading only CUBINs avoids those crashes;
     76 // therefore, it's useful to have this hook to hack in uniform CUBIN-ation of
     77 // PTX code.
     78 //
     79 // As this is an implementation-detail workaround, the usage is to declare this
     80 // variable with extern linkage and populate it from another translation unit.
     81 std::function<string(const string &)> g_cubinate;
     82 
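        // Given a platform-independent event datatype, returns the internal CUDA
        // platform implementation pointer.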
     83 static CUDAEvent *AsCUDAEvent(Event *event) {
     84   DCHECK(event != nullptr);
     85   return static_cast<CUDAEvent *>(event->implementation());
     86 }
     87 
     88 
     89 // Given a platform-independent timer datatype, returns the internal CUDA
     90 // platform implementation pointer.
     91 static CUDATimer *AsCUDATimer(Timer *timer) {
     92   DCHECK(timer != nullptr);
     93   return static_cast<CUDATimer *>(timer->implementation());
     94 }
     95 
     96 // Given const GPU memory, returns a libcuda device pointer datatype, suitable
     97 // for passing directly to libcuda APIs.
     98 //
     99 // N.B. we must lose constness in order to pass a suitable type to the existing
    100 // libcuda APIs, so the caller should take care to only pass the result of const
    101 // GPU memory conversions to libcuda functions which will honor constness.
    102 static CUdeviceptr AsCudaDevicePtr(const DeviceMemoryBase &gpu_mem) {
    103   return reinterpret_cast<CUdeviceptr>(gpu_mem.opaque());
    104 }
    105 
    106 // See description on const version above.
    107 static CUdeviceptr AsCudaDevicePtr(DeviceMemoryBase *gpu_mem) {
    108   return AsCudaDevicePtr(*gpu_mem);
    109 }
    110 
    111 CudaContext* ExtractCudaContext(CUDAExecutor *cuda_exec) {
    112   CHECK(cuda_exec != nullptr);
    113   return cuda_exec->cuda_context();
    114 }
    115 
    116 CUDAExecutor *ExtractCudaExecutor(StreamExecutor *stream_exec) {
    117   return static_cast<CUDAExecutor *>(stream_exec->implementation());
    118 }
    119 
    120 CUDAExecutor::~CUDAExecutor() {
    121   CHECK(kernel_to_gpu_binary_.empty()) << "CUDAExecutor has live kernels.";
    122   CHECK(gpu_binary_to_module_.empty()) << "CUDAExecutor has loaded modules.";
    123   if (context_ != nullptr) {
    124     CUDADriver::DestroyContext(context_);
    125   }
    126 }
    127 
    128 port::Status CUDAExecutor::Init(int device_ordinal,
    129                                 DeviceOptions device_options) {
    130   device_ordinal_ = device_ordinal;
    131 
    132   auto status = CUDADriver::Init();
    133   if (!status.ok()) {
    134     return status;
    135   }
    136 
    137   status = CUDADriver::GetDevice(device_ordinal_, &device_);
    138   if (!status.ok()) {
    139     return status;
    140   }
    141 
    142   status = CUDADriver::CreateContext(device_, device_options, &context_);
    143   if (!status.ok()) {
    144     return status;
    145   }
    146 
    147   return CUDADriver::GetComputeCapability(&cc_major_, &cc_minor_, device_);
    148 }
    149 
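        // Looks for a compute-capability-specific variant of |filename| (filename +
        // ".cc" + cc_major_ + cc_minor_ + canonical_suffix) on disk, falling back to
        // the plain filename; returns true and sets *found_filename if either exists.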
    150 bool CUDAExecutor::FindOnDiskForComputeCapability(
    151     port::StringPiece filename, port::StringPiece canonical_suffix,
    152     string *found_filename) const {
    153   if (cc_major_ == 0 && cc_minor_ == 0) {
    154     return false;
    155   }
    156 
    157   string cc_specific =
    158       port::StrCat(filename, ".cc", cc_major_, cc_minor_, canonical_suffix);
    159   if (port::FileExists(cc_specific).ok()) {
    160     VLOG(2) << "found compute-capability-specific file, using that: "
    161             << cc_specific;
    162     *found_filename = cc_specific;
    163     return true;
    164   }
    165 
    166   VLOG(2) << "could not find compute-capability-specific file at: "
    167           << cc_specific;
    168   if (port::FileExists(filename.ToString()).ok()) {
    169     *found_filename = filename.ToString();
    170     return true;
    171   }
    172 
    173   return false;
    174 }
    175 
    176 // Returns the path to the running executable.
    177 // N.B. Derived from //knowledge/smalltalk/background_kb.cc
    178 // Arg: strip_exe: if true, remove the name of the executable itself from the
    179 //                 returned string. Example: calling this from /usr/bin/foo
    180 //                 would return /usr/bin.
    181 static string GetBinaryDir(bool strip_exe) {
    182   char exe_path[PATH_MAX] = {0};
    183 #if defined(__APPLE__)
    184   uint32_t buffer_size = 0U;
    185   _NSGetExecutablePath(nullptr, &buffer_size);
    186   char unresolved_path[buffer_size];
    187   _NSGetExecutablePath(unresolved_path, &buffer_size);
    188   CHECK_ERR(realpath(unresolved_path, exe_path) ? 1 : -1);
    189 #else
    190 #if defined(PLATFORM_WINDOWS)
    191   HMODULE hModule = GetModuleHandle(NULL);
    192   GetModuleFileName(hModule, exe_path, MAX_PATH);
    193 #else
    194   CHECK_ERR(readlink("/proc/self/exe", exe_path, sizeof(exe_path) - 1));
    195 #endif
    196 #endif
    197   // Make sure it's null-terminated:
    198   exe_path[sizeof(exe_path) - 1] = 0;
    199 
    200   if (strip_exe) {
    201     // The exe is the last component of the path, so remove one component.
    202     string ret = exe_path;
    203     std::vector<string> components = port::Split(exe_path, '/');
    204     components.pop_back();
    205     return port::Join(components, "/");
    206   }
    207   return exe_path;
    208 }
    209 
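        // Loads the kernel described by |spec| from an in-memory CUBIN or PTX image.
        // Loaded GPU binaries are refcounted in gpu_binary_to_module_ so that kernels
        // built from the same image share a single driver module.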
    210 bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
    211                              KernelBase *kernel) {
    212   CUDAKernel *cuda_kernel = AsCUDAKernel(kernel);
    213   CUmodule module;
    214   const string *kernelname;
    215 
    216   VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name();
    217 
    218   if (spec.has_cuda_cubin_in_memory()) {
    219     kernelname = &spec.cuda_cubin_in_memory().kernelname();
    220     const char *cubin = spec.cuda_cubin_in_memory().bytes();
    221     mutex_lock lock{in_memory_modules_mu_};
    222     uint64_t module_refcount;
    223     std::tie(module, module_refcount) = gpu_binary_to_module_[cubin];
    224 
    225     if (module == nullptr) {
    226       auto load_status = CUDADriver::LoadCubin(context_, cubin, &module);
    227       if (!load_status.ok()) {
    228         LOG(ERROR) << "failed to load CUBIN: " << load_status;
    229         return false;
    230       }
    231       module_refcount = 1;
    232       VLOG(3) << "Loaded CUBIN " << static_cast<const void *>(cubin)
    233               << " as module " << module;
    234     } else {
    235       ++module_refcount;
    236       VLOG(3) << "CUBIN " << static_cast<const void *>(cubin)
    237               << " is already loaded as module " << module;
    238     }
    239     kernel_to_gpu_binary_[kernel] = cubin;
    240     gpu_binary_to_module_[cubin] = {module, module_refcount};
    241   } else if (spec.has_cuda_ptx_in_memory()) {
    242     kernelname = &spec.cuda_ptx_in_memory().kernelname();
    243 
    244     if (cc_major_ == 0 && cc_minor_ == 0) {
    245       return false;
    246     }
    247 
    248     const char *ptx = spec.cuda_ptx_in_memory().text(cc_major_, cc_minor_);
    249     if (ptx == nullptr) {
    250       ptx = spec.cuda_ptx_in_memory().default_text();
    251     }
    252     if (ptx == nullptr) {
    253       LOG(FATAL) << "loader spec has no ptx for kernel " << *kernelname;
    254       return false;
    255     }
    256 
    257     mutex_lock lock{in_memory_modules_mu_};
    258     uint64_t module_refcount;
    259     std::tie(module, module_refcount) = gpu_binary_to_module_[ptx];
    260 
    261     if (module == nullptr) {
    262       if (!CUDADriver::LoadPtx(context_, ptx, &module)) {
    263         LOG(ERROR) << "failed to load PTX for kernel " << *kernelname;
    264         return false;
    265       }
    266       VLOG(3) << "Loaded PTX " << static_cast<const void *>(ptx)
    267               << " as module " << module;
    268       module_refcount = 1;
    269     } else {
    270       ++module_refcount;
    271       VLOG(3) << "PTX " << static_cast<const void *>(ptx)
    272               << " is already loaded as module " << module;
    273     }
    274     kernel_to_gpu_binary_[kernel] = ptx;
    275     gpu_binary_to_module_[ptx] = {module, module_refcount};
    276   } else {
    277     LOG(WARNING) << "no method of loading CUDA kernel provided";
    278     return false;
    279   }
    280   VLOG(2) << "getting function " << *kernelname << " from module " << module;
    281   if (!CUDADriver::GetModuleFunction(context_, module, kernelname->c_str(),
    282                                      cuda_kernel->cuda_function_ptr())) {
    283     return false;
    284   }
    285 
    286   // We have to trust the kernel loader spec arity because there doesn't appear
    287   // to be a way to reflect on the number of expected arguments w/the CUDA API.
    288   cuda_kernel->set_arity(spec.arity());
    289 
    290   KernelMetadata kernel_metadata;
    291   if (!GetKernelMetadata(cuda_kernel, &kernel_metadata)) {
    292     LOG(WARNING) << "unable to get metadata for kernel " << *kernelname;
    293   }
    294   kernel->set_metadata(kernel_metadata);
    295   kernel->set_name(*kernelname);
    296   return true;
    297 }
    298 
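        // Drops the refcount on the module backing |kernel| and unloads the module
        // from the driver once no other kernels reference it.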
    299 void CUDAExecutor::UnloadKernel(const KernelBase *kernel) {
    300   VLOG(3) << "Unloading kernel " << kernel << " : " << kernel->name();
    301 
    302   mutex_lock lock{in_memory_modules_mu_};
    303   auto gpu_binary_it = kernel_to_gpu_binary_.find(kernel);
    304   if (kernel_to_gpu_binary_.end() == gpu_binary_it) {
    305     VLOG(3) << "Kernel " << kernel << " : " << kernel->name()
    306             << " has never been loaded.";
    307     return;  // We've never seen this kernel.
    308   }
    309   VLOG(3) << "Kernel " << kernel << " : " << kernel->name()
    310           << " has loaded GPU code " << gpu_binary_it->second;
    311   auto module_it = gpu_binary_to_module_.find(gpu_binary_it->second);
    312   if (gpu_binary_to_module_.end() == module_it) {
    313     VLOG(3) << "Kernel " << kernel << " : " << kernel->name()
    314             << " has no loaded CUDA module.";
    315     return;  // This kernel never loaded any modules
    316   }
    317   auto &module = module_it->second.first;
    318   auto &refcount = module_it->second.second;
    319   VLOG(3) << "Kernel " << kernel << " : " << kernel->name()
    320           << " has loaded GPU code " << gpu_binary_it->second
    321           << " into CUDA module " << module << " with refcount " << refcount;
    322   if (--refcount == 0) {
    323     VLOG(3) << "Unloading CUDA module " << module;
    324     CUDADriver::UnloadModule(context_, module);
    325     gpu_binary_to_module_.erase(module_it);
    326   }
    327   kernel_to_gpu_binary_.erase(gpu_binary_it);
    328 }
    329 
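        // Fills |kernel_metadata| with the per-thread register count and static
        // shared-memory usage reported by the driver for |cuda_kernel|.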
    330 bool CUDAExecutor::GetKernelMetadata(CUDAKernel *cuda_kernel,
    331                                      KernelMetadata *kernel_metadata) {
    332   int value;
    333   if (!CUDADriver::FuncGetAttribute(CU_FUNC_ATTRIBUTE_NUM_REGS,
    334                                     *cuda_kernel->cuda_function_ptr(),
    335                                     &value)) {
    336     return false;
    337   }
    338   kernel_metadata->set_registers_per_thread(value);
    339 
    340   if (!CUDADriver::FuncGetAttribute(CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES,
    341                                     *cuda_kernel->cuda_function_ptr(),
    342                                     &value)) {
    343     return false;
    344   }
    345   kernel_metadata->set_shared_memory_bytes(value);
    346 
    347   return true;
    348 }
    349 
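        // Launches |kernel| on |stream| with the given block and thread dimensions;
        // at VLOG level 2 and above, occupancy information is also logged the first
        // time each function is launched.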
    350 bool CUDAExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
    351                           const BlockDim &block_dims, const KernelBase &kernel,
    352                           const KernelArgsArrayBase &args) {
    353   CHECK_EQ(kernel.Arity(), args.number_of_arguments());
    354   CUstream custream = AsCUDAStreamValue(stream);
    355   const CUDAKernel *cuda_kernel = AsCUDAKernel(&kernel);
    356   CUfunction cufunc = cuda_kernel->AsCUDAFunctionValue();
    357 
    358   // Only perform/print the occupancy check once.  Even just checking to see
    359   // whether we've done an occupancy check on this kernel before isn't free
    360   // (because we have to synchronize), so we only do this at -v 2+.
    361   if (VLOG_IS_ON(2)) {
    362     mutex_lock lock(launched_kernels_mu_);
    363     if (!launched_kernels_.count(cufunc)) {
    364       VlogOccupancyInfo(kernel, thread_dims, block_dims);
    365       // TODO(rspringer): Remove elements from launched_kernels_...if we ever
    366       // expose a kernel/module deallocation method.
    367       launched_kernels_.insert(cufunc);
    368     }
    369   }
    370 
    371   if (cuda_kernel->GetPreferredCacheConfig() !=
    372       KernelCacheConfig::kNoPreference) {
    373     CUDADriver::FuncSetCacheConfig(cufunc, cuda_kernel->GetCUDACacheConfig());
    374   }
    375 
    376   void **kernel_params = const_cast<void **>(args.argument_addresses().data());
    377 
    378   if (!CUDADriver::LaunchKernel(context_, cufunc, block_dims.x, block_dims.y,
    379                                 block_dims.z, thread_dims.x, thread_dims.y,
    380                                 thread_dims.z, args.number_of_shared_bytes(),
    381                                 custream, kernel_params,
    382                                 nullptr /* = extra */)) {
    383     LOG(ERROR) << "failed to launch CUDA kernel " << kernel.name() << " with "
    384                << args.number_of_arguments()
    385                << " args; thread dim: " << thread_dims.ToString()
    386                << "; block dim: " << block_dims.ToString();
    387     return false;
    388   }
    389 
    390   return true;
    391 }
    392 
    393 // This is a non-essential operation; if there's a failure, proceed without
    394 // logging an error. It's nearly certain that in the case of a failure we'd
    395 // never get here in the first place; these are very low-impact routines.
    396 void CUDAExecutor::VlogOccupancyInfo(const KernelBase &kernel,
    397                                      const ThreadDim &thread_dims,
    398                                      const BlockDim &block_dims) {
    399   VLOG(2) << "Computing kernel occupancy for kernel "
    400           << kernel.demangled_name();
    401   VLOG(2) << "Thread dimensions (" << thread_dims.x << ", " << thread_dims.y
    402           << ", " << thread_dims.z << ")";
    403 
    404   int regs_per_thread;
    405   if (!kernel.metadata().registers_per_thread(&regs_per_thread)) {
    406     return;
    407   }
    408 
    409   int smem_per_block;
    410   if (!kernel.metadata().shared_memory_bytes(&smem_per_block)) {
    411     return;
    412   }
    413 
    414   const DeviceDescription &device_description =
    415       kernel.parent()->GetDeviceDescription();
    416 
    417   uint64 blocks_per_sm = CalculateOccupancy(
    418       device_description, regs_per_thread, smem_per_block, thread_dims);
    419   VLOG(2) << "Resident blocks per SM is " << blocks_per_sm;
    420 
    421   // To increase occupancy, there must be enough blocks available to spread
    422   // across the SMs at this new, higher occupancy level.
    423   int multiprocessor_count = device_description.core_count();
    424   int block_count = block_dims.x * block_dims.y * block_dims.z;
    425   int available_blocks_per_sm =
    426       port::MathUtil::CeilOfRatio(block_count, multiprocessor_count);
    427   if (available_blocks_per_sm <= static_cast<int64>(blocks_per_sm)) {
    428     VLOG(2) << "Occupancy is limited by number of blocks available per SM.";
    429     return;
    430   }
    431 
    432   uint64 improved_regs_per_thread = CalculateRegisterLimitForTargetOccupancy(
    433       device_description, smem_per_block, thread_dims, blocks_per_sm + 1);
    434   if (improved_regs_per_thread != 0) {
    435     VLOG(2) << "Reducing register usage from " << regs_per_thread
    436             << " to " << improved_regs_per_thread
    437             << " could increase resident blocks per SM by one.";
    438   } else {
    439     VLOG(2) << "Resident blocks per SM cannot be increased by reducing "
    440         "register usage.";
    441   }
    442 }
    443 
    444 void *CUDAExecutor::Allocate(uint64 size) {
    445   return CUDADriver::DeviceAllocate(context_, size);
    446 }
    447 
    448 void *CUDAExecutor::AllocateSubBuffer(DeviceMemoryBase *mem,
    449                                       uint64 offset_bytes, uint64 size_bytes) {
    450   // offset and size are in bytes, so char* works as the pointer type.
    451   return reinterpret_cast<char *>(mem->opaque()) + offset_bytes;
    452 }
    453 
    454 void CUDAExecutor::Deallocate(DeviceMemoryBase *mem) {
    455   // CUDA "sub-buffers" are just pointer + offset, so no dealloc is necessary.
    456   if (!mem->is_sub_buffer()) {
    457     CUDADriver::DeviceDeallocate(context_, mem->opaque());
    458   }
    459 }
    460 
    461 bool CUDAExecutor::HostMemoryRegister(void *location, uint64 size) {
    462   if (location == nullptr || size == 0) {
    463     LOG(WARNING) << "attempting to register null or zero-sized memory: "
    464                  << location << "; size " << size;
    465   }
    466   VLOG(2) << "registering " << location << " size " << size;
    467   return CUDADriver::HostRegister(context_, location, size);
    468 }
    469 
    470 bool CUDAExecutor::HostMemoryUnregister(void *location) {
    471   VLOG(2) << "unregistering " << location;
    472   return CUDADriver::HostUnregister(context_, location);
    473 }
    474 
    475 bool CUDAExecutor::SynchronizeAllActivity() {
    476   return CUDADriver::SynchronizeContext(context_);
    477 }
    478 
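        // Zeroes |size| bytes at |location|. Here and in SynchronousMemSet below, the
        // 32-bit memset path is used when the address and size are 4-byte aligned, and
        // the byte-wide path otherwise.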
    479 bool CUDAExecutor::SynchronousMemZero(DeviceMemoryBase *location, uint64 size) {
    480   if (reinterpret_cast<uintptr_t>(location->opaque()) % 4 == 0 &&
    481       size % 4 == 0) {
    482     return CUDADriver::SynchronousMemsetUint32(
    483         context_, AsCudaDevicePtr(location), 0x0, size / 4);
    484   }
    485   return CUDADriver::SynchronousMemsetUint8(context_, AsCudaDevicePtr(location),
    486                                             0x0, size);
    487 }
    488 
    489 bool CUDAExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value,
    490                                      uint64 size) {
    491   if (reinterpret_cast<uintptr_t>(location->opaque()) % 4 == 0 &&
    492       size % 4 == 0) {
    493     // cudaMemset reinterprets "value" as a uint8.
    494     uint8 byte_value = static_cast<uint8>(value);
    495     uint32 pattern = (byte_value << 24) | (byte_value << 16) |
    496                      (byte_value << 8) | byte_value;
    497     return CUDADriver::SynchronousMemsetUint32(
    498         context_, AsCudaDevicePtr(location), pattern, size / 4);
    499   }
    500   return CUDADriver::SynchronousMemsetUint8(context_, AsCudaDevicePtr(location),
    501                                             value, size);
    502 }
    503 
    504 port::Status CUDAExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
    505                                              const void *host_src,
    506                                              uint64 size) {
    507   return CUDADriver::SynchronousMemcpyH2D(context_, AsCudaDevicePtr(gpu_dst),
    508                                           host_src, size);
    509 }
    510 
    511 port::Status CUDAExecutor::SynchronousMemcpy(void *host_dst,
    512                                              const DeviceMemoryBase &gpu_src,
    513                                              uint64 size) {
    514   return CUDADriver::SynchronousMemcpyD2H(context_, host_dst,
    515                                           AsCudaDevicePtr(gpu_src), size);
    516 }
    517 
    518 port::Status CUDAExecutor::SynchronousMemcpyDeviceToDevice(
    519     DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64 size) {
    520   return CUDADriver::SynchronousMemcpyD2D(context_, AsCudaDevicePtr(gpu_dst),
    521                                           AsCudaDevicePtr(gpu_src), size);
    522 }
    523 
    524 bool CUDAExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
    525                            uint64 size) {
    526   if (reinterpret_cast<uintptr_t>(location->opaque()) % 4 == 0 &&
    527       size % 4 == 0) {
    528     return Memset32(stream, location, 0x0, size);
    529   } else {
    530     return Memset(stream, location, 0x0, size);
    531   }
    532 }
    533 
    534 bool CUDAExecutor::Memset(Stream *stream, DeviceMemoryBase *location,
    535                            uint8 pattern, uint64 size) {
    536   VLOG(2) << "enqueueing memset8 operation onto stream " << stream
    537           << " at location " << location << " with size " << size
    538           << " and pattern " << std::hex << pattern;
    539   return CUDADriver::AsynchronousMemsetUint8(
    540       context_, AsCudaDevicePtr(location), pattern, size,
    541       AsCUDAStreamValue(stream));
    542 }
    543 
    544 bool CUDAExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
    545                             uint32 pattern, uint64 size) {
    546   VLOG(2) << "enqueueing memset32 operation onto stream " << stream
    547           << " at location " << location << " with size " << size
    548           << " and pattern " << std::hex << pattern;
    549   CHECK(reinterpret_cast<uintptr_t>(location->opaque()) % 4 == 0 &&
    550         size % 4 == 0);
    551   return CUDADriver::AsynchronousMemsetUint32(
    552       context_, AsCudaDevicePtr(location), pattern, size / 4,
    553       AsCUDAStreamValue(stream));
    554 }
    555 
    556 bool CUDAExecutor::Memcpy(Stream *stream, void *host_dst,
    557                           const DeviceMemoryBase &gpu_src, uint64 size) {
    558   return CUDADriver::AsynchronousMemcpyD2H(context_, host_dst,
    559                                            AsCudaDevicePtr(gpu_src), size,
    560                                            AsCUDAStreamValue(stream));
    561 }
    562 
    563 bool CUDAExecutor::Memcpy(Stream *stream, DeviceMemoryBase *gpu_dst,
    564                           const void *host_src, uint64 size) {
    565   return CUDADriver::AsynchronousMemcpyH2D(context_, AsCudaDevicePtr(gpu_dst),
    566                                            host_src, size,
    567                                            AsCUDAStreamValue(stream));
    568 }
    569 
    570 bool CUDAExecutor::MemcpyDeviceToDevice(Stream *stream,
    571                                         DeviceMemoryBase *gpu_dst,
    572                                         const DeviceMemoryBase &gpu_src,
    573                                         uint64 size) {
    574   return CUDADriver::AsynchronousMemcpyD2D(context_, AsCudaDevicePtr(gpu_dst),
    575                                            AsCudaDevicePtr(gpu_src), size,
    576                                            AsCUDAStreamValue(stream));
    577 }
    578 
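        // Enqueues |callback| on |stream|. The callback is copied onto the heap and
        // handed to InternalHostCallback, which deletes it after running it.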
    579 bool CUDAExecutor::HostCallback(Stream *stream,
    580                                 std::function<void()> callback) {
    581   auto callback_ptr = new std::function<void()>(callback);
    582   return CUDADriver::AddStreamCallback(context_, AsCUDAStreamValue(stream),
    583                                        InternalHostCallback, callback_ptr);
    584 }
    585 
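        // Trampoline invoked by the driver when the stream reaches the callback: runs
        // the heap-allocated std::function installed by HostCallback, then deletes it.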
    586 /* static */ void CUDAExecutor::InternalHostCallback(CUstream stream,
    587                                                      CUresult status,
    588                                                      void *data) {
    589   std::function<void()> *callback =
    590       reinterpret_cast<std::function<void()> *>(data);
    591   (*callback)();
    592   delete callback;
    593 }
    594 
    595 port::Status CUDAExecutor::AllocateEvent(Event *event) {
    596   return AsCUDAEvent(event)->Init();
    597 }
    598 
    599 port::Status CUDAExecutor::DeallocateEvent(Event *event) {
    600   return AsCUDAEvent(event)->Destroy();
    601 }
    602 
    603 port::Status CUDAExecutor::RecordEvent(Stream *stream, Event *event) {
    604   return AsCUDAEvent(event)->Record(AsCUDAStream(stream));
    605 }
    606 
    607 port::Status CUDAExecutor::WaitForEvent(Stream *stream, Event *event) {
    608   if (CUDADriver::WaitStreamOnEvent(context_,
    609                                     AsCUDAStream(stream)->cuda_stream(),
    610                                     AsCUDAEvent(event)->cuda_event())) {
    611     return port::Status::OK();
    612   } else {
    613     return port::Status{
    614         port::error::INTERNAL,
    615         port::Printf("error waiting for CUDA event on stream %p",
    616                      stream)};
    617   }
    618 }
    619 
    620 Event::Status CUDAExecutor::PollForEventStatus(Event *event) {
    621   return AsCUDAEvent(event)->PollForStatus();
    622 }
    623 
    624 bool CUDAExecutor::AllocateStream(Stream *stream) {
    625   return AsCUDAStream(stream)->Init();
    626 }
    627 
    628 void CUDAExecutor::DeallocateStream(Stream *stream) {
    629   CUDAStream *cuda_stream = AsCUDAStream(stream);
    630   if (!cuda_stream->IsIdle()) {
    631     LOG(ERROR) << "Deallocating stream with pending work";
    632   }
    633   cuda_stream->Destroy();
    634 }
    635 
    636 bool CUDAExecutor::AllocateTimer(Timer *timer) {
    637   return AsCUDATimer(timer)->Init();
    638 }
    639 
    640 void CUDAExecutor::DeallocateTimer(Timer *timer) {
    641   AsCUDATimer(timer)->Destroy();
    642 }
    643 
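        // Makes |dependent| wait for the work currently enqueued on |other| by
        // recording |other|'s completion event and having |dependent| wait on it.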
    644 bool CUDAExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
    645   CUevent other_completed_event = *AsCUDAStream(other)->completed_event();
    646   bool ok = CUDADriver::RecordEvent(context_, other_completed_event,
    647                                     AsCUDAStreamValue(other))
    648       .ok();
    649   if (!ok) {
    650     LOG(ERROR) << "failed to record completion event; "
    651                   "therefore, failed to create inter-stream dependency";
    652     return false;
    653   }
    654 
    655   return CUDADriver::WaitStreamOnEvent(context_, AsCUDAStreamValue(dependent),
    656                                        other_completed_event);
    657 }
    658 
    659 bool CUDAExecutor::StartTimer(Stream *stream, Timer *timer) {
    660   return AsCUDATimer(timer)->Start(AsCUDAStream(stream));
    661 }
    662 
    663 bool CUDAExecutor::StopTimer(Stream *stream, Timer *timer) {
    664   return AsCUDATimer(timer)->Stop(AsCUDAStream(stream));
    665 }
    666 
    667 port::Status CUDAExecutor::BlockHostUntilDone(Stream *stream) {
    668   return CUDADriver::SynchronizeStream(context_, AsCUDAStreamValue(stream));
    669 }
    670 
    671 blas::BlasSupport *CUDAExecutor::CreateBlas() {
    672   PluginRegistry *registry = PluginRegistry::Instance();
    673   port::StatusOr<PluginRegistry::BlasFactory> status =
    674       registry->GetFactory<PluginRegistry::BlasFactory>(kCudaPlatformId,
    675                                                         plugin_config_.blas());
    676   if (!status.ok()) {
    677     LOG(ERROR) << "Unable to retrieve BLAS factory: "
    678                << status.status().error_message();
    679     return nullptr;
    680   }
    681 
    682   return status.ValueOrDie()(this);
    683 }
    684 
    685 dnn::DnnSupport *CUDAExecutor::CreateDnn() {
    686   PluginRegistry *registry = PluginRegistry::Instance();
    687   port::StatusOr<PluginRegistry::DnnFactory> status =
    688       registry->GetFactory<PluginRegistry::DnnFactory>(kCudaPlatformId,
    689                                                        plugin_config_.dnn());
    690   if (!status.ok()) {
    691     LOG(ERROR) << "Unable to retrieve DNN factory: "
    692                << status.status().error_message();
    693     return nullptr;
    694   }
    695 
    696   return status.ValueOrDie()(this);
    697 }
    698 
    699 fft::FftSupport *CUDAExecutor::CreateFft() {
    700   PluginRegistry *registry = PluginRegistry::Instance();
    701   port::StatusOr<PluginRegistry::FftFactory> status =
    702       registry->GetFactory<PluginRegistry::FftFactory>(kCudaPlatformId,
    703                                                        plugin_config_.fft());
    704   if (!status.ok()) {
    705     LOG(ERROR) << "Unable to retrieve FFT factory: "
    706                << status.status().error_message();
    707     return nullptr;
    708   }
    709 
    710   return status.ValueOrDie()(this);
    711 }
    712 
    713 rng::RngSupport *CUDAExecutor::CreateRng() {
    714   PluginRegistry *registry = PluginRegistry::Instance();
    715   port::StatusOr<PluginRegistry::RngFactory> status =
    716       registry->GetFactory<PluginRegistry::RngFactory>(kCudaPlatformId,
    717                                                        plugin_config_.rng());
    718   if (!status.ok()) {
    719     LOG(ERROR) << "Unable to retrieve RNG factory: "
    720                << status.status().error_message();
    721     return nullptr;
    722   }
    723 
    724   return status.ValueOrDie()(this);
    725 }
    726 
    727 // TODO(rspringer): Remove in b/18544742.
    728 bool CUDAExecutor::SupportsDnn() const {
    729   return true;
    730 }
    731 
    732 bool CUDAExecutor::CanEnablePeerAccessTo(StreamExecutorInterface *other) {
    733   CUDAExecutor *cuda_other = static_cast<CUDAExecutor *>(other);
    734   return CUDADriver::CanEnablePeerAccess(context_, cuda_other->context_);
    735 }
    736 
    737 port::Status CUDAExecutor::EnablePeerAccessTo(StreamExecutorInterface *other) {
    738   CUDAExecutor *cuda_other = static_cast<CUDAExecutor *>(other);
    739   return CUDADriver::EnablePeerAccess(context_, cuda_other->context_);
    740 }
    741 
    742 SharedMemoryConfig CUDAExecutor::GetDeviceSharedMemoryConfig() {
    743   port::StatusOr<CUsharedconfig> cuda_config =
    744       CUDADriver::ContextGetSharedMemConfig(context_);
    745   if (!cuda_config.ok()) {
    746     // Don't log; the failed call will log necessary output.
    747     return SharedMemoryConfig::kDefault;
    748   }
    749 
    750   switch (cuda_config.ValueOrDie()) {
    751     case CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE:
    752       return SharedMemoryConfig::kDefault;
    753     case CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE:
    754       return SharedMemoryConfig::kFourByte;
    755     case CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE:
    756       return SharedMemoryConfig::kEightByte;
    757     default:
    758       LOG(FATAL) << "Invalid shared memory configuration returned: "
    759                  << cuda_config.ValueOrDie();
    760   }
    761 }
    762 
    763 port::Status CUDAExecutor::SetDeviceSharedMemoryConfig(
    764     SharedMemoryConfig config) {
    765   CUsharedconfig cuda_config;
    766   switch (config) {
    767     case SharedMemoryConfig::kDefault:
    768       cuda_config = CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE;
    769       break;
    770     case SharedMemoryConfig::kFourByte:
    771       cuda_config = CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE;
    772       break;
    773     case SharedMemoryConfig::kEightByte:
    774       cuda_config = CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE;
    775       break;
    776     default:
    777       LOG(FATAL) << "Invalid shared memory configuration specified: "
    778                  << static_cast<int>(config);
    779   }
    780   return CUDADriver::ContextSetSharedMemConfig(context_, cuda_config);
    781 }
    782 
    783 bool CUDAExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
    784   return CUDADriver::GetDeviceMemoryInfo(context_, free, total);
    785 }
    786 
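        // Searches every loaded module for |symbol_name|, returning its device
        // address in *mem and its size in *bytes if it is found.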
    787 bool CUDAExecutor::GetSymbol(const string& symbol_name, void **mem,
    788                              size_t *bytes) {
    789   {  // give limited scope to mutex_lock
    790     mutex_lock lock{in_memory_modules_mu_};
    791     for (auto &it : gpu_binary_to_module_) {
    792       CUmodule module = it.second.first;
    793       CHECK(module != nullptr);
    794       if (CUDADriver::GetModuleSymbol(context_, module, symbol_name.c_str(),
    795                                       reinterpret_cast<CUdeviceptr *>(mem),
    796                                       bytes)) {
    797         return true;
    798       }
    799     }
    800   }
    801 
    802   LOG(INFO) << "Failed to find symbol in any modules: " << symbol_name;
    803   return false;
    804 }
    805 
    806 bool CUDAExecutor::FillBlockDimLimit(BlockDim *block_dim_limit) const {
    807   // The BlockDim name is a mismatch against these GRID_DIM_* queries because
    808   // we use BlockDims to express the dimensions of blocks within a grid
    809   // (as opposed to ThreadDim which expresses the dimensions of threads
    810   // within a block).
    811   int x, y, z;
    812   if (!CUDADriver::GetGridLimits(&x, &y, &z, device_)) {
    813     return false;
    814   }
    815 
    816   block_dim_limit->x = x;
    817   block_dim_limit->y = y;
    818   block_dim_limit->z = z;
    819   return true;
    820 }
    821 
    822 bool CUDAExecutor::SupportsBlas() const { return true; }
    823 
    824 bool CUDAExecutor::SupportsFft() const { return true; }
    825 
    826 bool CUDAExecutor::SupportsRng() const { return true; }
    827 
    828 std::unique_ptr<internal::EventInterface>
    829 CUDAExecutor::CreateEventImplementation() {
    830   return std::unique_ptr<internal::EventInterface>(new CUDAEvent(this));
    831 }
    832 
    833 std::unique_ptr<internal::KernelInterface>
    834 CUDAExecutor::CreateKernelImplementation() {
    835   return std::unique_ptr<internal::KernelInterface>(new CUDAKernel());
    836 }
    837 
    838 std::unique_ptr<internal::StreamInterface>
    839 CUDAExecutor::GetStreamImplementation() {
    840   return std::unique_ptr<internal::StreamInterface>(new CUDAStream(this));
    841 }
    842 
    843 std::unique_ptr<internal::TimerInterface>
    844 CUDAExecutor::GetTimerImplementation() {
    845   return std::unique_ptr<internal::TimerInterface>(new CUDATimer(this));
    846 }
    847 
    848 void *CUDAExecutor::CudaContextHack() { return context_; }
    849 
    850 CudaContext* CUDAExecutor::cuda_context() { return context_; }
    851 
    852 // Attempts to read the NUMA node corresponding to the GPU device's PCI bus out
    853 // of SysFS. Returns -1 if it cannot.
    854 //
    855 // For anything more complicated/prod-focused than this, you'll likely want to
    856 // turn to gsys' topology modeling.
    857 static int TryToReadNumaNode(const string &pci_bus_id, int device_ordinal) {
    858 #if defined(__APPLE__)
    859   LOG(INFO) << "OS X does not support NUMA - returning NUMA node zero";
    860   return 0;
    861 #elif defined(PLATFORM_WINDOWS)
    862   // Windows support for NUMA is not currently implemented. Return node 0.
    863   return 0;
    864 #elif defined(__aarch64__)
    865   LOG(INFO) << "ARM64 does not support NUMA - returning NUMA node zero";
    866   return 0;
    867 #else
    868   VLOG(2) << "trying to read NUMA node for device ordinal: " << device_ordinal;
    869   static const int kUnknownNumaNode = -1;
    870 
    871   if (pci_bus_id.empty()) {
    872     LOG(INFO) << "no PCI bus ID for device ordinal: " << device_ordinal;
    873     return kUnknownNumaNode;
    874   }
    875 
    876   string filename =
    877       port::Printf("/sys/bus/pci/devices/%s/numa_node", pci_bus_id.c_str());
    878 
    879   // We have to use fopen/fread here so that the device properties can be
    880   // populated before the InitGoogle procedure has completed (at which point
    881   // we could use the file::* utilities).
    882   FILE *file = fopen(filename.c_str(), "r");
    883   if (file == nullptr) {
    884     LOG(ERROR) << "could not open file to read NUMA node: " << filename
    885                << "\nYour kernel may have been built without NUMA support.";
    886     return kUnknownNumaNode;
    887   }
    888 
    889   string content;
    890   char buf[32];
    891   size_t did_read = fread(buf, sizeof(buf[0]), sizeof(buf) - 1, file);
    892   buf[did_read] = '\0';
    893   content = buf;
    894 
    895   int32 value;
    896   if (port::safe_strto32(content, &value)) {
    897     if (value < 0) {  // See http://b/18228951 for details on this path.
    898       LOG(INFO) << "successful NUMA node read from SysFS had negative value ("
    899                 << value << "), but there must be at least one NUMA node"
    900                             ", so returning NUMA node zero";
    901       fclose(file);
    902       return 0;
    903     }
    904     fclose(file);
    905     return value;
    906   }
    907 
    908   LOG(WARNING)
    909       << "could not convert SysFS file contents to integral NUMA node value: "
    910       << content;
    911 
    912   fclose(file);
    913   return kUnknownNumaNode;
    914 #endif
    915 }
    916 
    917 // Set of compute capability specific device parameters that cannot be
    918 // queried from the driver API.  These values instead are baked into a
    919 // lookup table indexed by compute capability version.
    920 struct UnqueryableDeviceParams {
    921   int cc_major;
    922   int cc_minor;
    923   uint64 blocks_per_core_limit;
    924   uint64 registers_per_core_limit;
    925   uint64 registers_per_thread_limit;
    926   uint64 warp_alloc_granularity;
    927   uint64 register_alloc_granularity;
    928   uint64 shared_memory_alloc_granularity;
    929 };
    930 
    931 // http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities
    932 // https://developer.download.nvidia.com/compute/cuda/CUDA_Occupancy_calculator.xls
    933 static const UnqueryableDeviceParams kAllUnqueryableDeviceParams[] = {
    934     {
    935         2, 0,       // compute capability (2.0)
    936         8,          // blocks_per_core_limit
    937         32 * 1024,  // registers_per_core_limit
    938         63,         // registers_per_thread_limit
    939         2,          // warp_alloc_granularity
    940         64,         // register_alloc_granularity
    941         128,        // shared_memory_alloc_granularity
    942     },
    943     {
    944         2, 1,       // compute capability (2.1)
    945         8,          // blocks_per_core_limit
    946         32 * 1024,  // registers_per_core_limit
    947         63,         // registers_per_thread_limit
    948         2,          // warp_alloc_granularity
    949         64,         // register_alloc_granularity
    950         128,        // shared_memory_alloc_granularity
    951     },
    952     {
    953         3, 0,       // compute capability (3.0)
    954         16,         // blocks_per_core_limit
    955         64 * 1024,  // registers_per_core_limit
    956         63,         // registers_per_thread_limit
    957         4,          // warp_alloc_granularity
    958         256,        // register_alloc_granularity
    959         256,        // shared_memory_alloc_granularity
    960     },
    961     {
    962         3, 2,       // compute capability (3.2)
    963         16,         // blocks_per_core_limit
    964         64 * 1024,  // registers_per_core_limit
    965         255,        // registers_per_thread_limit
    966         4,          // warp_alloc_granularity
    967         256,        // register_alloc_granularity
    968         256,        // shared_memory_alloc_granularity
    969     },
    970     {
    971         3, 5,       // compute capability (3.5)
    972         16,         // blocks_per_core_limit
    973         64 * 1024,  // registers_per_core_limit
    974         255,        // registers_per_thread_limit
    975         4,          // warp_alloc_granularity
    976         256,        // register_alloc_granularity
    977         256,        // shared_memory_alloc_granularity
    978     },
    979     {
    980         3, 7,        // compute capability (3.7)
    981         16,          // blocks_per_core_limit
    982         128 * 1024,  // registers_per_core_limit
    983         255,         // registers_per_thread_limit
    984         4,           // warp_alloc_granularity
    985         256,         // register_alloc_granularity
    986         256,         // shared_memory_alloc_granularity
    987     },
    988     {
    989         5, 0,       // compute capability (5.0)
    990         32,         // blocks_per_core_limit
    991         64 * 1024,  // registers_per_core_limit
    992         255,        // registers_per_thread_limit
    993         4,          // warp_alloc_granularity
    994         256,        // register_alloc_granularity
    995         256,        // shared_memory_alloc_granularity
    996     },
    997     {
    998         5, 2,       // compute capability (5.2)
    999         32,         // blocks_per_core_limit
   1000         64 * 1024,  // registers_per_core_limit
   1001         255,        // registers_per_thread_limit
   1002         4,          // warp_alloc_granularity
   1003         256,        // register_alloc_granularity
   1004         256,        // shared_memory_alloc_granularity
   1005     },
   1006     {
   1007         5, 3,       // compute capability (5.3)
   1008         32,         // blocks_per_core_limit
   1009         64 * 1024,  // registers_per_core_limit
   1010         255,        // registers_per_thread_limit
   1011         4,          // warp_alloc_granularity
   1012         256,        // register_alloc_granularity
   1013         256,        // shared_memory_alloc_granularity
   1014     },
   1015     {
   1016         6, 0,       // compute capability (6.0)
   1017         32,         // blocks_per_core_limit
   1018         64 * 1024,  // registers_per_core_limit
   1019         255,        // registers_per_thread_limit
   1020         2,          // warp_alloc_granularity
   1021         256,        // register_alloc_granularity
   1022         256,        // shared_memory_alloc_granularity
   1023     },
   1024     {
   1025         6, 1,       // compute capability (6.1)
   1026         32,         // blocks_per_core_limit
   1027         64 * 1024,  // registers_per_core_limit
   1028         255,        // registers_per_thread_limit
   1029         4,          // warp_alloc_granularity
   1030         256,        // register_alloc_granularity
   1031         256,        // shared_memory_alloc_granularity
   1032     },
   1033     {
   1034         6, 2,       // compute capability (6.2)
   1035         32,         // blocks_per_core_limit
   1036         64 * 1024,  // registers_per_core_limit
   1037         255,        // registers_per_thread_limit
   1038         4,          // warp_alloc_granularity
   1039         256,        // register_alloc_granularity
   1040         256,        // shared_memory_alloc_granularity
   1041     },
   1042     // TODO(jlebar): Confirm the alloc granularity values for sm_70.  These are
   1043     // not published in the spreadsheet linked above.  Currently we guess that
   1044     // they're the same as sm_60.
   1045     {
   1046         7, 0,       // compute capability (7.0)
   1047         32,         // blocks_per_core_limit
   1048         64 * 1024,  // registers_per_core_limit
   1049         255,        // registers_per_thread_limit
   1050         2,          // warp_alloc_granularity
   1051         256,        // register_alloc_granularity
   1052         256,        // shared_memory_alloc_granularity
   1053     },
   1054 };
   1055 
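        // Builds a DeviceDescription for this device by combining driver queries
        // (driver version, PCI bus ID, NUMA node, device properties) with the static
        // per-compute-capability parameters tabulated above.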
   1056 DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const {
   1057   internal::DeviceDescriptionBuilder builder;
   1058 
   1059   {
   1060     int driver_version = 0;
   1061     (void)CUDADriver::GetDriverVersion(&driver_version);
   1062     string augmented_driver_version = port::Printf(
   1063         "%d (%s)", driver_version,
   1064         DriverVersionStatusToString(Diagnostician::FindDsoVersion()).c_str());
   1065     builder.set_driver_version(augmented_driver_version);
   1066   }
   1067 
   1068   {
   1069     string pci_bus_id = CUDADriver::GetPCIBusID(device_);
   1070 
   1071     // Lowercase the hex characters to match sysfs.
   1072     pci_bus_id = port::Lowercase(pci_bus_id);
   1073     builder.set_pci_bus_id(pci_bus_id);
   1074 
   1075     // Read the NUMA node corresponding to the PCI bus ID out of sysfs.
   1076     int numa_node = TryToReadNumaNode(pci_bus_id, device_ordinal_);
   1077     builder.set_numa_node(numa_node);
   1078   }
   1079 
   1080   CUdevprop prop;
   1081   if (CUDADriver::GetDeviceProperties(&prop, device_ordinal_)) {
   1082     builder.set_threads_per_block_limit(prop.maxThreadsPerBlock);
   1083 
   1084     ThreadDim thread_dim_limit;
   1085     thread_dim_limit.x = prop.maxThreadsDim[0];
   1086     thread_dim_limit.y = prop.maxThreadsDim[1];
   1087     thread_dim_limit.z = prop.maxThreadsDim[2];
   1088     builder.set_thread_dim_limit(thread_dim_limit);
   1089 
   1090     float clock_rate_ghz = static_cast<float>(prop.clockRate) / 1e6;
   1091     builder.set_clock_rate_ghz(clock_rate_ghz);
   1092   }
   1093 
   1094   {
   1095     bool ecc_enabled = false;
   1096     (void)CUDADriver::IsEccEnabled(device_, &ecc_enabled);
   1097     builder.set_ecc_enabled(ecc_enabled);
   1098   }
   1099 
   1100   {
   1101     uint64 device_memory_size = -1;
   1102     (void)CUDADriver::GetDeviceTotalMemory(device_, &device_memory_size);
   1103     builder.set_device_memory_size(device_memory_size);
   1104   }
   1105 
   1106   {
   1107     BlockDim block_dim_limit;
   1108     FillBlockDimLimit(&block_dim_limit);
   1109     builder.set_block_dim_limit(block_dim_limit);
   1110   }
   1111 
   1112   {
   1113     string device_name;
   1114     (void)CUDADriver::GetDeviceName(device_, &device_name);
   1115     builder.set_name(device_name);
   1116   }
   1117 
   1118   for (size_t i = 0; i < ARRAYSIZE(kAllUnqueryableDeviceParams); i++) {
   1119     const auto &params = kAllUnqueryableDeviceParams[i];
   1120     if (params.cc_major == cc_major_ && params.cc_minor == cc_minor_) {
   1121       builder.set_blocks_per_core_limit(params.blocks_per_core_limit);
   1122       builder.set_registers_per_core_limit(params.registers_per_core_limit);
   1123       builder.set_registers_per_thread_limit(params.registers_per_thread_limit);
   1124       builder.set_warp_alloc_granularity(params.warp_alloc_granularity);
   1125       builder.set_register_alloc_granularity(params.register_alloc_granularity);
   1126       builder.set_shared_memory_alloc_granularity(
   1127           params.shared_memory_alloc_granularity);
   1128     }
   1129   }
   1130 
   1131   builder.set_platform_version(
   1132       port::StrCat("Compute Capability ", cc_major_, ".", cc_minor_));
   1133 
   1134   // TODO(leary) should be a way to query this from the driver, but this is
   1135   // unlikely to change for us any time soon.
   1136   builder.set_device_address_bits(64);
   1137 
   1138   builder.set_device_vendor("NVIDIA Corporation");
   1139   builder.set_cuda_compute_capability(cc_major_, cc_minor_);
   1140   builder.set_shared_memory_per_core(
   1141       CUDADriver::GetMaxSharedMemoryPerCore(device_).ValueOrDie());
   1142   builder.set_shared_memory_per_block(
   1143       CUDADriver::GetMaxSharedMemoryPerBlock(device_).ValueOrDie());
   1144   builder.set_core_count(
   1145       CUDADriver::GetMultiprocessorCount(device_).ValueOrDie());
   1146   builder.set_threads_per_core_limit(
   1147       CUDADriver::GetMaxThreadsPerMultiprocessor(device_).ValueOrDie());
   1148   builder.set_registers_per_block_limit(
   1149       CUDADriver::GetMaxRegistersPerBlock(device_).ValueOrDie());
   1150   builder.set_threads_per_warp(
   1151       CUDADriver::GetThreadsPerWarp(device_).ValueOrDie());
   1152 
   1153   auto built = builder.Build();
   1154   return built.release();
   1155 }
   1156 
   1157 }  // namespace cuda
   1158 
   1159 namespace gpu = ::perftools::gputools;
   1160 
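        // Installs the factory that the StreamExecutor machinery uses to construct
        // CUDAExecutor instances for the CUDA platform.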
   1161 void initialize_cuda_gpu_executor() {
   1162   *gpu::internal::MakeCUDAExecutorImplementation() = [](
   1163       const gpu::PluginConfig &config) {
   1164     return new gpu::cuda::CUDAExecutor{config};
   1165   };
   1166 }
   1167 
   1168 }  // namespace gputools
   1169 }  // namespace perftools
   1170 
   1171 REGISTER_MODULE_INITIALIZER(
   1172     cuda_gpu_executor, {perftools::gputools::initialize_cuda_gpu_executor();});
   1173