1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/array_ops.cc. 17 18 #define EIGEN_USE_THREADS 19 20 #if GOOGLE_CUDA 21 #define EIGEN_USE_GPU 22 #endif // GOOGLE_CUDA 23 24 #include "tensorflow/core/kernels/matrix_set_diag_op.h" 25 26 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/framework/register_types.h" 29 #include "tensorflow/core/framework/tensor.h" 30 #include "tensorflow/core/framework/tensor_shape.h" 31 #include "tensorflow/core/framework/tensor_types.h" 32 #include "tensorflow/core/framework/types.h" 33 #include "tensorflow/core/lib/core/threadpool.h" 34 #include "tensorflow/core/platform/logging.h" 35 #include "tensorflow/core/platform/macros.h" 36 37 namespace tensorflow { 38 39 typedef Eigen::ThreadPoolDevice CPUDevice; 40 typedef Eigen::GpuDevice GPUDevice; 41 42 template <typename Device, typename T> 43 class MatrixSetDiagOp : public OpKernel { 44 public: 45 explicit MatrixSetDiagOp(OpKernelConstruction* context) : OpKernel(context) {} 46 47 void Compute(OpKernelContext* context) override { 48 const Tensor& input = context->input(0); 49 const Tensor& diag = context->input(1); 50 51 const TensorShape& input_shape = input.shape(); 52 const TensorShape& diag_shape = diag.shape(); 53 const int rank = input_shape.dims(); 54 55 // Preliminary validation of sizes. 56 OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), 57 errors::InvalidArgument( 58 "input must be at least 2-dim, received shape: ", 59 input.shape().DebugString())); 60 61 // Check to make sure the last dimension of diag is equal to the smaller of 62 // the last two dimensions of input. 63 const int64 min_dim = std::min(input_shape.dim_size(rank - 1), 64 input_shape.dim_size(rank - 2)); 65 TensorShape expected_diag_shape = input_shape; 66 expected_diag_shape.RemoveLastDims(2); 67 expected_diag_shape.AddDim(min_dim); 68 OP_REQUIRES(context, expected_diag_shape == diag_shape, 69 errors::InvalidArgument( 70 "must have diagonal.shape == input.shape[:-2] + " 71 "min(input.shape[-2:]), but received input shape: ", 72 input_shape.DebugString(), 73 " and diagonal shape: ", diag_shape.DebugString())); 74 75 if (input.NumElements() == 0) { 76 // This is a no-op. 77 context->set_output(0, input); 78 return; 79 } 80 81 auto input_reshaped = input.flat_inner_dims<T, 3>(); 82 auto diag_reshaped = diag.flat_inner_dims<T, 2>(); 83 Tensor* output = nullptr; 84 OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( 85 {0}, 0, input_shape, &output)); 86 auto output_reshaped = output->flat_inner_dims<T, 3>(); 87 functor::MatrixSetDiag<Device, T>::Compute( 88 context, context->eigen_device<Device>(), input_reshaped, diag_reshaped, 89 output_reshaped); 90 } 91 92 private: 93 TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp); 94 }; 95 96 #define REGISTER_MATRIX_SET_DIAG(type) \ 97 REGISTER_KERNEL_BUILDER( \ 98 Name("MatrixSetDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 99 MatrixSetDiagOp<CPUDevice, type>); 100 TF_CALL_POD_TYPES(REGISTER_MATRIX_SET_DIAG); 101 #undef REGISTER_MATRIX_SET_DIAG 102 103 // Registration of the deprecated kernel. 104 // Delete after 10mar2017. 105 #define REGISTER_BATCH_MATRIX_SET_DIAG(type) \ 106 REGISTER_KERNEL_BUILDER( \ 107 Name("BatchMatrixSetDiag").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 108 MatrixSetDiagOp<CPUDevice, type>); 109 TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_SET_DIAG); 110 #undef REGISTER_BATCH_MATRIX_SET_DIAG 111 112 namespace functor { 113 114 // Implementation of the functor specialization for CPU. 115 template <typename T> 116 struct MatrixSetDiag<CPUDevice, T> { 117 static void Compute(OpKernelContext* context, const CPUDevice& device, 118 typename TTypes<T, 3>::ConstTensor input, 119 typename TTypes<T, 2>::ConstTensor diag, 120 typename TTypes<T, 3>::Tensor output) { 121 if (input.data() != output.data()) { 122 output.device(device) = input; 123 } 124 auto compute_shard = [&output, &diag](int64 begin, int64 end) { 125 for (int64 batch = begin; batch < end; ++batch) { 126 for (int64 col = 0; col < diag.dimension(1); ++col) { 127 output(batch, col, col) = diag(batch, col); 128 } 129 } 130 }; 131 auto thread_pool = 132 context->device()->tensorflow_cpu_worker_threads()->workers; 133 int64 cost_per_batch = 10 * output.dimension(1); // Heuristic. 134 thread_pool->ParallelFor(output.dimension(0), cost_per_batch, 135 std::move(compute_shard)); 136 } 137 }; 138 139 } // namespace functor 140 141 #if GOOGLE_CUDA 142 143 // Forward declarations of the functor specializations for GPU. 144 namespace functor { 145 #define DECLARE_GPU_SPEC(T) \ 146 template <> \ 147 void MatrixSetDiag<GPUDevice, T>::Compute( \ 148 OpKernelContext* context, const GPUDevice& d, \ 149 typename TTypes<T, 3>::ConstTensor input, \ 150 typename TTypes<T, 2>::ConstTensor diag, \ 151 typename TTypes<T, 3>::Tensor output); \ 152 extern template struct MatrixSetDiag<GPUDevice, T>; 153 154 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); 155 TF_CALL_bool(DECLARE_GPU_SPEC); 156 TF_CALL_complex64(DECLARE_GPU_SPEC); 157 TF_CALL_complex128(DECLARE_GPU_SPEC); 158 159 } // namespace functor 160 161 // Registration of the GPU implementations. 162 #define REGISTER_MATRIX_SET_DIAG_GPU(type) \ 163 REGISTER_KERNEL_BUILDER( \ 164 Name("MatrixSetDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 165 MatrixSetDiagOp<GPUDevice, type>); 166 TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG_GPU); 167 TF_CALL_bool(REGISTER_MATRIX_SET_DIAG_GPU); 168 TF_CALL_complex64(REGISTER_MATRIX_SET_DIAG_GPU); 169 TF_CALL_complex128(REGISTER_MATRIX_SET_DIAG_GPU); 170 #undef REGISTER_MATRIX_SET_DIAG_GPU 171 172 // Registration of the deprecated kernel. 173 // Delete after 10mar2017. 174 #define REGISTER_BATCH_MATRIX_SET_DIAG_GPU(type) \ 175 REGISTER_KERNEL_BUILDER( \ 176 Name("BatchMatrixSetDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 177 MatrixSetDiagOp<GPUDevice, type>); 178 TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATRIX_SET_DIAG_GPU); 179 #undef REGISTER_BATCH_MATRIX_SET_DIAG_GPU 180 181 #endif // GOOGLE_CUDA 182 183 } // namespace tensorflow 184