1 //===================================================== 2 // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud (at) inria.fr> 3 //===================================================== 4 // 5 // This program is free software; you can redistribute it and/or 6 // modify it under the terms of the GNU General Public License 7 // as published by the Free Software Foundation; either version 2 8 // of the License, or (at your option) any later version. 9 // 10 // This program is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU General Public License for more details. 14 // You should have received a copy of the GNU General Public License 15 // along with this program; if not, write to the Free Software 16 // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 17 // 18 #ifndef EIGEN3_INTERFACE_HH 19 #define EIGEN3_INTERFACE_HH 20 21 #include <Eigen/Eigen> 22 #include <vector> 23 #include "btl.hh" 24 25 using namespace Eigen; 26 27 template<class real, int SIZE=Dynamic> 28 class eigen3_interface 29 { 30 31 public : 32 33 enum {IsFixedSize = (SIZE!=Dynamic)}; 34 35 typedef real real_type; 36 37 typedef std::vector<real> stl_vector; 38 typedef std::vector<stl_vector> stl_matrix; 39 40 typedef Eigen::Matrix<real,SIZE,SIZE> gene_matrix; 41 typedef Eigen::Matrix<real,SIZE,1> gene_vector; 42 43 static inline std::string name( void ) 44 { 45 return EIGEN_MAKESTRING(BTL_PREFIX); 46 } 47 48 static void free_matrix(gene_matrix & /*A*/, int /*N*/) {} 49 50 static void free_vector(gene_vector & /*B*/) {} 51 52 static BTL_DONT_INLINE void matrix_from_stl(gene_matrix & A, stl_matrix & A_stl){ 53 A.resize(A_stl[0].size(), A_stl.size()); 54 55 for (unsigned int j=0; j<A_stl.size() ; j++){ 56 for (unsigned int i=0; i<A_stl[j].size() ; i++){ 57 A.coeffRef(i,j) = A_stl[j][i]; 58 } 59 } 60 } 61 62 static BTL_DONT_INLINE void vector_from_stl(gene_vector & B, stl_vector & B_stl){ 63 B.resize(B_stl.size(),1); 64 65 for (unsigned int i=0; i<B_stl.size() ; i++){ 66 B.coeffRef(i) = B_stl[i]; 67 } 68 } 69 70 static BTL_DONT_INLINE void vector_to_stl(gene_vector & B, stl_vector & B_stl){ 71 for (unsigned int i=0; i<B_stl.size() ; i++){ 72 B_stl[i] = B.coeff(i); 73 } 74 } 75 76 static BTL_DONT_INLINE void matrix_to_stl(gene_matrix & A, stl_matrix & A_stl){ 77 int N=A_stl.size(); 78 79 for (int j=0;j<N;j++){ 80 A_stl[j].resize(N); 81 for (int i=0;i<N;i++){ 82 A_stl[j][i] = A.coeff(i,j); 83 } 84 } 85 } 86 87 static inline void matrix_matrix_product(const gene_matrix & A, const gene_matrix & B, gene_matrix & X, int /*N*/){ 88 X.noalias() = A*B; 89 } 90 91 static inline void transposed_matrix_matrix_product(const gene_matrix & A, const gene_matrix & B, gene_matrix & X, int /*N*/){ 92 X.noalias() = A.transpose()*B.transpose(); 93 } 94 95 // static inline void ata_product(const gene_matrix & A, gene_matrix & X, int /*N*/){ 96 // X.noalias() = A.transpose()*A; 97 // } 98 99 static inline void aat_product(const gene_matrix & A, gene_matrix & X, int /*N*/){ 100 X.template triangularView<Lower>().setZero(); 101 X.template selfadjointView<Lower>().rankUpdate(A); 102 } 103 104 static inline void matrix_vector_product(const gene_matrix & A, const gene_vector & B, gene_vector & X, int /*N*/){ 105 X.noalias() = A*B; 106 } 107 108 static inline void symv(const gene_matrix & A, const gene_vector & B, gene_vector & X, int /*N*/){ 109 X.noalias() = (A.template selfadjointView<Lower>() * B); 110 // internal::product_selfadjoint_vector<real,0,LowerTriangularBit,false,false>(N,A.data(),N, B.data(), 1, X.data(), 1); 111 } 112 113 template<typename Dest, typename Src> static void triassign(Dest& dst, const Src& src) 114 { 115 typedef typename Dest::Scalar Scalar; 116 typedef typename internal::packet_traits<Scalar>::type Packet; 117 const int PacketSize = sizeof(Packet)/sizeof(Scalar); 118 int size = dst.cols(); 119 for(int j=0; j<size; j+=1) 120 { 121 // const int alignedEnd = alignedStart + ((innerSize-alignedStart) & ~packetAlignedMask); 122 Scalar* A0 = dst.data() + j*dst.stride(); 123 int starti = j; 124 int alignedEnd = starti; 125 int alignedStart = (starti) + internal::first_aligned(&A0[starti], size-starti); 126 alignedEnd = alignedStart + ((size-alignedStart)/(2*PacketSize))*(PacketSize*2); 127 128 // do the non-vectorizable part of the assignment 129 for (int index = starti; index<alignedStart ; ++index) 130 { 131 if(Dest::Flags&RowMajorBit) 132 dst.copyCoeff(j, index, src); 133 else 134 dst.copyCoeff(index, j, src); 135 } 136 137 // do the vectorizable part of the assignment 138 for (int index = alignedStart; index<alignedEnd; index+=PacketSize) 139 { 140 if(Dest::Flags&RowMajorBit) 141 dst.template copyPacket<Src, Aligned, Unaligned>(j, index, src); 142 else 143 dst.template copyPacket<Src, Aligned, Unaligned>(index, j, src); 144 } 145 146 // do the non-vectorizable part of the assignment 147 for (int index = alignedEnd; index<size; ++index) 148 { 149 if(Dest::Flags&RowMajorBit) 150 dst.copyCoeff(j, index, src); 151 else 152 dst.copyCoeff(index, j, src); 153 } 154 //dst.col(j).tail(N-j) = src.col(j).tail(N-j); 155 } 156 } 157 158 static EIGEN_DONT_INLINE void syr2(gene_matrix & A, gene_vector & X, gene_vector & Y, int N){ 159 // internal::product_selfadjoint_rank2_update<real,0,LowerTriangularBit>(N,A.data(),N, X.data(), 1, Y.data(), 1, -1); 160 for(int j=0; j<N; ++j) 161 A.col(j).tail(N-j) += X[j] * Y.tail(N-j) + Y[j] * X.tail(N-j); 162 } 163 164 static EIGEN_DONT_INLINE void ger(gene_matrix & A, gene_vector & X, gene_vector & Y, int N){ 165 for(int j=0; j<N; ++j) 166 A.col(j) += X * Y[j]; 167 } 168 169 static EIGEN_DONT_INLINE void rot(gene_vector & A, gene_vector & B, real c, real s, int /*N*/){ 170 internal::apply_rotation_in_the_plane(A, B, JacobiRotation<real>(c,s)); 171 } 172 173 static inline void atv_product(gene_matrix & A, gene_vector & B, gene_vector & X, int /*N*/){ 174 X.noalias() = (A.transpose()*B); 175 } 176 177 static inline void axpy(real coef, const gene_vector & X, gene_vector & Y, int /*N*/){ 178 Y += coef * X; 179 } 180 181 static inline void axpby(real a, const gene_vector & X, real b, gene_vector & Y, int /*N*/){ 182 Y = a*X + b*Y; 183 } 184 185 static EIGEN_DONT_INLINE void copy_matrix(const gene_matrix & source, gene_matrix & cible, int /*N*/){ 186 cible = source; 187 } 188 189 static EIGEN_DONT_INLINE void copy_vector(const gene_vector & source, gene_vector & cible, int /*N*/){ 190 cible = source; 191 } 192 193 static inline void trisolve_lower(const gene_matrix & L, const gene_vector& B, gene_vector& X, int /*N*/){ 194 X = L.template triangularView<Lower>().solve(B); 195 } 196 197 static inline void trisolve_lower_matrix(const gene_matrix & L, const gene_matrix& B, gene_matrix& X, int /*N*/){ 198 X = L.template triangularView<Upper>().solve(B); 199 } 200 201 static inline void trmm(const gene_matrix & L, const gene_matrix& B, gene_matrix& X, int /*N*/){ 202 X.noalias() = L.template triangularView<Lower>() * B; 203 } 204 205 static inline void cholesky(const gene_matrix & X, gene_matrix & C, int /*N*/){ 206 C = X; 207 internal::llt_inplace<real,Lower>::blocked(C); 208 //C = X.llt().matrixL(); 209 // C = X; 210 // Cholesky<gene_matrix>::computeInPlace(C); 211 // Cholesky<gene_matrix>::computeInPlaceBlock(C); 212 } 213 214 static inline void lu_decomp(const gene_matrix & X, gene_matrix & C, int /*N*/){ 215 C = X.fullPivLu().matrixLU(); 216 } 217 218 static inline void partial_lu_decomp(const gene_matrix & X, gene_matrix & C, int N){ 219 Matrix<DenseIndex,1,Dynamic> piv(N); 220 DenseIndex nb; 221 C = X; 222 internal::partial_lu_inplace(C,piv,nb); 223 // C = X.partialPivLu().matrixLU(); 224 } 225 226 static inline void tridiagonalization(const gene_matrix & X, gene_matrix & C, int N){ 227 typename Tridiagonalization<gene_matrix>::CoeffVectorType aux(N-1); 228 C = X; 229 internal::tridiagonalization_inplace(C, aux); 230 } 231 232 static inline void hessenberg(const gene_matrix & X, gene_matrix & C, int /*N*/){ 233 C = HessenbergDecomposition<gene_matrix>(X).packedMatrix(); 234 } 235 236 237 238 }; 239 240 #endif 241