Home | History | Annotate | Download | only in math
      1 /*
      2  * Copyright (C) 2013 The Guava Authors
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  * http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package com.google.common.math;
     18 
     19 import com.google.caliper.BeforeExperiment;
     20 import com.google.caliper.Benchmark;
     21 import com.google.caliper.Param;
     22 import com.google.caliper.api.SkipThisScenarioException;
     23 import com.google.common.primitives.Doubles;
     24 
     25 import java.util.Random;
     26 
     27 /**
     28  * Benchmarks for various algorithms for computing the mean and/or variance.
     29  *
     30  * @author Louis Wasserman
     31  */
     32 public class StatsBenchmark {
     33 
     34   enum MeanAlgorithm {
     35     SIMPLE {
     36       @Override
     37       double mean(double[] values) {
     38         double sum = 0.0;
     39         for (double value : values) {
     40           sum += value;
     41         }
     42         return sum / values.length;
     43       }
     44     },
     45     KAHAN {
     46       @Override
     47       double mean(double[] values) {
     48         double sum = 0.0;
     49         double c = 0.0;
     50         for (double value : values) {
     51           double y = value - c;
     52           double t = sum + y;
     53           c = (t - sum) - y;
     54           sum = t;
     55         }
     56         return sum / values.length;
     57       }
     58     },
     59     KNUTH {
     60       @Override
     61       double mean(double[] values) {
     62         double mean = values[0];
     63         for (int i = 1; i < values.length; i++) {
     64           mean = mean + (values[i] - mean) / (i + 1);
     65         }
     66         return mean;
     67       }
     68     };
     69 
     70     abstract double mean(double[] values);
     71   }
     72 
     73   static class MeanAndVariance {
     74     private final double mean;
     75     private final double variance;
     76 
     77     MeanAndVariance(double mean, double variance) {
     78       this.mean = mean;
     79       this.variance = variance;
     80     }
     81 
     82     @Override
     83     public int hashCode() {
     84       return Doubles.hashCode(mean) * 31 + Doubles.hashCode(variance);
     85     }
     86   }
     87 
     88   enum VarianceAlgorithm {
     89     DO_NOT_COMPUTE {
     90       @Override
     91       MeanAndVariance variance(double[] values, MeanAlgorithm meanAlgorithm) {
     92         return new MeanAndVariance(meanAlgorithm.mean(values), 0.0);
     93       }
     94     },
     95     SIMPLE {
     96       @Override
     97       MeanAndVariance variance(double[] values, MeanAlgorithm meanAlgorithm) {
     98         double mean = meanAlgorithm.mean(values);
     99         double sumOfSquaresOfDeltas = 0.0;
    100         for (double value : values) {
    101           double delta = value - mean;
    102           sumOfSquaresOfDeltas += delta * delta;
    103         }
    104         return new MeanAndVariance(mean, sumOfSquaresOfDeltas / values.length);
    105       }
    106     },
    107     KAHAN {
    108       @Override
    109       MeanAndVariance variance(double[] values, MeanAlgorithm meanAlgorithm) {
    110         double mean = meanAlgorithm.mean(values);
    111         double sumOfSquaresOfDeltas = 0.0;
    112         double c = 0.0;
    113         for (double value : values) {
    114           double delta = value - mean;
    115           double deltaSquared = delta * delta;
    116           double y = deltaSquared - c;
    117           double t = sumOfSquaresOfDeltas + deltaSquared;
    118           c = (t - sumOfSquaresOfDeltas) - y;
    119           sumOfSquaresOfDeltas = t;
    120         }
    121         return new MeanAndVariance(mean, sumOfSquaresOfDeltas / values.length);
    122       }
    123     },
    124     KNUTH {
    125       @Override
    126       MeanAndVariance variance(double[] values, MeanAlgorithm meanAlgorithm) {
    127         if (meanAlgorithm != MeanAlgorithm.KNUTH) {
    128           throw new SkipThisScenarioException();
    129         }
    130         double mean = values[0];
    131         double s = 0.0;
    132         for (int i = 1; i < values.length; i++) {
    133           double nextMean = mean + (values[i] - mean) / (i + 1);
    134           s += (values[i] - mean) * (values[i] - nextMean);
    135           mean = nextMean;
    136         }
    137         return new MeanAndVariance(mean, s / values.length);
    138       }
    139     };
    140 
    141     abstract MeanAndVariance variance(double[] values, MeanAlgorithm meanAlgorithm);
    142   }
    143 
    144   @Param({"100", "10000"})
    145   int n;
    146 
    147   @Param
    148   MeanAlgorithm meanAlgorithm;
    149   @Param
    150   VarianceAlgorithm varianceAlgorithm;
    151 
    152   private double[][] values = new double[0x100][];
    153 
    154   @BeforeExperiment
    155   void setUp() {
    156     Random rng = new Random();
    157     for (int i = 0; i < 0x100; i++) {
    158       values[i] = new double[n];
    159       for (int j = 0; j < n; j++) {
    160         values[i][j] = rng.nextDouble();
    161       }
    162     }
    163   }
    164 
    165   @Benchmark int meanAndVariance(int reps) {
    166     int tmp = 0;
    167     for (int i = 0; i < reps; i++) {
    168       tmp += varianceAlgorithm.variance(values[i & 0xFF], meanAlgorithm).hashCode();
    169     }
    170     return tmp;
    171   }
    172 }
    173