Home | History | Annotate | Download | only in Core
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_SOLVE_H
     11 #define EIGEN_SOLVE_H
     12 
     13 namespace Eigen {
     14 
     15 template<typename Decomposition, typename RhsType, typename StorageKind> class SolveImpl;
     16 
/** \class Solve
  * \ingroup Core_Module
  *
  * \brief Pseudo expression representing a solving operation
  *
  * \tparam Decomposition the type of the matrix or decomposition object
  * \tparam RhsType the type of the right-hand side
  *
  * This class represents an expression of A.solve(B)
  * and most of the time this is the only way it is used.
  *
  */
     29 namespace internal {
     30 
     31 // this solve_traits class permits to determine the evaluation type with respect to storage kind (Dense vs Sparse)
     32 template<typename Decomposition, typename RhsType,typename StorageKind> struct solve_traits;
     33 
     34 template<typename Decomposition, typename RhsType>
     35 struct solve_traits<Decomposition,RhsType,Dense>
     36 {
     37   typedef Matrix<typename RhsType::Scalar,
     38                  Decomposition::ColsAtCompileTime,
     39                  RhsType::ColsAtCompileTime,
     40                  RhsType::PlainObject::Options,
     41                  Decomposition::MaxColsAtCompileTime,
     42                  RhsType::MaxColsAtCompileTime> PlainObject;
     43 };
     44 
     45 template<typename Decomposition, typename RhsType>
     46 struct traits<Solve<Decomposition, RhsType> >
     47   : traits<typename solve_traits<Decomposition,RhsType,typename internal::traits<RhsType>::StorageKind>::PlainObject>
     48 {
     49   typedef typename solve_traits<Decomposition,RhsType,typename internal::traits<RhsType>::StorageKind>::PlainObject PlainObject;
     50   typedef typename promote_index_type<typename Decomposition::StorageIndex, typename RhsType::StorageIndex>::type StorageIndex;
     51   typedef traits<PlainObject> BaseTraits;
     52   enum {
     53     Flags = BaseTraits::Flags & RowMajorBit,
     54     CoeffReadCost = HugeCost
     55   };
     56 };
     57 
     58 }
     59 
     60 
     61 template<typename Decomposition, typename RhsType>
     62 class Solve : public SolveImpl<Decomposition,RhsType,typename internal::traits<RhsType>::StorageKind>
     63 {
     64 public:
     65   typedef typename internal::traits<Solve>::PlainObject PlainObject;
     66   typedef typename internal::traits<Solve>::StorageIndex StorageIndex;
     67 
     68   Solve(const Decomposition &dec, const RhsType &rhs)
     69     : m_dec(dec), m_rhs(rhs)
     70   {}
     71 
     72   EIGEN_DEVICE_FUNC Index rows() const { return m_dec.cols(); }
     73   EIGEN_DEVICE_FUNC Index cols() const { return m_rhs.cols(); }
     74 
     75   EIGEN_DEVICE_FUNC const Decomposition& dec() const { return m_dec; }
     76   EIGEN_DEVICE_FUNC const RhsType&       rhs() const { return m_rhs; }
     77 
     78 protected:
     79   const Decomposition &m_dec;
     80   const RhsType       &m_rhs;
     81 };
     82 
     83 
     84 // Specialization of the Solve expression for dense results
     85 template<typename Decomposition, typename RhsType>
     86 class SolveImpl<Decomposition,RhsType,Dense>
     87   : public MatrixBase<Solve<Decomposition,RhsType> >
     88 {
     89   typedef Solve<Decomposition,RhsType> Derived;
     90 
     91 public:
     92 
     93   typedef MatrixBase<Solve<Decomposition,RhsType> > Base;
     94   EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
     95 
     96 private:
     97 
     98   Scalar coeff(Index row, Index col) const;
     99   Scalar coeff(Index i) const;
    100 };
    101 
    102 // Generic API dispatcher
    103 template<typename Decomposition, typename RhsType, typename StorageKind>
    104 class SolveImpl : public internal::generic_xpr_base<Solve<Decomposition,RhsType>, MatrixXpr, StorageKind>::type
    105 {
    106   public:
    107     typedef typename internal::generic_xpr_base<Solve<Decomposition,RhsType>, MatrixXpr, StorageKind>::type Base;
    108 };
    109 
    110 namespace internal {
    111 
// Evaluator of Solve -> eval into a temporary.
// The whole solve is performed into m_result when this evaluator is
// constructed; coefficient reads then go through the inherited
// evaluator<PlainObject>, which is bound to m_result.
template<typename Decomposition, typename RhsType>
struct evaluator<Solve<Decomposition,RhsType> >
  : public evaluator<typename Solve<Decomposition,RhsType>::PlainObject>
{
  typedef Solve<Decomposition,RhsType> SolveType;
  typedef typename SolveType::PlainObject PlainObject;
  typedef evaluator<PlainObject> Base;

  // Flags of the plain-object evaluator plus EvalBeforeNestingBit (presumably
  // so enclosing expressions evaluate the solve once rather than nesting it
  // lazily -- confirm against Eigen's flag documentation).
  enum { Flags = Base::Flags | EvalBeforeNestingBit };

  /** Sizes m_result to solve.rows() x solve.cols(), binds the Base
    * sub-object to m_result, then solves directly into m_result.
    */
  EIGEN_DEVICE_FUNC explicit evaluator(const SolveType& solve)
    : m_result(solve.rows(), solve.cols())
  {
    // Base classes are initialized before data members, so Base cannot be
    // given m_result in the mem-initializer list. It is default-constructed
    // first and re-constructed in place here, once m_result exists.
    ::new (static_cast<Base*>(this)) Base(m_result);
    // NOTE(review): Base was bound to m_result just above. This presumably
    // relies on _solve_impl not reallocating m_result (it already has its
    // final size) -- keep this ordering in mind when changing it.
    solve.dec()._solve_impl(solve.rhs(), m_result);
  }

protected:
  // Owned storage for the solution, read through the Base evaluator.
  PlainObject m_result;
};
    133 
    134 // Specialization for "dst = dec.solve(rhs)"
    135 // NOTE we need to specialize it for Dense2Dense to avoid ambiguous specialization error and a Sparse2Sparse specialization must exist somewhere
    136 template<typename DstXprType, typename DecType, typename RhsType, typename Scalar>
    137 struct Assignment<DstXprType, Solve<DecType,RhsType>, internal::assign_op<Scalar,Scalar>, Dense2Dense>
    138 {
    139   typedef Solve<DecType,RhsType> SrcXprType;
    140   static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar,Scalar> &)
    141   {
    142     Index dstRows = src.rows();
    143     Index dstCols = src.cols();
    144     if((dst.rows()!=dstRows) || (dst.cols()!=dstCols))
    145       dst.resize(dstRows, dstCols);
    146 
    147     src.dec()._solve_impl(src.rhs(), dst);
    148   }
    149 };
    150 
    151 // Specialization for "dst = dec.transpose().solve(rhs)"
    152 template<typename DstXprType, typename DecType, typename RhsType, typename Scalar>
    153 struct Assignment<DstXprType, Solve<Transpose<const DecType>,RhsType>, internal::assign_op<Scalar,Scalar>, Dense2Dense>
    154 {
    155   typedef Solve<Transpose<const DecType>,RhsType> SrcXprType;
    156   static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar,Scalar> &)
    157   {
    158     Index dstRows = src.rows();
    159     Index dstCols = src.cols();
    160     if((dst.rows()!=dstRows) || (dst.cols()!=dstCols))
    161       dst.resize(dstRows, dstCols);
    162 
    163     src.dec().nestedExpression().template _solve_impl_transposed<false>(src.rhs(), dst);
    164   }
    165 };
    166 
    167 // Specialization for "dst = dec.adjoint().solve(rhs)"
    168 template<typename DstXprType, typename DecType, typename RhsType, typename Scalar>
    169 struct Assignment<DstXprType, Solve<CwiseUnaryOp<internal::scalar_conjugate_op<typename DecType::Scalar>, const Transpose<const DecType> >,RhsType>,
    170                   internal::assign_op<Scalar,Scalar>, Dense2Dense>
    171 {
    172   typedef Solve<CwiseUnaryOp<internal::scalar_conjugate_op<typename DecType::Scalar>, const Transpose<const DecType> >,RhsType> SrcXprType;
    173   static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar,Scalar> &)
    174   {
    175     Index dstRows = src.rows();
    176     Index dstCols = src.cols();
    177     if((dst.rows()!=dstRows) || (dst.cols()!=dstCols))
    178       dst.resize(dstRows, dstCols);
    179 
    180     src.dec().nestedExpression().nestedExpression().template _solve_impl_transposed<true>(src.rhs(), dst);
    181   }
    182 };
    183 
} // end namespace internal
    185 
    186 } // end namespace Eigen
    187 
    188 #endif // EIGEN_SOLVE_H
    189