Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // This module implements a common subexpression elimination pass.  We
     17 // process the nodes in the graph in reverse postorder
     18 // (i.e. inputs before their downstream dependencies).  The rough algorithm is
     19 // as follows:
     20 //
     21 // std::unordered_map<size_t, Node*> available
     22 // for each node n in forward topological order:
     23 //   h = NodeHash(n)
     24 //   if available[h] exists and Equivalent(available[h], n)
     25 //     redirect downstream uses of outputs of n to available[h]
     26 //     remove n from graph
     27 //   else
     28 //     if available[h] does not exist
     29 //       available[h] = n
     30 //
     31 // This is similar to the global value numbering algorithm described in this
     32 // paper:
     33 //
     34 // "Global code motion/global value numbering", Cliff Click, PLDI '95
     35 // Proceedings of the ACM SIGPLAN 1995 conference on Programming
     36 // language design and implementation, Pages 246-257
     37 //      http://dl.acm.org/citation.cfm?id=207154
     38 
     39 #include "tensorflow/core/graph/optimizer_cse.h"
     40 
     41 #include <unordered_map>
     42 #include <utility>
     43 #include <vector>
     44 
     45 #include "tensorflow/core/framework/node_def.pb.h"
     46 #include "tensorflow/core/graph/algorithm.h"
     47 #include "tensorflow/core/lib/gtl/map_util.h"
     48 #include "tensorflow/core/lib/hash/hash.h"
     49 #include "tensorflow/core/platform/logging.h"
     50 
     51 namespace tensorflow {
     52 
     53 class OptimizerCSE {
     54  public:
     55   explicit OptimizerCSE(Graph* g) : g_(g) {}
     56 
     57   bool Optimize(const std::function<bool(const Node*)>& consider_fn);
     58 
     59  private:
     60   static size_t NodeHash(const Node* n);
     61   static bool Equivalent(const Node* a, const Node* b,
     62                          AttrSlice::Scratch* scratch);
     63 
     64   Graph* g_;
     65 };
     66 
     67 static void FillInputs(const Node* n,
     68                        gtl::InlinedVector<Node*, 4>* control_edges,
     69                        gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
     70   DCHECK_EQ(in->size(), n->num_inputs());
     71   control_edges->clear();
     72   for (const Edge* e : n->in_edges()) {
     73     if (e->IsControlEdge()) {
     74       control_edges->push_back(e->src());
     75     } else {
     76       (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
     77     }
     78   }
     79   std::sort(control_edges->begin(), control_edges->end());
     80   if (n->op_def().is_commutative()) {
     81     // For commutative inputs, we sort the input by the input Node*
     82     // to get a canonical ordering (so that add(a,b) and add(b, a) will
     83     // hash to the same value if is_commutative is true for 'add').
     84     std::sort(in->begin(), in->end());
     85   }
     86 }
     87 
     88 static size_t kIllegalNodeHash = 0;
     89 
     90 size_t OptimizerCSE::NodeHash(const Node* n) {
     91   const DataTypeVector& out = n->output_types();
     92   string str_to_hash = strings::StrCat(n->type_string(), out.size());
     93   for (DataType dt : out) {
     94     strings::StrAppend(&str_to_hash, dt);
     95   }
     96 
     97   const int N_in = n->num_inputs();
     98   strings::StrAppend(&str_to_hash, N_in);
     99   gtl::InlinedVector<Node*, 4> control_edges;
    100   gtl::InlinedVector<std::pair<Node*, int>, 4> in(N_in);
    101   FillInputs(n, &control_edges, &in);
    102   for (const auto& edge : in) {
    103     strings::StrAppend(&str_to_hash, edge.first->id(), edge.second);
    104   }
    105 
    106   size_t h = Hash64(str_to_hash);
    107 
    108 #if !defined(__ANDROID__)
    109   // Hash the attrs.  For example, this makes sure different constants
    110   // end up in different hash buckets.
    111   string tmp;
    112   for (const auto& attr : n->attrs()) {
    113     tmp = attr.first;
    114     attr.second.AppendToString(&tmp);
    115     // Add hashes of attrs, so the order of attrs doesn't matter.
    116     h += Hash32(tmp.data(), tmp.size(), 0x87341245);
    117   }
    118 #endif
    119 
    120   if (h == kIllegalNodeHash) h = kIllegalNodeHash + 1;
    121   return h;
    122 }
    123 
    124 static bool HasRefInput(const Node* n) {
    125   for (auto dt : n->input_types()) {
    126     if (IsRefType(dt)) return true;
    127   }
    128   return false;
    129 }
    130 
    131 bool OptimizerCSE::Equivalent(const Node* a, const Node* b,
    132                               AttrSlice::Scratch* scratch) {
    133   // Different op names are different
    134   if (a->type_string() != b->type_string()) return false;
    135 
    136   // Never consider stateful nodes (such as non-const inputs) equivalent.
    137   if (a->op_def().is_stateful()) return false;
    138 
    139   // For now, we consider any node that takes a ref input to not be
    140   // equivalent to any other node.
    141   if (HasRefInput(a) || HasRefInput(b)) return false;
    142 
    143   // Compare attrs.  Note that equal attrs implies equal input and
    144   // output types.
    145   if (!a->attrs().EqualAttrs(b->attrs(), scratch)) return false;
    146 
    147   // Compare input sources
    148   if (a->num_inputs() != b->num_inputs()) return false;
    149   const int N_in = a->num_inputs();
    150   gtl::InlinedVector<Node*, 4> a_control_edges;
    151   gtl::InlinedVector<Node*, 4> b_control_edges;
    152   gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
    153   gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in);
    154   FillInputs(a, &a_control_edges, &a_in);
    155   FillInputs(b, &b_control_edges, &b_in);
    156   if (a_in != b_in) return false;
    157   if (a_control_edges != b_control_edges) return false;
    158 
    159   return true;
    160 }
    161 
    162 bool OptimizerCSE::Optimize(
    163     const std::function<bool(const Node*)>& consider_fn) {
    164   // This very simple implementation works if the whole graph is one
    165   // giant basic block (because we just traverse nodes in a
    166   // topological order). This simple implementation works well
    167   // with control flow/loops/etc. But we need to be careful about
    168   // control flow if we want to add more sophisticated CSE optimizations.
    169 
    170   // TODO(jeff): We need to handle Update nodes specially, but dealing
    171   // with more general control flow will also solve this issue, and for
    172   // now, our updates are almost always the most downstream nodes in
    173   // the graph.
    174   std::vector<Node*> order;
    175   GetReversePostOrder(*g_, &order);
    176 
    177   // Our value is just a single Node*, meaning we keep just a single
    178   // candidate for a given node hash value.  This may cause us to
    179   // (rarely) lose some optimization opportunities if there are
    180   // hash collisions, but it allows us to avoid having the value
    181   // be a set<Node*> (or equivalent).
    182   std::unordered_map<size_t, Node*> available;
    183 
    184   // Scratch space for Equivalent calls.  Allocated here and passed in to
    185   // Equivalent to avoid allocation inside the loop below.
    186   bool changed = false;
    187   AttrSlice::Scratch scratch;
    188   for (Node* n : order) {
    189     if (!n->IsOp()) continue;
    190 
    191     // Don't prune placeholder nodes.
    192     if (n->type_string() == "Placeholder" ||
    193         n->type_string() == "PlaceholderV2" ||
    194         n->type_string() == "PlaceholderWithDefault") {
    195       continue;
    196     }
    197 
    198     // See if we should consider this node at all
    199     if (consider_fn != nullptr && !consider_fn(n)) continue;
    200 
    201     size_t h = NodeHash(n);
    202     Node** candidate = &available[h];
    203     if (*candidate == nullptr) {
    204       // No existing match: insert "n" into the hash table under "h"
    205       *candidate = n;
    206     } else if (Equivalent(*candidate, n, &scratch)) {
    207       VLOG(1) << "CSE: equivalent: " << (*candidate)->name() << " and "
    208               << n->name();
    209       // *candidate and n are equivalent.  Therefore, we can replace
    210       // n with *candidate by fixing up outgoing edges from "n" to instead
    211       // come from "*candidate", and then delete n from the graph
    212       for (const Edge* e : n->out_edges()) {
    213         g_->AddEdge(*candidate, e->src_output(), e->dst(), e->dst_input());
    214       }
    215 
    216       g_->RemoveNode(n);
    217       changed = true;
    218     }
    219   }
    220   return changed;
    221 }
    222 
    223 bool OptimizeCSE(Graph* g,
    224                  const std::function<bool(const Node*)>& consider_fn) {
    225   OptimizerCSE opt(g);
    226   return opt.Optimize(consider_fn);
    227 }
    228 
    229 }  // namespace tensorflow
    230