/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/tf2xla.h"

#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace
tensorflow {

// Attribute / op names used to mark the rewritten graph's inputs and outputs.
const char* const kArgOp = "_Arg";
const char* const kRetvalOp = "_Retval";
const char* const kFeedIdAttr = "_feed_id";
const char* const kFetchIdAttr = "_fetch_id";
const char* const kShapeAttr = "_shape";
const char* const kDebugNameAttr = "_debug_name";

namespace {

typedef std::unordered_map<string, Node*> NodeMap;

// Each feed id identifies the positional output of some node, which may consist
// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
// tensor with a placeholder. For each feed tensor, replaces all edges so they
// point from a new _Arg node instead.
//
// `node_map` maps node names in `graph` to the nodes themselves, and
// `feed_remapping` maps name:port of a feed to the name:port of the
// placeholder that AddPlaceholdersForFeeds created for it.
Status AddArgNodes(Graph* graph, const NodeMap& node_map,
                   const protobuf::RepeatedPtrField<tf2xla::Feed>& feeds,
                   const std::unordered_map<string, string>& feed_remapping) {
  for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
    const tf2xla::Feed& feed = feeds[arg_index];
    // All feeds have been replaced by placeholders.
    const int output_index = 0;

    const string key = TensorIdToString(feed.id());
    const auto remap_it = feed_remapping.find(key);
    // Fix: previously remap_it->second was read without this check, which is
    // undefined behavior if the feed is absent from the remapping.
    if (remap_it == feed_remapping.end()) {
      return errors::InvalidArgument("Could not find feed in remapping: ", key);
    }
    auto node_it = node_map.find(remap_it->second);
    if (node_it == node_map.end()) {
      // Strip off the aot_feed_#/ prefix.
      StringPiece name(remap_it->second);
      const auto index = name.find('/');
      // Fix: test against npos explicitly; the old `index > 0` test only
      // worked for the no-'/' case via unsigned wraparound (npos + 1 == 0).
      if (index != StringPiece::npos) name.remove_prefix(index + 1);
      return errors::InvalidArgument(
          "Node is fed but not needed for fetching: ", name);
    }
    const Node* feed_node = node_it->second;

    // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
    // "_shape" attr if we can determine it. That way the graph will be
    // initialized with whatever shapes we can infer, while the user can still
    // explicitly specify or override them.
    Node* arg_node = nullptr;
    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
            .Attr("T", BaseType(feed_node->output_type(output_index)))
            .Attr("index", arg_index)
            .Attr(kFeedIdAttr, TensorIdToString(feed.id()))
            .Attr(kShapeAttr, TensorShape(feed.shape()))
            .Attr(kDebugNameAttr, feed.name())
            .Finalize(graph, &arg_node));

    // Collects out-edges from the feed node that have a matching edge index;
    // these will be replaced with edges from the arg node instead.
    //
    // We must collect the edges first and process them in a second pass, since
    // removing the edge from the graph invalidates feed_node->out_edges.
    std::vector<const Edge*> feed_edges;
    for (const Edge* edge : feed_node->out_edges()) {
      if (edge->src_output() == output_index) {
        feed_edges.push_back(edge);
      }
    }
    for (const Edge* edge : feed_edges) {
      graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
      graph->RemoveEdge(edge);
    }
  }
  return Status::OK();
}

// Each fetch id identifies the positional output of some node. For each fetch
// node, adds a new _Retval node instead, and adds the node to `retval_nodes`.
Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
                      const protobuf::RepeatedPtrField<tf2xla::Fetch>& fetches,
                      std::unordered_set<const Node*>* retval_nodes) {
  for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
    const tf2xla::TensorId& id = fetches[ret_index].id();
    auto it = node_map.find(id.node_name());
    if (it == node_map.end()) {
      return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
    }
    Node* fetch_node = it->second;
    if (id.output_index() >= fetch_node->num_outputs()) {
      return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
                                     ", output index should be < ",
                                     fetch_node->num_outputs());
    }
    // Connects fetch_node -> retval_node.
    Node* retval_node = nullptr;
    TF_RETURN_IF_ERROR(
        NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
            .Input(fetch_node, id.output_index())
            .Attr("T", BaseType(fetch_node->output_type(id.output_index())))
            .Attr("index", ret_index)
            .Attr(kFetchIdAttr, TensorIdToString(id))
            .Finalize(graph, &retval_node));
    retval_nodes->insert(retval_node);
  }
  return Status::OK();
}

// RewriteAndPruneGraph identifies input and output edges (named by the feed and
// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
// execution to know the input and output args for the generated function.
Status RewriteAndPruneGraph(
    Graph* graph, const tf2xla::Config& config,
    const std::unordered_map<string, string>& feed_remapping) {
  NodeMap node_map;
  for (Node* n : graph->nodes()) {
    node_map[n->name()] = n;
  }
  TF_RETURN_IF_ERROR(
      AddArgNodes(graph, node_map, config.feed(), feed_remapping));
  std::unordered_set<const Node*> retval_nodes;
  TF_RETURN_IF_ERROR(
      AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
  VLOG(2) << "Post rewrite: "
          << dump_graph::DumpGraphToFile("tf2xla_post_rewrite", *graph);
  // Drop every node that no _Retval node transitively depends on.
  PruneForReverseReachability(graph, retval_nodes);
  FixupSourceAndSinkEdges(graph);
  VLOG(2) << "Post prune: "
          << dump_graph::DumpGraphToFile("tfcompile_post_prune", *graph);
  // Sanity-check, to make sure the feeds and fetches still exist post-pruning.
  std::set<string> missing_feeds, missing_fetches;
  for (const tf2xla::Feed& feed : config.feed()) {
    missing_feeds.insert(TensorIdToString(feed.id()));
  }
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    missing_fetches.insert(TensorIdToString(fetch.id()));
  }
  for (const Node* n : graph->op_nodes()) {
    if (n->type_string() == kArgOp) {
      string feed_id;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
      if (missing_feeds.erase(feed_id) == 0) {
        return errors::Aborted(kArgOp,
                               " node found with unknown feed id: ", feed_id);
      }
    } else if (n->type_string() == kRetvalOp) {
      string fetch_id;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
      if (missing_fetches.erase(fetch_id) == 0) {
        return errors::Aborted(kRetvalOp,
                               " node found with unknown fetch id: ", fetch_id);
      }
    }
  }
  if (!missing_feeds.empty() || !missing_fetches.empty()) {
    return errors::Aborted(
        "Post graph-pruning",
        ", missing feeds: ", str_util::Join(missing_feeds, ", "),
        ", missing fetches: ", str_util::Join(missing_fetches, ", "));
  }
  return Status::OK();
}

// CollectArgNodes collects _Arg nodes from the graph, and performs basic
// sanity-checking to ensure the index and type attributes of each node are
// initialized correctly.
208 Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) { 209 std::map<int, Node*> indexed_arg_nodes; 210 for (Node* n : graph.nodes()) { 211 if (n->type_string() == kArgOp) { 212 int index; 213 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); 214 auto insert_result = indexed_arg_nodes.insert({index, n}); 215 if (!insert_result.second) { 216 const Node* dup = insert_result.first->second; 217 return errors::InvalidArgument( 218 "Multiple ", kArgOp, " nodes with index ", index, ", ", 219 n->DebugString(), " and ", dup->DebugString()); 220 } 221 } 222 } 223 arg_nodes->clear(); 224 for (const auto& index_node : indexed_arg_nodes) { 225 if (index_node.first != arg_nodes->size()) { 226 return errors::InvalidArgument("Expected ", kArgOp, " node with index ", 227 arg_nodes->size(), ", but got index ", 228 index_node.first); 229 } 230 arg_nodes->push_back(index_node.second); 231 } 232 return Status::OK(); 233 } 234 235 // Fills in xla_args from the corresponding _Arg nodes in the graph. 236 Status CreateXlaArgs(const Graph& graph, 237 std::vector<XlaCompiler::Argument>* xla_args) { 238 std::vector<Node*> arg_nodes; 239 TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes)); 240 for (const Node* node : arg_nodes) { 241 XlaCompiler::Argument arg; 242 arg.kind = XlaCompiler::Argument::kParameter; 243 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); 244 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape)); 245 TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); 246 xla_args->push_back(arg); 247 } 248 return Status::OK(); 249 } 250 251 // Converts the TensorFlow graph into an XLA computation, by executing the 252 // graph symbolically, with each op building up the XLA HLO. 
Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
                         xla::Computation* computation) {
  XlaOpRegistry::RegisterCompilationKernels();
  // Assign every node to the XLA_CPU JIT device so compilation resolves ops
  // against the XLA kernel registrations.
  for (Node* node : graph->nodes()) {
    node->set_assigned_device_name(
        strings::StrCat("/device:", DEVICE_CPU_XLA_JIT));
  }
  std::vector<XlaCompiler::Argument> xla_args;
  TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));

  // Compile the graph into an XLA computation.
  XlaCompiler::Options compiler_options;
  compiler_options.client = client;
  DeviceType device_type(DEVICE_CPU_XLA_JIT);
  // NOTE: compiler_options holds a pointer to the stack-local device_type;
  // this is fine because `compiler` is not used beyond this function.
  compiler_options.device_type = &device_type;
  // The flib_def pointer stays valid across std::move(graph) below, since the
  // Graph object itself is not destroyed, only its ownership is transferred.
  compiler_options.flib_def = &graph->flib_def();
  compiler_options.graph_def_version = graph->versions().producer();
  compiler_options.allow_cpu_custom_calls = true;
  XlaCompiler compiler(compiler_options);

  XlaCompiler::CompilationResult result;
  // CompileGraph takes ownership of the graph.
  TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
                                           "tfcompile", std::move(graph),
                                           xla_args, &result));
  *computation = std::move(*result.computation);

  int num_const_results = 0;
  for (int i = 0; i < result.outputs.size(); ++i) {
    // Ending up with const results (i.e. output args) is an error, since it
    // means that one or more fetches that the user specified will be dropped
    // from the generated function. It's most likely a configuration error,
    // since the user shouldn't be asking for output args that end up as consts.
    //
    // TODO(toddw): Provide a way for the user to access const output args,
    // e.g. perhaps hard-coded into the header, or somehow copied into the
    // output buffers.
    if (result.outputs[i].is_constant) {
      ++num_const_results;
      LOG(ERROR) << "ConstRetVal index:" << i
                 << " value:" << result.outputs[i].constant_value.DebugString();
    }
  }
  if (num_const_results > 0) {
    return errors::Unimplemented(
        "Conversion from TensorFlow graph to XLA resulted in ",
        num_const_results,
        " constant results. The configuration of "
        "the output args (i.e. fetch ids) is probably wrong.");
  }
  return Status::OK();
}

// InitGraph creates a graph based on the graph_def, that may then be converted
// to an xla::Computation via ConvertGraphToXla.
//
// The graph is rewritten with _Arg and _Retval nodes, representing the inputs
// and outputs of the function that will be compiled. Each feed id causes a new
// _Arg node to be created, where we first collect all existing edges pointing
// from the named node's output index, and then rewrite them to point from that
// _Arg node instead. Each fetch id causes a new _Retval node to be created,
// with a new edge pointing from the named node's output index to that _Retval
// node.
Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
                 std::unique_ptr<Graph>* graph) {
  TF_RETURN_IF_ERROR(ValidateConfig(config));

  FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
  std::unique_ptr<Graph> g(new Graph(flib_def));

  // Replace references to fed tensors with references to newly added
  // placeholders.
  GraphDef first_copy_def = graph_def;

  // Maps from name:port of a feed to the name:port of the placeholder to use.
  std::unordered_map<string, string> feed_remapping;
  TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
                                             &feed_remapping, &first_copy_def));

  // Prune the GraphDef first so that unknown ops that we aren't compiling get
  // filtered out.
  GraphDef second_copy_def;
  TF_RETURN_IF_ERROR(
      PruneGraphDefInto(config, first_copy_def, &second_copy_def));

  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
      &second_copy_def, *g->op_registry(), /*node_offset=*/0));

  // Materialize the pruned GraphDef as a Graph, then rewrite feeds/fetches
  // into _Arg/_Retval nodes and prune everything the fetches don't need.
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
                                            second_copy_def, g.get()));
  TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));
  *graph = std::move(g);
  return Status::OK();
}

}  // namespace

// Public entry point: rewrites graph_def per `config` (feeds -> _Arg,
// fetches -> _Retval), then compiles it into `computation` via `client`.
Status ConvertGraphDefToXla(const GraphDef& graph_def,
                            const tf2xla::Config& config, xla::Client* client,
                            xla::Computation* computation) {
  std::unique_ptr<Graph> graph;
  TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
  TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation));
  return Status::OK();
}

}  // namespace tensorflow