# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper conversion to trt_graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import,line-too-long
import six as _six
from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl as _impl
from tensorflow.python.framework import ops


def _to_bytes(inp):
  """Returns `inp` as a byte string to hand to the C++ wrapper.

  On Python 2 the value is returned unchanged. On Python 3, text is UTF-8
  encoded with the "surrogateescape" handler, and values that are already
  `bytes` are passed through instead of failing on a missing `.encode`.
  """
  if _six.PY2 or isinstance(inp, bytes):
    return inp
  return inp.encode("utf-8", errors="surrogateescape")


def _to_string(inp):
  """Returns a byte string produced by the C++ wrapper as a native `str`.

  On Python 2 the value is returned unchanged. On Python 3, `bytes` are
  decoded as UTF-8; a value that is already `str` is passed through.
  """
  if _six.PY2 or not isinstance(inp, bytes):
    return inp
  return inp.decode("utf-8")


def _raise_on_error_status(status):
  """Raises the exception described by a `trt_convert` status string.

  As parsed here, a status starting with "OK" means success. Any other
  status must have the form "<integer code>;<message>". The message itself
  may contain further ';' characters.

  Args:
    status: decoded status string returned by `trt_convert`.

  Raises:
    errors_impl.UnknownError: if `status` is shorter than two characters.
    RuntimeError: if a non-OK status has no ';' separator or a non-integer
      code.
    An `errors_impl` exception selected from the code by
    `_make_specific_exception`, for any other non-OK status.
  """
  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] == "OK":
    return
  code, sep, message = status.partition(";")
  try:
    error_code = int(code)
  except ValueError:
    # A non-numeric code is just as malformed as a missing separator; report
    # it the same way instead of leaking a bare ValueError to the caller.
    error_code = None
  if not sep or error_code is None:
    raise RuntimeError("Status message is malformed {}".format(status))
  # pylint: disable=protected-access
  raise _impl._make_specific_exception(None, None, message, error_code)
  # pylint: enable=protected-access


# TODO(skama): get outputs from session when implemented as c++
# optimization pass
def create_inference_graph(input_graph_def,
                           outputs,
                           max_batch_size=1,
                           max_workspace_size_bytes=2 << 20):
  """Python wrapper for the TRT transformation.

  Args:
    input_graph_def: GraphDef object containing a model to be transformed.
    outputs: List of tensors or node names for the model outputs. For a
      tensor, its `.name` is used. Names may be text or byte strings.
    max_batch_size: max size for the input batch.
    max_workspace_size_bytes: parameter to control memory allocation (in
      Bytes). Defaults to 2 << 20, i.e. 2 MiB.

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing subgraphs.

  Raises:
    RuntimeError: if the returned status message is malformed.
    errors_impl.UnknownError: if the returned status is too short to parse.
    An `errors_impl` exception derived from the error code in the returned
    status, when the conversion reports a failure.
  """
  out_names = [
      _to_bytes(out.name if isinstance(out, ops.Tensor) else out)
      for out in outputs
  ]

  input_graph_def_str = input_graph_def.SerializeToString()

  # TODO(sami): Fix this when we can return status from C++ library
  # There is a problem with the TF internal library setup that doesn't
  # allow us to return a status object from C++. Thus we return a
  # pair or strings where first one is encoded status and the second
  # one is the transformed graphs protobuf string.
  out = trt_convert(input_graph_def_str, out_names, max_batch_size,
                    max_workspace_size_bytes)
  del input_graph_def_str  # Save some memory

  status = _to_string(out[0])
  output_graph_def_string = out[1]
  _raise_on_error_status(status)

  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def