Home | History | Annotate | Download | only in cts
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package android.net.cts;
     18 
     19 import static org.junit.Assert.assertArrayEquals;
     20 
     21 import android.content.Context;
     22 import android.net.IpSecAlgorithm;
     23 import android.net.IpSecManager;
     24 import android.net.IpSecTransform;
     25 import android.system.Os;
     26 import android.system.OsConstants;
     27 import android.test.AndroidTestCase;
     28 import android.util.Log;
     29 
     30 import java.io.FileDescriptor;
     31 import java.io.IOException;
     32 import java.net.DatagramPacket;
     33 import java.net.DatagramSocket;
     34 import java.net.Inet4Address;
     35 import java.net.Inet6Address;
     36 import java.net.InetAddress;
     37 import java.net.InetSocketAddress;
     38 import java.net.ServerSocket;
     39 import java.net.Socket;
     40 import java.net.SocketException;
     41 import java.util.Arrays;
     42 import java.util.concurrent.atomic.AtomicInteger;
     43 
     44 public class IpSecBaseTest extends AndroidTestCase {
     45 
     46     private static final String TAG = IpSecBaseTest.class.getSimpleName();
     47 
     48     protected static final String IPV4_LOOPBACK = "127.0.0.1";
     49     protected static final String IPV6_LOOPBACK = "::1";
     50     protected static final String[] LOOPBACK_ADDRS = new String[] {IPV4_LOOPBACK, IPV6_LOOPBACK};
     51     protected static final int[] DIRECTIONS =
     52             new int[] {IpSecManager.DIRECTION_IN, IpSecManager.DIRECTION_OUT};
     53 
     54     protected static final byte[] TEST_DATA = "Best test data ever!".getBytes();
     55     protected static final int DATA_BUFFER_LEN = 4096;
     56     protected static final int SOCK_TIMEOUT = 500;
     57 
     58     private static final byte[] KEY_DATA = {
     59         0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
     60         0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
     61         0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
     62         0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F,
     63         0x20, 0x21, 0x22, 0x23
     64     };
     65 
     66     protected static final byte[] AUTH_KEY = getKey(256);
     67     protected static final byte[] CRYPT_KEY = getKey(256);
     68 
     69     protected IpSecManager mISM;
     70 
     71     protected void setUp() throws Exception {
     72         super.setUp();
     73         mISM = (IpSecManager) getContext().getSystemService(Context.IPSEC_SERVICE);
     74     }
     75 
     76     protected static byte[] getKey(int bitLength) {
     77         return Arrays.copyOf(KEY_DATA, bitLength / 8);
     78     }
     79 
     80     protected static int getDomain(InetAddress address) {
     81         int domain;
     82         if (address instanceof Inet6Address) {
     83             domain = OsConstants.AF_INET6;
     84         } else {
     85             domain = OsConstants.AF_INET;
     86         }
     87         return domain;
     88     }
     89 
     90     protected static int getPort(FileDescriptor sock) throws Exception {
     91         return ((InetSocketAddress) Os.getsockname(sock)).getPort();
     92     }
     93 
     94     public static interface GenericSocket extends AutoCloseable {
     95         void send(byte[] data) throws Exception;
     96 
     97         byte[] receive() throws Exception;
     98 
     99         int getPort() throws Exception;
    100 
    101         void close() throws Exception;
    102 
    103         void applyTransportModeTransform(
    104                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception;
    105 
    106         void removeTransportModeTransforms(IpSecManager ism) throws Exception;
    107     }
    108 
    109     public static interface GenericTcpSocket extends GenericSocket {}
    110 
    111     public static interface GenericUdpSocket extends GenericSocket {
    112         void sendTo(byte[] data, InetAddress dstAddr, int port) throws Exception;
    113     }
    114 
    115     public abstract static class NativeSocket implements GenericSocket {
    116         public FileDescriptor mFd;
    117 
    118         public NativeSocket(FileDescriptor fd) {
    119             mFd = fd;
    120         }
    121 
    122         @Override
    123         public void send(byte[] data) throws Exception {
    124             Os.write(mFd, data, 0, data.length);
    125         }
    126 
    127         @Override
    128         public byte[] receive() throws Exception {
    129             byte[] in = new byte[DATA_BUFFER_LEN];
    130             AtomicInteger bytesRead = new AtomicInteger(-1);
    131 
    132             Thread readSockThread = new Thread(() -> {
    133                 long startTime = System.currentTimeMillis();
    134                 while (bytesRead.get() < 0 && System.currentTimeMillis() < startTime + SOCK_TIMEOUT) {
    135                     try {
    136                         bytesRead.set(Os.recvfrom(mFd, in, 0, DATA_BUFFER_LEN, 0, null));
    137                     } catch (Exception e) {
    138                         Log.e(TAG, "Error encountered reading from socket", e);
    139                     }
    140                 }
    141             });
    142 
    143             readSockThread.start();
    144             readSockThread.join(SOCK_TIMEOUT);
    145 
    146             if (bytesRead.get() < 0) {
    147                 throw new IOException("No data received from socket");
    148             }
    149 
    150             return Arrays.copyOfRange(in, 0, bytesRead.get());
    151         }
    152 
    153         @Override
    154         public int getPort() throws Exception {
    155             return IpSecBaseTest.getPort(mFd);
    156         }
    157 
    158         @Override
    159         public void close() throws Exception {
    160             Os.close(mFd);
    161         }
    162 
    163         @Override
    164         public void applyTransportModeTransform(
    165                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception {
    166             ism.applyTransportModeTransform(mFd, direction, transform);
    167         }
    168 
    169         @Override
    170         public void removeTransportModeTransforms(IpSecManager ism) throws Exception {
    171             ism.removeTransportModeTransforms(mFd);
    172         }
    173     }
    174 
    175     public static class NativeTcpSocket extends NativeSocket implements GenericTcpSocket {
    176         public NativeTcpSocket(FileDescriptor fd) {
    177             super(fd);
    178         }
    179     }
    180 
    181     public static class NativeUdpSocket extends NativeSocket implements GenericUdpSocket {
    182         public NativeUdpSocket(FileDescriptor fd) {
    183             super(fd);
    184         }
    185 
    186         @Override
    187         public void sendTo(byte[] data, InetAddress dstAddr, int port) throws Exception {
    188             Os.sendto(mFd, data, 0, data.length, 0, dstAddr, port);
    189         }
    190     }
    191 
    192     public static class JavaUdpSocket implements GenericUdpSocket {
    193         public final DatagramSocket mSocket;
    194 
    195         public JavaUdpSocket(InetAddress localAddr) {
    196             try {
    197                 mSocket = new DatagramSocket(0, localAddr);
    198                 mSocket.setSoTimeout(SOCK_TIMEOUT);
    199             } catch (SocketException e) {
    200                 // Fail loudly if we can't set up sockets properly. And without the timeout, we
    201                 // could easily end up in an endless wait.
    202                 throw new RuntimeException(e);
    203             }
    204         }
    205 
    206         @Override
    207         public void send(byte[] data) throws Exception {
    208             mSocket.send(new DatagramPacket(data, data.length));
    209         }
    210 
    211         @Override
    212         public void sendTo(byte[] data, InetAddress dstAddr, int port) throws Exception {
    213             mSocket.send(new DatagramPacket(data, data.length, dstAddr, port));
    214         }
    215 
    216         @Override
    217         public int getPort() throws Exception {
    218             return mSocket.getLocalPort();
    219         }
    220 
    221         @Override
    222         public void close() throws Exception {
    223             mSocket.close();
    224         }
    225 
    226         @Override
    227         public byte[] receive() throws Exception {
    228             DatagramPacket data = new DatagramPacket(new byte[DATA_BUFFER_LEN], DATA_BUFFER_LEN);
    229             mSocket.receive(data);
    230             return Arrays.copyOfRange(data.getData(), 0, data.getLength());
    231         }
    232 
    233         @Override
    234         public void applyTransportModeTransform(
    235                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception {
    236             ism.applyTransportModeTransform(mSocket, direction, transform);
    237         }
    238 
    239         @Override
    240         public void removeTransportModeTransforms(IpSecManager ism) throws Exception {
    241             ism.removeTransportModeTransforms(mSocket);
    242         }
    243     }
    244 
    245     public static class JavaTcpSocket implements GenericTcpSocket {
    246         public final Socket mSocket;
    247 
    248         public JavaTcpSocket(Socket socket) {
    249             mSocket = socket;
    250             try {
    251                 mSocket.setSoTimeout(SOCK_TIMEOUT);
    252             } catch (SocketException e) {
    253                 // Fail loudly if we can't set up sockets properly. And without the timeout, we
    254                 // could easily end up in an endless wait.
    255                 throw new RuntimeException(e);
    256             }
    257         }
    258 
    259         @Override
    260         public void send(byte[] data) throws Exception {
    261             mSocket.getOutputStream().write(data);
    262         }
    263 
    264         @Override
    265         public byte[] receive() throws Exception {
    266             byte[] in = new byte[DATA_BUFFER_LEN];
    267             int bytesRead = mSocket.getInputStream().read(in);
    268             return Arrays.copyOfRange(in, 0, bytesRead);
    269         }
    270 
    271         @Override
    272         public int getPort() throws Exception {
    273             return mSocket.getLocalPort();
    274         }
    275 
    276         @Override
    277         public void close() throws Exception {
    278             mSocket.close();
    279         }
    280 
    281         @Override
    282         public void applyTransportModeTransform(
    283                 IpSecManager ism, int direction, IpSecTransform transform) throws Exception {
    284             ism.applyTransportModeTransform(mSocket, direction, transform);
    285         }
    286 
    287         @Override
    288         public void removeTransportModeTransforms(IpSecManager ism) throws Exception {
    289             ism.removeTransportModeTransforms(mSocket);
    290         }
    291     }
    292 
    293     public static class SocketPair<T> {
    294         public final T mLeftSock;
    295         public final T mRightSock;
    296 
    297         public SocketPair(T leftSock, T rightSock) {
    298             mLeftSock = leftSock;
    299             mRightSock = rightSock;
    300         }
    301     }
    302 
    303     protected static void applyTransformBidirectionally(
    304             IpSecManager ism, IpSecTransform transform, GenericSocket socket) throws Exception {
    305         for (int direction : DIRECTIONS) {
    306             socket.applyTransportModeTransform(ism, direction, transform);
    307         }
    308     }
    309 
    310     public static SocketPair<NativeUdpSocket> getNativeUdpSocketPair(
    311             InetAddress localAddr, IpSecManager ism, IpSecTransform transform, boolean connected)
    312             throws Exception {
    313         int domain = getDomain(localAddr);
    314 
    315         NativeUdpSocket leftSock = new NativeUdpSocket(
    316             Os.socket(domain, OsConstants.SOCK_DGRAM, OsConstants.IPPROTO_UDP));
    317         NativeUdpSocket rightSock = new NativeUdpSocket(
    318             Os.socket(domain, OsConstants.SOCK_DGRAM, OsConstants.IPPROTO_UDP));
    319 
    320         for (NativeUdpSocket sock : new NativeUdpSocket[] {leftSock, rightSock}) {
    321             applyTransformBidirectionally(ism, transform, sock);
    322             Os.bind(sock.mFd, localAddr, 0);
    323         }
    324 
    325         if (connected) {
    326             Os.connect(leftSock.mFd, localAddr, rightSock.getPort());
    327             Os.connect(rightSock.mFd, localAddr, leftSock.getPort());
    328         }
    329 
    330         return new SocketPair<>(leftSock, rightSock);
    331     }
    332 
    333     public static SocketPair<NativeTcpSocket> getNativeTcpSocketPair(
    334             InetAddress localAddr, IpSecManager ism, IpSecTransform transform) throws Exception {
    335         int domain = getDomain(localAddr);
    336 
    337         NativeTcpSocket server = new NativeTcpSocket(
    338                 Os.socket(domain, OsConstants.SOCK_STREAM, OsConstants.IPPROTO_TCP));
    339         NativeTcpSocket client = new NativeTcpSocket(
    340                 Os.socket(domain, OsConstants.SOCK_STREAM, OsConstants.IPPROTO_TCP));
    341 
    342         Os.bind(server.mFd, localAddr, 0);
    343 
    344         applyTransformBidirectionally(ism, transform, server);
    345         applyTransformBidirectionally(ism, transform, client);
    346 
    347         Os.listen(server.mFd, 10);
    348         Os.connect(client.mFd, localAddr, server.getPort());
    349         NativeTcpSocket accepted = new NativeTcpSocket(Os.accept(server.mFd, null));
    350 
    351         applyTransformBidirectionally(ism, transform, accepted);
    352         server.close();
    353 
    354         return new SocketPair<>(client, accepted);
    355     }
    356 
    357     public static SocketPair<JavaUdpSocket> getJavaUdpSocketPair(
    358             InetAddress localAddr, IpSecManager ism, IpSecTransform transform, boolean connected)
    359             throws Exception {
    360         JavaUdpSocket leftSock = new JavaUdpSocket(localAddr);
    361         JavaUdpSocket rightSock = new JavaUdpSocket(localAddr);
    362 
    363         applyTransformBidirectionally(ism, transform, leftSock);
    364         applyTransformBidirectionally(ism, transform, rightSock);
    365 
    366         if (connected) {
    367             leftSock.mSocket.connect(localAddr, rightSock.mSocket.getLocalPort());
    368             rightSock.mSocket.connect(localAddr, leftSock.mSocket.getLocalPort());
    369         }
    370 
    371         return new SocketPair<>(leftSock, rightSock);
    372     }
    373 
    374     public static SocketPair<JavaTcpSocket> getJavaTcpSocketPair(
    375             InetAddress localAddr, IpSecManager ism, IpSecTransform transform) throws Exception {
    376         JavaTcpSocket clientSock = new JavaTcpSocket(new Socket());
    377         ServerSocket serverSocket = new ServerSocket();
    378         serverSocket.bind(new InetSocketAddress(localAddr, 0));
    379 
    380         // While technically the client socket does not need to be bound, the OpenJDK implementation
    381         // of Socket only allocates an FD when bind() or connect() or other similar methods are
    382         // called. So we call bind to force the FD creation, so that we can apply a transform to it
    383         // prior to socket connect.
    384         clientSock.mSocket.bind(new InetSocketAddress(localAddr, 0));
    385 
    386         // IpSecService doesn't support serverSockets at the moment; workaround using FD
    387         FileDescriptor serverFd = serverSocket.getImpl().getFD$();
    388 
    389         applyTransformBidirectionally(ism, transform, new NativeTcpSocket(serverFd));
    390         applyTransformBidirectionally(ism, transform, clientSock);
    391 
    392         clientSock.mSocket.connect(new InetSocketAddress(localAddr, serverSocket.getLocalPort()));
    393         JavaTcpSocket acceptedSock = new JavaTcpSocket(serverSocket.accept());
    394 
    395         applyTransformBidirectionally(ism, transform, acceptedSock);
    396         serverSocket.close();
    397 
    398         return new SocketPair<>(clientSock, acceptedSock);
    399     }
    400 
    401     private void checkSocketPair(GenericSocket left, GenericSocket right) throws Exception {
    402         left.send(TEST_DATA);
    403         assertArrayEquals(TEST_DATA, right.receive());
    404 
    405         right.send(TEST_DATA);
    406         assertArrayEquals(TEST_DATA, left.receive());
    407 
    408         left.close();
    409         right.close();
    410     }
    411 
    412     private void checkUnconnectedUdpSocketPair(
    413             GenericUdpSocket left, GenericUdpSocket right, InetAddress localAddr) throws Exception {
    414         left.sendTo(TEST_DATA, localAddr, right.getPort());
    415         assertArrayEquals(TEST_DATA, right.receive());
    416 
    417         right.sendTo(TEST_DATA, localAddr, left.getPort());
    418         assertArrayEquals(TEST_DATA, left.receive());
    419 
    420         left.close();
    421         right.close();
    422     }
    423 
    424     protected static IpSecTransform buildIpSecTransform(
    425             Context mContext,
    426             IpSecManager.SecurityParameterIndex spi,
    427             IpSecManager.UdpEncapsulationSocket encapSocket,
    428             InetAddress remoteAddr)
    429             throws Exception {
    430         String localAddr = (remoteAddr instanceof Inet4Address) ? IPV4_LOOPBACK : IPV6_LOOPBACK;
    431         IpSecTransform.Builder builder =
    432                 new IpSecTransform.Builder(mContext)
    433                 .setEncryption(new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY))
    434                 .setAuthentication(
    435                         new IpSecAlgorithm(
    436                                 IpSecAlgorithm.AUTH_HMAC_SHA256,
    437                                 AUTH_KEY,
    438                                 AUTH_KEY.length * 4));
    439 
    440         if (encapSocket != null) {
    441             builder.setIpv4Encapsulation(encapSocket, encapSocket.getPort());
    442         }
    443 
    444         return builder.buildTransportModeTransform(InetAddress.getByName(localAddr), spi);
    445     }
    446 
    447     private IpSecTransform buildDefaultTransform(InetAddress localAddr) throws Exception {
    448         try (IpSecManager.SecurityParameterIndex spi =
    449                 mISM.allocateSecurityParameterIndex(localAddr)) {
    450             return buildIpSecTransform(mContext, spi, null, localAddr);
    451         }
    452     }
    453 
    454     public void testJavaTcpSocketPair() throws Exception {
    455         for (String addr : LOOPBACK_ADDRS) {
    456             InetAddress local = InetAddress.getByName(addr);
    457             try (IpSecTransform transform = buildDefaultTransform(local)) {
    458                 SocketPair<JavaTcpSocket> sockets = getJavaTcpSocketPair(local, mISM, transform);
    459                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
    460             }
    461         }
    462     }
    463 
    464     public void testJavaUdpSocketPair() throws Exception {
    465         for (String addr : LOOPBACK_ADDRS) {
    466             InetAddress local = InetAddress.getByName(addr);
    467             try (IpSecTransform transform = buildDefaultTransform(local)) {
    468                 SocketPair<JavaUdpSocket> sockets =
    469                         getJavaUdpSocketPair(local, mISM, transform, true);
    470                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
    471             }
    472         }
    473     }
    474 
    475     public void testJavaUdpSocketPairUnconnected() throws Exception {
    476         for (String addr : LOOPBACK_ADDRS) {
    477             InetAddress local = InetAddress.getByName(addr);
    478             try (IpSecTransform transform = buildDefaultTransform(local)) {
    479                 SocketPair<JavaUdpSocket> sockets =
    480                         getJavaUdpSocketPair(local, mISM, transform, false);
    481                 checkUnconnectedUdpSocketPair(sockets.mLeftSock, sockets.mRightSock, local);
    482             }
    483         }
    484     }
    485 
    486     public void testNativeTcpSocketPair() throws Exception {
    487         for (String addr : LOOPBACK_ADDRS) {
    488             InetAddress local = InetAddress.getByName(addr);
    489             try (IpSecTransform transform = buildDefaultTransform(local)) {
    490                 SocketPair<NativeTcpSocket> sockets =
    491                         getNativeTcpSocketPair(local, mISM, transform);
    492                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
    493             }
    494         }
    495     }
    496 
    497     public void testNativeUdpSocketPair() throws Exception {
    498         for (String addr : LOOPBACK_ADDRS) {
    499             InetAddress local = InetAddress.getByName(addr);
    500             try (IpSecTransform transform = buildDefaultTransform(local)) {
    501                 SocketPair<NativeUdpSocket> sockets =
    502                         getNativeUdpSocketPair(local, mISM, transform, true);
    503                 checkSocketPair(sockets.mLeftSock, sockets.mRightSock);
    504             }
    505         }
    506     }
    507 
    508     public void testNativeUdpSocketPairUnconnected() throws Exception {
    509         for (String addr : LOOPBACK_ADDRS) {
    510             InetAddress local = InetAddress.getByName(addr);
    511             try (IpSecTransform transform = buildDefaultTransform(local)) {
    512                 SocketPair<NativeUdpSocket> sockets =
    513                         getNativeUdpSocketPair(local, mISM, transform, false);
    514                 checkUnconnectedUdpSocketPair(sockets.mLeftSock, sockets.mRightSock, local);
    515             }
    516         }
    517     }
    518 }
    519