/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.
// NOTE: If the algorithm is changed, please run the test
// .../python/kernel_tests:parameterized_truncated_normal_op_test
// with the "tf.set_random_seed(seed)" lines commented out and with the
// "--runs-per-test=1000" flag. This checks the statistical correctness of
// the op's results.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/parameterized_truncated_normal_op.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {
using random::PhiloxRandom;

template <typename T>
struct TruncatedNormalFunctor<CPUDevice, T> {
  static const int kMaxIterations = 100;

  void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
                  int64 samples_per_batch, int64 num_elements,
                  typename TTypes<T>::ConstFlat means,
                  typename TTypes<T>::ConstFlat stddevs,
                  typename TTypes<T>::ConstFlat minvals,
                  typename TTypes<T>::ConstFlat maxvals,
                  const random::PhiloxRandom& gen,
                  typename TTypes<T>::Flat output) {
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    auto DoWork = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
                   &minvals, &maxvals, &gen,
                   &output](int start_batch, int limit_batch) {
      // Capturing "gen" by-value would only make a copy for the _shared_
      // lambda.  Since we want to let each worker have its own copy, we pass
      // "gen" by reference and explicitly do a copy assignment here.
      random::PhiloxRandom gen_copy = gen;
      // Skip takes units of 128 bits.  +3 is so rounding doesn't lead to
      // us using the same state in different batches.
      // The sample from each iteration uses 2 random numbers.
      gen_copy.Skip(start_batch * 2 * kMaxIterations * (samples_per_batch + 3) /
                    4);
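      // The skip count bounds what one batch may draw: at most kMaxIterations
      // candidates per output value, times 2 random numbers per candidate,
      // times samples_per_batch output values; the division by 4 converts
      // individual random values into 128-bit Philox samples, which is the
      // unit Skip() counts in.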
      typedef random::UniformDistribution<random::PhiloxRandom, T> Uniform;
      Uniform dist;

      // Vectorized intermediate calculations for uniform rejection sampling.
      // We always generate at most 4 samples.
      Eigen::array<T, 4> z;
      Eigen::array<T, 4> g;

      for (int64 b = start_batch; b < limit_batch; ++b) {
        // We are passed a flat array for each of the parameter tensors.
        // Each input is either a scalar broadcast to all batches (arriving
        // here as a flat array of length 1) or a vector of length num_batches.
        T mean = means((means.dimension(0) == 1) ? 0 : b);
        T stddev = stddevs((stddevs.dimension(0) == 1) ? 0 : b);
        T minval = minvals((minvals.dimension(0) == 1) ? 0 : b);
        T maxval = maxvals((maxvals.dimension(0) == 1) ? 0 : b);

        // The last batch can be short, if we adjusted num_batches and
        // samples_per_batch.
        const int64 limit_sample =
            std::min((b + 1) * samples_per_batch, num_elements);
        int64 sample = b * samples_per_batch;

        // On GPU, this check will just fill samples with NAN if it fails.
        OP_REQUIRES(ctx,
                    stddev > T(0) && minval < maxval &&
                        (Eigen::numext::isfinite(minval) ||
                         Eigen::numext::isfinite(maxval)),
                    errors::InvalidArgument("Invalid parameters"));

        int numIterations = 0;

        // If possible, make the one-sided bound the lower bound, or make both
        // bounds positive. Otherwise, the bounds lie on either side of the
        // mean.
        if ((Eigen::numext::isinf(minval) && minval < T(0)) || maxval < mean) {
          // Reverse all calculations. normMin and normMax will be flipped.
          std::swap(minval, maxval);
          stddev = -stddev;
        }
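        // Note that negating stddev both reflects the interval (normMin and
        // normMax below describe the mirrored problem) and undoes the
        // reflection when samples are mapped back via z * stddev + mean.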

        // Calculate normalized samples, then convert them.
        const T normMin = (minval - mean) / stddev;
        const T normMax = (maxval - mean) / stddev;

        // Determine the method to use.
        const T sqrtFactor = Eigen::numext::sqrt((normMin * normMin) + T(4));
        const T cutoff =
            T(2) *
            Eigen::numext::exp(T(0.5) +
                               (normMin * (normMin - sqrtFactor)) / T(4)) /
            (normMin + sqrtFactor);
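        // The cutoff is the interval width at which uniform rejection on
        // [normMin, normMax] stops being cheaper than the translated
        // exponential sampler in the else-branch below; it matches the
        // switching rule of the classic truncated-normal accept/reject
        // scheme (cf. Robert, "Simulation of truncated normal variables",
        // 1995).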
        const T diff = normMax - normMin;
        if (diff < cutoff) {
          // Sample from a uniform distribution on [normMin, normMax].

          const T plusFactor = (normMin < T(0)) ? T(0) : normMin * normMin;
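          // plusFactor is normMin^2 when the interval lies at or right of
          // zero and 0 otherwise, so exp((plusFactor - z^2) / 2), computed
          // via g below, is the standard normal density at z divided by its
          // maximum over the interval, i.e. the acceptance probability of
          // candidate z.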

          while (sample < limit_sample) {
            const auto rand = dist(&gen_copy);
            const int size = rand.size();
            // NOTE(ringwalt): These loops seem to only generate packed AVX
            // instructions for float32.
            for (int i = 0; i < size; i++) {
              z[i] = rand[i] * diff + normMin;
            }
            for (int i = 0; i < size; i++) {
              g[i] = (plusFactor - z[i] * z[i]) / T(2.0);
            }

            const auto u = dist(&gen_copy);
            for (int i = 0; i < size; i++) {
              if (u[i] <= Eigen::numext::exp(g[i]) ||
                  numIterations + 1 >= kMaxIterations) {
                // Accept the sample z.
                // If we run out of iterations, just use the current uniform
                // sample. Empirically, the probability of accepting each sample
                // is at least 50% for typical inputs, so we will always accept
                // by 100 iterations.
                // This introduces a slight inaccuracy when at least one bound
                // is large, minval is negative and maxval is positive.
                output(sample) = z[i] * stddev + mean;
                sample++;
                if (sample >= limit_sample) {
                  break;
                }
                numIterations = 0;
              } else {
                numIterations++;
              }
            }
          }
        } else {
          // Sample from an exponential distribution with alpha maximizing
          // acceptance probability, offset by normMin from the origin.
          // Accept only if less than normMax.
          const T alpha =
              (normMin + Eigen::numext::sqrt((normMin * normMin) + T(4))) /
              T(2);
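          // Each candidate is z = normMin + E with E ~ Exponential(alpha);
          // it is accepted with probability exp(-(z - alpha)^2 / 2) (the g
          // computed below), and only if it also lies below normMax.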
          while (sample < limit_sample) {
            auto rand = dist(&gen_copy);
            const int size = rand.size();
            int i = 0;
            while (i < size) {
              const T z = -Eigen::numext::log(rand[i]) / alpha + normMin;
              i++;
              const T x = normMin < alpha ? alpha - z : normMin - alpha;
              const T g = Eigen::numext::exp(-x * x / T(2.0));
              const T u = rand[i];
              i++;
              if ((u <= g && z < normMax) ||
                  numIterations + 1 >= kMaxIterations) {
                output(sample) = z * stddev + mean;
                sample++;
                if (sample >= limit_sample) {
                  break;
                }
                numIterations = 0;
              } else {
                numIterations++;
              }
            }
          }
        }
      }
    };
    // The cost of the initial calculations for the batch.
    const int64 batchInitCost =
        // normMin, normMax
        (Eigen::TensorOpCost::AddCost<T>() +
         Eigen::TensorOpCost::MulCost<T>()) *
            2
        // sqrtFactor
        + Eigen::TensorOpCost::AddCost<T>() +
        Eigen::TensorOpCost::MulCost<T>() +
        Eigen::internal::functor_traits<
            Eigen::internal::scalar_sqrt_op<T>>::Cost
        // cutoff
        + Eigen::TensorOpCost::MulCost<T>() * 4 +
        Eigen::internal::functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost
        // diff
        + Eigen::TensorOpCost::AddCost<T>();
    const int64 uniformSampleCost =
        random::PhiloxRandom::kElementCost +
        random::UniformDistribution<random::PhiloxRandom, T>::kElementCost;
    // The cost of a single uniform sampling round.
    const int64 uniformRejectionSamplingCost =
        uniformSampleCost + Eigen::TensorOpCost::MulCost<T>() +
        Eigen::TensorOpCost::AddCost<T>() +
        Eigen::TensorOpCost::MulCost<T>() * 2 +
        Eigen::TensorOpCost::AddCost<T>() + uniformSampleCost +
        Eigen::internal::functor_traits<
            Eigen::internal::scalar_exp_op<T>>::Cost +
        Eigen::TensorOpCost::MulCost<T>() + Eigen::TensorOpCost::AddCost<T>();
    // Estimate the cost for an entire batch.
    // Assume we use uniform sampling, and accept the 2nd sample on average.
    const int64 batchCost =
        batchInitCost + uniformRejectionSamplingCost * 2 * samples_per_batch;
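    // Shard splits the num_batches units of work across the worker threads,
    // using batchCost as the estimated cost of one unit to choose a sensible
    // granularity.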
    Shard(worker_threads.num_threads, worker_threads.workers, num_batches,
          batchCost, DoWork);
  }
};

}  // namespace functor

namespace {

// Samples from a truncated normal distribution, using the given parameters.
template <typename Device, typename T>
class ParameterizedTruncatedNormalOp : public OpKernel {
  // Reshape batches so each batch is this size if possible.
  static const int32 kDesiredBatchSize = 100;

 public:
  explicit ParameterizedTruncatedNormalOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, generator_.Init(context));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_tensor = ctx->input(0);
    const Tensor& means_tensor = ctx->input(1);
    const Tensor& stddevs_tensor = ctx->input(2);
    const Tensor& minvals_tensor = ctx->input(3);
    const Tensor& maxvals_tensor = ctx->input(4);

    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
        errors::InvalidArgument("Input shape should be a vector, got shape: ",
                                shape_tensor.shape().DebugString()));
    int32 num_batches = shape_tensor.flat<int32>()(0);

    int32 samples_per_batch = 1;
    const int32 num_dims = shape_tensor.dim_size(0);
    for (int32 i = 1; i < num_dims; i++) {
      samples_per_batch *= shape_tensor.flat<int32>()(i);
    }
    const int32 num_elements = num_batches * samples_per_batch;

    // Allocate the output before fudging num_batches and samples_per_batch.
    auto shape_vec = shape_tensor.flat<int32>();
    TensorShape tensor_shape;
    OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
                            shape_vec.data(), shape_vec.size(), &tensor_shape));
    Tensor* samples_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, tensor_shape, &samples_tensor));

    // Parameters must be 0-d or 1-d.
    OP_REQUIRES(ctx, means_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input means should be a scalar or vector, got shape: ",
                    means_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, stddevs_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input stddevs should be a scalar or vector, got shape: ",
                    stddevs_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, minvals_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input minvals should be a scalar or vector, got shape: ",
                    minvals_tensor.shape().DebugString()));
    OP_REQUIRES(ctx, maxvals_tensor.dims() <= 1,
                errors::InvalidArgument(
                    "Input maxvals should be a scalar or vector, got shape: ",
                    maxvals_tensor.shape().DebugString()));

    if ((means_tensor.dims() == 0 || means_tensor.dim_size(0) == 1) &&
        (stddevs_tensor.dims() == 0 || stddevs_tensor.dim_size(0) == 1) &&
        minvals_tensor.dims() == 0 && maxvals_tensor.dims() == 0) {
      // All batches have the same parameters, so we can pick a batch size
      // that improves parallelism: ensure there are enough batches to keep
      // all workers busy, and avoid very small batches, which have high
      // per-batch overhead.
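      // For example, a requested shape of [2, 3, 50] with scalar parameters
      // gives size = 300, which is regrouped into 3 batches of 100 samples.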
      int32 size = num_batches * samples_per_batch;
      int32 adjusted_samples = kDesiredBatchSize;
      // Ensure adjusted_batches * adjusted_samples >= size.
      int32 adjusted_batches = Eigen::divup(size, adjusted_samples);
      num_batches = adjusted_batches;
      samples_per_batch = adjusted_samples;
    } else {
      // Parameters must be broadcastable to the shape [num_batches].
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(means_tensor.shape()) ||
              means_tensor.dim_size(0) == 1 ||
              means_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input means should have length 1 or shape[0], got shape: ",
              means_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(stddevs_tensor.shape()) ||
              stddevs_tensor.dim_size(0) == 1 ||
              stddevs_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input stddevs should have length 1 or shape[0], got shape: ",
              stddevs_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(minvals_tensor.shape()) ||
              minvals_tensor.dim_size(0) == 1 ||
              minvals_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input minvals should have length 1 or shape[0], got shape: ",
              minvals_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx,
          TensorShapeUtils::IsScalar(maxvals_tensor.shape()) ||
              maxvals_tensor.dim_size(0) == 1 ||
              maxvals_tensor.dim_size(0) == num_batches,
          errors::InvalidArgument(
              "Input maxvals should have length 1 or shape[0], got shape: ",
              maxvals_tensor.shape().DebugString()));
    }

    auto truncFunctor = functor::TruncatedNormalFunctor<Device, T>();
    // Each worker applies the same (samples_per_batch + 3) / 4 fudge factor
    // when it skips ahead in the random stream, so use it here as well.
    random::PhiloxRandom rng = generator_.ReserveSamples128(
        num_batches * 2 * truncFunctor.kMaxIterations *
        (samples_per_batch + 3) / 4);
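    // ReserveSamples128 reserves this many 128-bit samples from the guarded
    // generator, so later ops sharing the generator start from a
    // non-overlapping part of the random stream.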
    truncFunctor(ctx, ctx->eigen_device<Device>(), num_batches,
                 samples_per_batch, num_elements, means_tensor.flat<T>(),
                 stddevs_tensor.flat<T>(), minvals_tensor.flat<T>(),
                 maxvals_tensor.flat<T>(), rng, samples_tensor->flat<T>());
  }

 private:
  GuardedPhiloxRandom generator_;

  TF_DISALLOW_COPY_AND_ASSIGN(ParameterizedTruncatedNormalOp);
};

}  // namespace

#define REGISTER(TYPE)                                         \
  REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
                              .Device(DEVICE_CPU)              \
                              .TypeConstraint<TYPE>("dtype"),  \
                          ParameterizedTruncatedNormalOp<CPUDevice, TYPE>)

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#if GOOGLE_CUDA

#define REGISTER(TYPE)                                         \
  REGISTER_KERNEL_BUILDER(Name("ParameterizedTruncatedNormal") \
                              .Device(DEVICE_GPU)              \
                              .HostMemory("shape")             \
                              .TypeConstraint<TYPE>("dtype"),  \
                          ParameterizedTruncatedNormalOp<GPUDevice, TYPE>)

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#endif  // GOOGLE_CUDA

}  // end namespace tensorflow