1 /* 2 * Copyright (C) 2016 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.security.net.config.cts; 18 19 import android.security.net.config.cts.CtsNetSecConfigDownloadManagerTestCases.R; 20 21 import android.app.DownloadManager; 22 import android.content.BroadcastReceiver; 23 import android.content.Context; 24 import android.content.Intent; 25 import android.content.IntentFilter; 26 import android.database.Cursor; 27 import android.net.Uri; 28 import android.os.SystemClock; 29 import android.test.AndroidTestCase; 30 import android.text.format.DateUtils; 31 import android.util.Log; 32 33 import java.io.ByteArrayOutputStream; 34 import java.io.InputStream; 35 import java.net.Socket; 36 import java.security.KeyFactory; 37 import java.security.KeyStore; 38 import java.security.PrivateKey; 39 import java.security.Provider; 40 import java.security.Security; 41 import java.security.Signature; 42 import java.security.cert.Certificate; 43 import java.security.cert.CertificateFactory; 44 import java.security.cert.X509Certificate; 45 import java.security.spec.PKCS8EncodedKeySpec; 46 import java.util.Collection; 47 import java.util.HashSet; 48 import java.util.concurrent.Callable; 49 import java.util.concurrent.ExecutionException; 50 import java.util.concurrent.FutureTask; 51 import java.util.concurrent.TimeUnit; 52 import java.util.concurrent.TimeoutException; 53 54 import javax.net.ssl.KeyManagerFactory; 55 import javax.net.ssl.SSLContext; 56 import javax.net.ssl.SSLServerSocket; 57 import javax.net.ssl.TrustManager; 58 import javax.net.ssl.TrustManagerFactory; 59 import javax.net.ssl.X509TrustManager; 60 61 public class DownloadManagerTest extends AndroidTestCase { 62 63 private static final String HTTP_RESPONSE = 64 "HTTP/1.0 200 OK\r\nContent-Type: text/plain\r\nContent-length: 5\r\n\r\nhello"; 65 private static final long TIMEOUT = 3 * DateUtils.SECOND_IN_MILLIS; 66 67 public void testConfigTrustedCaAccepted() throws Exception { 68 runDownloadManagerTest(R.raw.valid_chain, R.raw.test_key); 69 } 70 71 public void testUntrustedCaRejected() throws Exception { 72 try { 73 runDownloadManagerTest(R.raw.invalid_chain, R.raw.test_key); 74 fail("Invalid CA should be rejected"); 75 } catch (Exception expected) { 76 } 77 } 78 79 private void runDownloadManagerTest(int chainResId, int keyResId) throws Exception { 80 DownloadManager dm = 81 (DownloadManager) getContext().getSystemService(Context.DOWNLOAD_SERVICE); 82 DownloadCompleteReceiver receiver = new DownloadCompleteReceiver(); 83 final SSLServerSocket serverSocket = bindTLSServer(chainResId, keyResId); 84 FutureTask<Void> serverFuture = new FutureTask<Void>(new Callable() { 85 @Override 86 public Void call() throws Exception { 87 runServer(serverSocket); 88 return null; 89 } 90 }); 91 try { 92 IntentFilter filter = new IntentFilter(DownloadManager.ACTION_DOWNLOAD_COMPLETE); 93 getContext().registerReceiver(receiver, filter); 94 new Thread(serverFuture).start(); 95 Uri destination = Uri.parse("https://localhost:" + serverSocket.getLocalPort()); 96 long id = dm.enqueue(new DownloadManager.Request(destination)); 97 try { 98 serverFuture.get(TIMEOUT, TimeUnit.MILLISECONDS); 99 // Check that the download was successful. 100 receiver.waitForDownloadComplete(TIMEOUT, id); 101 assertSuccessfulDownload(id); 102 } catch (InterruptedException e) { 103 // Wrap InterruptedException since otherwise it gets eaten by AndroidTest 104 throw new RuntimeException(e); 105 } finally { 106 dm.remove(id); 107 } 108 } finally { 109 getContext().unregisterReceiver(receiver); 110 serverFuture.cancel(true); 111 try { 112 serverSocket.close(); 113 } catch (Exception ignored) {} 114 } 115 } 116 117 private void runServer(SSLServerSocket server) throws Exception { 118 Socket s = server.accept(); 119 s.getOutputStream().write(HTTP_RESPONSE.getBytes()); 120 s.getOutputStream().flush(); 121 s.close(); 122 } 123 124 private SSLServerSocket bindTLSServer(int chainResId, int keyResId) throws Exception { 125 // Load certificate chain. 126 CertificateFactory fact = CertificateFactory.getInstance("X.509"); 127 Collection<? extends Certificate> certs; 128 try (InputStream is = getContext().getResources().openRawResource(chainResId)) { 129 certs = fact.generateCertificates(is); 130 } 131 X509Certificate[] chain = new X509Certificate[certs.size()]; 132 int i = 0; 133 for (Certificate cert : certs) { 134 chain[i++] = (X509Certificate) cert; 135 } 136 137 // Load private key for the leaf. 138 PrivateKey key; 139 try (InputStream is = getContext().getResources().openRawResource(keyResId)) { 140 ByteArrayOutputStream keyout = new ByteArrayOutputStream(); 141 byte[] buffer = new byte[4096]; 142 int chunk_size; 143 while ((chunk_size = is.read(buffer)) != -1) { 144 keyout.write(buffer, 0, chunk_size); 145 } 146 is.close(); 147 byte[] keyBytes = keyout.toByteArray(); 148 key = KeyFactory.getInstance("RSA") 149 .generatePrivate(new PKCS8EncodedKeySpec(keyBytes)); 150 } 151 152 // Create KeyStore based on the private key/chain. 153 KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); 154 ks.load(null); 155 ks.setKeyEntry("name", key, null, chain); 156 157 // Create SSLContext. 158 TrustManagerFactory tmf = TrustManagerFactory.getInstance("PKIX"); 159 tmf.init(ks); 160 KeyManagerFactory kmf = 161 KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); 162 kmf.init(ks, null); 163 SSLContext context = SSLContext.getInstance("TLS"); 164 context.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); 165 166 SSLServerSocket s = (SSLServerSocket) context.getServerSocketFactory().createServerSocket(); 167 s.bind(null); 168 return s; 169 } 170 171 private void assertSuccessfulDownload(long id) throws Exception { 172 Cursor cursor = null; 173 DownloadManager dm = 174 (DownloadManager) getContext().getSystemService(Context.DOWNLOAD_SERVICE); 175 try { 176 cursor = dm.query(new DownloadManager.Query().setFilterById(id)); 177 assertTrue(cursor.moveToNext()); 178 assertEquals(DownloadManager.STATUS_SUCCESSFUL, cursor.getInt( 179 cursor.getColumnIndex(DownloadManager.COLUMN_STATUS))); 180 } finally { 181 if (cursor != null) { 182 cursor.close(); 183 } 184 } 185 } 186 187 private static final class DownloadCompleteReceiver extends BroadcastReceiver { 188 private HashSet<Long> mCompletedDownloads = new HashSet<>(); 189 190 public DownloadCompleteReceiver() { 191 } 192 193 @Override 194 public void onReceive(Context context, Intent intent) { 195 synchronized(mCompletedDownloads) { 196 mCompletedDownloads.add(intent.getLongExtra(DownloadManager.EXTRA_DOWNLOAD_ID, -1)); 197 mCompletedDownloads.notifyAll(); 198 } 199 } 200 201 public void waitForDownloadComplete(long timeout, long id) 202 throws TimeoutException, InterruptedException { 203 long deadline = SystemClock.elapsedRealtime() + timeout; 204 do { 205 synchronized (mCompletedDownloads) { 206 long millisTillTimeout = deadline - SystemClock.elapsedRealtime(); 207 if (millisTillTimeout > 0) { 208 mCompletedDownloads.wait(millisTillTimeout); 209 } 210 if (mCompletedDownloads.contains(id)) { 211 return; 212 } 213 } 214 } while (SystemClock.elapsedRealtime() < deadline); 215 216 throw new TimeoutException("Timed out waiting for download complete"); 217 } 218 } 219 220 221 } 222