1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2011 Jitse Niesen <jitse (at) maths.leeds.ac.uk> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_MATRIX_SQUARE_ROOT 11 #define EIGEN_MATRIX_SQUARE_ROOT 12 13 namespace Eigen { 14 15 /** \ingroup MatrixFunctions_Module 16 * \brief Class for computing matrix square roots of upper quasi-triangular matrices. 17 * \tparam MatrixType type of the argument of the matrix square root, 18 * expected to be an instantiation of the Matrix class template. 19 * 20 * This class computes the square root of the upper quasi-triangular 21 * matrix stored in the upper Hessenberg part of the matrix passed to 22 * the constructor. 23 * 24 * \sa MatrixSquareRoot, MatrixSquareRootTriangular 25 */ 26 template <typename MatrixType> 27 class MatrixSquareRootQuasiTriangular 28 { 29 public: 30 31 /** \brief Constructor. 32 * 33 * \param[in] A upper quasi-triangular matrix whose square root 34 * is to be computed. 35 * 36 * The class stores a reference to \p A, so it should not be 37 * changed (or destroyed) before compute() is called. 38 */ 39 MatrixSquareRootQuasiTriangular(const MatrixType& A) 40 : m_A(A) 41 { 42 eigen_assert(A.rows() == A.cols()); 43 } 44 45 /** \brief Compute the matrix square root 46 * 47 * \param[out] result square root of \p A, as specified in the constructor. 48 * 49 * Only the upper Hessenberg part of \p result is updated, the 50 * rest is not touched. See MatrixBase::sqrt() for details on 51 * how this computation is implemented. 52 */ 53 template <typename ResultType> void compute(ResultType &result); 54 55 private: 56 typedef typename MatrixType::Index Index; 57 typedef typename MatrixType::Scalar Scalar; 58 59 void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); 60 void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); 61 void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i); 62 void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 63 typename MatrixType::Index i, typename MatrixType::Index j); 64 void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 65 typename MatrixType::Index i, typename MatrixType::Index j); 66 void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 67 typename MatrixType::Index i, typename MatrixType::Index j); 68 void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 69 typename MatrixType::Index i, typename MatrixType::Index j); 70 71 template <typename SmallMatrixType> 72 static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, 73 const SmallMatrixType& B, const SmallMatrixType& C); 74 75 const MatrixType& m_A; 76 }; 77 78 template <typename MatrixType> 79 template <typename ResultType> 80 void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result) 81 { 82 // Compute Schur decomposition of m_A 83 const RealSchur<MatrixType> schurOfA(m_A); 84 const MatrixType& T = schurOfA.matrixT(); 85 const MatrixType& U = schurOfA.matrixU(); 86 87 // Compute square root of T 88 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); 89 computeDiagonalPartOfSqrt(sqrtT, T); 90 computeOffDiagonalPartOfSqrt(sqrtT, T); 91 92 // Compute square root of m_A 93 result = U * sqrtT * U.adjoint(); 94 } 95 96 // pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size 97 // post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T 98 template <typename MatrixType> 99 void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT, 100 const MatrixType& T) 101 { 102 const Index size = m_A.rows(); 103 for (Index i = 0; i < size; i++) { 104 if (i == size - 1 || T.coeff(i+1, i) == 0) { 105 eigen_assert(T(i,i) > 0); 106 sqrtT.coeffRef(i,i) = internal::sqrt(T.coeff(i,i)); 107 } 108 else { 109 compute2x2diagonalBlock(sqrtT, T, i); 110 ++i; 111 } 112 } 113 } 114 115 // pre: T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T. 116 // post: sqrtT is the square root of T. 117 template <typename MatrixType> 118 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, 119 const MatrixType& T) 120 { 121 const Index size = m_A.rows(); 122 for (Index j = 1; j < size; j++) { 123 if (T.coeff(j, j-1) != 0) // if T(j-1:j, j-1:j) is a 2-by-2 block 124 continue; 125 for (Index i = j-1; i >= 0; i--) { 126 if (i > 0 && T.coeff(i, i-1) != 0) // if T(i-1:i, i-1:i) is a 2-by-2 block 127 continue; 128 bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0); 129 bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0); 130 if (iBlockIs2x2 && jBlockIs2x2) 131 compute2x2offDiagonalBlock(sqrtT, T, i, j); 132 else if (iBlockIs2x2 && !jBlockIs2x2) 133 compute2x1offDiagonalBlock(sqrtT, T, i, j); 134 else if (!iBlockIs2x2 && jBlockIs2x2) 135 compute1x2offDiagonalBlock(sqrtT, T, i, j); 136 else if (!iBlockIs2x2 && !jBlockIs2x2) 137 compute1x1offDiagonalBlock(sqrtT, T, i, j); 138 } 139 } 140 } 141 142 // pre: T.block(i,i,2,2) has complex conjugate eigenvalues 143 // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2) 144 template <typename MatrixType> 145 void MatrixSquareRootQuasiTriangular<MatrixType> 146 ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i) 147 { 148 // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere 149 // in EigenSolver. If we expose it, we could call it directly from here. 150 Matrix<Scalar,2,2> block = T.template block<2,2>(i,i); 151 EigenSolver<Matrix<Scalar,2,2> > es(block); 152 sqrtT.template block<2,2>(i,i) 153 = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real(); 154 } 155 156 // pre: block structure of T is such that (i,j) is a 1x1 block, 157 // all blocks of sqrtT to left of and below (i,j) are correct 158 // post: sqrtT(i,j) has the correct value 159 template <typename MatrixType> 160 void MatrixSquareRootQuasiTriangular<MatrixType> 161 ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 162 typename MatrixType::Index i, typename MatrixType::Index j) 163 { 164 Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value(); 165 sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j)); 166 } 167 168 // similar to compute1x1offDiagonalBlock() 169 template <typename MatrixType> 170 void MatrixSquareRootQuasiTriangular<MatrixType> 171 ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 172 typename MatrixType::Index i, typename MatrixType::Index j) 173 { 174 Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j); 175 if (j-i > 1) 176 rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2); 177 Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity(); 178 A += sqrtT.template block<2,2>(j,j).transpose(); 179 sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose()); 180 } 181 182 // similar to compute1x1offDiagonalBlock() 183 template <typename MatrixType> 184 void MatrixSquareRootQuasiTriangular<MatrixType> 185 ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 186 typename MatrixType::Index i, typename MatrixType::Index j) 187 { 188 Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j); 189 if (j-i > 2) 190 rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1); 191 Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity(); 192 A += sqrtT.template block<2,2>(i,i); 193 sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs); 194 } 195 196 // similar to compute1x1offDiagonalBlock() 197 template <typename MatrixType> 198 void MatrixSquareRootQuasiTriangular<MatrixType> 199 ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 200 typename MatrixType::Index i, typename MatrixType::Index j) 201 { 202 Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i); 203 Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j); 204 Matrix<Scalar,2,2> C = T.template block<2,2>(i,j); 205 if (j-i > 2) 206 C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2); 207 Matrix<Scalar,2,2> X; 208 solveAuxiliaryEquation(X, A, B, C); 209 sqrtT.template block<2,2>(i,j) = X; 210 } 211 212 // solves the equation A X + X B = C where all matrices are 2-by-2 213 template <typename MatrixType> 214 template <typename SmallMatrixType> 215 void MatrixSquareRootQuasiTriangular<MatrixType> 216 ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, 217 const SmallMatrixType& B, const SmallMatrixType& C) 218 { 219 EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value), 220 EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT); 221 222 Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero(); 223 coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0); 224 coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1); 225 coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0); 226 coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1); 227 coeffMatrix.coeffRef(0,1) = B.coeff(1,0); 228 coeffMatrix.coeffRef(0,2) = A.coeff(0,1); 229 coeffMatrix.coeffRef(1,0) = B.coeff(0,1); 230 coeffMatrix.coeffRef(1,3) = A.coeff(0,1); 231 coeffMatrix.coeffRef(2,0) = A.coeff(1,0); 232 coeffMatrix.coeffRef(2,3) = B.coeff(1,0); 233 coeffMatrix.coeffRef(3,1) = A.coeff(1,0); 234 coeffMatrix.coeffRef(3,2) = B.coeff(0,1); 235 236 Matrix<Scalar,4,1> rhs; 237 rhs.coeffRef(0) = C.coeff(0,0); 238 rhs.coeffRef(1) = C.coeff(0,1); 239 rhs.coeffRef(2) = C.coeff(1,0); 240 rhs.coeffRef(3) = C.coeff(1,1); 241 242 Matrix<Scalar,4,1> result; 243 result = coeffMatrix.fullPivLu().solve(rhs); 244 245 X.coeffRef(0,0) = result.coeff(0); 246 X.coeffRef(0,1) = result.coeff(1); 247 X.coeffRef(1,0) = result.coeff(2); 248 X.coeffRef(1,1) = result.coeff(3); 249 } 250 251 252 /** \ingroup MatrixFunctions_Module 253 * \brief Class for computing matrix square roots of upper triangular matrices. 254 * \tparam MatrixType type of the argument of the matrix square root, 255 * expected to be an instantiation of the Matrix class template. 256 * 257 * This class computes the square root of the upper triangular matrix 258 * stored in the upper triangular part (including the diagonal) of 259 * the matrix passed to the constructor. 260 * 261 * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular 262 */ 263 template <typename MatrixType> 264 class MatrixSquareRootTriangular 265 { 266 public: 267 MatrixSquareRootTriangular(const MatrixType& A) 268 : m_A(A) 269 { 270 eigen_assert(A.rows() == A.cols()); 271 } 272 273 /** \brief Compute the matrix square root 274 * 275 * \param[out] result square root of \p A, as specified in the constructor. 276 * 277 * Only the upper triangular part (including the diagonal) of 278 * \p result is updated, the rest is not touched. See 279 * MatrixBase::sqrt() for details on how this computation is 280 * implemented. 281 */ 282 template <typename ResultType> void compute(ResultType &result); 283 284 private: 285 const MatrixType& m_A; 286 }; 287 288 template <typename MatrixType> 289 template <typename ResultType> 290 void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result) 291 { 292 // Compute Schur decomposition of m_A 293 const ComplexSchur<MatrixType> schurOfA(m_A); 294 const MatrixType& T = schurOfA.matrixT(); 295 const MatrixType& U = schurOfA.matrixU(); 296 297 // Compute square root of T and store it in upper triangular part of result 298 // This uses that the square root of triangular matrices can be computed directly. 299 result.resize(m_A.rows(), m_A.cols()); 300 typedef typename MatrixType::Index Index; 301 for (Index i = 0; i < m_A.rows(); i++) { 302 result.coeffRef(i,i) = internal::sqrt(T.coeff(i,i)); 303 } 304 for (Index j = 1; j < m_A.cols(); j++) { 305 for (Index i = j-1; i >= 0; i--) { 306 typedef typename MatrixType::Scalar Scalar; 307 // if i = j-1, then segment has length 0 so tmp = 0 308 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value(); 309 // denominator may be zero if original matrix is singular 310 result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j)); 311 } 312 } 313 314 // Compute square root of m_A as U * result * U.adjoint() 315 MatrixType tmp; 316 tmp.noalias() = U * result.template triangularView<Upper>(); 317 result.noalias() = tmp * U.adjoint(); 318 } 319 320 321 /** \ingroup MatrixFunctions_Module 322 * \brief Class for computing matrix square roots of general matrices. 323 * \tparam MatrixType type of the argument of the matrix square root, 324 * expected to be an instantiation of the Matrix class template. 325 * 326 * \sa MatrixSquareRootTriangular, MatrixSquareRootQuasiTriangular, MatrixBase::sqrt() 327 */ 328 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex> 329 class MatrixSquareRoot 330 { 331 public: 332 333 /** \brief Constructor. 334 * 335 * \param[in] A matrix whose square root is to be computed. 336 * 337 * The class stores a reference to \p A, so it should not be 338 * changed (or destroyed) before compute() is called. 339 */ 340 MatrixSquareRoot(const MatrixType& A); 341 342 /** \brief Compute the matrix square root 343 * 344 * \param[out] result square root of \p A, as specified in the constructor. 345 * 346 * See MatrixBase::sqrt() for details on how this computation is 347 * implemented. 348 */ 349 template <typename ResultType> void compute(ResultType &result); 350 }; 351 352 353 // ********** Partial specialization for real matrices ********** 354 355 template <typename MatrixType> 356 class MatrixSquareRoot<MatrixType, 0> 357 { 358 public: 359 360 MatrixSquareRoot(const MatrixType& A) 361 : m_A(A) 362 { 363 eigen_assert(A.rows() == A.cols()); 364 } 365 366 template <typename ResultType> void compute(ResultType &result) 367 { 368 // Compute Schur decomposition of m_A 369 const RealSchur<MatrixType> schurOfA(m_A); 370 const MatrixType& T = schurOfA.matrixT(); 371 const MatrixType& U = schurOfA.matrixU(); 372 373 // Compute square root of T 374 MatrixSquareRootQuasiTriangular<MatrixType> tmp(T); 375 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); 376 tmp.compute(sqrtT); 377 378 // Compute square root of m_A 379 result = U * sqrtT * U.adjoint(); 380 } 381 382 private: 383 const MatrixType& m_A; 384 }; 385 386 387 // ********** Partial specialization for complex matrices ********** 388 389 template <typename MatrixType> 390 class MatrixSquareRoot<MatrixType, 1> 391 { 392 public: 393 394 MatrixSquareRoot(const MatrixType& A) 395 : m_A(A) 396 { 397 eigen_assert(A.rows() == A.cols()); 398 } 399 400 template <typename ResultType> void compute(ResultType &result) 401 { 402 // Compute Schur decomposition of m_A 403 const ComplexSchur<MatrixType> schurOfA(m_A); 404 const MatrixType& T = schurOfA.matrixT(); 405 const MatrixType& U = schurOfA.matrixU(); 406 407 // Compute square root of T 408 MatrixSquareRootTriangular<MatrixType> tmp(T); 409 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); 410 tmp.compute(sqrtT); 411 412 // Compute square root of m_A 413 result = U * sqrtT * U.adjoint(); 414 } 415 416 private: 417 const MatrixType& m_A; 418 }; 419 420 421 /** \ingroup MatrixFunctions_Module 422 * 423 * \brief Proxy for the matrix square root of some matrix (expression). 424 * 425 * \tparam Derived Type of the argument to the matrix square root. 426 * 427 * This class holds the argument to the matrix square root until it 428 * is assigned or evaluated for some other reason (so the argument 429 * should not be changed in the meantime). It is the return type of 430 * MatrixBase::sqrt() and most of the time this is the only way it is 431 * used. 432 */ 433 template<typename Derived> class MatrixSquareRootReturnValue 434 : public ReturnByValue<MatrixSquareRootReturnValue<Derived> > 435 { 436 typedef typename Derived::Index Index; 437 public: 438 /** \brief Constructor. 439 * 440 * \param[in] src %Matrix (expression) forming the argument of the 441 * matrix square root. 442 */ 443 MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { } 444 445 /** \brief Compute the matrix square root. 446 * 447 * \param[out] result the matrix square root of \p src in the 448 * constructor. 449 */ 450 template <typename ResultType> 451 inline void evalTo(ResultType& result) const 452 { 453 const typename Derived::PlainObject srcEvaluated = m_src.eval(); 454 MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated); 455 me.compute(result); 456 } 457 458 Index rows() const { return m_src.rows(); } 459 Index cols() const { return m_src.cols(); } 460 461 protected: 462 const Derived& m_src; 463 private: 464 MatrixSquareRootReturnValue& operator=(const MatrixSquareRootReturnValue&); 465 }; 466 467 namespace internal { 468 template<typename Derived> 469 struct traits<MatrixSquareRootReturnValue<Derived> > 470 { 471 typedef typename Derived::PlainObject ReturnType; 472 }; 473 } 474 475 template <typename Derived> 476 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const 477 { 478 eigen_assert(rows() == cols()); 479 return MatrixSquareRootReturnValue<Derived>(derived()); 480 } 481 482 } // end namespace Eigen 483 484 #endif // EIGEN_MATRIX_FUNCTION 485