// (code-browser navigation header: Home | History | Annotate | Download | only in kernels)
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // XLA-specific sequence and range Ops.
     17 
     18 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     19 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     20 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     21 #include "tensorflow/compiler/xla/literal_util.h"
     22 #include "tensorflow/core/framework/op_kernel.h"
     23 #include "tensorflow/core/framework/register_types.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/framework/tensor_shape.h"
     26 #include "tensorflow/core/framework/types.h"
     27 
     28 namespace tensorflow {
     29 namespace {
     30 
     31 template <typename T>
     32 Status GetValue(int index, XlaOpKernelContext* ctx, T* value) {
     33   xla::Literal literal;
     34   TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal));
     35   *value = literal.Get<T>({});
     36   return Status::OK();
     37 }
     38 
     39 Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) {
     40   xla::Literal literal;
     41   TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal));
     42   switch (literal.shape().element_type()) {
     43     case xla::S32:
     44       *value = literal.Get<int32>({});
     45       break;
     46     case xla::S64:
     47       *value = literal.Get<int64>({});
     48       break;
     49     default:
     50       return errors::InvalidArgument("Invalid argument type for argument",
     51                                      index);
     52   }
     53   return Status::OK();
     54 }
     55 
     56 // The type-specific part of the implementation of Range.
     57 template <typename T>
     58 Status CreateRangeTensor(const xla::Literal& start_literal,
     59                          const xla::Literal& limit_literal,
     60                          const xla::Literal& delta_literal, Tensor* output) {
     61   T start = start_literal.Get<T>({});
     62   T limit = limit_literal.Get<T>({});
     63   T delta = delta_literal.Get<T>({});
     64 
     65   if (delta == 0) {
     66     return errors::InvalidArgument("Requires delta != 0: ", delta);
     67   }
     68   if (delta > 0) {
     69     if (start > limit) {
     70       return errors::InvalidArgument("Requires start <= limit when delta > 0: ",
     71                                      start, "/", limit);
     72     }
     73   } else {
     74     if (start < limit) {
     75       return errors::InvalidArgument("Requires start >= limit when delta < 0: ",
     76                                      start, "/", limit);
     77     }
     78   }
     79   int64 size =
     80       (std::is_integral<T>::value
     81            ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
     82            : std::ceil(std::abs((limit - start) / delta)));
     83 
     84   *output = Tensor(DataTypeToEnum<T>::v(), TensorShape({size}));
     85   auto flat = output->flat<T>();
     86   T val = start;
     87   for (int64 i = 0; i < size; ++i) {
     88     flat(i) = val;
     89     val += delta;
     90   }
     91   return Status::OK();
     92 }
     93 
     94 class RangeOp : public XlaOpKernel {
     95  public:
     96   explicit RangeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
     97 
     98   void Compile(XlaOpKernelContext* ctx) override {
     99     const TensorShape start_in_shape = ctx->InputShape(0);
    100     const TensorShape limit_in_shape = ctx->InputShape(1);
    101     const TensorShape delta_in_shape = ctx->InputShape(2);
    102     OP_REQUIRES(ctx, IsLegacyScalar(start_in_shape),
    103                 errors::InvalidArgument("start must be a scalar, not shape ",
    104                                         start_in_shape.DebugString()));
    105     OP_REQUIRES(ctx, IsLegacyScalar(limit_in_shape),
    106                 errors::InvalidArgument("limit must be a scalar, not shape ",
    107                                         limit_in_shape.DebugString()));
    108     OP_REQUIRES(ctx, IsLegacyScalar(delta_in_shape),
    109                 errors::InvalidArgument("delta must be a scalar, not shape ",
    110                                         delta_in_shape.DebugString()));
    111     xla::Literal start, limit, delta;
    112     OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &start));
    113     OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &limit));
    114     OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta));
    115 
    116     DataType type = input_type(0);
    117     Tensor output;
    118     Status status;
    119     switch (type) {
    120       case DT_INT32:
    121         status = CreateRangeTensor<int32>(start, limit, delta, &output);
    122         break;
    123       case DT_INT64:
    124         status = CreateRangeTensor<int64>(start, limit, delta, &output);
    125         break;
    126       case DT_FLOAT:
    127         status = CreateRangeTensor<float>(start, limit, delta, &output);
    128         break;
    129       case DT_DOUBLE:
    130         status = CreateRangeTensor<double>(start, limit, delta, &output);
    131         break;
    132       default:
    133         status = errors::InvalidArgument("Invalid type for Range ",
    134                                          DataTypeString(type));
    135     }
    136     OP_REQUIRES_OK(ctx, status);
    137     ctx->SetConstantOutput(0, output);
    138   }
    139 };
    140 
// Range is evaluated entirely at compile time, so all of its inputs must be
// compile-time constants.
REGISTER_XLA_OP(Name("Range")
                    .CompileTimeConstInput("start")
                    .CompileTimeConstInput("limit")
                    .CompileTimeConstInput("delta"),
                RangeOp);
    146 
    147 class LinSpaceOp : public XlaOpKernel {
    148  public:
    149   explicit LinSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
    150 
    151   void Compile(XlaOpKernelContext* ctx) override {
    152     const TensorShape start_in_shape = ctx->InputShape(0);
    153     const TensorShape stop_in_shape = ctx->InputShape(1);
    154     const TensorShape num_in_shape = ctx->InputShape(2);
    155     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape),
    156                 errors::InvalidArgument("start must be a scalar, not shape ",
    157                                         start_in_shape.DebugString()));
    158     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(stop_in_shape),
    159                 errors::InvalidArgument("stop must be a scalar, not shape ",
    160                                         stop_in_shape.DebugString()));
    161     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(num_in_shape),
    162                 errors::InvalidArgument("num must be a scalar, not shape ",
    163                                         num_in_shape.DebugString()));
    164 
    165     DataType type = ctx->input_type(0);
    166 
    167     int64 num;
    168     OP_REQUIRES_OK(ctx, GetIntValue(2, ctx, &num));
    169     OP_REQUIRES(ctx, num > 0,
    170                 errors::InvalidArgument("Requires num > 0: ", num));
    171     Tensor out_constant(type, TensorShape({num}));
    172 
    173     switch (type) {
    174       case DT_FLOAT: {
    175         float start, stop;
    176         OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start));
    177         OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop));
    178         auto flat = out_constant.flat<float>();
    179         if (num == 1) {
    180           flat(0) = start;
    181         } else {
    182           const float step = (stop - start) / (num - 1);
    183           for (int64 i = 0; i < num; ++i) {
    184             flat(i) = start + step * i;
    185           }
    186         }
    187         break;
    188       }
    189       case DT_DOUBLE: {
    190         double start, stop;
    191         OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start));
    192         OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop));
    193         auto flat = out_constant.flat<double>();
    194         if (num == 1) {
    195           flat(0) = start;
    196         } else {
    197           const double step = (stop - start) / (num - 1);
    198           for (int64 i = 0; i < num; ++i) {
    199             flat(i) = start + step * i;
    200           }
    201         }
    202         break;
    203       }
    204 
    205       default:
    206         ctx->SetStatus(errors::InvalidArgument("Invalid argument type ",
    207                                                DataTypeString(type)));
    208         return;
    209     }
    210     ctx->SetConstantOutput(0, out_constant);
    211   }
    212 };
    213 
// LinSpace is evaluated entirely at compile time, so all of its inputs must
// be compile-time constants.
REGISTER_XLA_OP(Name("LinSpace")
                    .CompileTimeConstInput("start")
                    .CompileTimeConstInput("stop")
                    .CompileTimeConstInput("num"),
                LinSpaceOp);
    219 
    220 }  // namespace
    221 }  // namespace tensorflow
    222