1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/linalg_ops.cc. 17 18 #include <cmath> 19 20 #if GOOGLE_CUDA 21 #define EIGEN_USE_GPU 22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 23 #include "tensorflow/core/kernels/determinant_op.h" 24 #endif 25 26 #include "third_party/eigen3/Eigen/LU" 27 #include "tensorflow/core/framework/kernel_def_builder.h" 28 #include "tensorflow/core/framework/numeric_types.h" 29 #include "tensorflow/core/framework/op_kernel.h" 30 #include "tensorflow/core/framework/tensor_shape.h" 31 #include "tensorflow/core/kernels/linalg_ops_common.h" 32 #include "tensorflow/core/lib/core/errors.h" 33 #include "tensorflow/core/platform/logging.h" 34 #include "tensorflow/core/platform/types.h" 35 36 #if GOOGLE_CUDA 37 #include "tensorflow/core/kernels/cuda_solvers.h" 38 #include "tensorflow/core/kernels/fill_functor.h" 39 #endif 40 41 namespace tensorflow { 42 43 // A helper function to compute the sign and absolute value of the log of the 44 // determinant of inputs via a partially pivoted LU 45 // factorization. 46 // 47 // Returns the log of the absolute value of the determinant, and its sign in 48 // 'sign'. 
49 template <class Scalar> 50 static typename Eigen::NumTraits<Scalar>::Real SLogDet( 51 const Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>& inputs, 52 Scalar* sign) { 53 using RealScalar = typename Eigen::NumTraits<Scalar>::Real; 54 RealScalar log_abs_det = 0; 55 *sign = 1; 56 // An empty matrix' determinant is defined to be 1. 57 // (https://en.wikipedia.org/wiki/Determinant) 58 if (inputs.size() > 0) { 59 // Compute the log determinant through a Partially Pivoted LU decomposition 60 using Eigen::Dynamic; 61 Eigen::PartialPivLU<Eigen::Matrix<Scalar, Dynamic, Dynamic>> lu(inputs); 62 Eigen::Matrix<Scalar, Dynamic, Dynamic> LU = lu.matrixLU(); 63 *sign = lu.permutationP().determinant(); 64 auto diag = LU.diagonal().array().eval(); 65 auto abs_diag = diag.cwiseAbs().eval(); 66 log_abs_det += abs_diag.log().sum(); 67 *sign *= (diag / abs_diag).prod(); 68 } 69 if (!Eigen::numext::isfinite(log_abs_det)) { 70 *sign = 0; 71 log_abs_det = 72 log_abs_det > 0 ? -std::log(RealScalar(0)) : std::log(RealScalar(0)); 73 } 74 return log_abs_det; 75 } 76 77 template <class Scalar> 78 class LogDeterminantOp : public LinearAlgebraOp<Scalar> { 79 public: 80 INHERIT_LINALG_TYPEDEFS(Scalar); 81 82 explicit LogDeterminantOp(OpKernelConstruction* context) : Base(context) {} 83 84 TensorShapes GetOutputMatrixShapes( 85 const TensorShapes& input_matrix_shapes) const final { 86 return TensorShapes({TensorShape({}), TensorShape({})}); 87 } 88 89 void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, 90 MatrixMaps* outputs) final { 91 Scalar sign; 92 const RealScalar log_abs_det = SLogDet( 93 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>(inputs[0]), 94 &sign); 95 96 outputs->at(0)(0, 0) = sign; 97 outputs->at(1)(0, 0) = log_abs_det; 98 } 99 }; 100 101 template <class Scalar> 102 class DeterminantOp : public LinearAlgebraOp<Scalar> { 103 public: 104 INHERIT_LINALG_TYPEDEFS(Scalar); 105 106 explicit DeterminantOp(OpKernelConstruction* context) : 
Base(context) {} 107 108 TensorShapes GetOutputMatrixShapes( 109 const TensorShapes& input_matrix_shape) const final { 110 return TensorShapes({TensorShape({})}); 111 } 112 113 void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, 114 MatrixMaps* outputs) final { 115 Scalar sign; 116 const RealScalar log_abs_det = SLogDet( 117 Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>(inputs[0]), 118 &sign); 119 outputs->at(0)(0, 0) = sign * std::exp(log_abs_det); 120 } 121 }; 122 123 #if GOOGLE_CUDA 124 125 typedef Eigen::GpuDevice GPUDevice; 126 127 template <class Scalar> 128 class DeterminantOpGpu : public AsyncOpKernel { 129 public: 130 explicit DeterminantOpGpu(OpKernelConstruction* context) 131 : AsyncOpKernel(context) {} 132 133 void ComputeAsync(OpKernelContext* context, DoneCallback done) final { 134 const Tensor& input = context->input(0); 135 const int ndims = input.dims(); 136 const int64 n = input.dim_size(ndims - 1); 137 // Validate inputs. 138 OP_REQUIRES_ASYNC( 139 context, ndims >= 2, 140 errors::InvalidArgument("Input must have rank >= 2, got ", ndims), 141 done); 142 OP_REQUIRES_ASYNC( 143 context, input.dim_size(ndims - 2) == n, 144 errors::InvalidArgument("Input matrices must be square, got", 145 input.dim_size(ndims - 2), " != ", n), 146 done); 147 148 // Allocate output. 149 TensorShape out_shape; 150 for (int dim = 0; dim < ndims - 2; ++dim) { 151 out_shape.AddDim(input.dim_size(dim)); 152 } 153 out_shape.AppendShape(TensorShape({})); 154 Tensor* out; 155 OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, out_shape, &out), 156 done); 157 158 // By definition, the determinant of an empty matrix is equal to one. 159 const GPUDevice& d = context->eigen_device<GPUDevice>(); 160 if (input.NumElements() == 0) { 161 functor::SetOneFunctor<GPUDevice, Scalar> f; 162 f(d, out->template flat<Scalar>()); 163 done(); 164 return; 165 } 166 167 // TODO(rmlarsen): Convert to absl::make_unique when available. 
168 std::unique_ptr<CudaSolver> solver(new CudaSolver(context)); 169 170 // Reuse the input buffer or make a copy for the factorization step, 171 // depending on whether this ops owns it exclusively. 172 Tensor input_copy; 173 OP_REQUIRES_OK_ASYNC( 174 context, 175 solver->forward_input_or_allocate_scoped_tensor( 176 {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy), 177 done); 178 if (!input.SharesBufferWith(input_copy)) { 179 d.memcpy(input_copy.flat<Scalar>().data(), input.flat<Scalar>().data(), 180 input.NumElements() * sizeof(Scalar)); 181 } 182 auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>(); 183 const int64 batch_size = input_copy_reshaped.dimension(0); 184 185 // Allocate pivots on the device. 186 Tensor pivots; 187 OP_REQUIRES_OK_ASYNC( 188 context, 189 solver->allocate_scoped_tensor(DataTypeToEnum<int>::value, 190 TensorShape{batch_size, n}, &pivots), 191 done); 192 auto pivots_mat = pivots.template matrix<int>(); 193 194 // Prepare pointer arrays for cuBlas' batch interface. 195 // TODO(rmlarsen): Find a way to encode pointer arrays in pinned host memory 196 // without the ugly casting. 197 auto input_copy_ptrs = solver->GetScratchSpace<uint8>( 198 sizeof(Scalar*) * batch_size, "input_copy_ptrs", 199 /* on_host */ true); 200 auto output_reshaped = out->template flat_inner_dims<Scalar, 1>(); 201 202 // Compute the partially pivoted LU factorization(s) of the matrix/matrices. 203 std::vector<DeviceLapackInfo> dev_info; 204 if (n / batch_size <= 128) { 205 // For small matrices or large batch sizes, we use the batched interface 206 // from cuBlas. 
207 const Scalar** input_copy_ptrs_base = 208 reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data()); 209 for (int batch = 0; batch < batch_size; ++batch) { 210 input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0); 211 } 212 dev_info.push_back( 213 solver->GetDeviceLapackInfo(batch_size, "getrfBatched")); 214 OP_REQUIRES_OK_ASYNC( 215 context, 216 solver->GetrfBatched(n, input_copy_ptrs_base, n, pivots_mat.data(), 217 &dev_info.back(), batch_size), 218 done); 219 } else { 220 // For small batch sizes we use the non-batched interface from cuSolver, 221 // which is much faster for large matrices. 222 dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf")); 223 for (int batch = 0; batch < batch_size; ++batch) { 224 OP_REQUIRES_OK_ASYNC( 225 context, 226 solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n, 227 &pivots_mat(batch, 0), &dev_info.back()(batch)), 228 done); 229 } 230 } 231 232 // Compute the determinant for each batch as (-1)^s * prod(diag(U)), 233 // where s is the order of the permutation encoded in pivots and U is the 234 // upper triangular factor of the LU factorization, which is written to 235 // input_copy by the Getrf{Batched} kernel. 236 functor::DeterminantFromPivotedLUFunctor<GPUDevice, Scalar> functor; 237 functor(d, 238 const_cast<const Tensor*>(&input_copy) 239 ->template flat_inner_dims<Scalar, 3>(), 240 pivots_mat.data(), output_reshaped, dev_info.back().mutable_data()); 241 242 // Register callback to check info after kernels finish. 243 auto info_checker = [context, done]( 244 const Status& status, 245 const std::vector<HostLapackInfo>& host_infos) { 246 if (!status.ok() && errors::IsInvalidArgument(status) && 247 !host_infos.empty()) { 248 for (int i = 0; i < host_infos[0].size(); ++i) { 249 // It is OK for a matrix to be singular (signaled by info > 0), 250 // corresponding to determinant of zero, but we do want to catch 251 // invalid arguments to Getrf{Batched}. 
252 OP_REQUIRES_ASYNC( 253 context, host_infos[0](i) >= 0, 254 errors::InvalidArgument("Invalid input argument no. ", 255 host_infos[0].data()[i], 256 " for batch index ", i, "."), 257 done); 258 } 259 } 260 done(); 261 }; 262 CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info, 263 std::move(info_checker)); 264 } 265 }; 266 267 template <class Scalar> 268 class LogDeterminantOpGpu : public AsyncOpKernel { 269 public: 270 explicit LogDeterminantOpGpu(OpKernelConstruction* context) 271 : AsyncOpKernel(context) {} 272 273 void ComputeAsync(OpKernelContext* context, DoneCallback done) final { 274 const Tensor& input = context->input(0); 275 const int ndims = input.dims(); 276 const int64 n = input.dim_size(ndims - 1); 277 // Validate inputs. 278 OP_REQUIRES_ASYNC( 279 context, ndims >= 2, 280 errors::InvalidArgument("Input must have rank >= 2, got ", ndims), 281 done); 282 OP_REQUIRES_ASYNC( 283 context, input.dim_size(ndims - 2) == n, 284 errors::InvalidArgument("Input matrices must be square, got", 285 input.dim_size(ndims - 2), " != ", n), 286 done); 287 288 // Allocate output. 289 TensorShape out_shape; 290 for (int dim = 0; dim < ndims - 2; ++dim) { 291 out_shape.AddDim(input.dim_size(dim)); 292 } 293 out_shape.AppendShape(TensorShape({})); 294 Tensor* sign; 295 OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, out_shape, &sign), 296 done); 297 Tensor* log_abs_det; 298 OP_REQUIRES_OK_ASYNC( 299 context, context->allocate_output(1, out_shape, &log_abs_det), done); 300 301 // By definition, the determinant of an empty matrix is equal to one. 
302 const GPUDevice& d = context->eigen_device<GPUDevice>(); 303 if (input.NumElements() == 0) { 304 functor::SetOneFunctor<GPUDevice, Scalar> one_func; 305 one_func(d, sign->template flat<Scalar>()); 306 functor::SetZeroFunctor<GPUDevice, Scalar> zero_func; 307 zero_func(d, log_abs_det->template flat<Scalar>()); 308 done(); 309 return; 310 } 311 312 // TODO(rmlarsen): Convert to absl::make_unique when available. 313 std::unique_ptr<CudaSolver> solver(new CudaSolver(context)); 314 315 // Reuse the input buffer or make a copy for the factorization step, 316 // depending on whether this ops owns it exclusively. 317 Tensor input_copy; 318 OP_REQUIRES_OK_ASYNC( 319 context, 320 solver->forward_input_or_allocate_scoped_tensor( 321 {0}, DataTypeToEnum<Scalar>::value, input.shape(), &input_copy), 322 done); 323 if (!input.SharesBufferWith(input_copy)) { 324 d.memcpy(input_copy.flat<Scalar>().data(), input.flat<Scalar>().data(), 325 input.NumElements() * sizeof(Scalar)); 326 } 327 auto input_copy_reshaped = input_copy.template flat_inner_dims<Scalar, 3>(); 328 const int64 batch_size = input_copy_reshaped.dimension(0); 329 330 // Allocate pivots on the device. 331 Tensor pivots; 332 OP_REQUIRES_OK_ASYNC( 333 context, 334 solver->allocate_scoped_tensor(DataTypeToEnum<int>::value, 335 TensorShape{batch_size, n}, &pivots), 336 done); 337 auto pivots_mat = pivots.template matrix<int>(); 338 339 // Prepare pointer arrays for cuBlas' batch interface. 340 // TODO(rmlarsen): Find a way to encode pointer arrays in pinned host memory 341 // without the ugly casting. 342 auto input_copy_ptrs = solver->GetScratchSpace<uint8>( 343 sizeof(Scalar*) * batch_size, "input_copy_ptrs", 344 /* on_host */ true); 345 346 // Compute the partially pivoted LU factorization(s) of the matrix/matrices. 347 std::vector<DeviceLapackInfo> dev_info; 348 if (n / batch_size <= 128) { 349 // For small matrices or large batch sizes, we use the batched interface 350 // from cuBlas. 
351 const Scalar** input_copy_ptrs_base = 352 reinterpret_cast<const Scalar**>(input_copy_ptrs.mutable_data()); 353 for (int batch = 0; batch < batch_size; ++batch) { 354 input_copy_ptrs_base[batch] = &input_copy_reshaped(batch, 0, 0); 355 } 356 dev_info.push_back( 357 solver->GetDeviceLapackInfo(batch_size, "getrfBatched")); 358 OP_REQUIRES_OK_ASYNC( 359 context, 360 solver->GetrfBatched(n, input_copy_ptrs_base, n, pivots_mat.data(), 361 &dev_info.back(), batch_size), 362 done); 363 } else { 364 // For large matrices or small batch sizes we use the non-batched 365 // interface from cuSolver, which is much faster for large matrices. 366 dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "getrf")); 367 for (int batch = 0; batch < batch_size; ++batch) { 368 OP_REQUIRES_OK_ASYNC( 369 context, 370 solver->Getrf(n, n, &input_copy_reshaped(batch, 0, 0), n, 371 &pivots_mat(batch, 0), &dev_info.back()(batch)), 372 done); 373 } 374 } 375 376 auto input_copy_reshaped_const = 377 const_cast<const Tensor*>(&input_copy) 378 ->template flat_inner_dims<Scalar, 3>(); 379 auto sign_reshaped = sign->flat<Scalar>(); 380 auto log_abs_det_reshaped = log_abs_det->flat<Scalar>(); 381 // Compute the determinant for each batch as (-1)^s * prod(diag(U)), 382 // where s is the order of the permutation encoded in pivots and U is the 383 // upper triangular factor of the LU factorization, which is written to 384 // input_copy by the Getrf{Batched} kernel. 385 functor::LogDeterminantFromPivotedLUFunctor<GPUDevice, Scalar> functor; 386 functor(d, input_copy_reshaped_const, pivots_mat.data(), sign_reshaped, 387 log_abs_det_reshaped); 388 389 // Register callback to check info after kernels finish. 
390 auto info_checker = [context, done]( 391 const Status& status, 392 const std::vector<HostLapackInfo>& host_infos) { 393 if (!status.ok() && errors::IsInvalidArgument(status) && 394 !host_infos.empty()) { 395 for (int i = 0; i < host_infos[0].size(); ++i) { 396 // It is OK for a matrix to be singular (signaled by info > 0), 397 // corresponding to determinant of zero, but we do want to catch 398 // invalid arguments to Getrf{Batched}. 399 OP_REQUIRES_ASYNC( 400 context, host_infos[0](i) >= 0, 401 errors::InvalidArgument("Invalid input argument no. ", 402 host_infos[0].data()[i], 403 " for batch index ", i, "."), 404 done); 405 } 406 } 407 done(); 408 }; 409 CudaSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info, 410 std::move(info_checker)); 411 } 412 }; 413 414 REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<float>), float); 415 REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<double>), double); 416 REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<complex64>), 417 complex64); 418 REGISTER_LINALG_OP_GPU("MatrixDeterminant", (DeterminantOpGpu<complex128>), 419 complex128); 420 421 REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", (LogDeterminantOpGpu<float>), 422 float); 423 REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", (LogDeterminantOpGpu<double>), 424 double); 425 REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", (LogDeterminantOpGpu<complex64>), 426 complex64); 427 REGISTER_LINALG_OP_GPU("LogMatrixDeterminant", 428 (LogDeterminantOpGpu<complex128>), complex128); 429 #endif // GOOGLE_CUDA 430 431 REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<float>), float); 432 REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<double>), double); 433 REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<complex64>), complex64); 434 REGISTER_LINALG_OP("MatrixDeterminant", (DeterminantOp<complex128>), 435 complex128); 436 REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<float>), float); 437 
// Remaining "BatchMatrixDeterminant" registrations, all backed by
// DeterminantOp.
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<double>), double);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<complex64>),
                   complex64);
REGISTER_LINALG_OP("BatchMatrixDeterminant", (DeterminantOp<complex128>),
                   complex128);

// "LogMatrixDeterminant" registrations backed by LogDeterminantOp, one per
// supported scalar type. These use the non-GPU registration macro
// (presumably the CPU kernels; the macro is defined outside this file).
REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<float>), float);
REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<double>), double);
REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<complex64>),
                   complex64);
REGISTER_LINALG_OP("LogMatrixDeterminant", (LogDeterminantOp<complex128>),
                   complex128);
}  // namespace tensorflow