1 /* 2 * Copyright (C) 2015 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package org.conscrypt; 18 19 import static org.junit.Assert.assertEquals; 20 import static org.junit.Assert.assertFalse; 21 22 import java.io.FileNotFoundException; 23 import java.io.IOException; 24 import java.io.InputStream; 25 import java.lang.reflect.Method; 26 import java.net.InetAddress; 27 import java.net.ServerSocket; 28 import java.net.UnknownHostException; 29 import java.nio.ByteBuffer; 30 import java.nio.charset.Charset; 31 import java.security.NoSuchAlgorithmException; 32 import java.security.Provider; 33 import java.security.Security; 34 import java.util.ArrayList; 35 import java.util.Arrays; 36 import java.util.Iterator; 37 import java.util.LinkedHashSet; 38 import java.util.List; 39 import java.util.Set; 40 import javax.net.ssl.SSLContext; 41 import javax.net.ssl.SSLEngine; 42 import javax.net.ssl.SSLEngineResult; 43 import javax.net.ssl.SSLException; 44 import javax.net.ssl.SSLParameters; 45 import javax.net.ssl.SSLServerSocketFactory; 46 import javax.net.ssl.SSLSocketFactory; 47 import libcore.io.Streams; 48 import org.bouncycastle.jce.provider.BouncyCastleProvider; 49 import org.conscrypt.java.security.TestKeyStore; 50 import org.junit.Assume; 51 52 /** 53 * Utility methods to support testing. 54 */ 55 public final class TestUtils { 56 public static final Charset UTF_8 = Charset.forName("UTF-8"); 57 private static final String PROTOCOL_TLS_V1_2 = "TLSv1.2"; 58 private static final String PROTOCOL_TLS_V1_1 = "TLSv1.1"; 59 private static final String PROTOCOL_TLS_V1 = "TLSv1"; 60 private static final String[] DESIRED_PROTOCOLS = 61 new String[] {PROTOCOL_TLS_V1_2, PROTOCOL_TLS_V1_1, /* For Java 6 */ PROTOCOL_TLS_V1}; 62 private static final Provider JDK_PROVIDER = getDefaultTlsProvider(); 63 private static final byte[] CHARS = 64 "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".getBytes(UTF_8); 65 private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0); 66 private static final String[] PROTOCOLS = getProtocolsInternal(); 67 68 static final String TEST_CIPHER = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"; 69 70 private TestUtils() {} 71 72 private static Provider getDefaultTlsProvider() { 73 for (String protocol : DESIRED_PROTOCOLS) { 74 for (Provider p : Security.getProviders()) { 75 if (hasProtocol(p, protocol)) { 76 return p; 77 } 78 } 79 } 80 // For Java 1.6 testing 81 return new BouncyCastleProvider(); 82 } 83 84 private static boolean hasProtocol(Provider p, String protocol) { 85 return p.get("SSLContext." + protocol) != null; 86 } 87 88 static Provider getJdkProvider() { 89 return JDK_PROVIDER; 90 } 91 92 private static void assumeClassAvailable(String classname) { 93 boolean available = false; 94 try { 95 Class.forName(classname); 96 available = true; 97 } catch (ClassNotFoundException ignore) { 98 // Ignored 99 } 100 Assume.assumeTrue("Skipping test: " + classname + " unavailable", available); 101 } 102 103 public static void assumeSNIHostnameAvailable() { 104 assumeClassAvailable("javax.net.ssl.SNIHostName"); 105 } 106 107 public static void assumeSetEndpointIdentificationAlgorithmAvailable() { 108 boolean supported = false; 109 try { 110 SSLParameters.class.getMethod("setEndpointIdentificationAlgorithm", String.class); 111 supported = true; 112 } catch (NoSuchMethodException ignore) { 113 // Ignored 114 } 115 Assume.assumeTrue("Skipping test: " 116 + "SSLParameters.setEndpointIdentificationAlgorithm unavailable", supported); 117 } 118 119 public static void assumeAEADAvailable() { 120 assumeClassAvailable("javax.crypto.AEADBadTagException"); 121 } 122 123 private static boolean isAndroid() { 124 try { 125 Class.forName("android.app.Application", false, ClassLoader.getSystemClassLoader()); 126 return true; 127 } catch (Throwable ignored) { 128 // Failed to load the class uniquely available in Android. 129 return false; 130 } 131 } 132 133 public static void assumeAndroid() { 134 Assume.assumeTrue(isAndroid()); 135 } 136 137 public static void assumeAllowsUnsignedCrypto() { 138 // The Oracle JRE disallows loading crypto providers from unsigned jars 139 Assume.assumeTrue(isAndroid() 140 || !System.getProperty("java.vm.name").contains("HotSpot")); 141 } 142 143 public static InetAddress getLoopbackAddress() { 144 try { 145 Method method = InetAddress.class.getMethod("getLoopbackAddress"); 146 return (InetAddress) method.invoke(null); 147 } catch (Exception ignore) { 148 // Ignored. 149 } 150 try { 151 return InetAddress.getLocalHost(); 152 } catch (UnknownHostException e) { 153 throw new RuntimeException(e); 154 } 155 } 156 157 public static Provider getConscryptProvider() { 158 try { 159 return (Provider) conscryptClass("OpenSSLProvider").getConstructor().newInstance(); 160 } catch (Exception e) { 161 throw new RuntimeException(e); 162 } 163 } 164 165 public static synchronized void installConscryptAsDefaultProvider() { 166 final Provider conscryptProvider = getConscryptProvider(); 167 Provider[] providers = Security.getProviders(); 168 if (providers.length == 0 || !providers[0].equals(conscryptProvider)) { 169 Security.insertProviderAt(conscryptProvider, 1); 170 } 171 } 172 173 public static InputStream openTestFile(String name) throws FileNotFoundException { 174 InputStream is = TestUtils.class.getResourceAsStream("/" + name); 175 if (is == null) { 176 throw new FileNotFoundException(name); 177 } 178 return is; 179 } 180 181 public static byte[] readTestFile(String name) throws IOException { 182 return Streams.readFully(openTestFile(name)); 183 } 184 185 /** 186 * Looks up the conscrypt class for the given simple name (i.e. no package prefix). 187 */ 188 public static Class<?> conscryptClass(String simpleName) throws ClassNotFoundException { 189 ClassNotFoundException ex = null; 190 for (String packageName : new String[] {"org.conscrypt", "com.android.org.conscrypt"}) { 191 String name = packageName + "." + simpleName; 192 try { 193 return Class.forName(name); 194 } catch (ClassNotFoundException e) { 195 ex = e; 196 } 197 } 198 throw ex; 199 } 200 201 /** 202 * Returns an array containing only {@link #PROTOCOL_TLS_V1_2}. 203 */ 204 public static String[] getProtocols() { 205 return PROTOCOLS; 206 } 207 208 private static String[] getProtocolsInternal() { 209 List<String> protocols = new ArrayList<String>(); 210 for (String protocol : DESIRED_PROTOCOLS) { 211 if (hasProtocol(getJdkProvider(), protocol)) { 212 protocols.add(protocol); 213 } 214 } 215 return protocols.toArray(new String[protocols.size()]); 216 } 217 218 public static SSLSocketFactory getJdkSocketFactory() { 219 return getSocketFactory(JDK_PROVIDER); 220 } 221 222 public static SSLServerSocketFactory getJdkServerSocketFactory() { 223 return getServerSocketFactory(JDK_PROVIDER); 224 } 225 226 static SSLSocketFactory setUseEngineSocket( 227 SSLSocketFactory conscryptFactory, boolean useEngineSocket) { 228 try { 229 Class<?> clazz = conscryptClass("Conscrypt"); 230 Method method = 231 clazz.getMethod("setUseEngineSocket", SSLSocketFactory.class, boolean.class); 232 method.invoke(null, conscryptFactory, useEngineSocket); 233 return conscryptFactory; 234 } catch (Exception e) { 235 throw new RuntimeException(e); 236 } 237 } 238 239 static SSLServerSocketFactory setUseEngineSocket( 240 SSLServerSocketFactory conscryptFactory, boolean useEngineSocket) { 241 try { 242 Class<?> clazz = conscryptClass("Conscrypt"); 243 Method method = clazz.getMethod( 244 "setUseEngineSocket", SSLServerSocketFactory.class, boolean.class); 245 method.invoke(null, conscryptFactory, useEngineSocket); 246 return conscryptFactory; 247 } catch (Exception e) { 248 throw new RuntimeException(e); 249 } 250 } 251 252 public static SSLSocketFactory getConscryptSocketFactory(boolean useEngineSocket) { 253 return setUseEngineSocket(getSocketFactory(getConscryptProvider()), useEngineSocket); 254 } 255 256 public static SSLServerSocketFactory getConscryptServerSocketFactory(boolean useEngineSocket) { 257 return setUseEngineSocket(getServerSocketFactory(getConscryptProvider()), useEngineSocket); 258 } 259 260 private static SSLSocketFactory getSocketFactory(Provider provider) { 261 SSLContext clientContext = initClientSslContext(newContext(provider)); 262 return clientContext.getSocketFactory(); 263 } 264 265 private static SSLServerSocketFactory getServerSocketFactory(Provider provider) { 266 SSLContext serverContext = initServerSslContext(newContext(provider)); 267 return serverContext.getServerSocketFactory(); 268 } 269 270 static SSLContext newContext(Provider provider) { 271 try { 272 return SSLContext.getInstance("TLS", provider); 273 } catch (NoSuchAlgorithmException e) { 274 throw new RuntimeException(e); 275 } 276 } 277 278 static String[] getCommonCipherSuites() { 279 SSLContext jdkContext = 280 TestUtils.initSslContext(newContext(getJdkProvider()), TestKeyStore.getClient()); 281 SSLContext conscryptContext = TestUtils.initSslContext( 282 newContext(getConscryptProvider()), TestKeyStore.getClient()); 283 Set<String> supported = new LinkedHashSet<String>(); 284 supported.addAll(supportedCiphers(jdkContext)); 285 supported.retainAll(supportedCiphers(conscryptContext)); 286 filterCiphers(supported); 287 288 return supported.toArray(new String[supported.size()]); 289 } 290 291 private static List<String> supportedCiphers(SSLContext ctx) { 292 return Arrays.asList(ctx.getDefaultSSLParameters().getCipherSuites()); 293 } 294 295 private static void filterCiphers(Iterable<String> ciphers) { 296 // Filter all non-TLS ciphers. 297 Iterator<String> iter = ciphers.iterator(); 298 while (iter.hasNext()) { 299 String cipher = iter.next(); 300 if (cipher.startsWith("SSL_") || cipher.startsWith("TLS_EMPTY") 301 || cipher.contains("_RC4_")) { 302 iter.remove(); 303 } 304 } 305 } 306 307 /** 308 * Picks a port that is not used right at this moment. 309 * Warning: Not thread safe. May see "BindException: Address already in use: bind" if using the 310 * returned port to create a new server socket when other threads/processes are concurrently 311 * creating new sockets without a specific port. 312 */ 313 public static int pickUnusedPort() { 314 try { 315 ServerSocket serverSocket = new ServerSocket(0); 316 int port = serverSocket.getLocalPort(); 317 serverSocket.close(); 318 return port; 319 } catch (IOException e) { 320 throw new RuntimeException(e); 321 } 322 } 323 324 /** 325 * Creates a text message of the given length. 326 */ 327 public static byte[] newTextMessage(int length) { 328 byte[] msg = new byte[length]; 329 for (int msgIndex = 0; msgIndex < length;) { 330 int remaining = length - msgIndex; 331 int numChars = Math.min(remaining, CHARS.length); 332 System.arraycopy(CHARS, 0, msg, msgIndex, numChars); 333 msgIndex += numChars; 334 } 335 return msg; 336 } 337 338 static SSLContext newClientSslContext(Provider provider) { 339 SSLContext context = newContext(provider); 340 return initClientSslContext(context); 341 } 342 343 static SSLContext newServerSslContext(Provider provider) { 344 SSLContext context = newContext(provider); 345 return initServerSslContext(context); 346 } 347 348 /** 349 * Initializes the given client-side {@code context} with a default cert. 350 */ 351 public static SSLContext initClientSslContext(SSLContext context) { 352 return initSslContext(context, TestKeyStore.getClient()); 353 } 354 355 /** 356 * Initializes the given server-side {@code context} with the given cert chain and private key. 357 */ 358 public static SSLContext initServerSslContext(SSLContext context) { 359 return initSslContext(context, TestKeyStore.getServer()); 360 } 361 362 /** 363 * Initializes the given {@code context} from the {@code keyStore}. 364 */ 365 static SSLContext initSslContext(SSLContext context, TestKeyStore keyStore) { 366 try { 367 context.init(keyStore.keyManagers, keyStore.trustManagers, null); 368 return context; 369 } catch (Exception e) { 370 throw new RuntimeException(e); 371 } 372 } 373 374 /** 375 * Performs the intial TLS handshake between the two {@link SSLEngine} instances. 376 */ 377 public static void doEngineHandshake(SSLEngine clientEngine, SSLEngine serverEngine, 378 ByteBuffer clientAppBuffer, ByteBuffer clientPacketBuffer, ByteBuffer serverAppBuffer, 379 ByteBuffer serverPacketBuffer, boolean beginHandshake) throws SSLException { 380 if (beginHandshake) { 381 clientEngine.beginHandshake(); 382 serverEngine.beginHandshake(); 383 } 384 385 SSLEngineResult clientResult; 386 SSLEngineResult serverResult; 387 388 boolean clientHandshakeFinished = false; 389 boolean serverHandshakeFinished = false; 390 391 do { 392 int cTOsPos = clientPacketBuffer.position(); 393 int sTOcPos = serverPacketBuffer.position(); 394 395 clientResult = clientEngine.wrap(EMPTY_BUFFER, clientPacketBuffer); 396 runDelegatedTasks(clientResult, clientEngine); 397 serverResult = serverEngine.wrap(EMPTY_BUFFER, serverPacketBuffer); 398 runDelegatedTasks(serverResult, serverEngine); 399 400 // Verify that the consumed and produced number match what is in the buffers now. 401 assertEquals(0, clientResult.bytesConsumed()); 402 assertEquals(0, serverResult.bytesConsumed()); 403 assertEquals(clientPacketBuffer.position() - cTOsPos, clientResult.bytesProduced()); 404 assertEquals(serverPacketBuffer.position() - sTOcPos, serverResult.bytesProduced()); 405 406 clientPacketBuffer.flip(); 407 serverPacketBuffer.flip(); 408 409 // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED 410 if (isHandshakeFinished(clientResult)) { 411 assertFalse(clientHandshakeFinished); 412 clientHandshakeFinished = true; 413 } 414 if (isHandshakeFinished(serverResult)) { 415 assertFalse(serverHandshakeFinished); 416 serverHandshakeFinished = true; 417 } 418 419 cTOsPos = clientPacketBuffer.position(); 420 sTOcPos = serverPacketBuffer.position(); 421 422 int clientAppReadBufferPos = clientAppBuffer.position(); 423 int serverAppReadBufferPos = serverAppBuffer.position(); 424 425 clientResult = clientEngine.unwrap(serverPacketBuffer, clientAppBuffer); 426 runDelegatedTasks(clientResult, clientEngine); 427 serverResult = serverEngine.unwrap(clientPacketBuffer, serverAppBuffer); 428 runDelegatedTasks(serverResult, serverEngine); 429 430 // Verify that the consumed and produced number match what is in the buffers now. 431 assertEquals(serverPacketBuffer.position() - sTOcPos, clientResult.bytesConsumed()); 432 assertEquals(clientPacketBuffer.position() - cTOsPos, serverResult.bytesConsumed()); 433 assertEquals(clientAppBuffer.position() - clientAppReadBufferPos, 434 clientResult.bytesProduced()); 435 assertEquals(serverAppBuffer.position() - serverAppReadBufferPos, 436 serverResult.bytesProduced()); 437 438 clientPacketBuffer.compact(); 439 serverPacketBuffer.compact(); 440 441 // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED 442 if (isHandshakeFinished(clientResult)) { 443 assertFalse(clientHandshakeFinished); 444 clientHandshakeFinished = true; 445 } 446 if (isHandshakeFinished(serverResult)) { 447 assertFalse(serverHandshakeFinished); 448 serverHandshakeFinished = true; 449 } 450 } while (!clientHandshakeFinished || !serverHandshakeFinished); 451 } 452 453 private static boolean isHandshakeFinished(SSLEngineResult result) { 454 return result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED; 455 } 456 457 private static void runDelegatedTasks(SSLEngineResult result, SSLEngine engine) { 458 if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) { 459 for (;;) { 460 Runnable task = engine.getDelegatedTask(); 461 if (task == null) { 462 break; 463 } 464 task.run(); 465 } 466 } 467 } 468 469 /** 470 * Decodes the provided hexadecimal string into a byte array. Odd-length inputs 471 * are not allowed. 472 * 473 * Throws an {@code IllegalArgumentException} if the input is malformed. 474 */ 475 public static byte[] decodeHex(String encoded) throws IllegalArgumentException { 476 return decodeHex(encoded.toCharArray()); 477 } 478 479 /** 480 * Decodes the provided hexadecimal string into a byte array. If {@code allowSingleChar} 481 * is {@code true} odd-length inputs are allowed and the first character is interpreted 482 * as the lower bits of the first result byte. 483 * 484 * Throws an {@code IllegalArgumentException} if the input is malformed. 485 */ 486 public static byte[] decodeHex(String encoded, boolean allowSingleChar) throws IllegalArgumentException { 487 return decodeHex(encoded.toCharArray(), allowSingleChar); 488 } 489 490 /** 491 * Decodes the provided hexadecimal string into a byte array. Odd-length inputs 492 * are not allowed. 493 * 494 * Throws an {@code IllegalArgumentException} if the input is malformed. 495 */ 496 public static byte[] decodeHex(char[] encoded) throws IllegalArgumentException { 497 return decodeHex(encoded, false); 498 } 499 500 /** 501 * Decodes the provided hexadecimal string into a byte array. If {@code allowSingleChar} 502 * is {@code true} odd-length inputs are allowed and the first character is interpreted 503 * as the lower bits of the first result byte. 504 * 505 * Throws an {@code IllegalArgumentException} if the input is malformed. 506 */ 507 public static byte[] decodeHex(char[] encoded, boolean allowSingleChar) throws IllegalArgumentException { 508 int resultLengthBytes = (encoded.length + 1) / 2; 509 byte[] result = new byte[resultLengthBytes]; 510 511 int resultOffset = 0; 512 int i = 0; 513 if (allowSingleChar) { 514 if ((encoded.length % 2) != 0) { 515 // Odd number of digits -- the first digit is the lower 4 bits of the first result byte. 516 result[resultOffset++] = (byte) toDigit(encoded, i); 517 i++; 518 } 519 } else { 520 if ((encoded.length % 2) != 0) { 521 throw new IllegalArgumentException("Invalid input length: " + encoded.length); 522 } 523 } 524 525 for (int len = encoded.length; i < len; i += 2) { 526 result[resultOffset++] = (byte) ((toDigit(encoded, i) << 4) | toDigit(encoded, i + 1)); 527 } 528 529 return result; 530 } 531 532 533 private static int toDigit(char[] str, int offset) throws IllegalArgumentException { 534 // NOTE: that this isn't really a code point in the traditional sense, since we're 535 // just rejecting surrogate pairs outright. 536 int pseudoCodePoint = str[offset]; 537 538 if ('0' <= pseudoCodePoint && pseudoCodePoint <= '9') { 539 return pseudoCodePoint - '0'; 540 } else if ('a' <= pseudoCodePoint && pseudoCodePoint <= 'f') { 541 return 10 + (pseudoCodePoint - 'a'); 542 } else if ('A' <= pseudoCodePoint && pseudoCodePoint <= 'F') { 543 return 10 + (pseudoCodePoint - 'A'); 544 } 545 546 throw new IllegalArgumentException("Illegal char: " + str[offset] + 547 " at offset " + offset); 548 } 549 } 550