/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/grappler_item.h"

#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"

namespace tensorflow {
namespace grappler {

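// Copies the metadata (feeds, fetches, init ops, save/restore ops, and queue
// runners) from `other`, and takes ownership of `graphDef` by swapping it
// into the item's graph.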
GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef&& graphDef) {
  id = other.id;
  feed = other.feed;
  fetch = other.fetch;
  init_ops = other.init_ops;
  expected_init_time = other.expected_init_time;
  save_op = other.save_op;
  restore_op = other.restore_op;
  save_restore_loc_tensor = other.save_restore_loc_tensor;
  queue_runners = other.queue_runners;
  graph.Swap(&graphDef);
}

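// Returns the transitive fanin of the fetch nodes, i.e. every node that the
// fetched outputs depend on.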
std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {
  return ComputeTransitiveFanin(graph, fetch);
}

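// Gathers the enqueue ops registered with the item's queue runners and
// returns their transitive fanin.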
std::vector<const NodeDef*> GrapplerItem::EnqueueOpsFanin() const {
  std::vector<string> enqueue_ops;
  for (const auto& queue_runner : queue_runners) {
    for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
      enqueue_ops.push_back(enqueue_op);
    }
  }
  return ComputeTransitiveFanin(graph, enqueue_ops);
}

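// Returns the transitive fanin of the initialization ops.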
std::vector<const NodeDef*> GrapplerItem::InitOpsFanin() const {
  return ComputeTransitiveFanin(graph, init_ops);
}

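// Returns the variable nodes that are reachable from the initialization ops.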
std::vector<const NodeDef*> GrapplerItem::MainVariables() const {
  std::vector<const NodeDef*> fanin = ComputeTransitiveFanin(graph, init_ops);
  std::vector<const NodeDef*> vars;
  for (const NodeDef* node : fanin) {
    if (IsVariable(*node)) {
      vars.push_back(node);
    }
  }
  return vars;
}

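// Returns the names of the nodes that optimizers must keep intact: fetches,
// feeds, init ops, save/restore ops, the save/restore location tensor, and
// the ops referenced by the queue runners.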
std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
  std::unordered_set<string> result;
  for (const string& f : fetch) {
    VLOG(1) << "Add fetch " << f;
    result.insert(NodeName(f));
  }
  for (const auto& f : feed) {
    VLOG(1) << "Add feed " << f.first;
    result.insert(NodeName(f.first));
  }
  for (const auto& node : init_ops) {
    result.insert(NodeName(node));
  }
  if (!save_op.empty()) {
    result.insert(NodeName(save_op));
  }
  if (!restore_op.empty()) {
    result.insert(NodeName(restore_op));
  }
  if (!save_restore_loc_tensor.empty()) {
    result.insert(NodeName(save_restore_loc_tensor));
  }

  for (const auto& queue_runner : queue_runners) {
    for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
      result.insert(NodeName(enqueue_op));
    }
    if (!queue_runner.close_op_name().empty()) {
      result.insert(NodeName(queue_runner.close_op_name()));
    }
    if (!queue_runner.cancel_op_name().empty()) {
      result.insert(NodeName(queue_runner.cancel_op_name()));
    }
  }
  return result;
}

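// Convenience overload that CHECK-fails if the graph is ill-formed, e.g. if a
// terminal node or one of its transitive inputs is missing from the graph.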
std::vector<const NodeDef*> ComputeTransitiveFanin(
    const GraphDef& graph, const std::vector<string>& terminal_nodes) {
  bool ill_formed = false;
  std::vector<const NodeDef*> result =
      ComputeTransitiveFanin(graph, terminal_nodes, &ill_formed);
  CHECK(!ill_formed);
  return result;
}

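// Computes the transitive fanin of `terminal_nodes` with a depth-first walk
// over node inputs, following _Recv nodes back to the matching _Send when one
// exists in the (possibly partitioned) graph. Sets *ill_formed and returns an
// empty result if a referenced node cannot be found.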
std::vector<const NodeDef*> ComputeTransitiveFanin(
    const GraphDef& graph, const std::vector<string>& terminal_nodes,
    bool* ill_formed) {
  *ill_formed = false;
  std::unordered_map<string, const NodeDef*> name_to_node;
  std::unordered_map<string, const NodeDef*> name_to_send;
  for (const auto& node : graph.node()) {
    name_to_node[node.name()] = &node;
    if (node.op() == "_Send") {
      const auto& attr = node.attr();
      name_to_send[attr.at("tensor_name").s()] = &node;
    }
  }

  std::vector<const NodeDef*> queue;
  for (const string& root : terminal_nodes) {
    const NodeDef* node = name_to_node[NodeName(root)];
    if (!node) {
      *ill_formed = true;
      VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root;
      return {};
    }
    queue.push_back(node);
  }

  std::vector<const NodeDef*> result;
  std::unordered_set<const NodeDef*> visited;

  while (!queue.empty()) {
    const NodeDef* node = queue.back();
    queue.pop_back();
    if (!visited.insert(node).second) {
      // The node has already been visited.
      continue;
    }
    result.push_back(node);
    for (const string& input : node->input()) {
      const NodeDef* in = name_to_node[NodeName(input)];
      if (!in) {
        VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input;
        *ill_formed = true;
        return {};
      }
      queue.push_back(in);
    }
    if (node->op() == "_Recv") {
      const auto& attr = node->attr();
      const NodeDef* send = name_to_send[attr.at("tensor_name").s()];
      if (send) {
        queue.push_back(send);
      }
      // Subgraph after partitioning may have either _Send or _Recv, not both.
      // So, we do not set ill_formed for missing _Send.
    }
  }
  return result;
}

}  // end namespace grappler
}  // end namespace tensorflow