/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.commons.math.stat.regression;

import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.QRDecomposition;
import org.apache.commons.math.linear.QRDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.stat.StatUtils;
import org.apache.commons.math.stat.descriptive.moment.SecondMoment;

/**
 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
 * multiple linear regression model.</p>
 *
 * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
 * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre></p>
 *
 * <p>To solve the normal equations, this implementation uses QR decomposition
 * of the <code>X</code> matrix. (See {@link QRDecompositionImpl} for details on the
 * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix</i>,
 * has rows corresponding to sample observations and columns corresponding to independent
 * variables.  When the model is estimated with an intercept term (i.e., when
 * {@link #isNoIntercept() isNoIntercept} is false, as it is by default), the <code>X</code>
 * matrix includes an initial column identically equal to 1.  We solve the normal equations
 * as follows:
 * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
 * R b = Q<sup>T</sup> y </code></pre></p>
 *
 * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
 *
 * @version $Revision: 1073464 $ $Date: 2011-02-22 20:35:02 +0100 (Tue, 22 Feb 2011) $
 * @since 2.0
 */
public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {

    /** Cached QR decomposition of X matrix */
    private QRDecomposition qr = null;

    /**
     * Loads model x and y sample data, overriding any previous sample.
     *
     * Computes and caches QR decomposition of the X matrix.
     * @param y the [n,1] array representing the y sample
     * @param x the [n,k] array representing the x sample
     * @throws IllegalArgumentException if the x and y array data are not
     *             compatible for the regression
     */
    public void newSampleData(double[] y, double[][] x) {
        validateSampleData(x, y);
        newYSampleData(y);
        newXSampleData(x);
    }

    /**
     * {@inheritDoc}
     * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
     */
    @Override
    public void newSampleData(double[] data, int nobs, int nvars) {
        super.newSampleData(data, nobs, nvars);
        qr = new QRDecompositionImpl(X);
    }

    /**
     * <p>Compute the "hat" matrix.
     * </p>
     * <p>The hat matrix is defined in terms of the design matrix X
     *  by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
     * </p>
     * <p>The implementation here uses the QR decomposition to compute the
     * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
     * p-dimensional identity matrix augmented by 0's.  This computational
     * formula is from "The Hat Matrix in Regression and ANOVA",
     * David C. Hoaglin and Roy E. Welsch,
     * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.</p>
     *
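     * <p>The diagonal elements of the returned matrix are the observation
     * leverage values.  A brief sketch of extracting them (assuming sample data
     * have already been loaded into <code>regression</code>):</p>
     * <pre><code>
     * RealMatrix hat = regression.calculateHat();
     * double[] leverage = new double[hat.getRowDimension()];
     * for (int i = 0; i &lt; leverage.length; i++) {
     *     leverage[i] = hat.getEntry(i, i);   // h_ii, the leverage of observation i
     * }
     * </code></pre>
     *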
     * @return the hat matrix
     */
    public RealMatrix calculateHat() {
        // Create augmented identity matrix
        RealMatrix Q = qr.getQ();
        final int p = qr.getR().getColumnDimension();
        final int n = Q.getColumnDimension();
        Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
        double[][] augIData = augI.getDataRef();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (i == j && i < p) {
                    augIData[i][j] = 1d;
                } else {
                    augIData[i][j] = 0d;
                }
            }
        }

        // Compute and return Hat matrix
        return Q.multiply(augI).multiply(Q.transpose());
    }

    /**
     * <p>Returns the sum of squared deviations of Y from its mean.</p>
     *
     * <p>If the model has no intercept term, <code>0</code> is used for the
     * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
     *
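     * <p>In symbols, writing ybar for the mean of Y:</p>
     * <pre><code> SSTO = &sum; (y<sub>i</sub> - ybar)<sup>2</sup>   (model with intercept)
     * SSTO = &sum; y<sub>i</sub><sup>2</sup>             (no-intercept model) </code></pre>
     *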
     * <p>The value returned by this method is the SSTO value used in
     * the {@link #calculateRSquared() R-squared} computation.</p>
     *
     * @return SSTO - the total sum of squares
     * @see #isNoIntercept()
     * @since 2.2
     */
    public double calculateTotalSumOfSquares() {
        if (isNoIntercept()) {
            return StatUtils.sumSq(Y.getData());
        } else {
            return new SecondMoment().evaluate(Y.getData());
        }
    }

    /**
     * Returns the sum of squared residuals.
     *
     * @return residual sum of squares
     * @since 2.2
     */
    public double calculateResidualSumOfSquares() {
        final RealVector residuals = calculateResiduals();
        return residuals.dotProduct(residuals);
    }

    /**
     * Returns the R-Squared statistic, defined by the formula <pre>
     * R<sup>2</sup> = 1 - SSR / SSTO
     * </pre>
     * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
     * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}.
     *
     * @return R-square statistic
     * @since 2.2
     */
    public double calculateRSquared() {
        return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
    }

    /**
     * <p>Returns the adjusted R-squared statistic, defined by the formula <pre>
     * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
     * </pre>
     * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
     * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
     * of observations and p is the number of parameters estimated (including the intercept).</p>
     *
     * <p>If the regression is estimated without an intercept term, what is returned is <pre>
     * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
     * </pre></p>
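     *
     * <p>A brief usage sketch of retrieving both goodness-of-fit statistics
     * (assuming sample data have already been loaded into <code>regression</code>):</p>
     * <pre><code>
     * double rSquared = regression.calculateRSquared();
     * double rSquaredAdjusted = regression.calculateAdjustedRSquared();
     * </code></pre>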
     *
     * @return adjusted R-Squared statistic
     * @see #isNoIntercept()
     * @since 2.2
     */
    public double calculateAdjustedRSquared() {
        final double n = X.getRowDimension();
        if (isNoIntercept()) {
            return 1 - (1 - calculateRSquared()) * (n / (n - X.getColumnDimension()));
        } else {
            return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
                (calculateTotalSumOfSquares() * (n - X.getColumnDimension()));
        }
    }

    /**
     * {@inheritDoc}
     * <p>This implementation computes and caches the QR decomposition of the X matrix
     * once it is successfully loaded.</p>
     */
    @Override
    protected void newXSampleData(double[][] x) {
        super.newXSampleData(x);
        qr = new QRDecompositionImpl(X);
    }

    /**
     * Calculates the regression coefficients using OLS.
     *
     * <p>With X = QR, the normal equations reduce to
     * R b = Q<sup>T</sup> y (see the class javadoc), which the QR solver
     * solves by back-substitution.</p>
     *
     * @return beta
     */
    @Override
    protected RealVector calculateBeta() {
        // Solve R b = Q^T y using the cached QR decomposition of X
        return qr.getSolver().solve(Y);
    }

    /**
     * <p>Calculates the variance-covariance matrix of the regression parameters,
     * up to the scalar error variance; that is, this method returns
     * (X<sup>T</sup>X)<sup>-1</sup>, the factor of Var(b) that does not depend
     * on the error variance &sigma;<sup>2</sup>.
     * </p>
     * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
     * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
     * R included, where p = the length of the beta vector.</p>
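     *
     * <p>The reduction follows from the decomposition X = QR with Q orthogonal:</p>
     * <pre><code> X<sup>T</sup>X = R<sup>T</sup>(Q<sup>T</sup>Q)R = R<sup>T</sup>R
     * (X<sup>T</sup>X)<sup>-1</sup> = (R<sup>T</sup>R)<sup>-1</sup> = R<sup>-1</sup> (R<sup>-1</sup>)<sup>T</sup> </code></pre>
     * <p>The implementation forms exactly this last product.</p>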
     *
     * @return The beta variance-covariance matrix
     */
    @Override
    protected RealMatrix calculateBetaVariance() {
        int p = X.getColumnDimension();
        // Keep only the top p x p block of R; the rows below it are zero
        RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1, 0, p - 1);
        RealMatrix Rinv = new LUDecompositionImpl(Raug).getSolver().getInverse();
        // (X^T X)^-1 = (R^T R)^-1 = R^-1 (R^-1)^T
        return Rinv.multiply(Rinv.transpose());
    }

}