Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <cmath>
     17 
     18 #include "tensorflow/compiler/tf2xla/shape_util.h"
     19 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
     23 #include "tensorflow/core/framework/op_kernel.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/framework/tensor_shape.h"
     26 #include "tensorflow/core/lib/core/casts.h"
     27 #include "tensorflow/core/lib/math/math_util.h"
     28 
     29 namespace tensorflow {
     30 namespace {
     31 
     32 // Rotates a 32-bit integer 'v' left by 'distance' bits.
     33 xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder,
     34                                          const xla::ComputationDataHandle& v,
     35                                          int distance) {
     36   return builder->Or(
     37       builder->ShiftLeft(v, builder->ConstantR0<int>(distance)),
     38       builder->ShiftRightLogical(v, builder->ConstantR0<int>(32 - distance)));
     39 }
     40 
     41 // TODO(b/65209188): add a primitive XOR to XLA and call it here, rather than
     42 // building XOR out of other bitwise operators.
     43 xla::ComputationDataHandle BitwiseXor(xla::ComputationBuilder* builder,
     44                                       const xla::ComputationDataHandle& x,
     45                                       const xla::ComputationDataHandle& y) {
     46   return builder->Or(builder->And(x, builder->Not(y)),
     47                      builder->And(builder->Not(x), y));
     48 }
     49 
     50 using ThreeFry2x32State = std::array<xla::ComputationDataHandle, 2>;
     51 
     52 // Implements the ThreeFry counter-based PRNG algorithm.
     53 // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
     54 // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
     55 ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder,
     56                                ThreeFry2x32State input, ThreeFry2x32State key) {
     57   // Rotation distances specified by the Threefry2x32 algorithm.
     58   constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24};
     59   ThreeFry2x32State x;
     60 
     61   std::array<xla::ComputationDataHandle, 3> ks;
     62   // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
     63   ks[2] = builder->ConstantR0<int32>(0x1BD11BDA);
     64   for (int i = 0; i < 2; ++i) {
     65     ks[i] = key[i];
     66     x[i] = input[i];
     67     ks[2] = BitwiseXor(builder, ks[2], key[i]);
     68   }
     69 
     70   x[0] = builder->Add(x[0], ks[0]);
     71   x[1] = builder->Add(x[1], ks[1]);
     72 
     73   // Performs a single round of the Threefry2x32 algorithm, with a rotation
     74   // amount 'rotation'.
     75   auto round = [builder](ThreeFry2x32State v, int rotation) {
     76     v[0] = builder->Add(v[0], v[1]);
     77     v[1] = RotateLeftS32(builder, v[1], rotation);
     78     v[1] = BitwiseXor(builder, v[0], v[1]);
     79     return v;
     80   };
     81 
     82   // There are no known statistical flaws with 13 rounds of Threefry2x32.
     83   // We are conservative and use 20 rounds.
     84   x = round(x, rotations[0]);
     85   x = round(x, rotations[1]);
     86   x = round(x, rotations[2]);
     87   x = round(x, rotations[3]);
     88   x[0] = builder->Add(x[0], ks[1]);
     89   x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0<int32>(1));
     90 
     91   x = round(x, rotations[4]);
     92   x = round(x, rotations[5]);
     93   x = round(x, rotations[6]);
     94   x = round(x, rotations[7]);
     95   x[0] = builder->Add(x[0], ks[2]);
     96   x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0<int32>(2));
     97 
     98   x = round(x, rotations[0]);
     99   x = round(x, rotations[1]);
    100   x = round(x, rotations[2]);
    101   x = round(x, rotations[3]);
    102   x[0] = builder->Add(x[0], ks[0]);
    103   x[1] = builder->Add(builder->Add(x[1], ks[1]), builder->ConstantR0<int32>(3));
    104 
    105   x = round(x, rotations[4]);
    106   x = round(x, rotations[5]);
    107   x = round(x, rotations[6]);
    108   x = round(x, rotations[7]);
    109   x[0] = builder->Add(x[0], ks[1]);
    110   x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0<int32>(4));
    111 
    112   x = round(x, rotations[0]);
    113   x = round(x, rotations[1]);
    114   x = round(x, rotations[2]);
    115   x = round(x, rotations[3]);
    116   x[0] = builder->Add(x[0], ks[2]);
    117   x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0<int32>(5));
    118 
    119   return x;
    120 }
    121 
    122 // Returns a tensor of 'shape' random values uniformly distributed in the range
    123 // [minval, maxval)
    124 xla::ComputationDataHandle RandomUniform(xla::ComputationBuilder* builder,
    125                                          const xla::ComputationDataHandle& seed,
    126                                          const TensorShape& shape,
    127                                          double minval, double maxval) {
    128   // Split the seed into two 32-bit scalars to form a key.
    129   auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {});
    130   auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {});
    131   ThreeFry2x32State key = {seed0, seed1};
    132   const int64 size = shape.num_elements();
    133 
    134   const int64 half_size = MathUtil::CeilOfRatio<int64>(size, 2);
    135   const bool size_is_odd = (half_size * 2 != size);
    136 
    137   // Fill the generator inputs with unique counter values.
    138   ThreeFry2x32State inputs;
    139   TF_CHECK_OK(XlaHelpers::Iota(builder, DT_INT32, half_size, &inputs[0]));
    140   inputs[1] = builder->Add(inputs[0], builder->ConstantR0<int32>(half_size));
    141   ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key);
    142 
    143   if (size_is_odd) {
    144     outputs[1] = builder->Slice(outputs[1], {0}, {half_size - 1}, {1});
    145   }
    146 
    147   auto bits =
    148       builder->Reshape(builder->ConcatInDim(outputs, 0), shape.dim_sizes());
    149 
    150   // Form 22 random mantissa bits, with a leading 1 bit. The leading 1 bit
    151   // forces the random bits into the mantissa.
    152   constexpr int kFloatBits = 32;
    153   constexpr int kMantissaBits = 23;
    154   bits = builder->Or(
    155       builder->ShiftRightLogical(
    156           bits, builder->ConstantR0<int32>(kFloatBits - kMantissaBits)),
    157       builder->ConstantR0<int32>(bit_cast<int32>(1.0f)));
    158   auto floats = builder->BitcastConvertType(bits, xla::F32);
    159 
    160   // We have a floating point number in the range [1.0, 2.0).
    161   // Subtract 1.0f to shift to the range [0.0, 1.0)
    162   floats = builder->Sub(floats, builder->ConstantR0<float>(1.0f));
    163   // Multiply and add to shift to the range [minval, maxval).
    164   floats = builder->Mul(floats, builder->ConstantR0<float>(maxval - minval));
    165   floats = builder->Add(floats, builder->ConstantR0<float>(minval));
    166   return floats;
    167 }
    168 
    169 // Approximation for the inverse error function from
    170 //   Giles, M., "Approximating the erfinv function".
    171 // The approximation has the form:
    172 //   w = -log((1 - x) * (1 + x))
    173 //   if ( w < 5 ) {
    174 //     w = w - 2.5
    175 //     p = sum_{i=1}^n lq[i]*w^i
    176 //   } else {
    177 //     w = sqrt(w) - 3
    178 //     p = sum_{i=1}^n gq[i]*w^i
    179 //   }
    180 //   return p*x
    181 xla::ComputationDataHandle ErfInvF32(xla::ComputationBuilder* b,
    182                                      const xla::ComputationDataHandle& x,
    183                                      const TensorShape& shape) {
    184   constexpr int kDegree = 9;
    185   constexpr std::array<float, 9> w_less_than_5_constants = {
    186       2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
    187       -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
    188       -0.00417768164f,  0.246640727f,    1.50140941f};
    189   constexpr std::array<float, 9> w_greater_than_5_constants = {
    190       -0.000200214257f, 0.000100950558f, 0.00134934322f,
    191       -0.00367342844f,  0.00573950773f,  -0.0076224613f,
    192       0.00943887047f,   1.00167406f,     2.83297682f};
    193 
    194   auto one = b->ConstantR0<float>(1.0);
    195   auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x))));
    196 
    197   auto lt = b->Lt(w, b->ConstantR0<float>(5.0));
    198   auto coefficient = [&](int i) {
    199     return b->Select(
    200         lt,
    201         b->Broadcast(b->ConstantR0<float>(w_less_than_5_constants[i]),
    202                      shape.dim_sizes()),
    203         b->Broadcast(b->ConstantR0<float>(w_greater_than_5_constants[i]),
    204                      shape.dim_sizes()));
    205   };
    206   w = b->Select(lt, b->Sub(w, b->ConstantR0<float>(2.5f)),
    207                 b->Sub(b->SqrtF32(w), b->ConstantR0<float>(3.0f)));
    208   auto p = coefficient(0);
    209   for (int i = 1; i < kDegree; ++i) {
    210     p = b->Add(coefficient(i), b->Mul(p, w));
    211   }
    212   return b->Mul(p, x);
    213 }
    214 
    215 }  // namespace
    216 
    217 class StatelessRandomUniformOp : public XlaOpKernel {
    218  public:
    219   explicit StatelessRandomUniformOp(OpKernelConstruction* ctx)
    220       : XlaOpKernel(ctx) {}
    221 
    222   void Compile(XlaOpKernelContext* ctx) override {
    223     xla::ComputationBuilder* builder = ctx->builder();
    224 
    225     TensorShape shape;
    226     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
    227 
    228     TensorShape seed_shape = ctx->InputShape(1);
    229     OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
    230                 errors::InvalidArgument("seed must have shape [2], not ",
    231                                         seed_shape.DebugString()));
    232     xla::ComputationDataHandle seed = ctx->Input(1);
    233     ctx->SetOutput(0, RandomUniform(builder, seed, shape, 0.0, 1.0));
    234   }
    235 
    236  private:
    237   TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp);
    238 };
    239 
    240 // TODO(phawkins): generalize to non-float, non-int32 seed types.
    241 REGISTER_XLA_OP(Name("StatelessRandomUniform")
    242                     .TypeConstraint("dtype", DT_FLOAT)
    243                     .TypeConstraint("Tseed", DT_INT32),
    244                 StatelessRandomUniformOp);
    245 
    246 class StatelessRandomNormalOp : public XlaOpKernel {
    247  public:
    248   explicit StatelessRandomNormalOp(OpKernelConstruction* ctx)
    249       : XlaOpKernel(ctx) {}
    250 
    251   void Compile(XlaOpKernelContext* ctx) override {
    252     TensorShape shape;
    253     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
    254 
    255     TensorShape seed_shape = ctx->InputShape(1);
    256     OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
    257                 errors::InvalidArgument("seed must have shape [2], not ",
    258                                         seed_shape.DebugString()));
    259     xla::ComputationDataHandle seed = ctx->Input(1);
    260     xla::ComputationBuilder* builder = ctx->builder();
    261     auto uniform = RandomUniform(builder, seed, shape, -1.0, 1.0);
    262     // Convert uniform distribution to normal distribution by computing
    263     // sqrt(2) * erfinv(x)
    264     auto normal = builder->Mul(builder->ConstantR0<float>(std::sqrt(2.0)),
    265                                ErfInvF32(builder, uniform, shape));
    266     ctx->SetOutput(0, normal);
    267   }
    268 
    269  private:
    270   TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp);
    271 };
    272 
    273 // TODO(phawkins): generalize to non-float, non-int32 seed types.
    274 REGISTER_XLA_OP(Name("StatelessRandomNormal")
    275                     .TypeConstraint("dtype", DT_FLOAT)
    276                     .TypeConstraint("Tseed", DT_INT32),
    277                 StatelessRandomNormalOp);
    278 
    279 }  // namespace tensorflow
    280