1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "tensorflow/core/common_runtime/build_graph_options.h" 25 #include "tensorflow/core/common_runtime/device.h" 26 #include "tensorflow/core/common_runtime/device_set.h" 27 #include "tensorflow/core/framework/graph.pb.h" 28 #include "tensorflow/core/graph/costmodel.h" 29 #include "tensorflow/core/graph/graph.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/types.h" 33 34 namespace tensorflow { 35 struct SessionOptions; 36 37 namespace subgraph { 38 struct RewriteGraphMetadata; 39 } 40 41 struct GraphExecutionStateOptions { 42 const DeviceSet* device_set = nullptr; 43 const SessionOptions* session_options = nullptr; 44 // A map from node name to device name, representing the unchangeable 45 // placement of stateful nodes. 46 std::unordered_map<string, string> stateful_placements; 47 }; 48 49 // A ClientGraph is simply a sub-graph of the full graph as induced by 50 // BuildGraphOptions. 51 struct ClientGraph { 52 explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib, 53 DataTypeVector feed_types, DataTypeVector fetch_types) 54 : flib_def(std::move(flib)), 55 graph(flib_def.get()), 56 feed_types(std::move(feed_types)), 57 fetch_types(std::move(fetch_types)) {} 58 // Each client-graph gets its own function library since optimization passes 59 // post rewrite for execution might want to introduce new functions. 60 std::unique_ptr<FunctionLibraryDefinition> flib_def; 61 Graph graph; 62 DataTypeVector feed_types; 63 DataTypeVector fetch_types; 64 }; 65 66 // GraphExecutionState is responsible for generating an 67 // executable ClientGraph from the original GraphDef that specifies 68 // the complete graph and from BuildGraphOptions which specifies 69 // input/output nodes. 70 // 71 // An executable Graph differs from a GraphDef by being Placed, 72 // meaning that each Node is assigned to a single Device in the 73 // available set. 74 // 75 // When GraphExecutionState is first constructed it instantiates 76 // a full Graph from the provided GraphDef, and places it, using only 77 // the static device assignments from the GraphDef. Nodes without are 78 // currently placed in a very naive way. Since stateful Nodes cannot 79 // be moved after initial placement, it is important that stateful 80 // Nodes get sensible initial device assignments in the graph 81 // definition. 82 // 83 // Subsequently, GraphExecutionState generates a SimpleClientGraph on 84 // demand, which is a sub-graph of the latest placement of the full 85 // Graph. MasterSession uses such a ClientGraph to execute one or 86 // more similar client requests. 87 // 88 // GraphExecutionState is thread-safe. 89 90 class GraphExecutionState { 91 public: 92 virtual ~GraphExecutionState(); 93 94 // Creates a new `GraphExecutionState` for the given 95 // `graph_def`, which represents the entire graph for a session. 96 // 97 // N.B. This method uses `GraphDef::Swap()` and leaves `graph_def` 98 // in an undefined state. If it is necessary to use `*graph_def` 99 // after this call, make an explicit copy of the graph before 100 // calling this method. 101 static Status MakeForBaseGraph( 102 GraphDef* graph_def, const GraphExecutionStateOptions& options, 103 std::unique_ptr<GraphExecutionState>* out_state); 104 105 // Creates a new `GraphExecutionState` and `SimpleClientGraph` 106 // for the subgraph of `original_graph_def` defined by 107 // `subgraph_options`. 108 static Status MakeForPrunedGraph( 109 const FunctionDefLibrary& func_def_lib, 110 const GraphExecutionStateOptions& options, 111 const GraphDef& original_graph_def, 112 const BuildGraphOptions& subgraph_options, 113 std::unique_ptr<GraphExecutionState>* out_state, 114 std::unique_ptr<ClientGraph>* out_client_graph); 115 116 // Creates a new GraphExecutionState representing the 117 // concatenation of this graph, and the graph defined by 118 // "extension_def". The same name may not be used to define a node 119 // in both this graph and "extension_def". 120 // 121 // If successful, returns OK and the caller takes ownership of "*out". 122 // Otherwise returns an error and does not modify "*out". 123 // 124 // After calling `old_state->Extend()`, `old_state` may no longer be 125 // used. 126 // 127 // NOTE(mrry): This method respects the placement of stateful nodes in 128 // in *this, but currently does not transfer any other placement 129 // or cost model information to the new graph. 130 Status Extend(const GraphDef& extension_def, 131 std::unique_ptr<GraphExecutionState>* out) const; 132 133 // Builds a ClientGraph (a sub-graph of the full graph as induced by 134 // the Node set specified in "options"). If successful, returns OK 135 // and the caller takes the ownership of "*out". Otherwise, returns 136 // an error. 137 Status BuildGraph(const BuildGraphOptions& options, 138 std::unique_ptr<ClientGraph>* out); 139 140 // The graph returned by BuildGraph may contain only the pruned 141 // graph, whereas some clients may want access to the full graph. 142 const Graph* full_graph() { return graph_; } 143 144 // Returns the node with the given name, or null if it does not exist. 145 const Node* get_node_by_name(const string& name) const { 146 NodeNameToCostIdMap::const_iterator iter = 147 node_name_to_cost_id_map_.find(name); 148 if (iter != node_name_to_cost_id_map_.end()) { 149 return graph_->FindNodeId(iter->second); 150 } else { 151 return nullptr; 152 } 153 } 154 155 // Returns a reference to the current graph_def. Use must 156 // not extend beyond lifetime of GrahExecutionState object. 157 const GraphDef& original_graph_def() { return original_graph_def_; } 158 159 // Returns the map of stateful placements as a map of 160 // node name to placement string. 161 std::unordered_map<string, string> GetStatefulPlacements() const { 162 return stateful_placements_; 163 } 164 165 private: 166 GraphExecutionState(GraphDef* graph_def, 167 const GraphExecutionStateOptions& options); 168 169 Status InitBaseGraph(const BuildGraphOptions& options); 170 171 // Map of placed stateful nodes, i.e. nodes for which is_stateful() 172 // is true, such as "params" and "queue" nodes. Once placed these 173 // nodes can not be moved to a different device. Maps node names to 174 // device names. 175 std::unordered_map<string, string> stateful_placements_; // Immutable after 176 // ctor. 177 void SaveStatefulNodes(Graph* graph); 178 void RestoreStatefulNodes(Graph* graph); 179 180 Status OptimizeGraph(const BuildGraphOptions& options, 181 std::unique_ptr<Graph>* optimized_graph); 182 183 GraphDef original_graph_def_; // Immutable after ctor. 184 const DeviceSet* device_set_; // Not owned 185 const SessionOptions* session_options_; // Not owned 186 187 // Map from name to Node for the full graph in placed_. 188 NodeNameToCostIdMap node_name_to_cost_id_map_; 189 190 // 'flib_def_' is initialized from the initial graph def's library, 191 // and may be updated by a graph optimization pass. 192 std::unique_ptr<FunctionLibraryDefinition> flib_def_; 193 194 // `rewrite_metadata_` is only set for GraphExecutionState 195 // objects created by `MakeForPrunedGraph()`. 196 std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_; 197 198 // The dataflow graph owned by this object. 199 Graph* graph_; 200 201 TF_DISALLOW_COPY_AND_ASSIGN(GraphExecutionState); 202 }; 203 204 } // namespace tensorflow 205 206 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ 207