Home | History | Annotate | Download | only in test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2017 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 # pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
     18 from errno import *  # pylint: disable=wildcard-import
     19 import os
     20 import itertools
     21 from scapy import all as scapy
     22 from socket import *  # pylint: disable=wildcard-import
     23 import subprocess
     24 import threading
     25 import unittest
     26 
     27 import multinetwork_base
     28 import net_test
     29 from tun_twister import TapTwister
     30 import xfrm
     31 import xfrm_base
     32 
     33 # List of encryption algorithms for use in ParamTests.
     34 CRYPT_ALGOS = [
     35     xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 128)),
     36     xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 192)),
     37     xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, 256)),
     38 ]
     39 
     40 # List of auth algorithms for use in ParamTests.
     41 AUTH_ALGOS = [
     42     # RFC 4868 specifies that the only supported truncation length is half the
     43     # hash size.
     44     xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_MD5, 128, 96)),
     45     xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA1, 160, 96)),
     46     xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA256, 256, 128)),
     47     xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA384, 384, 192)),
     48     xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA512, 512, 256)),
     49     # Test larger truncation lengths for good measure.
     50     xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_MD5, 128, 128)),
     51     xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA1, 160, 160)),
     52     xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA256, 256, 256)),
     53     xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA384, 384, 384)),
     54     xfrm.XfrmAlgoAuth((xfrm.XFRM_AALG_HMAC_SHA512, 512, 512)),
     55 ]
     56 
     57 # List of aead algorithms for use in ParamTests.
     58 AEAD_ALGOS = [
     59     # RFC 4106 specifies that key length must be 128, 192 or 256 bits,
     60     #   with an additional 4 bytes (32 bits) of salt. The salt must be unique
     61     #   for each new SA using the same key.
     62     # RFC 4106 specifies that ICV length must be 8, 12, or 16 bytes
     63     xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32,  8*8)),
     64     xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32, 12*8)),
     65     xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 128+32, 16*8)),
     66     xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32,  8*8)),
     67     xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32, 12*8)),
     68     xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 192+32, 16*8)),
     69     xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32,  8*8)),
     70     xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32, 12*8)),
     71     xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, 256+32, 16*8)),
     72 ]
     73 
     74 def InjectTests():
     75     XfrmAlgorithmTest.InjectTests()
     76 
     77 class XfrmAlgorithmTest(xfrm_base.XfrmLazyTest):
     78   @classmethod
     79   def InjectTests(cls):
     80     """Inject parameterized test cases into this class.
     81 
     82     Because a library for parameterized testing is not availble in
     83     net_test.rootfs.20150203, this does a minimal parameterization.
     84 
     85     This finds methods named like "ParamTestFoo" and replaces them with several
     86     "testFoo(*)" methods taking different parameter dicts. A set of test
     87     parameters is generated from every combination of encryption,
     88     authentication, IP version, and TCP/UDP.
     89 
     90     The benefit of this approach is that an individually failing tests have a
     91     clearly separated stack trace, and one failed test doesn't prevent the rest
     92     from running.
     93     """
     94     param_test_names = [
     95         name for name in dir(cls) if name.startswith("ParamTest")
     96     ]
     97     VERSIONS = (4, 6)
     98     TYPES = (SOCK_DGRAM, SOCK_STREAM)
     99 
    100     # Tests all combinations of auth & crypt. Mutually exclusive with aead.
    101     for crypt, auth, version, proto, name in itertools.product(
    102         CRYPT_ALGOS, AUTH_ALGOS, VERSIONS, TYPES, param_test_names):
    103       XfrmAlgorithmTest.InjectSingleTest(name, version, proto, crypt=crypt, auth=auth)
    104 
    105     # Tests all combinations of aead. Mutually exclusive with auth/crypt.
    106     for aead, version, proto, name in itertools.product(
    107         AEAD_ALGOS, VERSIONS, TYPES, param_test_names):
    108       XfrmAlgorithmTest.InjectSingleTest(name, version, proto, aead=aead)
    109 
    110   @classmethod
    111   def InjectSingleTest(cls, name, version, proto, crypt=None, auth=None, aead=None):
    112     func = getattr(cls, name)
    113 
    114     def TestClosure(self):
    115       func(self, {"crypt": crypt, "auth": auth, "aead": aead,
    116           "version": version, "proto": proto})
    117 
    118     # Produce a unique and readable name for each test. e.g.
    119     #     testSocketPolicySimple_cbc-aes_256_hmac-sha512_512_256_IPv6_UDP
    120     param_string = ""
    121     if crypt is not None:
    122       param_string += "%s_%d_" % (crypt.name, crypt.key_len)
    123 
    124     if auth is not None:
    125       param_string += "%s_%d_%d_" % (auth.name, auth.key_len,
    126           auth.trunc_len)
    127 
    128     if aead is not None:
    129       param_string += "%s_%d_%d_" % (aead.name, aead.key_len,
    130           aead.icv_len)
    131 
    132     param_string += "%s_%s" % ("IPv4" if version == 4 else "IPv6",
    133         "UDP" if proto == SOCK_DGRAM else "TCP")
    134     new_name = "%s_%s" % (func.__name__.replace("ParamTest", "test"),
    135                           param_string)
    136     new_name = new_name.replace("(", "-").replace(")", "")  # remove parens
    137     setattr(cls, new_name, TestClosure)
    138 
    139   def ParamTestSocketPolicySimple(self, params):
    140     """Test two-way traffic using transport mode and socket policies."""
    141 
    142     def AssertEncrypted(packet):
    143       # This gives a free pass to ICMP and ICMPv6 packets, which show up
    144       # nondeterministically in tests.
    145       self.assertEquals(None,
    146                         packet.getlayer(scapy.UDP),
    147                         "UDP packet sent in the clear")
    148       self.assertEquals(None,
    149                         packet.getlayer(scapy.TCP),
    150                         "TCP packet sent in the clear")
    151 
    152     # We create a pair of sockets, "left" and "right", that will talk to each
    153     # other using transport mode ESP. Because of TapTwister, both sockets
    154     # perceive each other as owning "remote_addr".
    155     netid = self.RandomNetid()
    156     family = net_test.GetAddressFamily(params["version"])
    157     local_addr = self.MyAddress(params["version"], netid)
    158     remote_addr = self.GetRemoteSocketAddress(params["version"])
    159     crypt_left = (xfrm.XfrmAlgo((
    160         params["crypt"].name,
    161         params["crypt"].key_len)),
    162         os.urandom(params["crypt"].key_len / 8)) if params["crypt"] else None
    163     crypt_right = (xfrm.XfrmAlgo((
    164         params["crypt"].name,
    165         params["crypt"].key_len)),
    166         os.urandom(params["crypt"].key_len / 8)) if params["crypt"] else None
    167     auth_left = (xfrm.XfrmAlgoAuth((
    168         params["auth"].name,
    169         params["auth"].key_len,
    170         params["auth"].trunc_len)),
    171         os.urandom(params["auth"].key_len / 8)) if params["auth"] else None
    172     auth_right = (xfrm.XfrmAlgoAuth((
    173         params["auth"].name,
    174         params["auth"].key_len,
    175         params["auth"].trunc_len)),
    176         os.urandom(params["auth"].key_len / 8)) if params["auth"] else None
    177     aead_left = (xfrm.XfrmAlgoAead((
    178         params["aead"].name,
    179         params["aead"].key_len,
    180         params["aead"].icv_len)),
    181         os.urandom(params["aead"].key_len / 8)) if params["aead"] else None
    182     aead_right = (xfrm.XfrmAlgoAead((
    183         params["aead"].name,
    184         params["aead"].key_len,
    185         params["aead"].icv_len)),
    186         os.urandom(params["aead"].key_len / 8)) if params["aead"] else None
    187     spi_left = 0xbeefface
    188     spi_right = 0xcafed00d
    189     req_ids = [100, 200, 300, 400]  # Used to match templates and SAs.
    190 
    191     # Left outbound SA
    192     self.xfrm.AddSaInfo(
    193         src=local_addr,
    194         dst=remote_addr,
    195         spi=spi_right,
    196         mode=xfrm.XFRM_MODE_TRANSPORT,
    197         reqid=req_ids[0],
    198         encryption=crypt_right,
    199         auth_trunc=auth_right,
    200         aead=aead_right,
    201         encap=None,
    202         mark=None,
    203         output_mark=None)
    204     # Right inbound SA
    205     self.xfrm.AddSaInfo(
    206         src=remote_addr,
    207         dst=local_addr,
    208         spi=spi_right,
    209         mode=xfrm.XFRM_MODE_TRANSPORT,
    210         reqid=req_ids[1],
    211         encryption=crypt_right,
    212         auth_trunc=auth_right,
    213         aead=aead_right,
    214         encap=None,
    215         mark=None,
    216         output_mark=None)
    217     # Right outbound SA
    218     self.xfrm.AddSaInfo(
    219         src=local_addr,
    220         dst=remote_addr,
    221         spi=spi_left,
    222         mode=xfrm.XFRM_MODE_TRANSPORT,
    223         reqid=req_ids[2],
    224         encryption=crypt_left,
    225         auth_trunc=auth_left,
    226         aead=aead_left,
    227         encap=None,
    228         mark=None,
    229         output_mark=None)
    230     # Left inbound SA
    231     self.xfrm.AddSaInfo(
    232         src=remote_addr,
    233         dst=local_addr,
    234         spi=spi_left,
    235         mode=xfrm.XFRM_MODE_TRANSPORT,
    236         reqid=req_ids[3],
    237         encryption=crypt_left,
    238         auth_trunc=auth_left,
    239         aead=aead_left,
    240         encap=None,
    241         mark=None,
    242         output_mark=None)
    243 
    244     # Make two sockets.
    245     sock_left = socket(family, params["proto"], 0)
    246     sock_left.settimeout(2.0)
    247     sock_left.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    248     self.SelectInterface(sock_left, netid, "mark")
    249     sock_right = socket(family, params["proto"], 0)
    250     sock_right.settimeout(2.0)
    251     sock_right.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    252     self.SelectInterface(sock_right, netid, "mark")
    253 
    254     # For UDP, set SO_LINGER to 0, to prevent TCP sockets from hanging around
    255     # in a TIME_WAIT state.
    256     if params["proto"] == SOCK_STREAM:
    257         net_test.DisableFinWait(sock_left)
    258         net_test.DisableFinWait(sock_right)
    259 
    260     # Apply the left outbound socket policy.
    261     xfrm_base.ApplySocketPolicy(sock_left, family, xfrm.XFRM_POLICY_OUT,
    262                                 spi_right, req_ids[0], None)
    263     # Apply right inbound socket policy.
    264     xfrm_base.ApplySocketPolicy(sock_right, family, xfrm.XFRM_POLICY_IN,
    265                                 spi_right, req_ids[1], None)
    266     # Apply right outbound socket policy.
    267     xfrm_base.ApplySocketPolicy(sock_right, family, xfrm.XFRM_POLICY_OUT,
    268                                 spi_left, req_ids[2], None)
    269     # Apply left inbound socket policy.
    270     xfrm_base.ApplySocketPolicy(sock_left, family, xfrm.XFRM_POLICY_IN,
    271                                 spi_left, req_ids[3], None)
    272 
    273     server_ready = threading.Event()
    274     server_error = None  # Save exceptions thrown by the server.
    275 
    276     def TcpServer(sock, client_port):
    277       try:
    278         sock.listen(1)
    279         server_ready.set()
    280         accepted, peer = sock.accept()
    281         self.assertEquals(remote_addr, peer[0])
    282         self.assertEquals(client_port, peer[1])
    283         data = accepted.recv(2048)
    284         self.assertEquals("hello request", data)
    285         accepted.send("hello response")
    286       except Exception as e:
    287         server_error = e
    288       finally:
    289         sock.close()
    290 
    291     def UdpServer(sock, client_port):
    292       try:
    293         server_ready.set()
    294         data, peer = sock.recvfrom(2048)
    295         self.assertEquals(remote_addr, peer[0])
    296         self.assertEquals(client_port, peer[1])
    297         self.assertEquals("hello request", data)
    298         sock.sendto("hello response", peer)
    299       except Exception as e:
    300         server_error = e
    301       finally:
    302         sock.close()
    303 
    304     # Server and client need to know each other's port numbers in advance.
    305     wildcard_addr = net_test.GetWildcardAddress(params["version"])
    306     sock_left.bind((wildcard_addr, 0))
    307     sock_right.bind((wildcard_addr, 0))
    308     left_port = sock_left.getsockname()[1]
    309     right_port = sock_right.getsockname()[1]
    310 
    311     # Start the appropriate server type on sock_right.
    312     target = TcpServer if params["proto"] == SOCK_STREAM else UdpServer
    313     server = threading.Thread(
    314         target=target,
    315         args=(sock_right, left_port),
    316         name="SocketServer")
    317     server.start()
    318     # Wait for server to be ready before attempting to connect. TCP retries
    319     # hide this problem, but UDP will fail outright if the server socket has
    320     # not bound when we send.
    321     self.assertTrue(server_ready.wait(2.0), "Timed out waiting for server thread")
    322 
    323     with TapTwister(fd=self.tuns[netid].fileno(), validator=AssertEncrypted):
    324       sock_left.connect((remote_addr, right_port))
    325       sock_left.send("hello request")
    326       data = sock_left.recv(2048)
    327       self.assertEquals("hello response", data)
    328       sock_left.close()
    329       server.join()
    330     if server_error:
    331       raise server_error
    332 
    333 
    334 if __name__ == "__main__":
    335   XfrmAlgorithmTest.InjectTests()
    336   unittest.main()
    337