Home | History | Annotate | Download | only in safe_browsing
      1 #!/usr/bin/env python
      2 # Copyright 2013 The Chromium Authors. All rights reserved.
      3 # Use of this source code is governed by a BSD-style license that can be
      4 # found in the LICENSE file.
      5 
      6 """Testserver for the two phase upload protocol."""
      7 
      8 import base64
      9 import BaseHTTPServer
     10 import hashlib
     11 import os
     12 import sys
     13 import urlparse
     14 
     15 BASE_DIR = os.path.dirname(os.path.abspath(__file__))
     16 
     17 sys.path.append(os.path.join(BASE_DIR, '..', '..', '..', 'net',
     18                              'tools', 'testserver'))
     19 import testserver_base
     20 
     21 
     22 class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
     23   def ReadRequestBody(self):
     24     """This function reads the body of the current HTTP request, handling
     25     both plain and chunked transfer encoded requests."""
     26 
     27     if self.headers.getheader('transfer-encoding') == 'chunked':
     28       return ''
     29 
     30     length = int(self.headers.getheader('content-length'))
     31     return self.rfile.read(length)
     32 
     33   def do_GET(self):
     34     print 'GET', self.path
     35     self.send_error(400, 'GET not supported')
     36 
     37   def do_POST(self):
     38     request_body = self.ReadRequestBody()
     39     print 'POST', repr(self.path), repr(request_body)
     40 
     41     kStartHeader = 'x-goog-resumable'
     42     if kStartHeader not in self.headers:
     43       self.send_error(400, 'Missing header: ' + kStartHeader)
     44       return
     45     if self.headers.get(kStartHeader) != 'start':
     46       self.send_error(400, 'Invalid %s header value: %s' % (
     47           kStartHeader, self.headers.get(kStartHeader)))
     48       return
     49 
     50     metadata_hash = hashlib.sha1(request_body).hexdigest()
     51     _, _, url_path, _, query, _ = urlparse.urlparse(self.path)
     52     query_args = urlparse.parse_qs(query)
     53 
     54     if query_args.get('p1close'):
     55       self.close_connection = 1
     56       return
     57 
     58     put_url = 'http://%s:%d/put?%s,%s,%s' % (self.server.server_address[0],
     59                                           self.server.server_port,
     60                                           url_path,
     61                                           metadata_hash,
     62                                           base64.urlsafe_b64encode(query))
     63     self.send_response(int(query_args.get('p1code', [201])[0]))
     64     self.send_header('Location', put_url)
     65     self.end_headers()
     66 
     67   def do_PUT(self):
     68     _, _, url_path, _, query, _ = urlparse.urlparse(self.path)
     69     if url_path != '/put':
     70       self.send_error(400, 'invalid path on 2nd phase: ' + url_path)
     71       return
     72 
     73     initial_path, metadata_hash, config_query_b64 = query.split(',', 2)
     74     config_query = urlparse.parse_qs(base64.urlsafe_b64decode(config_query_b64))
     75 
     76     request_body = self.ReadRequestBody()
     77     print 'PUT', repr(self.path), len(request_body), 'bytes'
     78 
     79     if config_query.get('p2close'):
     80       self.close_connection = 1
     81       return
     82 
     83     self.send_response(int(config_query.get('p2code', [200])[0]))
     84     self.end_headers()
     85     self.wfile.write('%s\n%s\n%s\n' % (
     86         initial_path,
     87         metadata_hash,
     88         hashlib.sha1(request_body).hexdigest()))
     89 
     90 
     91 class ServerRunner(testserver_base.TestServerRunner):
     92   """TestServerRunner for safebrowsing_test_server.py."""
     93 
     94   def create_server(self, server_data):
     95     server = BaseHTTPServer.HTTPServer((self.options.host, self.options.port),
     96                                        RequestHandler)
     97     print 'server started on port %d...' % server.server_port
     98     server_data['port'] = server.server_port
     99 
    100     return server
    101 
    102 
    103 if __name__ == '__main__':
    104   sys.exit(ServerRunner().main())
    105