1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" 17 18 #include <functional> 19 #include <memory> 20 #include <numeric> 21 #include <string> 22 #include <unordered_map> 23 #include <vector> 24 25 #include "tensorflow/compiler/jit/graph_to_functiondef.h" 26 #include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" 27 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" 28 #include "tensorflow/compiler/tf2xla/const_analysis.h" 29 #include "tensorflow/compiler/tf2xla/dump_graph.h" 30 #include "tensorflow/compiler/xla/status_macros.h" 31 #include "tensorflow/core/common_runtime/function.h" 32 #include "tensorflow/core/common_runtime/optimization_registry.h" 33 #include "tensorflow/core/common_runtime/shape_refiner.h" 34 #include "tensorflow/core/framework/function.h" 35 #include "tensorflow/core/framework/graph_def_util.h" 36 #include "tensorflow/core/framework/node_def_builder.h" 37 #include "tensorflow/core/framework/node_def_util.h" 38 #include "tensorflow/core/graph/algorithm.h" 39 #include "tensorflow/core/graph/graph.h" 40 #include "tensorflow/core/graph/graph_def_builder.h" 41 #include "tensorflow/core/graph/tensor_id.h" 42 #include "tensorflow/core/lib/gtl/flatset.h" 43 #include "tensorflow/core/lib/gtl/map_util.h" 44 #include "tensorflow/core/lib/hash/hash.h" 45 #include "tensorflow/core/lib/strings/str_util.h" 46 #include "tensorflow/core/lib/strings/strcat.h" 47 #include "tensorflow/core/public/session_options.h" 48 #include "tensorflow/core/public/version.h" 49 #include "tensorflow/core/util/device_name_utils.h" 50 51 namespace tensorflow { 52 53 const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel"; 54 const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs"; 55 const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; 56 57 namespace { 58 59 bool AreAllParentsConst(const Node& n, 60 const gtl::FlatSet<const Node*>& runtime_const_nodes) { 61 if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") { 62 // If the current node is itself a cast-to-const, no need 63 // to look at the incoming edges. 64 return true; 65 } 66 67 bool all_parents_const = true; 68 bool atleast_one_non_control_edge = false; 69 for (const Edge* in : n.in_edges()) { 70 atleast_one_non_control_edge = 71 atleast_one_non_control_edge || !in->IsControlEdge(); 72 if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) { 73 all_parents_const = false; 74 break; 75 } 76 } 77 return all_parents_const && atleast_one_non_control_edge; 78 } 79 80 void MarkGuaranteedConstants( 81 const Graph& graph, 82 const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) { 83 gtl::FlatSet<const Node*> guaranteed_const_nodes; 84 std::vector<const Node*> srcs; 85 srcs.reserve(src_arg_pairs.size()); 86 for (const auto& src_arg : src_arg_pairs) { 87 srcs.push_back(src_arg.first); 88 } 89 ReverseDFSFrom(graph, srcs, /*enter=*/nullptr, 90 /*leave=*/[&guaranteed_const_nodes](const Node* n) { 91 // TODO(vinuraja): Doesn't work in the presence of loops. 92 if (AreAllParentsConst(*n, guaranteed_const_nodes)) { 93 guaranteed_const_nodes.insert(n); 94 } 95 }); 96 97 for (auto& src_arg : src_arg_pairs) { 98 if (guaranteed_const_nodes.count(src_arg.first) != 0) { 99 VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString(); 100 src_arg.second->AddAttr("_is_guaranteed_constant", true); 101 } 102 } 103 } 104 105 // A node/slot pair. 106 // TODO(phawkins): is there a common definition of this? 107 struct NodeSlot { 108 NodeSlot() : node(nullptr), slot(-1), dtype(DT_INVALID) {} 109 NodeSlot(const Node* node, int slot) 110 : node(node), slot(slot), dtype(DT_INVALID) {} 111 NodeSlot(const Node* node, int slot, DataType dtype) 112 : node(node), slot(slot), dtype(dtype) {} 113 114 const Node* node; 115 int slot; 116 117 // Optional: used to record the destination type of a source NodeSlot in case 118 // the source output is a Ref type that is cast to a Tensor at the 119 // destination. 120 DataType dtype; 121 122 bool operator==(const NodeSlot& other) const { 123 return node == other.node && slot == other.slot && dtype == other.dtype; 124 } 125 126 // Leave dtype out of the hash since there are never two NodeSlots with the 127 // same node and slot and different dtypes. 128 struct Hasher { 129 uint64 operator()(NodeSlot const& s) const { 130 return Hash64Combine(std::hash<const Node*>()(s.node), 131 std::hash<int>()(s.slot)); 132 } 133 }; 134 135 struct PairHasher { 136 uint64 operator()(std::pair<NodeSlot, NodeSlot> const& s) const { 137 return Hash64Combine(Hasher()(s.first), Hasher()(s.second)); 138 } 139 }; 140 }; 141 142 // TODO(phawkins) add a canonical copy of these operator names and refactor 143 // everything to use it. 144 static const char* const kArgOp = "_Arg"; 145 static const char* const kRetValOp = "_Retval"; 146 static const char* const kHostComputeOp = "_XlaHostCompute"; 147 static const char* const kSendFromHostOp = "_XlaSendFromHost"; 148 static const char* const kRecvAtHostOp = "_XlaRecvAtHost"; 149 150 class Encapsulator { 151 public: 152 Encapsulator(string group_attribute, string outside_compilation_attribute, 153 Graph const* graph_in) 154 : group_attribute_(std::move(group_attribute)), 155 outside_compilation_attribute_( 156 std::move(outside_compilation_attribute)), 157 graph_in_(graph_in) {} 158 159 // Find subgraphs marked with 'group_attribute', and build a new 160 // subgraph, one for each value of 'group_attribute'. 161 Status SplitIntoSubgraphs(); 162 163 // Build a FunctionDef for each subgraph, and add it 'library'. The values of 164 // the 'group_attribute' annotations become the function names. 165 // If 'reuse_existing_functions' is set, use an existing function with the 166 // same name, if any. 167 // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before 168 // function conversion. 169 Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn, 170 bool reuse_existing_functions, 171 FunctionLibraryDefinition* library); 172 173 // Write a copy of the input graph to 'graph_out', where the subgraphs are 174 // replaced with calls to the new functions. 175 Status BuildOutputGraph(bool parallel_checking, Graph* graph_out, 176 FunctionLibraryDefinition* library); 177 178 private: 179 // A subgraph of the input, all marked with a common 'group_attribute' 180 // value. A subgraph may contain multiple `outside_compilation' clusters. 181 // 182 // In the following simple example, A, B, ..., E are nodes in the original 183 // graph. The group attributes and outside_compilation attributes g and oc are 184 // each shown as either 0 or empty. 185 // 186 // A --> B --> C --> D --> E 187 // g: g:0 g:0 g:0 g: 188 // oc: oc: oc:0 oc: oc: 189 // 190 // The example is rewritten to two graphs; one on the host and one to be 191 // compiled. The host graph is as follows. RAH is a RecvAtHost node receiving 192 // input from the compiled cluster, and SFH is a SendFromHost node sending 193 // input back to the compiled cluster. Dotted edges are control edges. A 194 // 'sequencing' node S is inserted, and both RAH and SFH are connected via S 195 // to E (and in general all nodes that depend on nodes in the compiled 196 // cluster) to ensure that they are not pruned. 197 // 198 // A --> Call --> E 199 // ^ 200 // . 201 // ........> S 202 // .... ^ 203 // .. . 204 // RAH --> C --> SFH 205 // 206 // The compiled cluster is as follows. HC is a HostCompute node which is the 207 // source of a channel to the RAH node above and the destination of a channel 208 // from the SFH node above. 209 // 210 // Arg --> B --> HC --> D --> Retval 211 // 212 // The channels HC/RAH and SFH/HC each transmit multiple tensors, so there is 213 // at most one RAH and SFH in each outside_compilation cluster. This design is 214 // preferred over adding separate Arg/Retval nodes for each transmitted value 215 // because it allows optimizations to the host code that would like to limit 216 // communication between host and device and, e.g., raise only one interrupt 217 // per channel rather than one per transmitted value. 218 // 219 // The shapes of the outputs from the HC node in general cannot be determined 220 // until the shapes of its inputs are known at compile time, since e.g., 221 // above, the shape of C's outputs aren't known until the shape of its inputs 222 // are known. If the shapes of the HC's outputs can be determined during the 223 // rewrite, they are stored in the node's 'shapes' attr. Otherwise a minimal 224 // graph is stored in the shape_inference_graph attr. This graph can be used 225 // when compiling the HC Op to determined the shape of the SFH inputs given 226 // the shapes of any ancestor RAH outputs. If it can be determined that the 227 // shape of the SFH inputs will not be inferrable even once the shapes of the 228 // RAH outputs are known, an error is returned by the rewriter. 229 class Subgraph { 230 public: 231 // Creates a graph to build the subgraph in, if it doesn't already exist, 232 // using the same op registry and versions as graph_in. 233 Node* MakeNodeImage(const Graph* graph_in, Node* node); 234 235 // Returns the graph the subgraph is being built in. 236 Graph* GetGraph() const; 237 238 // Builds a FunctionDef, and adds it to 'library'. The value of the 239 // 'group_attribute' annotations becomes the function name. If 240 // 'reuse_existing_functions' is set, use an existing function with the same 241 // name, if any. If 'rewrite_subgraph_fn' is set, it is applied to the 242 // subgraph before function conversion. 243 Status BuildFunctionDef(const string& name_in, 244 const RewriteSubgraphFn& rewrite_subgraph_fn, 245 bool reuse_existing_functions, 246 FunctionLibraryDefinition* library); 247 248 // Adds the function call node to graph_out. 249 Status AddFunctionCallNode( 250 const std::unordered_map<const Node*, Node*>& node_images, 251 bool parallel_checking, Graph* graph_out); 252 253 // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out. 254 Status AddOutsideCompilationHostIONodes( 255 const string& subgraph_name, 256 const std::unordered_map<const Node*, Node*>& node_images, 257 Graph* graph_out); 258 259 // Returns the names of all the outside_compilation subgraphs in this 260 // Subgraph. 261 void GetOutsideCompilationSubgraphNames(std::vector<string>* names) const; 262 263 // Returns the Node that inputs to the function should be wired up to. 264 Node* GetCallNodeForInputs() const; 265 266 // Returns the Node that outputs to the function should be wired up to. 267 Node* GetCallNodeForOutputs() const; 268 269 // Returns the index of the arg that the dst of edge should connect to. 270 int GetArgIndexForEdge(const Edge* edge) const; 271 272 // Returns the index of the result that the src of edge should connect to. 273 int GetResultIndexForEdge(const Edge* edge) const; 274 275 // Returns the RecvAtHost node for an outside_compilation subgraph. 276 Node* GetRecvAtHostNode( 277 const string& outside_compilation_subgraph_name) const; 278 279 // Returns the output slot for the RecvAtHost node that corresponds to the 280 // source of edge in an outside_compilation subgraph. 281 int GetRecvAtHostSlot(const string& outside_compilation_subgraph_name, 282 const Edge* edge) const; 283 284 // Returns the SendFromHost node for an outside_compilation subgraph. 285 Node* GetSendFromHostNode( 286 const string& outside_compilation_subgraph_name) const; 287 288 // Returns the input slot for the SendFromHost node that corresponds to the 289 // destination of edge in an outside_compilation subgraph. 290 int GetSendFromHostSlot(const string& outside_compilation_subgraph_name, 291 const Edge* edge) const; 292 293 // Creates an _Arg node for the src node of edge, and add its index to 294 // args_by_src_, if none exists yet. Also adds its index to args_by_dst_, 295 // and adds the edge within the subgraph from the _Arg node to the image of 296 // the dst node. 297 Status RecordArg(const Edge* edge, 298 const std::unordered_map<const Node*, Node*>& node_images, 299 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs); 300 301 // Creates a _Retval node for the src node of edge, and add it to results_, 302 // if none exists yet. If a new _Retval node is created, also adds the edge 303 // within the subgraph from the src to the _Retval node. 304 Status RecordResult( 305 const Edge* edge, 306 const std::unordered_map<const Node*, Node*>& node_images); 307 308 // Creates an outside_compilation subgraph for outside_compilation_id if 309 // none exists yet. Creates an entry for the src node of edge in the list of 310 // inputs for the outside_compilation subgraph, if none exists yet. 311 void RecordOutsideCompilationInputOrControl( 312 const string& outside_compilation_id, const Edge* edge); 313 314 // Creates an outside_compilation subgraph for outside_compilation_id if 315 // none exists yet. Creates an entry for the src node of edge in the list of 316 // outputs by src for the outside_compilation subgraph, if none exists 317 // yet. Creates an entry for the dst node of edge in the list of outputs by 318 // dst for the outside_compilation subgraph. 319 void RecordOutsideCompilationOutputOrControl( 320 const string& outside_compilation_id, const Edge* edge); 321 322 // Adds the HostCompute nodes for each outside_compilation subgraph. 323 Status AddHostComputes( 324 const string& subgraph_name, 325 const std::unordered_map<const Node*, Node*>& node_images); 326 327 // Creates the sequencer node if it doesn't exist, adding it to graph_out. 328 Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out); 329 330 // If there is a sequencer node, adds a control edge from the sequencer to 331 // all the downstream nodes of call_node_outputs. 332 void ConnectSequencerToOutputs(Graph* graph_out); 333 334 Status AddShapeInferenceInfo( 335 const string& outside_compilation_subgraph_name, 336 const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph); 337 338 Status ReplaceFunctionDef(FunctionLibraryDefinition* library); 339 340 private: 341 struct OutsideCompilationSubgraph { 342 // Map from source (producer node/slot) tensors in the original graph to 343 // input index (slot number in the HostCompute/RecvAtHost nodes that will 344 // be created) for the outside_compilation subgraph. 345 std::unordered_map<NodeSlot, int, NodeSlot::Hasher> inputs; 346 347 // Set of nodes in the original graph that are the source of control edges 348 // that cross from the containing compiled subgraph into the 349 // outside_compilation subgraph. These are recorded by 350 // RecordOutsideCompilationInputOrControl while walking all the subgraph 351 // edges, and lifted control edges within the subgraph are added by 352 // AddSendsToOutsideCompilation once the _HostCompute node has been 353 // created. The matching control edge from _RecvAtHost to the 354 // destination is added by CopyEdgeToOutputGraph. 355 std::unordered_set<const Node*> control_inputs; 356 357 // Maps from source (producer node/slot) and destination (consumer 358 // node/slot) tensors in the original graph to output index (slot number 359 // in the SendFromHost/HostCompute nodes that will be created) for the 360 // outside_compilation subgraph. 361 std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_src; 362 std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_dst; 363 364 // Set of nodes in the original graph that are the destination of control 365 // edges that cross from the outside_compilation subgraph into the 366 // containing compiled subgraph. These are recorded by 367 // RecordOutsideCompilationOutputOrControl while walking all the subgraph 368 // edges, and lifted control edges within the subgraph are added by 369 // AddRecvsFromToOutsideCompilation once the _HostCompute node has been 370 // created. The matching control edge from the source to _SendFromHost to 371 // the destination is added by CopyEdgeToOutputGraph. 372 std::unordered_set<const Node*> control_outputs; 373 374 // Name of the _HostCompute node in the subgraph. 375 string host_compute_name; 376 377 // _RecvAtHost node in the output graph. Not owned. 378 Node* recv_at_host = nullptr; 379 380 // _SendFromHost node in the output graph. Not owned. 381 Node* send_from_host = nullptr; 382 }; 383 384 // Builds a ParallelCheck op that compares the output of the original 385 // subgraph with the encapsulated subgraph. 386 Status BuildParallelCheckOp( 387 const std::unordered_map<const Node*, Node*>& node_images, 388 Graph* graph_out); 389 390 // Builds a _RecvAtHost node producing all the inputs of an 391 // outside_compilation subgraph and stores it in oc_subgraph.recv_at_host. 392 Status AddRecvAtHostNode(const string& subgraph_name, 393 const string& oc_subgraph_name, 394 OutsideCompilationSubgraph* oc_subgraph, 395 Graph* graph_out); 396 397 // Builds a _SendFromHost node consuming all the outputs of an 398 // outside_compilation subgraph and stores it in oc_subgraph.send_from_host. 399 Status AddSendFromHostNode( 400 const std::unordered_map<const Node*, Node*>& node_images, 401 const string& subgraph_name, const string& oc_subgraph_name, 402 OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out); 403 404 // The subgraph extracted from the input graph, suitable for being turned 405 // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are 406 // returned by _Retval nodes. 407 std::unique_ptr<Graph> graph_; 408 409 // Which device are these nodes on? Used to assign a device to the call 410 // node. 411 string device_; 412 413 // NodeDef for the function call node. 414 NodeDef call_node_def_; 415 416 // Function call node(s) in the output graph. Not owned. 417 // If parallel_checking is enabled, 'call_node_inputs' is the function call 418 // node to which inputs should be fed, and 'call_node_outputs' is the 419 // parallel check op from which outputs should be read. If parallel checking 420 // is disabled, both point to the function call node. 421 Node* call_node_inputs_; 422 Node* call_node_outputs_; 423 424 // Maps from source (producer node/slot) and destination 425 // (consumer node/slot) tensors in the input graph to _Arg numbers in 426 // the subgraph. The source map is one-to-one, whereas the dest map may be 427 // many-to-one. 428 std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_src_; 429 std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_dst_; 430 431 // The _Arg nodes in the subgraph, in order by argument number. 432 std::vector<Node*> args_; 433 434 // Map from source tensor in the input graph to result #. 435 std::unordered_map<NodeSlot, int, NodeSlot::Hasher> results_; 436 437 // The outside_compilation clusters in this subgraph. 438 std::unordered_map<string, OutsideCompilationSubgraph> 439 outside_compilation_subgraphs_; 440 441 // NoOp node in the output graph that is sequenced after the call node and 442 // used to prevent host-side outside_compilation sends and recvs from being 443 // pruned. 444 Node* sequencer_ = nullptr; 445 }; 446 447 // Returns the key attribute and outside_compilation attribute associated 448 // with a node in attr, and outside_compilation_attr, respectively. Sets 449 // either result to the empty string if the respective attribute is not 450 // found. Returns error status if there is an outside_compilation attribute 451 // and no key attribute, 452 Status GetFunctionNameAttr(Node const* node, string* attr, 453 string* outside_compilation_attr) const; 454 455 // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to 456 // subgraphs for data edges that cross subgraph boundaries. 457 Status CopySubgraphEdges( 458 const std::unordered_map<const Node*, Node*>& node_images, 459 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs); 460 461 // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes, 462 // or nodes marked outside_compilation. 463 Status CopySubgraphNodes(std::unordered_map<const Node*, Node*>* node_images); 464 465 // Copies all nodes that aren't in a compiled subgraph to the output graph. 466 Status CopyNodesToOutputGraph( 467 bool parallel_checking, Graph* graph_out, 468 std::unordered_map<const Node*, Node*>* node_images); 469 470 // Adds function call nodes for each compiled subgraph. 471 Status AddFunctionCallNodes( 472 const std::unordered_map<const Node*, Node*>& node_images, 473 bool parallel_checking, Graph* graph_out); 474 475 // Adds _RecvAtHost and _SendFromHost nodes, where needed, for all 476 // outside_compilation subgraphs. 477 Status AddOutsideCompilationHostIONodes( 478 const std::unordered_map<const Node*, Node*>& node_images, 479 Graph* graph_out); 480 481 // Finds the image of an edge source in the output graph. If the edge crosses 482 // a subgraph boundary it is the output of a call node, otherwise it is a node 483 // in the output graph. 484 Status FindOutputImageOfEdgeSrc( 485 const string& src_func_id, const string& src_outside_compilation_id, 486 const string& dst_func_id, const string& dst_outside_compilation_id, 487 const std::unordered_map<const Node*, Node*>& node_images, 488 const Node* original_src_node, Node** src_image); 489 490 // Finds an edge source slot in the output graph. If the edge crosses a 491 // subgraph boundary it is a slot on the output of a call node or a 492 // _RecvAtHost node, otherwise it is a slot on a node in the output graph. 493 int FindOutputSlotOfEdgeSrc(const string& src_func_id, 494 const string& src_outside_compilation_id, 495 const string& dst_func_id, 496 const string& dst_outside_compilation_id, 497 const Edge* edge); 498 499 // Finds the image of an edge destination in the output graph. If the edge 500 // crosses a subgraph boundary it is the input of a call node or a 501 // _SendFromHost node, otherwise it is a node in the output graph. 502 Status FindOutputImageOfEdgeDst( 503 const string& src_func_id, const string& src_outside_compilation_id, 504 const string& dst_func_id, const string& dst_outside_compilation_id, 505 const std::unordered_map<const Node*, Node*>& node_images, 506 const Node* original_dst_node, Node** dst_image); 507 508 // Finds an edge destination slot in the output graph. If the edge crosses a 509 // subgraph boundary it is a slot on the input of a call node or a 510 // _SendFromHost node, otherwise it is a slot on a node in the output graph. 511 int FindOutputSlotOfEdgeDst(const string& src_func_id, 512 const string& src_outside_compilation_id, 513 const string& dst_func_id, 514 const string& dst_outside_compilation_id, 515 const Edge* edge); 516 517 // Copies a single edge to the output graph. The edge is either entirely 518 // within the output graph, or crosses into or out of a compiled subgraph. 519 Status CopyEdgeToOutputGraph( 520 const Edge* edge, const string& src_func_id, 521 const string& src_outside_compilation_id, const string& dst_func_id, 522 const string& dst_outside_compilation_id, 523 const std::unordered_map<const Node*, Node*>& node_images, 524 bool parallel_checking, Graph* graph_out, 525 std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>* 526 edges_added); 527 528 // Adds all edges to the output graph. 529 Status AddEdgesToOutputGraph( 530 const std::unordered_map<const Node*, Node*>& node_images, 531 bool parallel_checking, Graph* graph_out); 532 533 // Constructs a minimal shape inference graph that can be used to determine 534 // the shape of send_node at the time that the subgraph is compiled. 535 // recv_at_host_nodes contains the names of all the recv_at_host nodes that 536 // send_node might depend on. These recv_at_host nodes have shapes that are 537 // not known during the rewrite pass, but will be known at compile time. 538 // 539 // If the shapes of all the inputs to send_node can be determined during the 540 // rewrite pass, on exit graphdef_out is empty and the shapes are returned in 541 // static_shape_out. Otherwise graphdef_out contains a graph that can be used 542 // for shape inference at compile time, where all the source nodes of the 543 // graph are either constants with known shapes, or nodes named in 544 // recv_at_host_nodes. 545 // 546 // A non-OK status is returned if neither of the above conditions can be 547 // satisfied, e.g., because send_node depends on a node that doesn't have a 548 // registered shape inference function. 549 Status DoStaticShapeInferenceForOutsideCompilationSend( 550 const Graph& graph_in, const ShapeRefiner& shape_refiner, 551 const std::unordered_set<string>& recv_at_host_nodes, Node* send_node, 552 FunctionLibraryDefinition* library, 553 std::vector<TensorShapeProto>* static_shape_out, 554 std::unique_ptr<GraphDef>* graphdef_out); 555 556 // Makes a copy of graph containing only nodes that are ancestors of at least 557 // one node in send_from_host_nodes and store it in pruned_graph. On exit 558 // nodes_images contains a mapping from nodes in graph to nodes in 559 // pruned_graph. All functions in the copied graph are inlined. 560 Status MakePrunedGraphCopyAndInline( 561 const Graph& graph, const std::vector<Node*>& sink_nodes, 562 std::unique_ptr<Graph>* pruned_graph, 563 std::unordered_map<const Node*, Node*>* node_images, 564 FunctionLibraryDefinition* library); 565 566 // Makes a copy of graph containing only nodes that are ancestors of a 567 // send_from_host node in an outside_compilation subgraph, and store it in 568 // pruned_graph. Also perform shape inference on the pruned graph, using 569 // shape_refiner. On exit node_images contains a mapping from nodes in graph 570 // to nodes in pruned_graph. 571 Status MakeGraphForOutsideCompilationSends( 572 const Graph& graph, std::unique_ptr<Graph>* pruned_graph, 573 ShapeRefiner* shape_refiner, 574 std::unordered_map<const Node*, Node*>* node_images, 575 FunctionLibraryDefinition* library); 576 577 // Performs static shape inference, as far as possible, for the send_from_host 578 // nodes in each outside_compilation subgraph. Where it is not possible to 579 // determine the shape statically, stores a serialized GraphDef in the 580 // HostCompute 'shape_inference_graph' attr, to be used at compile time for 581 // final inference. If the shapes are known statically they are stored in the 582 // HostCompute 'shapes' attr. 583 Status GetShapeInfoForOutsideCompilationSends( 584 Graph* graph_out, FunctionLibraryDefinition* library); 585 586 const string group_attribute_; 587 const string outside_compilation_attribute_; 588 const Graph* graph_in_; 589 590 std::unordered_map<string, Subgraph> subgraphs_; 591 592 TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator); 593 }; 594 595 Node* Encapsulator::Subgraph::GetCallNodeForInputs() const { 596 return call_node_inputs_; 597 } 598 599 Node* Encapsulator::Subgraph::GetCallNodeForOutputs() const { 600 return call_node_outputs_; 601 } 602 603 int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const { 604 return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input())); 605 } 606 607 int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const { 608 return results_.at(NodeSlot(edge->src(), edge->src_output())); 609 } 610 611 Node* Encapsulator::Subgraph::GetRecvAtHostNode( 612 const string& outside_compilation_subgraph_name) const { 613 return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) 614 .recv_at_host; 615 } 616 617 int Encapsulator::Subgraph::GetRecvAtHostSlot( 618 const string& outside_compilation_subgraph_name, const Edge* edge) const { 619 return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) 620 .inputs.at(NodeSlot(edge->src(), edge->src_output())); 621 } 622 623 Node* Encapsulator::Subgraph::GetSendFromHostNode( 624 const string& outside_compilation_subgraph_name) const { 625 return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) 626 .send_from_host; 627 } 628 629 int Encapsulator::Subgraph::GetSendFromHostSlot( 630 const string& outside_compilation_subgraph_name, const Edge* edge) const { 631 return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) 632 .outputs_by_dst.at(NodeSlot(edge->dst(), edge->dst_input())); 633 } 634 635 Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { 636 if (!graph_) { 637 graph_.reset(new Graph(graph_in->op_registry())); 638 graph_->set_versions(graph_in->versions()); 639 } 640 641 if (device_.empty()) { 642 device_ = node->assigned_device_name().empty() 643 ? node->requested_device() 644 : node->assigned_device_name(); 645 } 646 647 return graph_->CopyNode(node); 648 } 649 650 Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); } 651 652 Status Encapsulator::Subgraph::RecordArg( 653 const Edge* edge, const std::unordered_map<const Node*, Node*>& node_images, 654 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) { 655 Node* src_node = edge->src(); 656 int src_slot = edge->src_output(); 657 std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter; 658 bool inserted; 659 std::tie(iter, inserted) = 660 args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size()); 661 int arg_index = iter->second; 662 if (inserted) { 663 NodeDef arg_def; 664 NodeDefBuilder builder( 665 strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); 666 DataType dtype = edge->dst()->input_type(edge->dst_input()); 667 builder.Attr("T", dtype); 668 builder.Attr("index", arg_index); 669 Status s = builder.Finalize(&arg_def); 670 if (!s.ok()) return s; 671 672 Node* arg = graph_->AddNode(arg_def, &s); 673 if (!s.ok()) return s; 674 675 src_arg_pairs->push_back({src_node, arg}); 676 args_.push_back(arg); 677 } 678 Node* dst_node = edge->dst(); 679 Node* dst_image = node_images.at(dst_node); 680 int dst_slot = edge->dst_input(); 681 args_by_dst_[NodeSlot(dst_node, dst_slot)] = arg_index; 682 graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot); 683 return Status::OK(); 684 } 685 686 Status Encapsulator::Subgraph::RecordResult( 687 const Edge* edge, 688 const std::unordered_map<const Node*, Node*>& node_images) { 689 Node* src_node = edge->src(); 690 Node* src_image = node_images.at(src_node); 691 int src_slot = edge->src_output(); 692 std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter; 693 bool inserted; 694 std::tie(iter, inserted) = 695 results_.emplace(NodeSlot(src_node, src_slot), results_.size()); 696 int ret_index = iter->second; 697 if (inserted) { 698 NodeDef ret_def; 699 NodeDefBuilder builder( 700 strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); 701 DataType dtype = src_node->output_type(src_slot); 702 builder.Attr("T", dtype); 703 builder.Attr("index", ret_index); 704 builder.Input(src_image->name(), src_slot, dtype); 705 Status s = builder.Finalize(&ret_def); 706 if (!s.ok()) return s; 707 Node* ret = graph_->AddNode(ret_def, &s); 708 if (!s.ok()) return s; 709 710 graph_->AddEdge(src_image, src_slot, ret, 0); 711 } 712 return Status::OK(); 713 } 714 715 void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( 716 const string& outside_compilation_id, const Edge* edge) { 717 auto iter = outside_compilation_subgraphs_ 718 .emplace(outside_compilation_id, OutsideCompilationSubgraph()) 719 .first; 720 OutsideCompilationSubgraph& outside_subgraph = iter->second; 721 if (edge->IsControlEdge()) { 722 outside_subgraph.control_inputs.insert(edge->src()); 723 } else { 724 int input_index = outside_subgraph.inputs.size(); 725 outside_subgraph.inputs.emplace(NodeSlot(edge->src(), edge->src_output()), 726 input_index); 727 } 728 } 729 730 void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( 731 const string& outside_compilation_id, const Edge* edge) { 732 auto subgraph_iter = 733 outside_compilation_subgraphs_ 734 .emplace(outside_compilation_id, OutsideCompilationSubgraph()) 735 .first; 736 OutsideCompilationSubgraph& outside_subgraph = subgraph_iter->second; 737 if (edge->IsControlEdge()) { 738 outside_subgraph.control_outputs.insert(edge->dst()); 739 } else { 740 DataType dtype = edge->dst()->input_type(edge->dst_input()); 741 auto output_iter = 742 outside_subgraph.outputs_by_src 743 .emplace(NodeSlot(edge->src(), edge->src_output(), dtype), 744 outside_subgraph.outputs_by_src.size()) 745 .first; 746 int output_index = output_iter->second; 747 outside_subgraph.outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = 748 output_index; 749 } 750 } 751 752 Status Encapsulator::Subgraph::AddHostComputes( 753 const string& subgraph_name, 754 const std::unordered_map<const Node*, Node*>& node_images) { 755 for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) { 756 const string& oc_subgraph_name = oc_subgraph_iter.first; 757 OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second; 758 if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() || 759 !oc_subgraph.outputs_by_src.empty() || 760 !oc_subgraph.control_outputs.empty()) { 761 // Build a _HostCompute node. 762 std::vector<NodeDefBuilder::NodeOut> inputs(oc_subgraph.inputs.size()); 763 std::vector<DataType> input_dtypes(oc_subgraph.inputs.size(), DT_INVALID); 764 std::vector<DataType> output_dtypes(oc_subgraph.outputs_by_src.size(), 765 DT_INVALID); 766 767 for (const auto& input_src : oc_subgraph.inputs) { 768 const Node* src_node = input_src.first.node; 769 Node* src_image = node_images.at(src_node); 770 int src_slot = input_src.first.slot; 771 int input_index = input_src.second; 772 773 DataType dtype = src_node->output_type(src_slot); 774 inputs[input_index].Reset(src_image->name(), src_slot, dtype); 775 input_dtypes[input_index] = dtype; 776 } 777 778 for (const auto& output : oc_subgraph.outputs_by_src) { 779 DataType dtype = output.first.dtype; 780 int output_index = output.second; 781 output_dtypes[output_index] = dtype; 782 } 783 784 NodeDef host_compute_def; 785 NodeDefBuilder builder(strings::StrCat("outside_compilation_", 786 oc_subgraph_name, "_host_compute"), 787 kHostComputeOp); 788 builder.Input(inputs); 789 builder.Attr("Tinputs", input_dtypes); 790 builder.Attr("Toutputs", output_dtypes); 791 builder.Attr("key", 792 strings::StrCat("host_compute_channel_", subgraph_name, "_", 793 oc_subgraph_name)); 794 Status s = builder.Finalize(&host_compute_def); 795 if (!s.ok()) return s; 796 797 Node* host_compute = graph_->AddNode(host_compute_def, &s); 798 if (!s.ok()) return s; 799 oc_subgraph.host_compute_name = host_compute->name(); 800 801 // Connect the _HostCompute node to its producers in the subgraph. 802 for (auto& input_src : oc_subgraph.inputs) { 803 const Node* src_node = input_src.first.node; 804 Node* src_image = node_images.at(src_node); 805 int src_slot = input_src.first.slot; 806 int input_index = input_src.second; 807 graph_->AddEdge(src_image, src_slot, host_compute, input_index); 808 } 809 810 // Connect the _HostCompute node to its control edge producers in the 811 // subgraph. 812 for (const auto& src_node : oc_subgraph.control_inputs) { 813 Node* src_image = node_images.at(src_node); 814 graph_->AddControlEdge(src_image, host_compute); 815 } 816 817 // Connect the consumers in the subgraph to the _HostCompute node. 818 for (const auto& output : oc_subgraph.outputs_by_dst) { 819 const Node* dst_node = output.first.node; 820 Node* dst_image = node_images.at(dst_node); 821 int dst_slot = output.first.slot; 822 int output_index = output.second; 823 824 graph_->AddEdge(host_compute, output_index, dst_image, dst_slot); 825 } 826 827 // Connect the control edge consumers in the subgraph to the _HostCompute 828 // node. 829 for (const auto& dst_node : oc_subgraph.control_outputs) { 830 Node* dst_image = node_images.at(dst_node); 831 graph_->AddControlEdge(host_compute, dst_image); 832 } 833 } 834 } 835 836 return Status::OK(); 837 } 838 839 Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, 840 Graph* graph_out) { 841 if (sequencer_ == nullptr) { 842 NodeDef seq_def; 843 NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"), 844 "NoOp"); 845 Status s = builder.Finalize(&seq_def); 846 if (!s.ok()) return s; 847 848 sequencer_ = graph_out->AddNode(seq_def, &s); 849 if (!s.ok()) return s; 850 sequencer_->set_assigned_device_name(device_); 851 } 852 return Status::OK(); 853 } 854 855 void Encapsulator::Subgraph::ConnectSequencerToOutputs(Graph* graph_out) { 856 if (sequencer_ != nullptr) { 857 std::unordered_set<Node*> output_dependencies; 858 for (Node* node : call_node_outputs_->out_nodes()) { 859 output_dependencies.insert(node); 860 } 861 for (Node* node : output_dependencies) { 862 graph_out->AddControlEdge(sequencer_, node); 863 } 864 } 865 } 866 867 Status Encapsulator::Subgraph::BuildFunctionDef( 868 const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn, 869 bool reuse_existing_functions, FunctionLibraryDefinition* library) { 870 // name_in is copied here because name may be modified below if 871 // rewrite_subgraph_fn is true. 872 string name = name_in; 873 call_node_def_.set_op(name); 874 call_node_def_.set_name(name); 875 call_node_def_.set_device(device_); 876 877 if (rewrite_subgraph_fn) { 878 // Initialize the input and output permutations to the identity. 879 std::vector<int> input_permutation(args_by_src_.size()); 880 std::iota(input_permutation.begin(), input_permutation.end(), 0); 881 std::vector<int> output_permutation(results_.size()); 882 std::iota(output_permutation.begin(), output_permutation.end(), 0); 883 884 TF_RETURN_IF_ERROR(rewrite_subgraph_fn( 885 &graph_, &input_permutation, &output_permutation, &call_node_def_)); 886 887 // Apply the input/output permutations to the 'args_by_...' and 'results_' 888 // mappings, so when we build edges in BuildOutputGraph() we 889 // connect them to the right input/output positions. 890 if (input_permutation.size() != args_by_src_.size()) { 891 return errors::InvalidArgument("Input permutation has incorrect size."); 892 } 893 if (output_permutation.size() != results_.size()) { 894 return errors::InvalidArgument("Output permutation has incorrect size."); 895 } 896 for (auto& arg : args_by_src_) { 897 arg.second = input_permutation[arg.second]; 898 } 899 for (auto& arg : args_by_dst_) { 900 arg.second = input_permutation[arg.second]; 901 } 902 for (auto& result : results_) { 903 result.second = output_permutation[result.second]; 904 } 905 906 name = call_node_def_.op(); 907 } 908 909 FunctionDef fdef; 910 TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); 911 912 if (VLOG_IS_ON(1)) { 913 VLOG(2) << "Build function def " << name; 914 dump_graph::DumpGraphToFile( 915 strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library); 916 dump_graph::DumpFunctionDefToFile( 917 strings::StrCat("encapsulate_fdef_", name), fdef); 918 } 919 920 if (!reuse_existing_functions || library->Find(name) == nullptr) { 921 TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); 922 } 923 return Status::OK(); 924 } 925 926 Status Encapsulator::Subgraph::AddShapeInferenceInfo( 927 const string& outside_compilation_subgraph_name, 928 const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph) { 929 OutsideCompilationSubgraph& oc_subgraph = 930 outside_compilation_subgraphs_.at(outside_compilation_subgraph_name); 931 932 Node* host_compute = nullptr; 933 for (Node* n : graph_->nodes()) { 934 if (n->name() == oc_subgraph.host_compute_name) { 935 host_compute = n; 936 break; 937 } 938 } 939 if (host_compute == nullptr) { 940 return errors::InvalidArgument( 941 "After rewriting subgraph ", outside_compilation_subgraph_name, 942 " there is no HostCompute Op for outside compilation subgraph ", 943 oc_subgraph.host_compute_name); 944 } 945 946 if (inference_graph == nullptr) { 947 host_compute->AddAttr("shape_inference_graph", ""); 948 host_compute->AddAttr("shapes", shapes); 949 } else { 950 string serialized_graph; 951 if (!inference_graph->SerializeToString(&serialized_graph)) { 952 return errors::Internal( 953 "Failed to serialize graph for outside compilation subgraph ", 954 oc_subgraph.host_compute_name); 955 } 956 host_compute->AddAttr("shape_inference_graph", serialized_graph); 957 host_compute->AddAttr("shapes", std::vector<TensorShapeProto>()); 958 } 959 return Status::OK(); 960 } 961 962 Status Encapsulator::Subgraph::ReplaceFunctionDef( 963 FunctionLibraryDefinition* library) { 964 const string& name = call_node_def_.name(); 965 966 FunctionDef fdef; 967 TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); 968 969 if (VLOG_IS_ON(1)) { 970 VLOG(2) << "Replace function def " << name; 971 dump_graph::DumpGraphToFile( 972 strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, 973 library); 974 dump_graph::DumpFunctionDefToFile( 975 strings::StrCat("replace_encapsulate_fdef_", name), fdef); 976 } 977 978 TF_RETURN_IF_ERROR(library->RemoveFunction(name)); 979 TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); 980 return Status::OK(); 981 } 982 983 Status Encapsulator::Subgraph::BuildParallelCheckOp( 984 const std::unordered_map<const Node*, Node*>& node_images, 985 Graph* graph_out) { 986 // Build an index mapping output positions to node/slot pairs in the 987 // original graph. 988 std::vector<NodeSlot> results_by_num(results_.size()); 989 for (const auto& entry : results_) { 990 results_by_num[entry.second] = entry.first; 991 } 992 993 // Build a parallel check NodeDef. 994 int num_results = results_by_num.size(); 995 std::vector<DataType> result_dtypes(num_results); 996 std::vector<NodeDefBuilder::NodeOut> expected_outputs(num_results); 997 std::vector<NodeDefBuilder::NodeOut> actual_outputs(num_results); 998 for (int i = 0; i < num_results; ++i) { 999 const NodeSlot& node_slot = results_by_num[i]; 1000 result_dtypes[i] = node_slot.node->output_type(node_slot.slot); 1001 expected_outputs[i] = 1002 NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(), 1003 node_slot.slot, result_dtypes[i]); 1004 actual_outputs[i] = 1005 NodeDefBuilder::NodeOut(call_node_def_.name(), i, result_dtypes[i]); 1006 } 1007 // Assign the parallel check op to a CPU on the same task as the cluster it is 1008 // checking. 1009 string device, dummy; 1010 if (!DeviceNameUtils::SplitDeviceName( 1011 call_node_inputs_->assigned_device_name(), &device, &dummy)) { 1012 return errors::InvalidArgument("Could not parse device name"); 1013 } 1014 strings::StrAppend(&device, "/cpu:0"); 1015 1016 NodeDef check_def; 1017 TF_RETURN_IF_ERROR( 1018 NodeDefBuilder(graph_out->NewName(strings::StrCat(call_node_def_.name(), 1019 "_parallel_check")), 1020 "ParallelCheck") 1021 .Device(device) 1022 .Attr("T", result_dtypes) 1023 .Input(expected_outputs) 1024 .Input(actual_outputs) 1025 .Finalize(&check_def)); 1026 1027 Status s; 1028 Node* check_op = graph_out->AddNode(check_def, &s); 1029 if (!s.ok()) return s; 1030 check_op->set_assigned_device_name(device); 1031 1032 // TODO(phawkins): it seems redundant to call AddEdge as well as 1033 // pass Inputs to the NodeDefBuilder, but I have been unable to find a 1034 // way to avoid it. 1035 for (int i = 0; i < num_results; ++i) { 1036 const NodeSlot& node_slot = results_by_num[i]; 1037 graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op, 1038 i); 1039 graph_out->AddEdge(call_node_inputs_, i, check_op, num_results + i); 1040 } 1041 1042 call_node_outputs_ = check_op; 1043 return Status::OK(); 1044 } 1045 1046 Status Encapsulator::Subgraph::AddFunctionCallNode( 1047 const std::unordered_map<const Node*, Node*>& node_images, 1048 bool parallel_checking, Graph* graph_out) { 1049 Status s; 1050 call_node_inputs_ = graph_out->AddNode(call_node_def_, &s); 1051 if (!s.ok()) return s; 1052 1053 // Copy the assigned device and the key_annotation over. 1054 call_node_inputs_->set_assigned_device_name(device_); 1055 call_node_outputs_ = call_node_inputs_; 1056 1057 if (parallel_checking) { 1058 TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, graph_out)); 1059 } 1060 return Status::OK(); 1061 } 1062 1063 Status Encapsulator::Subgraph::AddRecvAtHostNode( 1064 const string& subgraph_name, const string& oc_subgraph_name, 1065 OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { 1066 std::vector<DataType> dtypes(oc_subgraph->inputs.size(), DT_INVALID); 1067 1068 for (const auto& input : oc_subgraph->inputs) { 1069 const Node* src_node = input.first.node; 1070 int src_slot = input.first.slot; 1071 int input_index = input.second; 1072 1073 DataType dtype = src_node->output_type(src_slot); 1074 dtypes[input_index] = dtype; 1075 } 1076 1077 NodeDef recv_def; 1078 NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, 1079 "_", oc_subgraph_name, "_recv"), 1080 kRecvAtHostOp); 1081 builder.Attr("Toutputs", dtypes); 1082 builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, 1083 "_", oc_subgraph_name)); 1084 Status s = builder.Finalize(&recv_def); 1085 if (!s.ok()) return s; 1086 1087 oc_subgraph->recv_at_host = graph_out->AddNode(recv_def, &s); 1088 if (!s.ok()) return s; 1089 oc_subgraph->recv_at_host->set_assigned_device_name(device_); 1090 1091 // Add a control dependency forcing the RecvAtHost to run before the subgraph 1092 // completes. This has no effect on execution order but prevents the 1093 // RecvAtHost being pruned. 1094 TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); 1095 graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_); 1096 1097 return Status::OK(); 1098 } 1099 1100 Status Encapsulator::Subgraph::AddSendFromHostNode( 1101 const std::unordered_map<const Node*, Node*>& node_images, 1102 const string& subgraph_name, const string& oc_subgraph_name, 1103 OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { 1104 std::vector<DataType> dtypes(oc_subgraph->outputs_by_src.size(), DT_INVALID); 1105 std::vector<NodeDefBuilder::NodeOut> inputs( 1106 oc_subgraph->outputs_by_src.size()); 1107 1108 for (const auto& output : oc_subgraph->outputs_by_src) { 1109 const Node* src_node = output.first.node; 1110 Node* src_image = node_images.at(src_node); 1111 int src_slot = output.first.slot; 1112 int output_index = output.second; 1113 1114 DataType dtype = src_node->output_type(src_slot); 1115 dtypes[output_index] = dtype; 1116 inputs[output_index].Reset(src_image->name(), src_slot, dtype); 1117 } 1118 1119 NodeDef send_def; 1120 NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, 1121 "_", oc_subgraph_name, "_send"), 1122 kSendFromHostOp); 1123 builder.Attr("Tinputs", dtypes); 1124 builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, 1125 "_", oc_subgraph_name)); 1126 builder.Input(inputs); 1127 Status s = builder.Finalize(&send_def); 1128 if (!s.ok()) return s; 1129 1130 oc_subgraph->send_from_host = graph_out->AddNode(send_def, &s); 1131 if (!s.ok()) return s; 1132 oc_subgraph->send_from_host->set_assigned_device_name(device_); 1133 1134 // Add a control dependency forcing the SendFromHost to run before the 1135 // subgraph completes. This has no effect on execution order but prevents the 1136 // RecvAtHost being pruned. 1137 TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); 1138 graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_); 1139 1140 return Status::OK(); 1141 } 1142 1143 Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes( 1144 const string& subgraph_name, 1145 const std::unordered_map<const Node*, Node*>& node_images, 1146 Graph* graph_out) { 1147 for (auto& outside_compilation_subgraph_entry : 1148 outside_compilation_subgraphs_) { 1149 const string& oc_name = outside_compilation_subgraph_entry.first; 1150 OutsideCompilationSubgraph& oc_subgraph = 1151 outside_compilation_subgraph_entry.second; 1152 1153 if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) { 1154 TF_RETURN_IF_ERROR( 1155 AddRecvAtHostNode(subgraph_name, oc_name, &oc_subgraph, graph_out)); 1156 } 1157 1158 if (!oc_subgraph.outputs_by_src.empty() || 1159 !oc_subgraph.control_outputs.empty()) { 1160 TF_RETURN_IF_ERROR(AddSendFromHostNode(node_images, subgraph_name, 1161 oc_name, &oc_subgraph, graph_out)); 1162 } 1163 } 1164 return Status::OK(); 1165 } 1166 1167 void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames( 1168 std::vector<string>* names) const { 1169 for (auto& entry : outside_compilation_subgraphs_) { 1170 names->push_back(entry.first); 1171 } 1172 } 1173 1174 Status Encapsulator::GetFunctionNameAttr( 1175 Node const* node, string* attr, string* outside_compilation_attr) const { 1176 Status s = GetNodeAttr(node->attrs(), group_attribute_, attr); 1177 if (s.code() == error::Code::NOT_FOUND) { 1178 // Return empty attr if there's no group_attribute. 1179 attr->clear(); 1180 } else { 1181 TF_RETURN_IF_ERROR(s); 1182 } 1183 bool has_group_attr = s.ok(); 1184 s = GetNodeAttr(node->attrs(), outside_compilation_attribute_, 1185 outside_compilation_attr); 1186 if (s.code() == error::Code::NOT_FOUND) { 1187 // Return empty attr if there's no outside_compilation attribute. 1188 outside_compilation_attr->clear(); 1189 } else { 1190 TF_RETURN_IF_ERROR(s); 1191 if (!has_group_attr) { 1192 return errors::InvalidArgument( 1193 "Node ", node->name(), " has ", outside_compilation_attribute_, 1194 " attribute but no ", group_attribute_, " attribute."); 1195 } 1196 } 1197 return Status::OK(); 1198 } 1199 1200 bool IsInSubgraph(const string& func_id, const string& outside_compilation_id) { 1201 return !func_id.empty() && outside_compilation_id.empty(); 1202 } 1203 1204 Status Encapsulator::CopySubgraphNodes( 1205 std::unordered_map<const Node*, Node*>* node_images) { 1206 for (Node* node : graph_in_->op_nodes()) { 1207 string func_id; 1208 string outside_compilation_id; 1209 TF_RETURN_IF_ERROR( 1210 GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); 1211 if (!IsInSubgraph(func_id, outside_compilation_id)) continue; 1212 1213 Subgraph& subgraph = subgraphs_[func_id]; 1214 Node* image = subgraph.MakeNodeImage(graph_in_, node); 1215 image->ClearAttr(group_attribute_); 1216 (*node_images)[node] = image; 1217 } 1218 return Status::OK(); 1219 } 1220 1221 Status Encapsulator::CopySubgraphEdges( 1222 const std::unordered_map<const Node*, Node*>& node_images, 1223 std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) { 1224 for (const Edge* edge : graph_in_->edges()) { 1225 string src_func_id; 1226 string src_outside_compilation_id; 1227 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id, 1228 &src_outside_compilation_id)); 1229 string dst_func_id; 1230 string dst_outside_compilation_id; 1231 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id, 1232 &dst_outside_compilation_id)); 1233 Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr); 1234 Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr); 1235 1236 // Copy edges that are local to a subgraph. 1237 if (IsInSubgraph(src_func_id, src_outside_compilation_id) && 1238 IsInSubgraph(dst_func_id, dst_outside_compilation_id) && 1239 src_func_id == dst_func_id) { 1240 Graph* g = subgraphs_[src_func_id].GetGraph(); 1241 if (edge->IsControlEdge()) { 1242 g->AddControlEdge(src_image, dst_image); 1243 } else { 1244 g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input()); 1245 } 1246 continue; 1247 } 1248 1249 // Record 'src' as an output of its subgraph, if applicable. 1250 if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { 1251 if (!edge->IsControlEdge()) { 1252 DataType dtype = edge->src()->output_type(edge->src_output()); 1253 if (IsRefType(dtype)) { 1254 return errors::InvalidArgument( 1255 "Ref Tensors (e.g., Variables) are not supported as results: " 1256 "tensor ", 1257 edge->src()->name(), ":", edge->src_output()); 1258 } 1259 } 1260 1261 Subgraph& src_subgraph = subgraphs_[src_func_id]; 1262 if (src_func_id == dst_func_id) { 1263 // src is in the subgraph and dst is outside_compilation in the same 1264 // subgraph. 1265 src_subgraph.RecordOutsideCompilationInputOrControl( 1266 dst_outside_compilation_id, edge); 1267 } else { 1268 // Ignore control edges leaving the subgraph. We will lift them onto the 1269 // enclosing call operators in BuildOutputGraph(). 1270 if (!edge->IsControlEdge()) { 1271 TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images)); 1272 } 1273 } 1274 } 1275 1276 // Record 'dst' as an input of its subgraph, if applicable. 1277 if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { 1278 // Look at the type of the destination not the source, since Ref output 1279 // Tensors can be automatically cast to non-Ref Tensors at the 1280 // destination. 1281 if (!edge->IsControlEdge()) { 1282 DataType dtype = edge->dst()->input_type(edge->dst_input()); 1283 if (IsRefType(dtype)) { 1284 return errors::InvalidArgument( 1285 "Ref Tensors (e.g., Variables) are not supported as args: " 1286 "tensor ", 1287 edge->src()->name(), ":", edge->src_output()); 1288 } 1289 } 1290 1291 Subgraph& dst_subgraph = subgraphs_[dst_func_id]; 1292 if (src_func_id == dst_func_id) { 1293 // dst is in the subgraph and src is outside_compilation in the same 1294 // subgraph. 1295 dst_subgraph.RecordOutsideCompilationOutputOrControl( 1296 src_outside_compilation_id, edge); 1297 } else { 1298 // Ignore control edges entering the subgraph. We will lift them onto 1299 // the enclosing call operators in BuildOutputGraph(). 1300 if (!edge->IsControlEdge()) { 1301 TF_RETURN_IF_ERROR( 1302 dst_subgraph.RecordArg(edge, node_images, src_arg_pairs)); 1303 } 1304 } 1305 } 1306 } 1307 return Status::OK(); 1308 } 1309 1310 Status Encapsulator::SplitIntoSubgraphs() { 1311 Status s; 1312 1313 // Map from input graph nodes to subgraph nodes. 1314 std::unordered_map<const Node*, Node*> node_images; 1315 1316 // Each entry of src_arg_pairs is a pair whose first element is a node in the 1317 // original graph that has an output edge in the subgraph, and whose second 1318 // element is the arg node in the subgraph that it sends to. The vector will 1319 // be filled in below in AddArgs. 1320 std::vector<std::pair<const Node*, Node*>> src_arg_pairs; 1321 1322 TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images)); 1323 TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs)); 1324 1325 // For each subgraph, add the nodes that deal with inputs and outputs its 1326 // nested outside_compilation subgraphs. These could not be added earlier 1327 // during CopySubgraphEdges since we need to discover all the types of the 1328 // inputs and outputs for an outside_compilation subgraph before creating a 1329 // single input and output node for it. 1330 for (auto& entry : subgraphs_) { 1331 Subgraph& subgraph = entry.second; 1332 TF_RETURN_IF_ERROR(subgraph.AddHostComputes(entry.first, node_images)); 1333 } 1334 1335 MarkGuaranteedConstants(*graph_in_, src_arg_pairs); 1336 1337 for (auto& entry : subgraphs_) { 1338 Subgraph& subgraph = entry.second; 1339 FixupSourceAndSinkEdges(subgraph.GetGraph()); 1340 } 1341 1342 return s; 1343 } 1344 1345 Status Encapsulator::BuildFunctionDefs( 1346 const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, 1347 FunctionLibraryDefinition* library) { 1348 for (auto& subgraph_entry : subgraphs_) { 1349 string name = subgraph_entry.first; 1350 Subgraph& subgraph = subgraph_entry.second; 1351 TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef( 1352 name, rewrite_subgraph_fn, reuse_existing_functions, library)); 1353 } 1354 return Status::OK(); 1355 } 1356 1357 Status Encapsulator::CopyNodesToOutputGraph( 1358 bool parallel_checking, Graph* graph_out, 1359 std::unordered_map<const Node*, Node*>* node_images) { 1360 for (Node* node : graph_in_->op_nodes()) { 1361 string func_id; 1362 string outside_compilation_id; 1363 TF_RETURN_IF_ERROR( 1364 GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); 1365 1366 // Don't copy nodes that going to be encapsulated, unless parallel checking 1367 // is enabled. 1368 if (IsInSubgraph(func_id, outside_compilation_id) && !parallel_checking) 1369 continue; 1370 1371 Node* image = graph_out->CopyNode(node); 1372 if (!outside_compilation_id.empty()) { 1373 if (parallel_checking) { 1374 return errors::InvalidArgument( 1375 "Parallel checking is not supported when outside_compilation " 1376 "clusters are present."); 1377 } 1378 image->ClearAttr(group_attribute_); 1379 image->ClearAttr(outside_compilation_attribute_); 1380 } 1381 (*node_images)[node] = image; 1382 } 1383 (*node_images)[graph_in_->source_node()] = graph_out->source_node(); 1384 (*node_images)[graph_in_->sink_node()] = graph_out->sink_node(); 1385 return Status::OK(); 1386 } 1387 1388 Status Encapsulator::AddFunctionCallNodes( 1389 const std::unordered_map<const Node*, Node*>& node_images, 1390 bool parallel_checking, Graph* graph_out) { 1391 for (auto& subgraph_entry : subgraphs_) { 1392 TF_RETURN_IF_ERROR(subgraph_entry.second.AddFunctionCallNode( 1393 node_images, parallel_checking, graph_out)); 1394 } 1395 return Status::OK(); 1396 } 1397 1398 Status Encapsulator::AddOutsideCompilationHostIONodes( 1399 const std::unordered_map<const Node*, Node*>& node_images, 1400 Graph* graph_out) { 1401 for (auto& subgraph_entry : subgraphs_) { 1402 const string& subgraph_name = subgraph_entry.first; 1403 Subgraph& subgraph = subgraph_entry.second; 1404 TF_RETURN_IF_ERROR(subgraph.AddOutsideCompilationHostIONodes( 1405 subgraph_name, node_images, graph_out)); 1406 } 1407 return Status::OK(); 1408 } 1409 1410 Status Encapsulator::FindOutputImageOfEdgeSrc( 1411 const string& src_func_id, const string& src_outside_compilation_id, 1412 const string& dst_func_id, const string& dst_outside_compilation_id, 1413 const std::unordered_map<const Node*, Node*>& node_images, 1414 const Node* original_src_node, Node** src_image) { 1415 if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { 1416 if (dst_func_id == src_func_id) { 1417 // The edge is from a subgraph to an outside_compilation cluster in the 1418 // same subgraph so use the appropriate _RecvAtHost node in the output 1419 // graph. 1420 TF_RET_CHECK(!dst_outside_compilation_id.empty()); 1421 *src_image = subgraphs_.at(src_func_id) 1422 .GetRecvAtHostNode(dst_outside_compilation_id); 1423 } else { 1424 // The edge is from a subgraph to a regular node in the output graph so 1425 // use the subgraph's call node output. 1426 *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs(); 1427 } 1428 } else { 1429 // The source of the edge is in the output graph so use the node image in 1430 // the output graph. 1431 *src_image = node_images.at(original_src_node); 1432 } 1433 return Status::OK(); 1434 } 1435 1436 int Encapsulator::FindOutputSlotOfEdgeSrc( 1437 const string& src_func_id, const string& src_outside_compilation_id, 1438 const string& dst_func_id, const string& dst_outside_compilation_id, 1439 const Edge* edge) { 1440 if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { 1441 const Subgraph& src_subgraph = subgraphs_.at(src_func_id); 1442 if (src_func_id == dst_func_id) { 1443 // 'src' is in a subgraph and 'dst' is outside_compilation in the same 1444 // subgraph. Use the corresponding _RecvAtHost output instead. 1445 return src_subgraph.GetRecvAtHostSlot(dst_outside_compilation_id, edge); 1446 } else { 1447 // 'src' is in a subgraph and 'dst' is a regular node in the output 1448 // graph. Use the corresponding call output instead. 1449 return src_subgraph.GetResultIndexForEdge(edge); 1450 } 1451 } else { 1452 // The source of the edge is in the output graph so use the regular edge 1453 // slot. 1454 return edge->src_output(); 1455 } 1456 } 1457 1458 Status Encapsulator::FindOutputImageOfEdgeDst( 1459 const string& src_func_id, const string& src_outside_compilation_id, 1460 const string& dst_func_id, const string& dst_outside_compilation_id, 1461 const std::unordered_map<const Node*, Node*>& node_images, 1462 const Node* original_dst_node, Node** dst_image) { 1463 if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { 1464 if (src_func_id == dst_func_id) { 1465 // The edge is to a subgraph from an outside_compilation cluster in the 1466 // same subgraph so use the appropriate _SendFromHost node in the output 1467 // graph. 1468 TF_RET_CHECK(!src_outside_compilation_id.empty()); 1469 *dst_image = subgraphs_.at(dst_func_id) 1470 .GetSendFromHostNode(src_outside_compilation_id); 1471 } else { 1472 // The edge is to a subgraph from a regular node in the output graph so 1473 // use the subgraph's call node input. 1474 *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs(); 1475 } 1476 } else { 1477 // The destination of the edge is in the output graph so use the node image 1478 // in the output graph. 1479 *dst_image = node_images.at(original_dst_node); 1480 } 1481 return Status::OK(); 1482 } 1483 1484 int Encapsulator::FindOutputSlotOfEdgeDst( 1485 const string& src_func_id, const string& src_outside_compilation_id, 1486 const string& dst_func_id, const string& dst_outside_compilation_id, 1487 const Edge* edge) { 1488 if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { 1489 const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id); 1490 if (dst_func_id == src_func_id) { 1491 // 'dst' is in a subgraph and 'src' is outside_compilation in the same 1492 // subgraph. Use the corresponding _SendFromHost input instead. 1493 return dst_subgraph.GetSendFromHostSlot(src_outside_compilation_id, edge); 1494 } else { 1495 // 'dst' is in a subgraph and 'src' is a regular node in the output 1496 // graph. Use the corresponding call input instead. 1497 return dst_subgraph.GetArgIndexForEdge(edge); 1498 } 1499 } else { 1500 // The destination of the edge is in the output graph so use the regular 1501 // edge slot. 1502 return edge->dst_input(); 1503 } 1504 } 1505 1506 Status Encapsulator::CopyEdgeToOutputGraph( 1507 const Edge* edge, const string& src_func_id, 1508 const string& src_outside_compilation_id, const string& dst_func_id, 1509 const string& dst_outside_compilation_id, 1510 const std::unordered_map<const Node*, Node*>& node_images, 1511 bool parallel_checking, Graph* graph_out, 1512 std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>* 1513 edges_added) { 1514 Node* src_image; 1515 TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc( 1516 src_func_id, src_outside_compilation_id, dst_func_id, 1517 dst_outside_compilation_id, node_images, edge->src(), &src_image)); 1518 Node* dst_image; 1519 TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst( 1520 src_func_id, src_outside_compilation_id, dst_func_id, 1521 dst_outside_compilation_id, node_images, edge->dst(), &dst_image)); 1522 1523 // If this is a control edge then copy it and return. Lift control edges onto 1524 // the enclosing call operator. 1525 if (edge->IsControlEdge()) { 1526 // Add the control edge, if we have not already added it, using the images 1527 // determined above (potentially call operators or RecvAtHost/SendFromHost). 1528 if (edges_added->emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1)) 1529 .second) { 1530 graph_out->AddControlEdge(src_image, dst_image); 1531 } 1532 1533 // If parallel checking is enabled, also add a control edge to the 1534 // corresponding parallel check op. 1535 if (parallel_checking) { 1536 graph_out->AddControlEdge(src_image, node_images.at(edge->dst())); 1537 } 1538 return Status::OK(); 1539 } 1540 1541 int src_output = 1542 FindOutputSlotOfEdgeSrc(src_func_id, src_outside_compilation_id, 1543 dst_func_id, dst_outside_compilation_id, edge); 1544 1545 int dst_input = 1546 FindOutputSlotOfEdgeDst(src_func_id, src_outside_compilation_id, 1547 dst_func_id, dst_outside_compilation_id, edge); 1548 1549 if (IsInSubgraph(dst_func_id, dst_outside_compilation_id) && 1550 parallel_checking) { 1551 // If we are parallel checking, also feed the tensor as an input to the 1552 // corresponding parallel check subgraph. 1553 graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()), 1554 edge->dst_input()); 1555 } 1556 1557 // Add the edge, if we have not already added it. 1558 if (edges_added 1559 ->emplace(NodeSlot(src_image, src_output), 1560 NodeSlot(dst_image, dst_input)) 1561 .second) { 1562 graph_out->AddEdge(src_image, src_output, dst_image, dst_input); 1563 } 1564 return Status::OK(); 1565 } 1566 1567 Status Encapsulator::AddEdgesToOutputGraph( 1568 const std::unordered_map<const Node*, Node*>& node_images, 1569 bool parallel_checking, Graph* graph_out) { 1570 // Set of edges already added to the output graph, represented as (src, dst) 1571 // pairs. We use the set to deduplicate edges; multiple edges in the input 1572 // graph may map to one edge in the output graph. 1573 std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher> 1574 edges_added; 1575 1576 for (const Edge* edge : graph_in_->edges()) { 1577 string src_func_id; 1578 string src_outside_compilation_id; 1579 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id, 1580 &src_outside_compilation_id)); 1581 string dst_func_id; 1582 string dst_outside_compilation_id; 1583 TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id, 1584 &dst_outside_compilation_id)); 1585 1586 // Ignore edges that are strictly contained within one subgraph, unless 1587 // we are constructing parallel check graphs. 1588 if (IsInSubgraph(src_func_id, src_outside_compilation_id) && 1589 IsInSubgraph(dst_func_id, dst_outside_compilation_id) && 1590 src_func_id == dst_func_id) { 1591 if (parallel_checking) { 1592 Node* src_image = node_images.at(edge->src()); 1593 Node* dst_image = node_images.at(edge->dst()); 1594 if (edge->IsControlEdge()) { 1595 graph_out->AddControlEdge(src_image, dst_image); 1596 } else { 1597 graph_out->AddEdge(src_image, edge->src_output(), dst_image, 1598 edge->dst_input()); 1599 } 1600 } 1601 continue; 1602 } 1603 1604 // We have an edge that crosses a cluster boundary or is entirely within the 1605 // unclustered graph. 1606 TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph( 1607 edge, src_func_id, src_outside_compilation_id, dst_func_id, 1608 dst_outside_compilation_id, node_images, parallel_checking, graph_out, 1609 &edges_added)); 1610 } 1611 1612 for (auto& subgraph_entry : subgraphs_) { 1613 Subgraph& subgraph = subgraph_entry.second; 1614 subgraph.ConnectSequencerToOutputs(graph_out); 1615 } 1616 1617 return Status::OK(); 1618 } 1619 1620 namespace { 1621 1622 // Adds a dummy Const node to graph_out. The "constant" has the type of 1623 // data_type and the shape indicated in 'shape'. The dummy node is not a valid 1624 // Const node because it does not have any value defined, but this doesn't 1625 // matter because it will only be used subsequently for shape inference. (It 1626 // would be possible to add a switch statement over data_type to create a value 1627 // for the constant, but that would entail maintaining the logic as new types 1628 // are added, and is not necessary.) 1629 Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape, 1630 Graph* graph_out) { 1631 TensorProto dummy_proto; 1632 dummy_proto.set_dtype(data_type); 1633 *dummy_proto.mutable_tensor_shape() = shape; 1634 // Don't set any value field in the proto, since it is only going to be used 1635 // for shape inference. 1636 1637 GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); 1638 NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const", 1639 options.op_registry()); 1640 node_builder.Attr("dtype", data_type).Attr("value", dummy_proto); 1641 return options.FinalizeBuilder(&node_builder); 1642 } 1643 1644 // Adds a copy of node_in to graph_out and adds the mapping to 1645 // copied_node_images. 1646 Status CopyShapeInferenceNodeToGraph( 1647 Node* node_in, const Node* send_node, 1648 const std::unordered_map<Node*, Node*>& dummy_node_images, 1649 FunctionLibraryDefinition* library, 1650 std::unordered_map<Node*, Node*>* copied_node_images, Graph* graph_out) { 1651 // Once all the ancestor nodes have been added to graph_out, add this node 1652 // and connect it to its ancestors. 1653 Node* node_out = graph_out->CopyNode(node_in); 1654 (*copied_node_images)[node_in] = node_out; 1655 // Don't bother to build the shape inference graph if there's a node with no 1656 // shape inference function, since it would just result in an error later at 1657 // compile time. 1658 const OpRegistrationData* op_reg_data; 1659 TF_RETURN_IF_ERROR(library->LookUp(node_in->type_string(), &op_reg_data)); 1660 if (op_reg_data->shape_inference_fn == nullptr) { 1661 return errors::InvalidArgument( 1662 "Shape inference is not possible for outside_compilation " 1663 "SendFromHost node ", 1664 send_node->name(), " because it depends on node ", node_in->name(), 1665 " which does not have a shape inference function registered."); 1666 } 1667 // Add all the edges to the newly copied node. 1668 for (const Edge* in_edge : node_in->in_edges()) { 1669 if (!in_edge->IsControlEdge()) { 1670 Node* src = in_edge->src(); 1671 const auto iter = dummy_node_images.find(src); 1672 if (iter == dummy_node_images.end()) { 1673 // The src is a copied node so use the original output port. 1674 graph_out->AddEdge((*copied_node_images)[in_edge->src()], 1675 in_edge->src_output(), node_out, 1676 in_edge->dst_input()); 1677 } else { 1678 // The src is a dummy node so use output port 0. 1679 graph_out->AddEdge(iter->second, 0, node_out, in_edge->dst_input()); 1680 } 1681 } 1682 } 1683 return Status::OK(); 1684 } 1685 1686 } // namespace 1687 1688 Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( 1689 const Graph& graph_in, const ShapeRefiner& shape_refiner, 1690 const std::unordered_set<string>& recv_at_host_nodes, Node* send_node, 1691 FunctionLibraryDefinition* library, 1692 std::vector<TensorShapeProto>* static_shape_out, 1693 std::unique_ptr<GraphDef>* graphdef_out) { 1694 // Maps from nodes in graph_in to nodes in graph_out. 1695 // 1696 // When an edge has fully defined shape the source node in graph_in is 1697 // replaced in graph_out by a dummy constant node. The mapping from nodes 1698 // in graph_in to dummy nodes is stored in dummy_node_images. 1699 // 1700 // When a node in graph_in has at least one ancestor that doesn't have fully 1701 // defined shape, it is copied into graph_out. The mapping from nodes in 1702 // graph_in to copied nodes is stored in copied_node_images. 1703 // 1704 // The two types of node are treated differently because, when adding edges to 1705 // graph_out, an output from a dummy node always uses port 0, whereas an 1706 // output from a copied node uses the same port that was used in graph_in. 1707 std::unordered_map<Node*, Node*> dummy_node_images; 1708 std::unordered_map<Node*, Node*> copied_node_images; 1709 1710 std::unique_ptr<Graph> graph_out(new Graph(graph_in.op_registry())); 1711 graph_out->set_versions(graph_in.versions()); 1712 static_shape_out->resize(send_node->num_inputs()); 1713 1714 // We don't use the standard ReverseDFS because we want to cut off traversal 1715 // whenever we find an output with fully defined shape. 1716 // TODO(misard) make this work properly in the presence of control flow. 1717 struct Work { 1718 Node* node; 1719 bool leave; // Are we entering or leaving node? 1720 }; 1721 std::vector<Work> stack({{send_node, false}}); 1722 std::vector<bool> visited(graph_in.num_node_ids(), false); 1723 while (!stack.empty()) { 1724 Work w = stack.back(); 1725 stack.pop_back(); 1726 Node* n = w.node; 1727 1728 if (w.leave) { 1729 TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph( 1730 n, send_node, dummy_node_images, library, &copied_node_images, 1731 graph_out.get())); 1732 } else { 1733 if (visited[n->id()]) continue; 1734 visited[n->id()] = true; 1735 1736 // Arrange to revisit when all done with all inputs. 1737 stack.push_back(Work{n, true}); 1738 1739 bool has_parent_with_unknown_shape = false; 1740 for (const Edge* in_edge : n->in_edges()) { 1741 if (!in_edge->IsControlEdge()) { 1742 Node* src_node = in_edge->src(); 1743 int src_port = in_edge->src_output(); 1744 shape_inference::InferenceContext* context = 1745 shape_refiner.GetContext(src_node); 1746 shape_inference::ShapeHandle shape = context->output(src_port); 1747 if (context->FullyDefined(shape)) { 1748 // This ancestor has known shape, so instead of adding it to the 1749 // stack, add a dummy node with that shape to graph_out and 1750 // continue. 1751 TensorShapeProto proto; 1752 context->ShapeHandleToProto(shape, &proto); 1753 dummy_node_images[src_node] = AddDummyShapedNode( 1754 src_node->output_type(src_port), proto, graph_out.get()); 1755 if (n == send_node) { 1756 (*static_shape_out)[in_edge->dst_input()] = proto; 1757 } 1758 } else { 1759 if (!visited[src_node->id()]) { 1760 has_parent_with_unknown_shape = true; 1761 stack.push_back({src_node, false}); 1762 } 1763 } 1764 } 1765 } 1766 if (!has_parent_with_unknown_shape) { 1767 if (n == send_node) { 1768 // The shapes of all the inputs to send_node are statically known. We 1769 // won't have to do any inference at compile time so return now: the 1770 // shapes were stored in static_shape_out above. 1771 graphdef_out->reset(); 1772 return Status::OK(); 1773 } else { 1774 // Any shape that is being processed is either the original send node 1775 // or has at least one output with statically-unknown shape. If the 1776 // latter and it doesn't have any inputs with statically-unknown 1777 // shape, then check that it is of the recv nodes that we can fill in 1778 // the shape of at run-time later. If it isn't one of those, then we 1779 // won't have any additional knowledge at compile time, so we already 1780 // know we won't be able to do shape inference and we can return an 1781 // error now. 1782 if (recv_at_host_nodes.find(n->name()) == recv_at_host_nodes.end()) { 1783 return errors::InvalidArgument( 1784 "Shape inference is not possible for outside_compilation " 1785 "SendFromHost node ", 1786 send_node->name(), " because shape of node ", n->name(), 1787 " will not be known at compilation time."); 1788 } 1789 } 1790 } 1791 } 1792 } 1793 1794 graphdef_out->reset(new GraphDef()); 1795 graph_out->ToGraphDef(graphdef_out->get()); 1796 1797 return Status::OK(); 1798 } 1799 1800 Status Encapsulator::MakePrunedGraphCopyAndInline( 1801 const Graph& graph, const std::vector<Node*>& sink_nodes, 1802 std::unique_ptr<Graph>* pruned_graph, 1803 std::unordered_map<const Node*, Node*>* node_images, 1804 FunctionLibraryDefinition* library) { 1805 // First copy all ancestor nodes of sink_nodes into a new graph. 1806 pruned_graph->reset(new Graph(library)); 1807 (*pruned_graph)->set_versions(graph.versions()); 1808 ReverseDFSFrom(graph, sink_nodes, 1809 /*enter=*/nullptr, 1810 /*leave=*/[&](Node* n) { 1811 if (!n->IsSource()) { 1812 Node* copied = (*pruned_graph)->CopyNode(n); 1813 node_images->emplace(n, copied); 1814 } 1815 }); 1816 1817 // Add all the edges between copied nodes. 1818 for (auto entry : *node_images) { 1819 const Node* orig = entry.first; 1820 Node* image = entry.second; 1821 for (const Edge* out_edge : orig->out_edges()) { 1822 auto iter = node_images->find(out_edge->dst()); 1823 if (iter != node_images->end()) { 1824 // The source and destination are both in the copied graph. 1825 (*pruned_graph) 1826 ->AddEdge(image, out_edge->src_output(), iter->second, 1827 out_edge->dst_input()); 1828 } 1829 } 1830 } 1831 1832 // Find all the function call nodes, and inline them. 1833 std::vector<Node*> function_nodes; 1834 for (auto node : (*pruned_graph)->nodes()) { 1835 const OpRegistrationData* op_reg_data; 1836 TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data)); 1837 if (op_reg_data->is_function_op) { 1838 function_nodes.push_back(node); 1839 } 1840 } 1841 for (auto node : function_nodes) { 1842 VLOG(2) << "Inlining function " << node->name(); 1843 const FunctionDef* fdef = library->Find(node->type_string()); 1844 if (fdef == nullptr) { 1845 return errors::Internal("Failed to find function ", node->type_string(), 1846 " in function library."); 1847 } 1848 FunctionBody* fbody = nullptr; 1849 TF_RETURN_IF_ERROR( 1850 FunctionDefToBodyHelper(*fdef, node->attrs(), library, 1851 [library](const string& op, const OpDef** sig) { 1852 return library->LookUpOpDef(op, sig); 1853 }, 1854 &fbody)); 1855 InlineFunctionBody(*library, pruned_graph->get(), node, fbody); 1856 delete fbody; 1857 } 1858 1859 return Status::OK(); 1860 } 1861 1862 Status Encapsulator::MakeGraphForOutsideCompilationSends( 1863 const Graph& graph, std::unique_ptr<Graph>* pruned_graph, 1864 ShapeRefiner* shape_refiner, 1865 std::unordered_map<const Node*, Node*>* node_images, 1866 FunctionLibraryDefinition* library) { 1867 // Find all the send_from_host nodes in all subgraphs, to use as roots for the 1868 // pruning. 1869 std::vector<Node*> send_from_host_nodes; 1870 for (auto& subgraph_entry : subgraphs_) { 1871 Subgraph& subgraph = subgraph_entry.second; 1872 std::vector<string> outside_compilation_names; 1873 subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); 1874 for (const auto& name : outside_compilation_names) { 1875 Node* send_node = subgraph.GetSendFromHostNode(name); 1876 if (send_node != nullptr) { 1877 send_from_host_nodes.push_back(send_node); 1878 } 1879 } 1880 } 1881 1882 // Make a copy of all the graph nodes needed to evaluate the send_from_host 1883 // nodes, inlining any functions as needed. 1884 TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline( 1885 graph, send_from_host_nodes, pruned_graph, node_images, library)); 1886 1887 // Perform shape inference on the pruned graph. 1888 shape_refiner->set_require_shape_inference_fns(false); 1889 FixupSourceAndSinkEdges(pruned_graph->get()); 1890 std::vector<Node*> post_order; 1891 GetReversePostOrder(*(*pruned_graph), &post_order); 1892 for (auto node : post_order) { 1893 // Ignore the status returned by the shape_refiner. At this point we want 1894 // the best effort shapes, even if no shape function is registered for a 1895 // node. 1896 Status status = shape_refiner->AddNode(node); 1897 if (!status.ok()) { 1898 VLOG(1) << "Shape inference failed for node: " << status; 1899 } 1900 } 1901 1902 return Status::OK(); 1903 } 1904 1905 Status Encapsulator::GetShapeInfoForOutsideCompilationSends( 1906 Graph* graph_out, FunctionLibraryDefinition* library) { 1907 std::unique_ptr<Graph> pruned_graph; 1908 ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry()); 1909 std::unordered_map<const Node*, Node*> node_images; 1910 TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends( 1911 *graph_out, &pruned_graph, &shape_refiner, &node_images, library)); 1912 1913 for (auto& subgraph_entry : subgraphs_) { 1914 Subgraph& subgraph = subgraph_entry.second; 1915 // Find all the recv_at_host nodes in this subgraph. 1916 std::vector<string> outside_compilation_names; 1917 subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); 1918 std::unordered_set<string> recv_at_host_names; 1919 for (const auto& name : outside_compilation_names) { 1920 Node* recv_node = subgraph.GetRecvAtHostNode(name); 1921 if (recv_node != nullptr) { 1922 recv_at_host_names.insert(recv_node->name()); 1923 } 1924 } 1925 // For each send_from_host node, do as much shape inference as possible 1926 // without knowing the shape of the recv_at_host nodes, and store the 1927 // result, along with enough information to complete the job at compile time 1928 // once the recv_at_host shapes are known. 1929 for (const auto& name : outside_compilation_names) { 1930 Node* send_node = subgraph.GetSendFromHostNode(name); 1931 std::vector<TensorShapeProto> static_shape; 1932 std::unique_ptr<GraphDef> graphdef; 1933 if (send_node != nullptr) { 1934 TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend( 1935 *pruned_graph, shape_refiner, recv_at_host_names, 1936 node_images[send_node], library, &static_shape, &graphdef)); 1937 if (graphdef == nullptr) { 1938 VLOG(2) << "Send node " << send_node->name() << " shapes"; 1939 for (int i = 0; i < static_shape.size(); ++i) { 1940 VLOG(2) << static_shape[i].DebugString(); 1941 } 1942 } else { 1943 VLOG(2) << "Send node " << send_node->name() << " graph\n" 1944 << graphdef->DebugString(); 1945 } 1946 } 1947 TF_RETURN_IF_ERROR( 1948 subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get())); 1949 } 1950 if (!outside_compilation_names.empty()) { 1951 TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library)); 1952 } 1953 } 1954 1955 return Status::OK(); 1956 } 1957 1958 Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, 1959 FunctionLibraryDefinition* library) { 1960 // Map from nodes in the input graph to nodes in the output graph. 1961 std::unordered_map<const Node*, Node*> node_images; 1962 1963 TF_RETURN_IF_ERROR( 1964 CopyNodesToOutputGraph(parallel_checking, graph_out, &node_images)); 1965 TF_RETURN_IF_ERROR( 1966 AddFunctionCallNodes(node_images, parallel_checking, graph_out)); 1967 TF_RETURN_IF_ERROR(AddOutsideCompilationHostIONodes(node_images, graph_out)); 1968 TF_RETURN_IF_ERROR( 1969 AddEdgesToOutputGraph(node_images, parallel_checking, graph_out)); 1970 1971 TF_RETURN_IF_ERROR( 1972 GetShapeInfoForOutsideCompilationSends(graph_out, library)); 1973 1974 return Status::OK(); 1975 } 1976 1977 } // anonymous namespace 1978 1979 Status EncapsulateSubgraphsInFunctions( 1980 string group_attribute, string outside_compilation_attribute, 1981 const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, 1982 bool parallel_checking, bool reuse_existing_functions, 1983 std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) { 1984 Status s; 1985 1986 Encapsulator encapsulator(std::move(group_attribute), 1987 std::move(outside_compilation_attribute), 1988 &graph_in); 1989 TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs()); 1990 1991 TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs( 1992 rewrite_subgraph_fn, reuse_existing_functions, library)); 1993 1994 std::unique_ptr<Graph> out(new Graph(library)); 1995 out->set_versions(graph_in.versions()); 1996 TF_RETURN_IF_ERROR( 1997 encapsulator.BuildOutputGraph(parallel_checking, out.get(), library)); 1998 1999 *graph_out = std::move(out); 2000 return Status::OK(); 2001 } 2002 2003 // Finds the types of the _Arg nodes, indexed by position. 2004 static Status GetArgTypes(const Graph& graph, DataTypeVector* types) { 2005 for (Node* n : graph.op_nodes()) { 2006 if (n->type_string() == kArgOp) { 2007 int index; 2008 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); 2009 if (index < 0 || index >= types->size()) { 2010 return errors::InvalidArgument("Invalid argument number"); 2011 } 2012 (*types)[index] = n->output_type(0); 2013 } 2014 } 2015 return Status::OK(); 2016 } 2017 2018 // Renumber the indices of _Arg nodes in a graph, according to 2019 // 'permutation' that maps old indices to new indices. 2020 static Status RenumberArguments(Graph* graph, 2021 const std::vector<int>& permutation) { 2022 for (Node* n : graph->op_nodes()) { 2023 if (n->type_string() == kArgOp) { 2024 int index; 2025 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); 2026 if (index < 0 || index >= permutation.size()) { 2027 return errors::InvalidArgument("Invalid argument number"); 2028 } 2029 n->AddAttr("index", permutation[index]); 2030 } 2031 } 2032 return Status::OK(); 2033 } 2034 2035 Status EncapsulateSubgraphsPass::Run( 2036 const GraphOptimizationPassOptions& options) { 2037 VLOG(1) << "EncapsulateSubgraphsPass::Run"; 2038 legacy_flags::EncapsulateSubgraphsPassFlags* flags = 2039 legacy_flags::GetEncapsulateSubgraphsPassFlags(); 2040 if (VLOG_IS_ON(1)) { 2041 dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph, 2042 options.flib_def); 2043 } 2044 2045 std::unique_ptr<Graph> graph_out; 2046 FunctionLibraryDefinition* const library = options.flib_def; 2047 2048 OptimizerOptions opts; 2049 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( 2050 new ProcessFunctionLibraryRuntime(nullptr, options.session_options->env, 2051 TF_GRAPH_DEF_VERSION, library, opts)); 2052 FunctionLibraryRuntime* flr = 2053 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); 2054 2055 auto rewrite_subgraph = [flr](std::unique_ptr<Graph>* subgraph, 2056 std::vector<int>* input_permutation, 2057 std::vector<int>* output_permutation, 2058 NodeDef* node) { 2059 // Optimize the subgraph. 2060 OptimizeGraph(flr, subgraph); 2061 2062 const int num_args = input_permutation->size(); 2063 std::vector<bool> const_args(num_args); 2064 TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); 2065 2066 DataTypeVector arg_types(num_args); 2067 TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); 2068 2069 // Compute a permutation of the arguments such that the constant arguments 2070 // are first. 2071 const int num_consts = 2072 std::count(const_args.begin(), const_args.end(), true); 2073 2074 const int num_resources = 2075 std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE); 2076 const int num_nonconsts = num_args - num_resources - num_consts; 2077 if (num_nonconsts < 0) { 2078 return errors::Internal("num_nonconsts should be >= 0, was ", 2079 num_nonconsts); 2080 } 2081 2082 int const_pos = 0; 2083 int arg_pos = num_consts; 2084 int resource_pos = num_consts + num_nonconsts; 2085 for (int i = 0; i < num_args; ++i) { 2086 if (const_args[i]) { 2087 if (arg_types[i] == DT_RESOURCE) { 2088 return errors::Internal( 2089 "Resource arguments cannot be constant (argument ", i, ")"); 2090 } 2091 (*input_permutation)[i] = const_pos; 2092 ++const_pos; 2093 } else if (arg_types[i] == DT_RESOURCE) { 2094 (*input_permutation)[i] = resource_pos; 2095 ++resource_pos; 2096 } else { 2097 (*input_permutation)[i] = arg_pos; 2098 ++arg_pos; 2099 } 2100 } 2101 2102 // Renumber argument nodes in the graph. 2103 TF_RETURN_IF_ERROR(RenumberArguments(subgraph->get(), *input_permutation)); 2104 2105 // TODO(phawkins): add a forward is-constant analysis, similarly split 2106 // outputs into host-memory constants and device-memory non-constants. 2107 2108 AddNodeAttr(kXlaCompiledKernelAttr, true, node); 2109 AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node); 2110 AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node); 2111 return Status::OK(); 2112 }; 2113 2114 TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( 2115 kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, 2116 rewrite_subgraph, flags->tf_xla_parallel_checking, 2117 /*reuse_existing_functions=*/false, &graph_out, library)); 2118 2119 if (VLOG_IS_ON(1)) { 2120 dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, 2121 options.flib_def); 2122 } 2123 2124 *options.graph = std::move(graph_out); 2125 return Status::OK(); 2126 } 2127 2128 bool IsXlaCompiledKernel(const Node& node) { 2129 bool is_compiled = false; 2130 bool has_compilation_attr = 2131 GetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled).ok() && 2132 is_compiled; 2133 return has_compilation_attr ? is_compiled : false; 2134 } 2135 2136 } // namespace tensorflow 2137