1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/tools/graph_transforms/fold_constants_lib.h" 17 18 #include <algorithm> 19 #include <iterator> 20 #include <map> 21 #include <string> 22 #include <unordered_map> 23 #include <unordered_set> 24 #include <utility> 25 #include <vector> 26 27 #include "tensorflow/core/common_runtime/constant_folding.h" 28 #include "tensorflow/core/common_runtime/shape_refiner.h" 29 #include "tensorflow/core/graph/graph_constructor.h" 30 #include "tensorflow/core/graph/node_builder.h" 31 #include "tensorflow/core/graph/subgraph.h" 32 #include "tensorflow/core/lib/core/stringpiece.h" 33 #include "tensorflow/core/lib/strings/numbers.h" 34 #include "tensorflow/core/platform/init_main.h" 35 #include "tensorflow/core/public/session.h" 36 #include "tensorflow/core/util/command_line_flags.h" 37 #include "tensorflow/tools/graph_transforms/transform_utils.h" 38 39 namespace tensorflow { 40 namespace graph_transforms { 41 namespace { 42 using StringPieceSet = std::unordered_set<StringPiece, StringPieceHasher>; 43 template <typename T> 44 using StringPieceMap = std::unordered_map<StringPiece, T, StringPieceHasher>; 45 } // namespace 46 47 Status ReplaceSendRecvs(const GraphDef& original_graph_def, 48 const GraphDef& rewritten_graph_def, 49 const std::vector<string>& inputs, 50 const std::vector<string>& outputs, 51 GraphDef* output_graph_def) { 52 // recv_node_names serves as a string storage for recv node names. 53 std::vector<string> recv_node_names(inputs.size()); 54 StringPieceMap<TensorId> recv_node_map; 55 StringPieceSet input_nodes; 56 for (int i = 0; i < inputs.size(); ++i) { 57 // RewriteGraphForExecution adds a recv node for each input edge. We assume 58 // here that adding such recv node did not fail. For example, the original 59 // graph did not already have a node with the name for the new added recv 60 // node. 61 TensorId id = ParseTensorName(inputs[i]); 62 input_nodes.insert(id.first); 63 string& recv_node_name = recv_node_names[i]; 64 recv_node_name = strings::StrCat("_recv_", id.first, "_", id.second); 65 recv_node_map.emplace(recv_node_name, id); 66 } 67 68 StringPieceMap<const NodeDef*> original_map; 69 for (const NodeDef& node : original_graph_def.node()) { 70 original_map.emplace(node.name(), &node); 71 } 72 73 for (const NodeDef& node : rewritten_graph_def.node()) { 74 if ((node.op() == "_Send") || (node.op() == "_Recv")) { 75 // If the op is a Send or Recv that wasn't in the original, skip it. 76 if (original_map.count(node.name()) == 0) { 77 continue; 78 } 79 } 80 81 NodeDef* new_node = output_graph_def->add_node(); 82 new_node->MergeFrom(node); 83 for (int i = 0; i < new_node->input_size(); ++i) { 84 string& input = *new_node->mutable_input(i); 85 TensorId id = ParseTensorName(input); 86 const auto iter = recv_node_map.find(id.first); 87 if (iter != recv_node_map.end()) { 88 // The node being substituted is a Recv node, and it has only one 89 // output. If this input is not a control input, then replace the input 90 // with the mapped value. Otherwise, replace the node name only. 91 if (id.second != Graph::kControlSlot) { 92 CHECK_EQ(id.second, 0); 93 input = iter->second.ToString(); 94 } else { 95 id.first = iter->second.first; 96 input = id.ToString(); 97 } 98 } 99 } 100 101 // RewriteGraphForExecution() did not remove this input node. Remove this 102 // node name from input_nodes so that a duplicate does not get added to the 103 // output_graph_def. 104 auto iter = input_nodes.find(new_node->name()); 105 if (iter != input_nodes.end()) { 106 input_nodes.erase(iter); 107 } 108 } 109 110 // Some input nodes are removed in rewrite_graph_def. Add those nodes to 111 // output_graph_def. 112 for (StringPiece name : input_nodes) { 113 const NodeDef& removed_node = *CHECK_NOTNULL(original_map[name]); 114 output_graph_def->add_node()->MergeFrom(removed_node); 115 } 116 117 return Status::OK(); 118 } 119 120 Status RemoveUnusedNodes(const GraphDef& input_graph_def, 121 const TransformFuncContext& context, 122 GraphDef* output_graph_def) { 123 StringPieceMap<const NodeDef*> node_map; 124 for (const NodeDef& node : input_graph_def.node()) { 125 node_map.emplace(node.name(), &node); 126 } 127 128 std::unordered_set<TensorId, TensorId::Hasher> input_names; 129 for (const string& input : context.input_names) { 130 input_names.insert(ParseTensorName(input)); 131 } 132 StringPieceSet used_nodes; 133 StringPieceSet current_nodes; 134 for (const string& name : context.output_names) { 135 TensorId id = ParseTensorName(name); 136 used_nodes.insert(id.first); 137 current_nodes.insert(id.first); 138 } 139 while (!current_nodes.empty()) { 140 StringPieceSet next_nodes; 141 for (StringPiece node_name : current_nodes) { 142 if (node_map.count(node_name) == 0) { 143 LOG(ERROR) << "Bad graph structure, no node named '" << node_name 144 << "' found for input lookup"; 145 return errors::InvalidArgument("Bad graph structure, no node named '", 146 node_name, "' found for input lookup"); 147 } 148 const NodeDef& node = *(node_map[node_name]); 149 for (const string& input : node.input()) { 150 TensorId id = ParseTensorName(input); 151 if (input_names.count(id) > 0) { 152 continue; 153 } 154 if (used_nodes.insert(id.first).second) { 155 next_nodes.insert(id.first); 156 } 157 } 158 } 159 current_nodes.swap(next_nodes); 160 } 161 for (const TensorId& id : input_names) { 162 used_nodes.insert(id.first); 163 } 164 FilterGraphDef( 165 input_graph_def, 166 [&](const NodeDef& node) { return used_nodes.count(node.name()) > 0; }, 167 output_graph_def); 168 169 return Status::OK(); 170 } 171 172 // Converts a shape inference handle to a PartialTensorShape. 173 Status ShapeHandleToTensorShape(const shape_inference::ShapeHandle& handle, 174 shape_inference::InferenceContext* context, 175 PartialTensorShape* shape) { 176 // The default is already unknown. 177 if (!context->RankKnown(handle)) return Status::OK(); 178 179 std::vector<int64> dims(context->Rank(handle)); 180 for (int32 i = 0; i < dims.size(); ++i) { 181 dims[i] = context->Value(context->Dim(handle, i)); 182 } 183 return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape); 184 } 185 186 // Converts any sub-graphs that can be resolved into constant expressions into 187 // single Const ops. 188 Status FoldConstants(const GraphDef& input_graph_def, 189 const TransformFuncContext& context, 190 GraphDef* output_graph_def) { 191 Graph input_graph(OpRegistry::Global()); 192 TF_RETURN_IF_ERROR(input_graph.AddFunctionLibrary(input_graph_def.library())); 193 194 ShapeRefiner shape_refiner(input_graph.versions(), input_graph.op_registry()); 195 shape_refiner.set_require_shape_inference_fns(false); 196 shape_refiner.set_disable_constant_propagation(false); 197 shape_refiner.set_function_library_for_shape_inference( 198 &input_graph.flib_def()); 199 200 bool clear_output_shapes; 201 TF_RETURN_IF_ERROR(context.GetOneBoolParameter("clear_output_shapes", true, 202 &clear_output_shapes)); 203 if (clear_output_shapes) { 204 // Some older GraphDefs have saved _output_shapes attributes which are out 205 // of date and cause import errors, so clean them up first. 206 GraphDef cleaned_graph_def; 207 RemoveAttributes(input_graph_def, {"_output_shapes"}, &cleaned_graph_def); 208 209 TF_RETURN_IF_ERROR( 210 ImportGraphDef({}, cleaned_graph_def, &input_graph, &shape_refiner)); 211 } else { 212 TF_RETURN_IF_ERROR( 213 ImportGraphDef({}, input_graph_def, &input_graph, &shape_refiner)); 214 } 215 216 // Sorted array of input names as lookup table. 217 std::vector<TensorId> input_names; 218 input_names.reserve(context.input_names.size()); 219 std::transform(context.input_names.begin(), context.input_names.end(), 220 std::back_inserter(input_names), 221 [](const string& name) { return ParseTensorName(name); }); 222 223 const auto compare = [](TensorId lhs, TensorId rhs) { 224 return lhs.first < rhs.first; 225 }; 226 227 std::sort(input_names.begin(), input_names.end(), compare); 228 229 // Set statically inferred shapes. 230 std::unordered_map<string, std::vector<PartialTensorShape>> shape_map; 231 for (const Node* const node : input_graph.nodes()) { 232 auto ctx = shape_refiner.GetContext(node); 233 if (ctx == nullptr) { 234 continue; 235 } 236 237 std::vector<PartialTensorShape>& partial_shapes = shape_map[node->name()]; 238 if (ctx->num_outputs() <= 0) continue; 239 partial_shapes.resize(ctx->num_outputs()); 240 241 // Check all outputs. 242 for (const Edge* out_edge : node->out_edges()) { 243 if (out_edge->IsControlEdge()) continue; 244 245 const int output_idx = out_edge->src_output(); 246 TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(ctx->output(output_idx), ctx, 247 &partial_shapes[output_idx])); 248 } 249 250 // RewriteGraphForExecution() will add a Recv node for each input. Shape 251 // refiner does not include shape information of these Recv nodes. Therefore 252 // we add entries for Recv nodes here. 253 const auto pair = std::equal_range(input_names.begin(), input_names.end(), 254 TensorId{node->name(), 0}, compare); 255 for (auto it = pair.first; it != pair.second; ++it) { 256 const string recv_name = 257 strings::StrCat("_recv_", it->first, "_", it->second); 258 auto& recv_partial_shapes = shape_map[recv_name]; 259 // For whatever reason (for example, name collision) if the map entry was 260 // already there, then do nothing. 261 if (recv_partial_shapes.empty()) { 262 recv_partial_shapes.push_back(partial_shapes[it->second]); 263 } 264 } 265 } 266 267 subgraph::RewriteGraphMetadata unused_metadata; 268 TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( 269 &input_graph, context.input_names, context.output_names, {}, {}, 270 false /* use_function_convention */, &unused_metadata)); 271 272 ConstantFoldingOptions cf_opts; 273 cf_opts.shape_map = &shape_map; 274 275 // Exclude specified nodes from constant folding. 276 if (context.params.count("exclude_op") > 0) { 277 const auto& excluded_nodes = context.params.at("exclude_op"); 278 const std::set<string> excluded_nodes_set(excluded_nodes.begin(), 279 excluded_nodes.end()); 280 cf_opts.consider = [excluded_nodes_set](const Node* n) { 281 return excluded_nodes_set.find(n->op_def().name()) == 282 excluded_nodes_set.end(); 283 }; 284 } 285 286 // Constant folding. 287 bool was_mutated; 288 TF_RETURN_IF_ERROR(ConstantFold(cf_opts, nullptr, Env::Default(), nullptr, 289 &input_graph, &was_mutated)); 290 GraphDef folded_graph_def; 291 input_graph.ToGraphDef(&folded_graph_def); 292 GraphDef send_recvs_replaced; 293 TF_RETURN_IF_ERROR(ReplaceSendRecvs(input_graph_def, folded_graph_def, 294 context.input_names, context.output_names, 295 &send_recvs_replaced)); 296 TF_RETURN_IF_ERROR( 297 RemoveUnusedNodes(send_recvs_replaced, context, output_graph_def)); 298 return Status::OK(); 299 } 300 301 REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants); 302 303 } // namespace graph_transforms 304 } // namespace tensorflow 305