Home | History | Annotate | Download | only in common_runtime
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
     17 #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
     18 
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
     33 
     34 namespace tensorflow {
     35 struct SessionOptions;
     36 
     37 namespace subgraph {
     38 struct RewriteGraphMetadata;
     39 }
     40 
     41 struct GraphExecutionStateOptions {
     42   const DeviceSet* device_set = nullptr;
     43   const SessionOptions* session_options = nullptr;
     44   // A map from node name to device name, representing the unchangeable
     45   // placement of stateful nodes.
     46   std::unordered_map<string, string> stateful_placements;
     47 };
     48 
     49 // A ClientGraph is simply a sub-graph of the full graph as induced by
     50 // BuildGraphOptions.
     51 struct ClientGraph {
     52   explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
     53                        DataTypeVector feed_types, DataTypeVector fetch_types)
     54       : flib_def(std::move(flib)),
     55         graph(flib_def.get()),
     56         feed_types(std::move(feed_types)),
     57         fetch_types(std::move(fetch_types)) {}
     58   // Each client-graph gets its own function library since optimization passes
     59   // post rewrite for execution might want to introduce new functions.
     60   std::unique_ptr<FunctionLibraryDefinition> flib_def;
     61   Graph graph;
     62   DataTypeVector feed_types;
     63   DataTypeVector fetch_types;
     64 };
     65 
     66 // GraphExecutionState is responsible for generating an
     67 // executable ClientGraph from the original GraphDef that specifies
     68 // the complete graph and from BuildGraphOptions which specifies
     69 // input/output nodes.
     70 //
     71 // An executable Graph differs from a GraphDef by being Placed,
     72 // meaning that each Node is assigned to a single Device in the
     73 // available set.
     74 //
// When GraphExecutionState is first constructed it instantiates
// a full Graph from the provided GraphDef, and places it, using only
// the static device assignments from the GraphDef.  Nodes without a
// static device assignment are currently placed in a very naive way.
// Since stateful Nodes cannot be moved after initial placement, it is
// important that stateful Nodes get sensible initial device assignments
// in the graph definition.
     82 //
// Subsequently, GraphExecutionState generates a ClientGraph on
// demand, which is a sub-graph of the latest placement of the full
// Graph.  MasterSession uses such a ClientGraph to execute one or
// more similar client requests.
     87 //
     88 // GraphExecutionState is thread-safe.
     89 
     90 class GraphExecutionState {
     91  public:
     92   virtual ~GraphExecutionState();
     93 
     94   // Creates a new `GraphExecutionState` for the given
     95   // `graph_def`, which represents the entire graph for a session.
     96   //
     97   // N.B. This method uses `GraphDef::Swap()` and leaves `graph_def`
     98   // in an undefined state. If it is necessary to use `*graph_def`
     99   // after this call, make an explicit copy of the graph before
    100   // calling this method.
    101   static Status MakeForBaseGraph(
    102       GraphDef* graph_def, const GraphExecutionStateOptions& options,
    103       std::unique_ptr<GraphExecutionState>* out_state);
    104 
    105   // Creates a new `GraphExecutionState` and `SimpleClientGraph`
    106   // for the subgraph of `original_graph_def` defined by
    107   // `subgraph_options`.
    108   static Status MakeForPrunedGraph(
    109       const FunctionDefLibrary& func_def_lib,
    110       const GraphExecutionStateOptions& options,
    111       const GraphDef& original_graph_def,
    112       const BuildGraphOptions& subgraph_options,
    113       std::unique_ptr<GraphExecutionState>* out_state,
    114       std::unique_ptr<ClientGraph>* out_client_graph);
    115 
    116   // Creates a new GraphExecutionState representing the
    117   // concatenation of this graph, and the graph defined by
    118   // "extension_def". The same name may not be used to define a node
    119   // in both this graph and "extension_def".
    120   //
    121   // If successful, returns OK and the caller takes ownership of "*out".
    122   // Otherwise returns an error and does not modify "*out".
    123   //
    124   // After calling `old_state->Extend()`, `old_state` may no longer be
    125   // used.
    126   //
    127   // NOTE(mrry): This method respects the placement of stateful nodes in
    128   // in *this, but currently does not transfer any other placement
    129   // or cost model information to the new graph.
    130   Status Extend(const GraphDef& extension_def,
    131                 std::unique_ptr<GraphExecutionState>* out) const;
    132 
    133   // Builds a ClientGraph (a sub-graph of the full graph as induced by
    134   // the Node set specified in "options").  If successful, returns OK
    135   // and the caller takes the ownership of "*out". Otherwise, returns
    136   // an error.
    137   Status BuildGraph(const BuildGraphOptions& options,
    138                     std::unique_ptr<ClientGraph>* out);
    139 
    140   // The graph returned by BuildGraph may contain only the pruned
    141   // graph, whereas some clients may want access to the full graph.
    142   const Graph* full_graph() { return graph_; }
    143 
    144   // Returns the node with the given name, or null if it does not exist.
    145   const Node* get_node_by_name(const string& name) const {
    146     NodeNameToCostIdMap::const_iterator iter =
    147         node_name_to_cost_id_map_.find(name);
    148     if (iter != node_name_to_cost_id_map_.end()) {
    149       return graph_->FindNodeId(iter->second);
    150     } else {
    151       return nullptr;
    152     }
    153   }
    154 
    155   // Returns a reference to the current graph_def.  Use must
    156   // not extend beyond lifetime of GrahExecutionState object.
    157   const GraphDef& original_graph_def() { return original_graph_def_; }
    158 
    159   // Returns the map of stateful placements as a map of
    160   // node name to placement string.
    161   std::unordered_map<string, string> GetStatefulPlacements() const {
    162     return stateful_placements_;
    163   }
    164 
    165  private:
    166   GraphExecutionState(GraphDef* graph_def,
    167                       const GraphExecutionStateOptions& options);
    168 
    169   Status InitBaseGraph(const BuildGraphOptions& options);
    170 
    171   // Map of placed stateful nodes, i.e. nodes for which is_stateful()
    172   // is true, such as "params" and "queue" nodes.  Once placed these
    173   // nodes can not be moved to a different device.  Maps node names to
    174   // device names.
    175   std::unordered_map<string, string> stateful_placements_;  // Immutable after
    176                                                             // ctor.
    177   void SaveStatefulNodes(Graph* graph);
    178   void RestoreStatefulNodes(Graph* graph);
    179 
    180   Status OptimizeGraph(const BuildGraphOptions& options,
    181                        std::unique_ptr<Graph>* optimized_graph);
    182 
    183   GraphDef original_graph_def_;            // Immutable after ctor.
    184   const DeviceSet* device_set_;            // Not owned
    185   const SessionOptions* session_options_;  // Not owned
    186 
    187   // Map from name to Node for the full graph in placed_.
    188   NodeNameToCostIdMap node_name_to_cost_id_map_;
    189 
    190   // 'flib_def_' is initialized from the initial graph def's library,
    191   // and may be updated by a graph optimization pass.
    192   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
    193 
    194   // `rewrite_metadata_` is only set for GraphExecutionState
    195   // objects created by `MakeForPrunedGraph()`.
    196   std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_;
    197 
    198   // The dataflow graph owned by this object.
    199   Graph* graph_;
    200 
    201   TF_DISALLOW_COPY_AND_ASSIGN(GraphExecutionState);
    202 };
    203 
    204 }  // namespace tensorflow
    205 
    206 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_
    207