Home | History | Annotate | Download | only in cts
      1 /*
      2  * Copyright (C) 2016 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package android.security.net.config.cts;
     18 
     19 import android.security.net.config.cts.CtsNetSecConfigDownloadManagerTestCases.R;
     20 
     21 import android.app.DownloadManager;
     22 import android.content.BroadcastReceiver;
     23 import android.content.Context;
     24 import android.content.Intent;
     25 import android.content.IntentFilter;
     26 import android.database.Cursor;
     27 import android.net.Uri;
     28 import android.os.SystemClock;
     29 import android.test.AndroidTestCase;
     30 import android.text.format.DateUtils;
     31 import android.util.Log;
     32 
     33 import java.io.ByteArrayOutputStream;
     34 import java.io.InputStream;
     35 import java.net.Socket;
     36 import java.security.KeyFactory;
     37 import java.security.KeyStore;
     38 import java.security.PrivateKey;
     39 import java.security.Provider;
     40 import java.security.Security;
     41 import java.security.Signature;
     42 import java.security.cert.Certificate;
     43 import java.security.cert.CertificateFactory;
     44 import java.security.cert.X509Certificate;
     45 import java.security.spec.PKCS8EncodedKeySpec;
     46 import java.util.Collection;
     47 import java.util.HashSet;
     48 import java.util.concurrent.Callable;
     49 import java.util.concurrent.ExecutionException;
     50 import java.util.concurrent.FutureTask;
     51 import java.util.concurrent.TimeUnit;
     52 import java.util.concurrent.TimeoutException;
     53 
     54 import javax.net.ssl.KeyManagerFactory;
     55 import javax.net.ssl.SSLContext;
     56 import javax.net.ssl.SSLServerSocket;
     57 import javax.net.ssl.TrustManager;
     58 import javax.net.ssl.TrustManagerFactory;
     59 import javax.net.ssl.X509TrustManager;
     60 
     61 public class DownloadManagerTest extends AndroidTestCase {
     62 
     63     private static final String HTTP_RESPONSE =
     64             "HTTP/1.0 200 OK\r\nContent-Type: text/plain\r\nContent-length: 5\r\n\r\nhello";
     65     private static final long TIMEOUT = 3 * DateUtils.SECOND_IN_MILLIS;
     66 
     67     public void testConfigTrustedCaAccepted() throws Exception {
     68         runDownloadManagerTest(R.raw.valid_chain, R.raw.test_key);
     69     }
     70 
     71     public void testUntrustedCaRejected() throws Exception {
     72         try {
     73             runDownloadManagerTest(R.raw.invalid_chain, R.raw.test_key);
     74             fail("Invalid CA should be rejected");
     75         } catch (Exception expected) {
     76         }
     77     }
     78 
     79     private void runDownloadManagerTest(int chainResId, int keyResId) throws Exception {
     80         DownloadManager dm =
     81                 (DownloadManager) getContext().getSystemService(Context.DOWNLOAD_SERVICE);
     82         DownloadCompleteReceiver receiver = new DownloadCompleteReceiver();
     83         final SSLServerSocket serverSocket = bindTLSServer(chainResId, keyResId);
     84         FutureTask<Void> serverFuture = new FutureTask<Void>(new Callable() {
     85             @Override
     86             public Void call() throws Exception {
     87                 runServer(serverSocket);
     88                 return null;
     89             }
     90         });
     91         try {
     92             IntentFilter filter = new IntentFilter(DownloadManager.ACTION_DOWNLOAD_COMPLETE);
     93             getContext().registerReceiver(receiver, filter);
     94             new Thread(serverFuture).start();
     95             Uri destination = Uri.parse("https://localhost:" + serverSocket.getLocalPort());
     96             long id = dm.enqueue(new DownloadManager.Request(destination));
     97             try {
     98                 serverFuture.get(TIMEOUT, TimeUnit.MILLISECONDS);
     99                 // Check that the download was successful.
    100                 receiver.waitForDownloadComplete(TIMEOUT, id);
    101                 assertSuccessfulDownload(id);
    102             } catch (InterruptedException e) {
    103                 // Wrap InterruptedException since otherwise it gets eaten by AndroidTest
    104                 throw new RuntimeException(e);
    105             } finally {
    106                 dm.remove(id);
    107             }
    108         } finally {
    109             getContext().unregisterReceiver(receiver);
    110             serverFuture.cancel(true);
    111             try {
    112                 serverSocket.close();
    113             } catch (Exception ignored) {}
    114         }
    115     }
    116 
    117     private void runServer(SSLServerSocket server) throws Exception {
    118         Socket s = server.accept();
    119         s.getOutputStream().write(HTTP_RESPONSE.getBytes());
    120         s.getOutputStream().flush();
    121         s.close();
    122     }
    123 
    124     private SSLServerSocket bindTLSServer(int chainResId, int keyResId) throws Exception {
    125         // Load certificate chain.
    126         CertificateFactory fact = CertificateFactory.getInstance("X.509");
    127         Collection<? extends Certificate> certs;
    128         try (InputStream is = getContext().getResources().openRawResource(chainResId)) {
    129             certs = fact.generateCertificates(is);
    130         }
    131         X509Certificate[] chain = new X509Certificate[certs.size()];
    132         int i = 0;
    133         for (Certificate cert : certs) {
    134             chain[i++] = (X509Certificate) cert;
    135         }
    136 
    137         // Load private key for the leaf.
    138         PrivateKey key;
    139         try (InputStream is = getContext().getResources().openRawResource(keyResId)) {
    140             ByteArrayOutputStream keyout = new ByteArrayOutputStream();
    141             byte[] buffer = new byte[4096];
    142             int chunk_size;
    143             while ((chunk_size = is.read(buffer)) != -1) {
    144                 keyout.write(buffer, 0, chunk_size);
    145             }
    146             is.close();
    147             byte[] keyBytes = keyout.toByteArray();
    148             key = KeyFactory.getInstance("RSA")
    149                     .generatePrivate(new PKCS8EncodedKeySpec(keyBytes));
    150         }
    151 
    152         // Create KeyStore based on the private key/chain.
    153         KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
    154         ks.load(null);
    155         ks.setKeyEntry("name", key, null, chain);
    156 
    157         // Create SSLContext.
    158         TrustManagerFactory tmf = TrustManagerFactory.getInstance("PKIX");
    159         tmf.init(ks);
    160         KeyManagerFactory kmf =
    161                 KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
    162         kmf.init(ks, null);
    163         SSLContext context = SSLContext.getInstance("TLS");
    164         context.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
    165 
    166         SSLServerSocket s = (SSLServerSocket) context.getServerSocketFactory().createServerSocket();
    167         s.bind(null);
    168         return s;
    169     }
    170 
    171     private void assertSuccessfulDownload(long id) throws Exception {
    172         Cursor cursor = null;
    173         DownloadManager dm =
    174                 (DownloadManager) getContext().getSystemService(Context.DOWNLOAD_SERVICE);
    175         try {
    176             cursor = dm.query(new DownloadManager.Query().setFilterById(id));
    177             assertTrue(cursor.moveToNext());
    178             assertEquals(DownloadManager.STATUS_SUCCESSFUL, cursor.getInt(
    179                     cursor.getColumnIndex(DownloadManager.COLUMN_STATUS)));
    180         } finally {
    181             if (cursor != null) {
    182                 cursor.close();
    183             }
    184         }
    185     }
    186 
    187     private static final class DownloadCompleteReceiver extends BroadcastReceiver {
    188         private HashSet<Long> mCompletedDownloads = new HashSet<>();
    189 
    190         public DownloadCompleteReceiver() {
    191         }
    192 
    193         @Override
    194         public void onReceive(Context context, Intent intent) {
    195             synchronized(mCompletedDownloads) {
    196                 mCompletedDownloads.add(intent.getLongExtra(DownloadManager.EXTRA_DOWNLOAD_ID, -1));
    197                 mCompletedDownloads.notifyAll();
    198             }
    199         }
    200 
    201         public void waitForDownloadComplete(long timeout, long id)
    202                 throws TimeoutException, InterruptedException  {
    203             long deadline = SystemClock.elapsedRealtime() + timeout;
    204             do {
    205                 synchronized (mCompletedDownloads) {
    206                     long millisTillTimeout = deadline - SystemClock.elapsedRealtime();
    207                     if (millisTillTimeout > 0) {
    208                         mCompletedDownloads.wait(millisTillTimeout);
    209                     }
    210                     if (mCompletedDownloads.contains(id)) {
    211                         return;
    212                     }
    213                 }
    214             } while (SystemClock.elapsedRealtime() < deadline);
    215 
    216             throw new TimeoutException("Timed out waiting for download complete");
    217         }
    218     }
    219 
    220 
    221 }
    222