Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_TRANSPOSE_OP_H_
     17 #define TENSORFLOW_KERNELS_TRANSPOSE_OP_H_
     18 
     19 #include "tensorflow/core/framework/op_kernel.h"
     20 #include "tensorflow/core/framework/tensor.h"
     21 
     22 namespace tensorflow {
     23 
     24 class TransposeOp : public OpKernel {
     25  public:
     26   explicit TransposeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
     27 
     28   void Compute(OpKernelContext* ctx) override;
     29 
     30  protected:
     31   virtual Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
     32                              gtl::ArraySlice<int32> perm, Tensor* out) = 0;
     33   virtual bool IsConjugate() const { return false; }
     34 };
     35 
     36 class TransposeCpuOp : public TransposeOp {
     37  public:
     38   explicit TransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
     39 
     40  protected:
     41   Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
     42                      gtl::ArraySlice<int32> perm, Tensor* out) override;
     43 };
     44 
     45 #ifdef INTEL_MKL
     46 class MklTransposeCpuOp : public TransposeOp {
     47  public:
     48   explicit MklTransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
     49 
     50  protected:
     51   Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
     52                      gtl::ArraySlice<int32> perm, Tensor* out) override;
     53 };
     54 #endif  // INTEL_MKL
     55 
     56 class TransposeGpuOp : public TransposeOp {
     57  public:
     58   explicit TransposeGpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
     59 
     60  protected:
     61   Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
     62                      gtl::ArraySlice<int32> perm, Tensor* out) override;
     63 };
     64 
     65 #ifdef TENSORFLOW_USE_SYCL
     66 class TransposeSyclOp : public TransposeOp {
     67  public:
     68   explicit TransposeSyclOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
     69 
     70  protected:
     71   Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
     72                      gtl::ArraySlice<int32> perm, Tensor* out) override;
     73 };
     74 #endif  // TENSORFLOW_USE_SYCL
     75 
     76 // Conjugating transpose ops.
     77 class ConjugateTransposeCpuOp : public TransposeOp {
     78  public:
     79   explicit ConjugateTransposeCpuOp(OpKernelConstruction* ctx)
     80       : TransposeOp(ctx) {}
     81 
     82  protected:
     83   Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
     84                      gtl::ArraySlice<int32> perm, Tensor* out) override;
     85   bool IsConjugate() const override { return true; }
     86 };
     87 
     88 #ifdef INTEL_MKL
     89 class MklConjugateTransposeCpuOp : public TransposeOp {
     90  public:
     91   explicit MklConjugateTransposeCpuOp(OpKernelConstruction* ctx)
     92       : TransposeOp(ctx) {}
     93 
     94  protected:
     95   Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
     96                      gtl::ArraySlice<int32> perm, Tensor* out) override;
     97   bool IsConjugate() const override { return true; }
     98 };
     99 #endif  // INTEL_MKL
    100 
    101 class ConjugateTransposeGpuOp : public TransposeOp {
    102  public:
    103   explicit ConjugateTransposeGpuOp(OpKernelConstruction* ctx)
    104       : TransposeOp(ctx) {}
    105 
    106  protected:
    107   Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
    108                      gtl::ArraySlice<int32> perm, Tensor* out) override;
    109   bool IsConjugate() const override { return true; }
    110 };
    111 
    112 #ifdef TENSORFLOW_USE_SYCL
    113 class ConjugateTransposeSyclOp : public TransposeOp {
    114  public:
    115   explicit ConjugateTransposeSyclOp(OpKernelConstruction* ctx)
    116       : TransposeOp(ctx) {}
    117 
    118  protected:
    119   Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
    120                      gtl::ArraySlice<int32> perm, Tensor* out) override;
    121   bool IsConjugate() const override { return true; }
    122 };
    123 #endif  // TENSORFLOW_USE_SYCL
    124 
    125 }  // namespace tensorflow
    126 
    127 #endif  // TENSORFLOW_KERNELS_TRANSPOSE_OP_H_
    128