Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      7     http://www.apache.org/licenses/LICENSE-2.0
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
#include "tensorflow/core/util/equal_graph_def.h"

#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
     29 namespace tensorflow {
     31 bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
     32                    string* diff, const EqualGraphDefOptions& options) {
     33   // Intentionally do not check that versions match so that this routine can
     34   // be used for less brittle golden file tests.
     35   return EqualRepeatedNodeDef(actual.node(), expected.node(), diff, options);
     36 }
     38 uint64 GraphDefHash(const GraphDef& gdef, const EqualGraphDefOptions& options) {
     39   return RepeatedNodeDefHash(gdef.node(), options);
     40 }
     42 bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
     43                           const protobuf::RepeatedPtrField<NodeDef>& expected,
     44                           string* diff, const EqualGraphDefOptions& options) {
     45   std::unordered_map<string, const NodeDef*> actual_index;
     46   for (const NodeDef& node : actual) {
     47     actual_index[node.name()] = &node;
     48   }
     50   for (const NodeDef& expected_node : expected) {
     51     auto actual_iter = actual_index.find(expected_node.name());
     52     if (actual_iter == actual_index.end()) {
     53       if (diff != nullptr) {
     54         *diff = strings::StrCat("Did not find expected node '",
     55                                 SummarizeNodeDef(expected_node), "'");
     56       }
     57       return false;
     58     }
     60     if (!EqualNodeDef(*actual_iter->second, expected_node, diff, options)) {
     61       return false;
     62     }
     64     actual_index.erase(actual_iter);
     65   }
     67   if (!actual_index.empty()) {
     68     if (diff != nullptr) {
     69       *diff =
     70           strings::StrCat("Found unexpected node '",
     71                           SummarizeNodeDef(*actual_index.begin()->second), "'");
     72     }
     73     return false;
     74   }
     76   return true;
     77 }
     79 uint64 RepeatedNodeDefHash(const protobuf::RepeatedPtrField<NodeDef>& ndefs,
     80                            const EqualGraphDefOptions& options) {
     81   uint64 h = 0xDECAFCAFFE;
     82   // Insert NodeDefs into map to deterministically sort by name
     83   std::map<string, const NodeDef*> nodes;
     84   for (const NodeDef& node : ndefs) {
     85     nodes[node.name()] = &node;
     86   }
     87   for (const auto& pair : nodes) {
     88     h = Hash64(pair.first.data(), pair.first.size(), h);
     89     h = Hash64Combine(NodeDefHash(*pair.second, options), h);
     90   }
     91   return h;
     92 }
     94 namespace {
     96 string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {
     97   string ret;
     98   for (int i = 0; i < f.size(); ++i) {
     99     if (i > 0) strings::StrAppend(&ret, ", ");
    100     strings::StrAppend(&ret, f.Get(i));
    101   }
    102   return ret;
    103 }
    105 }  // namespace
    107 bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
    108                   const EqualGraphDefOptions& options) {
    109   if (actual.name() != expected.name()) {
    110     if (diff != nullptr) {
    111       *diff = strings::StrCat("Actual node name '", actual.name(),
    112                               "' is not expected '", expected.name(), "'");
    113     }
    114     return false;
    115   }
    117   if (actual.op() != expected.op()) {
    118     if (diff != nullptr) {
    119       *diff = strings::StrCat("Node named '", actual.name(), "' has op '",
    120                               actual.op(), "' that is not expected '",
    121                               expected.op(), "'");
    122     }
    123     return false;
    124   }
    126   if (actual.device() != expected.device()) {
    127     if (diff != nullptr) {
    128       *diff = strings::StrCat("Node named '", actual.name(), "' has device '",
    129                               actual.device(), "' that is not expected '",
    130                               expected.device(), "'");
    131     }
    132     return false;
    133   }
    135   if (actual.input_size() != expected.input_size()) {
    136     if (diff != nullptr) {
    137       *diff = strings::StrCat("Node named '", actual.name(), "' has inputs '",
    138                               JoinStringField(actual.input()),
    139                               "' that don't match expected '",
    140                               JoinStringField(expected.input()), "'");
    141     }
    142     return false;
    143   }
    145   int first_control_input = actual.input_size();
    146   for (int i = 0; i < actual.input_size(); ++i) {
    147     if (StringPiece(actual.input(i)).starts_with("^")) {
    148       first_control_input = i;
    149       break;
    150     }
    151     // Special case for inputs: "tensor" is equivalent to "tensor:0"
    152     if (actual.input(i) != expected.input(i) &&
    153         actual.input(i) != strings::StrCat(expected.input(i), ":0") &&
    154         strings::StrCat(actual.input(i), ":0") != expected.input(i)) {
    155       if (diff != nullptr) {
    156         *diff = strings::StrCat("Node named '", actual.name(), "' has input ",
    157                                 i, " '", actual.input(i),
    158                                 "' that doesn't match expected '",
    159                                 expected.input(i), "'");
    160       }
    161       return false;
    162     }
    163   }
    165   std::unordered_set<string> actual_control;
    166   std::unordered_set<string> expected_control;
    167   for (int i = first_control_input; i < actual.input_size(); ++i) {
    168     actual_control.insert(actual.input(i));
    169     expected_control.insert(expected.input(i));
    170   }
    171   for (const auto& e : expected_control) {
    172     if (actual_control.erase(e) == 0) {
    173       if (diff != nullptr) {
    174         *diff = strings::StrCat("Node named '", actual.name(),
    175                                 "' missing expected control input '", e, "'");
    176       }
    177       return false;
    178     }
    179   }
    180   if (!actual_control.empty()) {
    181     if (diff != nullptr) {
    182       *diff = strings::StrCat("Node named '", actual.name(),
    183                               "' has unexpected control input '",
    184                               *actual_control.begin(), "'");
    185     }
    186     return false;
    187   }
    189   std::unordered_set<string> actual_attr;
    190   for (const auto& a : actual.attr()) {
    191     if (options.ignore_internal_attrs && !a.first.empty() &&
    192         a.first[0] == '_') {
    193       continue;
    194     }
    195     actual_attr.insert(a.first);
    196   }
    197   for (const auto& e : expected.attr()) {
    198     if (options.ignore_internal_attrs && !e.first.empty() &&
    199         e.first[0] == '_') {
    200       continue;
    201     }
    203     if (actual_attr.erase(e.first) == 0) {
    204       if (diff != nullptr) {
    205         *diff = strings::StrCat("Node named '", actual.name(),
    206                                 "' missing expected attr '", e.first,
    207                                 "' with value: ", SummarizeAttrValue(e.second));
    208       }
    209       return false;
    210     }
    211     auto iter = actual.attr().find(e.first);
    212     if (!AreAttrValuesEqual(e.second, iter->second)) {
    213       if (diff != nullptr) {
    214         *diff = strings::StrCat(
    215             "Node named '", actual.name(), "' has attr '", e.first,
    216             "' with value: ", SummarizeAttrValue(iter->second),
    217             " that does not match expected: ", SummarizeAttrValue(e.second));
    218       }
    219       return false;
    220     }
    221   }
    222   if (!actual_attr.empty()) {
    223     if (diff != nullptr) {
    224       *diff = strings::StrCat(
    225           "Node named '", actual.name(), "' has unexpected attr '",
    226           *actual_attr.begin(), "' with value: ",
    227           SummarizeAttrValue(actual.attr().find(*actual_attr.begin())->second));
    228     }
    229     return false;
    230   }
    232   return true;
    233 }
    235 uint64 NodeDefHash(const NodeDef& ndef, const EqualGraphDefOptions& options) {
    236   uint64 h = Hash64(ndef.name());
    237   h = Hash64(ndef.op().data(), ndef.op().size(), h);
    238   h = Hash64(ndef.device().data(), ndef.device().size(), h);
    240   // Normal inputs. Order important.
    241   int first_control_input = ndef.input_size();
    242   for (int i = 0; i < ndef.input_size(); ++i) {
    243     if (StringPiece(ndef.input(i)).starts_with("^")) {
    244       first_control_input = i;
    245       break;
    246     }
    247     h = Hash64(ndef.input(i).data(), ndef.input(i).size(), h);
    248   }
    250   // Control inputs. Order irrelevant.
    251   std::set<string> ndef_control;
    252   for (int i = first_control_input; i < ndef.input_size(); ++i) {
    253     ndef_control.insert(ndef.input(i));
    254   }
    255   for (const string& s : ndef_control) {
    256     h = Hash64(s.data(), s.size(), h);
    257   }
    259   // Attributes
    260   std::map<string, AttrValue> ndef_attr;
    261   for (const auto& a : ndef.attr()) {
    262     if (options.ignore_internal_attrs && !a.first.empty() &&
    263         a.first[0] == '_') {
    264       continue;
    265     }
    266     ndef_attr[a.first] = a.second;
    267   }
    268   for (const auto& a : ndef_attr) {
    269     h = Hash64(a.first.data(), a.first.size(), h);
    270     h = Hash64Combine(AttrValueHash(a.second), h);
    271   }
    273   return h;
    274 }
    276 }  // namespace tensorflow