/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/random_op.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/hash/crc32c.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

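// The macros below are used later in this file (in RandomGammaOp) to silence
// -Wfloat-equal around an intentional exact floating-point comparison
// (alpha == 1.0).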
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

namespace functor {
using random::PhiloxRandom;
using random::SingleSampleAdapter;

// The default implementation of the functor, which should never be invoked.
// We still provide an implementation for now so that the linker can resolve
// the symbol, since we do not yet support every distribution on every device.
template <typename Device, class Distribution>
struct FillPhiloxRandom {
  typedef typename Distribution::ResultElementType T;
  void operator()(OpKernelContext*, const Device&, random::PhiloxRandom gen,
                  T* data, int64 size, Distribution dist) {
    LOG(FATAL) << "Default FillPhiloxRandom should not be executed.";
  }
};

// A class to fill a specified range of random groups
template <class Distribution, bool VariableSamplesPerOutput>
struct FillPhiloxRandomTask;

// Specialization for distributions that take a fixed number of samples per
// output.
template <class Distribution>
struct FillPhiloxRandomTask<Distribution, false> {
  typedef typename Distribution::ResultElementType T;
  static void Run(random::PhiloxRandom gen, T* data, int64 size,
                  int64 start_group, int64 limit_group, Distribution dist) {
    const int kGroupSize = Distribution::kResultElementCount;

    gen.Skip(start_group);
    int64 offset = start_group * kGroupSize;

    // First fill all the full-size groups
    int64 limit_group_full = std::min(limit_group, size / kGroupSize);
    for (int64 index = start_group; index < limit_group_full; ++index) {
      auto samples = dist(&gen);
      std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
      offset += kGroupSize;
    }

    // If there are any remaining elements that need to be filled, process them
    if (limit_group_full < limit_group) {
      int64 remaining_size = size - limit_group_full * kGroupSize;
      auto samples = dist(&gen);
      std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
    }
  }
};
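
// Illustrative bookkeeping for the task above (hypothetical numbers, not taken
// from the code): with kGroupSize == 4 and size == 10, the output has groups
// [0, 3); a task given start_group == 2 and limit_group == 3 skips two
// generator states, starts writing at offset 8, and takes the "remaining
// elements" branch to emit the last 10 - 2 * 4 = 2 values.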

// Specialization for distributions that take a variable number of samples per
// output. This is slower due to the added generality.
template <class Distribution>
struct FillPhiloxRandomTask<Distribution, true> {
  typedef typename Distribution::ResultElementType T;
  static const int64 kReservedSamplesPerOutput = 256;

  static void Run(random::PhiloxRandom base_gen, T* data, int64 size,
                  int64 start_group, int64 limit_group, Distribution dist) {
    const int kGroupSize = Distribution::kResultElementCount;

    static const int kGeneratorSkipPerOutputGroup =
        kGroupSize * kReservedSamplesPerOutput /
        PhiloxRandom::kResultElementCount;

    int64 offset = start_group * kGroupSize;

    // First fill all the full-size groups
    int64 limit_group_full = std::min(limit_group, size / kGroupSize);
    int64 group_index;
    for (group_index = start_group; group_index < limit_group_full;
         ++group_index) {
      // Reset the generator to the beginning of the output group region.
      // This is necessary if we want the results to be independent of how the
      // work is sharded across tasks.
      PhiloxRandom gen = base_gen;
      gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
      SingleSampleAdapter<PhiloxRandom> single_samples(&gen);

      auto samples = dist(&single_samples);
      std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
      offset += kGroupSize;
    }

    // If there are any remaining elements that need to be filled, process them
    if (limit_group_full < limit_group) {
      PhiloxRandom gen = base_gen;
      gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
      SingleSampleAdapter<PhiloxRandom> single_samples(&gen);

      int64 remaining_size = size - limit_group_full * kGroupSize;
      auto samples = dist(&single_samples);
      std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
    }
  }
};
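
// Note on the skip arithmetic above: each output element is budgeted
// kReservedSamplesPerOutput == 256 Philox samples, so one output group of
// kGroupSize elements corresponds to kGroupSize * 256 / 4 increments of the
// Philox counter (PhiloxRandom yields 4 elements per state). As a
// hypothetical example, kGroupSize == 4 gives 256 counter skips per group,
// which is why group_index is multiplied by kGeneratorSkipPerOutputGroup.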

// Partial specialization for CPU to fill the entire region with random values.
// It splits the work into several tasks and runs them in parallel.
template <class Distribution>
void FillPhiloxRandom<CPUDevice, Distribution>::operator()(
    OpKernelContext* context, const CPUDevice&, random::PhiloxRandom gen,
    typename Distribution::ResultElementType* data, int64 size,
    Distribution dist) {
  const int kGroupSize = Distribution::kResultElementCount;

  auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

  int64 total_group_count = (size + kGroupSize - 1) / kGroupSize;

  const int kGroupCost =
      random::PhiloxRandom::kResultElementCount *
      (random::PhiloxRandom::kElementCost + Distribution::kElementCost);
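  // How the call below divides the work (illustrative numbers only): Shard()
  // splits total_group_count among the worker threads using kGroupCost as the
  // estimated cost per group; e.g. 1,000,000 floats with a 4-element group
  // size yield 250,000 groups, and each worker fills its own contiguous
  // [start_group, limit_group) range, so the result does not depend on how
  // many threads run.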
  Shard(worker_threads.num_threads, worker_threads.workers, total_group_count,
        kGroupCost,
        [&gen, data, size, dist](int64 start_group, int64 limit_group) {
          FillPhiloxRandomTask<
              Distribution,
              Distribution::kVariableSamplesPerOutput>::Run(gen, data, size,
                                                            start_group,
                                                            limit_group, dist);
        });
}

}  // namespace functor

namespace {

static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape,
                                      int index, Tensor** output) {
  TensorShape tensor_shape;
  TF_RETURN_IF_ERROR(ctx->op_kernel().MakeShape(shape, &tensor_shape));
  return ctx->allocate_output(index, tensor_shape, output);
}

// For now, use the same interface as RandomOp, so we can choose either one
// at run time.
template <typename Device, class Distribution>
class PhiloxRandomOp : public OpKernel {
 public:
  typedef typename Distribution::ResultElementType T;
  explicit PhiloxRandomOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, generator_.Init(ctx));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape = ctx->input(0);
    Tensor* output;
    OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
    auto output_flat = output->flat<T>();
    functor::FillPhiloxRandom<Device, Distribution>()(
        ctx, ctx->eigen_device<Device>(),
        // The multiplier 256 matches kReservedSamplesPerOutput in
        // FillPhiloxRandomTask; keep the two values in sync.
        generator_.ReserveRandomOutputs(output_flat.size(), 256),
        output_flat.data(), output_flat.size(), Distribution());
  }

 private:
  GuardedPhiloxRandom generator_;
};
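
// PhiloxRandomOp backs the RandomUniform, RandomStandardNormal, and
// TruncatedNormal kernels registered later in this file; the Distribution
// template parameter selects which distribution is sampled.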

template <typename Device, class IntType>
class RandomUniformIntOp : public OpKernel {
 public:
  explicit RandomUniformIntOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, generator_.Init(ctx));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape = ctx->input(0);
    const Tensor& minval = ctx->input(1);
    const Tensor& maxval = ctx->input(2);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval.shape()),
                errors::InvalidArgument("minval must be 0-D, got shape ",
                                        minval.shape().DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval.shape()),
                errors::InvalidArgument("maxval must be 0-D, got shape ",
                                        maxval.shape().DebugString()));

    // Verify that minval < maxval
    IntType lo = minval.scalar<IntType>()();
    IntType hi = maxval.scalar<IntType>()();
    OP_REQUIRES(
        ctx, lo < hi,
        errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));

    // Build distribution
    typedef random::UniformDistribution<random::PhiloxRandom, IntType>
        Distribution;
    Distribution dist(lo, hi);
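    // Note: for integer types, UniformDistribution(lo, hi) draws from the
    // half-open range [lo, hi); maxval itself is never produced, which is why
    // lo < hi (rather than lo <= hi) is required above.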

    Tensor* output;
    OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
    auto output_flat = output->flat<IntType>();
    functor::FillPhiloxRandom<Device, Distribution>()(
        ctx, ctx->eigen_device<Device>(),
        // The multiplier 256 matches kReservedSamplesPerOutput in
        // FillPhiloxRandomTask; keep the two values in sync.
        generator_.ReserveRandomOutputs(output_flat.size(), 256),
        output_flat.data(), output_flat.size(), dist);
  }

 private:
  GuardedPhiloxRandom generator_;
};

// Samples from one or more gamma distributions. All internal computations are
// done with double precision for numerical stability.
template <typename T>
class RandomGammaOp : public OpKernel {
 public:
  explicit RandomGammaOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, generator_.Init(context));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_t = ctx->input(0);
    const Tensor& alpha_t = ctx->input(1);

    OP_REQUIRES(ctx,
                TensorShapeUtils::IsVector(shape_t.shape()) &&
                    (shape_t.dtype() == DataType::DT_INT32 ||
                     shape_t.dtype() == DataType::DT_INT64),
                errors::InvalidArgument(
                    "shape must be a vector of {int32,int64}, got shape: ",
                    shape_t.DebugString()));
    TensorShape samples_shape;
    if (shape_t.dtype() == DataType::DT_INT32) {
      auto vec = shape_t.flat<int32>();
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
                                                      &samples_shape));
    } else if (shape_t.dtype() == DataType::DT_INT64) {
      auto vec = shape_t.flat<int64>();
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(vec.data(), vec.size(),
                                                      &samples_shape));
    }
    const int64 num_samples = samples_shape.num_elements();

    samples_shape.AppendShape(alpha_t.shape());
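    // Resulting layout (descriptive note): the output shape is the requested
    // shape followed by alpha's shape, e.g. a requested shape of [10] with an
    // alpha of shape [3] yields a [10, 3] output. In the flat buffer below,
    // consecutive elements therefore vary fastest over alphas, which is why
    // samples are written at stride num_alphas
    // (samples_alpha_offset[sample_idx * num_alphas]).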
    // Allocate output samples.
    Tensor* samples_t = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, samples_shape, &samples_t));

    if (num_samples == 0) return;

    using random::PhiloxRandom;

    typedef random::NormalDistribution<PhiloxRandom, double> Normal;
    typedef random::UniformDistribution<PhiloxRandom, double> Uniform;
#define UNIFORM(X)                                    \
  if (uniform_remaining == 0) {                       \
    uniform_remaining = Uniform::kResultElementCount; \
    uniform_result = uniform(&gen);                   \
  }                                                   \
  uniform_remaining--;                                \
  double X = uniform_result[uniform_remaining]
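// What UNIFORM(u) does (descriptive note): it declares a fresh double "u"
// taken from a small buffer of uniform samples, refilling the buffer from the
// per-sample generator only when it runs empty. The caller must have
// "uniform_remaining", "uniform_result", "uniform", and "gen" in scope, as the
// DoWork lambda below does.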

    // Each rejection-sampling attempt below succeeds 95+% of the time and
    // consumes 1-2 normal samples plus 1 uniform sample.
    static constexpr int kReservedSamplesPerOutput = 256;

    const auto alpha_flat = alpha_t.flat<T>().data();
    const int64 num_alphas = alpha_t.NumElements();
    OP_REQUIRES(ctx, num_alphas > 0,
                errors::InvalidArgument(
                    "Input alpha should have non-zero element count, got: ",
                    num_alphas));
    auto samples_flat = samples_t->flat<T>().data();
    PhiloxRandom rng = generator_.ReserveRandomOutputs(
        num_samples * num_alphas, kReservedSamplesPerOutput);

    // We partition the work first across alphas and then across
    // samples-per-alpha so that a few computations that depend only on alpha
    // are not repeated for every sample.

    auto DoWork = [num_samples, num_alphas, &rng, samples_flat, alpha_flat](
                      int start_output, int limit_output) {
      using Eigen::numext::exp;
      using Eigen::numext::log;
      using Eigen::numext::pow;

      // Capturing "rng" by-value would only make a copy for the _shared_
      // lambda.  Since we want to let each worker have its own copy, we pass
      // "rng" by reference and explicitly do a copy assignment.

      Normal normal;
      Uniform uniform;
      typename Normal::ResultType norm_result;
      typename Uniform::ResultType uniform_result;
      for (int64 output_idx = start_output; output_idx < limit_output;
           /* output_idx incremented within inner loop below */) {
        int64 alpha_idx = output_idx / num_samples;

        // Instead of +alpha_idx for each sample, we offset the pointer once.
        T* const samples_alpha_offset = samples_flat + alpha_idx;

        // Several calculations can be done on a per-alpha basis.
        const double alpha = static_cast<double>(alpha_flat[alpha_idx]);

        DISABLE_FLOAT_EQUALITY_WARNING
        if (alpha == double(1.0)) {
          ENABLE_FLOAT_EQUALITY_WARNING
          // Sample from an exponential distribution.
          for (int64 sample_idx = output_idx % num_samples;
               sample_idx < num_samples && output_idx < limit_output;
               sample_idx++, output_idx++) {
            // As we want data stable regardless of sharding
            // (including eventually on GPU), we skip on a per-sample basis.
            PhiloxRandom gen = rng;
            gen.Skip(kReservedSamplesPerOutput * output_idx);
            short uniform_remaining = 0;
            UNIFORM(u);
            const double res = -log(1.0 - u);
            samples_alpha_offset[sample_idx * num_alphas] = static_cast<T>(res);
          }       // for (sample_idx)
        } else {  // if alpha != 1.0
          // Transformation-rejection from pairs of uniform and normal random
          // variables. http://dl.acm.org/citation.cfm?id=358414
          //
          // The algorithm has an acceptance rate of ~95% for small alpha (~1),
          // and higher accept rates for higher alpha, so runtime is
          // O(NumAlphas * NumSamples * k) with k ~ 1 / 0.95.
          //
          // For alpha<1, we add one to d=alpha-1/3, and multiply the final
          // result by uniform()^(1/alpha)
          const bool alpha_less_than_one = alpha < 1;
          const double d = alpha + (alpha_less_than_one ? 2.0 / 3 : -1.0 / 3);
          const double c = 1.0 / 3 / sqrt(d);
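          // Illustrative values (a worked example, not computed by the code):
          // for alpha = 2.0, d = 2 - 1/3 ~= 1.667 and
          // c = 1 / (3 * sqrt(1.667)) ~= 0.258; for alpha = 0.5,
          // d = 0.5 + 2/3 ~= 1.167 and c ~= 0.309, and the accepted value
          // d * v is additionally scaled by an extra uniform raised to
          // 1/alpha (here squared) in the alpha_less_than_one branch below.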

          // Compute the rest of the samples for the current alpha value.
          for (int64 sample_idx = output_idx % num_samples;
               sample_idx < num_samples && output_idx < limit_output;
               sample_idx++, output_idx++) {
            // Since each sample may use a variable number of normal/uniform
            // samples, and we want data stable regardless of sharding
            // (including eventually on GPU), we skip on a per-sample basis.
            PhiloxRandom gen = rng;
            gen.Skip(kReservedSamplesPerOutput * output_idx);
            short norm_remaining = 0;
            short uniform_remaining = 0;

            // Keep trying until we don't reject a sample. In practice, we will
            // only reject ~5% at worst, for low alpha near 1.
            while (true) {
              if (norm_remaining == 0) {
                norm_remaining = Normal::kResultElementCount;
                norm_result = normal(&gen);
              }
              norm_remaining--;
              const double x = norm_result[norm_remaining];
              double v = 1 + c * x;
              if (v <= 0) {
                continue;
              }
              v = v * v * v;
              UNIFORM(u);
              // The first option in the if is a "squeeze" short-circuit to
              // dodge the two logs. Magic constant sourced from the paper
              // linked above. Upward of .91 of the area covered by the log
              // inequality is covered by the squeeze as well (larger coverage
              // for smaller values of alpha).
              if ((u < 1 - 0.0331 * (x * x) * (x * x)) ||
                  (log(u) < 0.5 * x * x + d * (1 - v + log(v)))) {
                double res = d * v;
                if (alpha_less_than_one) {
                  UNIFORM(b);
                  res *= pow(b, 1 / alpha);
                }
                samples_alpha_offset[sample_idx * num_alphas] =
                    static_cast<T>(res);
                break;
              }
            }  // while: true
          }    // for: sample_idx
        }      // if (alpha == 1.0)
      }        // for: output_idx
    };         // DoWork
#undef UNIFORM
    // Two calls to log only occur for ~10% of samples reaching the log line.
    //   2 x 100 (64-bit cycles per log) x 0.10 = ~20.
    // Other ops: sqrt, +, *, /, %... something like 15 of these, at 3-6 cycles
    // each = ~60.
    // All of this /0.95 due to the rejection possibility = ~85.
    static const int kElementCost = 85 + 2 * Normal::kElementCost +
                                    Uniform::kElementCost +
                                    3 * PhiloxRandom::kElementCost;
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers,
          num_alphas * num_samples, kElementCost, DoWork);
  }

 private:
  GuardedPhiloxRandom generator_;

  TF_DISALLOW_COPY_AND_ASSIGN(RandomGammaOp);
};

}  // namespace

#define REGISTER(TYPE)                                                         \
  template struct functor::FillPhiloxRandom<                                   \
      CPUDevice, random::UniformDistribution<random::PhiloxRandom, TYPE>>;     \
  template struct functor::FillPhiloxRandom<                                   \
      CPUDevice, random::NormalDistribution<random::PhiloxRandom, TYPE>>;      \
  template struct functor::FillPhiloxRandom<                                   \
      CPUDevice,                                                               \
      random::TruncatedNormalDistribution<                                     \
          random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>;           \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomUniform")                                                    \
          .Device(DEVICE_CPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<CPUDevice, random::UniformDistribution<                   \
                                    random::PhiloxRandom, TYPE>>);             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomStandardNormal")                                             \
          .Device(DEVICE_CPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<CPUDevice,                                                \
                     random::NormalDistribution<random::PhiloxRandom, TYPE>>); \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("TruncatedNormal")                                                  \
          .Device(DEVICE_CPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<                                                          \
          CPUDevice,                                                           \
          random::TruncatedNormalDistribution<                                 \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);      \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),        \
      RandomGammaOp<TYPE>)

#define REGISTER_INT(IntType)                                   \
  REGISTER_KERNEL_BUILDER(Name("RandomUniformInt")              \
                              .Device(DEVICE_CPU)               \
                              .HostMemory("shape")              \
                              .HostMemory("minval")             \
                              .HostMemory("maxval")             \
                              .TypeConstraint<IntType>("Tout"), \
                          RandomUniformIntOp<CPUDevice, IntType>);

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
TF_CALL_int32(REGISTER_INT);
TF_CALL_int64(REGISTER_INT);

#undef REGISTER
#undef REGISTER_INT

#if GOOGLE_CUDA

#define REGISTER(TYPE)                                                         \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomUniform")                                                    \
          .Device(DEVICE_GPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<int32>("T")                                          \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<GPUDevice, random::UniformDistribution<                   \
                                    random::PhiloxRandom, TYPE>>);             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomStandardNormal")                                             \
          .Device(DEVICE_GPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<int32>("T")                                          \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<GPUDevice,                                                \
                     random::NormalDistribution<random::PhiloxRandom, TYPE>>); \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("TruncatedNormal")                                                  \
          .Device(DEVICE_GPU)                                                  \
          .HostMemory("shape")                                                 \
          .TypeConstraint<int32>("T")                                          \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<                                                          \
          GPUDevice,                                                           \
          random::TruncatedNormalDistribution<                                 \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);

#define REGISTER_INT(IntType)                                   \
  REGISTER_KERNEL_BUILDER(Name("RandomUniformInt")              \
                              .Device(DEVICE_GPU)               \
                              .HostMemory("shape")              \
                              .HostMemory("minval")             \
                              .HostMemory("maxval")             \
                              .TypeConstraint<int32>("T")       \
                              .TypeConstraint<IntType>("Tout"), \
                          RandomUniformIntOp<GPUDevice, IntType>);

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
TF_CALL_int32(REGISTER_INT);
TF_CALL_int64(REGISTER_INT);

#undef REGISTER
#undef REGISTER_INT

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL

namespace functor {

using namespace cl;

template <class Distribution, bool VariableSamplesPerOutput>
struct FillPhiloxRandomKernel;

template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, false> {
  typedef typename Distribution::ResultElementType T;
  using write_accessor = sycl::accessor<uint8_t, 1, sycl::access::mode::write,
                                        sycl::access::target::global_buffer>;

  FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen,
                         Distribution& dist)
      : data_(data), gen_(gen), dist_(dist) {}

  void operator()(sycl::nd_item<1> item) {
    const size_t kGroupSize = Distribution::kResultElementCount;

    const size_t item_id = item.get_global(0);
    const size_t total_item_count = item.get_global_range();
    size_t offset = item_id * kGroupSize;
    gen_.Skip(item_id);

    const size_t size = data_.get_size() / sizeof(T);
    T* data = ConvertToActualTypeSycl(T, data_);

    while (offset + kGroupSize <= size) {
      const typename Distribution::ResultType samples = dist_(&gen_);
      for (size_t i = 0; i < kGroupSize; ++i) {
        data[offset + i] = samples[i];
      }

      offset += (total_item_count - 1) * kGroupSize;
      gen_.Skip(total_item_count - 1);
    }

    const typename Distribution::ResultType samples = dist_(&gen_);
    for (size_t i = 0; i < kGroupSize; ++i) {
      if (offset >= size) {
        return;
      }
      data[offset] = samples[i];
      ++offset;
    }
  }

 private:
  write_accessor data_;
  random::PhiloxRandom gen_;
  Distribution dist_;
};
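
// Descriptive note on the kernel above: each work item fills an interleaved
// subset of the output groups rather than a contiguous block, skipping the
// Philox generator forward between groups, and the loop after the while
// handles a final, possibly partial, group. data_.get_size() returns a size
// in bytes, hence the division by sizeof(T) to obtain the element count.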

template <class Distribution>
struct FillPhiloxRandomKernel<Distribution, true> {
  typedef typename Distribution::ResultElementType T;
  using write_accessor = sycl::accessor<uint8_t, 1, sycl::access::mode::write,
                                        sycl::access::target::global_buffer>;

  FillPhiloxRandomKernel(write_accessor& data, random::PhiloxRandom& gen,
                         Distribution& dist)
      : data_(data), gen_(gen), dist_(dist) {}

  void operator()(sycl::nd_item<1> item) {
    using random::PhiloxRandom;
    using random::SingleSampleAdapter;

    const size_t kReservedSamplesPerOutput = 256;
    const size_t kGroupSize = Distribution::kResultElementCount;
    const size_t kGeneratorSkipPerOutputGroup =
        kGroupSize * kReservedSamplesPerOutput /
        PhiloxRandom::kResultElementCount;

    const size_t item_id = item.get_global(0);
    const size_t total_item_count = item.get_global_range();
    size_t group_index = item_id;
    size_t offset = group_index * kGroupSize;

    T* data = ConvertToActualTypeSycl(T, data_);
    const size_t size = data_.get_size() / sizeof(T);

    while (offset < size) {
      // Since each output takes a variable number of samples, we need to
      // realign the generator to the beginning of the current output group.
      PhiloxRandom gen = gen_;
      gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
      SingleSampleAdapter<PhiloxRandom> single_samples(&gen);

      const typename Distribution::ResultType samples = dist_(&single_samples);

      for (size_t i = 0; i < kGroupSize; ++i) {
        if (offset >= size) {
          return;
        }
        data[offset] = samples[i];
        ++offset;
      }

      offset += (total_item_count - 1) * kGroupSize;
      group_index += total_item_count;
    }
  }

 private:
  write_accessor data_;
  random::PhiloxRandom gen_;
  Distribution dist_;
};

template <typename T>
class FillRandomKernel;
// Partial specialization for SYCL to fill the entire region with random
// values. It splits the work into several tasks and runs them in parallel.
template <class Distribution>
void FillPhiloxRandom<SYCLDevice, Distribution>::operator()(
    OpKernelContext* context, const SYCLDevice& device,
    random::PhiloxRandom gen, typename Distribution::ResultElementType* data,
    int64 size, Distribution dist) {
  const size_t group_size = device.maxSyclThreadsPerBlock();
  const size_t group_count = (size + group_size - 1) / group_size;

  auto buffer = device.get_sycl_buffer(data);

  device.sycl_queue().submit([&](sycl::handler& cgh) {
    auto access = buffer.template get_access<sycl::access::mode::write>(cgh);

    FillPhiloxRandomKernel<Distribution,
                           Distribution::kVariableSamplesPerOutput>
        task(access, gen, dist);
    cgh.parallel_for<class FillRandomKernel<Distribution>>(
        sycl::nd_range<1>(sycl::range<1>(group_count * group_size),
                          sycl::range<1>(group_size)),
        task);
  });
}

}  // namespace functor

#define REGISTER(TYPE)                                                         \
  template struct functor::FillPhiloxRandom<                                   \
      SYCLDevice, random::UniformDistribution<random::PhiloxRandom, TYPE>>;    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomUniform")                                                    \
          .Device(DEVICE_SYCL)                                                 \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<SYCLDevice, random::UniformDistribution<                  \
                                     random::PhiloxRandom, TYPE>>);            \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RandomStandardNormal")                                             \
          .Device(DEVICE_SYCL)                                                 \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<SYCLDevice,                                               \
                     random::NormalDistribution<random::PhiloxRandom, TYPE>>); \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("TruncatedNormal")                                                  \
          .Device(DEVICE_SYCL)                                                 \
          .HostMemory("shape")                                                 \
          .TypeConstraint<TYPE>("dtype"),                                      \
      PhiloxRandomOp<                                                          \
          SYCLDevice,                                                          \
          random::TruncatedNormalDistribution<                                 \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);

#define REGISTER_INT(IntType)                                   \
  REGISTER_KERNEL_BUILDER(Name("RandomUniformInt")              \
                              .Device(DEVICE_SYCL)              \
                              .HostMemory("shape")              \
                              .HostMemory("minval")             \
                              .HostMemory("maxval")             \
                              .TypeConstraint<IntType>("Tout"), \
                          RandomUniformIntOp<SYCLDevice, IntType>);

TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
TF_CALL_int32(REGISTER_INT);
TF_CALL_int64(REGISTER_INT);

#undef REGISTER
#undef REGISTER_INT

#endif  // TENSORFLOW_USE_SYCL

}  // end namespace tensorflow