/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

%include <std_shared_ptr.i>

// Input typemap: converts a Python bytes object holding a serialized
// MetaGraphDef into a C++ tensorflow::MetaGraphDef. The proto is parsed into
// the typemap-local `temp`, and the wrapped function receives its address.
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
  char* raw_bytes;
  Py_ssize_t num_bytes;
  if (PyBytes_AsStringAndSize($input, &raw_bytes, &num_bytes) == -1) {
    // The Python error indicator is already set by the call above (likely
    // TypeError or UnicodeEncodeError), so simply bail out of the wrapper.
    SWIG_fail;
  }

  // Copy the buffer with its explicit length so embedded NUL bytes survive.
  const string serialized(raw_bytes, num_bytes);
  const bool parsed = temp.ParseFromString(serialized);
  if (!parsed) {
    PyErr_SetString(
        PyExc_TypeError,
        "The MetaGraphDef could not be parsed as a valid protocol buffer");
    SWIG_fail;
  }
  $1 = &temp;
}

// Wrap the item into an object that swig can manipulate. This ensures it will call the object
// destructor upon garbage collection instead of leaking memory.
struct GItem {
  std::shared_ptr<tensorflow::grappler::GrapplerItem> item_;
};


%{
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"

// Provide the implementation of the GItem struct here.
//
// GItem is a thin handle around a shared_ptr<GrapplerItem>. A
// default-constructed (or nullptr-constructed) GItem holds no item; callers
// detect that state with is_none().
struct GItem {
  GItem() {}
  // Takes ownership of `item`. Intentionally not `explicit`: TF_NewItem relies
  // on the implicit conversion when it does `return nullptr;`.
  GItem(tensorflow::grappler::GrapplerItem* item) : item_(item) {}

  tensorflow::grappler::GrapplerItem* operator->() const {
    return item_.get();
  }
  const tensorflow::grappler::GrapplerItem& operator*() const {
    return *item_.get();
  }
  // Returns true when this handle does not reference any GrapplerItem.
  bool is_none() const {
    return item_.get() == nullptr;
  }
  std::shared_ptr<tensorflow::grappler::GrapplerItem> item_;
};

// Builds a GrapplerItem from `meta_graph`.
//
// Args:
//   meta_graph: the metagraph to convert; it must contain a "train_op"
//     collection.
//   ignore_colocation: copied into ItemConfig::ignore_colocation.
//   ignore_user_placement: copied into ItemConfig::ignore_user_placement.
//   out_status: receives InvalidArgument when "train_op" is missing or when
//     the item could not be built, and OK otherwise.
//
// Returns a GItem owning the new item, or an empty GItem (is_none() == true)
// on failure.
static GItem TF_NewItem(
    const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
    bool ignore_user_placement, TF_Status* out_status) {
  if (meta_graph.collection_def().count("train_op") == 0) {
    tensorflow::Set_TF_Status_from_Status(
        out_status,
        tensorflow::errors::InvalidArgument("train_op not specified in the metagraph"));
    return nullptr;
  }

  tensorflow::grappler::ItemConfig cfg;
  cfg.ignore_user_placement = ignore_user_placement;
  cfg.ignore_colocation = ignore_colocation;
  std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
      tensorflow::grappler::GrapplerItemFromMetaGraphDef("item", meta_graph, cfg);
  if (!item) {
    tensorflow::Set_TF_Status_from_Status(
        out_status,
        tensorflow::errors::InvalidArgument("Invalid metagraph"));
    return nullptr;
  }
  tensorflow::Set_TF_Status_from_Status(out_status, tensorflow::Status::OK());
  // Ownership moves from the unique_ptr into the shared_ptr held by GItem.
  return GItem(item.release());
}

// Returns a Python list with the names of the nodes reported by
// MainOpsFanin() and EnqueueOpsFanin() of `item` (duplicates removed).
//
// Args:
//   item: the item to inspect. If it is empty, Python None is returned and
//     `status` is left untouched.
//   sort_topologically: when true, the selected nodes are copied into a
//     subgraph that is passed to TopologicalSort, and names are emitted in the
//     resulting node order. When false, names are emitted in the iteration
//     order of an unordered_set, i.e. in no particular order.
//   status: only written on the sort_topologically path, where it receives the
//     result of TopologicalSort. A list is still returned if the sort fails,
//     so callers should check `status`.
static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
                                         TF_Status* status) {
  if (item.is_none()) {
    // NOTE(review): this touches the None refcount before PyGILState_Ensure;
    // presumably the SWIG wrapper already holds the GIL here -- confirm.
    Py_RETURN_NONE;
  }

  std::vector<const tensorflow::NodeDef*> main_ops = item->MainOpsFanin();
  std::vector<const tensorflow::NodeDef*> enqueue_ops = item->EnqueueOpsFanin();
  std::unordered_set<string> op_names;
  for (const tensorflow::NodeDef* op : main_ops) {
    op_names.insert(op->name());
  }
  for (const tensorflow::NodeDef* op : enqueue_ops) {
    op_names.insert(op->name());
  }

  std::vector<string> ops;
  ops.reserve(op_names.size());
  if (sort_topologically) {
    tensorflow::GraphDef subgraph;
    for (const tensorflow::NodeDef& node : item->graph.node()) {
      if (op_names.find(node.name()) != op_names.end()) {
        *subgraph.add_node() = node;
      }
    }
    tensorflow::Status s = tensorflow::grappler::TopologicalSort(&subgraph);
    tensorflow::Set_TF_Status_from_Status(status, s);
    for (const tensorflow::NodeDef& node : subgraph.node()) {
      ops.push_back(node.name());
    }
  } else {
    for (const auto& op_name : op_names) {
      ops.push_back(op_name);
    }
  }

  PyGILState_STATE gstate = PyGILState_Ensure();
  const Py_ssize_t num_ops = static_cast<Py_ssize_t>(ops.size());
  PyObject* result = PyList_New(num_ops);
  for (Py_ssize_t i = 0; i < num_ops; ++i) {
    // PyList_SetItem steals the reference to the new string.
    PyList_SetItem(result, i, PyString_FromString(ops[i].c_str()));
  }
  PyGILState_Release(gstate);
  return result;
}

// Returns a Python dict mapping each node name of `item`'s graph to a list of
// bytes objects, one per output of the node, each holding a serialized
// OpInfo::TensorProperties.
//
// Returns Python None if `item` is empty or if static inference fails.
static PyObject* TF_GetOpProperties(GItem item) {
  if (item.is_none()) {
    Py_RETURN_NONE;
  }
  tensorflow::grappler::GraphProperties properties(*item);
  tensorflow::Status status = properties.InferStatically(false);
  if (!status.ok()) {
    Py_RETURN_NONE;
  }

  PyGILState_STATE gstate = PyGILState_Ensure();
  PyObject* props = PyDict_New();
  for (const auto& node : item->graph.node()) {
    const string& node_name = node.name();
    const std::vector<tensorflow::OpInfo::TensorProperties>& output_props =
        properties.GetOutputProperties(node_name);

    const Py_ssize_t num_outputs = static_cast<Py_ssize_t>(output_props.size());
    PyObject* prop = PyList_New(num_outputs);
    for (Py_ssize_t i = 0; i < num_outputs; ++i) {
      string output_prop_str = output_props[i].SerializeAsString();
      PyObject* output_prop = PyBytes_FromStringAndSize(output_prop_str.data(),
                                                        output_prop_str.size());
      // PyList_SetItem steals the reference to output_prop.
      PyList_SetItem(prop, i, output_prop);
    }
    // Unlike PyList_SetItem, PyDict_SetItem does NOT steal references: the
    // dict takes its own references to both key and value, so ours must be
    // released afterwards or every key and every list would leak.
    PyObject* key = PyString_FromString(node_name.c_str());
    CHECK_EQ(0, PyDict_SetItem(props, key, prop));
    Py_DECREF(key);
    Py_DECREF(prop);
  }
  PyGILState_Release(gstate);
  return props;
}

// Disjoint-set (union-find) structure over node names. Names are registered
// lazily the first time they are passed to Group(). Nodes are owned by the
// map, so they are released when the ColocationGroups object is destroyed.
class ColocationGroups {
 public:
  // Merges the set containing `x` with the set containing `y`, creating
  // singleton sets for names that have not been seen before.
  void Group(const string& x, const string& y) {
    Rep* x_root = Find(x);
    Rep* y_root = Find(y);

    // x and y are already in the same set
    if (x_root == y_root) {
      return;
    }
    // x and y are not in the same set, so we merge them: the root with the
    // lower rank is attached under the root with the higher rank.
    if (x_root->rank < y_root->rank) {
      x_root->parent = y_root;
    } else if (x_root->rank > y_root->rank) {
      y_root->parent = x_root;
    } else {
      // Equal ranks: arbitrarily make one root the new parent and bump its
      // rank.
      y_root->parent = x_root;
      x_root->rank = x_root->rank + 1;
    }
  }

  // Appends one vector of names per disjoint set to `groups`. Group ids are
  // assigned starting at 0 and used as indices, so `groups` is expected to be
  // empty on entry. The order of groups and of names within a group follows
  // the iteration order of an unordered_map and is therefore unspecified.
  void ExtractGroups(std::vector<std::vector<string>>* groups) {
    groups->reserve(nodes_.size());
    std::unordered_map<const Rep*, int> group_ids;
    for (const auto& rep : nodes_) {
      // Find() on an already registered name never inserts into nodes_, so
      // iterating over nodes_ here stays valid.
      Rep* r = Find(rep.first);
      auto it = group_ids.find(r);
      std::vector<string>* g;
      if (it == group_ids.end()) {
        int id = group_ids.size();
        group_ids[r] = id;
        groups->resize(id + 1);
        g = &groups->back();
      } else {
        int id = it->second;
        g = &((*groups)[id]);
      }
      g->push_back(rep.first);
    }
  }

 private:
  struct Rep {
    // Parent in the tree used to encode the set. A root is its own parent.
    Rep* parent;
    // Rank of the tree rooted at this node; Group() uses it to decide which
    // root becomes the parent when two sets are merged.
    int rank;
    // The node name.
    string value;
  };

  // Returns the root of the set containing `n`, registering `n` as a new
  // singleton set if it has not been seen before.
  Rep* Find(const string& n) {
    auto it = nodes_.find(n);
    if (it == nodes_.end()) {
      // This is the first time we process this name, create an entry for it.
      std::unique_ptr<Rep> node(new Rep);
      Rep* raw = node.get();
      raw->parent = raw;
      raw->rank = 0;
      raw->value = n;
      nodes_[n] = std::move(node);
      return raw;
    }
    // Return the representative for the set, which is the root of the tree.
    // Apply path compression to speed up future queries.
    Rep* node = it->second.get();
    Rep* root = node->parent;
    while (root != root->parent) {
      root = root->parent;
    }
    while (node->parent != root) {
      Rep* next = node->parent;
      node->parent = root;
      node = next;
    }
    return root;
  }

  // Owns every Rep; the raw Rep* values stored in `parent` point into this
  // map's values and stay valid for the lifetime of the object.
  std::unordered_map<string, std::unique_ptr<Rep>> nodes_;
};

// Returns a Python list of tuples of node names. Each tuple is one group of
// nodes linked, directly or transitively, through reference-typed inputs: for
// every input argument that the op definition marks as a ref, the consuming
// node is grouped with the node producing that input.
//
// Nodes whose op is not found in the global registry, or whose input/output
// name ranges cannot be computed, are skipped. Returns Python None if `item`
// is empty.
static PyObject* TF_GetColocationGroups(GItem item) {
  if (item.is_none()) {
    Py_RETURN_NONE;
  }
  ColocationGroups groupings;
  tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
  for (const auto& node : item->graph.node()) {
    const tensorflow::OpDef* op_def;
    tensorflow::Status s = registry->LookUpOpDef(node.op(), &op_def);
    if (!s.ok()) {
      continue;
    }
    tensorflow::NameRangeMap inputs;
    tensorflow::NameRangeMap outputs;
    s = tensorflow::NameRangesForNode(node, *op_def, &inputs, &outputs);
    if (!s.ok()) {
      continue;
    }
    for (const auto& arg : op_def->input_arg()) {
      if (!arg.is_ref()) {
        continue;
      }
      // [range.first, range.second) are the indices into node.input() that
      // belong to this argument.
      const auto& range = inputs[arg.name()];
      for (int i = range.first; i < range.second; ++i) {
        string input = tensorflow::grappler::NodeName(node.input(i));
        groupings.Group(node.name(), input);
      }
    }
  }

  std::vector<std::vector<string>> groups;
  groupings.ExtractGroups(&groups);

  PyGILState_STATE gstate = PyGILState_Ensure();
  const Py_ssize_t num_groups = static_cast<Py_ssize_t>(groups.size());
  PyObject* result = PyList_New(num_groups);
  for (Py_ssize_t i = 0; i < num_groups; ++i) {
    const std::vector<string>& group = groups[i];
    const Py_ssize_t group_size = static_cast<Py_ssize_t>(group.size());
    PyObject* g = PyTuple_New(group_size);
    for (Py_ssize_t j = 0; j < group_size; ++j) {
      const string& node_name = group[j];
      // PyTuple_SetItem steals the reference to the new string.
      PyTuple_SetItem(g, j, PyString_FromString(node_name.c_str()));
    }
    // PyList_SetItem steals the reference to the tuple.
    PyList_SetItem(result, i, g);
  }
  PyGILState_Release(gstate);
  return result;
}

%}


// Wrap these functions.
static GItem TF_NewItem(
    const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
    bool ignore_user_placement, TF_Status* out_status);
static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
                                         TF_Status* status);
static PyObject* TF_GetOpProperties(GItem item);
static PyObject* TF_GetColocationGroups(GItem item);