Home | History | Annotate | Download | only in cts
      1 /*
      2  * Copyright (C) 2016 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package android.security.net.config.cts;
     18 
     19 import android.security.net.config.cts.CtsNetSecConfigDownloadManagerTestCases.R;
     20 
     21 import android.app.DownloadManager;
     22 import android.content.BroadcastReceiver;
     23 import android.content.Context;
     24 import android.content.Intent;
     25 import android.content.IntentFilter;
     26 import android.database.Cursor;
     27 import android.net.Uri;
     28 import android.os.SystemClock;
     29 import android.test.AndroidTestCase;
     30 import android.text.format.DateUtils;
     31 import android.util.Log;
     32 
     33 import java.io.ByteArrayOutputStream;
     34 import java.io.InputStream;
     35 import java.net.Socket;
     36 import java.net.ServerSocket;
     37 import java.security.KeyFactory;
     38 import java.security.KeyStore;
     39 import java.security.PrivateKey;
     40 import java.security.Provider;
     41 import java.security.Security;
     42 import java.security.Signature;
     43 import java.security.cert.Certificate;
     44 import java.security.cert.CertificateFactory;
     45 import java.security.cert.X509Certificate;
     46 import java.security.spec.PKCS8EncodedKeySpec;
     47 import java.util.Collection;
     48 import java.util.HashSet;
     49 import java.util.concurrent.Callable;
     50 import java.util.concurrent.ExecutionException;
     51 import java.util.concurrent.FutureTask;
     52 import java.util.concurrent.TimeUnit;
     53 import java.util.concurrent.TimeoutException;
     54 
     55 import javax.net.ssl.KeyManagerFactory;
     56 import javax.net.ssl.SSLContext;
     57 import javax.net.ssl.SSLServerSocket;
     58 import javax.net.ssl.TrustManager;
     59 import javax.net.ssl.TrustManagerFactory;
     60 import javax.net.ssl.X509TrustManager;
     61 
     62 public class DownloadManagerTest extends AndroidTestCase {
     63 
     64     private static final String HTTP_RESPONSE =
     65             "HTTP/1.0 200 OK\r\nContent-Type: text/plain\r\nContent-length: 5\r\n\r\nhello";
     66     private static final long TIMEOUT = 3 * DateUtils.SECOND_IN_MILLIS;
     67 
     68     public void testConfigTrustedCaAccepted() throws Exception {
     69         SSLServerSocket serverSocket = bindTLSServer(R.raw.valid_chain, R.raw.test_key);
     70         runDownloadManagerTest(serverSocket, true);
     71     }
     72 
     73     public void testUntrustedCaRejected() throws Exception {
     74         try {
     75             SSLServerSocket serverSocket = bindTLSServer(R.raw.invalid_chain, R.raw.test_key);
     76             runDownloadManagerTest(serverSocket, true);
     77             fail("Invalid CA should be rejected");
     78         } catch (Exception expected) {
     79         }
     80     }
     81 
     82     public void testPerDomainCleartextAccepted() throws Exception {
     83         ServerSocket serverSocket = new ServerSocket();
     84         serverSocket.bind(null);
     85         runDownloadManagerTest(serverSocket, false);
     86     }
     87 
     88     private void runDownloadManagerTest(ServerSocket serverSocket, boolean https) throws Exception {
     89         DownloadManager dm =
     90                 (DownloadManager) getContext().getSystemService(Context.DOWNLOAD_SERVICE);
     91         DownloadCompleteReceiver receiver = new DownloadCompleteReceiver();
     92         FutureTask<Void> serverFuture = new FutureTask<Void>(new Callable() {
     93             @Override
     94             public Void call() throws Exception {
     95                 runServer(serverSocket);
     96                 return null;
     97             }
     98         });
     99         try {
    100             IntentFilter filter = new IntentFilter(DownloadManager.ACTION_DOWNLOAD_COMPLETE);
    101             getContext().registerReceiver(receiver, filter);
    102             new Thread(serverFuture).start();
    103             String host = (https ? "https" : "http") + "://localhost";
    104             Uri destination = Uri.parse(host + ":" + serverSocket.getLocalPort());
    105             long id = dm.enqueue(new DownloadManager.Request(destination));
    106             try {
    107                 serverFuture.get(TIMEOUT, TimeUnit.MILLISECONDS);
    108                 // Check that the download was successful.
    109                 receiver.waitForDownloadComplete(TIMEOUT, id);
    110                 assertSuccessfulDownload(id);
    111             } catch (InterruptedException e) {
    112                 // Wrap InterruptedException since otherwise it gets eaten by AndroidTest
    113                 throw new RuntimeException(e);
    114             } finally {
    115                 dm.remove(id);
    116             }
    117         } finally {
    118             getContext().unregisterReceiver(receiver);
    119             serverFuture.cancel(true);
    120             try {
    121                 serverSocket.close();
    122             } catch (Exception ignored) {}
    123         }
    124     }
    125 
    126     private void runServer(ServerSocket server) throws Exception {
    127         Socket s = server.accept();
    128         s.getOutputStream().write(HTTP_RESPONSE.getBytes());
    129         s.getOutputStream().flush();
    130         s.close();
    131     }
    132 
    133     private SSLServerSocket bindTLSServer(int chainResId, int keyResId) throws Exception {
    134         // Load certificate chain.
    135         CertificateFactory fact = CertificateFactory.getInstance("X.509");
    136         Collection<? extends Certificate> certs;
    137         try (InputStream is = getContext().getResources().openRawResource(chainResId)) {
    138             certs = fact.generateCertificates(is);
    139         }
    140         X509Certificate[] chain = new X509Certificate[certs.size()];
    141         int i = 0;
    142         for (Certificate cert : certs) {
    143             chain[i++] = (X509Certificate) cert;
    144         }
    145 
    146         // Load private key for the leaf.
    147         PrivateKey key;
    148         try (InputStream is = getContext().getResources().openRawResource(keyResId)) {
    149             ByteArrayOutputStream keyout = new ByteArrayOutputStream();
    150             byte[] buffer = new byte[4096];
    151             int chunk_size;
    152             while ((chunk_size = is.read(buffer)) != -1) {
    153                 keyout.write(buffer, 0, chunk_size);
    154             }
    155             is.close();
    156             byte[] keyBytes = keyout.toByteArray();
    157             key = KeyFactory.getInstance("RSA")
    158                     .generatePrivate(new PKCS8EncodedKeySpec(keyBytes));
    159         }
    160 
    161         // Create KeyStore based on the private key/chain.
    162         KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
    163         ks.load(null);
    164         ks.setKeyEntry("name", key, null, chain);
    165 
    166         // Create SSLContext.
    167         TrustManagerFactory tmf = TrustManagerFactory.getInstance("PKIX");
    168         tmf.init(ks);
    169         KeyManagerFactory kmf =
    170                 KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
    171         kmf.init(ks, null);
    172         SSLContext context = SSLContext.getInstance("TLS");
    173         context.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
    174 
    175         SSLServerSocket s = (SSLServerSocket) context.getServerSocketFactory().createServerSocket();
    176         s.bind(null);
    177         return s;
    178     }
    179 
    180     private void assertSuccessfulDownload(long id) throws Exception {
    181         Cursor cursor = null;
    182         DownloadManager dm =
    183                 (DownloadManager) getContext().getSystemService(Context.DOWNLOAD_SERVICE);
    184         try {
    185             cursor = dm.query(new DownloadManager.Query().setFilterById(id));
    186             assertTrue(cursor.moveToNext());
    187             assertEquals(DownloadManager.STATUS_SUCCESSFUL, cursor.getInt(
    188                     cursor.getColumnIndex(DownloadManager.COLUMN_STATUS)));
    189         } finally {
    190             if (cursor != null) {
    191                 cursor.close();
    192             }
    193         }
    194     }
    195 
    196     private static final class DownloadCompleteReceiver extends BroadcastReceiver {
    197         private HashSet<Long> mCompletedDownloads = new HashSet<>();
    198 
    199         public DownloadCompleteReceiver() {
    200         }
    201 
    202         @Override
    203         public void onReceive(Context context, Intent intent) {
    204             synchronized(mCompletedDownloads) {
    205                 mCompletedDownloads.add(intent.getLongExtra(DownloadManager.EXTRA_DOWNLOAD_ID, -1));
    206                 mCompletedDownloads.notifyAll();
    207             }
    208         }
    209 
    210         public void waitForDownloadComplete(long timeout, long id)
    211                 throws TimeoutException, InterruptedException  {
    212             long deadline = SystemClock.elapsedRealtime() + timeout;
    213             do {
    214                 synchronized (mCompletedDownloads) {
    215                     long millisTillTimeout = deadline - SystemClock.elapsedRealtime();
    216                     if (millisTillTimeout > 0) {
    217                         mCompletedDownloads.wait(millisTillTimeout);
    218                     }
    219                     if (mCompletedDownloads.contains(id)) {
    220                         return;
    221                     }
    222                 }
    223             } while (SystemClock.elapsedRealtime() < deadline);
    224 
    225             throw new TimeoutException("Timed out waiting for download complete");
    226         }
    227     }
    228 
    229 
    230 }
    231