1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """TensorFlow Lite tooling helper functionality. 16 17 EXPERIMENTAL: APIs here are unstable and likely to change without notice. 18 19 @@toco_convert 20 @@toco_convert_protos 21 @@OpHint 22 @@convert_op_hints_to_stubs 23 24 """ 25 from __future__ import absolute_import 26 from __future__ import division 27 from __future__ import print_function 28 import os 29 import subprocess 30 import tempfile 31 32 # pylint: disable=unused-import 33 from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs 34 from tensorflow.contrib.lite.python.op_hint import OpHint 35 # pylint: enable=unused-import 36 from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 37 from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 38 from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 39 from tensorflow.python.framework import dtypes as _dtypes 40 from tensorflow.python.platform import resource_loader as _resource_loader 41 from tensorflow.python.util.all_util import remove_undocumented 42 from tensorflow.python.util.lazy_loader import LazyLoader 43 44 # Lazy load since some of the performance benchmark skylark rules 45 # break dependencies. 
_toco_python = LazyLoader(
    "tensorflow_wrap_toco", globals(),
    "tensorflow.contrib.lite.toco.python."
    "tensorflow_wrap_toco")
del LazyLoader

# Enum types from the protobuf promoted to the API
FLOAT = _types_pb2.FLOAT
INT32 = _types_pb2.INT32
INT64 = _types_pb2.INT64
STRING = _types_pb2.STRING
QUANTIZED_UINT8 = _types_pb2.QUANTIZED_UINT8
TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
TFLITE = _toco_flags_pb2.TFLITE
GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT

# Currently the default mode of operation is to shell to another python process
# to protect against crashes. However, it breaks some dependent targets because
# it forces us to depend on an external py_binary. The experimental API doesn't
# have that drawback.
# Note: this flag is only consulted once, at import time, below; reassigning it
# after import does not change `_toco_from_proto_bin`.
EXPERIMENTAL_USE_TOCO_API_DIRECTLY = False

# Find the toco_from_protos binary using the resource loader if using from
# bazel, otherwise we are in a pip where console_scripts already has
# the toco_from_protos tool.
if EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
  _toco_from_proto_bin = ""
else:
  _toco_from_proto_bin = _resource_loader.get_path_to_datafile(
      "../toco/python/toco_from_protos")

# Fall back to looking the tool up by name (e.g. on PATH) when the bazel
# runfile is not present.
if _toco_from_proto_bin and not os.path.exists(_toco_from_proto_bin):
  _toco_from_proto_bin = "toco_from_protos"


def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
  """Convert `input_data_str` according to model and toco parameters.

  Unless you know what you are doing consider using
  the more friendly @{tf.contrib.lite.toco_convert}.

  If no `toco_from_protos` executable was selected at import time, the
  in-process `_toco_python.TocoConvert` wrapper is called directly. Otherwise
  the three serialized inputs are written to temporary files and converted by
  running `toco_from_protos` as a child process. That way a crash inside TOCO
  does not take down the caller.

  Args:
    model_flags_str: Serialized proto describing model properties, see
      `toco/model_flags.proto`.
    toco_flags_str: Serialized proto describing conversion properties, see
      `toco/toco_flags.proto`.
    input_data_str: Input data in serialized form (e.g. a graphdef is common)
  Returns:
    Converted model in serialized form (e.g. a TFLITE model is common).
  Raises:
    RuntimeError: When conversion fails, an exception is raised with the error
      message embedded. This also covers the case where the `toco_from_protos`
      executable cannot be launched at all.
  """
  # Imported locally on purpose. remove_undocumented() at the bottom of this
  # module is expected to delete public module attributes that are neither
  # @@-documented nor allow-listed. That would include the plain `subprocess`
  # and `tempfile` imports at the top of the file, so this function must not
  # rely on those module globals still existing when it is called.
  import subprocess as _subprocess  # pylint: disable=g-import-not-at-top
  import tempfile as _tempfile  # pylint: disable=g-import-not-at-top

  # TODO(aselle): When toco does not use fatal errors for failure, we can
  # switch this on.
  if not _toco_from_proto_bin:
    return _toco_python.TocoConvert(
        model_flags_str, toco_flags_str, input_data_str)

  with _tempfile.NamedTemporaryFile() as fp_toco, \
       _tempfile.NamedTemporaryFile() as fp_model, \
       _tempfile.NamedTemporaryFile() as fp_input, \
       _tempfile.NamedTemporaryFile() as fp_output:
    fp_model.write(model_flags_str)
    fp_toco.write(toco_flags_str)
    fp_input.write(input_data_str)
    # Flush so the child process sees the complete contents when it opens
    # these files by name.
    fp_model.flush()
    fp_toco.flush()
    fp_input.flush()

    # Pass an argument list rather than a joined shell command line, so paths
    # containing spaces or shell metacharacters reach the tool verbatim.
    cmd = [
        _toco_from_proto_bin, fp_model.name, fp_toco.name, fp_input.name,
        fp_output.name
    ]
    try:
      proc = _subprocess.Popen(
          cmd,
          stdout=_subprocess.PIPE,
          stderr=_subprocess.STDOUT,
          close_fds=True)
    except OSError as e:
      # Without a shell, a missing or non-executable binary raises OSError
      # here instead of producing a non-zero exit code. Keep the documented
      # RuntimeError contract.
      raise RuntimeError("TOCO failed see console for info.\n%s\n" % e)

    # stderr is redirected into stdout above, so the first element holds the
    # combined log and the second is always None.
    output, _ = proc.communicate()
    if proc.returncode != 0:
      raise RuntimeError("TOCO failed see console for info.\n%s\n" %
                         output.decode("utf-8", "replace"))

    # Re-open by name to read exactly what the child wrote to that path.
    with open(fp_output.name, "rb") as fp_result:
      return fp_result.read()


def _tensor_name(x):
  """Returns `x.name` with any ":<output index>" suffix removed."""
  return x.name.split(":")[0]


def toco_convert(input_data,
                 input_tensors,
                 output_tensors,
                 inference_type=FLOAT,
                 input_format=TENSORFLOW_GRAPHDEF,
                 output_format=TFLITE,
                 quantized_input_stats=None,
                 drop_control_dependency=True):
  """Convert a model using TOCO from `input_format` to `output_format`.

  Typically this is to convert from TensorFlow GraphDef to TFLite, in which
  case the default `input_format` and `output_format` are sufficient.

  Args:
    input_data: Input data (i.e. often `sess.graph_def`). Must provide
      `SerializeToString()`.
    input_tensors: List of input tensors. Type and shape are computed using
      `foo.get_shape()` and `foo.dtype`.
    output_tensors: List of output tensors (only .name is used from this).
    inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
    input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
    output_format: Type of data to write (currently must be TFLITE or
      GRAPHVIZ_DOT)
    quantized_input_stats: For each member of input_tensors the mean and
      std deviation of training data. Only needed if `inference_type` is
      `QUANTIZED_UINT8`.
    drop_control_dependency: Drops control dependencies silently. This is due
      to tf lite not supporting control dependencies.

  Returns:
    The converted data. For example if tflite was the destination, then
    this will be a tflite flatbuffer in a bytes array.

  Raises:
    ValueError: If the input tensor type is unknown, or if `inference_type` is
      `QUANTIZED_UINT8` and `quantized_input_stats` was not provided.
    RuntimeError: If TOCO fails to convert (in which case the runtime error's
      error text will contain the TOCO error log)
  """
  toco = _toco_flags_pb2.TocoFlags()
  toco.input_format = input_format
  toco.output_format = output_format
  toco.drop_control_dependency = drop_control_dependency
  model = _model_flags_pb2.ModelFlags()
  toco.inference_type = inference_type
  for idx, input_tensor in enumerate(input_tensors):
    if input_tensor.dtype == _dtypes.float32:
      tflite_input_type = FLOAT
    elif input_tensor.dtype == _dtypes.int32:
      tflite_input_type = INT32
    elif input_tensor.dtype == _dtypes.int64:
      tflite_input_type = INT64
    # TODO(aselle): Insert strings when they are available
    else:
      raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
                                                         input_tensor.dtype))

    input_array = model.input_arrays.add()

    if inference_type == QUANTIZED_UINT8:
      if tflite_input_type == FLOAT:
        tflite_input_type = QUANTIZED_UINT8
      # Fail with a clear message instead of an accidental TypeError from
      # indexing None.
      if quantized_input_stats is None:
        raise ValueError(
            "quantized_input_stats must be provided when inference_type is "
            "QUANTIZED_UINT8.")
      input_array.mean_value, input_array.std_value = quantized_input_stats[idx]

    input_array.name = _tensor_name(input_tensor)
    input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
    # A single inference_input_type is stored for the whole model; with several
    # inputs the type of the last one wins.
    toco.inference_input_type = tflite_input_type

  for output_tensor in output_tensors:
    model.output_arrays.append(_tensor_name(output_tensor))

  data = toco_convert_protos(model.SerializeToString(),
                             toco.SerializeToString(),
                             input_data.SerializeToString())
  return data


# Extra public names (beyond the @@ entries in the module docstring) that
# remove_undocumented() should keep on this module.
_allowed_symbols = [
    "FLOAT",
    "INT32",
    "INT64",
    "STRING",
    "QUANTIZED_UINT8",
    "TENSORFLOW_GRAPHDEF",
    "TFLITE",
    "GRAPHVIZ_DOT",
    "EXPERIMENTAL_USE_TOCO_API_DIRECTLY",
]
remove_undocumented(__name__, _allowed_symbols)