/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/parameterized_truncated_normal_op.h"

#include <assert.h>
#include <stdio.h>
#include <cmath>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

#if defined(_MSC_VER) && !defined(__clang__)
// MSVC does not support the unroll pragma. One could try the loop pragma, but
// we would need to take a closer look at whether it generates better code in
// this case. For now, let the compiler take care of it.
#define UNROLL
#else
#define UNROLL _Pragma("unroll")
#endif

namespace tensorflow {

class OpKernelContext;

namespace functor {

typedef Eigen::GpuDevice GPUDevice;

template <typename T>
__global__ void __launch_bounds__(1024)
    TruncatedNormalKernel(random::PhiloxRandom gen, T* data, int64 num_batches,
                          int64 samples_per_batch, int64 num_elements,
                          const T* means, bool single_mean, const T* stddevs,
                          bool single_stddev, const T* minvals,
                          bool single_minval, const T* maxvals,
                          bool single_maxval, int64 kMaxIterations) {
  const int32 max_samples_per_item = 2 * kMaxIterations;
  // Initial offset as given by CUDA_1D_KERNEL_LOOP.
  const int32 initial_offset = blockIdx.x * blockDim.x + threadIdx.x;
  gen.Skip(max_samples_per_item * initial_offset);
  typedef random::UniformDistribution<random::PhiloxRandom, T> Uniform;
  Uniform dist;
  const int kDistSize = Uniform::kResultElementCount;
  const T quietNaN = Eigen::NumTraits<T>::quiet_NaN();

  // To produce deterministic results across devices, each output element
  // reserves max_samples_per_item positions in the generator stream. After
  // generating one element, this thread must skip the reservations of every
  // thread in the grid (minus whatever it already consumed) to reach the
  // stream position of the next element it processes.
  const int32 samples_between_processed_elements =
      max_samples_per_item * (gridDim.x * blockDim.x);

  CUDA_1D_KERNEL_LOOP(offset, num_elements) {
    // Track how many more samples we need to skip before we process the next
    // element.
    int32 remaining_samples = samples_between_processed_elements;

    const int64 batch_id = offset / samples_per_batch;
    T mean = means[single_mean ? 0 : batch_id];
    const T input_stddev = stddevs[single_stddev ? 0 : batch_id];
    T minval = minvals[single_minval ? 0 : batch_id];
    T maxval = maxvals[single_maxval ? 0 : batch_id];

    // Flip the distribution if we can make the lower bound positive.
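    // Illustrative example of the flip (added commentary; the values are not
    // from the original code): with mean = 0, stddev = 1, minval = -inf and
    // maxval = -3, swapping the bounds and negating stddev gives normMin = 3
    // and normMax = +inf, so the sampler draws z from the upper tail
    // [3, +inf) and z * stddev + mean maps the result back into (-inf, -3].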
    T stddev;
    if (Eigen::numext::isinf(minval) || maxval < mean) {
      // Reverse all calculations. normMin and normMax will be flipped.
      // std::swap is a host function (not available in CUDA).
      T temp = minval;
      minval = maxval;
      maxval = temp;
      stddev = -input_stddev;
    } else {
      stddev = input_stddev;
    }

    // Calculate normalized samples, then scale them.
    const T normMin = (minval - mean) / stddev;
    const T normMax = (maxval - mean) / stddev;

    // Determine the method to use.
    const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
    const T cutoff =
        T(2) *
        Eigen::numext::exp(T(0.5) + (normMin * (normMin - sqrtFactor)) / T(4)) /
        (normMin + sqrtFactor);
    const T diff = normMax - normMin;
    const T two = T(2.0);

    // Validate the normalized min and max, because the originals may have
    // been flipped already.
    if (!(input_stddev > T(0) && normMin < normMax &&
          (Eigen::numext::isfinite(normMin) ||
           Eigen::numext::isfinite(normMax)))) {
      data[offset] = quietNaN;
    } else if (diff < cutoff) {
      // Sample from a uniform distribution on [normMin, normMax].

      // Vectorized intermediate calculations for uniform rejection sampling.
      // We always generate at most 4 samples.
      Eigen::array<T, 4> z;
      Eigen::array<T, 4> g;

      const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;

      int numIterations = 0;
      while (numIterations < kMaxIterations) {
        const auto rand = dist(&gen);
        remaining_samples -= gen.kResultElementCount;
        UNROLL for (int i = 0; i < kDistSize; i++) {
          z[i] = rand[i] * diff + normMin;
        }
        UNROLL for (int i = 0; i < kDistSize; i++) {
          g[i] = (plusFactor - z[i] * z[i]) / two;
        }

        const auto u = dist(&gen);
        remaining_samples -= gen.kResultElementCount;
        UNROLL for (int i = 0; i < kDistSize; i++) {
          if (u[i] <= Eigen::numext::exp(g[i]) ||
              numIterations + 1 >= kMaxIterations) {
            // Accept the sample z.
            // If we run out of iterations, just use the current uniform
            // sample. Empirically, the probability of accepting each sample
            // is at least 50% for typical inputs, so the chance of not
            // accepting within 100 iterations is negligible.
            // This introduces a slight inaccuracy when at least one bound
            // is large, minval is negative and maxval is positive.
            data[offset] = z[i] * stddev + mean;
            // Break out of the nested loop by updating numIterations.
            numIterations = kMaxIterations;
            break;
          } else {
            numIterations++;
          }
        }
      }
    } else {
      // Sample from an exponential distribution with rate alpha chosen to
      // maximize the acceptance probability, offset by normMin from the
      // origin. Accept only if the sample is less than normMax.
      const T alpha =
          (normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) / T(2);
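      // Added commentary: this matches Robert's (1995) rejection sampler for
      // the one-sided truncated normal. A proposal z is drawn from an
      // Exp(alpha) distribution shifted to start at normMin and accepted
      // with probability exp(-(alpha - z)^2 / 2); this particular alpha,
      // (normMin + sqrt(normMin^2 + 4)) / 2, maximizes the expected
      // acceptance rate for a lower bound of normMin.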
      int numIterations = 0;
      while (numIterations < kMaxIterations) {
        auto rand = dist(&gen);
        remaining_samples -= gen.kResultElementCount;
        UNROLL for (int i = 0; i < kDistSize; i += 2) {
          const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
          const T x = normMin < alpha ? alpha - z : normMin - alpha;
          const T g = Eigen::numext::exp(-x * x / two);
          const T u = rand[i + 1];
          if ((u <= g && z < normMax) || numIterations + 1 >= kMaxIterations) {
            data[offset] = z * stddev + mean;
            // Break out of the nested loop by updating numIterations.
            numIterations = kMaxIterations;
            break;
          } else {
            numIterations++;
          }
        }
      }
    }

    gen.Skip(remaining_samples);
  }
}

// Partial specialization for GPU.
template <typename T>
struct TruncatedNormalFunctor<GPUDevice, T> {
  static const int kMaxIterations = 100;

  void operator()(OpKernelContext* ctx, const GPUDevice& d, int64 num_batches,
                  int64 samples_per_batch, int64 num_elements,
                  typename TTypes<T>::ConstFlat means,
                  typename TTypes<T>::ConstFlat stddevs,
                  typename TTypes<T>::ConstFlat minvals,
                  typename TTypes<T>::ConstFlat maxvals,
                  const random::PhiloxRandom& gen,
                  typename TTypes<T>::Flat output) {
    const auto config = GetCudaLaunchConfig(num_elements, d);

    TruncatedNormalKernel<T>
        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
            gen, output.data(), num_batches, samples_per_batch, num_elements,
            means.data(), means.dimension(0) == 1, stddevs.data(),
            stddevs.dimension(0) == 1, minvals.data(),
            minvals.dimension(0) == 1, maxvals.data(),
            maxvals.dimension(0) == 1, kMaxIterations);
  }
};

// Explicit instantiation of the GPU distribution functors.
template struct TruncatedNormalFunctor<GPUDevice, Eigen::half>;
template struct TruncatedNormalFunctor<GPUDevice, float>;
template struct TruncatedNormalFunctor<GPUDevice, double>;

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA