# Home | History | Annotate | Download | only in python
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """TensorFlow Lite tooling helper functionality.
     16 
     17 EXPERIMENTAL: APIs here are unstable and likely to change without notice.
     18 
     19 @@toco_convert
     20 @@toco_convert_protos
     21 @@OpHint
     22 @@convert_op_hints_to_stubs
     23 
     24 """
     25 from __future__ import absolute_import
     26 from __future__ import division
     27 from __future__ import print_function
     28 import os
     29 import subprocess
     30 import tempfile
     31 
     32 # pylint: disable=unused-import
     33 from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs
     34 from tensorflow.contrib.lite.python.op_hint import OpHint
     35 # pylint: enable=unused-import
     36 from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
     37 from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
     38 from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2
     39 from tensorflow.python.framework import dtypes as _dtypes
     40 from tensorflow.python.platform import resource_loader as _resource_loader
     41 from tensorflow.python.util.all_util import remove_undocumented
     42 from tensorflow.python.util.lazy_loader import LazyLoader
     43 
# Lazy load since some of the performance benchmark skylark rules
# break dependencies.  Must use double quotes to match the string
# passed to LazyLoader below.
# `_toco_python` resolves to the SWIG-wrapped TOCO module
# (tensorflow_wrap_toco) on first attribute access, not at import time.
_toco_python = LazyLoader(
    "tensorflow_wrap_toco", globals(),
    "tensorflow.contrib.lite.toco.python."
    "tensorflow_wrap_toco")
# Remove the helper class from the module namespace so it is not
# accidentally exposed as part of this module's API.
del LazyLoader
     51 
# Enum types from the protobuf promoted to the API so callers can write
# e.g. `lite.FLOAT` / `lite.TFLITE` without importing the *_pb2 modules.

# Tensor element types (see toco/types.proto).
FLOAT = _types_pb2.FLOAT
INT32 = _types_pb2.INT32
INT64 = _types_pb2.INT64
STRING = _types_pb2.STRING
QUANTIZED_UINT8 = _types_pb2.QUANTIZED_UINT8
# File-format enums for input/output (see toco/toco_flags.proto).
TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
TFLITE = _toco_flags_pb2.TFLITE
GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT
     61 
# Currently the default mode of operation is to shell to another python process
# to protect against crashes. However, it breaks some dependent targets because
# it forces us to depend on an external py_binary. The experimental API doesn't
# have that drawback.
EXPERIMENTAL_USE_TOCO_API_DIRECTLY = False

# Find the toco_from_protos binary using the resource loader if using from
# bazel, otherwise we are in a pip where console_scripts already has
# the toco_from_protos tool.
if EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
  # Empty string means "call TOCO in-process via the SWIG wrapper"
  # (see toco_convert_protos below).
  _toco_from_proto_bin = ""
else:
  _toco_from_proto_bin = _resource_loader.get_path_to_datafile(
      "../toco/python/toco_from_protos")

# If the bazel-relative path does not exist, fall back to the bare tool
# name and rely on it being found via PATH (the pip console_scripts case).
if _toco_from_proto_bin and not os.path.exists(_toco_from_proto_bin):
  _toco_from_proto_bin = "toco_from_protos"
     79 
     80 
     81 def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
     82   """Convert `input_data_str` according to model and toco parameters.
     83 
     84   Unless you know what you are doing consider using
     85   the more friendly @{tf.contrib.lite.toco_convert}}.
     86 
     87   Args:
     88     model_flags_str: Serialized proto describing model properties, see
     89       `toco/model_flags.proto`.
     90     toco_flags_str: Serialized proto describing conversion properties, see
     91       `toco/toco_flags.proto`.
     92     input_data_str: Input data in serialized form (e.g. a graphdef is common)
     93   Returns:
     94     Converted model in serialized form (e.g. a TFLITE model is common).
     95   Raises:
     96     RuntimeError: When conversion fails, an exception is raised with the error
     97       message embedded.
     98   """
     99   # TODO(aselle): When toco does not use fatal errors for failure, we can
    100   # switch this on.
    101   if not _toco_from_proto_bin:
    102     return _toco_python.TocoConvert(
    103         model_flags_str, toco_flags_str, input_data_str)
    104 
    105   with tempfile.NamedTemporaryFile() as fp_toco, \
    106            tempfile.NamedTemporaryFile() as fp_model, \
    107            tempfile.NamedTemporaryFile() as fp_input, \
    108            tempfile.NamedTemporaryFile() as fp_output:
    109     fp_model.write(model_flags_str)
    110     fp_toco.write(toco_flags_str)
    111     fp_input.write(input_data_str)
    112     fp_model.flush()
    113     fp_toco.flush()
    114     fp_input.flush()
    115 
    116     cmd = [
    117         _toco_from_proto_bin, fp_model.name, fp_toco.name, fp_input.name,
    118         fp_output.name
    119     ]
    120     cmdline = " ".join(cmd)
    121     proc = subprocess.Popen(
    122         cmdline,
    123         shell=True,
    124         stdout=subprocess.PIPE,
    125         stderr=subprocess.STDOUT,
    126         close_fds=True)
    127     stdout, stderr = proc.communicate()
    128     exitcode = proc.returncode
    129     if exitcode == 0:
    130       stuff = fp_output.read()
    131       return stuff
    132     else:
    133       raise RuntimeError("TOCO failed see console for info.\n%s\n%s\n" %
    134                          (stdout, stderr))
    135 
    136 
    137 def _tensor_name(x):
    138   return x.name.split(":")[0]
    139 
    140 
    141 def toco_convert(input_data,
    142                  input_tensors,
    143                  output_tensors,
    144                  inference_type=FLOAT,
    145                  input_format=TENSORFLOW_GRAPHDEF,
    146                  output_format=TFLITE,
    147                  quantized_input_stats=None,
    148                  drop_control_dependency=True):
    149   """Convert a model using TOCO from `input_format` to `output_format`.
    150 
    151   Typically this is to convert from TensorFlow GraphDef to TFLite, in which
    152   case the default `input_format` and `output_format` are sufficient.
    153 
    154   Args:
    155     input_data: Input data (i.e. often `sess.graph_def`).
    156     input_tensors: List of input tensors. Type and shape are computed using
    157       `foo.get_shape()` and `foo.dtype`.
    158     output_tensors: List of output tensors (only .name is used from this).
    159     inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
    160     input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
    161     output_format: Type of data to write (currently must be TFLITE or
    162       GRAPHVIZ_DOT)
    163     quantized_input_stats: For each member of input_tensors the mean and
    164       std deviation of training data. Only needed if `inference_type` is
    165       `QUANTIZED_UINT8`.
    166     drop_control_dependency: Drops control dependencies silently. This is due
    167       to tf lite not supporting control dependencies.
    168 
    169   Returns:
    170     The converted data. For example if tflite was the destination, then
    171     this will be a tflite flatbuffer in a bytes array.
    172 
    173   Raises:
    174     ValueError: If the input tensor type is unknown
    175     RuntimeError: If TOCO fails to convert (in which case the runtime error's
    176       error text will contain the TOCO error log)
    177   """
    178   toco = _toco_flags_pb2.TocoFlags()
    179   toco.input_format = input_format
    180   toco.output_format = output_format
    181   toco.drop_control_dependency = drop_control_dependency
    182   model = _model_flags_pb2.ModelFlags()
    183   toco.inference_type = inference_type
    184   for idx, input_tensor in enumerate(input_tensors):
    185     if input_tensor.dtype == _dtypes.float32:
    186       tflite_input_type = FLOAT
    187     elif input_tensor.dtype == _dtypes.int32:
    188       tflite_input_type = INT32
    189     elif input_tensor.dtype == _dtypes.int64:
    190       tflite_input_type = INT64
    191     # TODO(aselle): Insert strings when they are available
    192     else:
    193       raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
    194                                                          input_tensor.dtype))
    195 
    196     input_array = model.input_arrays.add()
    197 
    198     if inference_type == QUANTIZED_UINT8:
    199       if tflite_input_type == FLOAT:
    200         tflite_input_type = QUANTIZED_UINT8
    201       input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
    202 
    203     input_array.name = _tensor_name(input_tensor)
    204     input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
    205     toco.inference_input_type = tflite_input_type
    206 
    207   for output_tensor in output_tensors:
    208     model.output_arrays.append(_tensor_name(output_tensor))
    209 
    210   data = toco_convert_protos(model.SerializeToString(),
    211                              toco.SerializeToString(),
    212                              input_data.SerializeToString())
    213   return data
    214 
    215 
# Public API of this module: everything else (helpers, lazy loaders, proto
# modules) is stripped from the namespace by remove_undocumented below.
_allowed_symbols = [
    "FLOAT",
    "INT32",
    "INT64",
    "STRING",
    "QUANTIZED_UINT8",
    "TENSORFLOW_GRAPHDEF",
    "TFLITE",
    "GRAPHVIZ_DOT",
    "EXPERIMENTAL_USE_TOCO_API_DIRECTLY",
]
# Symbols named in the module docstring (@@...) plus _allowed_symbols survive.
remove_undocumented(__name__, _allowed_symbols)
    228