Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
     17 #define TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
     18 
     19 #include <functional>
     20 #include <string>
     21 #include <unordered_map>
     22 #include <vector>
     23 
     24 #include "tensorflow/core/framework/function.h"
     25 #include "tensorflow/core/framework/graph.pb.h"
     26 #include "tensorflow/core/graph/costmodel.h"
     27 #include "tensorflow/core/graph/graph.h"
     28 
     29 namespace tensorflow {
     30 
     31 struct PartitionOptions {
     32   // A function that returns a location for the execution of a given
     33   // Node.
     34   typedef std::function<string(const Node*)> NodeToLocFunc;
     35   NodeToLocFunc node_to_loc = nullptr;
     36 
     37   // A function that returns a unique graph node name with the given
     38   // prefix.
     39   typedef std::function<string(const string&)> NewNameFunc;
     40   NewNameFunc new_name = nullptr;
     41 
     42   // A function that returns the incarnation of a device given the
     43   // device's fullname. If not found, GetIncarnationFunc should return
     44   // kIllegalIncarnation.
     45   static const uint64 kIllegalIncarnation = 0;
     46   typedef std::function<uint64(const string&)> GetIncarnationFunc;
     47   GetIncarnationFunc get_incarnation = nullptr;
     48 
     49   // If specified, flib_def defines a function library that should be
     50   // partitioned and replicated into each resulting partition graphs.
     51   const FunctionLibraryDefinition* flib_def = nullptr;
     52 
     53   // True if all the control flow "code" has already been added. The
     54   // control flow code needs to be added when we still have the entire
     55   // graph before any partitioning. So this flag should be false for
     56   // the first partitioning but true for all subsequent partitioning.
     57   //
     58   // TODO(yuanbyu): We could also make the addition of the control
     59   // flow code incremental based on 'node_to_loc'. This makes the
     60   // communication a broadcast tree, which could be more efficient when
     61   // the number of participating devices is large.
     62   bool control_flow_added = false;
     63 
     64   // A function that returns the data type into which the tensor
     65   // should be cast before sent over the wire.
     66   typedef std::function<DataType(const Edge*)> ShouldCastFunc;
     67   ShouldCastFunc should_cast = nullptr;
     68 
     69   // Schedule the execution of the recvs based on their start times
     70   // computed by some scheduling algorithm. The recvs are divided into
     71   // epochs based on their start times. A recv is enabled only when
     72   // execution reaches its epoch - N for some predefined N.
     73   bool scheduling_for_recvs = false;
     74   // The start time for each node in the graph computed by some scheduling
     75   // algorithm. If 'need_to_record_start_times' is true, we record them
     76   // in the graph as a node attribute.
     77   bool need_to_record_start_times = false;
     78   std::vector<Microseconds> start_times;
     79 };
     80 
     81 // Partition "input" graph into a set of graphs, one per location.
     82 // The location for node n is derived by calling opts.node_to_loc(n).
     83 // New nodes added by Partition use "opts.new_name(old_name)" to
     84 // generate node names.
     85 //
     86 // Stores the partitions in *partitions.
     87 Status Partition(const PartitionOptions& opts, Graph* input,
     88                  std::unordered_map<string, GraphDef>* partitions);
     89 
     90 // Add control edges to the partitions to control the ordering
     91 // and timing of the recv nodes based on the start times calculated
     92 // using some scheduling algorithm.
     93 Status AddControlEdges(const PartitionOptions& opts,
     94                        std::unordered_map<string, GraphDef>* partitions);
     95 
     96 }  // namespace tensorflow
     97 
     98 #endif  // TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
     99