/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.
// NOTE: If the algorithm is changed, please run the test
// .../python/kernel_tests:parameterized_truncated_normal_op_test
// commenting out the "tf.set_random_seed(seed)" lines, and using the
// "--runs-per-test=1000" flag. This tests the statistical correctness of the
// op results.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/parameterized_truncated_normal_op.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {
using random::PhiloxRandom;

template <typename T>
struct TruncatedNormalFunctor<CPUDevice, T> {
  static const int kMaxIterations = 100;

  void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
                  int64 samples_per_batch, int64 num_elements,
                  typename TTypes<T>::ConstFlat means,
                  typename TTypes<T>::ConstFlat stddevs,
                  typename TTypes<T>::ConstFlat minvals,
                  typename TTypes<T>::ConstFlat maxvals,
                  const random::PhiloxRandom& gen,
                  typename TTypes<T>::Flat output) {
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    auto DoWork = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
                   &minvals, &maxvals, &gen,
                   &output](int start_batch, int limit_batch) {
      // Capturing "gen" by-value would only make a copy for the _shared_
      // lambda. Since we want to let each worker have its own copy, we pass
      // "gen" by reference and explicitly do a copy assignment here.
      random::PhiloxRandom gen_copy = gen;
      // Skip takes units of 128 bits. +3 is so rounding doesn't lead to
      // us using the same state in different batches.
      // The sample from each iteration uses 2 random numbers.
      gen_copy.Skip(start_batch * 2 * kMaxIterations *
                    (samples_per_batch + 3) / 4);
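      // Rough budget check for the Skip() above (illustrative numbers, not
      // taken from the op): with samples_per_batch = 100 and
      // kMaxIterations = 100, each batch reserves
      // 2 * 100 * (100 + 3) / 4 = 5150 skip units of 128 bits, i.e.
      // 5150 * 4 = 20600 uniform values -- enough for every one of the 100
      // samples to spend 2 random numbers on each of its (at most) 100
      // rejection iterations.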
      typedef random::UniformDistribution<random::PhiloxRandom, T> Uniform;
      Uniform dist;

      // Vectorized intermediate calculations for uniform rejection sampling.
      // We always generate at most 4 samples.
      Eigen::array<T, 4> z;
      Eigen::array<T, 4> g;

      for (int64 b = start_batch; b < limit_batch; ++b) {
        // We are passed a flat array for each of the parameter tensors.
        // The input is either a scalar broadcasted to all batches or a vector
        // with length num_batches, but the scalar becomes an array of
        // length 1.
        T mean = means((means.dimension(0) == 1) ? 0 : b);
        T stddev = stddevs((stddevs.dimension(0) == 1) ? 0 : b);
        T minval = minvals((minvals.dimension(0) == 1) ? 0 : b);
        T maxval = maxvals((maxvals.dimension(0) == 1) ? 0 : b);

        // The last batch can be short, if we adjusted num_batches and
        // samples_per_batch.
        const int64 limit_sample =
            std::min((b + 1) * samples_per_batch, num_elements);
        int64 sample = b * samples_per_batch;

        // On GPU, this check will just fill samples with NAN if it fails.
        OP_REQUIRES(ctx,
                    stddev > T(0) && minval < maxval &&
                        (Eigen::numext::isfinite(minval) ||
                         Eigen::numext::isfinite(maxval)),
                    errors::InvalidArgument("Invalid parameters"));

        int numIterations = 0;

        // If possible, make one-sided bound be the lower bound, or make both
        // bounds positive. Otherwise, the bounds are on either side of the
        // mean.
        if ((Eigen::numext::isinf(minval) && minval < T(0)) || maxval < mean) {
          // Reverse all calculations. normMin and normMax will be flipped.
          std::swap(minval, maxval);
          stddev = -stddev;
        }

        // Calculate normalized samples, then convert them.
        const T normMin = (minval - mean) / stddev;
        const T normMax = (maxval - mean) / stddev;

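        // Note (descriptive only): the cutoff below chooses between two
        // rejection samplers. When the normalized window [normMin, normMax]
        // is narrow (diff < cutoff), uniform proposals over the window are
        // accepted often enough to be cheapest; when the window is a wide,
        // one-sided tail, a shifted exponential proposal accepts far more
        // often (cf. Robert, 1995, on sampling truncated normals).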
        // Determine the method to use.
        const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
        const T cutoff =
            T(2) *
            Eigen::numext::exp(T(0.5) +
                               (normMin * (normMin - sqrtFactor)) / T(4)) /
            (normMin + sqrtFactor);
        const T diff = normMax - normMin;
        if (diff < cutoff) {
          // Sample from a uniform distribution on [normMin, normMax].

          const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;

          while (sample < limit_sample) {
            const auto rand = dist(&gen_copy);
            const int size = rand.size();
            // NOTE(ringwalt): These loops seem to only generate packed AVX
            // instructions for float32.
            for (int i = 0; i < size; i++) {
              z[i] = rand[i] * diff + normMin;
            }
            for (int i = 0; i < size; i++) {
              g[i] = (plusFactor - z[i] * z[i]) / T(2.0);
            }

            const auto u = dist(&gen_copy);
            for (int i = 0; i < size; i++) {
              if (u[i] <= Eigen::numext::exp(g[i]) ||
                  numIterations + 1 >= kMaxIterations) {
                // Accept the sample z.
                // If we run out of iterations, just use the current uniform
                // sample. Empirically, the probability of accepting each
                // sample is at least 50% for typical inputs, so we will
                // always accept by 100 iterations.
                // This introduces a slight inaccuracy when at least one bound
                // is large, minval is negative and maxval is positive.
                output(sample) = z[i] * stddev + mean;
                sample++;
                if (sample >= limit_sample) {
                  break;
                }
                numIterations = 0;
              } else {
                numIterations++;
              }
            }
          }
        } else {
          // Sample from an exponential distribution with alpha maximizing
          // acceptance probability, offset by normMin from the origin.
          // Accept only if less than normMax.
          const T alpha =
              (normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) /
              T(2);
          while (sample < limit_sample) {
            auto rand = dist(&gen_copy);
            const int size = rand.size();
            int i = 0;
            while (i < size) {
              const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
              i++;
              const T x = normMin < alpha ? alpha - z : normMin - alpha;
              const T g = Eigen::numext::exp(-x * x / T(2.0));
              const T u = rand[i];
              i++;
              if ((u <= g && z < normMax) ||
                  numIterations + 1 >= kMaxIterations) {
                output(sample) = z * stddev + mean;
                sample++;
                if (sample >= limit_sample) {
                  break;
                }
                numIterations = 0;
              } else {
                numIterations++;
              }
            }
          }
        }
      }
    };
    // The cost of the initial calculations for the batch.
    const int64 batchInitCost =
        // normMin, normMax
        (Eigen::TensorOpCost::AddCost<T>() +
         Eigen::TensorOpCost::MulCost<T>()) *
            2
        // sqrtFactor
        + Eigen::TensorOpCost::AddCost<T>() +
        Eigen::TensorOpCost::MulCost<T>() +
        Eigen::internal::functor_traits<
            Eigen::internal::scalar_sqrt_op<T>>::Cost
        // cutoff
        + Eigen::TensorOpCost::MulCost<T>() * 4 +
        Eigen::internal::functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost
        // diff
        + Eigen::TensorOpCost::AddCost<T>();
    const int64 uniformSampleCost =
        random::PhiloxRandom::kElementCost +
        random::UniformDistribution<random::PhiloxRandom, T>::kElementCost;
    // The cost of a single uniform sampling round.
    const int64 uniformRejectionSamplingCost =
        uniformSampleCost + Eigen::TensorOpCost::MulCost<T>() +
        Eigen::TensorOpCost::AddCost<T>() +
        Eigen::TensorOpCost::MulCost<T>() * 2 +
        Eigen::TensorOpCost::AddCost<T>() + uniformSampleCost +
        Eigen::internal::functor_traits<
            Eigen::internal::scalar_exp_op<T>>::Cost +
        Eigen::TensorOpCost::MulCost<T>() + Eigen::TensorOpCost::AddCost<T>();
    // Estimate the cost for an entire batch.
    // Assume we use uniform sampling, and accept the 2nd sample on average.
    const int64 batchCost =
        batchInitCost + uniformRejectionSamplingCost * 2 * samples_per_batch;
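    // Descriptive note: Shard() splits the range [0, num_batches) across the
    // worker threadpool and invokes DoWork(start_batch, limit_batch) on each
    // sub-range; batchCost above is the per-batch cost hint it uses to decide
    // how finely to split the work (see work_sharder.h for the exact
    // contract).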
    Shard(worker_threads.num_threads, worker_threads.workers, num_batches,
          batchCost, DoWork);
  }
};

}  // namespace functor

namespace {

// Samples from a truncated normal distribution, using the given parameters.
template <typename Device, typename T>
class ParameterizedTruncatedNormalOp : public OpKernel {
  // Reshape batches so each batch is this size if possible.
  static const int32 kDesiredBatchSize = 100;

 public:
  explicit ParameterizedTruncatedNormalOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, generator_.Init(context));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_tensor = ctx->input(0);
    const Tensor& means_tensor = ctx->input(1);
    const Tensor& stddevs_tensor = ctx->input(2);
    const Tensor& minvals_tensor = ctx->input(3);
    const Tensor& maxvals_tensor = ctx->input(4);

    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
        errors::InvalidArgument("Input shape should be a vector, got shape: ",
                                shape_tensor.shape().DebugString()));
    int32 num_batches = shape_tensor.flat<int32>()(0);

    int32 samples_per_batch = 1;
    const int32 num_dims = shape_tensor.dim_size(0);
    for (int32 i = 1; i < num_dims; i++) {
      samples_per_batch *= shape_tensor.flat<int32>()(i);
    }
    const int32 num_elements = num_batches * samples_per_batch;

    // Allocate the output before fudging num_batches and samples_per_batch.
    auto shape_vec = shape_tensor.flat<int32>();
    TensorShape tensor_shape;
    OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
                            shape_vec.data(), shape_vec.size(), &tensor_shape));
    Tensor* samples_tensor;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, tensor_shape, &samples_tensor));

    // Parameters must be 0-d or 1-d.
    OP_REQUIRES(ctx, means_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input means should be a scalar or vector, got shape: ",
                    means_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, stddevs_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input stddevs should be a scalar or vector, got shape: ",
                    stddevs_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, minvals_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input minvals should be a scalar or vector, got shape: ",
                    minvals_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, maxvals_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input maxvals should be a scalar or vector, got shape: ",
                    maxvals_tensor.shape().DebugString()));

    if ((means_tensor.dims() == 0 || means_tensor.dim_size(0) == 1) &&
        (stddevs_tensor.dims() == 0 || stddevs_tensor.dim_size(0) == 1) &&
        minvals_tensor.dims() == 0 && maxvals_tensor.dims() == 0) {
      // All batches have the same parameters, so we can update the batch size
      // to a reasonable value to improve parallelism (ensure enough batches,
      // and no very small batches which have high overhead).
      int32 size = num_batches * samples_per_batch;
      int32 adjusted_samples = kDesiredBatchSize;
      // Ensure adjusted_batches * adjusted_samples >= size.
      int32 adjusted_batches = Eigen::divup(size, adjusted_samples);
      num_batches = adjusted_batches;
      samples_per_batch = adjusted_samples;
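      // For example (illustrative numbers): shape = [1000, 70] with scalar
      // parameters arrives as num_batches = 1000, samples_per_batch = 70
      // (70000 elements total) and is reshaped to num_batches = 700,
      // samples_per_batch = 100, since divup(70000, 100) = 700.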
    } else {
      // Parameters must be broadcastable to the shape [num_batches].
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(means_tensor.shape()) ||
              means_tensor.dim_size(0) == 1 ||
              means_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input means should have length 1 or shape[0], got shape: ",
              means_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(stddevs_tensor.shape()) ||
              stddevs_tensor.dim_size(0) == 1 ||
              stddevs_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input stddevs should have length 1 or shape[0], got shape: ",
              stddevs_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(minvals_tensor.shape()) ||
              minvals_tensor.dim_size(0) == 1 ||
              minvals_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input minvals should have length 1 or shape[0], got shape: ",
              minvals_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(maxvals_tensor.shape()) ||
              maxvals_tensor.dim_size(0) == 1 ||
              maxvals_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input maxvals should have length 1 or shape[0], got shape: ",
              maxvals_tensor.shape().DebugString()));
    }

    auto truncFunctor = functor::TruncatedNormalFunctor<Device, T>();
    // Each worker has the fudge factor for samples_per_batch, so use it here.
    random::PhiloxRandom rng = generator_.ReserveSamples128(
        num_batches * 2 * truncFunctor.kMaxIterations *
        (samples_per_batch + 3) / 4);
    truncFunctor(ctx, ctx->eigen_device<Device>(), num_batches,
                 samples_per_batch, num_elements, means_tensor.flat<T>(),
                 stddevs_tensor.flat<T>(), minvals_tensor.flat<T>(),
                 maxvals_tensor.flat<T>(), rng, samples_tensor->flat<T>());
  }

 private:
  GuardedPhiloxRandom generator_;

  TF_DISALLOW_COPY_AND_ASSIGN(ParameterizedTruncatedNormalOp);
};

}  // namespace

#define REGISTER(TYPE)                                         \
  REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<TYPE>("dtype"),  \
                          ParameterizedTruncatedNormalOp<CPUDevice, TYPE>)

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#if GOOGLE_CUDA

#define REGISTER(TYPE)                                         \
  REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
                              .Device(DEVICE_GPU)              \
                              .HostMemory("shape")             \
                              .TypeConstraint<TYPE>("dtype"),  \
                          ParameterizedTruncatedNormalOp<GPUDevice, TYPE>)

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#endif  // GOOGLE_CUDA

}  // end namespace tensorflow