Home | History | Annotate | Download | only in test
      1 """Mock socket module used by the smtpd and smtplib tests.
      2 """
      3 
      4 # imported for _GLOBAL_DEFAULT_TIMEOUT
      5 import socket as socket_module
      6 
      7 # Mock socket module
      8 _defaulttimeout = None
      9 _reply_data = None
     10 
     11 # This is used to queue up data to be read through socket.makefile, typically
     12 # *before* the socket object is even created. It is intended to handle a single
     13 # line which the socket will feed on recv() or makefile().
     14 def reply_with(line):
     15     global _reply_data
     16     _reply_data = line
     17 
     18 
     19 class MockFile:
     20     """Mock file object returned by MockSocket.makefile().
     21     """
     22     def __init__(self, lines):
     23         self.lines = lines
     24     def readline(self, limit=-1):
     25         result = self.lines.pop(0) + b'\r\n'
     26         if limit >= 0:
     27             # Re-insert the line, removing the \r\n we added.
     28             self.lines.insert(0, result[limit:-2])
     29             result = result[:limit]
     30         return result
     31     def close(self):
     32         pass
     33 
     34 
     35 class MockSocket:
     36     """Mock socket object used by smtpd and smtplib tests.
     37     """
     38     def __init__(self, family=None):
     39         global _reply_data
     40         self.family = family
     41         self.output = []
     42         self.lines = []
     43         if _reply_data:
     44             self.lines.append(_reply_data)
     45             _reply_data = None
     46         self.conn = None
     47         self.timeout = None
     48 
     49     def queue_recv(self, line):
     50         self.lines.append(line)
     51 
     52     def recv(self, bufsize, flags=None):
     53         data = self.lines.pop(0) + b'\r\n'
     54         return data
     55 
     56     def fileno(self):
     57         return 0
     58 
     59     def settimeout(self, timeout):
     60         if timeout is None:
     61             self.timeout = _defaulttimeout
     62         else:
     63             self.timeout = timeout
     64 
     65     def gettimeout(self):
     66         return self.timeout
     67 
     68     def setsockopt(self, level, optname, value):
     69         pass
     70 
     71     def getsockopt(self, level, optname, buflen=None):
     72         return 0
     73 
     74     def bind(self, address):
     75         pass
     76 
     77     def accept(self):
     78         self.conn = MockSocket()
     79         return self.conn, 'c'
     80 
     81     def getsockname(self):
     82         return ('0.0.0.0', 0)
     83 
     84     def setblocking(self, flag):
     85         pass
     86 
     87     def listen(self, backlog):
     88         pass
     89 
     90     def makefile(self, mode='r', bufsize=-1):
     91         handle = MockFile(self.lines)
     92         return handle
     93 
     94     def sendall(self, buffer, flags=None):
     95         self.last = data
     96         self.output.append(data)
     97         return len(data)
     98 
     99     def send(self, data, flags=None):
    100         self.last = data
    101         self.output.append(data)
    102         return len(data)
    103 
    104     def getpeername(self):
    105         return ('peer-address', 'peer-port')
    106 
    107     def close(self):
    108         pass
    109 
    110 
    111 def socket(family=None, type=None, proto=None):
    112     return MockSocket(family)
    113 
    114 def create_connection(address, timeout=socket_module._GLOBAL_DEFAULT_TIMEOUT,
    115                       source_address=None):
    116     try:
    117         int_port = int(address[1])
    118     except ValueError:
    119         raise error
    120     ms = MockSocket()
    121     if timeout is socket_module._GLOBAL_DEFAULT_TIMEOUT:
    122         timeout = getdefaulttimeout()
    123     ms.settimeout(timeout)
    124     return ms
    125 
    126 
    127 def setdefaulttimeout(timeout):
    128     global _defaulttimeout
    129     _defaulttimeout = timeout
    130 
    131 
    132 def getdefaulttimeout():
    133     return _defaulttimeout
    134 
    135 
    136 def getfqdn():
    137     return ""
    138 
    139 
    140 def gethostname():
    141     pass
    142 
    143 
    144 def gethostbyname(name):
    145     return ""
    146 
    147 def getaddrinfo(*args, **kw):
    148     return socket_module.getaddrinfo(*args, **kw)
    149 
    150 gaierror = socket_module.gaierror
    151 error = socket_module.error
    152 
    153 
    154 # Constants
    155 AF_INET = socket_module.AF_INET
    156 AF_INET6 = socket_module.AF_INET6
    157 SOCK_STREAM = socket_module.SOCK_STREAM
    158 SOL_SOCKET = None
    159 SO_REUSEADDR = None
    160