Home | History | Annotate | Download | only in python
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # =============================================================================
     15 """Exposes the Python wrapper conversion to trt_graph."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 # pylint: disable=unused-import,line-too-long
     22 import six as _six
     23 from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
     24 from tensorflow.core.framework import graph_pb2
     25 from tensorflow.python.framework import errors
     26 from tensorflow.python.framework import errors_impl as _impl
     27 from tensorflow.python.framework import ops
     28 
     29 
     30 # TODO(skama): get outputs from session when implemented as c++
     31 # optimization pass
     32 def create_inference_graph(input_graph_def,
     33                            outputs,
     34                            max_batch_size=1,
     35                            max_workspace_size_bytes=2 << 20):
     36   """Python wrapper for the TRT transormation.
     37 
     38 
     39   Args:
     40     input_graph_def: GraphDef object containing a model to be transformed.
     41     outputs: List of tensors or node names for the model outputs.
     42     max_batch_size: max size for the input batch
     43     max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
     44 
     45   Returns:
     46     New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
     47 
     48   Raises:
     49     RuntimeError: if the returned status message is malformed.
     50   """
     51 
     52   def py2bytes(inp):
     53     return inp
     54 
     55   def py3bytes(inp):
     56     return inp.encode("utf-8", errors="surrogateescape")
     57 
     58   def py2string(inp):
     59     return inp
     60 
     61   def py3string(inp):
     62     return inp.decode("utf-8")
     63 
     64   if _six.PY2:
     65     to_bytes = py2bytes
     66     to_string = py2string
     67   else:
     68     to_bytes = py3bytes
     69     to_string = py3string
     70 
     71   out_names = []
     72   for i in outputs:
     73     if isinstance(i, ops.Tensor):
     74       out_names.append(to_bytes(i.name))
     75     else:
     76       out_names.append(to_bytes(i))
     77 
     78   input_graph_def_str = input_graph_def.SerializeToString()
     79 
     80   # TODO(sami): Fix this when we can return status from C++ library
     81   # There is a problem with the TF internal library setup that doesn't
     82   # allow us to return a status object from C++.  Thus we return a
     83   # pair or strings where first one is encoded status and the second
     84   # one is the transformed graphs protobuf string.
     85   out = trt_convert(input_graph_def_str, out_names, max_batch_size,
     86                     max_workspace_size_bytes)
     87   status = to_string(out[0])
     88   output_graph_def_string = out[1]
     89   del input_graph_def_str  # Save some memory
     90   if len(status) < 2:
     91     raise _impl.UnknownError(None, None, status)
     92   if status[:2] != "OK":
     93     msg = status.split(";")
     94     if len(msg) == 1:
     95       raise RuntimeError("Status message is malformed {}".format(status))
     96     # pylint: disable=protected-access
     97     raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
     98                                          int(msg[0]))
     99     # pylint: enable=protected-access
    100   output_graph_def = graph_pb2.GraphDef()
    101   output_graph_def.ParseFromString(output_graph_def_string)
    102   del output_graph_def_string  # Save some memory
    103   return output_graph_def
    104