Home | History | Annotate | Download | only in layers
      1 #############################################################################
      2 ## ipsec.py --- IPsec support for Scapy                                    ##
      3 ##                                                                         ##
      4 ## Copyright (C) 2014  6WIND                                               ##
      5 ##                                                                         ##
      6 ## This program is free software; you can redistribute it and/or modify it ##
      7 ## under the terms of the GNU General Public License version 2 as          ##
      8 ## published by the Free Software Foundation.                              ##
      9 ##                                                                         ##
     10 ## This program is distributed in the hope that it will be useful, but     ##
     11 ## WITHOUT ANY WARRANTY; without even the implied warranty of              ##
     12 ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU       ##
     13 ## General Public License for more details.                                ##
     14 #############################################################################
     15 """
     16 IPsec layer
     17 ===========
     18 
     19 Example of use:
     20 
     21 >>> sa = SecurityAssociation(ESP, spi=0xdeadbeef, crypt_algo='AES-CBC',
     22 ...                          crypt_key='sixteenbytes key')
     23 >>> p = IP(src='1.1.1.1', dst='2.2.2.2')
     24 >>> p /= TCP(sport=45012, dport=80)
     25 >>> p /= Raw('testdata')
     26 >>> p = IP(raw(p))
     27 >>> p
     28 <IP  version=4L ihl=5L tos=0x0 len=48 id=1 flags= frag=0L ttl=64 proto=tcp chksum=0x74c2 src=1.1.1.1 dst=2.2.2.2 options=[] |<TCP  sport=45012 dport=http seq=0 ack=0 dataofs=5L reserved=0L flags=S window=8192 chksum=0x1914 urgptr=0 options=[] |<Raw  load='testdata' |>>>
     29 >>>
     30 >>> e = sa.encrypt(p)
     31 >>> e
     32 <IP  version=4L ihl=5L tos=0x0 len=76 id=1 flags= frag=0L ttl=64 proto=esp chksum=0x747a src=1.1.1.1 dst=2.2.2.2 |<ESP  spi=0xdeadbeef seq=1 data=b'\xf8\xdb\x1e\x83[T\xab\\\xd2\x1b\xed\xd1\xe5\xc8Y\xc2\xa5d\x92\xc1\x05\x17\xa6\x92\x831\xe6\xc1]\x9a\xd6K}W\x8bFfd\xa5B*+\xde\xc8\x89\xbf{\xa9' |>>
     33 >>>
     34 >>> d = sa.decrypt(e)
     35 >>> d
     36 <IP  version=4L ihl=5L tos=0x0 len=48 id=1 flags= frag=0L ttl=64 proto=tcp chksum=0x74c2 src=1.1.1.1 dst=2.2.2.2 |<TCP  sport=45012 dport=http seq=0 ack=0 dataofs=5L reserved=0L flags=S window=8192 chksum=0x1914 urgptr=0 options=[] |<Raw  load='testdata' |>>>
     37 >>>
     38 >>> d == p
     39 True
     40 """
     41 
from __future__ import absolute_import
try:
    # math.gcd is available since Python 3.5; fractions.gcd was removed in 3.9
    from math import gcd
except ImportError:
    from fractions import gcd
import os
import socket
import struct

from scapy.config import conf, crypto_validator
from scapy.compat import orb, raw
from scapy.data import IP_PROTOS
from scapy.compat import *
from scapy.error import log_loading
from scapy.fields import ByteEnumField, ByteField, IntField, PacketField, \
    ShortField, StrField, XIntField, XStrField, XStrLenField
from scapy.packet import Packet, bind_layers, Raw
from scapy.layers.inet import IP, UDP
import scapy.modules.six as six
from scapy.modules.six.moves import range
from scapy.layers.inet6 import IPv6, IPv6ExtHdrHopByHop, IPv6ExtHdrDestOpt, \
    IPv6ExtHdrRouting
     61 
     62 
     63 #------------------------------------------------------------------------------
class AH(Packet):
    """
    Authentication Header

    See https://tools.ietf.org/rfc/rfc4302.txt
    """

    name = 'AH'

    def __get_icv_len(self):
        """
        Compute the size of the ICV based on the payloadlen field.
        Padding size is included as it can only be known from the authentication
        algorithm provided by the Security Association.

        @return: the number of bytes to read for the 'icv' field
                 (ICV and padding together)
        """
        # payloadlen = length of AH in 32-bit words (4-byte units), minus "2"
        # payloadlen = 3 32-bit word fixed fields + ICV + padding - 2
        # ICV = (payloadlen + 2 - 3 - padding) in 32-bit words
        return (self.payloadlen - 1) * 4

    fields_desc = [
        ByteEnumField('nh', None, IP_PROTOS),
        ByteField('payloadlen', None),
        ShortField('reserved', None),
        XIntField('spi', 0x0),
        IntField('seq', 0),
        # The length of 'icv' is derived from payloadlen (see __get_icv_len),
        # so on dissection it also contains the padding bytes, if any.
        XStrLenField('icv', None, length_from=__get_icv_len),
        # Padding len can only be known with the SecurityAssociation.auth_algo
        XStrLenField('padding', None, length_from=lambda x: 0),
    ]

    # Value written in the protocol / next header field of the layer that
    # precedes AH (scapy's overload_fields mechanism).
    overload_fields = {
        IP: {'proto': socket.IPPROTO_AH},
        IPv6: {'nh': socket.IPPROTO_AH},
        IPv6ExtHdrHopByHop: {'nh': socket.IPPROTO_AH},
        IPv6ExtHdrDestOpt: {'nh': socket.IPPROTO_AH},
        IPv6ExtHdrRouting: {'nh': socket.IPPROTO_AH},
    }
    102 
# AH is carried by IP (proto) and IPv6 (nh); the AH 'nh' field in turn selects
# the IP or IPv6 layer that follows the AH header.
# NOTE(review): socket.IPPROTO_IP is 0, whereas IPv4-in-IP encapsulation is
# protocol number 4 (socket.IPPROTO_IPIP) -- confirm that 0 is the value the
# rest of the module writes in tunnel mode before changing this binding.
bind_layers(IP, AH, proto=socket.IPPROTO_AH)
bind_layers(IPv6, AH, nh=socket.IPPROTO_AH)
bind_layers(AH, IP, nh=socket.IPPROTO_IP)
bind_layers(AH, IPv6, nh=socket.IPPROTO_IPV6)
    107 
    108 #------------------------------------------------------------------------------
class ESP(Packet):
    """
    Encapsulated Security Payload

    See https://tools.ietf.org/rfc/rfc4303.txt

    Only the SPI and the sequence number are exposed as separate fields;
    everything that follows them is kept as one opaque 'data' string.
    """
    name = 'ESP'

    fields_desc = [
        XIntField('spi', 0x0),
        IntField('seq', 0),
        # everything after the sequence number, undissected
        XStrField('data', None),
    ]

    # Value written in the protocol / next header field of the layer that
    # precedes ESP (scapy's overload_fields mechanism).
    overload_fields = {
        IP: {'proto': socket.IPPROTO_ESP},
        IPv6: {'nh': socket.IPPROTO_ESP},
        IPv6ExtHdrHopByHop: {'nh': socket.IPPROTO_ESP},
        IPv6ExtHdrDestOpt: {'nh': socket.IPPROTO_ESP},
        IPv6ExtHdrRouting: {'nh': socket.IPPROTO_ESP},
    }
    130 
# ESP is carried directly by IP (proto) and IPv6 (nh), or inside UDP when
# either port is 4500.
bind_layers(IP, ESP, proto=socket.IPPROTO_ESP)
bind_layers(IPv6, ESP, nh=socket.IPPROTO_ESP)
bind_layers(UDP, ESP, dport=4500)  # NAT-Traversal encapsulation
bind_layers(UDP, ESP, sport=4500)  # NAT-Traversal encapsulation
    135 
    136 #------------------------------------------------------------------------------
    137 class _ESPPlain(Packet):
    138     """
    139     Internal class to represent unencrypted ESP packets.
    140     """
    141     name = 'ESP'
    142 
    143     fields_desc = [
    144         XIntField('spi', 0x0),
    145         IntField('seq', 0),
    146 
    147         StrField('iv', ''),
    148         PacketField('data', '', Raw),
    149         StrField('padding', ''),
    150 
    151         ByteField('padlen', 0),
    152         ByteEnumField('nh', 0, IP_PROTOS),
    153         StrField('icv', ''),
    154     ]
    155 
    156     def data_for_encryption(self):
    157         return raw(self.data) + self.padding + struct.pack("BB", self.padlen, self.nh)
    158 
    159 #------------------------------------------------------------------------------
    160 if conf.crypto_valid:
    161     from cryptography.exceptions import InvalidTag
    162     from cryptography.hazmat.backends import default_backend
    163     from cryptography.hazmat.primitives.ciphers import (
    164         Cipher,
    165         algorithms,
    166         modes,
    167     )
    168 else:
    169     log_loading.info("Can't import python-cryptography v1.7+. "
    170                      "Disabled IPsec encryption/authentication.")
    171     InvalidTag = default_backend = None
    172     Cipher = algorithms = modes = None
    173 
    174 #------------------------------------------------------------------------------
    175 def _lcm(a, b):
    176     """
    177     Least Common Multiple between 2 integers.
    178     """
    179     if a == 0 or b == 0:
    180         return 0
    181     else:
    182         return abs(a * b) // gcd(a, b)
    183 
    184 class CryptAlgo(object):
    185     """
    186     IPsec encryption algorithm
    187     """
    188 
    189     def __init__(self, name, cipher, mode, block_size=None, iv_size=None,
    190                  key_size=None, icv_size=None, salt_size=None, format_mode_iv=None):
    191         """
    192         @param name: the name of this encryption algorithm
    193         @param cipher: a Cipher module
    194         @param mode: the mode used with the cipher module
    195         @param block_size: the length a block for this algo. Defaults to the
    196                            `block_size` of the cipher.
    197         @param iv_size: the length of the initialization vector of this algo.
    198                         Defaults to the `block_size` of the cipher.
    199         @param key_size: an integer or list/tuple of integers. If specified,
    200                          force the secret keys length to one of the values.
    201                          Defaults to the `key_size` of the cipher.
    202         @param icv_size: the length of the Integrity Check Value of this algo.
    203                          Used by Combined Mode Algorithms e.g. GCM
    204         @param salt_size: the length of the salt to use as the IV prefix.
    205                           Usually used by Counter modes e.g. CTR
    206         @param format_mode_iv: function to format the Initialization Vector
    207                                e.g. handle the salt value
    208                                Default is the random buffer from `generate_iv`
    209         """
    210         self.name = name
    211         self.cipher = cipher
    212         self.mode = mode
    213         self.icv_size = icv_size
    214 
    215         if modes and self.mode is not None:
    216             self.is_aead = issubclass(self.mode,
    217                                       modes.ModeWithAuthenticationTag)
    218         else:
    219             self.is_aead = False
    220 
    221         if block_size is not None:
    222             self.block_size = block_size
    223         elif cipher is not None:
    224             self.block_size = cipher.block_size // 8
    225         else:
    226             self.block_size = 1
    227 
    228         if iv_size is None:
    229             self.iv_size = self.block_size
    230         else:
    231             self.iv_size = iv_size
    232 
    233         if key_size is not None:
    234             self.key_size = key_size
    235         elif cipher is not None:
    236             self.key_size = tuple(i // 8 for i in cipher.key_sizes)
    237         else:
    238             self.key_size = None
    239 
    240         if salt_size is None:
    241             self.salt_size = 0
    242         else:
    243             self.salt_size = salt_size
    244 
    245         if format_mode_iv is None:
    246             self._format_mode_iv = lambda iv, **kw: iv
    247         else:
    248             self._format_mode_iv = format_mode_iv
    249 
    250     def check_key(self, key):
    251         """
    252         Check that the key length is valid.
    253 
    254         @param key:    a byte string
    255         """
    256         if self.key_size and not (len(key) == self.key_size or len(key) in self.key_size):
    257             raise TypeError('invalid key size %s, must be %s' %
    258                             (len(key), self.key_size))
    259 
    260     def generate_iv(self):
    261         """
    262         Generate a random initialization vector.
    263         """
    264         # XXX: Handle counter modes with real counters? RFCs allow the use of
    265         # XXX: random bytes for counters, so it is not wrong to do it that way
    266         return os.urandom(self.iv_size)
    267 
    268     @crypto_validator
    269     def new_cipher(self, key, mode_iv, digest=None):
    270         """
    271         @param key:     the secret key, a byte string
    272         @param mode_iv: the initialization vector or nonce, a byte string.
    273                         Formatted by `format_mode_iv`.
    274         @param digest:  also known as tag or icv. A byte string containing the
    275                         digest of the encrypted data. Only use this during
    276                         decryption!
    277 
    278         @return:    an initialized cipher object for this algo
    279         """
    280         if self.is_aead and digest is not None:
    281             # With AEAD, the mode needs the digest during decryption.
    282             return Cipher(
    283                 self.cipher(key),
    284                 self.mode(mode_iv, digest, len(digest)),
    285                 default_backend(),
    286             )
    287         else:
    288             return Cipher(
    289                 self.cipher(key),
    290                 self.mode(mode_iv),
    291                 default_backend(),
    292             )
    293 
    294     def pad(self, esp):
    295         """
    296         Add the correct amount of padding so that the data to encrypt is
    297         exactly a multiple of the algorithm's block size.
    298 
    299         Also, make sure that the total ESP packet length is a multiple of 4
    300         bytes.
    301 
    302         @param esp:    an unencrypted _ESPPlain packet
    303 
    304         @return:    an unencrypted _ESPPlain packet with valid padding
    305         """
    306         # 2 extra bytes for padlen and nh
    307         data_len = len(esp.data) + 2
    308 
    309         # according to the RFC4303, section 2.4. Padding (for Encryption)
    310         # the size of the ESP payload must be a multiple of 32 bits
    311         align = _lcm(self.block_size, 4)
    312 
    313         # pad for block size
    314         esp.padlen = -data_len % align
    315 
    316         # Still according to the RFC, the default value for padding *MUST* be an
    317         # array of bytes starting from 1 to padlen
    318         # TODO: Handle padding function according to the encryption algo
    319         esp.padding = struct.pack("B" * esp.padlen, *range(1, esp.padlen + 1))
    320 
    321         # If the following test fails, it means that this algo does not comply
    322         # with the RFC
    323         payload_len = len(esp.iv) + len(esp.data) + len(esp.padding) + 2
    324         if payload_len % 4 != 0:
    325             raise ValueError('The size of the ESP data is not aligned to 32 bits after padding.')
    326 
    327         return esp
    328 
    329     def encrypt(self, sa, esp, key):
    330         """
    331         Encrypt an ESP packet
    332 
    333         @param sa:   the SecurityAssociation associated with the ESP packet.
    334         @param esp:  an unencrypted _ESPPlain packet with valid padding
    335         @param key:  the secret key used for encryption
    336 
    337         @return:    a valid ESP packet encrypted with this algorithm
    338         """
    339         data = esp.data_for_encryption()
    340 
    341         if self.cipher:
    342             mode_iv = self._format_mode_iv(algo=self, sa=sa, iv=esp.iv)
    343             cipher = self.new_cipher(key, mode_iv)
    344             encryptor = cipher.encryptor()
    345 
    346             if self.is_aead:
    347                 aad = struct.pack('!LL', esp.spi, esp.seq)
    348                 encryptor.authenticate_additional_data(aad)
    349                 data = encryptor.update(data) + encryptor.finalize()
    350                 data += encryptor.tag[:self.icv_size]
    351             else:
    352                 data = encryptor.update(data) + encryptor.finalize()
    353 
    354         return ESP(spi=esp.spi, seq=esp.seq, data=esp.iv + data)
    355 
    356     def decrypt(self, sa, esp, key, icv_size=None):
    357         """
    358         Decrypt an ESP packet
    359 
    360         @param sa:         the SecurityAssociation associated with the ESP packet.
    361         @param esp:        an encrypted ESP packet
    362         @param key:        the secret key used for encryption
    363         @param icv_size:   the length of the icv used for integrity check
    364 
    365         @return:    a valid ESP packet encrypted with this algorithm
    366         @raise IPSecIntegrityError: if the integrity check fails with an AEAD
    367                                     algorithm
    368         """
    369         if icv_size is None:
    370             icv_size = self.icv_size if self.is_aead else 0
    371 
    372         iv = esp.data[:self.iv_size]
    373         data = esp.data[self.iv_size:len(esp.data) - icv_size]
    374         icv = esp.data[len(esp.data) - icv_size:]
    375 
    376         if self.cipher:
    377             mode_iv = self._format_mode_iv(sa=sa, iv=iv)
    378             cipher = self.new_cipher(key, mode_iv, icv)
    379             decryptor = cipher.decryptor()
    380 
    381             if self.is_aead:
    382                 # Tag value check is done during the finalize method
    383                 decryptor.authenticate_additional_data(
    384                     struct.pack('!LL', esp.spi, esp.seq)
    385                 )
    386 
    387             try:
    388                 data = decryptor.update(data) + decryptor.finalize()
    389             except InvalidTag as err:
    390                 raise IPSecIntegrityError(err)
    391 
    392         # extract padlen and nh
    393         padlen = orb(data[-2])
    394         nh = orb(data[-1])
    395 
    396         # then use padlen to determine data and padding
    397         data = data[:len(data) - padlen - 2]
    398         padding = data[len(data) - padlen - 2: len(data) - 2]
    399 
    400         return _ESPPlain(spi=esp.spi,
    401                         seq=esp.seq,
    402                         iv=iv,
    403                         data=data,
    404                         padding=padding,
    405                         padlen=padlen,
    406                         nh=nh,
    407                         icv=icv)
    408 
    409 #------------------------------------------------------------------------------
# The names of the encryption algorithms are the same as in scapy.contrib.ikev2
    411 # see http://www.iana.org/assignments/ikev2-parameters/ikev2-parameters.xhtml
    412 
# Registry of the available encryption algorithms, indexed by name.
# 'NULL' (no cipher, no IV) is always present; the other entries are only
# added when the `algorithms` module could be imported.
CRYPT_ALGOS = {
    'NULL': CryptAlgo('NULL', cipher=None, mode=None, iv_size=0),
}

if algorithms:
    CRYPT_ALGOS['AES-CBC'] = CryptAlgo('AES-CBC',
                                       cipher=algorithms.AES,
                                       mode=modes.CBC)
    # Value given to the CTR mode: the SA salt, then the packet IV, then the
    # 4 bytes 00 00 00 01
    _aes_ctr_format_mode_iv = lambda sa, iv, **kw: sa.crypt_salt + iv + b'\x00\x00\x00\x01'
    CRYPT_ALGOS['AES-CTR'] = CryptAlgo('AES-CTR',
                                       cipher=algorithms.AES,
                                       mode=modes.CTR,
                                       iv_size=8,
                                       salt_size=4,
                                       format_mode_iv=_aes_ctr_format_mode_iv)
    # Value given to the GCM / CCM modes: the SA salt followed by the packet IV
    _salt_format_mode_iv = lambda sa, iv, **kw: sa.crypt_salt + iv
    CRYPT_ALGOS['AES-GCM'] = CryptAlgo('AES-GCM',
                                       cipher=algorithms.AES,
                                       mode=modes.GCM,
                                       salt_size=4,
                                       iv_size=8,
                                       icv_size=16,
                                       format_mode_iv=_salt_format_mode_iv)
    # CCM is only registered when the imported `modes` module provides it
    if hasattr(modes, 'CCM'):
        CRYPT_ALGOS['AES-CCM'] = CryptAlgo('AES-CCM',
                                           cipher=algorithms.AES,
                                           mode=modes.CCM,
                                           iv_size=8,
                                           salt_size=3,
                                           icv_size=16,
                                           format_mode_iv=_salt_format_mode_iv)
    # XXX: Flagged as weak by 'cryptography'. Kept for backward compatibility
    CRYPT_ALGOS['Blowfish'] = CryptAlgo('Blowfish',
                                        cipher=algorithms.Blowfish,
                                        mode=modes.CBC)
    # XXX: RFC7321 states that DES *MUST NOT* be implemented.
    # XXX: Keep for backward compatibility?
    # Using a TripleDES cipher algorithm for DES is done by using the same 64
    # bits key 3 times (done by cryptography when given a 64 bits key)
    CRYPT_ALGOS['DES'] = CryptAlgo('DES',
                                   cipher=algorithms.TripleDES,
                                   mode=modes.CBC,
                                   key_size=(8,))
    CRYPT_ALGOS['3DES'] = CryptAlgo('3DES',
                                    cipher=algorithms.TripleDES,
                                    mode=modes.CBC)
    CRYPT_ALGOS['CAST'] = CryptAlgo('CAST',
                                    cipher=algorithms.CAST5,
                                    mode=modes.CBC)
    462 
    463 #------------------------------------------------------------------------------
    464 if conf.crypto_valid:
    465     from cryptography.hazmat.primitives.hmac import HMAC
    466     from cryptography.hazmat.primitives.cmac import CMAC
    467     from cryptography.hazmat.primitives import hashes
    468 else:
    469     # no error if cryptography is not available but authentication won't be supported
    470     HMAC = CMAC = hashes = None
    471 
    472 #------------------------------------------------------------------------------
    473 class IPSecIntegrityError(Exception):
    474     """
    475     Error risen when the integrity check fails.
    476     """
    477     pass
    478 
    479 class AuthAlgo(object):
    480     """
    481     IPsec integrity algorithm
    482     """
    483 
    484     def __init__(self, name, mac, digestmod, icv_size, key_size=None):
    485         """
    486         @param name: the name of this integrity algorithm
    487         @param mac: a Message Authentication Code module
    488         @param digestmod: a Hash or Cipher module
    489         @param icv_size: the length of the integrity check value of this algo
    490         @param key_size: an integer or list/tuple of integers. If specified,
    491                          force the secret keys length to one of the values.
    492                          Defaults to the `key_size` of the cipher.
    493         """
    494         self.name = name
    495         self.mac = mac
    496         self.digestmod = digestmod
    497         self.icv_size = icv_size
    498         self.key_size = key_size
    499 
    500     def check_key(self, key):
    501         """
    502         Check that the key length is valid.
    503 
    504         @param key:    a byte string
    505         """
    506         if self.key_size and len(key) not in self.key_size:
    507             raise TypeError('invalid key size %s, must be one of %s' %
    508                             (len(key), self.key_size))
    509 
    510     @crypto_validator
    511     def new_mac(self, key):
    512         """
    513         @param key:    a byte string
    514         @return:       an initialized mac object for this algo
    515         """
    516         if self.mac is CMAC:
    517             return self.mac(self.digestmod(key), default_backend())
    518         else:
    519             return self.mac(key, self.digestmod(), default_backend())
    520 
    521     def sign(self, pkt, key):
    522         """
    523         Sign an IPsec (ESP or AH) packet with this algo.
    524 
    525         @param pkt:    a packet that contains a valid encrypted ESP or AH layer
    526         @param key:    the authentication key, a byte string
    527 
    528         @return: the signed packet
    529         """
    530         if not self.mac:
    531             return pkt
    532 
    533         mac = self.new_mac(key)
    534 
    535         if pkt.haslayer(ESP):
    536             mac.update(raw(pkt[ESP]))
    537             pkt[ESP].data += mac.finalize()[:self.icv_size]
    538 
    539         elif pkt.haslayer(AH):
    540             clone = zero_mutable_fields(pkt.copy(), sending=True)
    541             mac.update(raw(clone))
    542             pkt[AH].icv = mac.finalize()[:self.icv_size]
    543 
    544         return pkt
    545 
    546     def verify(self, pkt, key):
    547         """
    548         Check that the integrity check value (icv) of a packet is valid.
    549 
    550         @param pkt:    a packet that contains a valid encrypted ESP or AH layer
    551         @param key:    the authentication key, a byte string
    552 
    553         @raise IPSecIntegrityError: if the integrity check fails
    554         """
    555         if not self.mac or self.icv_size == 0:
    556             return
    557 
    558         mac = self.new_mac(key)
    559 
    560         pkt_icv = 'not found'
    561         computed_icv = 'not computed'
    562 
    563         if isinstance(pkt, ESP):
    564             pkt_icv = pkt.data[len(pkt.data) - self.icv_size:]
    565             clone = pkt.copy()
    566             clone.data = clone.data[:len(clone.data) - self.icv_size]
    567 
    568         elif pkt.haslayer(AH):
    569             if len(pkt[AH].icv) != self.icv_size:
    570                 # Fill padding since we know the actual icv_size
    571                 pkt[AH].padding = pkt[AH].icv[self.icv_size:]
    572                 pkt[AH].icv = pkt[AH].icv[:self.icv_size]
    573             pkt_icv = pkt[AH].icv
    574             clone = zero_mutable_fields(pkt.copy(), sending=False)
    575 
    576         mac.update(raw(clone))
    577         computed_icv = mac.finalize()[:self.icv_size]
    578 
    579         # XXX: Cannot use mac.verify because the ICV can be truncated
    580         if pkt_icv != computed_icv:
    581             raise IPSecIntegrityError('pkt_icv=%r, computed_icv=%r' %
    582                                       (pkt_icv, computed_icv))
    583 
    584 #------------------------------------------------------------------------------
# The names of the integrity algorithms are the same as in scapy.contrib.ikev2
    586 # see http://www.iana.org/assignments/ikev2-parameters/ikev2-parameters.xhtml
    587 
# Registry of the available integrity algorithms, indexed by name.
# 'NULL' (no mac, empty ICV) is always present; the HMAC and CMAC entries are
# only added when the corresponding `cryptography` modules could be imported.
# icv_size is the number of digest bytes kept as the ICV.
AUTH_ALGOS = {
    'NULL': AuthAlgo('NULL', mac=None, digestmod=None, icv_size=0),
}

if HMAC and hashes:
    # XXX: NIST has deprecated SHA1 but is required by RFC7321
    AUTH_ALGOS['HMAC-SHA1-96'] = AuthAlgo('HMAC-SHA1-96',
                                          mac=HMAC,
                                          digestmod=hashes.SHA1,
                                          icv_size=12)
    AUTH_ALGOS['SHA2-256-128'] = AuthAlgo('SHA2-256-128',
                                          mac=HMAC,
                                          digestmod=hashes.SHA256,
                                          icv_size=16)
    AUTH_ALGOS['SHA2-384-192'] = AuthAlgo('SHA2-384-192',
                                          mac=HMAC,
                                          digestmod=hashes.SHA384,
                                          icv_size=24)
    AUTH_ALGOS['SHA2-512-256'] = AuthAlgo('SHA2-512-256',
                                          mac=HMAC,
                                          digestmod=hashes.SHA512,
                                          icv_size=32)
    # XXX:Flagged as deprecated by 'cryptography'. Kept for backward compat
    AUTH_ALGOS['HMAC-MD5-96'] = AuthAlgo('HMAC-MD5-96',
                                         mac=HMAC,
                                         digestmod=hashes.MD5,
                                         icv_size=12)
if CMAC and algorithms:
    # the only entry that restricts the key length (16 bytes)
    AUTH_ALGOS['AES-CMAC-96'] = AuthAlgo('AES-CMAC-96',
                                      mac=CMAC,
                                      digestmod=algorithms.AES,
                                      icv_size=12,
                                      key_size=(16,))
    621 
    622 #------------------------------------------------------------------------------
    623 def split_for_transport(orig_pkt, transport_proto):
    624     """
    625     Split an IP(v6) packet in the correct location to insert an ESP or AH
    626     header.
    627 
    628     @param orig_pkt: the packet to split. Must be an IP or IPv6 packet
    629     @param transport_proto: the IPsec protocol number that will be inserted
    630                             at the split position.
    631     @return: a tuple (header, nh, payload) where nh is the protocol number of
    632              payload.
    633     """
    634     # force resolution of default fields to avoid padding errors
    635     header = orig_pkt.__class__(raw(orig_pkt))
    636     next_hdr = header.payload
    637     nh = None
    638 
    639     if header.version == 4:
    640         nh = header.proto
    641         header.proto = transport_proto
    642         header.remove_payload()
    643         del header.chksum
    644         del header.len
    645 
    646         return header, nh, next_hdr
    647     else:
    648         found_rt_hdr = False
    649         prev = header
    650 
    651         # Since the RFC 4302 is vague about where the ESP/AH headers should be
    652         # inserted in IPv6, I chose to follow the linux implementation.
    653         while isinstance(next_hdr, (IPv6ExtHdrHopByHop, IPv6ExtHdrRouting, IPv6ExtHdrDestOpt)):
    654             if isinstance(next_hdr, IPv6ExtHdrHopByHop):
    655                 pass
    656             if isinstance(next_hdr, IPv6ExtHdrRouting):
    657                 found_rt_hdr = True
    658             elif isinstance(next_hdr, IPv6ExtHdrDestOpt) and found_rt_hdr:
    659                 break
    660 
    661             prev = next_hdr
    662             next_hdr = next_hdr.payload
    663 
    664         nh = prev.nh
    665         prev.nh = transport_proto
    666         prev.remove_payload()
    667         del header.plen
    668 
    669         return header, nh, next_hdr
    670 
    671 #------------------------------------------------------------------------------
# see RFC 4302 - Appendix A. Mutability of IP Options/Extension Headers
# Numbers of the IPv4 options that zero_mutable_fields() keeps untouched; any
# other option is replaced by zero bytes of the same length.
IMMUTABLE_IPV4_OPTIONS = (
    0, # End Of List
    1, # No OPeration
    2, # Security
    5, # Extended Security
    6, # Commercial Security
    20, # Router Alert
    21, # Sender Directed Multi-Destination Delivery
)
    682 def zero_mutable_fields(pkt, sending=False):
    683     """
    684     When using AH, all "mutable" fields must be "zeroed" before calculating
    685     the ICV. See RFC 4302, Section 3.3.3.1. Handling Mutable Fields.
    686 
    687     @param pkt: an IP(v6) packet containing an AH layer.
    688                 NOTE: The packet will be modified
    689     @param sending: if true, ipv6 routing headers will not be reordered
    690     """
    691 
    692     if pkt.haslayer(AH):
    693         pkt[AH].icv = b"\x00" * len(pkt[AH].icv)
    694     else:
    695         raise TypeError('no AH layer found')
    696 
    697     if pkt.version == 4:
    698         # the tos field has been replaced by DSCP and ECN
    699         # Routers may rewrite the DS field as needed to provide a
    700         # desired local or end-to-end service
    701         pkt.tos = 0
    702         # an intermediate router might set the DF bit, even if the source
    703         # did not select it.
    704         pkt.flags = 0
    705         # changed en route as a normal course of processing by routers
    706         pkt.ttl = 0
    707         # will change if any of these other fields change
    708         pkt.chksum = 0
    709 
    710         immutable_opts = []
    711         for opt in pkt.options:
    712             if opt.option in IMMUTABLE_IPV4_OPTIONS:
    713                 immutable_opts.append(opt)
    714             else:
    715                 immutable_opts.append(Raw(b"\x00" * len(opt)))
    716         pkt.options = immutable_opts
    717 
    718     else:
    719         # holds DSCP and ECN
    720         pkt.tc = 0
    721         # The flow label described in AHv1 was mutable, and in RFC 2460 [DH98]
    722         # was potentially mutable. To retain compatibility with existing AH
    723         # implementations, the flow label is not included in the ICV in AHv2.
    724         pkt.fl = 0
    725         # same as ttl
    726         pkt.hlim = 0
    727 
    728         next_hdr = pkt.payload
    729 
    730         while isinstance(next_hdr, (IPv6ExtHdrHopByHop, IPv6ExtHdrRouting, IPv6ExtHdrDestOpt)):
    731             if isinstance(next_hdr, (IPv6ExtHdrHopByHop, IPv6ExtHdrDestOpt)):
    732                 for opt in next_hdr.options:
    733                     if opt.otype & 0x20:
    734                         # option data can change en-route and must be zeroed
    735                         opt.optdata = b"\x00" * opt.optlen
    736             elif isinstance(next_hdr, IPv6ExtHdrRouting) and sending:
    737                 # The sender must order the field so that it appears as it
    738                 # will at the receiver, prior to performing the ICV computation.
    739                 next_hdr.segleft = 0
    740                 if next_hdr.addresses:
    741                     final = next_hdr.addresses.pop()
    742                     next_hdr.addresses.insert(0, pkt.dst)
    743                     pkt.dst = final
    744             else:
    745                 break
    746 
    747             next_hdr = next_hdr.payload
    748 
    749     return pkt
    750 
    751 #------------------------------------------------------------------------------
class SecurityAssociation(object):
    """
    This class is responsible for "encryption" and "decryption" of IPsec
    packets.
    """

    # packet classes accepted by encrypt() and decrypt()
    SUPPORTED_PROTOS = (IP, IPv6)
    758 
    def __init__(self, proto, spi, seq_num=1, crypt_algo=None, crypt_key=None,
                 auth_algo=None, auth_key=None, tunnel_header=None, nat_t_header=None):
        """
        @param proto: the IPsec proto to use (ESP or AH), given either as the
                      layer class or as its name
        @param spi: the Security Parameters Index of this SA
        @param seq_num: the initial value for the sequence number on encrypted
                        packets
        @param crypt_algo: the encryption algorithm name (only used with ESP)
        @param crypt_key: the encryption key (only used with ESP). The last
                          salt_size bytes (as declared by the algorithm) are
                          split off and stored as the salt.
        @param auth_algo: the integrity algorithm name
        @param auth_key: the integrity key
        @param tunnel_header: an instance of a IP(v6) header that will be used
                              to encapsulate the encrypted packets.
        @param nat_t_header: an instance of a UDP header that will be used
                             for NAT-Traversal.

        @raise ValueError: if proto is neither ESP nor AH
        @raise TypeError: if an algorithm name is unknown, if tunnel_header is
                          not an IP(v6) instance, or if nat_t_header is not a
                          UDP instance or is used with something else than
                          ESP
        """

        if proto not in (ESP, AH, ESP.name, AH.name):
            raise ValueError("proto must be either ESP or AH")
        if isinstance(proto, six.string_types):
            # proto was validated above: resolve the name with an explicit
            # table instead of eval()
            self.proto = {ESP.name: ESP, AH.name: AH}[proto]
        else:
            self.proto = proto

        self.spi = spi
        self.seq_num = seq_num

        if crypt_algo:
            if crypt_algo not in CRYPT_ALGOS:
                raise TypeError('unsupported encryption algo %r, try %r' %
                                (crypt_algo, list(CRYPT_ALGOS)))
            self.crypt_algo = CRYPT_ALGOS[crypt_algo]

            if crypt_key:
                salt_size = self.crypt_algo.salt_size
                # slice from "len - salt_size" rather than "-salt_size" so
                # that a salt_size of 0 keeps the whole key and yields an
                # empty salt
                split_at = len(crypt_key) - salt_size
                self.crypt_key = crypt_key[:split_at]
                self.crypt_salt = crypt_key[split_at:]
            else:
                self.crypt_key = None
                self.crypt_salt = None

        else:
            self.crypt_algo = CRYPT_ALGOS['NULL']
            self.crypt_key = None
            # always define the attribute, as the branch above does
            self.crypt_salt = None

        if auth_algo:
            if auth_algo not in AUTH_ALGOS:
                raise TypeError('unsupported integrity algo %r, try %r' %
                                (auth_algo, list(AUTH_ALGOS)))
            self.auth_algo = AUTH_ALGOS[auth_algo]
            self.auth_key = auth_key
        else:
            self.auth_algo = AUTH_ALGOS['NULL']
            self.auth_key = None

        if tunnel_header and not isinstance(tunnel_header, (IP, IPv6)):
            raise TypeError('tunnel_header must be %s or %s' % (IP.name, IPv6.name))
        self.tunnel_header = tunnel_header

        if nat_t_header:
            # compare the resolved class: the proto argument may be a name,
            # in which case an identity test against ESP would always fail
            if self.proto is not ESP:
                raise TypeError('nat_t_header is only allowed with ESP')
            if not isinstance(nat_t_header, UDP):
                raise TypeError('nat_t_header must be %s' % UDP.name)
        self.nat_t_header = nat_t_header
    824 
    825     def check_spi(self, pkt):
    826         if pkt.spi != self.spi:
    827             raise TypeError('packet spi=0x%x does not match the SA spi=0x%x' %
    828                             (pkt.spi, self.spi))
    829 
    def _encrypt_esp(self, pkt, seq_num=None, iv=None):
        """
        Encapsulate an IP(v6) packet in ESP according to this SA.

        @param pkt:     the IP(v6) packet to protect
        @param seq_num: if not None, use this sequence number (0 included)
                        and leave self.seq_num unchanged
        @param iv:      if not None, use this initialization vector; its
                        length must be self.crypt_algo.iv_size

        @return: the IP(v6) header, followed by the NAT-T header when one is
                 configured, followed by the ESP layer
        @raise TypeError: if iv does not have the expected length
        """

        if iv is None:
            iv = self.crypt_algo.generate_iv()
        elif len(iv) != self.crypt_algo.iv_size:
            raise TypeError('iv length must be %s' % self.crypt_algo.iv_size)

        # test against None (not truthiness) so that an explicit seq_num of 0
        # is honoured instead of silently falling back to self.seq_num
        seq = self.seq_num if seq_num is None else seq_num
        esp = _ESPPlain(spi=self.spi, seq=seq, iv=iv)

        if self.tunnel_header:
            tunnel = self.tunnel_header.copy()

            # reset the payload-dependent fields (presumably so that they are
            # recomputed when the encapsulated packet is rebuilt below)
            if tunnel.version == 4:
                del tunnel.proto
                del tunnel.len
                del tunnel.chksum
            else:
                del tunnel.nh
                del tunnel.plen

            pkt = tunnel.__class__(raw(tunnel / pkt))

        ip_header, nh, payload = split_for_transport(pkt, socket.IPPROTO_ESP)
        esp.data = payload
        esp.nh = nh

        esp = self.crypt_algo.pad(esp)
        esp = self.crypt_algo.encrypt(self, esp, self.crypt_key)

        # NOTE(review): the return value is ignored; sign() presumably
        # updates esp in place - confirm against the auth algo implementation
        self.auth_algo.sign(esp, self.auth_key)

        if self.nat_t_header:
            nat_t_header = self.nat_t_header.copy()
            nat_t_header.chksum = 0
            del nat_t_header.len
            if ip_header.version == 4:
                del ip_header.proto
            else:
                del ip_header.nh
            ip_header /= nat_t_header

        if ip_header.version == 4:
            # the length is set explicitly because esp is only appended after
            # the header has been rebuilt
            ip_header.len = len(ip_header) + len(esp)
            del ip_header.chksum
            ip_header = ip_header.__class__(raw(ip_header))
        else:
            ip_header.plen = len(ip_header.payload) + len(esp)

        # sequence number must always change, unless specified by the user
        if seq_num is None:
            self.seq_num += 1

        return ip_header / esp
    884 
    def _encrypt_ah(self, pkt, seq_num=None):
        """
        Insert and sign an AH layer in an IP(v6) packet according to this SA.

        @param pkt:     the IP(v6) packet to protect
        @param seq_num: if not None, use this sequence number (0 included)
                        and leave self.seq_num unchanged

        @return: the value returned by the integrity algorithm's sign() for
                 the assembled header / AH / payload packet
        """

        # test against None (not truthiness) so that an explicit seq_num of 0
        # is honoured instead of silently falling back to self.seq_num
        seq = self.seq_num if seq_num is None else seq_num
        ah = AH(spi=self.spi, seq=seq,
                icv=b"\x00" * self.auth_algo.icv_size)

        if self.tunnel_header:
            tunnel = self.tunnel_header.copy()

            # reset the payload-dependent fields (presumably so that they are
            # recomputed when the encapsulated packet is rebuilt below)
            if tunnel.version == 4:
                del tunnel.proto
                del tunnel.len
                del tunnel.chksum
            else:
                del tunnel.nh
                del tunnel.plen

            pkt = tunnel.__class__(raw(tunnel / pkt))

        ip_header, nh, payload = split_for_transport(pkt, socket.IPPROTO_AH)
        ah.nh = nh

        if ip_header.version == 6 and len(ah) % 8 != 0:
            # For IPv6, the total length of the header must be a multiple of
            # 8-octet units.
            ah.padding = b"\x00" * (-len(ah) % 8)
        elif len(ah) % 4 != 0:
            # For IPv4, the total length of the header must be a multiple of
            # 4-octet units.
            ah.padding = b"\x00" * (-len(ah) % 4)

        # RFC 4302 - Section 2.2. Payload Length
        # This 8-bit field specifies the length of AH in 32-bit words (4-byte
        # units), minus "2".
        ah.payloadlen = len(ah) // 4 - 2

        if ip_header.version == 4:
            # the length is set explicitly because ah and payload are only
            # appended after the header has been rebuilt
            ip_header.len = len(ip_header) + len(ah) + len(payload)
            del ip_header.chksum
            ip_header = ip_header.__class__(raw(ip_header))
        else:
            ip_header.plen = len(ip_header.payload) + len(ah) + len(payload)

        signed_pkt = self.auth_algo.sign(ip_header / ah / payload, self.auth_key)

        # sequence number must always change, unless specified by the user
        if seq_num is None:
            self.seq_num += 1

        return signed_pkt
    934 
    935     def encrypt(self, pkt, seq_num=None, iv=None):
    936         """
    937         Encrypt (and encapsulate) an IP(v6) packet with ESP or AH according
    938         to this SecurityAssociation.
    939 
    940         @param pkt:     the packet to encrypt
    941         @param seq_num: if specified, use this sequence number instead of the
    942                         generated one
    943         @param iv:      if specified, use this initialization vector for
    944                         encryption instead of a random one.
    945 
    946         @return: the encrypted/encapsulated packet
    947         """
    948         if not isinstance(pkt, self.SUPPORTED_PROTOS):
    949             raise TypeError('cannot encrypt %s, supported protos are %s'
    950                             % (pkt.__class__, self.SUPPORTED_PROTOS))
    951         if self.proto is ESP:
    952             return self._encrypt_esp(pkt, seq_num=seq_num, iv=iv)
    953         else:
    954             return self._encrypt_ah(pkt, seq_num=seq_num)
    955 
    def _decrypt_esp(self, pkt, verify=True):
        """
        Decrypt the ESP layer of a packet and rebuild the clear packet.

        @param pkt:    an IP(v6) packet containing an ESP layer.
                       NOTE: the packet is modified
        @param verify: if true, check the SPI and run the integrity
                       verification before decrypting

        @return: when a tunnel header is configured, the decapsulated inner
                 packet; otherwise the IP(v6) header followed by the
                 decrypted payload
        """

        esp_layer = pkt[ESP]

        if verify:
            self.check_spi(pkt)
            self.auth_algo.verify(esp_layer, self.auth_key)

        # the ICV size of the cipher takes precedence when it is non-zero
        icv_size = self.crypt_algo.icv_size or self.auth_algo.icv_size
        plain = self.crypt_algo.decrypt(self, esp_layer, self.crypt_key,
                                        icv_size)

        if self.tunnel_header:
            # tunnel mode: the outer header is only used to guess the class
            # of the inner packet, which is returned on its own
            pkt.remove_payload()
            if pkt.version == 4:
                pkt.proto = plain.nh
            else:
                pkt.nh = plain.nh
            inner_cls = pkt.guess_payload_class(plain.data)

            return inner_cls(plain.data)

        # transport mode: restore the original header fields
        outer = pkt

        if outer.version == 4:
            outer.proto = plain.nh
            del outer.chksum
            outer.remove_payload()
            outer.len = len(outer) + len(plain.data)
            # rebuild the header so that the checksum is recomputed
            outer = outer.__class__(raw(outer))
        else:
            # the header right before ESP carries the next-header value
            before_esp = esp_layer.underlayer
            before_esp.nh = plain.nh
            before_esp.remove_payload()
            outer.plen = len(outer.payload) + len(plain.data)

        payload_cls = outer.guess_payload_class(plain.data)

        # put the decrypted payload back behind the header
        return outer / payload_cls(plain.data)
    998 
    def _decrypt_ah(self, pkt, verify=True):
        """
        Strip the AH layer from a packet.

        @param pkt:    an IP(v6) packet containing an AH layer.
                       NOTE: the packet is modified
        @param verify: if true, check the SPI and run the integrity
                       verification first

        @return: when a tunnel header is configured, the payload that
                 followed AH; otherwise the IP(v6) header followed by that
                 payload
        """

        if verify:
            self.check_spi(pkt)
            self.auth_algo.verify(pkt, self.auth_key)

        ah_layer = pkt[AH]
        inner = ah_layer.payload
        # detach the payload from the AH layer (the argument is unused)
        inner.remove_underlayer(None)

        if self.tunnel_header:
            return inner

        # transport mode: restore the original header fields
        outer = pkt

        if outer.version == 4:
            outer.proto = ah_layer.nh
            del outer.chksum
            outer.remove_payload()
            outer.len = len(outer) + len(inner)
            # rebuild the header so that the checksum is recomputed
            outer = outer.__class__(raw(outer))
        else:
            # the header right before AH carries the next-header value
            before_ah = ah_layer.underlayer
            before_ah.nh = ah_layer.nh
            before_ah.remove_payload()
            outer.plen = len(outer.payload) + len(inner)

        # put the payload back behind the header
        return outer / inner
   1028 
   1029     def decrypt(self, pkt, verify=True):
   1030         """
   1031         Decrypt (and decapsulate) an IP(v6) packet containing ESP or AH.
   1032 
   1033         @param pkt:     the packet to decrypt
   1034         @param verify:  if False, do not perform the integrity check
   1035 
   1036         @return: the decrypted/decapsulated packet
   1037         @raise IPSecIntegrityError: if the integrity check fails
   1038         """
   1039         if not isinstance(pkt, self.SUPPORTED_PROTOS):
   1040             raise TypeError('cannot decrypt %s, supported protos are %s'
   1041                             % (pkt.__class__, self.SUPPORTED_PROTOS))
   1042 
   1043         if self.proto is ESP and pkt.haslayer(ESP):
   1044             return self._decrypt_esp(pkt, verify=verify)
   1045         elif self.proto is AH and pkt.haslayer(AH):
   1046             return self._decrypt_ah(pkt, verify=verify)
   1047         else:
   1048             raise TypeError('%s has no %s layer' % (pkt, self.proto.name))
   1049