Home | History | Annotate | Download | only in rsa
      1 # -*- coding: utf-8 -*-
      2 #
#  Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
      4 #
      5 #  Licensed under the Apache License, Version 2.0 (the "License");
      6 #  you may not use this file except in compliance with the License.
      7 #  You may obtain a copy of the License at
      8 #
      9 #      https://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 #  Unless required by applicable law or agreed to in writing, software
     12 #  distributed under the License is distributed on an "AS IS" BASIS,
     13 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 #  See the License for the specific language governing permissions and
     15 #  limitations under the License.
     16 
     17 from rsa._compat import zip
     18 
     19 """Common functionality shared by several modules."""
     20 
     21 
     22 class NotRelativePrimeError(ValueError):
     23     def __init__(self, a, b, d, msg=None):
     24         super(NotRelativePrimeError, self).__init__(
     25             msg or "%d and %d are not relatively prime, divider=%i" % (a, b, d))
     26         self.a = a
     27         self.b = b
     28         self.d = d
     29 
     30 
     31 def bit_size(num):
     32     """
     33     Number of bits needed to represent a integer excluding any prefix
     34     0 bits.
     35 
     36     Usage::
     37 
     38         >>> bit_size(1023)
     39         10
     40         >>> bit_size(1024)
     41         11
     42         >>> bit_size(1025)
     43         11
     44 
     45     :param num:
     46         Integer value. If num is 0, returns 0. Only the absolute value of the
     47         number is considered. Therefore, signed integers will be abs(num)
     48         before the number's bit length is determined.
     49     :returns:
     50         Returns the number of bits in the integer.
     51     """
     52 
     53     try:
     54         return num.bit_length()
     55     except AttributeError:
     56         raise TypeError('bit_size(num) only supports integers, not %r' % type(num))
     57 
     58 
     59 def byte_size(number):
     60     """
     61     Returns the number of bytes required to hold a specific long number.
     62 
     63     The number of bytes is rounded up.
     64 
     65     Usage::
     66 
     67         >>> byte_size(1 << 1023)
     68         128
     69         >>> byte_size((1 << 1024) - 1)
     70         128
     71         >>> byte_size(1 << 1024)
     72         129
     73 
     74     :param number:
     75         An unsigned integer
     76     :returns:
     77         The number of bytes required to hold a specific long number.
     78     """
     79     if number == 0:
     80         return 1
     81     return ceil_div(bit_size(number), 8)
     82 
     83 
     84 def ceil_div(num, div):
     85     """
     86     Returns the ceiling function of a division between `num` and `div`.
     87 
     88     Usage::
     89 
     90         >>> ceil_div(100, 7)
     91         15
     92         >>> ceil_div(100, 10)
     93         10
     94         >>> ceil_div(1, 4)
     95         1
     96 
     97     :param num: Division's numerator, a number
     98     :param div: Division's divisor, a number
     99 
    100     :return: Rounded up result of the division between the parameters.
    101     """
    102     quanta, mod = divmod(num, div)
    103     if mod:
    104         quanta += 1
    105     return quanta
    106 
    107 
    108 def extended_gcd(a, b):
    109     """Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb
    110     """
    111     # r = gcd(a,b) i = multiplicitive inverse of a mod b
    112     #      or      j = multiplicitive inverse of b mod a
    113     # Neg return values for i or j are made positive mod b or a respectively
    114     # Iterateive Version is faster and uses much less stack space
    115     x = 0
    116     y = 1
    117     lx = 1
    118     ly = 0
    119     oa = a  # Remember original a/b to remove
    120     ob = b  # negative values from return results
    121     while b != 0:
    122         q = a // b
    123         (a, b) = (b, a % b)
    124         (x, lx) = ((lx - (q * x)), x)
    125         (y, ly) = ((ly - (q * y)), y)
    126     if lx < 0:
    127         lx += ob  # If neg wrap modulo orignal b
    128     if ly < 0:
    129         ly += oa  # If neg wrap modulo orignal a
    130     return a, lx, ly  # Return only positive values
    131 
    132 
    133 def inverse(x, n):
    134     """Returns the inverse of x % n under multiplication, a.k.a x^-1 (mod n)
    135 
    136     >>> inverse(7, 4)
    137     3
    138     >>> (inverse(143, 4) * 143) % 4
    139     1
    140     """
    141 
    142     (divider, inv, _) = extended_gcd(x, n)
    143 
    144     if divider != 1:
    145         raise NotRelativePrimeError(x, n, divider)
    146 
    147     return inv
    148 
    149 
    150 def crt(a_values, modulo_values):
    151     """Chinese Remainder Theorem.
    152 
    153     Calculates x such that x = a[i] (mod m[i]) for each i.
    154 
    155     :param a_values: the a-values of the above equation
    156     :param modulo_values: the m-values of the above equation
    157     :returns: x such that x = a[i] (mod m[i]) for each i
    158 
    159 
    160     >>> crt([2, 3], [3, 5])
    161     8
    162 
    163     >>> crt([2, 3, 2], [3, 5, 7])
    164     23
    165 
    166     >>> crt([2, 3, 0], [7, 11, 15])
    167     135
    168     """
    169 
    170     m = 1
    171     x = 0
    172 
    173     for modulo in modulo_values:
    174         m *= modulo
    175 
    176     for (m_i, a_i) in zip(modulo_values, a_values):
    177         M_i = m // m_i
    178         inv = inverse(M_i, m_i)
    179 
    180         x = (x + a_i * M_i * inv) % m
    181 
    182     return x
    183 
    184 
# When this file is executed directly as a script, run the usage examples
# embedded in the docstrings above as doctests.
if __name__ == '__main__':
    import doctest

    doctest.testmod()
    189