Home | History | Annotate | Download | only in cts
      1 /*
      2  * Copyright (C) 2014 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package android.security.cts;
     18 
     19 import android.content.Context;
     20 import android.test.InstrumentationTestCase;
     21 import android.util.Log;
     22 
     23 import com.android.cts.security.R;
     24 
     25 import java.io.ByteArrayInputStream;
     26 import java.io.ByteArrayOutputStream;
     27 import java.io.EOFException;
     28 import java.io.IOException;
     29 import java.io.InputStream;
     30 import java.io.OutputStream;
     31 import java.net.ServerSocket;
     32 import java.net.Socket;
     33 import java.net.SocketAddress;
     34 import java.security.KeyFactory;
     35 import java.security.Principal;
     36 import java.security.PrivateKey;
     37 import java.security.cert.CertificateException;
     38 import java.security.cert.CertificateFactory;
     39 import java.security.cert.X509Certificate;
     40 import java.security.spec.PKCS8EncodedKeySpec;
     41 import java.util.concurrent.Callable;
     42 import java.util.concurrent.ExecutionException;
     43 import java.util.concurrent.ExecutorService;
     44 import java.util.concurrent.Executors;
     45 import java.util.concurrent.Future;
     46 import java.util.concurrent.TimeUnit;
     47 
     48 import javax.net.ServerSocketFactory;
     49 import javax.net.SocketFactory;
     50 import javax.net.ssl.KeyManager;
     51 import javax.net.ssl.SSLContext;
     52 import javax.net.ssl.SSLException;
     53 import javax.net.ssl.SSLServerSocket;
     54 import javax.net.ssl.SSLSocket;
     55 import javax.net.ssl.TrustManager;
     56 import javax.net.ssl.X509KeyManager;
     57 import javax.net.ssl.X509TrustManager;
     58 
     59 /**
     60  * Tests for the OpenSSL Heartbleed vulnerability.
     61  */
     62 public class OpenSSLHeartbleedTest extends InstrumentationTestCase {
     63 
     64     // IMPLEMENTATION NOTE: This test spawns an SSLSocket client, SSLServerSocket server, and a
     65     // Man-in-The-Middle (MiTM). The client connects to the MiTM which then connects to the server
     66     // and starts forwarding all TLS records between the client and the server. In tests that check
     67     // for the Heartbleed vulnerability, the MiTM also injects a HeartbeatRequest message into the
     68     // traffic.
     69 
    // IMPLEMENTATION NOTE: This test spawns several background threads that perform network I/O
    // on localhost. To ensure that these background threads are cleaned up at the end of the test,
    // tearDown() closes the sockets they may be using. To support this, all Socket and
    // ServerSocket instances are stored in fields of this class. These fields should be accessed
    // via the synchronized setters and getters to avoid memory visibility issues due to
    // concurrency.
     75 
    // Log tag for all messages emitted by this test.
    private static final String TAG = OpenSSLHeartbleedTest.class.getSimpleName();

    // Sockets used by the server, the client and the MiTM. They are kept in fields so that
    // tearDown() can close them; access them only through the synchronized getters/setters.
    private SSLServerSocket mServerListeningSocket;
    private SSLSocket mServerSocket;
    private SSLSocket mClientSocket;
    private ServerSocket mMitmListeningSocket;
    private Socket mMitmServerSocket;
    private Socket mMitmClientSocket;
    // Runs the MiTM, server, client and MiTM server-to-client forwarding tasks; created by
    // handshake() and shut down by tearDown().
    private ExecutorService mExecutorService;

    // Observations recorded by the MiTM threads and read by the test thread via synchronized
    // accessors.
    private boolean mHeartbeatRequestWasInjected;
    // NOTE(review): "Detetected" is a typo for "Detected"; renaming also requires updating its
    // accessors.
    private boolean mHeartbeatResponseWasDetetected;
    // Description code of the first fatal alert recorded by the MiTM, or -1 if none was recorded.
    private int mFirstDetectedFatalAlertDescription = -1;
     89 
     90     @Override
     91     protected void tearDown() throws Exception {
     92         Log.i(TAG, "Tearing down");
     93         if (mExecutorService != null) {
     94             mExecutorService.shutdownNow();
     95         }
     96         closeQuietly(getServerListeningSocket());
     97         closeQuietly(getServerSocket());
     98         closeQuietly(getClientSocket());
     99         closeQuietly(getMitmListeningSocket());
    100         closeQuietly(getMitmServerSocket());
    101         closeQuietly(getMitmClientSocket());
    102         super.tearDown();
    103         Log.i(TAG, "Tear down completed");
    104     }
    105 
    106     /**
    107      * Tests that TLS handshake succeeds when the MiTM simply forwards all data without tampering
    108      * with it. This is to catch issues unrelated to TLS heartbeats.
    109      */
    110     public void testWithoutHeartbeats() throws Exception {
    111         handshake(false, false);
    112     }
    113 
    114     /**
    115      * Tests whether client sockets are vulnerable to Heartbleed.
    116      */
    117     public void testClientHeartbleed() throws Exception {
    118         checkHeartbleed(true);
    119     }
    120 
    121     /**
    122      * Tests whether server sockets are vulnerable to Heartbleed.
    123      */
    124     public void testServerHeartbleed() throws Exception {
    125         checkHeartbleed(false);
    126     }
    127 
    128     /**
    129      * Tests for Heartbleed.
    130      *
    131      * @param client {@code true} to test the client, {@code false} to test the server.
    132      */
    133     private void checkHeartbleed(boolean client) throws Exception {
    134         // IMPLEMENTATION NOTE: The MiTM is forwarding all TLS records between the client and the
    135         // server unmodified. Additionally, the MiTM transmits a malformed HeartbeatRequest to
    136         // server (if "client" argument is false) right after client's ClientKeyExchange or to
    137         // client (if "client" argument is true) right after server's ServerHello. The peer is
    138         // expected to either ignore the HeartbeatRequest (if heartbeats are supported) or to abort
    139         // the handshake with unexpected_message alert (if heartbeats are not supported).
    140         try {
    141             handshake(true, client);
    142         } catch (ExecutionException e) {
    143             assertFalse(
    144                     "SSLSocket is vulnerable to Heartbleed in " + ((client) ? "client" : "server")
    145                             + " mode",
    146                     wasHeartbeatResponseDetected());
    147             if (e.getCause() instanceof SSLException) {
    148                 // TLS handshake or data exchange failed. Check whether the error was caused by
    149                 // fatal alert unexpected_message
    150                 int alertDescription = getFirstDetectedFatalAlertDescription();
    151                 if (alertDescription == -1) {
    152                     fail("Handshake failed without a fatal alert");
    153                 }
    154                 assertEquals(
    155                         "First fatal alert description received from server",
    156                         AlertMessage.DESCRIPTION_UNEXPECTED_MESSAGE,
    157                         alertDescription);
    158                 return;
    159             } else {
    160                 throw e;
    161             }
    162         }
    163 
    164         // TLS handshake succeeded
    165         assertFalse(
    166                 "SSLSocket is vulnerable to Heartbleed in " + ((client) ? "client" : "server")
    167                         + " mode",
    168                 wasHeartbeatResponseDetected());
    169         assertTrue("HeartbeatRequest not injected", wasHeartbeatRequestInjected());
    170     }
    171 
    172     /**
    173      * Starts the client, server, and the MiTM. Makes the client and server perform a TLS handshake
    174      * and exchange application-level data. The MiTM injects a HeartbeatRequest message if requested
    175      * by {@code heartbeatRequestInjected}. The direction of the injected message is specified by
    176      * {@code injectedIntoClient}.
    177      */
    178     private void handshake(
    179             final boolean heartbeatRequestInjected,
    180             final boolean injectedIntoClient) throws Exception {
    181         mExecutorService = Executors.newFixedThreadPool(4);
    182         setServerListeningSocket(serverBind());
    183         final SocketAddress serverAddress = getServerListeningSocket().getLocalSocketAddress();
    184         Log.i(TAG, "Server bound to " + serverAddress);
    185 
    186         setMitmListeningSocket(mitmBind());
    187         final SocketAddress mitmAddress = getMitmListeningSocket().getLocalSocketAddress();
    188         Log.i(TAG, "MiTM bound to " + mitmAddress);
    189 
    190         // Start the MiTM daemon in the background
    191         mExecutorService.submit(new Callable<Void>() {
    192             @Override
    193             public Void call() throws Exception {
    194                 mitmAcceptAndForward(
    195                         serverAddress,
    196                         heartbeatRequestInjected,
    197                         injectedIntoClient);
    198                 return null;
    199             }
    200         });
    201         // Start the server in the background
    202         Future<Void> serverFuture = mExecutorService.submit(new Callable<Void>() {
    203             @Override
    204             public Void call() throws Exception {
    205                 serverAcceptAndHandshake();
    206                 return null;
    207             }
    208         });
    209         // Start the client in the background
    210         Future<Void> clientFuture = mExecutorService.submit(new Callable<Void>() {
    211             @Override
    212             public Void call() throws Exception {
    213                 clientConnectAndHandshake(mitmAddress);
    214                 return null;
    215             }
    216         });
    217 
    218         // Wait for both client and server to terminate, to ensure that we observe all the traffic
    219         // exchanged between them. Throw an exception if one of them failed.
    220         Log.i(TAG, "Waiting for client");
    221         // Wait for the client, but don't yet throw an exception if it failed.
    222         Exception clientException = null;
    223         try {
    224             clientFuture.get(10, TimeUnit.SECONDS);
    225         } catch (Exception e) {
    226             clientException = e;
    227         }
    228         Log.i(TAG, "Waiting for server");
    229         // Wait for the server and throw an exception if it failed.
    230         serverFuture.get(5, TimeUnit.SECONDS);
    231         // Throw an exception if the client failed.
    232         if (clientException != null) {
    233             throw clientException;
    234         }
    235         Log.i(TAG, "Handshake completed and application data exchanged");
    236     }
    237 
    238     private void clientConnectAndHandshake(SocketAddress serverAddress) throws Exception {
    239         SSLContext sslContext = SSLContext.getInstance("TLS");
    240         sslContext.init(
    241                 null,
    242                 new TrustManager[] {new TrustAllX509TrustManager()},
    243                 null);
    244         SSLSocket socket = (SSLSocket) sslContext.getSocketFactory().createSocket();
    245         setClientSocket(socket);
    246         try {
    247             Log.i(TAG, "Client connecting to " + serverAddress);
    248             socket.connect(serverAddress);
    249             Log.i(TAG, "Client connected to server from " + socket.getLocalSocketAddress());
    250             // Ensure a TLS handshake is performed and an exception is thrown if it fails.
    251             socket.getOutputStream().write("client".getBytes());
    252             socket.getOutputStream().flush();
    253             Log.i(TAG, "Client sent request. Reading response");
    254             int b = socket.getInputStream().read();
    255             Log.i(TAG, "Client read response: " + b);
    256         } catch (Exception e) {
    257             Log.w(TAG, "Client failed", e);
    258             throw e;
    259           } finally {
    260             socket.close();
    261         }
    262     }
    263 
    264     public SSLServerSocket serverBind() throws Exception {
    265         // Load the server's private key and cert chain
    266         KeyFactory keyFactory = KeyFactory.getInstance("RSA");
    267         PrivateKey privateKey = keyFactory.generatePrivate(new PKCS8EncodedKeySpec(
    268                 readResource(
    269                         getInstrumentation().getContext(), R.raw.openssl_heartbleed_test_key)));
    270         CertificateFactory certFactory = CertificateFactory.getInstance("X.509");
    271         X509Certificate[] certChain =  new X509Certificate[] {
    272                 (X509Certificate) certFactory.generateCertificate(
    273                         new ByteArrayInputStream(readResource(
    274                                 getInstrumentation().getContext(),
    275                                 R.raw.openssl_heartbleed_test_cert)))
    276         };
    277 
    278         // Initialize TLS context to use the private key and cert chain for server sockets
    279         SSLContext sslContext = SSLContext.getInstance("TLS");
    280         sslContext.init(
    281                 new KeyManager[] {new HardcodedCertX509KeyManager(privateKey, certChain)},
    282                 null,
    283                 null);
    284 
    285         Log.i(TAG, "Server binding to local port");
    286         return (SSLServerSocket) sslContext.getServerSocketFactory().createServerSocket(0);
    287     }
    288 
    289     private void serverAcceptAndHandshake() throws Exception {
    290         SSLSocket socket = null;
    291         SSLServerSocket serverSocket = getServerListeningSocket();
    292         try {
    293             Log.i(TAG, "Server listening for incoming connection");
    294             socket = (SSLSocket) serverSocket.accept();
    295             setServerSocket(socket);
    296             Log.i(TAG, "Server accepted connection from " + socket.getRemoteSocketAddress());
    297             // Ensure a TLS handshake is performed and an exception is thrown if it fails.
    298             socket.getOutputStream().write("server".getBytes());
    299             socket.getOutputStream().flush();
    300             Log.i(TAG, "Server sent reply. Reading response");
    301             int b = socket.getInputStream().read();
    302             Log.i(TAG, "Server read response: " + b);
    303         } catch (Exception e) {
    304           Log.w(TAG, "Server failed", e);
    305           throw e;
    306         } finally {
    307             if (socket != null) {
    308                 socket.close();
    309             }
    310         }
    311     }
    312 
    313     private ServerSocket mitmBind() throws Exception {
    314         Log.i(TAG, "MiTM binding to local port");
    315         return ServerSocketFactory.getDefault().createServerSocket(0);
    316     }
    317 
    318     /**
    319      * Accepts the connection on the MiTM listening socket, forwards the TLS records between the
    320      * client and the server, and, if requested, injects a {@code HeartbeatRequest}.
    321      *
    322      * @param injectHeartbeat whether to inject a {@code HeartbeatRequest} message.
    323      * @param injectIntoClient when {@code injectHeartbeat} is {@code true}, whether to inject the
    324      *        {@code HeartbeatRequest} message into client or into server.
    325      */
    326     private void mitmAcceptAndForward(
    327             SocketAddress serverAddress,
    328             final boolean injectHeartbeat,
    329             final boolean injectIntoClient) throws Exception {
    330         Socket clientSocket = null;
    331         Socket serverSocket = null;
    332         ServerSocket listeningSocket = getMitmListeningSocket();
    333         try {
    334             Log.i(TAG, "MiTM waiting for incoming connection");
    335             clientSocket = listeningSocket.accept();
    336             setMitmClientSocket(clientSocket);
    337             Log.i(TAG, "MiTM accepted connection from " + clientSocket.getRemoteSocketAddress());
    338             serverSocket = SocketFactory.getDefault().createSocket();
    339             setMitmServerSocket(serverSocket);
    340             Log.i(TAG, "MiTM connecting to server " + serverAddress);
    341             serverSocket.connect(serverAddress, 10000);
    342             Log.i(TAG, "MiTM connected to server from " + serverSocket.getLocalSocketAddress());
    343             final InputStream serverInputStream = serverSocket.getInputStream();
    344             final OutputStream clientOutputStream = clientSocket.getOutputStream();
    345             Future<Void> serverToClientTask = mExecutorService.submit(new Callable<Void>() {
    346                 @Override
    347                 public Void call() throws Exception {
    348                     // Inject HeatbeatRequest after ServerHello, if requested
    349                     forwardTlsRecords(
    350                             "MiTM S->C",
    351                             serverInputStream,
    352                             clientOutputStream,
    353                             (injectHeartbeat && injectIntoClient)
    354                                     ? HandshakeMessage.TYPE_SERVER_HELLO : -1);
    355                     return null;
    356                 }
    357             });
    358             // Inject HeatbeatRequest after ClientKeyExchange, if requested
    359             forwardTlsRecords(
    360                     "MiTM C->S",
    361                     clientSocket.getInputStream(),
    362                     serverSocket.getOutputStream(),
    363                     (injectHeartbeat && !injectIntoClient)
    364                             ? HandshakeMessage.TYPE_CLIENT_KEY_EXCHANGE : -1);
    365             serverToClientTask.get(10, TimeUnit.SECONDS);
    366         } catch (Exception e) {
    367             Log.w(TAG, "MiTM failed", e);
    368             throw e;
    369           } finally {
    370             closeQuietly(clientSocket);
    371             closeQuietly(serverSocket);
    372         }
    373     }
    374 
    375     /**
    376      * Forwards TLS records from the provided {@code InputStream} to the provided
    377      * {@code OutputStream}. If requested, injects a {@code HeartbeatMessage}.
    378      */
    379     private void forwardTlsRecords(
    380             String logPrefix,
    381             InputStream in,
    382             OutputStream out,
    383             int handshakeMessageTypeAfterWhichToInjectHeartbeatRequest) throws Exception {
    384         Log.i(TAG, logPrefix + ": record forwarding started");
    385         boolean interestingRecordsLogged =
    386                 handshakeMessageTypeAfterWhichToInjectHeartbeatRequest == -1;
    387         try {
    388             TlsRecordReader reader = new TlsRecordReader(in);
    389             byte[] recordBytes;
    390             // Fragments contained in records may be encrypted after a certain point in the
    391             // handshake. Once they are encrypted, this MiTM cannot inspect their plaintext which.
    392             boolean fragmentEncryptionMayBeEnabled = false;
    393             while ((recordBytes = reader.readRecord()) != null) {
    394                 TlsRecord record = TlsRecord.parse(recordBytes);
    395                 forwardTlsRecord(logPrefix,
    396                         recordBytes,
    397                         record,
    398                         fragmentEncryptionMayBeEnabled,
    399                         out,
    400                         interestingRecordsLogged,
    401                         handshakeMessageTypeAfterWhichToInjectHeartbeatRequest);
    402                 if (record.protocol == TlsProtocols.CHANGE_CIPHER_SPEC) {
    403                     fragmentEncryptionMayBeEnabled = true;
    404                 }
    405             }
    406         } catch (Exception e) {
    407             Log.w(TAG, logPrefix + ": failed", e);
    408             throw e;
    409         } finally {
    410             Log.d(TAG, logPrefix + ": record forwarding finished");
    411         }
    412     }
    413 
    414     private void forwardTlsRecord(
    415             String logPrefix,
    416             byte[] recordBytes,
    417             TlsRecord record,
    418             boolean fragmentEncryptionMayBeEnabled,
    419             OutputStream out,
    420             boolean interestingRecordsLogged,
    421             int handshakeMessageTypeAfterWhichToInjectHeartbeatRequest) throws IOException {
    422         // Save information about the records if its of interest to this test
    423         if (interestingRecordsLogged) {
    424             switch (record.protocol) {
    425                 case TlsProtocols.ALERT:
    426                     if (!fragmentEncryptionMayBeEnabled) {
    427                         AlertMessage alert = AlertMessage.tryParse(record);
    428                         if ((alert != null) && (alert.level == AlertMessage.LEVEL_FATAL)) {
    429                             setFatalAlertDetected(alert.description);
    430                         }
    431                     }
    432                     break;
    433                 case TlsProtocols.HEARTBEAT:
    434                     // When TLS records are encrypted, we cannot determine whether a
    435                     // heartbeat is a HeartbeatResponse. In our setup, the client and the
    436                     // server are not expected to sent HeartbeatRequests. Thus, we err on
    437                     // the side of caution and assume that any heartbeat message sent by
    438                     // client or server is a HeartbeatResponse.
    439                     Log.e(TAG, logPrefix
    440                             + ": heartbeat response detected -- vulnerable to Heartbleed");
    441                     setHeartbeatResponseWasDetected();
    442                     break;
    443             }
    444         }
    445 
    446         Log.i(TAG, logPrefix + ": Forwarding TLS record. "
    447                 + getRecordInfo(record, fragmentEncryptionMayBeEnabled));
    448         out.write(recordBytes);
    449         out.flush();
    450 
    451         // Inject HeartbeatRequest, if necessary, after the specified handshake message type
    452         if (handshakeMessageTypeAfterWhichToInjectHeartbeatRequest != -1) {
    453             if ((!fragmentEncryptionMayBeEnabled) && (isHandshakeMessageType(
    454                     record, handshakeMessageTypeAfterWhichToInjectHeartbeatRequest))) {
    455                 // The Heartbeat Request message below is malformed because its declared
    456                 // length of payload one byte larger than the actual payload. The peer is
    457                 // supposed to reject such messages.
    458                 byte[] payload = "arbitrary".getBytes("US-ASCII");
    459                 byte[] heartbeatRequestRecordBytes = createHeartbeatRequestRecord(
    460                         record.versionMajor,
    461                         record.versionMinor,
    462                         payload.length + 1,
    463                         payload);
    464                 Log.i(TAG, logPrefix + ": Injecting malformed HeartbeatRequest: "
    465                         + getRecordInfo(
    466                                 TlsRecord.parse(heartbeatRequestRecordBytes), false));
    467                 setHeartbeatRequestWasInjected();
    468                 out.write(heartbeatRequestRecordBytes);
    469                 out.flush();
    470             }
    471         }
    472     }
    473 
    474     private static String getRecordInfo(TlsRecord record, boolean mayBeEncrypted) {
    475         StringBuilder result = new StringBuilder();
    476         result.append(getProtocolName(record.protocol))
    477                 .append(", ")
    478                 .append(getFragmentInfo(record, mayBeEncrypted));
    479         return result.toString();
    480     }
    481 
    482     private static String getProtocolName(int protocol) {
    483         switch (protocol) {
    484             case TlsProtocols.ALERT:
    485                 return "alert";
    486             case TlsProtocols.APPLICATION_DATA:
    487                 return "application data";
    488             case TlsProtocols.CHANGE_CIPHER_SPEC:
    489                 return "change cipher spec";
    490             case TlsProtocols.HANDSHAKE:
    491                 return "handshake";
    492             case TlsProtocols.HEARTBEAT:
    493                 return "heatbeat";
    494             default:
    495                 return String.valueOf(protocol);
    496         }
    497     }
    498 
    499     private static String getFragmentInfo(TlsRecord record, boolean mayBeEncrypted) {
    500         StringBuilder result = new StringBuilder();
    501         if (mayBeEncrypted) {
    502             result.append("encrypted?");
    503         } else {
    504             switch (record.protocol) {
    505                 case TlsProtocols.ALERT:
    506                     result.append("level: " + ((record.fragment.length > 0)
    507                             ? String.valueOf(record.fragment[0] & 0xff) : "n/a")
    508                     + ", description: "
    509                     + ((record.fragment.length > 1)
    510                             ? String.valueOf(record.fragment[1] & 0xff) : "n/a"));
    511                     break;
    512                 case TlsProtocols.APPLICATION_DATA:
    513                     break;
    514                 case TlsProtocols.CHANGE_CIPHER_SPEC:
    515                     result.append("payload: " + ((record.fragment.length > 0)
    516                             ? String.valueOf(record.fragment[0] & 0xff) : "n/a"));
    517                     break;
    518                 case TlsProtocols.HANDSHAKE:
    519                     result.append("type: " + ((record.fragment.length > 0)
    520                             ? String.valueOf(record.fragment[0] & 0xff) : "n/a"));
    521                     break;
    522                 case TlsProtocols.HEARTBEAT:
    523                     result.append("type: " + ((record.fragment.length > 0)
    524                             ? String.valueOf(record.fragment[0] & 0xff) : "n/a")
    525                             + ", payload length: "
    526                             + ((record.fragment.length >= 3)
    527                                     ? String.valueOf(
    528                                             getUnsignedShortBigEndian(record.fragment, 1))
    529                                     : "n/a"));
    530                     break;
    531             }
    532         }
    533         result.append(", ").append("fragment length: " + record.fragment.length);
    534         return result.toString();
    535     }
    536 
    537     private synchronized void setServerListeningSocket(SSLServerSocket socket) {
    538         mServerListeningSocket = socket;
    539     }
    540 
    541     private synchronized SSLServerSocket getServerListeningSocket() {
    542         return mServerListeningSocket;
    543     }
    544 
    545     private synchronized void setServerSocket(SSLSocket socket) {
    546         mServerSocket = socket;
    547     }
    548 
    549     private synchronized SSLSocket getServerSocket() {
    550         return mServerSocket;
    551     }
    552 
    553     private synchronized void setClientSocket(SSLSocket socket) {
    554         mClientSocket = socket;
    555     }
    556 
    557     private synchronized SSLSocket getClientSocket() {
    558         return mClientSocket;
    559     }
    560 
    561     private synchronized void setMitmListeningSocket(ServerSocket socket) {
    562         mMitmListeningSocket = socket;
    563     }
    564 
    565     private synchronized ServerSocket getMitmListeningSocket() {
    566         return mMitmListeningSocket;
    567     }
    568 
    569     private synchronized void setMitmServerSocket(Socket socket) {
    570         mMitmServerSocket = socket;
    571     }
    572 
    573     private synchronized Socket getMitmServerSocket() {
    574         return mMitmServerSocket;
    575     }
    576 
    577     private synchronized void setMitmClientSocket(Socket socket) {
    578         mMitmClientSocket = socket;
    579     }
    580 
    581     private synchronized Socket getMitmClientSocket() {
    582         return mMitmClientSocket;
    583     }
    584 
    585     private synchronized void setHeartbeatRequestWasInjected() {
    586         mHeartbeatRequestWasInjected = true;
    587     }
    588 
    589     private synchronized boolean wasHeartbeatRequestInjected() {
    590         return mHeartbeatRequestWasInjected;
    591     }
    592 
    593     private synchronized void setHeartbeatResponseWasDetected() {
    594         mHeartbeatResponseWasDetetected = true;
    595     }
    596 
    597     private synchronized boolean wasHeartbeatResponseDetected() {
    598         return mHeartbeatResponseWasDetetected;
    599     }
    600 
    601     private synchronized void setFatalAlertDetected(int description) {
    602         if (mFirstDetectedFatalAlertDescription == -1) {
    603             mFirstDetectedFatalAlertDescription = description;
    604         }
    605     }
    606 
    607     private synchronized int getFirstDetectedFatalAlertDescription() {
    608         return mFirstDetectedFatalAlertDescription;
    609     }
    610 
    611     public static abstract class TlsProtocols {
    612         public static final int CHANGE_CIPHER_SPEC = 20;
    613         public static final int ALERT = 21;
    614         public static final int HANDSHAKE = 22;
    615         public static final int APPLICATION_DATA = 23;
    616         public static final int HEARTBEAT = 24;
    617         private TlsProtocols() {}
    618     }
    619 
    620     public static class TlsRecord {
    621         public int protocol;
    622         public int versionMajor;
    623         public int versionMinor;
    624         public byte[] fragment;
    625 
    626         public static TlsRecord parse(byte[] record) throws IOException {
    627             TlsRecord result = new TlsRecord();
    628             if (record.length < TlsRecordReader.RECORD_HEADER_LENGTH) {
    629                 throw new IOException("Record too short: " + record.length);
    630             }
    631             result.protocol = record[0] & 0xff;
    632             result.versionMajor = record[1] & 0xff;
    633             result.versionMinor = record[2] & 0xff;
    634             int fragmentLength = getUnsignedShortBigEndian(record, 3);
    635             int actualFragmentLength = record.length - TlsRecordReader.RECORD_HEADER_LENGTH;
    636             if (fragmentLength != actualFragmentLength) {
    637                 throw new IOException("Fragment length mismatch. Expected: " + fragmentLength
    638                         + ", actual: " + actualFragmentLength);
    639             }
    640             result.fragment = new byte[fragmentLength];
    641             System.arraycopy(
    642                     record, TlsRecordReader.RECORD_HEADER_LENGTH,
    643                     result.fragment, 0,
    644                     fragmentLength);
    645             return result;
    646         }
    647 
    648         public static byte[] unparse(TlsRecord record) {
    649             byte[] result = new byte[TlsRecordReader.RECORD_HEADER_LENGTH + record.fragment.length];
    650             result[0] = (byte) record.protocol;
    651             result[1] = (byte) record.versionMajor;
    652             result[2] = (byte) record.versionMinor;
    653             putUnsignedShortBigEndian(result, 3, record.fragment.length);
    654             System.arraycopy(
    655                     record.fragment, 0,
    656                     result, TlsRecordReader.RECORD_HEADER_LENGTH,
    657                     record.fragment.length);
    658             return result;
    659         }
    660     }
    661 
    662     public static final boolean isHandshakeMessageType(TlsRecord record, int type) {
    663         HandshakeMessage handshake = HandshakeMessage.tryParse(record);
    664         if (handshake == null) {
    665             return false;
    666         }
    667         return handshake.type == type;
    668     }
    669 
    670     public static class HandshakeMessage {
    671         public static final int TYPE_SERVER_HELLO = 2;
    672         public static final int TYPE_CERTIFICATE = 11;
    673         public static final int TYPE_CLIENT_KEY_EXCHANGE = 16;
    674 
    675         public int type;
    676 
    677         /**
    678          * Parses the provided TLS record as a handshake message.
    679          *
    680          * @return alert message or {@code null} if the record does not contain a handshake message.
    681          */
    682         public static HandshakeMessage tryParse(TlsRecord record) {
    683             if (record.protocol != TlsProtocols.HANDSHAKE) {
    684                 return null;
    685             }
    686             if (record.fragment.length < 1) {
    687                 return null;
    688             }
    689             HandshakeMessage result = new HandshakeMessage();
    690             result.type = record.fragment[0] & 0xff;
    691             return result;
    692         }
    693     }
    694 
    695     public static class AlertMessage {
    696         public static final int LEVEL_FATAL = 2;
    697         public static final int DESCRIPTION_UNEXPECTED_MESSAGE = 10;
    698 
    699         public int level;
    700         public int description;
    701 
    702         /**
    703          * Parses the provided TLS record as an alert message.
    704          *
    705          * @return alert message or {@code null} if the record does not contain an alert message.
    706          */
    707         public static AlertMessage tryParse(TlsRecord record) {
    708             if (record.protocol != TlsProtocols.ALERT) {
    709                 return null;
    710             }
    711             if (record.fragment.length < 2) {
    712                 return null;
    713             }
    714             AlertMessage result = new AlertMessage();
    715             result.level = record.fragment[0] & 0xff;
    716             result.description = record.fragment[1] & 0xff;
    717             return result;
    718         }
    719     }
    720 
    721     private static abstract class HeartbeatProtocol {
    722         private HeartbeatProtocol() {}
    723 
    724         private static final int MESSAGE_TYPE_REQUEST = 1;
    725         @SuppressWarnings("unused")
    726         private static final int MESSAGE_TYPE_RESPONSE = 2;
    727 
    728         private static final int MESSAGE_HEADER_LENGTH = 3;
    729         private static final int MESSAGE_PADDING_LENGTH = 16;
    730     }
    731 
    732     private static byte[] createHeartbeatRequestRecord(
    733             int versionMajor, int versionMinor,
    734             int declaredPayloadLength, byte[] payload) {
    735 
    736         byte[] fragment = new byte[HeartbeatProtocol.MESSAGE_HEADER_LENGTH
    737                 + payload.length + HeartbeatProtocol.MESSAGE_PADDING_LENGTH];
    738         fragment[0] = HeartbeatProtocol.MESSAGE_TYPE_REQUEST;
    739         putUnsignedShortBigEndian(fragment, 1, declaredPayloadLength); // payload_length
    740         TlsRecord record = new TlsRecord();
    741         record.protocol = TlsProtocols.HEARTBEAT;
    742         record.versionMajor = versionMajor;
    743         record.versionMinor = versionMinor;
    744         record.fragment = fragment;
    745         return TlsRecord.unparse(record);
    746     }
    747 
    748     /**
    749      * Reader of TLS records.
    750      */
    751     public static class TlsRecordReader {
    752         private static final int MAX_RECORD_LENGTH = 16384;
    753         public static final int RECORD_HEADER_LENGTH = 5;
    754 
    755         private final InputStream in;
    756         private final byte[] buffer;
    757         private int firstBufferedByteOffset;
    758         private int bufferedByteCount;
    759 
    760         public TlsRecordReader(InputStream in) {
    761             this.in = in;
    762             buffer = new byte[MAX_RECORD_LENGTH];
    763         }
    764 
    765         /**
    766          * Reads the next TLS record.
    767          *
    768          * @return TLS record or {@code null} if EOF was encountered before any bytes of a record
    769          *         could be read.
    770          */
    771         public byte[] readRecord() throws IOException {
    772             // Ensure that a TLS record header (or more) is in the buffer.
    773             if (bufferedByteCount < RECORD_HEADER_LENGTH) {
    774                 boolean eofPermittedInstead = (bufferedByteCount == 0);
    775                 boolean eofEncounteredInstead =
    776                         !readAtLeast(RECORD_HEADER_LENGTH, eofPermittedInstead);
    777                 if (eofEncounteredInstead) {
    778                     // End of stream reached exactly before a TLS record start.
    779                     return null;
    780                 }
    781             }
    782 
    783             // TLS record header (or more) is in the buffer.
    784             // Ensure that the rest of the record is in the buffer.
    785             int fragmentLength = getUnsignedShortBigEndian(buffer, firstBufferedByteOffset + 3);
    786             int recordLength = RECORD_HEADER_LENGTH + fragmentLength;
    787             if (recordLength > MAX_RECORD_LENGTH) {
    788                 throw new IOException("TLS record too long: " + recordLength);
    789             }
    790             if (bufferedByteCount < recordLength) {
    791                 readAtLeast(recordLength - bufferedByteCount, false);
    792             }
    793 
    794             // TLS record (or more) is in the buffer.
    795             byte[] record = new byte[recordLength];
    796             System.arraycopy(buffer, firstBufferedByteOffset, record, 0, recordLength);
    797             firstBufferedByteOffset += recordLength;
    798             bufferedByteCount -= recordLength;
    799             return record;
    800         }
    801 
    802         /**
    803          * Reads at least the specified number of bytes from the underlying {@code InputStream} into
    804          * the {@code buffer}.
    805          *
    806          * <p>Bytes buffered but not yet returned to the client in the {@code buffer} are relocated
    807          * to the start of the buffer to make space if necessary.
    808          *
    809          * @param eofPermittedInstead {@code true} if it's permitted for an EOF to be encountered
    810          *        without any bytes having been read.
    811          *
    812          * @return {@code true} if the requested number of bytes (or more) has been read,
    813          *         {@code false} if {@code eofPermittedInstead} was {@code true} and EOF was
    814          *         encountered when no bytes have yet been read.
    815          */
    816         private boolean readAtLeast(int size, boolean eofPermittedInstead) throws IOException {
    817             ensureRemainingBufferCapacityAtLeast(size);
    818             boolean firstAttempt = true;
    819             while (size > 0) {
    820                 int chunkSize = in.read(
    821                         buffer,
    822                         firstBufferedByteOffset + bufferedByteCount,
    823                         buffer.length - (firstBufferedByteOffset + bufferedByteCount));
    824                 if (chunkSize == -1) {
    825                     if ((firstAttempt) && (eofPermittedInstead)) {
    826                         return false;
    827                     } else {
    828                         throw new EOFException("Premature EOF");
    829                     }
    830                 }
    831                 firstAttempt = false;
    832                 bufferedByteCount += chunkSize;
    833                 size -= chunkSize;
    834             }
    835             return true;
    836         }
    837 
    838         /**
    839          * Ensures that there is enough capacity in the buffer to store the specified number of
    840          * bytes at the {@code firstBufferedByteOffset + bufferedByteCount} offset.
    841          */
    842         private void ensureRemainingBufferCapacityAtLeast(int size) throws IOException {
    843             int bufferCapacityRemaining =
    844                     buffer.length - (firstBufferedByteOffset + bufferedByteCount);
    845             if (bufferCapacityRemaining >= size) {
    846                 return;
    847             }
    848             // Insufficient capacity at the end of the buffer.
    849             if (firstBufferedByteOffset > 0) {
    850                 // Some of the bytes at the start of the buffer have already been returned to the
    851                 // client of this reader. Check if moving the remaining buffered bytes to the start
    852                 // of the buffer will make enough space at the end of the buffer.
    853                 bufferCapacityRemaining += firstBufferedByteOffset;
    854                 if (bufferCapacityRemaining >= size) {
    855                     System.arraycopy(buffer, firstBufferedByteOffset, buffer, 0, bufferedByteCount);
    856                     firstBufferedByteOffset = 0;
    857                     return;
    858                 }
    859             }
    860 
    861             throw new IOException("Insuffucient remaining capacity in the buffer. Requested: "
    862                     + size + ", remaining: " + bufferCapacityRemaining);
    863         }
    864     }
    865 
    866     private static int getUnsignedShortBigEndian(byte[] buf, int offset) {
    867         return ((buf[offset] & 0xff) << 8) | (buf[offset + 1] & 0xff);
    868     }
    869 
    870     private static void putUnsignedShortBigEndian(byte[] buf, int offset, int value) {
    871         buf[offset] = (byte) ((value >>> 8) & 0xff);
    872         buf[offset + 1] = (byte) (value & 0xff);
    873     }
    874 
    // IMPLEMENTATION NOTE: We can't implement just one closeQuietly(Closeable) because on some
    // older Android platforms Socket did not implement the Closeable interface. To make this patch
    // easy to apply to these older platforms, we declare every variant of closeQuietly that is
    // needed without relying on the Closeable interface.
    879 
    880     private static void closeQuietly(InputStream in) {
    881         if (in != null) {
    882             try {
    883                 in.close();
    884             } catch (IOException ignored) {}
    885         }
    886     }
    887 
    888     public static void closeQuietly(ServerSocket socket) {
    889         if (socket != null) {
    890             try {
    891                 socket.close();
    892             } catch (IOException ignored) {}
    893         }
    894     }
    895 
    896     public static void closeQuietly(Socket socket) {
    897         if (socket != null) {
    898             try {
    899                 socket.close();
    900             } catch (IOException ignored) {}
    901         }
    902     }
    903 
    904     public static byte[] readResource(Context context, int resId) throws IOException {
    905         ByteArrayOutputStream result = new ByteArrayOutputStream();
    906         InputStream in = null;
    907         byte[] buf = new byte[16 * 1024];
    908         try {
    909             in = context.getResources().openRawResource(resId);
    910             int chunkSize;
    911             while ((chunkSize = in.read(buf)) != -1) {
    912                 result.write(buf, 0, chunkSize);
    913             }
    914             return result.toByteArray();
    915         } finally {
    916             closeQuietly(in);
    917         }
    918     }
    919 
    920     /**
    921      * {@link X509TrustManager} which trusts all certificate chains.
    922      */
    923     public static class TrustAllX509TrustManager implements X509TrustManager {
    924         @Override
    925         public void checkClientTrusted(X509Certificate[] chain, String authType)
    926                 throws CertificateException {
    927         }
    928 
    929         @Override
    930         public void checkServerTrusted(X509Certificate[] chain, String authType)
    931                 throws CertificateException {
    932         }
    933 
    934         @Override
    935         public X509Certificate[] getAcceptedIssuers() {
    936             return new X509Certificate[0];
    937         }
    938     }
    939 
    940     /**
    941      * {@link X509KeyManager} which uses the provided private key and cert chain for all sockets.
    942      */
    943     public static class HardcodedCertX509KeyManager implements X509KeyManager {
    944 
    945         private final PrivateKey mPrivateKey;
    946         private final X509Certificate[] mCertChain;
    947 
    948         HardcodedCertX509KeyManager(PrivateKey privateKey, X509Certificate[] certChain) {
    949             mPrivateKey = privateKey;
    950             mCertChain = certChain;
    951         }
    952 
    953         @Override
    954         public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) {
    955             return null;
    956         }
    957 
    958         @Override
    959         public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) {
    960             return "singleton";
    961         }
    962 
    963         @Override
    964         public X509Certificate[] getCertificateChain(String alias) {
    965             return mCertChain;
    966         }
    967 
    968         @Override
    969         public String[] getClientAliases(String keyType, Principal[] issuers) {
    970             return null;
    971         }
    972 
    973         @Override
    974         public PrivateKey getPrivateKey(String alias) {
    975             return mPrivateKey;
    976         }
    977 
    978         @Override
    979         public String[] getServerAliases(String keyType, Principal[] issuers) {
    980             return new String[] {"singleton"};
    981         }
    982     }
    983 }
    984