Home | History | Annotate | Download | only in optimizers
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
     17 #include <unordered_map>
     18 #include <unordered_set>
     19 #include "tensorflow/core/framework/function.pb.h"
     20 #include "tensorflow/core/framework/node_def.pb.h"
     21 #include "tensorflow/core/framework/node_def_util.h"
     22 #include "tensorflow/core/framework/op.h"
     23 #include "tensorflow/core/framework/op_def.pb.h"
     24 #include "tensorflow/core/grappler/grappler_item.h"
     25 #include "tensorflow/core/grappler/op_types.h"
     26 #include "tensorflow/core/grappler/utils.h"
     27 
     28 namespace tensorflow {
     29 namespace grappler {
     30 
     31 GraphRewriter::GraphRewriter(const GrapplerItem& item) {
     32   OpRegistryInterface* op_registry = OpRegistry::Global();
     33   for (auto& node : item.graph.node()) {
     34     NodeInfo* info = new NodeInfo();
     35     info->def = &node;
     36 
     37     const OpRegistrationData* op_reg_data = nullptr;
     38     Status s = op_registry->LookUp(node.op(), &op_reg_data);
     39     // TODO(bsteiner): make this not a best-effort lookup and evaluation?
     40     if (s.ok()) {
     41       DataTypeVector inputs;
     42       s = InOutTypesForNode(node, op_reg_data->op_def, &inputs, &info->outputs);
     43       if (!s.ok()) {
     44         info->outputs.clear();
     45       }
     46     }
     47 
     48     nodes_[node.name()].reset(info);
     49   }
     50 
     51   std::unordered_set<string> function_names;
     52   for (const auto& function : item.graph.library().function()) {
     53     function_names.insert(function.signature().name());
     54   }
     55 
     56   for (auto& node : item.graph.node()) {
     57     RecordConnectivity(node, function_names);
     58   }
     59 }
     60 
     61 void GraphRewriter::ForwardInputs(
     62     const NodeDef& original_node,
     63     const std::unordered_set<const NodeDef*>& nodes_to_delete,
     64     NodeDef* new_node) {
     65   ForwardInputsInternal(original_node, nodes_to_delete, false, new_node);
     66   if (!new_node->name().empty()) {
     67     optimized_nodes_[new_node->name()] = new_node;
     68   }
     69   // Reorder inputs such that control inputs come after regular inputs.
     70   int pos = 0;
     71   for (int i = 0; i < new_node->input_size(); ++i) {
     72     if (!IsControlInput(new_node->input(i))) {
     73       new_node->mutable_input()->SwapElements(pos, i);
     74       ++pos;
     75     }
     76   }
     77   DedupControlInputs(new_node);
     78 }
     79 
     80 bool GraphRewriter::DrivesControlDependency(const NodeDef& node) const {
     81   return control_dependency_drivers_.find(&node) !=
     82          control_dependency_drivers_.end();
     83 }
     84 
     85 bool GraphRewriter::FeedsMerge(const NodeDef& node) const {
     86   return merge_feeders_.find(&node) != merge_feeders_.end();
     87 }
     88 
     89 bool GraphRewriter::IsDrivenByControlDependency(const NodeDef& node) const {
     90   for (const auto& input : node.input()) {
     91     CHECK(!input.empty());
     92     if (input[0] == '^') {
     93       return true;
     94     }
     95   }
     96   return false;
     97 }
     98 
     99 bool GraphRewriter::IsConnectedToFunction(const NodeDef& node) const {
    100   return function_neighbors_.find(&node) != function_neighbors_.end();
    101 }
    102 
    103 bool GraphRewriter::IsDrivenByAnotherDevice(const NodeDef& node) const {
    104   return cross_device_receivers_.find(&node) != cross_device_receivers_.end();
    105 }
    106 
    107 bool GraphRewriter::ReceivesRefValue(const NodeDef& node) const {
    108   return ref_receivers_.find(&node) != ref_receivers_.end();
    109 }
    110 
    111 bool GraphRewriter::IsDrivenBySwitch(const NodeDef& node) const {
    112   return switch_receivers_.find(&node) != switch_receivers_.end();
    113 }
    114 
    115 bool GraphRewriter::RemovalIncreasesEdgeCount(const NodeDef& node) const {
    116   const int in_degree = node.input_size();
    117   auto itr = nodes_.find(node.name());
    118   if (itr == nodes_.end()) {
    119     return true;
    120   }
    121   const int out_degree = itr->second->out_degree;
    122   return in_degree * out_degree > in_degree + out_degree;
    123 }
    124 
    125 void GraphRewriter::RecordConnectivity(
    126     const NodeDef& node, const std::unordered_set<string>& function_names) {
    127   const bool is_function =
    128       function_names.find(node.op()) != function_names.end();
    129 
    130   bool ref_receiver = false;
    131   bool switch_receiver = false;
    132   for (const auto& input : node.input()) {
    133     int position = 0;
    134     string input_node_name = ParseNodeName(input, &position);
    135     auto itr = nodes_.find(input_node_name);
    136     if (itr == nodes_.end()) {
    137       continue;
    138     }
    139 
    140     NodeInfo* fanin_info = itr->second.get();
    141     const NodeDef* fanin = fanin_info->def;
    142     if (IsMerge(node)) {
    143       merge_feeders_.insert(fanin);
    144     }
    145     // Update out_degree of fanin.
    146     ++fanin_info->out_degree;
    147     if (position < 0) {
    148       // This is a control edge
    149       control_dependency_drivers_.insert(fanin);
    150     } else {
    151       // This is a regular edge
    152       if (function_names.find(fanin->op()) != function_names.end()) {
    153         function_neighbors_.insert(&node);
    154       }
    155       if (is_function) {
    156         function_neighbors_.insert(fanin);
    157       }
    158       if (IsSwitch(*fanin)) {
    159         switch_receiver = true;
    160       }
    161       if (position < fanin_info->outputs.size() &&
    162           IsRefType(fanin_info->outputs[position])) {
    163         ref_receiver = true;
    164       }
    165     }
    166     if (fanin->device() != node.device()) {
    167       cross_device_receivers_.insert(&node);
    168     }
    169   }
    170 
    171   if (ref_receiver) {
    172     ref_receivers_.insert(&node);
    173   }
    174   if (switch_receiver) {
    175     switch_receivers_.insert(&node);
    176   }
    177 }
    178 
    179 void GraphRewriter::ForwardInputsInternal(
    180     const NodeDef& node,
    181     const std::unordered_set<const NodeDef*>& nodes_to_delete,
    182     bool add_as_control, NodeDef* new_node) {
    183   // To speed things up, use the optimized version of the node if
    184   // available.
    185   auto itr = optimized_nodes_.find(node.name());
    186   if (itr != optimized_nodes_.end()) {
    187     for (const string& input : itr->second->input()) {
    188       *new_node->add_input() =
    189           add_as_control ? AsControlDependency(NodeName(input)) : input;
    190     }
    191     return;
    192   }
    193   for (const auto& input : node.input()) {
    194     const string input_node_name = NodeName(input);
    195     auto itr = nodes_.find(input_node_name);
    196     if (itr == nodes_.end()) {
    197       // Invalid input, preserve it as is.
    198       *new_node->add_input() =
    199           add_as_control ? AsControlDependency(NodeName(input)) : input;
    200       continue;
    201     }
    202     const NodeDef* input_node = itr->second->def;
    203     if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) {
    204       ForwardInputsInternal(*input_node, nodes_to_delete,
    205                             add_as_control || IsControlInput(input), new_node);
    206     } else {
    207       *new_node->add_input() =
    208           add_as_control ? AsControlDependency(NodeName(input)) : input;
    209     }
    210   }
    211 }
    212 
    213 }  // end namespace grappler
    214 }  // end namespace tensorflow
    215