Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/op_kernel.h"
     17 #include "tensorflow/core/framework/register_types.h"
     18 #include "tensorflow/core/framework/tensor.h"
     19 #include "tensorflow/core/framework/tensor_shape.h"
     20 #include "tensorflow/core/kernels/bounds_check.h"
     21 #include "tensorflow/core/kernels/random_op.h"
     22 #include "tensorflow/core/lib/random/random_distributions.h"
     23 #include "tensorflow/core/platform/logging.h"
     24 
     25 namespace tensorflow {
     26 
     27 using CPUDevice = Eigen::ThreadPoolDevice;
     28 using GPUDevice = Eigen::GpuDevice;
     29 
     30 namespace {
     31 
// Base class for stateless random ops: the output is a pure function of the
// requested `shape` and the user-supplied `seed` input (no kernel-resident
// RNG state), so the same (shape, seed) pair always yields the same values.
class StatelessRandomOpBase : public OpKernel {
 public:
  explicit StatelessRandomOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Sanitize input: input 0 is the output shape, input 1 is the seed.
    const Tensor& shape_t = context->input(0);
    const Tensor& seed_t = context->input(1);
    TensorShape shape;
    OP_REQUIRES_OK(context, MakeShape(shape_t, &shape));
    // The seed must be exactly a length-2 vector (two scalar seeds).
    OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_t.shape().DebugString()));

    // Allocate output
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
    if (shape.num_elements() == 0) return;

    // Grab the two seeds. Both int32 and int64 seed tensors are accepted;
    // each element is widened/narrowed into a uint64 lane.
    uint64 seed0;
    uint64 seed1;
    if (context->input_dtype(1) == DT_INT32) {
      const auto seed = seed_t.flat<int32>();
      seed0 = internal::SubtleMustCopy(seed(0));
      seed1 = internal::SubtleMustCopy(seed(1));
    } else {
      // The op's type constraint limits Tseed to int32/int64, so any other
      // dtype here is an internal invariant violation, not a user error.
      CHECK_EQ(DT_INT64, context->input_dtype(1));
      const auto seed = seed_t.flat<int64>();
      seed0 = internal::SubtleMustCopy(seed(0));
      seed1 = internal::SubtleMustCopy(seed(1));
    }

    // Scramble the seeds so that the user doesn't need to worry about which
    // part of the seed needs to be strong. One Philox evaluation with a
    // fixed, arbitrary key mixes the full 128-bit user seed; its 128-bit
    // output is then split into the key (64 bits) and the upper counter
    // half (64 bits) of the generator actually used to fill the output.
    // NOTE: the constants and the exact layout below define the op's output
    // stream — changing any of them silently changes all generated values.
    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    key[0] = 0x3ec8f720;
    key[1] = 0x02461e29;
    counter[0] = static_cast<uint32>(seed0);
    counter[1] = static_cast<uint32>(seed0 >> 32);
    counter[2] = static_cast<uint32>(seed1);
    counter[3] = static_cast<uint32>(seed1 >> 32);
    const auto mix = random::PhiloxRandom(counter, key)();
    key[0] = mix[0];
    key[1] = mix[1];
    // Lower counter half starts at zero so FillPhiloxRandom can advance it.
    counter[0] = counter[1] = 0;
    counter[2] = mix[2];
    counter[3] = mix[3];

    // Fill in the random numbers
    Fill(context, random::PhiloxRandom(counter, key), output);
  }

  // The part of Compute that depends on device, type, and distribution.
  // `random` is the fully-seeded generator; implementations draw from it to
  // populate `output`.
  virtual void Fill(OpKernelContext* context, random::PhiloxRandom random,
                    Tensor* output) = 0;
};
     91 
     92 template <typename Device, class Distribution>
     93 class StatelessRandomOp : public StatelessRandomOpBase {
     94  public:
     95   using StatelessRandomOpBase::StatelessRandomOpBase;
     96 
     97   void Fill(OpKernelContext* context, random::PhiloxRandom random,
     98             Tensor* output) override {
     99     typedef typename Distribution::ResultElementType T;
    100     auto flat = output->flat<T>();
    101     // Reuse the compute kernels from the stateful random ops
    102     functor::FillPhiloxRandom<Device, Distribution>()(
    103         context, context->eigen_device<Device>(), random, flat.data(),
    104         flat.size(), Distribution());
    105   }
    106 };
    107 
    108 }  // namespace
    109 
// CPU registrations for the three stateless random ops for one element type.
// `shape` is pinned to host memory because Compute only reads it to size the
// output tensor.
#define REGISTER(TYPE)                                                 \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("StatelessRandomUniform")                                   \
          .Device(DEVICE_CPU)                                          \
          .HostMemory("shape")                                         \
          .TypeConstraint<TYPE>("dtype"),                              \
      StatelessRandomOp<CPUDevice, random::UniformDistribution<        \
                                       random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("StatelessRandomNormal")                                    \
          .Device(DEVICE_CPU)                                          \
          .HostMemory("shape")                                         \
          .TypeConstraint<TYPE>("dtype"),                              \
      StatelessRandomOp<CPUDevice, random::NormalDistribution<         \
                                       random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("StatelessTruncatedNormal")                                 \
          .Device(DEVICE_CPU)                                          \
          .HostMemory("shape")                                         \
          .TypeConstraint<TYPE>("dtype"),                              \
      StatelessRandomOp<                                               \
          CPUDevice,                                                   \
          random::TruncatedNormalDistribution<                         \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);

// Instantiate for the floating-point dtypes these ops support.
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER
    140 
#if GOOGLE_CUDA

// GPU registrations, mirroring the CPU set above. In addition to `shape`,
// `seed` is also pinned to host memory: the seed scrambling in Compute runs
// on the CPU, so the seed tensor must not live in device memory.
#define REGISTER(TYPE)                                                 \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("StatelessRandomUniform")                                   \
          .Device(DEVICE_GPU)                                          \
          .HostMemory("shape")                                         \
          .HostMemory("seed")                                          \
          .TypeConstraint<TYPE>("dtype"),                              \
      StatelessRandomOp<GPUDevice, random::UniformDistribution<        \
                                       random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("StatelessRandomNormal")                                    \
          .Device(DEVICE_GPU)                                          \
          .HostMemory("shape")                                         \
          .HostMemory("seed")                                          \
          .TypeConstraint<TYPE>("dtype"),                              \
      StatelessRandomOp<GPUDevice, random::NormalDistribution<         \
                                       random::PhiloxRandom, TYPE> >); \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("StatelessTruncatedNormal")                                 \
          .Device(DEVICE_GPU)                                          \
          .HostMemory("shape")                                         \
          .HostMemory("seed")                                          \
          .TypeConstraint<TYPE>("dtype"),                              \
      StatelessRandomOp<                                               \
          GPUDevice,                                                   \
          random::TruncatedNormalDistribution<                         \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);

// Instantiate for the floating-point dtypes these ops support.
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#endif  // GOOGLE_CUDA
    178 
    179 }  // namespace tensorflow
    180