Home | History | Annotate | Download | only in regression
      1 /*
      2  * Licensed to the Apache Software Foundation (ASF) under one or more
      3  * contributor license agreements.  See the NOTICE file distributed with
      4  * this work for additional information regarding copyright ownership.
      5  * The ASF licenses this file to You under the Apache License, Version 2.0
      6  * (the "License"); you may not use this file except in compliance with
      7  * the License.  You may obtain a copy of the License at
      8  *
      9  *      http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  */
     17 
     18 package org.apache.commons.math.stat.regression;
     19 import java.io.Serializable;
     20 
     21 import org.apache.commons.math.MathException;
     22 import org.apache.commons.math.MathRuntimeException;
     23 import org.apache.commons.math.distribution.TDistribution;
     24 import org.apache.commons.math.distribution.TDistributionImpl;
     25 import org.apache.commons.math.exception.util.LocalizedFormats;
     26 import org.apache.commons.math.util.FastMath;
     27 
     28 /**
     29  * Estimates an ordinary least squares regression model
     30  * with one independent variable.
     31  * <p>
     32  * <code> y = intercept + slope * x  </code></p>
     33  * <p>
     34  * Standard errors for <code>intercept</code> and <code>slope</code> are
     35  * available as well as ANOVA, r-square and Pearson's r statistics.</p>
     36  * <p>
     37  * Observations (x,y pairs) can be added to the model one at a time or they
     38  * can be provided in a 2-dimensional array.  The observations are not stored
     39  * in memory, so there is no limit to the number of observations that can be
     40  * added to the model.</p>
     41  * <p>
     42  * <strong>Usage Notes</strong>: <ul>
     43  * <li> When there are fewer than two observations in the model, or when
     44  * there is no variation in the x values (i.e. all x values are the same)
     45  * all statistics return <code>NaN</code>. At least two observations with
     46  * different x coordinates are requred to estimate a bivariate regression
     47  * model.
     48  * </li>
     49  * <li> getters for the statistics always compute values based on the current
     50  * set of observations -- i.e., you can get statistics, then add more data
     51  * and get updated statistics without using a new instance.  There is no
     52  * "compute" method that updates all statistics.  Each of the getters performs
     53  * the necessary computations to return the requested statistic.</li>
     54  * </ul></p>
     55  *
     56  * @version $Revision: 1042336 $ $Date: 2010-12-05 13:40:48 +0100 (dim. 05 dc. 2010) $
     57  */
     58 public class SimpleRegression implements Serializable {
     59 
     60     /** Serializable version identifier */
     61     private static final long serialVersionUID = -3004689053607543335L;
     62 
     63     /** the distribution used to compute inference statistics. */
     64     private TDistribution distribution;
     65 
     66     /** sum of x values */
     67     private double sumX = 0d;
     68 
     69     /** total variation in x (sum of squared deviations from xbar) */
     70     private double sumXX = 0d;
     71 
     72     /** sum of y values */
     73     private double sumY = 0d;
     74 
     75     /** total variation in y (sum of squared deviations from ybar) */
     76     private double sumYY = 0d;
     77 
     78     /** sum of products */
     79     private double sumXY = 0d;
     80 
     81     /** number of observations */
     82     private long n = 0;
     83 
     84     /** mean of accumulated x values, used in updating formulas */
     85     private double xbar = 0;
     86 
     87     /** mean of accumulated y values, used in updating formulas */
     88     private double ybar = 0;
     89 
     90     // ---------------------Public methods--------------------------------------
     91 
     92     /**
     93      * Create an empty SimpleRegression instance
     94      */
     95     public SimpleRegression() {
     96         this(new TDistributionImpl(1.0));
     97     }
     98 
     99     /**
    100      * Create an empty SimpleRegression using the given distribution object to
    101      * compute inference statistics.
    102      * @param t the distribution used to compute inference statistics.
    103      * @since 1.2
    104      * @deprecated in 2.2 (to be removed in 3.0). Please use the {@link
    105      * #SimpleRegression(int) other constructor} instead.
    106      */
    107     @Deprecated
    108     public SimpleRegression(TDistribution t) {
    109         super();
    110         setDistribution(t);
    111     }
    112 
    113     /**
    114      * Create an empty SimpleRegression.
    115      *
    116      * @param degrees Number of degrees of freedom of the distribution
    117      * used to compute inference statistics.
    118      * @since 2.2
    119      */
    120     public SimpleRegression(int degrees) {
    121         setDistribution(new TDistributionImpl(degrees));
    122     }
    123 
    124     /**
    125      * Adds the observation (x,y) to the regression data set.
    126      * <p>
    127      * Uses updating formulas for means and sums of squares defined in
    128      * "Algorithms for Computing the Sample Variance: Analysis and
    129      * Recommendations", Chan, T.F., Golub, G.H., and LeVeque, R.J.
    130      * 1983, American Statistician, vol. 37, pp. 242-247, referenced in
    131      * Weisberg, S. "Applied Linear Regression". 2nd Ed. 1985.</p>
    132      *
    133      *
    134      * @param x independent variable value
    135      * @param y dependent variable value
    136      */
    137     public void addData(double x, double y) {
    138         if (n == 0) {
    139             xbar = x;
    140             ybar = y;
    141         } else {
    142             double dx = x - xbar;
    143             double dy = y - ybar;
    144             sumXX += dx * dx * (double) n / (n + 1d);
    145             sumYY += dy * dy * (double) n / (n + 1d);
    146             sumXY += dx * dy * (double) n / (n + 1d);
    147             xbar += dx / (n + 1.0);
    148             ybar += dy / (n + 1.0);
    149         }
    150         sumX += x;
    151         sumY += y;
    152         n++;
    153 
    154         if (n > 2) {
    155             distribution.setDegreesOfFreedom(n - 2);
    156         }
    157     }
    158 
    159 
    160     /**
    161      * Removes the observation (x,y) from the regression data set.
    162      * <p>
    163      * Mirrors the addData method.  This method permits the use of
    164      * SimpleRegression instances in streaming mode where the regression
    165      * is applied to a sliding "window" of observations, however the caller is
    166      * responsible for maintaining the set of observations in the window.</p>
    167      *
    168      * The method has no effect if there are no points of data (i.e. n=0)
    169      *
    170      * @param x independent variable value
    171      * @param y dependent variable value
    172      */
    173     public void removeData(double x, double y) {
    174         if (n > 0) {
    175             double dx = x - xbar;
    176             double dy = y - ybar;
    177             sumXX -= dx * dx * (double) n / (n - 1d);
    178             sumYY -= dy * dy * (double) n / (n - 1d);
    179             sumXY -= dx * dy * (double) n / (n - 1d);
    180             xbar -= dx / (n - 1.0);
    181             ybar -= dy / (n - 1.0);
    182             sumX -= x;
    183             sumY -= y;
    184             n--;
    185 
    186             if (n > 2) {
    187                 distribution.setDegreesOfFreedom(n - 2);
    188             }
    189         }
    190     }
    191 
    192     /**
    193      * Adds the observations represented by the elements in
    194      * <code>data</code>.
    195      * <p>
    196      * <code>(data[0][0],data[0][1])</code> will be the first observation, then
    197      * <code>(data[1][0],data[1][1])</code>, etc.</p>
    198      * <p>
    199      * This method does not replace data that has already been added.  The
    200      * observations represented by <code>data</code> are added to the existing
    201      * dataset.</p>
    202      * <p>
    203      * To replace all data, use <code>clear()</code> before adding the new
    204      * data.</p>
    205      *
    206      * @param data array of observations to be added
    207      */
    208     public void addData(double[][] data) {
    209         for (int i = 0; i < data.length; i++) {
    210             addData(data[i][0], data[i][1]);
    211         }
    212     }
    213 
    214 
    215     /**
    216      * Removes observations represented by the elements in <code>data</code>.
    217       * <p>
    218      * If the array is larger than the current n, only the first n elements are
    219      * processed.  This method permits the use of SimpleRegression instances in
    220      * streaming mode where the regression is applied to a sliding "window" of
    221      * observations, however the caller is responsible for maintaining the set
    222      * of observations in the window.</p>
    223      * <p>
    224      * To remove all data, use <code>clear()</code>.</p>
    225      *
    226      * @param data array of observations to be removed
    227      */
    228     public void removeData(double[][] data) {
    229         for (int i = 0; i < data.length && n > 0; i++) {
    230             removeData(data[i][0], data[i][1]);
    231         }
    232     }
    233 
    234     /**
    235      * Clears all data from the model.
    236      */
    237     public void clear() {
    238         sumX = 0d;
    239         sumXX = 0d;
    240         sumY = 0d;
    241         sumYY = 0d;
    242         sumXY = 0d;
    243         n = 0;
    244     }
    245 
    246     /**
    247      * Returns the number of observations that have been added to the model.
    248      *
    249      * @return n number of observations that have been added.
    250      */
    251     public long getN() {
    252         return n;
    253     }
    254 
    255     /**
    256      * Returns the "predicted" <code>y</code> value associated with the
    257      * supplied <code>x</code> value,  based on the data that has been
    258      * added to the model when this method is activated.
    259      * <p>
    260      * <code> predict(x) = intercept + slope * x </code></p>
    261      * <p>
    262      * <strong>Preconditions</strong>: <ul>
    263      * <li>At least two observations (with at least two different x values)
    264      * must have been added before invoking this method. If this method is
    265      * invoked before a model can be estimated, <code>Double,NaN</code> is
    266      * returned.
    267      * </li></ul></p>
    268      *
    269      * @param x input <code>x</code> value
    270      * @return predicted <code>y</code> value
    271      */
    272     public double predict(double x) {
    273         double b1 = getSlope();
    274         return getIntercept(b1) + b1 * x;
    275     }
    276 
    277     /**
    278      * Returns the intercept of the estimated regression line.
    279      * <p>
    280      * The least squares estimate of the intercept is computed using the
    281      * <a href="http://www.xycoon.com/estimation4.htm">normal equations</a>.
    282      * The intercept is sometimes denoted b0.</p>
    283      * <p>
    284      * <strong>Preconditions</strong>: <ul>
    285      * <li>At least two observations (with at least two different x values)
    286      * must have been added before invoking this method. If this method is
    287      * invoked before a model can be estimated, <code>Double,NaN</code> is
    288      * returned.
    289      * </li></ul></p>
    290      *
    291      * @return the intercept of the regression line
    292      */
    293     public double getIntercept() {
    294         return getIntercept(getSlope());
    295     }
    296 
    297     /**
    298     * Returns the slope of the estimated regression line.
    299     * <p>
    300     * The least squares estimate of the slope is computed using the
    301     * <a href="http://www.xycoon.com/estimation4.htm">normal equations</a>.
    302     * The slope is sometimes denoted b1.</p>
    303     * <p>
    304     * <strong>Preconditions</strong>: <ul>
    305     * <li>At least two observations (with at least two different x values)
    306     * must have been added before invoking this method. If this method is
    307     * invoked before a model can be estimated, <code>Double.NaN</code> is
    308     * returned.
    309     * </li></ul></p>
    310     *
    311     * @return the slope of the regression line
    312     */
    313     public double getSlope() {
    314         if (n < 2) {
    315             return Double.NaN; //not enough data
    316         }
    317         if (FastMath.abs(sumXX) < 10 * Double.MIN_VALUE) {
    318             return Double.NaN; //not enough variation in x
    319         }
    320         return sumXY / sumXX;
    321     }
    322 
    323     /**
    324      * Returns the <a href="http://www.xycoon.com/SumOfSquares.htm">
    325      * sum of squared errors</a> (SSE) associated with the regression
    326      * model.
    327      * <p>
    328      * The sum is computed using the computational formula</p>
    329      * <p>
    330      * <code>SSE = SYY - (SXY * SXY / SXX)</code></p>
    331      * <p>
    332      * where <code>SYY</code> is the sum of the squared deviations of the y
    333      * values about their mean, <code>SXX</code> is similarly defined and
    334      * <code>SXY</code> is the sum of the products of x and y mean deviations.
    335      * </p><p>
    336      * The sums are accumulated using the updating algorithm referenced in
    337      * {@link #addData}.</p>
    338      * <p>
    339      * The return value is constrained to be non-negative - i.e., if due to
    340      * rounding errors the computational formula returns a negative result,
    341      * 0 is returned.</p>
    342      * <p>
    343      * <strong>Preconditions</strong>: <ul>
    344      * <li>At least two observations (with at least two different x values)
    345      * must have been added before invoking this method. If this method is
    346      * invoked before a model can be estimated, <code>Double,NaN</code> is
    347      * returned.
    348      * </li></ul></p>
    349      *
    350      * @return sum of squared errors associated with the regression model
    351      */
    352     public double getSumSquaredErrors() {
    353         return FastMath.max(0d, sumYY - sumXY * sumXY / sumXX);
    354     }
    355 
    356     /**
    357      * Returns the sum of squared deviations of the y values about their mean.
    358      * <p>
    359      * This is defined as SSTO
    360      * <a href="http://www.xycoon.com/SumOfSquares.htm">here</a>.</p>
    361      * <p>
    362      * If <code>n < 2</code>, this returns <code>Double.NaN</code>.</p>
    363      *
    364      * @return sum of squared deviations of y values
    365      */
    366     public double getTotalSumSquares() {
    367         if (n < 2) {
    368             return Double.NaN;
    369         }
    370         return sumYY;
    371     }
    372 
    373     /**
    374      * Returns the sum of squared deviations of the x values about their mean.
    375      *
    376      * If <code>n < 2</code>, this returns <code>Double.NaN</code>.</p>
    377      *
    378      * @return sum of squared deviations of x values
    379      */
    380     public double getXSumSquares() {
    381         if (n < 2) {
    382             return Double.NaN;
    383         }
    384         return sumXX;
    385     }
    386 
    387     /**
    388      * Returns the sum of crossproducts, x<sub>i</sub>*y<sub>i</sub>.
    389      *
    390      * @return sum of cross products
    391      */
    392     public double getSumOfCrossProducts() {
    393         return sumXY;
    394     }
    395 
    396     /**
    397      * Returns the sum of squared deviations of the predicted y values about
    398      * their mean (which equals the mean of y).
    399      * <p>
    400      * This is usually abbreviated SSR or SSM.  It is defined as SSM
    401      * <a href="http://www.xycoon.com/SumOfSquares.htm">here</a></p>
    402      * <p>
    403      * <strong>Preconditions</strong>: <ul>
    404      * <li>At least two observations (with at least two different x values)
    405      * must have been added before invoking this method. If this method is
    406      * invoked before a model can be estimated, <code>Double.NaN</code> is
    407      * returned.
    408      * </li></ul></p>
    409      *
    410      * @return sum of squared deviations of predicted y values
    411      */
    412     public double getRegressionSumSquares() {
    413         return getRegressionSumSquares(getSlope());
    414     }
    415 
    416     /**
    417      * Returns the sum of squared errors divided by the degrees of freedom,
    418      * usually abbreviated MSE.
    419      * <p>
    420      * If there are fewer than <strong>three</strong> data pairs in the model,
    421      * or if there is no variation in <code>x</code>, this returns
    422      * <code>Double.NaN</code>.</p>
    423      *
    424      * @return sum of squared deviations of y values
    425      */
    426     public double getMeanSquareError() {
    427         if (n < 3) {
    428             return Double.NaN;
    429         }
    430         return getSumSquaredErrors() / (n - 2);
    431     }
    432 
    433     /**
    434      * Returns <a href="http://mathworld.wolfram.com/CorrelationCoefficient.html">
    435      * Pearson's product moment correlation coefficient</a>,
    436      * usually denoted r.
    437      * <p>
    438      * <strong>Preconditions</strong>: <ul>
    439      * <li>At least two observations (with at least two different x values)
    440      * must have been added before invoking this method. If this method is
    441      * invoked before a model can be estimated, <code>Double,NaN</code> is
    442      * returned.
    443      * </li></ul></p>
    444      *
    445      * @return Pearson's r
    446      */
    447     public double getR() {
    448         double b1 = getSlope();
    449         double result = FastMath.sqrt(getRSquare());
    450         if (b1 < 0) {
    451             result = -result;
    452         }
    453         return result;
    454     }
    455 
    456     /**
    457      * Returns the <a href="http://www.xycoon.com/coefficient1.htm">
    458      * coefficient of determination</a>,
    459      * usually denoted r-square.
    460      * <p>
    461      * <strong>Preconditions</strong>: <ul>
    462      * <li>At least two observations (with at least two different x values)
    463      * must have been added before invoking this method. If this method is
    464      * invoked before a model can be estimated, <code>Double,NaN</code> is
    465      * returned.
    466      * </li></ul></p>
    467      *
    468      * @return r-square
    469      */
    470     public double getRSquare() {
    471         double ssto = getTotalSumSquares();
    472         return (ssto - getSumSquaredErrors()) / ssto;
    473     }
    474 
    475     /**
    476      * Returns the <a href="http://www.xycoon.com/standarderrorb0.htm">
    477      * standard error of the intercept estimate</a>,
    478      * usually denoted s(b0).
    479      * <p>
    480      * If there are fewer that <strong>three</strong> observations in the
    481      * model, or if there is no variation in x, this returns
    482      * <code>Double.NaN</code>.</p>
    483      *
    484      * @return standard error associated with intercept estimate
    485      */
    486     public double getInterceptStdErr() {
    487         return FastMath.sqrt(
    488             getMeanSquareError() * ((1d / (double) n) + (xbar * xbar) / sumXX));
    489     }
    490 
    491     /**
    492      * Returns the <a href="http://www.xycoon.com/standerrorb(1).htm">standard
    493      * error of the slope estimate</a>,
    494      * usually denoted s(b1).
    495      * <p>
    496      * If there are fewer that <strong>three</strong> data pairs in the model,
    497      * or if there is no variation in x, this returns <code>Double.NaN</code>.
    498      * </p>
    499      *
    500      * @return standard error associated with slope estimate
    501      */
    502     public double getSlopeStdErr() {
    503         return FastMath.sqrt(getMeanSquareError() / sumXX);
    504     }
    505 
    506     /**
    507      * Returns the half-width of a 95% confidence interval for the slope
    508      * estimate.
    509      * <p>
    510      * The 95% confidence interval is</p>
    511      * <p>
    512      * <code>(getSlope() - getSlopeConfidenceInterval(),
    513      * getSlope() + getSlopeConfidenceInterval())</code></p>
    514      * <p>
    515      * If there are fewer that <strong>three</strong> observations in the
    516      * model, or if there is no variation in x, this returns
    517      * <code>Double.NaN</code>.</p>
    518      * <p>
    519      * <strong>Usage Note</strong>:<br>
    520      * The validity of this statistic depends on the assumption that the
    521      * observations included in the model are drawn from a
    522      * <a href="http://mathworld.wolfram.com/BivariateNormalDistribution.html">
    523      * Bivariate Normal Distribution</a>.</p>
    524      *
    525      * @return half-width of 95% confidence interval for the slope estimate
    526      * @throws MathException if the confidence interval can not be computed.
    527      */
    528     public double getSlopeConfidenceInterval() throws MathException {
    529         return getSlopeConfidenceInterval(0.05d);
    530     }
    531 
    532     /**
    533      * Returns the half-width of a (100-100*alpha)% confidence interval for
    534      * the slope estimate.
    535      * <p>
    536      * The (100-100*alpha)% confidence interval is </p>
    537      * <p>
    538      * <code>(getSlope() - getSlopeConfidenceInterval(),
    539      * getSlope() + getSlopeConfidenceInterval())</code></p>
    540      * <p>
    541      * To request, for example, a 99% confidence interval, use
    542      * <code>alpha = .01</code></p>
    543      * <p>
    544      * <strong>Usage Note</strong>:<br>
    545      * The validity of this statistic depends on the assumption that the
    546      * observations included in the model are drawn from a
    547      * <a href="http://mathworld.wolfram.com/BivariateNormalDistribution.html">
    548      * Bivariate Normal Distribution</a>.</p>
    549      * <p>
    550      * <strong> Preconditions:</strong><ul>
    551      * <li>If there are fewer that <strong>three</strong> observations in the
    552      * model, or if there is no variation in x, this returns
    553      * <code>Double.NaN</code>.
    554      * </li>
    555      * <li><code>(0 < alpha < 1)</code>; otherwise an
    556      * <code>IllegalArgumentException</code> is thrown.
    557      * </li></ul></p>
    558      *
    559      * @param alpha the desired significance level
    560      * @return half-width of 95% confidence interval for the slope estimate
    561      * @throws MathException if the confidence interval can not be computed.
    562      */
    563     public double getSlopeConfidenceInterval(double alpha)
    564         throws MathException {
    565         if (alpha >= 1 || alpha <= 0) {
    566             throw MathRuntimeException.createIllegalArgumentException(
    567                   LocalizedFormats.OUT_OF_BOUND_SIGNIFICANCE_LEVEL,
    568                   alpha, 0.0, 1.0);
    569         }
    570         return getSlopeStdErr() *
    571             distribution.inverseCumulativeProbability(1d - alpha / 2d);
    572     }
    573 
    574     /**
    575      * Returns the significance level of the slope (equiv) correlation.
    576      * <p>
    577      * Specifically, the returned value is the smallest <code>alpha</code>
    578      * such that the slope confidence interval with significance level
    579      * equal to <code>alpha</code> does not include <code>0</code>.
    580      * On regression output, this is often denoted <code>Prob(|t| > 0)</code>
    581      * </p><p>
    582      * <strong>Usage Note</strong>:<br>
    583      * The validity of this statistic depends on the assumption that the
    584      * observations included in the model are drawn from a
    585      * <a href="http://mathworld.wolfram.com/BivariateNormalDistribution.html">
    586      * Bivariate Normal Distribution</a>.</p>
    587      * <p>
    588      * If there are fewer that <strong>three</strong> observations in the
    589      * model, or if there is no variation in x, this returns
    590      * <code>Double.NaN</code>.</p>
    591      *
    592      * @return significance level for slope/correlation
    593      * @throws MathException if the significance level can not be computed.
    594      */
    595     public double getSignificance() throws MathException {
    596         return 2d * (1.0 - distribution.cumulativeProbability(
    597                     FastMath.abs(getSlope()) / getSlopeStdErr()));
    598     }
    599 
    600     // ---------------------Private methods-----------------------------------
    601 
    602     /**
    603     * Returns the intercept of the estimated regression line, given the slope.
    604     * <p>
    605     * Will return <code>NaN</code> if slope is <code>NaN</code>.</p>
    606     *
    607     * @param slope current slope
    608     * @return the intercept of the regression line
    609     */
    610     private double getIntercept(double slope) {
    611         return (sumY - slope * sumX) / n;
    612     }
    613 
    614     /**
    615      * Computes SSR from b1.
    616      *
    617      * @param slope regression slope estimate
    618      * @return sum of squared deviations of predicted y values
    619      */
    620     private double getRegressionSumSquares(double slope) {
    621         return slope * slope * sumXX;
    622     }
    623 
    624     /**
    625      * Modify the distribution used to compute inference statistics.
    626      * @param value the new distribution
    627      * @since 1.2
    628      * @deprecated in 2.2 (to be removed in 3.0).
    629      */
    630     @Deprecated
    631     public void setDistribution(TDistribution value) {
    632         distribution = value;
    633 
    634         // modify degrees of freedom
    635         if (n > 2) {
    636             distribution.setDegreesOfFreedom(n - 2);
    637         }
    638     }
    639 }
    640