Home | History | Annotate | Download | only in general
      1 /*
      2  * Licensed to the Apache Software Foundation (ASF) under one or more
      3  * contributor license agreements.  See the NOTICE file distributed with
      4  * this work for additional information regarding copyright ownership.
      5  * The ASF licenses this file to You under the Apache License, Version 2.0
      6  * (the "License"); you may not use this file except in compliance with
      7  * the License.  You may obtain a copy of the License at
      8  *
      9  *      http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  */
     17 
     18 package org.apache.commons.math.optimization.general;
     19 
     20 import org.apache.commons.math.ConvergenceException;
     21 import org.apache.commons.math.FunctionEvaluationException;
     22 import org.apache.commons.math.analysis.UnivariateRealFunction;
     23 import org.apache.commons.math.analysis.solvers.BrentSolver;
     24 import org.apache.commons.math.analysis.solvers.UnivariateRealSolver;
     25 import org.apache.commons.math.exception.util.LocalizedFormats;
     26 import org.apache.commons.math.optimization.GoalType;
     27 import org.apache.commons.math.optimization.OptimizationException;
     28 import org.apache.commons.math.optimization.RealPointValuePair;
     29 import org.apache.commons.math.util.FastMath;
     30 
     31 /**
     32  * Non-linear conjugate gradient optimizer.
     33  * <p>
     34  * This class supports both the Fletcher-Reeves and the Polak-Ribi&egrave;re
     35  * update formulas for the conjugate search directions. It also supports
     36  * optional preconditioning.
     37  * </p>
     38  *
     39  * @version $Revision: 1070725 $ $Date: 2011-02-15 02:31:12 +0100 (mar. 15 fvr. 2011) $
     40  * @since 2.0
     41  *
     42  */
     43 
     44 public class NonLinearConjugateGradientOptimizer
     45     extends AbstractScalarDifferentiableOptimizer {
     46 
     47     /** Update formula for the beta parameter. */
     48     private final ConjugateGradientFormula updateFormula;
     49 
     50     /** Preconditioner (may be null). */
     51     private Preconditioner preconditioner;
     52 
     53     /** solver to use in the line search (may be null). */
     54     private UnivariateRealSolver solver;
     55 
     56     /** Initial step used to bracket the optimum in line search. */
     57     private double initialStep;
     58 
     59     /** Simple constructor with default settings.
     60      * <p>The convergence check is set to a {@link
     61      * org.apache.commons.math.optimization.SimpleVectorialValueChecker}
     62      * and the maximal number of iterations is set to
     63      * {@link AbstractScalarDifferentiableOptimizer#DEFAULT_MAX_ITERATIONS}.
     64      * @param updateFormula formula to use for updating the &beta; parameter,
     65      * must be one of {@link ConjugateGradientFormula#FLETCHER_REEVES} or {@link
     66      * ConjugateGradientFormula#POLAK_RIBIERE}
     67      */
     68     public NonLinearConjugateGradientOptimizer(final ConjugateGradientFormula updateFormula) {
     69         this.updateFormula = updateFormula;
     70         preconditioner     = null;
     71         solver             = null;
     72         initialStep        = 1.0;
     73     }
     74 
     75     /**
     76      * Set the preconditioner.
     77      * @param preconditioner preconditioner to use for next optimization,
     78      * may be null to remove an already registered preconditioner
     79      */
     80     public void setPreconditioner(final Preconditioner preconditioner) {
     81         this.preconditioner = preconditioner;
     82     }
     83 
     84     /**
     85      * Set the solver to use during line search.
     86      * @param lineSearchSolver solver to use during line search, may be null
     87      * to remove an already registered solver and fall back to the
     88      * default {@link BrentSolver Brent solver}.
     89      */
     90     public void setLineSearchSolver(final UnivariateRealSolver lineSearchSolver) {
     91         this.solver = lineSearchSolver;
     92     }
     93 
     94     /**
     95      * Set the initial step used to bracket the optimum in line search.
     96      * <p>
     97      * The initial step is a factor with respect to the search direction,
     98      * which itself is roughly related to the gradient of the function
     99      * </p>
    100      * @param initialStep initial step used to bracket the optimum in line search,
    101      * if a non-positive value is used, the initial step is reset to its
    102      * default value of 1.0
    103      */
    104     public void setInitialStep(final double initialStep) {
    105         if (initialStep <= 0) {
    106             this.initialStep = 1.0;
    107         } else {
    108             this.initialStep = initialStep;
    109         }
    110     }
    111 
    112     /** {@inheritDoc} */
    113     @Override
    114     protected RealPointValuePair doOptimize()
    115         throws FunctionEvaluationException, OptimizationException, IllegalArgumentException {
    116         try {
    117 
    118             // initialization
    119             if (preconditioner == null) {
    120                 preconditioner = new IdentityPreconditioner();
    121             }
    122             if (solver == null) {
    123                 solver = new BrentSolver();
    124             }
    125             final int n = point.length;
    126             double[] r = computeObjectiveGradient(point);
    127             if (goal == GoalType.MINIMIZE) {
    128                 for (int i = 0; i < n; ++i) {
    129                     r[i] = -r[i];
    130                 }
    131             }
    132 
    133             // initial search direction
    134             double[] steepestDescent = preconditioner.precondition(point, r);
    135             double[] searchDirection = steepestDescent.clone();
    136 
    137             double delta = 0;
    138             for (int i = 0; i < n; ++i) {
    139                 delta += r[i] * searchDirection[i];
    140             }
    141 
    142             RealPointValuePair current = null;
    143             while (true) {
    144 
    145                 final double objective = computeObjectiveValue(point);
    146                 RealPointValuePair previous = current;
    147                 current = new RealPointValuePair(point, objective);
    148                 if (previous != null) {
    149                     if (checker.converged(getIterations(), previous, current)) {
    150                         // we have found an optimum
    151                         return current;
    152                     }
    153                 }
    154 
    155                 incrementIterationsCounter();
    156 
    157                 double dTd = 0;
    158                 for (final double di : searchDirection) {
    159                     dTd += di * di;
    160                 }
    161 
    162                 // find the optimal step in the search direction
    163                 final UnivariateRealFunction lsf = new LineSearchFunction(searchDirection);
    164                 final double step = solver.solve(lsf, 0, findUpperBound(lsf, 0, initialStep));
    165 
    166                 // validate new point
    167                 for (int i = 0; i < point.length; ++i) {
    168                     point[i] += step * searchDirection[i];
    169                 }
    170                 r = computeObjectiveGradient(point);
    171                 if (goal == GoalType.MINIMIZE) {
    172                     for (int i = 0; i < n; ++i) {
    173                         r[i] = -r[i];
    174                     }
    175                 }
    176 
    177                 // compute beta
    178                 final double deltaOld = delta;
    179                 final double[] newSteepestDescent = preconditioner.precondition(point, r);
    180                 delta = 0;
    181                 for (int i = 0; i < n; ++i) {
    182                     delta += r[i] * newSteepestDescent[i];
    183                 }
    184 
    185                 final double beta;
    186                 if (updateFormula == ConjugateGradientFormula.FLETCHER_REEVES) {
    187                     beta = delta / deltaOld;
    188                 } else {
    189                     double deltaMid = 0;
    190                     for (int i = 0; i < r.length; ++i) {
    191                         deltaMid += r[i] * steepestDescent[i];
    192                     }
    193                     beta = (delta - deltaMid) / deltaOld;
    194                 }
    195                 steepestDescent = newSteepestDescent;
    196 
    197                 // compute conjugate search direction
    198                 if ((getIterations() % n == 0) || (beta < 0)) {
    199                     // break conjugation: reset search direction
    200                     searchDirection = steepestDescent.clone();
    201                 } else {
    202                     // compute new conjugate search direction
    203                     for (int i = 0; i < n; ++i) {
    204                         searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
    205                     }
    206                 }
    207 
    208             }
    209 
    210         } catch (ConvergenceException ce) {
    211             throw new OptimizationException(ce);
    212         }
    213     }
    214 
    215     /**
    216      * Find the upper bound b ensuring bracketing of a root between a and b
    217      * @param f function whose root must be bracketed
    218      * @param a lower bound of the interval
    219      * @param h initial step to try
    220      * @return b such that f(a) and f(b) have opposite signs
    221      * @exception FunctionEvaluationException if the function cannot be computed
    222      * @exception OptimizationException if no bracket can be found
    223      */
    224     private double findUpperBound(final UnivariateRealFunction f,
    225                                   final double a, final double h)
    226         throws FunctionEvaluationException, OptimizationException {
    227         final double yA = f.value(a);
    228         double yB = yA;
    229         for (double step = h; step < Double.MAX_VALUE; step *= FastMath.max(2, yA / yB)) {
    230             final double b = a + step;
    231             yB = f.value(b);
    232             if (yA * yB <= 0) {
    233                 return b;
    234             }
    235         }
    236         throw new OptimizationException(LocalizedFormats.UNABLE_TO_BRACKET_OPTIMUM_IN_LINE_SEARCH);
    237     }
    238 
    239     /** Default identity preconditioner. */
    240     private static class IdentityPreconditioner implements Preconditioner {
    241 
    242         /** {@inheritDoc} */
    243         public double[] precondition(double[] variables, double[] r) {
    244             return r.clone();
    245         }
    246 
    247     }
    248 
    249     /** Internal class for line search.
    250      * <p>
    251      * The function represented by this class is the dot product of
    252      * the objective function gradient and the search direction. Its
    253      * value is zero when the gradient is orthogonal to the search
    254      * direction, i.e. when the objective function value is a local
    255      * extremum along the search direction.
    256      * </p>
    257      */
    258     private class LineSearchFunction implements UnivariateRealFunction {
    259         /** Search direction. */
    260         private final double[] searchDirection;
    261 
    262         /** Simple constructor.
    263          * @param searchDirection search direction
    264          */
    265         public LineSearchFunction(final double[] searchDirection) {
    266             this.searchDirection = searchDirection;
    267         }
    268 
    269         /** {@inheritDoc} */
    270         public double value(double x) throws FunctionEvaluationException {
    271 
    272             // current point in the search direction
    273             final double[] shiftedPoint = point.clone();
    274             for (int i = 0; i < shiftedPoint.length; ++i) {
    275                 shiftedPoint[i] += x * searchDirection[i];
    276             }
    277 
    278             // gradient of the objective function
    279             final double[] gradient;
    280             gradient = computeObjectiveGradient(shiftedPoint);
    281 
    282             // dot product with the search direction
    283             double dotProduct = 0;
    284             for (int i = 0; i < gradient.length; ++i) {
    285                 dotProduct += gradient[i] * searchDirection[i];
    286             }
    287 
    288             return dotProduct;
    289 
    290         }
    291 
    292     }
    293 
    294 }
    295