Home | History | Annotate | Download | only in big
      1 // Copyright 2010 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // This file implements multi-precision rational numbers.
      6 
      7 package big
      8 
      9 import (
     10 	"encoding/binary"
     11 	"errors"
     12 	"fmt"
     13 	"math"
     14 )
     15 
     16 // A Rat represents a quotient a/b of arbitrary precision.
     17 // The zero value for a Rat represents the value 0.
     18 type Rat struct {
     19 	// To make zero values for Rat work w/o initialization,
     20 	// a zero value of b (len(b) == 0) acts like b == 1.
     21 	// a.neg determines the sign of the Rat, b.neg is ignored.
     22 	a, b Int
     23 }
     24 
     25 // NewRat creates a new Rat with numerator a and denominator b.
     26 func NewRat(a, b int64) *Rat {
     27 	return new(Rat).SetFrac64(a, b)
     28 }
     29 
     30 // SetFloat64 sets z to exactly f and returns z.
     31 // If f is not finite, SetFloat returns nil.
     32 func (z *Rat) SetFloat64(f float64) *Rat {
     33 	const expMask = 1<<11 - 1
     34 	bits := math.Float64bits(f)
     35 	mantissa := bits & (1<<52 - 1)
     36 	exp := int((bits >> 52) & expMask)
     37 	switch exp {
     38 	case expMask: // non-finite
     39 		return nil
     40 	case 0: // denormal
     41 		exp -= 1022
     42 	default: // normal
     43 		mantissa |= 1 << 52
     44 		exp -= 1023
     45 	}
     46 
     47 	shift := 52 - exp
     48 
     49 	// Optimization (?): partially pre-normalise.
     50 	for mantissa&1 == 0 && shift > 0 {
     51 		mantissa >>= 1
     52 		shift--
     53 	}
     54 
     55 	z.a.SetUint64(mantissa)
     56 	z.a.neg = f < 0
     57 	z.b.Set(intOne)
     58 	if shift > 0 {
     59 		z.b.Lsh(&z.b, uint(shift))
     60 	} else {
     61 		z.a.Lsh(&z.a, uint(-shift))
     62 	}
     63 	return z.norm()
     64 }
     65 
     66 // quotToFloat32 returns the non-negative float32 value
     67 // nearest to the quotient a/b, using round-to-even in
     68 // halfway cases.  It does not mutate its arguments.
     69 // Preconditions: b is non-zero; a and b have no common factors.
     70 func quotToFloat32(a, b nat) (f float32, exact bool) {
     71 	const (
     72 		// float size in bits
     73 		Fsize = 32
     74 
     75 		// mantissa
     76 		Msize  = 23
     77 		Msize1 = Msize + 1 // incl. implicit 1
     78 		Msize2 = Msize1 + 1
     79 
     80 		// exponent
     81 		Esize = Fsize - Msize1
     82 		Ebias = 1<<(Esize-1) - 1
     83 		Emin  = 1 - Ebias
     84 		Emax  = Ebias
     85 	)
     86 
     87 	// TODO(adonovan): specialize common degenerate cases: 1.0, integers.
     88 	alen := a.bitLen()
     89 	if alen == 0 {
     90 		return 0, true
     91 	}
     92 	blen := b.bitLen()
     93 	if blen == 0 {
     94 		panic("division by zero")
     95 	}
     96 
     97 	// 1. Left-shift A or B such that quotient A/B is in [1<<Msize1, 1<<(Msize2+1)
     98 	// (Msize2 bits if A < B when they are left-aligned, Msize2+1 bits if A >= B).
     99 	// This is 2 or 3 more than the float32 mantissa field width of Msize:
    100 	// - the optional extra bit is shifted away in step 3 below.
    101 	// - the high-order 1 is omitted in "normal" representation;
    102 	// - the low-order 1 will be used during rounding then discarded.
    103 	exp := alen - blen
    104 	var a2, b2 nat
    105 	a2 = a2.set(a)
    106 	b2 = b2.set(b)
    107 	if shift := Msize2 - exp; shift > 0 {
    108 		a2 = a2.shl(a2, uint(shift))
    109 	} else if shift < 0 {
    110 		b2 = b2.shl(b2, uint(-shift))
    111 	}
    112 
    113 	// 2. Compute quotient and remainder (q, r).  NB: due to the
    114 	// extra shift, the low-order bit of q is logically the
    115 	// high-order bit of r.
    116 	var q nat
    117 	q, r := q.div(a2, a2, b2) // (recycle a2)
    118 	mantissa := low32(q)
    119 	haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
    120 
    121 	// 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1
    122 	// (in effect---we accomplish this incrementally).
    123 	if mantissa>>Msize2 == 1 {
    124 		if mantissa&1 == 1 {
    125 			haveRem = true
    126 		}
    127 		mantissa >>= 1
    128 		exp++
    129 	}
    130 	if mantissa>>Msize1 != 1 {
    131 		panic(fmt.Sprintf("expected exactly %d bits of result", Msize2))
    132 	}
    133 
    134 	// 4. Rounding.
    135 	if Emin-Msize <= exp && exp <= Emin {
    136 		// Denormal case; lose 'shift' bits of precision.
    137 		shift := uint(Emin - (exp - 1)) // [1..Esize1)
    138 		lostbits := mantissa & (1<<shift - 1)
    139 		haveRem = haveRem || lostbits != 0
    140 		mantissa >>= shift
    141 		exp = 2 - Ebias // == exp + shift
    142 	}
    143 	// Round q using round-half-to-even.
    144 	exact = !haveRem
    145 	if mantissa&1 != 0 {
    146 		exact = false
    147 		if haveRem || mantissa&2 != 0 {
    148 			if mantissa++; mantissa >= 1<<Msize2 {
    149 				// Complete rollover 11...1 => 100...0, so shift is safe
    150 				mantissa >>= 1
    151 				exp++
    152 			}
    153 		}
    154 	}
    155 	mantissa >>= 1 // discard rounding bit.  Mantissa now scaled by 1<<Msize1.
    156 
    157 	f = float32(math.Ldexp(float64(mantissa), exp-Msize1))
    158 	if math.IsInf(float64(f), 0) {
    159 		exact = false
    160 	}
    161 	return
    162 }
    163 
    164 // quotToFloat64 returns the non-negative float64 value
    165 // nearest to the quotient a/b, using round-to-even in
    166 // halfway cases.  It does not mutate its arguments.
    167 // Preconditions: b is non-zero; a and b have no common factors.
    168 func quotToFloat64(a, b nat) (f float64, exact bool) {
    169 	const (
    170 		// float size in bits
    171 		Fsize = 64
    172 
    173 		// mantissa
    174 		Msize  = 52
    175 		Msize1 = Msize + 1 // incl. implicit 1
    176 		Msize2 = Msize1 + 1
    177 
    178 		// exponent
    179 		Esize = Fsize - Msize1
    180 		Ebias = 1<<(Esize-1) - 1
    181 		Emin  = 1 - Ebias
    182 		Emax  = Ebias
    183 	)
    184 
    185 	// TODO(adonovan): specialize common degenerate cases: 1.0, integers.
    186 	alen := a.bitLen()
    187 	if alen == 0 {
    188 		return 0, true
    189 	}
    190 	blen := b.bitLen()
    191 	if blen == 0 {
    192 		panic("division by zero")
    193 	}
    194 
    195 	// 1. Left-shift A or B such that quotient A/B is in [1<<Msize1, 1<<(Msize2+1)
    196 	// (Msize2 bits if A < B when they are left-aligned, Msize2+1 bits if A >= B).
    197 	// This is 2 or 3 more than the float64 mantissa field width of Msize:
    198 	// - the optional extra bit is shifted away in step 3 below.
    199 	// - the high-order 1 is omitted in "normal" representation;
    200 	// - the low-order 1 will be used during rounding then discarded.
    201 	exp := alen - blen
    202 	var a2, b2 nat
    203 	a2 = a2.set(a)
    204 	b2 = b2.set(b)
    205 	if shift := Msize2 - exp; shift > 0 {
    206 		a2 = a2.shl(a2, uint(shift))
    207 	} else if shift < 0 {
    208 		b2 = b2.shl(b2, uint(-shift))
    209 	}
    210 
    211 	// 2. Compute quotient and remainder (q, r).  NB: due to the
    212 	// extra shift, the low-order bit of q is logically the
    213 	// high-order bit of r.
    214 	var q nat
    215 	q, r := q.div(a2, a2, b2) // (recycle a2)
    216 	mantissa := low64(q)
    217 	haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
    218 
    219 	// 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1
    220 	// (in effect---we accomplish this incrementally).
    221 	if mantissa>>Msize2 == 1 {
    222 		if mantissa&1 == 1 {
    223 			haveRem = true
    224 		}
    225 		mantissa >>= 1
    226 		exp++
    227 	}
    228 	if mantissa>>Msize1 != 1 {
    229 		panic(fmt.Sprintf("expected exactly %d bits of result", Msize2))
    230 	}
    231 
    232 	// 4. Rounding.
    233 	if Emin-Msize <= exp && exp <= Emin {
    234 		// Denormal case; lose 'shift' bits of precision.
    235 		shift := uint(Emin - (exp - 1)) // [1..Esize1)
    236 		lostbits := mantissa & (1<<shift - 1)
    237 		haveRem = haveRem || lostbits != 0
    238 		mantissa >>= shift
    239 		exp = 2 - Ebias // == exp + shift
    240 	}
    241 	// Round q using round-half-to-even.
    242 	exact = !haveRem
    243 	if mantissa&1 != 0 {
    244 		exact = false
    245 		if haveRem || mantissa&2 != 0 {
    246 			if mantissa++; mantissa >= 1<<Msize2 {
    247 				// Complete rollover 11...1 => 100...0, so shift is safe
    248 				mantissa >>= 1
    249 				exp++
    250 			}
    251 		}
    252 	}
    253 	mantissa >>= 1 // discard rounding bit.  Mantissa now scaled by 1<<Msize1.
    254 
    255 	f = math.Ldexp(float64(mantissa), exp-Msize1)
    256 	if math.IsInf(f, 0) {
    257 		exact = false
    258 	}
    259 	return
    260 }
    261 
    262 // Float32 returns the nearest float32 value for x and a bool indicating
    263 // whether f represents x exactly. If the magnitude of x is too large to
    264 // be represented by a float32, f is an infinity and exact is false.
    265 // The sign of f always matches the sign of x, even if f == 0.
    266 func (x *Rat) Float32() (f float32, exact bool) {
    267 	b := x.b.abs
    268 	if len(b) == 0 {
    269 		b = b.set(natOne) // materialize denominator
    270 	}
    271 	f, exact = quotToFloat32(x.a.abs, b)
    272 	if x.a.neg {
    273 		f = -f
    274 	}
    275 	return
    276 }
    277 
    278 // Float64 returns the nearest float64 value for x and a bool indicating
    279 // whether f represents x exactly. If the magnitude of x is too large to
    280 // be represented by a float64, f is an infinity and exact is false.
    281 // The sign of f always matches the sign of x, even if f == 0.
    282 func (x *Rat) Float64() (f float64, exact bool) {
    283 	b := x.b.abs
    284 	if len(b) == 0 {
    285 		b = b.set(natOne) // materialize denominator
    286 	}
    287 	f, exact = quotToFloat64(x.a.abs, b)
    288 	if x.a.neg {
    289 		f = -f
    290 	}
    291 	return
    292 }
    293 
    294 // SetFrac sets z to a/b and returns z.
    295 func (z *Rat) SetFrac(a, b *Int) *Rat {
    296 	z.a.neg = a.neg != b.neg
    297 	babs := b.abs
    298 	if len(babs) == 0 {
    299 		panic("division by zero")
    300 	}
    301 	if &z.a == b || alias(z.a.abs, babs) {
    302 		babs = nat(nil).set(babs) // make a copy
    303 	}
    304 	z.a.abs = z.a.abs.set(a.abs)
    305 	z.b.abs = z.b.abs.set(babs)
    306 	return z.norm()
    307 }
    308 
    309 // SetFrac64 sets z to a/b and returns z.
    310 func (z *Rat) SetFrac64(a, b int64) *Rat {
    311 	z.a.SetInt64(a)
    312 	if b == 0 {
    313 		panic("division by zero")
    314 	}
    315 	if b < 0 {
    316 		b = -b
    317 		z.a.neg = !z.a.neg
    318 	}
    319 	z.b.abs = z.b.abs.setUint64(uint64(b))
    320 	return z.norm()
    321 }
    322 
    323 // SetInt sets z to x (by making a copy of x) and returns z.
    324 func (z *Rat) SetInt(x *Int) *Rat {
    325 	z.a.Set(x)
    326 	z.b.abs = z.b.abs[:0]
    327 	return z
    328 }
    329 
    330 // SetInt64 sets z to x and returns z.
    331 func (z *Rat) SetInt64(x int64) *Rat {
    332 	z.a.SetInt64(x)
    333 	z.b.abs = z.b.abs[:0]
    334 	return z
    335 }
    336 
    337 // Set sets z to x (by making a copy of x) and returns z.
    338 func (z *Rat) Set(x *Rat) *Rat {
    339 	if z != x {
    340 		z.a.Set(&x.a)
    341 		z.b.Set(&x.b)
    342 	}
    343 	return z
    344 }
    345 
    346 // Abs sets z to |x| (the absolute value of x) and returns z.
    347 func (z *Rat) Abs(x *Rat) *Rat {
    348 	z.Set(x)
    349 	z.a.neg = false
    350 	return z
    351 }
    352 
    353 // Neg sets z to -x and returns z.
    354 func (z *Rat) Neg(x *Rat) *Rat {
    355 	z.Set(x)
    356 	z.a.neg = len(z.a.abs) > 0 && !z.a.neg // 0 has no sign
    357 	return z
    358 }
    359 
    360 // Inv sets z to 1/x and returns z.
    361 func (z *Rat) Inv(x *Rat) *Rat {
    362 	if len(x.a.abs) == 0 {
    363 		panic("division by zero")
    364 	}
    365 	z.Set(x)
    366 	a := z.b.abs
    367 	if len(a) == 0 {
    368 		a = a.set(natOne) // materialize numerator
    369 	}
    370 	b := z.a.abs
    371 	if b.cmp(natOne) == 0 {
    372 		b = b[:0] // normalize denominator
    373 	}
    374 	z.a.abs, z.b.abs = a, b // sign doesn't change
    375 	return z
    376 }
    377 
    378 // Sign returns:
    379 //
    380 //	-1 if x <  0
    381 //	 0 if x == 0
    382 //	+1 if x >  0
    383 //
    384 func (x *Rat) Sign() int {
    385 	return x.a.Sign()
    386 }
    387 
    388 // IsInt reports whether the denominator of x is 1.
    389 func (x *Rat) IsInt() bool {
    390 	return len(x.b.abs) == 0 || x.b.abs.cmp(natOne) == 0
    391 }
    392 
    393 // Num returns the numerator of x; it may be <= 0.
    394 // The result is a reference to x's numerator; it
    395 // may change if a new value is assigned to x, and vice versa.
    396 // The sign of the numerator corresponds to the sign of x.
    397 func (x *Rat) Num() *Int {
    398 	return &x.a
    399 }
    400 
    401 // Denom returns the denominator of x; it is always > 0.
    402 // The result is a reference to x's denominator; it
    403 // may change if a new value is assigned to x, and vice versa.
    404 func (x *Rat) Denom() *Int {
    405 	x.b.neg = false // the result is always >= 0
    406 	if len(x.b.abs) == 0 {
    407 		x.b.abs = x.b.abs.set(natOne) // materialize denominator
    408 	}
    409 	return &x.b
    410 }
    411 
    412 func (z *Rat) norm() *Rat {
    413 	switch {
    414 	case len(z.a.abs) == 0:
    415 		// z == 0 - normalize sign and denominator
    416 		z.a.neg = false
    417 		z.b.abs = z.b.abs[:0]
    418 	case len(z.b.abs) == 0:
    419 		// z is normalized int - nothing to do
    420 	case z.b.abs.cmp(natOne) == 0:
    421 		// z is int - normalize denominator
    422 		z.b.abs = z.b.abs[:0]
    423 	default:
    424 		neg := z.a.neg
    425 		z.a.neg = false
    426 		z.b.neg = false
    427 		if f := NewInt(0).binaryGCD(&z.a, &z.b); f.Cmp(intOne) != 0 {
    428 			z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f.abs)
    429 			z.b.abs, _ = z.b.abs.div(nil, z.b.abs, f.abs)
    430 			if z.b.abs.cmp(natOne) == 0 {
    431 				// z is int - normalize denominator
    432 				z.b.abs = z.b.abs[:0]
    433 			}
    434 		}
    435 		z.a.neg = neg
    436 	}
    437 	return z
    438 }
    439 
    440 // mulDenom sets z to the denominator product x*y (by taking into
    441 // account that 0 values for x or y must be interpreted as 1) and
    442 // returns z.
    443 func mulDenom(z, x, y nat) nat {
    444 	switch {
    445 	case len(x) == 0:
    446 		return z.set(y)
    447 	case len(y) == 0:
    448 		return z.set(x)
    449 	}
    450 	return z.mul(x, y)
    451 }
    452 
    453 // scaleDenom computes x*f.
    454 // If f == 0 (zero value of denominator), the result is (a copy of) x.
    455 func scaleDenom(x *Int, f nat) *Int {
    456 	var z Int
    457 	if len(f) == 0 {
    458 		return z.Set(x)
    459 	}
    460 	z.abs = z.abs.mul(x.abs, f)
    461 	z.neg = x.neg
    462 	return &z
    463 }
    464 
    465 // Cmp compares x and y and returns:
    466 //
    467 //   -1 if x <  y
    468 //    0 if x == y
    469 //   +1 if x >  y
    470 //
    471 func (x *Rat) Cmp(y *Rat) int {
    472 	return scaleDenom(&x.a, y.b.abs).Cmp(scaleDenom(&y.a, x.b.abs))
    473 }
    474 
    475 // Add sets z to the sum x+y and returns z.
    476 func (z *Rat) Add(x, y *Rat) *Rat {
    477 	a1 := scaleDenom(&x.a, y.b.abs)
    478 	a2 := scaleDenom(&y.a, x.b.abs)
    479 	z.a.Add(a1, a2)
    480 	z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
    481 	return z.norm()
    482 }
    483 
    484 // Sub sets z to the difference x-y and returns z.
    485 func (z *Rat) Sub(x, y *Rat) *Rat {
    486 	a1 := scaleDenom(&x.a, y.b.abs)
    487 	a2 := scaleDenom(&y.a, x.b.abs)
    488 	z.a.Sub(a1, a2)
    489 	z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
    490 	return z.norm()
    491 }
    492 
    493 // Mul sets z to the product x*y and returns z.
    494 func (z *Rat) Mul(x, y *Rat) *Rat {
    495 	z.a.Mul(&x.a, &y.a)
    496 	z.b.abs = mulDenom(z.b.abs, x.b.abs, y.b.abs)
    497 	return z.norm()
    498 }
    499 
    500 // Quo sets z to the quotient x/y and returns z.
    501 // If y == 0, a division-by-zero run-time panic occurs.
    502 func (z *Rat) Quo(x, y *Rat) *Rat {
    503 	if len(y.a.abs) == 0 {
    504 		panic("division by zero")
    505 	}
    506 	a := scaleDenom(&x.a, y.b.abs)
    507 	b := scaleDenom(&y.a, x.b.abs)
    508 	z.a.abs = a.abs
    509 	z.b.abs = b.abs
    510 	z.a.neg = a.neg != b.neg
    511 	return z.norm()
    512 }
    513 
    514 // Gob codec version. Permits backward-compatible changes to the encoding.
    515 const ratGobVersion byte = 1
    516 
    517 // GobEncode implements the gob.GobEncoder interface.
    518 func (x *Rat) GobEncode() ([]byte, error) {
    519 	if x == nil {
    520 		return nil, nil
    521 	}
    522 	buf := make([]byte, 1+4+(len(x.a.abs)+len(x.b.abs))*_S) // extra bytes for version and sign bit (1), and numerator length (4)
    523 	i := x.b.abs.bytes(buf)
    524 	j := x.a.abs.bytes(buf[:i])
    525 	n := i - j
    526 	if int(uint32(n)) != n {
    527 		// this should never happen
    528 		return nil, errors.New("Rat.GobEncode: numerator too large")
    529 	}
    530 	binary.BigEndian.PutUint32(buf[j-4:j], uint32(n))
    531 	j -= 1 + 4
    532 	b := ratGobVersion << 1 // make space for sign bit
    533 	if x.a.neg {
    534 		b |= 1
    535 	}
    536 	buf[j] = b
    537 	return buf[j:], nil
    538 }
    539 
    540 // GobDecode implements the gob.GobDecoder interface.
    541 func (z *Rat) GobDecode(buf []byte) error {
    542 	if len(buf) == 0 {
    543 		// Other side sent a nil or default value.
    544 		*z = Rat{}
    545 		return nil
    546 	}
    547 	b := buf[0]
    548 	if b>>1 != ratGobVersion {
    549 		return fmt.Errorf("Rat.GobDecode: encoding version %d not supported", b>>1)
    550 	}
    551 	const j = 1 + 4
    552 	i := j + binary.BigEndian.Uint32(buf[j-4:j])
    553 	z.a.neg = b&1 != 0
    554 	z.a.abs = z.a.abs.setBytes(buf[j:i])
    555 	z.b.abs = z.b.abs.setBytes(buf[i:])
    556 	return nil
    557 }
    558 
    559 // MarshalText implements the encoding.TextMarshaler interface.
    560 func (r *Rat) MarshalText() (text []byte, err error) {
    561 	return []byte(r.RatString()), nil
    562 }
    563 
    564 // UnmarshalText implements the encoding.TextUnmarshaler interface.
    565 func (r *Rat) UnmarshalText(text []byte) error {
    566 	if _, ok := r.SetString(string(text)); !ok {
    567 		return fmt.Errorf("math/big: cannot unmarshal %q into a *big.Rat", text)
    568 	}
    569 	return nil
    570 }
    571