Home | History | Annotate | Download | only in test_email
      1 import os
      2 import unittest
      3 import collections
      4 import email
      5 from email.message import Message
      6 from email._policybase import compat32
      7 from test.support import load_package_tests
      8 from test.test_email import __file__ as landmark
      9 
     10 # Load all tests in package
     11 def load_tests(*args):
     12     return load_package_tests(os.path.dirname(__file__), *args)
     13 
     14 
     15 # helper code used by a number of test modules.
     16 
     17 def openfile(filename, *args, **kws):
     18     path = os.path.join(os.path.dirname(landmark), 'data', filename)
     19     return open(path, *args, **kws)
     20 
     21 
     22 # Base test class
     23 class TestEmailBase(unittest.TestCase):
     24 
     25     maxDiff = None
     26     # Currently the default policy is compat32.  By setting that as the default
     27     # here we make minimal changes in the test_email tests compared to their
     28     # pre-3.3 state.
     29     policy = compat32
     30     # Likewise, the default message object is Message.
     31     message = Message
     32 
     33     def __init__(self, *args, **kw):
     34         super().__init__(*args, **kw)
     35         self.addTypeEqualityFunc(bytes, self.assertBytesEqual)
     36 
     37     # Backward compatibility to minimize test_email test changes.
     38     ndiffAssertEqual = unittest.TestCase.assertEqual
     39 
     40     def _msgobj(self, filename):
     41         with openfile(filename) as fp:
     42             return email.message_from_file(fp, policy=self.policy)
     43 
     44     def _str_msg(self, string, message=None, policy=None):
     45         if policy is None:
     46             policy = self.policy
     47         if message is None:
     48             message = self.message
     49         return email.message_from_string(string, message, policy=policy)
     50 
     51     def _bytes_msg(self, bytestring, message=None, policy=None):
     52         if policy is None:
     53             policy = self.policy
     54         if message is None:
     55             message = self.message
     56         return email.message_from_bytes(bytestring, message, policy=policy)
     57 
     58     def _make_message(self):
     59         return self.message(policy=self.policy)
     60 
     61     def _bytes_repr(self, b):
     62         return [repr(x) for x in b.splitlines(keepends=True)]
     63 
     64     def assertBytesEqual(self, first, second, msg):
     65         """Our byte strings are really encoded strings; improve diff output"""
     66         self.assertEqual(self._bytes_repr(first), self._bytes_repr(second))
     67 
     68     def assertDefectsEqual(self, actual, expected):
     69         self.assertEqual(len(actual), len(expected), actual)
     70         for i in range(len(actual)):
     71             self.assertIsInstance(actual[i], expected[i],
     72                                     'item {}'.format(i))
     73 
     74 
     75 def parameterize(cls):
     76     """A test method parameterization class decorator.
     77 
     78     Parameters are specified as the value of a class attribute that ends with
     79     the string '_params'.  Call the portion before '_params' the prefix.  Then
     80     a method to be parameterized must have the same prefix, the string
     81     '_as_', and an arbitrary suffix.
     82 
     83     The value of the _params attribute may be either a dictionary or a list.
     84     The values in the dictionary and the elements of the list may either be
     85     single values, or a list.  If single values, they are turned into single
     86     element tuples.  However derived, the resulting sequence is passed via
     87     *args to the parameterized test function.
     88 
     89     In a _params dictionary, the keys become part of the name of the generated
     90     tests.  In a _params list, the values in the list are converted into a
     91     string by joining the string values of the elements of the tuple by '_' and
     92     converting any blanks into '_'s, and this become part of the name.
     93     The  full name of a generated test is a 'test_' prefix, the portion of the
     94     test function name after the  '_as_' separator, plus an '_', plus the name
     95     derived as explained above.
     96 
     97     For example, if we have:
     98 
     99         count_params = range(2)
    100 
    101         def count_as_foo_arg(self, foo):
    102             self.assertEqual(foo+1, myfunc(foo))
    103 
    104     we will get parameterized test methods named:
    105         test_foo_arg_0
    106         test_foo_arg_1
    107         test_foo_arg_2
    108 
    109     Or we could have:
    110 
    111         example_params = {'foo': ('bar', 1), 'bing': ('bang', 2)}
    112 
    113         def example_as_myfunc_input(self, name, count):
    114             self.assertEqual(name+str(count), myfunc(name, count))
    115 
    116     and get:
    117         test_myfunc_input_foo
    118         test_myfunc_input_bing
    119 
    120     Note: if and only if the generated test name is a valid identifier can it
    121     be used to select the test individually from the unittest command line.
    122 
    123     The values in the params dict can be a single value, a tuple, or a
    124     dict.  If a single value of a tuple, it is passed to the test function
    125     as positional arguments.  If a dict, it is a passed via **kw.
    126 
    127     """
    128     paramdicts = {}
    129     testers = collections.defaultdict(list)
    130     for name, attr in cls.__dict__.items():
    131         if name.endswith('_params'):
    132             if not hasattr(attr, 'keys'):
    133                 d = {}
    134                 for x in attr:
    135                     if not hasattr(x, '__iter__'):
    136                         x = (x,)
    137                     n = '_'.join(str(v) for v in x).replace(' ', '_')
    138                     d[n] = x
    139                 attr = d
    140             paramdicts[name[:-7] + '_as_'] = attr
    141         if '_as_' in name:
    142             testers[name.split('_as_')[0] + '_as_'].append(name)
    143     testfuncs = {}
    144     for name in paramdicts:
    145         if name not in testers:
    146             raise ValueError("No tester found for {}".format(name))
    147     for name in testers:
    148         if name not in paramdicts:
    149             raise ValueError("No params found for {}".format(name))
    150     for name, attr in cls.__dict__.items():
    151         for paramsname, paramsdict in paramdicts.items():
    152             if name.startswith(paramsname):
    153                 testnameroot = 'test_' + name[len(paramsname):]
    154                 for paramname, params in paramsdict.items():
    155                     if hasattr(params, 'keys'):
    156                         test = (lambda self, name=name, params=params:
    157                                     getattr(self, name)(**params))
    158                     else:
    159                         test = (lambda self, name=name, params=params:
    160                                         getattr(self, name)(*params))
    161                     testname = testnameroot + '_' + paramname
    162                     test.__name__ = testname
    163                     testfuncs[testname] = test
    164     for key, value in testfuncs.items():
    165         setattr(cls, key, value)
    166     return cls
    167