Home | History | Annotate | Download | only in patches
      1 diff --git a/third_party/tlslite/tlslite/TLSConnection.py b/third_party/tlslite/tlslite/TLSConnection.py
      2 index f8811a9..e882e2c 100644
      3 --- a/third_party/tlslite/tlslite/TLSConnection.py
      4 +++ b/third_party/tlslite/tlslite/TLSConnection.py
      5 @@ -611,6 +611,8 @@ class TLSConnection(TLSRecordLayer):
      6                                     settings.cipherImplementations)
      7  
      8              #Exchange ChangeCipherSpec and Finished messages
      9 +            for result in self._getChangeCipherSpec():
     10 +                yield result
     11              for result in self._getFinished():
     12                  yield result
     13              for result in self._sendFinished():
     14 @@ -920,6 +922,8 @@ class TLSConnection(TLSRecordLayer):
     15              #Exchange ChangeCipherSpec and Finished messages
     16              for result in self._sendFinished():
     17                  yield result
     18 +            for result in self._getChangeCipherSpec():
     19 +                yield result
     20              for result in self._getFinished():
     21                  yield result
     22  
     23 @@ -1089,6 +1093,7 @@ class TLSConnection(TLSRecordLayer):
     24          clientCertChain = None
     25          serverCertChain = None #We may set certChain to this later
     26          postFinishedError = None
     27 +        doingChannelID = False
     28  
     29          #Tentatively set version to most-desirable version, so if an error
     30          #occurs parsing the ClientHello, this is what we'll use for the
     31 @@ -1208,6 +1213,8 @@ class TLSConnection(TLSRecordLayer):
     32                  serverHello.create(self.version, serverRandom,
     33                                     session.sessionID, session.cipherSuite,
     34                                     certificateType)
     35 +                serverHello.channel_id = clientHello.channel_id
     36 +                doingChannelID = clientHello.channel_id
     37                  for result in self._sendMsg(serverHello):
     38                      yield result
     39  
     40 @@ -1221,6 +1228,11 @@ class TLSConnection(TLSRecordLayer):
     41                  #Exchange ChangeCipherSpec and Finished messages
     42                  for result in self._sendFinished():
     43                      yield result
     44 +                for result in self._getChangeCipherSpec():
     45 +                    yield result
     46 +                if doingChannelID:
     47 +                    for result in self._getEncryptedExtensions():
     48 +                        yield result
     49                  for result in self._getFinished():
     50                      yield result
     51  
     52 @@ -1399,8 +1411,12 @@ class TLSConnection(TLSRecordLayer):
     53              #Send ServerHello, Certificate[, CertificateRequest],
     54              #ServerHelloDone
     55              msgs = []
     56 -            msgs.append(ServerHello().create(self.version, serverRandom,
     57 -                        sessionID, cipherSuite, certificateType))
     58 +            serverHello = ServerHello().create(
     59 +                    self.version, serverRandom,
     60 +                    sessionID, cipherSuite, certificateType)
     61 +            serverHello.channel_id = clientHello.channel_id
     62 +            doingChannelID = clientHello.channel_id
     63 +            msgs.append(serverHello)
     64              msgs.append(Certificate(certificateType).create(serverCertChain))
     65              if reqCert and reqCAs:
     66                  msgs.append(CertificateRequest().create([], reqCAs))
     67 @@ -1528,6 +1544,11 @@ class TLSConnection(TLSRecordLayer):
     68                                 settings.cipherImplementations)
     69  
     70          #Exchange ChangeCipherSpec and Finished messages
     71 +        for result in self._getChangeCipherSpec():
     72 +            yield result
     73 +        if doingChannelID:
     74 +            for result in self._getEncryptedExtensions():
     75 +                yield result
     76          for result in self._getFinished():
     77              yield result
     78  
     79 diff --git a/third_party/tlslite/tlslite/TLSRecordLayer.py b/third_party/tlslite/tlslite/TLSRecordLayer.py
     80 index 1bbd09d..933b95a 100644
     81 --- a/third_party/tlslite/tlslite/TLSRecordLayer.py
     82 +++ b/third_party/tlslite/tlslite/TLSRecordLayer.py
     83 @@ -714,6 +714,8 @@ class TLSRecordLayer:
     84                                              self.version).parse(p)
     85                  elif subType == HandshakeType.finished:
     86                      yield Finished(self.version).parse(p)
     87 +                elif subType == HandshakeType.encrypted_extensions:
     88 +                    yield EncryptedExtensions().parse(p)
     89                  else:
     90                      raise AssertionError()
     91  
     92 @@ -1067,7 +1069,7 @@ class TLSRecordLayer:
     93          for result in self._sendMsg(finished):
     94              yield result
     95  
     96 -    def _getFinished(self):
     97 +    def _getChangeCipherSpec(self):
     98          #Get and check ChangeCipherSpec
     99          for result in self._getMsg(ContentType.change_cipher_spec):
    100              if result in (0,1):
    101 @@ -1082,6 +1084,15 @@ class TLSRecordLayer:
    102          #Switch to pending read state
    103          self._changeReadState()
    104  
    105 +    def _getEncryptedExtensions(self):
    106 +        for result in self._getMsg(ContentType.handshake,
    107 +                                   HandshakeType.encrypted_extensions):
    108 +            if result in (0,1):
    109 +                yield result
    110 +        encrypted_extensions = result
    111 +        self.channel_id = encrypted_extensions.channel_id_key
    112 +
    113 +    def _getFinished(self):
    114          #Calculate verification data
    115          verifyData = self._calcFinished(False)
    116  
    117 diff --git a/third_party/tlslite/tlslite/constants.py b/third_party/tlslite/tlslite/constants.py
    118 index 04302c0..e357dd0 100644
    119 --- a/third_party/tlslite/tlslite/constants.py
    120 +++ b/third_party/tlslite/tlslite/constants.py
    121 @@ -22,6 +22,7 @@ class HandshakeType:
    122      certificate_verify = 15
    123      client_key_exchange = 16
    124      finished = 20
    125 +    encrypted_extensions = 203
    126  
    127  class ContentType:
    128      change_cipher_spec = 20
    129 @@ -30,6 +31,9 @@ class ContentType:
    130      application_data = 23
    131      all = (20,21,22,23)
    132  
    133 +class ExtensionType:
    134 +    channel_id = 30031
    135 +
    136  class AlertLevel:
    137      warning = 1
    138      fatal = 2
    139 diff --git a/third_party/tlslite/tlslite/messages.py b/third_party/tlslite/tlslite/messages.py
    140 index dc6ed32..fa4d817 100644
    141 --- a/third_party/tlslite/tlslite/messages.py
    142 +++ b/third_party/tlslite/tlslite/messages.py
    143 @@ -130,6 +130,7 @@ class ClientHello(HandshakeMsg):
    144          self.certificate_types = [CertificateType.x509]
    145          self.compression_methods = []   # a list of 8-bit values
    146          self.srp_username = None        # a string
    147 +        self.channel_id = False
    148  
    149      def create(self, version, random, session_id, cipher_suites,
    150                 certificate_types=None, srp_username=None):
    151 @@ -174,6 +175,8 @@ class ClientHello(HandshakeMsg):
    152                          self.srp_username = bytesToString(p.getVarBytes(1))
    153                      elif extType == 7:
    154                          self.certificate_types = p.getVarList(1, 1)
    155 +                    elif extType == ExtensionType.channel_id:
    156 +                        self.channel_id = True
    157                      else:
    158                          p.getFixBytes(extLength)
    159                      soFar += 4 + extLength
    160 @@ -220,6 +223,7 @@ class ServerHello(HandshakeMsg):
    161          self.cipher_suite = 0
    162          self.certificate_type = CertificateType.x509
    163          self.compression_method = 0
    164 +        self.channel_id = False
    165  
    166      def create(self, version, random, session_id, cipher_suite,
    167                 certificate_type):
    168 @@ -266,6 +270,9 @@ class ServerHello(HandshakeMsg):
    169                  CertificateType.x509:
    170              extLength += 5
    171  
    172 +        if self.channel_id:
    173 +            extLength += 4
    174 +
    175          if extLength != 0:
    176              w.add(extLength, 2)
    177  
    178 @@ -275,6 +282,10 @@ class ServerHello(HandshakeMsg):
    179              w.add(1, 2)
    180              w.add(self.certificate_type, 1)
    181  
    182 +        if self.channel_id:
    183 +            w.add(ExtensionType.channel_id, 2)
    184 +            w.add(0, 2)
    185 +
    186          return HandshakeMsg.postWrite(self, w, trial)
    187  
    188  class Certificate(HandshakeMsg):
    189 @@ -567,6 +578,28 @@ class Finished(HandshakeMsg):
    190          w.addFixSeq(self.verify_data, 1)
    191          return HandshakeMsg.postWrite(self, w, trial)
    192  
    193 +class EncryptedExtensions(HandshakeMsg):
    194 +    def __init__(self):
    195 +        self.channel_id_key = None
    196 +        self.channel_id_proof = None
    197 +
    198 +    def parse(self, p):
    199 +        p.startLengthCheck(3)
    200 +        soFar = 0
    201 +        while soFar != p.lengthCheck:
    202 +            extType = p.get(2)
    203 +            extLength = p.get(2)
    204 +            if extType == ExtensionType.channel_id:
    205 +                if extLength != 32*4:
    206 +                    raise SyntaxError()
    207 +                self.channel_id_key = p.getFixBytes(64)
    208 +                self.channel_id_proof = p.getFixBytes(64)
    209 +            else:
    210 +                p.getFixBytes(extLength)
    211 +            soFar += 4 + extLength
    212 +        p.stopLengthCheck()
    213 +        return self
    214 +
    215  class ApplicationData(Msg):
    216      def __init__(self):
    217          self.contentType = ContentType.application_data
    218