Home | History | Annotate | Download | only in securegcm
      1 /* Copyright 2018 Google LLC
      2  *
      3  * Licensed under the Apache License, Version 2.0 (the "License");
      4  * you may not use this file except in compliance with the License.
      5  * You may obtain a copy of the License at
      6  *
      7  *     https://www.apache.org/licenses/LICENSE-2.0
      8  *
      9  * Unless required by applicable law or agreed to in writing, software
     10  * distributed under the License is distributed on an "AS IS" BASIS,
     11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12  * See the License for the specific language governing permissions and
     13  * limitations under the License.
     14  */
     15 package com.google.security.cryptauth.lib.securegcm;
     16 
     17 import com.google.protobuf.ByteString;
     18 import com.google.protobuf.InvalidProtocolBufferException;
     19 import com.google.security.cryptauth.lib.securegcm.UkeyProto.Ukey2Alert;
     20 import com.google.security.cryptauth.lib.securegcm.UkeyProto.Ukey2ClientFinished;
     21 import com.google.security.cryptauth.lib.securegcm.UkeyProto.Ukey2ClientInit;
     22 import com.google.security.cryptauth.lib.securegcm.UkeyProto.Ukey2ClientInit.CipherCommitment;
     23 import com.google.security.cryptauth.lib.securegcm.UkeyProto.Ukey2Message;
     24 import com.google.security.cryptauth.lib.securegcm.UkeyProto.Ukey2ServerInit;
     25 import com.google.security.cryptauth.lib.securemessage.CryptoOps;
     26 import com.google.security.cryptauth.lib.securemessage.PublicKeyProtoUtil;
     27 import com.google.security.cryptauth.lib.securemessage.SecureMessageProto.GenericPublicKey;
     28 import java.io.ByteArrayOutputStream;
     29 import java.io.IOException;
     30 import java.io.UnsupportedEncodingException;
     31 import java.security.InvalidKeyException;
     32 import java.security.KeyPair;
     33 import java.security.MessageDigest;
     34 import java.security.NoSuchAlgorithmException;
     35 import java.security.PublicKey;
     36 import java.security.SecureRandom;
     37 import java.security.spec.InvalidKeySpecException;
     38 import java.util.Arrays;
     39 import java.util.HashMap;
     40 import java.util.List;
     41 import javax.annotation.Nullable;
     42 import javax.crypto.SecretKey;
     43 import javax.crypto.spec.SecretKeySpec;
     44 
     45 /**
     46  * Implements UKEY2 and produces a {@link D2DConnectionContext}.
     47  *
     48  * <p>Client Usage:
     49  * <code>
     50  * try {
     51  *   Ukey2Handshake client = Ukey2Handshake.forInitiator(HandshakeCipher.P256_SHA512);
     52  *   byte[] handshakeMessage;
     53  *
     54  *   // Message 1 (Client Init)
     55  *   handshakeMessage = client.getNextHandshakeMessage();
     56  *   sendMessageToServer(handshakeMessage);
     57  *
     58  *   // Message 2 (Server Init)
     59  *   handshakeMessage = receiveMessageFromServer();
     60  *   client.parseHandshakeMessage(handshakeMessage);
     61  *
     62  *   // Message 3 (Client Finish)
     63  *   handshakeMessage = client.getNextHandshakeMessage();
     64  *   sendMessageToServer(handshakeMessage);
     65  *
     66  *   // Get the auth string
     67  *   byte[] clientAuthString = client.getVerificationString(STRING_LENGTH);
     68  *   showStringToUser(clientAuthString);
     69  *
     70  *   // Using out-of-band channel, verify auth string, then call:
     71  *   client.verifyHandshake();
     72  *
     73  *   // Make a connection context
     74  *   D2DConnectionContext clientContext = client.toConnectionContext();
     75  * } catch (AlertException e) {
 *   log(e.getMessage());
     77  *   sendMessageToServer(e.getAlertMessageToSend());
     78  * } catch (HandshakeException e) {
     79  *   log(e);
     80  *   // terminate handshake
     81  * }
     82  * </code>
     83  *
     84  * <p>Server Usage:
     85  * <code>
     86  * try {
     87  *   Ukey2Handshake server = Ukey2Handshake.forResponder(HandshakeCipher.P256_SHA512);
     88  *   byte[] handshakeMessage;
     89  *
     90  *   // Message 1 (Client Init)
     91  *   handshakeMessage = receiveMessageFromClient();
     92  *   server.parseHandshakeMessage(handshakeMessage);
     93  *
     94  *   // Message 2 (Server Init)
     95  *   handshakeMessage = server.getNextHandshakeMessage();
 *   sendMessageToClient(handshakeMessage);
     97  *
     98  *   // Message 3 (Client Finish)
     99  *   handshakeMessage = receiveMessageFromClient();
    100  *   server.parseHandshakeMessage(handshakeMessage);
    101  *
    102  *   // Get the auth string
    103  *   byte[] serverAuthString = server.getVerificationString(STRING_LENGTH);
    104  *   showStringToUser(serverAuthString);
    105  *
    106  *   // Using out-of-band channel, verify auth string, then call:
    107  *   server.verifyHandshake();
    108  *
    109  *   // Make a connection context
    110  *   D2DConnectionContext serverContext = server.toConnectionContext();
    111  * } catch (AlertException e) {
 *   log(e.getMessage());
    113  *   sendMessageToClient(e.getAlertMessageToSend());
    114  * } catch (HandshakeException e) {
    115  *   log(e);
    116  *   // terminate handshake
    117  * }
    118  * </code>
    119  */
    120 public class Ukey2Handshake {
    121 
    122   /**
    123    * Creates a {@link Ukey2Handshake} with a particular cipher that can be used by an initiator /
    124    * client.
    125    *
    126    * @throws HandshakeException
    127    */
    128   public static Ukey2Handshake forInitiator(HandshakeCipher cipher) throws HandshakeException {
    129     return new Ukey2Handshake(InternalState.CLIENT_START, cipher);
    130   }
    131 
    132   /**
    133    * Creates a {@link Ukey2Handshake} with a particular cipher that can be used by an responder /
    134    * server.
    135    *
    136    * @throws HandshakeException
    137    */
    138   public static Ukey2Handshake forResponder(HandshakeCipher cipher) throws HandshakeException {
    139     return new Ukey2Handshake(InternalState.SERVER_START, cipher);
    140   }
    141 
    142   /**
    143    * Handshake States. Meaning of states:
    144    * <ul>
    145    * <li>IN_PROGRESS: The handshake is in progress, caller should use
    146    * {@link Ukey2Handshake#getNextHandshakeMessage()} and
    147    * {@link Ukey2Handshake#parseHandshakeMessage(byte[])} to continue the handshake.
    148    * <li>VERIFICATION_NEEDED: The handshake is complete, but pending verification of the
    149    * authentication string. Clients should use {@link Ukey2Handshake#getVerificationString(int)} to
    150    * get the verification string and use out-of-band methods to authenticate the handshake.
    151    * <li>VERIFICATION_IN_PROGRESS: The handshake is complete, verification string has been
    152    * generated, but has not been confirmed. After authenticating the handshake out-of-band, use
    153    * {@link Ukey2Handshake#verifyHandshake()} to mark the handshake as verified.
    154    * <li>FINISHED: The handshake is finished, and caller can use
    155    * {@link Ukey2Handshake#toConnectionContext()} to produce a {@link D2DConnectionContext}.
    156    * <li>ALREADY_USED: The handshake has already been used and should be discarded / garbage
    157    * collected.
    158    * <li>ERROR: The handshake produced an error and should be destroyed.
    159    * </ul>
    160    */
    161   public enum State {
    162     IN_PROGRESS,
    163     VERIFICATION_NEEDED,
    164     VERIFICATION_IN_PROGRESS,
    165     FINISHED,
    166     ALREADY_USED,
    167     ERROR,
    168   }
    169 
    170   /**
    171    * Currently implemented UKEY2 handshake ciphers. Each cipher is a tuple consisting of a key
    172    * negotiation cipher and a hash function used for a commitment. Currently the ciphers are:
    173    * <code>
    174    *   +-----------------------------------------------------+
    175    *   | Enum        | Key negotiation       | Hash function |
    176    *   +-------------+-----------------------+---------------+
    177    *   | P256_SHA512 | ECDH using NIST P-256 | SHA512        |
    178    *   +-----------------------------------------------------+
    179    * </code>
    180    *
    181    * <p>Note that these should correspond to values in device_to_device_messages.proto.
    182    */
    183   public enum HandshakeCipher {
    184     P256_SHA512(UkeyProto.Ukey2HandshakeCipher.P256_SHA512);
    185     // TODO(aczeskis): add CURVE25519_SHA512
    186 
    187     private final UkeyProto.Ukey2HandshakeCipher value;
    188 
    189     HandshakeCipher(UkeyProto.Ukey2HandshakeCipher value) {
    190       // Make sure we only accept values that are valid as per the ukey protobuf.
    191       // NOTE: Don't use switch statement on value, as that will trigger a bug. b/30682989.
    192       if (value == UkeyProto.Ukey2HandshakeCipher.P256_SHA512) {
    193           this.value = value;
    194       } else {
    195           throw new IllegalArgumentException("Unknown cipher value: " + value);
    196       }
    197     }
    198 
    199     public UkeyProto.Ukey2HandshakeCipher getValue() {
    200       return value;
    201     }
    202   }
    203 
    204   /**
    205    * If thrown, this exception contains information that should be sent on the wire. Specifically,
    206    * the {@link #getAlertMessageToSend()} method returns a <code>byte[]</code> that communicates the
    207    * error to the other party in the handshake. Meanwhile, the {@link #getMessage()} method can be
    208    * used to get a log-able error message.
    209    */
    210   public static class AlertException extends Exception {
    211     private final Ukey2Alert alertMessageToSend;
    212 
    213     public AlertException(String alertMessageToLog, Ukey2Alert alertMessageToSend) {
    214       super(alertMessageToLog);
    215       this.alertMessageToSend = alertMessageToSend;
    216     }
    217 
    218     /**
    219      * @return a message suitable for sending to other member of handshake.
    220      */
    221     public byte[] getAlertMessageToSend() {
    222       return alertMessageToSend.toByteArray();
    223     }
    224   }
    225 
    226   // Maximum version of the handshake supported by this class.
    227   public static final int VERSION = 1;
    228 
    229   // Random nonce is fixed at 32 bytes (as per go/ukey2).
    230   private static final int NONCE_LENGTH_IN_BYTES = 32;
    231 
    232   private static final String UTF_8 = "UTF-8";
    233 
    234   // Currently, we only support one next protocol.
    235   private static final String NEXT_PROTOCOL = "AES_256_CBC-HMAC_SHA256";
    236 
    237   // Clients need to store a map of message 3's (client finishes) for each commitment.
    238   private final HashMap<HandshakeCipher, byte[]> rawMessage3Map = new HashMap<>();
    239 
    240   private final HandshakeCipher handshakeCipher;
    241   private final HandshakeRole handshakeRole;
    242   private InternalState handshakeState;
    243   private final KeyPair ourKeyPair;
    244   private PublicKey theirPublicKey;
    245   private SecretKey derivedSecretKey;
    246 
    247   // Servers need to store client commitments.
    248   private byte[] theirCommitment;
    249 
    250   // We store the raw messages sent for computing the authentication strings and next key.
    251   private byte[] rawMessage1;
    252   private byte[] rawMessage2;
    253 
    254   // Enums for internal state machinery
    255   private enum InternalState {
    256     // Initiator/client state
    257     CLIENT_START,
    258     CLIENT_WAITING_FOR_SERVER_INIT,
    259     CLIENT_AFTER_SERVER_INIT,
    260 
    261     // Responder/server state
    262     SERVER_START,
    263     SERVER_AFTER_CLIENT_INIT,
    264     SERVER_WAITING_FOR_CLIENT_FINISHED,
    265 
    266     // Common completion state
    267     HANDSHAKE_VERIFICATION_NEEDED,
    268     HANDSHAKE_VERIFICATION_IN_PROGRESS,
    269     HANDSHAKE_FINISHED,
    270     HANDSHAKE_ALREADY_USED,
    271     HANDSHAKE_ERROR,
    272   }
    273 
    274   // Helps us remember our role in the handshake
    275   private enum HandshakeRole {
    276     CLIENT,
    277     SERVER
    278   }
    279 
    280   /**
    281    * Never invoked directly. Caller should use {@link #forInitiator(HandshakeCipher)} or
    282    * {@link #forResponder(HandshakeCipher)} instead.
    283    *
    284    * @throws HandshakeException if an unrecoverable error occurs and the connection should be shut
    285    * down.
    286    */
    287   private Ukey2Handshake(InternalState state, HandshakeCipher cipher) throws HandshakeException {
    288     if (cipher == null) {
    289       throwIllegalArgumentException("Invalid handshake cipher");
    290     }
    291     this.handshakeCipher = cipher;
    292 
    293     switch (state) {
    294       case CLIENT_START:
    295         handshakeRole = HandshakeRole.CLIENT;
    296         break;
    297       case SERVER_START:
    298         handshakeRole = HandshakeRole.SERVER;
    299         break;
    300       default:
    301         throwIllegalStateException("Invalid handshake state");
    302         handshakeRole = null; // unreachable, but makes compiler happy
    303     }
    304     this.handshakeState = state;
    305 
    306     this.ourKeyPair = genKeyPair(cipher);
    307   }
    308 
    309   /**
    310    * Get the next handshake message suitable for sending on the wire.
    311    *
    312    * @throws HandshakeException if an unrecoverable error occurs and the connection should be shut
    313    * down.
    314    */
    315   public byte[] getNextHandshakeMessage() throws HandshakeException {
    316     switch (handshakeState) {
    317       case CLIENT_START:
    318         rawMessage1 = makeUkey2Message(Ukey2Message.Type.CLIENT_INIT, makeClientInitMessage());
    319         handshakeState = InternalState.CLIENT_WAITING_FOR_SERVER_INIT;
    320         return rawMessage1;
    321 
    322       case SERVER_AFTER_CLIENT_INIT:
    323         rawMessage2 = makeUkey2Message(Ukey2Message.Type.SERVER_INIT, makeServerInitMessage());
    324         handshakeState = InternalState.SERVER_WAITING_FOR_CLIENT_FINISHED;
    325         return rawMessage2;
    326 
    327       case CLIENT_AFTER_SERVER_INIT:
    328         // Make sure we have a message 3 for the chosen cipher.
    329         if (!rawMessage3Map.containsKey(handshakeCipher)) {
    330           throwIllegalStateException(
    331               "Client state is CLIENT_AFTER_SERVER_INIT, and cipher is "
    332                   + handshakeCipher
    333                   + ", but no corresponding raw client finished message has been generated");
    334         }
    335         handshakeState = InternalState.HANDSHAKE_VERIFICATION_NEEDED;
    336         return rawMessage3Map.get(handshakeCipher);
    337 
    338       default:
    339         throwIllegalStateException("Cannot get next message in state: " + handshakeState);
    340         return null; // unreachable, but makes compiler happy
    341     }
    342   }
    343 
    344   /**
    345    * Returns an authentication string suitable for authenticating the handshake out-of-band. Note
    346    * that the authentication string can be short (e.g., a 6 digit visual confirmation code). Note:
    347    * this should only be called when the state returned byte {@link #getHandshakeState()} is
    348    * {@link State#VERIFICATION_NEEDED}, which means this can only be called once.
    349    *
    350    * @param byteLength length of output in bytes. Min length is 1; max length is 32.
    351    */
    352   public byte[] getVerificationString(int byteLength) throws HandshakeException {
    353     if (byteLength < 1 || byteLength > 32) {
    354       throwIllegalArgumentException("Minimum length is 1 byte, max is 32 bytes");
    355     }
    356 
    357     if (handshakeState != InternalState.HANDSHAKE_VERIFICATION_NEEDED) {
    358       throwIllegalStateException("Unexpected state: " + handshakeState);
    359     }
    360 
    361     try {
    362       derivedSecretKey =
    363           EnrollmentCryptoOps.doKeyAgreement(ourKeyPair.getPrivate(), theirPublicKey);
    364     } catch (InvalidKeyException e) {
    365       // unreachable in practice
    366       throwHandshakeException(e);
    367     }
    368 
    369     ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
    370     try {
    371       byteStream.write(rawMessage1);
    372       byteStream.write(rawMessage2);
    373     } catch (IOException e) {
    374       // unreachable in practice
    375       throwHandshakeException(e);
    376     }
    377     byte[] info = byteStream.toByteArray();
    378 
    379     byte[] salt = null;
    380 
    381     try {
    382       salt = "UKEY2 v1 auth".getBytes(UTF_8);
    383     } catch (UnsupportedEncodingException e) {
    384       // unreachable in practice
    385       throwHandshakeException(e);
    386     }
    387 
    388     byte[] authString = null;
    389     try {
    390       authString = CryptoOps.hkdf(derivedSecretKey, salt, info);
    391     } catch (InvalidKeyException | NoSuchAlgorithmException e) {
    392       // unreachable in practice
    393       throwHandshakeException(e);
    394     }
    395 
    396     handshakeState = InternalState.HANDSHAKE_VERIFICATION_IN_PROGRESS;
    397     return Arrays.copyOf(authString, byteLength);
    398   }
    399 
    400   /**
    401    * Invoked to let handshake state machine know that caller has validated the authentication
    402    * string obtained via {@link #getVerificationString(int)}; Note: this should only be called when
    403    * the state returned byte {@link #getHandshakeState()} is {@link State#VERIFICATION_IN_PROGRESS}.
    404    */
    405   public void verifyHandshake() {
    406     if (handshakeState != InternalState.HANDSHAKE_VERIFICATION_IN_PROGRESS) {
    407       throwIllegalStateException("Unexpected state: " + handshakeState);
    408     }
    409     handshakeState = InternalState.HANDSHAKE_FINISHED;
    410   }
    411 
    412   /**
    413    * Parses the given handshake message.
    414    * @throws AlertException if an error occurs that should be sent to other party.
    415    * @throws HandshakeException in an error occurs and the connection should be torn down.
    416    */
    417   public void parseHandshakeMessage(byte[] handshakeMessage)
    418       throws AlertException, HandshakeException {
    419     switch (handshakeState) {
    420       case SERVER_START:
    421         parseMessage1(handshakeMessage);
    422         handshakeState = InternalState.SERVER_AFTER_CLIENT_INIT;
    423         break;
    424 
    425       case CLIENT_WAITING_FOR_SERVER_INIT:
    426         parseMessage2(handshakeMessage);
    427         handshakeState = InternalState.CLIENT_AFTER_SERVER_INIT;
    428         break;
    429 
    430       case SERVER_WAITING_FOR_CLIENT_FINISHED:
    431         parseMessage3(handshakeMessage);
    432         handshakeState = InternalState.HANDSHAKE_VERIFICATION_NEEDED;
    433         break;
    434 
    435       default:
    436         throwIllegalStateException("Cannot parse message in state " + handshakeState);
    437     }
    438   }
    439 
    440   /**
    441    * Returns the current state of the handshake. See {@link State}.
    442    */
    443   public State getHandshakeState() {
    444     switch (handshakeState) {
    445       case CLIENT_START:
    446       case CLIENT_WAITING_FOR_SERVER_INIT:
    447       case CLIENT_AFTER_SERVER_INIT:
    448       case SERVER_START:
    449       case SERVER_WAITING_FOR_CLIENT_FINISHED:
    450       case SERVER_AFTER_CLIENT_INIT:
    451         // fallback intended -- these are all in-progress states
    452         return State.IN_PROGRESS;
    453 
    454       case HANDSHAKE_ERROR:
    455         return State.ERROR;
    456 
    457       case HANDSHAKE_VERIFICATION_NEEDED:
    458         return State.VERIFICATION_NEEDED;
    459 
    460       case HANDSHAKE_VERIFICATION_IN_PROGRESS:
    461         return State.VERIFICATION_IN_PROGRESS;
    462 
    463       case HANDSHAKE_FINISHED:
    464         return State.FINISHED;
    465 
    466       case HANDSHAKE_ALREADY_USED:
    467         return State.ALREADY_USED;
    468 
    469       default:
    470         // unreachable in practice
    471         throwIllegalStateException("Unknown state");
    472         return null; // really unreachable, but makes compiler happy
    473     }
    474   }
    475 
    476   /**
    477    * Can be called to generate a {@link D2DConnectionContext}. Note: this should only be called
    478    * when the state returned byte {@link #getHandshakeState()} is {@link State#FINISHED}.
    479    *
    480    * @throws HandshakeException
    481    */
    482   public D2DConnectionContext toConnectionContext() throws HandshakeException {
    483     switch (handshakeState) {
    484       case HANDSHAKE_ERROR:
    485         throwIllegalStateException("Cannot make context; handshake had error");
    486         return null; // makes linter happy
    487       case HANDSHAKE_ALREADY_USED:
    488         throwIllegalStateException("Cannot reuse handshake context; is has already been used");
    489         return null; // makes linter happy
    490       case HANDSHAKE_VERIFICATION_NEEDED:
    491         throwIllegalStateException("Handshake not verified, cannot create context");
    492         return null; // makes linter happy
    493       case HANDSHAKE_FINISHED:
    494         // We're done, okay to return a context
    495         break;
    496       default:
    497         // unreachable in practice
    498         throwIllegalStateException("Handshake is not complete; cannot create connection context");
    499     }
    500 
    501     if (derivedSecretKey == null) {
    502       throwIllegalStateException("Unexpected state error: derived key is null");
    503     }
    504 
    505     ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
    506     try {
    507       byteStream.write(rawMessage1);
    508       byteStream.write(rawMessage2);
    509     } catch (IOException e) {
    510       // unreachable in practice
    511       throwHandshakeException(e);
    512     }
    513     byte[] info = byteStream.toByteArray();
    514 
    515     byte[] salt = null;
    516     try {
    517       salt = "UKEY2 v1 next".getBytes(UTF_8);
    518     } catch (UnsupportedEncodingException e) {
    519       // unreachable
    520       throwHandshakeException(e);
    521     }
    522 
    523     SecretKey nextProtocolKey = null;
    524     try {
    525       nextProtocolKey = new SecretKeySpec(CryptoOps.hkdf(derivedSecretKey, salt, info), "AES");
    526     } catch (InvalidKeyException | NoSuchAlgorithmException e) {
    527       // unreachable in practice
    528       throwHandshakeException(e);
    529     }
    530 
    531     SecretKey clientKey = null;
    532     SecretKey serverKey = null;
    533     try {
    534       clientKey = D2DCryptoOps.deriveNewKeyForPurpose(nextProtocolKey, "client");
    535       serverKey = D2DCryptoOps.deriveNewKeyForPurpose(nextProtocolKey, "server");
    536     } catch (InvalidKeyException | NoSuchAlgorithmException e) {
    537       // unreachable in practice
    538       throwHandshakeException(e);
    539     }
    540 
    541     handshakeState = InternalState.HANDSHAKE_ALREADY_USED;
    542 
    543     return new D2DConnectionContextV1(
    544         handshakeRole == HandshakeRole.CLIENT ? clientKey : serverKey,
    545         handshakeRole == HandshakeRole.CLIENT ? serverKey : clientKey,
    546         0 /* initial encode sequence number */,
    547         0 /* initial decode sequence number */);
    548   }
    549 
    550   /**
    551    * Generates the byte[] encoding of a {@link Ukey2ClientInit} message.
    552    *
    553    * @throws HandshakeException
    554    */
    555   private byte[] makeClientInitMessage() throws HandshakeException {
    556     Ukey2ClientInit.Builder clientInit = Ukey2ClientInit.newBuilder();
    557     clientInit.setVersion(VERSION);
    558     clientInit.setRandom(ByteString.copyFrom(generateRandomNonce()));
    559     clientInit.setNextProtocol(NEXT_PROTOCOL);
    560 
    561     // At the moment, we only support one cipher
    562     clientInit.addCipherCommitments(generateP256SHA512Commitment());
    563 
    564     return clientInit.build().toByteArray();
    565   }
    566 
    567   /**
    568    * Generates the byte[] encoding of a {@link Ukey2ServerInit} message.
    569    */
    570   private byte[] makeServerInitMessage() {
    571     Ukey2ServerInit.Builder serverInit = Ukey2ServerInit.newBuilder();
    572     serverInit.setVersion(VERSION);
    573     serverInit.setRandom(ByteString.copyFrom(generateRandomNonce()));
    574     serverInit.setHandshakeCipher(handshakeCipher.getValue());
    575     serverInit.setPublicKey(
    576         PublicKeyProtoUtil.encodePublicKey(ourKeyPair.getPublic()).toByteString());
    577 
    578     return serverInit.build().toByteArray();
    579   }
    580 
    581   /**
    582    * Generates a keypair for the provided handshake cipher. Currently only P256_SHA512 is
    583    * supported.
    584    *
    585    * @throws HandshakeException
    586    */
    587   private KeyPair genKeyPair(HandshakeCipher cipher) throws HandshakeException {
    588     switch (cipher) {
    589       case P256_SHA512:
    590         return PublicKeyProtoUtil.generateEcP256KeyPair();
    591       default:
    592         // Should never happen
    593         throwHandshakeException("unknown cipher: " + cipher);
    594     }
    595     return null; // unreachable, but makes compiler happy
    596   }
    597 
    598   /**
    599    * Attempts to parse message 1 (which is a wrapped {@link Ukey2ClientInit}). See go/ukey2 for
    600    * details.
    601    *
    602    * @throws AlertException if an error occurs
    603    */
    604   private void parseMessage1(byte[] handshakeMessage) throws AlertException, HandshakeException {
    605     // Deserialize the protobuf; send a BAD_MESSAGE message if deserialization fails
    606     Ukey2Message message = null;
    607     try {
    608       message = Ukey2Message.parseFrom(handshakeMessage);
    609     } catch (InvalidProtocolBufferException e) {
    610       throwAlertException(Ukey2Alert.AlertType.BAD_MESSAGE,
    611           "Can't parse message 1 " + e.getMessage());
    612     }
    613 
    614     // Verify that message_type == Type.CLIENT_INIT; send a BAD_MESSAGE_TYPE message if mismatch
    615     if (!message.hasMessageType() || message.getMessageType() != Ukey2Message.Type.CLIENT_INIT) {
    616       throwAlertException(
    617           Ukey2Alert.AlertType.BAD_MESSAGE_TYPE,
    618           "Expected, but did not find ClientInit message type");
    619     }
    620 
    621     // Deserialize message_data as a ClientInit message; send a BAD_MESSAGE_DATA message if
    622     // deserialization fails
    623     if (!message.hasMessageData()) {
    624       throwAlertException(Ukey2Alert.AlertType.BAD_MESSAGE_DATA,
    625           "Expected message data, but didn't find it");
    626     }
    627     Ukey2ClientInit clientInit = null;
    628     try {
    629       clientInit = Ukey2ClientInit.parseFrom(message.getMessageData());
    630     } catch (InvalidProtocolBufferException e) {
    631       throwAlertException(Ukey2Alert.AlertType.BAD_MESSAGE_DATA,
    632           "Can't parse message data into ClientInit");
    633     }
    634 
    635     // Check that version == VERSION; send BAD_VERSION message if mismatch
    636     if (!clientInit.hasVersion()) {
    637       throwAlertException(Ukey2Alert.AlertType.BAD_VERSION, "ClientInit missing version");
    638     }
    639     if (clientInit.getVersion() != VERSION) {
    640       throwAlertException(Ukey2Alert.AlertType.BAD_VERSION, "ClientInit version mismatch");
    641     }
    642 
    643     // Check that random is exactly NONCE_LENGTH_IN_BYTES bytes; send Alert.BAD_RANDOM message if
    644     // not.
    645     if (!clientInit.hasRandom()) {
    646       throwAlertException(Ukey2Alert.AlertType.BAD_RANDOM, "ClientInit missing random");
    647     }
    648     if (clientInit.getRandom().toByteArray().length != NONCE_LENGTH_IN_BYTES) {
    649       throwAlertException(Ukey2Alert.AlertType.BAD_RANDOM, "ClientInit has incorrect nonce length");
    650     }
    651 
    652     // Check to see if any of the handshake_cipher in cipher_commitment are acceptable. Servers
    653     // should select the first handshake_cipher that it finds acceptable to support clients
    654     // signaling deprecated but supported HandshakeCiphers. If no handshake_cipher is acceptable
    655     // (or there are no HandshakeCiphers in the message), the server sends a BAD_HANDSHAKE_CIPHER
    656     //  message
    657     List<Ukey2ClientInit.CipherCommitment> commitments = clientInit.getCipherCommitmentsList();
    658     if (commitments.isEmpty()) {
    659       throwAlertException(
    660           Ukey2Alert.AlertType.BAD_HANDSHAKE_CIPHER, "ClientInit is missing cipher commitments");
    661     }
    662     for (Ukey2ClientInit.CipherCommitment commitment : commitments) {
    663       if (!commitment.hasHandshakeCipher()
    664           || !commitment.hasCommitment()) {
    665         throwAlertException(
    666             Ukey2Alert.AlertType.BAD_HANDSHAKE_CIPHER,
    667             "ClientInit has improperly formatted cipher commitment");
    668       }
    669 
    670       // TODO(aczeskis): for now we only support one cipher, eventually support more
    671       if (commitment.getHandshakeCipher() == handshakeCipher.getValue()) {
    672         theirCommitment = commitment.getCommitment().toByteArray();
    673       }
    674     }
    675     if (theirCommitment == null) {
    676       throwAlertException(Ukey2Alert.AlertType.BAD_HANDSHAKE_CIPHER,
    677           "No acceptable commitments found");
    678     }
    679 
    680     // Checks that next_protocol contains a protocol that the server supports. Send a
    681     // BAD_NEXT_PROTOCOL message if not. We currently only support one protocol
    682     if (!clientInit.hasNextProtocol() || !NEXT_PROTOCOL.equals(clientInit.getNextProtocol())) {
    683       throwAlertException(Ukey2Alert.AlertType.BAD_NEXT_PROTOCOL, "Incorrect next protocol");
    684     }
    685 
    686     // Store raw message for AUTH_STRING computation
    687     rawMessage1 = handshakeMessage;
    688   }
    689 
    690   /**
    691    * Attempts to parse message 2 (which is a wrapped {@link Ukey2ServerInit}). See go/ukey2 for
    692    * details.
    693    */
    694   private void parseMessage2(final byte[] handshakeMessage)
    695       throws AlertException, HandshakeException {
    696     // Deserialize the protobuf; send a BAD_MESSAGE message if deserialization fails
    697     Ukey2Message message = null;
    698     try {
    699       message = Ukey2Message.parseFrom(handshakeMessage);
    700     } catch (InvalidProtocolBufferException e) {
    701       throwAlertException(Ukey2Alert.AlertType.BAD_MESSAGE,
    702           "Can't parse message 2 " + e.getMessage());
    703     }
    704 
    705     // Verify that message_type == Type.SERVER_INIT; send a BAD_MESSAGE_TYPE message if mismatch
    706     if (!message.hasMessageType()) {
    707       throwAlertException(Ukey2Alert.AlertType.BAD_MESSAGE_TYPE,
    708           "Expected, but did not find message type");
    709     }
    710     if (message.getMessageType() == Ukey2Message.Type.ALERT) {
    711       handshakeState = InternalState.HANDSHAKE_ERROR;
    712       throwHandshakeMessageFromAlertMessage(message);
    713     }
    714     if (message.getMessageType() != Ukey2Message.Type.SERVER_INIT) {
    715       throwAlertException(
    716           Ukey2Alert.AlertType.BAD_MESSAGE_TYPE,
    717           "Expected, but did not find SERVER_INIT message type");
    718     }
    719 
    720     // Deserialize message_data as a ServerInit message; send a BAD_MESSAGE_DATA message if
    721     // deserialization fails
    722     if (!message.hasMessageData()) {
    723 
    724       throwAlertException(Ukey2Alert.AlertType.BAD_MESSAGE_DATA,
    725           "Expected message data, but didn't find it");
    726     }
    727     Ukey2ServerInit serverInit = null;
    728     try {
    729       serverInit = Ukey2ServerInit.parseFrom(message.getMessageData());
    730     } catch (InvalidProtocolBufferException e) {
    731       throwAlertException(Ukey2Alert.AlertType.BAD_MESSAGE_DATA,
    732           "Can't parse message data into ServerInit");
    733     }
    734 
    735     // Check that version == VERSION; send BAD_VERSION message if mismatch
    736     if (!serverInit.hasVersion()) {
    737       throwAlertException(Ukey2Alert.AlertType.BAD_VERSION, "ServerInit missing version");
    738     }
    739     if (serverInit.getVersion() != VERSION) {
    740       throwAlertException(Ukey2Alert.AlertType.BAD_VERSION, "ServerInit version mismatch");
    741     }
    742 
    743     // Check that random is exactly NONCE_LENGTH_IN_BYTES bytes; send Alert.BAD_RANDOM message if
    744     // not.
    745     if (!serverInit.hasRandom()) {
    746       throwAlertException(Ukey2Alert.AlertType.BAD_RANDOM, "ServerInit missing random");
    747     }
    748     if (serverInit.getRandom().toByteArray().length != NONCE_LENGTH_IN_BYTES) {
    749       throwAlertException(Ukey2Alert.AlertType.BAD_RANDOM, "ServerInit has incorrect nonce length");
    750     }
    751 
    752     // Check that handshake_cipher matches a handshake cipher that was sent in
    753     // ClientInit.cipher_commitments. If not, send a BAD_HANDSHAKECIPHER message
    754     if (!serverInit.hasHandshakeCipher()) {
    755       throwAlertException(Ukey2Alert.AlertType.BAD_HANDSHAKE_CIPHER, "No handshake cipher found");
    756     }
    757     HandshakeCipher serverCipher = null;
    758     for (HandshakeCipher cipher : HandshakeCipher.values()) {
    759       if (cipher.getValue() == serverInit.getHandshakeCipher()) {
    760         serverCipher = cipher;
    761         break;
    762       }
    763     }
    764     if (serverCipher == null || serverCipher != handshakeCipher) {
    765       throwAlertException(Ukey2Alert.AlertType.BAD_HANDSHAKE_CIPHER,
    766           "No acceptable handshake cipher found");
    767     }
    768 
    769     // Check that public_key parses into a correct public key structure. If not, send a
    770     // BAD_PUBLIC_KEY message.
    771     if (!serverInit.hasPublicKey()) {
    772       throwAlertException(Ukey2Alert.AlertType.BAD_PUBLIC_KEY, "No public key found in ServerInit");
    773     }
    774     theirPublicKey = parseP256PublicKey(serverInit.getPublicKey().toByteArray());
    775 
    776     // Store raw message for AUTH_STRING computation
    777     rawMessage2 = handshakeMessage;
    778   }
    779 
    780   /**
    781    * Attempts to parse message 3 (which is a wrapped {@link Ukey2ClientFinished}). See go/ukey2 for
    782    * details.
    783    */
    784   private void parseMessage3(final byte[] handshakeMessage) throws HandshakeException {
    785     // Deserialize the protobuf; terminate the connection if deserialization fails.
    786     Ukey2Message message = null;
    787     try {
    788       message = Ukey2Message.parseFrom(handshakeMessage);
    789     } catch (InvalidProtocolBufferException e) {
    790       throwHandshakeException("Can't parse message 3", e);
    791     }
    792 
    793     // Verify that message_type == Type.CLIENT_FINISH; terminate connection if mismatch occurs
    794     if (!message.hasMessageType()) {
    795       throw new HandshakeException("Expected, but did not find message type");
    796     }
    797     if (message.getMessageType() == Ukey2Message.Type.ALERT) {
    798       throwHandshakeMessageFromAlertMessage(message);
    799     }
    800     if (message.getMessageType() != Ukey2Message.Type.CLIENT_FINISH) {
    801       throwHandshakeException("Expected, but did not find CLIENT_FINISH message type");
    802     }
    803 
    804     // Verify that the hash of the ClientFinished matches the expected commitment from ClientInit.
    805     // Terminate the connection if the expected match fails.
    806     verifyCommitment(handshakeMessage);
    807 
    808     // Deserialize message_data as a ClientFinished message; terminate the connection if
    809     // deserialization fails.
    810     if (!message.hasMessageData()) {
    811       throwHandshakeException("Expected message data, but didn't find it");
    812     }
    813     Ukey2ClientFinished clientFinished = null;
    814     try {
    815       clientFinished = Ukey2ClientFinished.parseFrom(message.getMessageData());
    816     } catch (InvalidProtocolBufferException e) {
    817       throwHandshakeException(e);
    818     }
    819 
    820     // Check that public_key parses into a correct public key structure. If not, terminate the
    821     // connection.
    822     if (!clientFinished.hasPublicKey()) {
    823       throwHandshakeException("No public key found in ClientFinished");
    824     }
    825     try {
    826       theirPublicKey = parseP256PublicKey(clientFinished.getPublicKey().toByteArray());
    827     } catch (AlertException e) {
    828       // Wrap in a HandshakeException because error should not be sent on the wire.
    829       throwHandshakeException(e);
    830     }
    831   }
    832 
    833   private void verifyCommitment(byte[] handshakeMessage) throws HandshakeException {
    834     byte[] actualClientFinishHash = null;
    835     switch (handshakeCipher) {
    836       case P256_SHA512:
    837         actualClientFinishHash = sha512(handshakeMessage);
    838         break;
    839       default:
    840         // should be unreachable
    841         throwIllegalStateException("Unexpected handshakeCipher");
    842     }
    843 
    844     // Time constant after Java SE 6 Update 17
    845     // See http://www.oracle.com/technetwork/java/javase/6u17-141447.html
    846     if (!MessageDigest.isEqual(actualClientFinishHash, theirCommitment)) {
    847       throwHandshakeException("Commitment does not match");
    848     }
    849   }
    850 
    851   private void throwHandshakeMessageFromAlertMessage(Ukey2Message message)
    852       throws HandshakeException {
    853     if (message.hasMessageData()) {
    854       Ukey2Alert alert = null;
    855       try {
    856         alert = Ukey2Alert.parseFrom(message.getMessageData());
    857       } catch (InvalidProtocolBufferException e) {
    858         throwHandshakeException("Cannot parse alert message", e);
    859       }
    860 
    861       if (alert.hasType() && alert.hasErrorMessage()) {
    862         throwHandshakeException(
    863             "Received Alert message. Type: "
    864                 + alert.getType()
    865                 + " Error Message: "
    866                 + alert.getErrorMessage());
    867       } else if (alert.hasType()) {
    868         throwHandshakeException("Received Alert message. Type: " + alert.getType());
    869       }
    870     }
    871 
    872     throwHandshakeException("Received empty Alert Message");
    873   }
    874 
    875   /**
    876    * Parses an encoded public P256 key.
    877    */
    878   private PublicKey parseP256PublicKey(byte[] encodedPublicKey)
    879       throws AlertException, HandshakeException {
    880     try {
    881       return PublicKeyProtoUtil.parsePublicKey(GenericPublicKey.parseFrom(encodedPublicKey));
    882     } catch (InvalidProtocolBufferException | InvalidKeySpecException e) {
    883       throwAlertException(Ukey2Alert.AlertType.BAD_PUBLIC_KEY,
    884           "Cannot parse public key: " + e.getMessage());
    885       return null; // unreachable, but makes compiler happy
    886     }
    887   }
    888 
    889   /**
    890    * Generates a {@link CipherCommitment} for the P256_SHA512 cipher.
    891    */
    892   private CipherCommitment generateP256SHA512Commitment() throws HandshakeException {
    893     // Generate the corresponding finished message if it's not done yet
    894     if (!rawMessage3Map.containsKey(HandshakeCipher.P256_SHA512)) {
    895       generateP256SHA512ClientFinished(ourKeyPair);
    896     }
    897 
    898     CipherCommitment.Builder cipherCommitment = CipherCommitment.newBuilder();
    899     cipherCommitment.setHandshakeCipher(UkeyProto.Ukey2HandshakeCipher.P256_SHA512);
    900     cipherCommitment.setCommitment(
    901         ByteString.copyFrom(sha512(rawMessage3Map.get(HandshakeCipher.P256_SHA512))));
    902 
    903     return cipherCommitment.build();
    904   }
    905 
    906   /**
    907    * Generates and records a {@link Ukey2ClientFinished} message for the P256_SHA512 cipher.
    908    */
    909   private Ukey2ClientFinished generateP256SHA512ClientFinished(KeyPair p256KeyPair) {
    910     byte[] encodedKey = PublicKeyProtoUtil.encodePublicKey(p256KeyPair.getPublic()).toByteArray();
    911 
    912     Ukey2ClientFinished.Builder clientFinished = Ukey2ClientFinished.newBuilder();
    913     clientFinished.setPublicKey(ByteString.copyFrom(encodedKey));
    914 
    915     rawMessage3Map.put(
    916         HandshakeCipher.P256_SHA512,
    917         makeUkey2Message(Ukey2Message.Type.CLIENT_FINISH, clientFinished.build().toByteArray()));
    918 
    919     return clientFinished.build();
    920   }
    921 
    922   /**
    923    * Generates the serialized representation of a {@link Ukey2Message} based on the provided type
    924    * and data.
    925    */
    926   private byte[] makeUkey2Message(Ukey2Message.Type messageType, byte[] messageData) {
    927     Ukey2Message.Builder message = Ukey2Message.newBuilder();
    928 
    929     switch (messageType) {
    930       case ALERT:
    931       case CLIENT_INIT:
    932       case SERVER_INIT:
    933       case CLIENT_FINISH:
    934         // fall through intentional; valid message types
    935         break;
    936       default:
    937         throwIllegalArgumentException("Invalid message type: " + messageType);
    938     }
    939     message.setMessageType(messageType);
    940 
    941     // Alerts a blank message data field
    942     if (messageType != Ukey2Message.Type.ALERT) {
    943       if (messageData == null || messageData.length == 0) {
    944         throwIllegalArgumentException("Cannot send empty message data for non-alert messages");
    945       }
    946       message.setMessageData(ByteString.copyFrom(messageData));
    947     }
    948 
    949     return message.build().toByteArray();
    950   }
    951 
    952   /**
    953    * Returns a {@link Ukey2Alert} message of given type and having the loggable additional data if
    954    * present.
    955    */
    956   private Ukey2Alert makeAlertMessage(Ukey2Alert.AlertType alertType,
    957       @Nullable String loggableAdditionalData) throws HandshakeException {
    958     switch (alertType) {
    959       case BAD_MESSAGE:
    960       case BAD_MESSAGE_TYPE:
    961       case INCORRECT_MESSAGE:
    962       case BAD_MESSAGE_DATA:
    963       case BAD_VERSION:
    964       case BAD_RANDOM:
    965       case BAD_HANDSHAKE_CIPHER:
    966       case BAD_NEXT_PROTOCOL:
    967       case BAD_PUBLIC_KEY:
    968       case INTERNAL_ERROR:
    969         // fall through intentional; valid alert types
    970         break;
    971       default:
    972         throwHandshakeException("Unknown alert type: " + alertType);
    973     }
    974 
    975     Ukey2Alert.Builder alert = Ukey2Alert.newBuilder();
    976     alert.setType(alertType);
    977 
    978     if (loggableAdditionalData != null) {
    979       alert.setErrorMessage(loggableAdditionalData);
    980     }
    981 
    982     return alert.build();
    983   }
    984 
    985   /**
    986    * Generates a cryptoraphically random nonce of NONCE_LENGTH_IN_BYTES bytes.
    987    */
    988   private static byte[] generateRandomNonce() {
    989     SecureRandom rng = new SecureRandom();
    990     byte[] randomNonce = new byte[NONCE_LENGTH_IN_BYTES];
    991     rng.nextBytes(randomNonce);
    992     return randomNonce;
    993   }
    994 
    995   /**
    996    * Handy wrapper to do SHA512.
    997    */
    998   private byte[] sha512(byte[] input) throws HandshakeException {
    999     MessageDigest sha512;
   1000     try {
   1001       sha512 = MessageDigest.getInstance("SHA-512");
   1002       return sha512.digest(input);
   1003     } catch (NoSuchAlgorithmException e) {
   1004       throwHandshakeException("No security provider initialized yet?", e);
   1005       return null; // unreachable in practice, but makes compiler happy
   1006     }
   1007   }
   1008 
   1009   // Exception wrappers that remember to set the handshake state to ERROR
   1010 
   1011   private void throwAlertException(Ukey2Alert.AlertType alertType, String alertLogStatement)
   1012       throws AlertException, HandshakeException {
   1013     handshakeState = InternalState.HANDSHAKE_ERROR;
   1014     throw new AlertException(alertLogStatement, makeAlertMessage(alertType, alertLogStatement));
   1015   }
   1016 
   1017   private void throwHandshakeException(String logMessage) throws HandshakeException {
   1018     handshakeState = InternalState.HANDSHAKE_ERROR;
   1019     throw new HandshakeException(logMessage);
   1020   }
   1021 
   1022   private void throwHandshakeException(Exception e) throws HandshakeException {
   1023     handshakeState = InternalState.HANDSHAKE_ERROR;
   1024     throw new HandshakeException(e);
   1025   }
   1026 
   1027   private void throwHandshakeException(String logMessage, Exception e) throws HandshakeException {
   1028     handshakeState = InternalState.HANDSHAKE_ERROR;
   1029     throw new HandshakeException(logMessage, e);
   1030   }
   1031 
   1032   private void throwIllegalStateException(String logMessage) {
   1033     handshakeState = InternalState.HANDSHAKE_ERROR;
   1034     throw new IllegalStateException(logMessage);
   1035   }
   1036 
   1037   private void throwIllegalArgumentException(String logMessage) {
   1038     handshakeState = InternalState.HANDSHAKE_ERROR;
   1039     throw new IllegalArgumentException(logMessage);
   1040   }
   1041 }
   1042