/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Suite of datatypes to represent data-parallel kernel objects (code entities).
// KernelBase is the untyped variant, whereas TypedKernel takes a type signature
// to do some template-based helper generation and give compile-time type
// checking for kernel launch parameters.
//
// Users typically don't see KernelBase, they see typed kernels, analogous to a
// typed function pointer. TypedKernels express their argument types via
// template parameters like so:
//
//   TypedKernel<DeviceMemory<int>*, int>
//
// Which expresses a data parallel kernel signature for:
//
//   void(int*, int);
//
// And for a const memory region:
//
//   TypedKernel<const DeviceMemory<int>&, int>
//
// Corresponds to a data parallel kernel signature for:
//
//   void(const int*, int);
//
// Note that kernels always have a void return type, so results typically must
// be memcpy'd from device memory to the host.
//
// Also note that a scalar integer residing in device memory and an array of
// integers residing in device memory have the same signature: DeviceMemory<T>.
// However, in the future, checks may be added for additional safety that arrays
// of minimum sizes are passed when those minimum sizes are contractually
// expected by the kernel.
//
// For user-defined types whose definitions are appropriately shared between the
// host code doing the launching and the kernel code being launched, the user
// defined types are similarly permitted to be expressed as residing in device
// memory:
//
//   TypedKernel<DeviceMemory<MyUserDefinedStructure>>
//
// And, when the alignment and padding are agreed upon, POD types will also be
// able to be passed by value; for example, it is a common idiom to specify a
// bunch of options simultaneously with a structure:
//
//   TypedKernel<MyOptionsStructurePassedByValue, DeviceMemory<float>>
//
// Which corresponds to a data parallel kernel signature like:
//
//   void(MyOptionsStructurePassedByValue value, float *result);
//
// Users typically won't need to type out the TypedKernel signature in full,
// because it will be typedef'd by automatically generated code; for example,
// see perftools::gputools::executor_sample::VecReduceAddKernel.
68 69 #ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_ 70 #define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_ 71 72 #include <array> 73 #include <memory> 74 #include <tuple> 75 #include <type_traits> 76 #include <vector> 77 78 #include "tensorflow/stream_executor/device_memory.h" 79 #include "tensorflow/stream_executor/kernel_cache_config.h" 80 #include "tensorflow/stream_executor/lib/array_slice.h" 81 #include "tensorflow/stream_executor/lib/inlined_vector.h" 82 #include "tensorflow/stream_executor/lib/stringpiece.h" 83 #include "tensorflow/stream_executor/platform/port.h" 84 85 namespace perftools { 86 namespace gputools { 87 88 class DeviceMemoryBase; 89 template <typename ElemT> 90 class DeviceMemory; 91 class StreamExecutor; 92 93 namespace internal { 94 class KernelInterface; 95 } // namespace internal 96 97 // KernelMetadata holds runtime-queryable attributes of a loaded kernel, such as 98 // registers allocated, shared memory used, etc. 99 // Not all platforms support reporting of all information, so each accessor 100 // returns false if the associated field is not populated in the underlying 101 // platform. 102 class KernelMetadata { 103 public: 104 KernelMetadata() 105 : has_registers_per_thread_(false), has_shared_memory_bytes_(false) {} 106 107 // Returns the number of registers used per thread executing this kernel. 108 bool registers_per_thread(int *registers_per_thread) const; 109 110 // Sets the number of registers used per thread executing this kernel. 111 void set_registers_per_thread(int registers_per_thread); 112 113 // Returns the amount of [static] shared memory used per block executing this 114 // kernel. Note that dynamic shared memory allocations are not (and can not) 115 // be reported here (since they're not specified until kernel launch time). 116 bool shared_memory_bytes(int *shared_memory_bytes) const; 117 118 // Sets the amount of [static] shared memory used per block executing this 119 // kernel. 
120 void set_shared_memory_bytes(int shared_memory_bytes); 121 122 private: 123 // Holds the value returned by registers_per_thread above. 124 bool has_registers_per_thread_; 125 int registers_per_thread_; 126 127 // Holds the value returned by shared_memory_bytes above. 128 bool has_shared_memory_bytes_; 129 int64 shared_memory_bytes_; 130 }; 131 132 // A data-parallel kernel (code entity) for launching via the StreamExecutor, 133 // analogous to a void* device function pointer. See TypedKernel for the typed 134 // variant. 135 // 136 // Thread-compatible. 137 class KernelBase { 138 public: 139 KernelBase(KernelBase &&from); 140 141 // Constructs an "empty" (not-yet-loaded) kernel instance. 142 // 143 // parent is the StreamExecutor that will be responsible for loading the 144 // implementation of this kernel. It must not be null. 145 explicit KernelBase(StreamExecutor *parent); 146 147 // Test-only constructor that can take a mock KernelInterface implementation. 148 KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation); 149 150 // Releases resources associated with the kernel instance (i.e. 151 // platform-specific implementation). 152 ~KernelBase(); 153 154 // Returns the number of parameters that this kernel accepts. (Arity refers to 155 // nullary, unary, ...). 156 unsigned Arity() const; 157 158 // Returns the StreamExecutor that represents the platform this kernel 159 // executes upon. 160 StreamExecutor *parent() const { return parent_; } 161 162 // Returns a const pointer to the (opaque) platform-dependent implementation. 163 const internal::KernelInterface *implementation() const { 164 return implementation_.get(); 165 } 166 167 // Returns a non-const pointer to the (opaque) platform-dependent 168 // implementation. 
169 internal::KernelInterface *implementation() { return implementation_.get(); } 170 171 void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; } 172 173 const KernelMetadata &metadata() const { return metadata_; } 174 175 // Sets the preferred cache configuration for a kernel. This is just a 176 // suggestion to the runtime, and may not be honored during execution. 177 void SetPreferredCacheConfig(KernelCacheConfig config); 178 179 // Gets the preferred cache configuration for a kernel. 180 KernelCacheConfig GetPreferredCacheConfig() const; 181 182 void set_name(port::StringPiece name); 183 const string &name() const { return name_; } 184 const string &demangled_name() const { return demangled_name_; } 185 186 private: 187 // The StreamExecutor that loads this kernel object. 188 StreamExecutor *parent_; 189 190 // Implementation delegated to for platform-specific functionality. 191 std::unique_ptr<internal::KernelInterface> implementation_; 192 193 string name_; 194 string demangled_name_; 195 196 KernelMetadata metadata_; 197 198 SE_DISALLOW_COPY_AND_ASSIGN(KernelBase); 199 }; 200 201 // Whether T is a DeviceMemory-family pointer. 202 template <typename T> 203 struct IsDeviceMemoryPointer { 204 static constexpr bool value = false; 205 }; 206 207 template <typename U> 208 struct IsDeviceMemoryPointer<DeviceMemory<U> *> { 209 static constexpr bool value = true; 210 }; 211 212 template <> 213 struct IsDeviceMemoryPointer<DeviceMemoryBase *> { 214 static constexpr bool value = true; 215 }; 216 217 // Whether T is a DeviceMemory-family value-like thing (which includes a 218 // reference). This trait is useful because we pack values in the same manner as 219 // references. 
220 template <typename T> 221 struct IsDeviceMemoryValueLike { 222 static constexpr bool value = false; 223 }; 224 225 template <typename U> 226 struct IsDeviceMemoryValueLike<DeviceMemory<U> &> { 227 static constexpr bool value = true; 228 }; 229 230 // We need to treat SharedDeviceMemory types differently than other DeviceMemory 231 // types (since they maintain no allocations), hence these specializations. 232 template <typename U> 233 struct IsDeviceMemoryValueLike<SharedDeviceMemory<U> &> { 234 static constexpr bool value = false; 235 }; 236 237 template <> 238 struct IsDeviceMemoryValueLike<DeviceMemoryBase &> { 239 static constexpr bool value = true; 240 }; 241 242 template <typename U> 243 struct IsDeviceMemoryValueLike<DeviceMemory<U>> { 244 static constexpr bool value = true; 245 }; 246 247 template <typename U> 248 struct IsDeviceMemoryValueLike<SharedDeviceMemory<U>> { 249 static constexpr bool value = false; 250 }; 251 252 template <> 253 struct IsDeviceMemoryValueLike<DeviceMemoryBase> { 254 static constexpr bool value = true; 255 }; 256 257 template <typename U> 258 struct IsSharedDeviceMemory { 259 static constexpr bool value = false; 260 }; 261 262 template <typename U> 263 struct IsSharedDeviceMemory<SharedDeviceMemory<U> &> { 264 static constexpr bool value = true; 265 }; 266 267 template <typename U> 268 struct IsSharedDeviceMemory<SharedDeviceMemory<U>> { 269 static constexpr bool value = true; 270 }; 271 272 // Basic data about a kernel argument. 273 struct KernelArg { 274 bool is_shared; 275 const void *address; 276 size_t size; 277 }; 278 279 // An iterator for traversing all the arguments of a KernelArgsArray. 
280 class KernelArgIterator { 281 public: 282 KernelArgIterator(int number_of_argument_addresses, 283 int number_of_shared_memory_arguments, 284 const void *const *arg_addresses_data, 285 const size_t *arg_sizes_data, 286 const size_t *shmem_bytes_data, 287 const size_t *shmem_indices_data) 288 : arg_index_(0), 289 number_of_arguments_(number_of_argument_addresses + 290 number_of_shared_memory_arguments), 291 arg_address_iter_(arg_addresses_data), 292 arg_size_iter_(arg_sizes_data), 293 shmem_bytes_iter_(shmem_bytes_data), 294 shmem_indices_iter_(shmem_indices_data), 295 shmem_indices_end_(shmem_indices_data + 296 number_of_shared_memory_arguments) {} 297 298 // Returns true if another argument is present in the iterator. 299 bool has_next() { return arg_index_ < number_of_arguments_; } 300 301 // Returns the next argument in the iterator. 302 // 303 // Returns a default-constructed KernelArg if there is no next argument. 304 KernelArg next() { 305 KernelArg result = {}; 306 if (!has_next()) { 307 return result; 308 } else if ((shmem_indices_iter_ != shmem_indices_end_) && 309 (arg_index_ == *shmem_indices_iter_)) { 310 result.is_shared = true; 311 result.address = nullptr; 312 result.size = *shmem_bytes_iter_; 313 ++shmem_indices_iter_; 314 ++shmem_bytes_iter_; 315 } else { 316 result.is_shared = false; 317 result.address = *arg_address_iter_; 318 result.size = *arg_size_iter_; 319 ++arg_address_iter_; 320 ++arg_size_iter_; 321 } 322 ++arg_index_; 323 return result; 324 } 325 326 private: 327 size_t arg_index_; 328 size_t number_of_arguments_; 329 const void *const *arg_address_iter_; 330 const size_t *arg_size_iter_; 331 const size_t *shmem_bytes_iter_; 332 const size_t *shmem_indices_iter_; 333 const size_t *const shmem_indices_end_; 334 }; 335 336 // Base class for KernelArgsArray. 337 // 338 // Supports all the getter methods that do not depend on the compile-time number 339 // of arguments template parameter. 
340 // 341 // This class exists as a way to pass kernel arguments to 342 // StreamExecutorInterface::Launch. That Launch method is virtual, so it can't 343 // be templated to accept any KernelArgsArray type, therefore a reference to 344 // this base type is passed instead. 345 // 346 // Performance is not a concern here because each of these methods will be 347 // called at most once per kernel launch. Past performance concerns with 348 // KernelArgsArray have been in reference to the argument packing routines which 349 // are called once per kernel argument. Those packing routines are now handled 350 // by the templated KernelArgsArray subclass of this class where they can take 351 // advantage of compile-time knowledge of the number of arguments in order to be 352 // very efficient. 353 class KernelArgsArrayBase { 354 public: 355 virtual ~KernelArgsArrayBase() = default; 356 357 // Gets the number of arguments added so far, including shared memory 358 // arguments. 359 virtual size_t number_of_arguments() const = 0; 360 361 // Gets the total number of shared memory bytes added so far. 362 virtual uint64 number_of_shared_bytes() const = 0; 363 364 // Gets the list of argument addresses. 365 virtual port::ArraySlice<const void *> argument_addresses() const = 0; 366 367 // Gets an iterator to the arguments in the array. 368 virtual KernelArgIterator arg_iterator() const = 0; 369 }; 370 371 // A list of arguments for a kernel call. 372 // 373 // The template parameter kNumArgs is the maximum number of arguments which can 374 // be stored in the list. 375 // 376 // Contains a list of addresses for non-shared-memory arguments and a list of 377 // sizes for shared-memory arguments. Since the shared-memory arguments may be 378 // interspersed with the non-shared-memory arguments, it also stores a list of 379 // the indices at which the shared-memory arguments appeared. 
380 // 381 // For example, if the argument address list contains {a, b, c, d, e}, the 382 // shared-memory arguments list contains the sizes of {A, B, C}, and the 383 // shared-memory indices list contains {0, 3, 5}, then the original list of 384 // arguments was {A, a, b, B, c, C, d, e}. 385 // 386 // This way of storing the arguments makes CUDA kernel calls efficient because 387 // they only require the argument address list and the total number of shared 388 // bytes, but it also makes it possible for OpenCL kernel calls because they 389 // depend on the location of each shared-memory argument and its size. 390 // 391 // Note that the code for adding arguments has been identified as a performance 392 // hotspot in some real-world applications so this structure has been optimized 393 // for the performance of argument adding. 394 template <size_t kNumArgs> 395 class KernelArgsArray : public KernelArgsArrayBase { 396 public: 397 explicit KernelArgsArray() 398 : total_shared_memory_bytes_(0), 399 number_of_argument_addresses_(0), 400 number_of_shared_memory_arguments_(0) {} 401 402 // Adds an argument to the list. 403 // 404 // Note that the address of the argument is stored, so the input must not go 405 // out of scope before the instance of this class that calls this method does. 406 template <typename T> 407 void add_argument(const T &arg) { 408 argument_addresses_[number_of_argument_addresses_] = 409 static_cast<const void *>(&arg); 410 argument_sizes_[number_of_argument_addresses_] = sizeof(arg); 411 ++number_of_argument_addresses_; 412 } 413 414 // Adds a device memory argument to the list. 
415 void add_device_memory_argument(const DeviceMemoryBase &arg) { 416 const void **copy_ptr = 417 &device_memory_opaque_pointers_[number_of_argument_addresses_]; 418 *copy_ptr = arg.opaque(); 419 argument_addresses_[number_of_argument_addresses_] = copy_ptr; 420 argument_sizes_[number_of_argument_addresses_] = sizeof(void *); 421 ++number_of_argument_addresses_; 422 } 423 424 // Adds a shared memory argument to the list. 425 // 426 // The only significant information about a shared argument is its size, so 427 // that is the only parameter in this function. 428 void add_shared_bytes(size_t number_of_bytes) { 429 shared_memory_indices_[number_of_shared_memory_arguments_] = 430 number_of_argument_addresses_ + number_of_shared_memory_arguments_; 431 shared_memory_bytes_[number_of_shared_memory_arguments_] = number_of_bytes; 432 ++number_of_shared_memory_arguments_; 433 total_shared_memory_bytes_ += number_of_bytes; 434 } 435 436 // Gets the number of arguments added so far, including shared memory 437 // arguments. 438 size_t number_of_arguments() const override { 439 return number_of_argument_addresses_ + number_of_shared_memory_arguments_; 440 } 441 442 // Gets the total number of shared memory bytes added so far. 443 uint64 number_of_shared_bytes() const override { 444 return total_shared_memory_bytes_; 445 } 446 447 // Gets the list of argument addresses. 448 port::ArraySlice<const void *> argument_addresses() const override { 449 return port::ArraySlice<const void *>(argument_addresses_.data(), 450 number_of_argument_addresses_); 451 } 452 453 // Gets an iterator to the arguments in the array. 
454 KernelArgIterator arg_iterator() const override { 455 return KernelArgIterator( 456 number_of_argument_addresses_, number_of_shared_memory_arguments_, 457 argument_addresses_.data(), argument_sizes_.data(), 458 shared_memory_bytes_.data(), shared_memory_indices_.data()); 459 } 460 461 private: 462 // A place to store copies of opaque pointers from device memory arguments. 463 std::array<const void *, kNumArgs> device_memory_opaque_pointers_; 464 465 // Addresses for non-shared-memory arguments. 466 std::array<const void *, kNumArgs> argument_addresses_; 467 468 // Sizes for non-shared-memory arguments. 469 std::array<size_t, kNumArgs> argument_sizes_; 470 471 // Size in bytes for each shared memory argument. 472 std::array<size_t, kNumArgs> shared_memory_bytes_; 473 474 // Indices in the arguments array for shared memory arguments. 475 std::array<size_t, kNumArgs> shared_memory_indices_; 476 477 // Total of all shared memory sizes. 478 size_t total_shared_memory_bytes_; 479 480 // Number of significant entries in argument_addresses_ and argument_sizes_. 481 size_t number_of_argument_addresses_; 482 483 // Number of significant entries in shared_memory_bytes_ and 484 // shared_memory_indices_. 485 size_t number_of_shared_memory_arguments_; 486 }; 487 488 // Typed variant of KernelBase, like a typed device function pointer. See the 489 // file comment for details and example usage. 490 // 491 // This class contains template metaprogramming magic to type check the 492 // parameters passed to a kernel launch are acceptable, and subsequently pack 493 // them into a form which can be used by the StreamExecutorInterface 494 // implementation. (i.e. CUDA and OpenCL both bind void*s with associated 495 // sizes as kernel arguments.) 496 // 497 // Thread-compatible. 498 template <typename... 
Params> 499 class TypedKernel : public KernelBase { 500 public: 501 static constexpr size_t kNumberOfParameters = sizeof...(Params); 502 503 // Delegates to KernelBase::KernelBase(), see that constructor. 504 explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {} 505 506 // Test-only constructor that can take a mock KernelInterface implementation. 507 // Takes ownership of implementation, it should not be null. 508 TypedKernel(StreamExecutor *parent, internal::KernelInterface *implementation) 509 : KernelBase(parent, implementation) {} 510 511 private: 512 // Stream needs access to the specific parameter-packing functionality that 513 // the TypedKernel provides for its corresponding type signature (and no other 514 // type signatures). 515 friend class Stream; 516 517 // This is the main entry point into the magic. Packs the parameters (which 518 // must type check against the class template) into the args and sizes 519 // arrays. 520 // 521 // Const refs are taken as parameters on all of the handlers to avoid 522 // implicit type promotion of integers. 523 // 524 // WARNING: as a performance optimization this method may store pointers to 525 // some of the input parameters in the kernel args structure, so any params 526 // passed into this method must live at least as long as the kernel args 527 // structure. 528 void PackParams(KernelArgsArray<kNumberOfParameters> *args, 529 Params &... params) const { 530 PackOneParam(args, params...); 531 } 532 533 template <typename T, typename... RestOfParams> 534 void PackOneParam(KernelArgsArray<kNumberOfParameters> *args, const T &arg, 535 const RestOfParams &... rest) const { 536 PackOneParam(args, arg); 537 PackOneParam(args, rest...); 538 } 539 540 // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array. 541 // The enable_if<> is for excluding DeviceMemoryBase args, which have a 542 // separate implementation below. 
543 template <typename T> 544 void PackOneParam( 545 KernelArgsArray<kNumberOfParameters> *args, const T &arg, 546 typename std::enable_if<!IsDeviceMemoryValueLike<T>::value && 547 !IsDeviceMemoryPointer<T>::value && 548 !IsSharedDeviceMemory<T>::value>::type * = 549 nullptr) const { 550 static_assert(!std::is_pointer<T>::value, 551 "cannot pass raw pointer to the device"); 552 static_assert(!std::is_convertible<T, DeviceMemoryBase>::value, 553 "cannot pass device memory as a normal value"); 554 args->add_argument(arg); 555 } 556 557 // DeviceMemoryBase family reference override. 558 template <typename T> 559 void PackOneParam( 560 KernelArgsArray<kNumberOfParameters> *args, const T &arg, 561 typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * = 562 nullptr) const { 563 args->add_device_memory_argument(arg); 564 } 565 566 // DeviceMemoryBase family pointer override. 567 template <typename T> 568 void PackOneParam( 569 KernelArgsArray<kNumberOfParameters> *args, T arg, 570 typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * = 571 nullptr) const { 572 DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg); 573 args->add_device_memory_argument(*ptr); 574 } 575 576 // Dynamic shared device memory has a size, but no associated allocation on 577 // the host; internally, the device will allocate storage. 578 template <typename T> 579 void PackOneParam( 580 KernelArgsArray<kNumberOfParameters> *args, T arg, 581 typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * = 582 nullptr) const { 583 args->add_shared_bytes(arg.size()); 584 } 585 586 // Base case for variadic template expansion - nothing to do! 
587 void PackOneParam(KernelArgsArray<kNumberOfParameters> *args) const {} 588 589 SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel); 590 }; 591 592 // Template metaprogramming helper type that helps us produce better error 593 // messages at compile time when the are mismatches between the parameter 594 // type list and the argument type list. 595 template <typename ParamTuple, typename ArgTuple> 596 struct KernelInvocationChecker { 597 // Whether the parameter tuple and argument tuple match in length. 598 static constexpr bool kLengthMatches = 599 std::tuple_size<ParamTuple>::value == std::tuple_size<ArgTuple>::value; 600 601 // The (matching) length of the parameters and arguments type lists. 602 static constexpr int kTupleLength = 603 static_cast<int>(std::tuple_size<ArgTuple>::value); 604 605 // Helper trait to say whether the parameter wants a DeviceMemory-reference 606 // compatible type. This is for inexact type matches, so that it doesn't have 607 // to be precisely a const DeviceMemory<T>&, but can also be a value that 608 // represents the same. 609 template <typename ParamType, typename ArgType> 610 struct IsCompatibleDeviceMemoryRef { 611 static constexpr bool value = false; 612 }; 613 614 // See type trait definition above. 615 template <typename U> 616 struct IsCompatibleDeviceMemoryRef<const DeviceMemory<U> &, DeviceMemory<U>> { 617 static constexpr bool value = true; 618 }; 619 620 // See type trait definition above. 621 template <typename U> 622 struct IsCompatibleDeviceMemoryRef<const SharedDeviceMemory<U> &, 623 SharedDeviceMemory<U>> { 624 static constexpr bool value = true; 625 }; 626 627 // Returns whether ParamT and ArgT are compatible for data parallel kernel 628 // parameter packing without any assert functionality. 
629 template <typename ParamT, typename ArgT> 630 static constexpr bool CompatibleNoAssert() { 631 return std::is_same<typename std::remove_const<ParamT>::type, 632 ArgT>::value || 633 IsCompatibleDeviceMemoryRef<ParamT, ArgT>::value; 634 } 635 636 // Checks whether ParamT and ArgT are compatible for data parallel kernel 637 // parameter packing. kArgumentNumber is unused, it just for error display. 638 // 639 // NOTE: if you encounter an error here, you can see the mismatch by looking 640 // at the end of the last error message, which will be of the form: 641 // 642 // ...::Compatible<const perftools::gputools::DeviceMemory<OneThing> &, 643 // perftools::gputools::DeviceMemory<AnotherThing>, true, 644 // 0>' 645 // requested here 646 // 647 // This means that the 0th argument you passed to the kernel invocation should 648 // have been DeviceMemory<OneThing> but was observed to be 649 // DeviceMemory<AnotherThing>. 650 template <typename ParamT, typename ArgT, bool kShouldStaticAssert, 651 int kArgumentNumber> 652 static constexpr bool Compatible() { 653 static_assert( 654 kShouldStaticAssert ? CompatibleNoAssert<ParamT, ArgT>() : true, 655 "parameter type (LHS) is not compatible with argument type (RHS)"); 656 return CompatibleNoAssert<ParamT, ArgT>(); 657 } 658 659 // Checks the parameter/argument match at kArgumentNumber for an out of bounds 660 // argument number. 661 // 662 // This is the base case: we've run out of argument to check, so we're all 663 // good. 664 template <int kArgumentNumber, bool kShouldStaticAssert> 665 static constexpr bool CheckParam( 666 typename std::enable_if<(kArgumentNumber < 0)>::type *dummy = nullptr) { 667 return true; 668 } 669 670 // Checks the parameter/argument match at kArgumentNumber. 671 // kShouldStaticAssert determines whether to assert out on a mismatch, or just 672 // yield the constexpr boolean value. 
673 template <int kArgumentNumber, bool kShouldStaticAssert> 674 static constexpr bool CheckParam( 675 typename std::enable_if<kArgumentNumber >= 0>::type *dummy = nullptr) { 676 typedef typename std::tuple_element<kArgumentNumber, ParamTuple>::type 677 ParamT; 678 typedef typename std::tuple_element<kArgumentNumber, ArgTuple>::type ArgT; 679 return Compatible<ParamT, ArgT, kShouldStaticAssert, kArgumentNumber>() && 680 CheckParam<kArgumentNumber - 1, kShouldStaticAssert>(); 681 } 682 683 // Checks the parameters/arguments for match, but doesn't static assert out. 684 // This is useful for testing/inspecting whether a set of parameters match in 685 // things like tests. 686 static constexpr bool CheckAllNoStaticAssert() { 687 return kLengthMatches && CheckParam<kTupleLength - 1, false>(); 688 } 689 690 // Checks the parameters and static asserts out with a helpful error message 691 // (and useful template parameters in the instantiation stack) if there is an 692 // error. 693 static constexpr bool CheckAllStaticAssert() { 694 static_assert(kLengthMatches, 695 "argument length mismatched against typed kernel parameters"); 696 return kLengthMatches && CheckParam<kTupleLength - 1, true>(); 697 } 698 }; 699 700 // This is a convenience type for checking whether a typed kernel matches 701 // against a type list. 702 template <typename KernelT, typename... Params> 703 struct KernelParamsOk { 704 static constexpr bool kResult = false; 705 }; 706 707 // See above. 708 template <typename... Params, typename... Args> 709 struct KernelParamsOk<TypedKernel<Params...>, Args...> { 710 static constexpr bool kResult = KernelInvocationChecker< 711 std::tuple<Params...>, std::tuple<Args...>>::CheckAllNoStaticAssert(); 712 }; 713 714 } // namespace gputools 715 } // namespace perftools 716 717 #endif // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_ 718