# Home | History | Annotate | Download | only in webpagereplay
      1 #!/usr/bin/env python
      2 # Copyright 2012 Google Inc. All Rights Reserved.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #      http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 
     16 """Retrieve web resources over http."""
     17 
     18 import copy
     19 import httplib
     20 import logging
     21 import random
     22 import ssl
     23 import StringIO
     24 
     25 import httparchive
     26 import platformsettings
     27 import script_injector
     28 
     29 
# PIL isn't always available, but we still want to be able to run without
# the image scrambling functionality in this case.
try:
  import Image
except ImportError:
  Image = None  # _ScrambleImages asserts on this before attempting to scramble.

# Timer function provided by platformsettings; used below to measure the
# connect, header, and per-chunk delays that get stored in the archive.
TIMER = platformsettings.timer
     38 
     39 
     40 class HttpClientException(Exception):
     41   """Base class for all exceptions in httpclient."""
     42   pass
     43 
     44 
     45 def _InjectScripts(response, inject_script):
     46   """Injects |inject_script| immediately after <head> or <html>.
     47 
     48   Copies |response| if it is modified.
     49 
     50   Args:
     51     response: an ArchivedHttpResponse
     52     inject_script: JavaScript string (e.g. "Math.random = function(){...}")
     53   Returns:
     54     an ArchivedHttpResponse
     55   """
     56   if type(response) == tuple:
     57     logging.warn('tuple response: %s', response)
     58   content_type = response.get_header('content-type')
     59   if content_type and content_type.startswith('text/html'):
     60     text = response.get_data_as_text()
     61     text, already_injected = script_injector.InjectScript(
     62         text, 'text/html', inject_script)
     63     if not already_injected:
     64       response = copy.deepcopy(response)
     65       response.set_data(text)
     66   return response
     67 
     68 
     69 def _ScrambleImages(response):
     70   """If the |response| is an image, attempt to scramble it.
     71 
     72   Copies |response| if it is modified.
     73 
     74   Args:
     75     response: an ArchivedHttpResponse
     76   Returns:
     77     an ArchivedHttpResponse
     78   """
     79 
     80   assert Image, '--scramble_images requires the PIL module to be installed.'
     81 
     82   content_type = response.get_header('content-type')
     83   if content_type and content_type.startswith('image/'):
     84     try:
     85       image_data = response.response_data[0]
     86       image_data.decode(encoding='base64')
     87       im = Image.open(StringIO.StringIO(image_data))
     88 
     89       pixel_data = list(im.getdata())
     90       random.shuffle(pixel_data)
     91 
     92       scrambled_image = im.copy()
     93       scrambled_image.putdata(pixel_data)
     94 
     95       output_image_io = StringIO.StringIO()
     96       scrambled_image.save(output_image_io, im.format)
     97       output_image_data = output_image_io.getvalue()
     98       output_image_data.encode(encoding='base64')
     99 
    100       response = copy.deepcopy(response)
    101       response.set_data(output_image_data)
    102     except Exception:
    103       pass
    104 
    105   return response
    106 
    107 
    108 class DetailedHTTPResponse(httplib.HTTPResponse):
    109   """Preserve details relevant to replaying responses.
    110 
    111   WARNING: This code uses attributes and methods of HTTPResponse
    112   that are not part of the public interface.
    113   """
    114 
    115   def read_chunks(self):
    116     """Return the response body content and timing data.
    117 
    118     The returned chunks have the chunk size and CRLFs stripped off.
    119     If the response was compressed, the returned data is still compressed.
    120 
    121     Returns:
    122       (chunks, delays)
    123         chunks:
    124           [response_body]                  # non-chunked responses
    125           [chunk_1, chunk_2, ...]          # chunked responses
    126         delays:
    127           [0]                              # non-chunked responses
    128           [chunk_1_first_byte_delay, ...]  # chunked responses
    129 
    130       The delay for the first body item should be recorded by the caller.
    131     """
    132     buf = []
    133     chunks = []
    134     delays = []
    135     if not self.chunked:
    136       chunks.append(self.read())
    137       delays.append(0)
    138     else:
    139       start = TIMER()
    140       try:
    141         while True:
    142           line = self.fp.readline()
    143           chunk_size = self._read_chunk_size(line)
    144           if chunk_size is None:
    145             raise httplib.IncompleteRead(''.join(chunks))
    146           if chunk_size == 0:
    147             break
    148           delays.append(TIMER() - start)
    149           chunks.append(self._safe_read(chunk_size))
    150           self._safe_read(2)  # skip the CRLF at the end of the chunk
    151           start = TIMER()
    152 
    153         # Ignore any trailers.
    154         while True:
    155           line = self.fp.readline()
    156           if not line or line == '\r\n':
    157             break
    158       finally:
    159         self.close()
    160     return chunks, delays
    161 
    162   @classmethod
    163   def _read_chunk_size(cls, line):
    164     chunk_extensions_pos = line.find(';')
    165     if chunk_extensions_pos != -1:
    166       line = line[:chunk_extensions_pos]  # strip chunk-extensions
    167     try:
    168       chunk_size = int(line, 16)
    169     except ValueError:
    170       return None
    171     return chunk_size
    172 
    173 
    174 class DetailedHTTPConnection(httplib.HTTPConnection):
    175   """Preserve details relevant to replaying connections."""
    176   response_class = DetailedHTTPResponse
    177 
    178 
    179 class DetailedHTTPSResponse(DetailedHTTPResponse):
    180   """Preserve details relevant to replaying SSL responses."""
    181   pass
    182 
    183 
    184 class DetailedHTTPSConnection(httplib.HTTPSConnection):
    185   """Preserve details relevant to replaying SSL connections."""
    186   response_class = DetailedHTTPSResponse
    187 
    188   def __init__(self, host, port):
    189     # https://www.python.org/dev/peps/pep-0476/#opting-out
    190     if hasattr(ssl, '_create_unverified_context'):
    191       httplib.HTTPSConnection.__init__(
    192           self, host=host, port=port, context=ssl._create_unverified_context())
    193     else:
    194       httplib.HTTPSConnection.__init__(self, host=host, port=port)
    195 
    196 
    197 class RealHttpFetch(object):
    198 
    199   def __init__(self, real_dns_lookup):
    200     """Initialize RealHttpFetch.
    201 
    202     Args:
    203       real_dns_lookup: a function that resolves a host to an IP.
    204     """
    205     self._real_dns_lookup = real_dns_lookup
    206 
    207   @staticmethod
    208   def _GetHeaderNameValue(header):
    209     """Parse the header line and return a name/value tuple.
    210 
    211     Args:
    212       header: a string for a header such as "Content-Length: 314".
    213     Returns:
    214       A tuple (header_name, header_value) on success or None if the header
    215       is not in expected format. header_name is in lowercase.
    216     """
    217     i = header.find(':')
    218     if i > 0:
    219       return (header[:i].lower(), header[i+1:].strip())
    220     return None
    221 
    222   @staticmethod
    223   def _ToTuples(headers):
    224     """Parse headers and save them to a list of tuples.
    225 
    226     This method takes HttpResponse.msg.headers as input and convert it
    227     to a list of (header_name, header_value) tuples.
    228     HttpResponse.msg.headers is a list of strings where each string
    229     represents either a header or a continuation line of a header.
    230     1. a normal header consists of two parts which are separated by colon :
    231        "header_name:header_value..."
    232     2. a continuation line is a string starting with whitespace
    233        "[whitespace]continued_header_value..."
    234     If a header is not in good shape or an unexpected continuation line is
    235     seen, it will be ignored.
    236 
    237     Should avoid using response.getheaders() directly
    238     because response.getheaders() can't handle multiple headers
    239     with the same name properly. Instead, parse the
    240     response.msg.headers using this method to get all headers.
    241 
    242     Args:
    243       headers: an instance of HttpResponse.msg.headers.
    244     Returns:
    245       A list of tuples which looks like:
    246       [(header_name, header_value), (header_name2, header_value2)...]
    247     """
    248     all_headers = []
    249     for line in headers:
    250       if line[0] in '\t ':
    251         if not all_headers:
    252           logging.warning(
    253               'Unexpected response header continuation line [%s]', line)
    254           continue
    255         name, value = all_headers.pop()
    256         value += '\n ' + line.strip()
    257       else:
    258         name_value = RealHttpFetch._GetHeaderNameValue(line)
    259         if not name_value:
    260           logging.warning(
    261               'Response header in wrong format [%s]', line)
    262           continue
    263         name, value = name_value  # pylint: disable=unpacking-non-sequence
    264       all_headers.append((name, value))
    265     return all_headers
    266 
    267   @staticmethod
    268   def _get_request_host_port(request):
    269     host_parts = request.host.split(':')
    270     host = host_parts[0]
    271     port = int(host_parts[1]) if len(host_parts) == 2 else None
    272     return host, port
    273 
    274   @staticmethod
    275   def _get_system_proxy(is_ssl):
    276     return platformsettings.get_system_proxy(is_ssl)
    277 
    278   def _get_connection(self, request_host, request_port, is_ssl):
    279     """Return a detailed connection object for host/port pair.
    280 
    281     If a system proxy is defined (see platformsettings.py), it will be used.
    282 
    283     Args:
    284       request_host: a host string (e.g. "www.example.com").
    285       request_port: a port integer (e.g. 8080) or None (for the default port).
    286       is_ssl: True if HTTPS connection is needed.
    287     Returns:
    288       A DetailedHTTPSConnection or DetailedHTTPConnection instance.
    289     """
    290     connection_host = request_host
    291     connection_port = request_port
    292     system_proxy = self._get_system_proxy(is_ssl)
    293     if system_proxy:
    294       connection_host = system_proxy.host
    295       connection_port = system_proxy.port
    296 
    297     # Use an IP address because WPR may override DNS settings.
    298     connection_ip = self._real_dns_lookup(connection_host)
    299     if not connection_ip:
    300       logging.critical('Unable to find host ip for name: %s', connection_host)
    301       return None
    302 
    303     if is_ssl:
    304       connection = DetailedHTTPSConnection(connection_ip, connection_port)
    305       if system_proxy:
    306         connection.set_tunnel(request_host, request_port)
    307     else:
    308       connection = DetailedHTTPConnection(connection_ip, connection_port)
    309     return connection
    310 
    311   def __call__(self, request):
    312     """Fetch an HTTP request.
    313 
    314     Args:
    315       request: an ArchivedHttpRequest
    316     Returns:
    317       an ArchivedHttpResponse
    318     """
    319     logging.debug('RealHttpFetch: %s %s', request.host, request.full_path)
    320     request_host, request_port = self._get_request_host_port(request)
    321     retries = 3
    322     while True:
    323       try:
    324         connection = self._get_connection(
    325             request_host, request_port, request.is_ssl)
    326         connect_start = TIMER()
    327         connection.connect()
    328         connect_delay = int((TIMER() - connect_start) * 1000)
    329         start = TIMER()
    330         connection.request(
    331             request.command,
    332             request.full_path,
    333             request.request_body,
    334             request.headers)
    335         response = connection.getresponse()
    336         headers_delay = int((TIMER() - start) * 1000)
    337 
    338         chunks, chunk_delays = response.read_chunks()
    339         delays = {
    340             'connect': connect_delay,
    341             'headers': headers_delay,
    342             'data': chunk_delays
    343             }
    344         archived_http_response = httparchive.ArchivedHttpResponse(
    345             response.version,
    346             response.status,
    347             response.reason,
    348             RealHttpFetch._ToTuples(response.msg.headers),
    349             chunks,
    350             delays)
    351         return archived_http_response
    352       except Exception, e:
    353         if retries:
    354           retries -= 1
    355           logging.warning('Retrying fetch %s: %s', request, repr(e))
    356           continue
    357         logging.critical('Could not fetch %s: %s', request, repr(e))
    358         return None
    359 
    360 
    361 class RecordHttpArchiveFetch(object):
    362   """Make real HTTP fetches and save responses in the given HttpArchive."""
    363 
    364   def __init__(self, http_archive, real_dns_lookup, inject_script):
    365     """Initialize RecordHttpArchiveFetch.
    366 
    367     Args:
    368       http_archive: an instance of a HttpArchive
    369       real_dns_lookup: a function that resolves a host to an IP.
    370       inject_script: script string to inject in all pages
    371     """
    372     self.http_archive = http_archive
    373     self.real_http_fetch = RealHttpFetch(real_dns_lookup)
    374     self.inject_script = inject_script
    375 
    376   def __call__(self, request):
    377     """Fetch the request and return the response.
    378 
    379     Args:
    380       request: an ArchivedHttpRequest.
    381     Returns:
    382       an ArchivedHttpResponse
    383     """
    384     # If request is already in the archive, return the archived response.
    385     if request in self.http_archive:
    386       logging.debug('Repeated request found: %s', request)
    387       response = self.http_archive[request]
    388     else:
    389       response = self.real_http_fetch(request)
    390       if response is None:
    391         return None
    392       self.http_archive[request] = response
    393     if self.inject_script:
    394       response = _InjectScripts(response, self.inject_script)
    395     logging.debug('Recorded: %s', request)
    396     return response
    397 
    398 
    399 class ReplayHttpArchiveFetch(object):
    400   """Serve responses from the given HttpArchive."""
    401 
    402   def __init__(self, http_archive, real_dns_lookup, inject_script,
    403                use_diff_on_unknown_requests=False,
    404                use_closest_match=False, scramble_images=False):
    405     """Initialize ReplayHttpArchiveFetch.
    406 
    407     Args:
    408       http_archive: an instance of a HttpArchive
    409       real_dns_lookup: a function that resolves a host to an IP.
    410       inject_script: script string to inject in all pages
    411       use_diff_on_unknown_requests: If True, log unknown requests
    412         with a diff to requests that look similar.
    413       use_closest_match: If True, on replay mode, serve the closest match
    414         in the archive instead of giving a 404.
    415     """
    416     self.http_archive = http_archive
    417     self.inject_script = inject_script
    418     self.use_diff_on_unknown_requests = use_diff_on_unknown_requests
    419     self.use_closest_match = use_closest_match
    420     self.scramble_images = scramble_images
    421     self.real_http_fetch = RealHttpFetch(real_dns_lookup)
    422 
    423   def __call__(self, request):
    424     """Fetch the request and return the response.
    425 
    426     Args:
    427       request: an instance of an ArchivedHttpRequest.
    428     Returns:
    429       Instance of ArchivedHttpResponse (if found) or None
    430     """
    431     if request.host.startswith('127.0.0.1:'):
    432       return self.real_http_fetch(request)
    433 
    434     response = self.http_archive.get(request)
    435 
    436     if self.use_closest_match and not response:
    437       closest_request = self.http_archive.find_closest_request(
    438           request, use_path=True)
    439       if closest_request:
    440         response = self.http_archive.get(closest_request)
    441         if response:
    442           logging.info('Request not found: %s\nUsing closest match: %s',
    443                        request, closest_request)
    444 
    445     if not response:
    446       reason = str(request)
    447       if self.use_diff_on_unknown_requests:
    448         diff = self.http_archive.diff(request)
    449         if diff:
    450           reason += (
    451               "\nNearest request diff "
    452               "('-' for archived request, '+' for current request):\n%s" % diff)
    453       logging.warning('Could not replay: %s', reason)
    454     else:
    455       if self.inject_script:
    456         response = _InjectScripts(response, self.inject_script)
    457       if self.scramble_images:
    458         response = _ScrambleImages(response)
    459     return response
    460 
    461 
    462 class ControllableHttpArchiveFetch(object):
    463   """Controllable fetch function that can swap between record and replay."""
    464 
    465   def __init__(self, http_archive, real_dns_lookup,
    466                inject_script, use_diff_on_unknown_requests,
    467                use_record_mode, use_closest_match, scramble_images):
    468     """Initialize HttpArchiveFetch.
    469 
    470     Args:
    471       http_archive: an instance of a HttpArchive
    472       real_dns_lookup: a function that resolves a host to an IP.
    473       inject_script: script string to inject in all pages.
    474       use_diff_on_unknown_requests: If True, log unknown requests
    475         with a diff to requests that look similar.
    476       use_record_mode: If True, start in server in record mode.
    477       use_closest_match: If True, on replay mode, serve the closest match
    478         in the archive instead of giving a 404.
    479     """
    480     self.http_archive = http_archive
    481     self.record_fetch = RecordHttpArchiveFetch(
    482         http_archive, real_dns_lookup, inject_script)
    483     self.replay_fetch = ReplayHttpArchiveFetch(
    484         http_archive, real_dns_lookup, inject_script,
    485         use_diff_on_unknown_requests, use_closest_match, scramble_images)
    486     if use_record_mode:
    487       self.SetRecordMode()
    488     else:
    489       self.SetReplayMode()
    490 
    491   def SetRecordMode(self):
    492     self.fetch = self.record_fetch
    493     self.is_record_mode = True
    494 
    495   def SetReplayMode(self):
    496     self.fetch = self.replay_fetch
    497     self.is_record_mode = False
    498 
    499   def __call__(self, *args, **kwargs):
    500     """Forward calls to Replay/Record fetch functions depending on mode."""
    501     return self.fetch(*args, **kwargs)
    502