1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA implementations of Random ops 17 // TODO(misard,phawkins): handle random number generator seeds/states correctly. 18 // TODO(misard,phawkins): add tests. 19 20 #include "tensorflow/compiler/tf2xla/shape_util.h" 21 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 24 #include "tensorflow/compiler/xla/client/lib/arithmetic.h" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor_shape.h" 28 29 namespace tensorflow { 30 namespace { 31 32 class RandomUniformOp : public XlaOpKernel { 33 public: 34 explicit RandomUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 35 36 void Compile(XlaOpKernelContext* ctx) override { 37 TensorShape shape; 38 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); 39 40 const DataType dtype = output_type(0); 41 xla::Shape xla_shape; 42 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); 43 44 xla::ComputationBuilder* b = ctx->builder(); 45 xla::ComputationDataHandle result = b->RngUniform( 46 XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape); 47 48 ctx->SetOutput(0, result); 49 } 50 51 
private: 52 TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp); 53 }; 54 55 REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"), 56 RandomUniformOp); 57 58 class RandomUniformIntOp : public XlaOpKernel { 59 public: 60 explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 61 62 void Compile(XlaOpKernelContext* ctx) override { 63 TensorShape shape; 64 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); 65 xla::Shape xla_shape; 66 OP_REQUIRES_OK(ctx, 67 TensorShapeToXLAShape(input_type(1), shape, &xla_shape)); 68 69 const TensorShape minval_shape = ctx->InputShape(1); 70 const TensorShape maxval_shape = ctx->InputShape(2); 71 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape), 72 errors::InvalidArgument("minval must be 0-D, got shape ", 73 minval_shape.DebugString())); 74 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape), 75 errors::InvalidArgument("maxval must be 0-D, got shape ", 76 maxval_shape.DebugString())); 77 78 auto minval = ctx->Input(1); 79 auto maxval = ctx->Input(2); 80 ctx->SetOutput(0, ctx->builder()->RngUniform(minval, maxval, xla_shape)); 81 } 82 83 private: 84 TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp); 85 }; 86 87 REGISTER_XLA_OP(Name("RandomUniformInt").CompileTimeConstInput("shape"), 88 RandomUniformIntOp); 89 90 class RandomStandardNormalOp : public XlaOpKernel { 91 public: 92 explicit RandomStandardNormalOp(OpKernelConstruction* ctx) 93 : XlaOpKernel(ctx) {} 94 95 void Compile(XlaOpKernelContext* ctx) override { 96 const DataType dtype = output_type(0); 97 98 TensorShape shape; 99 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); 100 xla::Shape xla_shape; 101 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); 102 103 xla::ComputationBuilder* b = ctx->builder(); 104 105 // Normal distribution with a mean of 0 and a standard deviation of 1: 106 xla::ComputationDataHandle result = b->RngNormal( 107 XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), 
xla_shape); 108 109 ctx->SetOutput(0, result); 110 } 111 112 private: 113 TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp); 114 }; 115 116 REGISTER_XLA_OP(Name("RandomStandardNormal").CompileTimeConstInput("shape"), 117 RandomStandardNormalOp); 118 119 class TruncatedNormalOp : public XlaOpKernel { 120 public: 121 explicit TruncatedNormalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 122 123 void Compile(XlaOpKernelContext* ctx) override { 124 const DataType dtype = output_type(0); 125 126 TensorShape shape; 127 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); 128 xla::Shape xla_shape; 129 OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); 130 xla::Shape xla_element_shape = 131 xla::ShapeUtil::MakeShape(xla_shape.element_type(), {}); 132 133 xla::ComputationBuilder* b = ctx->builder(); 134 xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype); 135 xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype); 136 xla::ComputationDataHandle candidate = 137 b->RngNormal(mean, stddev, xla_shape); 138 139 auto two_sd = [dtype](bool negate, xla::ComputationBuilder* b) { 140 return XlaHelpers::FloatLiteral(b, dtype, negate ? 
-2.0 : 2.0); 141 }; 142 auto out_of_range_mask = [two_sd](xla::ComputationDataHandle candidate, 143 xla::ComputationBuilder* b) { 144 xla::ComputationDataHandle too_large = b->Gt(candidate, two_sd(false, b)); 145 xla::ComputationDataHandle too_small = b->Lt(candidate, two_sd(true, b)); 146 return b->Or(too_large, too_small); 147 }; 148 149 // The algorithm we're using is roughly: 150 // 151 // while (any(candidate < mean-2*sd || candidate > mean+2*sd)) { 152 // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd 153 // candidate = select(out_of_range_mask, rng_normal(), candidate) 154 // } 155 std::unique_ptr<xla::ComputationBuilder> test_builder = 156 b->CreateSubBuilder("truncated_normal_test"); 157 { 158 auto* b = test_builder.get(); 159 xla::ComputationDataHandle candidate = 160 b->Parameter(0, xla_shape, "candidate"); 161 xla::ComputationDataHandle oor_mask = out_of_range_mask(candidate, b); 162 OP_REQUIRES_OK(ctx, Any(out_of_range_mask(candidate, b), b).status()); 163 } 164 165 std::unique_ptr<xla::ComputationBuilder> body_builder = 166 b->CreateSubBuilder("truncated_normal_body"); 167 { 168 auto* b = body_builder.get(); 169 xla::ComputationDataHandle candidate = 170 b->Parameter(0, xla_shape, "candidate"); 171 xla::ComputationDataHandle to_resample = out_of_range_mask(candidate, b); 172 xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype); 173 xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype); 174 b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), candidate); 175 } 176 177 xla::StatusOr<xla::Computation> test_computation = test_builder->Build(); 178 OP_REQUIRES_OK(ctx, test_computation.status()); 179 xla::StatusOr<xla::Computation> body_computation = body_builder->Build(); 180 OP_REQUIRES_OK(ctx, body_computation.status()); 181 xla::ComputationDataHandle result = 182 b->While(test_computation.ValueOrDie(), body_computation.ValueOrDie(), 183 candidate); 184 185 ctx->SetOutput(0, result); 186 } 187 }; 188 189 
// "shape" must be a compile-time constant so the XLA shape can be fixed at
// compile time, consistent with the other random-op registrations above.
REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"),
                TruncatedNormalOp);

}  // anonymous namespace
}  // namespace tensorflow