Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_SHAPE_OPS_H_
     17 #define TENSORFLOW_KERNELS_SHAPE_OPS_H_
     18 
     19 #include <limits>
     20 #include <unordered_set>
     21 #include <vector>
     22 
     23 #include "tensorflow/core/framework/op_kernel.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/framework/tensor_shape.h"
     26 #include "tensorflow/core/framework/variant_op_registry.h"
     27 #include "tensorflow/core/kernels/bounds_check.h"
     28 
     29 namespace tensorflow {
     30 
     31 namespace shape_op_helpers {
     32 inline Status GetRegularOrVariantShape(OpKernelContext* ctx, int input_index,
     33                                        TensorShape* shape) {
     34   const Tensor& inp = ctx->input(input_index);
     35   if (ctx->input_dtype(0) == DT_VARIANT) {
     36     if (inp.dims() != 0) {
     37       return errors::InvalidArgument(
     38           "Shape of non-unary Variant not supported.");
     39     }
     40     TF_RETURN_IF_ERROR(GetUnaryVariantShape(inp, shape));
     41   } else {
     42     *shape = inp.shape();
     43   }
     44   return Status::OK();
     45 }
     46 }  // namespace shape_op_helpers
     47 
     48 template <typename OutType>
     49 class ShapeOp : public OpKernel {
     50  public:
     51   explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
     52 
     53   void Compute(OpKernelContext* ctx) override {
     54     TensorShape shape;
     55     OP_REQUIRES_OK(ctx,
     56                    shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape));
     57     const int rank = shape.dims();
     58     Tensor* out = nullptr;
     59     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out));
     60     auto vec = out->vec<OutType>();
     61     for (int i = 0; i < rank; ++i) {
     62       int64 dim_size = shape.dim_size(i);
     63       if (out->dtype() == DT_INT32) {
     64         OP_REQUIRES(
     65             ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
     66             errors::InvalidArgument("Shape output type is 32-bit ", " but dim ",
     67                                     i, " is ", dim_size));
     68       }
     69       vec(i) = static_cast<OutType>(dim_size);
     70     }
     71   }
     72 
     73   bool IsExpensive() override { return false; }
     74 };
     75 
     76 template <typename OutType>
     77 class ShapeNOp : public OpKernel {
     78  public:
     79   explicit ShapeNOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
     80 
     81   void Compute(OpKernelContext* ctx) override {
     82     for (int i = 0; i < ctx->num_inputs(); ++i) {
     83       TensorShape shape;
     84       OP_REQUIRES_OK(
     85           ctx, shape_op_helpers::GetRegularOrVariantShape(ctx, i, &shape));
     86       const int dims = shape.dims();
     87       Tensor* out = nullptr;
     88       OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out));
     89       auto vec = out->vec<OutType>();
     90 
     91       for (int j = 0; j < dims; ++j) {
     92         int64 dim_size = shape.dim_size(j);
     93         if (out->dtype() == DT_INT32) {
     94           OP_REQUIRES(
     95               ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
     96               errors::InvalidArgument("ShapeN output type is 32-bit but shape ",
     97                                       i, " dim ", j, " is ", dim_size));
     98         }
     99         vec(j) = static_cast<OutType>(dim_size);
    100       }
    101     }
    102   }
    103 
    104   bool IsExpensive() override { return false; }
    105 };
    106 
    107 class RankOp : public OpKernel {
    108  public:
    109   explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    110 
    111   void Compute(OpKernelContext* ctx) override {
    112     TensorShape shape;
    113     OP_REQUIRES_OK(ctx,
    114                    shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape));
    115     const int rank = shape.dims();
    116     Tensor* out = nullptr;
    117     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    118     out->scalar<int32>()() = rank;
    119   }
    120 
    121   bool IsExpensive() override { return false; }
    122 };
    123 
    124 template <typename OutType>
    125 class SizeOp : public OpKernel {
    126  public:
    127   explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    128 
    129   void Compute(OpKernelContext* ctx) override {
    130     TensorShape shape;
    131     OP_REQUIRES_OK(ctx,
    132                    shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape));
    133     const int64 size = shape.num_elements();
    134     Tensor* out = nullptr;
    135     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    136     if (out->dtype() == DT_INT32) {
    137       OP_REQUIRES(
    138           ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()),
    139           errors::InvalidArgument("Number of elements was larger than "
    140                                   "representable by 32-bit output type"));
    141     }
    142     out->scalar<OutType>()() = static_cast<OutType>(size);
    143   }
    144 
    145   bool IsExpensive() override { return false; }
    146 };
    147 
    148 template <typename Tdim>
    149 class ExpandDimsOp : public OpKernel {
    150  public:
    151   explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    152 
    153   void Compute(OpKernelContext* ctx) override {
    154     OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
    155                 errors::InvalidArgument("ExpandDims on Variant not supported"));
    156 
    157     Tdim dim = ctx->input(1).flat<Tdim>()(0);
    158     OP_REQUIRES(
    159         ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()),
    160         errors::InvalidArgument("Tried to expand dim index ", dim,
    161                                 " for tensor with ", ctx->input(0).dims(),
    162                                 " dimensions."));
    163 
    164     auto existing_dims = ctx->input(0).shape().dim_sizes();
    165     // Safe - # elements in tensor dims bounded.
    166     const int existing_dims_size = static_cast<int>(existing_dims.size());
    167     std::vector<int64> new_shape(existing_dims_size);
    168     for (size_t i = 0; i < new_shape.size(); ++i) {
    169       new_shape[i] = existing_dims[i];
    170     }
    171 
    172     // We emulate numpy's interpretation of the dim axis when
    173     // -input.dims() >= dim <= input.dims().
    174     if (dim < 0) {
    175       dim += existing_dims.size() + 1;
    176     }
    177 
    178     // Clamp to the end if needed.
    179     dim = std::min<Tdim>(dim, existing_dims_size);
    180     new_shape.emplace(new_shape.begin() + dim, 1);
    181     const TensorShape output_shape(new_shape);
    182 
    183     Tensor* output = nullptr;
    184     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
    185     if (!output->CopyFrom(ctx->input(0), output_shape)) {
    186       // This should never happen, since the sizes of the input and output
    187       // should always be the same (we only expand the dimension with 1).
    188       ctx->SetStatus(
    189           errors::Internal("Could not expand dimension with input shape ",
    190                            ctx->input(0).shape().DebugString(),
    191                            " and output shape ", output_shape.DebugString()));
    192     }
    193   }
    194 
    195   bool IsExpensive() override { return false; }
    196 };
    197 
    198 class SqueezeOp : public OpKernel {
    199  public:
    200   explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    201     std::vector<int32> squeeze_dims;
    202     OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
    203     squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
    204   }
    205 
    206   void Compute(OpKernelContext* ctx) override {
    207     OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
    208                 errors::InvalidArgument("Squeeze on Variant not supported"));
    209 
    210     auto existing_dims = ctx->input(0).shape().dim_sizes();
    211     const int existing_dims_size = static_cast<int>(existing_dims.size());
    212     std::vector<int64> new_shape;
    213 
    214     std::unordered_set<int32> wrapped_squeeze_dims;
    215     wrapped_squeeze_dims.reserve(squeeze_dims_.size());
    216     // Validate squeeze dims against the input.
    217     for (int32 dim : squeeze_dims_) {
    218       OP_REQUIRES(
    219           ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()),
    220           errors::InvalidArgument("Tried to squeeze dim index ", dim,
    221                                   " for tensor with ", ctx->input(0).dims(),
    222                                   " dimensions."));
    223       // If dim is < 0, we wrap around (-1 means the last element).
    224       if (dim < 0) {
    225         dim = existing_dims_size + dim;
    226       }
    227 
    228       wrapped_squeeze_dims.insert(dim);
    229     }
    230 
    231     for (int i = 0; i < existing_dims_size; ++i) {
    232       auto existing_dim = existing_dims[i];
    233 
    234       // If squeeze_set is non-empty, only squeeze those dimensions.
    235       if (!wrapped_squeeze_dims.empty()) {
    236         if (wrapped_squeeze_dims.count(i) > 0) {
    237           OP_REQUIRES(ctx, existing_dim == 1,
    238                       errors::InvalidArgument(
    239                           "Tried to explicitly squeeze "
    240                           "dimension ",
    241                           i, " but dimension was not 1: ", existing_dim));
    242         } else {
    243           // This dimension is not being squeezed.
    244           new_shape.push_back(existing_dim);
    245         }
    246       } else {
    247         // Copy over all non-1-length dimensions.
    248         if (existing_dim != 1) {
    249           new_shape.push_back(existing_dim);
    250         }
    251       }
    252     }
    253 
    254     const TensorShape output_shape(new_shape);
    255     Tensor* output = nullptr;
    256     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
    257     if (!output->CopyFrom(ctx->input(0), output_shape)) {
    258       // This should never happen, since the sizes of the input and
    259       // output should always be the same.
    260       ctx->SetStatus(errors::Internal("Could not squeeze input with shape ",
    261                                       ctx->input(0).shape().DebugString(),
    262                                       " and output shape ",
    263                                       output_shape.DebugString()));
    264     }
    265   }
    266 
    267   bool IsExpensive() override { return false; }
    268 
    269  private:
    270   std::unordered_set<int32> squeeze_dims_;
    271 };
    272 
    273 }  // namespace tensorflow
    274 
    275 #endif  // TENSORFLOW_KERNELS_SHAPE_OPS_H_
    276