Home | History | Annotate | Download | only in tests
      1 #!/usr/bin/env python3
      2 
      3 # pylint: disable=unused-import,import-error
      4 
      5 import sys
      6 
      7 
      8 try:
      9     from tempfile import TemporaryDirectory
     10 except ImportError:
     11     import shutil
     12     import tempfile
     13 
     14 
     15     class TemporaryDirectory(object):
     16         def __init__(self, suffix='', prefix='tmp', dir=None):
     17             # pylint: disable=redefined-builtin
     18             self.name = tempfile.mkdtemp(suffix, prefix, dir)
     19 
     20 
     21         def __del__(self):
     22             self.cleanup()
     23 
     24 
     25         def __enter__(self):
     26             return self.name
     27 
     28 
     29         def __exit__(self, exc, value, tb):
     30             self.cleanup()
     31 
     32 
     33         def cleanup(self):
     34             if self.name:
     35                 shutil.rmtree(self.name)
     36                 self.name = None
     37 
     38 
     39 if sys.version_info >= (3, 0):
     40     from os import makedirs
     41 else:
     42     import os
     43 
     44 
     45     def makedirs(path, exist_ok):
     46         if exist_ok and os.path.exists(path):
     47             return
     48         os.makedirs(path)
     49 
     50 
     51 if sys.version_info >= (3, 0):
     52     from io import StringIO
     53 else:
     54     from StringIO import StringIO
     55 
     56 
     57 try:
     58     from unittest.mock import patch
     59 except ImportError:
     60     import contextlib
     61 
     62 
     63     @contextlib.contextmanager
     64     def patch(target, mock):
     65         obj, attr = target.rsplit('.')
     66         obj = __import__(obj)
     67         original_value = getattr(obj, attr)
     68         setattr(obj, attr, mock)
     69         try:
     70             yield
     71         finally:
     72             setattr(obj, attr, original_value)
     73 
     74 
     75 if sys.version_info >= (3, 2):
     76     from unittest import TestCase
     77 else:
     78     import unittest
     79 
     80 
     81     class TestCase(unittest.TestCase):
     82         def assertRegex(self, text, expected_regex, msg=None):
     83             # pylint: disable=deprecated-method
     84             self.assertRegexpMatches(text, expected_regex, msg)
     85 
     86 
     87         def assertNotRegex(self, text, unexpected_regex, msg=None):
     88             # pylint: disable=deprecated-method
     89             self.assertNotRegexpMatches(text, unexpected_regex, msg)
     90