/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_

#include <numeric>
#include <string>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {

// Transpose tensor 'in' into tensor 'out' according to dimension
// permutation 'perm'.
//
// REQUIRES: in.dtype() == out->dtype()
// REQUIRES: in.dims() == out->dims()
// REQUIRES: in.dims() == perm.size()
// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
template <typename Device>
Status DoTranspose(const Device& device, const Tensor& in,
                   const gtl::ArraySlice<int32> perm, Tensor* out);
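
// Example usage of DoTranspose (a sketch only; it assumes an Eigen device
// `d` obtained by the caller, e.g. via OpKernelContext::eigen_device, and
// caller-constructed tensors):
//
//   Tensor in(DT_FLOAT, TensorShape({2, 3, 4}));
//   Tensor out(DT_FLOAT, TensorShape({4, 2, 3}));
//   // perm[i] is the input dimension that becomes output dimension i, so
//   // out.dim_size(i) == in.dim_size(perm[i]).
//   TF_RETURN_IF_ERROR(DoTranspose(d, in, {2, 0, 1}, &out));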

// Conjugate and transpose tensor 'in' into tensor 'out' according to dimension
// permutation 'perm'.
//
// REQUIRES: in.dtype() == out->dtype()
// REQUIRES: in.dims() == out->dims()
// REQUIRES: in.dims() == perm.size()
// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i)
template <typename Device>
Status DoConjugateTranspose(const Device& device, const Tensor& in,
                            const gtl::ArraySlice<int32> perm, Tensor* out);
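
// Used like DoTranspose above, but each element is additionally conjugated;
// for real dtypes the result matches DoTranspose. A sketch, with the same
// assumed `d`, `in` and `out` as above:
//
//   TF_RETURN_IF_ERROR(DoConjugateTranspose(d, in, {2, 0, 1}, &out));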

// Convenience version of DoTranspose that only swaps the last (inner) two
// dimensions.
template <typename Device>
Status DoMatrixTranspose(const Device& device, const Tensor& in, Tensor* out);

// Convenience version of DoConjugateTranspose that only swaps the last
// (inner) two dimensions.
template <typename Device>
Status DoConjugateMatrixTranspose(const Device& device, const Tensor& in,
                                  Tensor* out);

// Primary device-specific functor, to be specialized for each device and type.
template <typename Device, typename T, bool conjugate = false>
struct Transpose {
  static void run(const Device& d, const Tensor& in,
                  const gtl::ArraySlice<int32> perm, Tensor* out);
};
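
// Each device provides its own specializations (see, for example,
// transpose_functor_cpu.cc and transpose_functor_gpu.cu.cc). A minimal sketch
// of what a specialization looks like, for a hypothetical device type
// MyDevice:
//
//   template <typename T, bool conjugate>
//   struct Transpose<MyDevice, T, conjugate> {
//     static void run(const MyDevice& d, const Tensor& in,
//                     const gtl::ArraySlice<int32> perm, Tensor* out) {
//       // Typically dispatches on in.dims() to an implementation such as
//       // internal::TransposeUsingEigen<MyDevice, T, NDIMS>.
//     }
//   };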

// Implementation details.
namespace internal {

typedef gtl::InlinedVector<int64, 8> TransposeDimsVec;
typedef gtl::InlinedVector<int32, 8> TransposePermsVec;

// Helper function that takes a tensor shape and a permutation and combines
// neighboring dimensions whose indices are consecutive in the permutation.
// The function outputs the combined shape and the new permutation.
// Example: tensor shape {2, 3, 4, 5, 120} and permutation {0, 4, 1, 2, 3}
// produce the new shape {2, 60, 120} and the new permutation {0, 2, 1}.
inline void ReduceTransposeDimensions(const TensorShape& shape,
                                      gtl::ArraySlice<int32> perm,
                                      TransposePermsVec* new_perm,
                                      TransposeDimsVec* new_dims) {
  CHECK_EQ(shape.dims(), perm.size());
  if (shape.dims() == 1) {
    // If the input is already one-dimensional, there is nothing to combine.
    new_perm->resize(1);
    new_dims->resize(1);
    (*new_perm)[0] = perm[0];
    (*new_dims)[0] = shape.dim_size(0);
    return;
  }
  TransposePermsVec new_dim_position(shape.dims(), -1);
  TransposeDimsVec combined_dims(shape.dims(), 0);
  int cur_head = perm[0];
  new_dim_position[cur_head] = 0;
  combined_dims[0] = shape.dim_size(cur_head);
  int dim_idx = 0;
  for (int perm_idx = 1; perm_idx < shape.dims(); ++perm_idx) {
    // If two indices in the permutation are consecutive numbers, combine
    // their dimensions.
    if (cur_head + 1 == perm[perm_idx]) {
      cur_head = perm[perm_idx];
      combined_dims[dim_idx] *= shape.dim_size(cur_head);
    } else {
      // Otherwise start a new dimension.
      cur_head = perm[perm_idx];
      dim_idx++;
      new_dim_position[cur_head] = dim_idx;
      combined_dims[dim_idx] = shape.dim_size(cur_head);
    }
  }
  // Compact the new permutation and dimension sizes.
  new_perm->resize(dim_idx + 1);
  new_dims->resize(dim_idx + 1);
  dim_idx = 0;
  for (int i = 0; i < new_dim_position.size(); ++i) {
    if (new_dim_position[i] >= 0) {
      int new_perm_idx = new_dim_position[i];
      (*new_perm)[dim_idx] = new_perm_idx;
      (*new_dims)[dim_idx] = combined_dims[new_perm_idx];
      dim_idx++;
    }
  }
}
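
// Worked usage for the example in the comment above (a sketch; the local
// variables are illustrative only):
//
//   TensorShape shape({2, 3, 4, 5, 120});
//   TransposePermsVec new_perm;
//   TransposeDimsVec new_dims;
//   ReduceTransposeDimensions(shape, {0, 4, 1, 2, 3}, &new_perm, &new_dims);
//   // new_dims == {2, 60, 120} and new_perm == {0, 2, 1}: input dimensions
//   // 1, 2 and 3 stay adjacent under the permutation, so they collapse into
//   // a single dimension of size 3 * 4 * 5 = 60.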

// If all non-singleton dimensions remain in ascending order, the shuffled
// singletons can be transposed by a reshape, saving a memory allocation & copy.
// |permutation| must be a permutation of {0, .., input_shape.dims() - 1}.
// That is, for all i, 0 <= perm[i] < input_shape.dims().
// In practice, this is checked in TransposeOp::Compute prior to calling this
// function, and the function sits here to facilitate unit testing.
inline bool NonSingletonDimensionsAlign(const TensorShape& input_shape,
                                        const std::vector<int32>& permutation) {
  int last_nonsingleton_perm_dim = -1;
  for (int perm_dim : permutation) {
    if (input_shape.dim_size(perm_dim) == 1) {
      continue;
    }
    if (perm_dim < last_nonsingleton_perm_dim) {
      return false;
    }
    last_nonsingleton_perm_dim = perm_dim;
  }
  return true;
}
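
// Illustrative behavior (a sketch of the contract above):
//
//   // Shape {1, 2, 1, 3} with permutation {1, 0, 3, 2}: the non-singleton
//   // dimensions (sizes 2 and 3) keep their relative order, so this returns
//   // true and the transpose can be done as a reshape.
//   NonSingletonDimensionsAlign(TensorShape({1, 2, 1, 3}), {1, 0, 3, 2});
//
//   // Shape {2, 3} with permutation {1, 0}: the two non-singleton dimensions
//   // swap, so this returns false and a real data shuffle is needed.
//   NonSingletonDimensionsAlign(TensorShape({2, 3}), {1, 0});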

// Uses Eigen to transpose.
template <typename Device, typename T, int NDIMS>
void TransposeUsingEigen(const Device& d, const Tensor& in,
                         const gtl::ArraySlice<int32> perm, bool conjugate,
                         Tensor* out) {
  Eigen::array<int, NDIMS> p;
  for (int i = 0; i < NDIMS; ++i) p[i] = perm[i];
  // Map the input and output buffers as NDIMS-dimensional Eigen tensors and
  // evaluate the shuffle (optionally with conjugation) on the given device.
  auto x = typename TTypes<T, NDIMS>::ConstTensor(
      reinterpret_cast<const T*>(in.tensor_data().data()),
      in.shape().AsEigenDSizes<NDIMS>());
  auto y = typename TTypes<T, NDIMS>::Tensor(
      reinterpret_cast<T*>(const_cast<char*>(out->tensor_data().data())),
      out->shape().AsEigenDSizes<NDIMS>());
  if (conjugate) {
    y.device(d) = x.conjugate().shuffle(p);
  } else {
    y.device(d) = x.shuffle(p);
  }
}

template <typename Device>
Status DoTransposeImpl(const Device& d, const Tensor& in,
                       const gtl::ArraySlice<int32> perm, bool conjugate,
                       Tensor* out) {
  CHECK_GE(in.dims(), 2);
  CHECK_EQ(in.dims(), out->dims());
  CHECK_EQ(in.dims(), perm.size());
  CHECK_EQ(in.dtype(), out->dtype());
  // A non-conjugating transpose only moves bytes, so dtypes of equal element
  // size can share the same kernel (e.g. float, int32 and qint32 all use the
  // uint32 instantiation).
  switch (in.dtype()) {
    case DT_BOOL:
    case DT_INT8:
    case DT_QINT8:
    case DT_QUINT8:
    case DT_UINT8:
      Transpose<Device, uint8>::run(d, in, perm, out);
      break;

    case DT_BFLOAT16:
    case DT_HALF:
    case DT_INT16:
    case DT_QINT16:
    case DT_QUINT16:
    case DT_UINT16:
      Transpose<Device, uint16>::run(d, in, perm, out);
      break;

    case DT_FLOAT:
    case DT_INT32:
    case DT_QINT32:
      Transpose<Device, uint32>::run(d, in, perm, out);
      break;

    case DT_DOUBLE:
    case DT_INT64:
      Transpose<Device, uint64>::run(d, in, perm, out);
      break;

    case DT_COMPLEX64:
      if (conjugate) {
#if defined(__ANDROID__) and !defined(__clang__)
        // Workaround for a GCC compiler bug in the Android toolchain.
        return errors::Unimplemented(
            "Conjugate transpose of complex64 not supported for GCC on "
            "Android.");
#else
        Transpose<Device, complex64, /*conjugate=*/true>::run(d, in, perm, out);
#endif
      } else {
        // complex64 elements are 8 bytes, so a plain (non-conjugating)
        // transpose can reuse the uint64 kernel.
        Transpose<Device, uint64>::run(d, in, perm, out);
      }
      break;

    case DT_COMPLEX128:
      if (conjugate) {
        Transpose<Device, complex128, /*conjugate=*/true>::run(d, in, perm,
                                                               out);
      } else {
        Transpose<Device, complex128, /*conjugate=*/false>::run(d, in, perm,
                                                                out);
      }
      break;

    case DT_STRING:
      Transpose<Device, string>::run(d, in, perm, out);
      break;

    default:
      return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype());
  }
  return Status::OK();
}

template <typename Device>
inline Status DoMatrixTransposeImpl(const Device& device, const Tensor& in,
                                    bool conjugate, Tensor* out) {
  const int ndims = in.dims();
  if (ndims == 0) return Status::OK();
  TransposePermsVec perm(ndims);
  std::iota(perm.begin(), perm.end(), 0);
  std::swap(perm[ndims - 2], perm[ndims - 1]);
  return DoTransposeImpl(device, in, perm, conjugate, out);
}

#ifdef TENSORFLOW_USE_SYCL
// For SYCL, always go through Eigen.
template <typename Device, typename T>
void TransposeSYCL(const Device& d, const Tensor& in,
                   const gtl::ArraySlice<int32> perm, bool conjugate,
                   Tensor* out);
#endif  // TENSORFLOW_USE_SYCL

}  // namespace internal
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_