/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"

#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/segment/segment.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorrt/include/NvInfer.h"

namespace tensorflow {
namespace tensorrt {
namespace convert {
namespace {

// Returns true if the op in 'node_def' is supported by the TensorRT converter.
static bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
  // LINT.IfChange
  // TODO(jie): Segmentation shouldn't be associated with op names.
  // Split it into a registration for each kernel.
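  // Whitelist of ops the converter currently knows how to translate to
  // TensorRT layers; nodes running any other op are excluded from
  // TensorRT segments.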
  static const std::set<string> candidate_ops = {
      "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
      "Add",      "Mul",   "Sub",    "Rsqrt",   "Pad"  // "Placeholder", "Mean"
  };
  // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h)
  return candidate_ops.count(node_def.op());
}

// Collects the edges that enter the subgraph from nodes outside of it.
void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
                              const std::set<int>& subgraph_node_ids,
                              tensorflow::EdgeSet* incoming_edges) {
  for (int node_id : subgraph_node_ids) {
    const tensorflow::Node* node = graph.FindNodeId(node_id);
    for (const tensorflow::Edge* edge : node->in_edges()) {
      if (!subgraph_node_ids.count(edge->src()->id()) &&
          !edge->src()->IsSource()) {
        incoming_edges->insert(edge);
      }
    }
  }
}

// Collects the edges that leave the subgraph toward nodes outside of it.
void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph,
                              const std::set<int>& subgraph_node_ids,
                              tensorflow::EdgeSet* outgoing_edges) {
  for (int node_id : subgraph_node_ids) {
    const tensorflow::Node* node = graph.FindNodeId(node_id);
    for (const tensorflow::Edge* edge : node->out_edges()) {
      if (!subgraph_node_ids.count(edge->dst()->id()) &&
          !edge->dst()->IsSink()) {
        outgoing_edges->insert(edge);
      }
    }
  }
}

// Splits a tensor name of the form "node_name:output_index" into its node
// name and output index; the index defaults to 'default_idx' when omitted.
std::pair<string, int> ParseTensorName(string name, int default_idx = 0) {
  int idx = default_idx;
  size_t sep = name.find_last_of(':');
  if (sep != string::npos) {
    // Parse the index from the original name before trimming it off.
    idx = std::stoi(name.substr(sep + 1));
    name = name.substr(0, sep);
  }
  return std::make_pair(name, idx);
}

// Groups the requested output indices by node name, e.g. {"a:0", "a:1", "b"}
// becomes {"a": [0, 1], "b": [0]}.
std::unordered_map<string, std::vector<int>> BuildTensorNameMap(
    const std::vector<string>& tensor_names) {
  std::unordered_map<string, std::vector<int>> result;
  for (const string& tensor_name : tensor_names) {
    string node_name;
    int index;
    std::tie(node_name, index) = ParseTensorName(tensor_name);
    result[node_name].push_back(index);
  }
  return result;
}

// Replaces the nodes in 'subgraph_node_ids' with a single node that runs the
// corresponding TensorRT engine.
tensorflow::Status ConvertSubGraphToTensorRT(
    const std::vector<string>& output_names,
    const std::set<int>& subgraph_node_ids,
    size_t max_batch_size,  // Max batch size that the engine will be built for
    // Max amount of memory the engine is allowed to consume, in bytes
    size_t max_workspace_size_bytes,
    const tensorflow::grappler::GraphProperties& graph_properties,
    tensorflow::Graph* graph) {
  tensorflow::EdgeSet subgraph_incoming_edges;
  GetSubGraphIncomingEdges(*graph, subgraph_node_ids, &subgraph_incoming_edges);

  std::vector<std::pair<int, int>> subgraph_inputs;

  // Collect inputs by looking for incoming edges.
  for (const tensorflow::Edge* edge : subgraph_incoming_edges) {
    subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
  }
  std::set<std::pair<int, int>> subgraph_outputs_set;
  // Collect outputs referenced from output_names.
  auto output_name_to_index_map = BuildTensorNameMap(output_names);
  for (int node_id : subgraph_node_ids) {
    tensorflow::Node* node = graph->FindNodeId(node_id);
    if (output_name_to_index_map.count(node->name())) {
      for (int index : output_name_to_index_map.at(node->name())) {
        subgraph_outputs_set.insert({node_id, index});
      }
    }
  }
  // Collect outputs referenced from outgoing edges.
  tensorflow::EdgeSet subgraph_outgoing_edges;
  GetSubGraphOutgoingEdges(*graph, subgraph_node_ids, &subgraph_outgoing_edges);
  for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
    subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
  }
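  // The position of each (node_id, output_index) pair in the ordered vector
  // built below becomes the corresponding output port of the TensorRT node;
  // the edge re-mapping after node creation relies on this assignment.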
  // Impose an ordering on the outputs.
  std::vector<std::pair<int, int>> subgraph_outputs(
      subgraph_outputs_set.begin(), subgraph_outputs_set.end());
  // Build the TensorRT node and add it to the graph.
  tensorflow::NodeDef trt_node_def;
  TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
      *graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
      max_batch_size, max_workspace_size_bytes, graph_properties,
      &trt_node_def));
  tensorflow::Status status;
  tensorflow::Node* trt_node = graph->AddNode(trt_node_def, &status);
  TF_RETURN_IF_ERROR(status);

  // Re-map outgoing edges to use the new TRT node instead of the original
  // subgraph nodes.
  std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
  for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
    subgraph_edge_to_output_map.insert(
        {subgraph_outputs.at(i), static_cast<int>(i)});
  }
  for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
    std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
    int new_src_output = subgraph_edge_to_output_map.at(old_src);
    TF_RETURN_IF_ERROR(graph->UpdateEdge(trt_node, new_src_output, edge->dst(),
                                         edge->dst_input()));
  }
  // Remove the original subgraph.
  for (int node_id : subgraph_node_ids) {
    tensorflow::Node* node = graph->FindNodeId(node_id);
    // Don't remove the input placeholders.
    if (node->type_string() == "Placeholder") {
      continue;
    }
    graph->RemoveNode(node);
  }
  return tensorflow::Status::OK();
}

// Builds a map from node name to Node*, failing if any name is duplicated.
tensorflow::Status BuildNodeMap(
    const tensorflow::Graph& graph,
    std::unordered_map<string, tensorflow::Node*>* node_map) {
  for (auto* node : graph.op_nodes()) {
    if (!node_map->insert({node->name(), node}).second) {
      return tensorflow::errors::AlreadyExists(
          "Node name is not unique in graph: " + node->name());
    }
  }
  return tensorflow::Status::OK();
}

}  // namespace

tensorflow::Status ConvertGraphDefToTensorRT(
    const tensorflow::GraphDef& graph_def,
    const std::vector<string>& output_names, size_t max_batch_size,
    size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def) {
  // Optimization pass.
  tensorflow::grappler::GrapplerItem item;
  item.fetch = output_names;
  tensorflow::GraphDef gdef;

  // Layout optimization.
  item.graph = graph_def;
  tensorflow::grappler::LayoutOptimizer optimizer;

  // Virtual cluster describing the target GPU.
  tensorflow::DeviceProperties device_properties;
  device_properties.set_type("GPU");
  device_properties.mutable_environment()->insert({"architecture", "6"});
  std::unique_ptr<tensorflow::grappler::Cluster> cluster(
      new tensorflow::grappler::VirtualCluster(
          {{"/GPU:0", device_properties}}));

  TF_RETURN_IF_ERROR(optimizer.Optimize(cluster.get(), item, &gdef));

  // Constant folding.
  item.graph = gdef;
  tensorflow::grappler::ConstantFolding fold(nullptr);
  TF_RETURN_IF_ERROR(fold.Optimize(nullptr, item, &gdef));

  // Shape inference through grappler GraphProperties.
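  // The statically inferred shapes and dtypes are handed to each subgraph
  // conversion below so the TensorRT network inputs can be dimensioned.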
  tensorflow::grappler::GraphProperties static_graph_properties(item);
  TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(false));

  // Build the full graph.
  tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
                                             gdef.library());
  tensorflow::Graph graph(flib);
  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
      tensorflow::GraphConstructorOptions(), gdef, &graph));

  // Segment the graph into subgraphs that can be converted to TensorRT.
  tensorflow::tensorrt::segment::SegmentOptions segment_options;

  // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
  for (const auto& node : output_names) {
    segment_options.exclude_node_list.insert(node);
  }

  // TODO(sami): this should be passed in as a knob.
  segment_options.minimum_segment_size = 2;
  tensorflow::tensorrt::segment::SegmentNodesVector segments;
  TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
      gdef, IsTensorRTCandidate, segment_options, &segments));
  if (segments.size() > 1) {
    VLOG(0) << "Found multiple TensorRT candidate subgraphs: "
            << segments.size();
  }
  std::unordered_map<string, tensorflow::Node*> node_map;
  TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
  for (const std::set<string>& subgraph_node_names : segments) {
    std::set<int> subgraph_node_ids;
    for (const string& node_name : subgraph_node_names) {
      subgraph_node_ids.insert(node_map.at(node_name)->id());
    }
    TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
        output_names, subgraph_node_ids, max_batch_size,
        max_workspace_size_bytes, static_graph_properties, &graph));
  }
  graph.ToGraphDef(new_graph_def);
  return tensorflow::Status::OK();
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA
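// Illustrative usage sketch: a caller holding a frozen GraphDef would invoke
// the conversion roughly as follows. The input graph, fetch name, and size
// limits below are hypothetical placeholders, not values required by this API.
//
//   tensorflow::GraphDef trt_graph_def;
//   TF_CHECK_OK(tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
//       frozen_graph_def, /*output_names=*/{"logits"}, /*max_batch_size=*/8,
//       /*max_workspace_size_bytes=*/1 << 30, &trt_graph_def));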