1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <cmath> 17 18 #include "tensorflow/compiler/tf2xla/shape_util.h" 19 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/framework/tensor_shape.h" 26 #include "tensorflow/core/lib/core/casts.h" 27 #include "tensorflow/core/lib/math/math_util.h" 28 29 namespace tensorflow { 30 namespace { 31 32 // Rotates a 32-bit integer 'v' left by 'distance' bits. 33 xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder, 34 const xla::ComputationDataHandle& v, 35 int distance) { 36 return builder->Or( 37 builder->ShiftLeft(v, builder->ConstantR0<int>(distance)), 38 builder->ShiftRightLogical(v, builder->ConstantR0<int>(32 - distance))); 39 } 40 41 // TODO(b/65209188): add a primitive XOR to XLA and call it here, rather than 42 // building XOR out of other bitwise operators. 43 xla::ComputationDataHandle BitwiseXor(xla::ComputationBuilder* builder, 44 const xla::ComputationDataHandle& x, 45 const xla::ComputationDataHandle& y) { 46 return builder->Or(builder->And(x, builder->Not(y)), 47 builder->And(builder->Not(x), y)); 48 } 49 50 using ThreeFry2x32State = std::array<xla::ComputationDataHandle, 2>; 51 52 // Implements the ThreeFry counter-based PRNG algorithm. 53 // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. 54 // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf 55 ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder, 56 ThreeFry2x32State input, ThreeFry2x32State key) { 57 // Rotation distances specified by the Threefry2x32 algorithm. 58 constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24}; 59 ThreeFry2x32State x; 60 61 std::array<xla::ComputationDataHandle, 3> ks; 62 // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. 63 ks[2] = builder->ConstantR0<int32>(0x1BD11BDA); 64 for (int i = 0; i < 2; ++i) { 65 ks[i] = key[i]; 66 x[i] = input[i]; 67 ks[2] = BitwiseXor(builder, ks[2], key[i]); 68 } 69 70 x[0] = builder->Add(x[0], ks[0]); 71 x[1] = builder->Add(x[1], ks[1]); 72 73 // Performs a single round of the Threefry2x32 algorithm, with a rotation 74 // amount 'rotation'. 75 auto round = [builder](ThreeFry2x32State v, int rotation) { 76 v[0] = builder->Add(v[0], v[1]); 77 v[1] = RotateLeftS32(builder, v[1], rotation); 78 v[1] = BitwiseXor(builder, v[0], v[1]); 79 return v; 80 }; 81 82 // There are no known statistical flaws with 13 rounds of Threefry2x32. 83 // We are conservative and use 20 rounds. 84 x = round(x, rotations[0]); 85 x = round(x, rotations[1]); 86 x = round(x, rotations[2]); 87 x = round(x, rotations[3]); 88 x[0] = builder->Add(x[0], ks[1]); 89 x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0<int32>(1)); 90 91 x = round(x, rotations[4]); 92 x = round(x, rotations[5]); 93 x = round(x, rotations[6]); 94 x = round(x, rotations[7]); 95 x[0] = builder->Add(x[0], ks[2]); 96 x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0<int32>(2)); 97 98 x = round(x, rotations[0]); 99 x = round(x, rotations[1]); 100 x = round(x, rotations[2]); 101 x = round(x, rotations[3]); 102 x[0] = builder->Add(x[0], ks[0]); 103 x[1] = builder->Add(builder->Add(x[1], ks[1]), builder->ConstantR0<int32>(3)); 104 105 x = round(x, rotations[4]); 106 x = round(x, rotations[5]); 107 x = round(x, rotations[6]); 108 x = round(x, rotations[7]); 109 x[0] = builder->Add(x[0], ks[1]); 110 x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0<int32>(4)); 111 112 x = round(x, rotations[0]); 113 x = round(x, rotations[1]); 114 x = round(x, rotations[2]); 115 x = round(x, rotations[3]); 116 x[0] = builder->Add(x[0], ks[2]); 117 x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0<int32>(5)); 118 119 return x; 120 } 121 122 // Returns a tensor of 'shape' random values uniformly distributed in the range 123 // [minval, maxval) 124 xla::ComputationDataHandle RandomUniform(xla::ComputationBuilder* builder, 125 const xla::ComputationDataHandle& seed, 126 const TensorShape& shape, 127 double minval, double maxval) { 128 // Split the seed into two 32-bit scalars to form a key. 129 auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {}); 130 auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {}); 131 ThreeFry2x32State key = {seed0, seed1}; 132 const int64 size = shape.num_elements(); 133 134 const int64 half_size = MathUtil::CeilOfRatio<int64>(size, 2); 135 const bool size_is_odd = (half_size * 2 != size); 136 137 // Fill the generator inputs with unique counter values. 138 ThreeFry2x32State inputs; 139 TF_CHECK_OK(XlaHelpers::Iota(builder, DT_INT32, half_size, &inputs[0])); 140 inputs[1] = builder->Add(inputs[0], builder->ConstantR0<int32>(half_size)); 141 ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key); 142 143 if (size_is_odd) { 144 outputs[1] = builder->Slice(outputs[1], {0}, {half_size - 1}, {1}); 145 } 146 147 auto bits = 148 builder->Reshape(builder->ConcatInDim(outputs, 0), shape.dim_sizes()); 149 150 // Form 22 random mantissa bits, with a leading 1 bit. The leading 1 bit 151 // forces the random bits into the mantissa. 152 constexpr int kFloatBits = 32; 153 constexpr int kMantissaBits = 23; 154 bits = builder->Or( 155 builder->ShiftRightLogical( 156 bits, builder->ConstantR0<int32>(kFloatBits - kMantissaBits)), 157 builder->ConstantR0<int32>(bit_cast<int32>(1.0f))); 158 auto floats = builder->BitcastConvertType(bits, xla::F32); 159 160 // We have a floating point number in the range [1.0, 2.0). 161 // Subtract 1.0f to shift to the range [0.0, 1.0) 162 floats = builder->Sub(floats, builder->ConstantR0<float>(1.0f)); 163 // Multiply and add to shift to the range [minval, maxval). 164 floats = builder->Mul(floats, builder->ConstantR0<float>(maxval - minval)); 165 floats = builder->Add(floats, builder->ConstantR0<float>(minval)); 166 return floats; 167 } 168 169 // Approximation for the inverse error function from 170 // Giles, M., "Approximating the erfinv function". 171 // The approximation has the form: 172 // w = -log((1 - x) * (1 + x)) 173 // if ( w < 5 ) { 174 // w = w - 2.5 175 // p = sum_{i=1}^n lq[i]*w^i 176 // } else { 177 // w = sqrt(w) - 3 178 // p = sum_{i=1}^n gq[i]*w^i 179 // } 180 // return p*x 181 xla::ComputationDataHandle ErfInvF32(xla::ComputationBuilder* b, 182 const xla::ComputationDataHandle& x, 183 const TensorShape& shape) { 184 constexpr int kDegree = 9; 185 constexpr std::array<float, 9> w_less_than_5_constants = { 186 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, 187 -4.39150654e-06f, 0.00021858087f, -0.00125372503f, 188 -0.00417768164f, 0.246640727f, 1.50140941f}; 189 constexpr std::array<float, 9> w_greater_than_5_constants = { 190 -0.000200214257f, 0.000100950558f, 0.00134934322f, 191 -0.00367342844f, 0.00573950773f, -0.0076224613f, 192 0.00943887047f, 1.00167406f, 2.83297682f}; 193 194 auto one = b->ConstantR0<float>(1.0); 195 auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x)))); 196 197 auto lt = b->Lt(w, b->ConstantR0<float>(5.0)); 198 auto coefficient = [&](int i) { 199 return b->Select( 200 lt, 201 b->Broadcast(b->ConstantR0<float>(w_less_than_5_constants[i]), 202 shape.dim_sizes()), 203 b->Broadcast(b->ConstantR0<float>(w_greater_than_5_constants[i]), 204 shape.dim_sizes())); 205 }; 206 w = b->Select(lt, b->Sub(w, b->ConstantR0<float>(2.5f)), 207 b->Sub(b->SqrtF32(w), b->ConstantR0<float>(3.0f))); 208 auto p = coefficient(0); 209 for (int i = 1; i < kDegree; ++i) { 210 p = b->Add(coefficient(i), b->Mul(p, w)); 211 } 212 return b->Mul(p, x); 213 } 214 215 } // namespace 216 217 class StatelessRandomUniformOp : public XlaOpKernel { 218 public: 219 explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) 220 : XlaOpKernel(ctx) {} 221 222 void Compile(XlaOpKernelContext* ctx) override { 223 xla::ComputationBuilder* builder = ctx->builder(); 224 225 TensorShape shape; 226 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); 227 228 TensorShape seed_shape = ctx->InputShape(1); 229 OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2, 230 errors::InvalidArgument("seed must have shape [2], not ", 231 seed_shape.DebugString())); 232 xla::ComputationDataHandle seed = ctx->Input(1); 233 ctx->SetOutput(0, RandomUniform(builder, seed, shape, 0.0, 1.0)); 234 } 235 236 private: 237 TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp); 238 }; 239 240 // TODO(phawkins): generalize to non-float, non-int32 seed types. 241 REGISTER_XLA_OP(Name("StatelessRandomUniform") 242 .TypeConstraint("dtype", DT_FLOAT) 243 .TypeConstraint("Tseed", DT_INT32), 244 StatelessRandomUniformOp); 245 246 class StatelessRandomNormalOp : public XlaOpKernel { 247 public: 248 explicit StatelessRandomNormalOp(OpKernelConstruction* ctx) 249 : XlaOpKernel(ctx) {} 250 251 void Compile(XlaOpKernelContext* ctx) override { 252 TensorShape shape; 253 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); 254 255 TensorShape seed_shape = ctx->InputShape(1); 256 OP_REQUIRES(ctx, seed_shape == TensorShape({2}), 257 errors::InvalidArgument("seed must have shape [2], not ", 258 seed_shape.DebugString())); 259 xla::ComputationDataHandle seed = ctx->Input(1); 260 xla::ComputationBuilder* builder = ctx->builder(); 261 auto uniform = RandomUniform(builder, seed, shape, -1.0, 1.0); 262 // Convert uniform distribution to normal distribution by computing 263 // sqrt(2) * erfinv(x) 264 auto normal = builder->Mul(builder->ConstantR0<float>(std::sqrt(2.0)), 265 ErfInvF32(builder, uniform, shape)); 266 ctx->SetOutput(0, normal); 267 } 268 269 private: 270 TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp); 271 }; 272 273 // TODO(phawkins): generalize to non-float, non-int32 seed types. 274 REGISTER_XLA_OP(Name("StatelessRandomNormal") 275 .TypeConstraint("dtype", DT_FLOAT) 276 .TypeConstraint("Tseed", DT_INT32), 277 StatelessRandomNormalOp); 278 279 } // namespace tensorflow 280