Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // XLA implementations of Random ops
     17 // TODO(misard,phawkins): handle random number generator seeds/states correctly.
     18 // TODO(misard,phawkins): add tests.
     19 
     20 #include "tensorflow/compiler/tf2xla/shape_util.h"
     21 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     24 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
     25 #include "tensorflow/core/framework/op_kernel.h"
     26 #include "tensorflow/core/framework/tensor.h"
     27 #include "tensorflow/core/framework/tensor_shape.h"
     28 
     29 namespace tensorflow {
     30 namespace {
     31 
     32 class RandomUniformOp : public XlaOpKernel {
     33  public:
     34   explicit RandomUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
     35 
     36   void Compile(XlaOpKernelContext* ctx) override {
     37     TensorShape shape;
     38     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
     39 
     40     const DataType dtype = output_type(0);
     41     xla::Shape xla_shape;
     42     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
     43 
     44     xla::ComputationBuilder* b = ctx->builder();
     45     xla::ComputationDataHandle result = b->RngUniform(
     46         XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape);
     47 
     48     ctx->SetOutput(0, result);
     49   }
     50 
     51  private:
     52   TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp);
     53 };
     54 
// "shape" is marked compile-time constant so the XLA output shape is static.
REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"),
                RandomUniformOp);
     57 
     58 class RandomUniformIntOp : public XlaOpKernel {
     59  public:
     60   explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
     61 
     62   void Compile(XlaOpKernelContext* ctx) override {
     63     TensorShape shape;
     64     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
     65     xla::Shape xla_shape;
     66     OP_REQUIRES_OK(ctx,
     67                    TensorShapeToXLAShape(input_type(1), shape, &xla_shape));
     68 
     69     const TensorShape minval_shape = ctx->InputShape(1);
     70     const TensorShape maxval_shape = ctx->InputShape(2);
     71     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
     72                 errors::InvalidArgument("minval must be 0-D, got shape ",
     73                                         minval_shape.DebugString()));
     74     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
     75                 errors::InvalidArgument("maxval must be 0-D, got shape ",
     76                                         maxval_shape.DebugString()));
     77 
     78     auto minval = ctx->Input(1);
     79     auto maxval = ctx->Input(2);
     80     ctx->SetOutput(0, ctx->builder()->RngUniform(minval, maxval, xla_shape));
     81   }
     82 
     83  private:
     84   TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp);
     85 };
     86 
// "shape" is marked compile-time constant so the XLA output shape is static.
REGISTER_XLA_OP(Name("RandomUniformInt").CompileTimeConstInput("shape"),
                RandomUniformIntOp);
     89 
     90 class RandomStandardNormalOp : public XlaOpKernel {
     91  public:
     92   explicit RandomStandardNormalOp(OpKernelConstruction* ctx)
     93       : XlaOpKernel(ctx) {}
     94 
     95   void Compile(XlaOpKernelContext* ctx) override {
     96     const DataType dtype = output_type(0);
     97 
     98     TensorShape shape;
     99     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
    100     xla::Shape xla_shape;
    101     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
    102 
    103     xla::ComputationBuilder* b = ctx->builder();
    104 
    105     // Normal distribution with a mean of 0 and a standard deviation of 1:
    106     xla::ComputationDataHandle result = b->RngNormal(
    107         XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape);
    108 
    109     ctx->SetOutput(0, result);
    110   }
    111 
    112  private:
    113   TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp);
    114 };
    115 
// "shape" is marked compile-time constant so the XLA output shape is static.
REGISTER_XLA_OP(Name("RandomStandardNormal").CompileTimeConstInput("shape"),
                RandomStandardNormalOp);
    118 
    119 class TruncatedNormalOp : public XlaOpKernel {
    120  public:
    121   explicit TruncatedNormalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
    122 
    123   void Compile(XlaOpKernelContext* ctx) override {
    124     const DataType dtype = output_type(0);
    125 
    126     TensorShape shape;
    127     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
    128     xla::Shape xla_shape;
    129     OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
    130     xla::Shape xla_element_shape =
    131         xla::ShapeUtil::MakeShape(xla_shape.element_type(), {});
    132 
    133     xla::ComputationBuilder* b = ctx->builder();
    134     xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype);
    135     xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype);
    136     xla::ComputationDataHandle candidate =
    137         b->RngNormal(mean, stddev, xla_shape);
    138 
    139     auto two_sd = [dtype](bool negate, xla::ComputationBuilder* b) {
    140       return XlaHelpers::FloatLiteral(b, dtype, negate ? -2.0 : 2.0);
    141     };
    142     auto out_of_range_mask = [two_sd](xla::ComputationDataHandle candidate,
    143                                       xla::ComputationBuilder* b) {
    144       xla::ComputationDataHandle too_large = b->Gt(candidate, two_sd(false, b));
    145       xla::ComputationDataHandle too_small = b->Lt(candidate, two_sd(true, b));
    146       return b->Or(too_large, too_small);
    147     };
    148 
    149     // The algorithm we're using is roughly:
    150     //
    151     // while (any(candidate < mean-2*sd || candidate > mean+2*sd)) {
    152     //   out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd
    153     //   candidate = select(out_of_range_mask, rng_normal(), candidate)
    154     // }
    155     std::unique_ptr<xla::ComputationBuilder> test_builder =
    156         b->CreateSubBuilder("truncated_normal_test");
    157     {
    158       auto* b = test_builder.get();
    159       xla::ComputationDataHandle candidate =
    160           b->Parameter(0, xla_shape, "candidate");
    161       xla::ComputationDataHandle oor_mask = out_of_range_mask(candidate, b);
    162       OP_REQUIRES_OK(ctx, Any(out_of_range_mask(candidate, b), b).status());
    163     }
    164 
    165     std::unique_ptr<xla::ComputationBuilder> body_builder =
    166         b->CreateSubBuilder("truncated_normal_body");
    167     {
    168       auto* b = body_builder.get();
    169       xla::ComputationDataHandle candidate =
    170           b->Parameter(0, xla_shape, "candidate");
    171       xla::ComputationDataHandle to_resample = out_of_range_mask(candidate, b);
    172       xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype);
    173       xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype);
    174       b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), candidate);
    175     }
    176 
    177     xla::StatusOr<xla::Computation> test_computation = test_builder->Build();
    178     OP_REQUIRES_OK(ctx, test_computation.status());
    179     xla::StatusOr<xla::Computation> body_computation = body_builder->Build();
    180     OP_REQUIRES_OK(ctx, body_computation.status());
    181     xla::ComputationDataHandle result =
    182         b->While(test_computation.ValueOrDie(), body_computation.ValueOrDie(),
    183                  candidate);
    184 
    185     ctx->SetOutput(0, result);
    186   }
    187 };
    188 
// "shape" is marked compile-time constant so the XLA output shape is static.
REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"),
                TruncatedNormalOp);
    191 
    192 }  // anonymous namespace
    193 }  // namespace tensorflow
    194