1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_TRANSPOSE_OP_H_ 17 #define TENSORFLOW_KERNELS_TRANSPOSE_OP_H_ 18 19 #include "tensorflow/core/framework/op_kernel.h" 20 #include "tensorflow/core/framework/tensor.h" 21 22 namespace tensorflow { 23 24 class TransposeOp : public OpKernel { 25 public: 26 explicit TransposeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 27 28 void Compute(OpKernelContext* ctx) override; 29 30 protected: 31 virtual Status DoTranspose(OpKernelContext* ctx, const Tensor& in, 32 gtl::ArraySlice<int32> perm, Tensor* out) = 0; 33 virtual bool IsConjugate() const { return false; } 34 }; 35 36 class TransposeCpuOp : public TransposeOp { 37 public: 38 explicit TransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {} 39 40 protected: 41 Status DoTranspose(OpKernelContext* ctx, const Tensor& in, 42 gtl::ArraySlice<int32> perm, Tensor* out) override; 43 }; 44 45 #ifdef INTEL_MKL 46 class MklTransposeCpuOp : public TransposeOp { 47 public: 48 explicit MklTransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {} 49 50 protected: 51 Status DoTranspose(OpKernelContext* ctx, const Tensor& in, 52 gtl::ArraySlice<int32> perm, Tensor* out) override; 53 }; 54 #endif // INTEL_MKL 55 56 class TransposeGpuOp : public TransposeOp { 57 public: 58 explicit TransposeGpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {} 59 60 protected: 61 Status DoTranspose(OpKernelContext* ctx, const Tensor& in, 62 gtl::ArraySlice<int32> perm, Tensor* out) override; 63 }; 64 65 #ifdef TENSORFLOW_USE_SYCL 66 class TransposeSyclOp : public TransposeOp { 67 public: 68 explicit TransposeSyclOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {} 69 70 protected: 71 Status DoTranspose(OpKernelContext* ctx, const Tensor& in, 72 gtl::ArraySlice<int32> perm, Tensor* out) override; 73 }; 74 #endif // TENSORFLOW_USE_SYCL 75 76 // Conjugating transpose ops. 77 class ConjugateTransposeCpuOp : public TransposeOp { 78 public: 79 explicit ConjugateTransposeCpuOp(OpKernelConstruction* ctx) 80 : TransposeOp(ctx) {} 81 82 protected: 83 Status DoTranspose(OpKernelContext* ctx, const Tensor& in, 84 gtl::ArraySlice<int32> perm, Tensor* out) override; 85 bool IsConjugate() const override { return true; } 86 }; 87 88 #ifdef INTEL_MKL 89 class MklConjugateTransposeCpuOp : public TransposeOp { 90 public: 91 explicit MklConjugateTransposeCpuOp(OpKernelConstruction* ctx) 92 : TransposeOp(ctx) {} 93 94 protected: 95 Status DoTranspose(OpKernelContext* ctx, const Tensor& in, 96 gtl::ArraySlice<int32> perm, Tensor* out) override; 97 bool IsConjugate() const override { return true; } 98 }; 99 #endif // INTEL_MKL 100 101 class ConjugateTransposeGpuOp : public TransposeOp { 102 public: 103 explicit ConjugateTransposeGpuOp(OpKernelConstruction* ctx) 104 : TransposeOp(ctx) {} 105 106 protected: 107 Status DoTranspose(OpKernelContext* ctx, const Tensor& in, 108 gtl::ArraySlice<int32> perm, Tensor* out) override; 109 bool IsConjugate() const override { return true; } 110 }; 111 112 #ifdef TENSORFLOW_USE_SYCL 113 class ConjugateTransposeSyclOp : public TransposeOp { 114 public: 115 explicit ConjugateTransposeSyclOp(OpKernelConstruction* ctx) 116 : TransposeOp(ctx) {} 117 118 protected: 119 Status DoTranspose(OpKernelContext* ctx, const Tensor& in, 120 gtl::ArraySlice<int32> perm, Tensor* out) override; 121 bool IsConjugate() const override { return true; } 122 }; 123 #endif // TENSORFLOW_USE_SYCL 124 125 } // namespace tensorflow 126 127 #endif // TENSORFLOW_KERNELS_TRANSPOSE_OP_H_ 128