// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
28 // 29 // Author: sameeragarwal (at) google.com (Sameer Agarwal) 30 31 #include "ceres/implicit_schur_complement.h" 32 33 #include "Eigen/Dense" 34 #include "ceres/block_sparse_matrix.h" 35 #include "ceres/block_structure.h" 36 #include "ceres/internal/eigen.h" 37 #include "ceres/internal/scoped_ptr.h" 38 #include "ceres/types.h" 39 #include "glog/logging.h" 40 41 namespace ceres { 42 namespace internal { 43 44 ImplicitSchurComplement::ImplicitSchurComplement(int num_eliminate_blocks, 45 bool preconditioner) 46 : num_eliminate_blocks_(num_eliminate_blocks), 47 preconditioner_(preconditioner), 48 A_(NULL), 49 D_(NULL), 50 b_(NULL), 51 block_diagonal_EtE_inverse_(NULL), 52 block_diagonal_FtF_inverse_(NULL) { 53 } 54 55 ImplicitSchurComplement::~ImplicitSchurComplement() { 56 } 57 58 void ImplicitSchurComplement::Init(const BlockSparseMatrixBase& A, 59 const double* D, 60 const double* b) { 61 // Since initialization is reasonably heavy, perhaps we can save on 62 // constructing a new object everytime. 63 if (A_ == NULL) { 64 A_.reset(new PartitionedMatrixView(A, num_eliminate_blocks_)); 65 } 66 67 D_ = D; 68 b_ = b; 69 70 // Initialize temporary storage and compute the block diagonals of 71 // E'E and F'E. 72 if (block_diagonal_EtE_inverse_ == NULL) { 73 block_diagonal_EtE_inverse_.reset(A_->CreateBlockDiagonalEtE()); 74 if (preconditioner_) { 75 block_diagonal_FtF_inverse_.reset(A_->CreateBlockDiagonalFtF()); 76 } 77 rhs_.resize(A_->num_cols_f()); 78 rhs_.setZero(); 79 tmp_rows_.resize(A_->num_rows()); 80 tmp_e_cols_.resize(A_->num_cols_e()); 81 tmp_e_cols_2_.resize(A_->num_cols_e()); 82 tmp_f_cols_.resize(A_->num_cols_f()); 83 } else { 84 A_->UpdateBlockDiagonalEtE(block_diagonal_EtE_inverse_.get()); 85 if (preconditioner_) { 86 A_->UpdateBlockDiagonalFtF(block_diagonal_FtF_inverse_.get()); 87 } 88 } 89 90 // The block diagonals of the augmented linear system contain 91 // contributions from the diagonal D if it is non-null. 
Add that to 92 // the block diagonals and invert them. 93 AddDiagonalAndInvert(D_, block_diagonal_EtE_inverse_.get()); 94 if (preconditioner_) { 95 AddDiagonalAndInvert((D_ == NULL) ? NULL : D_ + A_->num_cols_e(), 96 block_diagonal_FtF_inverse_.get()); 97 } 98 99 // Compute the RHS of the Schur complement system. 100 UpdateRhs(); 101 } 102 103 // Evaluate the product 104 // 105 // Sx = [F'F - F'E (E'E)^-1 E'F]x 106 // 107 // By breaking it down into individual matrix vector products 108 // involving the matrices E and F. This is implemented using a 109 // PartitionedMatrixView of the input matrix A. 110 void ImplicitSchurComplement::RightMultiply(const double* x, double* y) const { 111 // y1 = F x 112 tmp_rows_.setZero(); 113 A_->RightMultiplyF(x, tmp_rows_.data()); 114 115 // y2 = E' y1 116 tmp_e_cols_.setZero(); 117 A_->LeftMultiplyE(tmp_rows_.data(), tmp_e_cols_.data()); 118 119 // y3 = -(E'E)^-1 y2 120 tmp_e_cols_2_.setZero(); 121 block_diagonal_EtE_inverse_->RightMultiply(tmp_e_cols_.data(), 122 tmp_e_cols_2_.data()); 123 tmp_e_cols_2_ *= -1.0; 124 125 // y1 = y1 + E y3 126 A_->RightMultiplyE(tmp_e_cols_2_.data(), tmp_rows_.data()); 127 128 // y5 = D * x 129 if (D_ != NULL) { 130 ConstVectorRef Dref(D_ + A_->num_cols_e(), num_cols()); 131 VectorRef(y, num_cols()) = 132 (Dref.array().square() * 133 ConstVectorRef(x, num_cols()).array()).matrix(); 134 } else { 135 VectorRef(y, num_cols()).setZero(); 136 } 137 138 // y = y5 + F' y1 139 A_->LeftMultiplyF(tmp_rows_.data(), y); 140 } 141 142 // Given a block diagonal matrix and an optional array of diagonal 143 // entries D, add them to the diagonal of the matrix and compute the 144 // inverse of each diagonal block. 
145 void ImplicitSchurComplement::AddDiagonalAndInvert( 146 const double* D, 147 BlockSparseMatrix* block_diagonal) { 148 const CompressedRowBlockStructure* block_diagonal_structure = 149 block_diagonal->block_structure(); 150 for (int r = 0; r < block_diagonal_structure->rows.size(); ++r) { 151 const int row_block_pos = block_diagonal_structure->rows[r].block.position; 152 const int row_block_size = block_diagonal_structure->rows[r].block.size; 153 const Cell& cell = block_diagonal_structure->rows[r].cells[0]; 154 MatrixRef m(block_diagonal->mutable_values() + cell.position, 155 row_block_size, row_block_size); 156 157 if (D != NULL) { 158 ConstVectorRef d(D + row_block_pos, row_block_size); 159 m += d.array().square().matrix().asDiagonal(); 160 } 161 162 m = m 163 .selfadjointView<Eigen::Upper>() 164 .ldlt() 165 .solve(Matrix::Identity(row_block_size, row_block_size)); 166 } 167 } 168 169 // Similar to RightMultiply, use the block structure of the matrix A 170 // to compute y = (E'E)^-1 (E'b - E'F x). 171 void ImplicitSchurComplement::BackSubstitute(const double* x, double* y) { 172 const int num_cols_e = A_->num_cols_e(); 173 const int num_cols_f = A_->num_cols_f(); 174 const int num_cols = A_->num_cols(); 175 const int num_rows = A_->num_rows(); 176 177 // y1 = F x 178 tmp_rows_.setZero(); 179 A_->RightMultiplyF(x, tmp_rows_.data()); 180 181 // y2 = b - y1 182 tmp_rows_ = ConstVectorRef(b_, num_rows) - tmp_rows_; 183 184 // y3 = E' y2 185 tmp_e_cols_.setZero(); 186 A_->LeftMultiplyE(tmp_rows_.data(), tmp_e_cols_.data()); 187 188 // y = (E'E)^-1 y3 189 VectorRef(y, num_cols).setZero(); 190 block_diagonal_EtE_inverse_->RightMultiply(tmp_e_cols_.data(), y); 191 192 // The full solution vector y has two blocks. The first block of 193 // variables corresponds to the eliminated variables, which we just 194 // computed via back substitution. 
The second block of variables 195 // corresponds to the Schur complement system, so we just copy those 196 // values from the solution to the Schur complement. 197 VectorRef(y + num_cols_e, num_cols_f) = ConstVectorRef(x, num_cols_f); 198 } 199 200 // Compute the RHS of the Schur complement system. 201 // 202 // rhs = F'b - F'E (E'E)^-1 E'b 203 // 204 // Like BackSubstitute, we use the block structure of A to implement 205 // this using a series of matrix vector products. 206 void ImplicitSchurComplement::UpdateRhs() { 207 // y1 = E'b 208 tmp_e_cols_.setZero(); 209 A_->LeftMultiplyE(b_, tmp_e_cols_.data()); 210 211 // y2 = (E'E)^-1 y1 212 Vector y2 = Vector::Zero(A_->num_cols_e()); 213 block_diagonal_EtE_inverse_->RightMultiply(tmp_e_cols_.data(), y2.data()); 214 215 // y3 = E y2 216 tmp_rows_.setZero(); 217 A_->RightMultiplyE(y2.data(), tmp_rows_.data()); 218 219 // y3 = b - y3 220 tmp_rows_ = ConstVectorRef(b_, A_->num_rows()) - tmp_rows_; 221 222 // rhs = F' y3 223 rhs_.setZero(); 224 A_->LeftMultiplyF(tmp_rows_.data(), rhs_.data()); 225 } 226 227 } // namespace internal 228 } // namespace ceres 229