Home | History | Annotate | Download | only in Core
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_PRODUCT_H
     11 #define EIGEN_PRODUCT_H
     12 
     13 namespace Eigen {
     14 
     /** Forward declaration of the storage-kind specific implementation layer of
       * Product. It is declared here because the Product class uses it as its base
       * class; the generic and Dense versions are defined further down in this file.
       */
     template<typename Lhs,
              typename Rhs,
              int Option,
              typename StorageKind>
     class ProductImpl;
     16 
     17 namespace internal {
     18 
     19 template<typename Lhs, typename Rhs, int Option>
     20 struct traits<Product<Lhs, Rhs, Option> >
     21 {
     22   typedef typename remove_all<Lhs>::type LhsCleaned;
     23   typedef typename remove_all<Rhs>::type RhsCleaned;
     24   typedef traits<LhsCleaned> LhsTraits;
     25   typedef traits<RhsCleaned> RhsTraits;
     26 
     27   typedef MatrixXpr XprKind;
     28 
     29   typedef typename ScalarBinaryOpTraits<typename traits<LhsCleaned>::Scalar, typename traits<RhsCleaned>::Scalar>::ReturnType Scalar;
     30   typedef typename product_promote_storage_type<typename LhsTraits::StorageKind,
     31                                                 typename RhsTraits::StorageKind,
     32                                                 internal::product_type<Lhs,Rhs>::ret>::ret StorageKind;
     33   typedef typename promote_index_type<typename LhsTraits::StorageIndex,
     34                                       typename RhsTraits::StorageIndex>::type StorageIndex;
     35 
     36   enum {
     37     RowsAtCompileTime    = LhsTraits::RowsAtCompileTime,
     38     ColsAtCompileTime    = RhsTraits::ColsAtCompileTime,
     39     MaxRowsAtCompileTime = LhsTraits::MaxRowsAtCompileTime,
     40     MaxColsAtCompileTime = RhsTraits::MaxColsAtCompileTime,
     41 
     42     // FIXME: only needed by GeneralMatrixMatrixTriangular
     43     InnerSize = EIGEN_SIZE_MIN_PREFER_FIXED(LhsTraits::ColsAtCompileTime, RhsTraits::RowsAtCompileTime),
     44 
     45     // The storage order is somewhat arbitrary here. The correct one will be determined through the evaluator.
     46     Flags = (MaxRowsAtCompileTime==1 && MaxColsAtCompileTime!=1) ? RowMajorBit
     47           : (MaxColsAtCompileTime==1 && MaxRowsAtCompileTime!=1) ? 0
     48           : (   ((LhsTraits::Flags&NoPreferredStorageOrderBit) && (RhsTraits::Flags&RowMajorBit))
     49              || ((RhsTraits::Flags&NoPreferredStorageOrderBit) && (LhsTraits::Flags&RowMajorBit)) ) ? RowMajorBit
     50           : NoPreferredStorageOrderBit
     51   };
     52 };
     53 
     54 } // end namespace internal
     55 
     56 /** \class Product
     57   * \ingroup Core_Module
     58   *
     59   * \brief Expression of the product of two arbitrary matrices or vectors
     60   *
     61   * \tparam _Lhs the type of the left-hand side expression
     62   * \tparam _Rhs the type of the right-hand side expression
     63   *
     64   * This class represents an expression of the product of two arbitrary matrices.
     65   *
     66   * The other template parameters are:
     67   * \tparam Option     can be DefaultProduct, AliasFreeProduct, or LazyProduct
     68   *
     69   */
     70 template<typename _Lhs, typename _Rhs, int Option>
     71 class Product : public ProductImpl<_Lhs,_Rhs,Option,
     72                                    typename internal::product_promote_storage_type<typename internal::traits<_Lhs>::StorageKind,
     73                                                                                    typename internal::traits<_Rhs>::StorageKind,
     74                                                                                    internal::product_type<_Lhs,_Rhs>::ret>::ret>
     75 {
     76   public:
     77 
     78     typedef _Lhs Lhs;
     79     typedef _Rhs Rhs;
     80 
     81     typedef typename ProductImpl<
     82         Lhs, Rhs, Option,
     83         typename internal::product_promote_storage_type<typename internal::traits<Lhs>::StorageKind,
     84                                                         typename internal::traits<Rhs>::StorageKind,
     85                                                         internal::product_type<Lhs,Rhs>::ret>::ret>::Base Base;
     86     EIGEN_GENERIC_PUBLIC_INTERFACE(Product)
     87 
     88     typedef typename internal::ref_selector<Lhs>::type LhsNested;
     89     typedef typename internal::ref_selector<Rhs>::type RhsNested;
     90     typedef typename internal::remove_all<LhsNested>::type LhsNestedCleaned;
     91     typedef typename internal::remove_all<RhsNested>::type RhsNestedCleaned;
     92 
     93     EIGEN_DEVICE_FUNC Product(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs)
     94     {
     95       eigen_assert(lhs.cols() == rhs.rows()
     96         && "invalid matrix product"
     97         && "if you wanted a coeff-wise or a dot product use the respective explicit functions");
     98     }
     99 
    100     EIGEN_DEVICE_FUNC inline Index rows() const { return m_lhs.rows(); }
    101     EIGEN_DEVICE_FUNC inline Index cols() const { return m_rhs.cols(); }
    102 
    103     EIGEN_DEVICE_FUNC const LhsNestedCleaned& lhs() const { return m_lhs; }
    104     EIGEN_DEVICE_FUNC const RhsNestedCleaned& rhs() const { return m_rhs; }
    105 
    106   protected:
    107 
    108     LhsNested m_lhs;
    109     RhsNested m_rhs;
    110 };
    111 
    112 namespace internal {
    113 
    114 template<typename Lhs, typename Rhs, int Option, int ProductTag = internal::product_type<Lhs,Rhs>::ret>
    115 class dense_product_base
    116  : public internal::dense_xpr_base<Product<Lhs,Rhs,Option> >::type
    117 {};
    118 
    119 /** Convertion to scalar for inner-products */
    120 template<typename Lhs, typename Rhs, int Option>
    121 class dense_product_base<Lhs, Rhs, Option, InnerProduct>
    122  : public internal::dense_xpr_base<Product<Lhs,Rhs,Option> >::type
    123 {
    124   typedef Product<Lhs,Rhs,Option> ProductXpr;
    125   typedef typename internal::dense_xpr_base<ProductXpr>::type Base;
    126 public:
    127   using Base::derived;
    128   typedef typename Base::Scalar Scalar;
    129 
    130   operator const Scalar() const
    131   {
    132     return internal::evaluator<ProductXpr>(derived()).coeff(0,0);
    133   }
    134 };
    135 
    136 } // namespace internal
    137 
    138 // Generic API dispatcher
    139 template<typename Lhs, typename Rhs, int Option, typename StorageKind>
    140 class ProductImpl : public internal::generic_xpr_base<Product<Lhs,Rhs,Option>, MatrixXpr, StorageKind>::type
    141 {
    142   public:
    143     typedef typename internal::generic_xpr_base<Product<Lhs,Rhs,Option>, MatrixXpr, StorageKind>::type Base;
    144 };
    145 
    146 template<typename Lhs, typename Rhs, int Option>
    147 class ProductImpl<Lhs,Rhs,Option,Dense>
    148   : public internal::dense_product_base<Lhs,Rhs,Option>
    149 {
    150     typedef Product<Lhs, Rhs, Option> Derived;
    151 
    152   public:
    153 
    154     typedef typename internal::dense_product_base<Lhs, Rhs, Option> Base;
    155     EIGEN_DENSE_PUBLIC_INTERFACE(Derived)
    156   protected:
    157     enum {
    158       IsOneByOne = (RowsAtCompileTime == 1 || RowsAtCompileTime == Dynamic) &&
    159                    (ColsAtCompileTime == 1 || ColsAtCompileTime == Dynamic),
    160       EnableCoeff = IsOneByOne || Option==LazyProduct
    161     };
    162 
    163   public:
    164 
    165     EIGEN_DEVICE_FUNC Scalar coeff(Index row, Index col) const
    166     {
    167       EIGEN_STATIC_ASSERT(EnableCoeff, THIS_METHOD_IS_ONLY_FOR_INNER_OR_LAZY_PRODUCTS);
    168       eigen_assert( (Option==LazyProduct) || (this->rows() == 1 && this->cols() == 1) );
    169 
    170       return internal::evaluator<Derived>(derived()).coeff(row,col);
    171     }
    172 
    173     EIGEN_DEVICE_FUNC Scalar coeff(Index i) const
    174     {
    175       EIGEN_STATIC_ASSERT(EnableCoeff, THIS_METHOD_IS_ONLY_FOR_INNER_OR_LAZY_PRODUCTS);
    176       eigen_assert( (Option==LazyProduct) || (this->rows() == 1 && this->cols() == 1) );
    177 
    178       return internal::evaluator<Derived>(derived()).coeff(i);
    179     }
    180 
    181 
    182 };
    183 
    184 } // end namespace Eigen
    185 
    186 #endif // EIGEN_PRODUCT_H
    187