Home | History | Annotate | Download | only in convert
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
     17 
     18 #include <map>
     19 #include <set>
     20 #include <unordered_map>
     21 #include <utility>
     22 #include <vector>
     23 
     24 #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
     25 #include "tensorflow/contrib/tensorrt/segment/segment.h"
     26 #include "tensorflow/core/graph/algorithm.h"
     27 #include "tensorflow/core/graph/graph.h"
     28 #include "tensorflow/core/graph/graph_constructor.h"
     29 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
     30 #include "tensorflow/core/grappler/costs/graph_properties.h"
     31 #include "tensorflow/core/grappler/devices.h"
     32 #include "tensorflow/core/grappler/grappler_item.h"
     33 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
     34 #include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
     35 #include "tensorflow/core/grappler/utils.h"
     36 #include "tensorflow/core/lib/core/errors.h"
     37 #include "tensorflow/core/lib/core/status.h"
     38 #include "tensorflow/core/platform/logging.h"
     39 #include "tensorflow/core/platform/types.h"
     40 #include "tensorflow/core/protobuf/device_properties.pb.h"
     41 
     42 #if GOOGLE_CUDA
     43 #if GOOGLE_TENSORRT
     44 #include "tensorrt/include/NvInfer.h"
     45 
     46 namespace tensorflow {
     47 namespace tensorrt {
     48 namespace convert {
     49 namespace {
     50 
     51 static bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
     52   // LINT.IfChange
     53   // TODO(jie): Segmentation shouldn't associated with op name.
     54   //            Split it into a registration for each kernel.
     55   static const std::set<string> candidate_ops = {
     56       "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
     57       "Add",      "Mul",   "Sub",    "Rsqrt",   "Pad"  // "Placeholder" ,"Mean"
     58   };
     59   // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h)
     60   return candidate_ops.count(node_def.op());
     61 }
     62 
     63 void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
     64                               const std::set<int>& subgraph_node_ids,
     65                               tensorflow::EdgeSet* incoming_edges) {
     66   for (int node_id : subgraph_node_ids) {
     67     const tensorflow::Node* node = graph.FindNodeId(node_id);
     68     for (const tensorflow::Edge* edge : node->in_edges()) {
     69       if (!subgraph_node_ids.count(edge->src()->id()) &&
     70           !edge->src()->IsSource()) {
     71         incoming_edges->insert(edge);
     72       }
     73     }
     74   }
     75 }
     76 
     77 void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph,
     78                               const std::set<int>& subgraph_node_ids,
     79                               tensorflow::EdgeSet* outgoing_edges) {
     80   for (int node_id : subgraph_node_ids) {
     81     const tensorflow::Node* node = graph.FindNodeId(node_id);
     82     for (const tensorflow::Edge* edge : node->out_edges()) {
     83       if (!subgraph_node_ids.count(edge->dst()->id()) &&
     84           !edge->dst()->IsSink()) {
     85         outgoing_edges->insert(edge);
     86       }
     87     }
     88   }
     89 }
     90 
     91 std::pair<string, int> ParseTensorName(string name, int default_idx = 0) {
     92   int idx = default_idx;
     93   size_t sep = name.find_last_of(':');
     94   if (sep != string::npos) {
     95     name = name.substr(0, sep);
     96     idx = std::stoi(name.substr(sep + 1));
     97   }
     98   return std::make_pair(name, idx);
     99 }
    100 
    101 std::unordered_map<string, std::vector<int>> BuildTensorNameMap(
    102     const std::vector<string>& tensor_names) {
    103   std::unordered_map<string, std::vector<int>> result;
    104   for (string const& tensor_name : tensor_names) {
    105     string node_name;
    106     int index;
    107     std::tie(node_name, index) = ParseTensorName(tensor_name);
    108     result[node_name].push_back(index);
    109   }
    110   return result;
    111 }
    112 
    113 tensorflow::Status ConvertSubGraphToTensorRT(
    114     const std::vector<string>& output_names,
    115     const std::set<int>& subgraph_node_ids,
    116     size_t max_batch_size,  // Max batch size that engine will be created for
    117     // Max amount of memory that engine will be allowed to consume, in bytes
    118     size_t max_workspace_size_bytes,
    119     const tensorflow::grappler::GraphProperties& graph_properties,
    120     tensorflow::Graph* graph) {
    121   tensorflow::EdgeSet subgraph_incoming_edges;
    122   GetSubGraphIncomingEdges(*graph, subgraph_node_ids, &subgraph_incoming_edges);
    123 
    124   std::vector<std::pair<int, int>> subgraph_inputs;
    125 
    126   // Collect inputs by looking for incoming edges
    127   for (const tensorflow::Edge* edge : subgraph_incoming_edges) {
    128     subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
    129   }
    130   std::set<std::pair<int, int>> subgraph_outputs_set;
    131   // Collect outputs referenced from output_names
    132   auto output_name_to_index_map = BuildTensorNameMap(output_names);
    133   for (int node_id : subgraph_node_ids) {
    134     tensorflow::Node* node = graph->FindNodeId(node_id);
    135     if (output_name_to_index_map.count(node->name())) {
    136       for (int index : output_name_to_index_map.at(node->name())) {
    137         subgraph_outputs_set.insert({node_id, index});
    138       }
    139     }
    140   }
    141   // Collect outputs referenced from outgoing edges
    142   tensorflow::EdgeSet subgraph_outgoing_edges;
    143   GetSubGraphOutgoingEdges(*graph, subgraph_node_ids, &subgraph_outgoing_edges);
    144   for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
    145     subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
    146   }
    147   // Impose an ordering on the outputs
    148   std::vector<std::pair<int, int>> subgraph_outputs(
    149       subgraph_outputs_set.begin(), subgraph_outputs_set.end());
    150   // Build TensorRT node and add it to the graph
    151   tensorflow::NodeDef trt_node_def;
    152   TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
    153       *graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
    154       max_batch_size, max_workspace_size_bytes, graph_properties,
    155       &trt_node_def));
    156   tensorflow::Status status;
    157   tensorflow::Node* trt_node = graph->AddNode(trt_node_def, &status);
    158   TF_RETURN_IF_ERROR(status);
    159 
    160   // Re-map outgoing edges to use the new TRT node instead of the orig subgraph
    161   std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
    162   for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
    163     subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i});
    164   }
    165   TF_RETURN_IF_ERROR(status);
    166   for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
    167     std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
    168     int new_src_output = subgraph_edge_to_output_map.at(old_src);
    169     TF_RETURN_IF_ERROR(graph->UpdateEdge(trt_node, new_src_output, edge->dst(),
    170                                          edge->dst_input()));
    171   }
    172   // Remove the original subgraph
    173   for (int node_id : subgraph_node_ids) {
    174     tensorflow::Node* node = graph->FindNodeId(node_id);
    175     // Don't remove the input placeholders
    176     if (node->type_string() == "Placeholder") {
    177       continue;
    178     }
    179     graph->RemoveNode(node);
    180   }
    181   return tensorflow::Status::OK();
    182 }
    183 
    184 tensorflow::Status BuildNodeMap(
    185     const tensorflow::Graph& graph,
    186     std::unordered_map<string, tensorflow::Node*>* node_map) {
    187   for (auto* node : graph.op_nodes()) {
    188     if (!node_map->insert({node->name(), node}).second) {
    189       return tensorflow::errors::AlreadyExists(
    190           "Node name is not unique in graph: " + node->name());
    191     }
    192   }
    193   return tensorflow::Status::OK();
    194 }
    195 
    196 }  // namespace
    197 
    198 tensorflow::Status ConvertGraphDefToTensorRT(
    199     const tensorflow::GraphDef& graph_def,
    200     const std::vector<string>& output_names, size_t max_batch_size,
    201     size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def) {
    202   // Optimization pass
    203   tensorflow::grappler::GrapplerItem item;
    204   item.fetch = output_names;
    205   tensorflow::GraphDef gdef;
    206 
    207   // Layout optimization
    208   item.graph = graph_def;
    209   tensorflow::grappler::LayoutOptimizer optimizer;
    210   tensorflow::grappler::Cluster* cluster;
    211 
    212   // Virtual cluster
    213   tensorflow::DeviceProperties device_properties;
    214   device_properties.set_type("GPU");
    215   device_properties.mutable_environment()->insert({"architecture", "6"});
    216   cluster =
    217       new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}});
    218 
    219   TF_RETURN_IF_ERROR(optimizer.Optimize(cluster, item, &gdef));
    220 
    221   // Constant folding
    222   item.graph = gdef;
    223   tensorflow::grappler::ConstantFolding fold(nullptr);
    224   TF_RETURN_IF_ERROR(fold.Optimize(nullptr, item, &gdef));
    225 
    226   // AJ refactoring shape inference through grappler/GraphProperties.
    227   tensorflow::grappler::GraphProperties static_graph_properties(item);
    228   TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(false));
    229 
    230   // Build full graph
    231   tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
    232                                              gdef.library());
    233   tensorflow::Graph graph(flib);
    234   TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
    235       tensorflow::GraphConstructorOptions(), gdef, &graph));
    236 
    237   // Segment the graph into subgraphs that can be converted to TensorRT
    238   tensorflow::tensorrt::segment::SegmentOptions segment_options;
    239 
    240   // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
    241   for (auto node : output_names) {
    242     segment_options.exclude_node_list.insert(node);
    243   }
    244 
    245   // TODO(sami): this should be passed as a knob!!!!
    246   segment_options.minimum_segment_size = 2;
    247   tensorflow::tensorrt::segment::SegmentNodesVector segments;
    248   TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
    249       gdef, IsTensorRTCandidate, segment_options, &segments));
    250   if (segments.size() > 1) {
    251     VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << segments.size();
    252   }
    253   std::unordered_map<string, tensorflow::Node*> node_map;
    254   TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
    255   for (const std::set<string>& subgraph_node_names : segments) {
    256     std::set<int> subgraph_node_ids;
    257     for (const string& node_name : subgraph_node_names) {
    258       subgraph_node_ids.insert(node_map.at(node_name)->id());
    259     }
    260     TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
    261         output_names, subgraph_node_ids, max_batch_size,
    262         max_workspace_size_bytes, static_graph_properties, &graph));
    263   }
    264   graph.ToGraphDef(new_graph_def);
    265   return tensorflow::Status::OK();
    266 }
    267 
    268 }  // namespace convert
    269 }  // namespace tensorrt
    270 }  // namespace tensorflow
    271 
    272 #endif  // GOOGLE_TENSORRT
    273 #endif  // GOOGLE_CUDA
    274