1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/op_def_util.h" 17 18 #include <set> 19 #include <unordered_map> 20 #include <unordered_set> 21 #include "tensorflow/core/framework/attr_value.pb.h" 22 #include "tensorflow/core/framework/attr_value_util.h" 23 #include "tensorflow/core/framework/op_def.pb_text.h" 24 #include "tensorflow/core/framework/types.h" 25 #include "tensorflow/core/lib/core/errors.h" 26 #include "tensorflow/core/lib/core/stringpiece.h" 27 #include "tensorflow/core/lib/gtl/map_util.h" 28 #include "tensorflow/core/lib/hash/hash.h" 29 #include "tensorflow/core/lib/strings/scanner.h" 30 #include "tensorflow/core/lib/strings/strcat.h" 31 #include "tensorflow/core/platform/mutex.h" 32 #include "tensorflow/core/platform/protobuf.h" 33 #include "tensorflow/core/platform/types.h" 34 35 namespace tensorflow { 36 namespace { // ------ Helper functions ------ 37 38 bool HasAttrStyleType(const OpDef::ArgDef& arg) { 39 return arg.type() != DT_INVALID || !arg.type_attr().empty() || 40 !arg.type_list_attr().empty(); 41 } 42 43 Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) { 44 const AttrValue& allowed_values(attr.allowed_values()); 45 for (auto allowed : allowed_values.list().type()) { 46 if (dt == allowed) { 47 return Status::OK(); 48 } 49 } 50 string allowed_str; 51 for (int i = 0; i < allowed_values.list().type_size(); ++i) { 52 if (!allowed_str.empty()) { 53 strings::StrAppend(&allowed_str, ", "); 54 } 55 strings::StrAppend(&allowed_str, 56 DataTypeString(allowed_values.list().type(i))); 57 } 58 return errors::InvalidArgument( 59 "Value for attr '", attr.name(), "' of ", DataTypeString(dt), 60 " is not in the list of allowed values: ", allowed_str); 61 } 62 63 Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) { 64 const AttrValue& allowed_values(attr.allowed_values()); 65 for (const auto& allowed : allowed_values.list().s()) { 66 if (str == allowed) { 67 return Status::OK(); 68 } 69 } 70 string allowed_str; 71 for (const string& allowed : allowed_values.list().s()) { 72 if (!allowed_str.empty()) { 73 strings::StrAppend(&allowed_str, ", "); 74 } 75 strings::StrAppend(&allowed_str, "\"", allowed, "\""); 76 } 77 return errors::InvalidArgument( 78 "Value for attr '", attr.name(), "' of \"", str, 79 "\" is not in the list of allowed values: ", allowed_str); 80 } 81 82 } // namespace 83 84 // Requires: attr has already been validated. 85 Status ValidateAttrValue(const AttrValue& attr_value, 86 const OpDef::AttrDef& attr) { 87 // Is it a valid value? 88 TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()), 89 " for attr '", attr.name(), "'"); 90 91 // Does the value satisfy the minimum constraint in the AttrDef? 92 if (attr.has_minimum()) { 93 if (attr.type() == "int") { 94 if (attr_value.i() < attr.minimum()) { 95 return errors::InvalidArgument( 96 "Value for attr '", attr.name(), "' of ", attr_value.i(), 97 " must be at least minimum ", attr.minimum()); 98 } 99 } else { 100 int length = -1; 101 if (attr.type() == "list(string)") { 102 length = attr_value.list().s_size(); 103 } else if (attr.type() == "list(int)") { 104 length = attr_value.list().i_size(); 105 } else if (attr.type() == "list(float)") { 106 length = attr_value.list().f_size(); 107 } else if (attr.type() == "list(bool)") { 108 length = attr_value.list().b_size(); 109 } else if (attr.type() == "list(type)") { 110 length = attr_value.list().type_size(); 111 } else if (attr.type() == "list(shape)") { 112 length = attr_value.list().shape_size(); 113 } else if (attr.type() == "list(tensor)") { 114 length = attr_value.list().tensor_size(); 115 } 116 if (length < attr.minimum()) { 117 return errors::InvalidArgument( 118 "Length for attr '", attr.name(), "' of ", length, 119 " must be at least minimum ", attr.minimum()); 120 } 121 } 122 } 123 124 // Does the value satisfy the allowed_value constraint in the AttrDef? 125 if (attr.has_allowed_values()) { 126 if (attr.type() == "type") { 127 TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr)); 128 } else if (attr.type() == "list(type)") { 129 for (int dt : attr_value.list().type()) { 130 TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr)); 131 } 132 } else if (attr.type() == "string") { 133 TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr)); 134 } else if (attr.type() == "list(string)") { 135 for (const string& str : attr_value.list().s()) { 136 TF_RETURN_IF_ERROR(AllowedStringValue(str, attr)); 137 } 138 } else { 139 return errors::Unimplemented( 140 "Support for allowed_values not implemented for type ", attr.type()); 141 } 142 } 143 return Status::OK(); 144 } 145 146 const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) { 147 for (int i = 0; i < op_def.attr_size(); ++i) { 148 if (op_def.attr(i).name() == name) { 149 return &op_def.attr(i); 150 } 151 } 152 return nullptr; 153 } 154 155 OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) { 156 for (int i = 0; i < op_def->attr_size(); ++i) { 157 if (op_def->attr(i).name() == name) { 158 return op_def->mutable_attr(i); 159 } 160 } 161 return nullptr; 162 } 163 164 const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) { 165 for (int i = 0; i < op_def.input_arg_size(); ++i) { 166 if (op_def.input_arg(i).name() == name) { 167 return &op_def.input_arg(i); 168 } 169 } 170 return nullptr; 171 } 172 173 #define VALIDATE(EXPR, ...) \ 174 do { \ 175 if (!(EXPR)) { \ 176 return errors::InvalidArgument( \ 177 __VA_ARGS__, "; in OpDef: ", ProtoShortDebugString(op_def)); \ 178 } \ 179 } while (false) 180 181 static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, 182 bool output, std::set<string>* names) { 183 const string suffix = strings::StrCat( 184 output ? " for output '" : " for input '", arg.name(), "'"); 185 VALIDATE(gtl::InsertIfNotPresent(names, arg.name()), 186 "Duplicate name: ", arg.name()); 187 VALIDATE(HasAttrStyleType(arg), "Missing type", suffix); 188 189 if (!arg.number_attr().empty()) { 190 const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def); 191 VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'", 192 suffix); 193 VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length", 194 suffix, " has type ", attr->type(), " != int"); 195 VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length", 196 suffix, " must have minimum"); 197 VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length", 198 suffix, " must have minimum >= 0"); 199 VALIDATE(arg.type_list_attr().empty(), 200 "Can't have both number_attr and type_list_attr", suffix); 201 VALIDATE((arg.type() != DT_INVALID ? 1 : 0) + 202 (!arg.type_attr().empty() ? 1 : 0) == 203 1, 204 "Exactly one of type, type_attr must be set", suffix); 205 } else { 206 const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) + 207 (!arg.type_attr().empty() ? 1 : 0) + 208 (!arg.type_list_attr().empty() ? 1 : 0); 209 VALIDATE(num_type_fields == 1, 210 "Exactly one of type, type_attr, type_list_attr must be set", 211 suffix); 212 } 213 214 if (!arg.type_attr().empty()) { 215 const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def); 216 VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'", 217 suffix); 218 VALIDATE(attr->type() == "type", "Attr '", attr->name(), 219 "' used as type_attr", suffix, " has type ", attr->type(), 220 " != type"); 221 } else if (!arg.type_list_attr().empty()) { 222 const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def); 223 VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'", 224 suffix); 225 VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(), 226 "' used as type_list_attr", suffix, " has type ", attr->type(), 227 " != list(type)"); 228 } else { 229 // All argument types should be non-reference types at this point. 230 // ArgDef.is_ref is set to true for reference arguments. 231 VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '", 232 DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix); 233 } 234 235 return Status::OK(); 236 } 237 238 Status ValidateOpDef(const OpDef& op_def) { 239 using ::tensorflow::strings::Scanner; 240 241 if (!StringPiece(op_def.name()).starts_with("_")) { 242 VALIDATE(Scanner(op_def.name()) 243 .One(Scanner::UPPERLETTER) 244 .Any(Scanner::LETTER_DIGIT) 245 .Eos() 246 .GetResult(), 247 "Invalid name: ", op_def.name(), " (Did you use CamelCase?)"); 248 } 249 250 std::set<string> names; // for detecting duplicate names 251 for (const auto& attr : op_def.attr()) { 252 // Validate name 253 VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()), 254 "Duplicate name: ", attr.name()); 255 DataType dt; 256 VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ", 257 attr.name(), " that matches a data type"); 258 259 // Validate type 260 StringPiece type(attr.type()); 261 bool is_list = type.Consume("list("); 262 bool found = false; 263 for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape", 264 "tensor", "func"}) { 265 if (type.Consume(valid)) { 266 found = true; 267 break; 268 } 269 } 270 VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(), 271 "'"); 272 if (is_list) { 273 VALIDATE(type.Consume(")"), "'list(' is missing ')' in attr ", 274 attr.name(), "'s type ", attr.type()); 275 } 276 VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ", 277 attr.name(), "'s type ", attr.type()); 278 279 // Validate minimum 280 if (attr.has_minimum()) { 281 VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(), 282 "' has minimum for unsupported type ", attr.type()); 283 if (is_list) { 284 VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(), 285 "' with list type must have a non-negative minimum, not ", 286 attr.minimum()); 287 } 288 } else { 289 VALIDATE(attr.minimum() == 0, "Attr '", attr.name(), 290 "' with has_minimum = false but minimum ", attr.minimum(), 291 " not equal to default of 0"); 292 } 293 294 // Validate allowed_values 295 if (attr.has_allowed_values()) { 296 const string list_type = 297 is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")"); 298 TF_RETURN_WITH_CONTEXT_IF_ERROR( 299 AttrValueHasType(attr.allowed_values(), list_type), " for attr '", 300 attr.name(), "' in Op '", op_def.name(), "'"); 301 } 302 303 // Validate default_value (after we have validated the rest of the attr, 304 // so we can use ValidateAttrValue()). 305 if (attr.has_default_value()) { 306 TF_RETURN_WITH_CONTEXT_IF_ERROR( 307 ValidateAttrValue(attr.default_value(), attr), " in Op '", 308 op_def.name(), "'"); 309 } 310 } 311 312 for (const auto& arg : op_def.input_arg()) { 313 TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names)); 314 } 315 316 for (const auto& arg : op_def.output_arg()) { 317 TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names)); 318 } 319 320 return Status::OK(); 321 } 322 323 #undef VALIDATE 324 325 Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) { 326 if (op_def.has_deprecation()) { 327 const OpDeprecation& dep = op_def.deprecation(); 328 if (graph_def_version >= dep.version()) { 329 return errors::Unimplemented( 330 "Op ", op_def.name(), " is not available in GraphDef version ", 331 graph_def_version, ". It has been removed in version ", dep.version(), 332 ". ", dep.explanation(), "."); 333 } else { 334 // Warn only once for each op name, and do it in a threadsafe manner. 335 static mutex mu(LINKER_INITIALIZED); 336 static std::unordered_set<string> warned; 337 bool warn; 338 { 339 mutex_lock lock(mu); 340 warn = warned.insert(op_def.name()).second; 341 } 342 if (warn) { 343 LOG(WARNING) << "Op " << op_def.name() << " is deprecated." 344 << " It will cease to work in GraphDef version " 345 << dep.version() << ". " << dep.explanation() << "."; 346 } 347 } 348 } 349 return Status::OK(); 350 } 351 352 namespace { 353 354 string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) { 355 string ret; 356 for (const OpDef::ArgDef& arg : args) { 357 if (!ret.empty()) strings::StrAppend(&ret, ", "); 358 strings::StrAppend(&ret, arg.name(), ":"); 359 if (arg.is_ref()) strings::StrAppend(&ret, "Ref("); 360 if (!arg.number_attr().empty()) { 361 strings::StrAppend(&ret, arg.number_attr(), "*"); 362 } 363 if (arg.type() != DT_INVALID) { 364 strings::StrAppend(&ret, DataTypeString(arg.type())); 365 } else { 366 strings::StrAppend(&ret, arg.type_attr()); 367 } 368 if (arg.is_ref()) strings::StrAppend(&ret, ")"); 369 } 370 return ret; 371 } 372 373 } // namespace 374 375 string SummarizeOpDef(const OpDef& op_def) { 376 string ret = strings::StrCat("Op<name=", op_def.name()); 377 strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()), 378 " -> ", SummarizeArgs(op_def.output_arg())); 379 for (int i = 0; i < op_def.attr_size(); ++i) { 380 strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":", 381 op_def.attr(i).type()); 382 if (op_def.attr(i).has_default_value()) { 383 strings::StrAppend(&ret, ",default=", 384 SummarizeAttrValue(op_def.attr(i).default_value())); 385 } 386 if (op_def.attr(i).has_minimum()) { 387 strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum()); 388 } 389 if (op_def.attr(i).has_allowed_values()) { 390 strings::StrAppend(&ret, ",allowed=", 391 SummarizeAttrValue(op_def.attr(i).allowed_values())); 392 } 393 } 394 if (op_def.is_commutative()) { 395 strings::StrAppend(&ret, "; is_commutative=true"); 396 } 397 if (op_def.is_aggregate()) { 398 strings::StrAppend(&ret, "; is_aggregate=true"); 399 } 400 if (op_def.is_stateful()) { 401 strings::StrAppend(&ret, "; is_stateful=true"); 402 } 403 if (op_def.allows_uninitialized_input()) { 404 strings::StrAppend(&ret, "; allows_uninitialized_input=true"); 405 } 406 strings::StrAppend(&ret, ">"); 407 return ret; 408 } 409 410 namespace { 411 412 // Returns true if every element of `sub` is contained in `super`. 413 template <class T> 414 bool IsSubsetOf(const T& sub, const T& super) { 415 for (const auto& o : sub) { 416 bool found = false; 417 for (const auto& n : super) { 418 if (o == n) { 419 found = true; 420 break; 421 } 422 } 423 if (!found) return false; 424 } 425 return true; 426 } 427 428 bool MoreRestrictive(const OpDef::AttrDef& old_attr, 429 const OpDef::AttrDef& new_attr) { 430 // Anything -> no restriction : not more restrictive. 431 if (!new_attr.has_allowed_values()) return false; 432 // No restriction -> restriction : more restrictive. 433 if (!old_attr.has_allowed_values()) return true; 434 // If anything that was previously allowed is no longer allowed: 435 // more restrictive. 436 if (!IsSubsetOf(old_attr.allowed_values().list().type(), 437 new_attr.allowed_values().list().type())) { 438 return true; 439 } 440 if (!IsSubsetOf(old_attr.allowed_values().list().s(), 441 new_attr.allowed_values().list().s())) { 442 return true; 443 } 444 return false; 445 } 446 447 string AllowedStr(const OpDef::AttrDef& attr) { 448 if (!attr.has_allowed_values()) return "no restriction"; 449 return SummarizeAttrValue(attr.allowed_values()); 450 } 451 452 string DefaultAttrStr(const OpDef::AttrDef& attr) { 453 if (!attr.has_default_value()) return "no default"; 454 return SummarizeAttrValue(attr.default_value()); 455 } 456 457 bool HigherMinimum(const OpDef::AttrDef& old_attr, 458 const OpDef::AttrDef& new_attr) { 459 // Anything -> no restriction : not more restrictive. 460 if (!new_attr.has_minimum()) return false; 461 // No restriction -> restriction : more restrictive. 462 if (!old_attr.has_minimum()) return true; 463 // If anything that was previously allowed is no longer allowed: 464 // more restrictive. 465 return new_attr.minimum() > old_attr.minimum(); 466 } 467 468 string MinStr(const OpDef::AttrDef& attr) { 469 if (!attr.has_minimum()) return "no minimum"; 470 return strings::StrCat(attr.minimum()); 471 } 472 473 typedef std::unordered_map<string, const OpDef::AttrDef*> AttrMap; 474 void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) { 475 for (const auto& attr : op_def.attr()) { 476 (*attr_map)[attr.name()] = &attr; 477 } 478 } 479 480 // Add a comma to *s every call but the first (*add_comma should be 481 // initialized to false). 482 void AddComma(string* s, bool* add_comma) { 483 if (*add_comma) { 484 strings::StrAppend(s, ", "); 485 } else { 486 *add_comma = true; 487 } 488 } 489 490 // Will add the `name` from arg if name is true. 491 void AddName(string* s, bool name, const OpDef::ArgDef& arg) { 492 if (name) { 493 strings::StrAppend(s, arg.name(), ":"); 494 } 495 } 496 497 // Compute a signature for either inputs or outputs that will be the 498 // same for both the old and new OpDef if they are compatible. We 499 // assume that new_attrs is a superset of old_attrs, and that any attr 500 // in the difference has a default. Our strategy is to make a list of 501 // types, where the types are things like: 502 // * "int32", "float", etc., 503 // * "T" for some attr "T" in old_attrs, or 504 // * "N * type" for "N" either some attr in old_attrs. 505 // 506 // We get the types by either using the attrs in args if they are in 507 // old_attrs, or substituting the default value from new_attrs. 508 string ComputeArgSignature( 509 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, 510 const AttrMap& old_attrs, const AttrMap& new_attrs, std::vector<bool>* ref, 511 bool names) { 512 string s; 513 bool add_comma = false; 514 for (const OpDef::ArgDef& arg : args) { 515 if (!arg.type_list_attr().empty()) { 516 const OpDef::AttrDef* old_attr = 517 gtl::FindPtrOrNull(old_attrs, arg.type_list_attr()); 518 if (old_attr) { 519 // Both old and new have the list(type) attr, so can use it directly. 520 AddComma(&s, &add_comma); 521 AddName(&s, names, arg); 522 strings::StrAppend(&s, arg.type_list_attr()); 523 ref->push_back(arg.is_ref()); 524 } else { 525 // Missing the list(type) attr in the old, so use the default 526 // value for the attr from new instead. 527 const OpDef::AttrDef* new_attr = 528 gtl::FindPtrOrNull(new_attrs, arg.type_list_attr()); 529 const auto& type_list = new_attr->default_value().list().type(); 530 if (type_list.empty()) continue; 531 for (int i = 0; i < type_list.size(); ++i) { 532 AddComma(&s, &add_comma); 533 AddName(&s, names, arg); 534 strings::StrAppend( 535 &s, DataTypeString(static_cast<DataType>(type_list.Get(i)))); 536 ref->push_back(arg.is_ref()); 537 } 538 } 539 } else { 540 int num = 1; // How many input/outputs does this represent? 541 string type; // What is the type of this arg? 542 AddName(&type, names, arg); 543 if (!arg.number_attr().empty()) { 544 // N * type case. 545 const OpDef::AttrDef* old_attr = 546 gtl::FindPtrOrNull(old_attrs, arg.number_attr()); 547 if (old_attr) { 548 // Both old and new have the number attr, so can use it directly. 549 strings::StrAppend(&type, arg.number_attr(), " * "); 550 } else { 551 // Missing the number attr in the old, so use the default 552 // value for the attr from new instead. 553 const OpDef::AttrDef* new_attr = 554 gtl::FindPtrOrNull(new_attrs, arg.number_attr()); 555 num = new_attr->default_value().i(); 556 } 557 } 558 559 if (arg.type() != DT_INVALID) { 560 // int32, float, etc. case 561 strings::StrAppend(&type, DataTypeString(arg.type())); 562 } else { 563 const OpDef::AttrDef* old_attr = 564 gtl::FindPtrOrNull(old_attrs, arg.type_attr()); 565 if (old_attr) { 566 // Both old and new have the type attr, so can use it directly. 567 strings::StrAppend(&type, arg.type_attr()); 568 } else { 569 // Missing the type attr in the old, so use the default 570 // value for the attr from new instead. 571 const OpDef::AttrDef* new_attr = 572 gtl::FindPtrOrNull(new_attrs, arg.type_attr()); 573 strings::StrAppend(&type, 574 DataTypeString(new_attr->default_value().type())); 575 } 576 } 577 578 // Record `num` * `type` in the signature. 579 for (int i = 0; i < num; ++i) { 580 AddComma(&s, &add_comma); 581 strings::StrAppend(&s, type); 582 ref->push_back(arg.is_ref()); 583 } 584 } 585 } 586 587 return s; 588 } 589 590 } // namespace 591 592 Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) { 593 #define VALIDATE(CONDITION, ...) \ 594 if (!(CONDITION)) { \ 595 return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \ 596 "; old: ", SummarizeOpDef(old_op), \ 597 "; new: ", SummarizeOpDef(new_op)); \ 598 } 599 600 VALIDATE(old_op.name() == new_op.name(), "Name mismatch"); 601 602 AttrMap new_attrs, old_attrs; 603 FillAttrMap(old_op, &old_attrs); 604 FillAttrMap(new_op, &new_attrs); 605 for (const auto& old_attr : old_op.attr()) { 606 const OpDef::AttrDef* new_attr = 607 gtl::FindPtrOrNull(new_attrs, old_attr.name()); 608 VALIDATE(new_attr != nullptr, "Attr '", old_attr.name(), "' removed"); 609 VALIDATE(old_attr.type() == new_attr->type(), "Attr '", old_attr.name(), 610 "' changed type '", old_attr.type(), "' -> '", new_attr->type(), 611 "'"); 612 VALIDATE(!MoreRestrictive(old_attr, *new_attr), "Attr '", old_attr.name(), 613 "' has a stricter set of allowed values; from ", 614 AllowedStr(old_attr), " to ", AllowedStr(*new_attr)); 615 VALIDATE(!HigherMinimum(old_attr, *new_attr), "Attr '", old_attr.name(), 616 "' has a higher minimum; from ", MinStr(old_attr), " to ", 617 MinStr(*new_attr)); 618 } 619 620 for (const auto& new_attr : new_op.attr()) { 621 const OpDef::AttrDef* old_attr = 622 gtl::FindPtrOrNull(old_attrs, new_attr.name()); 623 VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '", 624 new_attr.name(), "' added without default"); 625 } 626 627 std::vector<bool> old_in_ref, new_in_ref, old_out_ref, new_out_ref; 628 const string old_in_sig = ComputeArgSignature( 629 old_op.input_arg(), old_attrs, new_attrs, &old_in_ref, false /* names */); 630 const string new_in_sig = ComputeArgSignature( 631 new_op.input_arg(), old_attrs, new_attrs, &new_in_ref, false /* names */); 632 VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig, 633 "' vs. '", new_in_sig, "'"); 634 VALIDATE(old_in_ref.size() == new_in_ref.size(), // Should not happen 635 "Unexpected change in input ref lists."); 636 for (int i = 0; i < old_in_ref.size(); ++i) { 637 // Allowed to remove "ref" from an input (or leave it unchanged). 638 VALIDATE(old_in_ref[i] || !new_in_ref[i], "Input ", i, 639 " changed from non-ref to ref"); 640 } 641 642 const string old_out_sig = 643 ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs, 644 &old_out_ref, true /* names */); 645 const string new_out_sig = 646 ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs, 647 &new_out_ref, true /* names */); 648 VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '", 649 old_out_sig, "' vs. '", new_out_sig, "'"); 650 VALIDATE(old_out_ref.size() == new_out_ref.size(), // Should not happen 651 "Unexpected change in output ref lists"); 652 for (int i = 0; i < old_out_ref.size(); ++i) { 653 // Allowed to add "ref" to an output (or leave it unchanged). 654 VALIDATE(!old_out_ref[i] || new_out_ref[i], "Output ", i, 655 " changed from ref to non-ref"); 656 } 657 658 return Status::OK(); 659 } 660 661 Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, 662 const OpDef& penultimate_op, 663 const OpDef& new_op) { 664 AttrMap new_attrs, old_attrs; 665 FillAttrMap(old_op, &old_attrs); 666 FillAttrMap(new_op, &new_attrs); 667 668 for (const auto& penultimate_attr : penultimate_op.attr()) { 669 const OpDef::AttrDef* old_attr = 670 gtl::FindPtrOrNull(old_attrs, penultimate_attr.name()); 671 if (old_attr != nullptr) continue; // attr wasn't added 672 const OpDef::AttrDef* new_attr = 673 gtl::FindPtrOrNull(new_attrs, penultimate_attr.name()); 674 675 // These shouldn't happen if the op passed OpDefCompatible(). 676 if (new_attr == nullptr) { 677 return errors::InvalidArgument("Missing attr '", penultimate_attr.name(), 678 "' in op: ", SummarizeOpDef(new_op)); 679 } 680 if (!penultimate_attr.has_default_value() || 681 !new_attr->has_default_value()) { 682 return errors::InvalidArgument("Missing default for attr '", 683 penultimate_attr.name(), 684 "' in op: ", SummarizeOpDef(new_op)); 685 } 686 687 // Actually test that the attr's default value hasn't changed. 688 if (!AreAttrValuesEqual(penultimate_attr.default_value(), 689 new_attr->default_value())) { 690 return errors::InvalidArgument( 691 "Can't change default value for attr '", penultimate_attr.name(), 692 "' from ", SummarizeAttrValue(penultimate_attr.default_value()), 693 " in op: ", SummarizeOpDef(new_op)); 694 } 695 } 696 697 return Status::OK(); 698 } 699 700 Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) { 701 AttrMap new_attrs, old_attrs; 702 FillAttrMap(old_op, &old_attrs); 703 FillAttrMap(new_op, &new_attrs); 704 705 for (const auto& old_attr : old_op.attr()) { 706 const OpDef::AttrDef* new_attr = 707 gtl::FindPtrOrNull(new_attrs, old_attr.name()); 708 if (new_attr == nullptr) continue; 709 if (old_attr.has_default_value() != new_attr->has_default_value()) { 710 return errors::InvalidArgument( 711 "Attr '", old_attr.name(), "' has added/removed it's default; ", 712 "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr)); 713 } 714 if (old_attr.has_default_value() && 715 !AreAttrValuesEqual(old_attr.default_value(), 716 new_attr->default_value())) { 717 return errors::InvalidArgument( 718 "Attr '", old_attr.name(), "' has changed it's default value; ", 719 "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr)); 720 } 721 } 722 723 return Status::OK(); 724 } 725 726 void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) { 727 for (int i = 0; i < op_def->input_arg_size(); ++i) { 728 op_def->mutable_input_arg(i)->clear_description(); 729 } 730 for (int i = 0; i < op_def->output_arg_size(); ++i) { 731 op_def->mutable_output_arg(i)->clear_description(); 732 } 733 for (int i = 0; i < op_def->attr_size(); ++i) { 734 op_def->mutable_attr(i)->clear_description(); 735 } 736 op_def->clear_summary(); 737 op_def->clear_description(); 738 } 739 740 void RemoveDescriptionsFromOpDef(OpDef* op_def) { 741 RemoveNonDeprecationDescriptionsFromOpDef(op_def); 742 if (op_def->has_deprecation()) { 743 op_def->mutable_deprecation()->clear_explanation(); 744 } 745 } 746 747 void RemoveDescriptionsFromOpList(OpList* op_list) { 748 for (int i = 0; i < op_list->op_size(); ++i) { 749 OpDef* op_def = op_list->mutable_op(i); 750 RemoveDescriptionsFromOpDef(op_def); 751 } 752 } 753 754 bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2) { 755 #ifndef TENSORFLOW_LITE_PROTOS 756 DCHECK_EQ(7, a1.GetDescriptor()->field_count()) 757 << "Please modify these equality and hash functions to reflect the " 758 "changes to the AttrDef protobuf"; 759 #endif // TENSORFLOW_LITE_PROTOS 760 761 if (a1.name() != a2.name()) return false; 762 if (a1.type() != a2.type()) return false; 763 if (a1.description() != a2.description()) return false; 764 if (a1.has_minimum() != a2.has_minimum()) return false; 765 if (a1.has_minimum() && a1.minimum() != a2.minimum()) return false; 766 if (!AreAttrValuesEqual(a1.default_value(), a2.default_value())) return false; 767 if (!AreAttrValuesEqual(a1.allowed_values(), a2.allowed_values())) 768 return false; 769 return true; 770 } 771 772 uint64 AttrDefHash(const OpDef::AttrDef& a) { 773 uint64 h = Hash64(a.name()); 774 h = Hash64(a.type().data(), a.type().size(), h); 775 h = Hash64Combine(AttrValueHash(a.default_value()), h); 776 h = Hash64(a.description().data(), a.description().size(), h); 777 h = Hash64Combine(static_cast<uint64>(a.has_minimum()), h); 778 h = Hash64Combine(static_cast<uint64>(a.minimum()), h); 779 h = Hash64Combine(AttrValueHash(a.allowed_values()), h); 780 return h; 781 } 782 783 bool RepeatedAttrDefEqual( 784 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a1, 785 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a2) { 786 std::unordered_map<string, const OpDef::AttrDef*> a1_set; 787 for (const OpDef::AttrDef& def : a1) { 788 DCHECK(a1_set.find(def.name()) == a1_set.end()) 789 << "AttrDef names must be unique, but '" << def.name() 790 << "' appears more than once"; 791 a1_set[def.name()] = &def; 792 } 793 for (const OpDef::AttrDef& def : a2) { 794 auto iter = a1_set.find(def.name()); 795 if (iter == a1_set.end()) return false; 796 if (!AttrDefEqual(*iter->second, def)) return false; 797 a1_set.erase(iter); 798 } 799 if (!a1_set.empty()) return false; 800 return true; 801 } 802 803 uint64 RepeatedAttrDefHash( 804 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a) { 805 // Insert AttrDefs into map to deterministically sort by name 806 std::map<string, const OpDef::AttrDef*> a_set; 807 for (const OpDef::AttrDef& def : a) { 808 a_set[def.name()] = &def; 809 } 810 // Iterate and combines hashes of keys and values 811 uint64 h = 0xDECAFCAFFE; 812 for (const auto& pair : a_set) { 813 h = Hash64(pair.first.data(), pair.first.size(), h); 814 h = Hash64Combine(AttrDefHash(*pair.second), h); 815 } 816 return h; 817 } 818 819 bool OpDefEqual(const OpDef& o1, const OpDef& o2) { 820 // attr order doesn't matter. 821 // Compare it separately here instead of serializing below. 822 if (!RepeatedAttrDefEqual(o1.attr(), o2.attr())) return false; 823 824 // Clear attr field, serialize, and compare serialized strings 825 OpDef o1_copy = o1; 826 OpDef o2_copy = o2; 827 o1_copy.clear_attr(); 828 o2_copy.clear_attr(); 829 string s1, s2; 830 SerializeToStringDeterministic(o1_copy, &s1); 831 SerializeToStringDeterministic(o2_copy, &s2); 832 if (s1 != s2) return false; 833 return true; 834 } 835 836 uint64 OpDefHash(const OpDef& o) { 837 uint64 h = RepeatedAttrDefHash(o.attr()); 838 OpDef o_copy = o; 839 o_copy.clear_attr(); 840 string s; 841 SerializeToStringDeterministic(o_copy, &s); 842 return Hash64(s.data(), s.size(), h); 843 } 844 845 } // namespace tensorflow 846