Home | History | Annotate | Download | only in MatrixFunctions
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2011 Jitse Niesen <jitse (at) maths.leeds.ac.uk>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_MATRIX_SQUARE_ROOT
     11 #define EIGEN_MATRIX_SQUARE_ROOT
     12 
     13 namespace Eigen {
     14 
     15 /** \ingroup MatrixFunctions_Module
     16   * \brief Class for computing matrix square roots of upper quasi-triangular matrices.
     17   * \tparam  MatrixType  type of the argument of the matrix square root,
     18   *                      expected to be an instantiation of the Matrix class template.
     19   *
     20   * This class computes the square root of the upper quasi-triangular
     21   * matrix stored in the upper Hessenberg part of the matrix passed to
     22   * the constructor.
     23   *
     24   * \sa MatrixSquareRoot, MatrixSquareRootTriangular
     25   */
     26 template <typename MatrixType>
     27 class MatrixSquareRootQuasiTriangular
     28 {
     29   public:
     30 
     31     /** \brief Constructor.
     32       *
     33       * \param[in]  A  upper quasi-triangular matrix whose square root
     34       *                is to be computed.
     35       *
     36       * The class stores a reference to \p A, so it should not be
     37       * changed (or destroyed) before compute() is called.
     38       */
     39     MatrixSquareRootQuasiTriangular(const MatrixType& A)
     40       : m_A(A)
     41     {
     42       eigen_assert(A.rows() == A.cols());
     43     }
     44 
     45     /** \brief Compute the matrix square root
     46       *
     47       * \param[out] result  square root of \p A, as specified in the constructor.
     48       *
     49       * Only the upper Hessenberg part of \p result is updated, the
     50       * rest is not touched.  See MatrixBase::sqrt() for details on
     51       * how this computation is implemented.
     52       */
     53     template <typename ResultType> void compute(ResultType &result);
     54 
     55   private:
     56     typedef typename MatrixType::Index Index;
     57     typedef typename MatrixType::Scalar Scalar;
     58 
     59     void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
     60     void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
     61     void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i);
     62     void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
     63   				  typename MatrixType::Index i, typename MatrixType::Index j);
     64     void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
     65   				  typename MatrixType::Index i, typename MatrixType::Index j);
     66     void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
     67   				  typename MatrixType::Index i, typename MatrixType::Index j);
     68     void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
     69   				  typename MatrixType::Index i, typename MatrixType::Index j);
     70 
     71     template <typename SmallMatrixType>
     72     static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
     73   				     const SmallMatrixType& B, const SmallMatrixType& C);
     74 
     75     const MatrixType& m_A;
     76 };
     77 
     78 template <typename MatrixType>
     79 template <typename ResultType>
     80 void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)
     81 {
     82   // Compute Schur decomposition of m_A
     83   const RealSchur<MatrixType> schurOfA(m_A);
     84   const MatrixType& T = schurOfA.matrixT();
     85   const MatrixType& U = schurOfA.matrixU();
     86 
     87   // Compute square root of T
     88   MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
     89   computeDiagonalPartOfSqrt(sqrtT, T);
     90   computeOffDiagonalPartOfSqrt(sqrtT, T);
     91 
     92   // Compute square root of m_A
     93   result = U * sqrtT * U.adjoint();
     94 }
     95 
     96 // pre:  T is quasi-upper-triangular and sqrtT is a zero matrix of the same size
     97 // post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T
     98 template <typename MatrixType>
     99 void MatrixSquareRootQuasiTriangular<MatrixType>::computeDiagonalPartOfSqrt(MatrixType& sqrtT,
    100 									  const MatrixType& T)
    101 {
    102   const Index size = m_A.rows();
    103   for (Index i = 0; i < size; i++) {
    104     if (i == size - 1 || T.coeff(i+1, i) == 0) {
    105       eigen_assert(T(i,i) > 0);
    106       sqrtT.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
    107     }
    108     else {
    109       compute2x2diagonalBlock(sqrtT, T, i);
    110       ++i;
    111     }
    112   }
    113 }
    114 
    115 // pre:  T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T.
    116 // post: sqrtT is the square root of T.
    117 template <typename MatrixType>
    118 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
    119 									     const MatrixType& T)
    120 {
    121   const Index size = m_A.rows();
    122   for (Index j = 1; j < size; j++) {
    123       if (T.coeff(j, j-1) != 0)  // if T(j-1:j, j-1:j) is a 2-by-2 block
    124 	continue;
    125     for (Index i = j-1; i >= 0; i--) {
    126       if (i > 0 && T.coeff(i, i-1) != 0)  // if T(i-1:i, i-1:i) is a 2-by-2 block
    127 	continue;
    128       bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
    129       bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
    130       if (iBlockIs2x2 && jBlockIs2x2)
    131 	compute2x2offDiagonalBlock(sqrtT, T, i, j);
    132       else if (iBlockIs2x2 && !jBlockIs2x2)
    133 	compute2x1offDiagonalBlock(sqrtT, T, i, j);
    134       else if (!iBlockIs2x2 && jBlockIs2x2)
    135 	compute1x2offDiagonalBlock(sqrtT, T, i, j);
    136       else if (!iBlockIs2x2 && !jBlockIs2x2)
    137 	compute1x1offDiagonalBlock(sqrtT, T, i, j);
    138     }
    139   }
    140 }
    141 
    142 // pre:  T.block(i,i,2,2) has complex conjugate eigenvalues
    143 // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2)
    144 template <typename MatrixType>
    145 void MatrixSquareRootQuasiTriangular<MatrixType>
    146      ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i)
    147 {
    148   // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere
    149   //       in EigenSolver. If we expose it, we could call it directly from here.
    150   Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
    151   EigenSolver<Matrix<Scalar,2,2> > es(block);
    152   sqrtT.template block<2,2>(i,i)
    153     = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
    154 }
    155 
    156 // pre:  block structure of T is such that (i,j) is a 1x1 block,
    157 //       all blocks of sqrtT to left of and below (i,j) are correct
    158 // post: sqrtT(i,j) has the correct value
    159 template <typename MatrixType>
    160 void MatrixSquareRootQuasiTriangular<MatrixType>
    161      ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
    162 				  typename MatrixType::Index i, typename MatrixType::Index j)
    163 {
    164   Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
    165   sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
    166 }
    167 
    168 // similar to compute1x1offDiagonalBlock()
    169 template <typename MatrixType>
    170 void MatrixSquareRootQuasiTriangular<MatrixType>
    171      ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
    172 				  typename MatrixType::Index i, typename MatrixType::Index j)
    173 {
    174   Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
    175   if (j-i > 1)
    176     rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
    177   Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
    178   A += sqrtT.template block<2,2>(j,j).transpose();
    179   sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
    180 }
    181 
    182 // similar to compute1x1offDiagonalBlock()
    183 template <typename MatrixType>
    184 void MatrixSquareRootQuasiTriangular<MatrixType>
    185      ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
    186 				  typename MatrixType::Index i, typename MatrixType::Index j)
    187 {
    188   Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
    189   if (j-i > 2)
    190     rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
    191   Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
    192   A += sqrtT.template block<2,2>(i,i);
    193   sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
    194 }
    195 
    196 // similar to compute1x1offDiagonalBlock()
    197 template <typename MatrixType>
    198 void MatrixSquareRootQuasiTriangular<MatrixType>
    199      ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
    200 				  typename MatrixType::Index i, typename MatrixType::Index j)
    201 {
    202   Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
    203   Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
    204   Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
    205   if (j-i > 2)
    206     C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
    207   Matrix<Scalar,2,2> X;
    208   solveAuxiliaryEquation(X, A, B, C);
    209   sqrtT.template block<2,2>(i,j) = X;
    210 }
    211 
    212 // solves the equation A X + X B = C where all matrices are 2-by-2
    213 template <typename MatrixType>
    214 template <typename SmallMatrixType>
    215 void MatrixSquareRootQuasiTriangular<MatrixType>
    216      ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
    217 			      const SmallMatrixType& B, const SmallMatrixType& C)
    218 {
    219   EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
    220 		      EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
    221 
    222   Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
    223   coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
    224   coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
    225   coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
    226   coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
    227   coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
    228   coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
    229   coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
    230   coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
    231   coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
    232   coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
    233   coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
    234   coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
    235 
    236   Matrix<Scalar,4,1> rhs;
    237   rhs.coeffRef(0) = C.coeff(0,0);
    238   rhs.coeffRef(1) = C.coeff(0,1);
    239   rhs.coeffRef(2) = C.coeff(1,0);
    240   rhs.coeffRef(3) = C.coeff(1,1);
    241 
    242   Matrix<Scalar,4,1> result;
    243   result = coeffMatrix.fullPivLu().solve(rhs);
    244 
    245   X.coeffRef(0,0) = result.coeff(0);
    246   X.coeffRef(0,1) = result.coeff(1);
    247   X.coeffRef(1,0) = result.coeff(2);
    248   X.coeffRef(1,1) = result.coeff(3);
    249 }
    250 
    251 
    252 /** \ingroup MatrixFunctions_Module
    253   * \brief Class for computing matrix square roots of upper triangular matrices.
    254   * \tparam  MatrixType  type of the argument of the matrix square root,
    255   *                      expected to be an instantiation of the Matrix class template.
    256   *
    257   * This class computes the square root of the upper triangular matrix
    258   * stored in the upper triangular part (including the diagonal) of
    259   * the matrix passed to the constructor.
    260   *
    261   * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular
    262   */
    263 template <typename MatrixType>
    264 class MatrixSquareRootTriangular
    265 {
    266   public:
    267     MatrixSquareRootTriangular(const MatrixType& A)
    268       : m_A(A)
    269     {
    270       eigen_assert(A.rows() == A.cols());
    271     }
    272 
    273     /** \brief Compute the matrix square root
    274       *
    275       * \param[out] result  square root of \p A, as specified in the constructor.
    276       *
    277       * Only the upper triangular part (including the diagonal) of
    278       * \p result is updated, the rest is not touched.  See
    279       * MatrixBase::sqrt() for details on how this computation is
    280       * implemented.
    281       */
    282     template <typename ResultType> void compute(ResultType &result);
    283 
    284  private:
    285     const MatrixType& m_A;
    286 };
    287 
    288 template <typename MatrixType>
    289 template <typename ResultType>
    290 void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
    291 {
    292   // Compute Schur decomposition of m_A
    293   const ComplexSchur<MatrixType> schurOfA(m_A);
    294   const MatrixType& T = schurOfA.matrixT();
    295   const MatrixType& U = schurOfA.matrixU();
    296 
    297   // Compute square root of T and store it in upper triangular part of result
    298   // This uses that the square root of triangular matrices can be computed directly.
    299   result.resize(m_A.rows(), m_A.cols());
    300   typedef typename MatrixType::Index Index;
    301   for (Index i = 0; i < m_A.rows(); i++) {
    302     result.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
    303   }
    304   for (Index j = 1; j < m_A.cols(); j++) {
    305     for (Index i = j-1; i >= 0; i--) {
    306       typedef typename MatrixType::Scalar Scalar;
    307       // if i = j-1, then segment has length 0 so tmp = 0
    308       Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
    309       // denominator may be zero if original matrix is singular
    310       result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
    311     }
    312   }
    313 
    314   // Compute square root of m_A as U * result * U.adjoint()
    315   MatrixType tmp;
    316   tmp.noalias() = U * result.template triangularView<Upper>();
    317   result.noalias() = tmp * U.adjoint();
    318 }
    319 
    320 
    321 /** \ingroup MatrixFunctions_Module
    322   * \brief Class for computing matrix square roots of general matrices.
    323   * \tparam  MatrixType  type of the argument of the matrix square root,
    324   *                      expected to be an instantiation of the Matrix class template.
    325   *
    326   * \sa MatrixSquareRootTriangular, MatrixSquareRootQuasiTriangular, MatrixBase::sqrt()
    327   */
    328 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
    329 class MatrixSquareRoot
    330 {
    331   public:
    332 
    333     /** \brief Constructor.
    334       *
    335       * \param[in]  A  matrix whose square root is to be computed.
    336       *
    337       * The class stores a reference to \p A, so it should not be
    338       * changed (or destroyed) before compute() is called.
    339       */
    340     MatrixSquareRoot(const MatrixType& A);
    341 
    342     /** \brief Compute the matrix square root
    343       *
    344       * \param[out] result  square root of \p A, as specified in the constructor.
    345       *
    346       * See MatrixBase::sqrt() for details on how this computation is
    347       * implemented.
    348       */
    349     template <typename ResultType> void compute(ResultType &result);
    350 };
    351 
    352 
    353 // ********** Partial specialization for real matrices **********
    354 
    355 template <typename MatrixType>
    356 class MatrixSquareRoot<MatrixType, 0>
    357 {
    358   public:
    359 
    360     MatrixSquareRoot(const MatrixType& A)
    361       : m_A(A)
    362     {
    363       eigen_assert(A.rows() == A.cols());
    364     }
    365 
    366     template <typename ResultType> void compute(ResultType &result)
    367     {
    368       // Compute Schur decomposition of m_A
    369       const RealSchur<MatrixType> schurOfA(m_A);
    370       const MatrixType& T = schurOfA.matrixT();
    371       const MatrixType& U = schurOfA.matrixU();
    372 
    373       // Compute square root of T
    374       MatrixSquareRootQuasiTriangular<MatrixType> tmp(T);
    375       MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
    376       tmp.compute(sqrtT);
    377 
    378       // Compute square root of m_A
    379       result = U * sqrtT * U.adjoint();
    380     }
    381 
    382   private:
    383     const MatrixType& m_A;
    384 };
    385 
    386 
    387 // ********** Partial specialization for complex matrices **********
    388 
    389 template <typename MatrixType>
    390 class MatrixSquareRoot<MatrixType, 1>
    391 {
    392   public:
    393 
    394     MatrixSquareRoot(const MatrixType& A)
    395       : m_A(A)
    396     {
    397       eigen_assert(A.rows() == A.cols());
    398     }
    399 
    400     template <typename ResultType> void compute(ResultType &result)
    401     {
    402       // Compute Schur decomposition of m_A
    403       const ComplexSchur<MatrixType> schurOfA(m_A);
    404       const MatrixType& T = schurOfA.matrixT();
    405       const MatrixType& U = schurOfA.matrixU();
    406 
    407       // Compute square root of T
    408       MatrixSquareRootTriangular<MatrixType> tmp(T);
    409       MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
    410       tmp.compute(sqrtT);
    411 
    412       // Compute square root of m_A
    413       result = U * sqrtT * U.adjoint();
    414     }
    415 
    416   private:
    417     const MatrixType& m_A;
    418 };
    419 
    420 
    421 /** \ingroup MatrixFunctions_Module
    422   *
    423   * \brief Proxy for the matrix square root of some matrix (expression).
    424   *
    425   * \tparam Derived  Type of the argument to the matrix square root.
    426   *
    427   * This class holds the argument to the matrix square root until it
    428   * is assigned or evaluated for some other reason (so the argument
    429   * should not be changed in the meantime). It is the return type of
    430   * MatrixBase::sqrt() and most of the time this is the only way it is
    431   * used.
    432   */
    433 template<typename Derived> class MatrixSquareRootReturnValue
    434 : public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
    435 {
    436     typedef typename Derived::Index Index;
    437   public:
    438     /** \brief Constructor.
    439       *
    440       * \param[in]  src  %Matrix (expression) forming the argument of the
    441       * matrix square root.
    442       */
    443     MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { }
    444 
    445     /** \brief Compute the matrix square root.
    446       *
    447       * \param[out]  result  the matrix square root of \p src in the
    448       * constructor.
    449       */
    450     template <typename ResultType>
    451     inline void evalTo(ResultType& result) const
    452     {
    453       const typename Derived::PlainObject srcEvaluated = m_src.eval();
    454       MatrixSquareRoot<typename Derived::PlainObject> me(srcEvaluated);
    455       me.compute(result);
    456     }
    457 
    458     Index rows() const { return m_src.rows(); }
    459     Index cols() const { return m_src.cols(); }
    460 
    461   protected:
    462     const Derived& m_src;
    463   private:
    464     MatrixSquareRootReturnValue& operator=(const MatrixSquareRootReturnValue&);
    465 };
    466 
    467 namespace internal {
    468 template<typename Derived>
    469 struct traits<MatrixSquareRootReturnValue<Derived> >
    470 {
    471   typedef typename Derived::PlainObject ReturnType;
    472 };
    473 }
    474 
    475 template <typename Derived>
    476 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const
    477 {
    478   eigen_assert(rows() == cols());
    479   return MatrixSquareRootReturnValue<Derived>(derived());
    480 }
    481 
    482 } // end namespace Eigen
    483 
#endif // EIGEN_MATRIX_SQUARE_ROOT
    485