Home | History | Annotate | Download | only in test_generator
      1 #!/usr/bin/python3
      2 
      3 # Copyright 2017, The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 """VTS testcase generator
     17 
     18 Implements VTS test backend. Shares most logic with the CTS test
     19 generator. Invoked by ml/nn/runtime/test/specs/generate_vts_tests.sh;
     20 See that script for details on how this script is used.
     21 
     22 """
     23 
     24 from __future__ import absolute_import
     25 from __future__ import division
     26 from __future__ import print_function
     27 import argparse
     28 from functools import reduce
     29 import math
     30 import numpy as np
     31 import os
     32 import re
     33 import struct
     34 import sys
     35 import contextlib
     36 import pprint
     37 
     38 # Stuff from test generator
     39 import test_generator as tg
     40 from test_generator import ActivationConverter
     41 from test_generator import BoolScalar
     42 from test_generator import Configuration
     43 from test_generator import DataTypeConverter
     44 from test_generator import DataLayoutConverter
     45 from test_generator import Example
     46 from test_generator import Float16Scalar
     47 from test_generator import Float32Scalar
     48 from test_generator import Float32Vector
     49 from test_generator import IgnoredOutput
     50 from test_generator import Input
     51 from test_generator import Int32Scalar
     52 from test_generator import Int32Vector
     53 from test_generator import Internal
     54 from test_generator import Model
     55 from test_generator import Operand
     56 from test_generator import Output
     57 from test_generator import Parameter
     58 from test_generator import ParameterAsInputConverter
     59 from test_generator import RelaxedModeConverter
     60 from test_generator import SmartOpen
     61 from test_generator import SymmPerChannelQuantParams
     62 
     63 # Dumping methods that shared with CTS generator
     64 from cts_generator import DumpCtsExample
     65 from cts_generator import DumpCtsIsIgnored
     66 
     67 # Take a model from command line
     68 def ParseCmdLine():
     69     parser = argparse.ArgumentParser()
     70     parser.add_argument("spec", help="the spec file")
     71     parser.add_argument(
     72         "-m", "--model", help="the output model file", default="-")
     73     parser.add_argument(
     74         "-e", "--example", help="the output example file", default="-")
     75     parser.add_argument(
     76         "-t", "--test", help="the output test file", default="-")
     77     args = parser.parse_args()
     78     tg.FileNames.InitializeFileLists(
     79         args.spec, args.model, args.example, args.test)
     80 
     81 # Generate operands in VTS format
     82 def generate_vts_operands(model):
     83   # Dump operand definitions
     84   op_def = """\
     85         {{
     86             .type = OperandType::{operand_type},
     87             .dimensions = {shape},
     88             .numberOfConsumers = {no_consumers},
     89             .scale = {scale},
     90             .zeroPoint = {zero_point},
     91             .lifetime = OperandLifeTime::{lifetime},
     92             .location = {{.poolIndex = 0, .offset = {offset}, .length = {length}}},{extraParams}
     93         }}"""
     94   offset = 0
     95   op_definitions = []
     96   extra_params_definitions = []
     97   for index, o in enumerate(model.operands):
     98     length = o.type.GetByteSize() if isinstance(o, Parameter) else 0
     99     add_extra_params = o.type.extraParams is not None and not o.type.extraParams.hide
    100     op = {
    101         "operand_type": o.type.type,
    102         "shape": o.type.GetDimensionsString(),
    103         "no_consumers": len(o.outs),
    104         "scale": tg.PrettyPrintAsFloat(o.type.scale),
    105         "zero_point": str(int(o.type.zeroPoint)),
    106         "lifetime": o.lifetime,
    107         "offset": offset if isinstance(o, Parameter) else 0,
    108         "length": length,
    109         "extraParams": "" if not add_extra_params else "\n            .extraParams = std::move(extraParams%d)," % (index,),
    110     }
    111     offset += length
    112     op_definitions.append(op_def.format(**op))
    113 
    114     extra_params_def = """\
    115     Operand::ExtraParams extraParams{index};
    116     extraParams{index}.{setMethodName}({param});
    117 """
    118 
    119     if add_extra_params:
    120       ep = o.type.extraParams
    121       op = {
    122           "index": index,
    123           "setMethodName": ep.GetVtsSetter(),
    124           "param": ep.GetVtsConstructor(),
    125       }
    126       extra_params_definitions.append(extra_params_def.format(**op))
    127 
    128   op_vec = """{0}\
    129     const std::vector<Operand> operands = {{
    130 {1}
    131     }};""".format(",\n".join(extra_params_definitions), ",\n".join(op_definitions))
    132   return op_vec
    133 
    134 # Generate VTS operand values
    135 def generate_vts_operand_values(operands):
    136     weights = [o for o in operands if isinstance(o, Parameter)]
    137     binit = []
    138     for w in weights:
    139         ty = w.type.type
    140         if ty == "TENSOR_QUANT8_ASYMM":
    141             binit += w.value
    142         elif ty == "TENSOR_QUANT8_SYMM_PER_CHANNEL" or ty == "TENSOR_QUANT8_SYMM":
    143             binit += [struct.pack("b", value)[0] for value in w.value]
    144         elif ty == "BOOL" or ty == "TENSOR_BOOL8":
    145             binit += [1 if x else 0 for x in w.value]
    146         elif ty == "TENSOR_FLOAT16" or ty == "FLOAT16":
    147             for f in w.value:
    148                 # The pack format for float16 is not available until Python 3.6.
    149                 binit += [int(x) for x in np.float16(f).tostring()]
    150         elif ty in {"TENSOR_FLOAT32", "FLOAT32", "TENSOR_INT32", "INT32", "TENSOR_QUANT16_ASYMM"}:
    151             if ty in ["TENSOR_FLOAT32", "FLOAT32"]:
    152                 fmt = "f"
    153             elif ty in ["TENSOR_INT32", "INT32"]:
    154                 fmt = "i"
    155             elif ty == "TENSOR_QUANT16_ASYMM":
    156                 fmt = "H"
    157             for f in w.value:
    158                 binit += [int(x) for x in struct.pack(fmt, f)]
    159         else:
    160             assert 0 and "Unsupported VTS operand type"
    161 
    162     init_defs = ", ".join([str(x) for x in binit])
    163     if (init_defs != ""):
    164         init_defs = "\n      %s\n    " % init_defs
    165     byte_vec_fmt = """{%s}""" % init_defs
    166     return byte_vec_fmt
    167 
    168 # Generate VTS operations
    169 def generate_vts_operation(op, model):
    170     op_fmt = """\
    171         {{
    172             .type = OperationType::{op_code},
    173             .inputs = {{{ins}}},
    174             .outputs = {{{outs}}},
    175         }}"""
    176     op_content = {
    177         'op_code': op.optype,
    178         'ins': tg.GetJointStr(model.GetIndexOfOperands(op.ins)),
    179         'outs': tg.GetJointStr(model.GetIndexOfOperands(op.outs))
    180     }
    181     return op_fmt.format(**op_content)
    182 
    183 def generate_vts_operations(model):
    184     vts_ops = [generate_vts_operation(op, model) for op in model.operations]
    185     return ",\n".join(vts_ops)
    186 
    187 def generate_vts_model(model, model_file):
    188   operand_values_fmt = ""
    189   if Configuration.useSHM():
    190     # Boilerplate code for passing weights in shared memory
    191     operand_values_fmt = """\
    192     std::vector<uint8_t> operandValues = {{}};
    193     const uint8_t data[] = {operand_values};
    194 
    195     // Allocate segment of android shared memory, wrapped in hidl_memory.
    196     // This object will be automatically freed when sharedMemory is destroyed.
    197     hidl_memory sharedMemory = allocateSharedMemory(sizeof(data));
    198 
    199     // Mmap ashmem into usable address and hold it within the mappedMemory object.
    200     // MappedMemory will automatically munmap the memory when it is destroyed.
    201     sp<IMemory> mappedMemory = mapMemory(sharedMemory);
    202 
    203     if (mappedMemory != nullptr) {{
    204         // Retrieve the mmapped pointer.
    205         uint8_t* mappedPointer =
    206                 static_cast<uint8_t*>(static_cast<void*>(mappedMemory->getPointer()));
    207 
    208         if (mappedPointer != nullptr) {{
    209             // Acquire the write lock for the shared memory segment, upload the data,
    210             // and release the lock.
    211             mappedMemory->update();
    212             std::copy(data, data + sizeof(data), mappedPointer);
    213             mappedMemory->commit();
    214         }}
    215     }}
    216 
    217     const std::vector<hidl_memory> pools = {{sharedMemory}};
    218 """
    219   else:
    220     # Passing weights via operandValues
    221     operand_values_fmt = """\
    222     std::vector<uint8_t> operandValues = {operand_values};
    223     const std::vector<hidl_memory> pools = {{}};
    224 """
    225 
    226   operand_values_val = {
    227       'operand_values': generate_vts_operand_values(model.operands)
    228   }
    229   operand_values = operand_values_fmt.format(**operand_values_val)
    230   #  operand_values = operand_values_fmt
    231   model_fmt = """\
    232 // Create the model
    233 Model {create_test_model_name}() {{
    234 {operand_decls}
    235 
    236     const std::vector<Operation> operations = {{
    237 {operations}
    238     }};
    239 
    240     const std::vector<uint32_t> inputIndexes = {{{input_indices}}};
    241     const std::vector<uint32_t> outputIndexes = {{{output_indices}}};
    242 {operand_values}
    243     return {{
    244         .operands = operands,
    245         .operations = operations,
    246         .inputIndexes = inputIndexes,
    247         .outputIndexes = outputIndexes,
    248         .operandValues = operandValues,
    249         .pools = pools,{relaxed_field}
    250     }};
    251 }}
    252 """
    253   model_dict = {
    254       "create_test_model_name": str(model.createTestFunctionName),
    255       "operations": generate_vts_operations(model),
    256       "operand_decls": generate_vts_operands(model),
    257       "operand_values": operand_values,
    258       "output_indices": tg.GetJointStr(model.GetOutputsIndex()),
    259       "input_indices": tg.GetJointStr(model.GetInputsIndex()),
    260       "relaxed_field":
    261         "\n        .relaxComputationFloat32toFloat16 = true," if (model.isRelaxed) else ""
    262   }
    263   print(model_fmt.format(**model_dict), file = model_file)
    264 
    265 def generate_vts(model, model_file):
    266   assert model.compiled
    267   generate_vts_model(model, model_file)
    268   DumpCtsIsIgnored(model, model_file)
    269 
    270 def generate_vts_test(example, test_file):
    271     testTemplate = """\
    272 TEST_F({test_case_name}, {test_name}) {{
    273   generated_tests::Execute(device,
    274                            {namespace}::{create_model_name},
    275                            {namespace}::{is_ignored_name},
    276                            {namespace}::get_{examples_name}(){test_dynamic_output_shape});\n}}
    277 
    278 TEST_F(ValidationTest, {test_name}) {{
    279   const Model model = {namespace}::{create_model_name}();
    280   const std::vector<Request> requests = createRequests({namespace}::get_{examples_name}());
    281   validateEverything(model, requests);
    282 }}\n
    283 """
    284     if example.model.hasDynamicOutputShape:
    285         print("#ifdef NN_TEST_DYNAMIC_OUTPUT_SHAPE", file=test_fd)
    286     print(testTemplate.format(
    287             test_case_name="DynamicOutputShapeTest" if example.model.hasDynamicOutputShape \
    288                            else "NeuralnetworksHidlTest",
    289             test_name=str(example.testName),
    290             namespace=tg.FileNames.specName,
    291             create_model_name=str(example.model.createTestFunctionName),
    292             is_ignored_name=str(example.model.isIgnoredFunctionName),
    293             examples_name=str(example.examplesName),
    294             test_dynamic_output_shape=", true" if example.model.hasDynamicOutputShape else ""
    295         ), file=test_fd)
    296     if example.model.hasDynamicOutputShape:
    297         print("#endif", file=test_fd)
    298 
    299 def InitializeFiles(model_fd, example_fd, test_fd):
    300     fileHeader = "// clang-format off\n// Generated file (from: {spec_file}). Do not edit"
    301     testFileHeader = """\
    302 // Generated from: {spec_file}.
    303 namespace {spec_name} {{
    304 // Generated {spec_name} test
    305 #include "{example_file}"
    306 // Generated model constructor
    307 #include "{model_file}"
    308 }} // namespace {spec_name}\n"""
    309     # This regex is to remove prefix and get relative path for #include
    310     pathRegex = r".*frameworks/ml/nn/(runtime/test/generated/)?"
    311     specFileBase = os.path.basename(tg.FileNames.specFile)
    312     print(fileHeader.format(spec_file=specFileBase), file=model_fd)
    313     print(fileHeader.format(spec_file=specFileBase), file=example_fd)
    314     print(testFileHeader.format(
    315         spec_file=specFileBase,
    316         model_file=re.sub(pathRegex, "", tg.FileNames.modelFile),
    317         example_file=re.sub(pathRegex, "", tg.FileNames.exampleFile),
    318         spec_name=tg.FileNames.specName), file=test_fd)
    319 
    320 if __name__ == "__main__":
    321     ParseCmdLine()
    322     while tg.FileNames.NextFile():
    323         print("Generating test(s) from spec: %s" % tg.FileNames.specFile, file=sys.stderr)
    324         exec (open(tg.FileNames.specFile, "r").read())
    325         print("Output VTS model: %s" % tg.FileNames.modelFile, file=sys.stderr)
    326         print("Output example:" + tg.FileNames.exampleFile, file=sys.stderr)
    327         with SmartOpen(tg.FileNames.modelFile) as model_fd, \
    328              SmartOpen(tg.FileNames.exampleFile) as example_fd, \
    329              SmartOpen(tg.FileNames.testFile, mode="a") as test_fd:
    330             InitializeFiles(model_fd, example_fd, test_fd)
    331             Example.DumpAllExamples(
    332                 DumpModel=generate_vts, model_fd=model_fd,
    333                 DumpExample=DumpCtsExample, example_fd=example_fd,
    334                 DumpTest=generate_vts_test, test_fd=test_fd)
    335