Home | History | Annotate | Download | only in linear
      1 /*
      2  * Licensed to the Apache Software Foundation (ASF) under one or more
      3  * contributor license agreements.  See the NOTICE file distributed with
      4  * this work for additional information regarding copyright ownership.
      5  * The ASF licenses this file to You under the Apache License, Version 2.0
      6  * (the "License"); you may not use this file except in compliance with
      7  * the License.  You may obtain a copy of the License at
      8  *
      9  *      http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  */
     17 
     18 package org.apache.commons.math.linear;
     19 
     20 import java.lang.reflect.Array;
     21 
     22 import org.apache.commons.math.Field;
     23 import org.apache.commons.math.FieldElement;
     24 import org.apache.commons.math.MathRuntimeException;
     25 import org.apache.commons.math.exception.util.LocalizedFormats;
     26 
     27 /**
     28  * Calculates the LUP-decomposition of a square matrix.
     29  * <p>The LUP-decomposition of a matrix A consists of three matrices
     30  * L, U and P that satisfy: PA = LU, L is lower triangular, and U is
     31  * upper triangular and P is a permutation matrix. All matrices are
     32  * m&times;m.</p>
     33  * <p>Since {@link FieldElement field elements} do not provide an ordering
     34  * operator, the permutation matrix is computed here only in order to avoid
     35  * a zero pivot element, no attempt is done to get the largest pivot element.</p>
     36  *
     37  * @param <T> the type of the field elements
     38  * @version $Revision: 983921 $ $Date: 2010-08-10 12:46:06 +0200 (mar. 10 aot 2010) $
     39  * @since 2.0
     40  */
     41 public class FieldLUDecompositionImpl<T extends FieldElement<T>> implements FieldLUDecomposition<T> {
     42 
     43     /** Field to which the elements belong. */
     44     private final Field<T> field;
     45 
     46     /** Entries of LU decomposition. */
     47     private T lu[][];
     48 
     49     /** Pivot permutation associated with LU decomposition */
     50     private int[] pivot;
     51 
     52     /** Parity of the permutation associated with the LU decomposition */
     53     private boolean even;
     54 
     55     /** Singularity indicator. */
     56     private boolean singular;
     57 
     58     /** Cached value of L. */
     59     private FieldMatrix<T> cachedL;
     60 
     61     /** Cached value of U. */
     62     private FieldMatrix<T> cachedU;
     63 
     64     /** Cached value of P. */
     65     private FieldMatrix<T> cachedP;
     66 
     67     /**
     68      * Calculates the LU-decomposition of the given matrix.
     69      * @param matrix The matrix to decompose.
     70      * @exception NonSquareMatrixException if matrix is not square
     71      */
     72     public FieldLUDecompositionImpl(FieldMatrix<T> matrix)
     73         throws NonSquareMatrixException {
     74 
     75         if (!matrix.isSquare()) {
     76             throw new NonSquareMatrixException(matrix.getRowDimension(), matrix.getColumnDimension());
     77         }
     78 
     79         final int m = matrix.getColumnDimension();
     80         field = matrix.getField();
     81         lu = matrix.getData();
     82         pivot = new int[m];
     83         cachedL = null;
     84         cachedU = null;
     85         cachedP = null;
     86 
     87         // Initialize permutation array and parity
     88         for (int row = 0; row < m; row++) {
     89             pivot[row] = row;
     90         }
     91         even     = true;
     92         singular = false;
     93 
     94         // Loop over columns
     95         for (int col = 0; col < m; col++) {
     96 
     97             T sum = field.getZero();
     98 
     99             // upper
    100             for (int row = 0; row < col; row++) {
    101                 final T[] luRow = lu[row];
    102                 sum = luRow[col];
    103                 for (int i = 0; i < row; i++) {
    104                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
    105                 }
    106                 luRow[col] = sum;
    107             }
    108 
    109             // lower
    110             int nonZero = col; // permutation row
    111             for (int row = col; row < m; row++) {
    112                 final T[] luRow = lu[row];
    113                 sum = luRow[col];
    114                 for (int i = 0; i < col; i++) {
    115                     sum = sum.subtract(luRow[i].multiply(lu[i][col]));
    116                 }
    117                 luRow[col] = sum;
    118 
    119                 if (lu[nonZero][col].equals(field.getZero())) {
    120                     // try to select a better permutation choice
    121                     ++nonZero;
    122                 }
    123             }
    124 
    125             // Singularity check
    126             if (nonZero >= m) {
    127                 singular = true;
    128                 return;
    129             }
    130 
    131             // Pivot if necessary
    132             if (nonZero != col) {
    133                 T tmp = field.getZero();
    134                 for (int i = 0; i < m; i++) {
    135                     tmp = lu[nonZero][i];
    136                     lu[nonZero][i] = lu[col][i];
    137                     lu[col][i] = tmp;
    138                 }
    139                 int temp = pivot[nonZero];
    140                 pivot[nonZero] = pivot[col];
    141                 pivot[col] = temp;
    142                 even = !even;
    143             }
    144 
    145             // Divide the lower elements by the "winning" diagonal elt.
    146             final T luDiag = lu[col][col];
    147             for (int row = col + 1; row < m; row++) {
    148                 final T[] luRow = lu[row];
    149                 luRow[col] = luRow[col].divide(luDiag);
    150             }
    151         }
    152 
    153     }
    154 
    155     /** {@inheritDoc} */
    156     public FieldMatrix<T> getL() {
    157         if ((cachedL == null) && !singular) {
    158             final int m = pivot.length;
    159             cachedL = new Array2DRowFieldMatrix<T>(field, m, m);
    160             for (int i = 0; i < m; ++i) {
    161                 final T[] luI = lu[i];
    162                 for (int j = 0; j < i; ++j) {
    163                     cachedL.setEntry(i, j, luI[j]);
    164                 }
    165                 cachedL.setEntry(i, i, field.getOne());
    166             }
    167         }
    168         return cachedL;
    169     }
    170 
    171     /** {@inheritDoc} */
    172     public FieldMatrix<T> getU() {
    173         if ((cachedU == null) && !singular) {
    174             final int m = pivot.length;
    175             cachedU = new Array2DRowFieldMatrix<T>(field, m, m);
    176             for (int i = 0; i < m; ++i) {
    177                 final T[] luI = lu[i];
    178                 for (int j = i; j < m; ++j) {
    179                     cachedU.setEntry(i, j, luI[j]);
    180                 }
    181             }
    182         }
    183         return cachedU;
    184     }
    185 
    186     /** {@inheritDoc} */
    187     public FieldMatrix<T> getP() {
    188         if ((cachedP == null) && !singular) {
    189             final int m = pivot.length;
    190             cachedP = new Array2DRowFieldMatrix<T>(field, m, m);
    191             for (int i = 0; i < m; ++i) {
    192                 cachedP.setEntry(i, pivot[i], field.getOne());
    193             }
    194         }
    195         return cachedP;
    196     }
    197 
    198     /** {@inheritDoc} */
    199     public int[] getPivot() {
    200         return pivot.clone();
    201     }
    202 
    203     /** {@inheritDoc} */
    204     public T getDeterminant() {
    205         if (singular) {
    206             return field.getZero();
    207         } else {
    208             final int m = pivot.length;
    209             T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
    210             for (int i = 0; i < m; i++) {
    211                 determinant = determinant.multiply(lu[i][i]);
    212             }
    213             return determinant;
    214         }
    215     }
    216 
    217     /** {@inheritDoc} */
    218     public FieldDecompositionSolver<T> getSolver() {
    219         return new Solver<T>(field, lu, pivot, singular);
    220     }
    221 
    222     /** Specialized solver. */
    223     private static class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {
    224 
    225         /** Serializable version identifier. */
    226         private static final long serialVersionUID = -6353105415121373022L;
    227 
    228         /** Field to which the elements belong. */
    229         private final Field<T> field;
    230 
    231         /** Entries of LU decomposition. */
    232         private final T lu[][];
    233 
    234         /** Pivot permutation associated with LU decomposition. */
    235         private final int[] pivot;
    236 
    237         /** Singularity indicator. */
    238         private final boolean singular;
    239 
    240         /**
    241          * Build a solver from decomposed matrix.
    242          * @param field field to which the matrix elements belong
    243          * @param lu entries of LU decomposition
    244          * @param pivot pivot permutation associated with LU decomposition
    245          * @param singular singularity indicator
    246          */
    247         private Solver(final Field<T> field, final T[][] lu,
    248                        final int[] pivot, final boolean singular) {
    249             this.field    = field;
    250             this.lu       = lu;
    251             this.pivot    = pivot;
    252             this.singular = singular;
    253         }
    254 
    255         /** {@inheritDoc} */
    256         public boolean isNonSingular() {
    257             return !singular;
    258         }
    259 
    260         /** {@inheritDoc} */
    261         public T[] solve(T[] b)
    262             throws IllegalArgumentException, InvalidMatrixException {
    263 
    264             final int m = pivot.length;
    265             if (b.length != m) {
    266                 throw MathRuntimeException.createIllegalArgumentException(
    267                         LocalizedFormats.VECTOR_LENGTH_MISMATCH,
    268                         b.length, m);
    269             }
    270             if (singular) {
    271                 throw new SingularMatrixException();
    272             }
    273 
    274             @SuppressWarnings("unchecked") // field is of type T
    275             final T[] bp = (T[]) Array.newInstance(field.getZero().getClass(), m);
    276 
    277             // Apply permutations to b
    278             for (int row = 0; row < m; row++) {
    279                 bp[row] = b[pivot[row]];
    280             }
    281 
    282             // Solve LY = b
    283             for (int col = 0; col < m; col++) {
    284                 final T bpCol = bp[col];
    285                 for (int i = col + 1; i < m; i++) {
    286                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
    287                 }
    288             }
    289 
    290             // Solve UX = Y
    291             for (int col = m - 1; col >= 0; col--) {
    292                 bp[col] = bp[col].divide(lu[col][col]);
    293                 final T bpCol = bp[col];
    294                 for (int i = 0; i < col; i++) {
    295                     bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
    296                 }
    297             }
    298 
    299             return bp;
    300 
    301         }
    302 
    303         /** {@inheritDoc} */
    304         public FieldVector<T> solve(FieldVector<T> b)
    305             throws IllegalArgumentException, InvalidMatrixException {
    306             try {
    307                 return solve((ArrayFieldVector<T>) b);
    308             } catch (ClassCastException cce) {
    309 
    310                 final int m = pivot.length;
    311                 if (b.getDimension() != m) {
    312                     throw MathRuntimeException.createIllegalArgumentException(
    313                             LocalizedFormats.VECTOR_LENGTH_MISMATCH,
    314                             b.getDimension(), m);
    315                 }
    316                 if (singular) {
    317                     throw new SingularMatrixException();
    318                 }
    319 
    320                 @SuppressWarnings("unchecked") // field is of type T
    321                 final T[] bp = (T[]) Array.newInstance(field.getZero().getClass(), m);
    322 
    323                 // Apply permutations to b
    324                 for (int row = 0; row < m; row++) {
    325                     bp[row] = b.getEntry(pivot[row]);
    326                 }
    327 
    328                 // Solve LY = b
    329                 for (int col = 0; col < m; col++) {
    330                     final T bpCol = bp[col];
    331                     for (int i = col + 1; i < m; i++) {
    332                         bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
    333                     }
    334                 }
    335 
    336                 // Solve UX = Y
    337                 for (int col = m - 1; col >= 0; col--) {
    338                     bp[col] = bp[col].divide(lu[col][col]);
    339                     final T bpCol = bp[col];
    340                     for (int i = 0; i < col; i++) {
    341                         bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
    342                     }
    343                 }
    344 
    345                 return new ArrayFieldVector<T>(bp, false);
    346 
    347             }
    348         }
    349 
    350         /** Solve the linear equation A &times; X = B.
    351          * <p>The A matrix is implicit here. It is </p>
    352          * @param b right-hand side of the equation A &times; X = B
    353          * @return a vector X such that A &times; X = B
    354          * @exception IllegalArgumentException if matrices dimensions don't match
    355          * @exception InvalidMatrixException if decomposed matrix is singular
    356          */
    357         public ArrayFieldVector<T> solve(ArrayFieldVector<T> b)
    358             throws IllegalArgumentException, InvalidMatrixException {
    359             return new ArrayFieldVector<T>(solve(b.getDataRef()), false);
    360         }
    361 
    362         /** {@inheritDoc} */
    363         public FieldMatrix<T> solve(FieldMatrix<T> b)
    364             throws IllegalArgumentException, InvalidMatrixException {
    365 
    366             final int m = pivot.length;
    367             if (b.getRowDimension() != m) {
    368                 throw MathRuntimeException.createIllegalArgumentException(
    369                         LocalizedFormats.DIMENSIONS_MISMATCH_2x2,
    370                         b.getRowDimension(), b.getColumnDimension(), m, "n");
    371             }
    372             if (singular) {
    373                 throw new SingularMatrixException();
    374             }
    375 
    376             final int nColB = b.getColumnDimension();
    377 
    378             // Apply permutations to b
    379             @SuppressWarnings("unchecked") // field is of type T
    380             final T[][] bp = (T[][]) Array.newInstance(field.getZero().getClass(), new int[] { m, nColB });
    381             for (int row = 0; row < m; row++) {
    382                 final T[] bpRow = bp[row];
    383                 final int pRow = pivot[row];
    384                 for (int col = 0; col < nColB; col++) {
    385                     bpRow[col] = b.getEntry(pRow, col);
    386                 }
    387             }
    388 
    389             // Solve LY = b
    390             for (int col = 0; col < m; col++) {
    391                 final T[] bpCol = bp[col];
    392                 for (int i = col + 1; i < m; i++) {
    393                     final T[] bpI = bp[i];
    394                     final T luICol = lu[i][col];
    395                     for (int j = 0; j < nColB; j++) {
    396                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
    397                     }
    398                 }
    399             }
    400 
    401             // Solve UX = Y
    402             for (int col = m - 1; col >= 0; col--) {
    403                 final T[] bpCol = bp[col];
    404                 final T luDiag = lu[col][col];
    405                 for (int j = 0; j < nColB; j++) {
    406                     bpCol[j] = bpCol[j].divide(luDiag);
    407                 }
    408                 for (int i = 0; i < col; i++) {
    409                     final T[] bpI = bp[i];
    410                     final T luICol = lu[i][col];
    411                     for (int j = 0; j < nColB; j++) {
    412                         bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
    413                     }
    414                 }
    415             }
    416 
    417             return new Array2DRowFieldMatrix<T>(bp, false);
    418 
    419         }
    420 
    421         /** {@inheritDoc} */
    422         public FieldMatrix<T> getInverse() throws InvalidMatrixException {
    423             final int m = pivot.length;
    424             final T one = field.getOne();
    425             FieldMatrix<T> identity = new Array2DRowFieldMatrix<T>(field, m, m);
    426             for (int i = 0; i < m; ++i) {
    427                 identity.setEntry(i, i, one);
    428             }
    429             return solve(identity);
    430         }
    431 
    432     }
    433 
    434 }
    435