Home | History | Annotate | Download | only in direct
      1 /*
      2  * Licensed to the Apache Software Foundation (ASF) under one or more
      3  * contributor license agreements.  See the NOTICE file distributed with
      4  * this work for additional information regarding copyright ownership.
      5  * The ASF licenses this file to You under the Apache License, Version 2.0
      6  * (the "License"); you may not use this file except in compliance with
      7  * the License.  You may obtain a copy of the License at
      8  *
      9  *      http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  */
     17 
     18 package org.apache.commons.math.optimization.direct;
     19 
     20 import java.util.Comparator;
     21 
     22 import org.apache.commons.math.FunctionEvaluationException;
     23 import org.apache.commons.math.optimization.OptimizationException;
     24 import org.apache.commons.math.optimization.RealPointValuePair;
     25 
     26 /**
     27  * This class implements the Nelder-Mead direct search method.
     28  *
     29  * @version $Revision: 1070725 $ $Date: 2011-02-15 02:31:12 +0100 (mar. 15 fvr. 2011) $
     30  * @see MultiDirectional
     31  * @since 1.2
     32  */
     33 public class NelderMead extends DirectSearchOptimizer {
     34 
     35     /** Reflection coefficient. */
     36     private final double rho;
     37 
     38     /** Expansion coefficient. */
     39     private final double khi;
     40 
     41     /** Contraction coefficient. */
     42     private final double gamma;
     43 
     44     /** Shrinkage coefficient. */
     45     private final double sigma;
     46 
     47     /** Build a Nelder-Mead optimizer with default coefficients.
     48      * <p>The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
     49      * for both gamma and sigma.</p>
     50      */
     51     public NelderMead() {
     52         this.rho   = 1.0;
     53         this.khi   = 2.0;
     54         this.gamma = 0.5;
     55         this.sigma = 0.5;
     56     }
     57 
     58     /** Build a Nelder-Mead optimizer with specified coefficients.
     59      * @param rho reflection coefficient
     60      * @param khi expansion coefficient
     61      * @param gamma contraction coefficient
     62      * @param sigma shrinkage coefficient
     63      */
     64     public NelderMead(final double rho, final double khi,
     65                       final double gamma, final double sigma) {
     66         this.rho   = rho;
     67         this.khi   = khi;
     68         this.gamma = gamma;
     69         this.sigma = sigma;
     70     }
     71 
     72     /** {@inheritDoc} */
     73     @Override
     74     protected void iterateSimplex(final Comparator<RealPointValuePair> comparator)
     75         throws FunctionEvaluationException, OptimizationException {
     76 
     77         incrementIterationsCounter();
     78 
     79         // the simplex has n+1 point if dimension is n
     80         final int n = simplex.length - 1;
     81 
     82         // interesting values
     83         final RealPointValuePair best       = simplex[0];
     84         final RealPointValuePair secondBest = simplex[n-1];
     85         final RealPointValuePair worst      = simplex[n];
     86         final double[] xWorst = worst.getPointRef();
     87 
     88         // compute the centroid of the best vertices
     89         // (dismissing the worst point at index n)
     90         final double[] centroid = new double[n];
     91         for (int i = 0; i < n; ++i) {
     92             final double[] x = simplex[i].getPointRef();
     93             for (int j = 0; j < n; ++j) {
     94                 centroid[j] += x[j];
     95             }
     96         }
     97         final double scaling = 1.0 / n;
     98         for (int j = 0; j < n; ++j) {
     99             centroid[j] *= scaling;
    100         }
    101 
    102         // compute the reflection point
    103         final double[] xR = new double[n];
    104         for (int j = 0; j < n; ++j) {
    105             xR[j] = centroid[j] + rho * (centroid[j] - xWorst[j]);
    106         }
    107         final RealPointValuePair reflected = new RealPointValuePair(xR, evaluate(xR), false);
    108 
    109         if ((comparator.compare(best, reflected) <= 0) &&
    110             (comparator.compare(reflected, secondBest) < 0)) {
    111 
    112             // accept the reflected point
    113             replaceWorstPoint(reflected, comparator);
    114 
    115         } else if (comparator.compare(reflected, best) < 0) {
    116 
    117             // compute the expansion point
    118             final double[] xE = new double[n];
    119             for (int j = 0; j < n; ++j) {
    120                 xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
    121             }
    122             final RealPointValuePair expanded = new RealPointValuePair(xE, evaluate(xE), false);
    123 
    124             if (comparator.compare(expanded, reflected) < 0) {
    125                 // accept the expansion point
    126                 replaceWorstPoint(expanded, comparator);
    127             } else {
    128                 // accept the reflected point
    129                 replaceWorstPoint(reflected, comparator);
    130             }
    131 
    132         } else {
    133 
    134             if (comparator.compare(reflected, worst) < 0) {
    135 
    136                 // perform an outside contraction
    137                 final double[] xC = new double[n];
    138                 for (int j = 0; j < n; ++j) {
    139                     xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
    140                 }
    141                 final RealPointValuePair outContracted = new RealPointValuePair(xC, evaluate(xC), false);
    142 
    143                 if (comparator.compare(outContracted, reflected) <= 0) {
    144                     // accept the contraction point
    145                     replaceWorstPoint(outContracted, comparator);
    146                     return;
    147                 }
    148 
    149             } else {
    150 
    151                 // perform an inside contraction
    152                 final double[] xC = new double[n];
    153                 for (int j = 0; j < n; ++j) {
    154                     xC[j] = centroid[j] - gamma * (centroid[j] - xWorst[j]);
    155                 }
    156                 final RealPointValuePair inContracted = new RealPointValuePair(xC, evaluate(xC), false);
    157 
    158                 if (comparator.compare(inContracted, worst) < 0) {
    159                     // accept the contraction point
    160                     replaceWorstPoint(inContracted, comparator);
    161                     return;
    162                 }
    163 
    164             }
    165 
    166             // perform a shrink
    167             final double[] xSmallest = simplex[0].getPointRef();
    168             for (int i = 1; i < simplex.length; ++i) {
    169                 final double[] x = simplex[i].getPoint();
    170                 for (int j = 0; j < n; ++j) {
    171                     x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
    172                 }
    173                 simplex[i] = new RealPointValuePair(x, Double.NaN, false);
    174             }
    175             evaluateSimplex(comparator);
    176 
    177         }
    178 
    179     }
    180 
    181 }
    182