/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Shape Ops.

#include <limits>
#include <unordered_set>
#include <vector>

#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace {

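// Shape: XLA tensor shapes are fully known at compile time, so the shape
// vector is emitted as a compile-time constant rather than computed at
// runtime. For example, a float32 input of shape [2, 3] yields the
// constant [2, 3] in `out_type`.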
class ShapeOp : public XlaOpKernel {
 public:
  explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
    OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
    ctx->SetConstantOutput(0, shape_constant);
  }

 private:
  DataType out_dtype_;
};

REGISTER_XLA_OP(Name("Shape"), ShapeOp);

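// ShapeN: the N-input variant of Shape; emits one shape constant per input.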
class ShapeNOp : public XlaOpKernel {
 public:
  explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      const TensorShape input_shape = ctx->InputShape(i);
      Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
      OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
      ctx->SetConstantOutput(i, shape_constant);
    }
  }

  // This kernel does no real work at runtime (its outputs are compile-time
  // constants), so mark it as cheap for the scheduler.
  bool IsExpensive() override { return false; }

 private:
  DataType out_dtype_;
};
REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp);

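// Rank: emits the input's rank (number of dimensions) as a scalar int32
// constant.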
class RankOp : public XlaOpKernel {
 public:
  explicit RankOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    const int rank = input_shape.dims();
    Tensor rank_constant(DT_INT32, TensorShape({}));
    rank_constant.scalar<int32>()() = rank;

    ctx->SetConstantOutput(0, rank_constant);
  }
};

REGISTER_XLA_OP(Name("Rank"), RankOp);

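// Size: emits the input's total element count as a scalar int32 constant.
// Because the output type is int32, inputs with more than int32-max
// elements are rejected.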
class SizeOp : public XlaOpKernel {
 public:
  explicit SizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    const int64 size = input_shape.num_elements();
    OP_REQUIRES(ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()),
                errors::InvalidArgument("Size does not work for tensors > "
                                        "int32 max."));
    Tensor size_constant(DT_INT32, TensorShape({}));
    size_constant.scalar<int32>()() = static_cast<int32>(size);

    ctx->SetConstantOutput(0, size_constant);
  }
};

REGISTER_XLA_OP(Name("Size"), SizeOp);

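// ExpandDims: inserts a dimension of size 1 at axis `dim`, which must be a
// compile-time constant (see the CompileTimeConstInput registration below).
// Negative axes count from the end: for an input of shape [2, 3], dim = 0
// produces [1, 2, 3] and dim = -1 produces [2, 3, 1].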
class ExpandDimsOp : public XlaOpKernel {
 public:
  explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    const TensorShape dim_shape = ctx->InputShape(1);

    // TODO(phawkins): the standard implementation of ExpandDimsOp seems to
    // accept legacy scalars, even when they should be forbidden by the
    // graphdef version.
    OP_REQUIRES(ctx, dim_shape.num_elements() == 1,
                errors::InvalidArgument(strings::StrCat(
                    "dim input to ExpandDims must be a scalar; got ",
                    dim_shape.DebugString())));

    xla::Literal literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal));

    int dim = literal.data<int32>()[0];

    OP_REQUIRES(ctx,
                (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()),
                errors::InvalidArgument("Tried to expand dim index ", dim,
                                        " for tensor with ", input_shape.dims(),
                                        " dimensions."));

    auto existing_dims = input_shape.dim_sizes();
    // Safe: the number of dimensions of a tensor is small, so the cast to
    // int cannot overflow.
    const int existing_dims_size = static_cast<int>(existing_dims.size());
    std::vector<int64> new_shape(existing_dims_size);
    for (size_t i = 0; i < new_shape.size(); ++i) {
      new_shape[i] = existing_dims[i];
    }

    // We emulate numpy's interpretation of the dim axis when
    // -input.dims() - 1 <= dim <= input.dims().
    if (dim < 0) {
      dim += existing_dims.size() + 1;
    }

    // Clamp to the end if needed.
    dim = std::min<int32>(dim, existing_dims_size);
    new_shape.emplace(new_shape.begin() + dim, 1);

    ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape));
  }
};
REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp);

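// Squeeze: removes dimensions of size 1. If the squeeze_dims attribute is
// non-empty, only those axes are removed (each must have size 1); otherwise
// all size-1 dimensions are dropped. For example, squeezing shape
// [1, 2, 1, 3] with no squeeze_dims yields [2, 3].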
class SqueezeOp : public XlaOpKernel {
 public:
  explicit SqueezeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    std::vector<int32> squeeze_dims;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
    squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    auto existing_dims = input_shape.dim_sizes();
    int existing_dims_size = input_shape.dims();
    std::vector<int64> new_shape;

    std::unordered_set<int32> wrapped_squeeze_dims;
    wrapped_squeeze_dims.reserve(squeeze_dims_.size());
    // Validate squeeze dims against the input.
    for (int32 dim : squeeze_dims_) {
      OP_REQUIRES(ctx, (dim >= -input_shape.dims() && dim < input_shape.dims()),
                  errors::InvalidArgument("Tried to squeeze dim index ", dim,
                                          " for tensor with ",
                                          input_shape.dims(), " dimensions."));
      // If dim is < 0, we wrap around (-1 means the last element).
      if (dim < 0) {
        dim = existing_dims_size + dim;
      }

      wrapped_squeeze_dims.insert(dim);
    }

    for (int i = 0; i < existing_dims_size; ++i) {
      auto existing_dim = existing_dims[i];

      // If squeeze_dims was specified, only squeeze those dimensions.
      if (!wrapped_squeeze_dims.empty()) {
        if (wrapped_squeeze_dims.count(i) > 0) {
          OP_REQUIRES(ctx, existing_dim == 1,
                      errors::InvalidArgument("Tried to explicitly squeeze "
                                              "dimension ",
                                              i, " but dimension was not 1: ",
                                              existing_dim));
        } else {
          // This dimension is not being squeezed.
          new_shape.push_back(existing_dim);
        }
      } else {
        // Copy over all non-1-length dimensions.
        if (existing_dim != 1) {
          new_shape.push_back(existing_dim);
        }
      }
    }

    ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape));
  }

 private:
  std::unordered_set<int32> squeeze_dims_;
};

REGISTER_XLA_OP(Name("Squeeze"), SqueezeOp);

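// ZerosLike: broadcasts a scalar zero of the input's element type to the
// input's (statically known) shape.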
class ZerosLikeOp : public XlaOpKernel {
 public:
  explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);

    auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
    ctx->SetOutput(0, ctx->builder()->Broadcast(zero, input_shape.dim_sizes()));
  }
};

REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);

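// OnesLike: as ZerosLike, but broadcasts a scalar one instead.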
class OnesLikeOp : public XlaOpKernel {
 public:
  explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);

    auto one = XlaHelpers::One(ctx->builder(), input_type(0));
    ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes()));
  }
};

REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp);

}  // namespace
}  // namespace tensorflow