Home | History | Annotate | Download | only in optimizers
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/optimizers/auto_parallel.h"
     17 
     18 #include "tensorflow/core/framework/attr_value.pb.h"
     19 #include "tensorflow/core/framework/function.pb.h"
     20 #include "tensorflow/core/framework/node_def.pb.h"
     21 #include "tensorflow/core/framework/tensor.pb.h"
     22 #include "tensorflow/core/framework/versions.pb.h"
     23 #include "tensorflow/core/grappler/clusters/cluster.h"
     24 #include "tensorflow/core/grappler/devices.h"
     25 #include "tensorflow/core/grappler/grappler_item.h"
     26 #include "tensorflow/core/grappler/op_types.h"
     27 #include "tensorflow/core/grappler/utils.h"
     28 #include "tensorflow/core/lib/strings/strcat.h"
     29 
     30 namespace tensorflow {
     31 namespace grappler {
     32 const char kAutoParallelPrefix[] = "AutoParallel";
     33 
     34 NodeDef* AutoParallel::AddNodeDivConst() {
     35   NodeDef* node = graph_.add_node();
     36   node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-Const"));
     37   node->set_op("Const");
     38 
     39   AttrValue attr_data_type;
     40   attr_data_type.set_type(DT_FLOAT);
     41   node->mutable_attr()->insert({"dtype", attr_data_type});
     42 
     43   AttrValue attr_tensor;
     44   auto tensor = attr_tensor.mutable_tensor();
     45   tensor->add_float_val(static_cast<float>(num_replicas_));
     46   tensor->set_dtype(DT_FLOAT);
     47   node->mutable_attr()->insert({"value", attr_tensor});
     48   return node;
     49 }
     50 
     51 NodeDef* AutoParallel::AddNodeDiv(const string& name, const string& input_a,
     52                                   const string& input_b) {
     53   NodeDef* node = graph_.add_node();
     54   node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-", name));
     55   node->set_op("RealDiv");
     56   node->add_input(input_a);
     57   node->add_input(input_b);
     58   AttrValue attr_type;
     59   attr_type.set_type(DT_FLOAT);
     60   node->mutable_attr()->insert({"T", attr_type});
     61   return node;
     62 }
     63 
     64 NodeDef* AutoParallel::AddNodeControl(const string& name,
     65                                       const std::set<string>& deps,
     66                                       GraphDef* graph) {
     67   NodeDef* node = graph->add_node();
     68   node->set_name(name);
     69   node->set_op("NoOp");
     70   for (const auto& dep : deps) {
     71     node->add_input(strings::StrCat("^", dep));
     72   }
     73   return node;
     74 }
     75 
     76 Status AutoParallel::Initialize(const GrapplerItem& item) {
     77   num_gpus_ = GetNumAvailableGPUs();
     78   LOG(INFO) << "Number of GPUs: " << num_gpus_;
     79   item_ = &item;
     80   graph_ = item.graph;
     81   LOG(INFO) << "Original graph size: " << graph_.node_size();
     82   if (item.fetch.empty()) {
     83     return Status(error::INVALID_ARGUMENT, "No fetch nodes provided.");
     84   }
     85 
     86   if (item.MainVariables().empty()) {
     87     return Status(error::INVALID_ARGUMENT, "No variables provided.");
     88   }
     89 
     90   for (const auto& init : item.init_ops) {
     91     VLOG(1) << "Init node: " << init;
     92   }
     93 
     94   for (const auto& fetch : item.fetch) {
     95     VLOG(1) << "Fetch node: " << fetch;
     96   }
     97 
     98   for (const auto& var : item.MainVariables()) {
     99     VLOG(2) << "Variable: " << var->name();
    100   }
    101 
    102   const std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
    103                                                 "ApplyProximalGradientDescent",
    104                                                 "ApplyAdadelta",
    105                                                 "ApplyAdagrad",
    106                                                 "ApplyProximalAdagrad",
    107                                                 "ApplyAdagradDA",
    108                                                 "ApplyFtrl",
    109                                                 "ApplyMomentum",
    110                                                 "ApplyAdam",
    111                                                 "ApplyRMSProp",
    112                                                 "ApplyCenteredRMSProp"};
    113   for (int i = 0; i < graph_.node_size(); i++) {
    114     all_nodes_.insert(
    115         std::make_pair(graph_.node(i).name(), graph_.mutable_node(i)));
    116     if (apply_gradients_ops.find(graph_.node(i).op()) !=
    117         apply_gradients_ops.end()) {
    118       apply_gradients_nodes_.insert(graph_.node(i).name());
    119       VLOG(2) << "Apply gradients node: " << graph_.node(i).name();
    120     }
    121   }
    122 
    123   auto div_const_node = AddNodeDivConst();
    124   all_nodes_.insert(std::make_pair(div_const_node->name(), div_const_node));
    125   std::map<string, int> gradient_pos = {{"ApplyGradientDescent", 2},
    126                                         {"ApplyProximalGradientDescent", 4},
    127                                         {"ApplyAdadelta", 6},
    128                                         {"ApplyAdagrad", 3},
    129                                         {"ApplyProximalAdagrad", 5},
    130                                         {"ApplyAdagradDA", 3},
    131                                         {"ApplyFtrl", 3},
    132                                         {"ApplyMomentum", 3},
    133                                         {"ApplyAdam", 9},
    134                                         {"ApplyRMSProp", 7},
    135                                         {"ApplyCenteredRMSProp", 8}};
    136   for (const auto& apply_gradient_node_name : apply_gradients_nodes_) {
    137     auto apply_gradients_op = all_nodes_[apply_gradient_node_name]->op();
    138     auto apply_gradients_node = all_nodes_[apply_gradient_node_name];
    139 
    140     auto div_node = AddNodeDiv(
    141         apply_gradient_node_name,
    142         apply_gradients_node->input(gradient_pos[apply_gradients_op]),
    143         div_const_node->name());
    144     all_nodes_.insert(std::make_pair(div_node->name(), div_node));
    145     *apply_gradients_node->mutable_input(gradient_pos[apply_gradients_op]) =
    146         div_node->name();
    147   }
    148   LOG(INFO) << "Graph size after adding div nodes: " << all_nodes_.size();
    149 
    150   auto train_nodes = ComputeTransitiveFanin(graph_, item.fetch);
    151   LOG(INFO) << "Number of training nodes: " << train_nodes.size();
    152 
    153   const NodeDef* dequeue_node;
    154   for (const auto& train_node : train_nodes) {
    155     if (IsDequeueOp(*train_node)) {
    156       dequeue_node = train_node;
    157       break;
    158     }
    159   }
    160 
    161   std::vector<const NodeDef*> input_nodes;
    162   if (dequeue_node) {
    163     LOG(INFO) << "Dequeue node: " << dequeue_node->name();
    164     input_nodes = ComputeTransitiveFanin(graph_, {dequeue_node->name()});
    165   }
    166   LOG(INFO) << "Number of input nodes: " << input_nodes.size();
    167 
    168   std::set<string> dont_replicate_nodes;
    169   for (const auto& variable : item.MainVariables()) {
    170     dont_replicate_nodes.insert(variable->name());
    171   }
    172 
    173   for (const auto& init : item.init_ops) {
    174     dont_replicate_nodes.insert(NodeName(init));
    175   }
    176 
    177   // Don't replicate all input nodes, except the dequeue node.
    178   for (const auto& input_node : input_nodes) {
    179     if (input_node->name() != dequeue_node->name()) {
    180       dont_replicate_nodes.insert(input_node->name());
    181     }
    182   }
    183 
    184   for (const auto& node : train_nodes) {
    185     if (dont_replicate_nodes.find(node->name()) == dont_replicate_nodes.end()) {
    186       replica_nodes_.insert(node->name());
    187     }
    188   }
    189   LOG(INFO) << "Number of replica nodes: " << replica_nodes_.size();
    190 
    191   for (const auto& node : all_nodes_) {
    192     if (replica_nodes_.find(node.first) == replica_nodes_.end()) {
    193       shared_nodes_.insert(node.first);
    194     }
    195   }
    196   LOG(INFO) << "Number of shared nodes: " << shared_nodes_.size();
    197   return Status::OK();
    198 }
    199 
    200 bool AutoParallel::NotSharedNode(const string& name) {
    201   return shared_nodes_.find(name) == shared_nodes_.end();
    202 }
    203 
    204 void AutoParallel::AddSharedNodes(GraphDef* graph) {
    205   string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", 0);
    206   for (const auto& node : shared_nodes_) {
    207     auto new_node = graph->add_node();
    208     *new_node = *all_nodes_[node];
    209     for (int i = 0; i < new_node->input_size(); i++) {
    210       if (NotSharedNode(NodeName(new_node->input(i)))) {
    211         string new_name = AddPrefixToNodeName(new_node->input(i), prefix);
    212         *new_node->mutable_input(i) = new_name;
    213       }
    214     }
    215   }
    216 }
    217 
    218 void AutoParallel::AddOneReplica(GraphDef* graph, int number) {
    219   string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", number);
    220   for (const auto& node : replica_nodes_) {
    221     auto new_node = graph->add_node();
    222     *new_node = *all_nodes_[node];
    223     if (NotSharedNode(new_node->name())) {
    224       new_node->set_name(AddPrefixToNodeName(new_node->name(), prefix));
    225       if (num_gpus_ > 0) {
    226         new_node->set_device(strings::StrCat("/gpu:", number % num_gpus_));
    227       }
    228       for (int i = 0; i < new_node->input_size(); i++) {
    229         if (NotSharedNode(NodeName(new_node->input(i)))) {
    230           string new_name = AddPrefixToNodeName(new_node->input(i), prefix);
    231           *new_node->mutable_input(i) = new_name;
    232         }
    233       }
    234     }
    235   }
    236 }
    237 
    238 void AutoParallel::BuildGraph(GraphDef* graph) {
    239   AddSharedNodes(graph);
    240   for (int i = 0; i < num_replicas_; i++) {
    241     AddOneReplica(graph, i);
    242   }
    243   std::set<string> fetches;
    244   for (size_t i = 0; i < item_->fetch.size(); i++) {
    245     for (int j = 0; j < num_replicas_; j++) {
    246       string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", j);
    247       string fetch = AddPrefixToNodeName(item_->fetch[i], prefix);
    248       fetches.insert(fetch);
    249     }
    250   }
    251   string name_control =
    252       strings::StrCat(kAutoParallelPrefix, "-Control-", "Fetch");
    253   auto control = AddNodeControl(name_control, fetches, graph);
    254 
    255   for (const auto& fetch : item_->fetch) {
    256     AddNodeControl(fetch, {control->name()}, graph);
    257   }
    258   *graph->mutable_library() = item_->graph.library();
    259   *graph->mutable_versions() = item_->graph.versions();
    260   LOG(INFO) << "Parallelized graph size: " << graph->node_size();
    261 }
    262 
    263 Status AutoParallel::Optimize(Cluster* cluster, const GrapplerItem& item,
    264                               GraphDef* output) {
    265   TF_RETURN_IF_ERROR(Initialize(item));
    266   BuildGraph(output);
    267   return Status::OK();
    268 }
    269 
    270 void AutoParallel::Feedback(Cluster* cluster, const GrapplerItem& item,
    271                             const GraphDef& optimize_output, double result) {
    272   // TODO(yaozhang): Add feedback.
    273 }
    274 
    275 }  // end namespace grappler
    276 }  // end namespace tensorflow
    277