Home | History | Annotate | Download | only in web-page-replay
      1 #!/usr/bin/env python
      2 # Copyright 2012 Google Inc. All Rights Reserved.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #      http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 
     16 """Retrieve web resources over http."""
     17 
     18 import copy
     19 import datetime
     20 import httplib
     21 import logging
     22 import random
     23 import ssl
     24 import StringIO
     25 
     26 import httparchive
     27 import platformsettings
     28 import script_injector
     29 
     30 
     31 # PIL isn't always available, but we still want to be able to run without
     32 # the image scrambling functionality in this case.
     33 try:
     34   import Image
     35 except ImportError:
     36   Image = None
     37 
# Module-level alias for the platform's most precise wall-clock timer;
# used below to measure connect/header/chunk delays for the archive.
TIMER = platformsettings.timer
     39 
     40 
     41 class HttpClientException(Exception):
     42   """Base class for all exceptions in httpclient."""
     43   pass
     44 
     45 
     46 def _InjectScripts(response, injector):
     47   """Injects script generated by |injector| immediately after <head> or <html>.
     48 
     49   Copies |response| if it is modified.
     50 
     51   Args:
     52     response: an ArchivedHttpResponse
     53     injector: function which generates JavaScript string
     54       based on recording time (e.g. "Math.random = function(){...}")
     55   Returns:
     56     an ArchivedHttpResponse
     57   """
     58   if type(response) == tuple:
     59     logging.warn('tuple response: %s', response)
     60   content_type = response.get_header('content-type')
     61   if content_type and content_type.startswith('text/html'):
     62     text_chunks = response.get_data_as_chunks()
     63     text_chunks, just_injected = script_injector.InjectScript(
     64         text_chunks, 'text/html', injector(response.request_time))
     65     if just_injected:
     66       response = copy.deepcopy(response)
     67       response.set_data_from_chunks(text_chunks)
     68   return response
     69 
     70 
     71 def _ScrambleImages(response):
     72   """If the |response| is an image, attempt to scramble it.
     73 
     74   Copies |response| if it is modified.
     75 
     76   Args:
     77     response: an ArchivedHttpResponse
     78   Returns:
     79     an ArchivedHttpResponse
     80   """
     81 
     82   assert Image, '--scramble_images requires the PIL module to be installed.'
     83 
     84   content_type = response.get_header('content-type')
     85   if content_type and content_type.startswith('image/'):
     86     try:
     87       image_data = response.response_data[0]
     88       image_data.decode(encoding='base64')
     89       im = Image.open(StringIO.StringIO(image_data))
     90 
     91       pixel_data = list(im.getdata())
     92       random.shuffle(pixel_data)
     93 
     94       scrambled_image = im.copy()
     95       scrambled_image.putdata(pixel_data)
     96 
     97       output_image_io = StringIO.StringIO()
     98       scrambled_image.save(output_image_io, im.format)
     99       output_image_data = output_image_io.getvalue()
    100       output_image_data.encode(encoding='base64')
    101 
    102       response = copy.deepcopy(response)
    103       response.set_data(output_image_data)
    104     except Exception:
    105       pass
    106 
    107   return response
    108 
    109 
class DetailedHTTPResponse(httplib.HTTPResponse):
  """Preserve details relevant to replaying responses.

  WARNING: This code uses attributes and methods of HTTPResponse
  that are not part of the public interface.
  """

  def read_chunks(self):
    """Return the response body content and timing data.

    The returned chunks have the chunk size and CRLFs stripped off.
    If the response was compressed, the returned data is still compressed.

    Returns:
      (chunks, delays)
        chunks:
          [response_body]                  # non-chunked responses
          [chunk_1, chunk_2, ...]          # chunked responses
        delays:
          [0]                              # non-chunked responses
          [chunk_1_first_byte_delay, ...]  # chunked responses

      The delay for the first body item should be recorded by the caller.
    """
    buf = []  # NOTE(review): never used below; looks like leftover scaffolding.
    chunks = []
    delays = []
    if not self.chunked:
      # Non-chunked body: return it as a single chunk with zero extra delay.
      chunks.append(self.read())
      delays.append(0)
    else:
      start = TIMER()
      try:
        while True:
          # Each chunk begins with a "<hex-size>[;extensions]\r\n" line.
          line = self.fp.readline()
          chunk_size = self._read_chunk_size(line)
          if chunk_size is None:
            # Malformed size line: surface what we got so far.
            raise httplib.IncompleteRead(''.join(chunks))
          if chunk_size == 0:
            break  # A zero-size chunk terminates the body.
          # Record the time from the previous chunk to this chunk's data.
          delays.append(TIMER() - start)
          chunks.append(self._safe_read(chunk_size))
          self._safe_read(2)  # skip the CRLF at the end of the chunk
          start = TIMER()

        # Ignore any trailers.
        while True:
          line = self.fp.readline()
          if not line or line == '\r\n':
            break
      finally:
        # Always release the underlying socket, even on a read error.
        self.close()
    return chunks, delays

  @classmethod
  def _read_chunk_size(cls, line):
    """Parse a chunk-size line; return the size as an int, or None if bad."""
    chunk_extensions_pos = line.find(';')
    if chunk_extensions_pos != -1:
      line = line[:chunk_extensions_pos]  # strip chunk-extensions
    try:
      chunk_size = int(line, 16)  # chunk sizes are hexadecimal per RFC 2616
    except ValueError:
      return None
    return chunk_size
    174 
    175 
class DetailedHTTPConnection(httplib.HTTPConnection):
  """Preserve details relevant to replaying connections."""
  # Make getresponse() produce DetailedHTTPResponse so per-chunk
  # timing data is captured for the archive.
  response_class = DetailedHTTPResponse
    179 
    180 
class DetailedHTTPSResponse(DetailedHTTPResponse):
  """Preserve details relevant to replaying SSL responses."""
  # No SSL-specific behavior yet; kept as a distinct type for SSL responses.
  pass
    184 
    185 
    186 class DetailedHTTPSConnection(httplib.HTTPSConnection):
    187   """Preserve details relevant to replaying SSL connections."""
    188   response_class = DetailedHTTPSResponse
    189 
    190   def __init__(self, host, port):
    191     # https://www.python.org/dev/peps/pep-0476/#opting-out
    192     if hasattr(ssl, '_create_unverified_context'):
    193       httplib.HTTPSConnection.__init__(
    194           self, host=host, port=port, context=ssl._create_unverified_context())
    195     else:
    196       httplib.HTTPSConnection.__init__(self, host=host, port=port)
    197 
    198 
class RealHttpFetch(object):
  """Fetches requests over the live network and archives the responses."""

  def __init__(self, real_dns_lookup):
    """Initialize RealHttpFetch.

    Args:
      real_dns_lookup: a function that resolves a host to an IP. RealHttpFetch
        will resolve host name to the IP before making fetching request if this
        is not None.
    """
    self._real_dns_lookup = real_dns_lookup

  @staticmethod
  def _GetHeaderNameValue(header):
    """Parse the header line and return a name/value tuple.

    Args:
      header: a string for a header such as "Content-Length: 314".
    Returns:
      A tuple (header_name, header_value) on success or None if the header
      is not in expected format. header_name is in lowercase.
    """
    i = header.find(':')
    # i > 0 (not >= 0): a line starting with ':' has no name and is malformed.
    if i > 0:
      return (header[:i].lower(), header[i+1:].strip())
    return None

  @staticmethod
  def _ToTuples(headers):
    """Parse headers and save them to a list of tuples.

    This method takes HttpResponse.msg.headers as input and convert it
    to a list of (header_name, header_value) tuples.
    HttpResponse.msg.headers is a list of strings where each string
    represents either a header or a continuation line of a header.
    1. a normal header consists of two parts which are separated by colon :
       "header_name:header_value..."
    2. a continuation line is a string starting with whitespace
       "[whitespace]continued_header_value..."
    If a header is not in good shape or an unexpected continuation line is
    seen, it will be ignored.

    Should avoid using response.getheaders() directly
    because response.getheaders() can't handle multiple headers
    with the same name properly. Instead, parse the
    response.msg.headers using this method to get all headers.

    Args:
      headers: an instance of HttpResponse.msg.headers.
    Returns:
      A list of tuples which looks like:
      [(header_name, header_value), (header_name2, header_value2)...]
    """
    all_headers = []
    for line in headers:
      if line[0] in '\t ':
        # Leading whitespace marks a continuation of the previous header.
        if not all_headers:
          logging.warning(
              'Unexpected response header continuation line [%s]', line)
          continue
        # Fold the continuation into the most recent header's value.
        name, value = all_headers.pop()
        value += '\n ' + line.strip()
      else:
        name_value = RealHttpFetch._GetHeaderNameValue(line)
        if not name_value:
          logging.warning(
              'Response header in wrong format [%s]', line)
          continue
        name, value = name_value  # pylint: disable=unpacking-non-sequence
      all_headers.append((name, value))
    return all_headers

  @staticmethod
  def _get_request_host_port(request):
    """Split request.host ("host" or "host:port") into (host, port-or-None)."""
    host_parts = request.host.split(':')
    host = host_parts[0]
    port = int(host_parts[1]) if len(host_parts) == 2 else None
    return host, port

  @staticmethod
  def _get_system_proxy(is_ssl):
    """Return the system proxy settings for http/https (may be None)."""
    return platformsettings.get_system_proxy(is_ssl)

  def _get_connection(self, request_host, request_port, is_ssl):
    """Return a detailed connection object for host/port pair.

    If a system proxy is defined (see platformsettings.py), it will be used.

    Args:
      request_host: a host string (e.g. "www.example.com").
      request_port: a port integer (e.g. 8080) or None (for the default port).
      is_ssl: True if HTTPS connection is needed.
    Returns:
      A DetailedHTTPSConnection or DetailedHTTPConnection instance,
      or None if the host name could not be resolved to an IP.
    """
    connection_host = request_host
    connection_port = request_port
    system_proxy = self._get_system_proxy(is_ssl)
    if system_proxy:
      # Connect to the proxy instead of the origin server.
      connection_host = system_proxy.host
      connection_port = system_proxy.port

    # Use an IP address because WPR may override DNS settings.
    if self._real_dns_lookup:
      connection_ip = self._real_dns_lookup(connection_host)
      if not connection_ip:
        logging.critical(
            'Unable to find IP for host name: %s', connection_host)
        return None
      connection_host = connection_ip

    if is_ssl:
      connection = DetailedHTTPSConnection(connection_host, connection_port)
      if system_proxy:
        # Tunnel (CONNECT) through the proxy to the originally requested host.
        connection.set_tunnel(request_host, request_port)
    else:
      connection = DetailedHTTPConnection(connection_host, connection_port)
    return connection

  def __call__(self, request):
    """Fetch an HTTP request.

    Args:
      request: an ArchivedHttpRequest
    Returns:
      an ArchivedHttpResponse, or None if the fetch failed after all retries.
    """
    logging.debug('RealHttpFetch: %s %s', request.host, request.full_path)
    request_host, request_port = self._get_request_host_port(request)
    retries = 3
    while True:
      try:
        request_time = datetime.datetime.utcnow()
        connection = self._get_connection(
            request_host, request_port, request.is_ssl)
        # Time the connect separately from the request/response headers.
        connect_start = TIMER()
        connection.connect()
        connect_delay = int((TIMER() - connect_start) * 1000)
        start = TIMER()
        connection.request(
            request.command,
            request.full_path,
            request.request_body,
            request.headers)
        response = connection.getresponse()
        headers_delay = int((TIMER() - start) * 1000)

        chunks, chunk_delays = response.read_chunks()
        # connect/headers delays are in milliseconds; per-chunk body delays
        # come from DetailedHTTPResponse.read_chunks.
        delays = {
            'connect': connect_delay,
            'headers': headers_delay,
            'data': chunk_delays
            }
        archived_http_response = httparchive.ArchivedHttpResponse(
            response.version,
            response.status,
            response.reason,
            RealHttpFetch._ToTuples(response.msg.headers),
            chunks,
            delays,
            request_time)
        return archived_http_response
      except Exception, e:
        # NOTE(review): broad catch deliberately retries any failure
        # (including the AttributeError from a None connection) up to
        # 3 times before giving up.
        if retries:
          retries -= 1
          logging.warning('Retrying fetch %s: %s', request, repr(e))
          continue
        logging.critical('Could not fetch %s: %s', request, repr(e))
        return None
    368 
    369 
    370 class RecordHttpArchiveFetch(object):
    371   """Make real HTTP fetches and save responses in the given HttpArchive."""
    372 
    373   def __init__(self, http_archive, injector):
    374     """Initialize RecordHttpArchiveFetch.
    375 
    376     Args:
    377       http_archive: an instance of a HttpArchive
    378       injector: script injector to inject scripts in all pages
    379     """
    380     self.http_archive = http_archive
    381     # Do not resolve host name to IP when recording to avoid SSL3 handshake
    382     # failure.
    383     # See https://github.com/chromium/web-page-replay/issues/73 for details.
    384     self.real_http_fetch = RealHttpFetch(real_dns_lookup=None)
    385     self.injector = injector
    386 
    387   def __call__(self, request):
    388     """Fetch the request and return the response.
    389 
    390     Args:
    391       request: an ArchivedHttpRequest.
    392     Returns:
    393       an ArchivedHttpResponse
    394     """
    395     # If request is already in the archive, return the archived response.
    396     if request in self.http_archive:
    397       logging.debug('Repeated request found: %s', request)
    398       response = self.http_archive[request]
    399     else:
    400       response = self.real_http_fetch(request)
    401       if response is None:
    402         return None
    403       self.http_archive[request] = response
    404     if self.injector:
    405       response = _InjectScripts(response, self.injector)
    406     logging.debug('Recorded: %s', request)
    407     return response
    408 
    409 
    410 class ReplayHttpArchiveFetch(object):
    411   """Serve responses from the given HttpArchive."""
    412 
    413   def __init__(self, http_archive, real_dns_lookup, injector,
    414                use_diff_on_unknown_requests=False,
    415                use_closest_match=False, scramble_images=False):
    416     """Initialize ReplayHttpArchiveFetch.
    417 
    418     Args:
    419       http_archive: an instance of a HttpArchive
    420       real_dns_lookup: a function that resolves a host to an IP.
    421       injector: script injector to inject scripts in all pages
    422       use_diff_on_unknown_requests: If True, log unknown requests
    423         with a diff to requests that look similar.
    424       use_closest_match: If True, on replay mode, serve the closest match
    425         in the archive instead of giving a 404.
    426     """
    427     self.http_archive = http_archive
    428     self.injector = injector
    429     self.use_diff_on_unknown_requests = use_diff_on_unknown_requests
    430     self.use_closest_match = use_closest_match
    431     self.scramble_images = scramble_images
    432     self.real_http_fetch = RealHttpFetch(real_dns_lookup)
    433 
    434   def __call__(self, request):
    435     """Fetch the request and return the response.
    436 
    437     Args:
    438       request: an instance of an ArchivedHttpRequest.
    439     Returns:
    440       Instance of ArchivedHttpResponse (if found) or None
    441     """
    442     if request.host.startswith('127.0.0.1:'):
    443       return self.real_http_fetch(request)
    444 
    445     response = self.http_archive.get(request)
    446 
    447     if self.use_closest_match and not response:
    448       closest_request = self.http_archive.find_closest_request(
    449           request, use_path=True)
    450       if closest_request:
    451         response = self.http_archive.get(closest_request)
    452         if response:
    453           logging.info('Request not found: %s\nUsing closest match: %s',
    454                        request, closest_request)
    455 
    456     if not response:
    457       reason = str(request)
    458       if self.use_diff_on_unknown_requests:
    459         diff = self.http_archive.diff(request)
    460         if diff:
    461           reason += (
    462               "\nNearest request diff "
    463               "('-' for archived request, '+' for current request):\n%s" % diff)
    464       logging.warning('Could not replay: %s', reason)
    465     else:
    466       if self.injector:
    467         response = _InjectScripts(response, self.injector)
    468       if self.scramble_images:
    469         response = _ScrambleImages(response)
    470     return response
    471 
    472 
    473 class ControllableHttpArchiveFetch(object):
    474   """Controllable fetch function that can swap between record and replay."""
    475 
    476   def __init__(self, http_archive, real_dns_lookup,
    477                injector, use_diff_on_unknown_requests,
    478                use_record_mode, use_closest_match, scramble_images):
    479     """Initialize HttpArchiveFetch.
    480 
    481     Args:
    482       http_archive: an instance of a HttpArchive
    483       real_dns_lookup: a function that resolves a host to an IP.
    484       injector: function to inject scripts in all pages.
    485         takes recording time as datetime.datetime object.
    486       use_diff_on_unknown_requests: If True, log unknown requests
    487         with a diff to requests that look similar.
    488       use_record_mode: If True, start in server in record mode.
    489       use_closest_match: If True, on replay mode, serve the closest match
    490         in the archive instead of giving a 404.
    491     """
    492     self.http_archive = http_archive
    493     self.record_fetch = RecordHttpArchiveFetch(http_archive, injector)
    494     self.replay_fetch = ReplayHttpArchiveFetch(
    495         http_archive, real_dns_lookup, injector,
    496         use_diff_on_unknown_requests, use_closest_match, scramble_images)
    497     if use_record_mode:
    498       self.SetRecordMode()
    499     else:
    500       self.SetReplayMode()
    501 
    502   def SetRecordMode(self):
    503     self.fetch = self.record_fetch
    504     self.is_record_mode = True
    505 
    506   def SetReplayMode(self):
    507     self.fetch = self.replay_fetch
    508     self.is_record_mode = False
    509 
    510   def __call__(self, *args, **kwargs):
    511     """Forward calls to Replay/Record fetch functions depending on mode."""
    512     return self.fetch(*args, **kwargs)
    513