1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/array_ops.cc. 17 18 #define EIGEN_USE_THREADS 19 20 #if GOOGLE_CUDA 21 #define EIGEN_USE_GPU 22 #endif // GOOGLE_CUDA 23 24 #include "tensorflow/core/kernels/matrix_diag_op.h" 25 26 #include <memory> 27 #include <vector> 28 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 29 #include "tensorflow/core/framework/op_kernel.h" 30 #include "tensorflow/core/framework/register_types.h" 31 #include "tensorflow/core/framework/tensor.h" 32 #include "tensorflow/core/framework/tensor_shape.h" 33 #include "tensorflow/core/framework/tensor_types.h" 34 #include "tensorflow/core/framework/types.h" 35 #include "tensorflow/core/platform/logging.h" 36 #include "tensorflow/core/platform/macros.h" 37 38 namespace tensorflow { 39 40 typedef Eigen::ThreadPoolDevice CPUDevice; 41 typedef Eigen::GpuDevice GPUDevice; 42 43 template <typename Device, typename T> 44 class MatrixDiagPartOp : public OpKernel { 45 public: 46 explicit MatrixDiagPartOp(OpKernelConstruction* context) 47 : OpKernel(context) {} 48 49 void Compute(OpKernelContext* context) override { 50 const Tensor& input = context->input(0); 51 52 const TensorShape& input_shape = input.shape(); 53 const int rank = input_shape.dims(); 54 55 // Preliminary validation of sizes. 56 OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), 57 errors::InvalidArgument( 58 "input must be at least 2-dim, received shape: ", 59 input.shape().DebugString())); 60 61 TensorShape output_shape; 62 for (int i = 0; i < rank - 2; ++i) { 63 output_shape.AddDim(input_shape.dim_size(i)); 64 } 65 const int64 min_dim = std::min(input_shape.dim_size(rank - 2), 66 input_shape.dim_size(rank - 1)); 67 output_shape.AddDim(min_dim); 68 69 Tensor* output = nullptr; 70 OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); 71 72 auto output_reshaped = output->flat_inner_dims<T, 2>(); 73 auto input_reshaped = input.flat_inner_dims<T, 3>(); 74 75 functor::MatrixDiagPart<Device, T>::Compute( 76 context->eigen_device<Device>(), input_reshaped, output_reshaped); 77 } 78 79 private: 80 TF_DISALLOW_COPY_AND_ASSIGN(MatrixDiagPartOp); 81 }; 82 83 template <typename Device, typename T> 84 class MatrixDiagOp : public OpKernel { 85 public: 86 explicit MatrixDiagOp(OpKernelConstruction* context) : OpKernel(context) {} 87 88 void Compute(OpKernelContext* context) override { 89 const Tensor& input = context->input(0); 90 91 const TensorShape& input_shape = input.shape(); 92 const int rank = input_shape.dims(); 93 94 // Preliminary validation of sizes. 95 OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(input_shape), 96 errors::InvalidArgument( 97 "input must be at least 1-dim, received shape: ", 98 input.shape().DebugString())); 99 100 const int64 k = input_shape.dim_size(rank - 1); 101 auto input_reshaped = input.flat_inner_dims<T, 2>(); 102 103 TensorShape output_shape = input_shape; 104 output_shape.AddDim(k); 105 106 Tensor* output = nullptr; 107 OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); 108 109 auto output_reshaped = output->flat_inner_dims<T, 3>(); 110 111 functor::MatrixDiag<Device, T>::Compute(context->eigen_device<Device>(), 112 input_reshaped, output_reshaped); 113 } 114 115 private: 116 TF_DISALLOW_COPY_AND_ASSIGN(MatrixDiagOp); 117 }; 118 119 #define REGISTER_MATRIX_DIAG(type) \ 120 REGISTER_KERNEL_BUILDER( \ 121 Name("MatrixDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 122 MatrixDiagOp<CPUDevice, type>); \ 123 REGISTER_KERNEL_BUILDER( \ 124 Name("MatrixDiagPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 125 MatrixDiagPartOp<CPUDevice, type>); 126 TF_CALL_POD_TYPES(REGISTER_MATRIX_DIAG); 127 #undef REGISTER_MATRIX_DIAG 128 129 // Registration of the deprecated kernel. 130 // Delete after 10mar2017. 131 #define REGISTER_BATCH_MATRIX_DIAG(type) \ 132 REGISTER_KERNEL_BUILDER( \ 133 Name("BatchMatrixDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 134 MatrixDiagOp<CPUDevice, type>); \ 135 REGISTER_KERNEL_BUILDER(Name("BatchMatrixDiagPart") \ 136 .Device(DEVICE_CPU) \ 137 .TypeConstraint<type>("T"), \ 138 MatrixDiagPartOp<CPUDevice, type>); 139 TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_DIAG); 140 #undef REGISTER_BATCH_MATRIX_DIAG 141 142 // Implementation of the functor specialization for CPU. 143 namespace functor { 144 template <typename T> 145 struct MatrixDiag<CPUDevice, T> { 146 static void Compute(const CPUDevice& d, 147 typename TTypes<T, 2>::ConstTensor input, 148 typename TTypes<T, 3>::Tensor output) { 149 output.device(d) = output.constant(T()); 150 for (int64 r = 0; r < output.dimension(0); ++r) { 151 for (int64 d = 0; d < output.dimension(1); ++d) { 152 output(r, d, d) = input(r, d); 153 } 154 } 155 } 156 }; 157 158 template <typename T> 159 struct MatrixDiagPart<CPUDevice, T> { 160 static void Compute(const CPUDevice& d, 161 typename TTypes<T, 3>::ConstTensor input, 162 typename TTypes<T, 2>::Tensor output) { 163 for (int64 r = 0; r < output.dimension(0); ++r) { 164 for (int64 d = 0; d < output.dimension(1); ++d) { 165 output(r, d) = input(r, d, d); 166 } 167 } 168 } 169 }; 170 171 } // namespace functor 172 173 #if GOOGLE_CUDA 174 175 // Forward declarations of the functor specializations for GPU. 176 namespace functor { 177 #define DECLARE_GPU_SPEC(T) \ 178 template <> \ 179 void MatrixDiag<GPUDevice, T>::Compute( \ 180 const GPUDevice& d, typename TTypes<T, 2>::ConstTensor input, \ 181 typename TTypes<T, 3>::Tensor output); \ 182 extern template struct MatrixDiag<GPUDevice, T>; \ 183 template <> \ 184 void MatrixDiagPart<GPUDevice, T>::Compute( \ 185 const GPUDevice& d, typename TTypes<T, 3>::ConstTensor input, \ 186 typename TTypes<T, 2>::Tensor output); \ 187 extern template struct MatrixDiagPart<GPUDevice, T>; 188 189 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); 190 TF_CALL_bool(DECLARE_GPU_SPEC); 191 TF_CALL_complex64(DECLARE_GPU_SPEC); 192 TF_CALL_complex128(DECLARE_GPU_SPEC); 193 194 } // namespace functor 195 196 // Registration of the GPU implementations. 197 #define REGISTER_MATRIX_DIAG_GPU(type) \ 198 REGISTER_KERNEL_BUILDER( \ 199 Name("MatrixDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 200 MatrixDiagOp<GPUDevice, type>); \ 201 REGISTER_KERNEL_BUILDER( \ 202 Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 203 MatrixDiagPartOp<GPUDevice, type>); 204 TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_DIAG_GPU); 205 TF_CALL_bool(REGISTER_MATRIX_DIAG_GPU); 206 TF_CALL_complex64(REGISTER_MATRIX_DIAG_GPU); 207 TF_CALL_complex128(REGISTER_MATRIX_DIAG_GPU); 208 #undef REGISTER_MATRIX_DIAG_GPU 209 210 // Registration of the deprecated kernel. 211 // Delete after 10mar2017. 212 #define REGISTER_BATCH_MATRIX_DIAG_GPU(type) \ 213 REGISTER_KERNEL_BUILDER( \ 214 Name("BatchMatrixDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 215 MatrixDiagOp<GPUDevice, type>); \ 216 REGISTER_KERNEL_BUILDER(Name("BatchMatrixDiagPart") \ 217 .Device(DEVICE_GPU) \ 218 .TypeConstraint<type>("T"), \ 219 MatrixDiagPartOp<GPUDevice, type>); 220 TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATRIX_DIAG_GPU); 221 #undef REGISTER_BATCH_MATRIX_DIAG_GPU 222 223 #endif // GOOGLE_CUDA 224 225 } // namespace tensorflow 226