1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 package org.apache.commons.math.stat.regression; 18 19 import org.apache.commons.math.MathRuntimeException; 20 import org.apache.commons.math.exception.util.LocalizedFormats; 21 import org.apache.commons.math.linear.RealMatrix; 22 import org.apache.commons.math.linear.Array2DRowRealMatrix; 23 import org.apache.commons.math.linear.RealVector; 24 import org.apache.commons.math.linear.ArrayRealVector; 25 import org.apache.commons.math.stat.descriptive.moment.Variance; 26 import org.apache.commons.math.util.FastMath; 27 28 /** 29 * Abstract base class for implementations of MultipleLinearRegression. 30 * @version $Revision: 1073459 $ $Date: 2011-02-22 20:18:12 +0100 (mar. 22 fvr. 2011) $ 31 * @since 2.0 32 */ 33 public abstract class AbstractMultipleLinearRegression implements 34 MultipleLinearRegression { 35 36 /** X sample data. */ 37 protected RealMatrix X; 38 39 /** Y sample data. */ 40 protected RealVector Y; 41 42 /** Whether or not the regression model includes an intercept. True means no intercept. */ 43 private boolean noIntercept = false; 44 45 /** 46 * @return true if the model has no intercept term; false otherwise 47 * @since 2.2 48 */ 49 public boolean isNoIntercept() { 50 return noIntercept; 51 } 52 53 /** 54 * @param noIntercept true means the model is to be estimated without an intercept term 55 * @since 2.2 56 */ 57 public void setNoIntercept(boolean noIntercept) { 58 this.noIntercept = noIntercept; 59 } 60 61 /** 62 * <p>Loads model x and y sample data from a flat input array, overriding any previous sample. 63 * </p> 64 * <p>Assumes that rows are concatenated with y values first in each row. For example, an input 65 * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with 66 * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two 67 * independent variables, as below: 68 * <pre> 69 * y x[0] x[1] 70 * -------------- 71 * 1 2 3 72 * 4 5 6 73 * 7 8 9 74 * </pre> 75 * </p> 76 * <p>Note that there is no need to add an initial unitary column (column of 1's) when 77 * specifying a model including an intercept term. If {@link #isNoIntercept()} is <code>true</code>, 78 * the X matrix will be created without an initial column of "1"s; otherwise this column will 79 * be added. 80 * </p> 81 * <p>Throws IllegalArgumentException if any of the following preconditions fail: 82 * <ul><li><code>data</code> cannot be null</li> 83 * <li><code>data.length = nobs * (nvars + 1)</li> 84 * <li><code>nobs > nvars</code></li></ul> 85 * </p> 86 * 87 * @param data input data array 88 * @param nobs number of observations (rows) 89 * @param nvars number of independent variables (columns, not counting y) 90 * @throws IllegalArgumentException if the preconditions are not met 91 */ 92 public void newSampleData(double[] data, int nobs, int nvars) { 93 if (data == null) { 94 throw MathRuntimeException.createIllegalArgumentException( 95 LocalizedFormats.NULL_NOT_ALLOWED); 96 } 97 if (data.length != nobs * (nvars + 1)) { 98 throw MathRuntimeException.createIllegalArgumentException( 99 LocalizedFormats.INVALID_REGRESSION_ARRAY, data.length, nobs, nvars); 100 } 101 if (nobs <= nvars) { 102 throw MathRuntimeException.createIllegalArgumentException( 103 LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS); 104 } 105 double[] y = new double[nobs]; 106 final int cols = noIntercept ? nvars: nvars + 1; 107 double[][] x = new double[nobs][cols]; 108 int pointer = 0; 109 for (int i = 0; i < nobs; i++) { 110 y[i] = data[pointer++]; 111 if (!noIntercept) { 112 x[i][0] = 1.0d; 113 } 114 for (int j = noIntercept ? 0 : 1; j < cols; j++) { 115 x[i][j] = data[pointer++]; 116 } 117 } 118 this.X = new Array2DRowRealMatrix(x); 119 this.Y = new ArrayRealVector(y); 120 } 121 122 /** 123 * Loads new y sample data, overriding any previous data. 124 * 125 * @param y the array representing the y sample 126 * @throws IllegalArgumentException if y is null or empty 127 */ 128 protected void newYSampleData(double[] y) { 129 if (y == null) { 130 throw MathRuntimeException.createIllegalArgumentException( 131 LocalizedFormats.NULL_NOT_ALLOWED); 132 } 133 if (y.length == 0) { 134 throw MathRuntimeException.createIllegalArgumentException( 135 LocalizedFormats.NO_DATA); 136 } 137 this.Y = new ArrayRealVector(y); 138 } 139 140 /** 141 * <p>Loads new x sample data, overriding any previous data. 142 * </p> 143 * The input <code>x</code> array should have one row for each sample 144 * observation, with columns corresponding to independent variables. 145 * For example, if <pre> 146 * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre> 147 * then <code>setXSampleData(x) </code> results in a model with two independent 148 * variables and 3 observations: 149 * <pre> 150 * x[0] x[1] 151 * ---------- 152 * 1 2 153 * 3 4 154 * 5 6 155 * </pre> 156 * </p> 157 * <p>Note that there is no need to add an initial unitary column (column of 1's) when 158 * specifying a model including an intercept term. 159 * </p> 160 * @param x the rectangular array representing the x sample 161 * @throws IllegalArgumentException if x is null, empty or not rectangular 162 */ 163 protected void newXSampleData(double[][] x) { 164 if (x == null) { 165 throw MathRuntimeException.createIllegalArgumentException( 166 LocalizedFormats.NULL_NOT_ALLOWED); 167 } 168 if (x.length == 0) { 169 throw MathRuntimeException.createIllegalArgumentException( 170 LocalizedFormats.NO_DATA); 171 } 172 if (noIntercept) { 173 this.X = new Array2DRowRealMatrix(x, true); 174 } else { // Augment design matrix with initial unitary column 175 final int nVars = x[0].length; 176 final double[][] xAug = new double[x.length][nVars + 1]; 177 for (int i = 0; i < x.length; i++) { 178 if (x[i].length != nVars) { 179 throw MathRuntimeException.createIllegalArgumentException( 180 LocalizedFormats.DIFFERENT_ROWS_LENGTHS, 181 x[i].length, nVars); 182 } 183 xAug[i][0] = 1.0d; 184 System.arraycopy(x[i], 0, xAug[i], 1, nVars); 185 } 186 this.X = new Array2DRowRealMatrix(xAug, false); 187 } 188 } 189 190 /** 191 * Validates sample data. Checks that 192 * <ul><li>Neither x nor y is null or empty;</li> 193 * <li>The length (i.e. number of rows) of x equals the length of y</li> 194 * <li>x has at least one more row than it has columns (i.e. there is 195 * sufficient data to estimate regression coefficients for each of the 196 * columns in x plus an intercept.</li> 197 * </ul> 198 * 199 * @param x the [n,k] array representing the x data 200 * @param y the [n,1] array representing the y data 201 * @throws IllegalArgumentException if any of the checks fail 202 * 203 */ 204 protected void validateSampleData(double[][] x, double[] y) { 205 if ((x == null) || (y == null) || (x.length != y.length)) { 206 throw MathRuntimeException.createIllegalArgumentException( 207 LocalizedFormats.DIMENSIONS_MISMATCH_SIMPLE, 208 (x == null) ? 0 : x.length, 209 (y == null) ? 0 : y.length); 210 } 211 if (x.length == 0) { // Must be no y data either 212 throw MathRuntimeException.createIllegalArgumentException( 213 LocalizedFormats.NO_DATA); 214 } 215 if (x[0].length + 1 > x.length) { 216 throw MathRuntimeException.createIllegalArgumentException( 217 LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS, 218 x.length, x[0].length); 219 } 220 } 221 222 /** 223 * Validates that the x data and covariance matrix have the same 224 * number of rows and that the covariance matrix is square. 225 * 226 * @param x the [n,k] array representing the x sample 227 * @param covariance the [n,n] array representing the covariance matrix 228 * @throws IllegalArgumentException if the number of rows in x is not equal 229 * to the number of rows in covariance or covariance is not square. 230 */ 231 protected void validateCovarianceData(double[][] x, double[][] covariance) { 232 if (x.length != covariance.length) { 233 throw MathRuntimeException.createIllegalArgumentException( 234 LocalizedFormats.DIMENSIONS_MISMATCH_SIMPLE, x.length, covariance.length); 235 } 236 if (covariance.length > 0 && covariance.length != covariance[0].length) { 237 throw MathRuntimeException.createIllegalArgumentException( 238 LocalizedFormats.NON_SQUARE_MATRIX, 239 covariance.length, covariance[0].length); 240 } 241 } 242 243 /** 244 * {@inheritDoc} 245 */ 246 public double[] estimateRegressionParameters() { 247 RealVector b = calculateBeta(); 248 return b.getData(); 249 } 250 251 /** 252 * {@inheritDoc} 253 */ 254 public double[] estimateResiduals() { 255 RealVector b = calculateBeta(); 256 RealVector e = Y.subtract(X.operate(b)); 257 return e.getData(); 258 } 259 260 /** 261 * {@inheritDoc} 262 */ 263 public double[][] estimateRegressionParametersVariance() { 264 return calculateBetaVariance().getData(); 265 } 266 267 /** 268 * {@inheritDoc} 269 */ 270 public double[] estimateRegressionParametersStandardErrors() { 271 double[][] betaVariance = estimateRegressionParametersVariance(); 272 double sigma = calculateErrorVariance(); 273 int length = betaVariance[0].length; 274 double[] result = new double[length]; 275 for (int i = 0; i < length; i++) { 276 result[i] = FastMath.sqrt(sigma * betaVariance[i][i]); 277 } 278 return result; 279 } 280 281 /** 282 * {@inheritDoc} 283 */ 284 public double estimateRegressandVariance() { 285 return calculateYVariance(); 286 } 287 288 /** 289 * Estimates the variance of the error. 290 * 291 * @return estimate of the error variance 292 * @since 2.2 293 */ 294 public double estimateErrorVariance() { 295 return calculateErrorVariance(); 296 297 } 298 299 /** 300 * Estimates the standard error of the regression. 301 * 302 * @return regression standard error 303 * @since 2.2 304 */ 305 public double estimateRegressionStandardError() { 306 return Math.sqrt(estimateErrorVariance()); 307 } 308 309 /** 310 * Calculates the beta of multiple linear regression in matrix notation. 311 * 312 * @return beta 313 */ 314 protected abstract RealVector calculateBeta(); 315 316 /** 317 * Calculates the beta variance of multiple linear regression in matrix 318 * notation. 319 * 320 * @return beta variance 321 */ 322 protected abstract RealMatrix calculateBetaVariance(); 323 324 325 /** 326 * Calculates the variance of the y values. 327 * 328 * @return Y variance 329 */ 330 protected double calculateYVariance() { 331 return new Variance().evaluate(Y.getData()); 332 } 333 334 /** 335 * <p>Calculates the variance of the error term.</p> 336 * Uses the formula <pre> 337 * var(u) = u · u / (n - k) 338 * </pre> 339 * where n and k are the row and column dimensions of the design 340 * matrix X. 341 * 342 * @return error variance estimate 343 * @since 2.2 344 */ 345 protected double calculateErrorVariance() { 346 RealVector residuals = calculateResiduals(); 347 return residuals.dotProduct(residuals) / 348 (X.getRowDimension() - X.getColumnDimension()); 349 } 350 351 /** 352 * Calculates the residuals of multiple linear regression in matrix 353 * notation. 354 * 355 * <pre> 356 * u = y - X * b 357 * </pre> 358 * 359 * @return The residuals [n,1] matrix 360 */ 361 protected RealVector calculateResiduals() { 362 RealVector b = calculateBeta(); 363 return Y.subtract(X.operate(b)); 364 } 365 366 } 367