Home | History | Annotate | Download | only in grappler
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/grappler_item.h"
     17 
     18 #include <unordered_map>
     19 #include <unordered_set>
     20 #include <vector>
     21 
     22 #include "absl/container/flat_hash_set.h"
     23 #include "absl/strings/str_join.h"
     24 #include "tensorflow/core/framework/attr_value.pb.h"
     25 #include "tensorflow/core/framework/node_def.pb.h"
     26 #include "tensorflow/core/grappler/op_types.h"
     27 #include "tensorflow/core/grappler/utils.h"
     28 #include "tensorflow/core/util/device_name_utils.h"
     29 
     30 namespace tensorflow {
     31 namespace grappler {
     32 
     33 GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
     34   GrapplerItem item;
     35   item.id = id;
     36   item.feed = feed;
     37   item.fetch = fetch;
     38   item.init_ops = init_ops;
     39   item.keep_ops = keep_ops;
     40   item.expected_init_time = expected_init_time;
     41   item.save_op = save_op;
     42   item.restore_op = restore_op;
     43   item.save_restore_loc_tensor = save_restore_loc_tensor;
     44   item.queue_runners = queue_runners;
     45   item.devices_ = devices_;
     46   item.optimization_options_ = optimization_options_;
     47   item.graph.Swap(&graph_def);
     48   return item;
     49 }
     50 
// Returns every node that the fetch nodes transitively depend on (including
// the fetch nodes themselves). CHECK-fails if a referenced node is missing
// from the graph (see ComputeTransitiveFanin below).
std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {
  return ComputeTransitiveFanin(graph, fetch);
}
     54 
     55 std::vector<const NodeDef*> GrapplerItem::EnqueueOpsFanin() const {
     56   std::vector<string> enqueue_ops;
     57   for (const auto& queue_runner : queue_runners) {
     58     for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
     59       enqueue_ops.push_back(enqueue_op);
     60     }
     61   }
     62   return ComputeTransitiveFanin(graph, enqueue_ops);
     63 }
     64 
// Returns every node that the initialization ops transitively depend on
// (including the init ops themselves). CHECK-fails on a malformed graph.
std::vector<const NodeDef*> GrapplerItem::InitOpsFanin() const {
  return ComputeTransitiveFanin(graph, init_ops);
}
     68 
     69 std::vector<const NodeDef*> GrapplerItem::MainVariables() const {
     70   std::vector<const NodeDef*> fanin = ComputeTransitiveFanin(graph, init_ops);
     71   std::vector<const NodeDef*> vars;
     72   for (const NodeDef* node : fanin) {
     73     if (IsVariable(*node)) {
     74       vars.push_back(node);
     75     }
     76   }
     77   return vars;
     78 }
     79 
     80 std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
     81   std::unordered_set<string> result;
     82   for (const string& f : fetch) {
     83     VLOG(1) << "Add fetch " << f;
     84     result.insert(NodeName(f));
     85   }
     86   for (const auto& f : feed) {
     87     VLOG(1) << "Add feed " << f.first;
     88     result.insert(NodeName(f.first));
     89   }
     90   for (const auto& node : init_ops) {
     91     result.insert(NodeName(node));
     92   }
     93   for (const auto& node : keep_ops) {
     94     result.insert(NodeName(node));
     95   }
     96   if (!save_op.empty()) {
     97     result.insert(NodeName(save_op));
     98   }
     99   if (!restore_op.empty()) {
    100     result.insert(NodeName(restore_op));
    101   }
    102   if (!save_restore_loc_tensor.empty()) {
    103     result.insert(NodeName(save_restore_loc_tensor));
    104   }
    105 
    106   for (const auto& queue_runner : queue_runners) {
    107     for (const string& enqueue_op : queue_runner.enqueue_op_name()) {
    108       result.insert(NodeName(enqueue_op));
    109     }
    110     if (!queue_runner.close_op_name().empty()) {
    111       result.insert(NodeName(queue_runner.close_op_name()));
    112     }
    113     if (!queue_runner.cancel_op_name().empty()) {
    114       result.insert(NodeName(queue_runner.cancel_op_name()));
    115     }
    116   }
    117 
    118   // Tensorflow functions do not prune stateful or dataset-output ops from
    119   // the function body (see PruneFunctionBody in common_runtime/function.cc).
    120   if (!optimization_options_.allow_pruning_stateful_and_dataset_ops) {
    121     FunctionLibraryDefinition fn_library(OpRegistry::Global(), graph.library());
    122     for (const NodeDef& node : graph.node()) {
    123       if (IsStateful(node, &fn_library) || IsDataset(node)) {
    124         result.insert(node.name());
    125       }
    126     }
    127   }
    128 
    129   return result;
    130 }
    131 
// Read-only accessor for the set of fully-specified device names that have
// been added to (or inferred for) this item.
const std::unordered_set<string>& GrapplerItem::devices() const {
  return devices_;
}
    135 
    136 Status GrapplerItem::AddDevice(const string& device) {
    137   DeviceNameUtils::ParsedName name;
    138 
    139   if (!DeviceNameUtils::ParseFullName(device, &name)) {
    140     return errors::InvalidArgument("Invalid device name: device=", device);
    141 
    142   } else if (!name.has_job || !name.has_replica || !name.has_task ||
    143              !name.has_type || !name.has_id) {
    144     return errors::InvalidArgument("Not a fully defined device name: device=",
    145                                    device);
    146   }
    147 
    148   devices_.insert(DeviceNameUtils::ParsedNameToString(name));
    149   return Status::OK();
    150 }
    151 
    152 Status GrapplerItem::AddDevices(const GrapplerItem& other) {
    153   std::vector<absl::string_view> invalid_devices;
    154   for (const string& device : other.devices()) {
    155     Status added = AddDevice(device);
    156     if (!added.ok()) invalid_devices.emplace_back(device);
    157   }
    158   return invalid_devices.empty()
    159              ? Status::OK()
    160              : errors::InvalidArgument("Skipped invalid devices: [",
    161                                        absl::StrJoin(invalid_devices, ", "),
    162                                        "]");
    163 }
    164 
    165 Status GrapplerItem::InferDevicesFromGraph() {
    166   absl::flat_hash_set<absl::string_view> invalid_devices;
    167   for (const NodeDef& node : graph.node()) {
    168     Status added = AddDevice(node.device());
    169     if (!added.ok()) invalid_devices.insert(node.device());
    170   }
    171   VLOG(2) << "Inferred device set: [" << absl::StrJoin(devices_, ", ") << "]";
    172   return invalid_devices.empty()
    173              ? Status::OK()
    174              : errors::InvalidArgument("Skipped invalid devices: [",
    175                                        absl::StrJoin(invalid_devices, ", "),
    176                                        "]");
    177 }
    178 
    179 void GrapplerItem::ClearDevices() { devices_.clear(); }
    180 
// Read-only accessor for the per-item optimizer configuration.
const GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options()
    const {
  return optimization_options_;
}
    185 
// Mutable accessor for the per-item optimizer configuration.
GrapplerItem::OptimizationOptions& GrapplerItem::optimization_options() {
  return optimization_options_;
}
    189 
    190 std::vector<const NodeDef*> ComputeTransitiveFanin(
    191     const GraphDef& graph, const std::vector<string>& terminal_nodes) {
    192   bool ill_formed = false;
    193   std::vector<const NodeDef*> result =
    194       ComputeTransitiveFanin(graph, terminal_nodes, &ill_formed);
    195   CHECK(!ill_formed);
    196   return result;
    197 }
    198 
    199 std::vector<const NodeDef*> ComputeTransitiveFanin(
    200     const GraphDef& graph, const std::vector<string>& terminal_nodes,
    201     bool* ill_formed) {
    202   *ill_formed = false;
    203   std::unordered_map<string, const NodeDef*> name_to_node;
    204   std::unordered_map<string, const NodeDef*> name_to_send;
    205   for (const auto& node : graph.node()) {
    206     name_to_node[node.name()] = &node;
    207     if (node.op() == "_Send") {
    208       const auto& attr = node.attr();
    209       name_to_send[attr.at("tensor_name").s()] = &node;
    210     }
    211   }
    212 
    213   std::vector<const NodeDef*> queue;
    214   for (const string& root : terminal_nodes) {
    215     const NodeDef* node = name_to_node[NodeName(root)];
    216     if (!node) {
    217       *ill_formed = true;
    218       VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root;
    219       return {};
    220     }
    221     queue.push_back(node);
    222   }
    223 
    224   std::vector<const NodeDef*> result;
    225   std::unordered_set<const NodeDef*> visited;
    226 
    227   while (!queue.empty()) {
    228     const NodeDef* node = queue.back();
    229     queue.pop_back();
    230     if (!visited.insert(node).second) {
    231       // The node has already been visited.
    232       continue;
    233     }
    234     result.push_back(node);
    235     for (const string& input : node->input()) {
    236       const NodeDef* in = name_to_node[NodeName(input)];
    237       if (!in) {
    238         VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input;
    239         *ill_formed = true;
    240         return {};
    241       }
    242       queue.push_back(in);
    243     }
    244     if (node->op() == "_Recv") {
    245       const auto& attr = node->attr();
    246       const NodeDef* send = name_to_send[attr.at("tensor_name").s()];
    247       if (send) {
    248         queue.push_back(send);
    249       }
    250       // Subgraph after partitioning may have either _Send or _Recv, not both.
    251       // So, we do not set ill_formed for missing _Send.
    252     }
    253   }
    254   return result;
    255 }
    256 
    257 }  // end namespace grappler
    258 }  // end namespace tensorflow
    259