1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA-specific Transpose Op. This is very different to the Eigen 17 // version in third_party/tensorflow because XLA's reshape neatly 18 // handles all transposes, while Eigen needs a restricted DoTranspose 19 // helper. 20 21 #include "tensorflow/core/kernels/transpose_op.h" 22 #include "tensorflow/compiler/tf2xla/type_util.h" 23 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 24 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 26 #include "tensorflow/core/framework/kernel_def_builder.h" 27 #include "tensorflow/core/framework/register_types.h" 28 #include "tensorflow/core/kernels/bounds_check.h" 29 30 namespace tensorflow { 31 namespace { 32 33 class TransposeOp : public XlaOpKernel { 34 public: 35 explicit TransposeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 36 37 void Compile(XlaOpKernelContext* ctx) override { 38 const TensorShape input_shape = ctx->InputShape(0); 39 const TensorShape perm_tensor_shape = ctx->InputShape(1); 40 41 // Preliminary validation of sizes. 42 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape), 43 errors::InvalidArgument("perm must be a vector, not ", 44 perm_tensor_shape.DebugString())); 45 46 const int dims = input_shape.dims(); 47 OP_REQUIRES(ctx, dims == perm_tensor_shape.num_elements(), 48 errors::InvalidArgument("transpose expects a vector of size ", 49 input_shape.dims(), 50 ". But input(1) is a vector of size ", 51 perm_tensor_shape.num_elements())); 52 53 xla::Literal literal; 54 OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal)); 55 56 std::vector<int32> perm(dims); 57 std::copy(literal.data<int32>().begin(), literal.data<int32>().end(), 58 perm.begin()); 59 60 std::vector<int64> transposed_order; 61 // Check whether permutation is a permutation of integers of [0 .. dims). 62 gtl::InlinedVector<bool, 8> bits(dims); 63 bool is_identity = true; 64 for (int i = 0; i < dims; ++i) { 65 const int32 d = perm[i]; 66 OP_REQUIRES( 67 ctx, 0 <= d && d < dims, 68 errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")")); 69 bits[d] = true; 70 transposed_order.push_back(d); 71 if (d != i) { 72 is_identity = false; 73 } 74 } 75 for (int i = 0; i < dims; ++i) { 76 OP_REQUIRES( 77 ctx, bits[i], 78 errors::InvalidArgument(i, " is missing from 'perm' argument.")); 79 } 80 81 // 0-D, 1-D, and identity transposes do nothing. 82 if (dims <= 1 || is_identity) { 83 ctx->SetOutput(0, ctx->Input(0)); 84 return; 85 } 86 87 ctx->SetOutput(0, 88 ctx->builder()->Transpose(ctx->Input(0), transposed_order)); 89 } 90 }; 91 92 REGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp); 93 94 // InvertPermutation frequently forms part of the gradient of Transpose. 95 // 96 // inv = InvertPermutationOp(T<int32> p) takes a permutation of 97 // integers 0, 1, ..., n - 1 and returns the inverted 98 // permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n). 99 // 100 // REQUIRES: input is a vector of int32. 101 // REQUIRES: input is a permutation of 0, 1, ..., n-1. 102 103 class InvertPermutationOp : public XlaOpKernel { 104 public: 105 explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 106 107 void Compile(XlaOpKernelContext* ctx) override { 108 OP_REQUIRES(ctx, 109 FastBoundsCheck(ctx->InputShape(0).num_elements(), 110 std::numeric_limits<int32>::max()), 111 errors::InvalidArgument("permutation of nonnegative int32s " 112 "must have <= int32 max elements")); 113 114 std::vector<int64> perm; 115 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm)); 116 117 int size = perm.size(); 118 119 std::vector<int32> output(size); 120 std::fill_n(output.data(), size, -1); 121 for (int i = 0; i < size; ++i) { 122 const int64 d = perm[i]; 123 OP_REQUIRES(ctx, FastBoundsCheck(d, size), 124 errors::InvalidArgument(d, " is not between 0 and ", size)); 125 OP_REQUIRES(ctx, output[d] == -1, 126 errors::InvalidArgument(d, " is duplicated in the input.")); 127 output[d] = i; 128 } 129 130 ctx->SetOutput(0, ctx->builder()->ConstantR1<int32>(output)); 131 } 132 }; 133 134 REGISTER_XLA_OP(Name("InvertPermutation") 135 .TypeConstraint("T", DT_INT32) 136 .CompileTimeConstInput("x"), 137 InvertPermutationOp); 138 139 } // namespace 140 } // namespace tensorflow 141