/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/random_op.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace {

class StatelessRandomOpBase : public OpKernel {
 public:
  explicit StatelessRandomOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Sanitize input
    const Tensor& shape_t = context->input(0);
    const Tensor& seed_t = context->input(1);
    TensorShape shape;
    OP_REQUIRES_OK(context, MakeShape(shape_t, &shape));
    OP_REQUIRES(context, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_t.shape().DebugString()));

    // Allocate output
    Tensor* output;
    OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
    if (shape.num_elements() == 0) return;

    // Grab the two seeds
    uint64 seed0;
    uint64 seed1;
    if (context->input_dtype(1) == DT_INT32) {
      const auto seed = seed_t.flat<int32>();
      seed0 = internal::SubtleMustCopy(seed(0));
      seed1 = internal::SubtleMustCopy(seed(1));
    } else {
      CHECK_EQ(DT_INT64, context->input_dtype(1));
      const auto seed = seed_t.flat<int64>();
      seed0 = internal::SubtleMustCopy(seed(0));
      seed1 = internal::SubtleMustCopy(seed(1));
    }

    // Scramble the seeds so that the user doesn't need to worry about which
    // part of the seed needs to be strong.
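    // The scramble packs the two 64-bit seed words into a Philox counter and
    // runs one invocation of PhiloxRandom under a fixed key; the resulting
    // 128-bit output then becomes the key and the upper counter words of the
    // generator that actually produces the samples, with the lower counter
    // words reset to zero.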
    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    key[0] = 0x3ec8f720;
    key[1] = 0x02461e29;
    counter[0] = static_cast<uint32>(seed0);
    counter[1] = static_cast<uint32>(seed0 >> 32);
    counter[2] = static_cast<uint32>(seed1);
    counter[3] = static_cast<uint32>(seed1 >> 32);
    const auto mix = random::PhiloxRandom(counter, key)();
    key[0] = mix[0];
    key[1] = mix[1];
    counter[0] = counter[1] = 0;
    counter[2] = mix[2];
    counter[3] = mix[3];

    // Fill in the random numbers
    Fill(context, random::PhiloxRandom(counter, key), output);
  }

  // The part of Compute that depends on device, type, and distribution
  virtual void Fill(OpKernelContext* context, random::PhiloxRandom random,
                    Tensor* output) = 0;
};

template <typename Device, class Distribution>
class StatelessRandomOp : public StatelessRandomOpBase {
 public:
  using StatelessRandomOpBase::StatelessRandomOpBase;

  void Fill(OpKernelContext* context, random::PhiloxRandom random,
            Tensor* output) override {
    typedef typename Distribution::ResultElementType T;
    auto flat = output->flat<T>();
    // Reuse the compute kernels from the stateful random ops
    functor::FillPhiloxRandom<Device, Distribution>()(
        context, context->eigen_device<Device>(), random, flat.data(),
        flat.size(), Distribution());
  }
};

}  // namespace

#define REGISTER(TYPE)                                                  \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("StatelessRandomUniform")                                    \
          .Device(DEVICE_CPU)                                           \
          .HostMemory("shape")                                          \
          .TypeConstraint<TYPE>("dtype"),                               \
      StatelessRandomOp<CPUDevice, random::UniformDistribution<         \
                                       random::PhiloxRandom, TYPE> >);  \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("StatelessRandomNormal")                                     \
          .Device(DEVICE_CPU)                                           \
          .HostMemory("shape")                                          \
          .TypeConstraint<TYPE>("dtype"),                               \
      StatelessRandomOp<CPUDevice, random::NormalDistribution<          \
                                       random::PhiloxRandom, TYPE> >);  \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("StatelessTruncatedNormal")                                  \
          .Device(DEVICE_CPU)                                           \
          .HostMemory("shape")                                          \
          .TypeConstraint<TYPE>("dtype"),                               \
      StatelessRandomOp<                                                \
          CPUDevice,                                                    \
          random::TruncatedNormalDistribution<                          \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#if GOOGLE_CUDA

#define REGISTER(TYPE)                                                  \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("StatelessRandomUniform")                                    \
          .Device(DEVICE_GPU)                                           \
          .HostMemory("shape")                                          \
          .HostMemory("seed")                                           \
          .TypeConstraint<TYPE>("dtype"),                               \
      StatelessRandomOp<GPUDevice, random::UniformDistribution<         \
                                       random::PhiloxRandom, TYPE> >);  \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("StatelessRandomNormal")                                     \
          .Device(DEVICE_GPU)                                           \
          .HostMemory("shape")                                          \
          .HostMemory("seed")                                           \
          .TypeConstraint<TYPE>("dtype"),                               \
      StatelessRandomOp<GPUDevice, random::NormalDistribution<          \
                                       random::PhiloxRandom, TYPE> >);  \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("StatelessTruncatedNormal")                                  \
          .Device(DEVICE_GPU)                                           \
          .HostMemory("shape")                                          \
          .HostMemory("seed")                                           \
          .TypeConstraint<TYPE>("dtype"),                               \
      StatelessRandomOp<                                                \
          GPUDevice,                                                    \
          random::TruncatedNormalDistribution<                          \
              random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);

#undef REGISTER

#endif  // GOOGLE_CUDA

}  // namespace tensorflow