/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op_def_util.h"

#include <set>
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace {  // ------ Helper functions ------

// Returns true if `arg` declares its type in at least one of the three
// supported ways: a fixed `type`, a `type_attr`, or a `type_list_attr`.
bool HasAttrStyleType(const OpDef::ArgDef& arg) {
  return arg.type() != DT_INVALID || !arg.type_attr().empty() ||
         !arg.type_list_attr().empty();
}

// Returns OK if `dt` is one of the types in `attr.allowed_values().list()`.
// Otherwise returns InvalidArgument naming the attr, the rejected type and
// the comma-separated list of allowed types.
Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
  const AttrValue& allowed_values(attr.allowed_values());
  for (auto allowed : allowed_values.list().type()) {
    if (dt == allowed) {
      return Status::OK();
    }
  }
  // Not found: build a human-readable list for the error message.
  string allowed_str;
  for (int i = 0; i < allowed_values.list().type_size(); ++i) {
    if (!allowed_str.empty()) {
      strings::StrAppend(&allowed_str, ", ");
    }
    strings::StrAppend(&allowed_str,
                       DataTypeString(allowed_values.list().type(i)));
  }
  return errors::InvalidArgument(
      "Value for attr '", attr.name(), "' of ", DataTypeString(dt),
      " is not in the list of allowed values: ", allowed_str);
}

// Returns OK if `str` is one of the strings in `attr.allowed_values().list()`.
// Otherwise returns InvalidArgument listing the allowed (quoted) strings.
Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
  const AttrValue& allowed_values(attr.allowed_values());
  for (const auto& allowed : allowed_values.list().s()) {
    if (str == allowed) {
      return Status::OK();
    }
  }
  // Not found: build a human-readable list for the error message.
  string allowed_str;
  for (const string& allowed : allowed_values.list().s()) {
    if (!allowed_str.empty()) {
      strings::StrAppend(&allowed_str, ", ");
    }
    strings::StrAppend(&allowed_str, "\"", allowed, "\"");
  }
  return errors::InvalidArgument(
      "Value for attr '", attr.name(), "' of \"", str,
      "\" is not in the list of allowed values: ", allowed_str);
}

}  // namespace

// Requires: attr has already been validated.
//
// Checks, in order, that `attr_value`:
//   1. has the type declared by `attr.type()`;
//   2. satisfies `attr.minimum()` if set (the value itself for "int", the
//      element count for list types);
//   3. uses only values from `attr.allowed_values()` if set.
// allowed_values is only supported for "type", "list(type)", "string" and
// "list(string)"; any other type with allowed_values yields Unimplemented.
// Returns the first failure found, or OK.
Status ValidateAttrValue(const AttrValue& attr_value,
                         const OpDef::AttrDef& attr) {
  // Is it a valid value?
  TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()),
                                  " for attr '", attr.name(), "'");

  // Does the value satisfy the minimum constraint in the AttrDef?
  if (attr.has_minimum()) {
    if (attr.type() == "int") {
      if (attr_value.i() < attr.minimum()) {
        return errors::InvalidArgument(
            "Value for attr '", attr.name(), "' of ", attr_value.i(),
            " must be at least minimum ", attr.minimum());
      }
    } else {
      // For list types the minimum constrains the number of elements. For a
      // type not matched below `length` stays -1, so any minimum >= 0 fails.
      int length = -1;
      if (attr.type() == "list(string)") {
        length = attr_value.list().s_size();
      } else if (attr.type() == "list(int)") {
        length = attr_value.list().i_size();
      } else if (attr.type() == "list(float)") {
        length = attr_value.list().f_size();
      } else if (attr.type() == "list(bool)") {
        length = attr_value.list().b_size();
      } else if (attr.type() == "list(type)") {
        length = attr_value.list().type_size();
      } else if (attr.type() == "list(shape)") {
        length = attr_value.list().shape_size();
      } else if (attr.type() == "list(tensor)") {
        length = attr_value.list().tensor_size();
      } else if (attr.type() == "list(func)") {
        length = attr_value.list().func_size();
      }
      if (length < attr.minimum()) {
        return errors::InvalidArgument(
            "Length for attr '", attr.name(), "' of ", length,
            " must be at least minimum ", attr.minimum());
      }
    }
  }

  // Does the value satisfy the allowed_value constraint in the AttrDef?
  if (attr.has_allowed_values()) {
    if (attr.type() == "type") {
      TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr));
    } else if (attr.type() == "list(type)") {
      // Every element of the list must be individually allowed.
      for (int dt : attr_value.list().type()) {
        TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr));
      }
    } else if (attr.type() == "string") {
      TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr));
    } else if (attr.type() == "list(string)") {
      for (const string& str : attr_value.list().s()) {
        TF_RETURN_IF_ERROR(AllowedStringValue(str, attr));
      }
    } else {
      return errors::Unimplemented(
          "Support for allowed_values not implemented for type ", attr.type());
    }
  }
  return Status::OK();
}

// Returns the attr of `op_def` named `name` (linear scan), or nullptr if
// there is none. The pointer refers into `op_def`.
const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) {
  for (int i = 0; i < op_def.attr_size(); ++i) {
    if (op_def.attr(i).name() == name) {
      return &op_def.attr(i);
    }
  }
  return nullptr;
}

// Mutable variant of FindAttr(): returns a pointer into `*op_def` that can be
// used to modify the attr, or nullptr if no attr is named `name`.
OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) {
  for (int i = 0; i < op_def->attr_size(); ++i) {
    if (op_def->attr(i).name() == name) {
      return op_def->mutable_attr(i);
    }
  }
  return nullptr;
}

// Returns the input arg of `op_def` named `name`, or nullptr if none exists.
const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) {
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    if (op_def.input_arg(i).name() == name) {
      return &op_def.input_arg(i);
    }
  }
  return nullptr;
}

// Returns the `in_arg` entry of `api_def` named `name`, or nullptr if none.
const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
  for (int i = 0; i < api_def.in_arg_size(); ++i) {
    if (api_def.in_arg(i).name() == name) {
      return &api_def.in_arg(i);
    }
  }
  return nullptr;
}

// VALIDATE(EXPR, msg...): if EXPR is false, returns InvalidArgument built
// from the message pieces followed by "; in OpDef: " and the short debug
// string of `op_def`, which must therefore be a variable in scope at each use.
// The do/while(false) wrapper makes each use a single statement. The macro
// is #undef'd right after ValidateOpDef().
#define VALIDATE(EXPR, ...)                                            \
  do {                                                                 \
    if (!(EXPR)) {                                                     \
      return errors::InvalidArgument(                                  \
          __VA_ARGS__, "; in OpDef: ", ProtoShortDebugString(op_def)); \
    }                                                                  \
  } while (false)

// Validates one input (`output` == false) or output (`output` == true) arg of
// `op_def`. Inserts the arg name into `*names` and fails on a duplicate.
// `*names` is shared with the attr names, so an arg may not reuse an attr's
// name. Also checks that:
//   * exactly one way of specifying the type is used (with number_attr, only
//     `type` or `type_attr` may be combined with it);
//   * any referenced number_attr / type_attr / type_list_attr exists and has
//     type "int" (with minimum >= 0) / "type" / "list(type)" respectively;
//   * a fixed `type` is not itself a ref type (refs are expressed via is_ref).
static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
                          bool output, std::set<string>* names) {
  const string suffix = strings::StrCat(
      output ? " for output '" : " for input '", arg.name(), "'");
  VALIDATE(gtl::InsertIfNotPresent(names, arg.name()),
           "Duplicate name: ", arg.name());
  VALIDATE(HasAttrStyleType(arg), "Missing type", suffix);

  if (!arg.number_attr().empty()) {
    // "N * type" form: N must be a non-negative int attr with a minimum.
    const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
    VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'",
             suffix);
    VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length",
             suffix, " has type ", attr->type(), " != int");
    VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length",
             suffix, " must have minimum");
    VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length",
             suffix, " must have minimum >= 0");
    VALIDATE(arg.type_list_attr().empty(),
             "Can't have both number_attr and type_list_attr", suffix);
    VALIDATE((arg.type() != DT_INVALID ? 1 : 0) +
                     (!arg.type_attr().empty() ? 1 : 0) ==
                 1,
             "Exactly one of type, type_attr must be set", suffix);
  } else {
    const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) +
                                (!arg.type_attr().empty() ? 1 : 0) +
                                (!arg.type_list_attr().empty() ? 1 : 0);
    VALIDATE(num_type_fields == 1,
             "Exactly one of type, type_attr, type_list_attr must be set",
             suffix);
  }

  if (!arg.type_attr().empty()) {
    const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
    VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'",
             suffix);
    VALIDATE(attr->type() == "type", "Attr '", attr->name(),
             "' used as type_attr", suffix, " has type ", attr->type(),
             " != type");
  } else if (!arg.type_list_attr().empty()) {
    const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def);
    VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'",
             suffix);
    VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(),
             "' used as type_list_attr", suffix, " has type ", attr->type(),
             " != list(type)");
  } else {
    // All argument types should be non-reference types at this point.
    // ArgDef.is_ref is set to true for reference arguments.
    VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '",
             DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix);
  }

  return Status::OK();
}

// Validates the structure of `op_def` and returns the first problem found as
// InvalidArgument (or the error from AttrValueHasType / ValidateAttrValue):
//   * names not starting with "_" must match UPPERLETTER LETTER_DIGIT*;
//   * for each attr: unique name, not a DataType name, a recognized (possibly
//     "list(...)") type, a minimum only on int or list types (non-negative
//     for lists; 0 when has_minimum is false), allowed_values matching the
//     list form of the type, and a default_value accepted by
//     ValidateAttrValue();
//   * then every input arg and every output arg, via ValidateArg().
// Checks run in the order above, which determines which error is reported.
Status ValidateOpDef(const OpDef& op_def) {
  using ::tensorflow::strings::Scanner;

  if (!str_util::StartsWith(op_def.name(), "_")) {
    VALIDATE(Scanner(op_def.name())
                 .One(Scanner::UPPERLETTER)
                 .Any(Scanner::LETTER_DIGIT)
                 .Eos()
                 .GetResult(),
             "Invalid name: ", op_def.name(), " (Did you use CamelCase?)");
  }

  std::set<string> names;  // for detecting duplicate names
  for (const auto& attr : op_def.attr()) {
    // Validate name
    VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()),
             "Duplicate name: ", attr.name());
    DataType dt;
    VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ",
             attr.name(), " that matches a data type");

    // Validate type
    // `type` is consumed piece by piece: optional "list(", one base type
    // name, then ")" for lists; anything left over is an error.
    StringPiece type(attr.type());
    bool is_list = str_util::ConsumePrefix(&type, "list(");
    bool found = false;
    for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
                              "tensor", "func"}) {
      if (str_util::ConsumePrefix(&type, valid)) {
        found = true;
        break;
      }
    }
    VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
             "'");
    if (is_list) {
      VALIDATE(str_util::ConsumePrefix(&type, ")"),
               "'list(' is missing ')' in attr ", attr.name(), "'s type ",
               attr.type());
    }
    VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ",
             attr.name(), "'s type ", attr.type());

    // Validate minimum
    if (attr.has_minimum()) {
      VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(),
               "' has minimum for unsupported type ", attr.type());
      if (is_list) {
        VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(),
                 "' with list type must have a non-negative minimum, not ",
                 attr.minimum());
      }
    } else {
      VALIDATE(attr.minimum() == 0, "Attr '", attr.name(),
               "' with has_minimum = false but minimum ", attr.minimum(),
               " not equal to default of 0");
    }

    // Validate allowed_values
    // allowed_values is always stored as a list, so a scalar attr type T is
    // checked against "list(T)".
    if (attr.has_allowed_values()) {
      const string list_type =
          is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")");
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          AttrValueHasType(attr.allowed_values(), list_type), " for attr '",
          attr.name(), "' in Op '", op_def.name(), "'");
    }

    // Validate default_value (after we have validated the rest of the attr,
    // so we can use ValidateAttrValue()).
    if (attr.has_default_value()) {
      TF_RETURN_WITH_CONTEXT_IF_ERROR(
          ValidateAttrValue(attr.default_value(), attr), " in Op '",
          op_def.name(), "'");
    }
  }

  // `names` already holds every attr name, so args are also checked for
  // collisions with attrs as well as with each other.
  for (const auto& arg : op_def.input_arg()) {
    TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names));
  }

  for (const auto& arg : op_def.output_arg()) {
    TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names));
  }

  return Status::OK();
}

#undef VALIDATE

// If `op_def` is deprecated: returns Unimplemented when `graph_def_version`
// is at or past the deprecation version; otherwise logs a warning the first
// time each op name is seen (a process-wide set guarded by a mutex) and
// returns OK. Non-deprecated ops always return OK.
Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) {
  if (op_def.has_deprecation()) {
    const OpDeprecation& dep = op_def.deprecation();
    if (graph_def_version >= dep.version()) {
      return errors::Unimplemented(
          "Op ", op_def.name(), " is not available in GraphDef version ",
          graph_def_version, ". It has been removed in version ", dep.version(),
          ". ", dep.explanation(), ".");
    } else {
      // Warn only once for each op name, and do it in a threadsafe manner.
      static mutex mu(LINKER_INITIALIZED);
      static std::unordered_set<string> warned;
      bool warn;
      {
        mutex_lock lock(mu);
        warn = warned.insert(op_def.name()).second;
      }
      // Logging happens outside the lock.
      if (warn) {
        LOG(WARNING) << "Op " << op_def.name() << " is deprecated."
                     << " It will cease to work in GraphDef version "
                     << dep.version() << ". " << dep.explanation() << ".";
      }
    }
  }
  return Status::OK();
}

namespace {

// Formats `args` as a ", "-separated list of entries of the form
//   name:[Ref(][N*]type[)]
// where `type` is the DataType name for a fixed type, else `type_attr`.
// NOTE(review): for args that use type_list_attr neither `type` nor
// `type_attr` is set in a validated OpDef, so no type is printed for them.
string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
  string ret;
  for (const OpDef::ArgDef& arg : args) {
    if (!ret.empty()) strings::StrAppend(&ret, ", ");
    strings::StrAppend(&ret, arg.name(), ":");
    if (arg.is_ref()) strings::StrAppend(&ret, "Ref(");
    if (!arg.number_attr().empty()) {
      strings::StrAppend(&ret, arg.number_attr(), "*");
    }
    if (arg.type() != DT_INVALID) {
      strings::StrAppend(&ret, DataTypeString(arg.type()));
    } else {
      strings::StrAppend(&ret, arg.type_attr());
    }
    if (arg.is_ref()) strings::StrAppend(&ret, ")");
  }
  return ret;
}

}  // namespace

// Returns a one-line human-readable summary of `op_def`:
//   Op<name=...; signature=inputs -> outputs; attr=name:type[,default=...]
//      [,min=...][,allowed=...]; ...; is_commutative=true; ...>
// Boolean properties are only listed when true. Used in error messages.
string SummarizeOpDef(const OpDef& op_def) {
  string ret = strings::StrCat("Op<name=", op_def.name());
  strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()),
                     " -> ", SummarizeArgs(op_def.output_arg()));
  for (int i = 0; i < op_def.attr_size(); ++i) {
    strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":",
                       op_def.attr(i).type());
    if (op_def.attr(i).has_default_value()) {
      strings::StrAppend(&ret, ",default=",
                         SummarizeAttrValue(op_def.attr(i).default_value()));
    }
    if (op_def.attr(i).has_minimum()) {
      strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum());
    }
    if (op_def.attr(i).has_allowed_values()) {
      strings::StrAppend(&ret, ",allowed=",
                         SummarizeAttrValue(op_def.attr(i).allowed_values()));
    }
  }
  if (op_def.is_commutative()) {
    strings::StrAppend(&ret, "; is_commutative=true");
  }
  if (op_def.is_aggregate()) {
    strings::StrAppend(&ret, "; is_aggregate=true");
  }
  if (op_def.is_stateful()) {
    strings::StrAppend(&ret, "; is_stateful=true");
  }
  if (op_def.allows_uninitialized_input()) {
    strings::StrAppend(&ret, "; allows_uninitialized_input=true");
  }
  strings::StrAppend(&ret, ">");
  return ret;
}

namespace {

// Returns true if every element of `sub` is contained in `super`.
// Quadratic (nested linear scans); intended for short allowed-value lists.
template <class T>
bool IsSubsetOf(const T& sub, const T& super) {
  for (const auto& o : sub) {
    bool found = false;
    for (const auto& n : super) {
      if (o == n) {
        found = true;
        break;
      }
    }
    if (!found) return false;
  }
  return true;
}

// Returns true if `new_attr` rejects some value that `old_attr` allowed.
// Only the `type` and `s` lists of allowed_values are compared.
bool MoreRestrictive(const OpDef::AttrDef& old_attr,
                     const OpDef::AttrDef& new_attr) {
  // Anything -> no restriction : not more restrictive.
  if (!new_attr.has_allowed_values()) return false;
  // No restriction -> restriction : more restrictive.
  if (!old_attr.has_allowed_values()) return true;
  // If anything that was previously allowed is no longer allowed:
  // more restrictive.
  if (!IsSubsetOf(old_attr.allowed_values().list().type(),
                  new_attr.allowed_values().list().type())) {
    return true;
  }
  if (!IsSubsetOf(old_attr.allowed_values().list().s(),
                  new_attr.allowed_values().list().s())) {
    return true;
  }
  return false;
}

// Describes `attr`'s allowed_values for error messages.
string AllowedStr(const OpDef::AttrDef& attr) {
  if (!attr.has_allowed_values()) return "no restriction";
  return SummarizeAttrValue(attr.allowed_values());
}

// Describes `attr`'s default_value for error messages.
string DefaultAttrStr(const OpDef::AttrDef& attr) {
  if (!attr.has_default_value()) return "no default";
  return SummarizeAttrValue(attr.default_value());
}

// Returns true if `new_attr`'s minimum rejects a value `old_attr` accepted.
bool HigherMinimum(const OpDef::AttrDef& old_attr,
                   const OpDef::AttrDef& new_attr) {
  // Anything -> no restriction : not more restrictive.
  if (!new_attr.has_minimum()) return false;
  // No restriction -> restriction : more restrictive.
  if (!old_attr.has_minimum()) return true;
  // If anything that was previously allowed is no longer allowed:
  // more restrictive.
  return new_attr.minimum() > old_attr.minimum();
}

// Describes `attr`'s minimum for error messages.
string MinStr(const OpDef::AttrDef& attr) {
  if (!attr.has_minimum()) return "no minimum";
  return strings::StrCat(attr.minimum());
}

// Maps attr name -> AttrDef. The pointers refer into the OpDef passed to
// FillAttrMap(), so the map must not outlive that OpDef.
typedef std::unordered_map<string, const OpDef::AttrDef*> AttrMap;
void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) {
  for (const auto& attr : op_def.attr()) {
    (*attr_map)[attr.name()] = &attr;
  }
}

// Add a comma to *s every call but the first (*add_comma should be
// initialized to false).
void AddComma(string* s, bool* add_comma) {
  if (*add_comma) {
    strings::StrAppend(s, ", ");
  } else {
    *add_comma = true;
  }
}

// Appends "<arg name>:" to *s if `name` is true; does nothing otherwise.
void AddName(string* s, bool name, const OpDef::ArgDef& arg) {
  if (name) {
    strings::StrAppend(s, arg.name(), ":");
  }
}

// Compute a signature for either inputs or outputs that will be the
// same for both the old and new OpDef if they are compatible. We
// assume that new_attrs is a superset of old_attrs, and that any attr
// in the difference has a default. Our strategy is to make a list of
// types, where the types are things like:
// * "int32", "float", etc.,
// * "T" for some attr "T" in old_attrs,
// * "N * type" for some attr "N" in old_attrs, or
// * "L" for some list(type) attr "L" in old_attrs.
//
// We get the types by either using the attrs in args if they are in
// old_attrs, or substituting the default value from new_attrs.
522 string ComputeArgSignature( 523 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, 524 const AttrMap& old_attrs, const AttrMap& new_attrs, std::vector<bool>* ref, 525 bool names) { 526 string s; 527 bool add_comma = false; 528 for (const OpDef::ArgDef& arg : args) { 529 if (!arg.type_list_attr().empty()) { 530 const OpDef::AttrDef* old_attr = 531 gtl::FindPtrOrNull(old_attrs, arg.type_list_attr()); 532 if (old_attr) { 533 // Both old and new have the list(type) attr, so can use it directly. 534 AddComma(&s, &add_comma); 535 AddName(&s, names, arg); 536 strings::StrAppend(&s, arg.type_list_attr()); 537 ref->push_back(arg.is_ref()); 538 } else { 539 // Missing the list(type) attr in the old, so use the default 540 // value for the attr from new instead. 541 const OpDef::AttrDef* new_attr = 542 gtl::FindPtrOrNull(new_attrs, arg.type_list_attr()); 543 const auto& type_list = new_attr->default_value().list().type(); 544 if (type_list.empty()) continue; 545 for (int i = 0; i < type_list.size(); ++i) { 546 AddComma(&s, &add_comma); 547 AddName(&s, names, arg); 548 strings::StrAppend( 549 &s, DataTypeString(static_cast<DataType>(type_list.Get(i)))); 550 ref->push_back(arg.is_ref()); 551 } 552 } 553 } else { 554 int num = 1; // How many input/outputs does this represent? 555 string type; // What is the type of this arg? 556 AddName(&type, names, arg); 557 if (!arg.number_attr().empty()) { 558 // N * type case. 559 const OpDef::AttrDef* old_attr = 560 gtl::FindPtrOrNull(old_attrs, arg.number_attr()); 561 if (old_attr) { 562 // Both old and new have the number attr, so can use it directly. 563 strings::StrAppend(&type, arg.number_attr(), " * "); 564 } else { 565 // Missing the number attr in the old, so use the default 566 // value for the attr from new instead. 567 const OpDef::AttrDef* new_attr = 568 gtl::FindPtrOrNull(new_attrs, arg.number_attr()); 569 num = new_attr->default_value().i(); 570 } 571 } 572 573 if (arg.type() != DT_INVALID) { 574 // int32, float, etc. 
case 575 strings::StrAppend(&type, DataTypeString(arg.type())); 576 } else { 577 const OpDef::AttrDef* old_attr = 578 gtl::FindPtrOrNull(old_attrs, arg.type_attr()); 579 if (old_attr) { 580 // Both old and new have the type attr, so can use it directly. 581 strings::StrAppend(&type, arg.type_attr()); 582 } else { 583 // Missing the type attr in the old, so use the default 584 // value for the attr from new instead. 585 const OpDef::AttrDef* new_attr = 586 gtl::FindPtrOrNull(new_attrs, arg.type_attr()); 587 strings::StrAppend(&type, 588 DataTypeString(new_attr->default_value().type())); 589 } 590 } 591 592 // Record `num` * `type` in the signature. 593 for (int i = 0; i < num; ++i) { 594 AddComma(&s, &add_comma); 595 strings::StrAppend(&s, type); 596 ref->push_back(arg.is_ref()); 597 } 598 } 599 } 600 601 return s; 602 } 603 604 } // namespace 605 606 Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) { 607 #define VALIDATE(CONDITION, ...) \ 608 if (!(CONDITION)) { \ 609 return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \ 610 "; old: ", SummarizeOpDef(old_op), \ 611 "; new: ", SummarizeOpDef(new_op)); \ 612 } 613 614 VALIDATE(old_op.name() == new_op.name(), "Name mismatch"); 615 616 AttrMap new_attrs, old_attrs; 617 FillAttrMap(old_op, &old_attrs); 618 FillAttrMap(new_op, &new_attrs); 619 for (const auto& old_attr : old_op.attr()) { 620 const OpDef::AttrDef* new_attr = 621 gtl::FindPtrOrNull(new_attrs, old_attr.name()); 622 VALIDATE(new_attr != nullptr, "Attr '", old_attr.name(), "' removed"); 623 VALIDATE(old_attr.type() == new_attr->type(), "Attr '", old_attr.name(), 624 "' changed type '", old_attr.type(), "' -> '", new_attr->type(), 625 "'"); 626 VALIDATE(!MoreRestrictive(old_attr, *new_attr), "Attr '", old_attr.name(), 627 "' has a stricter set of allowed values; from ", 628 AllowedStr(old_attr), " to ", AllowedStr(*new_attr)); 629 VALIDATE(!HigherMinimum(old_attr, *new_attr), "Attr '", old_attr.name(), 630 "' has a higher 
minimum; from ", MinStr(old_attr), " to ", 631 MinStr(*new_attr)); 632 } 633 634 for (const auto& new_attr : new_op.attr()) { 635 const OpDef::AttrDef* old_attr = 636 gtl::FindPtrOrNull(old_attrs, new_attr.name()); 637 VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '", 638 new_attr.name(), "' added without default"); 639 } 640 641 std::vector<bool> old_in_ref, new_in_ref, old_out_ref, new_out_ref; 642 const string old_in_sig = ComputeArgSignature( 643 old_op.input_arg(), old_attrs, new_attrs, &old_in_ref, false /* names */); 644 const string new_in_sig = ComputeArgSignature( 645 new_op.input_arg(), old_attrs, new_attrs, &new_in_ref, false /* names */); 646 VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig, 647 "' vs. '", new_in_sig, "'"); 648 VALIDATE(old_in_ref.size() == new_in_ref.size(), // Should not happen 649 "Unexpected change in input ref lists."); 650 for (int i = 0; i < old_in_ref.size(); ++i) { 651 // Allowed to remove "ref" from an input (or leave it unchanged). 652 VALIDATE(old_in_ref[i] || !new_in_ref[i], "Input ", i, 653 " changed from non-ref to ref"); 654 } 655 656 const string old_out_sig = 657 ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs, 658 &old_out_ref, true /* names */); 659 const string new_out_sig = 660 ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs, 661 &new_out_ref, true /* names */); 662 VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '", 663 old_out_sig, "' vs. '", new_out_sig, "'"); 664 VALIDATE(old_out_ref.size() == new_out_ref.size(), // Should not happen 665 "Unexpected change in output ref lists"); 666 for (int i = 0; i < old_out_ref.size(); ++i) { 667 // Allowed to add "ref" to an output (or leave it unchanged). 
668 VALIDATE(!old_out_ref[i] || new_out_ref[i], "Output ", i, 669 " changed from ref to non-ref"); 670 } 671 672 return Status::OK(); 673 } 674 675 Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, 676 const OpDef& penultimate_op, 677 const OpDef& new_op) { 678 AttrMap new_attrs, old_attrs; 679 FillAttrMap(old_op, &old_attrs); 680 FillAttrMap(new_op, &new_attrs); 681 682 for (const auto& penultimate_attr : penultimate_op.attr()) { 683 const OpDef::AttrDef* old_attr = 684 gtl::FindPtrOrNull(old_attrs, penultimate_attr.name()); 685 if (old_attr != nullptr) continue; // attr wasn't added 686 const OpDef::AttrDef* new_attr = 687 gtl::FindPtrOrNull(new_attrs, penultimate_attr.name()); 688 689 // These shouldn't happen if the op passed OpDefCompatible(). 690 if (new_attr == nullptr) { 691 return errors::InvalidArgument("Missing attr '", penultimate_attr.name(), 692 "' in op: ", SummarizeOpDef(new_op)); 693 } 694 if (!penultimate_attr.has_default_value() || 695 !new_attr->has_default_value()) { 696 return errors::InvalidArgument("Missing default for attr '", 697 penultimate_attr.name(), 698 "' in op: ", SummarizeOpDef(new_op)); 699 } 700 701 // Actually test that the attr's default value hasn't changed. 
702 if (!AreAttrValuesEqual(penultimate_attr.default_value(), 703 new_attr->default_value())) { 704 return errors::InvalidArgument( 705 "Can't change default value for attr '", penultimate_attr.name(), 706 "' from ", SummarizeAttrValue(penultimate_attr.default_value()), 707 " in op: ", SummarizeOpDef(new_op)); 708 } 709 } 710 711 return Status::OK(); 712 } 713 714 Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) { 715 AttrMap new_attrs, old_attrs; 716 FillAttrMap(old_op, &old_attrs); 717 FillAttrMap(new_op, &new_attrs); 718 719 for (const auto& old_attr : old_op.attr()) { 720 const OpDef::AttrDef* new_attr = 721 gtl::FindPtrOrNull(new_attrs, old_attr.name()); 722 if (new_attr == nullptr) continue; 723 if (old_attr.has_default_value() != new_attr->has_default_value()) { 724 return errors::InvalidArgument( 725 "Attr '", old_attr.name(), "' has added/removed it's default; ", 726 "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr)); 727 } 728 if (old_attr.has_default_value() && 729 !AreAttrValuesEqual(old_attr.default_value(), 730 new_attr->default_value())) { 731 return errors::InvalidArgument( 732 "Attr '", old_attr.name(), "' has changed it's default value; ", 733 "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr)); 734 } 735 } 736 737 return Status::OK(); 738 } 739 740 void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) { 741 for (int i = 0; i < op_def->input_arg_size(); ++i) { 742 op_def->mutable_input_arg(i)->clear_description(); 743 } 744 for (int i = 0; i < op_def->output_arg_size(); ++i) { 745 op_def->mutable_output_arg(i)->clear_description(); 746 } 747 for (int i = 0; i < op_def->attr_size(); ++i) { 748 op_def->mutable_attr(i)->clear_description(); 749 } 750 op_def->clear_summary(); 751 op_def->clear_description(); 752 } 753 754 void RemoveDescriptionsFromOpDef(OpDef* op_def) { 755 RemoveNonDeprecationDescriptionsFromOpDef(op_def); 756 if (op_def->has_deprecation()) { 757 
op_def->mutable_deprecation()->clear_explanation(); 758 } 759 } 760 761 void RemoveDescriptionsFromOpList(OpList* op_list) { 762 for (int i = 0; i < op_list->op_size(); ++i) { 763 OpDef* op_def = op_list->mutable_op(i); 764 RemoveDescriptionsFromOpDef(op_def); 765 } 766 } 767 768 bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2) { 769 #ifndef TENSORFLOW_LITE_PROTOS 770 DCHECK_EQ(7, a1.GetDescriptor()->field_count()) 771 << "Please modify these equality and hash functions to reflect the " 772 "changes to the AttrDef protobuf"; 773 #endif // TENSORFLOW_LITE_PROTOS 774 775 if (a1.name() != a2.name()) return false; 776 if (a1.type() != a2.type()) return false; 777 if (a1.description() != a2.description()) return false; 778 if (a1.has_minimum() != a2.has_minimum()) return false; 779 if (a1.has_minimum() && a1.minimum() != a2.minimum()) return false; 780 if (!AreAttrValuesEqual(a1.default_value(), a2.default_value())) return false; 781 if (!AreAttrValuesEqual(a1.allowed_values(), a2.allowed_values())) 782 return false; 783 return true; 784 } 785 786 uint64 AttrDefHash(const OpDef::AttrDef& a) { 787 uint64 h = Hash64(a.name()); 788 h = Hash64(a.type().data(), a.type().size(), h); 789 h = Hash64Combine(AttrValueHash(a.default_value()), h); 790 h = Hash64(a.description().data(), a.description().size(), h); 791 h = Hash64Combine(static_cast<uint64>(a.has_minimum()), h); 792 h = Hash64Combine(static_cast<uint64>(a.minimum()), h); 793 h = Hash64Combine(AttrValueHash(a.allowed_values()), h); 794 return h; 795 } 796 797 bool RepeatedAttrDefEqual( 798 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a1, 799 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a2) { 800 std::unordered_map<string, const OpDef::AttrDef*> a1_set; 801 for (const OpDef::AttrDef& def : a1) { 802 DCHECK(a1_set.find(def.name()) == a1_set.end()) 803 << "AttrDef names must be unique, but '" << def.name() 804 << "' appears more than once"; 805 a1_set[def.name()] = &def; 806 } 807 for (const 
OpDef::AttrDef& def : a2) { 808 auto iter = a1_set.find(def.name()); 809 if (iter == a1_set.end()) return false; 810 if (!AttrDefEqual(*iter->second, def)) return false; 811 a1_set.erase(iter); 812 } 813 if (!a1_set.empty()) return false; 814 return true; 815 } 816 817 uint64 RepeatedAttrDefHash( 818 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a) { 819 // Insert AttrDefs into map to deterministically sort by name 820 std::map<string, const OpDef::AttrDef*> a_set; 821 for (const OpDef::AttrDef& def : a) { 822 a_set[def.name()] = &def; 823 } 824 // Iterate and combines hashes of keys and values 825 uint64 h = 0xDECAFCAFFE; 826 for (const auto& pair : a_set) { 827 h = Hash64(pair.first.data(), pair.first.size(), h); 828 h = Hash64Combine(AttrDefHash(*pair.second), h); 829 } 830 return h; 831 } 832 833 bool OpDefEqual(const OpDef& o1, const OpDef& o2) { 834 // attr order doesn't matter. 835 // Compare it separately here instead of serializing below. 836 if (!RepeatedAttrDefEqual(o1.attr(), o2.attr())) return false; 837 838 // `control_output` order doesn't matter. 839 std::set<string> control_output1(o1.control_output().begin(), 840 o1.control_output().end()); 841 std::set<string> control_output2(o2.control_output().begin(), 842 o2.control_output().end()); 843 if (control_output1 != control_output2) return false; 844 845 // Clear `attr` and `control_output` fields, serialize, and compare serialized 846 // strings. 847 OpDef o1_copy = o1; 848 OpDef o2_copy = o2; 849 o1_copy.clear_attr(); 850 o1_copy.clear_control_output(); 851 o2_copy.clear_attr(); 852 o2_copy.clear_control_output(); 853 854 return AreSerializedProtosEqual(o1_copy, o2_copy); 855 } 856 857 uint64 OpDefHash(const OpDef& o) { 858 uint64 h = RepeatedAttrDefHash(o.attr()); 859 860 // Compute deterministic order-independent control outputs hash. 
861 std::set<string> control_output(o.control_output().begin(), 862 o.control_output().end()); 863 for (const auto& co : control_output) h = Hash64Combine(h, Hash64(co)); 864 865 OpDef o_copy = o; 866 o_copy.clear_attr(); 867 o_copy.clear_control_output(); 868 return DeterministicProtoHash64(o_copy, h); 869 } 870 871 } // namespace tensorflow 872