1 /* 2 * Copyright (C) 2015 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package org.conscrypt; 18 19 import static org.junit.Assert.assertEquals; 20 import static org.junit.Assert.assertFalse; 21 22 import java.io.FileNotFoundException; 23 import java.io.IOException; 24 import java.io.InputStream; 25 import java.lang.reflect.Method; 26 import java.net.ServerSocket; 27 import java.nio.ByteBuffer; 28 import java.nio.charset.Charset; 29 import java.security.NoSuchAlgorithmException; 30 import java.security.Provider; 31 import java.security.Security; 32 import javax.net.ssl.SSLContext; 33 import javax.net.ssl.SSLEngine; 34 import javax.net.ssl.SSLEngineResult; 35 import javax.net.ssl.SSLException; 36 import javax.net.ssl.SSLServerSocketFactory; 37 import javax.net.ssl.SSLSocketFactory; 38 import libcore.io.Streams; 39 import libcore.java.security.TestKeyStore; 40 41 /** 42 * Utility methods to support testing. 43 */ 44 public final class TestUtils { 45 static final Charset UTF_8 = Charset.forName("UTF-8"); 46 47 private static final Provider JDK_PROVIDER = getDefaultTlsProvider(); 48 private static final byte[] CHARS = 49 "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".getBytes(UTF_8); 50 private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0); 51 52 public static final String PROTOCOL_TLS_V1_2 = "TLSv1.2"; 53 public static final String PROVIDER_PROPERTY = "SSLContext.TLSv1.2"; 54 public static final String LOCALHOST = "localhost"; 55 56 static final String TEST_CIPHER = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"; 57 58 private TestUtils() {} 59 60 private static Provider getDefaultTlsProvider() { 61 for (Provider p : Security.getProviders()) { 62 if (p.get(PROVIDER_PROPERTY) != null) { 63 return p; 64 } 65 } 66 throw new RuntimeException("Unable to find a default provider for " + PROVIDER_PROPERTY); 67 } 68 69 static Provider getJdkProvider() { 70 return JDK_PROVIDER; 71 } 72 73 public static Provider getConscryptProvider() { 74 try { 75 return (Provider) conscryptClass("OpenSSLProvider") 76 .getConstructor() 77 .newInstance(); 78 } catch (Exception e) { 79 throw new RuntimeException(e); 80 } 81 } 82 83 public static void installConscryptAsDefaultProvider() { 84 final Provider conscryptProvider = getConscryptProvider(); 85 synchronized (getConscryptProvider()) { 86 Provider[] providers = Security.getProviders(); 87 if (providers.length == 0 || !providers[0].equals(conscryptProvider)) { 88 Security.insertProviderAt(conscryptProvider, 1); 89 return; 90 } 91 } 92 } 93 94 public static InputStream openTestFile(String name) throws FileNotFoundException { 95 InputStream is = TestUtils.class.getResourceAsStream("/" + name); 96 if (is == null) { 97 throw new FileNotFoundException(name); 98 } 99 return is; 100 } 101 102 public static byte[] readTestFile(String name) throws IOException { 103 return Streams.readFully(openTestFile(name)); 104 } 105 106 /** 107 * Looks up the conscrypt class for the given simple name (i.e. no package prefix). 108 */ 109 public static Class<?> conscryptClass(String simpleName) throws ClassNotFoundException { 110 ClassNotFoundException ex = null; 111 for (String packageName : new String[]{"com.android.org.conscrypt", "org.conscrypt"}) { 112 String name = packageName + "." + simpleName; 113 try { 114 return Class.forName(name); 115 } catch (ClassNotFoundException e) { 116 ex = e; 117 } 118 } 119 throw ex; 120 } 121 122 /** 123 * Returns an array containing only {@link #PROTOCOL_TLS_V1_2}. 124 */ 125 public static String[] getProtocols() { 126 return new String[] {PROTOCOL_TLS_V1_2}; 127 } 128 129 public static SSLSocketFactory getJdkSocketFactory() { 130 return getSocketFactory(JDK_PROVIDER); 131 } 132 133 public static SSLServerSocketFactory getJdkServerSocketFactory() { 134 return getServerSocketFactory(JDK_PROVIDER); 135 } 136 137 static SSLSocketFactory setUseEngineSocket(SSLSocketFactory conscryptFactory, boolean useEngineSocket) { 138 try { 139 Class<?> clazz = conscryptClass("Conscrypt$SocketFactories"); 140 Method method = clazz.getMethod("setUseEngineSocket", SSLSocketFactory.class, boolean.class); 141 method.invoke(null, conscryptFactory, useEngineSocket); 142 return conscryptFactory; 143 } catch (Exception e) { 144 throw new RuntimeException(e); 145 } 146 } 147 148 static SSLServerSocketFactory setUseEngineSocket(SSLServerSocketFactory conscryptFactory, boolean useEngineSocket) { 149 try { 150 Class<?> clazz = conscryptClass("Conscrypt$ServerSocketFactories"); 151 Method method = clazz.getMethod("setUseEngineSocket", SSLServerSocketFactory.class, boolean.class); 152 method.invoke(null, conscryptFactory, useEngineSocket); 153 return conscryptFactory; 154 } catch (Exception e) { 155 throw new RuntimeException(e); 156 } 157 } 158 159 public static SSLSocketFactory getConscryptSocketFactory(boolean useEngineSocket) { 160 return setUseEngineSocket(getSocketFactory(getConscryptProvider()), useEngineSocket); 161 } 162 163 public static SSLServerSocketFactory getConscryptServerSocketFactory(boolean useEngineSocket) { 164 return setUseEngineSocket(getServerSocketFactory(getConscryptProvider()), useEngineSocket); 165 } 166 167 private static SSLSocketFactory getSocketFactory(Provider provider) { 168 SSLContext clientContext = initClientSslContext(newContext(provider)); 169 return clientContext.getSocketFactory(); 170 } 171 172 private static SSLServerSocketFactory getServerSocketFactory(Provider provider) { 173 SSLContext serverContext = initServerSslContext(newContext(provider)); 174 return serverContext.getServerSocketFactory(); 175 } 176 177 private static SSLContext newContext(Provider provider) { 178 try { 179 return SSLContext.getInstance("TLS", provider); 180 } catch (NoSuchAlgorithmException e) { 181 throw new RuntimeException(e); 182 } 183 } 184 185 /** 186 * Picks a port that is not used right at this moment. 187 * Warning: Not thread safe. May see "BindException: Address already in use: bind" if using the 188 * returned port to create a new server socket when other threads/processes are concurrently 189 * creating new sockets without a specific port. 190 */ 191 public static int pickUnusedPort() { 192 try { 193 ServerSocket serverSocket = new ServerSocket(0); 194 int port = serverSocket.getLocalPort(); 195 serverSocket.close(); 196 return port; 197 } catch (IOException e) { 198 throw new RuntimeException(e); 199 } 200 } 201 202 /** 203 * Creates a text message of the given length. 204 */ 205 public static byte[] newTextMessage(int length) { 206 byte[] msg = new byte[length]; 207 for (int msgIndex = 0; msgIndex < length;) { 208 int remaining = length - msgIndex; 209 int numChars = Math.min(remaining, CHARS.length); 210 System.arraycopy(CHARS, 0, msg, msgIndex, numChars); 211 msgIndex += numChars; 212 } 213 return msg; 214 } 215 216 /** 217 * Initializes the given engine with the cipher and client mode. 218 */ 219 static SSLEngine initEngine(SSLEngine engine, String cipher, boolean client) { 220 engine.setEnabledProtocols(getProtocols()); 221 engine.setEnabledCipherSuites(new String[] {cipher}); 222 engine.setUseClientMode(client); 223 return engine; 224 } 225 226 static SSLContext newClientSslContext(Provider provider) { 227 SSLContext context = newContext(provider); 228 return initClientSslContext(context); 229 } 230 231 static SSLContext newServerSslContext(Provider provider) { 232 SSLContext context = newContext(provider); 233 return initServerSslContext(context); 234 } 235 236 /** 237 * Initializes the given client-side {@code context} with a default cert. 238 */ 239 public static SSLContext initClientSslContext(SSLContext context) { 240 return initSslContext(context, TestKeyStore.getClient()); 241 } 242 243 /** 244 * Initializes the given server-side {@code context} with the given cert chain and private key. 245 */ 246 public static SSLContext initServerSslContext(SSLContext context) { 247 return initSslContext(context, TestKeyStore.getServer()); 248 } 249 250 /** 251 * Initializes the given {@code context} from the {@code keyStore}. 252 */ 253 static SSLContext initSslContext(SSLContext context, TestKeyStore keyStore) { 254 try { 255 context.init(keyStore.keyManagers, keyStore.trustManagers, null); 256 return context; 257 } catch (Exception e) { 258 throw new RuntimeException(e); 259 } 260 } 261 262 /** 263 * Performs the intial TLS handshake between the two {@link SSLEngine} instances. 264 */ 265 public static void doEngineHandshake(SSLEngine clientEngine, SSLEngine serverEngine, 266 ByteBuffer clientAppBuffer, ByteBuffer clientPacketBuffer, ByteBuffer serverAppBuffer, 267 ByteBuffer serverPacketBuffer) throws SSLException { 268 clientEngine.beginHandshake(); 269 serverEngine.beginHandshake(); 270 271 SSLEngineResult clientResult; 272 SSLEngineResult serverResult; 273 274 boolean clientHandshakeFinished = false; 275 boolean serverHandshakeFinished = false; 276 277 do { 278 int cTOsPos = clientPacketBuffer.position(); 279 int sTOcPos = serverPacketBuffer.position(); 280 281 clientResult = clientEngine.wrap(EMPTY_BUFFER, clientPacketBuffer); 282 runDelegatedTasks(clientResult, clientEngine); 283 serverResult = serverEngine.wrap(EMPTY_BUFFER, serverPacketBuffer); 284 runDelegatedTasks(serverResult, serverEngine); 285 286 // Verify that the consumed and produced number match what is in the buffers now. 287 assertEquals(0, clientResult.bytesConsumed()); 288 assertEquals(0, serverResult.bytesConsumed()); 289 assertEquals(clientPacketBuffer.position() - cTOsPos, clientResult.bytesProduced()); 290 assertEquals(serverPacketBuffer.position() - sTOcPos, serverResult.bytesProduced()); 291 292 clientPacketBuffer.flip(); 293 serverPacketBuffer.flip(); 294 295 // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED 296 if (isHandshakeFinished(clientResult)) { 297 assertFalse(clientHandshakeFinished); 298 clientHandshakeFinished = true; 299 } 300 if (isHandshakeFinished(serverResult)) { 301 assertFalse(serverHandshakeFinished); 302 serverHandshakeFinished = true; 303 } 304 305 cTOsPos = clientPacketBuffer.position(); 306 sTOcPos = serverPacketBuffer.position(); 307 308 int clientAppReadBufferPos = clientAppBuffer.position(); 309 int serverAppReadBufferPos = serverAppBuffer.position(); 310 311 clientResult = clientEngine.unwrap(serverPacketBuffer, clientAppBuffer); 312 runDelegatedTasks(clientResult, clientEngine); 313 serverResult = serverEngine.unwrap(clientPacketBuffer, serverAppBuffer); 314 runDelegatedTasks(serverResult, serverEngine); 315 316 // Verify that the consumed and produced number match what is in the buffers now. 317 assertEquals(serverPacketBuffer.position() - sTOcPos, clientResult.bytesConsumed()); 318 assertEquals(clientPacketBuffer.position() - cTOsPos, serverResult.bytesConsumed()); 319 assertEquals(clientAppBuffer.position() - clientAppReadBufferPos, 320 clientResult.bytesProduced()); 321 assertEquals(serverAppBuffer.position() - serverAppReadBufferPos, 322 serverResult.bytesProduced()); 323 324 clientPacketBuffer.compact(); 325 serverPacketBuffer.compact(); 326 327 // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED 328 if (isHandshakeFinished(clientResult)) { 329 assertFalse(clientHandshakeFinished); 330 clientHandshakeFinished = true; 331 } 332 if (isHandshakeFinished(serverResult)) { 333 assertFalse(serverHandshakeFinished); 334 serverHandshakeFinished = true; 335 } 336 } while (!clientHandshakeFinished || !serverHandshakeFinished); 337 } 338 339 private static boolean isHandshakeFinished(SSLEngineResult result) { 340 return result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED; 341 } 342 343 private static void runDelegatedTasks(SSLEngineResult result, SSLEngine engine) { 344 if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) { 345 for (;;) { 346 Runnable task = engine.getDelegatedTask(); 347 if (task == null) { 348 break; 349 } 350 task.run(); 351 } 352 } 353 } 354 } 355