1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/util/equal_graph_def.h" 17 18 #include <unordered_map> 19 #include <unordered_set> 20 #include "tensorflow/core/framework/attr_value.pb.h" 21 #include "tensorflow/core/framework/attr_value_util.h" 22 #include "tensorflow/core/framework/graph.pb.h" 23 #include "tensorflow/core/framework/node_def.pb.h" 24 #include "tensorflow/core/framework/node_def_util.h" 25 #include "tensorflow/core/lib/hash/hash.h" 26 #include "tensorflow/core/lib/strings/strcat.h" 27 #include "tensorflow/core/platform/protobuf.h" 28 29 namespace tensorflow { 30 31 bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, 32 string* diff, const EqualGraphDefOptions& options) { 33 // Intentionally do not check that versions match so that this routine can 34 // be used for less brittle golden file tests. 35 return EqualRepeatedNodeDef(actual.node(), expected.node(), diff, options); 36 } 37 38 uint64 GraphDefHash(const GraphDef& gdef, const EqualGraphDefOptions& options) { 39 return RepeatedNodeDefHash(gdef.node(), options); 40 } 41 42 bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual, 43 const protobuf::RepeatedPtrField<NodeDef>& expected, 44 string* diff, const EqualGraphDefOptions& options) { 45 std::unordered_map<string, const NodeDef*> actual_index; 46 for (const NodeDef& node : actual) { 47 actual_index[node.name()] = &node; 48 } 49 50 for (const NodeDef& expected_node : expected) { 51 auto actual_iter = actual_index.find(expected_node.name()); 52 if (actual_iter == actual_index.end()) { 53 if (diff != nullptr) { 54 *diff = strings::StrCat("Did not find expected node '", 55 SummarizeNodeDef(expected_node), "'"); 56 } 57 return false; 58 } 59 60 if (!EqualNodeDef(*actual_iter->second, expected_node, diff, options)) { 61 return false; 62 } 63 64 actual_index.erase(actual_iter); 65 } 66 67 if (!actual_index.empty()) { 68 if (diff != nullptr) { 69 *diff = 70 strings::StrCat("Found unexpected node '", 71 SummarizeNodeDef(*actual_index.begin()->second), "'"); 72 } 73 return false; 74 } 75 76 return true; 77 } 78 79 uint64 RepeatedNodeDefHash(const protobuf::RepeatedPtrField<NodeDef>& ndefs, 80 const EqualGraphDefOptions& options) { 81 uint64 h = 0xDECAFCAFFE; 82 // Insert NodeDefs into map to deterministically sort by name 83 std::map<string, const NodeDef*> nodes; 84 for (const NodeDef& node : ndefs) { 85 nodes[node.name()] = &node; 86 } 87 for (const auto& pair : nodes) { 88 h = Hash64(pair.first.data(), pair.first.size(), h); 89 h = Hash64Combine(NodeDefHash(*pair.second, options), h); 90 } 91 return h; 92 } 93 94 namespace { 95 96 string JoinStringField(const protobuf::RepeatedPtrField<string>& f) { 97 string ret; 98 for (int i = 0; i < f.size(); ++i) { 99 if (i > 0) strings::StrAppend(&ret, ", "); 100 strings::StrAppend(&ret, f.Get(i)); 101 } 102 return ret; 103 } 104 105 } // namespace 106 107 bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff, 108 const EqualGraphDefOptions& options) { 109 if (actual.name() != expected.name()) { 110 if (diff != nullptr) { 111 *diff = strings::StrCat("Actual node name '", actual.name(), 112 "' is not expected '", expected.name(), "'"); 113 } 114 return false; 115 } 116 117 if (actual.op() != expected.op()) { 118 if (diff != nullptr) { 119 *diff = strings::StrCat("Node named '", actual.name(), "' has op '", 120 actual.op(), "' that is not expected '", 121 expected.op(), "'"); 122 } 123 return false; 124 } 125 126 if (actual.device() != expected.device()) { 127 if (diff != nullptr) { 128 *diff = strings::StrCat("Node named '", actual.name(), "' has device '", 129 actual.device(), "' that is not expected '", 130 expected.device(), "'"); 131 } 132 return false; 133 } 134 135 if (actual.input_size() != expected.input_size()) { 136 if (diff != nullptr) { 137 *diff = strings::StrCat("Node named '", actual.name(), "' has inputs '", 138 JoinStringField(actual.input()), 139 "' that don't match expected '", 140 JoinStringField(expected.input()), "'"); 141 } 142 return false; 143 } 144 145 int first_control_input = actual.input_size(); 146 for (int i = 0; i < actual.input_size(); ++i) { 147 if (StringPiece(actual.input(i)).starts_with("^")) { 148 first_control_input = i; 149 break; 150 } 151 // Special case for inputs: "tensor" is equivalent to "tensor:0" 152 if (actual.input(i) != expected.input(i) && 153 actual.input(i) != strings::StrCat(expected.input(i), ":0") && 154 strings::StrCat(actual.input(i), ":0") != expected.input(i)) { 155 if (diff != nullptr) { 156 *diff = strings::StrCat("Node named '", actual.name(), "' has input ", 157 i, " '", actual.input(i), 158 "' that doesn't match expected '", 159 expected.input(i), "'"); 160 } 161 return false; 162 } 163 } 164 165 std::unordered_set<string> actual_control; 166 std::unordered_set<string> expected_control; 167 for (int i = first_control_input; i < actual.input_size(); ++i) { 168 actual_control.insert(actual.input(i)); 169 expected_control.insert(expected.input(i)); 170 } 171 for (const auto& e : expected_control) { 172 if (actual_control.erase(e) == 0) { 173 if (diff != nullptr) { 174 *diff = strings::StrCat("Node named '", actual.name(), 175 "' missing expected control input '", e, "'"); 176 } 177 return false; 178 } 179 } 180 if (!actual_control.empty()) { 181 if (diff != nullptr) { 182 *diff = strings::StrCat("Node named '", actual.name(), 183 "' has unexpected control input '", 184 *actual_control.begin(), "'"); 185 } 186 return false; 187 } 188 189 std::unordered_set<string> actual_attr; 190 for (const auto& a : actual.attr()) { 191 if (options.ignore_internal_attrs && !a.first.empty() && 192 a.first[0] == '_') { 193 continue; 194 } 195 actual_attr.insert(a.first); 196 } 197 for (const auto& e : expected.attr()) { 198 if (options.ignore_internal_attrs && !e.first.empty() && 199 e.first[0] == '_') { 200 continue; 201 } 202 203 if (actual_attr.erase(e.first) == 0) { 204 if (diff != nullptr) { 205 *diff = strings::StrCat("Node named '", actual.name(), 206 "' missing expected attr '", e.first, 207 "' with value: ", SummarizeAttrValue(e.second)); 208 } 209 return false; 210 } 211 auto iter = actual.attr().find(e.first); 212 if (!AreAttrValuesEqual(e.second, iter->second)) { 213 if (diff != nullptr) { 214 *diff = strings::StrCat( 215 "Node named '", actual.name(), "' has attr '", e.first, 216 "' with value: ", SummarizeAttrValue(iter->second), 217 " that does not match expected: ", SummarizeAttrValue(e.second)); 218 } 219 return false; 220 } 221 } 222 if (!actual_attr.empty()) { 223 if (diff != nullptr) { 224 *diff = strings::StrCat( 225 "Node named '", actual.name(), "' has unexpected attr '", 226 *actual_attr.begin(), "' with value: ", 227 SummarizeAttrValue(actual.attr().find(*actual_attr.begin())->second)); 228 } 229 return false; 230 } 231 232 return true; 233 } 234 235 uint64 NodeDefHash(const NodeDef& ndef, const EqualGraphDefOptions& options) { 236 uint64 h = Hash64(ndef.name()); 237 h = Hash64(ndef.op().data(), ndef.op().size(), h); 238 h = Hash64(ndef.device().data(), ndef.device().size(), h); 239 240 // Normal inputs. Order important. 241 int first_control_input = ndef.input_size(); 242 for (int i = 0; i < ndef.input_size(); ++i) { 243 if (StringPiece(ndef.input(i)).starts_with("^")) { 244 first_control_input = i; 245 break; 246 } 247 h = Hash64(ndef.input(i).data(), ndef.input(i).size(), h); 248 } 249 250 // Control inputs. Order irrelevant. 251 std::set<string> ndef_control; 252 for (int i = first_control_input; i < ndef.input_size(); ++i) { 253 ndef_control.insert(ndef.input(i)); 254 } 255 for (const string& s : ndef_control) { 256 h = Hash64(s.data(), s.size(), h); 257 } 258 259 // Attributes 260 std::map<string, AttrValue> ndef_attr; 261 for (const auto& a : ndef.attr()) { 262 if (options.ignore_internal_attrs && !a.first.empty() && 263 a.first[0] == '_') { 264 continue; 265 } 266 ndef_attr[a.first] = a.second; 267 } 268 for (const auto& a : ndef_attr) { 269 h = Hash64(a.first.data(), a.first.size(), h); 270 h = Hash64Combine(AttrValueHash(a.second), h); 271 } 272 273 return h; 274 } 275 276 } // namespace tensorflow 277