Home | History | Annotate | Download | only in products
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud (at) inria.fr>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_TRIANGULAR_SOLVER_MATRIX_H
     11 #define EIGEN_TRIANGULAR_SOLVER_MATRIX_H
     12 
     13 namespace Eigen {
     14 
     15 namespace internal {
     16 
     17 // Row-major right-hand side: instead of solving directly, transpose the whole product.
        //
        // This partial specialization (last template argument == RowMajor) performs no
        // computation itself. It forwards every runtime argument unchanged to the
        // specialization whose `other` operand is ColMajor, with the compile-time
        // parameters rewritten as follows:
        //   - Side:            OnTheLeft <-> OnTheRight
        //   - Mode:            the UnitDiag bit is kept; Upper selects Lower, anything else selects Upper
        //   - Conjugate:       kept only when NumTraits<Scalar>::IsComplex is true
        //   - TriStorageOrder: RowMajor <-> ColMajor
        // The same buffers and strides are passed through; only their interpretation changes.
     18 template <typename Scalar, typename Index, int Side, int Mode, bool Conjugate, int TriStorageOrder>
     19 struct triangular_solve_matrix<Scalar,Index,Side,Mode,Conjugate,TriStorageOrder,RowMajor>
     20 {
     21   static void run(
     22     Index size, Index cols,
     23     const Scalar*  tri, Index triStride,
     24     Scalar* _other, Index otherStride,
     25     level3_blocking<Scalar,Scalar>& blocking)
     26   {
     27     triangular_solve_matrix<
     28       Scalar, Index, Side==OnTheLeft?OnTheRight:OnTheLeft,  // flip the side
     29       (Mode&UnitDiag) | ((Mode&Upper) ? Lower : Upper),     // swap the triangle, keep UnitDiag
     30       NumTraits<Scalar>::IsComplex && Conjugate,            // conjugation is dropped for non-complex scalars
     31       TriStorageOrder==RowMajor ? ColMajor : RowMajor, ColMajor>  // flip tri's storage order; `other` becomes ColMajor
     32       ::run(size, cols, tri, triStride, _other, otherStride, blocking);
     33   }
     34 };
     35 
     36 /* Optimized triangular solver with multiple right hand sides and the triangular matrix on the left.
         *
         * Specialization for Side == OnTheLeft with a column-major right-hand side (`other`).
         * run() is only declared here; its definition follows out of line.
         *
         * Parameters of run(), as used by the definition:
         *   size        - dimension of the triangular matrix; the blocked loops iterate over [0,size)
         *   otherSize   - number of columns of `other`
         *   _tri        - read-only data of the triangular matrix, addressed with TriStorageOrder and triStride
         *   triStride   - stride handed to the data mapper of _tri
         *   _other      - data of the right-hand side, addressed as ColMajor with otherStride;
         *                 it is overwritten in place with the solution
         *   otherStride - stride handed to the data mapper of _other
         *   blocking    - supplies the kc/mc block sizes and the blockA/blockB/blockW work buffers
     37  */
     38 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder>
     39 struct triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>
     40 {
     41   static EIGEN_DONT_INLINE void run(
     42     Index size, Index otherSize,
     43     const Scalar* _tri, Index triStride,
     44     Scalar* _other, Index otherStride,
     45     level3_blocking<Scalar,Scalar>& blocking);
     46 };
     47 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder>
     48 EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStorageOrder,ColMajor>::run(
     49     Index size, Index otherSize,
     50     const Scalar* _tri, Index triStride,
     51     Scalar* _other, Index otherStride,
     52     level3_blocking<Scalar,Scalar>& blocking)
     53   {
            // Blocked in-place solve with the triangular matrix on the left:
            // `other` (size x otherSize, column-major) is overwritten with the solution.
            // Entries of the triangular matrix are passed through conj_if<Conjugate> before use,
            // and the diagonal is never read when Mode has the UnitDiag bit set.
            //
            //   size        - dimension of the triangular matrix (rows of `other`)
            //   otherSize   - number of right-hand-side columns
            //   _tri        - triangular matrix data, addressed with TriStorageOrder/triStride
            //   _other      - right-hand side data, addressed as ColMajor with otherStride
            //   blocking    - provides kc, mc and the blockA/blockB/blockW buffers
     54     Index cols = otherSize;
     55     const_blas_data_mapper<Scalar, Index, TriStorageOrder> tri(_tri,triStride);
     56     blas_data_mapper<Scalar, Index, ColMajor> other(_other,otherStride);
     57 
     58     typedef gebp_traits<Scalar,Scalar> Traits;
     59     enum {
     60       SmallPanelWidth   = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr),
     61       IsLower = (Mode&Lower) == Lower
     62     };
     63 
     64     Index kc = blocking.kc();                   // cache block size along the K direction
     65     Index mc = (std::min)(size,blocking.mc());  // cache block size along the M direction
     66 
     67     std::size_t sizeA = kc*mc;
     68     std::size_t sizeB = kc*cols;
     69     std::size_t sizeW = kc*Traits::WorkSpaceFactor;
     70 
     71     ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
     72     ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
     73     ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW());
     74 
     75     conj_if<Conjugate> conj;
     76     gebp_kernel<Scalar, Scalar, Index, Traits::mr, Traits::nr, Conjugate, false> gebp_kernel;
     77     gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, TriStorageOrder> pack_lhs;
     78     gemm_pack_rhs<Scalar, Index, Traits::nr, ColMajor, false, true> pack_rhs;
     79 
     80     // the goal here is to subdivide the Rhs panels such that we keep some cache
     81     // coherence when accessing the rhs elements
     82     std::ptrdiff_t l1, l2;
     83     manage_caching_sizes(GetAction, &l1, &l2);
            // The divisor is clamped to at least 1: with otherStride == 0 (presumably a
            // right-hand side without rows; confirm against callers) the unguarded
            // expression is an integer division by zero even though none of the loops
            // below would execute. Whenever otherStride >= size >= 1 the value is
            // unchanged; subcols only controls how the columns are partitioned.
     84     Index subcols = cols>0 ? l2/(4 * sizeof(Scalar) * std::max<Index>(std::max<Index>(otherStride,size),1)) : 0;
            // round down to a multiple of Traits::nr, but never below one nr-wide panel
     85     subcols = std::max<Index>((subcols/Traits::nr)*Traits::nr, Traits::nr);
     86 
            // Walk the diagonal blocks of tri: top to bottom for a lower triangle,
            // bottom to top for an upper triangle (k2 is then the end of the block).
     87     for(Index k2=IsLower ? 0 : size;
     88         IsLower ? k2<size : k2>0;
     89         IsLower ? k2+=kc : k2-=kc)
     90     {
     91       const Index actual_kc = (std::min)(IsLower ? size-k2 : k2, kc);
     92 
     93       // We have selected and packed a big horizontal panel R1 of rhs. Let B be the packed copy of this panel,
     94       // and R2 the remaining part of rhs. The corresponding vertical panel of lhs is split into
     95       // A11 (the triangular part) and A21 the remaining rectangular part.
     96       // Then the high level algorithm is:
     97       //  - B = R1                    => general block copy (done during the next step)
     98       //  - R1 = A11^-1 B             => tricky part
     99       //  - update B from the new R1  => actually this has to be performed continuously during the above step
    100       //  - R2 -= A21 * B             => GEPP
    101 
    102       // The tricky part: compute R1 = A11^-1 B while updating B from R1
    103       // The idea is to split A11 into multiple small vertical panels.
    104       // Each panel can be split into a small triangular part T1k which is processed without optimization,
    105       // and the remaining small part T2k which is processed using gebp with appropriate block strides
    106       for(Index j2=0; j2<cols; j2+=subcols)
    107       {
    108         Index actual_cols = (std::min)(cols-j2,subcols);
    109         // for each small vertical panels [T1k^T, T2k^T]^T of lhs
    110         for (Index k1=0; k1<actual_kc; k1+=SmallPanelWidth)
    111         {
    112           Index actualPanelWidth = std::min<Index>(actual_kc-k1, SmallPanelWidth);
    113           // tr solve
    114           for (Index k=0; k<actualPanelWidth; ++k)
    115           {
    116             // TODO write a small kernel handling this (can be shared with trsv)
                    // i:  row of `other` being solved in this step
                    // s:  first of the k rows already solved in this small panel (used by the RowMajor branch)
    117             Index i  = IsLower ? k2+k1+k : k2-k1-k-1;
    118             Index s  = IsLower ? k2+k1 : i+1;
    119             Index rs = actualPanelWidth - k - 1; // remaining size
    120 
                    // reciprocal of the diagonal entry; tri(i,i) is not read for UnitDiag
    121             Scalar a = (Mode & UnitDiag) ? Scalar(1) : Scalar(1)/conj(tri(i,i));
    122             for (Index j=j2; j<j2+actual_cols; ++j)
    123             {
    124               if (TriStorageOrder==RowMajor)
    125               {
                        // row i of tri is contiguous: accumulate the dot product with the k
                        // already solved entries, then subtract and scale
    126                 Scalar b(0);
    127                 const Scalar* l = &tri(i,s);
    128                 Scalar* r = &other(s,j);
    129                 for (Index i3=0; i3<k; ++i3)
    130                   b += conj(l[i3]) * r[i3];
    131 
    132                 other(i,j) = (other(i,j) - b)*a;
    133               }
    134               else
    135               {
                        // column i of tri is contiguous: finalize other(i,j), then eliminate it
                        // from the rs rows of the panel that are still unsolved.
                        // Note: this `s` intentionally shadows the outer one.
    136                 Index s = IsLower ? i+1 : i-rs;
    137                 Scalar b = (other(i,j) *= a);
    138                 Scalar* r = &other(s,j);
    139                 const Scalar* l = &tri(s,i);
    140                 for (Index i3=0;i3<rs;++i3)
    141                   r[i3] -= b * conj(l[i3]);
    142               }
    143             }
    144           }
    145 
                  // lengthTarget: rows of the current kc-block that still have to be updated by this panel
    146           Index lengthTarget = actual_kc-k1-actualPanelWidth;
    147           Index startBlock   = IsLower ? k2+k1 : k2-k1-actualPanelWidth;
    148           Index blockBOffset = IsLower ? k1 : lengthTarget;
    149 
    150           // update the respective rows of B from other
    151           pack_rhs(blockB+actual_kc*j2, &other(startBlock,j2), otherStride, actualPanelWidth, actual_cols, actual_kc, blockBOffset);
    152 
    153           // GEBP
    154           if (lengthTarget>0)
    155           {
    156             Index startTarget  = IsLower ? k2+k1+actualPanelWidth : k2-actual_kc;
    157 
    158             pack_lhs(blockA, &tri(startTarget,startBlock), triStride, actualPanelWidth, lengthTarget);
    159 
    160             gebp_kernel(&other(startTarget,j2), otherStride, blockA, blockB+actual_kc*j2, lengthTarget, actualPanelWidth, actual_cols, Scalar(-1),
    161                         actualPanelWidth, actual_kc, 0, blockBOffset, blockW);
    162           }
    163         }
    164       }
    165 
    166       // R2 -= A21 * B => GEPP
    167       {
                // rows of `other` outside the current diagonal block that are still unsolved:
                // below it for a lower triangle, above it for an upper one. For the last
                // (possibly partial) upper block `end` is <= 0 and the loop does not run.
    168         Index start = IsLower ? k2+kc : 0;
    169         Index end   = IsLower ? size : k2-kc;
    170         for(Index i2=start; i2<end; i2+=mc)
    171         {
    172           const Index actual_mc = (std::min)(mc,end-i2);
    173           if (actual_mc>0)
    174           {
    175             pack_lhs(blockA, &tri(i2, IsLower ? k2 : k2-kc), triStride, actual_kc, actual_mc);
    176 
    177             gebp_kernel(_other+i2, otherStride, blockA, blockB, actual_mc, actual_kc, cols, Scalar(-1), -1, -1, 0, 0, blockW);
    178           }
    179         }
    180       }
    181     }
    182   }
    183 
    184 /* Optimized triangular solver with multiple left hand sides and the triangular matrix on the right.
         *
         * Specialization for Side == OnTheRight with a column-major `other` operand.
         * run() is only declared here; its definition follows out of line.
         *
         * Parameters of run(), as used by the definition:
         *   size        - dimension of the triangular matrix; the blocked loops iterate over [0,size)
         *   otherSize   - number of rows of `other`
         *   _tri        - read-only data of the triangular matrix, addressed with TriStorageOrder and triStride
         *   triStride   - stride handed to the data mapper of _tri
         *   _other      - data of the left-hand side, addressed as ColMajor with otherStride;
         *                 it is overwritten in place with the solution
         *   otherStride - stride handed to the data mapper of _other
         *   blocking    - supplies the kc/mc block sizes and the blockA/blockB/blockW work buffers
    185  */
    186 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder>
    187 struct triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor>
    188 {
    189   static EIGEN_DONT_INLINE void run(
    190     Index size, Index otherSize,
    191     const Scalar* _tri, Index triStride,
    192     Scalar* _other, Index otherStride,
    193     level3_blocking<Scalar,Scalar>& blocking);
    194 };
    195 template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder>
    196 EIGEN_DONT_INLINE void triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStorageOrder,ColMajor>::run(
    197     Index size, Index otherSize,
    198     const Scalar* _tri, Index triStride,
    199     Scalar* _other, Index otherStride,
    200     level3_blocking<Scalar,Scalar>& blocking)
    201   {
            // Blocked in-place solve with the triangular matrix on the right:
            // `other` (otherSize x size, column-major) is overwritten with the solution.
            // Below, `rhs` denotes the triangular matrix and `lhs` denotes `other`.
            // Entries of the triangular matrix are passed through conj_if<Conjugate> before
            // use in the unblocked part; the diagonal is never read when Mode has UnitDiag set.
    202     Index rows = otherSize;
    203     const_blas_data_mapper<Scalar, Index, TriStorageOrder> rhs(_tri,triStride);
    204     blas_data_mapper<Scalar, Index, ColMajor> lhs(_other,otherStride);
    205 
    206     typedef gebp_traits<Scalar,Scalar> Traits;
    207     enum {
    208       RhsStorageOrder   = TriStorageOrder,
    209       SmallPanelWidth   = EIGEN_PLAIN_ENUM_MAX(Traits::mr,Traits::nr),
    210       IsLower = (Mode&Lower) == Lower
    211     };
    212 
    213     Index kc = blocking.kc();                   // cache block size along the K direction
    214     Index mc = (std::min)(rows,blocking.mc());  // cache block size along the M direction
    215 
    216     std::size_t sizeA = kc*mc;
    217     std::size_t sizeB = kc*size;
    218     std::size_t sizeW = kc*Traits::WorkSpaceFactor;
    219 
    220     ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
    221     ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
    222     ei_declare_aligned_stack_constructed_variable(Scalar, blockW, sizeW, blocking.blockW());
    223 
    224     conj_if<Conjugate> conj;
    225     gebp_kernel<Scalar,Scalar, Index, Traits::mr, Traits::nr, false, Conjugate> gebp_kernel;
    226     gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs;
    227     gemm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder,false,true> pack_rhs_panel;
    228     gemm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, ColMajor, false, true> pack_lhs_panel;
    229 
            // Walk the diagonal blocks of the triangular matrix: from the bottom-right corner
            // upwards for a lower triangle (k2 is then the end of the block), from the
            // top-left corner downwards for an upper triangle.
    230     for(Index k2=IsLower ? size : 0;
    231         IsLower ? k2>0 : k2<size;
    232         IsLower ? k2-=kc : k2+=kc)
    233     {
    234       const Index actual_kc = (std::min)(IsLower ? k2 : size-k2, kc);
              // first row/column index of the current diagonal block
    235       Index actual_k2 = IsLower ? k2-actual_kc : k2 ;
    236 
              // Columns outside the diagonal block that this block still has to update:
              // [0, actual_k2) for a lower triangle, [k2+actual_kc, size) for an upper one.
    237       Index startPanel = IsLower ? 0 : k2+actual_kc;
    238       Index rs = IsLower ? actual_k2 : size - actual_k2 - actual_kc;
              // The packed rectangular panel is stored in blockB after the
              // actual_kc*actual_kc entries used for the packed triangular part.
    239       Scalar* geb = blockB+actual_kc*actual_kc;
    240 
    241       if (rs>0) pack_rhs(geb, &rhs(actual_k2,startPanel), triStride, actual_kc, rs);
    242 
    243       // triangular packing (we only pack the panels off the diagonal,
    244       // neglecting the blocks overlapping the diagonal)
    245       {
    246         for (Index j2=0; j2<actual_kc; j2+=SmallPanelWidth)
    247         {
    248           Index actualPanelWidth = std::min<Index>(actual_kc-j2, SmallPanelWidth);
    249           Index actual_j2 = actual_k2 + j2;
                  // rows of the small panel lying strictly off the diagonal block:
                  // below it for a lower triangle, above it for an upper one
    250           Index panelOffset = IsLower ? j2+actualPanelWidth : 0;
    251           Index panelLength = IsLower ? actual_kc-j2-actualPanelWidth : j2;
    252 
    253           if (panelLength>0)
    254           pack_rhs_panel(blockB+j2*actual_kc,
    255                          &rhs(actual_k2+panelOffset, actual_j2), triStride,
    256                          panelLength, actualPanelWidth,
    257                          actual_kc, panelOffset);
    258         }
    259       }
    260 
              // process `other` by horizontal slabs of at most mc rows
    261       for(Index i2=0; i2<rows; i2+=mc)
    262       {
    263         const Index actual_mc = (std::min)(mc,rows-i2);
    264 
    265         // triangular solver kernel
    266         {
    267           // for each small block of the diagonal (=> vertical panels of rhs)
                  // For a lower triangle the panels are visited from the last one (which may be
                  // narrower than SmallPanelWidth) back to the first; for an upper triangle
                  // from the first to the last.
    268           for (Index j2 = IsLower
    269                       ? (actual_kc - ((actual_kc%SmallPanelWidth) ? Index(actual_kc%SmallPanelWidth)
    270                                                                   : Index(SmallPanelWidth)))
    271                       : 0;
    272                IsLower ? j2>=0 : j2<actual_kc;
    273                IsLower ? j2-=SmallPanelWidth : j2+=SmallPanelWidth)
    274           {
    275             Index actualPanelWidth = std::min<Index>(actual_kc-j2, SmallPanelWidth);
    276             Index absolute_j2 = actual_k2 + j2;
    277             Index panelOffset = IsLower ? j2+actualPanelWidth : 0;
    278             Index panelLength = IsLower ? actual_kc - j2 - actualPanelWidth : j2;
    279 
    280             // GEBP
                    // Applies the packed off-diagonal panel to the columns of this small panel
                    // with a factor of Scalar(-1), using the part of blockA packed by previous
                    // iterations (presumably the already solved columns; see pack_lhs_panel below).
    281             if(panelLength>0)
    282             {
    283               gebp_kernel(&lhs(i2,absolute_j2), otherStride,
    284                           blockA, blockB+j2*actual_kc,
    285                           actual_mc, panelLength, actualPanelWidth,
    286                           Scalar(-1),
    287                           actual_kc, actual_kc, // strides
    288                           panelOffset, panelOffset, // offsets
    289                           blockW);  // workspace
    290             }
    291 
    292             // unblocked triangular solve
                    // Column j of the panel: subtract the contribution of the k columns of the
                    // panel solved before it, then scale by the reciprocal of the diagonal entry.
    293             for (Index k=0; k<actualPanelWidth; ++k)
    294             {
    295               Index j = IsLower ? absolute_j2+actualPanelWidth-k-1 : absolute_j2+k;
    296 
    297               Scalar* r = &lhs(i2,j);
    298               for (Index k3=0; k3<k; ++k3)
    299               {
    300                 Scalar b = conj(rhs(IsLower ? j+1+k3 : absolute_j2+k3,j));
    301                 Scalar* a = &lhs(i2,IsLower ? j+1+k3 : absolute_j2+k3);
    302                 for (Index i=0; i<actual_mc; ++i)
    303                   r[i] -= a[i] * b;
    304               }
                      // rhs(j,j) is not read for UnitDiag
    305               Scalar b = (Mode & UnitDiag) ? Scalar(1) : Scalar(1)/conj(rhs(j,j));
    306               for (Index i=0; i<actual_mc; ++i)
    307                 r[i] *= b;
    308             }
    309 
    310             // pack the just computed part of lhs to A
    311             pack_lhs_panel(blockA, _other+absolute_j2*otherStride+i2, otherStride,
    312                            actualPanelWidth, actual_mc,
    313                            actual_kc, j2);
    314           }
    315         }
    316 
                // Update the rs columns outside the diagonal block from the freshly solved
                // block (packed in blockA) and the rectangular panel packed in geb, with a
                // factor of Scalar(-1). NOTE(review): the -1 strides presumably select the
                // kernel's defaults; confirm against gebp_kernel.
    317         if (rs>0)
    318           gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb,
    319                       actual_mc, actual_kc, rs, Scalar(-1),
    320                       -1, -1, 0, 0, blockW);
    321       }
    322     }
    323   }
    324 
    325 } // end namespace internal
    326 
    327 } // end namespace Eigen
    328 
    329 #endif // EIGEN_TRIANGULAR_SOLVER_MATRIX_H
    330