/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
#define TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_

/**
 * Wrappers and helpers for CUDA device code.
 *
 * Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide
 * backwards compatibility, see go/volta-porting for details.
 * Provides atomic operations on types that aren't natively supported.
 */

#if GOOGLE_CUDA

#include <algorithm>
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "cuda/include/cuda.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace detail {

// Helper for range-based for loop using 'delta' increments.
// Usage: see CudaGridRange?() functions below.
template <typename T>
class CudaGridRange {
  struct Iterator {
    __device__ Iterator(T index, T delta) : index_(index), delta_(delta) {}
    __device__ T operator*() const { return index_; }
    __device__ Iterator& operator++() {
      index_ += delta_;
      return *this;
    }
    __device__ bool operator!=(const Iterator& other) const {
      bool greater = index_ > other.index_;
      bool less = index_ < other.index_;
      // Anything past an end iterator (delta_ == 0) is equal.
      // In range-based for loops, this optimizes to 'return less'.
      if (!other.delta_) {
        return less;
      }
      if (!delta_) {
        return greater;
      }
      return less || greater;
    }

   private:
    T index_;
    const T delta_;
  };

 public:
  __device__ CudaGridRange(T begin, T delta, T end)
      : begin_(begin), delta_(delta), end_(end) {}

  __device__ Iterator begin() const { return Iterator{begin_, delta_}; }
  __device__ Iterator end() const { return Iterator{end_, 0}; }

 private:
  T begin_;
  T delta_;
  T end_;
};

}  // namespace detail

// Helper to visit indices in the range 0 <= i < count, using the x-coordinate
// of the global thread index. That is, each index i is visited by all threads
// with the same x-coordinate.
// Usage: for(int i : CudaGridRangeX(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeX(T count) {
  return detail::CudaGridRange<T>(blockIdx.x * blockDim.x + threadIdx.x,
                                  gridDim.x * blockDim.x, count);
}

// Helper to visit indices in the range 0 <= i < count using the y-coordinate.
// Usage: for(int i : CudaGridRangeY(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeY(T count) {
  return detail::CudaGridRange<T>(blockIdx.y * blockDim.y + threadIdx.y,
                                  gridDim.y * blockDim.y, count);
}

// Helper to visit indices in the range 0 <= i < count using the z-coordinate.
// Usage: for(int i : CudaGridRangeZ(count)) { visit(i); }
template <typename T>
__device__ detail::CudaGridRange<T> CudaGridRangeZ(T count) {
  return detail::CudaGridRange<T>(blockIdx.z * blockDim.z + threadIdx.z,
                                  gridDim.z * blockDim.z, count);
}
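
// Example (illustrative sketch, not part of the original header): an
// elementwise kernel that uses CudaGridRangeX to cover 'count' elements with
// a grid-stride loop. The kernel name is hypothetical.
template <typename T>
__global__ void CudaExampleScaleKernel(int count, T alpha, T* data) {
  for (int i : CudaGridRangeX(count)) {
    // With a 1-D launch, each index in [0, count) is visited exactly once.
    data[i] *= alpha;
  }
}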

// Mask for all 32 threads in a warp.
const unsigned kCudaWarpAll = 0xffffffff;

// Returns the warp lane ID of the calling thread
__device__ inline unsigned CudaLaneId() {
  unsigned int lane_id;
  asm("mov.u32 %0, %%laneid;" : "=r"(lane_id));
  return lane_id;
}

namespace detail {
// Returns true if mask is a valid parameter for __shfl*sync to return a well
// defined value, assuming the calling lane will read from src_lane as part of
// the shuffle operation.
//
// Specifically, returns true iff mask has the calling lane bit and the src_lane
// bit set, and the src_lane calls this function with the same mask value
// (required for the two threads to wait for each other).
//
// On Volta, for some invalid masks, this function hangs or returns false
// positives, because the implementation shuffles with the same mask that
// we are validating. Run on Pascal if you suspect that the mask is incorrect.
__device__ inline bool CudaValidateShuffleSyncMask(unsigned mask,
                                                   unsigned src_lane) {
  unsigned src_dst_mask = 1u << CudaLaneId() | 1u << src_lane;
#if CUDA_VERSION >= 9000
  unsigned src_lane_mask = __shfl_sync(mask, mask, src_lane);
#else
  unsigned src_lane_mask = __shfl(mask, src_lane);
#endif
  return (src_dst_mask & ~mask) == 0 && src_lane_mask == mask;
}

// Returns the actual source lane for shuffle.
__device__ inline unsigned CudaShuffleGetSrcLane(int src_lane, int width) {
  int lane_id = CudaLaneId();
  int lane_base = lane_id & ~width + 1;
  int lane_offset = src_lane & width - 1;
  return lane_base + lane_offset;
}

// Returns the source lane for shuffle up.
__device__ inline unsigned CudaShuffleUpGetSrcLane(unsigned delta, int width) {
  unsigned lane_id = CudaLaneId();
  if ((lane_id & width - 1) < delta) {
    return lane_id;
  }
  return lane_id - delta;
}

// Returns the source lane for shuffle down.
__device__ inline unsigned CudaShuffleDownGetSrcLane(unsigned delta,
                                                     int width) {
  unsigned lane_id = CudaLaneId();
  if ((lane_id & width - 1) + delta >= width) {
    return lane_id;
  }
  return lane_id + delta;
}

// Returns the source lane for shuffle xor.
__device__ inline unsigned CudaShuffleXorGetSrcLane(int lane_mask, int width) {
  int lane_id = CudaLaneId();
  int src_lane = lane_id ^ lane_mask;
  if (src_lane > (lane_id | width - 1)) {
    return lane_id;
  }
  return src_lane;
}
}  // namespace detail

// For all *_sync wrappers below, it is illegal to synchronize threads from
// different program locations, because that is not supported before sm_70.
// In other words, all threads in 'mask' must call the functions in convergence.
// Code that requires sm_70 (and CUDA 9) may use the intrinsic directly.
//
// It is also illegal to shuffle with a mask that produces an undefined result
// for any of the threads. Specifically, all source threads of the shuffle
// must have their corresponding bit in 'mask' set.

// Wrapper for __syncwarp. No-op for CUDA 8 and earlier.
__device__ inline void CudaSyncWarp(unsigned mask = kCudaWarpAll) {
  assert(mask & 1u << CudaLaneId());
#if CUDA_VERSION >= 9000
  __syncwarp(mask);
#endif
}

// Wrapper for __ballot_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
__device__ inline unsigned CudaBallotSync(unsigned mask, int pred) {
  assert(mask & 1u << CudaLaneId());
#if CUDA_VERSION >= 9000
  return __ballot_sync(mask, pred);
#else
  return __ballot(pred) & mask;  // Apply mask to match __ballot_sync's spec.
#endif
}

// Wrapper for __any_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
__device__ inline int CudaAnySync(unsigned mask, int pred) {
  assert(mask & 1u << CudaLaneId());
#if CUDA_VERSION >= 9000
  return __any_sync(mask, pred);
#else
  return __any(pred);
#endif
}

// Wrapper for __all_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
__device__ inline int CudaAllSync(unsigned mask, int pred) {
  assert(mask & 1u << CudaLaneId());
#if CUDA_VERSION >= 9000
  return __all_sync(mask, pred);
#else
  return __all(pred);
#endif
}
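
// Example (illustrative sketch): counting how many lanes of a fully converged
// warp satisfy a predicate, by combining CudaBallotSync with __popc. Assumes
// all 32 lanes participate, so kCudaWarpAll is a valid mask. The function name
// is hypothetical.
__device__ inline int CudaExampleCountVotes(int pred) {
  unsigned ballot = CudaBallotSync(kCudaWarpAll, pred);
  return __popc(ballot);  // Set bits == number of lanes with pred != 0.
}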

// Wrapper for __shfl_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
template <typename T>
__device__ T CudaShuffleSync(unsigned mask, T value, int src_lane,
                             int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleGetSrcLane(src_lane, width)));
#if CUDA_VERSION >= 9000
  return __shfl_sync(mask, value, src_lane, width);
#else
  return __shfl(value, src_lane, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleSync(unsigned mask, double value,
                                         int src_lane, int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = CudaShuffleSync(mask, hi, src_lane, width);
  lo = CudaShuffleSync(mask, lo, src_lane, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}

// Wrapper for __shfl_up_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
template <typename T>
__device__ inline T CudaShuffleUpSync(unsigned mask, T value, unsigned delta,
                                      int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleUpGetSrcLane(delta, width)));
#if CUDA_VERSION >= 9000
  return __shfl_up_sync(mask, value, delta, width);
#else
  return __shfl_up(value, delta, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleUpSync(unsigned mask, double value,
                                           unsigned delta,
                                           int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = CudaShuffleUpSync(mask, hi, delta, width);
  lo = CudaShuffleUpSync(mask, lo, delta, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}

// Wrapper for __shfl_down_sync. All threads in 'mask' must call this function
// in convergence, see comment above for details.
template <typename T>
__device__ inline T CudaShuffleDownSync(unsigned mask, T value, unsigned delta,
                                        int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleDownGetSrcLane(delta, width)));
#if CUDA_VERSION >= 9000
  return __shfl_down_sync(mask, value, delta, width);
#else
  return __shfl_down(value, delta, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleDownSync(unsigned mask, double value,
                                             unsigned delta,
                                             int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = CudaShuffleDownSync(mask, hi, delta, width);
  lo = CudaShuffleDownSync(mask, lo, delta, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}

// Wrapper for __shfl_xor_sync. All threads in 'mask' must call this function in
// convergence, see comment above for details.
template <typename T>
__device__ T CudaShuffleXorSync(unsigned mask, T value, int lane_mask,
                                int width = warpSize) {
  assert(!(width & width - 1));
  assert(detail::CudaValidateShuffleSyncMask(
      mask, detail::CudaShuffleXorGetSrcLane(lane_mask, width)));
#if CUDA_VERSION >= 9000
  return __shfl_xor_sync(mask, value, lane_mask, width);
#else
  return __shfl_xor(value, lane_mask, width);
#endif
}

// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
// instead of float for lo and hi (which is incorrect with ftz, for example).
// See b/69446944.
__device__ inline double CudaShuffleXorSync(unsigned mask, double value,
                                            int lane_mask,
                                            int width = warpSize) {
  unsigned lo, hi;
  asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
  hi = CudaShuffleXorSync(mask, hi, lane_mask, width);
  lo = CudaShuffleXorSync(mask, lo, lane_mask, width);
  asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
  return value;
}
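
// Example (illustrative sketch): a butterfly warp reduction built on top of
// CudaShuffleXorSync. After the loop, every lane of a fully converged warp
// holds the sum of all 32 input values. The function name is hypothetical.
template <typename T>
__device__ T CudaExampleWarpReduceSum(T value) {
  for (int lane_mask = warpSize / 2; lane_mask > 0; lane_mask /= 2) {
    value += CudaShuffleXorSync(kCudaWarpAll, value, lane_mask);
  }
  return value;
}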

// Wrapper for __ldg.
template <typename T>
__host__ __device__ T CudaLdg(const T* address) {
#if __CUDA_ARCH__ >= 350
  return __ldg(address);
#else
  return *address;
#endif
}

__host__ __device__ inline bool CudaLdg(const bool* address) {
  return CudaLdg(reinterpret_cast<const char*>(address)) != 0;
}

__host__ __device__ inline std::complex<float> CudaLdg(
    const std::complex<float>* address) {
#if __CUDA_ARCH__ >= 350
  float2 mem = __ldg(reinterpret_cast<const float2*>(address));
  return std::complex<float>(mem.x, mem.y);
#else
  return *address;
#endif
}

__host__ __device__ inline std::complex<double> CudaLdg(
    const std::complex<double>* address) {
#if __CUDA_ARCH__ >= 350
  double2 mem = __ldg(reinterpret_cast<const double2*>(address));
  return std::complex<double>(mem.x, mem.y);
#else
  return *address;
#endif
}
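
// Example (illustrative sketch): streaming a read-only input through the
// read-only data cache with CudaLdg. The kernel name is hypothetical.
template <typename T>
__global__ void CudaExampleCopyKernel(int count, const T* in, T* out) {
  for (int i : CudaGridRangeX(count)) {
    out[i] = CudaLdg(in + i);  // __ldg on sm_35 and later, plain load otherwise.
  }
}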

// Zeroes count elements starting at ptr using all threads of a 1-D grid.
// Note: this function does not synchronize, and therefore the memory range is
// not guaranteed to be zero until the next kernel launch.
template <typename T>
__global__ void SetZero(const int count, T* ptr) {
  // Check that the grid is one dimensional and index doesn't overflow.
  assert(blockDim.y == 1 && blockDim.z == 1);
  assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
  for (int i : CudaGridRangeX(count)) {
    ptr[i] = T(0);
  }
}

// Helper to set all tensor entries to a specific value.
template <typename T>
__global__ void SetToValue(const int count, T* ptr, T value) {
  // Check that the grid is one dimensional and index doesn't overflow.
  assert(blockDim.y == 1 && blockDim.z == 1);
  assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
  for (int i : CudaGridRangeX(count)) {
    ptr[i] = value;
  }
}
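
// Example (illustrative sketch): a host-side launch of SetToValue with a
// hand-picked 1-D configuration. TensorFlow kernels would normally derive the
// grid/block sizes from a launch-config helper; the 256-thread block below is
// an assumption for illustration only, and the function name is hypothetical.
template <typename T>
void CudaExampleLaunchSetToValue(int count, T* ptr, T value,
                                 cudaStream_t stream) {
  if (count == 0) return;  // Avoid launching an empty grid.
  const int threads_per_block = 256;
  const int blocks = (count + threads_per_block - 1) / threads_per_block;
  SetToValue<T><<<blocks, threads_per_block, 0, stream>>>(count, ptr, value);
}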

namespace detail {
// Helper function for atomic accumulation implemented as CAS.
template <typename T, typename F>
__device__ T CudaAtomicCasHelper(T* ptr, F accumulate) {
  T old = *ptr;
  T assumed;
  do {
    assumed = old;
    old = atomicCAS(ptr, assumed, accumulate(assumed));
  } while (assumed != old);
  return old;
}

// Overload for floating point (using integer comparison to handle NaN
// correctly).
template <typename F>
__device__ float CudaAtomicCasHelper(float* ptr, F accumulate) {
  // Reinterpret the returned bits back to float (rather than converting the
  // integer value numerically), so the result matches what atomicCAS observed.
  return __int_as_float(
      CudaAtomicCasHelper(reinterpret_cast<int32*>(ptr), [accumulate](int32 a) {
        return __float_as_int(accumulate(__int_as_float(a)));
      }));
}
template <typename F>
__device__ double CudaAtomicCasHelper(double* ptr, F accumulate) {
  return __longlong_as_double(CudaAtomicCasHelper(
      reinterpret_cast<tensorflow::uint64*>(ptr),
      [accumulate](tensorflow::uint64 a) {
        return __double_as_longlong(accumulate(__longlong_as_double(a)));
      }));
}

// Overload of above function for half. Note that we don't have
// atomicCAS() for anything less than 32 bits, so we need to include the
// other 16 bits in the operation.
//
// This version is going to be very slow
// under high concurrency, since most threads will be spinning on failing
// their compare-and-swap tests. (The fact that we get false sharing on the
// neighboring fp16 makes this even worse.) If you are doing a large reduction,
// you are much better off with doing the intermediate steps in fp32 and then
// switching to fp16 as late as you can in the calculations.
//
// Note: Assumes little endian.
template <typename F>
__device__ Eigen::half CudaAtomicCasHelper(Eigen::half* ptr, F accumulate) {
#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__)
  static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian");
#endif
  namespace half_impl = Eigen::half_impl;
  intptr_t intptr = reinterpret_cast<intptr_t>(ptr);
  assert(!(intptr & 0x1));  // should be 2-aligned.
  if (intptr & 0x2) {
    // The half is in the second part of the uint32 (upper 16 bits).
    uint32* address = reinterpret_cast<uint32*>(intptr - 2);
    uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) {
      unsigned short high = static_cast<unsigned short>(arg >> 16);
      Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(high));
      return (static_cast<uint32>(acc.x) << 16) | (arg & 0xffff);
    });
    return half_impl::raw_uint16_to_half(static_cast<uint16>(result >> 16));
  } else {
    // The half is in the first part of the uint32 (lower 16 bits).
    uint32* address = reinterpret_cast<uint32*>(intptr);
    uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) {
      unsigned short low = static_cast<unsigned short>(arg & 0xffff);
      Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(low));
      return (arg & 0xffff0000) | static_cast<uint32>(acc.x);
    });
    return half_impl::raw_uint16_to_half(static_cast<uint16>(result & 0xffff));
  }
}

template <typename From, typename To>
using ToTypeIfConvertible =
    typename std::enable_if<std::is_convertible<From, To>::value, To>::type;

}  // namespace detail

// CUDA provides atomic ops, but not for all types.  We provide wrappers
// for some ops and provide implementation for all reasonable types.

template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicAdd(T* ptr, U value) {
  return atomicAdd(ptr, value);
}

__device__ inline Eigen::half CudaAtomicAdd(Eigen::half* ptr,
                                            Eigen::half value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](Eigen::half a) { return a + value; });
}

#if __CUDA_ARCH__ < 600
__device__ inline double CudaAtomicAdd(double* ptr, double value) {
  return detail::CudaAtomicCasHelper(ptr,
                                     [value](double a) { return a + value; });
}
#elif __clang__
// Clang cannot compile __nvvm_atom_add_gen_d builtin yet, use inline PTX.
// see https://reviews.llvm.org/D39638
__device__ inline double CudaAtomicAdd(double* ptr, double value) {
  double result;
  asm volatile("atom.add.f64 %0, [%1], %2;"
               : "=d"(result)
               : "l"(ptr), "d"(value)
               : "memory");
  return result;
}
#endif

// CudaAtomicAdd
// Specializations of CudaAtomicAdd for complex types, which CudaAtomicAdd does
// not support. We treat a std::complex<T>* as a T* (the C++ standard section
// 26.4.4 allows this explicitly) and atomic add the real and imaginary
// components individually. The operation as a whole is not atomic, but we can
// safely treat the components independently for the purpose of accumulating.
__device__ inline std::complex<float> CudaAtomicAdd(std::complex<float>* ptr,
                                                    std::complex<float> value) {
  auto ptr_scalar = reinterpret_cast<float*>(ptr);
  return std::complex<float>(CudaAtomicAdd(ptr_scalar, value.real()),
                             CudaAtomicAdd(ptr_scalar + 1, value.imag()));
}

__device__ inline std::complex<double> CudaAtomicAdd(
    std::complex<double>* ptr, std::complex<double> value) {
  auto ptr_scalar = reinterpret_cast<double*>(ptr);
  return std::complex<double>(CudaAtomicAdd(ptr_scalar, value.real()),
                              CudaAtomicAdd(ptr_scalar + 1, value.imag()));
}
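
// Example (illustrative sketch): scatter-accumulation where several threads
// may update the same output slot, so a plain store would race but
// CudaAtomicAdd is safe. Kernel name and argument layout are hypothetical.
template <typename T>
__global__ void CudaExampleScatterAddKernel(int count, const int* indices,
                                            const T* updates, T* out) {
  for (int i : CudaGridRangeX(count)) {
    CudaAtomicAdd(out + indices[i], updates[i]);
  }
}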

// CudaAtomicSub
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicSub(T* ptr, U value) {
  return atomicSub(ptr, value);
}

// Specializations of subtraction which add the negative value.
__device__ inline float CudaAtomicSub(float* ptr, float value) {
  return CudaAtomicAdd(ptr, -value);
}

__device__ inline double CudaAtomicSub(double* ptr, double value) {
  return CudaAtomicAdd(ptr, -value);
}

__device__ inline tensorflow::uint64 CudaAtomicSub(tensorflow::uint64* ptr,
                                                   tensorflow::uint64 value) {
  return CudaAtomicAdd(ptr, -value);
}

__device__ inline Eigen::half CudaAtomicSub(Eigen::half* ptr,
                                            Eigen::half value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](Eigen::half a) { return a - value; });
}

// CudaAtomicMax
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMax(T* ptr, U value) {
  return atomicMax(ptr, value);
}

__device__ inline float CudaAtomicMax(float* ptr, float value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](float a) { return max(a, value); });
}

__device__ inline double CudaAtomicMax(double* ptr, double value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](double a) { return max(a, value); });
}

__device__ inline Eigen::half CudaAtomicMax(Eigen::half* ptr,
                                            Eigen::half value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](Eigen::half a) { return max(a, value); });
}

#if __CUDA_ARCH__ < 320
__device__ inline tensorflow::uint64 CudaAtomicMax(tensorflow::uint64* ptr,
                                                   tensorflow::uint64 value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](tensorflow::uint64 a) { return max(a, value); });
}
#endif

// CudaAtomicMin
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMin(T* ptr, U value) {
  return atomicMin(ptr, value);
}

__device__ inline float CudaAtomicMin(float* ptr, float value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](float a) { return min(a, value); });
}

__device__ inline double CudaAtomicMin(double* ptr, double value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](double a) { return min(a, value); });
}

__device__ inline Eigen::half CudaAtomicMin(Eigen::half* ptr,
                                            Eigen::half value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](Eigen::half a) { return min(a, value); });
}

#if __CUDA_ARCH__ < 320
__device__ inline tensorflow::uint64 CudaAtomicMin(tensorflow::uint64* ptr,
                                                   tensorflow::uint64 value) {
  return detail::CudaAtomicCasHelper(
      ptr, [value](tensorflow::uint64 a) { return min(a, value); });
}
#endif

// CudaAtomicMul
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMul(T* ptr, U value) {
  return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a * value; });
}

// CudaAtomicDiv
template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicDiv(T* ptr, U value) {
  return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a / value; });
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
#endif  // TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_