Home | History | Annotate | Download | only in fixes
      1 """Fixer for operator functions.
      2 
      3 operator.isCallable(obj)       -> hasattr(obj, '__call__')
      4 operator.sequenceIncludes(obj) -> operator.contains(obj)
      5 operator.isSequenceType(obj)   -> isinstance(obj, collections.Sequence)
      6 operator.isMappingType(obj)    -> isinstance(obj, collections.Mapping)
      7 operator.isNumberType(obj)     -> isinstance(obj, numbers.Number)
      8 operator.repeat(obj, n)        -> operator.mul(obj, n)
      9 operator.irepeat(obj, n)       -> operator.imul(obj, n)
     10 """
     11 
     12 # Local imports
     13 from lib2to3 import fixer_base
     14 from lib2to3.fixer_util import Call, Name, String, touch_import
     15 
     16 
     17 def invocation(s):
     18     def dec(f):
     19         f.invocation = s
     20         return f
     21     return dec
     22 
     23 
     24 class FixOperator(fixer_base.BaseFix):
     25     BM_compatible = True
     26     order = "pre"
     27 
     28     methods = """
     29               method=('isCallable'|'sequenceIncludes'
     30                      |'isSequenceType'|'isMappingType'|'isNumberType'
     31                      |'repeat'|'irepeat')
     32               """
     33     obj = "'(' obj=any ')'"
     34     PATTERN = """
     35               power< module='operator'
     36                 trailer< '.' %(methods)s > trailer< %(obj)s > >
     37               |
     38               power< %(methods)s trailer< %(obj)s > >
     39               """ % dict(methods=methods, obj=obj)
     40 
     41     def transform(self, node, results):
     42         method = self._check_method(node, results)
     43         if method is not None:
     44             return method(node, results)
     45 
     46     @invocation("operator.contains(%s)")
     47     def _sequenceIncludes(self, node, results):
     48         return self._handle_rename(node, results, u"contains")
     49 
     50     @invocation("hasattr(%s, '__call__')")
     51     def _isCallable(self, node, results):
     52         obj = results["obj"]
     53         args = [obj.clone(), String(u", "), String(u"'__call__'")]
     54         return Call(Name(u"hasattr"), args, prefix=node.prefix)
     55 
     56     @invocation("operator.mul(%s)")
     57     def _repeat(self, node, results):
     58         return self._handle_rename(node, results, u"mul")
     59 
     60     @invocation("operator.imul(%s)")
     61     def _irepeat(self, node, results):
     62         return self._handle_rename(node, results, u"imul")
     63 
     64     @invocation("isinstance(%s, collections.Sequence)")
     65     def _isSequenceType(self, node, results):
     66         return self._handle_type2abc(node, results, u"collections", u"Sequence")
     67 
     68     @invocation("isinstance(%s, collections.Mapping)")
     69     def _isMappingType(self, node, results):
     70         return self._handle_type2abc(node, results, u"collections", u"Mapping")
     71 
     72     @invocation("isinstance(%s, numbers.Number)")
     73     def _isNumberType(self, node, results):
     74         return self._handle_type2abc(node, results, u"numbers", u"Number")
     75 
     76     def _handle_rename(self, node, results, name):
     77         method = results["method"][0]
     78         method.value = name
     79         method.changed()
     80 
     81     def _handle_type2abc(self, node, results, module, abc):
     82         touch_import(None, module, node)
     83         obj = results["obj"]
     84         args = [obj.clone(), String(u", " + u".".join([module, abc]))]
     85         return Call(Name(u"isinstance"), args, prefix=node.prefix)
     86 
     87     def _check_method(self, node, results):
     88         method = getattr(self, "_" + results["method"][0].value.encode("ascii"))
     89         if callable(method):
     90             if "module" in results:
     91                 return method
     92             else:
     93                 sub = (unicode(results["obj"]),)
     94                 invocation_str = unicode(method.invocation) % sub
     95                 self.warning(node, u"You should use '%s' here." % invocation_str)
     96         return None
     97