1 """ 2 Unit tests for refactor.py. 3 """ 4 5 from __future__ import with_statement 6 7 import sys 8 import os 9 import codecs 10 import operator 11 import StringIO 12 import tempfile 13 import shutil 14 import unittest 15 import warnings 16 17 from lib2to3 import refactor, pygram, fixer_base 18 from lib2to3.pgen2 import token 19 20 from . import support 21 22 23 TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data") 24 FIXER_DIR = os.path.join(TEST_DATA_DIR, "fixers") 25 26 sys.path.append(FIXER_DIR) 27 try: 28 _DEFAULT_FIXERS = refactor.get_fixers_from_package("myfixes") 29 finally: 30 sys.path.pop() 31 32 _2TO3_FIXERS = refactor.get_fixers_from_package("lib2to3.fixes") 33 34 class TestRefactoringTool(unittest.TestCase): 35 36 def setUp(self): 37 sys.path.append(FIXER_DIR) 38 39 def tearDown(self): 40 sys.path.pop() 41 42 def check_instances(self, instances, classes): 43 for inst, cls in zip(instances, classes): 44 if not isinstance(inst, cls): 45 self.fail("%s are not instances of %s" % instances, classes) 46 47 def rt(self, options=None, fixers=_DEFAULT_FIXERS, explicit=None): 48 return refactor.RefactoringTool(fixers, options, explicit) 49 50 def test_print_function_option(self): 51 rt = self.rt({"print_function" : True}) 52 self.assertTrue(rt.grammar is pygram.python_grammar_no_print_statement) 53 self.assertTrue(rt.driver.grammar is 54 pygram.python_grammar_no_print_statement) 55 56 def test_write_unchanged_files_option(self): 57 rt = self.rt() 58 self.assertFalse(rt.write_unchanged_files) 59 rt = self.rt({"write_unchanged_files" : True}) 60 self.assertTrue(rt.write_unchanged_files) 61 62 def test_fixer_loading_helpers(self): 63 contents = ["explicit", "first", "last", "parrot", "preorder"] 64 non_prefixed = refactor.get_all_fix_names("myfixes") 65 prefixed = refactor.get_all_fix_names("myfixes", False) 66 full_names = refactor.get_fixers_from_package("myfixes") 67 self.assertEqual(prefixed, ["fix_" + name for name in contents]) 68 self.assertEqual(non_prefixed, contents) 69 self.assertEqual(full_names, 70 ["myfixes.fix_" + name for name in contents]) 71 72 def test_detect_future_features(self): 73 run = refactor._detect_future_features 74 fs = frozenset 75 empty = fs() 76 self.assertEqual(run(""), empty) 77 self.assertEqual(run("from __future__ import print_function"), 78 fs(("print_function",))) 79 self.assertEqual(run("from __future__ import generators"), 80 fs(("generators",))) 81 self.assertEqual(run("from __future__ import generators, feature"), 82 fs(("generators", "feature"))) 83 inp = "from __future__ import generators, print_function" 84 self.assertEqual(run(inp), fs(("generators", "print_function"))) 85 inp ="from __future__ import print_function, generators" 86 self.assertEqual(run(inp), fs(("print_function", "generators"))) 87 inp = "from __future__ import (print_function,)" 88 self.assertEqual(run(inp), fs(("print_function",))) 89 inp = "from __future__ import (generators, print_function)" 90 self.assertEqual(run(inp), fs(("generators", "print_function"))) 91 inp = "from __future__ import (generators, nested_scopes)" 92 self.assertEqual(run(inp), fs(("generators", "nested_scopes"))) 93 inp = """from __future__ import generators 94 from __future__ import print_function""" 95 self.assertEqual(run(inp), fs(("generators", "print_function"))) 96 invalid = ("from", 97 "from 4", 98 "from x", 99 "from x 5", 100 "from x im", 101 "from x import", 102 "from x import 4", 103 ) 104 for inp in invalid: 105 self.assertEqual(run(inp), empty) 106 inp = "'docstring'\nfrom __future__ import print_function" 107 self.assertEqual(run(inp), fs(("print_function",))) 108 inp = "'docstring'\n'somng'\nfrom __future__ import print_function" 109 self.assertEqual(run(inp), empty) 110 inp = "# comment\nfrom __future__ import print_function" 111 self.assertEqual(run(inp), fs(("print_function",))) 112 inp = "# comment\n'doc'\nfrom __future__ import print_function" 113 self.assertEqual(run(inp), fs(("print_function",))) 114 inp = "class x: pass\nfrom __future__ import print_function" 115 self.assertEqual(run(inp), empty) 116 117 def test_get_headnode_dict(self): 118 class NoneFix(fixer_base.BaseFix): 119 pass 120 121 class FileInputFix(fixer_base.BaseFix): 122 PATTERN = "file_input< any * >" 123 124 class SimpleFix(fixer_base.BaseFix): 125 PATTERN = "'name'" 126 127 no_head = NoneFix({}, []) 128 with_head = FileInputFix({}, []) 129 simple = SimpleFix({}, []) 130 d = refactor._get_headnode_dict([no_head, with_head, simple]) 131 top_fixes = d.pop(pygram.python_symbols.file_input) 132 self.assertEqual(top_fixes, [with_head, no_head]) 133 name_fixes = d.pop(token.NAME) 134 self.assertEqual(name_fixes, [simple, no_head]) 135 for fixes in d.itervalues(): 136 self.assertEqual(fixes, [no_head]) 137 138 def test_fixer_loading(self): 139 from myfixes.fix_first import FixFirst 140 from myfixes.fix_last import FixLast 141 from myfixes.fix_parrot import FixParrot 142 from myfixes.fix_preorder import FixPreorder 143 144 rt = self.rt() 145 pre, post = rt.get_fixers() 146 147 self.check_instances(pre, [FixPreorder]) 148 self.check_instances(post, [FixFirst, FixParrot, FixLast]) 149 150 def test_naughty_fixers(self): 151 self.assertRaises(ImportError, self.rt, fixers=["not_here"]) 152 self.assertRaises(refactor.FixerError, self.rt, fixers=["no_fixer_cls"]) 153 self.assertRaises(refactor.FixerError, self.rt, fixers=["bad_order"]) 154 155 def test_refactor_string(self): 156 rt = self.rt() 157 input = "def parrot(): pass\n\n" 158 tree = rt.refactor_string(input, "<test>") 159 self.assertNotEqual(str(tree), input) 160 161 input = "def f(): pass\n\n" 162 tree = rt.refactor_string(input, "<test>") 163 self.assertEqual(str(tree), input) 164 165 def test_refactor_stdin(self): 166 167 class MyRT(refactor.RefactoringTool): 168 169 def print_output(self, old_text, new_text, filename, equal): 170 results.extend([old_text, new_text, filename, equal]) 171 172 results = [] 173 rt = MyRT(_DEFAULT_FIXERS) 174 save = sys.stdin 175 sys.stdin = StringIO.StringIO("def parrot(): pass\n\n") 176 try: 177 rt.refactor_stdin() 178 finally: 179 sys.stdin = save 180 expected = ["def parrot(): pass\n\n", 181 "def cheese(): pass\n\n", 182 "<stdin>", False] 183 self.assertEqual(results, expected) 184 185 def check_file_refactoring(self, test_file, fixers=_2TO3_FIXERS, 186 options=None, mock_log_debug=None, 187 actually_write=True): 188 tmpdir = tempfile.mkdtemp(prefix="2to3-test_refactor") 189 self.addCleanup(shutil.rmtree, tmpdir) 190 # make a copy of the tested file that we can write to 191 shutil.copy(test_file, tmpdir) 192 test_file = os.path.join(tmpdir, os.path.basename(test_file)) 193 os.chmod(test_file, 0o644) 194 195 def read_file(): 196 with open(test_file, "rb") as fp: 197 return fp.read() 198 199 old_contents = read_file() 200 rt = self.rt(fixers=fixers, options=options) 201 if mock_log_debug: 202 rt.log_debug = mock_log_debug 203 204 rt.refactor_file(test_file) 205 self.assertEqual(old_contents, read_file()) 206 207 if not actually_write: 208 return 209 rt.refactor_file(test_file, True) 210 new_contents = read_file() 211 self.assertNotEqual(old_contents, new_contents) 212 return new_contents 213 214 def test_refactor_file(self): 215 test_file = os.path.join(FIXER_DIR, "parrot_example.py") 216 self.check_file_refactoring(test_file, _DEFAULT_FIXERS) 217 218 def test_refactor_file_write_unchanged_file(self): 219 test_file = os.path.join(FIXER_DIR, "parrot_example.py") 220 debug_messages = [] 221 def recording_log_debug(msg, *args): 222 debug_messages.append(msg % args) 223 self.check_file_refactoring(test_file, fixers=(), 224 options={"write_unchanged_files": True}, 225 mock_log_debug=recording_log_debug, 226 actually_write=False) 227 # Testing that it logged this message when write=False was passed is 228 # sufficient to see that it did not bail early after "No changes". 229 message_regex = r"Not writing changes to .*%s%s" % ( 230 os.sep, os.path.basename(test_file)) 231 for message in debug_messages: 232 if "Not writing changes" in message: 233 self.assertRegexpMatches(message, message_regex) 234 break 235 else: 236 self.fail("%r not matched in %r" % (message_regex, debug_messages)) 237 238 def test_refactor_dir(self): 239 def check(structure, expected): 240 def mock_refactor_file(self, f, *args): 241 got.append(f) 242 save_func = refactor.RefactoringTool.refactor_file 243 refactor.RefactoringTool.refactor_file = mock_refactor_file 244 rt = self.rt() 245 got = [] 246 dir = tempfile.mkdtemp(prefix="2to3-test_refactor") 247 try: 248 os.mkdir(os.path.join(dir, "a_dir")) 249 for fn in structure: 250 open(os.path.join(dir, fn), "wb").close() 251 rt.refactor_dir(dir) 252 finally: 253 refactor.RefactoringTool.refactor_file = save_func 254 shutil.rmtree(dir) 255 self.assertEqual(got, 256 [os.path.join(dir, path) for path in expected]) 257 check([], []) 258 tree = ["nothing", 259 "hi.py", 260 ".dumb", 261 ".after.py", 262 "notpy.npy", 263 "sappy"] 264 expected = ["hi.py"] 265 check(tree, expected) 266 tree = ["hi.py", 267 os.path.join("a_dir", "stuff.py")] 268 check(tree, tree) 269 270 def test_file_encoding(self): 271 fn = os.path.join(TEST_DATA_DIR, "different_encoding.py") 272 self.check_file_refactoring(fn) 273 274 def test_bom(self): 275 fn = os.path.join(TEST_DATA_DIR, "bom.py") 276 data = self.check_file_refactoring(fn) 277 self.assertTrue(data.startswith(codecs.BOM_UTF8)) 278 279 def test_crlf_newlines(self): 280 old_sep = os.linesep 281 os.linesep = "\r\n" 282 try: 283 fn = os.path.join(TEST_DATA_DIR, "crlf.py") 284 fixes = refactor.get_fixers_from_package("lib2to3.fixes") 285 self.check_file_refactoring(fn, fixes) 286 finally: 287 os.linesep = old_sep 288 289 def test_refactor_docstring(self): 290 rt = self.rt() 291 292 doc = """ 293 >>> example() 294 42 295 """ 296 out = rt.refactor_docstring(doc, "<test>") 297 self.assertEqual(out, doc) 298 299 doc = """ 300 >>> def parrot(): 301 ... return 43 302 """ 303 out = rt.refactor_docstring(doc, "<test>") 304 self.assertNotEqual(out, doc) 305 306 def test_explicit(self): 307 from myfixes.fix_explicit import FixExplicit 308 309 rt = self.rt(fixers=["myfixes.fix_explicit"]) 310 self.assertEqual(len(rt.post_order), 0) 311 312 rt = self.rt(explicit=["myfixes.fix_explicit"]) 313 for fix in rt.post_order: 314 if isinstance(fix, FixExplicit): 315 break 316 else: 317 self.fail("explicit fixer not loaded") 318