/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
#define TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_

/**
 * Wrappers and helpers for CUDA device code.
 *
 * Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide
 * backwards compatibility, see go/volta-porting for details.
 * Provides atomic operations on types that aren't natively supported.
 */

#if GOOGLE_CUDA

#include <algorithm>
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "cuda/include/cuda.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace detail {

// Helper for range-based for loop using 'delta' increments.
// Usage: see the CudaGridRangeX/Y/Z() functions below.
template <typename T>
class CudaGridRange {
  struct Iterator {
    __device__ Iterator(T index, T delta) : index_(index), delta_(delta) {}
    __device__ T operator*() const { return index_; }
    __device__ Iterator& operator++() {
      index_ += delta_;
      return *this;
    }
    __device__ bool operator!=(const Iterator& other) const {
      bool greater = index_ > other.index_;
      bool less = index_ < other.index_;
      // Anything past an end iterator (delta_ == 0) is equal.
      // In range-based for loops, this optimizes to 'return less'.
      if (!other.delta_) {
        return less;
      }
      if (!delta_) {
        return greater;
      }
      return less || greater;
    }

   private:
    T index_;
    const T delta_;
  };

 public:
  __device__ CudaGridRange(T begin, T delta, T end)
      : begin_(begin), delta_(delta), end_(end) {}

  __device__ Iterator begin() const { return Iterator{begin_, delta_}; }
  __device__ Iterator end() const { return Iterator{end_, 0}; }

 private:
  T begin_;
  T delta_;
  T end_;
};

}  // namespace detail

// Helper to visit indices in the range 0 <= i < count, using the x-coordinate
// of the global thread index. That is, each index i is visited by all threads
// with the same x-coordinate.
// Usage: for(int i : CudaGridRangeX(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeX(T count) {
  return detail::CudaGridRange<T>(blockIdx.x * blockDim.x + threadIdx.x,
                                  gridDim.x * blockDim.x, count);
}

// Helper to visit indices in the range 0 <= i < count using the y-coordinate.
// Usage: for(int i : CudaGridRangeY(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeY(T count) {
  return detail::CudaGridRange<T>(blockIdx.y * blockDim.y + threadIdx.y,
                                  gridDim.y * blockDim.y, count);
}

// Helper to visit indices in the range 0 <= i < count using the z-coordinate.
// Usage: for(int i : CudaGridRangeZ(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeZ(T count) {
  return detail::CudaGridRange<T>(blockIdx.z * blockDim.z + threadIdx.z,
                                  gridDim.z * blockDim.z, count);
}
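
// Example (illustrative sketch, not part of the library): a simple elementwise
// kernel written with CudaGridRangeX. The kernel name and launch configuration
// are hypothetical; any 1-D launch works because the grid-stride loop covers
// the whole range regardless of how many threads were launched.
//
//   __global__ void ScaleKernel(int count, float alpha, float* data) {
//     for (int i : CudaGridRangeX(count)) {
//       data[i] = alpha * data[i];
//     }
//   }
//   // Host side: ScaleKernel<<<num_blocks, threads_per_block>>>(n, 2.f, ptr);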

// Mask for all 32 threads in a warp.
const unsigned kCudaWarpAll = 0xffffffff;

// Returns the warp lane ID of the calling thread.
__device__ inline unsigned CudaLaneId() {
  unsigned int lane_id;
  asm("mov.u32 %0, %%laneid;" : "=r"(lane_id));
  return lane_id;
}

namespace detail {
// Returns true if mask is a valid parameter for __shfl*sync to return a well
// defined value, assuming the calling lane will read from src_lane as part of
// the shuffle operation.
//
// Specifically, returns true iff mask has the calling lane bit and the
// src_lane bit set, and the src_lane calls this function with the same mask
// value (required for the two threads to wait for each other).
//
// On Volta, for some invalid masks, this function hangs or returns false
// positives, because the implementation shuffles with the same mask that
// we are validating. Run on Pascal if you suspect that the mask is incorrect.
__device__ inline bool CudaValidateShuffleSyncMask(unsigned mask,
                                                   unsigned src_lane) {
  unsigned src_dst_mask = 1u << CudaLaneId() | 1u << src_lane;
#if CUDA_VERSION >= 9000
  unsigned src_lane_mask = __shfl_sync(mask, mask, src_lane);
#else
  unsigned src_lane_mask = __shfl(mask, src_lane);
#endif
  return (src_dst_mask & ~mask) == 0 && src_lane_mask == mask;
}

// Returns the actual source lane for shuffle.
__device__ inline unsigned CudaShuffleGetSrcLane(int src_lane, int width) {
  int lane_id = CudaLaneId();
  int lane_base = lane_id & ~width + 1;
  int lane_offset = src_lane & width - 1;
  return lane_base + lane_offset;
}

// Returns the source lane for shuffle up.
__device__ inline unsigned CudaShuffleUpGetSrcLane(unsigned delta, int width) {
  unsigned lane_id = CudaLaneId();
  if ((lane_id & width - 1) < delta) {
    return lane_id;
  }
  return lane_id - delta;
}

// Returns the source lane for shuffle down.
__device__ inline unsigned CudaShuffleDownGetSrcLane(unsigned delta,
                                                     int width) {
  unsigned lane_id = CudaLaneId();
  if ((lane_id & width - 1) + delta >= width) {
    return lane_id;
  }
  return lane_id + delta;
}

// Returns the source lane for shuffle xor.
__device__ inline unsigned CudaShuffleXorGetSrcLane(int lane_mask, int width) {
  int lane_id = CudaLaneId();
  int src_lane = lane_id ^ lane_mask;
  if (src_lane > (lane_id | width - 1)) {
    return lane_id;
  }
  return src_lane;
}
}  // namespace detail
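
// Worked example for the helpers above (illustrative): with width = 8 the warp
// acts as four independent 8-lane segments. For lane_id = 13 (segment base 8):
//   detail::CudaShuffleGetSrcLane(2, 8)     == 10  (segment base + offset)
//   detail::CudaShuffleUpGetSrcLane(3, 8)   == 10  (13 - 3, still in segment)
//   detail::CudaShuffleDownGetSrcLane(6, 8) == 13  (5 + 6 >= 8, clamped to self)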

// For all *_sync wrappers below, it is illegal to synchronize threads from
// different program locations, because that is not supported before sm_70.
// In other words, all threads in 'mask' must call the functions in convergence.
// Code that requires sm_70 (and CUDA 9) may use the intrinsic directly.
//
// It is also illegal to shuffle with a mask that produces an undefined result
// for any of the threads. Specifically, all source threads of the shuffle
// must have their corresponding bit in 'mask' set.

// Wrapper for __syncwarp. No-op for CUDA 8 and earlier.
__device__ inline void CudaSyncWarp(unsigned mask = kCudaWarpAll) {
  assert(mask & 1u << CudaLaneId());
#if CUDA_VERSION >= 9000
  __syncwarp(mask);
#endif
}

// Wrapper for __ballot_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
__device__ inline unsigned CudaBallotSync(unsigned mask, int pred) {
  assert(mask & 1u << CudaLaneId());
#if CUDA_VERSION >= 9000
  return __ballot_sync(mask, pred);
#else
  return __ballot(pred) & mask;  // Apply mask to match __ballot_sync's spec.
#endif
}

// Wrapper for __any_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
__device__ inline int CudaAnySync(unsigned mask, int pred) {
  assert(mask & 1u << CudaLaneId());
#if CUDA_VERSION >= 9000
  return __any_sync(mask, pred);
#else
  return __any(pred);
#endif
}

// Wrapper for __all_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
__device__ inline int CudaAllSync(unsigned mask, int pred) {
  assert(mask & 1u << CudaLaneId());
#if CUDA_VERSION >= 9000
  return __all_sync(mask, pred);
#else
  return __all(pred);
#endif
}

// Wrapper for __shfl_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
template <typename T>
__device__ T CudaShuffleSync(unsigned mask, T value, int src_lane,
                             int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleGetSrcLane(src_lane, width)));
#if CUDA_VERSION >= 9000
  return __shfl_sync(mask, value, src_lane, width);
#else
  return __shfl(value, src_lane, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleSync(unsigned mask, double value,
                                          int src_lane, int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = CudaShuffleSync(mask, hi, src_lane, width);
  lo = CudaShuffleSync(mask, lo, src_lane, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}

// Wrapper for __shfl_up_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
template <typename T>
__device__ inline T CudaShuffleUpSync(unsigned mask, T value, unsigned delta,
                                      int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleUpGetSrcLane(delta, width)));
#if CUDA_VERSION >= 9000
  return __shfl_up_sync(mask, value, delta, width);
#else
  return __shfl_up(value, delta, width);
#endif
}
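
// Example (illustrative sketch): an inclusive prefix sum across a full warp
// using CudaShuffleUpSync. Assumes all 32 lanes are active and call in
// convergence; the variable names are hypothetical.
//
//   float x = ...;  // per-lane input
//   for (int delta = 1; delta < warpSize; delta *= 2) {
//     float up = CudaShuffleUpSync(kCudaWarpAll, x, delta);
//     if (CudaLaneId() >= delta) x += up;
//   }
//   // Lane i now holds the sum of the inputs from lanes 0..i.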

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleUpSync(unsigned mask, double value,
                                           unsigned delta,
                                           int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = CudaShuffleUpSync(mask, hi, delta, width);
  lo = CudaShuffleUpSync(mask, lo, delta, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}

// Wrapper for __shfl_down_sync. All threads in 'mask' must call this function
// in convergence, see comment above for details.
template <typename T>
__device__ inline T CudaShuffleDownSync(unsigned mask, T value, unsigned delta,
                                        int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleDownGetSrcLane(delta, width)));
#if CUDA_VERSION >= 9000
  return __shfl_down_sync(mask, value, delta, width);
#else
  return __shfl_down(value, delta, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleDownSync(unsigned mask, double value,
                                             unsigned delta,
                                             int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = CudaShuffleDownSync(mask, hi, delta, width);
  lo = CudaShuffleDownSync(mask, lo, delta, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}

// Wrapper for __shfl_xor_sync. All threads in 'mask' must call this function
// in convergence, see comment above for details.
template <typename T>
__device__ T CudaShuffleXorSync(unsigned mask, T value, int lane_mask,
                                int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleXorGetSrcLane(lane_mask, width)));
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, lane_mask, width);
#else
  return __shfl_xor(value, lane_mask, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleXorSync(unsigned mask, double value,
                                            int lane_mask,
                                            int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = CudaShuffleXorSync(mask, hi, lane_mask, width);
  lo = CudaShuffleXorSync(mask, lo, lane_mask, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}
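
// Example (illustrative sketch): a full-warp sum reduction with the butterfly
// pattern, built on CudaShuffleXorSync. Assumes all 32 lanes participate and
// call in convergence.
//
//   float sum = ...;  // per-lane partial value
//   for (int lane_mask = warpSize / 2; lane_mask > 0; lane_mask /= 2) {
//     sum += CudaShuffleXorSync(kCudaWarpAll, sum, lane_mask);
//   }
//   // Every lane now holds the warp-wide sum.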

// Wrapper for __ldg.
template <typename T>
__host__ __device__ T CudaLdg(const T* address) {
#if __CUDA_ARCH__ >= 350
  return __ldg(address);
#else
  return *address;
#endif
}

__host__ __device__ inline bool CudaLdg(const bool* address) {
  return CudaLdg(reinterpret_cast<const char*>(address)) != 0;
}

__host__ __device__ inline std::complex<float> CudaLdg(
    const std::complex<float>* address) {
#if __CUDA_ARCH__ >= 350
  float2 mem = __ldg(reinterpret_cast<const float2*>(address));
  return std::complex<float>(mem.x, mem.y);
#else
  return *address;
#endif
}

__host__ __device__ inline std::complex<double> CudaLdg(
    const std::complex<double>* address) {
#if __CUDA_ARCH__ >= 350
  double2 mem = __ldg(reinterpret_cast<const double2*>(address));
  return std::complex<double>(mem.x, mem.y);
#else
  return *address;
#endif
}

// Zeroes count elements starting at ptr using all threads of a 1-D grid.
// Note: this function does not synchronize, and therefore the memory range is
// not guaranteed to be zero until the next kernel launch.
template <typename T>
__global__ void SetZero(const int count, T* ptr) {
  // Check that the grid is one dimensional and index doesn't overflow.
  assert(blockDim.y == 1 && blockDim.z == 1);
  assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
  for (int i : CudaGridRangeX(count)) {
    ptr[i] = T(0);
  }
}

// Helper to set all tensor entries to a specific value.
template <typename T>
__global__ void SetToValue(const int count, T* ptr, T value) {
  // Check that the grid is one dimensional and index doesn't overflow.
  assert(blockDim.y == 1 && blockDim.z == 1);
  assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
  for (int i : CudaGridRangeX(count)) {
    ptr[i] = value;
  }
}

namespace detail {
// Helper function for atomic accumulation implemented as CAS.
template <typename T, typename F>
__device__ T CudaAtomicCasHelper(T* ptr, F accumulate) {
  T old = *ptr;
  T assumed;
  do {
    assumed = old;
    old = atomicCAS(ptr, assumed, accumulate(assumed));
  } while (assumed != old);
  return old;
}

// Overload for floating point (using integer comparison to handle NaN
// correctly). Note that the outer conversion must turn the returned bit
// pattern back into a float, hence __int_as_float.
template <typename F>
__device__ float CudaAtomicCasHelper(float* ptr, F accumulate) {
  return __int_as_float(
      CudaAtomicCasHelper(reinterpret_cast<int32*>(ptr), [accumulate](int32 a) {
        return __float_as_int(accumulate(__int_as_float(a)));
      }));
}
template <typename F>
__device__ double CudaAtomicCasHelper(double* ptr, F accumulate) {
  return __longlong_as_double(CudaAtomicCasHelper(
      reinterpret_cast<tensorflow::uint64*>(ptr),
      [accumulate](tensorflow::uint64 a) {
        return __double_as_longlong(accumulate(__longlong_as_double(a)));
      }));
}
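
// Example (illustrative): CudaAtomicCasHelper turns an arbitrary
// read-modify-write into an atomic update. For instance, an atomic
// "scale by 2" of a float could be written as
//   detail::CudaAtomicCasHelper(ptr, [](float a) { return 2.0f * a; });
// The CudaAtomicMul/CudaAtomicDiv wrappers at the end of this header are
// implemented this way.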

// Overload of above function for half. Note that we don't have
// atomicCAS() for anything less than 32 bits, so we need to include the
// other 16 bits in the operation.
//
// This version is going to be very slow under high concurrency, since most
// threads will be spinning on failing their compare-and-swap tests. (The fact
// that we get false sharing on the neighboring fp16 makes this even worse.)
// If you are doing a large reduction, you are much better off doing the
// intermediate steps in fp32 and then switching to fp16 as late as you can in
// the calculations.
//
// Note: Assumes little endian.
template <typename F>
__device__ Eigen::half CudaAtomicCasHelper(Eigen::half* ptr, F accumulate) {
#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__)
  static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian");
#endif
  namespace half_impl = Eigen::half_impl;
  intptr_t intptr = reinterpret_cast<intptr_t>(ptr);
  assert(!(intptr & 0x1));  // should be 2-aligned.
  if (intptr & 0x2) {
    // The half is in the second part of the uint32 (upper 16 bits).
    uint32* address = reinterpret_cast<uint32*>(intptr - 2);
    uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) {
      unsigned short high = static_cast<unsigned short>(arg >> 16);
      Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(high));
      return (static_cast<uint32>(acc.x) << 16) | (arg & 0xffff);
    });
    return half_impl::raw_uint16_to_half(static_cast<uint16>(result >> 16));
  } else {
    // The half is in the first part of the uint32 (lower 16 bits).
    uint32* address = reinterpret_cast<uint32*>(intptr);
    uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) {
      unsigned short low = static_cast<unsigned short>(arg & 0xffff);
      Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(low));
      return (arg & 0xffff0000) | static_cast<uint32>(acc.x);
    });
    return half_impl::raw_uint16_to_half(static_cast<uint16>(result & 0xffff));
  }
}

template <typename From, typename To>
using ToTypeIfConvertible =
    typename std::enable_if<std::is_convertible<From, To>::value, To>::type;

}  // namespace detail

// CUDA provides atomic ops, but not for all types. We provide wrappers
// for some ops and provide implementation for all reasonable types.

template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicAdd(T* ptr, U value) {
  return atomicAdd(ptr, value);
}

__device__ inline Eigen::half CudaAtomicAdd(Eigen::half* ptr,
                                            Eigen::half value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](Eigen::half a) { return a + value; });
}

#if __CUDA_ARCH__ < 600
__device__ inline double CudaAtomicAdd(double* ptr, double value) {
  return detail::CudaAtomicCasHelper(ptr,
                                     [value](double a) { return a + value; });
}
#elif __clang__
// Clang cannot compile the __nvvm_atom_add_gen_d builtin yet, use inline PTX.
// See https://reviews.llvm.org/D39638.
__device__ inline double CudaAtomicAdd(double* ptr, double value) {
  double result;
  asm volatile("atom.add.f64 %0, [%1], %2;"
               : "=d"(result)
               : "l"(ptr), "d"(value)
               : "memory");
  return result;
}
#endif
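
// Example (illustrative sketch): accumulating per-thread partial sums into a
// single float in global memory. 'out' is assumed to point to one float that
// was zero-initialized (e.g. with SetZero) before the launch; the kernel name
// is hypothetical.
//
//   __global__ void SumKernel(int count, const float* in, float* out) {
//     float partial = 0.f;
//     for (int i : CudaGridRangeX(count)) {
//       partial += CudaLdg(in + i);
//     }
//     CudaAtomicAdd(out, partial);
//   }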

// CudaAtomicAdd
// Specializations of CudaAtomicAdd for complex types, which atomicAdd does
// not support. We treat a std::complex<T>* as a T* (the C++ standard section
// 26.4.4 allows this explicitly) and atomically add the real and imaginary
// components individually. The operation as a whole is not atomic, but we can
// safely treat the components independently for the purpose of accumulating.
__device__ inline std::complex<float> CudaAtomicAdd(std::complex<float>* ptr,
                                                    std::complex<float> value) {
  auto ptr_scalar = reinterpret_cast<float*>(ptr);
  return std::complex<float>(CudaAtomicAdd(ptr_scalar, value.real()),
                             CudaAtomicAdd(ptr_scalar + 1, value.imag()));
}

__device__ inline std::complex<double> CudaAtomicAdd(
    std::complex<double>* ptr, std::complex<double> value) {
  auto ptr_scalar = reinterpret_cast<double*>(ptr);
  return std::complex<double>(CudaAtomicAdd(ptr_scalar, value.real()),
                              CudaAtomicAdd(ptr_scalar + 1, value.imag()));
}

// CudaAtomicSub
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicSub(T* ptr, U value) {
  return atomicSub(ptr, value);
}

// Specializations of subtraction which add the negative value.
__device__ inline float CudaAtomicSub(float* ptr, float value) {
  return CudaAtomicAdd(ptr, -value);
}

__device__ inline double CudaAtomicSub(double* ptr, double value) {
  return CudaAtomicAdd(ptr, -value);
}

__device__ inline tensorflow::uint64 CudaAtomicSub(tensorflow::uint64* ptr,
                                                   tensorflow::uint64 value) {
  return CudaAtomicAdd(ptr, -value);
}

__device__ inline Eigen::half CudaAtomicSub(Eigen::half* ptr,
                                            Eigen::half value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](Eigen::half a) { return a - value; });
}

// CudaAtomicMax
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMax(T* ptr, U value) {
  return atomicMax(ptr, value);
}

__device__ inline float CudaAtomicMax(float* ptr, float value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](float a) { return max(a, value); });
}

__device__ inline double CudaAtomicMax(double* ptr, double value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](double a) { return max(a, value); });
}

__device__ inline Eigen::half CudaAtomicMax(Eigen::half* ptr,
                                            Eigen::half value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](Eigen::half a) { return max(a, value); });
}

#if __CUDA_ARCH__ < 320
__device__ inline tensorflow::uint64 CudaAtomicMax(tensorflow::uint64* ptr,
                                                   tensorflow::uint64 value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](tensorflow::uint64 a) { return max(a, value); });
}
#endif

// CudaAtomicMin
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMin(T* ptr, U value) {
  return atomicMin(ptr, value);
}

__device__ inline float CudaAtomicMin(float* ptr, float value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](float a) { return min(a, value); });
}

__device__ inline double CudaAtomicMin(double* ptr, double value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](double a) { return min(a, value); });
}

__device__ inline Eigen::half CudaAtomicMin(Eigen::half* ptr,
                                            Eigen::half value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](Eigen::half a) { return min(a, value); });
}

#if __CUDA_ARCH__ < 320
__device__ inline tensorflow::uint64 CudaAtomicMin(tensorflow::uint64* ptr,
                                                   tensorflow::uint64 value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](tensorflow::uint64 a) { return min(a, value); });
}
#endif
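
// Example (illustrative sketch): computing the min and max of a buffer in one
// pass. 'range_out[0]' / 'range_out[1]' are assumed to be initialized to
// +FLT_MAX / -FLT_MAX before the launch; the kernel name is hypothetical.
//
//   __global__ void MinMaxKernel(int count, const float* in,
//                                float* range_out) {
//     for (int i : CudaGridRangeX(count)) {
//       float v = CudaLdg(in + i);
//       CudaAtomicMin(range_out, v);
//       CudaAtomicMax(range_out + 1, v);
//     }
//   }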

// CudaAtomicMul
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMul(T* ptr, U value) {
  return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a * value; });
}

// CudaAtomicDiv
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicDiv(T* ptr, U value) {
  return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a / value; });
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
#endif  // TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_