# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper for TensorRT graph conversion."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six as _six
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver


def _to_bytes(s):
  """Encode s if it is a sequence of chars."""
  if isinstance(s, _six.text_type):
    return s.encode("utf-8", errors="surrogateescape")
  return s


def _to_string(s):
  """Decode s if it is a sequence of bytes."""
  if isinstance(s, _six.binary_type):
    return s.decode("utf-8")
  return s


class GraphConverter(object):
  """Base class for offline converters to optimize SavedModels/GraphDefs.

  A `GraphConverter` object encapsulates the environment to convert (optimize) a
  TensorFlow SavedModel or GraphDef.

  To create a custom GraphConverter:

  ```python
  class MyGraphConverter(GraphConverter):
    ...

    def get_rewriter_config(self, rewriter_config_template=None):
      my_rewriter_config = ...
      return my_rewriter_config
  ```

  Then to run the conversion without quantization calibration:

  ```python
  my_converter = MyGraphConverter(input_saved_model_dir="my_dir")
  converted_graph_def = my_converter.convert()
  my_converter.save(output_saved_model_dir)  # Optional
  ```

  To run the conversion with quantization calibration:

  ```python
  my_converter = MyGraphConverter(input_saved_model_dir="my_dir")
  my_converter.convert()

  # Run calibration 10 times.
  converted_graph_def = my_converter.calibrate(
      fetch_names=['output:0'],
      num_runs=10,
      feed_dict_fn=lambda: {'input:0': my_next_data()})

  my_converter.save(output_saved_model_dir)  # Optional
  ```
  """

  # TODO(laigd): clean up the parameters.
  def __init__(self,
               input_saved_model_dir=None,
               input_saved_model_tags=None,
               input_saved_model_signature_key=None,
               input_graph_def=None,
               nodes_blacklist=None,
               session_config=None):
    """Initialize the converter.

    Args:
      input_saved_model_dir: the directory to load the SavedModel which contains
        the input graph to transform. Used only when input_graph_def is None.
      input_saved_model_tags: list of tags to load the SavedModel.
      input_saved_model_signature_key: the key of the signature to optimize the
        graph for.
      input_graph_def: a GraphDef object containing a model to be transformed.
        If set to None, the graph will be read from the SavedModel loaded from
        input_saved_model_dir.
      nodes_blacklist: list of node names to prevent the converter from
        touching. Only used when input_graph_def is not None.
      session_config: the ConfigProto used to create a Session. It's also used
        as a template to create a RewriterConfig for conversion. If not
        specified, a default ConfigProto will be used.

    Raises:
      ValueError: if the combination of the parameters is invalid.
    """
    if context.executing_eagerly():
      if input_graph_def or not input_saved_model_dir:
        raise ValueError(
            "TF 2.0 only supports conversion of SavedModels; please specify "
            "input_saved_model_dir as input.")
    else:
      if input_graph_def and input_saved_model_dir:
        raise ValueError(
            "Can only specify one of input_graph_def and input_saved_model_dir")
      if not input_graph_def and not input_saved_model_dir:
        raise ValueError("Must specify one of input_graph_def and "
                         "input_saved_model_dir")

      self._input_graph_def = input_graph_def
      self._nodes_blacklist = nodes_blacklist

    self._input_saved_model_dir = input_saved_model_dir
    self._converted = False
    self._grappler_meta_graph_def = None

    self._input_saved_model_tags = (
        input_saved_model_tags or [tag_constants.SERVING])
    self._input_saved_model_signature_key = (
        input_saved_model_signature_key or
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
    self._session_config = session_config or config_pb2.ConfigProto()

    # For calibration usage.
    self._calibration_graph = None
    self._calibration_sess = None
    self._calibration_data_collected = False

  def get_rewriter_config(self, rewriter_config_template=None):
    """Returns a RewriterConfig proto for TRT transformation.

    Args:
      rewriter_config_template: a template RewriterConfig proto used to create a
        RewriterConfig for the conversion. The implementation should not modify
        the template. If None, it will use a default one.

    Returns:
      A RewriterConfig proto which will be used to run the conversion using
      Grappler.
    """
    raise NotImplementedError("get_rewriter_config")

  def _run_conversion(self):
    """Run Grappler's OptimizeGraph() tool to convert the graph."""
    # Create custom ConfigProto for Grappler.
    grappler_session_config = config_pb2.ConfigProto()
    grappler_session_config.CopyFrom(self._session_config)
    rewriter_config = None
    if (grappler_session_config.HasField("graph_options") and
        grappler_session_config.graph_options.HasField("rewrite_options")):
      rewriter_config = grappler_session_config.graph_options.rewrite_options
    custom_rewriter_config = self.get_rewriter_config(rewriter_config)
    grappler_session_config.graph_options.rewrite_options.CopyFrom(
        custom_rewriter_config)

    # Run Grappler.
    self._converted_graph_def = tf_optimizer.OptimizeGraph(
        grappler_session_config,
        self._grappler_meta_graph_def,
        graph_id=b"tf_graph")
    self._converted = True

  def _add_nodes_blacklist(self):
    """Add the blacklisted node names to the 'train_op' collection.

    Grappler treats nodes in this collection as fetch nodes, so it will not
    transform or remove them during the conversion.
    """
    if self._nodes_blacklist:
      collection_def = self._grappler_meta_graph_def.collection_def["train_op"]
      blacklist = collection_def.node_list.value
      for i in self._nodes_blacklist:
        if isinstance(i, ops.Tensor):
          blacklist.append(_to_bytes(i.name))
        else:
          blacklist.append(_to_bytes(i))

  def _convert_graph_def(self):
    """Convert the input GraphDef."""
    graph = ops.Graph()
    with graph.as_default():
      importer.import_graph_def(self._input_graph_def, name="")
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
    self._add_nodes_blacklist()

    self._run_conversion()

  def _convert_saved_model(self):
    """Convert the input SavedModel."""
    graph = ops.Graph()
    with session.Session(graph=graph, config=self._session_config) as sess:
      input_meta_graph_def = loader.load(sess, self._input_saved_model_tags,
                                         self._input_saved_model_dir)
      input_signature_def = input_meta_graph_def.signature_def[
          self._input_saved_model_signature_key]

      def _gather_names(tensor_info):
        """Get the node names from a TensorInfo."""
        return set([tensor_info[key].name.split(":")[0] for key in tensor_info])

      # Get the input and output node names from the SignatureDef.
      output_node_names = _gather_names(input_signature_def.inputs).union(
          _gather_names(input_signature_def.outputs))

      # Freeze the variables in the SavedModel graph and copy the frozen
      # graph over.
      frozen_graph_def = graph_util.convert_variables_to_constants(
          sess, sess.graph.as_graph_def(add_shapes=True),
          list(output_node_names))
      self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
      self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)

      # Copy the collections that are not variables.
      for key in input_meta_graph_def.collection_def:
        # TODO(laigd): currently we use the collection key to filter out
        # collections that depend on variable ops, but this may miss some
        # other user-defined collections. A better way would be to use
        # CollectionDef::NodeList for the filtering.
        if key not in [
            "variables", "local_variables", "model_variables",
            "trainable_variables", "train_op", "table_initializer"
        ]:
          self._grappler_meta_graph_def.collection_def[key].CopyFrom(
              input_meta_graph_def.collection_def[key])

      self._add_nodes_blacklist()

      # Copy other information.
      self._grappler_meta_graph_def.meta_info_def.CopyFrom(
          input_meta_graph_def.meta_info_def)
      self._grappler_meta_graph_def.signature_def[
          self._input_saved_model_signature_key].CopyFrom(input_signature_def)
      # TODO(laigd): maybe add back AssetFileDef.

    self._run_conversion()

  # TODO(laigd): provide a utility function to optimize a ConcreteFunction and
  # use it here (b/124792963).
  def _convert_saved_model_v2(self):
    """Convert the input SavedModel in 2.0 format."""
    self._saved_model = load.load(self._input_saved_model_dir,
                                  self._input_saved_model_tags)
    func = self._saved_model.signatures[self._input_saved_model_signature_key]
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    fetch_collection = meta_graph_pb2.CollectionDef()
    for array in func.inputs + func.outputs:
      fetch_collection.node_list.value.append(array.name)
    self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
        fetch_collection)

    # Run TRT optimizer in Grappler to convert the graph.
    self._run_conversion()

    def _get_tensor(graph, tensors):
      new_tensors = []
      for tensor in tensors:
        new_tensor = graph.get_tensor_by_name(tensor.name)
        new_tensor.set_shape(tensor.shape)
        new_tensors.append(new_tensor)
      return new_tensors

    # TODO(laigd): do we need to use a different name, e.g. "trt_func_graph"?
    converted_graph = func_graph.FuncGraph(func.graph.name)
    with converted_graph.as_default():
      importer.import_graph_def(self._converted_graph_def, name="")

    converted_graph.inputs = _get_tensor(converted_graph, func.graph.inputs)
    converted_graph.outputs = _get_tensor(converted_graph, func.graph.outputs)
    converted_graph.structured_outputs = func.graph.structured_outputs
    converted_graph.structured_input_signature = (
        func.graph.structured_input_signature)

    # pylint: disable=protected-access
    # TODO(laigd): should we set up the signature as well?
    self._converted_func = function.ConcreteFunction(
        converted_graph, attrs=None, signature=None)
    self._converted_func.add_to_graph()
    self._converted_func._arg_keywords = func._arg_keywords
    self._converted_func._num_positional_args = func._num_positional_args
    self._converted_func._captured_inputs = func._captured_inputs
    self._converted_func.graph.variables = func.graph.variables
    # pylint: enable=protected-access

  def convert(self):
    """Run the conversion.

    Returns:
      The converted GraphDef for TF 1.x, or the converted ConcreteFunction in TF
      2.0+.
    """
    assert not self._converted

    if context.executing_eagerly():
      self._convert_saved_model_v2()
      return self._converted_func
    else:
      if self._input_graph_def:
        self._convert_graph_def()
      else:
        self._convert_saved_model()
      return self._converted_graph_def

  def calibrate(self,
                fetch_names,
                num_runs,
                feed_dict_fn=None,
                input_map_fn=None):
    """Run the calibration and return the calibrated GraphDef.

    Args:
      fetch_names: a list of output tensor names to fetch during calibration.
      num_runs: number of runs of the graph during calibration.
      feed_dict_fn: a function that returns a dictionary mapping input names (as
        strings) in the GraphDef to be calibrated to values (e.g. Python lists,
        numpy arrays, etc.). One and only one of `feed_dict_fn` and
        `input_map_fn` should be specified.
      input_map_fn: a function that returns a dictionary mapping input names (as
        strings) in the GraphDef to be calibrated to Tensor objects. The values
        of the named input tensors in the GraphDef to be calibrated will be
        re-mapped to the respective `Tensor` values during calibration. One and
        only one of `feed_dict_fn` and `input_map_fn` should be specified.
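
    Example: a minimal sketch using `feed_dict_fn`, assuming the converted
    graph has input tensor 'input:0' and output tensor 'output:0', and
    `my_next_data()` is a caller-supplied function producing one calibration
    batch per call:

    ```python
    my_converter.convert()
    calibrated_graph_def = my_converter.calibrate(
        fetch_names=['output:0'],
        num_runs=10,
        # feed_dict_fn is invoked once per calibration run.
        feed_dict_fn=lambda: {'input:0': my_next_data()})
    ```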

    Raises:
      ValueError: if the input combination is invalid.
      RuntimeError: if this method is called in eager mode.

    Returns:
      The GraphDef after the calibration.
    """
    assert self._converted
    assert not self._calibration_sess

    if context.executing_eagerly():
      raise RuntimeError("Calibration for TF 2.0 is not supported yet.")

    if (feed_dict_fn and input_map_fn) or (not feed_dict_fn and
                                           not input_map_fn):
      raise ValueError(
          "Should specify one and only one of feed_dict_fn and input_map_fn.")

    self._calibration_graph = ops.Graph()
    with self._calibration_graph.as_default():
      fetches = importer.import_graph_def(
          self._converted_graph_def,
          input_map=input_map_fn() if input_map_fn else None,
          return_elements=fetch_names,
          name="")
    self._calibration_sess = session.Session(
        graph=self._calibration_graph, config=self._session_config)

    for _ in range(num_runs):
      self._calibration_sess.run(
          fetches, feed_dict=feed_dict_fn() if feed_dict_fn else None)

    self.finalize_calibration()
    return self._converted_graph_def

  def finalize_calibration(self):
    """Clean up calibration resources and finalize the calibration.

    Implementations need to close self._calibration_sess before returning.
    """
    raise NotImplementedError("finalize_calibration")

  def save(self, output_saved_model_dir):
    """Save the converted graph as a SavedModel.

    Args:
      output_saved_model_dir: the directory to which a SavedModel constructed
        from the converted GraphDef will be saved. This option only works when
        the input graph is loaded from a SavedModel, i.e. when
        input_saved_model_dir is specified and input_graph_def is None in
        __init__().

    Raises:
      ValueError: if the input to the converter is a GraphDef instead of a
        SavedModel.
    """
    assert self._converted

    if context.executing_eagerly():
      # Rewrite the signature map using the optimized ConcreteFunction.
      signatures = {
          key: value for key, value in self._saved_model.signatures.items()
      }
      signatures[self._input_saved_model_signature_key] = self._converted_func
      save.save(self._saved_model, output_saved_model_dir, signatures)
    else:
      if self._input_graph_def:
        raise ValueError(
            "Not able to save to a SavedModel since input is a GraphDef")

      # Write the transformed graphdef as SavedModel.
      saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
      with ops.Graph().as_default():
        importer.import_graph_def(self._converted_graph_def, name="")
        # We don't use any specific converter here.
        with session.Session(config=self._session_config) as sess:
          saved_model_builder.add_meta_graph_and_variables(
              sess,
              self._input_saved_model_tags,
              signature_def_map=self._grappler_meta_graph_def.signature_def)
      # Ignore other meta graphs from the input SavedModel.
      saved_model_builder.save()


class TrtPrecisionMode(object):
  FP32 = "FP32"
  FP16 = "FP16"
  INT8 = "INT8"

  @staticmethod
  def supported_precision_modes():
    return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8]


# Use a large enough number as the default max_workspace_size for TRT engines,
# so it can produce reasonable performance results with the default.
DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30


class TrtGraphConverter(GraphConverter):
  """A GraphConverter for TRT transformation."""

  _TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF_TRT_Calibration"

  @classmethod
  def get_tensorrt_rewriter_config(
      cls,
      rewriter_config_template=None,
      max_batch_size=1,
      max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
      precision_mode=TrtPrecisionMode.FP32,
      minimum_segment_size=3,
      is_dynamic_op=False,
      maximum_cached_engines=1,
      cached_engine_batches=None,
      use_calibration=True,
      use_function_backup=True):
    """Returns a RewriterConfig proto for TRT transformation.

    Args:
      rewriter_config_template: a template RewriterConfig proto used to create a
        TRT-enabled RewriterConfig. If None, it will use a default one.
      max_batch_size: max size for the input batch.
      max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
        engine can use at execution time. This corresponds to the
        'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
      precision_mode: one of TrtPrecisionMode.supported_precision_modes().
      minimum_segment_size: the minimum number of nodes required for a subgraph
        to be replaced by TRTEngineOp.
      is_dynamic_op: whether to generate dynamic TRT ops which will build the
        TRT network and engine at run time.
      maximum_cached_engines: max number of cached TRT engines in dynamic TRT
        ops. If the number of cached engines is already at max but none of them
        can serve the input, the TRTEngineOp will fall back to run the TF
        function based on which the TRTEngineOp is created.
      cached_engine_batches: a list of batch sizes used to create cached
        engines, only used when is_dynamic_op is True. The length of the list
        should be <= maximum_cached_engines, and the dynamic TRT op will use
        this list to determine the batch sizes of the cached engines, instead of
        making the decision on the fly. This is useful when we know the most
        common batch size(s) the application is going to generate.
      use_calibration: this argument is ignored if precision_mode is not INT8.
        If set to True, a calibration graph will be created to calibrate the
        missing ranges. The calibration graph must be converted to an inference
        graph by running calibration with calibrate(). If set to False,
        quantization nodes will be expected for every tensor in the graph
        (excluding those which will be fused). If a range is missing, an error
        will occur. Please note that accuracy may be negatively affected if
        there is a mismatch between which tensors TRT quantizes and which
        tensors were trained with fake quantization.
      use_function_backup: if set to True, it will create a FunctionDef for each
        subgraph that is converted to TRT op, and if TRT ops fail to execute at
        runtime, it'll invoke that function as a fallback.

    Returns:
      A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.

    Raises:
      TypeError: if any of the parameters are of unexpected type.
      ValueError: if any of the parameters are of unexpected value.
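
    Example: a sketch of plugging the returned proto into a ConfigProto (using
    the `config_pb2` module imported at the top of this file); the resulting
    config can be passed to a Session so that Grappler runs the TRT optimizer
    when the graph is optimized:

    ```python
    rewriter_config = TrtGraphConverter.get_tensorrt_rewriter_config(
        precision_mode=TrtPrecisionMode.FP16,
        is_dynamic_op=True,
        maximum_cached_engines=4)
    session_config = config_pb2.ConfigProto()
    session_config.graph_options.rewrite_options.CopyFrom(rewriter_config)
    ```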
    """
    # Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain
    # even if it cannot find TensorRT library.
    trt_ops.load_trt_ops()
    # pylint: disable=g-import-not-at-top,unused-import,line-too-long,unused-variable
    # Import an arbitrary symbol to trigger loading of the TRT library.
    from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
    # pylint: enable=g-import-not-at-top,unused-import,line-too-long,unused-variable

    if rewriter_config_template is not None and not isinstance(
        rewriter_config_template, rewriter_config_pb2.RewriterConfig):
      raise TypeError(
          "rewriter_config_template should be a RewriterConfig proto.")

    rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
    if rewriter_config_template is None:
      # Layout optimizer may add Const nodes followed by Reshape nodes, thus we
      # need to run constant folding again.
      rewriter_config_with_trt.optimizers.extend(
          ["constfold", "layout", "constfold"])
      rewriter_config_with_trt.meta_optimizer_iterations = (
          rewriter_config_pb2.RewriterConfig.ONE)
    else:
      rewriter_config_with_trt.CopyFrom(rewriter_config_template)

    optimizer = rewriter_config_with_trt.custom_optimizers.add()
    optimizer.name = "TensorRTOptimizer"
    optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
    optimizer.parameter_map["max_batch_size"].i = max_batch_size
    optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
    optimizer.parameter_map[
        "max_workspace_size_bytes"].i = max_workspace_size_bytes
    optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
    optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
    if cached_engine_batches:
      optimizer.parameter_map["cached_engine_batches"].list.i.extend(
          cached_engine_batches)
    optimizer.parameter_map["use_calibration"].b = use_calibration
    optimizer.parameter_map["use_function_backup"].b = use_function_backup
    return rewriter_config_with_trt

  def __init__(self,
               input_saved_model_dir=None,
               input_saved_model_tags=None,
               input_saved_model_signature_key=None,
               input_graph_def=None,
               nodes_blacklist=None,
               session_config=None,
               max_batch_size=1,
               max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
               precision_mode=TrtPrecisionMode.FP32,
               minimum_segment_size=3,
               is_dynamic_op=False,
               maximum_cached_engines=1,
               cached_engine_batches=None,
               use_calibration=True,
               use_function_backup=True):
    """Initialize the converter.

    Args:
      input_saved_model_dir: the directory to load the SavedModel which contains
        the input graph to transform. Used only when input_graph_def is None.
      input_saved_model_tags: list of tags to load the SavedModel.
      input_saved_model_signature_key: the key of the signature to optimize the
        graph for.
      input_graph_def: a GraphDef object containing a model to be transformed.
        If set to None, the graph will be read from the SavedModel loaded from
        input_saved_model_dir.
      nodes_blacklist: list of node names to prevent the converter from
        touching. Only used when input_graph_def is not None.
      session_config: the ConfigProto used to create a Session. It's also used
        as a template to create a TRT-enabled ConfigProto for conversion. If not
        specified, a default ConfigProto will be used.
      max_batch_size: max size for the input batch.
      max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
        engine can use at execution time. This corresponds to the
        'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
      precision_mode: one of TrtPrecisionMode.supported_precision_modes().
      minimum_segment_size: the minimum number of nodes required for a subgraph
        to be replaced by TRTEngineOp.
      is_dynamic_op: whether to generate dynamic TRT ops which will build the
        TRT network and engine at run time.
      maximum_cached_engines: max number of cached TRT engines in dynamic TRT
        ops. If the number of cached engines is already at max but none of them
        can serve the input, the TRTEngineOp will fall back to run the TF
        function based on which the TRTEngineOp is created.
      cached_engine_batches: a list of batch sizes used to create cached
        engines, only used when is_dynamic_op is True. The length of the list
        should be <= maximum_cached_engines, and the dynamic TRT op will use
        this list to determine the batch sizes of the cached engines, instead of
        making the decision on the fly. This is useful when we know the most
        common batch size(s) the application is going to generate.
      use_calibration: this argument is ignored if precision_mode is not INT8.
        If set to True, a calibration graph will be created to calibrate the
        missing ranges. The calibration graph must be converted to an inference
        graph by running calibration with calibrate(). If set to False,
        quantization nodes will be expected for every tensor in the graph
        (excluding those which will be fused). If a range is missing, an error
        will occur. Please note that accuracy may be negatively affected if
        there is a mismatch between which tensors TRT quantizes and which
        tensors were trained with fake quantization.
      use_function_backup: if set to True, it will create a FunctionDef for each
        subgraph that is converted to TRT op, and if TRT ops fail to execute at
        runtime, it'll invoke that function as a fallback.

    Raises:
      ValueError: if the combination of the parameters is invalid.
      RuntimeError: if the TensorRT library version is incompatible.
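
    Example: a minimal sketch of an FP16 conversion of a SavedModel; the
    directory names are placeholders:

    ```python
    converter = TrtGraphConverter(
        input_saved_model_dir="my_input_dir",
        precision_mode=TrtPrecisionMode.FP16,
        is_dynamic_op=True)
    converter.convert()
    converter.save("my_output_dir")
    ```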
    """
    super(TrtGraphConverter, self).__init__(
        input_saved_model_dir=input_saved_model_dir,
        input_saved_model_tags=input_saved_model_tags,
        input_saved_model_signature_key=input_saved_model_signature_key,
        input_graph_def=input_graph_def,
        nodes_blacklist=nodes_blacklist,
        session_config=session_config)

    # TODO(laigd): move all the validations below to
    # get_tensorrt_rewriter_config().

    # Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain
    # even if it cannot find TensorRT library.
    trt_ops.load_trt_ops()
    # pylint: disable=g-import-not-at-top,line-too-long
    from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
    from tensorflow.python.compiler.tensorrt.wrap_conversion import get_loaded_tensorrt_version
    # pylint: enable=g-import-not-at-top,line-too-long

    # Check compatibility of TensorRT version.
    compiled_version = get_linked_tensorrt_version()
    loaded_version = get_loaded_tensorrt_version()
    tf_logging.info("Linked TensorRT version: %s" % str(compiled_version))
    tf_logging.info("Loaded TensorRT version: %s" % str(loaded_version))
    version_mismatch = False
    if loaded_version[0] < compiled_version[0]:
      tf_logging.error(
          "TensorRT version mismatch. TensorFlow was compiled against " +
          "TensorRT %s but library loaded from environment is TensorRT %s" %
          (".".join([str(x) for x in compiled_version]),
           ".".join([str(x) for x in loaded_version])) +
          ". Please make sure that the correct version of TensorRT " +
          "is available in the system and added to ldconfig or LD_LIBRARY_PATH")
      raise RuntimeError("Incompatible TensorRT library version")
    for i in zip(loaded_version, compiled_version):
      if i[0] != i[1]:
        tf_logging.warn("TensorRT mismatch. Compiled against version " +
                        "%s, but loaded %s. Things may not work." %
                        (".".join([str(x) for x in compiled_version]),
                         ".".join([str(x) for x in loaded_version])))
        version_mismatch = True
        break
    if not version_mismatch:
      tf_logging.info("Running against TensorRT version %s" %
                      ".".join([str(x) for x in loaded_version]))

    # Check input arguments.
    supported_precision_modes = TrtPrecisionMode.supported_precision_modes()
    if precision_mode not in supported_precision_modes:
      raise ValueError(("precision mode '{}' is not supported. "
                        "It should be one of {}").format(
                            precision_mode, supported_precision_modes))

    if cached_engine_batches:
      if not isinstance(cached_engine_batches, list):
        raise TypeError("cached_engine_batches should be a list.")
      if len(cached_engine_batches) > maximum_cached_engines:
        raise ValueError("cached_engine_batches should not contain more than "
                         "maximum_cached_engines items.")

    self._need_calibration = (
        precision_mode == TrtPrecisionMode.INT8 and use_calibration)
    self._use_function_backup = use_function_backup

    # TODO(laigd): consider providing a mechanism to remove the fallback path
    # after calibration is done.
    if self._need_calibration and not use_function_backup:
      raise ValueError(
          "Calibration requires enabling fallback to TF function execution.")

    # TODO(laigd):
    # - Get rid of the is_dynamic_op option; it should always be True, and it
    #   should accept N shapes as input.
    # - Verify in int8 mode that maximum_cached_engines and
    #   cached_engine_batches are set appropriately.
    # - If it fails to build the int8 engine it should return an error.
    self._max_batch_size = max_batch_size
    self._max_workspace_size_bytes = max_workspace_size_bytes
    self._precision_mode = precision_mode
    self._minimum_segment_size = minimum_segment_size
    self._is_dynamic_op = is_dynamic_op
    self._maximum_cached_engines = maximum_cached_engines
    self._cached_engine_batches = cached_engine_batches

  def get_rewriter_config(self, rewriter_config_template=None):
    return TrtGraphConverter.get_tensorrt_rewriter_config(
        rewriter_config_template,
        max_batch_size=self._max_batch_size,
        max_workspace_size_bytes=self._max_workspace_size_bytes,
        precision_mode=self._precision_mode,
        minimum_segment_size=self._minimum_segment_size,
        is_dynamic_op=self._is_dynamic_op,
        maximum_cached_engines=self._maximum_cached_engines,
        cached_engine_batches=self._cached_engine_batches,
        use_calibration=self._need_calibration,
        use_function_backup=self._use_function_backup)

  def finalize_calibration(self):
    assert self._need_calibration
    assert self._converted
    assert not self._calibration_data_collected

    # Lazily load the op, since it's not available in cpu-only builds. Importing
    # this at top will cause tests that import TF-TRT to fail when they're built
    # and run without CUDA/GPU.
    # pylint: disable=g-import-not-at-top,line-too-long
    from tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops import get_serialized_resource_op
    # pylint: enable=g-import-not-at-top,line-too-long

    # TODO(laigd): a better way would be to use self._calibration_sess to list
    # all the devices, add one get_serialized_resource_op for each device, and
    # fetch each such op for every resource until it's found. This can work
    # even when the device of the TRTEngineOp is empty or not fully specified.

    # Maps device name to the corresponding get_serialized_resource_op.
    device_to_get_resource_op_map = {}

    with self._calibration_graph.as_default():
      container_input = array_ops.placeholder(dtypes.string)
      resource_name_input = array_ops.placeholder(dtypes.string)

      for node in self._converted_graph_def.node:
        if node.op == "TRTEngineOp":
          # Adds the get_serialized_resource_op for the device if not done
          # before. We only add one such op for each device.
          # TODO(laigd): what if the device is empty?
          if node.device not in device_to_get_resource_op_map:
            with self._calibration_graph.device(node.device):
              serialized_resources_output = (
                  get_serialized_resource_op(container_input,
                                             resource_name_input))
            device_to_get_resource_op_map[node.device] = (
                serialized_resources_output)

          # Get the calibration resource.
          calibration_result = self._calibration_sess.run(
              device_to_get_resource_op_map[node.device],
              feed_dict={
                  container_input:
                      TrtGraphConverter
                      ._TRT_CALIBRATION_RESOURCE_CONTAINER_NAME,
                  resource_name_input:
                      node.name
              })
          node.attr["calibration_data"].s = calibration_result

    self._calibration_data_collected = True
    self._calibration_sess.close()

  def save(self, output_saved_model_dir):
    """Save the converted graph as a SavedModel."""
    if self._need_calibration:
      assert self._calibration_data_collected
    super(TrtGraphConverter, self).save(output_saved_model_dir)


def create_inference_graph(
    input_graph_def,
    outputs,
    max_batch_size=1,
    max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
    precision_mode=TrtPrecisionMode.FP32,
    minimum_segment_size=3,
    is_dynamic_op=False,
    maximum_cached_engines=1,
    cached_engine_batches=None,
    input_saved_model_dir=None,
    input_saved_model_tags=None,
    input_saved_model_signature_key=None,
    output_saved_model_dir=None,
    session_config=None):
  """Python wrapper for the TRT transformation.

  Args:
    input_graph_def: a GraphDef object containing a model to be transformed. If
      set to None, the graph will be read from the SavedModel loaded from
      input_saved_model_dir.
    outputs: list of tensors or node names for the model outputs. Only used when
      input_graph_def is not None.
    max_batch_size: max size for the input batch.
    max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
      engine can use at execution time. This corresponds to the 'workspaceSize'
      parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
    precision_mode: one of TrtPrecisionMode.supported_precision_modes().
    minimum_segment_size: the minimum number of nodes required for a subgraph to
      be replaced by TRTEngineOp.
    is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
      network and engine at run time.
    maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
      If the number of cached engines is already at max but none of them can
      serve the input, the TRTEngineOp will fall back to run the TF function
      based on which the TRTEngineOp is created.
    cached_engine_batches: a list of batch sizes used to create cached engines,
      only used when is_dynamic_op is True. The length of the list should be <=
      maximum_cached_engines, and the dynamic TRT op will use this list to
      determine the batch sizes of the cached engines, instead of making the
      decision on the fly. This is useful when we know the most common batch
      size(s) the application is going to generate.
    input_saved_model_dir: the directory to load the SavedModel which contains
      the input graph to transform. Used only when input_graph_def is None.
    input_saved_model_tags: list of tags to load the SavedModel.
    input_saved_model_signature_key: the key of the signature to optimize the
      graph for.
    output_saved_model_dir: if not None, construct a SavedModel using the
      returned GraphDef and save it to the specified directory. This option only
      works when the input graph is loaded from a SavedModel, i.e. when
      input_saved_model_dir is specified and input_graph_def is None.
    session_config: the ConfigProto used to create a Session. It's also used as
      a template to create a TRT-enabled ConfigProto for conversion. If not
      specified, a default ConfigProto will be used.

  Returns:
    A GraphDef transformed from input_graph_def (or the SavedModel graph def
    loaded from input_saved_model_dir, if input_graph_def is not present), where
    all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF
    function is added for each of the subgraphs.

    If is_dynamic_op is True, each TRTEngineOp will contain a serialized
    subgraph GraphDef, which will be converted to a TRT engine at execution time
    and the TRT engine will be cached for future usage. A new TRT engine will be
    created each time none of the cached engines match the input shapes. If it
    fails to execute the TRT engine or the number of cached engines reaches
    maximum_cached_engines, the op will fall back to call the corresponding TF
    function.

    If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
    engine created from the corresponding subgraph. No more engines will be
    created on the fly, and the op will fall back to call the corresponding TF
    function when it fails to execute the engine.

  Raises:
    ValueError: if the combination of the parameters is invalid.
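
  Example: a minimal sketch of converting a frozen GraphDef; `my_graph_def` is
  a placeholder for the caller's frozen GraphDef, and 'logits' is assumed to be
  an output node name in that graph:

  ```python
  converted_graph_def = create_inference_graph(
      input_graph_def=my_graph_def,
      outputs=['logits'],
      max_batch_size=8,
      precision_mode=TrtPrecisionMode.FP16)
  ```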
  """
  trt_converter = TrtGraphConverter(
      input_saved_model_dir=input_saved_model_dir,
      input_saved_model_tags=input_saved_model_tags,
      input_saved_model_signature_key=input_saved_model_signature_key,
      input_graph_def=input_graph_def,
      nodes_blacklist=outputs,
      session_config=session_config,
      max_batch_size=max_batch_size,
      max_workspace_size_bytes=max_workspace_size_bytes,
      precision_mode=precision_mode,
      minimum_segment_size=minimum_segment_size,
      is_dynamic_op=is_dynamic_op,
      maximum_cached_engines=maximum_cached_engines,
      cached_engine_batches=cached_engine_batches,
      use_calibration=False)
  converted_graph_def = trt_converter.convert()
  if output_saved_model_dir:
    trt_converter.save(output_saved_model_dir)
  return converted_graph_def