Home | History | Annotate | Download | only in cros
      1 # Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 import dbus
      6 import logging
      7 import socket
      8 import time
      9 import urllib2
     10 
     11 import common
     12 
     13 from autotest_lib.client.bin import utils
     14 from autotest_lib.client.common_lib import error
     15 from autotest_lib.client.cros import routing
     16 
     17 
     18 class IpTablesContext(object):
     19     """Context manager that manages iptables rules."""
     20     IPTABLES = '/sbin/iptables'
     21 
     22     def __init__(self, initial_allowed_host=None):
     23         self.initial_allowed_host = initial_allowed_host
     24         self.rules = []
     25 
     26     def _IpTables(self, command):
     27         # Run, log, return output
     28         return utils.system_output('%s %s' % (self.IPTABLES, command),
     29                                    retain_output=True)
     30 
     31     def _RemoveRule(self, rule):
     32         self._IpTables('-D ' + rule)
     33         self.rules.remove(rule)
     34 
     35     def AllowHost(self, host):
     36         """
     37         Allows the specified host through the firewall.
     38 
     39         @param host: Name of host to allow
     40         """
     41         for proto in ['tcp', 'udp']:
     42             rule = 'INPUT -s %s/32 -p %s -m %s -j ACCEPT' % (host, proto, proto)
     43             output = self._IpTables('-S INPUT')
     44             current = [x.rstrip() for x in output.splitlines()]
     45             logging.error('current: %s', current)
     46             if '-A ' + rule in current:
     47                 # Already have the rule
     48                 logging.info('Not adding redundant %s', rule)
     49                 continue
     50             self._IpTables('-A '+ rule)
     51             self.rules.append(rule)
     52 
     53     def _CleanupRules(self):
     54         for rule in self.rules:
     55             self._RemoveRule(rule)
     56 
     57     def __enter__(self):
     58         if self.initial_allowed_host:
     59             self.AllowHost(self.initial_allowed_host)
     60         return self
     61 
     62     def __exit__(self, exception, value, traceback):
     63         self._CleanupRules()
     64         return False
     65 
     66 
     67 def NameServersForService(conn_mgr, service):
     68     """
     69     Returns the list of name servers used by a connected service.
     70 
     71     @param conn_mgr: Connection manager (shill)
     72     @param service: Name of the connected service
     73     @return: List of name servers used by |service|
     74     """
     75     service_properties = service.GetProperties(utf8_strings=True)
     76     device_path = service_properties['Device']
     77     device = conn_mgr.GetObjectInterface('Device', device_path)
     78     if device is None:
     79         logging.error('No device for service %s', service)
     80         return []
     81 
     82     properties = device.GetProperties(utf8_strings=True)
     83 
     84     hosts = []
     85     for path in properties['IPConfigs']:
     86         ipconfig = conn_mgr.GetObjectInterface('IPConfig', path)
     87         ipconfig_properties = ipconfig.GetProperties(utf8_strings=True)
     88         hosts += ipconfig_properties['NameServers']
     89 
     90     logging.info('Name servers: %s', ', '.join(hosts))
     91 
     92     return hosts
     93 
     94 
     95 def CheckInterfaceForDestination(host, expected_interface):
     96     """
     97     Checks that routes for host go through a given interface.
     98 
     99     The concern here is that our network setup may have gone wrong
    100     and our test connections may go over some other network than
    101     the one we're trying to test.  So we take all the IP addresses
    102     for the supplied host and make sure they go through the given
    103     network interface.
    104 
    105     @param host: Destination host
    106     @param expected_interface: Expected interface name
    107     @raises: error.TestFail if the routes for the given host go through
    108             a different interface than the expected one.
    109 
    110     """
    111     # addrinfo records: (family, type, proto, canonname, (addr, port))
    112     server_addresses = [record[4][0]
    113                         for record in socket.getaddrinfo(host, 80)]
    114 
    115     route_found = False
    116     routes = routing.NetworkRoutes()
    117     for address in server_addresses:
    118         route = routes.getRouteFor(address)
    119         if not route:
    120             continue
    121 
    122         route_found = True
    123 
    124         interface = route.interface
    125         logging.info('interface for %s: %s', address, interface)
    126         if interface != expected_interface:
    127             raise error.TestFail('Target server %s uses interface %s'
    128                                  '(%s expected).' %
    129                                  (address, interface, expected_interface))
    130 
    131     if not route_found:
    132         raise error.TestFail('No route found for "%s".' % host)
    133 
    134 FETCH_URL_PATTERN_FOR_TEST = \
    135     'http://testing-chargen.appspot.com/download?size=%d'
    136 
    137 def FetchUrl(url_pattern, bytes_to_fetch=10, fetch_timeout=10):
    138     """
    139     Fetches a specified number of bytes from a URL.
    140 
    141     @param url_pattern: URL pattern for fetching a specified number of bytes.
    142             %d in the pattern is to be filled in with the number of bytes to
    143             fetch.
    144     @param bytes_to_fetch: Number of bytes to fetch.
    145     @param fetch_timeout: Number of seconds to wait for the fetch to complete
    146             before it times out.
    147     @return: The time in seconds spent for fetching the specified number of
    148             bytes.
    149     @raises: error.TestError if one of the following happens:
    150             - The fetch takes no time.
    151             - The number of bytes fetched differs from the specified
    152               number.
    153 
    154     """
    155     # Limit the amount of bytes to read at a time.
    156     _MAX_FETCH_READ_BYTES = 1024 * 1024
    157 
    158     url = url_pattern % bytes_to_fetch
    159     logging.info('FetchUrl %s', url)
    160     start_time = time.time()
    161     result = urllib2.urlopen(url, timeout=fetch_timeout)
    162     bytes_fetched = 0
    163     while bytes_fetched < bytes_to_fetch:
    164         bytes_left = bytes_to_fetch - bytes_fetched
    165         bytes_to_read = min(bytes_left, _MAX_FETCH_READ_BYTES)
    166         bytes_read = len(result.read(bytes_to_read))
    167         bytes_fetched += bytes_read
    168         if bytes_read != bytes_to_read:
    169             raise error.TestError('FetchUrl tried to read %d bytes, but got '
    170                                   '%d bytes instead.' %
    171                                   (bytes_to_read, bytes_read))
    172         fetch_time = time.time() - start_time
    173         if fetch_time > fetch_timeout:
    174             raise error.TestError('FetchUrl exceeded timeout.')
    175 
    176     return fetch_time
    177