Home | History | Annotate | Download | only in optimization
      1 /*
      2  * Licensed to the Apache Software Foundation (ASF) under one or more
      3  * contributor license agreements.  See the NOTICE file distributed with
      4  * this work for additional information regarding copyright ownership.
      5  * The ASF licenses this file to You under the Apache License, Version 2.0
      6  * (the "License"); you may not use this file except in compliance with
      7  * the License.  You may obtain a copy of the License at
      8  *
      9  *      http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  */
     17 
     18 package org.apache.commons.math.optimization;
     19 
     20 import java.util.Arrays;
     21 import java.util.Comparator;
     22 
     23 import org.apache.commons.math.MathRuntimeException;
     24 import org.apache.commons.math.analysis.DifferentiableMultivariateVectorialFunction;
     25 import org.apache.commons.math.FunctionEvaluationException;
     26 import org.apache.commons.math.exception.util.LocalizedFormats;
     27 import org.apache.commons.math.random.RandomVectorGenerator;
     28 
     29 /**
     30  * Special implementation of the {@link DifferentiableMultivariateVectorialOptimizer} interface adding
     31  * multi-start features to an existing optimizer.
     32  * <p>
     33  * This class wraps a classical optimizer to use it several times in
     34  * turn with different starting points in order to avoid being trapped
     35  * into a local extremum when looking for a global one.
     36  * </p>
     37  * @version $Revision: 1073158 $ $Date: 2011-02-21 22:46:52 +0100 (lun. 21 fvr. 2011) $
     38  * @since 2.0
     39  */
     40 public class MultiStartDifferentiableMultivariateVectorialOptimizer
     41     implements DifferentiableMultivariateVectorialOptimizer {
     42 
     43     /** Serializable version identifier. */
     44     private static final long serialVersionUID = 9206382258980561530L;
     45 
     46     /** Underlying classical optimizer. */
     47     private final DifferentiableMultivariateVectorialOptimizer optimizer;
     48 
     49     /** Maximal number of iterations allowed. */
     50     private int maxIterations;
     51 
     52     /** Number of iterations already performed for all starts. */
     53     private int totalIterations;
     54 
     55     /** Maximal number of evaluations allowed. */
     56     private int maxEvaluations;
     57 
     58     /** Number of evaluations already performed for all starts. */
     59     private int totalEvaluations;
     60 
     61     /** Number of jacobian evaluations already performed for all starts. */
     62     private int totalJacobianEvaluations;
     63 
     64     /** Number of starts to go. */
     65     private int starts;
     66 
     67     /** Random generator for multi-start. */
     68     private RandomVectorGenerator generator;
     69 
     70     /** Found optima. */
     71     private VectorialPointValuePair[] optima;
     72 
     73     /**
     74      * Create a multi-start optimizer from a single-start optimizer
     75      * @param optimizer single-start optimizer to wrap
     76      * @param starts number of starts to perform (including the
     77      * first one), multi-start is disabled if value is less than or
     78      * equal to 1
     79      * @param generator random vector generator to use for restarts
     80      */
     81     public MultiStartDifferentiableMultivariateVectorialOptimizer(
     82                 final DifferentiableMultivariateVectorialOptimizer optimizer,
     83                 final int starts,
     84                 final RandomVectorGenerator generator) {
     85         this.optimizer                = optimizer;
     86         this.totalIterations          = 0;
     87         this.totalEvaluations         = 0;
     88         this.totalJacobianEvaluations = 0;
     89         this.starts                   = starts;
     90         this.generator                = generator;
     91         this.optima                   = null;
     92         setMaxIterations(Integer.MAX_VALUE);
     93         setMaxEvaluations(Integer.MAX_VALUE);
     94     }
     95 
     96     /** Get all the optima found during the last call to {@link
     97      * #optimize(DifferentiableMultivariateVectorialFunction,
     98      * double[], double[], double[]) optimize}.
     99      * <p>The optimizer stores all the optima found during a set of
    100      * restarts. The {@link #optimize(DifferentiableMultivariateVectorialFunction,
    101      * double[], double[], double[]) optimize} method returns the
    102      * best point only. This method returns all the points found at the
    103      * end of each starts, including the best one already returned by the {@link
    104      * #optimize(DifferentiableMultivariateVectorialFunction, double[],
    105      * double[], double[]) optimize} method.
    106      * </p>
    107      * <p>
    108      * The returned array as one element for each start as specified
    109      * in the constructor. It is ordered with the results from the
    110      * runs that did converge first, sorted from best to worst
    111      * objective value (i.e in ascending order if minimizing and in
    112      * descending order if maximizing), followed by and null elements
    113      * corresponding to the runs that did not converge. This means all
    114      * elements will be null if the {@link #optimize(DifferentiableMultivariateVectorialFunction,
    115      * double[], double[], double[]) optimize} method did throw a {@link
    116      * org.apache.commons.math.ConvergenceException ConvergenceException}).
    117      * This also means that if the first element is non null, it is the best
    118      * point found across all starts.</p>
    119      * @return array containing the optima
    120      * @exception IllegalStateException if {@link #optimize(DifferentiableMultivariateVectorialFunction,
    121      * double[], double[], double[]) optimize} has not been called
    122      */
    123     public VectorialPointValuePair[] getOptima() throws IllegalStateException {
    124         if (optima == null) {
    125             throw MathRuntimeException.createIllegalStateException(LocalizedFormats.NO_OPTIMUM_COMPUTED_YET);
    126         }
    127         return optima.clone();
    128     }
    129 
    130     /** {@inheritDoc} */
    131     public void setMaxIterations(int maxIterations) {
    132         this.maxIterations = maxIterations;
    133     }
    134 
    135     /** {@inheritDoc} */
    136     public int getMaxIterations() {
    137         return maxIterations;
    138     }
    139 
    140     /** {@inheritDoc} */
    141     public int getIterations() {
    142         return totalIterations;
    143     }
    144 
    145     /** {@inheritDoc} */
    146     public void setMaxEvaluations(int maxEvaluations) {
    147         this.maxEvaluations = maxEvaluations;
    148     }
    149 
    150     /** {@inheritDoc} */
    151     public int getMaxEvaluations() {
    152         return maxEvaluations;
    153     }
    154 
    155     /** {@inheritDoc} */
    156     public int getEvaluations() {
    157         return totalEvaluations;
    158     }
    159 
    160     /** {@inheritDoc} */
    161     public int getJacobianEvaluations() {
    162         return totalJacobianEvaluations;
    163     }
    164 
    165     /** {@inheritDoc} */
    166     public void setConvergenceChecker(VectorialConvergenceChecker checker) {
    167         optimizer.setConvergenceChecker(checker);
    168     }
    169 
    170     /** {@inheritDoc} */
    171     public VectorialConvergenceChecker getConvergenceChecker() {
    172         return optimizer.getConvergenceChecker();
    173     }
    174 
    175     /** {@inheritDoc} */
    176     public VectorialPointValuePair optimize(final DifferentiableMultivariateVectorialFunction f,
    177                                             final double[] target, final double[] weights,
    178                                             final double[] startPoint)
    179         throws FunctionEvaluationException, OptimizationException, IllegalArgumentException {
    180 
    181         optima                   = new VectorialPointValuePair[starts];
    182         totalIterations          = 0;
    183         totalEvaluations         = 0;
    184         totalJacobianEvaluations = 0;
    185 
    186         // multi-start loop
    187         for (int i = 0; i < starts; ++i) {
    188 
    189             try {
    190                 optimizer.setMaxIterations(maxIterations - totalIterations);
    191                 optimizer.setMaxEvaluations(maxEvaluations - totalEvaluations);
    192                 optima[i] = optimizer.optimize(f, target, weights,
    193                                                (i == 0) ? startPoint : generator.nextVector());
    194             } catch (FunctionEvaluationException fee) {
    195                 optima[i] = null;
    196             } catch (OptimizationException oe) {
    197                 optima[i] = null;
    198             }
    199 
    200             totalIterations          += optimizer.getIterations();
    201             totalEvaluations         += optimizer.getEvaluations();
    202             totalJacobianEvaluations += optimizer.getJacobianEvaluations();
    203 
    204         }
    205 
    206         // sort the optima from best to worst, followed by null elements
    207         Arrays.sort(optima, new Comparator<VectorialPointValuePair>() {
    208             public int compare(final VectorialPointValuePair o1, final VectorialPointValuePair o2) {
    209                 if (o1 == null) {
    210                     return (o2 == null) ? 0 : +1;
    211                 } else if (o2 == null) {
    212                     return -1;
    213                 }
    214                 return Double.compare(weightedResidual(o1), weightedResidual(o2));
    215             }
    216             private double weightedResidual(final VectorialPointValuePair pv) {
    217                 final double[] value = pv.getValueRef();
    218                 double sum = 0;
    219                 for (int i = 0; i < value.length; ++i) {
    220                     final double ri = value[i] - target[i];
    221                     sum += weights[i] * ri * ri;
    222                 }
    223                 return sum;
    224             }
    225         });
    226 
    227         if (optima[0] == null) {
    228             throw new OptimizationException(
    229                     LocalizedFormats.NO_CONVERGENCE_WITH_ANY_START_POINT,
    230                     starts);
    231         }
    232 
    233         // return the found point given the best objective function value
    234         return optima[0];
    235 
    236     }
    237 
    238 }
    239