Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/util/equal_graph_def.h"
     17 
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
     28 
     29 namespace tensorflow {
     30 
     31 bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
     32                    string* diff, const EqualGraphDefOptions& options) {
     33   // Intentionally do not check that versions match so that this routine can
     34   // be used for less brittle golden file tests.
     35   return EqualRepeatedNodeDef(actual.node(), expected.node(), diff, options);
     36 }
     37 
     38 uint64 GraphDefHash(const GraphDef& gdef, const EqualGraphDefOptions& options) {
     39   return RepeatedNodeDefHash(gdef.node(), options);
     40 }
     41 
     42 bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
     43                           const protobuf::RepeatedPtrField<NodeDef>& expected,
     44                           string* diff, const EqualGraphDefOptions& options) {
     45   std::unordered_map<string, const NodeDef*> actual_index;
     46   for (const NodeDef& node : actual) {
     47     actual_index[node.name()] = &node;
     48   }
     49 
     50   for (const NodeDef& expected_node : expected) {
     51     auto actual_iter = actual_index.find(expected_node.name());
     52     if (actual_iter == actual_index.end()) {
     53       if (diff != nullptr) {
     54         *diff = strings::StrCat("Did not find expected node '",
     55                                 SummarizeNodeDef(expected_node), "'");
     56       }
     57       return false;
     58     }
     59 
     60     if (!EqualNodeDef(*actual_iter->second, expected_node, diff, options)) {
     61       return false;
     62     }
     63 
     64     actual_index.erase(actual_iter);
     65   }
     66 
     67   if (!actual_index.empty()) {
     68     if (diff != nullptr) {
     69       *diff =
     70           strings::StrCat("Found unexpected node '",
     71                           SummarizeNodeDef(*actual_index.begin()->second), "'");
     72     }
     73     return false;
     74   }
     75 
     76   return true;
     77 }
     78 
     79 uint64 RepeatedNodeDefHash(const protobuf::RepeatedPtrField<NodeDef>& ndefs,
     80                            const EqualGraphDefOptions& options) {
     81   uint64 h = 0xDECAFCAFFE;
     82   // Insert NodeDefs into map to deterministically sort by name
     83   std::map<string, const NodeDef*> nodes;
     84   for (const NodeDef& node : ndefs) {
     85     nodes[node.name()] = &node;
     86   }
     87   for (const auto& pair : nodes) {
     88     h = Hash64(pair.first.data(), pair.first.size(), h);
     89     h = Hash64Combine(NodeDefHash(*pair.second, options), h);
     90   }
     91   return h;
     92 }
     93 
     94 namespace {
     95 
     96 string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {
     97   string ret;
     98   for (int i = 0; i < f.size(); ++i) {
     99     if (i > 0) strings::StrAppend(&ret, ", ");
    100     strings::StrAppend(&ret, f.Get(i));
    101   }
    102   return ret;
    103 }
    104 
    105 }  // namespace
    106 
    107 bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
    108                   const EqualGraphDefOptions& options) {
    109   if (actual.name() != expected.name()) {
    110     if (diff != nullptr) {
    111       *diff = strings::StrCat("Actual node name '", actual.name(),
    112                               "' is not expected '", expected.name(), "'");
    113     }
    114     return false;
    115   }
    116 
    117   if (actual.op() != expected.op()) {
    118     if (diff != nullptr) {
    119       *diff = strings::StrCat("Node named '", actual.name(), "' has op '",
    120                               actual.op(), "' that is not expected '",
    121                               expected.op(), "'");
    122     }
    123     return false;
    124   }
    125 
    126   if (actual.device() != expected.device()) {
    127     if (diff != nullptr) {
    128       *diff = strings::StrCat("Node named '", actual.name(), "' has device '",
    129                               actual.device(), "' that is not expected '",
    130                               expected.device(), "'");
    131     }
    132     return false;
    133   }
    134 
    135   if (actual.input_size() != expected.input_size()) {
    136     if (diff != nullptr) {
    137       *diff = strings::StrCat("Node named '", actual.name(), "' has inputs '",
    138                               JoinStringField(actual.input()),
    139                               "' that don't match expected '",
    140                               JoinStringField(expected.input()), "'");
    141     }
    142     return false;
    143   }
    144 
    145   int first_control_input = actual.input_size();
    146   for (int i = 0; i < actual.input_size(); ++i) {
    147     if (StringPiece(actual.input(i)).starts_with("^")) {
    148       first_control_input = i;
    149       break;
    150     }
    151     // Special case for inputs: "tensor" is equivalent to "tensor:0"
    152     if (actual.input(i) != expected.input(i) &&
    153         actual.input(i) != strings::StrCat(expected.input(i), ":0") &&
    154         strings::StrCat(actual.input(i), ":0") != expected.input(i)) {
    155       if (diff != nullptr) {
    156         *diff = strings::StrCat("Node named '", actual.name(), "' has input ",
    157                                 i, " '", actual.input(i),
    158                                 "' that doesn't match expected '",
    159                                 expected.input(i), "'");
    160       }
    161       return false;
    162     }
    163   }
    164 
    165   std::unordered_set<string> actual_control;
    166   std::unordered_set<string> expected_control;
    167   for (int i = first_control_input; i < actual.input_size(); ++i) {
    168     actual_control.insert(actual.input(i));
    169     expected_control.insert(expected.input(i));
    170   }
    171   for (const auto& e : expected_control) {
    172     if (actual_control.erase(e) == 0) {
    173       if (diff != nullptr) {
    174         *diff = strings::StrCat("Node named '", actual.name(),
    175                                 "' missing expected control input '", e, "'");
    176       }
    177       return false;
    178     }
    179   }
    180   if (!actual_control.empty()) {
    181     if (diff != nullptr) {
    182       *diff = strings::StrCat("Node named '", actual.name(),
    183                               "' has unexpected control input '",
    184                               *actual_control.begin(), "'");
    185     }
    186     return false;
    187   }
    188 
    189   std::unordered_set<string> actual_attr;
    190   for (const auto& a : actual.attr()) {
    191     if (options.ignore_internal_attrs && !a.first.empty() &&
    192         a.first[0] == '_') {
    193       continue;
    194     }
    195     actual_attr.insert(a.first);
    196   }
    197   for (const auto& e : expected.attr()) {
    198     if (options.ignore_internal_attrs && !e.first.empty() &&
    199         e.first[0] == '_') {
    200       continue;
    201     }
    202 
    203     if (actual_attr.erase(e.first) == 0) {
    204       if (diff != nullptr) {
    205         *diff = strings::StrCat("Node named '", actual.name(),
    206                                 "' missing expected attr '", e.first,
    207                                 "' with value: ", SummarizeAttrValue(e.second));
    208       }
    209       return false;
    210     }
    211     auto iter = actual.attr().find(e.first);
    212     if (!AreAttrValuesEqual(e.second, iter->second)) {
    213       if (diff != nullptr) {
    214         *diff = strings::StrCat(
    215             "Node named '", actual.name(), "' has attr '", e.first,
    216             "' with value: ", SummarizeAttrValue(iter->second),
    217             " that does not match expected: ", SummarizeAttrValue(e.second));
    218       }
    219       return false;
    220     }
    221   }
    222   if (!actual_attr.empty()) {
    223     if (diff != nullptr) {
    224       *diff = strings::StrCat(
    225           "Node named '", actual.name(), "' has unexpected attr '",
    226           *actual_attr.begin(), "' with value: ",
    227           SummarizeAttrValue(actual.attr().find(*actual_attr.begin())->second));
    228     }
    229     return false;
    230   }
    231 
    232   return true;
    233 }
    234 
    235 uint64 NodeDefHash(const NodeDef& ndef, const EqualGraphDefOptions& options) {
    236   uint64 h = Hash64(ndef.name());
    237   h = Hash64(ndef.op().data(), ndef.op().size(), h);
    238   h = Hash64(ndef.device().data(), ndef.device().size(), h);
    239 
    240   // Normal inputs. Order important.
    241   int first_control_input = ndef.input_size();
    242   for (int i = 0; i < ndef.input_size(); ++i) {
    243     if (StringPiece(ndef.input(i)).starts_with("^")) {
    244       first_control_input = i;
    245       break;
    246     }
    247     h = Hash64(ndef.input(i).data(), ndef.input(i).size(), h);
    248   }
    249 
    250   // Control inputs. Order irrelevant.
    251   std::set<string> ndef_control;
    252   for (int i = first_control_input; i < ndef.input_size(); ++i) {
    253     ndef_control.insert(ndef.input(i));
    254   }
    255   for (const string& s : ndef_control) {
    256     h = Hash64(s.data(), s.size(), h);
    257   }
    258 
    259   // Attributes
    260   std::map<string, AttrValue> ndef_attr;
    261   for (const auto& a : ndef.attr()) {
    262     if (options.ignore_internal_attrs && !a.first.empty() &&
    263         a.first[0] == '_') {
    264       continue;
    265     }
    266     ndef_attr[a.first] = a.second;
    267   }
    268   for (const auto& a : ndef_attr) {
    269     h = Hash64(a.first.data(), a.first.size(), h);
    270     h = Hash64Combine(AttrValueHash(a.second), h);
    271   }
    272 
    273   return h;
    274 }
    275 
    276 }  // namespace tensorflow
    277