#!/usr/bin/python3

# Copyright 2017, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NN model compiler

Compile models and examples into VTS and NDK-based CTS unit tests
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from functools import reduce
import math
import os
import struct
import sys
import contextlib


@contextlib.contextmanager
def smart_open(filename=None):
    """Yield a writable text handle for *filename*.

    A falsy *filename* or '-' yields sys.stdout. Only a handle opened here is
    closed on exit; sys.stdout is never closed.
    """
    if filename and filename != '-':
        fh = open(filename, 'w')
    else:
        fh = sys.stdout

    try:
        yield fh
    finally:
        if fh is not sys.stdout:
            fh.close()


class Phase(object):
    """Ordered collection of objects and the statement text generated for each.

    Also keeps an ID -> object index so objects can be looked up by ID().
    """

    def __init__(self):
        self.__objects = []
        self.__contents = []
        self.__dict_of_objects = {}

    def append(self, obj, x):
        """Record *obj* (which must provide ID()) with its statement text *x*."""
        self.__objects.append(obj)
        self.__contents.append(x)
        self.__dict_of_objects[obj.ID()] = obj

    def dump(self, filename):
        """Print each statement, indented and ';'-terminated.

        Despite its name, *filename* is an open file object, not a path.
        """
        for x in self.__contents:
            print(" " + x + ";", file=filename)

    def objects(self):
        """Return the recorded objects in insertion order."""
        return self.__objects

    def search(self, i):
        """Return the object whose ID() is *i* (KeyError if unknown)."""
        return self.__dict_of_objects[i]


# Tracking objects inside a model with a not necessarily unique name and
# a unique number
class NamedObject(object):
    """An object with a (possibly repeated) name and a unique serial ID."""
    __serial = 0

    def __init__(self, name = "NamedObject"):
        self.__name = name
        # IDs are handed out sequentially across all NamedObjects.
        self.__id = NamedObject.serial()
        NamedObject.__serial += 1

    def ID(self):
        """Return this object's unique serial number."""
        return self.__id

    @staticmethod
    def serial():
        """Return the ID the next NamedObject will receive."""
        return NamedObject.__serial

    def get_name(self):
        return self.__name

    def __str__(self):
        return self.get_name()

    def __hash__(self):
        return self.__id


# Object that can be traversed during topological sorting phase
class Traversable(object):
    def traversable(self):
        return True


class Nontraversable(object):
    def traversable(self):
        return False


# Object that can take input from other objects
class Uses(object):
    """Mixin for graph nodes that consume other nodes.

    Every instance is registered in ``Uses.all_uses``. For each producer in
    *ins*, this node is appended to the producer's ``outs`` list.
    """
    all_uses = set()

    def __init__(self, ins = None):
        # Copy so later edits to the caller's list do not alias self.ins.
        self.ins = list(ins) if ins is not None else []
        Uses.all_uses.add(self)
        for i in self.ins:
            i.outs.append(self)


# Object that other objects take their definition from
class Definitions(object):
    """Mixin for graph nodes that produce values for other nodes.

    For each consumer in *outs*, this node is appended to the consumer's
    ``ins`` list.
    """

    def __init__(self, outs = None):
        self.outs = list(outs) if outs is not None else []
        for o in self.outs:
            o.ins.append(self)


class TypeLookup:
    """Maps NNAPI operand type names to the C++ element type used for them."""
    __type_lookup = {
        "INT32": "int32_t",
        "FLOAT32": "float",
        "TENSOR_INT32": "int32_t",
        "TENSOR_FLOAT32": "float",
        "TENSOR_QUANT8_ASYMM": "uint8_t",
    }

    @staticmethod
    def get_cpptype(nnapi_type):
        """Return the C++ type name for *nnapi_type* (KeyError if unknown)."""
        return TypeLookup.__type_lookup[nnapi_type]

    @staticmethod
    def is_float(nnapi_type):
        return TypeLookup.get_cpptype(nnapi_type) == "float"

    @staticmethod
    def get_size(nnapi_type):
        """Element size in bytes: 1 for uint8_t-backed types, otherwise 4."""
        return 1 if TypeLookup.get_cpptype(nnapi_type) == "uint8_t" else 4


class Type(object):
    """An interned operand type: an NNAPI element type plus a shape string.

    Types whose "vt, shape" string is equal share one id and the C++ name
    "type<id>". A Type built without vt or shape is unnamed (get_name() is
    None) and is not registered.
    """
    __types = {}
    __type_serial = 0  # types have their own numbering

    def __init__(self, vt = None, shape = None):
        self.__vt = vt
        self.__shape = shape
        if vt is None or shape is None:
            self.__name = None
            return

        key = str(self)
        if key not in Type.__types:
            self.__id = Type.__type_serial
            Type.__types[key] = self
            Type.__type_serial += 1
        else:
            self.__id = Type.__types[key].__id
        self.__name = "type" + str(self.__id)

    def get_shape(self):
        return self.__shape

    def get_element_type(self):
        return self.__vt

    def get_name(self):
        return self.__name

    def __str__(self):
        return ", ".join([self.__vt, self.__shape])

    def __hash__(self):
        return self.__id

    @staticmethod
    def dump(filename):
        """Print one ``OperandType`` declaration per registered type.

        *filename* is an open file object. Output is sorted by type key.
        """
        for key, value in sorted(Type.__types.items()):
            print(" OperandType " + str(value.__name) + "(Type::" + str(key) + ");", file=filename)

    def get_parsed_shape(self):
        """Split the shape string into (dimensions, scale, zero_point) strings.

        The shape string looks like "{d0, d1, ...}", optionally followed by
        ", <scale>" and then ", <zero_point>", e.g. "{1, 2}, 0.5f, 127".
        The dimensions are returned as the raw text between the braces ("" for
        a scalar). The scale has its 'f' suffix removed. Scale and zero point
        are stripped of whitespace and default to "0" when absent.
        """
        if self.__shape == "" or self.__shape == "{}":
            return "", "0", "0"
        _, _, after_brace = self.__shape.partition('{')
        real_shape, _, tail = after_brace.partition('}')
        # tail now looks like ", 0.5f, 127", ", 0.5f" or "".
        head, _, last = tail.rpartition(',')
        if head == "":
            # At most one value follows the dimensions: it is the scale.
            scale, zero_point = last, ""
        else:
            _, _, scale = head.partition(',')
            zero_point = last
        scale = scale.replace("f", "").strip() or "0"
        zero_point = zero_point.strip() or "0"
        return real_shape, scale, zero_point

    def get_size(self):
        """Return the operand's size in bytes (element size * element count).

        A scalar (empty dimension list) counts as one element.
        """
        element_size = TypeLookup.get_size(self.__vt)
        real_shape, _, _ = self.get_parsed_shape()
        nr_elements = 1
        if real_shape.strip():
            dims = [int(x) for x in real_shape.split(",")]
            nr_elements = reduce(lambda x, y: x * y, dims)
        return element_size * nr_elements


# A value is a typed, named object
class Value(NamedObject):
    def __init__(self, name, vt):
        NamedObject.__init__(self, name)
        self.type = vt

# An operand that can be fed into operations. Also, an operand is always
# declared before operations.
class Operand(Value):
    """A named, typed operand of the model.

    Construction records the ``model->addOperand`` declaration text in
    ``Operand.operands``, so operand declarations are emitted in creation
    order before any operation.
    """
    # All operand declarations in string form, in creation order.
    operands = Phase()

    def __init__(self, name, vt):
        Value.__init__(self, name, vt)
        def_string = (
            "auto " + self.get_name() + " = "
            "model->addOperand(&" + vt.get_name() + ")")
        Operand.operands.append(self, def_string)

    def Definition(self):
        """Return the statement that sets this operand's value, or None.

        Plain operands produce nothing when visited by TopologicalSort;
        Parameter overrides this.
        """
        return None

    def Reference(self):
        """Return the C++ identifier used to refer to this operand."""
        return NamedObject.__str__(self)

    @staticmethod
    def print_operands(operands):
        """Return the C++ references of *operands* as a list, in iteration order."""
        return [x.Reference() for x in operands]

    def is_weight(self):
        """Whether the operand's value is supplied as part of the model."""
        return False


class Input(Operand, Definitions, Traversable):
    """A user-declared model input; also the base class of Parameter."""
    # Next value for self.number; only external inputs advance it.
    __next_number = 0
    # Every Input, including internal ones (Parameters). TopologicalSort uses
    # these as its starting nodes.
    __inputs = set()

    def __init__(self, name, vt, shape, increase_next_number=True):
        Operand.__init__(self, name, Type(vt, shape))
        Definitions.__init__(self)
        Input.__inputs.add(self)
        self.number = Input.__next_number
        if increase_next_number is True:
            Input.__next_number += 1

    def lifetime(self):
        return "MODEL_INPUT"

    def is_internal(self):
        return False

    @staticmethod
    def get_inputs(exclude_internal = None):
        """Return the set of declared inputs.

        When *exclude_internal* is truthy, internal inputs (Parameters) are
        filtered out. Iteration order of the returned set is NOT declaration
        order; use get_ordered_external_inputs() where order matters.
        """
        if exclude_internal:
            return {x for x in Input.__inputs if not x.is_internal()}
        return Input.__inputs

    @staticmethod
    def get_ordered_external_inputs():
        """Return the external inputs as a list sorted by ``number``.

        Example data is keyed by ``number``, so generated input index lists
        must follow this order.
        """
        return sorted(Input.get_inputs(True), key=lambda x: x.number)


class Output(Operand, Uses, Nontraversable):
    """A user-declared model output."""
    # Next value for self.number.
    __next_number = 0
    # Every Output, in declaration order.
    __outputs = []

    def __init__(self, name, vt, shape):
        Operand.__init__(self, name, Type(vt, shape))
        Uses.__init__(self)
        Output.__outputs.append(self)
        self.number = Output.__next_number
        Output.__next_number += 1

    def lifetime(self):
        return "MODEL_OUTPUT"

    @staticmethod
    def get_outputs():
        """Return all unique outputs in their original declaration order."""
        seen = set()
        unique = []
        for x in Output.__outputs:
            if x not in seen:
                seen.add(x)
                unique.append(x)
        return unique


class IgnoredOutput(Output):
    """An output whose results we don't want to compare."""
    __ignored = set()

    def __init__(self, name, vt, shape):
        Output.__init__(self, name, vt, shape)
        IgnoredOutput.__ignored.add(self)

    @staticmethod
    def gen_ignored():
        """Return C++ source for ``bool is_ignored(int i)``.

        The function tests membership in the ignored outputs' numbers. The
        numbers are sorted so the generated text does not depend on set order.
        """
        numbers = sorted(x.number for x in IgnoredOutput.__ignored)
        ignored_func = """
bool is_ignored(int i) {
  static std::set<int> ignore = {%s};
  return ignore.find(i) != ignore.end();
}""" % ", ".join([str(n) for n in numbers])
        return ignored_func


class ModelArgument:
    """An extra parameter appended to the generated CreateModel() signature."""
    # "type name" strings for every ModelArgument created, in order.
    __arguments = []

    def __init__(self, arg_type, arg_name):
        self.__arg_type = arg_type
        self.__arg_name = arg_name
        ModelArgument.__arguments.append(" ".join([arg_type, arg_name]))

    def get_arg_type(self):
        return self.__arg_type

    def get_arg_name(self):
        return self.__arg_name

    @staticmethod
    def get_arguments():
        """Return the "type name" strings of all declared arguments."""
        return ModelArgument.__arguments

    def lifetime(self):
        return "CONSTANT_COPY"


def pretty_print_as_float(x):
    """Format *x* as a C float literal, e.g. 1 -> "1.0f", 0.5 -> "0.5f"."""
    s = str(float(x))
    if "." in s or "e" in s:
        return s + "f"
    return s + ".0f"


class Parameter(Input):
    """A constant operand whose value is baked into the model.

    It is modelled as an internal Input that does not advance the input
    numbering.
    """
    # TODO seems wrong that's an Input.
    def __init__(self, name, vt, shape, initializer):
        Input.__init__(self, name, vt, shape, False)
        self.initializer = initializer
        self.cpptype = TypeLookup.get_cpptype(vt)

    def is_internal(self):
        return True

    def Definition(self):
        """Return a static array definition plus a setOperandValue() call."""
        init_name = self.get_name() + "_init"
        values = [str(x) for x in self.initializer]
        if self.cpptype == "float":
            values = [pretty_print_as_float(x) for x in values]
        init = ("static " + self.cpptype + " " + init_name + "[]" +
                " = {" + ", ".join(values) + "};")
        args = [self.get_name(), init_name,
                "sizeof(" + self.cpptype + ") * " + str(len(self.initializer))]
        return "\n ".join([init,
                           "model->setOperandValue(" + ", ".join(args) + ");"])

    def is_weight(self):
        return True

    def lifetime(self):
        return "CONSTANT_COPY"


class Int32Scalar(Parameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, "INT32", "{}", [value])


class Float32Scalar(Parameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, "FLOAT32", "{}", [value])


# A compiler-generated intermediate result from an operation
class IntermediateResult(Operand, Definitions, Uses, Traversable):
    def __init__(self, src: Value):
        # Named after the serial ID this operand is about to receive.
        tmp_name = "tmp" + str(NamedObject.serial())
        Operand.__init__(self, tmp_name, src.type)
        Definitions.__init__(self)
        Uses.__init__(self, [src])

    def lifetime(self):
        return "TEMPORARY_VARIABLE"


# An explicitly declared intermediate result
class Internal(Operand, Definitions, Uses, Traversable):
    def __init__(self, name, vt, shape):
        Operand.__init__(self, name, Type(vt, shape))
        Definitions.__init__(self)
        Uses.__init__(self)

    def lifetime(self):
        return "TEMPORARY_VARIABLE"


class Operation(Definitions, Uses, Traversable):
    """An operation node in the model graph.

    Its type is taken from the first input, so *ins* must be non-empty.
    """

    def __init__(self, optype, ins, outs):
        self.type = ins[0].type
        Definitions.__init__(self, outs)
        Uses.__init__(self, ins)
        self.optype = optype

    def __str__(self):
        inputs = [str(x) for x in self.ins]
        return "Operation:" + self.optype + " " + ", ".join(inputs)

    def Reference(self):
        # NOTE(review): none of Operation's bases define ID(), so calling this
        # raises AttributeError; it appears to be unused in this file.
        return "operation" + str(self.ID())

    def Definition(self):
        """Return the ``model->addOperation(...)`` statement for this op."""
        inputs = Operand.print_operands(self.ins)
        outputs = Operand.print_operands(self.outs)
        return ("model->addOperation(ANEURALNETWORKS_" + self.optype + ", " +
                "{" + ", ".join(inputs) + "}, {" + ", ".join(outputs) + "});")


class Model(object):
    """Main interface used by spec files to build a model fluently.

    Each op method creates an Operation and makes it the current op. It
    returns self, so Out()/To() can attach outputs to that op.
    """

    def __init__(self):
        self.__currentOp = None

    def _emit_op(self, optype, ins, outs=None):
        """Create an Operation, make it the current op, and return self."""
        self.__currentOp = Operation(optype, ins, [] if outs is None else outs)
        return self

    # TODO turn this into generic binary operations
    def Add(self, i1: Value, i2 = None) -> Operation:
        """ADD i1 (and i2 if given).

        If an op is already current, its result is chained in as an extra
        input through an IntermediateResult.
        """
        ins = [i1]
        if i2 is not None:
            ins.append(i2)
        if self.__currentOp is not None:
            ir = IntermediateResult(self.__currentOp)
            self.__currentOp = ir
            ins.append(self.__currentOp)
        return self._emit_op("ADD", ins, [])

    def Operation(self, op_name, *args):
        """Generic op: ANEURALNETWORKS_<op_name> with *args* as inputs."""
        return self._emit_op(op_name, list(args))

    def RawAdd(self, i1: Value, i2: Value, o = None) -> Operation:
        outs = [o] if o is not None else []
        return self._emit_op("ADD", [i1, i2], outs)

    # See CpuExecutor::executeOperation() for the arguments of each op
    def AveragePool(self, input, padding, stride_width, stride_height, filter_width, filter_height, activation):
        return self._emit_op("AVERAGE_POOL_2D",
                             [input, padding, stride_width, stride_height,
                              filter_width, filter_height, activation])

    def Concatenation(self, *args):
        return self._emit_op("CONCATENATION", list(args))

    def Conv(self, filter, bias, input, padding, stride_width, stride_height, activation):
        return self._emit_op("CONV_2D",
                             [filter, bias, input, padding, stride_width,
                              stride_height, activation])

    def DepthWiseConv(self, filter, bias, input, padding, stride_width, stride_height, depth_multiplier, activation):
        return self._emit_op("DEPTHWISE_CONV_2D",
                             [filter, bias, input, padding, stride_width,
                              stride_height, depth_multiplier, activation])

    def FullyConnected(self, input, weights, bias, activation):
        return self._emit_op("FULLY_CONNECTED",
                             [input, weights, bias, activation])

    def Logistic(self, input):
        return self._emit_op("LOGISTIC", [input])

    def L2Pool(self, input, padding, stride_width, stride_height, filter_width, filter_height, activation):
        return self._emit_op("L2_POOL_2D",
                             [input, padding, stride_width, stride_height,
                              filter_width, filter_height, activation])

    def MaxPool(self, input, padding, stride_width, stride_height, filter_width, filter_height, activation):
        return self._emit_op("MAX_POOL_2D",
                             [input, padding, stride_width, stride_height,
                              filter_width, filter_height, activation])

    def SoftMax(self, input, beta):
        return self._emit_op("SOFTMAX", [input, beta])

    def Reshape(self, input, shape):
        return self._emit_op("RESHAPE", [input, shape])

    def Out(self, o):
        """Attach *o* (an operand, or a list/tuple of them) as output(s) of the current op."""
        targets = o if isinstance(o, (list, tuple)) else [o]
        for target in targets:
            self.__currentOp.outs.append(target)
            target.ins.append(self.__currentOp)
        return self

    def To(self, o:Value):
        """Like Out(), but also ends the chain (no current op afterwards)."""
        Model.Out(self, o)
        self.__currentOp = None
        return self


class FileNames:
    # Base name of the spec file being compiled ("" until import_source runs).
    SpecFile = ""


class Example():
    """Registry of test examples.

    Each registered example is unpacked as an (inputs, outputs) pair of dicts
    mapping an operand (or an int index) to its list of values.
    """
    __examples = []

    def __init__(self, list_of_examples):
        Example.__examples.append(list_of_examples)

    @staticmethod
    def dump_dict(d):
        """Render *d* as "{key, {v0, v1, ...}}, ..." C++ initializer text.

        Operand keys are replaced by their ``number``. Values containing a
        '.' get an 'f' suffix unless the operand's element type is non-float.
        """
        ret = []
        for k, v in d.items():
            key = str(k)
            suffix = "f"
            if not isinstance(k, int):
                key = str(k.number)
                if not TypeLookup.is_float(k.type.get_element_type()):
                    suffix = ""
            init = ", ".join(
                [str(i) + (suffix if str(i).find(".") != -1 else "") for i in v])
            ret.append("{%s, {%s}}" % (key, init))
        return ", ".join(ret)

    @staticmethod
    def dump_mixed_types(d):
        """Split *d* by element type and render it as a MixedTyped initializer.

        Raises AssertionError (after printing to stderr) for element types
        other than TENSOR_FLOAT32, TENSOR_INT32 and TENSOR_QUANT8_ASYMM.
        """
        float32_dict = {}
        int32_dict = {}
        uint8_dict = {}

        for k, v in d.items():
            # find out type of the operand addressed by the key
            ty = Operand.operands.search(k.ID()).type.get_element_type()
            if ty == "TENSOR_FLOAT32":
                float32_dict[k] = v
            elif ty == "TENSOR_INT32":
                int32_dict[k] = v
            elif ty == "TENSOR_QUANT8_ASYMM":
                uint8_dict[k] = v
            else:
                print("Unhandled type %s" % ty, file=sys.stderr)
                # An explicit raise is not stripped under `python -O`.
                raise AssertionError("unsupported example type")

        tuple_init = """\
{{ // See tools/test_generator/include/TestHarness.h:MixedTyped
  // int -> FLOAT32 map
  {{{float32_dict}}},
  // int -> INT32 map
  {{{int32_dict}}},
  // int -> QUANT8_ASYMM map
  {{{uint8_dict}}}
}}"""
        tuple_contents = {
            'float32_dict': Example.dump_dict(float32_dict),
            'int32_dict': Example.dump_dict(int32_dict),
            'uint8_dict': Example.dump_dict(uint8_dict)
        }
        return tuple_init.format(**tuple_contents)

    @staticmethod
    def dump(example_file):
        """Write every registered example to the open file *example_file*."""
        if len(Example.__examples) > 0:
            spec_file = " (from: %s)" % (FileNames.SpecFile)
            print('// Generated file%s. Do not edit' % (spec_file),
                  file=example_file)
        for i, o in Example.__examples:
            print('// Begin of an example', file=example_file)
            print('{', file=example_file)
            inputs = Example.dump_mixed_types(i)
            outputs = Example.dump_mixed_types(o)
            print('//Input(s)\n%s,' % inputs, file=example_file)
            print('//Output(s)\n%s' % outputs, file=example_file)
            print('}, // End of an example', file=example_file)


def TopologicalSort(format_op):
    """Call *format_op* on each graph node in dependency order.

    Starts from every Input (including Parameters). A node becomes ready once
    all of its producers (its ``ins``) have been visited. Nontraversable
    nodes (Outputs) are never visited.
    """
    start = Input.get_inputs().copy()
    deps = {x: set(x.ins) for x in Uses.all_uses}

    while start:
        cur = start.pop()
        format_op(cur)  # e.g. cur.Definition()
        for o in set(cur.outs):
            deps[o].remove(cur)
            if not deps[o] and o.traversable():
                start.add(o)


class Configuration:
    # True when generating a VTS model instead of a CTS one (set from -v).
    vts = False


def import_source():
    """Parse the command line and execute the spec file.

    Returns:
        (model, example): the -m/--model and -e/--example output paths
        ('-' means stdout).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("spec", help="the spec file")
    parser.add_argument(
        "-v",
        "--vts",
        help="generate VTS model instead",
        default=False,
        action="store_true")
    parser.add_argument(
        "-m", "--model", help="the output model file", default="-")
    parser.add_argument(
        "-e", "--example", help="the output example file", default="-")
    args = parser.parse_args()

    Configuration.vts = args.vts

    if os.path.exists(args.spec):
        FileNames.SpecFile = os.path.basename(args.spec)
        # Read via `with` so the handle is closed before the spec runs.
        with open(args.spec) as spec_fh:
            spec_source = spec_fh.read()
        # NOTE: exec runs the spec as arbitrary Python; only use trusted specs.
        exec(spec_source)
    else:
        print("Spec file %s not found; nothing will be generated from it"
              % args.spec, file=sys.stderr)

    return (args.model, args.example)


def generate_vts_operands():
    """Return the C++ ``operands`` vector initializer for the VTS model.

    Operands are emitted in declaration order. Weights are laid out back to
    back in operandValues, in the same order generate_vts_operand_values()
    packs them, so each weight's offset is the running sum of earlier
    weight sizes.
    """
    op_def = """\
        {{
            .type = OperandType::{operand_type},
            .dimensions = {shape},
            .numberOfConsumers = {no_consumers},
            .scale = {scale},
            .zeroPoint = {zero_point},
            .lifetime = OperandLifeTime::{lifetime},
            .location = {{.poolIndex = 0, .offset = {offset}, .length = {length}}},
        }}"""
    offset = 0
    op_definitions = []
    for o in Operand.operands.objects():
        ty = o.type
        length = ty.get_size() if o.is_weight() else 0
        real_shape, scale, zero_point = ty.get_parsed_shape()
        op = {
            "operand_type": ty.get_element_type(),
            "shape": "{%s}" % real_shape,
            "no_consumers": len(o.outs) if o.traversable() else 0,
            "scale": pretty_print_as_float(float(scale)),
            "zero_point": str(int(zero_point)),
            "lifetime": o.lifetime(),
            "offset": offset if o.is_weight() else 0,
            "length": length
        }
        offset += length
        op_definitions.append(op_def.format(**op))

    op_vec = """\
    const std::vector<Operand> operands = {{
{0}
    }};""".format(",\n".join(op_definitions))
    return op_vec


def generate_vts_operand_values():
    """Return the C++ ``operandValues`` byte vector holding all weights.

    Float/int32 values are packed with struct in native byte order; quant8
    values are emitted as-is. Raises AssertionError for other element types.
    """
    weights = [o for o in Operand.operands.objects() if o.is_weight()]
    binit = []
    for w in weights:
        ty = w.type.get_element_type()
        if ty == "TENSOR_QUANT8_ASYMM":
            binit.extend(w.initializer)
        elif ty in {"TENSOR_FLOAT32", "FLOAT32", "TENSOR_INT32", "INT32"}:
            fmt = "f" if ty in ("TENSOR_FLOAT32", "FLOAT32") else "i"
            for f in w.initializer:
                binit.extend(struct.pack(fmt, f))
        else:
            raise AssertionError("Unsupported VTS operand type")

    init_defs = ", ".join([str(x) for x in binit])
    if init_defs != "":
        init_defs = "\n %s\n " % init_defs
    byte_vec_fmt = """\
    std::vector<uint8_t> operandValues = {%s};""" % init_defs
    return byte_vec_fmt


class VTSOps(object):
    """Accumulates VTS Operation initializers produced during TopologicalSort."""
    vts_ops = []

    @staticmethod
    def generate_vts_operation(op):
        """Append the initializer for *op*; non-operations are ignored."""
        try:
            opcode = op.optype
        except AttributeError:  # not an op, but things like weights
            return
        op_fmt = """\
        {{
            .type = OperationType::{op_code},
            .inputs = {{{ins}}},
            .outputs = {{{outs}}},
        }}"""
        op_content = {
            'op_code': opcode,
            'ins': ", ".join([str(x.ID()) for x in op.ins]),
            'outs': ", ".join([str(x.ID()) for x in op.outs]),
        }
        VTSOps.vts_ops.append(op_fmt.format(**op_content))


def generate_vts_operations(model_file):
    """Return all VTS operation initializers in topological order.

    *model_file* is accepted for interface compatibility but unused.
    """
    TopologicalSort(VTSOps.generate_vts_operation)
    return ",\n".join(VTSOps.vts_ops)


def generate_vts_model(model_file):
    """Write the VTS ``createTestModel()`` function to *model_file*."""
    model_fmt = """\
// Generated code. Do not edit
// Create the model
Model createTestModel() {{
{operand_decls}

    const std::vector<Operation> operations = {{
{operations}
    }};

    const std::vector<uint32_t> inputIndexes = {{{input_indices}}};
    const std::vector<uint32_t> outputIndexes = {{{output_indices}}};
{operand_values}
    const std::vector<hidl_memory> pools = {{}};

    return {{
        .operands = operands,
        .operations = operations,
        .inputIndexes = inputIndexes,
        .outputIndexes = outputIndexes,
        .operandValues = operandValues,
        .pools = pools,
    }};
}}"""
    model = {
        "operations": generate_vts_operations(sys.stdout),
        "operand_decls": generate_vts_operands(),
        "operand_values": generate_vts_operand_values(),
        "output_indices": ", ".join([str(i.ID()) for i in Output.get_outputs()]),
        # Ordered by Input.number so indices line up with the example data.
        "input_indices": ", ".join(
            [str(i.ID()) for i in Input.get_ordered_external_inputs()])
    }
    print(model_fmt.format(**model), file=model_file)


def generate_vts(model_file):
    """Write the VTS model and the is_ignored() helper to *model_file*."""
    generate_vts_model(model_file)
    print(IgnoredOutput.gen_ignored(), file=model_file)


def print_cts_op(model_file, op):
    """Print *op*'s CTS definition statement, if it has one."""
    fmt = op.Definition()
    if fmt is not None:
        print(" %s" % fmt, file=model_file)


if __name__ == '__main__':
    (model, example) = import_source()
    # Boilerplate
    args = ""
    if len(ModelArgument.get_arguments()) > 0:
        args = ", " + ", ".join(ModelArgument.get_arguments())

    print(
        "Output %s model: %s" % ("VTS" if Configuration.vts else "CTS", model),
        file=sys.stderr)
    print("Output example:" + example, file=sys.stderr)

    if Configuration.vts:
        with smart_open(model) as model_file:
            generate_vts(model_file)
    else:
        with smart_open(model) as model_file:
            spec_file = " (from: %s)" % (FileNames.SpecFile)

            print('// Generated file%s. Do not edit' % (spec_file), file=model_file)
            print("void CreateModel(Model *model" + args + ") {", file=model_file)

            # Phase 0: types
            Type.dump(model_file)
            # Phase 1: add operands
            print(" // Phase 1, operands", file=model_file)
            Operand.operands.dump(model_file)

            # Phase 2: operations
            print(" // Phase 2, operations", file=model_file)
            TopologicalSort(lambda x: print_cts_op(model_file, x))

            # Phase 3: add inputs and outputs
            print(" // Phase 3, inputs and outputs", file=model_file)
            # Ordered by Input.number so indices line up with the example data.
            inputs = Operand.print_operands(Input.get_ordered_external_inputs())
            outputs = Operand.print_operands(Output.get_outputs())
            print(" model->identifyInputsAndOutputs(\n" +
                  " {" + ", ".join(inputs) + "},\n {" + ", ".join(outputs) + "});",
                  file=model_file)
            # Boilerplate
            print(" assert(model->isValid());", file=model_file)
            print("}", file=model_file)
            print(IgnoredOutput.gen_ignored(), file=model_file)

    with smart_open(example) as example_file:
        Example.dump(example_file)