Home | History | Annotate | Download | only in big
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // Package big implements multi-precision arithmetic (big numbers).
      6 // The following numeric types are supported:
      7 //
      8 //   Int    signed integers
      9 //   Rat    rational numbers
     10 //   Float  floating-point numbers
     11 //
     12 // Methods are typically of the form:
     13 //
     14 //   func (z *T) Unary(x *T) *T        // z = op x
     15 //   func (z *T) Binary(x, y *T) *T    // z = x op y
     16 //   func (x *T) M() T1                // v = x.M()
     17 //
     18 // with T one of Int, Rat, or Float. For unary and binary operations, the
     19 // result is the receiver (usually named z in that case); if it is one of
     20 // the operands x or y it may be overwritten (and its memory reused).
     21 // To enable chaining of operations, the result is also returned. Methods
     22 // returning a result other than *Int, *Rat, or *Float take an operand as
     23 // the receiver (usually named x in that case).
     24 //
     25 package big
     26 
     27 // This file contains operations on unsigned multi-precision integers.
     28 // These are the building blocks for the operations on signed integers
     29 // and rationals.
     30 
     31 import "math/rand"
     32 
     33 // An unsigned integer x of the form
     34 //
     35 //   x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
     36 //
     37 // with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
     38 // with the digits x[i] as the slice elements.
     39 //
     40 // A number is normalized if the slice contains no leading 0 digits.
     41 // During arithmetic operations, denormalized values may occur but are
     42 // always normalized before returning the final result. The normalized
     43 // representation of 0 is the empty or nil slice (length = 0).
     44 //
     45 type nat []Word
     46 
     47 var (
     48 	natOne = nat{1}
     49 	natTwo = nat{2}
     50 	natTen = nat{10}
     51 )
     52 
     53 func (z nat) clear() {
     54 	for i := range z {
     55 		z[i] = 0
     56 	}
     57 }
     58 
     59 func (z nat) norm() nat {
     60 	i := len(z)
     61 	for i > 0 && z[i-1] == 0 {
     62 		i--
     63 	}
     64 	return z[0:i]
     65 }
     66 
     67 func (z nat) make(n int) nat {
     68 	if n <= cap(z) {
     69 		return z[:n] // reuse z
     70 	}
     71 	// Choosing a good value for e has significant performance impact
     72 	// because it increases the chance that a value can be reused.
     73 	const e = 4 // extra capacity
     74 	return make(nat, n, n+e)
     75 }
     76 
     77 func (z nat) setWord(x Word) nat {
     78 	if x == 0 {
     79 		return z[:0]
     80 	}
     81 	z = z.make(1)
     82 	z[0] = x
     83 	return z
     84 }
     85 
     86 func (z nat) setUint64(x uint64) nat {
     87 	// single-digit values
     88 	if w := Word(x); uint64(w) == x {
     89 		return z.setWord(w)
     90 	}
     91 
     92 	// compute number of words n required to represent x
     93 	n := 0
     94 	for t := x; t > 0; t >>= _W {
     95 		n++
     96 	}
     97 
     98 	// split x into n words
     99 	z = z.make(n)
    100 	for i := range z {
    101 		z[i] = Word(x & _M)
    102 		x >>= _W
    103 	}
    104 
    105 	return z
    106 }
    107 
    108 func (z nat) set(x nat) nat {
    109 	z = z.make(len(x))
    110 	copy(z, x)
    111 	return z
    112 }
    113 
    114 func (z nat) add(x, y nat) nat {
    115 	m := len(x)
    116 	n := len(y)
    117 
    118 	switch {
    119 	case m < n:
    120 		return z.add(y, x)
    121 	case m == 0:
    122 		// n == 0 because m >= n; result is 0
    123 		return z[:0]
    124 	case n == 0:
    125 		// result is x
    126 		return z.set(x)
    127 	}
    128 	// m > 0
    129 
    130 	z = z.make(m + 1)
    131 	c := addVV(z[0:n], x, y)
    132 	if m > n {
    133 		c = addVW(z[n:m], x[n:], c)
    134 	}
    135 	z[m] = c
    136 
    137 	return z.norm()
    138 }
    139 
    140 func (z nat) sub(x, y nat) nat {
    141 	m := len(x)
    142 	n := len(y)
    143 
    144 	switch {
    145 	case m < n:
    146 		panic("underflow")
    147 	case m == 0:
    148 		// n == 0 because m >= n; result is 0
    149 		return z[:0]
    150 	case n == 0:
    151 		// result is x
    152 		return z.set(x)
    153 	}
    154 	// m > 0
    155 
    156 	z = z.make(m)
    157 	c := subVV(z[0:n], x, y)
    158 	if m > n {
    159 		c = subVW(z[n:], x[n:], c)
    160 	}
    161 	if c != 0 {
    162 		panic("underflow")
    163 	}
    164 
    165 	return z.norm()
    166 }
    167 
    168 func (x nat) cmp(y nat) (r int) {
    169 	m := len(x)
    170 	n := len(y)
    171 	if m != n || m == 0 {
    172 		switch {
    173 		case m < n:
    174 			r = -1
    175 		case m > n:
    176 			r = 1
    177 		}
    178 		return
    179 	}
    180 
    181 	i := m - 1
    182 	for i > 0 && x[i] == y[i] {
    183 		i--
    184 	}
    185 
    186 	switch {
    187 	case x[i] < y[i]:
    188 		r = -1
    189 	case x[i] > y[i]:
    190 		r = 1
    191 	}
    192 	return
    193 }
    194 
    195 func (z nat) mulAddWW(x nat, y, r Word) nat {
    196 	m := len(x)
    197 	if m == 0 || y == 0 {
    198 		return z.setWord(r) // result is r
    199 	}
    200 	// m > 0
    201 
    202 	z = z.make(m + 1)
    203 	z[m] = mulAddVWW(z[0:m], x, y, r)
    204 
    205 	return z.norm()
    206 }
    207 
    208 // basicMul multiplies x and y and leaves the result in z.
    209 // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
    210 func basicMul(z, x, y nat) {
    211 	z[0 : len(x)+len(y)].clear() // initialize z
    212 	for i, d := range y {
    213 		if d != 0 {
    214 			z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
    215 		}
    216 	}
    217 }
    218 
    219 // montgomery computes x*y*2^(-n*_W) mod m,
    220 // assuming k = -1/m mod 2^_W.
    221 // z is used for storing the result which is returned;
    222 // z must not alias x, y or m.
    223 func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
    224 	var c1, c2 Word
    225 	z = z.make(n)
    226 	z.clear()
    227 	for i := 0; i < n; i++ {
    228 		d := y[i]
    229 		c1 += addMulVVW(z, x, d)
    230 		t := z[0] * k
    231 		c2 = addMulVVW(z, m, t)
    232 
    233 		copy(z, z[1:])
    234 		z[n-1] = c1 + c2
    235 		if z[n-1] < c1 {
    236 			c1 = 1
    237 		} else {
    238 			c1 = 0
    239 		}
    240 	}
    241 	if c1 != 0 {
    242 		subVV(z, z, m)
    243 	}
    244 	return z
    245 }
    246 
    247 // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
    248 // Factored out for readability - do not use outside karatsuba.
    249 func karatsubaAdd(z, x nat, n int) {
    250 	if c := addVV(z[0:n], z, x); c != 0 {
    251 		addVW(z[n:n+n>>1], z[n:], c)
    252 	}
    253 }
    254 
    255 // Like karatsubaAdd, but does subtract.
    256 func karatsubaSub(z, x nat, n int) {
    257 	if c := subVV(z[0:n], z, x); c != 0 {
    258 		subVW(z[n:n+n>>1], z[n:], c)
    259 	}
    260 }
    261 
    262 // Operands that are shorter than karatsubaThreshold are multiplied using
    263 // "grade school" multiplication; for longer operands the Karatsuba algorithm
    264 // is used.
    265 var karatsubaThreshold int = 40 // computed by calibrate.go
    266 
    267 // karatsuba multiplies x and y and leaves the result in z.
    268 // Both x and y must have the same length n and n must be a
    269 // power of 2. The result vector z must have len(z) >= 6*n.
    270 // The (non-normalized) result is placed in z[0 : 2*n].
    271 func karatsuba(z, x, y nat) {
    272 	n := len(y)
    273 
    274 	// Switch to basic multiplication if numbers are odd or small.
    275 	// (n is always even if karatsubaThreshold is even, but be
    276 	// conservative)
    277 	if n&1 != 0 || n < karatsubaThreshold || n < 2 {
    278 		basicMul(z, x, y)
    279 		return
    280 	}
    281 	// n&1 == 0 && n >= karatsubaThreshold && n >= 2
    282 
    283 	// Karatsuba multiplication is based on the observation that
    284 	// for two numbers x and y with:
    285 	//
    286 	//   x = x1*b + x0
    287 	//   y = y1*b + y0
    288 	//
    289 	// the product x*y can be obtained with 3 products z2, z1, z0
    290 	// instead of 4:
    291 	//
    292 	//   x*y = x1*y1*b*b + (x1*y0 + x0*y1)*b + x0*y0
    293 	//       =    z2*b*b +              z1*b +    z0
    294 	//
    295 	// with:
    296 	//
    297 	//   xd = x1 - x0
    298 	//   yd = y0 - y1
    299 	//
    300 	//   z1 =      xd*yd                    + z2 + z0
    301 	//      = (x1-x0)*(y0 - y1)             + z2 + z0
    302 	//      = x1*y0 - x1*y1 - x0*y0 + x0*y1 + z2 + z0
    303 	//      = x1*y0 -    z2 -    z0 + x0*y1 + z2 + z0
    304 	//      = x1*y0                 + x0*y1
    305 
    306 	// split x, y into "digits"
    307 	n2 := n >> 1              // n2 >= 1
    308 	x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0
    309 	y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0
    310 
    311 	// z is used for the result and temporary storage:
    312 	//
    313 	//   6*n     5*n     4*n     3*n     2*n     1*n     0*n
    314 	// z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
    315 	//
    316 	// For each recursive call of karatsuba, an unused slice of
    317 	// z is passed in that has (at least) half the length of the
    318 	// caller's z.
    319 
    320 	// compute z0 and z2 with the result "in place" in z
    321 	karatsuba(z, x0, y0)     // z0 = x0*y0
    322 	karatsuba(z[n:], x1, y1) // z2 = x1*y1
    323 
    324 	// compute xd (or the negative value if underflow occurs)
    325 	s := 1 // sign of product xd*yd
    326 	xd := z[2*n : 2*n+n2]
    327 	if subVV(xd, x1, x0) != 0 { // x1-x0
    328 		s = -s
    329 		subVV(xd, x0, x1) // x0-x1
    330 	}
    331 
    332 	// compute yd (or the negative value if underflow occurs)
    333 	yd := z[2*n+n2 : 3*n]
    334 	if subVV(yd, y0, y1) != 0 { // y0-y1
    335 		s = -s
    336 		subVV(yd, y1, y0) // y1-y0
    337 	}
    338 
    339 	// p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
    340 	// p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
    341 	p := z[n*3:]
    342 	karatsuba(p, xd, yd)
    343 
    344 	// save original z2:z0
    345 	// (ok to use upper half of z since we're done recursing)
    346 	r := z[n*4:]
    347 	copy(r, z[:n*2])
    348 
    349 	// add up all partial products
    350 	//
    351 	//   2*n     n     0
    352 	// z = [ z2  | z0  ]
    353 	//   +    [ z0  ]
    354 	//   +    [ z2  ]
    355 	//   +    [  p  ]
    356 	//
    357 	karatsubaAdd(z[n2:], r, n)
    358 	karatsubaAdd(z[n2:], r[n:], n)
    359 	if s > 0 {
    360 		karatsubaAdd(z[n2:], p, n)
    361 	} else {
    362 		karatsubaSub(z[n2:], p, n)
    363 	}
    364 }
    365 
    366 // alias reports whether x and y share the same base array.
    367 func alias(x, y nat) bool {
    368 	return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
    369 }
    370 
    371 // addAt implements z += x<<(_W*i); z must be long enough.
    372 // (we don't use nat.add because we need z to stay the same
    373 // slice, and we don't need to normalize z after each addition)
    374 func addAt(z, x nat, i int) {
    375 	if n := len(x); n > 0 {
    376 		if c := addVV(z[i:i+n], z[i:], x); c != 0 {
    377 			j := i + n
    378 			if j < len(z) {
    379 				addVW(z[j:], z[j:], c)
    380 			}
    381 		}
    382 	}
    383 }
    384 
    385 func max(x, y int) int {
    386 	if x > y {
    387 		return x
    388 	}
    389 	return y
    390 }
    391 
    392 // karatsubaLen computes an approximation to the maximum k <= n such that
    393 // k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
    394 // result is the largest number that can be divided repeatedly by 2 before
    395 // becoming about the value of karatsubaThreshold.
    396 func karatsubaLen(n int) int {
    397 	i := uint(0)
    398 	for n > karatsubaThreshold {
    399 		n >>= 1
    400 		i++
    401 	}
    402 	return n << i
    403 }
    404 
    405 func (z nat) mul(x, y nat) nat {
    406 	m := len(x)
    407 	n := len(y)
    408 
    409 	switch {
    410 	case m < n:
    411 		return z.mul(y, x)
    412 	case m == 0 || n == 0:
    413 		return z[:0]
    414 	case n == 1:
    415 		return z.mulAddWW(x, y[0], 0)
    416 	}
    417 	// m >= n > 1
    418 
    419 	// determine if z can be reused
    420 	if alias(z, x) || alias(z, y) {
    421 		z = nil // z is an alias for x or y - cannot reuse
    422 	}
    423 
    424 	// use basic multiplication if the numbers are small
    425 	if n < karatsubaThreshold {
    426 		z = z.make(m + n)
    427 		basicMul(z, x, y)
    428 		return z.norm()
    429 	}
    430 	// m >= n && n >= karatsubaThreshold && n >= 2
    431 
    432 	// determine Karatsuba length k such that
    433 	//
    434 	//   x = xh*b + x0  (0 <= x0 < b)
    435 	//   y = yh*b + y0  (0 <= y0 < b)
    436 	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
    437 	//
    438 	k := karatsubaLen(n)
    439 	// k <= n
    440 
    441 	// multiply x0 and y0 via Karatsuba
    442 	x0 := x[0:k]              // x0 is not normalized
    443 	y0 := y[0:k]              // y0 is not normalized
    444 	z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
    445 	karatsuba(z, x0, y0)
    446 	z = z[0 : m+n]  // z has final length but may be incomplete
    447 	z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
    448 
    449 	// If xh != 0 or yh != 0, add the missing terms to z. For
    450 	//
    451 	//   xh = xi*b^i + ... + x2*b^2 + x1*b (0 <= xi < b)
    452 	//   yh =                         y1*b (0 <= y1 < b)
    453 	//
    454 	// the missing terms are
    455 	//
    456 	//   x0*y1*b and xi*y0*b^i, xi*y1*b^(i+1) for i > 0
    457 	//
    458 	// since all the yi for i > 1 are 0 by choice of k: If any of them
    459 	// were > 0, then yh >= b^2 and thus y >= b^2. Then k' = k*2 would
    460 	// be a larger valid threshold contradicting the assumption about k.
    461 	//
    462 	if k < n || m != n {
    463 		var t nat
    464 
    465 		// add x0*y1*b
    466 		x0 := x0.norm()
    467 		y1 := y[k:]       // y1 is normalized because y is
    468 		t = t.mul(x0, y1) // update t so we don't lose t's underlying array
    469 		addAt(z, t, k)
    470 
    471 		// add xi*y0<<i, xi*y1*b<<(i+k)
    472 		y0 := y0.norm()
    473 		for i := k; i < len(x); i += k {
    474 			xi := x[i:]
    475 			if len(xi) > k {
    476 				xi = xi[:k]
    477 			}
    478 			xi = xi.norm()
    479 			t = t.mul(xi, y0)
    480 			addAt(z, t, i)
    481 			t = t.mul(xi, y1)
    482 			addAt(z, t, i+k)
    483 		}
    484 	}
    485 
    486 	return z.norm()
    487 }
    488 
    489 // mulRange computes the product of all the unsigned integers in the
    490 // range [a, b] inclusively. If a > b (empty range), the result is 1.
    491 func (z nat) mulRange(a, b uint64) nat {
    492 	switch {
    493 	case a == 0:
    494 		// cut long ranges short (optimization)
    495 		return z.setUint64(0)
    496 	case a > b:
    497 		return z.setUint64(1)
    498 	case a == b:
    499 		return z.setUint64(a)
    500 	case a+1 == b:
    501 		return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b))
    502 	}
    503 	m := (a + b) / 2
    504 	return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
    505 }
    506 
    507 // q = (x-r)/y, with 0 <= r < y
    508 func (z nat) divW(x nat, y Word) (q nat, r Word) {
    509 	m := len(x)
    510 	switch {
    511 	case y == 0:
    512 		panic("division by zero")
    513 	case y == 1:
    514 		q = z.set(x) // result is x
    515 		return
    516 	case m == 0:
    517 		q = z[:0] // result is 0
    518 		return
    519 	}
    520 	// m > 0
    521 	z = z.make(m)
    522 	r = divWVW(z, 0, x, y)
    523 	q = z.norm()
    524 	return
    525 }
    526 
    527 func (z nat) div(z2, u, v nat) (q, r nat) {
    528 	if len(v) == 0 {
    529 		panic("division by zero")
    530 	}
    531 
    532 	if u.cmp(v) < 0 {
    533 		q = z[:0]
    534 		r = z2.set(u)
    535 		return
    536 	}
    537 
    538 	if len(v) == 1 {
    539 		var r2 Word
    540 		q, r2 = z.divW(u, v[0])
    541 		r = z2.setWord(r2)
    542 		return
    543 	}
    544 
    545 	q, r = z.divLarge(z2, u, v)
    546 	return
    547 }
    548 
    549 // q = (uIn-r)/v, with 0 <= r < y
    550 // Uses z as storage for q, and u as storage for r if possible.
    551 // See Knuth, Volume 2, section 4.3.1, Algorithm D.
    552 // Preconditions:
    553 //    len(v) >= 2
    554 //    len(uIn) >= len(v)
    555 func (z nat) divLarge(u, uIn, v nat) (q, r nat) {
    556 	n := len(v)
    557 	m := len(uIn) - n
    558 
    559 	// determine if z can be reused
    560 	// TODO(gri) should find a better solution - this if statement
    561 	//           is very costly (see e.g. time pidigits -s -n 10000)
    562 	if alias(z, uIn) || alias(z, v) {
    563 		z = nil // z is an alias for uIn or v - cannot reuse
    564 	}
    565 	q = z.make(m + 1)
    566 
    567 	qhatv := make(nat, n+1)
    568 	if alias(u, uIn) || alias(u, v) {
    569 		u = nil // u is an alias for uIn or v - cannot reuse
    570 	}
    571 	u = u.make(len(uIn) + 1)
    572 	u.clear() // TODO(gri) no need to clear if we allocated a new u
    573 
    574 	// D1.
    575 	shift := nlz(v[n-1])
    576 	if shift > 0 {
    577 		// do not modify v, it may be used by another goroutine simultaneously
    578 		v1 := make(nat, n)
    579 		shlVU(v1, v, shift)
    580 		v = v1
    581 	}
    582 	u[len(uIn)] = shlVU(u[0:len(uIn)], uIn, shift)
    583 
    584 	// D2.
    585 	for j := m; j >= 0; j-- {
    586 		// D3.
    587 		qhat := Word(_M)
    588 		if u[j+n] != v[n-1] {
    589 			var rhat Word
    590 			qhat, rhat = divWW(u[j+n], u[j+n-1], v[n-1])
    591 
    592 			// x1 | x2 = qv_{n-2}
    593 			x1, x2 := mulWW(qhat, v[n-2])
    594 			// test if qv_{n-2} > br + u_{j+n-2}
    595 			for greaterThan(x1, x2, rhat, u[j+n-2]) {
    596 				qhat--
    597 				prevRhat := rhat
    598 				rhat += v[n-1]
    599 				// v[n-1] >= 0, so this tests for overflow.
    600 				if rhat < prevRhat {
    601 					break
    602 				}
    603 				x1, x2 = mulWW(qhat, v[n-2])
    604 			}
    605 		}
    606 
    607 		// D4.
    608 		qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)
    609 
    610 		c := subVV(u[j:j+len(qhatv)], u[j:], qhatv)
    611 		if c != 0 {
    612 			c := addVV(u[j:j+n], u[j:], v)
    613 			u[j+n] += c
    614 			qhat--
    615 		}
    616 
    617 		q[j] = qhat
    618 	}
    619 
    620 	q = q.norm()
    621 	shrVU(u, u, shift)
    622 	r = u.norm()
    623 
    624 	return q, r
    625 }
    626 
    627 // Length of x in bits. x must be normalized.
    628 func (x nat) bitLen() int {
    629 	if i := len(x) - 1; i >= 0 {
    630 		return i*_W + bitLen(x[i])
    631 	}
    632 	return 0
    633 }
    634 
    635 const deBruijn32 = 0x077CB531
    636 
    637 var deBruijn32Lookup = []byte{
    638 	0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
    639 	31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9,
    640 }
    641 
    642 const deBruijn64 = 0x03f79d71b4ca8b09
    643 
    644 var deBruijn64Lookup = []byte{
    645 	0, 1, 56, 2, 57, 49, 28, 3, 61, 58, 42, 50, 38, 29, 17, 4,
    646 	62, 47, 59, 36, 45, 43, 51, 22, 53, 39, 33, 30, 24, 18, 12, 5,
    647 	63, 55, 48, 27, 60, 41, 37, 16, 46, 35, 44, 21, 52, 32, 23, 11,
    648 	54, 26, 40, 15, 34, 20, 31, 10, 25, 14, 19, 9, 13, 8, 7, 6,
    649 }
    650 
    651 // trailingZeroBits returns the number of consecutive least significant zero
    652 // bits of x.
    653 func trailingZeroBits(x Word) uint {
    654 	// x & -x leaves only the right-most bit set in the word. Let k be the
    655 	// index of that bit. Since only a single bit is set, the value is two
    656 	// to the power of k. Multiplying by a power of two is equivalent to
    657 	// left shifting, in this case by k bits.  The de Bruijn constant is
    658 	// such that all six bit, consecutive substrings are distinct.
    659 	// Therefore, if we have a left shifted version of this constant we can
    660 	// find by how many bits it was shifted by looking at which six bit
    661 	// substring ended up at the top of the word.
    662 	// (Knuth, volume 4, section 7.3.1)
    663 	switch _W {
    664 	case 32:
    665 		return uint(deBruijn32Lookup[((x&-x)*deBruijn32)>>27])
    666 	case 64:
    667 		return uint(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
    668 	default:
    669 		panic("unknown word size")
    670 	}
    671 }
    672 
    673 // trailingZeroBits returns the number of consecutive least significant zero
    674 // bits of x.
    675 func (x nat) trailingZeroBits() uint {
    676 	if len(x) == 0 {
    677 		return 0
    678 	}
    679 	var i uint
    680 	for x[i] == 0 {
    681 		i++
    682 	}
    683 	// x[i] != 0
    684 	return i*_W + trailingZeroBits(x[i])
    685 }
    686 
    687 // z = x << s
    688 func (z nat) shl(x nat, s uint) nat {
    689 	m := len(x)
    690 	if m == 0 {
    691 		return z[:0]
    692 	}
    693 	// m > 0
    694 
    695 	n := m + int(s/_W)
    696 	z = z.make(n + 1)
    697 	z[n] = shlVU(z[n-m:n], x, s%_W)
    698 	z[0 : n-m].clear()
    699 
    700 	return z.norm()
    701 }
    702 
    703 // z = x >> s
    704 func (z nat) shr(x nat, s uint) nat {
    705 	m := len(x)
    706 	n := m - int(s/_W)
    707 	if n <= 0 {
    708 		return z[:0]
    709 	}
    710 	// n > 0
    711 
    712 	z = z.make(n)
    713 	shrVU(z, x[m-n:], s%_W)
    714 
    715 	return z.norm()
    716 }
    717 
    718 func (z nat) setBit(x nat, i uint, b uint) nat {
    719 	j := int(i / _W)
    720 	m := Word(1) << (i % _W)
    721 	n := len(x)
    722 	switch b {
    723 	case 0:
    724 		z = z.make(n)
    725 		copy(z, x)
    726 		if j >= n {
    727 			// no need to grow
    728 			return z
    729 		}
    730 		z[j] &^= m
    731 		return z.norm()
    732 	case 1:
    733 		if j >= n {
    734 			z = z.make(j + 1)
    735 			z[n:].clear()
    736 		} else {
    737 			z = z.make(n)
    738 		}
    739 		copy(z, x)
    740 		z[j] |= m
    741 		// no need to normalize
    742 		return z
    743 	}
    744 	panic("set bit is not 0 or 1")
    745 }
    746 
    747 // bit returns the value of the i'th bit, with lsb == bit 0.
    748 func (x nat) bit(i uint) uint {
    749 	j := i / _W
    750 	if j >= uint(len(x)) {
    751 		return 0
    752 	}
    753 	// 0 <= j < len(x)
    754 	return uint(x[j] >> (i % _W) & 1)
    755 }
    756 
    757 // sticky returns 1 if there's a 1 bit within the
    758 // i least significant bits, otherwise it returns 0.
    759 func (x nat) sticky(i uint) uint {
    760 	j := i / _W
    761 	if j >= uint(len(x)) {
    762 		if len(x) == 0 {
    763 			return 0
    764 		}
    765 		return 1
    766 	}
    767 	// 0 <= j < len(x)
    768 	for _, x := range x[:j] {
    769 		if x != 0 {
    770 			return 1
    771 		}
    772 	}
    773 	if x[j]<<(_W-i%_W) != 0 {
    774 		return 1
    775 	}
    776 	return 0
    777 }
    778 
    779 func (z nat) and(x, y nat) nat {
    780 	m := len(x)
    781 	n := len(y)
    782 	if m > n {
    783 		m = n
    784 	}
    785 	// m <= n
    786 
    787 	z = z.make(m)
    788 	for i := 0; i < m; i++ {
    789 		z[i] = x[i] & y[i]
    790 	}
    791 
    792 	return z.norm()
    793 }
    794 
    795 func (z nat) andNot(x, y nat) nat {
    796 	m := len(x)
    797 	n := len(y)
    798 	if n > m {
    799 		n = m
    800 	}
    801 	// m >= n
    802 
    803 	z = z.make(m)
    804 	for i := 0; i < n; i++ {
    805 		z[i] = x[i] &^ y[i]
    806 	}
    807 	copy(z[n:m], x[n:m])
    808 
    809 	return z.norm()
    810 }
    811 
    812 func (z nat) or(x, y nat) nat {
    813 	m := len(x)
    814 	n := len(y)
    815 	s := x
    816 	if m < n {
    817 		n, m = m, n
    818 		s = y
    819 	}
    820 	// m >= n
    821 
    822 	z = z.make(m)
    823 	for i := 0; i < n; i++ {
    824 		z[i] = x[i] | y[i]
    825 	}
    826 	copy(z[n:m], s[n:m])
    827 
    828 	return z.norm()
    829 }
    830 
    831 func (z nat) xor(x, y nat) nat {
    832 	m := len(x)
    833 	n := len(y)
    834 	s := x
    835 	if m < n {
    836 		n, m = m, n
    837 		s = y
    838 	}
    839 	// m >= n
    840 
    841 	z = z.make(m)
    842 	for i := 0; i < n; i++ {
    843 		z[i] = x[i] ^ y[i]
    844 	}
    845 	copy(z[n:m], s[n:m])
    846 
    847 	return z.norm()
    848 }
    849 
    850 // greaterThan reports whether (x1<<_W + x2) > (y1<<_W + y2)
    851 func greaterThan(x1, x2, y1, y2 Word) bool {
    852 	return x1 > y1 || x1 == y1 && x2 > y2
    853 }
    854 
    855 // modW returns x % d.
    856 func (x nat) modW(d Word) (r Word) {
    857 	// TODO(agl): we don't actually need to store the q value.
    858 	var q nat
    859 	q = q.make(len(x))
    860 	return divWVW(q, 0, x, d)
    861 }
    862 
    863 // random creates a random integer in [0..limit), using the space in z if
    864 // possible. n is the bit length of limit.
    865 func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
    866 	if alias(z, limit) {
    867 		z = nil // z is an alias for limit - cannot reuse
    868 	}
    869 	z = z.make(len(limit))
    870 
    871 	bitLengthOfMSW := uint(n % _W)
    872 	if bitLengthOfMSW == 0 {
    873 		bitLengthOfMSW = _W
    874 	}
    875 	mask := Word((1 << bitLengthOfMSW) - 1)
    876 
    877 	for {
    878 		switch _W {
    879 		case 32:
    880 			for i := range z {
    881 				z[i] = Word(rand.Uint32())
    882 			}
    883 		case 64:
    884 			for i := range z {
    885 				z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
    886 			}
    887 		default:
    888 			panic("unknown word size")
    889 		}
    890 		z[len(limit)-1] &= mask
    891 		if z.cmp(limit) < 0 {
    892 			break
    893 		}
    894 	}
    895 
    896 	return z.norm()
    897 }
    898 
    899 // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
    900 // otherwise it sets z to x**y. The result is the value of z.
    901 func (z nat) expNN(x, y, m nat) nat {
    902 	if alias(z, x) || alias(z, y) {
    903 		// We cannot allow in-place modification of x or y.
    904 		z = nil
    905 	}
    906 
    907 	// x**y mod 1 == 0
    908 	if len(m) == 1 && m[0] == 1 {
    909 		return z.setWord(0)
    910 	}
    911 	// m == 0 || m > 1
    912 
    913 	// x**0 == 1
    914 	if len(y) == 0 {
    915 		return z.setWord(1)
    916 	}
    917 	// y > 0
    918 
    919 	// x**1 mod m == x mod m
    920 	if len(y) == 1 && y[0] == 1 && len(m) != 0 {
    921 		_, z = z.div(z, x, m)
    922 		return z
    923 	}
    924 	// y > 1
    925 
    926 	if len(m) != 0 {
    927 		// We likely end up being as long as the modulus.
    928 		z = z.make(len(m))
    929 	}
    930 	z = z.set(x)
    931 
    932 	// If the base is non-trivial and the exponent is large, we use
    933 	// 4-bit, windowed exponentiation. This involves precomputing 14 values
    934 	// (x^2...x^15) but then reduces the number of multiply-reduces by a
    935 	// third. Even for a 32-bit exponent, this reduces the number of
    936 	// operations. Uses Montgomery method for odd moduli.
    937 	if len(x) > 1 && len(y) > 1 && len(m) > 0 {
    938 		if m[0]&1 == 1 {
    939 			return z.expNNMontgomery(x, y, m)
    940 		}
    941 		return z.expNNWindowed(x, y, m)
    942 	}
    943 
    944 	v := y[len(y)-1] // v > 0 because y is normalized and y > 0
    945 	shift := nlz(v) + 1
    946 	v <<= shift
    947 	var q nat
    948 
    949 	const mask = 1 << (_W - 1)
    950 
    951 	// We walk through the bits of the exponent one by one. Each time we
    952 	// see a bit, we square, thus doubling the power. If the bit is a one,
    953 	// we also multiply by x, thus adding one to the power.
    954 
    955 	w := _W - int(shift)
    956 	// zz and r are used to avoid allocating in mul and div as
    957 	// otherwise the arguments would alias.
    958 	var zz, r nat
    959 	for j := 0; j < w; j++ {
    960 		zz = zz.mul(z, z)
    961 		zz, z = z, zz
    962 
    963 		if v&mask != 0 {
    964 			zz = zz.mul(z, x)
    965 			zz, z = z, zz
    966 		}
    967 
    968 		if len(m) != 0 {
    969 			zz, r = zz.div(r, z, m)
    970 			zz, r, q, z = q, z, zz, r
    971 		}
    972 
    973 		v <<= 1
    974 	}
    975 
    976 	for i := len(y) - 2; i >= 0; i-- {
    977 		v = y[i]
    978 
    979 		for j := 0; j < _W; j++ {
    980 			zz = zz.mul(z, z)
    981 			zz, z = z, zz
    982 
    983 			if v&mask != 0 {
    984 				zz = zz.mul(z, x)
    985 				zz, z = z, zz
    986 			}
    987 
    988 			if len(m) != 0 {
    989 				zz, r = zz.div(r, z, m)
    990 				zz, r, q, z = q, z, zz, r
    991 			}
    992 
    993 			v <<= 1
    994 		}
    995 	}
    996 
    997 	return z.norm()
    998 }
    999 
   1000 // expNNWindowed calculates x**y mod m using a fixed, 4-bit window.
   1001 func (z nat) expNNWindowed(x, y, m nat) nat {
   1002 	// zz and r are used to avoid allocating in mul and div as otherwise
   1003 	// the arguments would alias.
   1004 	var zz, r nat
   1005 
   1006 	const n = 4
   1007 	// powers[i] contains x^i.
   1008 	var powers [1 << n]nat
   1009 	powers[0] = natOne
   1010 	powers[1] = x
   1011 	for i := 2; i < 1<<n; i += 2 {
   1012 		p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
   1013 		*p = p.mul(*p2, *p2)
   1014 		zz, r = zz.div(r, *p, m)
   1015 		*p, r = r, *p
   1016 		*p1 = p1.mul(*p, x)
   1017 		zz, r = zz.div(r, *p1, m)
   1018 		*p1, r = r, *p1
   1019 	}
   1020 
   1021 	z = z.setWord(1)
   1022 
   1023 	for i := len(y) - 1; i >= 0; i-- {
   1024 		yi := y[i]
   1025 		for j := 0; j < _W; j += n {
   1026 			if i != len(y)-1 || j != 0 {
   1027 				// Unrolled loop for significant performance
   1028 				// gain.  Use go test -bench=".*" in crypto/rsa
   1029 				// to check performance before making changes.
   1030 				zz = zz.mul(z, z)
   1031 				zz, z = z, zz
   1032 				zz, r = zz.div(r, z, m)
   1033 				z, r = r, z
   1034 
   1035 				zz = zz.mul(z, z)
   1036 				zz, z = z, zz
   1037 				zz, r = zz.div(r, z, m)
   1038 				z, r = r, z
   1039 
   1040 				zz = zz.mul(z, z)
   1041 				zz, z = z, zz
   1042 				zz, r = zz.div(r, z, m)
   1043 				z, r = r, z
   1044 
   1045 				zz = zz.mul(z, z)
   1046 				zz, z = z, zz
   1047 				zz, r = zz.div(r, z, m)
   1048 				z, r = r, z
   1049 			}
   1050 
   1051 			zz = zz.mul(z, powers[yi>>(_W-n)])
   1052 			zz, z = z, zz
   1053 			zz, r = zz.div(r, z, m)
   1054 			z, r = r, z
   1055 
   1056 			yi <<= n
   1057 		}
   1058 	}
   1059 
   1060 	return z.norm()
   1061 }
   1062 
   1063 // expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
   1064 // Uses Montgomery representation.
   1065 func (z nat) expNNMontgomery(x, y, m nat) nat {
   1066 	var zz, one, rr, RR nat
   1067 
   1068 	numWords := len(m)
   1069 
   1070 	// We want the lengths of x and m to be equal.
   1071 	if len(x) > numWords {
   1072 		_, rr = rr.div(rr, x, m)
   1073 	} else if len(x) < numWords {
   1074 		rr = rr.make(numWords)
   1075 		rr.clear()
   1076 		for i := range x {
   1077 			rr[i] = x[i]
   1078 		}
   1079 	} else {
   1080 		rr = x
   1081 	}
   1082 	x = rr
   1083 
   1084 	// Ideally the precomputations would be performed outside, and reused
   1085 	// k0 = -m-1 mod 2_W. Algorithm from: Dumas, J.G. "On NewtonRaphson
   1086 	// Iteration for Multiplicative Inverses Modulo Prime Powers".
   1087 	k0 := 2 - m[0]
   1088 	t := m[0] - 1
   1089 	for i := 1; i < _W; i <<= 1 {
   1090 		t *= t
   1091 		k0 *= (t + 1)
   1092 	}
   1093 	k0 = -k0
   1094 
   1095 	// RR = 2(2*_W*len(m)) mod m
   1096 	RR = RR.setWord(1)
   1097 	zz = zz.shl(RR, uint(2*numWords*_W))
   1098 	_, RR = RR.div(RR, zz, m)
   1099 	if len(RR) < numWords {
   1100 		zz = zz.make(numWords)
   1101 		copy(zz, RR)
   1102 		RR = zz
   1103 	}
   1104 	// one = 1, with equal length to that of m
   1105 	one = one.make(numWords)
   1106 	one.clear()
   1107 	one[0] = 1
   1108 
   1109 	const n = 4
   1110 	// powers[i] contains x^i
   1111 	var powers [1 << n]nat
   1112 	powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
   1113 	powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
   1114 	for i := 2; i < 1<<n; i++ {
   1115 		powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
   1116 	}
   1117 
   1118 	// initialize z = 1 (Montgomery 1)
   1119 	z = z.make(numWords)
   1120 	copy(z, powers[0])
   1121 
   1122 	zz = zz.make(numWords)
   1123 
   1124 	// same windowed exponent, but with Montgomery multiplications
   1125 	for i := len(y) - 1; i >= 0; i-- {
   1126 		yi := y[i]
   1127 		for j := 0; j < _W; j += n {
   1128 			if i != len(y)-1 || j != 0 {
   1129 				zz = zz.montgomery(z, z, m, k0, numWords)
   1130 				z = z.montgomery(zz, zz, m, k0, numWords)
   1131 				zz = zz.montgomery(z, z, m, k0, numWords)
   1132 				z = z.montgomery(zz, zz, m, k0, numWords)
   1133 			}
   1134 			zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
   1135 			z, zz = zz, z
   1136 			yi <<= n
   1137 		}
   1138 	}
   1139 	// convert to regular number
   1140 	zz = zz.montgomery(z, one, m, k0, numWords)
   1141 	return zz.norm()
   1142 }
   1143 
   1144 // probablyPrime performs reps Miller-Rabin tests to check whether n is prime.
   1145 // If it returns true, n is prime with probability 1 - 1/4^reps.
   1146 // If it returns false, n is not prime.
   1147 func (n nat) probablyPrime(reps int) bool {
   1148 	if len(n) == 0 {
   1149 		return false
   1150 	}
   1151 
   1152 	if len(n) == 1 {
   1153 		if n[0] < 2 {
   1154 			return false
   1155 		}
   1156 
   1157 		if n[0]%2 == 0 {
   1158 			return n[0] == 2
   1159 		}
   1160 
   1161 		// We have to exclude these cases because we reject all
   1162 		// multiples of these numbers below.
   1163 		switch n[0] {
   1164 		case 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53:
   1165 			return true
   1166 		}
   1167 	}
   1168 
   1169 	if n[0]&1 == 0 {
   1170 		return false // n is even
   1171 	}
   1172 
   1173 	const primesProduct32 = 0xC0CFD797         //  {p  primes, 2 < p <= 29}
   1174 	const primesProduct64 = 0xE221F97C30E94E1D //  {p  primes, 2 < p <= 53}
   1175 
   1176 	var r Word
   1177 	switch _W {
   1178 	case 32:
   1179 		r = n.modW(primesProduct32)
   1180 	case 64:
   1181 		r = n.modW(primesProduct64 & _M)
   1182 	default:
   1183 		panic("Unknown word size")
   1184 	}
   1185 
   1186 	if r%3 == 0 || r%5 == 0 || r%7 == 0 || r%11 == 0 ||
   1187 		r%13 == 0 || r%17 == 0 || r%19 == 0 || r%23 == 0 || r%29 == 0 {
   1188 		return false
   1189 	}
   1190 
   1191 	if _W == 64 && (r%31 == 0 || r%37 == 0 || r%41 == 0 ||
   1192 		r%43 == 0 || r%47 == 0 || r%53 == 0) {
   1193 		return false
   1194 	}
   1195 
   1196 	nm1 := nat(nil).sub(n, natOne)
   1197 	// determine q, k such that nm1 = q << k
   1198 	k := nm1.trailingZeroBits()
   1199 	q := nat(nil).shr(nm1, k)
   1200 
   1201 	nm3 := nat(nil).sub(nm1, natTwo)
   1202 	rand := rand.New(rand.NewSource(int64(n[0])))
   1203 
   1204 	var x, y, quotient nat
   1205 	nm3Len := nm3.bitLen()
   1206 
   1207 NextRandom:
   1208 	for i := 0; i < reps; i++ {
   1209 		x = x.random(rand, nm3, nm3Len)
   1210 		x = x.add(x, natTwo)
   1211 		y = y.expNN(x, q, n)
   1212 		if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
   1213 			continue
   1214 		}
   1215 		for j := uint(1); j < k; j++ {
   1216 			y = y.mul(y, y)
   1217 			quotient, y = quotient.div(y, y, n)
   1218 			if y.cmp(nm1) == 0 {
   1219 				continue NextRandom
   1220 			}
   1221 			if y.cmp(natOne) == 0 {
   1222 				return false
   1223 			}
   1224 		}
   1225 		return false
   1226 	}
   1227 
   1228 	return true
   1229 }
   1230 
   1231 // bytes writes the value of z into buf using big-endian encoding.
   1232 // len(buf) must be >= len(z)*_S. The value of z is encoded in the
   1233 // slice buf[i:]. The number i of unused bytes at the beginning of
   1234 // buf is returned as result.
   1235 func (z nat) bytes(buf []byte) (i int) {
   1236 	i = len(buf)
   1237 	for _, d := range z {
   1238 		for j := 0; j < _S; j++ {
   1239 			i--
   1240 			buf[i] = byte(d)
   1241 			d >>= 8
   1242 		}
   1243 	}
   1244 
   1245 	for i < len(buf) && buf[i] == 0 {
   1246 		i++
   1247 	}
   1248 
   1249 	return
   1250 }
   1251 
   1252 // setBytes interprets buf as the bytes of a big-endian unsigned
   1253 // integer, sets z to that value, and returns z.
   1254 func (z nat) setBytes(buf []byte) nat {
   1255 	z = z.make((len(buf) + _S - 1) / _S)
   1256 
   1257 	k := 0
   1258 	s := uint(0)
   1259 	var d Word
   1260 	for i := len(buf); i > 0; i-- {
   1261 		d |= Word(buf[i-1]) << s
   1262 		if s += 8; s == _S*8 {
   1263 			z[k] = d
   1264 			k++
   1265 			s = 0
   1266 			d = 0
   1267 		}
   1268 	}
   1269 	if k < len(z) {
   1270 		z[k] = d
   1271 	}
   1272 
   1273 	return z.norm()
   1274 }
   1275