Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include <complex>
     19 
     20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     21 #include "tensorflow/core/framework/attr_value.pb.h"
     22 #include "tensorflow/core/framework/tensor.h"
     23 #include "tensorflow/core/kernels/ops_util.h"
     24 #include "tensorflow/core/kernels/transpose_functor.h"
     25 #include "tensorflow/core/lib/core/status.h"
     26 #include "tensorflow/core/lib/gtl/array_slice.h"
     27 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     28 
     29 typedef Eigen::ThreadPoolDevice CPUDevice;
     30 
     31 namespace tensorflow {
     32 namespace {
     33 
     34 template <typename T, bool conjugate>
     35 void TransposeSimple(const CPUDevice& device, const Tensor& in,
     36                      const gtl::ArraySlice<int32> perm, Tensor* out) {
     37   const int ndims = in.dims();
     38   gtl::InlinedVector<int64, 8> in_strides = ComputeStride<int64>(in.shape());
     39   gtl::InlinedVector<int64, 8> out_strides = ComputeStride<int64>(out->shape());
     40   const T* p = reinterpret_cast<const T*>(in.tensor_data().data());
     41   T* q = reinterpret_cast<T*>(const_cast<char*>((out->tensor_data().data())));
     42   auto transpose_fn = [=, &in_strides, &out_strides, &perm](int64 begin,
     43                                                             int64 end) {
     44     for (int64 o_idx = begin; o_idx < end; ++o_idx) {
     45       int64 i_idx = 0;
     46       int64 t = o_idx;
     47       for (int i = 0; i < ndims; ++i) {
     48         const int64 ratio = t / out_strides[i];
     49         t -= ratio * out_strides[i];
     50         i_idx += ratio * in_strides[perm[i]];
     51       }
     52       if (conjugate) {
     53         q[o_idx] = Eigen::numext::conj(p[i_idx]);
     54       } else {
     55         q[o_idx] = p[i_idx];
     56       }
     57     }
     58   };
     59   double cycles_per_element =
     60       (conjugate ? 1 : 0) + ndims * (Eigen::TensorOpCost::DivCost<int64>() +
     61                                      2 * Eigen::TensorOpCost::MulCost<int64>() +
     62                                      2 * Eigen::TensorOpCost::AddCost<int64>());
     63   Eigen::TensorOpCost cost(/*bytes_loaded=*/sizeof(T),
     64                            /*bytes_stored=*/sizeof(T), cycles_per_element);
     65   device.parallelFor(in.NumElements(), cost, std::move(transpose_fn));
     66 }
     67 
     68 }  // namespace
     69 
     70 template <typename T, bool conjugate>
     71 struct Transpose<CPUDevice, T, conjugate> {
     72   static void run(const CPUDevice& d, const Tensor& in,
     73                   const gtl::ArraySlice<int32> perm, Tensor* out) {
     74     switch (in.dims()) {
     75       case 2:
     76         internal::TransposeUsingEigen<CPUDevice, T, 2>(d, in, perm, conjugate,
     77                                                        out);
     78         break;
     79       case 3:
     80         internal::TransposeUsingEigen<CPUDevice, T, 3>(d, in, perm, conjugate,
     81                                                        out);
     82         break;
     83       case 4:
     84         internal::TransposeUsingEigen<CPUDevice, T, 4>(d, in, perm, conjugate,
     85                                                        out);
     86         break;
     87       case 5:
     88         internal::TransposeUsingEigen<CPUDevice, T, 5>(d, in, perm, conjugate,
     89                                                        out);
     90         break;
     91       case 6:
     92         internal::TransposeUsingEigen<CPUDevice, T, 6>(d, in, perm, conjugate,
     93                                                        out);
     94         break;
     95       case 7:
     96         internal::TransposeUsingEigen<CPUDevice, T, 7>(d, in, perm, conjugate,
     97                                                        out);
     98         break;
     99       case 8:
    100         internal::TransposeUsingEigen<CPUDevice, T, 8>(d, in, perm, conjugate,
    101                                                        out);
    102         break;
    103       default:
    104         TransposeSimple<T, conjugate>(d, in, perm, out);
    105         break;
    106     }
    107   }
    108 };
    109 
    110 #define INSTANTIATE(DEVICE)                                                 \
    111   template <>                                                               \
    112   Status DoTranspose(const DEVICE& device, const Tensor& in,                \
    113                      const gtl::ArraySlice<int32> perm, Tensor* out) {      \
    114     return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/false, \
    115                                      out);                                  \
    116   }                                                                         \
    117   template <>                                                               \
    118   Status DoConjugateTranspose(const DEVICE& device, const Tensor& in,       \
    119                               const gtl::ArraySlice<int32> perm,            \
    120                               Tensor* out) {                                \
    121     return internal::DoTransposeImpl(device, in, perm, /*conjugate=*/true,  \
    122                                      out);                                  \
    123   }                                                                         \
    124   template <>                                                               \
    125   Status DoMatrixTranspose(const DEVICE& device, const Tensor& in,          \
    126                            Tensor* out) {                                   \
    127     return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/false, \
    128                                            out);                            \
    129   }                                                                         \
    130   template <>                                                               \
    131   Status DoConjugateMatrixTranspose(const DEVICE& device, const Tensor& in, \
    132                                     Tensor* out) {                          \
    133     return internal::DoMatrixTransposeImpl(device, in, /*conjugate=*/true,  \
    134                                            out);                            \
    135   }
    136 
    137 INSTANTIATE(CPUDevice)
    138 
    139 #ifdef TENSORFLOW_USE_SYCL
    140 typedef Eigen::SyclDevice SYCLDevice;
    141 
    142 namespace internal {
    143 template <typename T>
    144 void TransposeSYCL(const SYCLDevice& d, const Tensor& in,
    145                    const gtl::ArraySlice<int32> perm, bool conjugate,
    146                    Tensor* out) {
    147   switch (in.dims()) {
    148     case 1:
    149       TransposeUsingEigen<SYCLDevice, T, 1>(d, in, perm, conjugate, out);
    150       break;
    151     case 2:
    152       TransposeUsingEigen<SYCLDevice, T, 2>(d, in, perm, conjugate, out);
    153       break;
    154     case 3:
    155       TransposeUsingEigen<SYCLDevice, T, 3>(d, in, perm, conjugate, out);
    156       break;
    157     case 4:
    158       TransposeUsingEigen<SYCLDevice, T, 4>(d, in, perm, conjugate, out);
    159       break;
    160     case 5:
    161       TransposeUsingEigen<SYCLDevice, T, 5>(d, in, perm, conjugate, out);
    162       break;
    163     case 6:
    164       TransposeUsingEigen<SYCLDevice, T, 6>(d, in, perm, conjugate, out);
    165       break;
    166     case 7:
    167       TransposeUsingEigen<SYCLDevice, T, 7>(d, in, perm, conjugate, out);
    168       break;
    169     case 8:
    170       TransposeUsingEigen<SYCLDevice, T, 8>(d, in, perm, conjugate, out);
    171       break;
    172     default:
    173       LOG(FATAL) << "Unsupported TransposeUsingEigen for: " << in.dims();
    174       break;
    175   }
    176 }
    177 
    178 }  // namespace internal
    179 
    180 template <typename T, bool conjugate>
    181 struct Transpose<SYCLDevice, T, conjugate> {
    182   static void run(const SYCLDevice& d, const Tensor& in,
    183                   const gtl::ArraySlice<int32> perm, Tensor* out) {
    184     internal::TransposeSycl(d, in, perm, conjugate, out);
    185   }
    186 };
    187 
    188 template <bool conjugate>
    189 struct Transpose<SYCLDevice, string, conjugate> {
    190   static void run(const SYCLDevice& d, const Tensor& in,
    191                   const gtl::ArraySlice<int32> perm, Tensor* out) {
    192     LOG(FATAL) << "DT_STRING not supported on SYCL device.";
    193   }
    194 };
    195 
    196 // Explicit instantiation.
    197 template struct Transpose<SYCLDevice, string, false>;
    198 
    199 INSTANTIATE(SYCLDevice)
    200 #undef INSTANTIATE
    201 
    202 #endif  // TENSORFLOW_USE_SYCL
    203 
    204 }  // namespace tensorflow
    205