Home | History | Annotate | Download | only in grappler
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 %include <std_shared_ptr.i>
     17 %typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
     18   char* c_string;
     19   Py_ssize_t py_size;
     20   if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
     21     // Python has raised an error (likely TypeError or UnicodeEncodeError).
     22     SWIG_fail;
     23   }
     24 
     25   if (!temp.ParseFromString(string(c_string, py_size))) {
     26     PyErr_SetString(
     27         PyExc_TypeError,
     28         "The MetaGraphDef could not be parsed as a valid protocol buffer");
     29     SWIG_fail;
     30   }
     31   $1 = &temp;
     32 }
     33 
     34 // Wrap the item into an object that swig can manipulate. This ensures it will call the object
     35 // destructor upon garbage collection instead of leaking memory.
     36 struct GItem {
     37   std::shared_ptr<tensorflow::grappler::GrapplerItem> item_;
     38 };
     39 
     40 
     41 %{
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
     55 
     56 // Provide the implementation fo the GItem struct here.
     57 struct GItem {
     58   GItem() {}
     59   GItem(tensorflow::grappler::GrapplerItem* item) : item_(item) {}
     60 
     61   tensorflow::grappler::GrapplerItem* operator->() const {
     62     return item_.get();
     63   }
     64   const tensorflow::grappler::GrapplerItem& operator*() const {
     65     return *item_.get();
     66   }
     67   bool is_none() const {
     68     return item_.get() == nullptr;
     69   }
     70   std::shared_ptr<tensorflow::grappler::GrapplerItem> item_;
     71 };
     72 
     73 static GItem TF_NewItem(
     74     const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
     75     bool ignore_user_placement, TF_Status* out_status) {
     76   if (meta_graph.collection_def().count("train_op") == 0) {
     77     tensorflow::Set_TF_Status_from_Status(
     78         out_status,
     79         tensorflow::errors::InvalidArgument("train_op not specified in the metagraph"));
     80     return nullptr;
     81   }
     82 
     83   tensorflow::grappler::ItemConfig cfg;
     84   cfg.ignore_user_placement = ignore_user_placement;
     85   cfg.ignore_colocation = ignore_colocation;
     86   std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
     87       tensorflow::grappler::GrapplerItemFromMetaGraphDef("item", meta_graph, cfg);
     88   if (!item) {
     89     tensorflow::Set_TF_Status_from_Status(
     90         out_status,
     91         tensorflow::errors::InvalidArgument("Invalid metagraph"));
     92     return nullptr;
     93   }
     94   tensorflow::Set_TF_Status_from_Status(out_status, tensorflow::Status::OK());
     95   return GItem(item.release());
     96 }
     97 
     98 static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
     99                                                    TF_Status* status) {
    100   if (item.is_none()) {
    101     Py_RETURN_NONE;
    102   }
    103 
    104   std::vector<const tensorflow::NodeDef*> main_ops = item->MainOpsFanin();
    105   std::vector<const tensorflow::NodeDef*> enqueue_ops = item->EnqueueOpsFanin();
    106   std::unordered_set<string> op_names;
    107   for (auto op : main_ops) {
    108     op_names.insert(op->name());
    109   }
    110   for (auto op : enqueue_ops) {
    111     op_names.insert(op->name());
    112   }
    113 
    114   std::vector<string> ops;
    115   if (sort_topologically) {
    116     tensorflow::GraphDef subgraph;
    117     for (const tensorflow::NodeDef& node : item->graph.node()) {
    118       if (op_names.find(node.name()) != op_names.end()) {
    119         *subgraph.add_node() = node;
    120       }
    121     }
    122     tensorflow::Status s = tensorflow::grappler::TopologicalSort(&subgraph);
    123     tensorflow::Set_TF_Status_from_Status(status, s);
    124     for (const tensorflow::NodeDef& node : subgraph.node()) {
    125       ops.push_back(node.name());
    126     }
    127   }
    128   else {
    129     for (const auto& op_name : op_names) {
    130       ops.push_back(op_name);
    131     }
    132   }
    133 
    134   PyGILState_STATE gstate = PyGILState_Ensure();
    135   PyObject* result = PyList_New(ops.size());
    136   for (int i = 0; i < ops.size(); ++i) {
    137     PyList_SetItem(result, i, PyString_FromString(ops[i].c_str()));
    138   }
    139   PyGILState_Release(gstate);
    140   return result;
    141 }
    142 
    143 static PyObject* TF_GetOpProperties(GItem item) {
    144   if (item.is_none()) {
    145     Py_RETURN_NONE;
    146   }
    147   tensorflow::grappler::GraphProperties properties(*item);
    148   tensorflow::Status status = properties.InferStatically(false);
    149   if (!status.ok()) {
    150     Py_RETURN_NONE;
    151   }
    152 
    153   PyGILState_STATE gstate = PyGILState_Ensure();
    154   PyObject* props = PyDict_New();
    155   for (const auto& node : item->graph.node()) {
    156     const string& node_name = node.name();
    157     const std::vector<tensorflow::OpInfo::TensorProperties>& output_props =
    158         properties.GetOutputProperties(node_name);
    159 
    160     PyObject* prop = PyList_New(output_props.size());
    161     for (int i = 0; i < output_props.size(); ++i) {
    162       string output_prop_str = output_props[i].SerializeAsString();
    163       PyObject* output_prop = PyBytes_FromStringAndSize(output_prop_str.data(),
    164                                                         output_prop_str.size());
    165       PyList_SetItem(prop, i, output_prop);
    166     }
    167     CHECK_EQ(0, PyDict_SetItem(props, PyString_FromString(node_name.c_str()), prop));
    168   }
    169   PyGILState_Release(gstate);
    170   return props;
    171 }
    172 
    173 class ColocationGroups {
    174 public:
    175   void Group(const string& x, const string& y) {
    176     Rep* x_root = Find(x);
    177     Rep* y_root = Find(y);
    178 
    179     // x and y are already in the same set
    180     if (x_root == y_root) {
    181       return;
    182     }
    183     // x and y are not in same set, so we merge them
    184     // Use the occasion to strengthen what we know about the handle by merging the
    185     // information about the 2 subsets.
    186     if (x_root->rank < y_root->rank) {
    187       x_root->parent = y_root;
    188     } else if (x_root->rank > y_root->rank) {
    189       y_root->parent = x_root;
    190     } else {
    191       // Arbitrarily make one root the new parent
    192       y_root->parent = x_root;
    193       x_root->rank = x_root->rank + 1;
    194     }
    195   }
    196 
    197   void ExtractGroups(std::vector<std::vector<string>>* groups) {
    198     groups->reserve(nodes_.size());
    199     std::unordered_map<const Rep*, int> group_ids;
    200     for (const auto& rep : nodes_) {
    201       Rep* r = Find(rep.first);
    202       auto it = group_ids.find(r);
    203       std::vector<string>* g;
    204       if (it == group_ids.end()) {
    205         int id = group_ids.size();
    206         group_ids[r] = id;
    207         groups->resize(id+1);
    208         g = &groups->back();
    209       } else {
    210         int id = it->second;
    211         g = &((*groups)[id]);
    212       }
    213       g->push_back(rep.first);
    214     }
    215   }
    216 
    217 private:
    218   struct Rep {
    219     // Parent in the tree used to encode the set.
    220     Rep* parent;
    221     // Rank in the tree, used to figure out how to compress the path to the root
    222     // of the tree.
    223     int rank;
    224     // The node.
    225     string value;
    226   };
    227 
    228   Rep* Find(const string& n) {
    229     auto it = nodes_.find(n);
    230     if (it == nodes_.end()) {
    231       // This is the first time we process this handle, create an entry for it.
    232       Rep* node = new Rep;
    233       node->parent = node;
    234       node->rank = 0;
    235       node->value = n;
    236       nodes_[n] = node;
    237       return node;
    238     }
    239     // Return the representative for the set, which is the root of the tree. Apply
    240     // path compression to speedup future queries.
    241     Rep* node = it->second;
    242     Rep* root = node->parent;
    243     while (root != root->parent) {
    244       root = root->parent;
    245     }
    246     while (node->parent != root) {
    247       Rep* next = node->parent;
    248       node->parent = root;
    249       node = next;
    250     }
    251     return root;
    252   }
    253 
    254   std::unordered_map<string, Rep*> nodes_;
    255 };
    256 
    257 static PyObject* TF_GetColocationGroups(GItem item) {
    258   if (item.is_none()) {
    259     Py_RETURN_NONE;
    260   }
    261   ColocationGroups groupings;
    262   tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
    263   for (const auto& node : item->graph.node()) {
    264     const tensorflow::OpDef* op_def;
    265     tensorflow::Status s = registry->LookUpOpDef(node.op(), &op_def);
    266     if (!s.ok()) {
    267       continue;
    268     }
    269     tensorflow::NameRangeMap inputs;
    270     tensorflow::NameRangeMap outputs;
    271     s = tensorflow::NameRangesForNode(node, *op_def, &inputs, &outputs);
    272     if (!s.ok()) {
    273       continue;
    274     }
    275     for (const auto& arg : op_def->input_arg()) {
    276       if (!arg.is_ref()) {
    277         continue;
    278       }
    279       const auto& range = inputs[arg.name()];
    280       for (int i = range.first; i < range.second; ++i) {
    281         string input = tensorflow::grappler::NodeName(node.input(i));
    282         groupings.Group(node.name(), input);
    283       }
    284     }
    285   }
    286 
    287   std::vector<std::vector<string>> groups;
    288   groupings.ExtractGroups(&groups);
    289 
    290   PyGILState_STATE gstate = PyGILState_Ensure();
    291   PyObject* result = PyList_New(groups.size());
    292   for (int i = 0; i < groups.size(); ++i) {
    293     const std::vector<string>& group = groups[i];
    294     PyObject* g = PyTuple_New(group.size());
    295     for (int j = 0; j < group.size(); ++j) {
    296       const string& node_name = group[j];
    297       PyTuple_SetItem(g, j, PyString_FromString(node_name.c_str()));
    298     }
    299     PyList_SetItem(result, i, g);
    300   }
    301   PyGILState_Release(gstate);
    302   return result;
    303 }
    304 
    305 %}
    306 
    307 
    308 // Wrap these functions.
    309 static GItem TF_NewItem(
    310     const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
    311     bool ignore_user_placement, TF_Status* out_status);
    312 static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
    313                                          TF_Status* status);
    314 static PyObject* TF_GetOpProperties(GItem item);
    315 static PyObject* TF_GetColocationGroups(GItem item);
    316