Home | History | Annotate | Download | only in tls
      1 ## This file is part of Scapy
      2 ## Copyright (C) 2017 Maxence Tury
      3 ## This program is published under a GPLv2 license
      4 
      5 """
      6 TLS 1.3 key exchange logic.
      7 """
      8 
      9 import math
     10 
     11 from scapy.config import conf, crypto_validator
     12 from scapy.error import log_runtime, warning
     13 from scapy.fields import *
     14 from scapy.packet import Packet, Raw, Padding
     15 from scapy.layers.tls.cert import PubKeyRSA, PrivKeyRSA
     16 from scapy.layers.tls.session import _GenericTLSSessionInheritance
     17 from scapy.layers.tls.basefields import _tls_version, _TLSClientVersionField
     18 from scapy.layers.tls.extensions import TLS_Ext_Unknown, _tls_ext
     19 from scapy.layers.tls.crypto.pkcs1 import pkcs_i2osp, pkcs_os2ip
     20 from scapy.layers.tls.crypto.groups import (_tls_named_ffdh_groups,
     21                                             _tls_named_curves, _ffdh_groups,
     22                                             _tls_named_groups)
     23 
     24 if conf.crypto_valid:
     25     from cryptography.hazmat.backends import default_backend
     26     from cryptography.hazmat.primitives.asymmetric import dh, ec
     27 if conf.crypto_valid_advanced:
     28     from cryptography.hazmat.primitives.asymmetric import x25519
     29 
     30 
     31 class KeyShareEntry(Packet):
     32     """
     33     When building from scratch, we create a DH private key, and when
     34     dissecting, we create a DH public key. Default group is secp256r1.
     35     """
     36     __slots__ = ["privkey", "pubkey"]
     37     name = "Key Share Entry"
     38     fields_desc = [ShortEnumField("group", None, _tls_named_groups),
     39                    FieldLenField("kxlen", None, length_of="key_exchange"),
     40                    StrLenField("key_exchange", "",
     41                                length_from=lambda pkt: pkt.kxlen) ]
     42 
     43     def __init__(self, *args, **kargs):
     44         self.privkey = None
     45         self.pubkey = None
     46         super(KeyShareEntry, self).__init__(*args, **kargs)
     47 
     48     def do_build(self):
     49         """
     50         We need this hack, else 'self' would be replaced by __iter__.next().
     51         """
     52         tmp = self.explicit
     53         self.explicit = True
     54         b = super(KeyShareEntry, self).do_build()
     55         self.explicit = tmp
     56         return b
     57 
     58     @crypto_validator
     59     def create_privkey(self):
     60         """
     61         This is called by post_build() for key creation.
     62         """
     63         if self.group in _tls_named_ffdh_groups:
     64             params = _ffdh_groups[_tls_named_ffdh_groups[self.group]][0]
     65             privkey = params.generate_private_key()
     66             self.privkey = privkey
     67             pubkey = privkey.public_key()
     68             self.key_exchange = pubkey.public_numbers().y
     69         elif self.group in _tls_named_curves:
     70             if _tls_named_curves[self.group] == "x25519":
     71                 if conf.crypto_valid_advanced:
     72                     privkey = x25519.X25519PrivateKey.generate()
     73                     self.privkey = privkey
     74                     pubkey = privkey.public_key()
     75                     self.key_exchange = pubkey.public_bytes()
     76             elif _tls_named_curves[self.group] != "x448":
     77                 curve = ec._CURVE_TYPES[_tls_named_curves[self.group]]()
     78                 privkey = ec.generate_private_key(curve, default_backend())
     79                 self.privkey = privkey
     80                 pubkey = privkey.public_key()
     81                 self.key_exchange = pubkey.public_numbers().encode_point()
     82 
     83     def post_build(self, pkt, pay):
     84         if self.group is None:
     85             self.group = 23     # secp256r1
     86 
     87         if not self.key_exchange:
     88             try:
     89                 self.create_privkey()
     90             except ImportError:
     91                 pass
     92 
     93         if self.kxlen is None:
     94             self.kxlen = len(self.key_exchange)
     95 
     96         group = struct.pack("!H", self.group)
     97         kxlen = struct.pack("!H", self.kxlen)
     98         return group + kxlen + self.key_exchange + pay
     99 
    100     @crypto_validator
    101     def register_pubkey(self):
    102         if self.group in _tls_named_ffdh_groups:
    103             params = _ffdh_groups[_tls_named_ffdh_groups[self.group]][0]
    104             pn = params.parameter_numbers()
    105             public_numbers = dh.DHPublicNumbers(self.key_exchange, pn)
    106             self.pubkey = public_numbers.public_key(default_backend())
    107         elif self.group in _tls_named_curves:
    108             if _tls_named_curves[self.group] == "x25519":
    109                 if conf.crypto_valid_advanced:
    110                     import_point = x25519.X25519PublicKey.from_public_bytes
    111                     self.pubkey = import_point(self.key_exchange)
    112             elif _tls_named_curves[self.group] != "x448":
    113                 curve = ec._CURVE_TYPES[_tls_named_curves[self.group]]()
    114                 import_point = ec.EllipticCurvePublicNumbers.from_encoded_point
    115                 public_numbers = import_point(curve, self.key_exchange)
    116                 self.pubkey = public_numbers.public_key(default_backend())
    117 
    118     def post_dissection(self, r):
    119         try:
    120             self.register_pubkey()
    121         except ImportError:
    122             pass
    123 
    124     def extract_padding(self, s):
    125         return "", s
    126 
    127 
    128 class TLS_Ext_KeyShare_CH(TLS_Ext_Unknown):
    129     name = "TLS Extension - Key Share (for ClientHello)"
    130     fields_desc = [ShortEnumField("type", 0x28, _tls_ext),
    131                    ShortField("len", None),
    132                    FieldLenField("client_shares_len", None,
    133                                  length_of="client_shares"),
    134                    PacketListField("client_shares", [], KeyShareEntry,
    135                             length_from=lambda pkt: pkt.client_shares_len) ]
    136 
    137     def post_build(self, pkt, pay):
    138         if not self.tls_session.frozen:
    139             privshares = self.tls_session.tls13_client_privshares
    140             for kse in self.client_shares:
    141                 if kse.privkey:
    142                     if _tls_named_curves[kse.group] in privshares:
    143                         pkt_info = pkt.firstlayer().summary()
    144                         log_runtime.info("TLS: group %s used twice in the same ClientHello [%s]", kse.group, pkt_info)
    145                         break
    146                     privshares[_tls_named_groups[kse.group]] = kse.privkey
    147         return super(TLS_Ext_KeyShare_CH, self).post_build(pkt, pay)
    148 
    149     def post_dissection(self, r):
    150         if not self.tls_session.frozen:
    151             for kse in self.client_shares:
    152                 if kse.pubkey:
    153                     pubshares = self.tls_session.tls13_client_pubshares
    154                     if _tls_named_curves[kse.group] in pubshares:
    155                         pkt_info = r.firstlayer().summary()
    156                         log_runtime.info("TLS: group %s used twice in the same ClientHello [%s]", kse.group, pkt_info)
    157                         break
    158                     pubshares[_tls_named_curves[kse.group]] = kse.pubkey
    159         return super(TLS_Ext_KeyShare_CH, self).post_dissection(r)
    160 
    161 
    162 class TLS_Ext_KeyShare_HRR(TLS_Ext_Unknown):
    163     name = "TLS Extension - Key Share (for HelloRetryRequest)"
    164     fields_desc = [ShortEnumField("type", 0x28, _tls_ext),
    165                    ShortField("len", None),
    166                    ShortEnumField("selected_group", None, _tls_named_groups) ]
    167 
    168 
    169 class TLS_Ext_KeyShare_SH(TLS_Ext_Unknown):
    170     name = "TLS Extension - Key Share (for ServerHello)"
    171     fields_desc = [ShortEnumField("type", 0x28, _tls_ext),
    172                    ShortField("len", None),
    173                    PacketField("server_share", None, KeyShareEntry) ]
    174 
    175     def post_build(self, pkt, pay):
    176         if not self.tls_session.frozen and self.server_share.privkey:
    177             # if there is a privkey, we assume the crypto library is ok
    178             privshare = self.tls_session.tls13_server_privshare
    179             if len(privshare) > 0:
    180                 pkt_info = pkt.firstlayer().summary()
    181                 log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)
    182             group_name = _tls_named_groups[self.server_share.group]
    183             privshare[group_name] = self.server_share.privkey
    184 
    185             if group_name in self.tls_session.tls13_client_pubshares:
    186                 privkey = self.server_share.privkey
    187                 pubkey = self.tls_session.tls13_client_pubshares[group_name]
    188                 if group_name in six.itervalues(_tls_named_ffdh_groups):
    189                     pms = privkey.exchange(pubkey)
    190                 elif group_name in six.itervalues(_tls_named_curves):
    191                     if group_name == "x25519":
    192                         pms = privkey.exchange(pubkey)
    193                     else:
    194                         pms = privkey.exchange(ec.ECDH(), pubkey)
    195                 self.tls_session.tls13_dhe_secret = pms
    196         return super(TLS_Ext_KeyShare_SH, self).post_build(pkt, pay)
    197 
    198     def post_dissection(self, r):
    199         if not self.tls_session.frozen and self.server_share.pubkey:
    200             # if there is a pubkey, we assume the crypto library is ok
    201             pubshare = self.tls_session.tls13_server_pubshare
    202             if len(pubshare) > 0:
    203                 pkt_info = r.firstlayer().summary()
    204                 log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)
    205             group_name = _tls_named_groups[self.server_share.group]
    206             pubshare[group_name] = self.server_share.pubkey
    207 
    208             if group_name in self.tls_session.tls13_client_privshares:
    209                 pubkey = self.server_share.pubkey
    210                 privkey = self.tls_session.tls13_client_privshares[group_name]
    211                 if group_name in six.itervalues(_tls_named_ffdh_groups):
    212                     pms = privkey.exchange(pubkey)
    213                 elif group_name in six.itervalues(_tls_named_curves):
    214                     if group_name == "x25519":
    215                         pms = privkey.exchange(pubkey)
    216                     else:
    217                         pms = privkey.exchange(ec.ECDH(), pubkey)
    218                 self.tls_session.tls13_dhe_secret = pms
    219         return super(TLS_Ext_KeyShare_SH, self).post_dissection(r)
    220 
    221 
    222 _tls_ext_keyshare_cls  = { 1: TLS_Ext_KeyShare_CH,
    223                            2: TLS_Ext_KeyShare_SH,
    224                            6: TLS_Ext_KeyShare_HRR }
    225 
    226 
    227 class Ticket(Packet):
    228     name = "Recommended Ticket Construction (from RFC 5077)"
    229     fields_desc = [ StrFixedLenField("key_name", None, 16),
    230                     StrFixedLenField("iv", None, 16),
    231                     FieldLenField("encstatelen", None, length_of="encstate"),
    232                     StrLenField("encstate", "",
    233                                 length_from=lambda pkt: pkt.encstatelen),
    234                     StrFixedLenField("mac", None, 32) ]
    235 
    236 class TicketField(PacketField):
    237     __slots__ = ["length_from"]
    238     def __init__(self, name, default, length_from=None, **kargs):
    239         self.length_from = length_from
    240         PacketField.__init__(self, name, default, Ticket, **kargs)
    241 
    242     def m2i(self, pkt, m):
    243         l = self.length_from(pkt)
    244         tbd, rem = m[:l], m[l:]
    245         return self.cls(tbd)/Padding(rem)
    246 
    247 class PSKIdentity(Packet):
    248     name = "PSK Identity"
    249     fields_desc = [FieldLenField("identity_len", None,
    250                                  length_of="identity"),
    251                    TicketField("identity", "",
    252                                length_from=lambda pkt: pkt.identity_len),
    253                    IntField("obfuscated_ticket_age", 0) ]
    254 
    255 class PSKBinderEntry(Packet):
    256     name = "PSK Binder Entry"
    257     fields_desc = [FieldLenField("binder_len", None, fmt="B",
    258                                  length_of="binder"),
    259                    StrLenField("binder", "",
    260                                length_from=lambda pkt: pkt.binder_len) ]
    261 
    262 class TLS_Ext_PreSharedKey_CH(TLS_Ext_Unknown):
    263     #XXX define post_build and post_dissection methods
    264     name = "TLS Extension - Pre Shared Key (for ClientHello)"
    265     fields_desc = [ShortEnumField("type", 0x28, _tls_ext),
    266                    ShortField("len", None),
    267                    FieldLenField("identities_len", None,
    268                                  length_of="identities"),
    269                    PacketListField("identities", [], PSKIdentity,
    270                             length_from=lambda pkt: pkt.identities_len),
    271                    FieldLenField("binders_len", None,
    272                                  length_of="binders"),
    273                    PacketListField("binders", [], PSKBinderEntry,
    274                             length_from=lambda pkt: pkt.binders_len) ]
    275 
    276 
    277 class TLS_Ext_PreSharedKey_SH(TLS_Ext_Unknown):
    278     name = "TLS Extension - Pre Shared Key (for ServerHello)"
    279     fields_desc = [ShortEnumField("type", 0x29, _tls_ext),
    280                    ShortField("len", None),
    281                    ShortField("selected_identity", None) ]
    282 
    283 
    284 _tls_ext_presharedkey_cls  = { 1: TLS_Ext_PreSharedKey_CH,
    285                                2: TLS_Ext_PreSharedKey_SH }
    286 
    287