#! /usr/bin/env python
#
# Copyright (C) 2014 Intel Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# Authors:
#    Jason Ekstrand (jason (at] jlekstrand.net)

"""Generator for NIR algebraic search-and-replace passes.

A transformation is a ``(search, replace[, condition])`` tuple.  Each one is
parsed into a tree of Value objects, its bit sizes are validated, and the
result is rendered to C source through Mako templates by AlgebraicPass.
"""

from __future__ import print_function
import ast
import itertools
import struct
import sys
import mako.template
import re
import traceback

from nir_opcodes import opcodes

# Python 2 has distinct ``unicode`` and ``long`` types which do not exist on
# Python 3.  Build the isinstance() tuples once so the rest of the module runs
# unchanged on either interpreter.
try:
    _string_types = (str, unicode)
    _integer_types = (int, long)
except NameError:
    _string_types = (str,)
    _integer_types = (int,)

_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the bit size encoded in a NIR type string, or 0 if unsized.

    For example ``"float32"`` gives 32 and ``"int"`` gives 0.  Asserts that
    the string starts with one of the base types int/uint/bool/float.
    """
    m = _type_re.match(type_str)
    assert m.group('type')

    if m.group('bits') is None:
        return 0
    else:
        return int(m.group('bits'))


# Represents a set of variables, each with a unique id
class VarSet(object):
    """Maps variable names to unique, sequentially allocated integer ids."""

    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        """Return the id for ``name``, allocating a new one on first use.

        Once lock() has been called, looking up an unknown name trips an
        assertion instead of allocating.
        """
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            # next() instead of the .next() method: works on Python 2.6+
            # and Python 3.
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        """Forbid the creation of any further variables."""
        self.immutable = True


class Value(object):
    """Base class for the nodes of a search or replace tree.

    Attributes set here:
      name     -- identifier of the C object emitted for this value
      type_str -- "constant", "variable" or "expression"
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the Value subclass matching the Python type of ``val``.

        tuple       -> Expression
        Expression  -> returned unchanged
        string      -> Variable
        bool/number -> Constant

        Raises TypeError for anything else, so that a malformed
        transformation is reported where it is written rather than failing
        later with an unrelated error on a None value.
        """
        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, _string_types):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, float) + _integer_types):
            return Constant(val, name_base)

        raise TypeError("Unsupported value in transformation: {0!r} "
                        "(type {1})".format(val, type(val).__name__))

    # C definition of a single nir_search_* object.  Which fields are emitted
    # depends on the concrete Value subclass.
    __template = mako.template.Template("""
#include "compiler/nir/nir_search_helpers.h"
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   nir_op_${val.opcode},
   { ${', '.join(src.c_ptr for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

    def __init__(self, name, type_str):
        self.name = name
        self.type_str = type_str

    @property
    def type_enum(self):
        """Name of the nir_search_value_* enumerant for this value."""
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        """Name of the C struct type used to declare this value."""
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        """C expression pointing at the embedded ``value`` member."""
        return "&{0}.value".format(self.name)

    def render(self):
        """Return the C source defining this value."""
        return self.__template.render(val=self,
                                      Constant=Constant,
                                      Variable=Variable,
                                      Expression=Expression)


_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal constant in a search or replace tree.

    ``val`` is either a Python bool/int/float or a string of the form
    ``"<literal>[@<bits>]"``.  ``bit_size`` is 0 when no size was given;
    booleans are always forced to 32 bits.
    """

    def __init__(self, val, name):
        Value.__init__(self, name, "constant")

        if isinstance(val, _string_types):
            m = _constant_re.match(val)
            assert m, "Malformed constant: " + val
            self.value = ast.literal_eval(m.group('value'))
            self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        else:
            self.value = val
            self.bit_size = 0

        if isinstance(self.value, bool):
            assert self.bit_size == 0 or self.bit_size == 32
            self.bit_size = 32

    def hex(self):
        """Return the C initializer text for this constant.

        Booleans become NIR_TRUE/NIR_FALSE, integers their hex literal and
        floats the hex of their bit pattern when packed as a double.
        """
        # bool must be tested first: it is a subclass of int.
        if isinstance(self.value, bool):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, _integer_types):
            return hex(self.value)
        elif isinstance(self.value, float):
            return hex(struct.unpack('Q', struct.pack('d', self.value))[0])
        else:
            assert False

    # Keeps the Python 2 builtin hex(constant) working for existing callers.
    __hex__ = hex

    def type(self):
        """Return the nir_type_* name for the constant's Python type."""
        if isinstance(self.value, bool):
            return "nir_type_bool32"
        elif isinstance(self.value, _integer_types):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"


_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")


class Variable(Value):
    """A named variable, parsed from ``[#]name[@[type][bits]][(cond)]``.

    Attributes:
      var_name      -- the bare variable name
      is_constant   -- True when the name was prefixed with ``#``
      cond          -- the parenthesised suffix (parentheses included) or None
      required_type -- 'int', 'uint', 'bool', 'float' or None
      bit_size      -- explicit bit size, 0 if none; bool is forced to 32
      index         -- id of the variable within ``varset``
    """

    def __init__(self, val, name, varset):
        Value.__init__(self, name, "variable")

        m = _var_name_re.match(val)
        assert m and m.group('name') is not None

        self.var_name = m.group('name')
        self.is_constant = m.group('const') is not None
        self.cond = m.group('cond')
        self.required_type = m.group('type')
        self.bit_size = int(m.group('bits')) if m.group('bits') else 0

        if self.required_type == 'bool':
            assert self.bit_size == 0 or self.bit_size == 32
            self.bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def type(self):
        """Return the nir_type_* name for required_type, or None if unset."""
        if self.required_type == 'bool':
            return "nir_type_bool32"
        elif self.required_type in ('int', 'uint'):
            return "nir_type_int"
        elif self.required_type == 'float':
            return "nir_type_float"


_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")


class Expression(Value):
    """An ALU expression: ``('[~]opcode[@bits][(cond)]', src0, src1, ...)``.

    Attributes:
      opcode   -- opcode name without any prefix or suffix
      bit_size -- explicit bit size, 0 if none
      inexact  -- True when the opcode was prefixed with ``~``
      cond     -- the parenthesised suffix (parentheses included) or None
      sources  -- list of Value objects built from the remaining elements
    """

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, name_base, "expression")
        assert isinstance(expr, tuple)

        m = _opcode_re.match(expr[0])
        assert m and m.group('opcode') is not None

        self.opcode = m.group('opcode')
        self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        self.inexact = m.group('inexact') is not None
        self.cond = m.group('cond')
        self.sources = [Value.create(src, "{0}_{1}".format(name_base, i), varset)
                        for (i, src) in enumerate(expr[1:])]

    def render(self):
        """Return C source for all sources followed by this expression.

        Sources come first because the expression's initializer takes their
        addresses.
        """
        srcs = "\n".join(src.render() for src in self.sources)
        return srcs + super(Expression, self).render()


class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.

    Each integer has a canonical form which is the maximum integer to which it
    is equivalent.  Two integers are equivalent precisely when they have the
    same canonical form.

    The convention of maximum is explicitly chosen to make using it in
    BitSizeValidator easier because it means that an actual bit_size (if any)
    will always be the canonical form.
    """
    def __init__(self):
        # Maps an integer to a strictly larger integer it is equivalent to.
        # Canonical integers are exactly those that are not keys.
        self._remap = {}

    def get_canonical(self, x):
        """Get the canonical integer corresponding to x."""
        if x in self._remap:
            return self.get_canonical(self._remap[x])
        else:
            return x

    def add_equiv(self, a, b):
        """Add an equivalence and return the canonical form."""
        canon_a = self.get_canonical(a)
        canon_b = self.get_canonical(b)
        c = max(canon_a, canon_b)

        # Link the canonical representatives, not a and b themselves.
        # Overwriting the entry of a non-canonical integer would drop its
        # existing link and silently split a class that was merged earlier.
        if canon_a != c:
            assert canon_a < c
            self._remap[canon_a] = c

        if canon_b != c:
            assert canon_b < c
            self._remap[canon_b] = c

        return c


class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.  If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit.  There are, however, cases where this
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end up
    trying to and a 32-bit value with a 64-bit value which would be invalid.

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.

    The basic operation of the validator is very similar to the bitsize_tree in
    nir_search only a little more subtle.  Instead of simply tracking bit
    sizes, it tracks "bit classes" where each class is represented by an
    integer.  A value of 0 means we don't know anything yet, positive values
    are actual bit-sizes, and negative values are used to track equivalence
    classes of sizes that must be the same but have yet to receive an actual
    size.  The first stage uses the bitsize_tree algorithm to assign bit
    classes to each variable.  If it ever comes across an inconsistency, it
    assert-fails.  Then the second stage uses that information to prove that
    the resulting expression can always validly be constructed.
    """

    def __init__(self, varset):
        self._num_classes = 0
        # One bit class per variable id; 0 means "nothing known yet".
        self._var_classes = [0] * len(varset.names)
        self._class_relation = IntEquivalenceRelation()

    def validate(self, search, replace):
        """Assert that ``replace`` is well-sized for every match of ``search``.

        Stage one assigns bit classes to the variables from the search
        expression; stage two checks the replacement against them.
        """
        dst_class = self._propagate_bit_size_up(search)
        if dst_class == 0:
            dst_class = self._new_class()
        self._propagate_bit_class_down(search, dst_class)

        validate_dst_class = self._validate_bit_class_up(replace)
        assert validate_dst_class == 0 or validate_dst_class == dst_class
        self._validate_bit_class_down(replace, dst_class)

    def _new_class(self):
        """Allocate a fresh (negative) class for a not-yet-known size."""
        self._num_classes += 1
        return -self._num_classes

    def _set_var_bit_class(self, var_id, bit_class):
        """Record that variable ``var_id`` belongs to ``bit_class``.

        If the variable already has a class, the two are merged; the existing
        class must then be unsized (negative) or equal to ``bit_class``.
        """
        assert bit_class != 0
        var_class = self._var_classes[var_id]
        if var_class == 0:
            self._var_classes[var_id] = bit_class
        else:
            canon_class = self._class_relation.get_canonical(var_class)
            assert canon_class < 0 or canon_class == bit_class
            var_class = self._class_relation.add_equiv(var_class, bit_class)
            self._var_classes[var_id] = var_class

    def _get_var_bit_class(self, var_id):
        """Return the canonical bit class of variable ``var_id``."""
        return self._class_relation.get_canonical(self._var_classes[var_id])

    def _propagate_bit_size_up(self, val):
        """Stage 1a: infer sizes bottom-up from explicit sizes in the search.

        Returns the bit size of ``val`` (0 if unknown) and stores the size
        shared by an expression's unsized sources in ``val.common_size``.
        """
        if isinstance(val, (Constant, Variable)):
            return val.bit_size

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_size = 0
            for i in range(nir_op.num_inputs):
                src_bits = self._propagate_bit_size_up(val.sources[i])
                if src_bits == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_bits == src_type_bits
                else:
                    assert val.common_size == 0 or src_bits == val.common_size
                    val.common_size = src_bits

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_size != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_size
                else:
                    val.common_size = val.bit_size
                return val.common_size

    def _propagate_bit_class_down(self, val, bit_class):
        """Stage 1b: push classes top-down and assign them to variables."""
        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class
            self._set_var_bit_class(val.index, bit_class)

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == 0 or bit_class == dst_type_bits
            else:
                assert val.common_size == 0 or val.common_size == bit_class
                val.common_size = bit_class

            if val.common_size:
                common_class = val.common_size
            elif nir_op.num_inputs:
                # If we got here then we have no idea what the actual size is.
                # Instead, we use a generic class
                common_class = self._new_class()

            # common_class is only read inside the loop, so it is always
            # bound whenever there is at least one input.
            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._propagate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._propagate_bit_class_down(val.sources[i], common_class)

    def _validate_bit_class_up(self, val):
        """Stage 2a: compute classes bottom-up through the replacement.

        Returns the bit class of ``val`` (0 if unknown) and stores the class
        shared by an expression's unsized sources in ``val.common_class``.
        """
        if isinstance(val, Constant):
            return val.bit_size

        elif isinstance(val, Variable):
            var_class = self._get_var_bit_class(val.index)
            # By the time we get to validation, every variable should have a class
            assert var_class != 0

            # If we have an explicit size provided by the user, the variable
            # *must* exactly match the search.  It cannot be implicitly sized
            # because otherwise we could end up with a conflict at runtime.
            assert val.bit_size == 0 or val.bit_size == var_class

            return var_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_class = 0
            for i in range(nir_op.num_inputs):
                src_class = self._validate_bit_class_up(val.sources[i])
                if src_class == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_class == src_type_bits
                else:
                    assert val.common_class == 0 or src_class == val.common_class
                    val.common_class = src_class

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_class != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_class
                else:
                    val.common_class = val.bit_size
                return val.common_class

    def _validate_bit_class_down(self, val, bit_class):
        """Stage 2b: check top-down that every replacement value has a class."""
        # At this point, everything *must* have a bit class.  Otherwise, we have
        # a value we don't know how to define.
        assert bit_class != 0

        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == dst_type_bits
            else:
                assert val.common_class == 0 or val.common_class == bit_class
                val.common_class = bit_class

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._validate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._validate_bit_class_down(val.sources[i], val.common_class)


# Source of the unique ids used to name the generated search/replace objects.
_optimization_ids = itertools.count()

# Every distinct condition string seen so far.  A transform refers to its
# condition by index into this module-level list.
condition_list = ['true']


class SearchAndReplace(object):
    """One parsed and validated ``(search, replace[, condition])`` transform.

    Attributes:
      id              -- unique id taken from the module-level counter
      condition       -- C expression text, 'true' when none was given
      condition_index -- position of ``condition`` in ``condition_list``
      search          -- the search Expression
      replace         -- the replacement Value
    """

    def __init__(self, transform):
        self.id = next(_optimization_ids)

        search = transform[0]
        replace = transform[1]
        if len(transform) > 2:
            self.condition = transform[2]
        else:
            self.condition = 'true'

        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id), varset)

        # The replacement may only use variables introduced by the search.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

        BitSizeValidator(varset).validate(self.search, self.replace)


_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

#endif

% for (opcode, xform_list) in xform_dict.items():
% for xform in xform_list:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, const bool *condition_flags,
                   void *mem_ctx)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in xform_dict.keys():
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(alu, xform->search, xform->replace,
                                  mem_ctx)) {
               progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   void *mem_ctx = ralloc_parent(impl);
   bool progress = false;

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")


class AlgebraicPass(object):
    """A named collection of transforms that renders to one C pass.

    ``transforms`` is an iterable of SearchAndReplace objects or of raw
    tuples accepted by SearchAndReplace.  Transforms are grouped in
    ``xform_dict`` by the opcode at the root of their search expression.

    If any raw transform fails to parse or validate, every failure is printed
    to stderr and the process exits with status 1 once all transforms have
    been tried.
    """

    def __init__(self, pass_name, transforms):
        self.xform_dict = {}
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Only genuine errors are reported here; KeyboardInterrupt
                    # and SystemExit must still be able to stop the build.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print("  " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            if xform.search.opcode not in self.xform_dict:
                self.xform_dict[xform.search.opcode] = []

            self.xform_dict[xform.search.opcode].append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Return the complete C source of the pass as a string."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xform_dict=self.xform_dict,
                                               condition_list=condition_list)