Home | History | Annotate | Download | only in test_generator
      1 #!/usr/bin/python3
      2 
      3 # Copyright 2017, The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 """Slicing the input Model file
     17 
     18 Invoked by ml/nn/runtime/test/specs/slicing.sh; this Python code is
     19 not intended to be invoked directly by users. See that script for
     20 details on how to use the slicing tool.
     21 
     22 This script does the following work:
     23 
     24 Perform a topological sort similar to the test generator, except that:
     25 * It stops at the N-th operation it encounters,
     26 * It renames the output of the N-th operation to a model output,
     27 * It names that operand as the output of the model, and
     28 * Only the inputs and weights used by the submodel are emitted.
     29 
     30 """
     31 
     32 from __future__ import absolute_import
     33 from __future__ import division
     34 from __future__ import print_function
     35 import argparse
     36 from functools import reduce
     37 import math
     38 import os
     39 import struct
     40 import sys
     41 import contextlib
     42 import test_generator
     43 import pprint
     44 # Stuff from test generator
     45 from test_generator import Configuration
     46 from test_generator import Example
     47 from test_generator import Float32Scalar
     48 from test_generator import Float32Vector
     49 from test_generator import IgnoredOutput
     50 from test_generator import Input
     51 from test_generator import Int32Scalar
     52 from test_generator import Int32Vector
     53 from test_generator import Internal
     54 from test_generator import Model
     55 from test_generator import Output
     56 from test_generator import Parameter
     57 from test_generator import SmartOpen
     58 
     59 
     60 # Take a model from command line
     61 def import_source():
     62   parser = argparse.ArgumentParser()
     63   parser.add_argument("spec", help="the spec file")
     64   parser.add_argument(
     65       "-n", "--number",
     66       help="number of operations in the sliced model. Default = 1",
     67       default=1)
     68   parser.add_argument(
     69       "-m", "--model", help="the output model file", default="-")
     70   parser.add_argument(
     71       "-e", "--example", help="the output example file", default="-")
     72   args = parser.parse_args()
     73 
     74   if os.path.exists(args.spec):
     75     test_generator.FileNames.specFile = os.path.basename(args.spec)
     76     exec (open(args.spec).read())
     77   else:
     78     print("cannot find file %s" % args.spec)
     79     sys.exit(1)
     80 
     81   return (args.model, args.example, args.number)
     82 
     83 
     84 # Slice up to the N-th op found by the topological sort;
     85 # the output of that op becomes the output of the model.
     86 class slicing:
     87 
     88   def __init__(self, threshold):
     89     self.__nr_op_seen = 0
     90     self.__threshold = threshold
     91     self.__last_outs = []
     92     self.__all_formatted_ops = []
     93     self.__referenced_operands = set()
     94 
     95   def format_as_py_op(self, op):
     96     fmt = op.PyDefinition()
     97     if fmt is not None:
     98       self.__nr_op_seen += 1
     99       if self.__nr_op_seen > self.__threshold:
    100         return False
    101       self.__last_outs = op.outs
    102       for o in op.ins:
    103         self.__referenced_operands.add(o)
    104       for o in op.outs:
    105         self.__referenced_operands.add(o)
    106       self.__all_formatted_ops.append("model = model.%s" % fmt)
    107       return True
    108 
    109   def dump(self, model_file):
    110     for x in self.__all_formatted_ops:
    111       print(x, file=model_file)
    112 
    113   def dump_example(self, example_file):
    114     override = {}
    115     # Make alias for the output variable
    116     for lo in self.__last_outs:
    117       override[str(lo)] = lo.type.GetNumberOfElements()
    118       alias_def = """\
    119 # Alias for the output variable {operand_name}
    120 aliased_output{number} = {operand_name}
    121 """
    122       op = {
    123           'operand_name': str(lo),
    124           'number': 0 # only support one output as of now
    125       }
    126       print (alias_def.format(**op), file=example_file)
    127     Example.py_dump(example_file, override, self.__referenced_operands)
    128 
    129   def format_operands(self, model):
    130     # Dump operand definitions
    131     op_definitions = []
    132     for o in model.operands:
    133       if o not in self.__referenced_operands:
    134         continue
    135       ty = o.type
    136       op_def = """{op_name} = {operand}("{op_name}", "{element_type}", "{shape}" """
    137       if isinstance(o, test_generator.Parameter):
    138         op_def += """, {initializer})"""
    139         init = o.value
    140         py_operand_name = "Parameter"
    141       else:
    142         op_def += ")"
    143         init = []
    144         py_operand_name = "IgnoredOutput" if o in set(
    145             self.__last_outs) else o.__class__.__name__
    146 
    147       op = {
    148           "element_type": ty.type,
    149           "shape": ty.GetRawShape(),
    150           "op_name": str(o),
    151           "operand": py_operand_name,
    152           "initializer": init
    153       }
    154       op_definitions.append(op_def.format(**op))
    155     return "\n".join(op_definitions)
    156 
    157 
    158 if __name__ == "__main__":
    159   (model, example, number) = import_source()
    160   s = slicing(int(number))
    161 
    162   with SmartOpen(model) as model_file:
    163     spec_file = " (from: %s)" % (test_generator.FileNames.specFile)
    164     print("# Generated file%s. Do not edit" % (spec_file), file=model_file)
    165     print("model = Model()", file=model_file)
    166     # slicing tool only support one single model per spec file
    167     model = Model.models[0].Compile()
    168     for op in model.operations:
    169         s.format_as_py_op(op)
    170     print(s.format_operands(model), file=model_file)
    171     s.dump(model_file)
    172   with SmartOpen(example) as example_file:
    173     s.dump_example(example_file)
    174