1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2011 Jitse Niesen <jitse (at) maths.leeds.ac.uk> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_MATRIX_SQUARE_ROOT 11 #define EIGEN_MATRIX_SQUARE_ROOT 12 13 namespace Eigen { 14 15 /** \ingroup MatrixFunctions_Module 16 * \brief Class for computing matrix square roots of upper quasi-triangular matrices. 17 * \tparam MatrixType type of the argument of the matrix square root, 18 * expected to be an instantiation of the Matrix class template. 19 * 20 * This class computes the square root of the upper quasi-triangular 21 * matrix stored in the upper Hessenberg part of the matrix passed to 22 * the constructor. 23 * 24 * \sa MatrixSquareRoot, MatrixSquareRootTriangular 25 */ 26 template <typename MatrixType> 27 class MatrixSquareRootQuasiTriangular 28 { 29 public: 30 31 /** \brief Constructor. 32 * 33 * \param[in] A upper quasi-triangular matrix whose square root 34 * is to be computed. 35 * 36 * The class stores a reference to \p A, so it should not be 37 * changed (or destroyed) before compute() is called. 38 */ 39 MatrixSquareRootQuasiTriangular(const MatrixType& A) 40 : m_A(A) 41 { 42 eigen_assert(A.rows() == A.cols()); 43 } 44 45 /** \brief Compute the matrix square root 46 * 47 * \param[out] result square root of \p A, as specified in the constructor. 48 * 49 * Only the upper Hessenberg part of \p result is updated, the 50 * rest is not touched. See MatrixBase::sqrt() for details on 51 * how this computation is implemented. 52 */ 53 template <typename ResultType> void compute(ResultType &result); 54 55 private: 56 typedef typename MatrixType::Index Index; 57 typedef typename MatrixType::Scalar Scalar; 58 59 void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); 60 void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); 61 void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i); 62 void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 63 typename MatrixType::Index i, typename MatrixType::Index j); 64 void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 65 typename MatrixType::Index i, typename MatrixType::Index j); 66 void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 67 typename MatrixType::Index i, typename MatrixType::Index j); 68 void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 69 typename MatrixType::Index i, typename MatrixType::Index j); 70 71 template <typename SmallMatrixType> 72 static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, 73 const SmallMatrixType& B, const SmallMatrixType& C); 74 75 const MatrixType& m_A; 76 }; 77 78 template <typename MatrixType> 79 template <typename ResultType> 80 void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result) 81 { 82 result.resize(m_A.rows(), m_A.cols()); 83 computeDiagonalPartOfSqrt(result, m_A); 84 computeOffDiagonalPartOfSqrt(result, m_A); 85 } 86 87 // pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size 88 // post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T 89 template <typename MatrixType> 90 void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT, 91 const MatrixType& T) 92 { 93 using std::sqrt; 94 const Index size = m_A.rows(); 95 for (Index i = 0; i < size; i++) { 96 if (i == size - 1 || T.coeff(i+1, i) == 0) { 97 eigen_assert(T(i,i) >= 0); 98 sqrtT.coeffRef(i,i) = sqrt(T.coeff(i,i)); 99 } 100 else { 101 compute2x2diagonalBlock(sqrtT, T, i); 102 ++i; 103 } 104 } 105 } 106 107 // pre: T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T. 108 // post: sqrtT is the square root of T. 109 template <typename MatrixType> 110 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, 111 const MatrixType& T) 112 { 113 const Index size = m_A.rows(); 114 for (Index j = 1; j < size; j++) { 115 if (T.coeff(j, j-1) != 0) // if T(j-1:j, j-1:j) is a 2-by-2 block 116 continue; 117 for (Index i = j-1; i >= 0; i--) { 118 if (i > 0 && T.coeff(i, i-1) != 0) // if T(i-1:i, i-1:i) is a 2-by-2 block 119 continue; 120 bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0); 121 bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0); 122 if (iBlockIs2x2 && jBlockIs2x2) 123 compute2x2offDiagonalBlock(sqrtT, T, i, j); 124 else if (iBlockIs2x2 && !jBlockIs2x2) 125 compute2x1offDiagonalBlock(sqrtT, T, i, j); 126 else if (!iBlockIs2x2 && jBlockIs2x2) 127 compute1x2offDiagonalBlock(sqrtT, T, i, j); 128 else if (!iBlockIs2x2 && !jBlockIs2x2) 129 compute1x1offDiagonalBlock(sqrtT, T, i, j); 130 } 131 } 132 } 133 134 // pre: T.block(i,i,2,2) has complex conjugate eigenvalues 135 // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2) 136 template <typename MatrixType> 137 void MatrixSquareRootQuasiTriangular<MatrixType> 138 ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i) 139 { 140 // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere 141 // in EigenSolver. If we expose it, we could call it directly from here. 142 Matrix<Scalar,2,2> block = T.template block<2,2>(i,i); 143 EigenSolver<Matrix<Scalar,2,2> > es(block); 144 sqrtT.template block<2,2>(i,i) 145 = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real(); 146 } 147 148 // pre: block structure of T is such that (i,j) is a 1x1 block, 149 // all blocks of sqrtT to left of and below (i,j) are correct 150 // post: sqrtT(i,j) has the correct value 151 template <typename MatrixType> 152 void MatrixSquareRootQuasiTriangular<MatrixType> 153 ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 154 typename MatrixType::Index i, typename MatrixType::Index j) 155 { 156 Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value(); 157 sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j)); 158 } 159 160 // similar to compute1x1offDiagonalBlock() 161 template <typename MatrixType> 162 void MatrixSquareRootQuasiTriangular<MatrixType> 163 ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 164 typename MatrixType::Index i, typename MatrixType::Index j) 165 { 166 Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j); 167 if (j-i > 1) 168 rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2); 169 Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity(); 170 A += sqrtT.template block<2,2>(j,j).transpose(); 171 sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose()); 172 } 173 174 // similar to compute1x1offDiagonalBlock() 175 template <typename MatrixType> 176 void MatrixSquareRootQuasiTriangular<MatrixType> 177 ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 178 typename MatrixType::Index i, typename MatrixType::Index j) 179 { 180 Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j); 181 if (j-i > 2) 182 rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1); 183 Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity(); 184 A += sqrtT.template block<2,2>(i,i); 185 sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs); 186 } 187 188 // similar to compute1x1offDiagonalBlock() 189 template <typename MatrixType> 190 void MatrixSquareRootQuasiTriangular<MatrixType> 191 ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, 192 typename MatrixType::Index i, typename MatrixType::Index j) 193 { 194 Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i); 195 Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j); 196 Matrix<Scalar,2,2> C = T.template block<2,2>(i,j); 197 if (j-i > 2) 198 C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2); 199 Matrix<Scalar,2,2> X; 200 solveAuxiliaryEquation(X, A, B, C); 201 sqrtT.template block<2,2>(i,j) = X; 202 } 203 204 // solves the equation A X + X B = C where all matrices are 2-by-2 205 template <typename MatrixType> 206 template <typename SmallMatrixType> 207 void MatrixSquareRootQuasiTriangular<MatrixType> 208 ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, 209 const SmallMatrixType& B, const SmallMatrixType& C) 210 { 211 EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value), 212 EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT); 213 214 Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero(); 215 coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0); 216 coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1); 217 coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0); 218 coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1); 219 coeffMatrix.coeffRef(0,1) = B.coeff(1,0); 220 coeffMatrix.coeffRef(0,2) = A.coeff(0,1); 221 coeffMatrix.coeffRef(1,0) = B.coeff(0,1); 222 coeffMatrix.coeffRef(1,3) = A.coeff(0,1); 223 coeffMatrix.coeffRef(2,0) = A.coeff(1,0); 224 coeffMatrix.coeffRef(2,3) = B.coeff(1,0); 225 coeffMatrix.coeffRef(3,1) = A.coeff(1,0); 226 coeffMatrix.coeffRef(3,2) = B.coeff(0,1); 227 228 Matrix<Scalar,4,1> rhs; 229 rhs.coeffRef(0) = C.coeff(0,0); 230 rhs.coeffRef(1) = C.coeff(0,1); 231 rhs.coeffRef(2) = C.coeff(1,0); 232 rhs.coeffRef(3) = C.coeff(1,1); 233 234 Matrix<Scalar,4,1> result; 235 result = coeffMatrix.fullPivLu().solve(rhs); 236 237 X.coeffRef(0,0) = result.coeff(0); 238 X.coeffRef(0,1) = result.coeff(1); 239 X.coeffRef(1,0) = result.coeff(2); 240 X.coeffRef(1,1) = result.coeff(3); 241 } 242 243 244 /** \ingroup MatrixFunctions_Module 245 * \brief Class for computing matrix square roots of upper triangular matrices. 246 * \tparam MatrixType type of the argument of the matrix square root, 247 * expected to be an instantiation of the Matrix class template. 248 * 249 * This class computes the square root of the upper triangular matrix 250 * stored in the upper triangular part (including the diagonal) of 251 * the matrix passed to the constructor. 252 * 253 * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular 254 */ 255 template <typename MatrixType> 256 class MatrixSquareRootTriangular 257 { 258 public: 259 MatrixSquareRootTriangular(const MatrixType& A) 260 : m_A(A) 261 { 262 eigen_assert(A.rows() == A.cols()); 263 } 264 265 /** \brief Compute the matrix square root 266 * 267 * \param[out] result square root of \p A, as specified in the constructor. 268 * 269 * Only the upper triangular part (including the diagonal) of 270 * \p result is updated, the rest is not touched. See 271 * MatrixBase::sqrt() for details on how this computation is 272 * implemented. 273 */ 274 template <typename ResultType> void compute(ResultType &result); 275 276 private: 277 const MatrixType& m_A; 278 }; 279 280 template <typename MatrixType> 281 template <typename ResultType> 282 void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result) 283 { 284 using std::sqrt; 285 286 // Compute square root of m_A and store it in upper triangular part of result 287 // This uses that the square root of triangular matrices can be computed directly. 288 result.resize(m_A.rows(), m_A.cols()); 289 typedef typename MatrixType::Index Index; 290 for (Index i = 0; i < m_A.rows(); i++) { 291 result.coeffRef(i,i) = sqrt(m_A.coeff(i,i)); 292 } 293 for (Index j = 1; j < m_A.cols(); j++) { 294 for (Index i = j-1; i >= 0; i--) { 295 typedef typename MatrixType::Scalar Scalar; 296 // if i = j-1, then segment has length 0 so tmp = 0 297 Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value(); 298 // denominator may be zero if original matrix is singular 299 result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j)); 300 } 301 } 302 } 303 304 305 /** \ingroup MatrixFunctions_Module 306 * \brief Class for computing matrix square roots of general matrices. 307 * \tparam MatrixType type of the argument of the matrix square root, 308 * expected to be an instantiation of the Matrix class template. 309 * 310 * \sa MatrixSquareRootTriangular, MatrixSquareRootQuasiTriangular, MatrixBase::sqrt() 311 */ 312 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex> 313 class MatrixSquareRoot 314 { 315 public: 316 317 /** \brief Constructor. 318 * 319 * \param[in] A matrix whose square root is to be computed. 320 * 321 * The class stores a reference to \p A, so it should not be 322 * changed (or destroyed) before compute() is called. 323 */ 324 MatrixSquareRoot(const MatrixType& A); 325 326 /** \brief Compute the matrix square root 327 * 328 * \param[out] result square root of \p A, as specified in the constructor. 329 * 330 * See MatrixBase::sqrt() for details on how this computation is 331 * implemented. 332 */ 333 template <typename ResultType> void compute(ResultType &result); 334 }; 335 336 337 // ********** Partial specialization for real matrices ********** 338 339 template <typename MatrixType> 340 class MatrixSquareRoot<MatrixType, 0> 341 { 342 public: 343 344 MatrixSquareRoot(const MatrixType& A) 345 : m_A(A) 346 { 347 eigen_assert(A.rows() == A.cols()); 348 } 349 350 template <typename ResultType> void compute(ResultType &result) 351 { 352 // Compute Schur decomposition of m_A 353 const RealSchur<MatrixType> schurOfA(m_A); 354 const MatrixType& T = schurOfA.matrixT(); 355 const MatrixType& U = schurOfA.matrixU(); 356 357 // Compute square root of T 358 MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols()); 359 MatrixSquareRootQuasiTriangular<MatrixType>(T).compute(sqrtT); 360 361 // Compute square root of m_A 362 result = U * sqrtT * U.adjoint(); 363 } 364 365 private: 366 const MatrixType& m_A; 367 }; 368 369 370 // ********** Partial specialization for complex matrices ********** 371 372 template <typename MatrixType> 373 class MatrixSquareRoot<MatrixType, 1> 374 { 375 public: 376 377 MatrixSquareRoot(const MatrixType& A) 378 : m_A(A) 379 { 380 eigen_assert(A.rows() == A.cols()); 381 } 382 383 template <typename ResultType> void compute(ResultType &result) 384 { 385 // Compute Schur decomposition of m_A 386 const ComplexSchur<MatrixType> schurOfA(m_A); 387 const MatrixType& T = schurOfA.matrixT(); 388 const MatrixType& U = schurOfA.matrixU(); 389 390 // Compute square root of T 391 MatrixType sqrtT; 392 MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT); 393 394 // Compute square root of m_A 395 result = U * (sqrtT.template triangularView<Upper>() * U.adjoint()); 396 } 397 398 private: 399 const MatrixType& m_A; 400 }; 401 402 403 /** \ingroup MatrixFunctions_Module 404 * 405 * \brief Proxy for the matrix square root of some matrix (expression). 406 * 407 * \tparam Derived Type of the argument to the matrix square root. 408 * 409 * This class holds the argument to the matrix square root until it 410 * is assigned or evaluated for some other reason (so the argument 411 * should not be changed in the meantime). It is the return type of 412 * MatrixBase::sqrt() and most of the time this is the only way it is 413 * used. 414 */ 415 template<typename Derived> class MatrixSquareRootReturnValue 416 : public ReturnByValue<MatrixSquareRootReturnValue<Derived> > 417 { 418 typedef typename Derived::Index Index; 419 public: 420 /** \brief Constructor. 421 * 422 * \param[in] src %Matrix (expression) forming the argument of the 423 * matrix square root. 424 */ 425 MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { } 426 427 /** \brief Compute the matrix square root. 428 * 429 * \param[out] result the matrix square root of \p src in the 430 * constructor. 431 */ 432 template <typename ResultType> 433 inline void evalTo(ResultType& result) const 434 { 435 const typename Derived::PlainObject srcEvaluated = m_src.eval(); 436 MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated); 437 me.compute(result); 438 } 439 440 Index rows() const { return m_src.rows(); } 441 Index cols() const { return m_src.cols(); } 442 443 protected: 444 const Derived& m_src; 445 private: 446 MatrixSquareRootReturnValue& operator=(const MatrixSquareRootReturnValue&); 447 }; 448 449 namespace internal { 450 template<typename Derived> 451 struct traits<MatrixSquareRootReturnValue<Derived> > 452 { 453 typedef typename Derived::PlainObject ReturnType; 454 }; 455 } 456 457 template <typename Derived> 458 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const 459 { 460 eigen_assert(rows() == cols()); 461 return MatrixSquareRootReturnValue<Derived>(derived()); 462 } 463 464 } // end namespace Eigen 465 466 #endif // EIGEN_MATRIX_FUNCTION 467