// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
28 // 29 // Author: sameeragarwal (at) google.com (Sameer Agarwal) 30 31 #include "ceres/implicit_schur_complement.h" 32 33 #include "Eigen/Dense" 34 #include "ceres/block_sparse_matrix.h" 35 #include "ceres/block_structure.h" 36 #include "ceres/internal/eigen.h" 37 #include "ceres/internal/scoped_ptr.h" 38 #include "ceres/linear_solver.h" 39 #include "ceres/types.h" 40 #include "glog/logging.h" 41 42 namespace ceres { 43 namespace internal { 44 45 ImplicitSchurComplement::ImplicitSchurComplement( 46 const LinearSolver::Options& options) 47 : options_(options), 48 D_(NULL), 49 b_(NULL) { 50 } 51 52 ImplicitSchurComplement::~ImplicitSchurComplement() { 53 } 54 55 void ImplicitSchurComplement::Init(const BlockSparseMatrix& A, 56 const double* D, 57 const double* b) { 58 // Since initialization is reasonably heavy, perhaps we can save on 59 // constructing a new object everytime. 60 if (A_ == NULL) { 61 A_.reset(PartitionedMatrixViewBase::Create(options_, A)); 62 } 63 64 D_ = D; 65 b_ = b; 66 67 // Initialize temporary storage and compute the block diagonals of 68 // E'E and F'E. 69 if (block_diagonal_EtE_inverse_ == NULL) { 70 block_diagonal_EtE_inverse_.reset(A_->CreateBlockDiagonalEtE()); 71 if (options_.preconditioner_type == JACOBI) { 72 block_diagonal_FtF_inverse_.reset(A_->CreateBlockDiagonalFtF()); 73 } 74 rhs_.resize(A_->num_cols_f()); 75 rhs_.setZero(); 76 tmp_rows_.resize(A_->num_rows()); 77 tmp_e_cols_.resize(A_->num_cols_e()); 78 tmp_e_cols_2_.resize(A_->num_cols_e()); 79 tmp_f_cols_.resize(A_->num_cols_f()); 80 } else { 81 A_->UpdateBlockDiagonalEtE(block_diagonal_EtE_inverse_.get()); 82 if (options_.preconditioner_type == JACOBI) { 83 A_->UpdateBlockDiagonalFtF(block_diagonal_FtF_inverse_.get()); 84 } 85 } 86 87 // The block diagonals of the augmented linear system contain 88 // contributions from the diagonal D if it is non-null. Add that to 89 // the block diagonals and invert them. 
90 AddDiagonalAndInvert(D_, block_diagonal_EtE_inverse_.get()); 91 if (options_.preconditioner_type == JACOBI) { 92 AddDiagonalAndInvert((D_ == NULL) ? NULL : D_ + A_->num_cols_e(), 93 block_diagonal_FtF_inverse_.get()); 94 } 95 96 // Compute the RHS of the Schur complement system. 97 UpdateRhs(); 98 } 99 100 // Evaluate the product 101 // 102 // Sx = [F'F - F'E (E'E)^-1 E'F]x 103 // 104 // By breaking it down into individual matrix vector products 105 // involving the matrices E and F. This is implemented using a 106 // PartitionedMatrixView of the input matrix A. 107 void ImplicitSchurComplement::RightMultiply(const double* x, double* y) const { 108 // y1 = F x 109 tmp_rows_.setZero(); 110 A_->RightMultiplyF(x, tmp_rows_.data()); 111 112 // y2 = E' y1 113 tmp_e_cols_.setZero(); 114 A_->LeftMultiplyE(tmp_rows_.data(), tmp_e_cols_.data()); 115 116 // y3 = -(E'E)^-1 y2 117 tmp_e_cols_2_.setZero(); 118 block_diagonal_EtE_inverse_->RightMultiply(tmp_e_cols_.data(), 119 tmp_e_cols_2_.data()); 120 tmp_e_cols_2_ *= -1.0; 121 122 // y1 = y1 + E y3 123 A_->RightMultiplyE(tmp_e_cols_2_.data(), tmp_rows_.data()); 124 125 // y5 = D * x 126 if (D_ != NULL) { 127 ConstVectorRef Dref(D_ + A_->num_cols_e(), num_cols()); 128 VectorRef(y, num_cols()) = 129 (Dref.array().square() * 130 ConstVectorRef(x, num_cols()).array()).matrix(); 131 } else { 132 VectorRef(y, num_cols()).setZero(); 133 } 134 135 // y = y5 + F' y1 136 A_->LeftMultiplyF(tmp_rows_.data(), y); 137 } 138 139 // Given a block diagonal matrix and an optional array of diagonal 140 // entries D, add them to the diagonal of the matrix and compute the 141 // inverse of each diagonal block. 
142 void ImplicitSchurComplement::AddDiagonalAndInvert( 143 const double* D, 144 BlockSparseMatrix* block_diagonal) { 145 const CompressedRowBlockStructure* block_diagonal_structure = 146 block_diagonal->block_structure(); 147 for (int r = 0; r < block_diagonal_structure->rows.size(); ++r) { 148 const int row_block_pos = block_diagonal_structure->rows[r].block.position; 149 const int row_block_size = block_diagonal_structure->rows[r].block.size; 150 const Cell& cell = block_diagonal_structure->rows[r].cells[0]; 151 MatrixRef m(block_diagonal->mutable_values() + cell.position, 152 row_block_size, row_block_size); 153 154 if (D != NULL) { 155 ConstVectorRef d(D + row_block_pos, row_block_size); 156 m += d.array().square().matrix().asDiagonal(); 157 } 158 159 m = m 160 .selfadjointView<Eigen::Upper>() 161 .llt() 162 .solve(Matrix::Identity(row_block_size, row_block_size)); 163 } 164 } 165 166 // Similar to RightMultiply, use the block structure of the matrix A 167 // to compute y = (E'E)^-1 (E'b - E'F x). 168 void ImplicitSchurComplement::BackSubstitute(const double* x, double* y) { 169 const int num_cols_e = A_->num_cols_e(); 170 const int num_cols_f = A_->num_cols_f(); 171 const int num_cols = A_->num_cols(); 172 const int num_rows = A_->num_rows(); 173 174 // y1 = F x 175 tmp_rows_.setZero(); 176 A_->RightMultiplyF(x, tmp_rows_.data()); 177 178 // y2 = b - y1 179 tmp_rows_ = ConstVectorRef(b_, num_rows) - tmp_rows_; 180 181 // y3 = E' y2 182 tmp_e_cols_.setZero(); 183 A_->LeftMultiplyE(tmp_rows_.data(), tmp_e_cols_.data()); 184 185 // y = (E'E)^-1 y3 186 VectorRef(y, num_cols).setZero(); 187 block_diagonal_EtE_inverse_->RightMultiply(tmp_e_cols_.data(), y); 188 189 // The full solution vector y has two blocks. The first block of 190 // variables corresponds to the eliminated variables, which we just 191 // computed via back substitution. 
The second block of variables 192 // corresponds to the Schur complement system, so we just copy those 193 // values from the solution to the Schur complement. 194 VectorRef(y + num_cols_e, num_cols_f) = ConstVectorRef(x, num_cols_f); 195 } 196 197 // Compute the RHS of the Schur complement system. 198 // 199 // rhs = F'b - F'E (E'E)^-1 E'b 200 // 201 // Like BackSubstitute, we use the block structure of A to implement 202 // this using a series of matrix vector products. 203 void ImplicitSchurComplement::UpdateRhs() { 204 // y1 = E'b 205 tmp_e_cols_.setZero(); 206 A_->LeftMultiplyE(b_, tmp_e_cols_.data()); 207 208 // y2 = (E'E)^-1 y1 209 Vector y2 = Vector::Zero(A_->num_cols_e()); 210 block_diagonal_EtE_inverse_->RightMultiply(tmp_e_cols_.data(), y2.data()); 211 212 // y3 = E y2 213 tmp_rows_.setZero(); 214 A_->RightMultiplyE(y2.data(), tmp_rows_.data()); 215 216 // y3 = b - y3 217 tmp_rows_ = ConstVectorRef(b_, A_->num_rows()) - tmp_rows_; 218 219 // rhs = F' y3 220 rhs_.setZero(); 221 A_->LeftMultiplyF(tmp_rows_.data(), rhs_.data()); 222 } 223 224 } // namespace internal 225 } // namespace ceres 226