Home | History | Annotate | Download | only in cros
      1 # Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 import dbus
      6 import logging
      7 import socket
      8 import time
      9 import urllib2
     10 
     11 import common
     12 
     13 # Import 'flimflam_test_path' first in order to import 'routing'.
     14 # Disable warning about flimflam_test_path not being used since it is used
     15 # to find routing but not explicitly used as a module.
     16 # pylint: disable-msg=W0611
     17 import flimflam_test_path
     18 import routing
     19 
     20 from autotest_lib.client.bin import utils
     21 from autotest_lib.client.common_lib import error
     22 from autotest_lib.client.cros.cellular import mm
     23 
     24 
     25 def _Bug24628WorkaroundEnable(modem):
     26     """Enables a modem.  Try again if a SerialResponseTimeout is received."""
     27     # http://code.google.com/p/chromium-os/issues/detail?id=24628
     28     tries = 5
     29     while tries > 0:
     30         try:
     31             modem.Enable(True)
     32             return
     33         except dbus.exceptions.DBusException, e:
     34             logging.error('Enable failed: %s', e)
     35             tries -= 1
     36             if tries > 0:
     37                 logging.error('_Bug24628WorkaroundEnable:  sleeping')
     38                 time.sleep(6)
     39                 logging.error('_Bug24628WorkaroundEnable:  retrying')
     40             else:
     41                 raise
     42 
     43 
     44 # TODO(rochberg):  Move modem-specific functions to cellular/cell_utils
     45 def ResetAllModems(conn_mgr):
     46     """
     47     Disables/Enables cycle all modems to ensure valid starting state.
     48 
     49     @param conn_mgr: Connection manager (shill)
     50     """
     51     service = conn_mgr.FindCellularService()
     52     if not service:
     53         conn_mgr.EnableTechnology('cellular')
     54         service = conn_mgr.FindCellularService()
     55 
     56     logging.info('ResetAllModems: found service %s', service)
     57 
     58     try:
     59         if service:
     60             service.SetProperty('AutoConnect', False),
     61     except dbus.exceptions.DBusException, e:
     62         # The service object may disappear, we can safely ignore it.
     63         if e._dbus_error_name != 'org.freedesktop.DBus.Error.UnknownMethod':
     64             raise
     65 
     66     for manager, path in mm.EnumerateDevices():
     67         modem = manager.GetModem(path)
     68         version = modem.GetVersion()
     69         # Icera modems behave weirdly if we cancel the operation while the
     70         # modem is connecting or disconnecting. Work around the issue by waiting
     71         # until the connect/disconnect operation completes.
     72         # TODO(benchan): Remove this workaround once the issue is addressed
     73         # on the modem side.
     74         utils.poll_for_condition(
     75             lambda: not modem.IsConnectingOrDisconnecting(),
     76             exception=utils.TimeoutError('Timed out waiting for modem to ' +
     77                                          'finish connecting/disconnecting'),
     78             sleep_interval=1,
     79             timeout=30)
     80         modem.Enable(False)
     81         # Although we disable at the ModemManager level, we need to wait for
     82         # shill to process the disable to ensure the modem is in a stable state
     83         # before continuing else we may end up trying to enable a modem that
     84         # is still in the process of disabling.
     85         cm_device = conn_mgr.FindElementByPropertySubstring('Device',
     86                                                             'DBus.Object',
     87                                                             path)
     88         utils.poll_for_condition(
     89             lambda: not cm_device.GetProperties()['Powered'],
     90             exception=utils.TimeoutError(
     91                 'Timed out waiting for shill device disable'),
     92             sleep_interval=1,
     93             timeout=30)
     94         assert modem.IsDisabled()
     95 
     96         if 'Y3300XXKB1' in version:
     97             _Bug24628WorkaroundEnable(modem)
     98         else:
     99             modem.Enable(True)
    100             # Wait for shill to process the enable for the same reason as
    101             # above (during disable).
    102             utils.poll_for_condition(
    103                 lambda: cm_device.GetProperties()['Powered'],
    104                 exception=utils.TimeoutError(
    105                     'Timed out waiting for shill device enable'),
    106                 sleep_interval=1,
    107                 timeout=30)
    108             assert modem.IsEnabled()
    109 
    110 
    111 class IpTablesContext(object):
    112     """Context manager that manages iptables rules."""
    113     IPTABLES = '/sbin/iptables'
    114 
    115     def __init__(self, initial_allowed_host=None):
    116         self.initial_allowed_host = initial_allowed_host
    117         self.rules = []
    118 
    119     def _IpTables(self, command):
    120         # Run, log, return output
    121         return utils.system_output('%s %s' % (self.IPTABLES, command),
    122                                    retain_output=True)
    123 
    124     def _RemoveRule(self, rule):
    125         self._IpTables('-D ' + rule)
    126         self.rules.remove(rule)
    127 
    128     def AllowHost(self, host):
    129         """
    130         Allows the specified host through the firewall.
    131 
    132         @param host: Name of host to allow
    133         """
    134         for proto in ['tcp', 'udp']:
    135             rule = 'INPUT -s %s/32 -p %s -m %s -j ACCEPT' % (host, proto, proto)
    136             output = self._IpTables('-S INPUT')
    137             current = [x.rstrip() for x in output.splitlines()]
    138             logging.error('current: %s', current)
    139             if '-A ' + rule in current:
    140                 # Already have the rule
    141                 logging.info('Not adding redundant %s', rule)
    142                 continue
    143             self._IpTables('-A '+ rule)
    144             self.rules.append(rule)
    145 
    146     def _CleanupRules(self):
    147         for rule in self.rules:
    148             self._RemoveRule(rule)
    149 
    150     def __enter__(self):
    151         if self.initial_allowed_host:
    152             self.AllowHost(self.initial_allowed_host)
    153         return self
    154 
    155     def __exit__(self, exception, value, traceback):
    156         self._CleanupRules()
    157         return False
    158 
    159 
    160 def NameServersForService(conn_mgr, service):
    161     """
    162     Returns the list of name servers used by a connected service.
    163 
    164     @param conn_mgr: Connection manager (shill)
    165     @param service: Name of the connected service
    166     @return: List of name servers used by |service|
    167     """
    168     service_properties = service.GetProperties(utf8_strings=True)
    169     device_path = service_properties['Device']
    170     device = conn_mgr.GetObjectInterface('Device', device_path)
    171     if device is None:
    172         logging.error('No device for service %s', service)
    173         return []
    174 
    175     properties = device.GetProperties(utf8_strings=True)
    176 
    177     hosts = []
    178     for path in properties['IPConfigs']:
    179         ipconfig = conn_mgr.GetObjectInterface('IPConfig', path)
    180         ipconfig_properties = ipconfig.GetProperties(utf8_strings=True)
    181         hosts += ipconfig_properties['NameServers']
    182 
    183     logging.info('Name servers: %s', ', '.join(hosts))
    184 
    185     return hosts
    186 
    187 
    188 def CheckInterfaceForDestination(host, expected_interface):
    189     """
    190     Checks that routes for host go through a given interface.
    191 
    192     The concern here is that our network setup may have gone wrong
    193     and our test connections may go over some other network than
    194     the one we're trying to test.  So we take all the IP addresses
    195     for the supplied host and make sure they go through the given
    196     network interface.
    197 
    198     @param host: Destination host
    199     @param expected_interface: Expected interface name
    200     @raises: error.TestFail if the routes for the given host go through
    201             a different interface than the expected one.
    202 
    203     """
    204     # socket.getaddrinfo() returns a list of tuples in one of the following
    205     # forms:
    206     #
    207     # For IPv4 address:
    208     #   (family, type, proto, canonname, (address, port))
    209     # For IPv6 address:
    210     #   (family, type, proto, canonname, (address, port, flow_info, scope_id))
    211     #
    212     # As routing.NetworkRoutes currently supports only IPv4 routes / addresses,
    213     # we filter out any IPv6 address reported by socket.getaddrinfo().
    214     #
    215     # TODO(benchan): Fix this limitation after porting routes.NetworkRoutes to
    216     # support both IPv4 and IPv6 (crbug.com/742046).
    217     server_addresses = [record[4][0]
    218                         for record in socket.getaddrinfo(host, 80)
    219                         if len(record[4][0]) == 2]
    220 
    221     routes = routing.NetworkRoutes()
    222     for address in server_addresses:
    223         interface = routes.getRouteFor(address).interface
    224         logging.info('interface for %s: %s', address, interface)
    225         if interface != expected_interface:
    226             raise error.TestFail('Target server %s uses interface %s'
    227                                  '(%s expected).' %
    228                                  (address, interface, expected_interface))
    229 
    230 
    231 FETCH_URL_PATTERN_FOR_TEST = \
    232     'http://testing-chargen.appspot.com/download?size=%d'
    233 
    234 def FetchUrl(url_pattern, bytes_to_fetch=10, fetch_timeout=10):
    235     """
    236     Fetches a specified number of bytes from a URL.
    237 
    238     @param url_pattern: URL pattern for fetching a specified number of bytes.
    239             %d in the pattern is to be filled in with the number of bytes to
    240             fetch.
    241     @param bytes_to_fetch: Number of bytes to fetch.
    242     @param fetch_timeout: Number of seconds to wait for the fetch to complete
    243             before it times out.
    244     @return: The time in seconds spent for fetching the specified number of
    245             bytes.
    246     @raises: error.TestError if one of the following happens:
    247             - The fetch takes no time.
    248             - The number of bytes fetched differs from the specified
    249               number.
    250 
    251     """
    252     # Limit the amount of bytes to read at a time.
    253     _MAX_FETCH_READ_BYTES = 1024 * 1024
    254 
    255     url = url_pattern % bytes_to_fetch
    256     logging.info('FetchUrl %s', url)
    257     start_time = time.time()
    258     result = urllib2.urlopen(url, timeout=fetch_timeout)
    259     bytes_fetched = 0
    260     while bytes_fetched < bytes_to_fetch:
    261         bytes_left = bytes_to_fetch - bytes_fetched
    262         bytes_to_read = min(bytes_left, _MAX_FETCH_READ_BYTES)
    263         bytes_read = len(result.read(bytes_to_read))
    264         bytes_fetched += bytes_read
    265         if bytes_read != bytes_to_read:
    266             raise error.TestError('FetchUrl tried to read %d bytes, but got '
    267                                   '%d bytes instead.' %
    268                                   (bytes_to_read, bytes_read))
    269         fetch_time = time.time() - start_time
    270         if fetch_time > fetch_timeout:
    271             raise error.TestError('FetchUrl exceeded timeout.')
    272 
    273     return fetch_time
    274