Home | History | Annotate | Download | only in grappler
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <memory>
     17 #include <vector>
     18 
     19 #include "tensorflow/core/framework/attr_value.pb.h"
     20 #include "tensorflow/core/framework/function.h"
     21 #include "tensorflow/core/framework/op.h"
     22 #include "tensorflow/core/framework/op_def.pb.h"
     23 #include "tensorflow/core/framework/types.h"
     24 #include "tensorflow/core/grappler/utils.h"
     25 #include "tensorflow/core/lib/strings/numbers.h"
     26 #include "tensorflow/core/lib/strings/scanner.h"
     27 #include "tensorflow/core/lib/strings/strcat.h"
     28 #include "tensorflow/core/platform/notification.h"
     29 
     30 namespace tensorflow {
     31 namespace grappler {
     32 
     33 NodeMap::NodeMap(GraphDef* graph) {
     34   CHECK(graph != nullptr);
     35   for (int i = 0; i < graph->node_size(); i++) {
     36     NodeDef* node = graph->mutable_node(i);
     37     const string& node_name = node->name();
     38     auto rslt = nodes_.emplace(node_name, node);
     39     // Check that the graph doesn't contain multiple nodes with the same name.
     40     if (!rslt.second) {
     41       LOG(WARNING) << "Duplicated node in the graph: " << node_name;
     42     }
     43     for (const auto& input : node->input()) {
     44       outputs_[NodeName(input)].insert(nodes_[node_name]);
     45     }
     46   }
     47 }
     48 
     49 void NodeMap::RemoveNode(const string& name) {
     50   nodes_.erase(NodeName(name));
     51   outputs_.erase(NodeName(name));
     52 }
     53 
     54 NodeDef* NodeMap::GetNode(const string& name) const {
     55   const string node_name = NodeName(name);
     56   auto it = nodes_.find(node_name);
     57   if (it == nodes_.end()) {
     58     return nullptr;
     59   }
     60   return it->second;
     61 }
     62 
     63 bool NodeMap::NodeExists(const string& name) const {
     64   const string node_name = NodeName(name);
     65   return nodes_.find(node_name) != nodes_.end();
     66 }
     67 
     68 const std::set<NodeDef*>& NodeMap::GetOutputs(const string& node_name) const {
     69   auto it = outputs_.find(node_name);
     70   if (it == outputs_.end()) {
     71     return empty_set_;
     72   }
     73   return it->second;
     74 }
     75 
     76 void NodeMap::AddNode(const string& node_name, NodeDef* node) {
     77   auto ret = nodes_.emplace(node_name, CHECK_NOTNULL(node));
     78   CHECK(ret.second) << "Pair (" << node_name << "," << node
     79                     << ") is not inserted because the same key already exists.";
     80 }
     81 
     82 void NodeMap::AddOutput(const string& node_name, const string& output_name) {
     83   auto output_node = nodes_[NodeName(output_name)];
     84   CHECK(output_node) << "Output node " << output_name
     85                      << " is missing in NodeMap.";
     86   outputs_[node_name].insert(output_node);
     87 }
     88 
     89 void NodeMap::RemoveOutput(const string& node_name, const string& output_name) {
     90   outputs_[node_name].erase(nodes_[NodeName(output_name)]);
     91 }
     92 
     93 void NodeMap::UpdateInput(const string& node_name, const string& old_input_name,
     94                           const string& new_input_name) {
     95   RemoveOutput(NodeName(old_input_name), node_name);
     96   AddOutput(NodeName(new_input_name), node_name);
     97 }
     98 
     99 void NodeMap::RemoveInputs(const string& node_name) {
    100   auto node = nodes_[node_name];
    101   for (const auto& input : node->input()) {
    102     RemoveOutput(NodeName(input), node->name());
    103   }
    104 }
    105 
    106 void NodeMap::RemoveOutputs(const string& node_name) {
    107   outputs_.erase(node_name);
    108 }
    109 
    110 void NodeMap::UpdateOutput(const string& node_name,
    111                            const string& old_output_name,
    112                            const string& new_output_name) {
    113   std::set<NodeDef*>& outputs = outputs_[node_name];
    114   outputs.erase(nodes_[NodeName(old_output_name)]);
    115   outputs.insert(nodes_[NodeName(new_output_name)]);
    116 }
    117 
    118 bool IsSameInput(const string& name1, const string& name2) {
    119   if (name1 == name2) {
    120     return true;
    121   }
    122   int position1;
    123   string node1 = ParseNodeName(name1, &position1);
    124   int position2;
    125   string node2 = ParseNodeName(name2, &position2);
    126   return (position1 == position2) && (node1 == node2);
    127 }
    128 
    129 string ParseNodeName(const string& name, int* position) {
    130   // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any)
    131   // to get a node name.
    132   strings::Scanner scan(name);
    133   scan.ZeroOrOneLiteral("^")
    134       .RestartCapture()
    135       .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
    136       .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
    137   StringPiece capture;
    138   StringPiece remaining;
    139   if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) {
    140     *position = 0;
    141     return "";
    142   } else {
    143     if (name[0] == '^') {
    144       *position = -1;
    145     } else if (remaining.empty()) {
    146       *position = 0;
    147     } else {
    148       // Skip the first ':' character.
    149       CHECK(strings::safe_strto32(remaining.substr(1), position));
    150     }
    151     return capture.ToString();
    152   }
    153 }
    154 
    155 bool IsControlInput(const string& name) {
    156   return !name.empty() && name[0] == '^';
    157 }
    158 
    159 string NodeName(const string& name) {
    160   int position;
    161   return ParseNodeName(name, &position);
    162 }
    163 
    164 int NodePosition(const string& name) {
    165   int position;
    166   ParseNodeName(name, &position);
    167   return position;
    168 }
    169 
    170 string AddPrefixToNodeName(const string& name, const string& prefix,
    171                            const string& delimiter) {
    172   if (!name.empty()) {
    173     if (name[0] == '^') {
    174       return strings::StrCat("^", prefix, delimiter, name.substr(1));
    175     }
    176   }
    177   return strings::StrCat(prefix, delimiter, name);
    178 }
    179 
    180 string AddPrefixToNodeName(const string& name, const string& prefix) {
    181   return AddPrefixToNodeName(name, prefix, "/");
    182 }
    183 
    184 bool ExecuteWithTimeout(std::function<void()> fn, const int64 timeout_in_ms,
    185                         thread::ThreadPool* const thread_pool) {
    186   if (timeout_in_ms <= 0) {
    187     fn();
    188     return true;
    189   }
    190   auto done = std::make_shared<Notification>();
    191   thread_pool->Schedule([done, fn]() {
    192     fn();
    193     done->Notify();
    194   });
    195   const bool notified =
    196       WaitForNotificationWithTimeout(done.get(), timeout_in_ms * 1000);
    197   return notified;
    198 }
    199 
    200 string AsControlDependency(const NodeDef& node) {
    201   return strings::StrCat("^", node.name());
    202 }
    203 
    204 string AsControlDependency(const string& node_name) {
    205   CHECK(!node_name.empty());
    206   return (!node_name.empty() && node_name[0] == '^')
    207              ? node_name
    208              : strings::StrCat("^", node_name);
    209 }
    210 
    211 int NumOutputs(const NodeDef& node, GraphDef* graph) {
    212   int num_outputs = 0;
    213   const OpDef* op_def = nullptr;
    214   auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
    215   if (status.ok()) {
    216     for (const auto& output : op_def->output_arg()) {
    217       if (!output.type_list_attr().empty()) {
    218         num_outputs +=
    219             node.attr().at(output.type_list_attr()).list().type_size();
    220       } else if (!output.number_attr().empty()) {
    221         num_outputs += node.attr().at(output.number_attr()).i();
    222       } else {
    223         num_outputs++;
    224       }
    225     }
    226   } else {
    227     FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library());
    228     auto status = fdef.LookUpOpDef(node.op(), &op_def);
    229     if (status.ok()) {
    230       num_outputs = op_def->output_arg_size();
    231     }
    232   }
    233   return num_outputs;
    234 }
    235 
    236 int NumNonControlInputs(const NodeDef& node) {
    237   int num_inputs = node.input_size();
    238   for (const string& input : node.input()) {
    239     if (IsControlInput(input)) {
    240       --num_inputs;
    241     }
    242   }
    243   return num_inputs;
    244 }
    245 
    246 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
    247   int num_outputs = 0;
    248   for (const NodeDef* output : node_map.GetOutputs(node.name())) {
    249     for (const string& node_as_input : output->input()) {
    250       if (IsControlInput(node_as_input)) {
    251         break;
    252       }
    253       if (NodeName(node_as_input) == node.name()) {
    254         ++num_outputs;
    255       }
    256     }
    257   }
    258   return num_outputs;
    259 }
    260 
    261 // Returns the data type in attribute `attr_name` of `node`. If that attribute
    262 // doesn't exist, returns DT_INVALID.
    263 DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) {
    264   if (!node.attr().count(attr_name)) {
    265     return DT_INVALID;
    266   }
    267   const auto& attr = node.attr().at(attr_name);
    268   if (attr.value_case() != AttrValue::kType) {
    269     return DT_INVALID;
    270   }
    271   return attr.type();
    272 }
    273 
    274 NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
    275                         bool follow_control_input,
    276                         const std::function<bool(const NodeDef&)>& pred_fn) {
    277   const NodeDef* current = &source;
    278   const NodeDef* next = current;
    279   while (next == &source || (next != nullptr && pred_fn(*next))) {
    280     current = next;
    281     if (current->input_size() == 0 ||
    282         (!follow_control_input && IsControlInput(current->input(0)))) {
    283       break;
    284     }
    285     next = node_map.GetNode(current->input(0));
    286     if (next == nullptr) {
    287       LOG(ERROR) << "Node not found: " << current->input(0);
    288     }
    289   }
    290   return const_cast<NodeDef*>(current);
    291 }
    292 
    293 // Every permutation is a product of one or more cycles. Iterate over the cycles
    294 // in the permutation, and convert each of those into a product of
    295 // transpositions (swaps): https://en.wikipedia.org/wiki/Cyclic_permutation
    296 void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
    297                          bool invert_permutation) {
    298   CHECK_EQ(graph->node_size(), permutation->size());
    299   std::vector<int> inv_perm(permutation->size(), 0);
    300   if (invert_permutation) {
    301     for (size_t n = 0; n < permutation->size(); ++n) {
    302       inv_perm[(*permutation)[n]] = n;
    303     }
    304     permutation->swap(inv_perm);
    305   }
    306   for (std::size_t n = 0; n + 1 < permutation->size(); ++n) {
    307     while (n != (*permutation)[n]) {
    308       std::size_t r = (*permutation)[n];
    309       graph->mutable_node()->SwapElements(n, r);
    310       std::swap((*permutation)[n], (*permutation)[r]);
    311     }
    312   }
    313 }
    314 
    315 void DedupControlInputs(NodeDef* node) {
    316   std::unordered_set<string> inputs;
    317   int pos = 0;
    318   while (pos < node->input_size()) {
    319     const string& input = node->input(pos);
    320     if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) {
    321       node->mutable_input()->SwapElements(pos, node->input_size() - 1);
    322       node->mutable_input()->RemoveLast();
    323     } else {
    324       ++pos;
    325     }
    326   }
    327 }
    328 
    329 namespace {
    330 template <typename T>
    331 inline void STLSortAndRemoveDuplicates(T* v) {
    332   std::sort(v->begin(), v->end());
    333   v->erase(std::unique(v->begin(), v->end()), v->end());
    334 }
    335 }  // namespace
    336 
    337 Status SimpleGraphView::Initialize(const GraphDef& graph, bool dedup_inputs,
    338                                    bool dedup_outputs) {
    339   const int num_nodes = graph.node_size();
    340   inputs_.clear();
    341   inputs_.resize(num_nodes);
    342   outputs_.clear();
    343   outputs_.resize(num_nodes);
    344   name_to_index_.clear();
    345   name_to_index_.reserve(num_nodes);
    346   index_to_name_.clear();
    347   index_to_name_.reserve(num_nodes);
    348 
    349   // Build map from name to index and vice versa.
    350   for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
    351     const NodeDef& node = graph.node(node_idx);
    352     name_to_index_.emplace(node.name(), node_idx);
    353     index_to_name_.push_back(node.name());
    354   }
    355 
    356   // Build forward and reverse adjacency lists.
    357   for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
    358     const NodeDef& node = graph.node(node_idx);
    359     inputs_[node_idx].reserve(node.input_size());
    360     for (const string& input : node.input()) {
    361       auto it = name_to_index_.find(NodeName(input));
    362       if (it == name_to_index_.end()) {
    363         return errors::InvalidArgument("Non-existent input ", input,
    364                                        " for node ", node.name());
    365       }
    366       const int input_idx = it->second;
    367       inputs_[node_idx].push_back(input_idx);
    368       outputs_[input_idx].push_back(node_idx);
    369     }
    370     if (dedup_inputs) {
    371       // Dedup the input list while it's still hot in cache.
    372       STLSortAndRemoveDuplicates(&inputs_[node_idx]);
    373     }
    374   }
    375 
    376   // Dedup outputs.
    377   if (dedup_outputs) {
    378     for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
    379       STLSortAndRemoveDuplicates(&outputs_[node_idx]);
    380     }
    381   }
    382   return Status::OK();
    383 }
    384 
    385 string SimpleGraphView::PrintToString() const {
    386   string str;
    387   for (int i = 0; i < num_nodes(); ++i) {
    388     strings::StrAppend(&str, "Node ", i, "'", node_name(i), "'\n", "Inputs: [");
    389     for (int input : inputs(i)) {
    390       strings::StrAppend(&str, input, " '", node_name(input), "', ");
    391     }
    392     strings::StrAppend(&str, "]\n", "Outputs: [");
    393     for (int j = 0; j < outputs(i).size(); ++j) {
    394       const int output = outputs(i)[j];
    395       if (j > 0) {
    396         strings::StrAppend(&str, ", ");
    397       }
    398       strings::StrAppend(&str, output, " '", node_name(output), "'");
    399     }
    400     strings::StrAppend(&str, "]\n");
    401   }
    402   return str;
    403 }
    404 
    405 }  // end namespace grappler
    406 }  // end namespace tensorflow
    407