1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA-specific sequence and range Ops. 17 18 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 19 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 20 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 21 #include "tensorflow/compiler/xla/literal_util.h" 22 #include "tensorflow/core/framework/op_kernel.h" 23 #include "tensorflow/core/framework/register_types.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/framework/tensor_shape.h" 26 #include "tensorflow/core/framework/types.h" 27 28 namespace tensorflow { 29 namespace { 30 31 template <typename T> 32 Status GetValue(int index, XlaOpKernelContext* ctx, T* value) { 33 xla::Literal literal; 34 TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); 35 *value = literal.Get<T>({}); 36 return Status::OK(); 37 } 38 39 Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { 40 xla::Literal literal; 41 TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); 42 switch (literal.shape().element_type()) { 43 case xla::S32: 44 *value = literal.Get<int32>({}); 45 break; 46 case xla::S64: 47 *value = literal.Get<int64>({}); 48 break; 49 default: 50 return errors::InvalidArgument("Invalid argument type for argument", 51 index); 52 } 53 return Status::OK(); 54 } 55 56 // The type-specific part of the implementation of Range. 57 template <typename T> 58 Status CreateRangeTensor(const xla::Literal& start_literal, 59 const xla::Literal& limit_literal, 60 const xla::Literal& delta_literal, Tensor* output) { 61 T start = start_literal.Get<T>({}); 62 T limit = limit_literal.Get<T>({}); 63 T delta = delta_literal.Get<T>({}); 64 65 if (delta == 0) { 66 return errors::InvalidArgument("Requires delta != 0: ", delta); 67 } 68 if (delta > 0) { 69 if (start > limit) { 70 return errors::InvalidArgument("Requires start <= limit when delta > 0: ", 71 start, "/", limit); 72 } 73 } else { 74 if (start < limit) { 75 return errors::InvalidArgument("Requires start >= limit when delta < 0: ", 76 start, "/", limit); 77 } 78 } 79 int64 size = 80 (std::is_integral<T>::value 81 ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta)) 82 : std::ceil(std::abs((limit - start) / delta))); 83 84 *output = Tensor(DataTypeToEnum<T>::v(), TensorShape({size})); 85 auto flat = output->flat<T>(); 86 T val = start; 87 for (int64 i = 0; i < size; ++i) { 88 flat(i) = val; 89 val += delta; 90 } 91 return Status::OK(); 92 } 93 94 class RangeOp : public XlaOpKernel { 95 public: 96 explicit RangeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 97 98 void Compile(XlaOpKernelContext* ctx) override { 99 const TensorShape start_in_shape = ctx->InputShape(0); 100 const TensorShape limit_in_shape = ctx->InputShape(1); 101 const TensorShape delta_in_shape = ctx->InputShape(2); 102 OP_REQUIRES(ctx, IsLegacyScalar(start_in_shape), 103 errors::InvalidArgument("start must be a scalar, not shape ", 104 start_in_shape.DebugString())); 105 OP_REQUIRES(ctx, IsLegacyScalar(limit_in_shape), 106 errors::InvalidArgument("limit must be a scalar, not shape ", 107 limit_in_shape.DebugString())); 108 OP_REQUIRES(ctx, IsLegacyScalar(delta_in_shape), 109 errors::InvalidArgument("delta must be a scalar, not shape ", 110 delta_in_shape.DebugString())); 111 xla::Literal start, limit, delta; 112 OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &start)); 113 OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &limit)); 114 OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta)); 115 116 DataType type = input_type(0); 117 Tensor output; 118 Status status; 119 switch (type) { 120 case DT_INT32: 121 status = CreateRangeTensor<int32>(start, limit, delta, &output); 122 break; 123 case DT_INT64: 124 status = CreateRangeTensor<int64>(start, limit, delta, &output); 125 break; 126 case DT_FLOAT: 127 status = CreateRangeTensor<float>(start, limit, delta, &output); 128 break; 129 case DT_DOUBLE: 130 status = CreateRangeTensor<double>(start, limit, delta, &output); 131 break; 132 default: 133 status = errors::InvalidArgument("Invalid type for Range ", 134 DataTypeString(type)); 135 } 136 OP_REQUIRES_OK(ctx, status); 137 ctx->SetConstantOutput(0, output); 138 } 139 }; 140 141 REGISTER_XLA_OP(Name("Range") 142 .CompileTimeConstInput("start") 143 .CompileTimeConstInput("limit") 144 .CompileTimeConstInput("delta"), 145 RangeOp); 146 147 class LinSpaceOp : public XlaOpKernel { 148 public: 149 explicit LinSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 150 151 void Compile(XlaOpKernelContext* ctx) override { 152 const TensorShape start_in_shape = ctx->InputShape(0); 153 const TensorShape stop_in_shape = ctx->InputShape(1); 154 const TensorShape num_in_shape = ctx->InputShape(2); 155 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape), 156 errors::InvalidArgument("start must be a scalar, not shape ", 157 start_in_shape.DebugString())); 158 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(stop_in_shape), 159 errors::InvalidArgument("stop must be a scalar, not shape ", 160 stop_in_shape.DebugString())); 161 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(num_in_shape), 162 errors::InvalidArgument("num must be a scalar, not shape ", 163 num_in_shape.DebugString())); 164 165 DataType type = ctx->input_type(0); 166 167 int64 num; 168 OP_REQUIRES_OK(ctx, GetIntValue(2, ctx, &num)); 169 OP_REQUIRES(ctx, num > 0, 170 errors::InvalidArgument("Requires num > 0: ", num)); 171 Tensor out_constant(type, TensorShape({num})); 172 173 switch (type) { 174 case DT_FLOAT: { 175 float start, stop; 176 OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); 177 OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop)); 178 auto flat = out_constant.flat<float>(); 179 if (num == 1) { 180 flat(0) = start; 181 } else { 182 const float step = (stop - start) / (num - 1); 183 for (int64 i = 0; i < num; ++i) { 184 flat(i) = start + step * i; 185 } 186 } 187 break; 188 } 189 case DT_DOUBLE: { 190 double start, stop; 191 OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start)); 192 OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop)); 193 auto flat = out_constant.flat<double>(); 194 if (num == 1) { 195 flat(0) = start; 196 } else { 197 const double step = (stop - start) / (num - 1); 198 for (int64 i = 0; i < num; ++i) { 199 flat(i) = start + step * i; 200 } 201 } 202 break; 203 } 204 205 default: 206 ctx->SetStatus(errors::InvalidArgument("Invalid argument type ", 207 DataTypeString(type))); 208 return; 209 } 210 ctx->SetConstantOutput(0, out_constant); 211 } 212 }; 213 214 REGISTER_XLA_OP(Name("LinSpace") 215 .CompileTimeConstInput("start") 216 .CompileTimeConstInput("stop") 217 .CompileTimeConstInput("num"), 218 LinSpaceOp); 219 220 } // namespace 221 } // namespace tensorflow 222