Home | History | Annotate | Download | only in scripts
      1 #!/usr/bin/env python
      2 # Copyright (C) 2013 Google Inc. All rights reserved.
      3 #
      4 # Redistribution and use in source and binary forms, with or without
      5 # modification, are permitted provided that the following conditions are
      6 # met:
      7 #
      8 #     * Redistributions of source code must retain the above copyright
      9 # notice, this list of conditions and the following disclaimer.
     10 #     * Redistributions in binary form must reproduce the above
     11 # copyright notice, this list of conditions and the following disclaimer
     12 # in the documentation and/or other materials provided with the
     13 # distribution.
     14 #     * Neither the name of Google Inc. nor the names of its
     15 # contributors may be used to endorse or promote products derived from
     16 # this software without specific prior written permission.
     17 #
     18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     19 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     20 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     21 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     22 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     23 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     24 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     25 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     26 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     27 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     29 
     30 import io
     31 import itertools
     32 import re
     33 import sys
     34 
     35 
     36 class BadInput(Exception):
     37     """Unsupported input has been found."""
     38 
     39 
     40 class SwitchCase(object):
     41     """Represents a CASE block."""
     42     def __init__(self, identifier, block):
     43         self.identifier = identifier
     44         self.block = block
     45 
     46 
     47 class Optimizer(object):
     48     """Generates optimized identifier matching code."""
     49     def __init__(self, output_file, array_variable, length_variable):
     50         self.output_file = output_file
     51         self.array_variable = array_variable
     52         self.length_variable = length_variable
     53 
     54     def inspect(self, cases):
     55         lengths = list(set([len(c.identifier) for c in cases]))
     56         lengths.sort()
     57 
     58         def response(length):
     59             self.inspect_array([c for c in cases if len(c.identifier) == length], range(length))
     60         self.write_selection(self.length_variable, lengths, str, response)
     61 
     62     def score(self, alternatives):
     63         return -sum([len(list(count)) ** 2 for _, count in itertools.groupby(sorted(alternatives))])
     64 
     65     def choose_selection_pos(self, cases, pending):
     66         candidates = [pos for pos in pending if all(alternative.isalpha() for alternative in [c.identifier[pos] for c in cases])]
     67         if not candidates:
     68             raise BadInput('Case-insensitive switching on non-alphabetic characters not yet implemented')
     69         return sorted(candidates, key=lambda pos: self.score([c.identifier[pos] for c in cases]))[0]
     70 
     71     def inspect_array(self, cases, pending):
     72         assert len(cases) >= 1
     73         if pending:
     74             common = [pos for pos in pending
     75                       if len(set([c.identifier[pos] for c in cases])) == 1]
     76             if common:
     77                 identifier = cases[0].identifier
     78                 for index in xrange(len(common)):
     79                     if index == 0:
     80                         self.output_file.write(u'if (LIKELY(')
     81                     else:
     82                         self.output_file.write(u' && ')
     83                     pos = common[index]
     84                     if identifier[pos].isalpha():
     85                         self.output_file.write("(%s[%d] | 0x20) == '%s'" %
     86                                                (self.array_variable, pos, identifier[pos]))
     87                     else:
     88                         self.output_file.write("%s[%d] == '%s'" %
     89                                                (self.array_variable, pos, identifier[pos]))
     90                 self.output_file.write(u')) {\n')
     91                 next_pending = list(set(pending) - set(common))
     92                 next_pending.sort()
     93                 self.inspect_array(cases, next_pending)
     94                 self.output_file.write(u'}\n')
     95             else:
     96                 pos = self.choose_selection_pos(cases, pending)
     97                 next_pending = filter(lambda p: p != pos, pending)
     98 
     99                 alternatives = list(set([c.identifier[pos] for c in cases]))
    100                 alternatives.sort()
    101 
    102                 def literal(alternative):
    103                     if isinstance(alternative, int):
    104                         return str(alternative)
    105                     else:
    106                         return "'%s'" % alternative
    107 
    108                 def response(alternative):
    109                     self.inspect_array([c for c in cases if c.identifier[pos] == alternative],
    110                                        next_pending)
    111 
    112                 expression = '(%s[%d] | 0x20)' % (self.array_variable, pos)
    113                 self.write_selection(expression, alternatives, literal, response)
    114         else:
    115             assert len(cases) == 1
    116             for block_line in cases[0].block:
    117                 self.output_file.write(block_line)
    118 
    119     def write_selection(self, expression, alternatives, literal, response):
    120         if len(alternatives) == 1:
    121             self.output_file.write(u'if (LIKELY(%s == %s)) {\n' % (expression, literal(alternatives[0])))
    122             response(alternatives[0])
    123             self.output_file.write(u'}\n')
    124         elif len(alternatives) == 2:
    125             self.output_file.write(u'if (%s == %s) {\n' % (expression, literal(alternatives[0])))
    126             response(alternatives[0])
    127             self.output_file.write(u'} else if (LIKELY(%s == %s)) {\n' % (expression, literal(alternatives[1])))
    128             response(alternatives[1])
    129             self.output_file.write(u'}\n')
    130         else:
    131             self.output_file.write('switch (%s) {\n' % expression)
    132             for alternative in alternatives:
    133                 self.output_file.write(u'case %s: {\n' % literal(alternative))
    134                 response(alternative)
    135                 self.output_file.write(u'} break;\n')
    136             self.output_file.write(u'}\n')
    137 
    138 
    139 class LineProcessor(object):
    140     def process_line(self, line):
    141         pass
    142 
    143 
    144 class MainLineProcessor(LineProcessor):
    145     """Processes the contents of an input file."""
    146     SWITCH_PATTERN = re.compile(r'\s*SWITCH\s*\((\w*),\s*(\w*)\) \{$')
    147 
    148     def __init__(self, output_file):
    149         self.output_file = output_file
    150 
    151     def process_line(self, line):
    152         match_switch = MainLineProcessor.SWITCH_PATTERN.match(line)
    153         if match_switch:
    154             array_variable = match_switch.group(1)
    155             length_variable = match_switch.group(2)
    156             return SwitchLineProcessor(self, self.output_file, array_variable, length_variable)
    157         else:
    158             self.output_file.write(line)
    159             return self
    160 
    161 
    162 class SwitchLineProcessor(LineProcessor):
    163     """Processes the contents of a SWITCH block."""
    164     CASE_PATTERN = re.compile(r'\s*CASE\s*\(\"([a-z0-9_\-\(]*)\"\) \{$')
    165     CLOSE_BRACE_PATTERN = re.compile(r'\s*\}$')
    166     EMPTY_PATTERN = re.compile(r'\s*$')
    167 
    168     def __init__(self, parent, output_file, array_variable, length_variable):
    169         self.parent = parent
    170         self.output_file = output_file
    171         self.array_variable = array_variable
    172         self.length_variable = length_variable
    173         self.cases = []
    174 
    175     def process_line(self, line):
    176         match_case = SwitchLineProcessor.CASE_PATTERN.match(line)
    177         match_close_brace = SwitchLineProcessor.CLOSE_BRACE_PATTERN.match(line)
    178         match_empty = SwitchLineProcessor.EMPTY_PATTERN.match(line)
    179         if match_case:
    180             identifier = match_case.group(1)
    181             return CaseLineProcessor(self, self.output_file, identifier)
    182         elif match_close_brace:
    183             Optimizer(self.output_file, self.array_variable, self.length_variable).inspect(self.cases)
    184             return self.parent
    185         elif match_empty:
    186             return self
    187         else:
    188             raise BadInput('Invalid line within SWITCH: %s' % line)
    189 
    190     def add_case(self, latest_case):
    191         if latest_case.identifier in [c.identifier for c in self.cases]:
    192             raise BadInput('Repeated case: %s' % latest_case.identifier)
    193         self.cases.append(latest_case)
    194 
    195 
    196 class CaseLineProcessor(LineProcessor):
    197     """Processes the contents of a CASE block."""
    198     CLOSE_BRACE_PATTERN = re.compile(r'\s*\}$')
    199     BREAK_PATTERN = re.compile(r'break;')
    200 
    201     def __init__(self, parent, output_file, identifier):
    202         self.parent = parent
    203         self.output_file = output_file
    204         self.identifier = identifier
    205         self.block = []
    206 
    207     def process_line(self, line):
    208         match_close_brace = CaseLineProcessor.CLOSE_BRACE_PATTERN.match(line)
    209         match_break = CaseLineProcessor.BREAK_PATTERN.search(line)
    210         if match_close_brace:
    211             self.parent.add_case(SwitchCase(self.identifier, self.block))
    212             return self.parent
    213         elif match_break:
    214             raise BadInput('break within CASE not supported: %s' % line)
    215         else:
    216             self.block.append(line)
    217             return self
    218 
    219 
    220 def process_file(input_name, output_name):
    221     """Transforms input file into legal C++ source code."""
    222     with io.open(input_name, 'r', -1, 'utf-8') as input_file:
    223         with io.open(output_name, 'w', -1, 'utf-8') as output_file:
    224             processor = MainLineProcessor(output_file)
    225             input_lines = input_file.readlines()
    226             for line in input_lines:
    227                 processor = processor.process_line(line)
    228 
    229 
    230 if __name__ == '__main__':
    231         process_file(sys.argv[1], sys.argv[2])
    232