# -*- coding: utf-8 -*-
#
#  Copyright 2011 Sybren A. Stüvel <sybren (at] stuvel.eu>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""Common functionality shared by several modules."""

# The docstring above has to be the first statement of the module; placed
# after the import it is just a discarded string and ``__doc__`` stays None.

try:
    from rsa._compat import zip
except ImportError:
    # No compatibility shim available: fall back to the builtin ``zip``,
    # which is sufficient for the single lock-step loop in crt().
    pass


class NotRelativePrimeError(ValueError):
    """Raised when two numbers that must be coprime share a common divisor.

    :param a: the first number.
    :param b: the second number.
    :param d: the common divisor that was found.
    :param msg: optional message overriding the default one.

    The three numbers are kept on the instance as ``a``, ``b`` and ``d``.
    """

    def __init__(self, a, b, d, msg=None):
        super(NotRelativePrimeError, self).__init__(
            msg or "%d and %d are not relatively prime, divider=%i" % (a, b, d))
        self.a = a
        self.b = b
        self.d = d


def bit_size(num):
    """
    Number of bits needed to represent an integer excluding any prefix
    0 bits.

    Usage::

        >>> bit_size(1023)
        10
        >>> bit_size(1024)
        11
        >>> bit_size(1025)
        11

    :param num:
        Integer value. If num is 0, returns 0. Only the absolute value of the
        number is considered. Therefore, signed integers will be abs(num)
        before the number's bit length is determined.
    :returns:
        Returns the number of bits in the integer.
    :raises TypeError:
        If ``num`` has no ``bit_length()`` method, i.e. is not an integer.
    """

    try:
        return num.bit_length()
    except AttributeError:
        raise TypeError('bit_size(num) only supports integers, not %r' % type(num))


def byte_size(number):
    """
    Returns the number of bytes required to hold a specific long number.

    The number of bytes is rounded up.

    Usage::

        >>> byte_size(1 << 1023)
        128
        >>> byte_size((1 << 1024) - 1)
        128
        >>> byte_size(1 << 1024)
        129

    :param number:
        An unsigned integer
    :returns:
        The number of bytes required to hold a specific long number.
    """
    # Zero has a bit size of 0 but still occupies one byte when stored.
    if number == 0:
        return 1
    return ceil_div(bit_size(number), 8)


def ceil_div(num, div):
    """
    Returns the ceiling function of a division between `num` and `div`.

    Usage::

        >>> ceil_div(100, 7)
        15
        >>> ceil_div(100, 10)
        10
        >>> ceil_div(1, 4)
        1

    :param num: Division's numerator, a number
    :param div: Division's divisor, a number

    :return: Rounded up result of the division between the parameters.
    """
    quanta, mod = divmod(num, div)
    # Any remainder means the exact quotient lies above ``quanta``.
    if mod:
        quanta += 1
    return quanta


def extended_gcd(a, b):
    """Returns a tuple (r, i, j) where r = gcd(a, b) and i, j are the
    Bezout coefficients of a and b.

    The extended Euclidean algorithm yields i and j with r = i*a + j*b.
    A negative i is then shifted up by the original b, and a negative j by
    the original a, so for the returned values the identity is only
    guaranteed modulo: i*a = r (mod b) and j*b = r (mod a).

    >>> extended_gcd(7, 4)
    (1, 3, 2)

    :param a: first integer.
    :param b: second integer.
    :returns: the tuple (r, i, j) described above.
    """
    # r = gcd(a,b) i = multiplicative inverse of a mod b
    #      or      j = multiplicative inverse of b mod a
    # Negative values for i or j are wrapped modulo b or a respectively.
    # The iterative version is faster and uses much less stack space than
    # a recursive one.
    x = 0
    y = 1
    lx = 1
    ly = 0
    oa = a  # Remember original a/b to remove
    ob = b  # negative values from return results
    while b != 0:
        q = a // b
        (a, b) = (b, a % b)
        (x, lx) = ((lx - (q * x)), x)
        (y, ly) = ((ly - (q * y)), y)
    if lx < 0:
        lx += ob  # If negative, wrap modulo original b
    if ly < 0:
        ly += oa  # If negative, wrap modulo original a
    return a, lx, ly


def inverse(x, n):
    """Returns the inverse of x % n under multiplication, a.k.a x^-1 (mod n)

    >>> inverse(7, 4)
    3
    >>> (inverse(143, 4) * 143) % 4
    1

    :param x: the number to invert.
    :param n: the modulus.
    :returns: the multiplicative inverse of x modulo n.
    :raises NotRelativePrimeError: if gcd(x, n) is not 1, in which case no
        inverse exists.
    """

    (divider, inv, _) = extended_gcd(x, n)

    if divider != 1:
        raise NotRelativePrimeError(x, n, divider)

    return inv


def crt(a_values, modulo_values):
    """Chinese Remainder Theorem.

    Calculates x such that x = a[i] (mod m[i]) for each i.

    :param a_values: the a-values of the above equation
    :param modulo_values: the m-values of the above equation
    :returns: x such that x = a[i] (mod m[i]) for each i
    :raises NotRelativePrimeError: if a modulus shares a factor with the
        product of the other moduli.

    The two sequences are paired with ``zip``, so surplus entries in the
    longer one are ignored when pairing (every modulus still contributes to
    the overall product).

    >>> crt([2, 3], [3, 5])
    8

    >>> crt([2, 3, 2], [3, 5, 7])
    23

    >>> crt([2, 3, 0], [7, 11, 15])
    135
    """

    m = 1
    x = 0

    # m is the product of all moduli; the result is reduced modulo m.
    for modulo in modulo_values:
        m *= modulo

    for (m_i, a_i) in zip(modulo_values, a_values):
        # M_i is the product of every modulus except m_i, so the term below
        # is a_i modulo m_i and 0 modulo each of the other moduli.
        M_i = m // m_i
        inv = inverse(M_i, m_i)

        x = (x + a_i * M_i * inv) % m

    return x


if __name__ == '__main__':
    import doctest

    doctest.testmod()