# -*- coding: utf-8 -*-
#
#  Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""Unittest for saving and loading keys.

Covers round-tripping RSA keys through PKCS#1 DER and PEM encodings, loading
a PEM key from disk, and pickling/unpickling key objects.
"""

import base64
import os.path
import pickle
import unittest
import warnings

try:
    from unittest import mock
except ImportError:
    # Python 2 has no unittest.mock; fall back to the third-party backport.
    import mock

from rsa._compat import range
import rsa.key

# Base64 of a small PKCS#1 DER-encoded private key, and its decoded bytes.
B64PRIV_DER = b'MC4CAQACBQDeKYlRAgMBAAECBQDHn4npAgMA/icCAwDfxwIDANcXAgInbwIDAMZt'
PRIVATE_DER = base64.standard_b64decode(B64PRIV_DER)

# Base64 of the matching PKCS#1 DER-encoded public key, and its decoded bytes.
B64PUB_DER = b'MAwCBQDeKYlRAgMBAAE='
PUBLIC_DER = base64.standard_b64decode(B64PUB_DER)

# PEM input wrapped in unrelated text and a header line, so loading must pick
# out only the RSA PRIVATE KEY section.
PRIVATE_PEM = b'''\
-----BEGIN CONFUSING STUFF-----
Cruft before the key

-----BEGIN RSA PRIVATE KEY-----
Comment: something blah

''' + B64PRIV_DER + b'''
-----END RSA PRIVATE KEY-----

Stuff after the key
-----END CONFUSING STUFF-----
'''

# Exact bytes expected from PrivateKey.save_pkcs1('PEM').
CLEAN_PRIVATE_PEM = b'''\
-----BEGIN RSA PRIVATE KEY-----
''' + B64PRIV_DER + b'''
-----END RSA PRIVATE KEY-----
'''

# PEM input wrapped in unrelated text and a header line, so loading must pick
# out only the RSA PUBLIC KEY section.
PUBLIC_PEM = b'''\
-----BEGIN CONFUSING STUFF-----
Cruft before the key

-----BEGIN RSA PUBLIC KEY-----
Comment: something blah

''' + B64PUB_DER + b'''
-----END RSA PUBLIC KEY-----

Stuff after the key
-----END CONFUSING STUFF-----
'''

# Exact bytes expected from PublicKey.save_pkcs1('PEM').
CLEAN_PUBLIC_PEM = b'''\
-----BEGIN RSA PUBLIC KEY-----
''' + B64PUB_DER + b'''
-----END RSA PUBLIC KEY-----
'''


class DerTest(unittest.TestCase):
    """Test saving and loading DER keys."""

    def test_load_private_key(self):
        """Test loading private DER keys."""

        key = rsa.key.PrivateKey.load_pkcs1(PRIVATE_DER, 'DER')
        expected = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)

        self.assertEqual(expected, key)
        self.assertEqual(key.exp1, 55063)
        self.assertEqual(key.exp2, 10095)
        self.assertEqual(key.coef, 50797)

    @mock.patch('pyasn1.codec.der.decoder.decode')
    def test_load_malformed_private_key(self, der_decode):
        """Test loading malformed private DER keys.

        The DER decoder is mocked to return an invalid exp2 (0). Each load
        must emit a UserWarning mentioning 'malformed', and the resulting key
        must still carry the correct exp1/exp2/coef values.
        """

        # Decode returns an invalid exp2 value.
        der_decode.return_value = (
            [0, 3727264081, 65537, 3349121513, 65063, 57287, 55063, 0, 50797],
            0,
        )

        with warnings.catch_warnings(record=True) as w:
            # 'always' disables de-duplication, so a warning is recorded for
            # every load, not just the first one.
            warnings.simplefilter('always')

            # Load 3 keys
            for _ in range(3):
                key = rsa.key.PrivateKey.load_pkcs1(PRIVATE_DER, 'DER')

        # Check that 3 warnings were generated.
        self.assertEqual(3, len(w))

        for warning in w:
            self.assertTrue(issubclass(warning.category, UserWarning))
            self.assertIn('malformed', str(warning.message))

        # Check that we are creating the key with correct values
        self.assertEqual(key.exp1, 55063)
        self.assertEqual(key.exp2, 10095)
        self.assertEqual(key.coef, 50797)

    def test_save_private_key(self):
        """Test saving private DER keys."""

        key = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
        der = key.save_pkcs1('DER')

        self.assertIsInstance(der, bytes)
        self.assertEqual(PRIVATE_DER, der)

    def test_load_public_key(self):
        """Test loading public DER keys."""

        key = rsa.key.PublicKey.load_pkcs1(PUBLIC_DER, 'DER')
        expected = rsa.key.PublicKey(3727264081, 65537)

        self.assertEqual(expected, key)

    def test_save_public_key(self):
        """Test saving public DER keys."""

        key = rsa.key.PublicKey(3727264081, 65537)
        der = key.save_pkcs1('DER')

        self.assertIsInstance(der, bytes)
        self.assertEqual(PUBLIC_DER, der)


class PemTest(unittest.TestCase):
    """Test saving and loading PEM keys."""

    def test_load_private_key(self):
        """Test loading private PEM files."""

        key = rsa.key.PrivateKey.load_pkcs1(PRIVATE_PEM, 'PEM')
        expected = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)

        self.assertEqual(expected, key)
        self.assertEqual(key.exp1, 55063)
        self.assertEqual(key.exp2, 10095)
        self.assertEqual(key.coef, 50797)

    def test_save_private_key(self):
        """Test saving private PEM files."""

        key = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
        pem = key.save_pkcs1('PEM')

        self.assertIsInstance(pem, bytes)
        self.assertEqual(CLEAN_PRIVATE_PEM, pem)

    def test_load_public_key(self):
        """Test loading public PEM files."""

        key = rsa.key.PublicKey.load_pkcs1(PUBLIC_PEM, 'PEM')
        expected = rsa.key.PublicKey(3727264081, 65537)

        self.assertEqual(expected, key)

    def test_save_public_key(self):
        """Test saving public PEM files."""

        key = rsa.key.PublicKey(3727264081, 65537)
        pem = key.save_pkcs1('PEM')

        self.assertIsInstance(pem, bytes)
        self.assertEqual(CLEAN_PUBLIC_PEM, pem)

    def test_load_from_disk(self):
        """Test loading a PEM file from disk.

        Reads 'private.pem' from the directory containing this test module.
        """

        fname = os.path.join(os.path.dirname(__file__), 'private.pem')
        with open(fname, mode='rb') as privatefile:
            keydata = privatefile.read()

        # Parse after the file is closed; the default format is used.
        privkey = rsa.key.PrivateKey.load_pkcs1(keydata)

        self.assertEqual(15945948582725241569, privkey.p)
        self.assertEqual(14617195220284816877, privkey.q)


class PickleTest(unittest.TestCase):
    """Test saving and loading keys by pickling."""

    def test_private_key(self):
        """A private key must compare equal to itself after a pickle round-trip."""
        pk = rsa.key.PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)

        pickled = pickle.dumps(pk)
        unpickled = pickle.loads(pickled)

        self.assertEqual(pk, unpickled)

    def test_public_key(self):
        """A public key must compare equal to itself after a pickle round-trip."""
        pk = rsa.key.PublicKey(3727264081, 65537)

        pickled = pickle.dumps(pk)
        unpickled = pickle.loads(pickled)

        self.assertEqual(pk, unpickled)