1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" 17 #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" 18 #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" 19 #include "tensorflow/compiler/tf2xla/shape_util.h" 20 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 22 #include "tensorflow/compiler/xla/client/computation_builder.h" 23 #include "tensorflow/compiler/xla/literal_util.h" 24 #include "tensorflow/core/framework/kernel_def_builder.h" 25 #include "tensorflow/core/framework/types.h" 26 #include "tensorflow/core/kernels/bounds_check.h" 27 #include "tensorflow/core/kernels/no_op.h" 28 29 namespace tensorflow { 30 namespace { 31 32 class VarIsInitializedOp : public XlaOpKernel { 33 public: 34 explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 35 void Compile(XlaOpKernelContext* ctx) override { 36 XlaResource* variable; 37 OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable)); 38 ctx->SetOutput(0, 39 ctx->builder()->ConstantR0<bool>(variable->initialized())); 40 } 41 }; 42 REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp); 43 44 class ReadVariableOp : public XlaOpKernel { 45 public: 46 explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 47 OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); 48 } 49 50 void Compile(XlaOpKernelContext* ctx) override { 51 xla::ComputationDataHandle handle; 52 OP_REQUIRES_OK( 53 ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle)); 54 ctx->SetOutput(0, handle); 55 } 56 57 private: 58 DataType dtype_; 59 }; 60 REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp); 61 62 class AssignVariableOp : public XlaOpKernel { 63 public: 64 explicit AssignVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 65 void Compile(XlaOpKernelContext* ctx) override { 66 OP_REQUIRES_OK(ctx, 67 ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1))); 68 } 69 }; 70 REGISTER_XLA_OP(Name("AssignVariableOp"), AssignVariableOp); 71 72 class AssignAddVariableOp : public XlaOpKernel { 73 public: 74 explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 75 void Compile(XlaOpKernelContext* ctx) override { 76 DataType type = ctx->input_type(1); 77 xla::ComputationDataHandle handle; 78 OP_REQUIRES_OK(ctx, 79 ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); 80 handle = ctx->builder()->Add(handle, ctx->Input(1)); 81 OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); 82 } 83 }; 84 REGISTER_XLA_OP( 85 Name("AssignAddVariableOp").TypeConstraint("dtype", kNumericTypes), 86 AssignAddVariableOp); 87 88 class AssignSubVariableOp : public XlaOpKernel { 89 public: 90 explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 91 void Compile(XlaOpKernelContext* ctx) override { 92 DataType type = ctx->input_type(1); 93 xla::ComputationDataHandle handle; 94 OP_REQUIRES_OK(ctx, 95 ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); 96 handle = ctx->builder()->Sub(handle, ctx->Input(1)); 97 OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); 98 } 99 }; 100 REGISTER_XLA_OP( 101 Name("AssignSubVariableOp").TypeConstraint("dtype", kNumericTypes), 102 AssignSubVariableOp); 103 104 class ResourceGatherOp : public XlaOpKernel { 105 public: 106 explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 107 void Compile(XlaOpKernelContext* ctx) override { 108 xla::ComputationBuilder* builder = ctx->builder(); 109 110 DataType type = ctx->expected_output_dtype(0); 111 112 TensorShape resource_shape; 113 xla::ComputationDataHandle resource_handle; 114 OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape, 115 &resource_handle)); 116 117 auto indices = ctx->Input(1); 118 auto indices_shape = ctx->InputShape(1); 119 DataType index_type = ctx->input_type(1); 120 xla::ComputationDataHandle gather; 121 OP_REQUIRES_OK( 122 ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape, 123 /*axis=*/0, /*indices_are_nd=*/false, type, index_type, 124 builder, &gather)); 125 ctx->SetOutput(0, gather); 126 } 127 }; 128 REGISTER_XLA_OP(Name("ResourceGather").TypeConstraint("dtype", kNumericTypes), 129 ResourceGatherOp); 130 131 class VariableShapeOp : public XlaOpKernel { 132 public: 133 explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 134 OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); 135 } 136 137 void Compile(XlaOpKernelContext* ctx) override { 138 DataType variable_dtype; 139 TensorShape shape; 140 OP_REQUIRES_OK(ctx, 141 ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape)); 142 Tensor shape_constant(out_dtype_, TensorShape({shape.dims()})); 143 OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant)); 144 ctx->SetConstantOutput(0, shape_constant); 145 } 146 147 private: 148 DataType out_dtype_; 149 }; 150 151 REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); 152 } // namespace 153 } // namespace tensorflow 154