Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // XLA-specific Transpose Op. This is very different from the Eigen
     17 // version in third_party/tensorflow because XLA's Transpose operation
     18 // handles arbitrary permutations directly, while Eigen needs a restricted
     19 // DoTranspose helper.
     20 
     21 #include "tensorflow/core/kernels/transpose_op.h"
     22 #include "tensorflow/compiler/tf2xla/type_util.h"
     23 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     24 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     26 #include "tensorflow/core/framework/kernel_def_builder.h"
     27 #include "tensorflow/core/framework/register_types.h"
     28 #include "tensorflow/core/kernels/bounds_check.h"
     29 
     30 namespace tensorflow {
     31 namespace {
     32 
     33 class TransposeOp : public XlaOpKernel {
     34  public:
     35   explicit TransposeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
     36 
     37   void Compile(XlaOpKernelContext* ctx) override {
     38     const TensorShape input_shape = ctx->InputShape(0);
     39     const TensorShape perm_tensor_shape = ctx->InputShape(1);
     40 
     41     // Preliminary validation of sizes.
     42     OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape),
     43                 errors::InvalidArgument("perm must be a vector, not ",
     44                                         perm_tensor_shape.DebugString()));
     45 
     46     const int dims = input_shape.dims();
     47     OP_REQUIRES(ctx, dims == perm_tensor_shape.num_elements(),
     48                 errors::InvalidArgument("transpose expects a vector of size ",
     49                                         input_shape.dims(),
     50                                         ". But input(1) is a vector of size ",
     51                                         perm_tensor_shape.num_elements()));
     52 
     53     xla::Literal literal;
     54     OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal));
     55 
     56     std::vector<int32> perm(dims);
     57     std::copy(literal.data<int32>().begin(), literal.data<int32>().end(),
     58               perm.begin());
     59 
     60     std::vector<int64> transposed_order;
     61     // Check whether permutation is a permutation of integers of [0 .. dims).
     62     gtl::InlinedVector<bool, 8> bits(dims);
     63     bool is_identity = true;
     64     for (int i = 0; i < dims; ++i) {
     65       const int32 d = perm[i];
     66       OP_REQUIRES(
     67           ctx, 0 <= d && d < dims,
     68           errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
     69       bits[d] = true;
     70       transposed_order.push_back(d);
     71       if (d != i) {
     72         is_identity = false;
     73       }
     74     }
     75     for (int i = 0; i < dims; ++i) {
     76       OP_REQUIRES(
     77           ctx, bits[i],
     78           errors::InvalidArgument(i, " is missing from 'perm' argument."));
     79     }
     80 
     81     // 0-D, 1-D, and identity transposes do nothing.
     82     if (dims <= 1 || is_identity) {
     83       ctx->SetOutput(0, ctx->Input(0));
     84       return;
     85     }
     86 
     87     ctx->SetOutput(0,
     88                    ctx->builder()->Transpose(ctx->Input(0), transposed_order));
     89   }
     90 };
     91 
     92 REGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp);
     93 
     94 // InvertPermutation frequently forms part of the gradient of Transpose.
     95 //
     96 // inv = InvertPermutationOp(T<int32> p) takes a permutation of
     97 // integers 0, 1, ..., n - 1 and returns the inverted
     98 // permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
     99 //
    100 // REQUIRES: input is a vector of int32.
    101 // REQUIRES: input is a permutation of 0, 1, ..., n-1.
    102 
    103 class InvertPermutationOp : public XlaOpKernel {
    104  public:
    105   explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
    106 
    107   void Compile(XlaOpKernelContext* ctx) override {
    108     OP_REQUIRES(ctx,
    109                 FastBoundsCheck(ctx->InputShape(0).num_elements(),
    110                                 std::numeric_limits<int32>::max()),
    111                 errors::InvalidArgument("permutation of nonnegative int32s "
    112                                         "must have <= int32 max elements"));
    113 
    114     std::vector<int64> perm;
    115     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm));
    116 
    117     int size = perm.size();
    118 
    119     std::vector<int32> output(size);
    120     std::fill_n(output.data(), size, -1);
    121     for (int i = 0; i < size; ++i) {
    122       const int64 d = perm[i];
    123       OP_REQUIRES(ctx, FastBoundsCheck(d, size),
    124                   errors::InvalidArgument(d, " is not between 0 and ", size));
    125       OP_REQUIRES(ctx, output[d] == -1,
    126                   errors::InvalidArgument(d, " is duplicated in the input."));
    127       output[d] = i;
    128     }
    129 
    130     ctx->SetOutput(0, ctx->builder()->ConstantR1<int32>(output));
    131   }
    132 };
    133 
    134 REGISTER_XLA_OP(Name("InvertPermutation")
    135                     .TypeConstraint("T", DT_INT32)
    136                     .CompileTimeConstInput("x"),
    137                 InvertPermutationOp);
    138 
    139 }  // namespace
    140 }  // namespace tensorflow
    141