Home | History | Annotate | Download | only in test
      1 from __future__ import nested_scopes    # Backward compat for 2.1
      2 from unittest import TestCase
      3 from wsgiref.util import setup_testing_defaults
      4 from wsgiref.headers import Headers
      5 from wsgiref.handlers import BaseHandler, BaseCGIHandler
      6 from wsgiref import util
      7 from wsgiref.validate import validator
      8 from wsgiref.simple_server import WSGIServer, WSGIRequestHandler, demo_app
      9 from wsgiref.simple_server import make_server
     10 from StringIO import StringIO
     11 from SocketServer import BaseServer
     12 import os
     13 import re
     14 import sys
     15 
     16 from test import test_support
     17 
     18 class MockServer(WSGIServer):
     19     """Non-socket HTTP server"""
     20 
     21     def __init__(self, server_address, RequestHandlerClass):
     22         BaseServer.__init__(self, server_address, RequestHandlerClass)
     23         self.server_bind()
     24 
     25     def server_bind(self):
     26         host, port = self.server_address
     27         self.server_name = host
     28         self.server_port = port
     29         self.setup_environ()
     30 
     31 
     32 class MockHandler(WSGIRequestHandler):
     33     """Non-socket HTTP handler"""
     34     def setup(self):
     35         self.connection = self.request
     36         self.rfile, self.wfile = self.connection
     37 
     38     def finish(self):
     39         pass
     40 
     41 
     42 def hello_app(environ,start_response):
     43     start_response("200 OK", [
     44         ('Content-Type','text/plain'),
     45         ('Date','Mon, 05 Jun 2006 18:49:54 GMT')
     46     ])
     47     return ["Hello, world!"]
     48 
     49 def run_amock(app=hello_app, data="GET / HTTP/1.0\n\n"):
     50     server = make_server("", 80, app, MockServer, MockHandler)
     51     inp, out, err, olderr = StringIO(data), StringIO(), StringIO(), sys.stderr
     52     sys.stderr = err
     53 
     54     try:
     55         server.finish_request((inp,out), ("127.0.0.1",8888))
     56     finally:
     57         sys.stderr = olderr
     58 
     59     return out.getvalue(), err.getvalue()
     60 
     61 
     62 def compare_generic_iter(make_it,match):
     63     """Utility to compare a generic 2.1/2.2+ iterator with an iterable
     64 
     65     If running under Python 2.2+, this tests the iterator using iter()/next(),
     66     as well as __getitem__.  'make_it' must be a function returning a fresh
     67     iterator to be tested (since this may test the iterator twice)."""
     68 
     69     it = make_it()
     70     n = 0
     71     for item in match:
     72         if not it[n]==item: raise AssertionError
     73         n+=1
     74     try:
     75         it[n]
     76     except IndexError:
     77         pass
     78     else:
     79         raise AssertionError("Too many items from __getitem__",it)
     80 
     81     try:
     82         iter, StopIteration
     83     except NameError:
     84         pass
     85     else:
     86         # Only test iter mode under 2.2+
     87         it = make_it()
     88         if not iter(it) is it: raise AssertionError
     89         for item in match:
     90             if not it.next()==item: raise AssertionError
     91         try:
     92             it.next()
     93         except StopIteration:
     94             pass
     95         else:
     96             raise AssertionError("Too many items from .next()",it)
     97 
     98 
     99 class IntegrationTests(TestCase):
    100 
    101     def check_hello(self, out, has_length=True):
    102         self.assertEqual(out,
    103             "HTTP/1.0 200 OK\r\n"
    104             "Server: WSGIServer/0.1 Python/"+sys.version.split()[0]+"\r\n"
    105             "Content-Type: text/plain\r\n"
    106             "Date: Mon, 05 Jun 2006 18:49:54 GMT\r\n" +
    107             (has_length and  "Content-Length: 13\r\n" or "") +
    108             "\r\n"
    109             "Hello, world!"
    110         )
    111 
    112     def test_plain_hello(self):
    113         out, err = run_amock()
    114         self.check_hello(out)
    115 
    116     def test_validated_hello(self):
    117         out, err = run_amock(validator(hello_app))
    118         # the middleware doesn't support len(), so content-length isn't there
    119         self.check_hello(out, has_length=False)
    120 
    121     def test_simple_validation_error(self):
    122         def bad_app(environ,start_response):
    123             start_response("200 OK", ('Content-Type','text/plain'))
    124             return ["Hello, world!"]
    125         out, err = run_amock(validator(bad_app))
    126         self.assertTrue(out.endswith(
    127             "A server error occurred.  Please contact the administrator."
    128         ))
    129         self.assertEqual(
    130             err.splitlines()[-2],
    131             "AssertionError: Headers (('Content-Type', 'text/plain')) must"
    132             " be of type list: <type 'tuple'>"
    133         )
    134 
    135 
    136 class UtilityTests(TestCase):
    137 
    138     def checkShift(self,sn_in,pi_in,part,sn_out,pi_out):
    139         env = {'SCRIPT_NAME':sn_in,'PATH_INFO':pi_in}
    140         util.setup_testing_defaults(env)
    141         self.assertEqual(util.shift_path_info(env),part)
    142         self.assertEqual(env['PATH_INFO'],pi_out)
    143         self.assertEqual(env['SCRIPT_NAME'],sn_out)
    144         return env
    145 
    146     def checkDefault(self, key, value, alt=None):
    147         # Check defaulting when empty
    148         env = {}
    149         util.setup_testing_defaults(env)
    150         if isinstance(value, StringIO):
    151             self.assertIsInstance(env[key], StringIO)
    152         else:
    153             self.assertEqual(env[key], value)
    154 
    155         # Check existing value
    156         env = {key:alt}
    157         util.setup_testing_defaults(env)
    158         self.assertTrue(env[key] is alt)
    159 
    160     def checkCrossDefault(self,key,value,**kw):
    161         util.setup_testing_defaults(kw)
    162         self.assertEqual(kw[key],value)
    163 
    164     def checkAppURI(self,uri,**kw):
    165         util.setup_testing_defaults(kw)
    166         self.assertEqual(util.application_uri(kw),uri)
    167 
    168     def checkReqURI(self,uri,query=1,**kw):
    169         util.setup_testing_defaults(kw)
    170         self.assertEqual(util.request_uri(kw,query),uri)
    171 
    172     def checkFW(self,text,size,match):
    173 
    174         def make_it(text=text,size=size):
    175             return util.FileWrapper(StringIO(text),size)
    176 
    177         compare_generic_iter(make_it,match)
    178 
    179         it = make_it()
    180         self.assertFalse(it.filelike.closed)
    181 
    182         for item in it:
    183             pass
    184 
    185         self.assertFalse(it.filelike.closed)
    186 
    187         it.close()
    188         self.assertTrue(it.filelike.closed)
    189 
    190     def testSimpleShifts(self):
    191         self.checkShift('','/', '', '/', '')
    192         self.checkShift('','/x', 'x', '/x', '')
    193         self.checkShift('/','', None, '/', '')
    194         self.checkShift('/a','/x/y', 'x', '/a/x', '/y')
    195         self.checkShift('/a','/x/',  'x', '/a/x', '/')
    196 
    197     def testNormalizedShifts(self):
    198         self.checkShift('/a/b', '/../y', '..', '/a', '/y')
    199         self.checkShift('', '/../y', '..', '', '/y')
    200         self.checkShift('/a/b', '//y', 'y', '/a/b/y', '')
    201         self.checkShift('/a/b', '//y/', 'y', '/a/b/y', '/')
    202         self.checkShift('/a/b', '/./y', 'y', '/a/b/y', '')
    203         self.checkShift('/a/b', '/./y/', 'y', '/a/b/y', '/')
    204         self.checkShift('/a/b', '///./..//y/.//', '..', '/a', '/y/')
    205         self.checkShift('/a/b', '///', '', '/a/b/', '')
    206         self.checkShift('/a/b', '/.//', '', '/a/b/', '')
    207         self.checkShift('/a/b', '/x//', 'x', '/a/b/x', '/')
    208         self.checkShift('/a/b', '/.', None, '/a/b', '')
    209 
    210     def testDefaults(self):
    211         for key, value in [
    212             ('SERVER_NAME','127.0.0.1'),
    213             ('SERVER_PORT', '80'),
    214             ('SERVER_PROTOCOL','HTTP/1.0'),
    215             ('HTTP_HOST','127.0.0.1'),
    216             ('REQUEST_METHOD','GET'),
    217             ('SCRIPT_NAME',''),
    218             ('PATH_INFO','/'),
    219             ('wsgi.version', (1,0)),
    220             ('wsgi.run_once', 0),
    221             ('wsgi.multithread', 0),
    222             ('wsgi.multiprocess', 0),
    223             ('wsgi.input', StringIO("")),
    224             ('wsgi.errors', StringIO()),
    225             ('wsgi.url_scheme','http'),
    226         ]:
    227             self.checkDefault(key,value)
    228 
    229     def testCrossDefaults(self):
    230         self.checkCrossDefault('HTTP_HOST',"foo.bar",SERVER_NAME="foo.bar")
    231         self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="on")
    232         self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="1")
    233         self.checkCrossDefault('wsgi.url_scheme',"https",HTTPS="yes")
    234         self.checkCrossDefault('wsgi.url_scheme',"http",HTTPS="foo")
    235         self.checkCrossDefault('SERVER_PORT',"80",HTTPS="foo")
    236         self.checkCrossDefault('SERVER_PORT',"443",HTTPS="on")
    237 
    238     def testGuessScheme(self):
    239         self.assertEqual(util.guess_scheme({}), "http")
    240         self.assertEqual(util.guess_scheme({'HTTPS':"foo"}), "http")
    241         self.assertEqual(util.guess_scheme({'HTTPS':"on"}), "https")
    242         self.assertEqual(util.guess_scheme({'HTTPS':"yes"}), "https")
    243         self.assertEqual(util.guess_scheme({'HTTPS':"1"}), "https")
    244 
    245     def testAppURIs(self):
    246         self.checkAppURI("http://127.0.0.1/")
    247         self.checkAppURI("http://127.0.0.1/spam", SCRIPT_NAME="/spam")
    248         self.checkAppURI("http://spam.example.com:2071/",
    249             HTTP_HOST="spam.example.com:2071", SERVER_PORT="2071")
    250         self.checkAppURI("http://spam.example.com/",
    251             SERVER_NAME="spam.example.com")
    252         self.checkAppURI("http://127.0.0.1/",
    253             HTTP_HOST="127.0.0.1", SERVER_NAME="spam.example.com")
    254         self.checkAppURI("https://127.0.0.1/", HTTPS="on")
    255         self.checkAppURI("http://127.0.0.1:8000/", SERVER_PORT="8000",
    256             HTTP_HOST=None)
    257 
    258     def testReqURIs(self):
    259         self.checkReqURI("http://127.0.0.1/")
    260         self.checkReqURI("http://127.0.0.1/spam", SCRIPT_NAME="/spam")
    261         self.checkReqURI("http://127.0.0.1/spammity/spam",
    262             SCRIPT_NAME="/spammity", PATH_INFO="/spam")
    263         self.checkReqURI("http://127.0.0.1/spammity/spam;ham",
    264             SCRIPT_NAME="/spammity", PATH_INFO="/spam;ham")
    265         self.checkReqURI("http://127.0.0.1/spammity/spam;cookie=1234,5678",
    266             SCRIPT_NAME="/spammity", PATH_INFO="/spam;cookie=1234,5678")
    267         self.checkReqURI("http://127.0.0.1/spammity/spam?say=ni",
    268             SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="say=ni")
    269         self.checkReqURI("http://127.0.0.1/spammity/spam", 0,
    270             SCRIPT_NAME="/spammity", PATH_INFO="/spam",QUERY_STRING="say=ni")
    271 
    272     def testFileWrapper(self):
    273         self.checkFW("xyz"*50, 120, ["xyz"*40,"xyz"*10])
    274 
    275     def testHopByHop(self):
    276         for hop in (
    277             "Connection Keep-Alive Proxy-Authenticate Proxy-Authorization "
    278             "TE Trailers Transfer-Encoding Upgrade"
    279         ).split():
    280             for alt in hop, hop.title(), hop.upper(), hop.lower():
    281                 self.assertTrue(util.is_hop_by_hop(alt))
    282 
    283         # Not comprehensive, just a few random header names
    284         for hop in (
    285             "Accept Cache-Control Date Pragma Trailer Via Warning"
    286         ).split():
    287             for alt in hop, hop.title(), hop.upper(), hop.lower():
    288                 self.assertFalse(util.is_hop_by_hop(alt))
    289 
    290 class HeaderTests(TestCase):
    291 
    292     def testMappingInterface(self):
    293         test = [('x','y')]
    294         self.assertEqual(len(Headers([])),0)
    295         self.assertEqual(len(Headers(test[:])),1)
    296         self.assertEqual(Headers(test[:]).keys(), ['x'])
    297         self.assertEqual(Headers(test[:]).values(), ['y'])
    298         self.assertEqual(Headers(test[:]).items(), test)
    299         self.assertFalse(Headers(test).items() is test)  # must be copy!
    300 
    301         h=Headers([])
    302         del h['foo']   # should not raise an error
    303 
    304         h['Foo'] = 'bar'
    305         for m in h.has_key, h.__contains__, h.get, h.get_all, h.__getitem__:
    306             self.assertTrue(m('foo'))
    307             self.assertTrue(m('Foo'))
    308             self.assertTrue(m('FOO'))
    309             self.assertFalse(m('bar'))
    310 
    311         self.assertEqual(h['foo'],'bar')
    312         h['foo'] = 'baz'
    313         self.assertEqual(h['FOO'],'baz')
    314         self.assertEqual(h.get_all('foo'),['baz'])
    315 
    316         self.assertEqual(h.get("foo","whee"), "baz")
    317         self.assertEqual(h.get("zoo","whee"), "whee")
    318         self.assertEqual(h.setdefault("foo","whee"), "baz")
    319         self.assertEqual(h.setdefault("zoo","whee"), "whee")
    320         self.assertEqual(h["foo"],"baz")
    321         self.assertEqual(h["zoo"],"whee")
    322 
    323     def testRequireList(self):
    324         self.assertRaises(TypeError, Headers, "foo")
    325 
    326 
    327     def testExtras(self):
    328         h = Headers([])
    329         self.assertEqual(str(h),'\r\n')
    330 
    331         h.add_header('foo','bar',baz="spam")
    332         self.assertEqual(h['foo'], 'bar; baz="spam"')
    333         self.assertEqual(str(h),'foo: bar; baz="spam"\r\n\r\n')
    334 
    335         h.add_header('Foo','bar',cheese=None)
    336         self.assertEqual(h.get_all('foo'),
    337             ['bar; baz="spam"', 'bar; cheese'])
    338 
    339         self.assertEqual(str(h),
    340             'foo: bar; baz="spam"\r\n'
    341             'Foo: bar; cheese\r\n'
    342             '\r\n'
    343         )
    344 
    345 
    346 class ErrorHandler(BaseCGIHandler):
    347     """Simple handler subclass for testing BaseHandler"""
    348 
    349     # BaseHandler records the OS environment at import time, but envvars
    350     # might have been changed later by other tests, which trips up
    351     # HandlerTests.testEnviron().
    352     os_environ = dict(os.environ.items())
    353 
    354     def __init__(self,**kw):
    355         setup_testing_defaults(kw)
    356         BaseCGIHandler.__init__(
    357             self, StringIO(''), StringIO(), StringIO(), kw,
    358             multithread=True, multiprocess=True
    359         )
    360 
    361 class TestHandler(ErrorHandler):
    362     """Simple handler subclass for testing BaseHandler, w/error passthru"""
    363 
    364     def handle_error(self):
    365         raise   # for testing, we want to see what's happening
    366 
    367 
    368 class HandlerTests(TestCase):
    369 
    370     def checkEnvironAttrs(self, handler):
    371         env = handler.environ
    372         for attr in [
    373             'version','multithread','multiprocess','run_once','file_wrapper'
    374         ]:
    375             if attr=='file_wrapper' and handler.wsgi_file_wrapper is None:
    376                 continue
    377             self.assertEqual(getattr(handler,'wsgi_'+attr),env['wsgi.'+attr])
    378 
    379     def checkOSEnviron(self,handler):
    380         empty = {}; setup_testing_defaults(empty)
    381         env = handler.environ
    382         from os import environ
    383         for k,v in environ.items():
    384             if k not in empty:
    385                 self.assertEqual(env[k],v)
    386         for k,v in empty.items():
    387             self.assertIn(k, env)
    388 
    389     def testEnviron(self):
    390         h = TestHandler(X="Y")
    391         h.setup_environ()
    392         self.checkEnvironAttrs(h)
    393         self.checkOSEnviron(h)
    394         self.assertEqual(h.environ["X"],"Y")
    395 
    396     def testCGIEnviron(self):
    397         h = BaseCGIHandler(None,None,None,{})
    398         h.setup_environ()
    399         for key in 'wsgi.url_scheme', 'wsgi.input', 'wsgi.errors':
    400             self.assertIn(key, h.environ)
    401 
    402     def testScheme(self):
    403         h=TestHandler(HTTPS="on"); h.setup_environ()
    404         self.assertEqual(h.environ['wsgi.url_scheme'],'https')
    405         h=TestHandler(); h.setup_environ()
    406         self.assertEqual(h.environ['wsgi.url_scheme'],'http')
    407 
    408     def testAbstractMethods(self):
    409         h = BaseHandler()
    410         for name in [
    411             '_flush','get_stdin','get_stderr','add_cgi_vars'
    412         ]:
    413             self.assertRaises(NotImplementedError, getattr(h,name))
    414         self.assertRaises(NotImplementedError, h._write, "test")
    415 
    416     def testContentLength(self):
    417         # Demo one reason iteration is better than write()...  ;)
    418 
    419         def trivial_app1(e,s):
    420             s('200 OK',[])
    421             return [e['wsgi.url_scheme']]
    422 
    423         def trivial_app2(e,s):
    424             s('200 OK',[])(e['wsgi.url_scheme'])
    425             return []
    426 
    427         def trivial_app4(e,s):
    428             # Simulate a response to a HEAD request
    429             s('200 OK',[('Content-Length', '12345')])
    430             return []
    431 
    432         h = TestHandler()
    433         h.run(trivial_app1)
    434         self.assertEqual(h.stdout.getvalue(),
    435             "Status: 200 OK\r\n"
    436             "Content-Length: 4\r\n"
    437             "\r\n"
    438             "http")
    439 
    440         h = TestHandler()
    441         h.run(trivial_app2)
    442         self.assertEqual(h.stdout.getvalue(),
    443             "Status: 200 OK\r\n"
    444             "\r\n"
    445             "http")
    446 
    447 
    448         h = TestHandler()
    449         h.run(trivial_app4)
    450         self.assertEqual(h.stdout.getvalue(),
    451             b'Status: 200 OK\r\n'
    452             b'Content-Length: 12345\r\n'
    453             b'\r\n')
    454 
    455     def testBasicErrorOutput(self):
    456 
    457         def non_error_app(e,s):
    458             s('200 OK',[])
    459             return []
    460 
    461         def error_app(e,s):
    462             raise AssertionError("This should be caught by handler")
    463 
    464         h = ErrorHandler()
    465         h.run(non_error_app)
    466         self.assertEqual(h.stdout.getvalue(),
    467             "Status: 200 OK\r\n"
    468             "Content-Length: 0\r\n"
    469             "\r\n")
    470         self.assertEqual(h.stderr.getvalue(),"")
    471 
    472         h = ErrorHandler()
    473         h.run(error_app)
    474         self.assertEqual(h.stdout.getvalue(),
    475             "Status: %s\r\n"
    476             "Content-Type: text/plain\r\n"
    477             "Content-Length: %d\r\n"
    478             "\r\n%s" % (h.error_status,len(h.error_body),h.error_body))
    479 
    480         self.assertNotEqual(h.stderr.getvalue().find("AssertionError"), -1)
    481 
    482     def testErrorAfterOutput(self):
    483         MSG = "Some output has been sent"
    484         def error_app(e,s):
    485             s("200 OK",[])(MSG)
    486             raise AssertionError("This should be caught by handler")
    487 
    488         h = ErrorHandler()
    489         h.run(error_app)
    490         self.assertEqual(h.stdout.getvalue(),
    491             "Status: 200 OK\r\n"
    492             "\r\n"+MSG)
    493         self.assertNotEqual(h.stderr.getvalue().find("AssertionError"), -1)
    494 
    495     def testHeaderFormats(self):
    496 
    497         def non_error_app(e,s):
    498             s('200 OK',[])
    499             return []
    500 
    501         stdpat = (
    502             r"HTTP/%s 200 OK\r\n"
    503             r"Date: \w{3}, [ 0123]\d \w{3} \d{4} \d\d:\d\d:\d\d GMT\r\n"
    504             r"%s" r"Content-Length: 0\r\n" r"\r\n"
    505         )
    506         shortpat = (
    507             "Status: 200 OK\r\n" "Content-Length: 0\r\n" "\r\n"
    508         )
    509 
    510         for ssw in "FooBar/1.0", None:
    511             sw = ssw and "Server: %s\r\n" % ssw or ""
    512 
    513             for version in "1.0", "1.1":
    514                 for proto in "HTTP/0.9", "HTTP/1.0", "HTTP/1.1":
    515 
    516                     h = TestHandler(SERVER_PROTOCOL=proto)
    517                     h.origin_server = False
    518                     h.http_version = version
    519                     h.server_software = ssw
    520                     h.run(non_error_app)
    521                     self.assertEqual(shortpat,h.stdout.getvalue())
    522 
    523                     h = TestHandler(SERVER_PROTOCOL=proto)
    524                     h.origin_server = True
    525                     h.http_version = version
    526                     h.server_software = ssw
    527                     h.run(non_error_app)
    528                     if proto=="HTTP/0.9":
    529                         self.assertEqual(h.stdout.getvalue(),"")
    530                     else:
    531                         self.assertTrue(
    532                             re.match(stdpat%(version,sw), h.stdout.getvalue()),
    533                             (stdpat%(version,sw), h.stdout.getvalue())
    534                         )
    535 
    536     def testCloseOnError(self):
    537         side_effects = {'close_called': False}
    538         MSG = b"Some output has been sent"
    539         def error_app(e,s):
    540             s("200 OK",[])(MSG)
    541             class CrashyIterable(object):
    542                 def __iter__(self):
    543                     while True:
    544                         yield b'blah'
    545                         raise AssertionError("This should be caught by handler")
    546 
    547                 def close(self):
    548                     side_effects['close_called'] = True
    549             return CrashyIterable()
    550 
    551         h = ErrorHandler()
    552         h.run(error_app)
    553         self.assertEqual(side_effects['close_called'], True)
    554 
    555 
    556 def test_main():
    557     test_support.run_unittest(__name__)
    558 
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    test_main()
    561