1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 package org.apache.commons.math.analysis.interpolation; 18 19 import java.io.Serializable; 20 import java.util.Arrays; 21 22 import org.apache.commons.math.MathException; 23 import org.apache.commons.math.analysis.polynomials.PolynomialSplineFunction; 24 import org.apache.commons.math.exception.util.Localizable; 25 import org.apache.commons.math.exception.util.LocalizedFormats; 26 import org.apache.commons.math.util.FastMath; 27 28 /** 29 * Implements the <a href="http://en.wikipedia.org/wiki/Local_regression"> 30 * Local Regression Algorithm</a> (also Loess, Lowess) for interpolation of 31 * real univariate functions. 32 * <p/> 33 * For reference, see 34 * <a href="http://www.math.tau.ac.il/~yekutiel/MA seminar/Cleveland 1979.pdf"> 35 * William S. Cleveland - Robust Locally Weighted Regression and Smoothing 36 * Scatterplots</a> 37 * <p/> 38 * This class implements both the loess method and serves as an interpolation 39 * adapter to it, allowing to build a spline on the obtained loess fit. 40 * 41 * @version $Revision: 990655 $ $Date: 2010-08-29 23:49:40 +0200 (dim. 29 aot 2010) $ 42 * @since 2.0 43 */ 44 public class LoessInterpolator 45 implements UnivariateRealInterpolator, Serializable { 46 47 /** Default value of the bandwidth parameter. */ 48 public static final double DEFAULT_BANDWIDTH = 0.3; 49 50 /** Default value of the number of robustness iterations. */ 51 public static final int DEFAULT_ROBUSTNESS_ITERS = 2; 52 53 /** 54 * Default value for accuracy. 55 * @since 2.1 56 */ 57 public static final double DEFAULT_ACCURACY = 1e-12; 58 59 /** serializable version identifier. */ 60 private static final long serialVersionUID = 5204927143605193821L; 61 62 /** 63 * The bandwidth parameter: when computing the loess fit at 64 * a particular point, this fraction of source points closest 65 * to the current point is taken into account for computing 66 * a least-squares regression. 67 * <p/> 68 * A sensible value is usually 0.25 to 0.5. 69 */ 70 private final double bandwidth; 71 72 /** 73 * The number of robustness iterations parameter: this many 74 * robustness iterations are done. 75 * <p/> 76 * A sensible value is usually 0 (just the initial fit without any 77 * robustness iterations) to 4. 78 */ 79 private final int robustnessIters; 80 81 /** 82 * If the median residual at a certain robustness iteration 83 * is less than this amount, no more iterations are done. 84 */ 85 private final double accuracy; 86 87 /** 88 * Constructs a new {@link LoessInterpolator} 89 * with a bandwidth of {@link #DEFAULT_BANDWIDTH}, 90 * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations 91 * and an accuracy of {#link #DEFAULT_ACCURACY}. 92 * See {@link #LoessInterpolator(double, int, double)} for an explanation of 93 * the parameters. 94 */ 95 public LoessInterpolator() { 96 this.bandwidth = DEFAULT_BANDWIDTH; 97 this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS; 98 this.accuracy = DEFAULT_ACCURACY; 99 } 100 101 /** 102 * Constructs a new {@link LoessInterpolator} 103 * with given bandwidth and number of robustness iterations. 104 * <p> 105 * Calling this constructor is equivalent to calling {link {@link 106 * #LoessInterpolator(double, int, double) LoessInterpolator(bandwidth, 107 * robustnessIters, LoessInterpolator.DEFAULT_ACCURACY)} 108 * </p> 109 * 110 * @param bandwidth when computing the loess fit at 111 * a particular point, this fraction of source points closest 112 * to the current point is taken into account for computing 113 * a least-squares regression.</br> 114 * A sensible value is usually 0.25 to 0.5, the default value is 115 * {@link #DEFAULT_BANDWIDTH}. 116 * @param robustnessIters This many robustness iterations are done.</br> 117 * A sensible value is usually 0 (just the initial fit without any 118 * robustness iterations) to 4, the default value is 119 * {@link #DEFAULT_ROBUSTNESS_ITERS}. 120 * @throws MathException if bandwidth does not lie in the interval [0,1] 121 * or if robustnessIters is negative. 122 * @see #LoessInterpolator(double, int, double) 123 */ 124 public LoessInterpolator(double bandwidth, int robustnessIters) throws MathException { 125 this(bandwidth, robustnessIters, DEFAULT_ACCURACY); 126 } 127 128 /** 129 * Constructs a new {@link LoessInterpolator} 130 * with given bandwidth, number of robustness iterations and accuracy. 131 * 132 * @param bandwidth when computing the loess fit at 133 * a particular point, this fraction of source points closest 134 * to the current point is taken into account for computing 135 * a least-squares regression.</br> 136 * A sensible value is usually 0.25 to 0.5, the default value is 137 * {@link #DEFAULT_BANDWIDTH}. 138 * @param robustnessIters This many robustness iterations are done.</br> 139 * A sensible value is usually 0 (just the initial fit without any 140 * robustness iterations) to 4, the default value is 141 * {@link #DEFAULT_ROBUSTNESS_ITERS}. 142 * @param accuracy If the median residual at a certain robustness iteration 143 * is less than this amount, no more iterations are done. 144 * @throws MathException if bandwidth does not lie in the interval [0,1] 145 * or if robustnessIters is negative. 146 * @see #LoessInterpolator(double, int) 147 * @since 2.1 148 */ 149 public LoessInterpolator(double bandwidth, int robustnessIters, double accuracy) throws MathException { 150 if (bandwidth < 0 || bandwidth > 1) { 151 throw new MathException(LocalizedFormats.BANDWIDTH_OUT_OF_INTERVAL, 152 bandwidth); 153 } 154 this.bandwidth = bandwidth; 155 if (robustnessIters < 0) { 156 throw new MathException(LocalizedFormats.NEGATIVE_ROBUSTNESS_ITERATIONS, robustnessIters); 157 } 158 this.robustnessIters = robustnessIters; 159 this.accuracy = accuracy; 160 } 161 162 /** 163 * Compute an interpolating function by performing a loess fit 164 * on the data at the original abscissae and then building a cubic spline 165 * with a 166 * {@link org.apache.commons.math.analysis.interpolation.SplineInterpolator} 167 * on the resulting fit. 168 * 169 * @param xval the arguments for the interpolation points 170 * @param yval the values for the interpolation points 171 * @return A cubic spline built upon a loess fit to the data at the original abscissae 172 * @throws MathException if some of the following conditions are false: 173 * <ul> 174 * <li> Arguments and values are of the same size that is greater than zero</li> 175 * <li> The arguments are in a strictly increasing order</li> 176 * <li> All arguments and values are finite real numbers</li> 177 * </ul> 178 */ 179 public final PolynomialSplineFunction interpolate( 180 final double[] xval, final double[] yval) throws MathException { 181 return new SplineInterpolator().interpolate(xval, smooth(xval, yval)); 182 } 183 184 /** 185 * Compute a weighted loess fit on the data at the original abscissae. 186 * 187 * @param xval the arguments for the interpolation points 188 * @param yval the values for the interpolation points 189 * @param weights point weights: coefficients by which the robustness weight of a point is multiplied 190 * @return values of the loess fit at corresponding original abscissae 191 * @throws MathException if some of the following conditions are false: 192 * <ul> 193 * <li> Arguments and values are of the same size that is greater than zero</li> 194 * <li> The arguments are in a strictly increasing order</li> 195 * <li> All arguments and values are finite real numbers</li> 196 * </ul> 197 * @since 2.1 198 */ 199 public final double[] smooth(final double[] xval, final double[] yval, final double[] weights) 200 throws MathException { 201 if (xval.length != yval.length) { 202 throw new MathException(LocalizedFormats.MISMATCHED_LOESS_ABSCISSA_ORDINATE_ARRAYS, 203 xval.length, yval.length); 204 } 205 206 final int n = xval.length; 207 208 if (n == 0) { 209 throw new MathException(LocalizedFormats.LOESS_EXPECTS_AT_LEAST_ONE_POINT); 210 } 211 212 checkAllFiniteReal(xval, LocalizedFormats.NON_REAL_FINITE_ABSCISSA); 213 checkAllFiniteReal(yval, LocalizedFormats.NON_REAL_FINITE_ORDINATE); 214 checkAllFiniteReal(weights, LocalizedFormats.NON_REAL_FINITE_WEIGHT); 215 216 checkStrictlyIncreasing(xval); 217 218 if (n == 1) { 219 return new double[]{yval[0]}; 220 } 221 222 if (n == 2) { 223 return new double[]{yval[0], yval[1]}; 224 } 225 226 int bandwidthInPoints = (int) (bandwidth * n); 227 228 if (bandwidthInPoints < 2) { 229 throw new MathException(LocalizedFormats.TOO_SMALL_BANDWIDTH, 230 n, 2.0 / n, bandwidth); 231 } 232 233 final double[] res = new double[n]; 234 235 final double[] residuals = new double[n]; 236 final double[] sortedResiduals = new double[n]; 237 238 final double[] robustnessWeights = new double[n]; 239 240 // Do an initial fit and 'robustnessIters' robustness iterations. 241 // This is equivalent to doing 'robustnessIters+1' robustness iterations 242 // starting with all robustness weights set to 1. 243 Arrays.fill(robustnessWeights, 1); 244 245 for (int iter = 0; iter <= robustnessIters; ++iter) { 246 final int[] bandwidthInterval = {0, bandwidthInPoints - 1}; 247 // At each x, compute a local weighted linear regression 248 for (int i = 0; i < n; ++i) { 249 final double x = xval[i]; 250 251 // Find out the interval of source points on which 252 // a regression is to be made. 253 if (i > 0) { 254 updateBandwidthInterval(xval, weights, i, bandwidthInterval); 255 } 256 257 final int ileft = bandwidthInterval[0]; 258 final int iright = bandwidthInterval[1]; 259 260 // Compute the point of the bandwidth interval that is 261 // farthest from x 262 final int edge; 263 if (xval[i] - xval[ileft] > xval[iright] - xval[i]) { 264 edge = ileft; 265 } else { 266 edge = iright; 267 } 268 269 // Compute a least-squares linear fit weighted by 270 // the product of robustness weights and the tricube 271 // weight function. 272 // See http://en.wikipedia.org/wiki/Linear_regression 273 // (section "Univariate linear case") 274 // and http://en.wikipedia.org/wiki/Weighted_least_squares 275 // (section "Weighted least squares") 276 double sumWeights = 0; 277 double sumX = 0; 278 double sumXSquared = 0; 279 double sumY = 0; 280 double sumXY = 0; 281 double denom = FastMath.abs(1.0 / (xval[edge] - x)); 282 for (int k = ileft; k <= iright; ++k) { 283 final double xk = xval[k]; 284 final double yk = yval[k]; 285 final double dist = (k < i) ? x - xk : xk - x; 286 final double w = tricube(dist * denom) * robustnessWeights[k] * weights[k]; 287 final double xkw = xk * w; 288 sumWeights += w; 289 sumX += xkw; 290 sumXSquared += xk * xkw; 291 sumY += yk * w; 292 sumXY += yk * xkw; 293 } 294 295 final double meanX = sumX / sumWeights; 296 final double meanY = sumY / sumWeights; 297 final double meanXY = sumXY / sumWeights; 298 final double meanXSquared = sumXSquared / sumWeights; 299 300 final double beta; 301 if (FastMath.sqrt(FastMath.abs(meanXSquared - meanX * meanX)) < accuracy) { 302 beta = 0; 303 } else { 304 beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX); 305 } 306 307 final double alpha = meanY - beta * meanX; 308 309 res[i] = beta * x + alpha; 310 residuals[i] = FastMath.abs(yval[i] - res[i]); 311 } 312 313 // No need to recompute the robustness weights at the last 314 // iteration, they won't be needed anymore 315 if (iter == robustnessIters) { 316 break; 317 } 318 319 // Recompute the robustness weights. 320 321 // Find the median residual. 322 // An arraycopy and a sort are completely tractable here, 323 // because the preceding loop is a lot more expensive 324 System.arraycopy(residuals, 0, sortedResiduals, 0, n); 325 Arrays.sort(sortedResiduals); 326 final double medianResidual = sortedResiduals[n / 2]; 327 328 if (FastMath.abs(medianResidual) < accuracy) { 329 break; 330 } 331 332 for (int i = 0; i < n; ++i) { 333 final double arg = residuals[i] / (6 * medianResidual); 334 if (arg >= 1) { 335 robustnessWeights[i] = 0; 336 } else { 337 final double w = 1 - arg * arg; 338 robustnessWeights[i] = w * w; 339 } 340 } 341 } 342 343 return res; 344 } 345 346 /** 347 * Compute a loess fit on the data at the original abscissae. 348 * 349 * @param xval the arguments for the interpolation points 350 * @param yval the values for the interpolation points 351 * @return values of the loess fit at corresponding original abscissae 352 * @throws MathException if some of the following conditions are false: 353 * <ul> 354 * <li> Arguments and values are of the same size that is greater than zero</li> 355 * <li> The arguments are in a strictly increasing order</li> 356 * <li> All arguments and values are finite real numbers</li> 357 * </ul> 358 */ 359 public final double[] smooth(final double[] xval, final double[] yval) 360 throws MathException { 361 if (xval.length != yval.length) { 362 throw new MathException(LocalizedFormats.MISMATCHED_LOESS_ABSCISSA_ORDINATE_ARRAYS, 363 xval.length, yval.length); 364 } 365 366 final double[] unitWeights = new double[xval.length]; 367 Arrays.fill(unitWeights, 1.0); 368 369 return smooth(xval, yval, unitWeights); 370 } 371 372 /** 373 * Given an index interval into xval that embraces a certain number of 374 * points closest to xval[i-1], update the interval so that it embraces 375 * the same number of points closest to xval[i], ignoring zero weights. 376 * 377 * @param xval arguments array 378 * @param weights weights array 379 * @param i the index around which the new interval should be computed 380 * @param bandwidthInterval a two-element array {left, right} such that: <p/> 381 * <tt>(left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])</tt> 382 * <p/> and also <p/> 383 * <tt>(right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])</tt>. 384 * The array will be updated. 385 */ 386 private static void updateBandwidthInterval(final double[] xval, final double[] weights, 387 final int i, 388 final int[] bandwidthInterval) { 389 final int left = bandwidthInterval[0]; 390 final int right = bandwidthInterval[1]; 391 392 // The right edge should be adjusted if the next point to the right 393 // is closer to xval[i] than the leftmost point of the current interval 394 int nextRight = nextNonzero(weights, right); 395 if (nextRight < xval.length && xval[nextRight] - xval[i] < xval[i] - xval[left]) { 396 int nextLeft = nextNonzero(weights, bandwidthInterval[0]); 397 bandwidthInterval[0] = nextLeft; 398 bandwidthInterval[1] = nextRight; 399 } 400 } 401 402 /** 403 * Returns the smallest index j such that j > i && (j==weights.length || weights[j] != 0) 404 * @param weights weights array 405 * @param i the index from which to start search; must be < weights.length 406 * @return the smallest index j such that j > i && (j==weights.length || weights[j] != 0) 407 */ 408 private static int nextNonzero(final double[] weights, final int i) { 409 int j = i + 1; 410 while(j < weights.length && weights[j] == 0) { 411 j++; 412 } 413 return j; 414 } 415 416 /** 417 * Compute the 418 * <a href="http://en.wikipedia.org/wiki/Local_regression#Weight_function">tricube</a> 419 * weight function 420 * 421 * @param x the argument 422 * @return (1-|x|^3)^3 423 */ 424 private static double tricube(final double x) { 425 final double tmp = 1 - x * x * x; 426 return tmp * tmp * tmp; 427 } 428 429 /** 430 * Check that all elements of an array are finite real numbers. 431 * 432 * @param values the values array 433 * @param pattern pattern of the error message 434 * @throws MathException if one of the values is not a finite real number 435 */ 436 private static void checkAllFiniteReal(final double[] values, final Localizable pattern) 437 throws MathException { 438 for (int i = 0; i < values.length; i++) { 439 final double x = values[i]; 440 if (Double.isInfinite(x) || Double.isNaN(x)) { 441 throw new MathException(pattern, i, x); 442 } 443 } 444 } 445 446 /** 447 * Check that elements of the abscissae array are in a strictly 448 * increasing order. 449 * 450 * @param xval the abscissae array 451 * @throws MathException if the abscissae array 452 * is not in a strictly increasing order 453 */ 454 private static void checkStrictlyIncreasing(final double[] xval) 455 throws MathException { 456 for (int i = 0; i < xval.length; ++i) { 457 if (i >= 1 && xval[i - 1] >= xval[i]) { 458 throw new MathException(LocalizedFormats.OUT_OF_ORDER_ABSCISSA_ARRAY, 459 i - 1, xval[i - 1], i, xval[i]); 460 } 461 } 462 } 463 } 464