Home | History | Annotate | Download | only in MatrixFunctions
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2011 Jitse Niesen <jitse (at) maths.leeds.ac.uk>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_MATRIX_SQUARE_ROOT
     11 #define EIGEN_MATRIX_SQUARE_ROOT
     12 
     13 namespace Eigen {
     14 
     15 /** \ingroup MatrixFunctions_Module
     16   * \brief Class for computing matrix square roots of upper quasi-triangular matrices.
     17   * \tparam  MatrixType  type of the argument of the matrix square root,
     18   *                      expected to be an instantiation of the Matrix class template.
     19   *
     20   * This class computes the square root of the upper quasi-triangular
     21   * matrix stored in the upper Hessenberg part of the matrix passed to
     22   * the constructor.
     23   *
     24   * \sa MatrixSquareRoot, MatrixSquareRootTriangular
     25   */
     26 template <typename MatrixType>
     27 class MatrixSquareRootQuasiTriangular
     28 {
     29   public:
     30 
     31     /** \brief Constructor.
     32       *
     33       * \param[in]  A  upper quasi-triangular matrix whose square root
     34       *                is to be computed.
     35       *
     36       * The class stores a reference to \p A, so it should not be
     37       * changed (or destroyed) before compute() is called.
     38       */
     39     MatrixSquareRootQuasiTriangular(const MatrixType& A)
     40       : m_A(A)
     41     {
     42       eigen_assert(A.rows() == A.cols());
     43     }
     44 
     45     /** \brief Compute the matrix square root
     46       *
     47       * \param[out] result  square root of \p A, as specified in the constructor.
     48       *
     49       * Only the upper Hessenberg part of \p result is updated, the
     50       * rest is not touched.  See MatrixBase::sqrt() for details on
     51       * how this computation is implemented.
     52       */
     53     template <typename ResultType> void compute(ResultType &result);
     54 
     55   private:
     56     typedef typename MatrixType::Index Index;
     57     typedef typename MatrixType::Scalar Scalar;
     58 
     59     void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
     60     void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
     61     void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i);
     62     void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
     63 				  typename MatrixType::Index i, typename MatrixType::Index j);
     64     void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
     65 				  typename MatrixType::Index i, typename MatrixType::Index j);
     66     void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
     67 				  typename MatrixType::Index i, typename MatrixType::Index j);
     68     void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
     69 				  typename MatrixType::Index i, typename MatrixType::Index j);
     70 
     71     template <typename SmallMatrixType>
     72     static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
     73 				     const SmallMatrixType& B, const SmallMatrixType& C);
     74 
     75     const MatrixType& m_A;
     76 };
     77 
     78 template <typename MatrixType>
     79 template <typename ResultType>
     80 void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)
     81 {
     82   result.resize(m_A.rows(), m_A.cols());
     83   computeDiagonalPartOfSqrt(result, m_A);
     84   computeOffDiagonalPartOfSqrt(result, m_A);
     85 }
     86 
     87 // pre:  T is quasi-upper-triangular and sqrtT is a zero matrix of the same size
     88 // post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T
     89 template <typename MatrixType>
     90 void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT,
     91 									  const MatrixType& T)
     92 {
     93   using std::sqrt;
     94   const Index size = m_A.rows();
     95   for (Index i = 0; i < size; i++) {
     96     if (i == size - 1 || T.coeff(i+1, i) == 0) {
     97       eigen_assert(T(i,i) >= 0);
     98       sqrtT.coeffRef(i,i) = sqrt(T.coeff(i,i));
     99     }
    100     else {
    101       compute2x2diagonalBlock(sqrtT, T, i);
    102       ++i;
    103     }
    104   }
    105 }
    106 
    107 // pre:  T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T.
    108 // post: sqrtT is the square root of T.
    109 template <typename MatrixType>
    110 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
    111 									     const MatrixType& T)
    112 {
    113   const Index size = m_A.rows();
    114   for (Index j = 1; j < size; j++) {
    115       if (T.coeff(j, j-1) != 0)  // if T(j-1:j, j-1:j) is a 2-by-2 block
    116 	continue;
    117     for (Index i = j-1; i >= 0; i--) {
    118       if (i > 0 && T.coeff(i, i-1) != 0)  // if T(i-1:i, i-1:i) is a 2-by-2 block
    119 	continue;
    120       bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
    121       bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
    122       if (iBlockIs2x2 && jBlockIs2x2)
    123 	compute2x2offDiagonalBlock(sqrtT, T, i, j);
    124       else if (iBlockIs2x2 && !jBlockIs2x2)
    125 	compute2x1offDiagonalBlock(sqrtT, T, i, j);
    126       else if (!iBlockIs2x2 && jBlockIs2x2)
    127 	compute1x2offDiagonalBlock(sqrtT, T, i, j);
    128       else if (!iBlockIs2x2 && !jBlockIs2x2)
    129 	compute1x1offDiagonalBlock(sqrtT, T, i, j);
    130     }
    131   }
    132 }
    133 
    134 // pre:  T.block(i,i,2,2) has complex conjugate eigenvalues
    135 // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2)
    136 template <typename MatrixType>
    137 void MatrixSquareRootQuasiTriangular<MatrixType>
    138      ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i)
    139 {
    140   // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere
    141   //       in EigenSolver. If we expose it, we could call it directly from here.
    142   Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
    143   EigenSolver<Matrix<Scalar,2,2> > es(block);
    144   sqrtT.template block<2,2>(i,i)
    145     = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
    146 }
    147 
    148 // pre:  block structure of T is such that (i,j) is a 1x1 block,
    149 //       all blocks of sqrtT to left of and below (i,j) are correct
    150 // post: sqrtT(i,j) has the correct value
    151 template <typename MatrixType>
    152 void MatrixSquareRootQuasiTriangular<MatrixType>
    153      ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
    154 				  typename MatrixType::Index i, typename MatrixType::Index j)
    155 {
    156   Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
    157   sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
    158 }
    159 
    160 // similar to compute1x1offDiagonalBlock()
    161 template <typename MatrixType>
    162 void MatrixSquareRootQuasiTriangular<MatrixType>
    163      ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
    164 				  typename MatrixType::Index i, typename MatrixType::Index j)
    165 {
    166   Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
    167   if (j-i > 1)
    168     rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
    169   Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
    170   A += sqrtT.template block<2,2>(j,j).transpose();
    171   sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
    172 }
    173 
    174 // similar to compute1x1offDiagonalBlock()
    175 template <typename MatrixType>
    176 void MatrixSquareRootQuasiTriangular<MatrixType>
    177      ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
    178 				  typename MatrixType::Index i, typename MatrixType::Index j)
    179 {
    180   Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
    181   if (j-i > 2)
    182     rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
    183   Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
    184   A += sqrtT.template block<2,2>(i,i);
    185   sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
    186 }
    187 
    188 // similar to compute1x1offDiagonalBlock()
    189 template <typename MatrixType>
    190 void MatrixSquareRootQuasiTriangular<MatrixType>
    191      ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
    192 				  typename MatrixType::Index i, typename MatrixType::Index j)
    193 {
    194   Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
    195   Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
    196   Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
    197   if (j-i > 2)
    198     C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
    199   Matrix<Scalar,2,2> X;
    200   solveAuxiliaryEquation(X, A, B, C);
    201   sqrtT.template block<2,2>(i,j) = X;
    202 }
    203 
    204 // solves the equation A X + X B = C where all matrices are 2-by-2
    205 template <typename MatrixType>
    206 template <typename SmallMatrixType>
    207 void MatrixSquareRootQuasiTriangular<MatrixType>
    208      ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
    209 			      const SmallMatrixType& B, const SmallMatrixType& C)
    210 {
    211   EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
    212 		      EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
    213 
    214   Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
    215   coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
    216   coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
    217   coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
    218   coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
    219   coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
    220   coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
    221   coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
    222   coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
    223   coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
    224   coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
    225   coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
    226   coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
    227 
    228   Matrix<Scalar,4,1> rhs;
    229   rhs.coeffRef(0) = C.coeff(0,0);
    230   rhs.coeffRef(1) = C.coeff(0,1);
    231   rhs.coeffRef(2) = C.coeff(1,0);
    232   rhs.coeffRef(3) = C.coeff(1,1);
    233 
    234   Matrix<Scalar,4,1> result;
    235   result = coeffMatrix.fullPivLu().solve(rhs);
    236 
    237   X.coeffRef(0,0) = result.coeff(0);
    238   X.coeffRef(0,1) = result.coeff(1);
    239   X.coeffRef(1,0) = result.coeff(2);
    240   X.coeffRef(1,1) = result.coeff(3);
    241 }
    242 
    243 
    244 /** \ingroup MatrixFunctions_Module
    245   * \brief Class for computing matrix square roots of upper triangular matrices.
    246   * \tparam  MatrixType  type of the argument of the matrix square root,
    247   *                      expected to be an instantiation of the Matrix class template.
    248   *
    249   * This class computes the square root of the upper triangular matrix
    250   * stored in the upper triangular part (including the diagonal) of
    251   * the matrix passed to the constructor.
    252   *
    253   * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular
    254   */
    255 template <typename MatrixType>
    256 class MatrixSquareRootTriangular
    257 {
    258   public:
    259     MatrixSquareRootTriangular(const MatrixType& A)
    260       : m_A(A)
    261     {
    262       eigen_assert(A.rows() == A.cols());
    263     }
    264 
    265     /** \brief Compute the matrix square root
    266       *
    267       * \param[out] result  square root of \p A, as specified in the constructor.
    268       *
    269       * Only the upper triangular part (including the diagonal) of
    270       * \p result is updated, the rest is not touched.  See
    271       * MatrixBase::sqrt() for details on how this computation is
    272       * implemented.
    273       */
    274     template <typename ResultType> void compute(ResultType &result);
    275 
    276  private:
    277     const MatrixType& m_A;
    278 };
    279 
    280 template <typename MatrixType>
    281 template <typename ResultType>
    282 void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
    283 {
    284   using std::sqrt;
    285 
    286   // Compute square root of m_A and store it in upper triangular part of result
    287   // This uses that the square root of triangular matrices can be computed directly.
    288   result.resize(m_A.rows(), m_A.cols());
    289   typedef typename MatrixType::Index Index;
    290   for (Index i = 0; i < m_A.rows(); i++) {
    291     result.coeffRef(i,i) = sqrt(m_A.coeff(i,i));
    292   }
    293   for (Index j = 1; j < m_A.cols(); j++) {
    294     for (Index i = j-1; i >= 0; i--) {
    295       typedef typename MatrixType::Scalar Scalar;
    296       // if i = j-1, then segment has length 0 so tmp = 0
    297       Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
    298       // denominator may be zero if original matrix is singular
    299       result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
    300     }
    301   }
    302 }
    303 
    304 
    305 /** \ingroup MatrixFunctions_Module
    306   * \brief Class for computing matrix square roots of general matrices.
    307   * \tparam  MatrixType  type of the argument of the matrix square root,
    308   *                      expected to be an instantiation of the Matrix class template.
    309   *
    310   * \sa MatrixSquareRootTriangular, MatrixSquareRootQuasiTriangular, MatrixBase::sqrt()
    311   */
    312 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
    313 class MatrixSquareRoot
    314 {
    315   public:
    316 
    317     /** \brief Constructor.
    318       *
    319       * \param[in]  A  matrix whose square root is to be computed.
    320       *
    321       * The class stores a reference to \p A, so it should not be
    322       * changed (or destroyed) before compute() is called.
    323       */
    324     MatrixSquareRoot(const MatrixType& A);
    325 
    326     /** \brief Compute the matrix square root
    327       *
    328       * \param[out] result  square root of \p A, as specified in the constructor.
    329       *
    330       * See MatrixBase::sqrt() for details on how this computation is
    331       * implemented.
    332       */
    333     template <typename ResultType> void compute(ResultType &result);
    334 };
    335 
    336 
    337 // ********** Partial specialization for real matrices **********
    338 
    339 template <typename MatrixType>
    340 class MatrixSquareRoot<MatrixType, 0>
    341 {
    342   public:
    343 
    344     MatrixSquareRoot(const MatrixType& A)
    345       : m_A(A)
    346     {
    347       eigen_assert(A.rows() == A.cols());
    348     }
    349 
    350     template <typename ResultType> void compute(ResultType &result)
    351     {
    352       // Compute Schur decomposition of m_A
    353       const RealSchur<MatrixType> schurOfA(m_A);
    354       const MatrixType& T = schurOfA.matrixT();
    355       const MatrixType& U = schurOfA.matrixU();
    356 
    357       // Compute square root of T
    358       MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols());
    359       MatrixSquareRootQuasiTriangular<MatrixType>(T).compute(sqrtT);
    360 
    361       // Compute square root of m_A
    362       result = U * sqrtT * U.adjoint();
    363     }
    364 
    365   private:
    366     const MatrixType& m_A;
    367 };
    368 
    369 
    370 // ********** Partial specialization for complex matrices **********
    371 
    372 template <typename MatrixType>
    373 class MatrixSquareRoot<MatrixType, 1>
    374 {
    375   public:
    376 
    377     MatrixSquareRoot(const MatrixType& A)
    378       : m_A(A)
    379     {
    380       eigen_assert(A.rows() == A.cols());
    381     }
    382 
    383     template <typename ResultType> void compute(ResultType &result)
    384     {
    385       // Compute Schur decomposition of m_A
    386       const ComplexSchur<MatrixType> schurOfA(m_A);
    387       const MatrixType& T = schurOfA.matrixT();
    388       const MatrixType& U = schurOfA.matrixU();
    389 
    390       // Compute square root of T
    391       MatrixType sqrtT;
    392       MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT);
    393 
    394       // Compute square root of m_A
    395       result = U * (sqrtT.template triangularView<Upper>() * U.adjoint());
    396     }
    397 
    398   private:
    399     const MatrixType& m_A;
    400 };
    401 
    402 
    403 /** \ingroup MatrixFunctions_Module
    404   *
    405   * \brief Proxy for the matrix square root of some matrix (expression).
    406   *
    407   * \tparam Derived  Type of the argument to the matrix square root.
    408   *
    409   * This class holds the argument to the matrix square root until it
    410   * is assigned or evaluated for some other reason (so the argument
    411   * should not be changed in the meantime). It is the return type of
    412   * MatrixBase::sqrt() and most of the time this is the only way it is
    413   * used.
    414   */
    415 template<typename Derived> class MatrixSquareRootReturnValue
    416 : public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
    417 {
    418     typedef typename Derived::Index Index;
    419   public:
    420     /** \brief Constructor.
    421       *
    422       * \param[in]  src  %Matrix (expression) forming the argument of the
    423       * matrix square root.
    424       */
    425     MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { }
    426 
    427     /** \brief Compute the matrix square root.
    428       *
    429       * \param[out]  result  the matrix square root of \p src in the
    430       * constructor.
    431       */
    432     template <typename ResultType>
    433     inline void evalTo(ResultType& result) const
    434     {
    435       const typename Derived::PlainObject srcEvaluated = m_src.eval();
    436       MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated);
    437       me.compute(result);
    438     }
    439 
    440     Index rows() const { return m_src.rows(); }
    441     Index cols() const { return m_src.cols(); }
    442 
    443   protected:
    444     const Derived& m_src;
    445   private:
    446     MatrixSquareRootReturnValue& operator=(const MatrixSquareRootReturnValue&);
    447 };
    448 
    449 namespace internal {
    450 template<typename Derived>
    451 struct traits<MatrixSquareRootReturnValue<Derived> >
    452 {
    453   typedef typename Derived::PlainObject ReturnType;
    454 };
    455 }
    456 
    457 template <typename Derived>
    458 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const
    459 {
    460   eigen_assert(rows() == cols());
    461   return MatrixSquareRootReturnValue<Derived>(derived());
    462 }
    463 
    464 } // end namespace Eigen
    465 
#endif // EIGEN_MATRIX_SQUARE_ROOT
    467