Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/array_ops.cc.
     17 
     18 #ifdef INTEL_MKL
     19 #define EIGEN_USE_THREADS
     20 
     21 #include "mkl_trans.h"
     22 #include "tensorflow/core/kernels/transpose_functor.h"
     23 #include "tensorflow/core/kernels/transpose_op.h"
     24 
     25 namespace tensorflow {
     26 
     27 // output = TransposeOp(T<any> input, T<int32> perm) takes a tensor
     28 // of type T and rank N, and a permutation of 0, 1, ..., N-1. It
     29 // shuffles the dimensions of the input tensor according to permutation.
     30 //
     31 // Specifically, the returned tensor output meets the following condition:
     32 // 1) output.dims() == input.dims();
     33 // 2) output.dim_size(i) == input.dim_size(perm[i]);
     34 // 3) output.tensor<T, N>(i_0, i_1, ..., i_N-1) ==
     35 //      input.tensor<T, N>(j_0, j_1, ..., j_N-1),
     36 //    where i_s == j_{perm[s]}
     37 //
     38 // REQUIRES: perm is a vector of int32.
     39 // REQUIRES: input.dims() == perm.size().
     40 // REQUIRES: perm is a permutation.
     41 
     42 namespace {
// Transposes the 2-D tensor `in` into the preallocated tensor `out` using
// MKL's out-of-place matrix-copy routines. `trans` is the MKL operation
// code ('T' = transpose, 'C' = conjugate transpose; see the constants
// below). Per-type specializations are defined further down in this
// namespace; there is no generic definition.
template <typename T>
Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out);
     45 
// Documentation here: https://software.intel.com/en-us/node/520863
// Parameters: (ordering:row-major, operation:transpose, num_rows, num_cols,
//              alpha (for scaling), array, dist_bet_adjacent_cols/rows
//              (source), array, dist_bet_adjacent_cols/rows (dest))
//
// Expands to a full specialization of MKLTranspose2D<T> that forwards to the
// MKL routine mkl_<PREFIX>omatcopy (e.g. PREFIX `s` -> mkl_somatcopy for
// float). alpha = 1 means no scaling. The source leading dimension is the
// input row length dim_size(1); the destination leading dimension is
// dim_size(0), the row length of the transposed result.

#define INSTANTIATE(T, PREFIX)                                                \
  template <>                                                                 \
  Status MKLTranspose2D<T>(const char trans, const Tensor& in, Tensor* out) { \
    mkl_##PREFIX##omatcopy('R', trans, in.dim_size(0), in.dim_size(1), 1,     \
                           in.flat<T>().data(), in.dim_size(1),               \
                           out->flat<T>().data(), in.dim_size(0));            \
    return Status::OK();                                                      \
  }

// Generate the specializations for the real floating-point types.
INSTANTIATE(float, s)
INSTANTIATE(double, d)

#undef INSTANTIATE
     64 
     65 template <>
     66 Status MKLTranspose2D<complex64>(const char trans, const Tensor& in, Tensor* out) {
     67     const MKL_Complex8 alpha = { 1.0f, 0.0f };
     68     mkl_comatcopy('R', trans, in.dim_size(0), in.dim_size(1), alpha,
     69                   reinterpret_cast<const MKL_Complex8*>(in.flat<complex64>().data()),
     70                   in.dim_size(1),
     71                   reinterpret_cast<MKL_Complex8*>(const_cast<complex64*>(out->flat<complex64>().data())),
     72                   in.dim_size(0));
     73     return Status::OK();
     74 }
     75 
     76 template <>
     77 Status MKLTranspose2D<complex128>(const char trans, const Tensor& in, Tensor* out) {
     78     const MKL_Complex16 alpha = { 1.0, 0.0 };
     79     mkl_zomatcopy('R', trans, in.dim_size(0), in.dim_size(1), alpha,
     80                   reinterpret_cast<const MKL_Complex16*>(in.flat<complex128>().data()),
     81                   in.dim_size(1),
     82                   reinterpret_cast<MKL_Complex16*>(const_cast<complex128*>(out->flat<complex128>().data())),
     83                   in.dim_size(0));
     84 	return Status::OK();
     85 }
     86 
// MKL operation codes passed as `trans` to the mkl_?omatcopy routines:
// 'T' = plain transpose, 'C' = conjugate transpose.
static const char kMKLTranspose = 'T';
static const char kMKLConjugateTranspose = 'C';
     89 
     90 }  // namespace
     91 
     92 Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
     93                                       gtl::ArraySlice<int32> perm,
     94                                       Tensor* out) {
     95   if (in.dims() == 2) {
     96     if (perm[0] == 0 && perm[1] == 1) {
     97       return Status::OK();
     98     }
     99     switch (in.dtype()) {
    100       case DT_FLOAT:
    101         return MKLTranspose2D<float>(kMKLTranspose, in, out);
    102       case DT_DOUBLE:
    103         return MKLTranspose2D<double>(kMKLTranspose, in, out);
    104       case DT_COMPLEX64:
    105         return MKLTranspose2D<complex64>(kMKLTranspose, in, out);
    106       case DT_COMPLEX128:
    107         return MKLTranspose2D<complex128>(kMKLTranspose, in, out);
    108       default:
    109         break;
    110     }
    111   }
    112   // Fallback to eigen if transpose parameters not supported by MKL
    113   typedef Eigen::ThreadPoolDevice CPUDevice;
    114   return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
    115                                    out);
    116 }
    117 
    118 Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
    119                                                const Tensor& in,
    120                                                gtl::ArraySlice<int32> perm,
    121                                                Tensor* out) {
    122   if (in.dims() == 2 && perm[0] == 1 && perm[1] == 0) {
    123     // TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels
    124     // for any transpose that can be reduced to swapping the last two
    125     // dimensions in a rank-3 tensor. We can even run each outer dimension in
    126     // a separate thread.
    127     switch (in.dtype()) {
    128       case DT_FLOAT:
    129         return MKLTranspose2D<float>(kMKLTranspose, in, out);
    130       case DT_DOUBLE:
    131         return MKLTranspose2D<double>(kMKLTranspose, in, out);
    132       case DT_COMPLEX64:
    133         return MKLTranspose2D<complex64>(kMKLConjugateTranspose, in, out);
    134       case DT_COMPLEX128:
    135         return MKLTranspose2D<complex128>(kMKLConjugateTranspose, in, out);
    136       default:
    137         break;
    138     }
    139   }
    140   // Fallback to eigen if transpose parameters not supported by MKL
    141   typedef Eigen::ThreadPoolDevice CPUDevice;
    142   return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(), in,
    143                                             perm, out);
    144 }
    145 
    146 }  // namespace tensorflow
    147 
    148 #endif  // INTEL_MKL
    149