Home | History | Annotate | Download | only in regression
      1 /*
      2  * Licensed to the Apache Software Foundation (ASF) under one or more
      3  * contributor license agreements.  See the NOTICE file distributed with
      4  * this work for additional information regarding copyright ownership.
      5  * The ASF licenses this file to You under the Apache License, Version 2.0
      6  * (the "License"); you may not use this file except in compliance with
      7  * the License.  You may obtain a copy of the License at
      8  *
      9  *      http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  */
     17 package org.apache.commons.math.stat.regression;
     18 
     19 import org.apache.commons.math.MathRuntimeException;
     20 import org.apache.commons.math.exception.util.LocalizedFormats;
     21 import org.apache.commons.math.linear.RealMatrix;
     22 import org.apache.commons.math.linear.Array2DRowRealMatrix;
     23 import org.apache.commons.math.linear.RealVector;
     24 import org.apache.commons.math.linear.ArrayRealVector;
     25 import org.apache.commons.math.stat.descriptive.moment.Variance;
     26 import org.apache.commons.math.util.FastMath;
     27 
     28 /**
     29  * Abstract base class for implementations of MultipleLinearRegression.
     30  * @version $Revision: 1073459 $ $Date: 2011-02-22 20:18:12 +0100 (mar. 22 fvr. 2011) $
     31  * @since 2.0
     32  */
     33 public abstract class AbstractMultipleLinearRegression implements
     34         MultipleLinearRegression {
     35 
     36     /** X sample data. */
     37     protected RealMatrix X;
     38 
     39     /** Y sample data. */
     40     protected RealVector Y;
     41 
     42     /** Whether or not the regression model includes an intercept.  True means no intercept. */
     43     private boolean noIntercept = false;
     44 
     45     /**
     46      * @return true if the model has no intercept term; false otherwise
     47      * @since 2.2
     48      */
     49     public boolean isNoIntercept() {
     50         return noIntercept;
     51     }
     52 
     53     /**
     54      * @param noIntercept true means the model is to be estimated without an intercept term
     55      * @since 2.2
     56      */
     57     public void setNoIntercept(boolean noIntercept) {
     58         this.noIntercept = noIntercept;
     59     }
     60 
     61     /**
     62      * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
     63      * </p>
     64      * <p>Assumes that rows are concatenated with y values first in each row.  For example, an input
     65      * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with
     66      * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two
     67      * independent variables, as below:
     68      * <pre>
     69      *   y   x[0]  x[1]
     70      *   --------------
     71      *   1     2     3
     72      *   4     5     6
     73      *   7     8     9
     74      * </pre>
     75      * </p>
     76      * <p>Note that there is no need to add an initial unitary column (column of 1's) when
     77      * specifying a model including an intercept term.  If {@link #isNoIntercept()} is <code>true</code>,
     78      * the X matrix will be created without an initial column of "1"s; otherwise this column will
     79      * be added.
     80      * </p>
     81      * <p>Throws IllegalArgumentException if any of the following preconditions fail:
     82      * <ul><li><code>data</code> cannot be null</li>
     83      * <li><code>data.length = nobs * (nvars + 1)</li>
     84      * <li><code>nobs > nvars</code></li></ul>
     85      * </p>
     86      *
     87      * @param data input data array
     88      * @param nobs number of observations (rows)
     89      * @param nvars number of independent variables (columns, not counting y)
     90      * @throws IllegalArgumentException if the preconditions are not met
     91      */
     92     public void newSampleData(double[] data, int nobs, int nvars) {
     93         if (data == null) {
     94             throw MathRuntimeException.createIllegalArgumentException(
     95                     LocalizedFormats.NULL_NOT_ALLOWED);
     96         }
     97         if (data.length != nobs * (nvars + 1)) {
     98             throw MathRuntimeException.createIllegalArgumentException(
     99                     LocalizedFormats.INVALID_REGRESSION_ARRAY, data.length, nobs, nvars);
    100         }
    101         if (nobs <= nvars) {
    102             throw MathRuntimeException.createIllegalArgumentException(
    103                     LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS);
    104         }
    105         double[] y = new double[nobs];
    106         final int cols = noIntercept ? nvars: nvars + 1;
    107         double[][] x = new double[nobs][cols];
    108         int pointer = 0;
    109         for (int i = 0; i < nobs; i++) {
    110             y[i] = data[pointer++];
    111             if (!noIntercept) {
    112                 x[i][0] = 1.0d;
    113             }
    114             for (int j = noIntercept ? 0 : 1; j < cols; j++) {
    115                 x[i][j] = data[pointer++];
    116             }
    117         }
    118         this.X = new Array2DRowRealMatrix(x);
    119         this.Y = new ArrayRealVector(y);
    120     }
    121 
    122     /**
    123      * Loads new y sample data, overriding any previous data.
    124      *
    125      * @param y the array representing the y sample
    126      * @throws IllegalArgumentException if y is null or empty
    127      */
    128     protected void newYSampleData(double[] y) {
    129         if (y == null) {
    130             throw MathRuntimeException.createIllegalArgumentException(
    131                     LocalizedFormats.NULL_NOT_ALLOWED);
    132         }
    133         if (y.length == 0) {
    134             throw MathRuntimeException.createIllegalArgumentException(
    135                     LocalizedFormats.NO_DATA);
    136         }
    137         this.Y = new ArrayRealVector(y);
    138     }
    139 
    140     /**
    141      * <p>Loads new x sample data, overriding any previous data.
    142      * </p>
    143      * The input <code>x</code> array should have one row for each sample
    144      * observation, with columns corresponding to independent variables.
    145      * For example, if <pre>
    146      * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre>
    147      * then <code>setXSampleData(x) </code> results in a model with two independent
    148      * variables and 3 observations:
    149      * <pre>
    150      *   x[0]  x[1]
    151      *   ----------
    152      *     1    2
    153      *     3    4
    154      *     5    6
    155      * </pre>
    156      * </p>
    157      * <p>Note that there is no need to add an initial unitary column (column of 1's) when
    158      * specifying a model including an intercept term.
    159      * </p>
    160      * @param x the rectangular array representing the x sample
    161      * @throws IllegalArgumentException if x is null, empty or not rectangular
    162      */
    163     protected void newXSampleData(double[][] x) {
    164         if (x == null) {
    165             throw MathRuntimeException.createIllegalArgumentException(
    166                     LocalizedFormats.NULL_NOT_ALLOWED);
    167         }
    168         if (x.length == 0) {
    169             throw MathRuntimeException.createIllegalArgumentException(
    170                     LocalizedFormats.NO_DATA);
    171         }
    172         if (noIntercept) {
    173             this.X = new Array2DRowRealMatrix(x, true);
    174         } else { // Augment design matrix with initial unitary column
    175             final int nVars = x[0].length;
    176             final double[][] xAug = new double[x.length][nVars + 1];
    177             for (int i = 0; i < x.length; i++) {
    178                 if (x[i].length != nVars) {
    179                     throw MathRuntimeException.createIllegalArgumentException(
    180                             LocalizedFormats.DIFFERENT_ROWS_LENGTHS,
    181                             x[i].length, nVars);
    182                 }
    183                 xAug[i][0] = 1.0d;
    184                 System.arraycopy(x[i], 0, xAug[i], 1, nVars);
    185             }
    186             this.X = new Array2DRowRealMatrix(xAug, false);
    187         }
    188     }
    189 
    190     /**
    191      * Validates sample data.  Checks that
    192      * <ul><li>Neither x nor y is null or empty;</li>
    193      * <li>The length (i.e. number of rows) of x equals the length of y</li>
    194      * <li>x has at least one more row than it has columns (i.e. there is
    195      * sufficient data to estimate regression coefficients for each of the
    196      * columns in x plus an intercept.</li>
    197      * </ul>
    198      *
    199      * @param x the [n,k] array representing the x data
    200      * @param y the [n,1] array representing the y data
    201      * @throws IllegalArgumentException if any of the checks fail
    202      *
    203      */
    204     protected void validateSampleData(double[][] x, double[] y) {
    205         if ((x == null) || (y == null) || (x.length != y.length)) {
    206             throw MathRuntimeException.createIllegalArgumentException(
    207                   LocalizedFormats.DIMENSIONS_MISMATCH_SIMPLE,
    208                   (x == null) ? 0 : x.length,
    209                   (y == null) ? 0 : y.length);
    210         }
    211         if (x.length == 0) {  // Must be no y data either
    212             throw MathRuntimeException.createIllegalArgumentException(
    213                     LocalizedFormats.NO_DATA);
    214         }
    215         if (x[0].length + 1 > x.length) {
    216             throw MathRuntimeException.createIllegalArgumentException(
    217                   LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
    218                   x.length, x[0].length);
    219         }
    220     }
    221 
    222     /**
    223      * Validates that the x data and covariance matrix have the same
    224      * number of rows and that the covariance matrix is square.
    225      *
    226      * @param x the [n,k] array representing the x sample
    227      * @param covariance the [n,n] array representing the covariance matrix
    228      * @throws IllegalArgumentException if the number of rows in x is not equal
    229      * to the number of rows in covariance or covariance is not square.
    230      */
    231     protected void validateCovarianceData(double[][] x, double[][] covariance) {
    232         if (x.length != covariance.length) {
    233             throw MathRuntimeException.createIllegalArgumentException(
    234                  LocalizedFormats.DIMENSIONS_MISMATCH_SIMPLE, x.length, covariance.length);
    235         }
    236         if (covariance.length > 0 && covariance.length != covariance[0].length) {
    237             throw MathRuntimeException.createIllegalArgumentException(
    238                   LocalizedFormats.NON_SQUARE_MATRIX,
    239                   covariance.length, covariance[0].length);
    240         }
    241     }
    242 
    243     /**
    244      * {@inheritDoc}
    245      */
    246     public double[] estimateRegressionParameters() {
    247         RealVector b = calculateBeta();
    248         return b.getData();
    249     }
    250 
    251     /**
    252      * {@inheritDoc}
    253      */
    254     public double[] estimateResiduals() {
    255         RealVector b = calculateBeta();
    256         RealVector e = Y.subtract(X.operate(b));
    257         return e.getData();
    258     }
    259 
    260     /**
    261      * {@inheritDoc}
    262      */
    263     public double[][] estimateRegressionParametersVariance() {
    264         return calculateBetaVariance().getData();
    265     }
    266 
    267     /**
    268      * {@inheritDoc}
    269      */
    270     public double[] estimateRegressionParametersStandardErrors() {
    271         double[][] betaVariance = estimateRegressionParametersVariance();
    272         double sigma = calculateErrorVariance();
    273         int length = betaVariance[0].length;
    274         double[] result = new double[length];
    275         for (int i = 0; i < length; i++) {
    276             result[i] = FastMath.sqrt(sigma * betaVariance[i][i]);
    277         }
    278         return result;
    279     }
    280 
    281     /**
    282      * {@inheritDoc}
    283      */
    284     public double estimateRegressandVariance() {
    285         return calculateYVariance();
    286     }
    287 
    288     /**
    289      * Estimates the variance of the error.
    290      *
    291      * @return estimate of the error variance
    292      * @since 2.2
    293      */
    294     public double estimateErrorVariance() {
    295         return calculateErrorVariance();
    296 
    297     }
    298 
    299     /**
    300      * Estimates the standard error of the regression.
    301      *
    302      * @return regression standard error
    303      * @since 2.2
    304      */
    305     public double estimateRegressionStandardError() {
    306         return Math.sqrt(estimateErrorVariance());
    307     }
    308 
    309     /**
    310      * Calculates the beta of multiple linear regression in matrix notation.
    311      *
    312      * @return beta
    313      */
    314     protected abstract RealVector calculateBeta();
    315 
    316     /**
    317      * Calculates the beta variance of multiple linear regression in matrix
    318      * notation.
    319      *
    320      * @return beta variance
    321      */
    322     protected abstract RealMatrix calculateBetaVariance();
    323 
    324 
    325     /**
    326      * Calculates the variance of the y values.
    327      *
    328      * @return Y variance
    329      */
    330     protected double calculateYVariance() {
    331         return new Variance().evaluate(Y.getData());
    332     }
    333 
    334     /**
    335      * <p>Calculates the variance of the error term.</p>
    336      * Uses the formula <pre>
    337      * var(u) = u &middot; u / (n - k)
    338      * </pre>
    339      * where n and k are the row and column dimensions of the design
    340      * matrix X.
    341      *
    342      * @return error variance estimate
    343      * @since 2.2
    344      */
    345     protected double calculateErrorVariance() {
    346         RealVector residuals = calculateResiduals();
    347         return residuals.dotProduct(residuals) /
    348                (X.getRowDimension() - X.getColumnDimension());
    349     }
    350 
    351     /**
    352      * Calculates the residuals of multiple linear regression in matrix
    353      * notation.
    354      *
    355      * <pre>
    356      * u = y - X * b
    357      * </pre>
    358      *
    359      * @return The residuals [n,1] matrix
    360      */
    361     protected RealVector calculateResiduals() {
    362         RealVector b = calculateBeta();
    363         return Y.subtract(X.operate(b));
    364     }
    365 
    366 }
    367