Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
     17 #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
     18 #include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
     19 #include "tensorflow/compiler/tf2xla/shape_util.h"
     20 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     22 #include "tensorflow/compiler/xla/client/computation_builder.h"
     23 #include "tensorflow/compiler/xla/literal_util.h"
     24 #include "tensorflow/core/framework/kernel_def_builder.h"
     25 #include "tensorflow/core/framework/types.h"
     26 #include "tensorflow/core/kernels/bounds_check.h"
     27 #include "tensorflow/core/kernels/no_op.h"
     28 
     29 namespace tensorflow {
     30 namespace {
     31 
     32 class VarIsInitializedOp : public XlaOpKernel {
     33  public:
     34   explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
     35   void Compile(XlaOpKernelContext* ctx) override {
     36     XlaResource* variable;
     37     OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable));
     38     ctx->SetOutput(0,
     39                    ctx->builder()->ConstantR0<bool>(variable->initialized()));
     40   }
     41 };
     42 REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp);
     43 
     44 class ReadVariableOp : public XlaOpKernel {
     45  public:
     46   explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
     47     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
     48   }
     49 
     50   void Compile(XlaOpKernelContext* ctx) override {
     51     xla::ComputationDataHandle handle;
     52     OP_REQUIRES_OK(
     53         ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle));
     54     ctx->SetOutput(0, handle);
     55   }
     56 
     57  private:
     58   DataType dtype_;
     59 };
     60 REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp);
     61 
     62 class AssignVariableOp : public XlaOpKernel {
     63  public:
     64   explicit AssignVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
     65   void Compile(XlaOpKernelContext* ctx) override {
     66     OP_REQUIRES_OK(ctx,
     67                    ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1)));
     68   }
     69 };
     70 REGISTER_XLA_OP(Name("AssignVariableOp"), AssignVariableOp);
     71 
     72 class AssignAddVariableOp : public XlaOpKernel {
     73  public:
     74   explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
     75   void Compile(XlaOpKernelContext* ctx) override {
     76     DataType type = ctx->input_type(1);
     77     xla::ComputationDataHandle handle;
     78     OP_REQUIRES_OK(ctx,
     79                    ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
     80     handle = ctx->builder()->Add(handle, ctx->Input(1));
     81     OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
     82   }
     83 };
     84 REGISTER_XLA_OP(
     85     Name("AssignAddVariableOp").TypeConstraint("dtype", kNumericTypes),
     86     AssignAddVariableOp);
     87 
     88 class AssignSubVariableOp : public XlaOpKernel {
     89  public:
     90   explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
     91   void Compile(XlaOpKernelContext* ctx) override {
     92     DataType type = ctx->input_type(1);
     93     xla::ComputationDataHandle handle;
     94     OP_REQUIRES_OK(ctx,
     95                    ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
     96     handle = ctx->builder()->Sub(handle, ctx->Input(1));
     97     OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
     98   }
     99 };
    100 REGISTER_XLA_OP(
    101     Name("AssignSubVariableOp").TypeConstraint("dtype", kNumericTypes),
    102     AssignSubVariableOp);
    103 
    104 class ResourceGatherOp : public XlaOpKernel {
    105  public:
    106   explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
    107   void Compile(XlaOpKernelContext* ctx) override {
    108     xla::ComputationBuilder* builder = ctx->builder();
    109 
    110     DataType type = ctx->expected_output_dtype(0);
    111 
    112     TensorShape resource_shape;
    113     xla::ComputationDataHandle resource_handle;
    114     OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
    115                                                &resource_handle));
    116 
    117     auto indices = ctx->Input(1);
    118     auto indices_shape = ctx->InputShape(1);
    119     DataType index_type = ctx->input_type(1);
    120     xla::ComputationDataHandle gather;
    121     OP_REQUIRES_OK(
    122         ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape,
    123                        /*axis=*/0, /*indices_are_nd=*/false, type, index_type,
    124                        builder, &gather));
    125     ctx->SetOutput(0, gather);
    126   }
    127 };
    128 REGISTER_XLA_OP(Name("ResourceGather").TypeConstraint("dtype", kNumericTypes),
    129                 ResourceGatherOp);
    130 
    131 class VariableShapeOp : public XlaOpKernel {
    132  public:
    133   explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    134     OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
    135   }
    136 
    137   void Compile(XlaOpKernelContext* ctx) override {
    138     DataType variable_dtype;
    139     TensorShape shape;
    140     OP_REQUIRES_OK(ctx,
    141                    ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape));
    142     Tensor shape_constant(out_dtype_, TensorShape({shape.dims()}));
    143     OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant));
    144     ctx->SetConstantOutput(0, shape_constant);
    145   }
    146 
    147  private:
    148   DataType out_dtype_;
    149 };
    150 
    151 REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp);
    152 }  // namespace
    153 }  // namespace tensorflow
    154