/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/multinomial_op.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

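// Generic functor declaration.  The CPU specialization is defined below; the
// GPU build supplies its own implementation of this functor.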
template <typename Device, typename T, typename OutputType>
struct MultinomialFunctor {
  void operator()(OpKernelContext* ctx, const Device& d,
                  typename TTypes<T>::ConstMatrix logits,
                  typename TTypes<float>::Flat noises,
                  typename TTypes<float>::Flat scores,
                  typename TTypes<float>::Flat scratch, int batch_size,
                  int num_classes, int num_samples,
                  const random::PhiloxRandom& gen,
                  typename TTypes<OutputType>::Matrix output);
};

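// CPU specialization: samples by inverse-CDF lookup (a binary search over each
// row's cumulative distribution), so the GPU-only buffers go unused here.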
template <typename T, typename OutputType>
struct MultinomialFunctor<CPUDevice, T, OutputType> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  typename TTypes<T>::ConstMatrix logits,
                  typename TTypes<float>::Flat /* noises */,
                  typename TTypes<float>::Flat /* scores */,
                  typename TTypes<float>::Flat /* scratch */, int batch_size,
                  int num_classes, int num_samples,
                  const random::PhiloxRandom& gen,
                  typename TTypes<OutputType>::Matrix output) {
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    // The implementation only parallelizes by batch.
    //
    // This takes O(BatchSize * NumSamples * log(NumClasses) + NumClasses) CPU
    // time.
    auto DoWork = [ctx, num_samples, num_classes, &gen, &output, &logits](
                      int64 start_row, int64 limit_row) {
      // Capturing "gen" by-value would only make a copy for the _shared_
      // lambda.  Since we want to let each worker have its own copy, we pass
      // "gen" by reference and explicitly do a copy assignment here.
      random::PhiloxRandom gen_copy = gen;
      // Skip takes units of 128 bits.  +3 is so rounding doesn't lead to
      // us using the same state in different batches.
      gen_copy.Skip(start_row * (num_samples + 3) / 4);
      random::SimplePhilox simple_philox(&gen_copy);

      Tensor cdf_tensor;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DT_DOUBLE, TensorShape({num_classes}),
                                        &cdf_tensor));
      auto cdf = cdf_tensor.flat<double>();
      for (int64 b = start_row; b < limit_row; ++b) {
        const auto* logits_row = &logits(b, 0);

        // Takes an along-class maximum (for numerical stability).
        T max = std::numeric_limits<T>::lowest();
        for (int64 j = 0; j < num_classes; ++j) {
          if (Eigen::numext::isfinite(logits_row[j])) {
            max = std::max(max, logits_row[j]);
          }
        }
        const double max_logit = static_cast<double>(max);

        // Precompute cumulative probability distribution across classes.
        // Note: This isn't normalized.
        cdf = (logits.template chip<0>(b).template cast<double>() - max_logit)
                  .exp();
        double running_total = 0;
        for (int64 j = 0; j < num_classes; ++j) {
          if (Eigen::numext::isfinite(logits_row[j])) {
            running_total += cdf(j);
          }
          cdf(j) = running_total;
        }
        // Generate each sample.
        const double* cdf_begin = cdf.data();
        const double* cdf_end = cdf.data() + num_classes;
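        // Each sample draws u uniformly from [0, running_total) and finds its
        // class by binary search over the (unnormalized) CDF.  upper_bound
        // returns the first entry strictly greater than u, so classes with
        // zero probability mass (a zero-width CDF step) are never selected.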
        for (int64 j = 0; j < num_samples; ++j) {
          const double to_find = simple_philox.RandDouble() * running_total;
          auto found_iter = std::upper_bound(cdf_begin, cdf_end, to_find);
          output(b, j) = std::distance(cdf_begin, found_iter);
        }
      }
    };
    // Incredibly rough estimate of clock cycles for DoWork(): ~50 cycles per
    // binary-search step per sample, plus the per-row CDF construction.
    const int64 cost =
        50 * (num_samples * std::log(num_classes) / std::log(2) + num_classes);
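    // Shard splits the batch rows across the threadpool workers, using `cost`
    // (estimated cycles per row) to pick a sensible per-thread granularity.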
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size, cost,
          DoWork);
  }
};

}  // namespace functor

// Samples from a multinomial distribution.
template <typename Device, typename T, typename OutputType>
class MultinomialOp : public OpKernel {
 public:
  explicit MultinomialOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, generator_.Init(context));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& logits_t = ctx->input(0);
    const Tensor& num_samples_t = ctx->input(1);

    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_t.shape()),
                errors::InvalidArgument("logits should be a matrix, got shape ",
                                        logits_t.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(num_samples_t.shape()),
        errors::InvalidArgument("num_samples should be a scalar, got shape ",
                                num_samples_t.shape().DebugString()));

    const int num_samples = num_samples_t.scalar<int>()();
    OP_REQUIRES(ctx, num_samples >= 0,
                errors::InvalidArgument(
                    "num_samples should be nonnegative, got ", num_samples));

    for (int i = 0; i < 2; i++) {
      const int64 dim = logits_t.dim_size(i);
      OP_REQUIRES(ctx, static_cast<int>(dim) == dim,
                  errors::InvalidArgument(
                      "logits.shape = ", logits_t.shape().DebugString(),
                      " too large for int"));
    }
    const int batch_size = static_cast<int>(logits_t.dim_size(0));
    const int num_classes = static_cast<int>(logits_t.dim_size(1));
    OP_REQUIRES(ctx, num_classes > 0,
                errors::InvalidArgument("num_classes should be positive, got ",
                                        num_classes));

    Tensor* samples_t;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({batch_size, num_samples}),
                                  &samples_t));

    // Execute kernel only for nonempty output; otherwise Eigen crashes on GPU.
    if (samples_t->NumElements() > 0) {
      Tensor noises, scores, scratch;  // Scratch space only used for GPU.
      if (std::is_same<Device, GPUDevice>::value) {
        OP_REQUIRES_OK(
            ctx,
            ctx->allocate_temp(
                DT_FLOAT, TensorShape({batch_size, num_samples, num_classes}),
                &noises));
        OP_REQUIRES_OK(
            ctx,
            ctx->allocate_temp(
                DT_FLOAT, TensorShape({batch_size, num_samples, num_classes}),
                &scores));
        OP_REQUIRES_OK(
            ctx,
            ctx->allocate_temp(DT_FLOAT, TensorShape({batch_size, num_samples}),
                               &scratch));
      }
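      // On CPU the three tensors above stay default-constructed (0-element
      // float tensors), so the flat<float>() views passed below are empty but
      // valid.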

      int num_samples_ceil_4 = (num_samples + 3) / 4 * 4;
      // The CPU path draws doubles, each consuming two 32-bit Philox outputs,
      // so reserve twice as many outputs per sample.
      if (std::is_same<Device, CPUDevice>::value) num_samples_ceil_4 *= 2;
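      // Conservatively reserve Philox state for every output up front (note
      // the 256 multiplier) so later reservations against the shared guarded
      // generator don't overlap this kernel's stream.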
      auto rng =
          generator_.ReserveRandomOutputs(batch_size * num_samples_ceil_4, 256);
      functor::MultinomialFunctor<Device, T, OutputType>()(
          ctx, ctx->eigen_device<Device>(), logits_t.matrix<T>(),
          noises.flat<float>(), scores.flat<float>(), scratch.flat<float>(),
          batch_size, num_classes, num_samples, rng,
          samples_t->matrix<OutputType>());
    }
  }

 private:
  GuardedPhiloxRandom generator_;

  TF_DISALLOW_COPY_AND_ASSIGN(MultinomialOp);
};

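// Registers CPU kernels for each floating-point logit type, with both int32
// and int64 output dtypes.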
#define REGISTER(TYPE)                                                   \
  REGISTER_KERNEL_BUILDER(Name("Multinomial")                            \
                              .Device(DEVICE_CPU)                        \
                              .TypeConstraint<TYPE>("T")                 \
                              .TypeConstraint("output_dtype", DT_INT32), \
                          MultinomialOp<CPUDevice, TYPE, int32>);        \
  REGISTER_KERNEL_BUILDER(Name("Multinomial")                            \
                              .Device(DEVICE_CPU)                        \
                              .TypeConstraint<TYPE>("T")                 \
                              .TypeConstraint("output_dtype", DT_INT64), \
                          MultinomialOp<CPUDevice, TYPE, int64>);

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
#undef REGISTER

#if GOOGLE_CUDA
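// GPU registrations mirror the CPU set.  num_samples is pinned to host
// memory because Compute() reads it to determine the output shape.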
#define REGISTER(TYPE)                                                   \
  REGISTER_KERNEL_BUILDER(Name("Multinomial")                            \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("num_samples")                 \
                              .TypeConstraint<TYPE>("T")                 \
                              .TypeConstraint("output_dtype", DT_INT32), \
                          MultinomialOp<GPUDevice, TYPE, int32>);        \
  REGISTER_KERNEL_BUILDER(Name("Multinomial")                            \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("num_samples")                 \
                              .TypeConstraint<TYPE>("T")                 \
                              .TypeConstraint("output_dtype", DT_INT64), \
                          MultinomialOp<GPUDevice, TYPE, int64>);

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
#undef REGISTER

#endif  // GOOGLE_CUDA

}  // end namespace tensorflow