Home | History | Annotate | Download | only in linear
      1 /*
      2  * Licensed to the Apache Software Foundation (ASF) under one or more
      3  * contributor license agreements.  See the NOTICE file distributed with
      4  * this work for additional information regarding copyright ownership.
      5  * The ASF licenses this file to You under the Apache License, Version 2.0
      6  * (the "License"); you may not use this file except in compliance with
      7  * the License.  You may obtain a copy of the License at
      8  *
      9  *      http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  */
     17 
     18 package org.apache.commons.math.linear;
     19 
     20 import java.util.Arrays;
     21 
     22 import org.apache.commons.math.MathRuntimeException;
     23 import org.apache.commons.math.exception.util.LocalizedFormats;
     24 import org.apache.commons.math.util.FastMath;
     25 
     26 
     27 /**
     28  * Calculates the QR-decomposition of a matrix.
     29  * <p>The QR-decomposition of a matrix A consists of two matrices Q and R
     30  * that satisfy: A = QR, Q is orthogonal (Q<sup>T</sup>Q = I), and R is
     31  * upper triangular. If A is m&times;n, Q is m&times;m and R m&times;n.</p>
     32  * <p>This class compute the decomposition using Householder reflectors.</p>
     33  * <p>For efficiency purposes, the decomposition in packed form is transposed.
     34  * This allows inner loop to iterate inside rows, which is much more cache-efficient
     35  * in Java.</p>
     36  *
     37  * @see <a href="http://mathworld.wolfram.com/QRDecomposition.html">MathWorld</a>
     38  * @see <a href="http://en.wikipedia.org/wiki/QR_decomposition">Wikipedia</a>
     39  *
     40  * @version $Revision: 990655 $ $Date: 2010-08-29 23:49:40 +0200 (dim. 29 aot 2010) $
     41  * @since 1.2
     42  */
     43 public class QRDecompositionImpl implements QRDecomposition {
     44 
     45     /**
     46      * A packed TRANSPOSED representation of the QR decomposition.
     47      * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
     48      * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
     49      * from which an explicit form of Q can be recomputed if desired.</p>
     50      */
     51     private double[][] qrt;
     52 
     53     /** The diagonal elements of R. */
     54     private double[] rDiag;
     55 
     56     /** Cached value of Q. */
     57     private RealMatrix cachedQ;
     58 
     59     /** Cached value of QT. */
     60     private RealMatrix cachedQT;
     61 
     62     /** Cached value of R. */
     63     private RealMatrix cachedR;
     64 
     65     /** Cached value of H. */
     66     private RealMatrix cachedH;
     67 
     68     /**
     69      * Calculates the QR-decomposition of the given matrix.
     70      * @param matrix The matrix to decompose.
     71      */
     72     public QRDecompositionImpl(RealMatrix matrix) {
     73 
     74         final int m = matrix.getRowDimension();
     75         final int n = matrix.getColumnDimension();
     76         qrt = matrix.transpose().getData();
     77         rDiag = new double[FastMath.min(m, n)];
     78         cachedQ  = null;
     79         cachedQT = null;
     80         cachedR  = null;
     81         cachedH  = null;
     82 
     83         /*
     84          * The QR decomposition of a matrix A is calculated using Householder
     85          * reflectors by repeating the following operations to each minor
     86          * A(minor,minor) of A:
     87          */
     88         for (int minor = 0; minor < FastMath.min(m, n); minor++) {
     89 
     90             final double[] qrtMinor = qrt[minor];
     91 
     92             /*
     93              * Let x be the first column of the minor, and a^2 = |x|^2.
     94              * x will be in the positions qr[minor][minor] through qr[m][minor].
     95              * The first column of the transformed minor will be (a,0,0,..)'
     96              * The sign of a is chosen to be opposite to the sign of the first
     97              * component of x. Let's find a:
     98              */
     99             double xNormSqr = 0;
    100             for (int row = minor; row < m; row++) {
    101                 final double c = qrtMinor[row];
    102                 xNormSqr += c * c;
    103             }
    104             final double a = (qrtMinor[minor] > 0) ? -FastMath.sqrt(xNormSqr) : FastMath.sqrt(xNormSqr);
    105             rDiag[minor] = a;
    106 
    107             if (a != 0.0) {
    108 
    109                 /*
    110                  * Calculate the normalized reflection vector v and transform
    111                  * the first column. We know the norm of v beforehand: v = x-ae
    112                  * so |v|^2 = <x-ae,x-ae> = <x,x>-2a<x,e>+a^2<e,e> =
    113                  * a^2+a^2-2a<x,e> = 2a*(a - <x,e>).
    114                  * Here <x, e> is now qr[minor][minor].
    115                  * v = x-ae is stored in the column at qr:
    116                  */
    117                 qrtMinor[minor] -= a; // now |v|^2 = -2a*(qr[minor][minor])
    118 
    119                 /*
    120                  * Transform the rest of the columns of the minor:
    121                  * They will be transformed by the matrix H = I-2vv'/|v|^2.
    122                  * If x is a column vector of the minor, then
    123                  * Hx = (I-2vv'/|v|^2)x = x-2vv'x/|v|^2 = x - 2<x,v>/|v|^2 v.
    124                  * Therefore the transformation is easily calculated by
    125                  * subtracting the column vector (2<x,v>/|v|^2)v from x.
    126                  *
    127                  * Let 2<x,v>/|v|^2 = alpha. From above we have
    128                  * |v|^2 = -2a*(qr[minor][minor]), so
    129                  * alpha = -<x,v>/(a*qr[minor][minor])
    130                  */
    131                 for (int col = minor+1; col < n; col++) {
    132                     final double[] qrtCol = qrt[col];
    133                     double alpha = 0;
    134                     for (int row = minor; row < m; row++) {
    135                         alpha -= qrtCol[row] * qrtMinor[row];
    136                     }
    137                     alpha /= a * qrtMinor[minor];
    138 
    139                     // Subtract the column vector alpha*v from x.
    140                     for (int row = minor; row < m; row++) {
    141                         qrtCol[row] -= alpha * qrtMinor[row];
    142                     }
    143                 }
    144             }
    145         }
    146     }
    147 
    148     /** {@inheritDoc} */
    149     public RealMatrix getR() {
    150 
    151         if (cachedR == null) {
    152 
    153             // R is supposed to be m x n
    154             final int n = qrt.length;
    155             final int m = qrt[0].length;
    156             cachedR = MatrixUtils.createRealMatrix(m, n);
    157 
    158             // copy the diagonal from rDiag and the upper triangle of qr
    159             for (int row = FastMath.min(m, n) - 1; row >= 0; row--) {
    160                 cachedR.setEntry(row, row, rDiag[row]);
    161                 for (int col = row + 1; col < n; col++) {
    162                     cachedR.setEntry(row, col, qrt[col][row]);
    163                 }
    164             }
    165 
    166         }
    167 
    168         // return the cached matrix
    169         return cachedR;
    170 
    171     }
    172 
    173     /** {@inheritDoc} */
    174     public RealMatrix getQ() {
    175         if (cachedQ == null) {
    176             cachedQ = getQT().transpose();
    177         }
    178         return cachedQ;
    179     }
    180 
    181     /** {@inheritDoc} */
    182     public RealMatrix getQT() {
    183 
    184         if (cachedQT == null) {
    185 
    186             // QT is supposed to be m x m
    187             final int n = qrt.length;
    188             final int m = qrt[0].length;
    189             cachedQT = MatrixUtils.createRealMatrix(m, m);
    190 
    191             /*
    192              * Q = Q1 Q2 ... Q_m, so Q is formed by first constructing Q_m and then
    193              * applying the Householder transformations Q_(m-1),Q_(m-2),...,Q1 in
    194              * succession to the result
    195              */
    196             for (int minor = m - 1; minor >= FastMath.min(m, n); minor--) {
    197                 cachedQT.setEntry(minor, minor, 1.0);
    198             }
    199 
    200             for (int minor = FastMath.min(m, n)-1; minor >= 0; minor--){
    201                 final double[] qrtMinor = qrt[minor];
    202                 cachedQT.setEntry(minor, minor, 1.0);
    203                 if (qrtMinor[minor] != 0.0) {
    204                     for (int col = minor; col < m; col++) {
    205                         double alpha = 0;
    206                         for (int row = minor; row < m; row++) {
    207                             alpha -= cachedQT.getEntry(col, row) * qrtMinor[row];
    208                         }
    209                         alpha /= rDiag[minor] * qrtMinor[minor];
    210 
    211                         for (int row = minor; row < m; row++) {
    212                             cachedQT.addToEntry(col, row, -alpha * qrtMinor[row]);
    213                         }
    214                     }
    215                 }
    216             }
    217 
    218         }
    219 
    220         // return the cached matrix
    221         return cachedQT;
    222 
    223     }
    224 
    225     /** {@inheritDoc} */
    226     public RealMatrix getH() {
    227 
    228         if (cachedH == null) {
    229 
    230             final int n = qrt.length;
    231             final int m = qrt[0].length;
    232             cachedH = MatrixUtils.createRealMatrix(m, n);
    233             for (int i = 0; i < m; ++i) {
    234                 for (int j = 0; j < FastMath.min(i + 1, n); ++j) {
    235                     cachedH.setEntry(i, j, qrt[j][i] / -rDiag[j]);
    236                 }
    237             }
    238 
    239         }
    240 
    241         // return the cached matrix
    242         return cachedH;
    243 
    244     }
    245 
    246     /** {@inheritDoc} */
    247     public DecompositionSolver getSolver() {
    248         return new Solver(qrt, rDiag);
    249     }
    250 
    251     /** Specialized solver. */
    252     private static class Solver implements DecompositionSolver {
    253 
    254         /**
    255          * A packed TRANSPOSED representation of the QR decomposition.
    256          * <p>The elements BELOW the diagonal are the elements of the UPPER triangular
    257          * matrix R, and the rows ABOVE the diagonal are the Householder reflector vectors
    258          * from which an explicit form of Q can be recomputed if desired.</p>
    259          */
    260         private final double[][] qrt;
    261 
    262         /** The diagonal elements of R. */
    263         private final double[] rDiag;
    264 
    265         /**
    266          * Build a solver from decomposed matrix.
    267          * @param qrt packed TRANSPOSED representation of the QR decomposition
    268          * @param rDiag diagonal elements of R
    269          */
    270         private Solver(final double[][] qrt, final double[] rDiag) {
    271             this.qrt   = qrt;
    272             this.rDiag = rDiag;
    273         }
    274 
    275         /** {@inheritDoc} */
    276         public boolean isNonSingular() {
    277 
    278             for (double diag : rDiag) {
    279                 if (diag == 0) {
    280                     return false;
    281                 }
    282             }
    283             return true;
    284 
    285         }
    286 
    287         /** {@inheritDoc} */
    288         public double[] solve(double[] b)
    289         throws IllegalArgumentException, InvalidMatrixException {
    290 
    291             final int n = qrt.length;
    292             final int m = qrt[0].length;
    293             if (b.length != m) {
    294                 throw MathRuntimeException.createIllegalArgumentException(
    295                         LocalizedFormats.VECTOR_LENGTH_MISMATCH,
    296                         b.length, m);
    297             }
    298             if (!isNonSingular()) {
    299                 throw new SingularMatrixException();
    300             }
    301 
    302             final double[] x = new double[n];
    303             final double[] y = b.clone();
    304 
    305             // apply Householder transforms to solve Q.y = b
    306             for (int minor = 0; minor < FastMath.min(m, n); minor++) {
    307 
    308                 final double[] qrtMinor = qrt[minor];
    309                 double dotProduct = 0;
    310                 for (int row = minor; row < m; row++) {
    311                     dotProduct += y[row] * qrtMinor[row];
    312                 }
    313                 dotProduct /= rDiag[minor] * qrtMinor[minor];
    314 
    315                 for (int row = minor; row < m; row++) {
    316                     y[row] += dotProduct * qrtMinor[row];
    317                 }
    318 
    319             }
    320 
    321             // solve triangular system R.x = y
    322             for (int row = rDiag.length - 1; row >= 0; --row) {
    323                 y[row] /= rDiag[row];
    324                 final double yRow   = y[row];
    325                 final double[] qrtRow = qrt[row];
    326                 x[row] = yRow;
    327                 for (int i = 0; i < row; i++) {
    328                     y[i] -= yRow * qrtRow[i];
    329                 }
    330             }
    331 
    332             return x;
    333 
    334         }
    335 
    336         /** {@inheritDoc} */
    337         public RealVector solve(RealVector b)
    338         throws IllegalArgumentException, InvalidMatrixException {
    339             try {
    340                 return solve((ArrayRealVector) b);
    341             } catch (ClassCastException cce) {
    342                 return new ArrayRealVector(solve(b.getData()), false);
    343             }
    344         }
    345 
    346         /** Solve the linear equation A &times; X = B.
    347          * <p>The A matrix is implicit here. It is </p>
    348          * @param b right-hand side of the equation A &times; X = B
    349          * @return a vector X that minimizes the two norm of A &times; X - B
    350          * @throws IllegalArgumentException if matrices dimensions don't match
    351          * @throws InvalidMatrixException if decomposed matrix is singular
    352          */
    353         public ArrayRealVector solve(ArrayRealVector b)
    354         throws IllegalArgumentException, InvalidMatrixException {
    355             return new ArrayRealVector(solve(b.getDataRef()), false);
    356         }
    357 
    358         /** {@inheritDoc} */
    359         public RealMatrix solve(RealMatrix b)
    360         throws IllegalArgumentException, InvalidMatrixException {
    361 
    362             final int n = qrt.length;
    363             final int m = qrt[0].length;
    364             if (b.getRowDimension() != m) {
    365                 throw MathRuntimeException.createIllegalArgumentException(
    366                         LocalizedFormats.DIMENSIONS_MISMATCH_2x2,
    367                         b.getRowDimension(), b.getColumnDimension(), m, "n");
    368             }
    369             if (!isNonSingular()) {
    370                 throw new SingularMatrixException();
    371             }
    372 
    373             final int columns        = b.getColumnDimension();
    374             final int blockSize      = BlockRealMatrix.BLOCK_SIZE;
    375             final int cBlocks        = (columns + blockSize - 1) / blockSize;
    376             final double[][] xBlocks = BlockRealMatrix.createBlocksLayout(n, columns);
    377             final double[][] y       = new double[b.getRowDimension()][blockSize];
    378             final double[]   alpha   = new double[blockSize];
    379 
    380             for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
    381                 final int kStart = kBlock * blockSize;
    382                 final int kEnd   = FastMath.min(kStart + blockSize, columns);
    383                 final int kWidth = kEnd - kStart;
    384 
    385                 // get the right hand side vector
    386                 b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
    387 
    388                 // apply Householder transforms to solve Q.y = b
    389                 for (int minor = 0; minor < FastMath.min(m, n); minor++) {
    390                     final double[] qrtMinor = qrt[minor];
    391                     final double factor     = 1.0 / (rDiag[minor] * qrtMinor[minor]);
    392 
    393                     Arrays.fill(alpha, 0, kWidth, 0.0);
    394                     for (int row = minor; row < m; ++row) {
    395                         final double   d    = qrtMinor[row];
    396                         final double[] yRow = y[row];
    397                         for (int k = 0; k < kWidth; ++k) {
    398                             alpha[k] += d * yRow[k];
    399                         }
    400                     }
    401                     for (int k = 0; k < kWidth; ++k) {
    402                         alpha[k] *= factor;
    403                     }
    404 
    405                     for (int row = minor; row < m; ++row) {
    406                         final double   d    = qrtMinor[row];
    407                         final double[] yRow = y[row];
    408                         for (int k = 0; k < kWidth; ++k) {
    409                             yRow[k] += alpha[k] * d;
    410                         }
    411                     }
    412 
    413                 }
    414 
    415                 // solve triangular system R.x = y
    416                 for (int j = rDiag.length - 1; j >= 0; --j) {
    417                     final int      jBlock = j / blockSize;
    418                     final int      jStart = jBlock * blockSize;
    419                     final double   factor = 1.0 / rDiag[j];
    420                     final double[] yJ     = y[j];
    421                     final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
    422                     int index = (j - jStart) * kWidth;
    423                     for (int k = 0; k < kWidth; ++k) {
    424                         yJ[k]          *= factor;
    425                         xBlock[index++] = yJ[k];
    426                     }
    427 
    428                     final double[] qrtJ = qrt[j];
    429                     for (int i = 0; i < j; ++i) {
    430                         final double rIJ  = qrtJ[i];
    431                         final double[] yI = y[i];
    432                         for (int k = 0; k < kWidth; ++k) {
    433                             yI[k] -= yJ[k] * rIJ;
    434                         }
    435                     }
    436 
    437                 }
    438 
    439             }
    440 
    441             return new BlockRealMatrix(n, columns, xBlocks, false);
    442 
    443         }
    444 
    445         /** {@inheritDoc} */
    446         public RealMatrix getInverse()
    447         throws InvalidMatrixException {
    448             return solve(MatrixUtils.createRealIdentityMatrix(rDiag.length));
    449         }
    450 
    451     }
    452 
    453 }
    454