1 #!/usr/bin/python3 2 3 # Copyright 2017, The Android Open Source Project 4 # 5 # Licensed under the Apache License, Version 2.0 (the "License"); 6 # you may not use this file except in compliance with the License. 7 # You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 """Slicing the input Model file 17 18 Invoked by ml/nn/runtime/test/specs/slicing.sh; this Python code is 19 not intended to be invoked directly by the users. See that script for 20 details on how to use the slicing tool is used. 21 22 This script does the following work: 23 24 Perform a topological sort similar to the test generator, except that: 25 * It would stop at the N-th operation it encounters, and 26 * Rename the output of the N-th operation to a model output, and 27 * Name that as the output of the model. 28 * Also only inputs and weights used by the submodel would be emitted. 29 30 """ 31 32 from __future__ import absolute_import 33 from __future__ import division 34 from __future__ import print_function 35 import argparse 36 from functools import reduce 37 import math 38 import os 39 import struct 40 import sys 41 import contextlib 42 import test_generator 43 import pprint 44 # Stuff from test generator 45 from test_generator import Configuration 46 from test_generator import Example 47 from test_generator import Float32Scalar 48 from test_generator import Float32Vector 49 from test_generator import IgnoredOutput 50 from test_generator import Input 51 from test_generator import Int32Scalar 52 from test_generator import Int32Vector 53 from test_generator import Internal 54 from test_generator import Model 55 from test_generator import Output 56 from test_generator import Parameter 57 from test_generator import SmartOpen 58 59 60 # Take a model from command line 61 def import_source(): 62 parser = argparse.ArgumentParser() 63 parser.add_argument("spec", help="the spec file") 64 parser.add_argument( 65 "-n", "--number", 66 help="number of operations in the sliced model. Default = 1", 67 default=1) 68 parser.add_argument( 69 "-m", "--model", help="the output model file", default="-") 70 parser.add_argument( 71 "-e", "--example", help="the output example file", default="-") 72 args = parser.parse_args() 73 74 if os.path.exists(args.spec): 75 test_generator.FileNames.specFile = os.path.basename(args.spec) 76 exec (open(args.spec).read()) 77 else: 78 print("cannot find file %s" % args.spec) 79 sys.exit(1) 80 81 return (args.model, args.example, args.number) 82 83 84 # Slice till the Nth op the topological sort finds 85 # the output of that op becomes the output of the model 86 class slicing: 87 88 def __init__(self, threshold): 89 self.__nr_op_seen = 0 90 self.__threshold = threshold 91 self.__last_outs = [] 92 self.__all_formatted_ops = [] 93 self.__referenced_operands = set() 94 95 def format_as_py_op(self, op): 96 fmt = op.PyDefinition() 97 if fmt is not None: 98 self.__nr_op_seen += 1 99 if self.__nr_op_seen > self.__threshold: 100 return False 101 self.__last_outs = op.outs 102 for o in op.ins: 103 self.__referenced_operands.add(o) 104 for o in op.outs: 105 self.__referenced_operands.add(o) 106 self.__all_formatted_ops.append("model = model.%s" % fmt) 107 return True 108 109 def dump(self, model_file): 110 for x in self.__all_formatted_ops: 111 print(x, file=model_file) 112 113 def dump_example(self, example_file): 114 override = {} 115 # Make alias for the output variable 116 for lo in self.__last_outs: 117 override[str(lo)] = lo.type.GetNumberOfElements() 118 alias_def = """\ 119 # Alias for the output variable {operand_name} 120 aliased_output{number} = {operand_name} 121 """ 122 op = { 123 'operand_name': str(lo), 124 'number': 0 # only support one output as of now 125 } 126 print (alias_def.format(**op), file=example_file) 127 Example.py_dump(example_file, override, self.__referenced_operands) 128 129 def format_operands(self, model): 130 # Dump operand definitions 131 op_definitions = [] 132 for o in model.operands: 133 if o not in self.__referenced_operands: 134 continue 135 ty = o.type 136 op_def = """{op_name} = {operand}("{op_name}", "{element_type}", "{shape}" """ 137 if isinstance(o, test_generator.Parameter): 138 op_def += """, {initializer})""" 139 init = o.value 140 py_operand_name = "Parameter" 141 else: 142 op_def += ")" 143 init = [] 144 py_operand_name = "IgnoredOutput" if o in set( 145 self.__last_outs) else o.__class__.__name__ 146 147 op = { 148 "element_type": ty.type, 149 "shape": ty.GetRawShape(), 150 "op_name": str(o), 151 "operand": py_operand_name, 152 "initializer": init 153 } 154 op_definitions.append(op_def.format(**op)) 155 return "\n".join(op_definitions) 156 157 158 if __name__ == "__main__": 159 (model, example, number) = import_source() 160 s = slicing(int(number)) 161 162 with SmartOpen(model) as model_file: 163 spec_file = " (from: %s)" % (test_generator.FileNames.specFile) 164 print("# Generated file%s. Do not edit" % (spec_file), file=model_file) 165 print("model = Model()", file=model_file) 166 # slicing tool only support one single model per spec file 167 model = Model.models[0].Compile() 168 for op in model.operations: 169 s.format_as_py_op(op) 170 print(s.format_operands(model), file=model_file) 171 s.dump(model_file) 172 with SmartOpen(example) as example_file: 173 s.dump_example(example_file) 174