Home | History | Annotate | Download | only in fixes
      1 """Fixer for operator functions.
      2 
      3 operator.isCallable(obj)       -> callable(obj)
      4 operator.sequenceIncludes(obj) -> operator.contains(obj)
      5 operator.isSequenceType(obj)   -> isinstance(obj, collections.abc.Sequence)
      6 operator.isMappingType(obj)    -> isinstance(obj, collections.abc.Mapping)
      7 operator.isNumberType(obj)     -> isinstance(obj, numbers.Number)
      8 operator.repeat(obj, n)        -> operator.mul(obj, n)
      9 operator.irepeat(obj, n)       -> operator.imul(obj, n)
     10 """
     11 
     12 import collections.abc
     13 
     14 # Local imports
     15 from lib2to3 import fixer_base
     16 from lib2to3.fixer_util import Call, Name, String, touch_import
     17 
     18 
     19 def invocation(s):
     20     def dec(f):
     21         f.invocation = s
     22         return f
     23     return dec
     24 
     25 
     26 class FixOperator(fixer_base.BaseFix):
     27     BM_compatible = True
     28     order = "pre"
     29 
     30     methods = """
     31               method=('isCallable'|'sequenceIncludes'
     32                      |'isSequenceType'|'isMappingType'|'isNumberType'
     33                      |'repeat'|'irepeat')
     34               """
     35     obj = "'(' obj=any ')'"
     36     PATTERN = """
     37               power< module='operator'
     38                 trailer< '.' %(methods)s > trailer< %(obj)s > >
     39               |
     40               power< %(methods)s trailer< %(obj)s > >
     41               """ % dict(methods=methods, obj=obj)
     42 
     43     def transform(self, node, results):
     44         method = self._check_method(node, results)
     45         if method is not None:
     46             return method(node, results)
     47 
     48     @invocation("operator.contains(%s)")
     49     def _sequenceIncludes(self, node, results):
     50         return self._handle_rename(node, results, "contains")
     51 
     52     @invocation("callable(%s)")
     53     def _isCallable(self, node, results):
     54         obj = results["obj"]
     55         return Call(Name("callable"), [obj.clone()], prefix=node.prefix)
     56 
     57     @invocation("operator.mul(%s)")
     58     def _repeat(self, node, results):
     59         return self._handle_rename(node, results, "mul")
     60 
     61     @invocation("operator.imul(%s)")
     62     def _irepeat(self, node, results):
     63         return self._handle_rename(node, results, "imul")
     64 
     65     @invocation("isinstance(%s, collections.abc.Sequence)")
     66     def _isSequenceType(self, node, results):
     67         return self._handle_type2abc(node, results, "collections.abc", "Sequence")
     68 
     69     @invocation("isinstance(%s, collections.abc.Mapping)")
     70     def _isMappingType(self, node, results):
     71         return self._handle_type2abc(node, results, "collections.abc", "Mapping")
     72 
     73     @invocation("isinstance(%s, numbers.Number)")
     74     def _isNumberType(self, node, results):
     75         return self._handle_type2abc(node, results, "numbers", "Number")
     76 
     77     def _handle_rename(self, node, results, name):
     78         method = results["method"][0]
     79         method.value = name
     80         method.changed()
     81 
     82     def _handle_type2abc(self, node, results, module, abc):
     83         touch_import(None, module, node)
     84         obj = results["obj"]
     85         args = [obj.clone(), String(", " + ".".join([module, abc]))]
     86         return Call(Name("isinstance"), args, prefix=node.prefix)
     87 
     88     def _check_method(self, node, results):
     89         method = getattr(self, "_" + results["method"][0].value)
     90         if isinstance(method, collections.abc.Callable):
     91             if "module" in results:
     92                 return method
     93             else:
     94                 sub = (str(results["obj"]),)
     95                 invocation_str = method.invocation % sub
     96                 self.warning(node, "You should use '%s' here." % invocation_str)
     97         return None
     98