Home | History | Annotate | Download | only in nir
      1 #! /usr/bin/env python
      2 #
      3 # Copyright (C) 2014 Intel Corporation
      4 #
      5 # Permission is hereby granted, free of charge, to any person obtaining a
      6 # copy of this software and associated documentation files (the "Software"),
      7 # to deal in the Software without restriction, including without limitation
      8 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
      9 # and/or sell copies of the Software, and to permit persons to whom the
     10 # Software is furnished to do so, subject to the following conditions:
     11 #
     12 # The above copyright notice and this permission notice (including the next
     13 # paragraph) shall be included in all copies or substantial portions of the
     14 # Software.
     15 #
     16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
     17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
     18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
     19 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
     20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
     21 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
     22 # IN THE SOFTWARE.
     23 #
     24 # Authors:
     25 #    Jason Ekstrand (jason@jlekstrand.net)
     26 
     27 from __future__ import print_function
     28 import ast
     29 import itertools
     30 import struct
     31 import sys
     32 import mako.template
     33 import re
     34 import traceback
     35 
     36 from nir_opcodes import opcodes
     37 
     38 _type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")
     39 
     40 def type_bits(type_str):
     41    m = _type_re.match(type_str)
     42    assert m.group('type')
     43 
     44    if m.group('bits') is None:
     45       return 0
     46    else:
     47       return int(m.group('bits'))
     48 
     49 # Represents a set of variables, each with a unique id
     50 class VarSet(object):
     51    def __init__(self):
     52       self.names = {}
     53       self.ids = itertools.count()
     54       self.immutable = False;
     55 
     56    def __getitem__(self, name):
     57       if name not in self.names:
     58          assert not self.immutable, "Unknown replacement variable: " + name
     59          self.names[name] = self.ids.next()
     60 
     61       return self.names[name]
     62 
     63    def lock(self):
     64       self.immutable = True
     65 
     66 class Value(object):
     67    @staticmethod
     68    def create(val, name_base, varset):
     69       if isinstance(val, tuple):
     70          return Expression(val, name_base, varset)
     71       elif isinstance(val, Expression):
     72          return val
     73       elif isinstance(val, (str, unicode)):
     74          return Variable(val, name_base, varset)
     75       elif isinstance(val, (bool, int, long, float)):
     76          return Constant(val, name_base)
     77 
     78    __template = mako.template.Template("""
     79 #include "compiler/nir/nir_search_helpers.h"
     80 static const ${val.c_type} ${val.name} = {
     81    { ${val.type_enum}, ${val.bit_size} },
     82 % if isinstance(val, Constant):
     83    ${val.type()}, { ${hex(val)} /* ${val.value} */ },
     84 % elif isinstance(val, Variable):
     85    ${val.index}, /* ${val.var_name} */
     86    ${'true' if val.is_constant else 'false'},
     87    ${val.type() or 'nir_type_invalid' },
     88    ${val.cond if val.cond else 'NULL'},
     89 % elif isinstance(val, Expression):
     90    ${'true' if val.inexact else 'false'},
     91    nir_op_${val.opcode},
     92    { ${', '.join(src.c_ptr for src in val.sources)} },
     93    ${val.cond if val.cond else 'NULL'},
     94 % endif
     95 };""")
     96 
     97    def __init__(self, name, type_str):
     98       self.name = name
     99       self.type_str = type_str
    100 
    101    @property
    102    def type_enum(self):
    103       return "nir_search_value_" + self.type_str
    104 
    105    @property
    106    def c_type(self):
    107       return "nir_search_" + self.type_str
    108 
    109    @property
    110    def c_ptr(self):
    111       return "&{0}.value".format(self.name)
    112 
    113    def render(self):
    114       return self.__template.render(val=self,
    115                                     Constant=Constant,
    116                                     Variable=Variable,
    117                                     Expression=Expression)
    118 
    119 _constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")
    120 
    121 class Constant(Value):
    122    def __init__(self, val, name):
    123       Value.__init__(self, name, "constant")
    124 
    125       if isinstance(val, (str)):
    126          m = _constant_re.match(val)
    127          self.value = ast.literal_eval(m.group('value'))
    128          self.bit_size = int(m.group('bits')) if m.group('bits') else 0
    129       else:
    130          self.value = val
    131          self.bit_size = 0
    132 
    133       if isinstance(self.value, bool):
    134          assert self.bit_size == 0 or self.bit_size == 32
    135          self.bit_size = 32
    136 
    137    def __hex__(self):
    138       if isinstance(self.value, (bool)):
    139          return 'NIR_TRUE' if self.value else 'NIR_FALSE'
    140       if isinstance(self.value, (int, long)):
    141          return hex(self.value)
    142       elif isinstance(self.value, float):
    143          return hex(struct.unpack('Q', struct.pack('d', self.value))[0])
    144       else:
    145          assert False
    146 
    147    def type(self):
    148       if isinstance(self.value, (bool)):
    149          return "nir_type_bool32"
    150       elif isinstance(self.value, (int, long)):
    151          return "nir_type_int"
    152       elif isinstance(self.value, float):
    153          return "nir_type_float"
    154 
    155 _var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
    156                           r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
    157                           r"(?P<cond>\([^\)]+\))?")
    158 
    159 class Variable(Value):
    160    def __init__(self, val, name, varset):
    161       Value.__init__(self, name, "variable")
    162 
    163       m = _var_name_re.match(val)
    164       assert m and m.group('name') is not None
    165 
    166       self.var_name = m.group('name')
    167       self.is_constant = m.group('const') is not None
    168       self.cond = m.group('cond')
    169       self.required_type = m.group('type')
    170       self.bit_size = int(m.group('bits')) if m.group('bits') else 0
    171 
    172       if self.required_type == 'bool':
    173          assert self.bit_size == 0 or self.bit_size == 32
    174          self.bit_size = 32
    175 
    176       if self.required_type is not None:
    177          assert self.required_type in ('float', 'bool', 'int', 'uint')
    178 
    179       self.index = varset[self.var_name]
    180 
    181    def type(self):
    182       if self.required_type == 'bool':
    183          return "nir_type_bool32"
    184       elif self.required_type in ('int', 'uint'):
    185          return "nir_type_int"
    186       elif self.required_type == 'float':
    187          return "nir_type_float"
    188 
    189 _opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
    190                         r"(?P<cond>\([^\)]+\))?")
    191 
    192 class Expression(Value):
    193    def __init__(self, expr, name_base, varset):
    194       Value.__init__(self, name_base, "expression")
    195       assert isinstance(expr, tuple)
    196 
    197       m = _opcode_re.match(expr[0])
    198       assert m and m.group('opcode') is not None
    199 
    200       self.opcode = m.group('opcode')
    201       self.bit_size = int(m.group('bits')) if m.group('bits') else 0
    202       self.inexact = m.group('inexact') is not None
    203       self.cond = m.group('cond')
    204       self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
    205                        for (i, src) in enumerate(expr[1:]) ]
    206 
    207    def render(self):
    208       srcs = "\n".join(src.render() for src in self.sources)
    209       return srcs + super(Expression, self).render()
    210 
    211 class IntEquivalenceRelation(object):
    212    """A class representing an equivalence relation on integers.
    213 
    214    Each integer has a canonical form which is the maximum integer to which it
    215    is equivalent.  Two integers are equivalent precisely when they have the
    216    same canonical form.
    217 
    218    The convention of maximum is explicitly chosen to make using it in
    219    BitSizeValidator easier because it means that an actual bit_size (if any)
    220    will always be the canonical form.
    221    """
    222    def __init__(self):
    223       self._remap = {}
    224 
    225    def get_canonical(self, x):
    226       """Get the canonical integer corresponding to x."""
    227       if x in self._remap:
    228          return self.get_canonical(self._remap[x])
    229       else:
    230          return x
    231 
    232    def add_equiv(self, a, b):
    233       """Add an equivalence and return the canonical form."""
    234       c = max(self.get_canonical(a), self.get_canonical(b))
    235       if a != c:
    236          assert a < c
    237          self._remap[a] = c
    238 
    239       if b != c:
    240          assert b < c
    241          self._remap[b] = c
    242 
    243       return c
    244 
class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.  Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to and a 32-bit value with a 64-bit value which would be invalid.

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   The basic operation of the validator is very similar to the bitsize_tree in
   nir_search only a little more subtle.  Instead of simply tracking bit
   sizes, it tracks "bit classes" where each class is represented by an
   integer.  A value of 0 means we don't know anything yet, positive values
   are actual bit-sizes, and negative values are used to track equivalence
   classes of sizes that must be the same but have yet to receive an actual
   size.  The first stage uses the bitsize_tree algorithm to assign bit
   classes to each variable.  If it ever comes across an inconsistency, it
   assert-fails.  Then the second stage uses that information to prove that
   the resulting expression can always validly be constructed.
   """

   def __init__(self, varset):
      """Create a validator with one bit-class slot per name in varset."""
      # Number of generic (negative) bit classes handed out so far.
      self._num_classes = 0
      # Bit class of each variable, indexed by variable id; 0 means unknown.
      self._var_classes = [0] * len(varset.names)
      # Records which bit classes have been found to be the same.
      self._class_relation = IntEquivalenceRelation()

   def validate(self, search, replace):
      """Check that replace is well-defined for everything search matches.

      The search tree is walked first (up, then down) to assign a bit class
      to every variable.  The replace tree is then walked (up, then down)
      and checked against those classes.  Any inconsistency fails an
      assertion.
      """
      dst_class = self._propagate_bit_size_up(search)
      if dst_class == 0:
         # The search expression has no known size; use a generic class.
         dst_class = self._new_class()
      self._propagate_bit_class_down(search, dst_class)

      validate_dst_class = self._validate_bit_class_up(replace)
      assert validate_dst_class == 0 or validate_dst_class == dst_class
      self._validate_bit_class_down(replace, dst_class)

   def _new_class(self):
      """Allocate and return a new generic bit class (a negative integer)."""
      self._num_classes += 1
      return -self._num_classes

   def _set_var_bit_class(self, var_id, bit_class):
      """Record that the variable var_id belongs to bit_class.

      If the variable already has a class, the two classes are merged.  The
      assertion fails when the existing canonical class is an actual size
      (non-negative) that differs from bit_class.
      """
      assert bit_class != 0
      var_class = self._var_classes[var_id]
      if var_class == 0:
         self._var_classes[var_id] = bit_class
      else:
         canon_class = self._class_relation.get_canonical(var_class)
         assert canon_class < 0 or canon_class == bit_class
         var_class = self._class_relation.add_equiv(var_class, bit_class)
         self._var_classes[var_id] = var_class

   def _get_var_bit_class(self, var_id):
      """Return the canonical bit class of the variable var_id."""
      return self._class_relation.get_canonical(self._var_classes[var_id])

   def _propagate_bit_size_up(self, val):
      """Compute explicit bit sizes bottom-up through a search tree.

      Constants and variables report their own bit_size.  For an expression,
      every source with a known size is checked against the opcode's input
      type, and the size shared by the unsized inputs is stored in
      val.common_size.  Returns the size of the opcode's output type if it
      has one, otherwise val.common_size (0 when nothing is known).
      """
      if isinstance(val, (Constant, Variable)):
         return val.bit_size

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         val.common_size = 0
         for i in range(nir_op.num_inputs):
            src_bits = self._propagate_bit_size_up(val.sources[i])
            if src_bits == 0:
               continue

            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               # Sized input: the source must have exactly that size.
               assert src_bits == src_type_bits
            else:
               # Unsized input: all unsized inputs must agree with each other.
               assert val.common_size == 0 or src_bits == val.common_size
               val.common_size = src_bits

         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert val.bit_size == 0 or val.bit_size == dst_type_bits
            return dst_type_bits
         else:
            if val.common_size != 0:
               assert val.bit_size == 0 or val.bit_size == val.common_size
            else:
               # Nothing came from the sources; use the explicit size, if any.
               val.common_size = val.bit_size
            return val.common_size

   def _propagate_bit_class_down(self, val, bit_class):
      """Push bit classes top-down through a search tree.

      bit_class is the class required of val's result.  Variables are
      assigned that class.  Expressions pass the size of each sized input
      type, or the class shared by the unsized inputs, on to their sources.
      """
      if isinstance(val, Constant):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Variable):
         assert val.bit_size == 0 or val.bit_size == bit_class
         self._set_var_bit_class(val.index, bit_class)

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert bit_class == 0 or bit_class == dst_type_bits
         else:
            # Unsized output: it shares its class with the unsized inputs.
            assert val.common_size == 0 or val.common_size == bit_class
            val.common_size = bit_class

         if val.common_size:
            common_class = val.common_size
         elif nir_op.num_inputs:
            # If we got here then we have no idea what the actual size is.
            # Instead, we use a generic class
            common_class = self._new_class()

         for i in range(nir_op.num_inputs):
            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               self._propagate_bit_class_down(val.sources[i], src_type_bits)
            else:
               self._propagate_bit_class_down(val.sources[i], common_class)

   def _validate_bit_class_up(self, val):
      """Compute bit classes bottom-up through a replace tree, checking them.

      Constants report their bit_size and variables report the class they
      were assigned from the search tree.  For an expression, the class
      shared by the unsized inputs is stored in val.common_class.  Returns
      the size of the opcode's output type if it has one, otherwise
      val.common_class (0 when nothing is known).
      """
      if isinstance(val, Constant):
         return val.bit_size

      elif isinstance(val, Variable):
         var_class = self._get_var_bit_class(val.index)
         # By the time we get to validation, every variable should have a class
         assert var_class != 0

         # If we have an explicit size provided by the user, the variable
         # *must* exactly match the search.  It cannot be implicitly sized
         # because otherwise we could end up with a conflict at runtime.
         assert val.bit_size == 0 or val.bit_size == var_class

         return var_class

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         val.common_class = 0
         for i in range(nir_op.num_inputs):
            src_class = self._validate_bit_class_up(val.sources[i])
            if src_class == 0:
               continue

            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               assert src_class == src_type_bits
            else:
               assert val.common_class == 0 or src_class == val.common_class
               val.common_class = src_class

         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert val.bit_size == 0 or val.bit_size == dst_type_bits
            return dst_type_bits
         else:
            if val.common_class != 0:
               assert val.bit_size == 0 or val.bit_size == val.common_class
            else:
               val.common_class = val.bit_size
            return val.common_class

   def _validate_bit_class_down(self, val, bit_class):
      """Check top-down that every node of a replace tree has a bit class.

      bit_class is the class required of val's result and must not be 0.
      Expressions pass the size of each sized input type, or their
      common_class, on to their sources.
      """
      # At this point, everything *must* have a bit class.  Otherwise, we have
      # a value we don't know how to define.
      assert bit_class != 0

      if isinstance(val, Constant):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Variable):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert bit_class == dst_type_bits
         else:
            assert val.common_class == 0 or val.common_class == bit_class
            val.common_class = bit_class

         for i in range(nir_op.num_inputs):
            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               self._validate_bit_class_down(val.sources[i], src_type_bits)
            else:
               self._validate_bit_class_down(val.sources[i], val.common_class)
    465 
    466 _optimization_ids = itertools.count()
    467 
    468 condition_list = ['true']
    469 
    470 class SearchAndReplace(object):
    471    def __init__(self, transform):
    472       self.id = _optimization_ids.next()
    473 
    474       search = transform[0]
    475       replace = transform[1]
    476       if len(transform) > 2:
    477          self.condition = transform[2]
    478       else:
    479          self.condition = 'true'
    480 
    481       if self.condition not in condition_list:
    482          condition_list.append(self.condition)
    483       self.condition_index = condition_list.index(self.condition)
    484 
    485       varset = VarSet()
    486       if isinstance(search, Expression):
    487          self.search = search
    488       else:
    489          self.search = Expression(search, "search{0}".format(self.id), varset)
    490 
    491       varset.lock()
    492 
    493       if isinstance(replace, Value):
    494          self.replace = replace
    495       else:
    496          self.replace = Value.create(replace, "replace{0}".format(self.id), varset)
    497 
    498       BitSizeValidator(varset).validate(self.search, self.replace)
    499 
    500 _algebraic_pass_template = mako.template.Template("""
    501 #include "nir.h"
    502 #include "nir_search.h"
    503 
    504 #ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
    505 #define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
    506 
    507 struct transform {
    508    const nir_search_expression *search;
    509    const nir_search_value *replace;
    510    unsigned condition_offset;
    511 };
    512 
    513 #endif
    514 
    515 % for (opcode, xform_list) in xform_dict.iteritems():
    516 % for xform in xform_list:
    517    ${xform.search.render()}
    518    ${xform.replace.render()}
    519 % endfor
    520 
    521 static const struct transform ${pass_name}_${opcode}_xforms[] = {
    522 % for xform in xform_list:
    523    { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
    524 % endfor
    525 };
    526 % endfor
    527 
    528 static bool
    529 ${pass_name}_block(nir_block *block, const bool *condition_flags,
    530                    void *mem_ctx)
    531 {
    532    bool progress = false;
    533 
    534    nir_foreach_instr_reverse_safe(instr, block) {
    535       if (instr->type != nir_instr_type_alu)
    536          continue;
    537 
    538       nir_alu_instr *alu = nir_instr_as_alu(instr);
    539       if (!alu->dest.dest.is_ssa)
    540          continue;
    541 
    542       switch (alu->op) {
    543       % for opcode in xform_dict.keys():
    544       case nir_op_${opcode}:
    545          for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
    546             const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
    547             if (condition_flags[xform->condition_offset] &&
    548                 nir_replace_instr(alu, xform->search, xform->replace,
    549                                   mem_ctx)) {
    550                progress = true;
    551                break;
    552             }
    553          }
    554          break;
    555       % endfor
    556       default:
    557          break;
    558       }
    559    }
    560 
    561    return progress;
    562 }
    563 
    564 static bool
    565 ${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
    566 {
    567    void *mem_ctx = ralloc_parent(impl);
    568    bool progress = false;
    569 
    570    nir_foreach_block_reverse(block, impl) {
    571       progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
    572    }
    573 
    574    if (progress)
    575       nir_metadata_preserve(impl, nir_metadata_block_index |
    576                                   nir_metadata_dominance);
    577 
    578    return progress;
    579 }
    580 
    581 
    582 bool
    583 ${pass_name}(nir_shader *shader)
    584 {
    585    bool progress = false;
    586    bool condition_flags[${len(condition_list)}];
    587    const nir_shader_compiler_options *options = shader->options;
    588    (void) options;
    589 
    590    % for index, condition in enumerate(condition_list):
    591    condition_flags[${index}] = ${condition};
    592    % endfor
    593 
    594    nir_foreach_function(function, shader) {
    595       if (function->impl)
    596          progress |= ${pass_name}_impl(function->impl, condition_flags);
    597    }
    598 
    599    return progress;
    600 }
    601 """)
    602 
    603 class AlgebraicPass(object):
    604    def __init__(self, pass_name, transforms):
    605       self.xform_dict = {}
    606       self.pass_name = pass_name
    607 
    608       error = False
    609 
    610       for xform in transforms:
    611          if not isinstance(xform, SearchAndReplace):
    612             try:
    613                xform = SearchAndReplace(xform)
    614             except:
    615                print("Failed to parse transformation:", file=sys.stderr)
    616                print("  " + str(xform), file=sys.stderr)
    617                traceback.print_exc(file=sys.stderr)
    618                print('', file=sys.stderr)
    619                error = True
    620                continue
    621 
    622          if xform.search.opcode not in self.xform_dict:
    623             self.xform_dict[xform.search.opcode] = []
    624 
    625          self.xform_dict[xform.search.opcode].append(xform)
    626 
    627       if error:
    628          sys.exit(1)
    629 
    630    def render(self):
    631       return _algebraic_pass_template.render(pass_name=self.pass_name,
    632                                              xform_dict=self.xform_dict,
    633                                              condition_list=condition_list)
    634