1 #!/usr/bin/python3 2 3 # Copyright 2017, The Android Open Source Project 4 # 5 # Licensed under the Apache License, Version 2.0 (the "License"); 6 # you may not use this file except in compliance with the License. 7 # You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 """VTS testcase generator 17 18 Implements VTS test backend. Shares most logic with the CTS test 19 generator. Invoked by ml/nn/runtime/test/specs/generate_vts_tests.sh; 20 See that script for details on how this script is used. 21 22 """ 23 24 from __future__ import absolute_import 25 from __future__ import division 26 from __future__ import print_function 27 import argparse 28 from functools import reduce 29 import math 30 import numpy as np 31 import os 32 import re 33 import struct 34 import sys 35 import contextlib 36 import pprint 37 38 # Stuff from test generator 39 import test_generator as tg 40 from test_generator import ActivationConverter 41 from test_generator import BoolScalar 42 from test_generator import Configuration 43 from test_generator import DataTypeConverter 44 from test_generator import DataLayoutConverter 45 from test_generator import Example 46 from test_generator import Float16Scalar 47 from test_generator import Float32Scalar 48 from test_generator import Float32Vector 49 from test_generator import IgnoredOutput 50 from test_generator import Input 51 from test_generator import Int32Scalar 52 from test_generator import Int32Vector 53 from test_generator import Internal 54 from test_generator import Model 55 from test_generator import Operand 56 from test_generator import Output 57 from test_generator import Parameter 58 from test_generator import ParameterAsInputConverter 59 from test_generator import RelaxedModeConverter 60 from test_generator import SmartOpen 61 from test_generator import SymmPerChannelQuantParams 62 63 # Dumping methods that shared with CTS generator 64 from cts_generator import DumpCtsExample 65 from cts_generator import DumpCtsIsIgnored 66 67 # Take a model from command line 68 def ParseCmdLine(): 69 parser = argparse.ArgumentParser() 70 parser.add_argument("spec", help="the spec file") 71 parser.add_argument( 72 "-m", "--model", help="the output model file", default="-") 73 parser.add_argument( 74 "-e", "--example", help="the output example file", default="-") 75 parser.add_argument( 76 "-t", "--test", help="the output test file", default="-") 77 args = parser.parse_args() 78 tg.FileNames.InitializeFileLists( 79 args.spec, args.model, args.example, args.test) 80 81 # Generate operands in VTS format 82 def generate_vts_operands(model): 83 # Dump operand definitions 84 op_def = """\ 85 {{ 86 .type = OperandType::{operand_type}, 87 .dimensions = {shape}, 88 .numberOfConsumers = {no_consumers}, 89 .scale = {scale}, 90 .zeroPoint = {zero_point}, 91 .lifetime = OperandLifeTime::{lifetime}, 92 .location = {{.poolIndex = 0, .offset = {offset}, .length = {length}}},{extraParams} 93 }}""" 94 offset = 0 95 op_definitions = [] 96 extra_params_definitions = [] 97 for index, o in enumerate(model.operands): 98 length = o.type.GetByteSize() if isinstance(o, Parameter) else 0 99 add_extra_params = o.type.extraParams is not None and not o.type.extraParams.hide 100 op = { 101 "operand_type": o.type.type, 102 "shape": o.type.GetDimensionsString(), 103 "no_consumers": len(o.outs), 104 "scale": tg.PrettyPrintAsFloat(o.type.scale), 105 "zero_point": str(int(o.type.zeroPoint)), 106 "lifetime": o.lifetime, 107 "offset": offset if isinstance(o, Parameter) else 0, 108 "length": length, 109 "extraParams": "" if not add_extra_params else "\n .extraParams = std::move(extraParams%d)," % (index,), 110 } 111 offset += length 112 op_definitions.append(op_def.format(**op)) 113 114 extra_params_def = """\ 115 Operand::ExtraParams extraParams{index}; 116 extraParams{index}.{setMethodName}({param}); 117 """ 118 119 if add_extra_params: 120 ep = o.type.extraParams 121 op = { 122 "index": index, 123 "setMethodName": ep.GetVtsSetter(), 124 "param": ep.GetVtsConstructor(), 125 } 126 extra_params_definitions.append(extra_params_def.format(**op)) 127 128 op_vec = """{0}\ 129 const std::vector<Operand> operands = {{ 130 {1} 131 }};""".format(",\n".join(extra_params_definitions), ",\n".join(op_definitions)) 132 return op_vec 133 134 # Generate VTS operand values 135 def generate_vts_operand_values(operands): 136 weights = [o for o in operands if isinstance(o, Parameter)] 137 binit = [] 138 for w in weights: 139 ty = w.type.type 140 if ty == "TENSOR_QUANT8_ASYMM": 141 binit += w.value 142 elif ty == "TENSOR_QUANT8_SYMM_PER_CHANNEL" or ty == "TENSOR_QUANT8_SYMM": 143 binit += [struct.pack("b", value)[0] for value in w.value] 144 elif ty == "BOOL" or ty == "TENSOR_BOOL8": 145 binit += [1 if x else 0 for x in w.value] 146 elif ty == "TENSOR_FLOAT16" or ty == "FLOAT16": 147 for f in w.value: 148 # The pack format for float16 is not available until Python 3.6. 149 binit += [int(x) for x in np.float16(f).tostring()] 150 elif ty in {"TENSOR_FLOAT32", "FLOAT32", "TENSOR_INT32", "INT32", "TENSOR_QUANT16_ASYMM"}: 151 if ty in ["TENSOR_FLOAT32", "FLOAT32"]: 152 fmt = "f" 153 elif ty in ["TENSOR_INT32", "INT32"]: 154 fmt = "i" 155 elif ty == "TENSOR_QUANT16_ASYMM": 156 fmt = "H" 157 for f in w.value: 158 binit += [int(x) for x in struct.pack(fmt, f)] 159 else: 160 assert 0 and "Unsupported VTS operand type" 161 162 init_defs = ", ".join([str(x) for x in binit]) 163 if (init_defs != ""): 164 init_defs = "\n %s\n " % init_defs 165 byte_vec_fmt = """{%s}""" % init_defs 166 return byte_vec_fmt 167 168 # Generate VTS operations 169 def generate_vts_operation(op, model): 170 op_fmt = """\ 171 {{ 172 .type = OperationType::{op_code}, 173 .inputs = {{{ins}}}, 174 .outputs = {{{outs}}}, 175 }}""" 176 op_content = { 177 'op_code': op.optype, 178 'ins': tg.GetJointStr(model.GetIndexOfOperands(op.ins)), 179 'outs': tg.GetJointStr(model.GetIndexOfOperands(op.outs)) 180 } 181 return op_fmt.format(**op_content) 182 183 def generate_vts_operations(model): 184 vts_ops = [generate_vts_operation(op, model) for op in model.operations] 185 return ",\n".join(vts_ops) 186 187 def generate_vts_model(model, model_file): 188 operand_values_fmt = "" 189 if Configuration.useSHM(): 190 # Boilerplate code for passing weights in shared memory 191 operand_values_fmt = """\ 192 std::vector<uint8_t> operandValues = {{}}; 193 const uint8_t data[] = {operand_values}; 194 195 // Allocate segment of android shared memory, wrapped in hidl_memory. 196 // This object will be automatically freed when sharedMemory is destroyed. 197 hidl_memory sharedMemory = allocateSharedMemory(sizeof(data)); 198 199 // Mmap ashmem into usable address and hold it within the mappedMemory object. 200 // MappedMemory will automatically munmap the memory when it is destroyed. 201 sp<IMemory> mappedMemory = mapMemory(sharedMemory); 202 203 if (mappedMemory != nullptr) {{ 204 // Retrieve the mmapped pointer. 205 uint8_t* mappedPointer = 206 static_cast<uint8_t*>(static_cast<void*>(mappedMemory->getPointer())); 207 208 if (mappedPointer != nullptr) {{ 209 // Acquire the write lock for the shared memory segment, upload the data, 210 // and release the lock. 211 mappedMemory->update(); 212 std::copy(data, data + sizeof(data), mappedPointer); 213 mappedMemory->commit(); 214 }} 215 }} 216 217 const std::vector<hidl_memory> pools = {{sharedMemory}}; 218 """ 219 else: 220 # Passing weights via operandValues 221 operand_values_fmt = """\ 222 std::vector<uint8_t> operandValues = {operand_values}; 223 const std::vector<hidl_memory> pools = {{}}; 224 """ 225 226 operand_values_val = { 227 'operand_values': generate_vts_operand_values(model.operands) 228 } 229 operand_values = operand_values_fmt.format(**operand_values_val) 230 # operand_values = operand_values_fmt 231 model_fmt = """\ 232 // Create the model 233 Model {create_test_model_name}() {{ 234 {operand_decls} 235 236 const std::vector<Operation> operations = {{ 237 {operations} 238 }}; 239 240 const std::vector<uint32_t> inputIndexes = {{{input_indices}}}; 241 const std::vector<uint32_t> outputIndexes = {{{output_indices}}}; 242 {operand_values} 243 return {{ 244 .operands = operands, 245 .operations = operations, 246 .inputIndexes = inputIndexes, 247 .outputIndexes = outputIndexes, 248 .operandValues = operandValues, 249 .pools = pools,{relaxed_field} 250 }}; 251 }} 252 """ 253 model_dict = { 254 "create_test_model_name": str(model.createTestFunctionName), 255 "operations": generate_vts_operations(model), 256 "operand_decls": generate_vts_operands(model), 257 "operand_values": operand_values, 258 "output_indices": tg.GetJointStr(model.GetOutputsIndex()), 259 "input_indices": tg.GetJointStr(model.GetInputsIndex()), 260 "relaxed_field": 261 "\n .relaxComputationFloat32toFloat16 = true," if (model.isRelaxed) else "" 262 } 263 print(model_fmt.format(**model_dict), file = model_file) 264 265 def generate_vts(model, model_file): 266 assert model.compiled 267 generate_vts_model(model, model_file) 268 DumpCtsIsIgnored(model, model_file) 269 270 def generate_vts_test(example, test_file): 271 testTemplate = """\ 272 TEST_F({test_case_name}, {test_name}) {{ 273 generated_tests::Execute(device, 274 {namespace}::{create_model_name}, 275 {namespace}::{is_ignored_name}, 276 {namespace}::get_{examples_name}(){test_dynamic_output_shape});\n}} 277 278 TEST_F(ValidationTest, {test_name}) {{ 279 const Model model = {namespace}::{create_model_name}(); 280 const std::vector<Request> requests = createRequests({namespace}::get_{examples_name}()); 281 validateEverything(model, requests); 282 }}\n 283 """ 284 if example.model.hasDynamicOutputShape: 285 print("#ifdef NN_TEST_DYNAMIC_OUTPUT_SHAPE", file=test_fd) 286 print(testTemplate.format( 287 test_case_name="DynamicOutputShapeTest" if example.model.hasDynamicOutputShape \ 288 else "NeuralnetworksHidlTest", 289 test_name=str(example.testName), 290 namespace=tg.FileNames.specName, 291 create_model_name=str(example.model.createTestFunctionName), 292 is_ignored_name=str(example.model.isIgnoredFunctionName), 293 examples_name=str(example.examplesName), 294 test_dynamic_output_shape=", true" if example.model.hasDynamicOutputShape else "" 295 ), file=test_fd) 296 if example.model.hasDynamicOutputShape: 297 print("#endif", file=test_fd) 298 299 def InitializeFiles(model_fd, example_fd, test_fd): 300 fileHeader = "// clang-format off\n// Generated file (from: {spec_file}). Do not edit" 301 testFileHeader = """\ 302 // Generated from: {spec_file}. 303 namespace {spec_name} {{ 304 // Generated {spec_name} test 305 #include "{example_file}" 306 // Generated model constructor 307 #include "{model_file}" 308 }} // namespace {spec_name}\n""" 309 # This regex is to remove prefix and get relative path for #include 310 pathRegex = r".*frameworks/ml/nn/(runtime/test/generated/)?" 311 specFileBase = os.path.basename(tg.FileNames.specFile) 312 print(fileHeader.format(spec_file=specFileBase), file=model_fd) 313 print(fileHeader.format(spec_file=specFileBase), file=example_fd) 314 print(testFileHeader.format( 315 spec_file=specFileBase, 316 model_file=re.sub(pathRegex, "", tg.FileNames.modelFile), 317 example_file=re.sub(pathRegex, "", tg.FileNames.exampleFile), 318 spec_name=tg.FileNames.specName), file=test_fd) 319 320 if __name__ == "__main__": 321 ParseCmdLine() 322 while tg.FileNames.NextFile(): 323 print("Generating test(s) from spec: %s" % tg.FileNames.specFile, file=sys.stderr) 324 exec (open(tg.FileNames.specFile, "r").read()) 325 print("Output VTS model: %s" % tg.FileNames.modelFile, file=sys.stderr) 326 print("Output example:" + tg.FileNames.exampleFile, file=sys.stderr) 327 with SmartOpen(tg.FileNames.modelFile) as model_fd, \ 328 SmartOpen(tg.FileNames.exampleFile) as example_fd, \ 329 SmartOpen(tg.FileNames.testFile, mode="a") as test_fd: 330 InitializeFiles(model_fd, example_fd, test_fd) 331 Example.DumpAllExamples( 332 DumpModel=generate_vts, model_fd=model_fd, 333 DumpExample=DumpCtsExample, example_fd=example_fd, 334 DumpTest=generate_vts_test, test_fd=test_fd) 335