Home | History | Annotate | Download | only in test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2017 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 # pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
     18 from errno import *  # pylint: disable=wildcard-import
     19 from scapy import all as scapy
     20 from socket import *  # pylint: disable=wildcard-import
     21 import struct
     22 import subprocess
     23 import threading
     24 import unittest
     25 
     26 import csocket
     27 import cstruct
     28 import multinetwork_base
     29 import net_test
     30 import packets
     31 import xfrm
     32 import xfrm_base
     33 
     34 ENCRYPTED_PAYLOAD = ("b1c74998efd6326faebe2061f00f2c750e90e76001664a80c287b150"
     35                      "59e74bf949769cc6af71e51b539e7de3a2a14cb05a231b969e035174"
     36                      "d98c5aa0cef1937db98889ec0d08fa408fecf616")
     37 
     38 TEST_ADDR1 = "2001:4860:4860::8888"
     39 TEST_ADDR2 = "2001:4860:4860::8844"
     40 
     41 # IP addresses to use for tunnel endpoints. For generality, these should be
     42 # different from the addresses we send packets to.
     43 TUNNEL_ENDPOINTS = {4: "8.8.4.4", 6: TEST_ADDR2}
     44 
     45 TEST_SPI = 0x1234
     46 TEST_SPI2 = 0x1235
     47 
     48 
     49 
     50 class XfrmFunctionalTest(xfrm_base.XfrmLazyTest):
     51 
     52   def assertIsUdpEncapEsp(self, packet, spi, seq, length):
     53     self.assertEquals(IPPROTO_UDP, packet.proto)
     54     udp_hdr = packet[scapy.UDP]
     55     self.assertEquals(4500, udp_hdr.dport)
     56     self.assertEquals(length, len(udp_hdr))
     57     esp_hdr, _ = cstruct.Read(str(udp_hdr.payload), xfrm.EspHdr)
     58     # FIXME: this file currently swaps SPI byte order manually, so SPI needs to
     59     # be double-swapped here.
     60     self.assertEquals(xfrm.EspHdr((spi, seq)), esp_hdr)
     61 
     62   def CreateNewSa(self, localAddr, remoteAddr, spi, reqId, encap_tmpl,
     63                   null_auth=False):
     64     auth_algo = (
     65         xfrm_base._ALGO_AUTH_NULL if null_auth else xfrm_base._ALGO_HMAC_SHA1)
     66     self.xfrm.AddSaInfo(localAddr, remoteAddr, spi, xfrm.XFRM_MODE_TRANSPORT,
     67                     reqId, xfrm_base._ALGO_CBC_AES_256, auth_algo, None,
     68                     encap_tmpl, None, None)
     69 
     70   def testAddSa(self):
     71     self.CreateNewSa("::", TEST_ADDR1, TEST_SPI, 3320, None)
     72     expected = (
     73         "src :: dst 2001:4860:4860::8888\n"
     74         "\tproto esp spi 0x00001234 reqid 3320 mode transport\n"
     75         "\treplay-window 4 \n"
     76         "\tauth-trunc hmac(sha1) 0x%s 96\n"
     77         "\tenc cbc(aes) 0x%s\n"
     78         "\tsel src ::/0 dst ::/0 \n" % (
     79             xfrm_base._AUTHENTICATION_KEY_128.encode("hex"),
     80             xfrm_base._ENCRYPTION_KEY_256.encode("hex")))
     81 
     82     actual = subprocess.check_output("ip xfrm state".split())
     83     # Newer versions of IP also show anti-replay context. Don't choke if it's
     84     # missing.
     85     actual = actual.replace(
     86         "\tanti-replay context: seq 0x0, oseq 0x0, bitmap 0x00000000\n", "")
     87     try:
     88       self.assertMultiLineEqual(expected, actual)
     89     finally:
     90       self.xfrm.DeleteSaInfo(TEST_ADDR1, TEST_SPI, IPPROTO_ESP)
     91 
     92   def testFlush(self):
     93     self.assertEquals(0, len(self.xfrm.DumpSaInfo()))
     94     self.CreateNewSa("::", "2000::", TEST_SPI, 1234, None)
     95     self.CreateNewSa("0.0.0.0", "192.0.2.1", TEST_SPI, 4321, None)
     96     self.assertEquals(2, len(self.xfrm.DumpSaInfo()))
     97     self.xfrm.FlushSaInfo()
     98     self.assertEquals(0, len(self.xfrm.DumpSaInfo()))
     99 
    100   def _TestSocketPolicy(self, version):
    101     # Open a UDP socket and connect it.
    102     family = net_test.GetAddressFamily(version)
    103     s = socket(family, SOCK_DGRAM, 0)
    104     netid = self.RandomNetid()
    105     self.SelectInterface(s, netid, "mark")
    106 
    107     remotesockaddr = self.GetRemoteSocketAddress(version)
    108     s.connect((remotesockaddr, 53))
    109     saddr, sport = s.getsockname()[:2]
    110     daddr, dport = s.getpeername()[:2]
    111     if version == 5:
    112       saddr = saddr.replace("::ffff:", "")
    113       daddr = daddr.replace("::ffff:", "")
    114 
    115     reqid = 0
    116 
    117     desc, pkt = packets.UDP(version, saddr, daddr, sport=sport)
    118     s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    119     self.ExpectPacketOn(netid, "Send after socket, expected %s" % desc, pkt)
    120 
    121     # Using IPv4 XFRM on a dual-stack socket requires setting an AF_INET policy
    122     # that's written in terms of IPv4 addresses.
    123     xfrm_version = 4 if version == 5 else version
    124     xfrm_family = net_test.GetAddressFamily(xfrm_version)
    125     xfrm_base.ApplySocketPolicy(s, xfrm_family, xfrm.XFRM_POLICY_OUT,
    126                                 TEST_SPI, reqid, None)
    127 
    128     # Because the policy has level set to "require" (the default), attempting
    129     # to send a packet results in an error, because there is no SA that
    130     # matches the socket policy we set.
    131     self.assertRaisesErrno(
    132         EAGAIN,
    133         s.sendto, net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    134 
    135     # Adding a matching SA causes the packet to go out encrypted. The SA's
    136     # SPI must match the one in our template, and the destination address must
    137     # match the packet's destination address (in tunnel mode, it has to match
    138     # the tunnel destination).
    139     self.CreateNewSa(
    140         net_test.GetWildcardAddress(xfrm_version),
    141         self.GetRemoteAddress(xfrm_version), TEST_SPI, reqid, None)
    142     s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    143     expected_length = xfrm_base.GetEspPacketLength(xfrm.XFRM_MODE_TRANSPORT,
    144                                                 version, False,
    145                                                 net_test.UDP_PAYLOAD,
    146                                                 xfrm_base._ALGO_HMAC_SHA1,
    147                                                 xfrm_base._ALGO_CBC_AES_256)
    148     self._ExpectEspPacketOn(netid, TEST_SPI, 1, expected_length, None, None)
    149 
    150     # Sending to another destination doesn't work: again, no matching SA.
    151     remoteaddr2 = self.GetOtherRemoteSocketAddress(version)
    152     self.assertRaisesErrno(
    153         EAGAIN,
    154         s.sendto, net_test.UDP_PAYLOAD, (remoteaddr2, 53))
    155 
    156     # Sending on another socket without the policy applied results in an
    157     # unencrypted packet going out.
    158     s2 = socket(family, SOCK_DGRAM, 0)
    159     self.SelectInterface(s2, netid, "mark")
    160     s2.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    161     pkts = self.ReadAllPacketsOn(netid)
    162     self.assertEquals(1, len(pkts))
    163     packet = pkts[0]
    164 
    165     protocol = packet.nh if version == 6 else packet.proto
    166     self.assertEquals(IPPROTO_UDP, protocol)
    167 
    168     # Deleting the SA causes the first socket to return errors again.
    169     self.xfrm.DeleteSaInfo(self.GetRemoteAddress(xfrm_version), TEST_SPI,
    170                            IPPROTO_ESP)
    171     self.assertRaisesErrno(
    172         EAGAIN,
    173         s.sendto, net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    174 
    175     # Clear the socket policy and expect a cleartext packet.
    176     xfrm_base.SetPolicySockopt(s, family, None)
    177     s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    178     self.ExpectPacketOn(netid, "Send after clear, expected %s" % desc, pkt)
    179 
    180     # Clearing the policy twice is safe.
    181     xfrm_base.SetPolicySockopt(s, family, None)
    182     s.sendto(net_test.UDP_PAYLOAD, (remotesockaddr, 53))
    183     self.ExpectPacketOn(netid, "Send after clear 2, expected %s" % desc, pkt)
    184 
    185     # Clearing if a policy was never set is safe.
    186     s = socket(AF_INET6, SOCK_DGRAM, 0)
    187     xfrm_base.SetPolicySockopt(s, family, None)
    188 
    189   def testSocketPolicyIPv4(self):
    190     self._TestSocketPolicy(4)
    191 
    192   def testSocketPolicyIPv6(self):
    193     self._TestSocketPolicy(6)
    194 
    195   def testSocketPolicyMapped(self):
    196     self._TestSocketPolicy(5)
    197 
    198   # Sets up sockets and marks to correct netid
    199   def _SetupUdpEncapSockets(self):
    200     netid = self.RandomNetid()
    201     myaddr = self.MyAddress(4, netid)
    202     remoteaddr = self.GetRemoteAddress(4)
    203 
    204     # Reserve a port on which to receive UDP encapsulated packets. Sending
    205     # packets works without this (and potentially can send packets with a source
    206     # port belonging to another application), but receiving requires the port to
    207     # be bound and the encapsulation socket option enabled.
    208     encap_sock = net_test.Socket(AF_INET, SOCK_DGRAM, 0)
    209     encap_sock.bind((myaddr, 0))
    210     encap_port = encap_sock.getsockname()[1]
    211     encap_sock.setsockopt(IPPROTO_UDP, xfrm.UDP_ENCAP, xfrm.UDP_ENCAP_ESPINUDP)
    212 
    213     # Open a socket to send traffic.
    214     s = socket(AF_INET, SOCK_DGRAM, 0)
    215     self.SelectInterface(s, netid, "mark")
    216     s.connect((remoteaddr, 53))
    217 
    218     return netid, myaddr, remoteaddr, encap_sock, encap_port, s
    219 
    220   # Sets up SAs and applies socket policy to given socket
    221   def _SetupUdpEncapSaPair(self, myaddr, remoteaddr, in_spi, out_spi,
    222                            encap_port, s, use_null_auth):
    223     in_reqid = 123
    224     out_reqid = 456
    225 
    226     # Create inbound and outbound SAs that specify UDP encapsulation.
    227     encaptmpl = xfrm.XfrmEncapTmpl((xfrm.UDP_ENCAP_ESPINUDP, htons(encap_port),
    228                                     htons(4500), 16 * "\x00"))
    229     self.CreateNewSa(myaddr, remoteaddr, out_spi, out_reqid, encaptmpl,
    230                      use_null_auth)
    231 
    232     # Add an encap template that's the mirror of the outbound one.
    233     encaptmpl.sport, encaptmpl.dport = encaptmpl.dport, encaptmpl.sport
    234     self.CreateNewSa(remoteaddr, myaddr, in_spi, in_reqid, encaptmpl,
    235                      use_null_auth)
    236 
    237     # Apply socket policies to s.
    238     xfrm_base.ApplySocketPolicy(s, AF_INET, xfrm.XFRM_POLICY_OUT, out_spi,
    239                                 out_reqid, None)
    240 
    241     # TODO: why does this work without a per-socket policy applied?
    242     # The received  packet obviously matches an SA, but don't inbound packets
    243     # need to match a policy as well? (b/71541609)
    244     xfrm_base.ApplySocketPolicy(s, AF_INET, xfrm.XFRM_POLICY_IN, in_spi,
    245                                 in_reqid, None)
    246 
    247     # Uncomment for debugging.
    248     # subprocess.call("ip xfrm state".split())
    249 
    250   # Check that packets can be sent and received.
    251   def _VerifyUdpEncapSocket(self, netid, remoteaddr, myaddr, encap_port, sock,
    252                            in_spi, out_spi, null_auth, seq_num):
    253     # Now send a packet.
    254     sock.sendto(net_test.UDP_PAYLOAD, (remoteaddr, 53))
    255     srcport = sock.getsockname()[1]
    256 
    257     # Expect to see an UDP encapsulated packet.
    258     pkts = self.ReadAllPacketsOn(netid)
    259     self.assertEquals(1, len(pkts))
    260     packet = pkts[0]
    261 
    262     auth_algo = (
    263         xfrm_base._ALGO_AUTH_NULL if null_auth else xfrm_base._ALGO_HMAC_SHA1)
    264     expected_len = xfrm_base.GetEspPacketLength(
    265         xfrm.XFRM_MODE_TRANSPORT, 4, True, net_test.UDP_PAYLOAD, auth_algo,
    266         xfrm_base._ALGO_CBC_AES_256)
    267     self.assertIsUdpEncapEsp(packet, out_spi, seq_num, expected_len)
    268 
    269     # Now test the receive path. Because we don't know how to decrypt packets,
    270     # we just play back the encrypted packet that kernel sent earlier. We swap
    271     # the addresses in the IP header to make the packet look like it's bound for
    272     # us, but we can't do that for the port numbers because the UDP header is
    273     # part of the integrity protected payload, which we can only replay as is.
    274     # So the source and destination ports are swapped and the packet appears to
    275     # be sent from srcport to port 53. Open another socket on that port, and
    276     # apply the inbound policy to it.
    277     twisted_socket = socket(AF_INET, SOCK_DGRAM, 0)
    278     csocket.SetSocketTimeout(twisted_socket, 100)
    279     twisted_socket.bind(("0.0.0.0", 53))
    280 
    281     # Save the payload of the packet so we can replay it back to ourselves, and
    282     # replace the SPI with our inbound SPI.
    283     payload = str(packet.payload)[8:]
    284     spi_seq = xfrm.EspHdr((in_spi, seq_num)).Pack()
    285     payload = spi_seq + payload[len(spi_seq):]
    286 
    287     sainfo = self.xfrm.FindSaInfo(in_spi)
    288     start_integrity_failures = sainfo.stats.integrity_failed
    289 
    290     # Now play back the valid packet and check that we receive it.
    291     incoming = (scapy.IP(src=remoteaddr, dst=myaddr) /
    292                 scapy.UDP(sport=4500, dport=encap_port) / payload)
    293     incoming = scapy.IP(str(incoming))
    294     self.ReceivePacketOn(netid, incoming)
    295 
    296     sainfo = self.xfrm.FindSaInfo(in_spi)
    297 
    298     # TODO: break this out into a separate test
    299     # If our SPIs are different, and we aren't using null authentication,
    300     # we expect the packet to be dropped. We also expect that the integrity
    301     # failure counter to increase, as SPIs are part of the authenticated or
    302     # integrity-verified portion of the packet.
    303     if not null_auth and in_spi != out_spi:
    304       self.assertRaisesErrno(EAGAIN, twisted_socket.recv, 4096)
    305       self.assertEquals(start_integrity_failures + 1,
    306                         sainfo.stats.integrity_failed)
    307     else:
    308       data, src = twisted_socket.recvfrom(4096)
    309       self.assertEquals(net_test.UDP_PAYLOAD, data)
    310       self.assertEquals((remoteaddr, srcport), src)
    311       self.assertEquals(start_integrity_failures, sainfo.stats.integrity_failed)
    312 
    313     # Check that unencrypted packets on twisted_socket are not received.
    314     unencrypted = (
    315         scapy.IP(src=remoteaddr, dst=myaddr) / scapy.UDP(
    316             sport=srcport, dport=53) / net_test.UDP_PAYLOAD)
    317     self.assertRaisesErrno(EAGAIN, twisted_socket.recv, 4096)
    318 
    319   def _RunEncapSocketPolicyTest(self, in_spi, out_spi, use_null_auth):
    320     netid, myaddr, remoteaddr, encap_sock, encap_port, s = \
    321         self._SetupUdpEncapSockets()
    322 
    323     self._SetupUdpEncapSaPair(myaddr, remoteaddr, in_spi, out_spi, encap_port,
    324                               s, use_null_auth)
    325 
    326     # Check that UDP encap sockets work with socket policy and given SAs
    327     self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s, in_spi,
    328                                out_spi, use_null_auth, 1)
    329 
    330   # TODO: Add tests for ESP (non-encap) sockets.
    331   def testUdpEncapSameSpisNullAuth(self):
    332     # Use the same SPI both inbound and outbound because this lets us receive
    333     # encrypted packets by simply replaying the packets the kernel sends
    334     # without having to disable authentication
    335     self._RunEncapSocketPolicyTest(TEST_SPI, TEST_SPI, True)
    336 
    337   def testUdpEncapSameSpis(self):
    338     self._RunEncapSocketPolicyTest(TEST_SPI, TEST_SPI, False)
    339 
    340   def testUdpEncapDifferentSpisNullAuth(self):
    341     self._RunEncapSocketPolicyTest(TEST_SPI, TEST_SPI2, True)
    342 
    343   def testUdpEncapDifferentSpis(self):
    344     self._RunEncapSocketPolicyTest(TEST_SPI, TEST_SPI2, False)
    345 
    346   def testUdpEncapRekey(self):
    347     # Select the two SPIs that will be used
    348     start_spi = TEST_SPI
    349     rekey_spi = TEST_SPI2
    350 
    351     # Setup sockets
    352     netid, myaddr, remoteaddr, encap_sock, encap_port, s = \
    353         self._SetupUdpEncapSockets()
    354 
    355     # The SAs must use null authentication, since we change SPIs on the fly
    356     # Without null authentication, this would result in an ESP authentication
    357     # error since the SPI is part of the authenticated section. The packet
    358     # would then be dropped
    359     self._SetupUdpEncapSaPair(myaddr, remoteaddr, start_spi, start_spi,
    360                               encap_port, s, True)
    361 
    362     # Check that UDP encap sockets work with socket policy and given SAs
    363     self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s,
    364                                start_spi, start_spi, True, 1)
    365 
    366     # Rekey this socket using the make-before-break paradigm. First we create
    367     # new SAs, update the per-socket policies, and only then remove the old SAs
    368     #
    369     # This allows us to switch to the new SA without breaking the outbound path.
    370     self._SetupUdpEncapSaPair(myaddr, remoteaddr, rekey_spi, rekey_spi,
    371                               encap_port, s, True)
    372 
    373     # Check that UDP encap socket works with updated socket policy, sending
    374     # using new SA, but receiving on both old and new SAs
    375     self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s,
    376                                rekey_spi, rekey_spi, True, 1)
    377     self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s,
    378                                start_spi, rekey_spi, True, 2)
    379 
    380     # Delete old SAs
    381     self.xfrm.DeleteSaInfo(remoteaddr, start_spi, IPPROTO_ESP)
    382     self.xfrm.DeleteSaInfo(myaddr, start_spi, IPPROTO_ESP)
    383 
    384     # Check that UDP encap socket works with updated socket policy and new SAs
    385     self._VerifyUdpEncapSocket(netid, remoteaddr, myaddr, encap_port, s,
    386                                rekey_spi, rekey_spi, True, 3)
    387 
    388   def testAllocSpecificSpi(self):
    389     spi = 0xABCD
    390     new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi)
    391     self.assertEquals(spi, new_sa.id.spi)
    392 
    393   def testAllocSpecificSpiUnavailable(self):
    394     """Attempt to allocate the same SPI twice."""
    395     spi = 0xABCD
    396     new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi)
    397     self.assertEquals(spi, new_sa.id.spi)
    398     with self.assertRaisesErrno(ENOENT):
    399       new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, spi, spi)
    400 
    401   def testAllocRangeSpi(self):
    402     start, end = 0xABCD0, 0xABCDF
    403     new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, start, end)
    404     spi = new_sa.id.spi
    405     self.assertGreaterEqual(spi, start)
    406     self.assertLessEqual(spi, end)
    407 
    408   def testAllocRangeSpiUnavailable(self):
    409     """Attempt to allocate N+1 SPIs from a range of size N."""
    410     start, end = 0xABCD0, 0xABCDF
    411     range_size = end - start + 1
    412     spis = set()
    413     # Assert that allocating SPI fails when none are available.
    414     with self.assertRaisesErrno(ENOENT):
    415       # Allocating range_size + 1 SPIs is guaranteed to fail.  Due to the way
    416       # kernel picks random SPIs, this has a high probability of failing before
    417       # reaching that limit.
    418       for i in xrange(range_size + 1):
    419         new_sa = self.xfrm.AllocSpi("::", IPPROTO_ESP, start, end)
    420         spi = new_sa.id.spi
    421         self.assertNotIn(spi, spis)
    422         spis.add(spi)
    423 
    424   def testSocketPolicyDstCacheV6(self):
    425     self._TestSocketPolicyDstCache(6)
    426 
    427   def testSocketPolicyDstCacheV4(self):
    428     self._TestSocketPolicyDstCache(4)
    429 
    430   def _TestSocketPolicyDstCache(self, version):
    431     """Test that destination cache is cleared with socket policy.
    432 
    433     This relies on the fact that connect() on a UDP socket populates the
    434     destination cache.
    435     """
    436 
    437     # Create UDP socket.
    438     family = net_test.GetAddressFamily(version)
    439     netid = self.RandomNetid()
    440     s = socket(family, SOCK_DGRAM, 0)
    441     self.SelectInterface(s, netid, "mark")
    442 
    443     # Populate the socket's destination cache.
    444     remote = self.GetRemoteAddress(version)
    445     s.connect((remote, 53))
    446 
    447     # Apply a policy to the socket. Should clear dst cache.
    448     reqid = 123
    449     xfrm_base.ApplySocketPolicy(s, family, xfrm.XFRM_POLICY_OUT,
    450                                 TEST_SPI, reqid, None)
    451 
    452     # Policy with no matching SA should result in EAGAIN. If destination cache
    453     # failed to clear, then the UDP packet will be sent normally.
    454     with self.assertRaisesErrno(EAGAIN):
    455       s.send(net_test.UDP_PAYLOAD)
    456     self.ExpectNoPacketsOn(netid, "Packet not blocked by policy")
    457 
    458   def _CheckNullEncryptionTunnelMode(self, version):
    459     family = net_test.GetAddressFamily(version)
    460     netid = self.RandomNetid()
    461     local_addr = self.MyAddress(version, netid)
    462     remote_addr = self.GetRemoteAddress(version)
    463 
    464     # Borrow the address of another netId as the source address of the tunnel
    465     tun_local = self.MyAddress(version, self.RandomNetid(netid))
    466     # For generality, pick a tunnel endpoint that's not the address we
    467     # connect the socket to.
    468     tun_remote = TUNNEL_ENDPOINTS[version]
    469 
    470     # Output
    471     self.xfrm.AddSaInfo(
    472         tun_local, tun_remote, 0xABCD, xfrm.XFRM_MODE_TUNNEL, 123,
    473         xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
    474         None, None, None, netid)
    475     # Input
    476     self.xfrm.AddSaInfo(
    477         tun_remote, tun_local, 0x9876, xfrm.XFRM_MODE_TUNNEL, 456,
    478         xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
    479         None, None, None, None)
    480 
    481     sock = net_test.UDPSocket(family)
    482     self.SelectInterface(sock, netid, "mark")
    483     sock.bind((local_addr, 0))
    484     local_port = sock.getsockname()[1]
    485     remote_port = 5555
    486 
    487     xfrm_base.ApplySocketPolicy(
    488         sock, family, xfrm.XFRM_POLICY_OUT, 0xABCD, 123,
    489         (tun_local, tun_remote))
    490     xfrm_base.ApplySocketPolicy(
    491         sock, family, xfrm.XFRM_POLICY_IN, 0x9876, 456,
    492         (tun_remote, tun_local))
    493 
    494     # Create and receive an ESP packet.
    495     IpType = {4: scapy.IP, 6: scapy.IPv6}[version]
    496     input_pkt = (IpType(src=remote_addr, dst=local_addr) /
    497                  scapy.UDP(sport=remote_port, dport=local_port) /
    498                  "input hello")
    499     input_pkt = IpType(str(input_pkt)) # Compute length, checksum.
    500     input_pkt = xfrm_base.EncryptPacketWithNull(input_pkt, 0x9876,
    501                                                 1, (tun_remote, tun_local))
    502 
    503     self.ReceivePacketOn(netid, input_pkt)
    504     msg, addr = sock.recvfrom(1024)
    505     self.assertEquals("input hello", msg)
    506     self.assertEquals((remote_addr, remote_port), addr[:2])
    507 
    508     # Send and capture a packet.
    509     sock.sendto("output hello", (remote_addr, remote_port))
    510     packets = self.ReadAllPacketsOn(netid)
    511     self.assertEquals(1, len(packets))
    512     output_pkt = packets[0]
    513     output_pkt, esp_hdr = xfrm_base.DecryptPacketWithNull(output_pkt)
    514     self.assertEquals(output_pkt[scapy.UDP].len, len("output_hello") + 8)
    515     self.assertEquals(remote_addr, output_pkt.dst)
    516     self.assertEquals(remote_port, output_pkt[scapy.UDP].dport)
    517     # length of the payload plus the UDP header
    518     self.assertEquals("output hello", str(output_pkt[scapy.UDP].payload))
    519     self.assertEquals(0xABCD, esp_hdr.spi)
    520 
    521   def testNullEncryptionTunnelMode(self):
    522     """Verify null encryption in tunnel mode.
    523 
    524     This test verifies both manual assembly and disassembly of UDP packets
    525     with ESP in IPsec tunnel mode.
    526     """
    527     for version in [4, 6]:
    528       self._CheckNullEncryptionTunnelMode(version)
    529 
    530   def _CheckNullEncryptionTransportMode(self, version):
    531     family = net_test.GetAddressFamily(version)
    532     netid = self.RandomNetid()
    533     local_addr = self.MyAddress(version, netid)
    534     remote_addr = self.GetRemoteAddress(version)
    535 
    536     # Output
    537     self.xfrm.AddSaInfo(
    538         local_addr, remote_addr, 0xABCD, xfrm.XFRM_MODE_TRANSPORT, 123,
    539         xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
    540         None, None, None, None)
    541     # Input
    542     self.xfrm.AddSaInfo(
    543         remote_addr, local_addr, 0x9876, xfrm.XFRM_MODE_TRANSPORT, 456,
    544         xfrm_base._ALGO_CRYPT_NULL, xfrm_base._ALGO_AUTH_NULL,
    545         None, None, None, None)
    546 
    547     sock = net_test.UDPSocket(family)
    548     self.SelectInterface(sock, netid, "mark")
    549     sock.bind((local_addr, 0))
    550     local_port = sock.getsockname()[1]
    551     remote_port = 5555
    552 
    553     xfrm_base.ApplySocketPolicy(
    554         sock, family, xfrm.XFRM_POLICY_OUT, 0xABCD, 123, None)
    555     xfrm_base.ApplySocketPolicy(
    556         sock, family, xfrm.XFRM_POLICY_IN, 0x9876, 456, None)
    557 
    558     # Create and receive an ESP packet.
    559     IpType = {4: scapy.IP, 6: scapy.IPv6}[version]
    560     input_pkt = (IpType(src=remote_addr, dst=local_addr) /
    561                  scapy.UDP(sport=remote_port, dport=local_port) /
    562                  "input hello")
    563     input_pkt = IpType(str(input_pkt)) # Compute length, checksum.
    564     input_pkt = xfrm_base.EncryptPacketWithNull(input_pkt, 0x9876, 1, None)
    565 
    566     self.ReceivePacketOn(netid, input_pkt)
    567     msg, addr = sock.recvfrom(1024)
    568     self.assertEquals("input hello", msg)
    569     self.assertEquals((remote_addr, remote_port), addr[:2])
    570 
    571     # Send and capture a packet.
    572     sock.sendto("output hello", (remote_addr, remote_port))
    573     packets = self.ReadAllPacketsOn(netid)
    574     self.assertEquals(1, len(packets))
    575     output_pkt = packets[0]
    576     output_pkt, esp_hdr = xfrm_base.DecryptPacketWithNull(output_pkt)
    577     # length of the payload plus the UDP header
    578     self.assertEquals(output_pkt[scapy.UDP].len, len("output_hello") + 8)
    579     self.assertEquals(remote_addr, output_pkt.dst)
    580     self.assertEquals(remote_port, output_pkt[scapy.UDP].dport)
    581     self.assertEquals("output hello", str(output_pkt[scapy.UDP].payload))
    582     self.assertEquals(0xABCD, esp_hdr.spi)
    583 
    584   def testNullEncryptionTransportMode(self):
    585     """Verify null encryption in transport mode.
    586 
    587     This test verifies both manual assembly and disassembly of UDP packets
    588     with ESP in IPsec transport mode.
    589     """
    590     for version in [4, 6]:
    591       self._CheckNullEncryptionTransportMode(version)
    592 
    593   def _CheckGlobalPoliciesByMark(self, version):
    594     """Tests that global policies may differ by only the mark."""
    595     family = net_test.GetAddressFamily(version)
    596     sel = xfrm.EmptySelector(family)
    597     # Pick 2 arbitrary mark values.
    598     mark1 = xfrm.XfrmMark(mark=0xf00, mask=xfrm_base.MARK_MASK_ALL)
    599     mark2 = xfrm.XfrmMark(mark=0xf00d, mask=xfrm_base.MARK_MASK_ALL)
    600     # Create a global policy.
    601     policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
    602     tmpl = xfrm.UserTemplate(AF_UNSPEC, 0xfeed, 0, None)
    603     # Create the policy with the first mark.
    604     self.xfrm.AddPolicyInfo(policy, tmpl, mark1)
    605     # Create the same policy but with the second (different) mark.
    606     self.xfrm.AddPolicyInfo(policy, tmpl, mark2)
    607     # Delete the policies individually
    608     self.xfrm.DeletePolicyInfo(sel, xfrm.XFRM_POLICY_OUT, mark1)
    609     self.xfrm.DeletePolicyInfo(sel, xfrm.XFRM_POLICY_OUT, mark2)
    610 
    611   def testGlobalPoliciesByMarkV4(self):
    612     self._CheckGlobalPoliciesByMark(4)
    613 
    614   def testGlobalPoliciesByMarkV6(self):
    615     self._CheckGlobalPoliciesByMark(6)
    616 
    617   def _CheckUpdatePolicy(self, version):
    618     """Tests that we can can update the template on a policy."""
    619     family = net_test.GetAddressFamily(version)
    620     tmpl1 = xfrm.UserTemplate(family, 0xdead, 0, None)
    621     tmpl2 = xfrm.UserTemplate(family, 0xbeef, 0, None)
    622     sel = xfrm.EmptySelector(family)
    623     policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
    624     mark = xfrm.XfrmMark(mark=0xf00, mask=xfrm_base.MARK_MASK_ALL)
    625 
    626     def _CheckTemplateMatch(tmpl):
    627       """Dump the SPD and match a single template on a single policy."""
    628       dump = self.xfrm.DumpPolicyInfo()
    629       self.assertEquals(1, len(dump))
    630       _, attributes = dump[0]
    631       self.assertEquals(attributes['XFRMA_TMPL'], tmpl)
    632 
    633     # Create a new policy using update.
    634     self.xfrm.UpdatePolicyInfo(policy, tmpl1, mark)
    635     # NEWPOLICY will not update the existing policy. This checks both that
    636     # UPDPOLICY created a policy and that NEWPOLICY will not perform updates.
    637     _CheckTemplateMatch(tmpl1)
    638     with self.assertRaisesErrno(EEXIST):
    639       self.xfrm.AddPolicyInfo(policy, tmpl2, mark)
    640     # Update the policy using UPDPOLICY.
    641     self.xfrm.UpdatePolicyInfo(policy, tmpl2, mark)
    642     # There should only be one policy after update, and it should have the
    643     # updated template.
    644     _CheckTemplateMatch(tmpl2)
    645 
    646   def testUpdatePolicyV4(self):
    647     self._CheckUpdatePolicy(4)
    648 
    649   def testUpdatePolicyV6(self):
    650     self._CheckUpdatePolicy(6)
    651 
    652   def _CheckPolicyDifferByDirection(self,version):
    653     """Tests that policies can differ only by direction."""
    654     family = net_test.GetAddressFamily(version)
    655     tmpl = xfrm.UserTemplate(family, 0xdead, 0, None)
    656     sel = xfrm.EmptySelector(family)
    657     mark = xfrm.XfrmMark(mark=0xf00, mask=xfrm_base.MARK_MASK_ALL)
    658     policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_OUT, sel)
    659     self.xfrm.AddPolicyInfo(policy, tmpl, mark)
    660     policy = xfrm.UserPolicy(xfrm.XFRM_POLICY_IN, sel)
    661     self.xfrm.AddPolicyInfo(policy, tmpl, mark)
    662 
    663   def testPolicyDifferByDirectionV4(self):
    664     self._CheckPolicyDifferByDirection(4)
    665 
    666   def testPolicyDifferByDirectionV6(self):
    667     self._CheckPolicyDifferByDirection(6)
    668 
    669 class XfrmOutputMarkTest(xfrm_base.XfrmLazyTest):
    670 
    671   def _CheckTunnelModeOutputMark(self, version, tunsrc, mark, expected_netid):
    672     """Tests sending UDP packets to tunnel mode SAs with output marks.
    673 
    674     Opens a UDP socket and binds it to a random netid, then sets up tunnel mode
    675     SAs with an output_mark of mark and sets a socket policy to use the SA.
    676     Then checks that sending on those SAs sends a packet on expected_netid,
    677     or, if expected_netid is zero, checks that sending returns ENETUNREACH.
    678 
    679     Args:
    680       version: 4 or 6.
    681       tunsrc: A string, the source address of the tunnel.
    682       mark: An integer, the output_mark to set in the SA.
    683       expected_netid: An integer, the netid to expect the kernel to send the
    684           packet on. If None, expect that sendto will fail with ENETUNREACH.
    685     """
    686     # Open a UDP socket and bind it to a random netid.
    687     family = net_test.GetAddressFamily(version)
    688     s = socket(family, SOCK_DGRAM, 0)
    689     self.SelectInterface(s, self.RandomNetid(), "mark")
    690 
    691     # For generality, pick a tunnel endpoint that's not the address we
    692     # connect the socket to.
    693     tundst = TUNNEL_ENDPOINTS[version]
    694     tun_addrs = (tunsrc, tundst)
    695 
    696     # Create a tunnel mode SA and use XFRM_OUTPUT_MARK to bind it to netid.
    697     spi = TEST_SPI * mark
    698     reqid = 100 + spi
    699     self.xfrm.AddSaInfo(tunsrc, tundst, spi, xfrm.XFRM_MODE_TUNNEL, reqid,
    700                         xfrm_base._ALGO_CBC_AES_256, xfrm_base._ALGO_HMAC_SHA1,
    701                         None, None, None, mark)
    702 
    703     # Set a socket policy to use it.
    704     xfrm_base.ApplySocketPolicy(s, family, xfrm.XFRM_POLICY_OUT, spi, reqid,
    705                                 tun_addrs)
    706 
    707     # Send a packet and check that we see it on the wire.
    708     remoteaddr = self.GetRemoteAddress(version)
    709 
    710     packetlen = xfrm_base.GetEspPacketLength(xfrm.XFRM_MODE_TUNNEL, version,
    711                                              False, net_test.UDP_PAYLOAD,
    712                                              xfrm_base._ALGO_HMAC_SHA1,
    713                                              xfrm_base._ALGO_CBC_AES_256)
    714 
    715     if expected_netid is not None:
    716       s.sendto(net_test.UDP_PAYLOAD, (remoteaddr, 53))
    717       self._ExpectEspPacketOn(expected_netid, spi, 1, packetlen, tunsrc, tundst)
    718     else:
    719       with self.assertRaisesErrno(ENETUNREACH):
    720         s.sendto(net_test.UDP_PAYLOAD, (remoteaddr, 53))
    721 
    722   def testTunnelModeOutputMarkIPv4(self):
    723     for netid in self.NETIDS:
    724       tunsrc = self.MyAddress(4, netid)
    725       self._CheckTunnelModeOutputMark(4, tunsrc, netid, netid)
    726 
    727   def testTunnelModeOutputMarkIPv6(self):
    728     for netid in self.NETIDS:
    729       tunsrc = self.MyAddress(6, netid)
    730       self._CheckTunnelModeOutputMark(6, tunsrc, netid, netid)
    731 
    732   def testTunnelModeOutputNoMarkIPv4(self):
    733     tunsrc = self.MyAddress(4, self.RandomNetid())
    734     self._CheckTunnelModeOutputMark(4, tunsrc, 0, None)
    735 
    736   def testTunnelModeOutputNoMarkIPv6(self):
    737     tunsrc = self.MyAddress(6, self.RandomNetid())
    738     self._CheckTunnelModeOutputMark(6, tunsrc, 0, None)
    739 
    740   def testTunnelModeOutputInvalidMarkIPv4(self):
    741     tunsrc = self.MyAddress(4, self.RandomNetid())
    742     self._CheckTunnelModeOutputMark(4, tunsrc, 9999, None)
    743 
    744   def testTunnelModeOutputInvalidMarkIPv6(self):
    745     tunsrc = self.MyAddress(6, self.RandomNetid())
    746     self._CheckTunnelModeOutputMark(6, tunsrc, 9999, None)
    747 
    748   def testTunnelModeOutputMarkAttributes(self):
    749     mark = 1234567
    750     self.xfrm.AddSaInfo(TEST_ADDR1, TUNNEL_ENDPOINTS[6], 0x1234,
    751                         xfrm.XFRM_MODE_TUNNEL, 100, xfrm_base._ALGO_CBC_AES_256,
    752                         xfrm_base._ALGO_HMAC_SHA1, None, None, None, mark)
    753     dump = self.xfrm.DumpSaInfo()
    754     self.assertEquals(1, len(dump))
    755     sainfo, attributes = dump[0]
    756     self.assertEquals(mark, attributes["XFRMA_OUTPUT_MARK"])
    757 
    758   def testInvalidAlgorithms(self):
    759     key = "af442892cdcd0ef650e9c299f9a8436a".decode("hex")
    760     invalid_auth = (xfrm.XfrmAlgoAuth(("invalid(algo)", 128, 96)), key)
    761     invalid_crypt = (xfrm.XfrmAlgo(("invalid(algo)", 128)), key)
    762     with self.assertRaisesErrno(ENOSYS):
    763         self.xfrm.AddSaInfo(TEST_ADDR1, TEST_ADDR2, 0x1234,
    764             xfrm.XFRM_MODE_TRANSPORT, 0, xfrm_base._ALGO_CBC_AES_256,
    765             invalid_auth, None, None, None, 0)
    766     with self.assertRaisesErrno(ENOSYS):
    767         self.xfrm.AddSaInfo(TEST_ADDR1, TEST_ADDR2, 0x1234,
    768             xfrm.XFRM_MODE_TRANSPORT, 0, invalid_crypt,
    769             xfrm_base._ALGO_HMAC_SHA1, None, None, None, 0)
    770 
    771   def testUpdateSaAddMark(self):
    772     """Test that when an SA has no mark, it can be updated to add a mark."""
    773     for version in [4, 6]:
    774       spi = 0xABCD
    775       # Test that an SA created with ALLOCSPI can be updated with the mark.
    776       new_sa = self.xfrm.AllocSpi(net_test.GetWildcardAddress(version),
    777                                   IPPROTO_ESP, spi, spi)
    778       mark = xfrm.ExactMatchMark(0xf00d)
    779       self.xfrm.AddSaInfo(net_test.GetWildcardAddress(version),
    780                           net_test.GetWildcardAddress(version),
    781                           spi, xfrm.XFRM_MODE_TUNNEL, 0,
    782                           xfrm_base._ALGO_CBC_AES_256,
    783                           xfrm_base._ALGO_HMAC_SHA1,
    784                           None, None, mark, 0, is_update=True)
    785       dump = self.xfrm.DumpSaInfo()
    786       self.assertEquals(1, len(dump)) # check that update updated
    787       sainfo, attributes = dump[0]
    788       self.assertEquals(mark, attributes["XFRMA_MARK"])
    789       self.xfrm.DeleteSaInfo(net_test.GetWildcardAddress(version),
    790                              spi, IPPROTO_ESP, mark)
    791 
    792       # TODO: we might also need to update the mark for a VALID SA.
    793 
    794 if __name__ == "__main__":
    795   unittest.main()
    796