Home | History | Annotate | Download | only in test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2015 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 import time
     18 from socket import *  # pylint: disable=wildcard-import
     19 
     20 import net_test
     21 import multinetwork_base
     22 import packets
     23 
     24 # TCP states. See include/net/tcp_states.h.
     25 TCP_ESTABLISHED = 1
     26 TCP_SYN_SENT = 2
     27 TCP_SYN_RECV = 3
     28 TCP_FIN_WAIT1 = 4
     29 TCP_FIN_WAIT2 = 5
     30 TCP_TIME_WAIT = 6
     31 TCP_CLOSE = 7
     32 TCP_CLOSE_WAIT = 8
     33 TCP_LAST_ACK = 9
     34 TCP_LISTEN = 10
     35 TCP_CLOSING = 11
     36 TCP_NEW_SYN_RECV = 12
     37 
     38 TCP_NOT_YET_ACCEPTED = -1
     39 
     40 
     41 class TcpBaseTest(multinetwork_base.MultiNetworkBaseTest):
     42 
     43   def tearDown(self):
     44     if hasattr(self, "s"):
     45       self.s.close()
     46     super(TcpBaseTest, self).tearDown()
     47 
     48   def OpenListenSocket(self, version, netid):
     49     family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
     50     address = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
     51     s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
     52     # We haven't configured inbound iptables marking, so bind explicitly.
     53     self.SelectInterface(s, netid, "mark")
     54     self.port = net_test.BindRandomPort(version, s)
     55     return s
     56 
     57   def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
     58     pkt = super(TcpBaseTest, self)._ReceiveAndExpectResponse(netid, packet,
     59                                                              reply, msg)
     60     self.last_packet = pkt
     61     return pkt
     62 
     63   def ReceivePacketOn(self, netid, packet):
     64     super(TcpBaseTest, self).ReceivePacketOn(netid, packet)
     65     self.last_packet = packet
     66 
     67   def ReceiveRstPacketOn(self, netid):
     68     # self.last_packet is the last packet we received. Invert direction twice.
     69     _, ack = packets.ACK(self.version, self.myaddr, self.remoteaddr,
     70                          self.last_packet)
     71     desc, rst = packets.RST(self.version, self.remoteaddr, self.myaddr,
     72                             ack)
     73     super(TcpBaseTest, self).ReceivePacketOn(netid, rst)
     74 
     75   def RstPacket(self):
     76     return packets.RST(self.version, self.myaddr, self.remoteaddr,
     77                        self.last_packet)
     78 
     79   def FinPacket(self):
     80     return packets.FIN(self.version, self.myaddr, self.remoteaddr,
     81                        self.last_packet)
     82 
     83 
     84   def IncomingConnection(self, version, end_state, netid):
     85     self.s = self.OpenListenSocket(version, netid)
     86     self.end_state = end_state
     87 
     88     remoteaddr = self.remoteaddr = self.GetRemoteAddress(version)
     89     remotesockaddr = self.remotesockaddr = self.GetRemoteSocketAddress(version)
     90 
     91     myaddr = self.myaddr = self.MyAddress(version, netid)
     92     mysockaddr = self.mysockaddr = self.MySocketAddress(version, netid)
     93 
     94     if version == 5: version = 4
     95     self.version = version
     96 
     97     if end_state == TCP_LISTEN:
     98       return
     99 
    100     desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
    101     synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
    102     msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
    103     reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
    104     if end_state == TCP_SYN_RECV:
    105       return
    106 
    107     establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
    108     self.ReceivePacketOn(netid, establishing_ack)
    109 
    110     if end_state == TCP_NOT_YET_ACCEPTED:
    111       return
    112 
    113     self.accepted, _ = self.s.accept()
    114     net_test.DisableFinWait(self.accepted)
    115 
    116     if end_state == TCP_ESTABLISHED:
    117       return
    118 
    119     desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
    120                              payload=net_test.UDP_PAYLOAD)
    121     self.accepted.send(net_test.UDP_PAYLOAD)
    122     self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
    123 
    124     desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
    125     fin = packets._GetIpLayer(version)(str(fin))
    126     ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
    127     msg = "Received %s, expected to see reply %s" % (desc, ack_desc)
    128 
    129     # TODO: Why can't we use this?
    130     #   self._ReceiveAndExpectResponse(netid, fin, ack, msg)
    131     self.ReceivePacketOn(netid, fin)
    132     time.sleep(0.1)
    133     self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
    134     if end_state == TCP_CLOSE_WAIT:
    135       return
    136 
    137     raise ValueError("Invalid TCP state %d specified" % end_state)
    138