/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"

#include <algorithm>
#include <iterator>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {
namespace {
using StringPieceSet = std::unordered_set<StringPiece, StringPieceHasher>;
template <typename T>
using StringPieceMap = std::unordered_map<StringPiece, T, StringPieceHasher>;
}  // namespace

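// Rewrites the graph that RewriteGraphForExecution() produced: the _Recv
// nodes it added for the requested inputs are replaced with references to the
// original input tensors, and any input nodes it pruned away are restored
// from original_graph_def.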
Status ReplaceSendRecvs(const GraphDef& original_graph_def,
                        const GraphDef& rewritten_graph_def,
                        const std::vector<string>& inputs,
                        const std::vector<string>& outputs,
                        GraphDef* output_graph_def) {
  // recv_node_names serves as string storage for the Recv node names.
  std::vector<string> recv_node_names(inputs.size());
  StringPieceMap<TensorId> recv_node_map;
  StringPieceSet input_nodes;
  for (int i = 0; i < inputs.size(); ++i) {
    // RewriteGraphForExecution adds a Recv node for each input edge. We assume
    // here that adding the Recv node did not fail; for example, that the
    // original graph did not already have a node with the name chosen for the
    // newly added Recv node.
    TensorId id = ParseTensorName(inputs[i]);
    input_nodes.insert(id.first);
    string& recv_node_name = recv_node_names[i];
    recv_node_name = strings::StrCat("_recv_", id.first, "_", id.second);
    recv_node_map.emplace(recv_node_name, id);
  }

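  // Index the original graph's nodes by name, so that input nodes pruned by
  // RewriteGraphForExecution() can be copied back into the output below.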
  StringPieceMap<const NodeDef*> original_map;
  for (const NodeDef& node : original_graph_def.node()) {
    original_map.emplace(node.name(), &node);
  }

  for (const NodeDef& node : rewritten_graph_def.node()) {
    if ((node.op() == "_Send") || (node.op() == "_Recv")) {
      // If the op is a Send or Recv that wasn't in the original, skip it.
      if (original_map.count(node.name()) == 0) {
        continue;
      }
    }

    NodeDef* new_node = output_graph_def->add_node();
    new_node->MergeFrom(node);
    for (int i = 0; i < new_node->input_size(); ++i) {
      string& input = *new_node->mutable_input(i);
      TensorId id = ParseTensorName(input);
      const auto iter = recv_node_map.find(id.first);
      if (iter != recv_node_map.end()) {
        // The node being substituted is a Recv node, and it has only one
        // output. If this input is not a control input, then replace the input
        // with the mapped value. Otherwise, replace the node name only.
        if (id.second != Graph::kControlSlot) {
          CHECK_EQ(id.second, 0);
          input = iter->second.ToString();
        } else {
          id.first = iter->second.first;
          input = id.ToString();
        }
      }
    }

    // RewriteGraphForExecution() did not remove this input node, so remove its
    // name from input_nodes to avoid adding a duplicate to output_graph_def.
    auto iter = input_nodes.find(new_node->name());
    if (iter != input_nodes.end()) {
      input_nodes.erase(iter);
    }
  }

  // Some input nodes were removed in rewritten_graph_def. Add those nodes back
  // to output_graph_def.
  for (StringPiece name : input_nodes) {
    const NodeDef& removed_node = *CHECK_NOTNULL(original_map[name]);
    output_graph_def->add_node()->MergeFrom(removed_node);
  }

  return Status::OK();
}

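// Removes any nodes that are not needed to compute the requested outputs,
// while always keeping the nodes named by the requested inputs.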
Status RemoveUnusedNodes(const GraphDef& input_graph_def,
                         const TransformFuncContext& context,
                         GraphDef* output_graph_def) {
  StringPieceMap<const NodeDef*> node_map;
  for (const NodeDef& node : input_graph_def.node()) {
    node_map.emplace(node.name(), &node);
  }

  std::unordered_set<TensorId, TensorId::Hasher> input_names;
  for (const string& input : context.input_names) {
    input_names.insert(ParseTensorName(input));
  }
  StringPieceSet used_nodes;
  StringPieceSet current_nodes;
  for (const string& name : context.output_names) {
    TensorId id = ParseTensorName(name);
    used_nodes.insert(id.first);
    current_nodes.insert(id.first);
  }
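  // Breadth-first traversal backwards from the output nodes, marking every
  // node that feeds them as used. The walk stops at the requested inputs, so
  // nodes that are only needed to compute those inputs are not retained.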
  while (!current_nodes.empty()) {
    StringPieceSet next_nodes;
    for (StringPiece node_name : current_nodes) {
      if (node_map.count(node_name) == 0) {
        LOG(ERROR) << "Bad graph structure, no node named '" << node_name
                   << "' found for input lookup";
        return errors::InvalidArgument("Bad graph structure, no node named '",
                                       node_name, "' found for input lookup");
      }
      const NodeDef& node = *(node_map[node_name]);
      for (const string& input : node.input()) {
        TensorId id = ParseTensorName(input);
        if (input_names.count(id) > 0) {
          continue;
        }
        if (used_nodes.insert(id.first).second) {
          next_nodes.insert(id.first);
        }
      }
    }
    current_nodes.swap(next_nodes);
  }
  for (const TensorId& id : input_names) {
    used_nodes.insert(id.first);
  }
  FilterGraphDef(
      input_graph_def,
      [&](const NodeDef& node) { return used_nodes.count(node.name()) > 0; },
      output_graph_def);

  return Status::OK();
}

// Converts a shape inference handle to a PartialTensorShape.
Status ShapeHandleToTensorShape(const shape_inference::ShapeHandle& handle,
                                shape_inference::InferenceContext* context,
                                PartialTensorShape* shape) {
  // The default-constructed PartialTensorShape is already unknown, so there is
  // nothing to fill in when the rank is unknown.
  if (!context->RankKnown(handle)) return Status::OK();

  std::vector<int64> dims(context->Rank(handle));
  for (int32 i = 0; i < dims.size(); ++i) {
    dims[i] = context->Value(context->Dim(handle, i));
  }
  return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
}

// Converts any sub-graphs that can be resolved into constant expressions into
// single Const ops.
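// Recognized parameters (read from the TransformFuncContext below):
//   clear_output_shapes: if true (the default), stale _output_shapes
//     attributes are stripped before the graph is imported.
//   exclude_op: op names that should never be folded into constants.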
Status FoldConstants(const GraphDef& input_graph_def,
                     const TransformFuncContext& context,
                     GraphDef* output_graph_def) {
  Graph input_graph(OpRegistry::Global());
  TF_RETURN_IF_ERROR(input_graph.AddFunctionLibrary(input_graph_def.library()));

  ShapeRefiner shape_refiner(input_graph.versions(), input_graph.op_registry());
  shape_refiner.set_require_shape_inference_fns(false);
  shape_refiner.set_disable_constant_propagation(false);
  shape_refiner.set_function_library_for_shape_inference(
      &input_graph.flib_def());

  bool clear_output_shapes;
  TF_RETURN_IF_ERROR(context.GetOneBoolParameter("clear_output_shapes", true,
                                                 &clear_output_shapes));
  if (clear_output_shapes) {
    // Some older GraphDefs have saved _output_shapes attributes that are out
    // of date and cause import errors, so clean them up first.
    GraphDef cleaned_graph_def;
    RemoveAttributes(input_graph_def, {"_output_shapes"}, &cleaned_graph_def);

    TF_RETURN_IF_ERROR(
        ImportGraphDef({}, cleaned_graph_def, &input_graph, &shape_refiner));
  } else {
    TF_RETURN_IF_ERROR(
        ImportGraphDef({}, input_graph_def, &input_graph, &shape_refiner));
  }

  // Sorted array of input names, used as a lookup table.
  std::vector<TensorId> input_names;
  input_names.reserve(context.input_names.size());
  std::transform(context.input_names.begin(), context.input_names.end(),
                 std::back_inserter(input_names),
                 [](const string& name) { return ParseTensorName(name); });

  const auto compare = [](TensorId lhs, TensorId rhs) {
    return lhs.first < rhs.first;
  };

  std::sort(input_names.begin(), input_names.end(), compare);

  // Set statically inferred shapes.
  std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
  for (const Node* const node : input_graph.nodes()) {
    auto ctx = shape_refiner.GetContext(node);
    if (ctx == nullptr) {
      continue;
    }

    std::vector<PartialTensorShape>& partial_shapes = shape_map[node->name()];
    if (ctx->num_outputs() <= 0) continue;
    partial_shapes.resize(ctx->num_outputs());

    // Check all outputs.
    for (const Edge* out_edge : node->out_edges()) {
      if (out_edge->IsControlEdge()) continue;

      const int output_idx = out_edge->src_output();
      TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(ctx->output(output_idx), ctx,
                                                  &partial_shapes[output_idx]));
    }

    // RewriteGraphForExecution() will add a Recv node for each input. The
    // shape refiner does not include shape information for these Recv nodes,
    // so we add entries for them here.
    const auto pair = std::equal_range(input_names.begin(), input_names.end(),
                                       TensorId{node->name(), 0}, compare);
    for (auto it = pair.first; it != pair.second; ++it) {
      const string recv_name =
          strings::StrCat("_recv_", it->first, "_", it->second);
      auto& recv_partial_shapes = shape_map[recv_name];
      // If the map entry was already present for whatever reason (for example,
      // a name collision), do nothing.
      if (recv_partial_shapes.empty()) {
        recv_partial_shapes.push_back(partial_shapes[it->second]);
      }
    }
  }

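  // Prune the graph down to just the requested inputs and outputs.
  // RewriteGraphForExecution() replaces each fed input with a _Recv node and
  // adds a _Send node for each fetched output; ReplaceSendRecvs() undoes this
  // once constant folding is done.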
  subgraph::RewriteGraphMetadata unused_metadata;
  TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
      &input_graph, context.input_names, context.output_names, {}, {},
      false /* use_function_convention */, &unused_metadata));

  ConstantFoldingOptions cf_opts;
  cf_opts.shape_map = &shape_map;

  // Exclude specified nodes from constant folding.
  if (context.params.count("exclude_op") > 0) {
    const auto& excluded_nodes = context.params.at("exclude_op");
    const std::set<string> excluded_nodes_set(excluded_nodes.begin(),
                                              excluded_nodes.end());
    cf_opts.consider = [excluded_nodes_set](const Node* n) {
      return excluded_nodes_set.find(n->op_def().name()) ==
             excluded_nodes_set.end();
    };
  }

  // Constant folding.
  bool was_mutated;
  TF_RETURN_IF_ERROR(ConstantFold(cf_opts, nullptr, Env::Default(), nullptr,
                                  &input_graph, &was_mutated));
  GraphDef folded_graph_def;
  input_graph.ToGraphDef(&folded_graph_def);
  GraphDef send_recvs_replaced;
  TF_RETURN_IF_ERROR(ReplaceSendRecvs(input_graph_def, folded_graph_def,
                                      context.input_names, context.output_names,
                                      &send_recvs_replaced));
  TF_RETURN_IF_ERROR(
      RemoveUnusedNodes(send_recvs_replaced, context, output_graph_def));
  return Status::OK();
}

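// A rough sketch of how this transform is typically invoked through the
// command-line transform_graph tool (flag values are illustrative; see the
// graph_transforms README for the exact flags in your TensorFlow version):
//
//   bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
//     --in_graph=frozen_graph.pb --out_graph=folded_graph.pb \
//     --inputs='input' --outputs='output' \
//     --transforms='fold_constants(clear_output_shapes=true)'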
REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants);

}  // namespace graph_transforms
}  // namespace tensorflow