Home | History | Annotate | Download | only in fitting
      1 /*
      2  * Licensed to the Apache Software Foundation (ASF) under one or more
      3  * contributor license agreements.  See the NOTICE file distributed with
      4  * this work for additional information regarding copyright ownership.
      5  * The ASF licenses this file to You under the Apache License, Version 2.0
      6  * (the "License"); you may not use this file except in compliance with
      7  * the License.  You may obtain a copy of the License at
      8  *
      9  *      http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  */
     17 
     18 package org.apache.commons.math.optimization.fitting;
     19 
     20 import java.util.ArrayList;
     21 import java.util.List;
     22 
     23 import org.apache.commons.math.analysis.DifferentiableMultivariateVectorialFunction;
     24 import org.apache.commons.math.analysis.MultivariateMatrixFunction;
     25 import org.apache.commons.math.FunctionEvaluationException;
     26 import org.apache.commons.math.optimization.DifferentiableMultivariateVectorialOptimizer;
     27 import org.apache.commons.math.optimization.OptimizationException;
     28 import org.apache.commons.math.optimization.VectorialPointValuePair;
     29 
     30 /** Fitter for parametric univariate real functions y = f(x).
     31  * <p>When a univariate real function y = f(x) does depend on some
     32  * unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
     33  * this class can be used to find these parameters. It does this
     34  * by <em>fitting</em> the curve so it remains very close to a set of
     35  * observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
     36  * y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
     37  * is done by finding the parameters values that minimizes the objective
     38  * function &sum;(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
     39  * really a least squares problem.</p>
     40  * @version $Revision: 1073158 $ $Date: 2011-02-21 22:46:52 +0100 (lun. 21 fvr. 2011) $
     41  * @since 2.0
     42  */
     43 public class CurveFitter {
     44 
     45     /** Optimizer to use for the fitting. */
     46     private final DifferentiableMultivariateVectorialOptimizer optimizer;
     47 
     48     /** Observed points. */
     49     private final List<WeightedObservedPoint> observations;
     50 
     51     /** Simple constructor.
     52      * @param optimizer optimizer to use for the fitting
     53      */
     54     public CurveFitter(final DifferentiableMultivariateVectorialOptimizer optimizer) {
     55         this.optimizer = optimizer;
     56         observations = new ArrayList<WeightedObservedPoint>();
     57     }
     58 
     59     /** Add an observed (x,y) point to the sample with unit weight.
     60      * <p>Calling this method is equivalent to call
     61      * <code>addObservedPoint(1.0, x, y)</code>.</p>
     62      * @param x abscissa of the point
     63      * @param y observed value of the point at x, after fitting we should
     64      * have f(x) as close as possible to this value
     65      * @see #addObservedPoint(double, double, double)
     66      * @see #addObservedPoint(WeightedObservedPoint)
     67      * @see #getObservations()
     68      */
     69     public void addObservedPoint(double x, double y) {
     70         addObservedPoint(1.0, x, y);
     71     }
     72 
     73     /** Add an observed weighted (x,y) point to the sample.
     74      * @param weight weight of the observed point in the fit
     75      * @param x abscissa of the point
     76      * @param y observed value of the point at x, after fitting we should
     77      * have f(x) as close as possible to this value
     78      * @see #addObservedPoint(double, double)
     79      * @see #addObservedPoint(WeightedObservedPoint)
     80      * @see #getObservations()
     81      */
     82     public void addObservedPoint(double weight, double x, double y) {
     83         observations.add(new WeightedObservedPoint(weight, x, y));
     84     }
     85 
     86     /** Add an observed weighted (x,y) point to the sample.
     87      * @param observed observed point to add
     88      * @see #addObservedPoint(double, double)
     89      * @see #addObservedPoint(double, double, double)
     90      * @see #getObservations()
     91      */
     92     public void addObservedPoint(WeightedObservedPoint observed) {
     93         observations.add(observed);
     94     }
     95 
     96     /** Get the observed points.
     97      * @return observed points
     98      * @see #addObservedPoint(double, double)
     99      * @see #addObservedPoint(double, double, double)
    100      * @see #addObservedPoint(WeightedObservedPoint)
    101      */
    102     public WeightedObservedPoint[] getObservations() {
    103         return observations.toArray(new WeightedObservedPoint[observations.size()]);
    104     }
    105 
    106     /**
    107      * Remove all observations.
    108      */
    109     public void clearObservations() {
    110         observations.clear();
    111     }
    112 
    113     /** Fit a curve.
    114      * <p>This method compute the coefficients of the curve that best
    115      * fit the sample of observed points previously given through calls
    116      * to the {@link #addObservedPoint(WeightedObservedPoint)
    117      * addObservedPoint} method.</p>
    118      * @param f parametric function to fit
    119      * @param initialGuess first guess of the function parameters
    120      * @return fitted parameters
    121      * @exception FunctionEvaluationException if the objective function throws one during the search
    122      * @exception OptimizationException if the algorithm failed to converge
    123      * @exception IllegalArgumentException if the start point dimension is wrong
    124      */
    125     public double[] fit(final ParametricRealFunction f,
    126                         final double[] initialGuess)
    127         throws FunctionEvaluationException, OptimizationException, IllegalArgumentException {
    128 
    129         // prepare least squares problem
    130         double[] target  = new double[observations.size()];
    131         double[] weights = new double[observations.size()];
    132         int i = 0;
    133         for (WeightedObservedPoint point : observations) {
    134             target[i]  = point.getY();
    135             weights[i] = point.getWeight();
    136             ++i;
    137         }
    138 
    139         // perform the fit
    140         VectorialPointValuePair optimum =
    141             optimizer.optimize(new TheoreticalValuesFunction(f), target, weights, initialGuess);
    142 
    143         // extract the coefficients
    144         return optimum.getPointRef();
    145 
    146     }
    147 
    148     /** Vectorial function computing function theoretical values. */
    149     private class TheoreticalValuesFunction
    150         implements DifferentiableMultivariateVectorialFunction {
    151 
    152         /** Function to fit. */
    153         private final ParametricRealFunction f;
    154 
    155         /** Simple constructor.
    156          * @param f function to fit.
    157          */
    158         public TheoreticalValuesFunction(final ParametricRealFunction f) {
    159             this.f = f;
    160         }
    161 
    162         /** {@inheritDoc} */
    163         public MultivariateMatrixFunction jacobian() {
    164             return new MultivariateMatrixFunction() {
    165                 public double[][] value(double[] point)
    166                     throws FunctionEvaluationException, IllegalArgumentException {
    167 
    168                     final double[][] jacobian = new double[observations.size()][];
    169 
    170                     int i = 0;
    171                     for (WeightedObservedPoint observed : observations) {
    172                         jacobian[i++] = f.gradient(observed.getX(), point);
    173                     }
    174 
    175                     return jacobian;
    176 
    177                 }
    178             };
    179         }
    180 
    181         /** {@inheritDoc} */
    182         public double[] value(double[] point) throws FunctionEvaluationException, IllegalArgumentException {
    183 
    184             // compute the residuals
    185             final double[] values = new double[observations.size()];
    186             int i = 0;
    187             for (WeightedObservedPoint observed : observations) {
    188                 values[i++] = f.value(observed.getX(), point);
    189             }
    190 
    191             return values;
    192 
    193         }
    194 
    195     }
    196 
    197 }
    198