Home | History | Annotate | Download | only in math
      1 /*
      2  * Copyright (C) 2011 The Guava Authors
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  * http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package com.google.common.math;
     18 
     19 import static com.google.common.base.Preconditions.checkArgument;
     20 import static com.google.common.base.Preconditions.checkNotNull;
     21 import static com.google.common.math.MathPreconditions.checkNonNegative;
     22 import static com.google.common.math.MathPreconditions.checkPositive;
     23 import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
     24 import static java.math.RoundingMode.CEILING;
     25 import static java.math.RoundingMode.FLOOR;
     26 import static java.math.RoundingMode.HALF_EVEN;
     27 
     28 import com.google.common.annotations.GwtCompatible;
     29 import com.google.common.annotations.GwtIncompatible;
     30 import com.google.common.annotations.VisibleForTesting;
     31 
     32 import java.math.BigDecimal;
     33 import java.math.BigInteger;
     34 import java.math.RoundingMode;
     35 import java.util.ArrayList;
     36 import java.util.List;
     37 
     38 /**
     39  * A class for arithmetic on values of type {@code BigInteger}.
     40  *
     41  * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
     42  * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
     43  *
     44  * <p>Similar functionality for {@code int} and for {@code long} can be found in
     45  * {@link IntMath} and {@link LongMath} respectively.
     46  *
     47  * @author Louis Wasserman
     48  * @since 11.0
     49  */
     50 @GwtCompatible(emulated = true)
     51 public final class BigIntegerMath {
     52   /**
     53    * Returns {@code true} if {@code x} represents a power of two.
     54    */
     55   public static boolean isPowerOfTwo(BigInteger x) {
     56     checkNotNull(x);
     57     return x.signum() > 0 && x.getLowestSetBit() == x.bitLength() - 1;
     58   }
     59 
     60   /**
     61    * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
     62    *
     63    * @throws IllegalArgumentException if {@code x <= 0}
     64    * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
     65    *         is not a power of two
     66    */
     67   @SuppressWarnings("fallthrough")
     68   // TODO(kevinb): remove after this warning is disabled globally
     69   public static int log2(BigInteger x, RoundingMode mode) {
     70     checkPositive("x", checkNotNull(x));
     71     int logFloor = x.bitLength() - 1;
     72     switch (mode) {
     73       case UNNECESSARY:
     74         checkRoundingUnnecessary(isPowerOfTwo(x)); // fall through
     75       case DOWN:
     76       case FLOOR:
     77         return logFloor;
     78 
     79       case UP:
     80       case CEILING:
     81         return isPowerOfTwo(x) ? logFloor : logFloor + 1;
     82 
     83       case HALF_DOWN:
     84       case HALF_UP:
     85       case HALF_EVEN:
     86         if (logFloor < SQRT2_PRECOMPUTE_THRESHOLD) {
     87           BigInteger halfPower = SQRT2_PRECOMPUTED_BITS.shiftRight(
     88               SQRT2_PRECOMPUTE_THRESHOLD - logFloor);
     89           if (x.compareTo(halfPower) <= 0) {
     90             return logFloor;
     91           } else {
     92             return logFloor + 1;
     93           }
     94         }
     95         /*
     96          * Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
     97          *
     98          * To determine which side of logFloor.5 the logarithm is, we compare x^2 to 2^(2 *
     99          * logFloor + 1).
    100          */
    101         BigInteger x2 = x.pow(2);
    102         int logX2Floor = x2.bitLength() - 1;
    103         return (logX2Floor < 2 * logFloor + 1) ? logFloor : logFloor + 1;
    104 
    105       default:
    106         throw new AssertionError();
    107     }
    108   }
    109 
    110   /*
    111    * The maximum number of bits in a square root for which we'll precompute an explicit half power
    112    * of two. This can be any value, but higher values incur more class load time and linearly
    113    * increasing memory consumption.
    114    */
    115   @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;
    116 
    117   @VisibleForTesting static final BigInteger SQRT2_PRECOMPUTED_BITS =
    118       new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);
    119 
    120   /**
    121    * Returns the base-10 logarithm of {@code x}, rounded according to the specified rounding mode.
    122    *
    123    * @throws IllegalArgumentException if {@code x <= 0}
    124    * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
    125    *         is not a power of ten
    126    */
    127   @GwtIncompatible("TODO")
    128   @SuppressWarnings("fallthrough")
    129   public static int log10(BigInteger x, RoundingMode mode) {
    130     checkPositive("x", x);
    131     if (fitsInLong(x)) {
    132       return LongMath.log10(x.longValue(), mode);
    133     }
    134 
    135     int approxLog10 = (int) (log2(x, FLOOR) * LN_2 / LN_10);
    136     BigInteger approxPow = BigInteger.TEN.pow(approxLog10);
    137     int approxCmp = approxPow.compareTo(x);
    138 
    139     /*
    140      * We adjust approxLog10 and approxPow until they're equal to floor(log10(x)) and
    141      * 10^floor(log10(x)).
    142      */
    143 
    144     if (approxCmp > 0) {
    145       /*
    146        * The code is written so that even completely incorrect approximations will still yield the
    147        * correct answer eventually, but in practice this branch should almost never be entered,
    148        * and even then the loop should not run more than once.
    149        */
    150       do {
    151         approxLog10--;
    152         approxPow = approxPow.divide(BigInteger.TEN);
    153         approxCmp = approxPow.compareTo(x);
    154       } while (approxCmp > 0);
    155     } else {
    156       BigInteger nextPow = BigInteger.TEN.multiply(approxPow);
    157       int nextCmp = nextPow.compareTo(x);
    158       while (nextCmp <= 0) {
    159         approxLog10++;
    160         approxPow = nextPow;
    161         approxCmp = nextCmp;
    162         nextPow = BigInteger.TEN.multiply(approxPow);
    163         nextCmp = nextPow.compareTo(x);
    164       }
    165     }
    166 
    167     int floorLog = approxLog10;
    168     BigInteger floorPow = approxPow;
    169     int floorCmp = approxCmp;
    170 
    171     switch (mode) {
    172       case UNNECESSARY:
    173         checkRoundingUnnecessary(floorCmp == 0);
    174         // fall through
    175       case FLOOR:
    176       case DOWN:
    177         return floorLog;
    178 
    179       case CEILING:
    180       case UP:
    181         return floorPow.equals(x) ? floorLog : floorLog + 1;
    182 
    183       case HALF_DOWN:
    184       case HALF_UP:
    185       case HALF_EVEN:
    186         // Since sqrt(10) is irrational, log10(x) - floorLog can never be exactly 0.5
    187         BigInteger x2 = x.pow(2);
    188         BigInteger halfPowerSquared = floorPow.pow(2).multiply(BigInteger.TEN);
    189         return (x2.compareTo(halfPowerSquared) <= 0) ? floorLog : floorLog + 1;
    190       default:
    191         throw new AssertionError();
    192     }
    193   }
    194 
    195   private static final double LN_10 = Math.log(10);
    196   private static final double LN_2 = Math.log(2);
    197 
    198   /**
    199    * Returns the square root of {@code x}, rounded with the specified rounding mode.
    200    *
    201    * @throws IllegalArgumentException if {@code x < 0}
    202    * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and
    203    *         {@code sqrt(x)} is not an integer
    204    */
    205   @GwtIncompatible("TODO")
    206   @SuppressWarnings("fallthrough")
    207   public static BigInteger sqrt(BigInteger x, RoundingMode mode) {
    208     checkNonNegative("x", x);
    209     if (fitsInLong(x)) {
    210       return BigInteger.valueOf(LongMath.sqrt(x.longValue(), mode));
    211     }
    212     BigInteger sqrtFloor = sqrtFloor(x);
    213     switch (mode) {
    214       case UNNECESSARY:
    215         checkRoundingUnnecessary(sqrtFloor.pow(2).equals(x)); // fall through
    216       case FLOOR:
    217       case DOWN:
    218         return sqrtFloor;
    219       case CEILING:
    220       case UP:
    221         int sqrtFloorInt = sqrtFloor.intValue();
    222         boolean sqrtFloorIsExact =
    223             (sqrtFloorInt * sqrtFloorInt == x.intValue()) // fast check mod 2^32
    224             && sqrtFloor.pow(2).equals(x); // slow exact check
    225         return sqrtFloorIsExact ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
    226       case HALF_DOWN:
    227       case HALF_UP:
    228       case HALF_EVEN:
    229         BigInteger halfSquare = sqrtFloor.pow(2).add(sqrtFloor);
    230         /*
    231          * We wish to test whether or not x <= (sqrtFloor + 0.5)^2 = halfSquare + 0.25. Since both
    232          * x and halfSquare are integers, this is equivalent to testing whether or not x <=
    233          * halfSquare.
    234          */
    235         return (halfSquare.compareTo(x) >= 0) ? sqrtFloor : sqrtFloor.add(BigInteger.ONE);
    236       default:
    237         throw new AssertionError();
    238     }
    239   }
    240 
    241   @GwtIncompatible("TODO")
    242   private static BigInteger sqrtFloor(BigInteger x) {
    243     /*
    244      * Adapted from Hacker's Delight, Figure 11-1.
    245      *
    246      * Using DoubleUtils.bigToDouble, getting a double approximation of x is extremely fast, and
    247      * then we can get a double approximation of the square root. Then, we iteratively improve this
    248      * guess with an application of Newton's method, which sets guess := (guess + (x / guess)) / 2.
    249      * This iteration has the following two properties:
    250      *
    251      * a) every iteration (except potentially the first) has guess >= floor(sqrt(x)). This is
    252      * because guess' is the arithmetic mean of guess and x / guess, sqrt(x) is the geometric mean,
    253      * and the arithmetic mean is always higher than the geometric mean.
    254      *
    255      * b) this iteration converges to floor(sqrt(x)). In fact, the number of correct digits doubles
    256      * with each iteration, so this algorithm takes O(log(digits)) iterations.
    257      *
    258      * We start out with a double-precision approximation, which may be higher or lower than the
    259      * true value. Therefore, we perform at least one Newton iteration to get a guess that's
    260      * definitely >= floor(sqrt(x)), and then continue the iteration until we reach a fixed point.
    261      */
    262     BigInteger sqrt0;
    263     int log2 = log2(x, FLOOR);
    264     if (log2 < Double.MAX_EXPONENT) {
    265       sqrt0 = sqrtApproxWithDoubles(x);
    266     } else {
    267       int shift = (log2 - DoubleUtils.SIGNIFICAND_BITS) & ~1; // even!
    268       /*
    269        * We have that x / 2^shift < 2^54. Our initial approximation to sqrtFloor(x) will be
    270        * 2^(shift/2) * sqrtApproxWithDoubles(x / 2^shift).
    271        */
    272       sqrt0 = sqrtApproxWithDoubles(x.shiftRight(shift)).shiftLeft(shift >> 1);
    273     }
    274     BigInteger sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
    275     if (sqrt0.equals(sqrt1)) {
    276       return sqrt0;
    277     }
    278     do {
    279       sqrt0 = sqrt1;
    280       sqrt1 = sqrt0.add(x.divide(sqrt0)).shiftRight(1);
    281     } while (sqrt1.compareTo(sqrt0) < 0);
    282     return sqrt0;
    283   }
    284 
    285   @GwtIncompatible("TODO")
    286   private static BigInteger sqrtApproxWithDoubles(BigInteger x) {
    287     return DoubleMath.roundToBigInteger(Math.sqrt(DoubleUtils.bigToDouble(x)), HALF_EVEN);
    288   }
    289 
    290   /**
    291    * Returns the result of dividing {@code p} by {@code q}, rounding using the specified
    292    * {@code RoundingMode}.
    293    *
    294    * @throws ArithmeticException if {@code q == 0}, or if {@code mode == UNNECESSARY} and {@code a}
    295    *         is not an integer multiple of {@code b}
    296    */
    297   @GwtIncompatible("TODO")
    298   public static BigInteger divide(BigInteger p, BigInteger q, RoundingMode mode) {
    299     BigDecimal pDec = new BigDecimal(p);
    300     BigDecimal qDec = new BigDecimal(q);
    301     return pDec.divide(qDec, 0, mode).toBigIntegerExact();
    302   }
    303 
    304   /**
    305    * Returns {@code n!}, that is, the product of the first {@code n} positive
    306    * integers, or {@code 1} if {@code n == 0}.
    307    *
    308    * <p><b>Warning:</b> the result takes <i>O(n log n)</i> space, so use cautiously.
    309    *
    310    * <p>This uses an efficient binary recursive algorithm to compute the factorial
    311    * with balanced multiplies.  It also removes all the 2s from the intermediate
    312    * products (shifting them back in at the end).
    313    *
    314    * @throws IllegalArgumentException if {@code n < 0}
    315    */
    316   public static BigInteger factorial(int n) {
    317     checkNonNegative("n", n);
    318 
    319     // If the factorial is small enough, just use LongMath to do it.
    320     if (n < LongMath.factorials.length) {
    321       return BigInteger.valueOf(LongMath.factorials[n]);
    322     }
    323 
    324     // Pre-allocate space for our list of intermediate BigIntegers.
    325     int approxSize = IntMath.divide(n * IntMath.log2(n, CEILING), Long.SIZE, CEILING);
    326     ArrayList<BigInteger> bignums = new ArrayList<BigInteger>(approxSize);
    327 
    328     // Start from the pre-computed maximum long factorial.
    329     int startingNumber = LongMath.factorials.length;
    330     long product = LongMath.factorials[startingNumber - 1];
    331     // Strip off 2s from this value.
    332     int shift = Long.numberOfTrailingZeros(product);
    333     product >>= shift;
    334 
    335     // Use floor(log2(num)) + 1 to prevent overflow of multiplication.
    336     int productBits = LongMath.log2(product, FLOOR) + 1;
    337     int bits = LongMath.log2(startingNumber, FLOOR) + 1;
    338     // Check for the next power of two boundary, to save us a CLZ operation.
    339     int nextPowerOfTwo = 1 << (bits - 1);
    340 
    341     // Iteratively multiply the longs as big as they can go.
    342     for (long num = startingNumber; num <= n; num++) {
    343       // Check to see if the floor(log2(num)) + 1 has changed.
    344       if ((num & nextPowerOfTwo) != 0) {
    345         nextPowerOfTwo <<= 1;
    346         bits++;
    347       }
    348       // Get rid of the 2s in num.
    349       int tz = Long.numberOfTrailingZeros(num);
    350       long normalizedNum = num >> tz;
    351       shift += tz;
    352       // Adjust floor(log2(num)) + 1.
    353       int normalizedBits = bits - tz;
    354       // If it won't fit in a long, then we store off the intermediate product.
    355       if (normalizedBits + productBits >= Long.SIZE) {
    356         bignums.add(BigInteger.valueOf(product));
    357         product = 1;
    358         productBits = 0;
    359       }
    360       product *= normalizedNum;
    361       productBits = LongMath.log2(product, FLOOR) + 1;
    362     }
    363     // Check for leftovers.
    364     if (product > 1) {
    365       bignums.add(BigInteger.valueOf(product));
    366     }
    367     // Efficiently multiply all the intermediate products together.
    368     return listProduct(bignums).shiftLeft(shift);
    369   }
    370 
    371   static BigInteger listProduct(List<BigInteger> nums) {
    372     return listProduct(nums, 0, nums.size());
    373   }
    374 
    375   static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
    376     switch (end - start) {
    377       case 0:
    378         return BigInteger.ONE;
    379       case 1:
    380         return nums.get(start);
    381       case 2:
    382         return nums.get(start).multiply(nums.get(start + 1));
    383       case 3:
    384         return nums.get(start).multiply(nums.get(start + 1)).multiply(nums.get(start + 2));
    385       default:
    386         // Otherwise, split the list in half and recursively do this.
    387         int m = (end + start) >>> 1;
    388         return listProduct(nums, start, m).multiply(listProduct(nums, m, end));
    389     }
    390   }
    391 
    392  /**
    393    * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
    394    * {@code k}, that is, {@code n! / (k! (n - k)!)}.
    395    *
    396    * <p><b>Warning:</b> the result can take as much as <i>O(k log n)</i> space.
    397    *
    398    * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
    399    */
    400   public static BigInteger binomial(int n, int k) {
    401     checkNonNegative("n", n);
    402     checkNonNegative("k", k);
    403     checkArgument(k <= n, "k (%s) > n (%s)", k, n);
    404     if (k > (n >> 1)) {
    405       k = n - k;
    406     }
    407     if (k < LongMath.biggestBinomials.length && n <= LongMath.biggestBinomials[k]) {
    408       return BigInteger.valueOf(LongMath.binomial(n, k));
    409     }
    410 
    411     BigInteger accum = BigInteger.ONE;
    412 
    413     long numeratorAccum = n;
    414     long denominatorAccum = 1;
    415 
    416     int bits = LongMath.log2(n, RoundingMode.CEILING);
    417 
    418     int numeratorBits = bits;
    419 
    420     for (int i = 1; i < k; i++) {
    421       int p = n - i;
    422       int q = i + 1;
    423 
    424       // log2(p) >= bits - 1, because p >= n/2
    425 
    426       if (numeratorBits + bits >= Long.SIZE - 1) {
    427         // The numerator is as big as it can get without risking overflow.
    428         // Multiply numeratorAccum / denominatorAccum into accum.
    429         accum = accum
    430             .multiply(BigInteger.valueOf(numeratorAccum))
    431             .divide(BigInteger.valueOf(denominatorAccum));
    432         numeratorAccum = p;
    433         denominatorAccum = q;
    434         numeratorBits = bits;
    435       } else {
    436         // We can definitely multiply into the long accumulators without overflowing them.
    437         numeratorAccum *= p;
    438         denominatorAccum *= q;
    439         numeratorBits += bits;
    440       }
    441     }
    442     return accum
    443         .multiply(BigInteger.valueOf(numeratorAccum))
    444         .divide(BigInteger.valueOf(denominatorAccum));
    445   }
    446 
    447   // Returns true if BigInteger.valueOf(x.longValue()).equals(x).
    448   @GwtIncompatible("TODO")
    449   static boolean fitsInLong(BigInteger x) {
    450     return x.bitLength() <= Long.SIZE - 1;
    451   }
    452 
    453   private BigIntegerMath() {}
    454 }
    455