Home | History | Annotate | Download | only in big
      1 // Copyright 2010 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // This file implements multi-precision rational numbers.
      6 
      7 package big
      8 
      9 import (
     10 	"fmt"
     11 	"math"
     12 )
     13 
     14 // A Rat represents a quotient a/b of arbitrary precision.
     15 // The zero value for a Rat represents the value 0.
     16 type Rat struct {
     17 	// To make zero values for Rat work w/o initialization,
     18 	// a zero value of b (len(b) == 0) acts like b == 1.
     19 	// a.neg determines the sign of the Rat, b.neg is ignored.
     20 	a, b Int
     21 }
     22 
     23 // NewRat creates a new Rat with numerator a and denominator b.
     24 func NewRat(a, b int64) *Rat {
     25 	return new(Rat).SetFrac64(a, b)
     26 }
     27 
     28 // SetFloat64 sets z to exactly f and returns z.
     29 // If f is not finite, SetFloat returns nil.
     30 func (z *Rat) SetFloat64(f float64) *Rat {
     31 	const expMask = 1<<11 - 1
     32 	bits := math.Float64bits(f)
     33 	mantissa := bits & (1<<52 - 1)
     34 	exp := int((bits >> 52) & expMask)
     35 	switch exp {
     36 	case expMask: // non-finite
     37 		return nil
     38 	case 0: // denormal
     39 		exp -= 1022
     40 	default: // normal
     41 		mantissa |= 1 << 52
     42 		exp -= 1023
     43 	}
     44 
     45 	shift := 52 - exp
     46 
     47 	// Optimization (?): partially pre-normalise.
     48 	for mantissa&1 == 0 && shift > 0 {
     49 		mantissa >>= 1
     50 		shift--
     51 	}
     52 
     53 	z.a.SetUint64(mantissa)
     54 	z.a.neg = f < 0
     55 	z.b.Set(intOne)
     56 	if shift > 0 {
     57 		z.b.Lsh(&z.b, uint(shift))
     58 	} else {
     59 		z.a.Lsh(&z.a, uint(-shift))
     60 	}
     61 	return z.norm()
     62 }
     63 
     64 // quotToFloat32 returns the non-negative float32 value
     65 // nearest to the quotient a/b, using round-to-even in
     66 // halfway cases. It does not mutate its arguments.
     67 // Preconditions: b is non-zero; a and b have no common factors.
     68 func quotToFloat32(a, b nat) (f float32, exact bool) {
     69 	const (
     70 		// float size in bits
     71 		Fsize = 32
     72 
     73 		// mantissa
     74 		Msize  = 23
     75 		Msize1 = Msize + 1 // incl. implicit 1
     76 		Msize2 = Msize1 + 1
     77 
     78 		// exponent
     79 		Esize = Fsize - Msize1
     80 		Ebias = 1<<(Esize-1) - 1
     81 		Emin  = 1 - Ebias
     82 		Emax  = Ebias
     83 	)
     84 
     85 	// TODO(adonovan): specialize common degenerate cases: 1.0, integers.
     86 	alen := a.bitLen()
     87 	if alen == 0 {
     88 		return 0, true
     89 	}
     90 	blen := b.bitLen()
     91 	if blen == 0 {
     92 		panic("division by zero")
     93 	}
     94 
     95 	// 1. Left-shift A or B such that quotient A/B is in [1<<Msize1, 1<<(Msize2+1)
     96 	// (Msize2 bits if A < B when they are left-aligned, Msize2+1 bits if A >= B).
     97 	// This is 2 or 3 more than the float32 mantissa field width of Msize:
     98 	// - the optional extra bit is shifted away in step 3 below.
     99 	// - the high-order 1 is omitted in "normal" representation;
    100 	// - the low-order 1 will be used during rounding then discarded.
    101 	exp := alen - blen
    102 	var a2, b2 nat
    103 	a2 = a2.set(a)
    104 	b2 = b2.set(b)
    105 	if shift := Msize2 - exp; shift > 0 {
    106 		a2 = a2.shl(a2, uint(shift))
    107 	} else if shift < 0 {
    108 		b2 = b2.shl(b2, uint(-shift))
    109 	}
    110 
    111 	// 2. Compute quotient and remainder (q, r).  NB: due to the
    112 	// extra shift, the low-order bit of q is logically the
    113 	// high-order bit of r.
    114 	var q nat
    115 	q, r := q.div(a2, a2, b2) // (recycle a2)
    116 	mantissa := low32(q)
    117 	haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
    118 
    119 	// 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1
    120 	// (in effect---we accomplish this incrementally).
    121 	if mantissa>>Msize2 == 1 {
    122 		if mantissa&1 == 1 {
    123 			haveRem = true
    124 		}
    125 		mantissa >>= 1
    126 		exp++
    127 	}
    128 	if mantissa>>Msize1 != 1 {
    129 		panic(fmt.Sprintf("expected exactly %d bits of result", Msize2))
    130 	}
    131 
    132 	// 4. Rounding.
    133 	if Emin-Msize <= exp && exp <= Emin {
    134 		// Denormal case; lose 'shift' bits of precision.
    135 		shift := uint(Emin - (exp - 1)) // [1..Esize1)
    136 		lostbits := mantissa & (1<<shift - 1)
    137 		haveRem = haveRem || lostbits != 0
    138 		mantissa >>= shift
    139 		exp = 2 - Ebias // == exp + shift
    140 	}
    141 	// Round q using round-half-to-even.
    142 	exact = !haveRem
    143 	if mantissa&1 != 0 {
    144 		exact = false
    145 		if haveRem || mantissa&2 != 0 {
    146 			if mantissa++; mantissa >= 1<<Msize2 {
    147 				// Complete rollover 11...1 => 100...0, so shift is safe
    148 				mantissa >>= 1
    149 				exp++
    150 			}
    151 		}
    152 	}
    153 	mantissa >>= 1 // discard rounding bit.  Mantissa now scaled by 1<<Msize1.
    154 
    155 	f = float32(math.Ldexp(float64(mantissa), exp-Msize1))
    156 	if math.IsInf(float64(f), 0) {
    157 		exact = false
    158 	}
    159 	return
    160 }
    161 
    162 // quotToFloat64 returns the non-negative float64 value
    163 // nearest to the quotient a/b, using round-to-even in
    164 // halfway cases. It does not mutate its arguments.
    165 // Preconditions: b is non-zero; a and b have no common factors.
    166 func quotToFloat64(a, b nat) (f float64, exact bool) {
    167 	const (
    168 		// float size in bits
    169 		Fsize = 64
    170 
    171 		// mantissa
    172 		Msize  = 52
    173 		Msize1 = Msize + 1 // incl. implicit 1
    174 		Msize2 = Msize1 + 1
    175 
    176 		// exponent
    177 		Esize = Fsize - Msize1
    178 		Ebias = 1<<(Esize-1) - 1
    179 		Emin  = 1 - Ebias
    180 		Emax  = Ebias
    181 	)
    182 
    183 	// TODO(adonovan): specialize common degenerate cases: 1.0, integers.
    184 	alen := a.bitLen()
    185 	if alen == 0 {
    186 		return 0, true
    187 	}
    188 	blen := b.bitLen()
    189 	if blen == 0 {
    190 		panic("division by zero")
    191 	}
    192 
    193 	// 1. Left-shift A or B such that quotient A/B is in [1<<Msize1, 1<<(Msize2+1)
    194 	// (Msize2 bits if A < B when they are left-aligned, Msize2+1 bits if A >= B).
    195 	// This is 2 or 3 more than the float64 mantissa field width of Msize:
    196 	// - the optional extra bit is shifted away in step 3 below.
    197 	// - the high-order 1 is omitted in "normal" representation;
    198 	// - the low-order 1 will be used during rounding then discarded.
    199 	exp := alen - blen
    200 	var a2, b2 nat
    201 	a2 = a2.set(a)
    202 	b2 = b2.set(b)
    203 	if shift := Msize2 - exp; shift > 0 {
    204 		a2 = a2.shl(a2, uint(shift))
    205 	} else if shift < 0 {
    206 		b2 = b2.shl(b2, uint(-shift))
    207 	}
    208 
    209 	// 2. Compute quotient and remainder (q, r).  NB: due to the
    210 	// extra shift, the low-order bit of q is logically the
    211 	// high-order bit of r.
    212 	var q nat
    213 	q, r := q.div(a2, a2, b2) // (recycle a2)
    214 	mantissa := low64(q)
    215 	haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
    216 
    217 	// 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1
    218 	// (in effect---we accomplish this incrementally).
    219 	if mantissa>>Msize2 == 1 {
    220 		if mantissa&1 == 1 {
    221 			haveRem = true
    222 		}
    223 		mantissa >>= 1
    224 		exp++
    225 	}
    226 	if mantissa>>Msize1 != 1 {
    227 		panic(fmt.Sprintf("expected exactly %d bits of result", Msize2))
    228 	}
    229 
    230 	// 4. Rounding.
    231 	if Emin-Msize <= exp && exp <= Emin {
    232 		// Denormal case; lose 'shift' bits of precision.
    233 		shift := uint(Emin - (exp - 1)) // [1..Esize1)
    234 		lostbits := mantissa & (1<<shift - 1)
    235 		haveRem = haveRem || lostbits != 0
    236 		mantissa >>= shift
    237 		exp = 2 - Ebias // == exp + shift
    238 	}
    239 	// Round q using round-half-to-even.
    240 	exact = !haveRem
    241 	if mantissa&1 != 0 {
    242 		exact = false
    243 		if haveRem || mantissa&2 != 0 {
    244 			if mantissa++; mantissa >= 1<<Msize2 {
    245 				// Complete rollover 11...1 => 100...0, so shift is safe
    246 				mantissa >>= 1
    247 				exp++
    248 			}
    249 		}
    250 	}
    251 	mantissa >>= 1 // discard rounding bit.  Mantissa now scaled by 1<<Msize1.
    252 
    253 	f = math.Ldexp(float64(mantissa), exp-Msize1)
    254 	if math.IsInf(f, 0) {
    255 		exact = false
    256 	}
    257 	return
    258 }
    259 
    260 // Float32 returns the nearest float32 value for x and a bool indicating
    261 // whether f represents x exactly. If the magnitude of x is too large to
    262 // be represented by a float32, f is an infinity and exact is false.
    263 // The sign of f always matches the sign of x, even if f == 0.
    264 func (x *Rat) Float32() (f float32, exact bool) {
    265 	b := x.b.abs
    266 	if len(b) == 0 {
    267 		b = b.set(natOne) // materialize denominator
    268 	}
    269 	f, exact = quotToFloat32(x.a.abs, b)
    270 	if x.a.neg {
    271 		f = -f
    272 	}
    273 	return
    274 }
    275 
    276 // Float64 returns the nearest float64 value for x and a bool indicating
    277 // whether f represents x exactly. If the magnitude of x is too large to
    278 // be represented by a float64, f is an infinity and exact is false.
    279 // The sign of f always matches the sign of x, even if f == 0.
    280 func (x *Rat) Float64() (f float64, exact bool) {
    281 	b := x.b.abs
    282 	if len(b) == 0 {
    283 		b = b.set(natOne) // materialize denominator
    284 	}
    285 	f, exact = quotToFloat64(x.a.abs, b)
    286 	if x.a.neg {
    287 		f = -f
    288 	}
    289 	return
    290 }
    291 
    292 // SetFrac sets z to a/b and returns z.
    293 func (z *Rat) SetFrac(a, b *Int) *Rat {
    294 	z.a.neg = a.neg != b.neg
    295 	babs := b.abs
    296 	if len(babs) == 0 {
    297 		panic("division by zero")
    298 	}
    299 	if &z.a == b || alias(z.a.abs, babs) {
    300 		babs = nat(nil).set(babs) // make a copy
    301 	}
    302 	z.a.abs = z.a.abs.set(a.abs)
    303 	z.b.abs = z.b.abs.set(babs)
    304 	return z.norm()
    305 }
    306 
    307 // SetFrac64 sets z to a/b and returns z.
    308 func (z *Rat) SetFrac64(a, b int64) *Rat {
    309 	z.a.SetInt64(a)
    310 	if b == 0 {
    311 		panic("division by zero")
    312 	}
    313 	if b < 0 {
    314 		b = -b
    315 		z.a.neg = !z.a.neg
    316 	}
    317 	z.b.abs = z.b.abs.setUint64(uint64(b))
    318 	return z.norm()
    319 }
    320 
    321 // SetInt sets z to x (by making a copy of x) and returns z.
    322 func (z *Rat) SetInt(x *Int) *Rat {
    323 	z.a.Set(x)
    324 	z.b.abs = z.b.abs[:0]
    325 	return z
    326 }
    327 
    328 // SetInt64 sets z to x and returns z.
    329 func (z *Rat) SetInt64(x int64) *Rat {
    330 	z.a.SetInt64(x)
    331 	z.b.abs = z.b.abs[:0]
    332 	return z
    333 }
    334 
    335 // Set sets z to x (by making a copy of x) and returns z.
    336 func (z *Rat) Set(x *Rat) *Rat {
    337 	if z != x {
    338 		z.a.Set(&x.a)
    339 		z.b.Set(&x.b)
    340 	}
    341 	return z
    342 }
    343 
    344 // Abs sets z to |x| (the absolute value of x) and returns z.
    345 func (z *Rat) Abs(x *Rat) *Rat {
    346 	z.Set(x)
    347 	z.a.neg = false
    348 	return z
    349 }
    350 
    351 // Neg sets z to -x and returns z.
    352 func (z *Rat) Neg(x *Rat) *Rat {
    353 	z.Set(x)
    354 	z.a.neg = len(z.a.abs) > 0 && !z.a.neg // 0 has no sign
    355 	return z
    356 }
    357 
    358 // Inv sets z to 1/x and returns z.
    359 func (z *Rat) Inv(x *Rat) *Rat {
    360 	if len(x.a.abs) == 0 {
    361 		panic("division by zero")
    362 	}
    363 	z.Set(x)
    364 	a := z.b.abs
    365 	if len(a) == 0 {
    366 		a = a.set(natOne) // materialize numerator
    367 	}
    368 	b := z.a.abs
    369 	if b.cmp(natOne) == 0 {
    370 		b = b[:0] // normalize denominator
    371 	}
    372 	z.a.abs, z.b.abs = a, b // sign doesn't change
    373 	return z
    374 }
    375 
    376 // Sign returns:
    377 //
    378 //	-1 if x <  0
    379 //	 0 if x == 0
    380 //	+1 if x >  0
    381 //
    382 func (x *Rat) Sign() int {
    383 	return x.a.Sign()
    384 }
    385 
    386 // IsInt reports whether the denominator of x is 1.
    387 func (x *Rat) IsInt() bool {
    388 	return len(x.b.abs) == 0 || x.b.abs.cmp(natOne) == 0
    389 }
    390 
    391 // Num returns the numerator of x; it may be <= 0.
    392 // The result is a reference to x's numerator; it
    393 // may change if a new value is assigned to x, and vice versa.
    394 // The sign of the numerator corresponds to the sign of x.
    395 func (x *Rat) Num() *Int {
    396 	return &x.a
    397 }
    398 
    399 // Denom returns the denominator of x; it is always > 0.
    400 // The result is a reference to x's denominator; it
    401 // may change if a new value is assigned to x, and vice versa.
    402 func (x *Rat) Denom() *Int {
    403 	x.b.neg = false // the result is always >= 0
    404 	if len(x.b.abs) == 0 {
    405 		x.b.abs = x.b.abs.set(natOne) // materialize denominator
    406 	}
    407 	return &x.b
    408 }
    409 
    410 func (z *Rat) norm() *Rat {
    411 	switch {
    412 	case len(z.a.abs) == 0:
    413 		// z == 0 - normalize sign and denominator
    414 		z.a.neg = false
    415 		z.b.abs = z.b.abs[:0]
    416 	case len(z.b.abs) == 0:
    417 		// z is normalized int - nothing to do
    418 	case z.b.abs.cmp(natOne) == 0:
    419 		// z is int - normalize denominator
    420 		z.b.abs = z.b.abs[:0]
    421 	default:
    422 		neg := z.a.neg
    423 		z.a.neg = false
    424 		z.b.neg = false
    425 		if f := NewInt(0).binaryGCD(&z.a, &z.b); f.Cmp(intOne) != 0 {
    426 			z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f.abs)
    427 			z.b.abs, _ = z.b.abs.div(nil, z.b.abs, f.abs)
    428 			if z.b.abs.cmp(natOne) == 0 {
    429 				// z is int - normalize denominator
    430 				z.b.abs = z.b.abs[:0]
    431 			}
    432 		}
    433 		z.a.neg = neg
    434 	}
    435 	return z
    436 }
    437 
    438 // mulDenom sets z to the denominator product x*y (by taking into
    439 // account that 0 values for x or y must be interpreted as 1) and
    440 // returns z.
    441 func mulDenom(z, x, y nat) nat {
    442 	switch {
    443 	case len(x) == 0:
    444 		return z.set(y)
    445 	case len(y) == 0:
    446 		return z.set(x)
    447 	}
    448 	return z.mul(x, y)
    449 }
    450 
    451 // scaleDenom computes x*f.
    452 // If f == 0 (zero value of denominator), the result is (a copy of) x.
    453 func scaleDenom(x *Int, f nat) *Int {
    454 	var z Int
    455 	if len(f) == 0 {
    456 		return z.Set(x)
    457 	}
    458 	z.abs = z.abs.mul(x.abs, f)
    459 	z.neg = x.neg
    460 	return &z
    461 }
    462 
    463 // Cmp compares x and y and returns:
    464 //
    465 //   -1 if x <  y
    466 //    0 if x == y
    467 //   +1 if x >  y
    468 //
    469 func (x *Rat) Cmp(y *Rat) int {
    470 	return scaleDenom(&x.a, y.b.abs).Cmp(scaleDenom(&y.a, x.b.abs))
    471 }
    472 
    473 // Add sets z to the sum x+y and returns z.
    474 func (z *Rat) Add(x, y *Rat) *Rat {
    475 	a1 := scaleDenom(&x.a, y.b.abs)
    476 	a2 := scaleDenom(&y.a, x.b.abs)
    477 	z.a.Add(a1, a2)
    478 	z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
    479 	return z.norm()
    480 }
    481 
    482 // Sub sets z to the difference x-y and returns z.
    483 func (z *Rat) Sub(x, y *Rat) *Rat {
    484 	a1 := scaleDenom(&x.a, y.b.abs)
    485 	a2 := scaleDenom(&y.a, x.b.abs)
    486 	z.a.Sub(a1, a2)
    487 	z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
    488 	return z.norm()
    489 }
    490 
    491 // Mul sets z to the product x*y and returns z.
    492 func (z *Rat) Mul(x, y *Rat) *Rat {
    493 	z.a.Mul(&x.a, &y.a)
    494 	z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
    495 	return z.norm()
    496 }
    497 
    498 // Quo sets z to the quotient x/y and returns z.
    499 // If y == 0, a division-by-zero run-time panic occurs.
    500 func (z *Rat) Quo(x, y *Rat) *Rat {
    501 	if len(y.a.abs) == 0 {
    502 		panic("division by zero")
    503 	}
    504 	a := scaleDenom(&x.a, y.b.abs)
    505 	b := scaleDenom(&y.a, x.b.abs)
    506 	z.a.abs = a.abs
    507 	z.b.abs = b.abs
    508 	z.a.neg = a.neg != b.neg
    509 	return z.norm()
    510 }
    511