1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_ 17 #define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_ 18 19 #include <atomic> 20 #include <memory> 21 #include <set> 22 #include <tuple> 23 #include <vector> 24 25 #include "tensorflow/stream_executor/lib/status.h" 26 #include "tensorflow/stream_executor/lib/statusor.h" 27 #include "tensorflow/stream_executor/lib/strcat.h" 28 #include "tensorflow/stream_executor/lib/threadpool.h" 29 #include "tensorflow/stream_executor/platform.h" 30 #include "tensorflow/stream_executor/platform/logging.h" 31 #include "tensorflow/stream_executor/platform/mutex.h" 32 #include "tensorflow/stream_executor/platform/port.h" 33 #include "tensorflow/stream_executor/platform/thread_annotations.h" 34 #include "tensorflow/stream_executor/rng.h" 35 #include "tensorflow/stream_executor/shared_memory_config.h" 36 #include "tensorflow/stream_executor/stream.h" 37 #include "tensorflow/stream_executor/stream_executor_internal.h" 38 #include "tensorflow/stream_executor/trace_listener.h" 39 40 namespace perftools { 41 namespace gputools { 42 43 // Structure used for device memory leak checking. 44 struct AllocRecord { 45 // The requested allocation size of the buffer. 46 uint64 bytes; 47 48 // Holds a representation of the stack at the time the associated buffer was 49 // allocated. Produced in a form described in 50 // //util/symbolize/symbolized_stacktrace.h. 51 string stack_trace; 52 }; 53 54 // Forward declaration of private friend class. 55 template <typename BeginCallT, typename CompleteCallT, 56 typename ReturnT, typename... BeginArgsT> 57 class ScopedTracer; 58 59 // A StreamExecutor manages a single device, in terms of executing work (kernel 60 // launches) and memory management (allocation/deallocation, memory copies to 61 // and from the device). It is conceptually the "handle" for a device -- Stream 62 // objects, which are used to enqueue work to run on the 63 // coprocessor have a StreamExecutor instance as their "parent" object. 64 // 65 // StreamExecutor objects have an underlying platform that is specified up 66 // front; 67 // e.g. either it is a CUDA or OpenCL executor. 68 // 69 // Thread-safe after initialization. 70 // StreamExecutor interface should not be invoked from a signal handler. 71 class StreamExecutor { 72 public: 73 explicit StreamExecutor(PlatformKind kind, 74 const PluginConfig &plugin_config = PluginConfig()); 75 76 StreamExecutor( 77 const Platform *platform, 78 std::unique_ptr<internal::StreamExecutorInterface> implementation); 79 80 ~StreamExecutor(); 81 82 port::Status Init(); 83 port::Status Init(int device_ordinal, DeviceOptions device_options); 84 85 // DEPRECATED: Do not use; use platform() instead. 86 // Returns the platform that this StreamExecutor is acting upon. 87 PlatformKind platform_kind() const { return platform_kind_; } 88 89 // Returns a reference to the platform that created this executor. 90 const Platform *platform() const { return platform_; } 91 92 // Retrieves (loads) a kernel for the platform this StreamExecutor is acting 93 // upon, if one exists. 94 // 95 // Parameters: 96 // spec: The MultiKernelLoaderSpec is usually generated as a compile-time 97 // constant into an appropriate namespace. For example, see 98 // perftools::gputools::executor_sample::kKernelLoaderSpecs, from which a 99 // MultiKernelLoaderSpec is selected. 100 // kernel: Outparam that the kernel is loaded into. A given Kernel 101 // instantiation should not be loaded into more than once. 102 // 103 // If an error occurs, or there is no kernel available for the StreamExecutor 104 // platform, false is returned. 105 bool GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel); 106 107 // Releases any state associated with the previously loaded kernel. 108 void UnloadKernel(const KernelBase *kernel); 109 110 // Synchronously allocates an array on the device of type T with element_count 111 // elements. 112 template <typename T> 113 DeviceMemory<T> AllocateArray(uint64 element_count); 114 115 // As AllocateArray(), but returns a ScopedDeviceMemory<T>. 116 template <typename T> 117 ScopedDeviceMemory<T> AllocateOwnedArray(uint64 element_count) { 118 return ScopedDeviceMemory<T>(this, AllocateArray<T>(element_count)); 119 } 120 121 // Convenience wrapper that allocates space for a single element of type T in 122 // device memory. 123 template <typename T> 124 DeviceMemory<T> AllocateScalar() { 125 return AllocateArray<T>(1); 126 } 127 128 // As AllocateScalar(), but returns a ScopedDeviceMemory<T>. 129 template <typename T> 130 ScopedDeviceMemory<T> AllocateOwnedScalar() { 131 return AllocateOwnedArray<T>(1); 132 } 133 134 // Synchronously allocates a scalar of type T on the device that is (POD) 135 // zero-byte initialized. 136 template <typename T> 137 DeviceMemory<T> AllocateZeroed(); 138 139 // As AllocateZeroed(), but returns a ScopedDeviceMemory<T>. 140 template <typename T> 141 ScopedDeviceMemory<T> AllocateOwnedZeroed() { 142 return ScopedDeviceMemory<T>(this, AllocateZeroed<T>()); 143 } 144 145 // Allocate a memory region inside another allocated memory region. 146 // Offset and size are specified in terms of T elements. 147 // Warning: Do not free a parent buffer before its sub-buffers; this may cause 148 // use-after-free issues (the specific behavior is not consistent across 149 // platforms). 150 // - Note: OpenCL uses refcounting to manage buffer lifetimes, so use of a 151 // sub-buffer after parent deallocation is expected to be safe. This will 152 // render your code non-platform-portable, however. 153 template <typename T> 154 DeviceMemory<T> AllocateSubBuffer(DeviceMemory<T> *parent, 155 uint64 element_offset, 156 uint64 element_count); 157 158 // As AllocateSubBuffer(), but returns a ScopedDeviceMemory<T>. 159 template <typename T> 160 ScopedDeviceMemory<T> AllocateOwnedSubBuffer(DeviceMemory<T> *parent, 161 uint64 element_offset, 162 uint64 element_count) { 163 return ScopedDeviceMemory<T>( 164 this, AllocateSubBuffer<T>(parent, element_offset, element_count)); 165 } 166 167 // Finds a symbol and returns device memory allocated to the symbol. The 168 // symbol is searched in any kernels that were previously loaded through 169 // GetKernel() before the GetSymbol() call. The user has to make sure that the 170 // type of symbol and T match. 171 // - Note: symbol_name should include its namespace as well. For example, 172 // pass "nms0::symbol" if referring to nms0::symbol. 173 template <typename T> 174 port::StatusOr<DeviceMemory<T>> GetSymbol(const string &symbol_name); 175 176 // Deallocate the DeviceMemory previously allocated via this interface. 177 // Deallocation of a nullptr-representative value is permitted. 178 // 179 // Resets the internal contents of mem to be null-representative, but this 180 // null-out effect should not be relied upon in client code. 181 void Deallocate(DeviceMemoryBase *mem); 182 183 // Retrieves a mapping of active opaque device memory pointer to a string 184 // representation of the [allocating thread's] stack at the time the pointer 185 // was allocated. Useful for tracking device memory leaks. 186 // 187 // Note: this will only be populated if --check_device_leaks flag is 188 // activated. 189 void GetMemAllocs(std::map<void *, AllocRecord> *records_out); 190 191 // Allocates a region of host memory and registers it with the platform API. 192 // Memory allocated in this manner (or allocated and registered with 193 // HostMemoryRegister() is required for use in asynchronous memcpy operations, 194 // such as Stream::ThenMemcpy. 195 void *HostMemoryAllocate(uint64 bytes); 196 197 // Deallocates a region of host memory allocated by HostMemoryAllocate(). 198 void HostMemoryDeallocate(void *location); 199 200 // Registers a region of host memory with the platform API. Registered memory 201 // (or memory allocated with HostMemoryAllocate) is required for use with 202 // asynchronous memcpy operations, such as Stream::ThenMemcpy. This method 203 // is used to register memory allocated outside the StreamExecutor; 204 // HostMemoryAllocate implicitly registers its allocations and 205 // HostMemoryDeallocate implicitly deregisters on deallocation. 206 bool HostMemoryRegister(void *location, uint64 size) SE_MUST_USE_RESULT; 207 208 // Unregisters a region of host memory registered with HostMemoryRegister. 209 // This should be done before deallocating the region with delete[]/free/etc. 210 bool HostMemoryUnregister(void *location) SE_MUST_USE_RESULT; 211 212 // Synchronizes all activity occurring in the StreamExecutor's context (most 213 // likely a whole device). 214 bool SynchronizeAllActivity() SE_MUST_USE_RESULT; 215 216 // Blocks the caller while "size" bytes are zeroed out (in POD fashion) at the 217 // given location in device memory. 218 bool SynchronousMemZero(DeviceMemoryBase *location, 219 uint64 size) SE_MUST_USE_RESULT; 220 221 // Blocks the caller while "size" bytes are initialized to "value" (in POD 222 // fashion) at the given location in device memory. 223 bool SynchronousMemSet(DeviceMemoryBase *location, int value, 224 uint64 size) SE_MUST_USE_RESULT; 225 226 // [deprecated] Blocks the caller while a data segment of the given size is 227 // copied from the host source to the device destination. 228 // 229 // Deprecation: prefer explicit H2D below, to avoid error-prone API usage. 230 bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src, 231 uint64 size) SE_MUST_USE_RESULT; 232 233 // [deprecated] Blocks the caller while a data segment of the given size is 234 // copied from the device source to the host destination. 235 // 236 // Deprecation: prefer explicit D2H below, to avoid error-prone API usage. 237 bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src, 238 uint64 size) SE_MUST_USE_RESULT; 239 240 // Same as SynchronousMemcpy(DeviceMemoryBase*, ...) above. 241 port::Status SynchronousMemcpyH2D(const void *host_src, int64 size, 242 DeviceMemoryBase *device_dst); 243 244 // Alternative interface for memcpying from host to device that takes an 245 // array slice. Checks that the destination size can accommodate the host 246 // slice size. 247 template <class T> 248 port::Status SynchronousMemcpyH2D(port::ArraySlice<T> host_src, 249 DeviceMemoryBase *device_dst) { 250 auto host_size = host_src.size() * sizeof(T); 251 CHECK(device_dst->size() == 0 || device_dst->size() >= host_size); 252 return SynchronousMemcpyH2D(host_src.begin(), host_size, device_dst); 253 } 254 255 // Same as SynchronousMemcpy(void*, ...) above. 256 port::Status SynchronousMemcpyD2H(const DeviceMemoryBase &device_src, 257 int64 size, void *host_dst); 258 259 // Alternative interface for memcpying from device to host that takes an 260 // array slice. Checks that the destination size can accommodate the host 261 // slice size. 262 template <typename T> 263 port::Status SynchronousMemcpyD2H(const DeviceMemory<T> &device_src, 264 port::MutableArraySlice<T> host_dst) { 265 auto host_size = host_dst.size() * sizeof(T); 266 CHECK(device_src.size() == 0 || host_size >= device_src.size()); 267 return SynchronousMemcpyD2H(device_src, host_size, host_dst.begin()); 268 } 269 270 // Blocks the caller while a data segment of the given size is copied from the 271 // device source to the device destination. 272 bool SynchronousMemcpy(DeviceMemoryBase *device_dst, 273 const DeviceMemoryBase &device_src, 274 uint64 size) SE_MUST_USE_RESULT; 275 276 // Enqueues an operation onto stream to zero out size bytes at the given 277 // device memory location. Neither stream nor location may be null. Returns 278 // whether the operation was successfully enqueued onto the stream. 279 bool MemZero(Stream *stream, DeviceMemoryBase *location, 280 uint64 size) SE_MUST_USE_RESULT; 281 282 // Enqueues an operation onto stream to set 32-bit patterns starting at 283 // location, for byte count given by size. size must be 32-bit quantified 284 // (i.e. evently divisible by 4). Returns whether the operation was 285 // successfully enqueued onto the stream. 286 bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern, 287 uint64 size) SE_MUST_USE_RESULT; 288 289 // Enables peer access from this StreamExecutor to memory 290 // allocated by other, such that launched device code, memcpies, etc may 291 // access it directly. 292 // 293 // Both this StreamExecutor and other must be backed by the same platform (as 294 // in 295 // CUDA vs OpenCL) implementation. 296 port::Status EnablePeerAccessTo(StreamExecutor *other); 297 298 // Returns whether it's possible to enable peer access from this 299 // StreamExecutor 300 // to memory allocated by another. 301 // 302 // Even when this returns true, EnablePeerAccessTo may fail for other reasons; 303 // this is more an up-front test as to whether it's expressly forbidden. 304 bool CanEnablePeerAccessTo(StreamExecutor *other); 305 306 // Gets the preferred shared memory configuration for the device to which this 307 // executor is bound. 308 SharedMemoryConfig GetDeviceSharedMemoryConfig(); 309 310 // Sets the preferred shared memory configuration for the device to which this 311 // executor is bound. 312 port::Status SetDeviceSharedMemoryConfig(SharedMemoryConfig config); 313 314 // Obtains metadata about the underlying device. 315 // The value is cached on first use. 316 const DeviceDescription &GetDeviceDescription() const; 317 318 // If implemented, returns device specific measurement of load 319 // (e.g. pending requests). 320 int64 GetDeviceLoad() const; 321 322 // Returns the underlying device memory usage information, if it is available. 323 // If it is not available (false is returned), free/total may not be 324 // initialized. 325 // 326 // Note: "Free" reflects the amount of free memory on the underlying device, 327 // so allocations via other StreamExecutors that have the same underlying 328 // device 329 // will be reflected in "free". 330 bool DeviceMemoryUsage(int64 *free, int64 *total) const; 331 332 // The device count reported by this StreamExecutor's platform. 333 // Note: on OpenCL we implicitly select platform zero at the moment. 334 int PlatformDeviceCount() const; 335 336 // Returns whether the StreamExecutor supports BLAS routines for the platform 337 // that underlies this interface. 338 bool SupportsBlas() const; 339 340 // Returns whether the StreamExecutor supports FFT routines for the platform 341 // that underlies this interface. 342 bool SupportsFft() const; 343 344 // Returns whether the StreamExecutor supports RNG routines for the platform 345 // that underlies this interface. 346 bool SupportsRng() const; 347 348 // Returns whether the StreamExecutor support neural net routines for the 349 // platform that underlies this interface. 350 bool SupportsDnn() const; 351 352 // Get the list of supported algorithms for the forward convolution opeartion. 353 bool GetConvolveAlgorithms(bool with_winograd_nonfused, 354 std::vector<dnn::AlgorithmDesc> *out_algorithms); 355 356 // Get the list of supported algorithms for the backward convolution on data. 357 bool GetConvolveBackwardDataAlgorithms( 358 bool with_winograd_nonfused, 359 std::vector<dnn::AlgorithmDesc> *out_algorithms); 360 361 // Get the list of supported algorithms for the backward convolution on the 362 // filter. 363 bool GetConvolveBackwardFilterAlgorithms( 364 bool with_winograd_nonfused, 365 std::vector<dnn::AlgorithmDesc> *out_algorithms); 366 367 // Get the list of supported algorithms for BLAS gemm. 368 bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms); 369 370 // Create an RNN descriptor based on model shapes and configurations. 371 // The caller retains the ownership of the descriptor. 372 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor( 373 int num_layers, int hidden_size, int input_size, 374 dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode, 375 dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, 376 uint64 seed, ScratchAllocator *state_allocator); 377 378 // Create a RNN sequence descriptor that specifies either the input or output 379 // sequence. The caller retains the ownership of the returned descriptor. 380 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>> 381 createRnnSequenceTensorDescriptor(int seq_length, int batch_size, 382 int data_size, dnn::DataType data_type); 383 384 // Create an RNN state descriptor that specifies the input or hidden state. 385 // The caller retains the ownership of the returned descriptor. 386 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>> 387 createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size, 388 dnn::DataType data_type); 389 390 // Returns the device ordinal that this StreamExecutor was initialized with. 391 // Meaningless before initialization. 392 int device_ordinal() const { return device_ordinal_; } 393 394 // Returns a borrowed pointer to the underlying StreamExecutor implementation. 395 internal::StreamExecutorInterface *implementation(); 396 397 // Warning: use Stream::ThenLaunch instead, this method is not for general 398 // consumption. However, this is the only way to launch a kernel for which 399 // the type signature is only known at runtime; say, if an application 400 // supports loading/launching kernels with arbitrary type signatures. 401 // In this case, the application is expected to know how to do parameter 402 // packing that obeys the contract of the underlying platform implementation. 403 // 404 // Launches a data parallel kernel with the given thread/block 405 // dimensionality and already-packed args/sizes to pass to the underlying 406 // platform driver. 407 // 408 // This is called by Stream::Launch() to delegate to the platform's launch 409 // implementation in StreamExecutorInterface::Launch(). 410 bool Launch(Stream *stream, const ThreadDim &thread_dims, 411 const BlockDim &block_dims, const KernelBase &kernel, 412 const KernelArgsArrayBase &args); 413 414 // Gets-or-creates (creates with memoization) a FftSupport datatype that can 415 // be used to execute FFT routines on the current platform. 416 // 417 // Ownership and user-facing is the same as AsBlas() below. 418 // 419 // Returns null if there was an error initializing the FFT support for the 420 // underlying platform. 421 fft::FftSupport *AsFft(); 422 423 // Gets-or-creates (creates with memoization) a DnnSupport datatype that can 424 // be used for neural network routines on the current platform. 425 // 426 // Ownership and user-facing is the same as AsBlas() below. 427 // 428 // Returns null if there was an error initializing the DNN support for the 429 // underlying platform. 430 dnn::DnnSupport *AsDnn(); 431 432 // Turns StreamExecutor operation tracing on or off. 433 void EnableTracing(bool enable); 434 435 // Registers a trace listener to receive callbacks for only a single 436 // StreamExecutor instance. 437 // To register a listener for all executors for a given platform, see 438 // Platform::RegisterTraceListener(). 439 // Does not take ownership of listener. 440 void RegisterTraceListener(TraceListener* listener); 441 442 // Removes a TraceListener from this StreamExecutor instance. 443 // Returns false (and logs) in cases where the argument listener was not 444 // previously registered. 445 bool UnregisterTraceListener(TraceListener* listener); 446 447 private: 448 template <typename BeginCallT, typename CompleteCallT, 449 typename ReturnT, typename... BeginArgsT> 450 friend class ScopedTracer; 451 friend class Event; 452 friend class Stream; 453 friend class Timer; 454 template <typename... Params> 455 friend class TypedKernel; 456 template <typename... Args> 457 friend struct ThenBlasImpl; 458 459 // Gets-or-creates (creates with memoization) a BlasSupport datatype that can 460 // be used to execute BLAS routines on the current platform. This is typically 461 // not user-facing, as users will use the Stream::ThenBlas* family of routines 462 // to entrain BLAS operations. See blas.h for additional details. 463 // 464 // Ownership is not transferred to the caller -- ownership is retained by this 465 // object for memoization. This BLAS interface is also only expected to be 466 // used by a Stream for entraining calls to BLAS functionality. 467 // 468 // Returns null if there was an error initializing the BLAS support for the 469 // underlying platform. 470 blas::BlasSupport *AsBlas(); 471 472 // Gets-or-creates (creates with memoization) an RngSupport datatype that can 473 // be used for random-number-generation routines on the current platform. 474 // 475 // Ownership and user-facing is the same as AsBlas() above. 476 // 477 // Returns null if there was an error initializing the RNG support for the 478 // underlying platform. 479 rng::RngSupport *AsRng(); 480 481 // Causes the host code to synchronously wait for operations entrained onto 482 // stream to complete. Effectively a join on the asynchronous device 483 // operations enqueued on the stream before this program point. 484 port::Status BlockHostUntilDone(Stream *stream); 485 486 // Synchronously allocates size bytes on the underlying platform and returns 487 // an opaque void* representing that allocation. In the case of failure, 488 // nullptr is returned. 489 void *Allocate(uint64 size); 490 491 // Finds and retrieves device memory for the symbol on the underlying 492 // platform. 493 bool GetSymbol(const string& symbol_name, void **mem, size_t *bytes); 494 495 // Entrains a memcpy operation onto stream, with a host destination location 496 // host_dst and a device memory source, with target size size. 497 bool Memcpy(Stream *stream, void *host_dst, 498 const DeviceMemoryBase &device_src, uint64 size); 499 500 // Entrains a memcpy operation onto stream, with a device destination location 501 // and a host memory source, with target size size. 502 bool Memcpy(Stream *stream, DeviceMemoryBase *device_dst, 503 const void *host_src, uint64 size); 504 505 // Entrains a memcpy operation onto stream, with a device destination location 506 // and a device source location, with target size size. Peer access should 507 // have been enabled between the StreamExecutors owning the device memory 508 // regions. 509 bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *device_dst, 510 const DeviceMemoryBase &device_src, uint64 size); 511 512 // Entrains on a stream a user-specified function to be run on the host. 513 // See Stream::ThenDoHostCallback for full details. 514 bool HostCallback(Stream *stream, std::function<void()> callback); 515 516 // Performs platform-specific allocation and initialization of an event. 517 port::Status AllocateEvent(Event *event); 518 519 // Performs platform-specific deallocation and cleanup of an event. 520 port::Status DeallocateEvent(Event *event); 521 522 // Inserts the specified event at the end of the specified stream. 523 port::Status RecordEvent(Stream *stream, Event *event); 524 525 // Wait for the specified event at the end of the specified stream. 526 port::Status WaitForEvent(Stream *stream, Event *event); 527 528 // Requests the current status of the event from the underlying platform. 529 Event::Status PollForEventStatus(Event *event); 530 531 // Allocates stream resources on the underlying platform for subject and 532 // initializes its internals. 533 bool AllocateStream(Stream *subject); 534 535 // Deallocates stream resources on the underlying platform. 536 void DeallocateStream(Stream *subject); 537 538 // Causes dependent to not begin execution until other has finished its 539 // last-enqueued work. 540 bool CreateStreamDependency(Stream *dependent, Stream *other); 541 542 // Allocates timer resources on the underlying platform for subject and 543 // initializes its internals. 544 bool AllocateTimer(Timer *subject); 545 546 // Deallocates timer resources on the underlying platform. 547 void DeallocateTimer(Timer *subject); 548 549 // Records a start event for an interval timer. 550 bool StartTimer(Stream *stream, Timer *timer); 551 552 // Records a stop event for an interval timer. 553 bool StopTimer(Stream *stream, Timer *timer); 554 555 // Allocates a new metadata object, appropriately populated, on the heap, with 556 // ownership transfer to caller. 557 DeviceDescription *PopulateDeviceDescription() const; 558 559 // Adds a task to the port::ThreadPool work queue. These tasks must be 560 // fire-and-forget and have no external data or timing dependencies; their 561 // execution order and completion time have no guarantees. 562 // For an example of an appropriate task, see HostBlas::DoBlasGemmInternal; 563 // there, temporary internal buffers are freed using this method. 564 void EnqueueOnBackgroundThread(std::function<void()> task); 565 566 // Adds an AllocRecord for 'opaque' of size 'bytes' to the record map, for 567 // leak checking. NULL buffer pointers and buffer sizes of 0 will not be 568 // tracked. 569 void CreateAllocRecord(void *opaque, uint64 size); 570 571 // Removes the AllocRecord keyed by 'opaque' from the record map. NULL 572 // pointers will not be erased (as they're not tracked, per above). 573 void EraseAllocRecord(void *opaque); 574 575 // Calls the relevant TraceListener routine to begin tracing for the specified 576 // asynchronous method. 577 template <typename TraceCallT, typename... ArgsT> 578 void SubmitTrace(TraceCallT trace_call, ArgsT&&... args); 579 580 // Reader/writer lock for class-static StreamExecutor members. 581 static mutex static_mu_; 582 583 // Reader/writer lock for mutable data structures on this StreamExecutor. 584 // 585 // Mutable so that caching functions (like DeviceDescription, AsBlas, etc.) 586 // can acquire the lock on their first (mutating) call as well. 587 mutable mutex mu_; 588 589 // Reference to the platform that created this executor. 590 const Platform *platform_; 591 592 // Pointer to the platform-specific-interface implementation. This is 593 // delegated to by the interface routines in pointer-to-implementation 594 // fashion. 595 std::unique_ptr<internal::StreamExecutorInterface> implementation_; 596 597 // A mapping of pointer (to device memory) to string representation of the 598 // stack (of the allocating thread) at the time at which the pointer was 599 // allocated. 600 std::map<void *, AllocRecord> mem_allocs_ GUARDED_BY(mu_); 601 602 // Memoized BLAS support object -- we only want to create this once when asked 603 // for a BLAS interface. 604 std::unique_ptr<blas::BlasSupport> blas_ GUARDED_BY(mu_); 605 606 // Memoized DNN support object -- we only want to create this once when asked 607 // for an DNN interface. 608 std::unique_ptr<dnn::DnnSupport> dnn_ GUARDED_BY(mu_); 609 610 // Memoized FFT support object -- we only want to create this once when asked 611 // for a FFT interface. 612 std::unique_ptr<fft::FftSupport> fft_; 613 614 // Memoized RNG support object -- we only want to create this once when asked 615 // for an RNG interface. 616 std::unique_ptr<rng::RngSupport> rng_ GUARDED_BY(mu_); 617 618 // Slot to cache the owned DeviceDescription for the underlying device 619 // once it has been quieried from DeviceDescription(). 620 mutable std::unique_ptr<DeviceDescription> device_description_ 621 GUARDED_BY(mu_); 622 623 // The kind of the underlying platform that is being targeted, as passed 624 // during construction. 625 // 626 // Immutable post-initialization. 627 PlatformKind platform_kind_; 628 629 // The device ordinal that this object was initialized with. 630 // 631 // Immutable post-initialization. 632 int device_ordinal_; 633 634 // Executor for handling host callback work that cannot be performed 635 // by a host callback thread - for example, cleanup after a host BLAS routine 636 // (which may make device API calls). This work cannot block the host 637 // callback thread, will be completed asynchronously, and should be treated 638 // as fire-and-forget. Assume no ordering guarantees WRT the tasks enqueued 639 // here. 640 // 641 // Immutable post-initialization. Object is thread-safe. 642 std::unique_ptr<port::ThreadPool> background_threads_; 643 644 // Counter for the current number of live streams. This is used to check 645 // for accidentally-outstanding streams at StreamExecutor teardown time, as 646 // well 647 // as to indicate leaks (via a large outstanding count being logged) in the 648 // case we can't allocate more streams. 649 std::atomic_int_fast32_t live_stream_count_; 650 651 // Only one worker thread is needed; little work will be done by the 652 // executor. 653 static const int kNumBackgroundThreads = 1; 654 655 // Indicates if StreamExecutor operation tracing should be performed. 656 bool tracing_enabled_; 657 658 // The set of TraceListeners registered for this StreamExecutor. 659 std::set<TraceListener*> listeners_ GUARDED_BY(mu_); 660 661 SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor); 662 }; 663 664 //////////// 665 // Inlines 666 667 template <typename T> 668 inline DeviceMemory<T> StreamExecutor::AllocateArray(uint64 element_count) { 669 uint64 bytes = sizeof(T) * element_count; 670 void *opaque = Allocate(bytes); 671 return DeviceMemory<T>::MakeFromByteSize(opaque, bytes); 672 } 673 674 template <typename T> 675 inline port::StatusOr<DeviceMemory<T>> StreamExecutor::GetSymbol( 676 const string &symbol_name) { 677 // If failed to get the symbol, opaque/bytes are unchanged. Initialize them to 678 // be nullptr/0 for consistency with DeviceMemory semantics. 679 void *opaque = nullptr; 680 size_t bytes = 0; 681 if (GetSymbol(symbol_name, &opaque, &bytes)) { 682 CHECK_EQ(bytes % sizeof(T), 0); 683 return DeviceMemory<T>::MakeFromByteSize(opaque, bytes); 684 } 685 return port::Status( 686 port::error::NOT_FOUND, 687 port::StrCat("Check if kernel using the symbol is loaded: ", 688 symbol_name)); 689 } 690 691 template <typename ElemT> 692 ScopedDeviceMemory<ElemT>::ScopedDeviceMemory() 693 : wrapped_(DeviceMemoryBase()), parent_(nullptr) {} 694 695 template <typename ElemT> 696 ScopedDeviceMemory<ElemT>::ScopedDeviceMemory(StreamExecutor *parent, 697 DeviceMemoryBase value) 698 : wrapped_(value), parent_(parent) {} 699 700 template <typename ElemT> 701 ScopedDeviceMemory<ElemT>::ScopedDeviceMemory( 702 StreamExecutor *parent, std::initializer_list<ElemT> values) 703 : ScopedDeviceMemory(parent, parent->AllocateArray<ElemT>(values.size())) { 704 if (ptr() != nullptr) { 705 std::vector<ElemT> local(values); 706 if (!parent->SynchronousMemcpy(ptr(), const_cast<const ElemT *>(&local[0]), 707 ptr()->size())) { 708 Reset(nullptr); 709 } 710 } 711 } 712 713 template <typename ElemT> 714 ScopedDeviceMemory<ElemT>::~ScopedDeviceMemory() { 715 if (wrapped_ == nullptr) return; 716 DCHECK(parent_ != nullptr); 717 parent_->Deallocate(&wrapped_); 718 } 719 720 template <typename ElemT> 721 void ScopedDeviceMemory<ElemT>::Reset(DeviceMemory<ElemT> updated) { 722 if (wrapped_ != nullptr) { 723 DCHECK(parent_ != nullptr); 724 parent_->Deallocate(&wrapped_); 725 } 726 wrapped_ = updated; 727 } 728 729 template <typename ElemT> 730 void ScopedDeviceMemory<ElemT>::Reset(std::nullptr_t) { 731 if (wrapped_ != nullptr) { 732 DCHECK(parent_ != nullptr); 733 parent_->Deallocate(&wrapped_); 734 } 735 wrapped_ = DeviceMemory<ElemT>{}; 736 } 737 738 template <typename T> 739 DeviceMemory<T> StreamExecutor::AllocateZeroed() { 740 void *opaque = Allocate(sizeof(T)); 741 if (opaque == nullptr) { 742 return DeviceMemory<T>{}; 743 } 744 745 DeviceMemory<T> result = DeviceMemory<T>::MakeFromByteSize(opaque, sizeof(T)); 746 bool ok = SynchronousMemZero(&result, sizeof(T)); 747 if (!ok) { 748 Deallocate(&result); 749 return DeviceMemory<T>{}; 750 } 751 752 return result; 753 } 754 755 template <typename T> 756 DeviceMemory<T> StreamExecutor::AllocateSubBuffer(DeviceMemory<T> *parent, 757 uint64 element_offset, 758 uint64 element_count) { 759 if (element_offset + element_count > parent->ElementCount()) { 760 LOG(ERROR) << "requested sub-buffer allocation (offset + size) is greater " 761 << "than parent allocation size: (" << element_offset << " + " 762 << element_count << ") vs. (" << parent->ElementCount() << ")"; 763 return DeviceMemory<T>{}; 764 } 765 766 void *opaque = implementation_->AllocateSubBuffer( 767 parent, sizeof(T) * element_offset, sizeof(T) * element_count); 768 if (opaque == nullptr) { 769 return DeviceMemory<T>{}; 770 } 771 CreateAllocRecord(opaque, sizeof(T) * element_count); 772 return DeviceMemory<T>(DeviceMemoryBase(opaque, sizeof(T) * element_count, 773 true /* = is_sub_buffer */)); 774 } 775 776 template <typename... Params, typename... Args> 777 inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, 778 const TypedKernel<Params...> &kernel, 779 Args... args) { 780 KernelInvocationChecker<std::tuple<Params...>, 781 std::tuple<Args...>>::CheckAllStaticAssert(); 782 if (ok()) { 783 // This is the core that allows type-safe kernel launching. 784 // Since the platforms take kernel arguments as tuples of (void *, size), 785 // we pack the variadic parameters passed as ...args into the desired 786 // tuple form and pass that packed form to the StreamExecutor::Launch() 787 // implementation. 788 KernelArgsArray<sizeof...(args)> kernel_args; 789 kernel.PackParams(&kernel_args, args...); 790 DCHECK(parent_ != nullptr); 791 bool ok = 792 parent_->Launch(this, thread_dims, block_dims, kernel, kernel_args); 793 if (!ok) { 794 SetError(); 795 LOG(WARNING) << "parent failed to launch kernel: " << &kernel; 796 } 797 } 798 return *this; 799 } 800 801 } // namespace gputools 802 } // namespace perftools 803 804 #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_ 805