// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
28 // 29 // Author: sameeragarwal (at) google.com (Sameer Agarwal) 30 31 #include "ceres/linear_least_squares_problems.h" 32 33 #include <cstdio> 34 #include <string> 35 #include <vector> 36 #include "ceres/block_sparse_matrix.h" 37 #include "ceres/block_structure.h" 38 #include "ceres/casts.h" 39 #include "ceres/compressed_row_sparse_matrix.h" 40 #include "ceres/file.h" 41 #include "ceres/internal/scoped_ptr.h" 42 #include "ceres/matrix_proto.h" 43 #include "ceres/stringprintf.h" 44 #include "ceres/triplet_sparse_matrix.h" 45 #include "ceres/types.h" 46 #include "glog/logging.h" 47 48 namespace ceres { 49 namespace internal { 50 51 LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromId(int id) { 52 switch (id) { 53 case 0: 54 return LinearLeastSquaresProblem0(); 55 case 1: 56 return LinearLeastSquaresProblem1(); 57 case 2: 58 return LinearLeastSquaresProblem2(); 59 case 3: 60 return LinearLeastSquaresProblem3(); 61 default: 62 LOG(FATAL) << "Unknown problem id requested " << id; 63 } 64 return NULL; 65 } 66 67 #ifndef CERES_NO_PROTOCOL_BUFFERS 68 LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromFile( 69 const string& filename) { 70 LinearLeastSquaresProblemProto problem_proto; 71 { 72 string serialized_proto; 73 ReadFileToStringOrDie(filename, &serialized_proto); 74 CHECK(problem_proto.ParseFromString(serialized_proto)); 75 } 76 77 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem; 78 const SparseMatrixProto& A = problem_proto.a(); 79 80 if (A.has_block_matrix()) { 81 problem->A.reset(new BlockSparseMatrix(A)); 82 } else if (A.has_triplet_matrix()) { 83 problem->A.reset(new TripletSparseMatrix(A)); 84 } else { 85 problem->A.reset(new CompressedRowSparseMatrix(A)); 86 } 87 88 if (problem_proto.b_size() > 0) { 89 problem->b.reset(new double[problem_proto.b_size()]); 90 for (int i = 0; i < problem_proto.b_size(); ++i) { 91 problem->b[i] = problem_proto.b(i); 92 } 93 } 94 95 if (problem_proto.d_size() > 0) { 96 
problem->D.reset(new double[problem_proto.d_size()]); 97 for (int i = 0; i < problem_proto.d_size(); ++i) { 98 problem->D[i] = problem_proto.d(i); 99 } 100 } 101 102 if (problem_proto.d_size() > 0) { 103 if (problem_proto.x_size() > 0) { 104 problem->x_D.reset(new double[problem_proto.x_size()]); 105 for (int i = 0; i < problem_proto.x_size(); ++i) { 106 problem->x_D[i] = problem_proto.x(i); 107 } 108 } 109 } else { 110 if (problem_proto.x_size() > 0) { 111 problem->x.reset(new double[problem_proto.x_size()]); 112 for (int i = 0; i < problem_proto.x_size(); ++i) { 113 problem->x[i] = problem_proto.x(i); 114 } 115 } 116 } 117 118 problem->num_eliminate_blocks = 0; 119 if (problem_proto.has_num_eliminate_blocks()) { 120 problem->num_eliminate_blocks = problem_proto.num_eliminate_blocks(); 121 } 122 123 return problem; 124 } 125 #else 126 LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromFile( 127 const string& filename) { 128 LOG(FATAL) 129 << "Loading a least squares problem from disk requires " 130 << "Ceres to be built with Protocol Buffers support."; 131 return NULL; 132 } 133 #endif // CERES_NO_PROTOCOL_BUFFERS 134 135 /* 136 A = [1 2] 137 [3 4] 138 [6 -10] 139 140 b = [ 8 141 18 142 -18] 143 144 x = [2 145 3] 146 147 D = [1 148 2] 149 150 x_D = [1.78448275; 151 2.82327586;] 152 */ 153 LinearLeastSquaresProblem* LinearLeastSquaresProblem0() { 154 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem; 155 156 TripletSparseMatrix* A = new TripletSparseMatrix(3, 2, 6); 157 problem->b.reset(new double[3]); 158 problem->D.reset(new double[2]); 159 160 problem->x.reset(new double[2]); 161 problem->x_D.reset(new double[2]); 162 163 int* Ai = A->mutable_rows(); 164 int* Aj = A->mutable_cols(); 165 double* Ax = A->mutable_values(); 166 167 int counter = 0; 168 for (int i = 0; i < 3; ++i) { 169 for (int j = 0; j< 2; ++j) { 170 Ai[counter]=i; 171 Aj[counter]=j; 172 ++counter; 173 } 174 }; 175 176 Ax[0] = 1.; 177 Ax[1] = 2.; 178 Ax[2] = 3.; 179 
Ax[3] = 4.; 180 Ax[4] = 6; 181 Ax[5] = -10; 182 A->set_num_nonzeros(6); 183 problem->A.reset(A); 184 185 problem->b[0] = 8; 186 problem->b[1] = 18; 187 problem->b[2] = -18; 188 189 problem->x[0] = 2.0; 190 problem->x[1] = 3.0; 191 192 problem->D[0] = 1; 193 problem->D[1] = 2; 194 195 problem->x_D[0] = 1.78448275; 196 problem->x_D[1] = 2.82327586; 197 return problem; 198 } 199 200 201 /* 202 A = [1 0 | 2 0 0 203 3 0 | 0 4 0 204 0 5 | 0 0 6 205 0 7 | 8 0 0 206 0 9 | 1 0 0 207 0 0 | 1 1 1] 208 209 b = [0 210 1 211 2 212 3 213 4 214 5] 215 216 c = A'* b = [ 3 217 67 218 33 219 9 220 17] 221 222 A'A = [10 0 2 12 0 223 0 155 65 0 30 224 2 65 70 1 1 225 12 0 1 17 1 226 0 30 1 1 37] 227 228 S = [ 42.3419 -1.4000 -11.5806 229 -1.4000 2.6000 1.0000 230 11.5806 1.0000 31.1935] 231 232 r = [ 4.3032 233 5.4000 234 5.0323] 235 236 S\r = [ 0.2102 237 2.1367 238 0.1388] 239 240 A\b = [-2.3061 241 0.3172 242 0.2102 243 2.1367 244 0.1388] 245 */ 246 // The following two functions create a TripletSparseMatrix and a 247 // BlockSparseMatrix version of this problem. 248 249 // TripletSparseMatrix version. 
250 LinearLeastSquaresProblem* LinearLeastSquaresProblem1() { 251 int num_rows = 6; 252 int num_cols = 5; 253 254 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem; 255 TripletSparseMatrix* A = new TripletSparseMatrix(num_rows, 256 num_cols, 257 num_rows * num_cols); 258 problem->b.reset(new double[num_rows]); 259 problem->D.reset(new double[num_cols]); 260 problem->num_eliminate_blocks = 2; 261 262 int* rows = A->mutable_rows(); 263 int* cols = A->mutable_cols(); 264 double* values = A->mutable_values(); 265 266 int nnz = 0; 267 268 // Row 1 269 { 270 rows[nnz] = 0; 271 cols[nnz] = 0; 272 values[nnz++] = 1; 273 274 rows[nnz] = 0; 275 cols[nnz] = 2; 276 values[nnz++] = 2; 277 } 278 279 // Row 2 280 { 281 rows[nnz] = 1; 282 cols[nnz] = 0; 283 values[nnz++] = 3; 284 285 rows[nnz] = 1; 286 cols[nnz] = 3; 287 values[nnz++] = 4; 288 } 289 290 // Row 3 291 { 292 rows[nnz] = 2; 293 cols[nnz] = 1; 294 values[nnz++] = 5; 295 296 rows[nnz] = 2; 297 cols[nnz] = 4; 298 values[nnz++] = 6; 299 } 300 301 // Row 4 302 { 303 rows[nnz] = 3; 304 cols[nnz] = 1; 305 values[nnz++] = 7; 306 307 rows[nnz] = 3; 308 cols[nnz] = 2; 309 values[nnz++] = 8; 310 } 311 312 // Row 5 313 { 314 rows[nnz] = 4; 315 cols[nnz] = 1; 316 values[nnz++] = 9; 317 318 rows[nnz] = 4; 319 cols[nnz] = 2; 320 values[nnz++] = 1; 321 } 322 323 // Row 6 324 { 325 rows[nnz] = 5; 326 cols[nnz] = 2; 327 values[nnz++] = 1; 328 329 rows[nnz] = 5; 330 cols[nnz] = 3; 331 values[nnz++] = 1; 332 333 rows[nnz] = 5; 334 cols[nnz] = 4; 335 values[nnz++] = 1; 336 } 337 338 A->set_num_nonzeros(nnz); 339 CHECK(A->IsValid()); 340 341 problem->A.reset(A); 342 343 for (int i = 0; i < num_cols; ++i) { 344 problem->D.get()[i] = 1; 345 } 346 347 for (int i = 0; i < num_rows; ++i) { 348 problem->b.get()[i] = i; 349 } 350 351 return problem; 352 } 353 354 // BlockSparseMatrix version 355 LinearLeastSquaresProblem* LinearLeastSquaresProblem2() { 356 int num_rows = 6; 357 int num_cols = 5; 358 359 LinearLeastSquaresProblem* 
problem = new LinearLeastSquaresProblem; 360 361 problem->b.reset(new double[num_rows]); 362 problem->D.reset(new double[num_cols]); 363 problem->num_eliminate_blocks = 2; 364 365 CompressedRowBlockStructure* bs = new CompressedRowBlockStructure; 366 scoped_array<double> values(new double[num_rows * num_cols]); 367 368 for (int c = 0; c < num_cols; ++c) { 369 bs->cols.push_back(Block()); 370 bs->cols.back().size = 1; 371 bs->cols.back().position = c; 372 } 373 374 int nnz = 0; 375 376 // Row 1 377 { 378 values[nnz++] = 1; 379 values[nnz++] = 2; 380 381 bs->rows.push_back(CompressedRow()); 382 CompressedRow& row = bs->rows.back(); 383 row.block.size = 1; 384 row.block.position = 0; 385 row.cells.push_back(Cell(0, 0)); 386 row.cells.push_back(Cell(2, 1)); 387 } 388 389 // Row 2 390 { 391 values[nnz++] = 3; 392 values[nnz++] = 4; 393 394 bs->rows.push_back(CompressedRow()); 395 CompressedRow& row = bs->rows.back(); 396 row.block.size = 1; 397 row.block.position = 1; 398 row.cells.push_back(Cell(0, 2)); 399 row.cells.push_back(Cell(3, 3)); 400 } 401 402 // Row 3 403 { 404 values[nnz++] = 5; 405 values[nnz++] = 6; 406 407 bs->rows.push_back(CompressedRow()); 408 CompressedRow& row = bs->rows.back(); 409 row.block.size = 1; 410 row.block.position = 2; 411 row.cells.push_back(Cell(1, 4)); 412 row.cells.push_back(Cell(4, 5)); 413 } 414 415 // Row 4 416 { 417 values[nnz++] = 7; 418 values[nnz++] = 8; 419 420 bs->rows.push_back(CompressedRow()); 421 CompressedRow& row = bs->rows.back(); 422 row.block.size = 1; 423 row.block.position = 3; 424 row.cells.push_back(Cell(1, 6)); 425 row.cells.push_back(Cell(2, 7)); 426 } 427 428 // Row 5 429 { 430 values[nnz++] = 9; 431 values[nnz++] = 1; 432 433 bs->rows.push_back(CompressedRow()); 434 CompressedRow& row = bs->rows.back(); 435 row.block.size = 1; 436 row.block.position = 4; 437 row.cells.push_back(Cell(1, 8)); 438 row.cells.push_back(Cell(2, 9)); 439 } 440 441 // Row 6 442 { 443 values[nnz++] = 1; 444 values[nnz++] = 1; 445 
values[nnz++] = 1; 446 447 bs->rows.push_back(CompressedRow()); 448 CompressedRow& row = bs->rows.back(); 449 row.block.size = 1; 450 row.block.position = 5; 451 row.cells.push_back(Cell(2, 10)); 452 row.cells.push_back(Cell(3, 11)); 453 row.cells.push_back(Cell(4, 12)); 454 } 455 456 BlockSparseMatrix* A = new BlockSparseMatrix(bs); 457 memcpy(A->mutable_values(), values.get(), nnz * sizeof(*A->values())); 458 459 for (int i = 0; i < num_cols; ++i) { 460 problem->D.get()[i] = 1; 461 } 462 463 for (int i = 0; i < num_rows; ++i) { 464 problem->b.get()[i] = i; 465 } 466 467 problem->A.reset(A); 468 469 return problem; 470 } 471 472 473 /* 474 A = [1 0 475 3 0 476 0 5 477 0 7 478 0 9 479 0 0] 480 481 b = [0 482 1 483 2 484 3 485 4 486 5] 487 */ 488 // BlockSparseMatrix version 489 LinearLeastSquaresProblem* LinearLeastSquaresProblem3() { 490 int num_rows = 5; 491 int num_cols = 2; 492 493 LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem; 494 495 problem->b.reset(new double[num_rows]); 496 problem->D.reset(new double[num_cols]); 497 problem->num_eliminate_blocks = 2; 498 499 CompressedRowBlockStructure* bs = new CompressedRowBlockStructure; 500 scoped_array<double> values(new double[num_rows * num_cols]); 501 502 for (int c = 0; c < num_cols; ++c) { 503 bs->cols.push_back(Block()); 504 bs->cols.back().size = 1; 505 bs->cols.back().position = c; 506 } 507 508 int nnz = 0; 509 510 // Row 1 511 { 512 values[nnz++] = 1; 513 bs->rows.push_back(CompressedRow()); 514 CompressedRow& row = bs->rows.back(); 515 row.block.size = 1; 516 row.block.position = 0; 517 row.cells.push_back(Cell(0, 0)); 518 } 519 520 // Row 2 521 { 522 values[nnz++] = 3; 523 bs->rows.push_back(CompressedRow()); 524 CompressedRow& row = bs->rows.back(); 525 row.block.size = 1; 526 row.block.position = 1; 527 row.cells.push_back(Cell(0, 1)); 528 } 529 530 // Row 3 531 { 532 values[nnz++] = 5; 533 bs->rows.push_back(CompressedRow()); 534 CompressedRow& row = bs->rows.back(); 535 
row.block.size = 1; 536 row.block.position = 2; 537 row.cells.push_back(Cell(1, 2)); 538 } 539 540 // Row 4 541 { 542 values[nnz++] = 7; 543 bs->rows.push_back(CompressedRow()); 544 CompressedRow& row = bs->rows.back(); 545 row.block.size = 1; 546 row.block.position = 3; 547 row.cells.push_back(Cell(1, 3)); 548 } 549 550 // Row 5 551 { 552 values[nnz++] = 9; 553 bs->rows.push_back(CompressedRow()); 554 CompressedRow& row = bs->rows.back(); 555 row.block.size = 1; 556 row.block.position = 4; 557 row.cells.push_back(Cell(1, 4)); 558 } 559 560 BlockSparseMatrix* A = new BlockSparseMatrix(bs); 561 memcpy(A->mutable_values(), values.get(), nnz * sizeof(*A->values())); 562 563 for (int i = 0; i < num_cols; ++i) { 564 problem->D.get()[i] = 1; 565 } 566 567 for (int i = 0; i < num_rows; ++i) { 568 problem->b.get()[i] = i; 569 } 570 571 problem->A.reset(A); 572 573 return problem; 574 } 575 576 bool DumpLinearLeastSquaresProblemToConsole(const string& directory, 577 int iteration, 578 const SparseMatrix* A, 579 const double* D, 580 const double* b, 581 const double* x, 582 int num_eliminate_blocks) { 583 CHECK_NOTNULL(A); 584 Matrix AA; 585 A->ToDenseMatrix(&AA); 586 LOG(INFO) << "A^T: \n" << AA.transpose(); 587 588 if (D != NULL) { 589 LOG(INFO) << "A's appended diagonal:\n" 590 << ConstVectorRef(D, A->num_cols()); 591 } 592 593 if (b != NULL) { 594 LOG(INFO) << "b: \n" << ConstVectorRef(b, A->num_rows()); 595 } 596 597 if (x != NULL) { 598 LOG(INFO) << "x: \n" << ConstVectorRef(x, A->num_cols()); 599 } 600 return true; 601 }; 602 603 #ifndef CERES_NO_PROTOCOL_BUFFERS 604 bool DumpLinearLeastSquaresProblemToProtocolBuffer(const string& directory, 605 int iteration, 606 const SparseMatrix* A, 607 const double* D, 608 const double* b, 609 const double* x, 610 int num_eliminate_blocks) { 611 CHECK_NOTNULL(A); 612 LinearLeastSquaresProblemProto lsqp; 613 A->ToProto(lsqp.mutable_a()); 614 615 if (D != NULL) { 616 for (int i = 0; i < A->num_cols(); ++i) { 617 lsqp.add_d(D[i]); 
618 } 619 } 620 621 if (b != NULL) { 622 for (int i = 0; i < A->num_rows(); ++i) { 623 lsqp.add_b(b[i]); 624 } 625 } 626 627 if (x != NULL) { 628 for (int i = 0; i < A->num_cols(); ++i) { 629 lsqp.add_x(x[i]); 630 } 631 } 632 633 lsqp.set_num_eliminate_blocks(num_eliminate_blocks); 634 string format_string = JoinPath(directory, 635 "lm_iteration_%03d.lsqp"); 636 string filename = 637 StringPrintf(format_string.c_str(), iteration); 638 LOG(INFO) << "Dumping least squares problem for iteration " << iteration 639 << " to disk. File: " << filename; 640 WriteStringToFileOrDie(lsqp.SerializeAsString(), filename); 641 return true; 642 } 643 #else 644 bool DumpLinearLeastSquaresProblemToProtocolBuffer(const string& directory, 645 int iteration, 646 const SparseMatrix* A, 647 const double* D, 648 const double* b, 649 const double* x, 650 int num_eliminate_blocks) { 651 LOG(ERROR) << "Dumping least squares problems is only " 652 << "supported when Ceres is compiled with " 653 << "protocol buffer support."; 654 return false; 655 } 656 #endif 657 658 void WriteArrayToFileOrDie(const string& filename, 659 const double* x, 660 const int size) { 661 CHECK_NOTNULL(x); 662 VLOG(2) << "Writing array to: " << filename; 663 FILE* fptr = fopen(filename.c_str(), "w"); 664 CHECK_NOTNULL(fptr); 665 for (int i = 0; i < size; ++i) { 666 fprintf(fptr, "%17f\n", x[i]); 667 } 668 fclose(fptr); 669 } 670 671 bool DumpLinearLeastSquaresProblemToTextFile(const string& directory, 672 int iteration, 673 const SparseMatrix* A, 674 const double* D, 675 const double* b, 676 const double* x, 677 int num_eliminate_blocks) { 678 CHECK_NOTNULL(A); 679 string format_string = JoinPath(directory, 680 "lm_iteration_%03d"); 681 string filename_prefix = 682 StringPrintf(format_string.c_str(), iteration); 683 684 LOG(INFO) << "writing to: " << filename_prefix << "*"; 685 686 string matlab_script; 687 StringAppendF(&matlab_script, 688 "function lsqp = lm_iteration_%03d()\n", iteration); 689 
StringAppendF(&matlab_script, 690 "lsqp.num_rows = %d;\n", A->num_rows()); 691 StringAppendF(&matlab_script, 692 "lsqp.num_cols = %d;\n", A->num_cols()); 693 694 { 695 string filename = filename_prefix + "_A.txt"; 696 FILE* fptr = fopen(filename.c_str(), "w"); 697 CHECK_NOTNULL(fptr); 698 A->ToTextFile(fptr); 699 fclose(fptr); 700 StringAppendF(&matlab_script, 701 "tmp = load('%s', '-ascii');\n", filename.c_str()); 702 StringAppendF( 703 &matlab_script, 704 "lsqp.A = sparse(tmp(:, 1) + 1, tmp(:, 2) + 1, tmp(:, 3), %d, %d);\n", 705 A->num_rows(), 706 A->num_cols()); 707 } 708 709 710 if (D != NULL) { 711 string filename = filename_prefix + "_D.txt"; 712 WriteArrayToFileOrDie(filename, D, A->num_cols()); 713 StringAppendF(&matlab_script, 714 "lsqp.D = load('%s', '-ascii');\n", filename.c_str()); 715 } 716 717 if (b != NULL) { 718 string filename = filename_prefix + "_b.txt"; 719 WriteArrayToFileOrDie(filename, b, A->num_rows()); 720 StringAppendF(&matlab_script, 721 "lsqp.b = load('%s', '-ascii');\n", filename.c_str()); 722 } 723 724 if (x != NULL) { 725 string filename = filename_prefix + "_x.txt"; 726 WriteArrayToFileOrDie(filename, x, A->num_cols()); 727 StringAppendF(&matlab_script, 728 "lsqp.x = load('%s', '-ascii');\n", filename.c_str()); 729 } 730 731 string matlab_filename = filename_prefix + ".m"; 732 WriteStringToFileOrDie(matlab_script, matlab_filename); 733 return true; 734 } 735 736 bool DumpLinearLeastSquaresProblem(const string& directory, 737 int iteration, 738 DumpFormatType dump_format_type, 739 const SparseMatrix* A, 740 const double* D, 741 const double* b, 742 const double* x, 743 int num_eliminate_blocks) { 744 switch (dump_format_type) { 745 case (CONSOLE): 746 return DumpLinearLeastSquaresProblemToConsole(directory, 747 iteration, 748 A, D, b, x, 749 num_eliminate_blocks); 750 case (PROTOBUF): 751 return DumpLinearLeastSquaresProblemToProtocolBuffer( 752 directory, 753 iteration, 754 A, D, b, x, 755 num_eliminate_blocks); 756 case (TEXTFILE): 
757 return DumpLinearLeastSquaresProblemToTextFile(directory, 758 iteration, 759 A, D, b, x, 760 num_eliminate_blocks); 761 default: 762 LOG(FATAL) << "Unknown DumpFormatType " << dump_format_type; 763 }; 764 765 return true; 766 } 767 768 } // namespace internal 769 } // namespace ceres 770