      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Define tflite op hints (intrinsic operations).
     16 
This allows defining a TensorFlow API for tflite operations in Python, with
hints on how they are represented in TensorFlow Lite. In effect this is a
form of tflite intrinsic: it wraps a subgraph of a TensorFlow execution
graph and is useful for LSTMs and other complicated TensorFlow constructions
that are difficult to pattern match in TOCO, but are represented by a single
accelerated tflite op.
     23 
     24 Example:
     25   def tflite_cool_activation(input):
     26     # A cool activation function.
     27     custom = tf.lite.OpHint("cool_activation")
     28     input, = custom.add_inputs(input)
     29     output = tf.sigmoid(input) * input
     30     output, = custom.add_outputs(output)
     31     return output
     32 
     33   image = tf.placeholder(tf.float32, (1, 16, 16, 1))
     34   output = tf.identity(tflite_cool_activation(image))
     35 
     36   session = tf.Session()
     37 
  graphdef_to_convert = tf.lite.experimental.convert_op_hints_to_stubs(session)
     39   tflite_graph = tf.lite.toco_convert(graphdef_to_convert, [image], [output])
     40   with open("/tmp/graph.fb", "wb") as fp:
     41     fp.write(tflite_graph)
     42 
How does it work?

OpHint is a helper that you use when defining a vanilla python function.
It allows you to wrap arguments with tf.identity ops that carry custom
attributes. These attributes allow you to find the original block of ops
that was created. For example, if you use cool_activation above you
essentially get:

a_input = tf.identity(input)
result = tf.multiply(tf.sigmoid(a_input), a_input)
output = tf.identity(result)

Here, a_input and output are identities that carry attributes recording which
argument they represent, the name of the function they should become in
TensorFlow Lite, and a UUID that uniquely identifies a particular invocation.

Once you have built your whole TensorFlow graph, you can run it and train it
as usual, but after you have done that, you need to convert the graph into
a form that replaces the identity-wrapped subgraphs with stub ops. These
ops don't actually exist in the normal TensorFlow runtime, but will be
understood by toco later.
     63 """
     64 
     65 # TODO(aselle): Make this use generic graph transformations.
     66 # TODO(aselle): _tensor_name_base should be called _tensor_name_to_op_name.
     67 
     68 from __future__ import absolute_import
     69 from __future__ import division
     70 from __future__ import print_function
     71 
     72 import collections as _collections
     73 import copy as _copy
     74 import json as _json
     75 import uuid as _uuid
     76 import six as _six
     77 
     78 from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
     79 from tensorflow.core.framework import graph_pb2 as _graph_pb2
     80 from tensorflow.core.framework import node_def_pb2 as _node_def_pb2
     81 from tensorflow.python.framework import ops as _ops
     82 # TODO(aselle): publicize these apis if we continue to use these.
     83 from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
     84 from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
     85 from tensorflow.python.ops import array_ops as _array_ops
     86 from tensorflow.python.util import compat as _compat
     87 from tensorflow.python.util.all_util import remove_undocumented
     88 from tensorflow.python.util.tf_export import tf_export as _tf_export
     89 
     90 
     91 @_tf_export("lite.OpHint")
     92 class OpHint(object):
     93   """A class that helps build tflite function invocations.
     94 
     95   It allows you to take a bunch of TensorFlow ops and annotate the construction
     96   such that toco knows how to convert it to tflite. This embeds a pseudo
     97   function in a TensorFlow graph. This allows embedding high-level API usage
     98   information in a lower level TensorFlow implementation so that an alternative
     99   implementation can be substituted later.
    100 
    101   Essentially, any "input" into this pseudo op is fed into an identity, and
    102   attributes are added to that input before being used by the constituent ops
    103   that make up the pseudo op. A similar process is done to any output that
    104   is to be exported from the current op.
    105 
    106   """
    107   # TODO(aselle): When TensorFlow functions functionality works for arbitrary
    108   # constructs, this mechanism can be retired and changed to use python defun's.
    109 
    110   # Attr constants that are used for representation in the GraphDef. These
    111   # will be used on every Identity op that is involved in a total OpHint.
    112 
    113   # Name of the OpHint function (cosmetic).
    114   FUNCTION_NAME_ATTR = "_tflite_function_name"
    115   # UUID of the function (each OpHint gets a new uuid).
    116   FUNCTION_UUID_ATTR = "_tflite_function_uuid"
  # The input index of the input (or nothing if it is an output).
    118   FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index"
    119   # The output index of the output (or nothing if it is an input).
    120   FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index"
  # An index that orders aggregate arguments. Aggregate arguments are ones
  # that are separate but will be fused horizontally. For example, a static
  # LSTM has an LSTM cell for each time step. Each cell has a separate OpHint,
  # but a fused SequentialLSTM will treat this as a single tensor.
    125   FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index"
    126   # The way in which multiple parts of the aggregate argument will be joined
    127   # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST,
    128   # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK.
    129   FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate"
    130   # On fused OpHint stub, the order of inputs that the final LSTM call will
    131   # have. What this means is that the TensorFlow order might be
    132   # "foo", "bar", "stuff" and you might want the TF lite op order to be
    133   # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
    134   # attribute to [2, 0, 1, -1].
    135   TFLITE_INPUT_INDICES = "_tflite_input_indices"
    136   # OpHint level.
    137   FUNCTION_LEVEL_ATTR = "_tflite_ophint_level"
  # OpHint internal mapping; this is for high-level OpHints only.
    139   # This basically contains three kinds of mapping:
    140   #   1) How parental ophinted inputs map to the first child ophinted inputs;
    141   #   2) How internal children nodes are connected;
    142   #   3) How parental ophinted outputs map to the last child ophinted outputs.
    143   CHILDREN_INPUTS_MAPPINGS = "_tflite_children_ophint_inputs_mapping"
    144 
  # Types of aggregations
  #  stack: stacks all ophints with matching tags, e.g. for a static rnn.
  #   Specifically, this is good for an input or output to a static rnn cell.
  AGGREGATE_STACK = "stack"
  #  first: only takes the first output (the one with the lowest sort index)
  #   of matching tags. This is good for the input state to an RNN.
  AGGREGATE_FIRST = "first"
  #  last: takes only the last tag (the one with the highest sort index).
  #   This is good for an output value on the last stack item of a
  #   static rnn.
  AGGREGATE_LAST = "last"
    156 
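  # A minimal sketch of tagged aggregation (the names here are illustrative,
  # not part of the API): each unrolled step of a static RNN adds its
  # per-step input with the same tag, and the stubbing pass later fuses them
  # into one stacked operand.
  #
  #   hint = OpHint("sketch_rnn_cell")
  #   for step_input in unrolled_step_inputs:
  #     step_input = hint.add_input(
  #         step_input, tag="input", aggregate=OpHint.AGGREGATE_STACK)
  #     ...  # build the per-step cell; wrap its outputs similarly
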
    157   class OpHintArgumentTracker(object):
    """Conceptually tracks indices of arguments of "OpHint functions".

    The inputs and outputs of these functions each use an instance
    of this class so they can have independent numbering.
    """
    163 
    164     def __init__(self,
    165                  function_name,
    166                  unique_function_id,
    167                  node_name_prefix,
    168                  attr_name,
    169                  level=1,
    170                  children_inputs_mappings=None):
    171       """Initialize ophint argument.
    172 
    173       Args:
    174         function_name: Name of the function that this tracks arguments for.
    175         unique_function_id: UUID of function that this tracks arguments for.
    176         node_name_prefix: How identities that are created are named.
        attr_name: Name of attribute to use to store the index for this hint,
          i.e. FUNCTION_INPUT_INDEX_ATTR or FUNCTION_OUTPUT_INDEX_ATTR.
    179         level: Hierarchical level of the Ophint node, a number.
    180         children_inputs_mappings: Inputs/Outputs mapping for children hints.
    181       """
    182 
    183       # The global index is the argument index of the op. This is in contrast
    184       # to the sort index which is the sequence number of a particular instance
    185       # of a given global index. For example, you may have called add hint
      # to the sort index which is the sequence number of a particular instance
      # of a given global index. For example, you may have called add() twice
      # with the tag "foo". Then the global index will be 0 for both
    188       self._function_name = function_name
    189       self._unique_function_id = unique_function_id
    190       self._next_global_index = 0  # The absolute global index
    191       self._used_global_indices = set()
    192       self._tag_to_global_index = {}  # The argument index a given tag maps to
    193       self._tag_to_next_sort_index = {}  # The current index for each tag
    194       self._node_name_prefix = node_name_prefix
    195       self._attr_name = attr_name
    196       self._level = level
    197       self._children_inputs_mappings = children_inputs_mappings
    198 
    199     def _get_new_global_index(self, index_override):
    200       """Return the next unused argument index in order or use an override.
    201 
    202       Args:
    203         index_override: An index to use instead of the next available or None
    204           to use the next available.
    205 
    206       Returns:
    207         A valid global_index to use for the next hint argument.
    208 
    209       Raises:
    210         ValueError: If the index_override is already used by another hint.
    211       """
    212       if index_override is None:
    213         global_index = self._next_global_index
    214       else:
    215         if index_override in self._used_global_indices:
          raise ValueError(
              "Index %d was already used by another call to add" %
              index_override)
    217         global_index = index_override
    218       # Make next_global_index valid
    219       self._used_global_indices.add(global_index)
    220       while self._next_global_index in self._used_global_indices:
    221         self._next_global_index += 1
    222       return global_index
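      # For example (illustrative values): after add(..., index_override=1),
      # the next un-overridden add() gets global index 0, and the one after
      # that gets 2, since index 1 is already taken.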
    223 
    224     def add(self, arg, tag=None, name=None, aggregate=None,
    225             index_override=None):
      """Return `arg` wrapped in an Identity op that carries hint metadata.
    227 
    228       Args:
    229         arg: A TensorFlow tensor that should be considered an argument.
    230         tag: String tag to identify arguments that should be packed.
    231         name: Name of argument. This is included in the Identity hint op names.
        aggregate: Strategy to aggregate. Acceptable values are
          OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST, and
          OpHint.AGGREGATE_STACK. Note, aggregate is only valid if tag is
          specified.
        index_override: Specify which input/output index this argument should
          have in the final stub. i.e. add(arg0, index_override=1);
          add(arg1, index_override=0) will make the final stub be
          stub_func(inputs=[arg1, arg0], outputs=[]) rather than the default
          call-order-based ordering.
    240 
    241       Returns:
    242         A tensor representing the wrapped argument.
    243 
    244       Raises:
    245         ValueError: When indices are not consistent.
    246       """
    247 
    248       # Find the appropriate index
    249       if tag is None:
    250         if aggregate is not None:
    251           raise ValueError("You must specify `tag` if using aggregate.")
    252         global_index = self._get_new_global_index(index_override)
    253         sort_index = None
    254       else:
    255         if aggregate is None:
    256           raise ValueError("You must specify `aggregate` if using tag.")
    257         if tag not in self._tag_to_global_index:
    258           self._tag_to_global_index[tag] = (
    259               self._get_new_global_index(index_override))
    260           self._tag_to_next_sort_index[tag] = 0
        elif (index_override is not None and
              index_override != self._tag_to_global_index[tag]):
    263           raise ValueError(
    264               "Tag %r was called with two indices %r and %r" %
    265               (tag, index_override, self._tag_to_global_index[tag]))
    266         global_index = self._tag_to_global_index[tag]
    267         sort_index = self._tag_to_next_sort_index[tag]
    268         self._tag_to_next_sort_index[tag] += 1
    269 
    270       uuid = self._unique_function_id
    271       name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name,
    272                                     uuid, global_index, sort_index, name)
    273 
    274       identity_op = _array_ops.identity(arg, name=name)
    275 
    276       # pylint: disable=protected-access
    277       identity_op.op._set_attr(
    278           OpHint.FUNCTION_NAME_ATTR,
    279           _attr_value_pb2.AttrValue(
    280               s=_compat.as_bytes(self._function_name)))
    281       identity_op.op._set_attr(
    282           OpHint.FUNCTION_UUID_ATTR,
    283           _attr_value_pb2.AttrValue(
    284               s=_compat.as_bytes(self._unique_function_id)))
    285       identity_op.op._set_attr(
    286           self._attr_name, _attr_value_pb2.AttrValue(i=global_index))
    287       identity_op.op._set_attr(OpHint.FUNCTION_LEVEL_ATTR,
    288                                _attr_value_pb2.AttrValue(i=self._level))
    289       if self._children_inputs_mappings:
    290         identity_op.op._set_attr(
    291             OpHint.CHILDREN_INPUTS_MAPPINGS,
    292             _attr_value_pb2.AttrValue(
    293                 s=_compat.as_bytes(_json.dumps(
    294                     self._children_inputs_mappings))))
    295 
    296       if sort_index is not None:
    297         identity_op.op._set_attr(
    298             OpHint.FUNCTION_SORT_INDEX_ATTR,
    299             _attr_value_pb2.AttrValue(i=sort_index))
    300       if aggregate is not None:
    301         identity_op.op._set_attr(
    302             OpHint.FUNCTION_AGGREGATE_ATTR,
            _attr_value_pb2.AttrValue(s=_compat.as_bytes(aggregate)))
    304       # pylint: enable=protected-access
    305       return identity_op
    306 
    307   def __init__(self,
    308                function_name,
    309                level=1,
    310                children_inputs_mappings=None,
    311                **kwargs):
    """Create an OpHint.
    313 
    314     Args:
      function_name: Name of the function (the custom op name in tflite).
      level: OpHint level.
      children_inputs_mappings: Children OpHint inputs/outputs mapping.
        children_inputs_mappings should look like the following:
        "parent_first_child_input":
            [{"parent_ophint_input_index": num,
              "first_child_ophint_input_index": num}, ...]
        "parent_last_child_output":
            [{"parent_output_index": num, "child_output_index": num}, ...]
        "internal_children_input_output":
            [{"child_input_index": num, "child_output_index": num}, ...]
    325       **kwargs: Keyword arguments of any constant attributes for the function.
    326     """
    327     self._function_name = function_name
    328     self._level = level
    329     if self._level == 1:
    330       assert children_inputs_mappings is None
    331     else:
    332       assert isinstance(children_inputs_mappings, dict)
    333     self._children_inputs_mappings = children_inputs_mappings
    334     if self._children_inputs_mappings is not None:
    335       self._validate_children_inputs_mappings(self._children_inputs_mappings)
    336     self._unique_function_id = _uuid.uuid1().hex  # TODO(aselle): Unique enough?
    337     self._attrs_to_store_later = kwargs
    338     self._stored_attrs = False
    339     self._inputs = OpHint.OpHintArgumentTracker(
    340         self._function_name, self._unique_function_id, "InputHint",
    341         OpHint.FUNCTION_INPUT_INDEX_ATTR, level, self._children_inputs_mappings)
    342     self._outputs = OpHint.OpHintArgumentTracker(
    343         self._function_name, self._unique_function_id, "OutputHint",
    344         OpHint.FUNCTION_OUTPUT_INDEX_ATTR, level,
    345         self._children_inputs_mappings)
    346 
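  # A sketch of a level-2 children_inputs_mappings dict (the indices are
  # illustrative); the keys are the ones checked by
  # _validate_children_inputs_mappings below:
  #
  #   {
  #       "parent_first_child_input": [
  #           {"parent_ophint_input_index": 0,
  #            "first_child_ophint_input_index": 0}],
  #       "internal_children_input_output": [
  #           {"child_input_index": 0, "child_output_index": 0}],
  #       "parent_last_child_output": [
  #           {"parent_output_index": 0, "child_output_index": 0}],
  #   }
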
    347   def _validate_children_inputs_mappings(self, children_inputs_mappings):
    """Validate that children_inputs_mappings is in the right format.
    349 
    350     Args:
    351       children_inputs_mappings: the Children ophint inputs/outputs mapping.
    352     """
    353     assert isinstance(children_inputs_mappings, dict)
    354     assert "parent_first_child_input" in children_inputs_mappings
    355     assert "parent_last_child_output" in children_inputs_mappings
    356     assert "internal_children_input_output" in children_inputs_mappings
    357 
    358     # validate parent_first_child_input.
    359 
    360     def assert_dictlist_has_keys(dictlist, keys):
    361       for dikt in dictlist:
    362         assert isinstance(dikt, dict)
    363         for key in keys:
    364           assert key in dikt
    365 
    366     assert_dictlist_has_keys(
    367         children_inputs_mappings["parent_first_child_input"],
    368         ["parent_ophint_input_index", "first_child_ophint_input_index"])
    369     assert_dictlist_has_keys(
    370         children_inputs_mappings["parent_last_child_output"],
    371         ["parent_output_index", "child_output_index"])
    372     assert_dictlist_has_keys(
    373         children_inputs_mappings["internal_children_input_output"],
    374         ["child_input_index", "child_output_index"])
    375 
    376   def _setattr(self, dest_op, name, value):
    377     tensor_value = _ops.convert_to_tensor(value)
    378     # pylint: disable=protected-access
    379     dest_op.op._set_attr(name, _attr_value_pb2.AttrValue(
    380         tensor=tensor_value.op.node_def.attr["value"].tensor))
    381     # pylint: enable=protected-access
    382 
    383   def add_input(self, *args, **kwargs):
    384     """Add a wrapped input argument to the hint.
    385 
    386     Args:
    387       *args: The input tensor.
    388       **kwargs:
        "name": a label for the argument.
        "tag": a tag to group multiple arguments that will be aggregated, i.e.
          a string like 'cool_input'. Basically multiple inputs can be added
          to the same hint for parallel operations that will eventually be
          combined. An example would be static_rnn which creates multiple
          copies of state or inputs.
        "aggregate": aggregation strategy that is valid only when tag is
          not None. Acceptable values are OpHint.AGGREGATE_FIRST,
          OpHint.AGGREGATE_LAST, and OpHint.AGGREGATE_STACK.
        "index_override": the global index to use. This corresponds to the
          argument order in the final stub that will be generated.
    400     Returns:
    401       The wrapped input tensor.
    402     """
    403     return self._inputs.add(*args, **kwargs)
    404 
    405   def add_output(self, *args, **kwargs):
    406     """Add a wrapped output argument to the hint.
    407 
    408     Args:
    409       *args: The output tensor.
    410       **kwargs:
        "name": a label for the argument.
        "tag": a tag to group multiple arguments that will be aggregated, i.e.
          a string like 'cool_output'. Basically multiple outputs can be added
          to the same hint for parallel operations that will eventually be
          combined. An example would be static_rnn which creates multiple
          copies of state or inputs.
        "aggregate": aggregation strategy that is valid only when tag is
          not None. Acceptable values are OpHint.AGGREGATE_FIRST,
          OpHint.AGGREGATE_LAST, and OpHint.AGGREGATE_STACK.
        "index_override": the global index to use. This corresponds to the
          argument order in the final stub that will be generated.
    422     Returns:
    423       The wrapped output tensor.
    424     """
    425     return self._outputs.add(*args, **kwargs)
    426 
    427   def add_inputs(self, *args, **kwargs):
    428     """Add a sequence of inputs to the function invocation.
    429 
    430     Args:
      *args: List of inputs to be converted (should be tf.Tensor).
      **kwargs: This allows 'names' which should be a list of names.
    Returns:
      Wrapped inputs (identity standins that have additional metadata). These
      are also tf.Tensors.
    436     """
    437     if "names" in kwargs:
    438       return [
    439           self._inputs.add(arg, name=name)
    440           for arg, name in zip(args, kwargs["names"])
    441       ]
    442     else:
    443       return [self._inputs.add(arg) for arg in args]
    444 
    445   def add_outputs(self, *args, **kwargs):
    446     """Add a sequence of outputs to the function invocation.
    447 
    448     Args:
    449       *args: List of outputs to be converted (should be tf.Tensor).
      **kwargs: This allows 'names' which should be a list of names.
    Returns:
      Wrapped outputs (identity standins that have additional metadata). These
      are also tf.Tensors.
    454     """
    455     if "names" in kwargs:
    456       return [
    457           self._outputs.add(arg, name=name)
    458           for arg, name in zip(args, kwargs["names"])
    459       ]
    460     else:
    461       return [self._outputs.add(arg) for arg in args]
    462 
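# A minimal usage sketch of the hint API above (the tensor names and
# some_tf_computation are illustrative):
#
#   hint = OpHint("my_custom_op")
#   x, y = hint.add_inputs(x, y, names=["x", "y"])
#   out = some_tf_computation(x, y)
#   out, = hint.add_outputs(out)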
    463 
    464 class _LiteOperand(object):
  """Abstract operand for a tflite hint function.

  This is a base class that handles representing arguments to an OpHint.
  It also is able to serialize operands to the stubbed graph_def.
  Child classes are responsible for storing information about the hint
  identity operators and for knowing how to serialize to output graphdefs.
    472 
    473   Typically this will be implemented by holding one or more identity nodes
    474   that were previously discovered as hints.
    475   """
    476 
    477   def aggregate_and_return_name_for_input(self, out_graphdef):
    478     """This adds the node(s) to out_graphdef and returns the input node name.
    479 
    480     Args:
    481       out_graphdef: A graphdef that is ready to have this input added.
    482 
    483     Returns:
    484       The output that the stub should use as an input for this operand.
    485 
    486     Raises:
    487       RuntimeError: if the method is not implemented.
    488     """
    489     del out_graphdef
    490     raise RuntimeError("Unimplemented abstract method.")
    491 
    492   def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
    493                                            out_graphdef):
    494     """Add node(s) to graph representing output operands and returns type.
    495 
    496     Args:
    497       fused_op_name: name of the fused op stub name.
    498       output_index: Output index that we are currently processing from stub.
    499       out_graphdef: The destination graphdef we are currently building up.
    500 
    501     Returns:
    502       The datatype of this identity.
    503 
    504     Raises:
    505       RuntimeError: if the method is not implemented.
    506     """
    507     del fused_op_name, output_index, out_graphdef
    508     raise RuntimeError("Unimplemented abstract method.")
    509 
    510 
    511 class _LiteSingleOperand(_LiteOperand):
    512   """A simple operand that is non-aggregated (i.e. most hints)."""
    513 
    514   def __init__(self, node):
    515     _LiteOperand.__init__(self)
    516     self.node = node
    517     self.name = _tensor_name_base(node.name)
    518 
    519   def flatten(self):
    520     return [self.name]
    521 
    522   def aggregate_and_return_name_for_input(self, out_graphdef):
    523     return self.name
    524 
    525   def aggregate_and_return_name_for_output(self, fused_op_name, index,
    526                                            out_graphdef):
    527     output_node = _copy.deepcopy(self.node)
    528     del output_node.input[:]
    529     output_node.input.append(_tensorflow_output_name(fused_op_name, index))
    530     out_graphdef.node.extend([output_node])
    # Use the identity's "T" attr, consistent with _LiteAggregateOperand.
    return self.node.attr["T"].type
    532 
    533   def __str__(self):
    534     return str(self.name)
    535 
    536 
    537 class _LiteAggregateOperand(_LiteOperand):
    538   """An operand for a tflite hint function that is aggregated from many.
    539 
    540   For example, an LSTM is a grid of operators that are all related. Inputs
    541   going into them may need to be fused, so they should all be tracked as
    542   related arguments.
    543   """
    544 
    545   def __init__(self, aggregation):
    546     _LiteOperand.__init__(self)
    547     self.aggregation = aggregation
    548     self.names = {}
    549     self.nodes = {}
    550     self.flattened = None
    551 
    552   def add(self, sort, node):
    553     self.names[sort] = _tensor_name_base(node.name)
    554     self.nodes[sort] = node
    555 
    556   def flatten_nodes(self):
    557     """Return a list of all the node protos in aggregation sorted order."""
    if self.flattened is None:
      self.flattened = [None] * len(self.nodes)
      for idx, node in _six.iteritems(self.nodes):
        self.flattened[idx] = node
      # Check the flattened list (not the dict keys) for missing arguments.
      for n in self.flattened:
        if n is None:
          raise RuntimeError("Aggregate was missing argument.")
    565       if self.aggregation == OpHint.AGGREGATE_FIRST:
    566         self.flattened = self.flattened[:1]
    567       elif self.aggregation == OpHint.AGGREGATE_LAST:
    568         self.flattened = self.flattened[-1:]
    569       elif self.aggregation == OpHint.AGGREGATE_STACK:
    570         pass
    571       else:
    572         raise ValueError(
    573             "Invalid aggregation type %r specified" % self.aggregation)
    574     return self.flattened
    575 
    576   def flatten(self):
    """Return a list of all node names in aggregation sorted order."""
    578     return [_tensor_name_base(x.name) for x in self.flatten_nodes()]
    579 
    580   def aggregate_and_return_name_for_input(self, out_graphdef):
    581     """This adds the nodes to out_graphdef and returns an aggregated output.
    582 
    In particular, if you have 4 inputs to a hint stub, this will be the
    node whose output you can use as the stub's single input. I.e. if you
    have 4 timesteps from a static rnn, then a fused UnidirectionalLSTM will
    expect 1 input with all 4 time steps. So here we make a pack and return
    the output name of that pack.
    588 
    589     Args:
    590       out_graphdef: A graphdef that is ready to have this input added.
    591 
    592     Returns:
    593       The name of a pack that aggregates this node.
    594     """
    595     flattened = self.flatten_nodes()
    596     if len(flattened) == 1:
    597       return _tensor_name_base(flattened[0].name)
    598     else:
    599       new_node = _node_def_pb2.NodeDef()
    600       new_node.op = "Pack"
    601       new_node.name = "OpHintStack-%s" % flattened[0].name
    602       new_node.attr["N"].i = len(flattened)
    603       new_node.attr["T"].type = flattened[0].attr["T"].type
    604       for discrete in flattened:
    605         new_node.input.append(_tensor_name_base(discrete.name))
    606       out_graphdef.node.extend([new_node])
    607       return new_node.name
    608 
    609   def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
    610                                            out_graphdef):
    611     """This adds to `out_graphdef` all the unaggregated outputs.
    612 
    613     I.e. we are outputting from a fused stub, but we need to make it compatible
    614     with the unfused original graph so we insert an unpack. Ideally in a later
    615     stage the unpack -> pack sequences will be removed.
    616 
    617     Args:
    618       fused_op_name: The name of the stub we are in the process of fusing.
      output_index: The output index this object represents.
      out_graphdef: The graphdef we are in the process of building.
    621 
    622     Returns:
    623       The type of the aggregated output (so we can finish building the stub
    624       op).
    625     """
    626     flattened = self.flatten_nodes()
    627     if len(flattened) == 1:
    628       temp_op = _LiteSingleOperand(flattened[0])
    629       return temp_op.aggregate_and_return_name_for_output(
    630           fused_op_name, output_index, out_graphdef)
    631     else:
    632       stack_node = _node_def_pb2.NodeDef()
    633       stack_node.op = "Unpack"
    634       stack_node.name = "OpHintUnstack-%s" % flattened[0].name
    635       stack_node.attr["num"].i = len(flattened)
    636       output_type = flattened[0].attr["T"].type
    637       stack_node.attr["T"].type = output_type
    638       stack_node.input.append(_tensorflow_output_name(
    639           fused_op_name, output_index))
    640       out_graphdef.node.extend([stack_node])
    641 
    642       for idx, discrete in enumerate(flattened):
    643         output_node = _copy.deepcopy(discrete)
    644         del output_node.input[:]
    645         output_node.input.append(_tensorflow_output_name(stack_node.name, idx))
    646         out_graphdef.node.extend([output_node])
    647 
    648       return output_type
    649 
    650   def __str__(self):
    651     s = "\t\t\tAGGREGATE %s\n" % self.aggregation
    # Use six for dict iteration so this works under both Python 2 and 3.
    for sort, val in _six.iteritems(self.names):
    653       s += "\t\t\t%d: %s\n" % (sort, val)
    654     return s
    655 
    656 
    657 class _LiteFuncCall(object):
    658   """Represent a TensorFlow Lite custom function.
    659 
  This is used to accumulate found hints in the graphdef into a single
    661   conceptual unit.
    662 
    663   Attributes:
    664     inputs: inputs to the op (hash from index # to argument)
    665     outputs: outputs to the op (hash from index # to argument)
    666     function_name: the tflite custom op name to use
    uuid: a unique call id for this particular call (i.e. multiple function
      calls would have the same function_name but different uuids).
    params: A map from parameter name to constant value, for op constant data,
      i.e. axis on a reduction, strides on a convolution, etc.
    672     level: Level of the OpHint.
    673     children_inputs_mappings: If the Ophint has children, children inputs
    674       mappings indicate how their inputs & outputs are mapped.
    675   """
    676 
    677   def __init__(self):
    678     self.inputs = {}
    679     self.outputs = {}
    680     self.function_name = None
    681     self.uuid = None
    682     self.params = {}
    683     self.level = -1
    684     self.children_inputs_mappings = {}
    685 
    686   def flattened_inputs_and_outputs(self):
    687     """Return a list of inputs and outputs in a flattened format.
    688 
    689     Returns:
      Tuple of (inputs, outputs), where inputs and outputs are each a list
      of names.
    691     """
    692     def _flatten(input_or_output_dict):
    693       flattened_items = []
    694       for item in input_or_output_dict.values():
    695         flattened_items.extend(item.flatten())
    696       return flattened_items
    697 
    698     return _flatten(self.inputs), _flatten(self.outputs)
    699 
    700   def __str__(self):
    701     def format_args(items):
    702       s = ""
      for idx, item in _six.iteritems(items):
    704         s += ("\t\t%d:\n" % idx) + str(item)
    705       return s
    706 
    707     inputs_str = "\tInputs\n" + format_args(self.inputs)
    708     outputs_str = "\tOutputs\n" + format_args(self.outputs)
    709 
    710     return (
    711         "tflite function %s call %s level %d "
    712         "\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" %
    713         (self.function_name, self.uuid, self.level, inputs_str, outputs_str))
    714 
    715 
    716 def _find_all_hints_in_nodes(nodes):
  """Look at all the input nodes and return a dict of LiteFuncCall objects.

  Args:
    nodes: A list of TensorFlow NodeDef protos to search for LiteFuncCalls.

  Returns:
    A dict mapping function uuid to the `_LiteFuncCall` objects found in
    the nodes.
  """
    726   func_calls = _collections.defaultdict(_LiteFuncCall)
    727 
    728   for node in nodes:
    729     attr = node.attr
    # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip.
    # Check for the attr before reading it: accessing a missing key would
    # insert a default entry into the proto attr map.
    if (OpHint.FUNCTION_UUID_ATTR not in attr
        or not attr[OpHint.FUNCTION_UUID_ATTR].s):
      continue
    uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
    735 
    736     # Start building function
    737     call_def = func_calls[uuid]
    738     call_def.uuid = uuid
    739     call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
    740     call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
    741     # Get sorting and aggregation information
    742 
    743     sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
    744             if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
    if sort == -1:
      sort = None
    746     aggregation = None
    747     if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
    748       aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)
    749 
    750     if OpHint.CHILDREN_INPUTS_MAPPINGS in attr:
    751       call_def.children_inputs_mappings = _json.loads(
    752           _compat.as_text(attr[OpHint.CHILDREN_INPUTS_MAPPINGS].s))
    753 
    754     # Add the input or output
    755     def put_operand(stuff, index, sort, operand, aggregation):
    756       """Add a given index into the function structure."""
    757       if sort is None:
    758         stuff[index] = _LiteSingleOperand(operand)
    759       else:
    760         if index not in stuff:
    761           stuff[index] = _LiteAggregateOperand(aggregation)
    762         stuff[index].add(sort, operand)
    763 
    764     if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
    765       put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i,
    766                   sort, node, aggregation)
    767     if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
    768       put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i,
    769                   sort, node, aggregation)
    770 
    771     # Remember attributes
    772     for a in attr:
    773       if a.startswith("_tflite_attr_"):
        # Strip the "_tflite_attr_" prefix to recover the parameter name.
        call_def.params[a.replace("_tflite_attr_", "")] = attr[a].tensor
    775 
    776   return func_calls
    777 
    778 
    779 def _extract_topology_sequence_mapping(nodes):
    780   return dict(
    781       (_tensor_name_base(node.name), idx) for idx, node in enumerate(nodes))
    782 
    783 
    784 def _find_children_hints_in_while_loop(function_def, nodes_mapping):
    785   """Find children hints and all nodes inside the while loop.
    786 
    787   Args:
    788     function_def: Function def of the while loop.
    nodes_mapping: Dict mapping the while loop's input_arg names to real node
      names.
    790 
    791   Returns:
    792     Ordered children hints and all re-mapped nodes inside the while loop.
    793   """
    794   new_nodes = []
    795 
    796   # Make nodes inside function def inputs point to the real nodes.
    797   for node in function_def.node_def:
    798     for i in range(len(node.input)):
    799       if node.input[i] in nodes_mapping:
    800         node.input[i] = nodes_mapping[node.input[i]]
    801     new_nodes.append(_copy.deepcopy(node))
    802   name_to_seq_num = _extract_topology_sequence_mapping(function_def.node_def)
    803   children_hints = _find_all_hints_in_nodes(new_nodes)
    804   children_hints_q = []
    805   # Ordered by the outputs.
    806   for hint in _six.itervalues(children_hints):
    807     _, output_names = hint.flattened_inputs_and_outputs()
    808     seq = name_to_seq_num[output_names[0]]
    809     for output_name in output_names:
    810       seq = min(seq, name_to_seq_num[output_name])
    811     children_hints_q.append((seq, hint))
    812   children_hints_q.sort(key=lambda tup: tup[0])
    813   ordered_children_hints = [x[1] for x in children_hints_q]
    814   return ordered_children_hints, new_nodes
    815 
    816 
    817 def _find_children_hints(call, graph_def):
    818   """Find all children hints.
    819 
  For a given OpHint, we find all children hints inside it. We also copy all
  the nodes inside function defs (if applicable) to the original graph_def;
  they are returned in a list as well.
    823 
    824   Args:
    825     call: Parent OpHint that contains children ophints.
    826     graph_def: Original graph def.
    827 
    828   Returns:
    829     Ordered children hints inside the parent ophint; new graph def that contains
    830     nodes inside function defs (if applicable); nodes inside function defs.
    831   """
    832   name_to_input_name, _, _ = _extract_graph_summary(graph_def)
    833   input_names, output_names = call.flattened_inputs_and_outputs()
    834 
    835   reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
    836   reachable_by_output = _bfs_for_reachable_nodes(output_names,
    837                                                  name_to_input_name)
    838   output_nodes_set = set(output_names)
    839   children_hints = []
    840   out = _graph_pb2.GraphDef()
    841   out.library.CopyFrom(graph_def.library)
    842   out.versions.CopyFrom(graph_def.versions)
    843   function_def_nodes = set()
    844   for node in graph_def.node:
    845     out.node.extend([_copy.deepcopy(node)])
    846     n = _tensor_name_base(node.name)
    847     if n in reachable_by_output:
    848       if n not in reachable_by_input and n not in output_nodes_set:
        # Special handling for while loop function defs.
    850         if node.op == "While":
    851           body_name = node.attr["body"].func.name
    852           inputs_outside_loop = node.input
    853           for function_def in graph_def.library.function:
    854             if function_def.signature.name == body_name:
    855               function_inputs = function_def.signature.input_arg
    856               assert len(inputs_outside_loop) == len(function_inputs)
    857               nodes_mapping = {}
    858               for i in range(len(function_inputs)):
    859                 nodes_mapping[function_inputs[i].name] = inputs_outside_loop[i]
    860               # TODO(b/123050804): Consider use grappler.
    861               (children_hints_in_loop,
    862                new_nodes) = _find_children_hints_in_while_loop(
    863                    function_def, nodes_mapping)
    864               function_def_nodes.update([x.name for x in new_nodes])
    865               children_hints.extend(children_hints_in_loop)
    866               out.node.extend(new_nodes)
    867 
    868   return children_hints, out, function_def_nodes
    869 
    870 
    871 def _tensor_name_base(full_tensor_name):
  """Removes the output port (and any control-input marker) from a tensor name.

  e.g. _tensor_name_base("foo:3") => "foo"
       _tensor_name_base("^foo") => "foo"

  Args:
    full_tensor_name: A tensor name that is annotated with an output port
      (this is what TensorFlow introspection gives).
  Returns:
    A bare op name, without any output port or control-input marker.
    881   """
    882   if full_tensor_name.startswith("^"):
    883     return full_tensor_name[1:]
    884   return full_tensor_name.split(":")[0]
    885 
    886 
    887 def _tensorflow_output_name(tensor_name, output_index):
    888   return tensor_name if output_index == 0 else "%s:%d" % (tensor_name,
    889                                                           output_index)
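# For example: _tensorflow_output_name("node", 0) => "node", while
# _tensorflow_output_name("node", 2) => "node:2", matching TensorFlow's
# tensor naming convention.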
    890 
    891 
    892 # TODO(aselle): This should be converted to grappler in the future.
    893 def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
    894                            name_to_input_name):
    895   """Checks to make sure node only connects to predecessor graph through inputs.
    896 
    897   Args:
    898     n: Node to check
    899     reachable_by_input: Nodes that are reachable by all inputs of subgraph
    900     input_nodes_set: The set of nodes that are "inputs".
    901     name_to_input_name: Maps from name to the list of inputs.
    902 
    903   Raises:
    904     TypeError: If the given node uses items past inputs directly.
    905   """
    906   next_to_visit = [n]
    907   visited = set()
    908   while next_to_visit:
    909     current_node = next_to_visit.pop()
    910     visited.add(current_node)
    911     if (current_node in reachable_by_input
    912         and current_node not in input_nodes_set):
    913       raise TypeError(
    914           "Node %s uses input %s not in input_nodes." % (n, current_node))
    915     if current_node not in input_nodes_set:
    916       next_to_visit += [
    917           input_node for input_node in name_to_input_name[current_node]
    918           if input_node not in visited
    919       ]
    920 
    921 
    922 # TODO(aselle): This should be converted to grappler in the future.
    923 def _convert_single_op_hint_to_stub(call,
    924                                     graph_def,
    925                                     function_def_nodes=None,
    926                                     is_last_run=True):
    927   """Given a graph_def, converts `call` into a stub and returns a new graph_def.
    928 
    929   Args:
    930     call: A single function call to be converted.
    graph_def: A graph_def to use as input (it must contain `call`).
    function_def_nodes: Names of nodes inside function defs that are not
      connected to the graph.
    is_last_run: Whether this is the last run for a given pass (used when the
      OpHint has children).
    936 
    937   Returns:
    938     A new transformed graph-def that has call as a stub (single op).
    939 
  Note: after this process, the graph_def can no longer be loaded into
      the TensorFlow runtime, so all future manipulations are done at the
      graph_def level.
    943   """
    944   if function_def_nodes is None:
    945     function_def_nodes = set()
    946   name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
    947       graph_def)
    948   input_names, output_names = call.flattened_inputs_and_outputs()
    949 
    950   reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
    951   reachable_by_output = _bfs_for_reachable_nodes(output_names,
    952                                                  name_to_input_name)
    953   output_nodes_set = set(output_names)
    954   nodes_after_fuse = []
    955   nodes_deleted_by_fuse = set()
  # Classify each node. We want to keep everything reachable from the inputs;
  # nodes reachable only from the outputs are deleted by the fuse, and nodes
  # reachable from neither come after the fused op.
    959   for node in graph_def.node:
    960     n = _tensor_name_base(node.name)
    961     if n in reachable_by_output:
    962       if n not in reachable_by_input and n not in output_nodes_set:
    963         nodes_deleted_by_fuse.add(n)
    964     elif n not in reachable_by_input and n not in function_def_nodes:
      # n is a node that comes after all the fusings, so keep it.
    966       nodes_after_fuse.append(n)
    967     else:
      # In the last run, n is a node that is in the graph but not connected
      # to the chain of dependencies; we delete it. Otherwise we keep such
      # nodes.
    971       if not is_last_run:
    972         nodes_after_fuse.append(n)
    973 
    974   # Make a new graphdef with all the pre-input and input nodes
    975   out = _graph_pb2.GraphDef()
    976   reachable_by_input_sorted = sorted(
    977       list(reachable_by_input), key=lambda n: name_to_seq_num[n])
    978   for node in reachable_by_input_sorted:
    979     out.node.extend([_copy.deepcopy(name_to_node[node])])
    980 
  # Create any stacks to aggregate arguments into a single input,
    982   # i.e. for static_rnn's.
    983   # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
    984   sorted_input_indices = list(call.inputs.keys())
    985   sorted_input_indices.sort()
    986   sorted_output_indices = list(call.outputs.keys())
    987   sorted_output_indices.sort()
    988   new_node = _node_def_pb2.NodeDef()
    989   # Delegate to each operand to produce the proper new input for this stub node.
    990   # In particular, an aggregate input will now be a Pack of some previously
    991   # non-fused things.
    992   for input_index in sorted_input_indices:
    993     inputs = call.inputs[input_index]
    994     input_name = inputs.aggregate_and_return_name_for_input(out)
    995     new_node.input.append(input_name)
    996   new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)
    997 
    998   # Create the function
    999   new_node.op = call.function_name
   1000   new_node.name = call.uuid
   1001   out.node.extend([new_node])
   1002 
   1003   # Now call each output argument to give them a chance to make the proper
   1004   # output type and add it to our new_node.
   1005   output_dtypes = []
   1006   for output_index in sorted_output_indices:
   1007     output = call.outputs[output_index]
   1008     output_dtype = (
   1009         output.aggregate_and_return_name_for_output(new_node.name, output_index,
   1010                                                     out))
   1011     output_dtypes.append(output_dtype)
   1012   new_node.attr["_output_types"].list.type[:] = output_dtypes
   1013   # TODO(aselle): what is right here?
   1014   new_node.attr["_output_quantized"].b = False
   1015 
   1016   # Add post output nodes that do not depend on the outputs
   1017   for n in nodes_after_fuse:
   1018     should_keep = True
   1019     for input_name in name_to_input_name[n]:
   1020       if input_name in nodes_deleted_by_fuse:
   1021         should_keep = False
   1022     if should_keep:
   1023       out.node.extend([_copy.deepcopy(name_to_node[n])])
   1024 
   1025   # Misc. graph_def data that needs copying.
   1026   out.library.CopyFrom(graph_def.library)
   1027   out.versions.CopyFrom(graph_def.versions)
   1028 
   1029   return out
   1030 
   1031 
   1032 # TODO(aselle): This should be converted to grappler in the future.
   1033 def _remove_one_redundant_stack_unstack(in_graph_def):
  """Removes one redundant stack/unstack pattern, returning a new graph.
   1035 
   1036   Args:
   1037     in_graph_def: Graph def to use as input.
   1038   Returns:
   1039     Simplified tuple (graph_def, changed_something) where changed_something
   1040     is true if anything was done.
   1041   """
   1042   name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
   1043       in_graph_def)
   1044   del name_to_seq_num
   1045 
   1046   # TODO(aselle): Make this not hardcoded.
   1047   do_generic_pack_unpack = True
   1048 
   1049   out = _graph_pb2.GraphDef()
   1050   out.library.CopyFrom(in_graph_def.library)
   1051   out.versions.CopyFrom(in_graph_def.versions)
   1052   for n in in_graph_def.node:
   1053     node_name = _tensor_name_base(n.name)
   1054     if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
   1055       continue
   1056     next_to_visit = [node_name]
   1057     visited = set()
   1058 
   1059     unpack_nodes = set()
   1060     pack_node = node_name
   1061 
    # Find a pattern of an unstack connected to a stack (with identities
    # in between).
   1064     matches_pattern = True
   1065     is_hint_created_stack = False
   1066     while next_to_visit:
   1067       current_node_name = next_to_visit[0]
   1068       visited.add(current_node_name)
   1069       del next_to_visit[0]
   1070       node = name_to_node[current_node_name]
   1071       is_op_hint_stack = node.name.startswith("OpHintStack")
   1072       is_op_hint_unstack = node.name.startswith("OpHintUnstack")
   1073       if (node.op == "Identity" or is_op_hint_stack
   1074           or (do_generic_pack_unpack and node.op == "Pack")):
   1075         is_hint_created_stack |= is_op_hint_stack
   1076         next_to_visit += [
   1077             input_node for input_node in name_to_input_name[current_node_name]
   1078             if input_node not in visited
   1079         ]
   1080       elif (is_op_hint_unstack
   1081             or (do_generic_pack_unpack and node.op == "Unpack")):
   1082         unpack_nodes.add(node.name)
   1083         is_hint_created_stack &= is_op_hint_unstack
   1084       else:
   1085         matches_pattern = False
   1086         break
   1087       visited.add(node.name)
   1088 
   1089     if matches_pattern and len(unpack_nodes) == 1:
   1090       pack_node = node_name
   1091 
   1092       # Check to see if anyone depends on the intermediate identity or the
   1093       # Unstacked form
   1094       no_external_dependency = True
   1095       for other_n in in_graph_def.node:
   1096         if other_n.name in visited: continue
   1097         for input_tensor in name_to_input_name[other_n.name]:
   1098           input_op = _tensor_name_base(input_tensor)
   1099           if input_op in visited and input_op != pack_node:
   1100             no_external_dependency = False
   1101       # Proceed with the substitution if the stack/unstack pair was created
   1102       # through hints, or that it was not, but nobody is consuming things
   1103       # between the stack and unstack.
   1104       if is_hint_created_stack or no_external_dependency:
   1105         end = unpack_nodes.pop()
   1106         end_input = name_to_node[end].input[0]
        # All nodes that depend on the final stack need to be rewired to use
        # the unstack's input instead.
   1108         for other_n in in_graph_def.node:
   1109           node_name = _tensor_name_base(other_n.name)
   1110           if node_name not in visited:
   1111             new_node = _copy.deepcopy(other_n)
   1112             new_node.input[:] = [
   1113                 (end_input if stripped == pack_node else
   1114                  non_stripped) for stripped, non_stripped in zip(
   1115                      name_to_input_name[node_name], new_node.input[:])
   1116             ]
   1117             out.node.extend([new_node])
   1118         return out, True
   1119   return in_graph_def, False
   1120 
   1121 
   1122 def _remove_redundant_stack_unstack(graph_def):
   1123   curr = graph_def
   1124   del graph_def
   1125   changed_stuff = True
   1126   while changed_stuff:
   1127     curr, changed_stuff = _remove_one_redundant_stack_unstack(curr)
   1128   return curr
   1129 
   1130 
def _get_correct_mapping(original_index, nodes):
  # Special handling for the index = -1 case: return the last index.
  if original_index == -1:
    node_indices = sorted(nodes.keys())
    return node_indices[-1]
  return original_index
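# For example (illustrative): _get_correct_mapping(-1, {0: a, 1: b, 2: c})
# returns 2 (the last index), while any non-negative index is returned
# unchanged.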
   1141 
   1142 
   1143 def _convert_op_hints_to_stubs_helper(
    graph_def, write_callback=lambda graph_def, comments: None):
   1145   """Converts a graph_def to a new graph_def where all op hints are stubbed.
   1146 
   1147   Args:
   1148     graph_def: A graph def that we should convert.
   1149     write_callback: A function pointer that can be used to write intermediate
   1150       steps of graph transformation (optional).
   1151   Returns:
   1152     A new stubbed graph_def.
   1153   """
   1154   hints = _find_all_hints_in_nodes(graph_def.node)
   1155 
   1156   hints_q = []
   1157   for hint in _six.itervalues(hints):
   1158     hints_q.append((hint.level, hint.uuid))
   1159 
   1160   hints_q.sort(key=lambda tup: tup[0])
   1163 
   1164   curr_graph_def = graph_def
   1165   del graph_def  # prevent using graph_def again (common source of error)
   1166   for i in range(len(hints_q) - 1, -1, -1):
   1167     level, hint_uuid = hints_q[i]
   1168     if level >= 2:
   1169       children_hints, curr_graph_def, function_def_nodes = _find_children_hints(
   1170           hints[hint_uuid], curr_graph_def)
   1171       # pylint: disable=superfluous-parens
   1172       assert (len(children_hints) > 0)  #  pylint: disable=g-explicit-length-test
   1173       # pylint: enable=superfluous-parens
   1174 
      # Re-wire the children hints' inputs/outputs, so a later child's inputs
      # connect to the previous child node's outputs.
   1177       children_inputs_mappings = hints[hint_uuid].children_inputs_mappings
   1178       for j in range(len(children_hints)):
   1179         child_hint = children_hints[j]
   1180         if j == 0:
   1181           for mapping in children_inputs_mappings["parent_first_child_input"]:
   1182             parent_input_index = _get_correct_mapping(
   1183                 mapping["parent_ophint_input_index"], hints[hint_uuid].inputs)
   1184             child_input_index = _get_correct_mapping(
   1185                 mapping["first_child_ophint_input_index"], child_hint.inputs)
   1186             child_hint.inputs[child_input_index] = hints[hint_uuid].inputs[
   1187                 parent_input_index]
   1188         else:
   1189           for mapping in children_inputs_mappings[
   1190               "internal_children_input_output"]:
   1191             input_index = _get_correct_mapping(mapping["child_input_index"],
   1192                                                child_hint.inputs)
   1193             output_index = _get_correct_mapping(mapping["child_output_index"],
   1194                                                 children_hints[j - 1].outputs)
   1195             child_hint.inputs[input_index] = children_hints[
   1196                 j - 1].outputs[output_index]
   1197         if j == len(children_hints) - 1:
   1198           for mapping in children_inputs_mappings["parent_last_child_output"]:
   1199             parent_output_index = _get_correct_mapping(
   1200                 mapping["parent_output_index"], hints[hint_uuid].outputs)
   1201             child_output_index = _get_correct_mapping(
   1202                 mapping["child_output_index"], child_hint.outputs)
   1203             child_hint.outputs[child_output_index] = hints[hint_uuid].outputs[
   1204                 parent_output_index]
   1205 
   1206       for j in range(len(children_hints)):
   1207         child_hint = children_hints[j]
   1208         curr_graph_def = _convert_single_op_hint_to_stub(
   1209             child_hint, curr_graph_def, function_def_nodes,
   1210             j == len(children_hints) - 1)
   1211     else:
   1212       curr_graph_def = _convert_single_op_hint_to_stub(hints[hint_uuid],
   1213                                                        curr_graph_def)
   1214       write_callback(curr_graph_def, "initial")
  # The stubbing process can create stacks/unstacks in the case of LSTMs;
  # remove them.
   1217   curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
   1218   return curr_graph_def
   1219 
   1220 
   1221 def find_all_hinted_output_nodes(session=None, graph_def=None):
   1222   """Find all Ophints output nodes in the graph.
   1223 
  This is used to get all the output nodes that are ophinted; it is important
  for operations like convert_variables_to_constants to keep all the ophint
  structure.
   1226   Note: only one of session or graph_def should be used, not both.
   1227 
   1228   Args:
   1229     session: A TensorFlow session that contains the graph to convert.
   1230     graph_def: A graph def that we should convert.
   1231 
   1232   Returns:
   1233     A list of OpHints output nodes.
   1234   Raises:
    ValueError: If both or neither of session and graph_def are provided.
   1236   """
   1237   if session is not None and graph_def is not None:
   1238     raise ValueError("Provide only one of session and graph_def.")
   1239   hinted_outputs_nodes = []
   1240   if session is not None:
   1241     hints = _find_all_hints_in_nodes(session.graph_def.node)
   1242   elif graph_def is not None:
    hints = _find_all_hints_in_nodes(graph_def.node)
  else:
    raise ValueError("Must specify session or graph_def as input.")
   1244   for hint in _six.itervalues(hints):
    _, output_nodes = hint.flattened_inputs_and_outputs()
    hinted_outputs_nodes.extend(output_nodes)
   1247   return hinted_outputs_nodes
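# A minimal usage sketch (assumes a tf.Session `sess` whose graph contains
# OpHints, and graph_util from tensorflow.python.framework):
#
#   output_nodes = find_all_hinted_output_nodes(session=sess)
#   output_nodes.append("my_regular_output")  # illustrative extra output
#   frozen_graph_def = graph_util.convert_variables_to_constants(
#       sess, sess.graph_def, output_nodes)
#
# Keeping the hinted output nodes alive ensures freezing does not prune the
# identity ops that carry the OpHint attributes.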
   1248 
   1249 
   1250 @_tf_export("lite.experimental.convert_op_hints_to_stubs")
   1251 def convert_op_hints_to_stubs(session=None,
   1252                               graph_def=None,
   1253                               write_callback=lambda graph_def, comments: None):
   1254   """Converts a graphdef with LiteOp hints into stub operations.
   1255 
   1256   This is used to prepare for toco conversion of complex intrinsic usages.
   1257   Note: only one of session or graph_def should be used, not both.
   1258 
   1259   Args:
   1260     session: A TensorFlow session that contains the graph to convert.
   1261     graph_def: A graph def that we should convert.
   1262     write_callback: A function pointer that can be used to write intermediate
   1263       steps of graph transformation (optional).
   1264   Returns:
   1265     A new graphdef with all ops contained in OpHints being replaced by
   1266     a single op call with the right parameters.
   1267   Raises:
   1268     ValueError: If both session and graph_def are provided.
   1269   """
   1270 
   1271   if session is not None and graph_def is not None:
   1272     raise ValueError("Provide only one of session and graph_def.")
   1273 
   1274   if session is not None:
   1275     return _convert_op_hints_to_stubs_helper(session.graph_def, write_callback)
   1276   elif graph_def is not None:
   1277     return _convert_op_hints_to_stubs_helper(graph_def, write_callback)
   1278   else:
   1279     raise ValueError("Must specify session or graph_def as input.")
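# A minimal usage sketch (mirroring the module docstring; `sess` is an
# illustrative tf.Session):
#
#   stubbed_graph_def = convert_op_hints_to_stubs(session=sess)
#   # or, operating on a GraphDef directly:
#   stubbed_graph_def = convert_op_hints_to_stubs(graph_def=sess.graph_def)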
   1280 
   1281 
   1282 _allowed_symbols = [
   1283     "OpHint", "convert_op_hints_to_stubs", "convert_op_hints_to_stubs_new",
   1284     "find_all_hinted_output_nodes"
   1285 ]
   1286 remove_undocumented(__name__, _allowed_symbols)
   1287