# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import dbus
import logging
import socket
import time
import urllib2

import common

from autotest_lib.client.bin import utils
from autotest_lib.client.common_lib import error
from autotest_lib.client.cros import routing


class IpTablesContext(object):
    """
    Context manager that manages iptables rules.

    On entry, optionally opens the firewall for |initial_allowed_host|.
    On exit, removes every rule that was added through this object, so
    rules that already existed before the context was entered are left
    untouched.
    """
    IPTABLES = '/sbin/iptables'

    def __init__(self, initial_allowed_host=None):
        """
        @param initial_allowed_host: Host to allow through the firewall as
                soon as the context is entered, or None to add no rule.
        """
        self.initial_allowed_host = initial_allowed_host
        # Rule specifications (without the leading -A/-D) added by this
        # object and therefore owned by it for cleanup purposes.
        self.rules = []

    def _IpTables(self, command):
        """
        Runs iptables with the given arguments.

        @param command: Argument string appended to the iptables binary path.
        @return: Whatever utils.system_output returns for the command.
        """
        # Run, log, return output
        return utils.system_output('%s %s' % (self.IPTABLES, command),
                                   retain_output=True)

    def _RemoveRule(self, rule):
        """
        Deletes one rule from iptables and stops tracking it.

        @param rule: Rule specification previously recorded in self.rules.
        """
        self._IpTables('-D ' + rule)
        self.rules.remove(rule)

    def AllowHost(self, host):
        """
        Allows the specified host through the firewall.

        One INPUT rule is added for TCP and one for UDP.  A rule that is
        already present in the INPUT chain is neither added again nor
        tracked, so it will not be removed on exit.

        @param host: Name of host to allow
        """
        for proto in ['tcp', 'udp']:
            rule = 'INPUT -s %s/32 -p %s -m %s -j ACCEPT' % (host, proto, proto)
            output = self._IpTables('-S INPUT')
            current = [x.rstrip() for x in output.splitlines()]
            # Diagnostic only; this is not an error condition.
            logging.debug('current: %s', current)
            if '-A ' + rule in current:
                # Already have the rule
                logging.info('Not adding redundant %s', rule)
                continue
            self._IpTables('-A ' + rule)
            self.rules.append(rule)

    def _CleanupRules(self):
        """Removes every rule that was added through this object."""
        # _RemoveRule mutates self.rules, so walk a snapshot; iterating the
        # live list would skip every other rule and leave it installed.
        for rule in list(self.rules):
            self._RemoveRule(rule)

    def __enter__(self):
        if self.initial_allowed_host:
            self.AllowHost(self.initial_allowed_host)
        return self

    def __exit__(self, exception, value, traceback):
        self._CleanupRules()
        # Returning False lets any exception raised in the with-body
        # propagate to the caller.
        return False


def NameServersForService(conn_mgr, service):
    """
    Returns the list of name servers used by a connected service.

    @param conn_mgr: Connection manager (shill)
    @param service: The connected service object; it must provide
            GetProperties() and expose a 'Device' property.
    @return: List of name servers used by |service|, or an empty list if
            the device backing the service cannot be found.
    """
    service_properties = service.GetProperties(utf8_strings=True)
    device_path = service_properties['Device']
    device = conn_mgr.GetObjectInterface('Device', device_path)
    if device is None:
        logging.error('No device for service %s', service)
        return []

    properties = device.GetProperties(utf8_strings=True)

    # Collect the name servers of every IP configuration on the device.
    hosts = []
    for path in properties['IPConfigs']:
        ipconfig = conn_mgr.GetObjectInterface('IPConfig', path)
        ipconfig_properties = ipconfig.GetProperties(utf8_strings=True)
        hosts += ipconfig_properties['NameServers']

    logging.info('Name servers: %s', ', '.join(hosts))

    return hosts


def CheckInterfaceForDestination(host, expected_interface):
    """
    Checks that routes for host go through a given interface.

    The concern here is that our network setup may have gone wrong
    and our test connections may go over some other network than
    the one we're trying to test.  So we take all the IP addresses
    for the supplied host and make sure they go through the given
    network interface.

    @param host: Destination host
    @param expected_interface: Expected interface name
    @raises: error.TestFail if the routes for the given host go through
            a different interface than the expected one, or if no route
            is found for any of the host's addresses.

    """
    # addrinfo records: (family, type, proto, canonname, (addr, port)).
    # getaddrinfo may return the same address more than once (e.g. once per
    # socket type), so keep only the first occurrence of each address.
    server_addresses = []
    for record in socket.getaddrinfo(host, 80):
        address = record[4][0]
        if address not in server_addresses:
            server_addresses.append(address)

    route_found = False
    routes = routing.NetworkRoutes()
    for address in server_addresses:
        route = routes.getRouteFor(address)
        if not route:
            continue

        route_found = True

        interface = route.interface
        logging.info('interface for %s: %s', address, interface)
        if interface != expected_interface:
            raise error.TestFail('Target server %s uses interface %s '
                                 '(%s expected).' %
                                 (address, interface, expected_interface))

    if not route_found:
        raise error.TestFail('No route found for "%s".' % host)


FETCH_URL_PATTERN_FOR_TEST = (
    'http://testing-chargen.appspot.com/download?size=%d')


def FetchUrl(url_pattern, bytes_to_fetch=10, fetch_timeout=10):
    """
    Fetches a specified number of bytes from a URL.

    @param url_pattern: URL pattern for fetching a specified number of bytes.
            %d in the pattern is to be filled in with the number of bytes to
            fetch.
    @param bytes_to_fetch: Number of bytes to fetch.
    @param fetch_timeout: Number of seconds to wait for the fetch to complete
            before it times out.
    @return: The time in seconds spent for fetching the specified number of
            bytes.
    @raises: error.TestError if one of the following happens:
            - A read returns fewer bytes than were requested.
            - The whole fetch takes longer than |fetch_timeout| seconds.
            Errors raised by urllib2.urlopen itself are not caught here.

    """
    # Limit the amount of bytes to read at a time.
    _MAX_FETCH_READ_BYTES = 1024 * 1024

    url = url_pattern % bytes_to_fetch
    logging.info('FetchUrl %s', url)
    start_time = time.time()
    result = urllib2.urlopen(url, timeout=fetch_timeout)
    try:
        bytes_fetched = 0
        while bytes_fetched < bytes_to_fetch:
            bytes_left = bytes_to_fetch - bytes_fetched
            bytes_to_read = min(bytes_left, _MAX_FETCH_READ_BYTES)
            bytes_read = len(result.read(bytes_to_read))
            bytes_fetched += bytes_read
            if bytes_read != bytes_to_read:
                raise error.TestError('FetchUrl tried to read %d bytes, but '
                                      'got %d bytes instead.' %
                                      (bytes_to_read, bytes_read))
        fetch_time = time.time() - start_time
    finally:
        # Release the connection on both the success and the error path.
        result.close()

    # The timeout passed to urlopen applies to individual blocking socket
    # operations, so the overall duration is checked separately here.
    if fetch_time > fetch_timeout:
        raise error.TestError('FetchUrl exceeded timeout.')

    return fetch_time