1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/op_def_builder.h" 17 18 #include <limits> 19 #include <vector> 20 #include "tensorflow/core/framework/attr_value.pb.h" 21 #include "tensorflow/core/framework/attr_value_util.h" 22 #include "tensorflow/core/framework/op_def_util.h" 23 #include "tensorflow/core/framework/types.h" 24 #include "tensorflow/core/lib/core/errors.h" 25 #include "tensorflow/core/lib/gtl/array_slice.h" 26 #include "tensorflow/core/lib/strings/scanner.h" 27 #include "tensorflow/core/lib/strings/str_util.h" 28 #include "tensorflow/core/lib/strings/strcat.h" 29 30 using ::tensorflow::strings::Scanner; 31 32 namespace tensorflow { 33 34 namespace { 35 36 string AttrError(StringPiece orig, const string& op_name) { 37 return strings::StrCat(" from Attr(\"", orig, "\") for Op ", op_name); 38 } 39 40 bool ConsumeAttrName(StringPiece* sp, StringPiece* out) { 41 return Scanner(*sp) 42 .One(Scanner::LETTER) 43 .Any(Scanner::LETTER_DIGIT_UNDERSCORE) 44 .StopCapture() 45 .AnySpace() 46 .OneLiteral(":") 47 .AnySpace() 48 .GetResult(sp, out); 49 } 50 51 bool ConsumeListPrefix(StringPiece* sp) { 52 return Scanner(*sp) 53 .OneLiteral("list") 54 .AnySpace() 55 .OneLiteral("(") 56 .AnySpace() 57 .GetResult(sp); 58 } 59 60 bool ConsumeQuotedString(char quote_ch, StringPiece* sp, StringPiece* out) { 61 const string quote_str(1, quote_ch); 62 return Scanner(*sp) 63 .OneLiteral(quote_str.c_str()) 64 .RestartCapture() 65 .ScanEscapedUntil(quote_ch) 66 .StopCapture() 67 .OneLiteral(quote_str.c_str()) 68 .AnySpace() 69 .GetResult(sp, out); 70 } 71 72 bool ConsumeAttrType(StringPiece* sp, StringPiece* out) { 73 return Scanner(*sp) 74 .Many(Scanner::LOWERLETTER_DIGIT) 75 .StopCapture() 76 .AnySpace() 77 .GetResult(sp, out); 78 } 79 80 bool ConsumeAttrNumber(StringPiece* sp, int64* out) { 81 Scanner scan(*sp); 82 StringPiece match; 83 StringPiece remaining; 84 85 scan.AnySpace().RestartCapture(); 86 if (scan.Peek() == '-') { 87 scan.OneLiteral("-"); 88 } 89 if (!scan.Many(Scanner::DIGIT) 90 .StopCapture() 91 .AnySpace() 92 .GetResult(&remaining, &match)) { 93 return false; 94 } 95 int64 value = 0; 96 if (!strings::safe_strto64(match, &value)) { 97 return false; 98 } 99 *out = value; 100 *sp = remaining; 101 return true; 102 } 103 104 #define VERIFY(expr, ...) \ 105 do { \ 106 if (!(expr)) { \ 107 errors->push_back( \ 108 strings::StrCat(__VA_ARGS__, AttrError(orig, op_def->name()))); \ 109 return; \ 110 } \ 111 } while (false) 112 113 bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) { 114 auto capture_begin = sp->begin(); 115 if (sp->Consume("numbertype") || sp->Consume("numerictype") || 116 sp->Consume("quantizedtype") || sp->Consume("realnumbertype") || 117 sp->Consume("realnumberictype")) { 118 *out = StringPiece(capture_begin, sp->begin() - capture_begin); 119 return true; 120 } 121 return false; 122 } 123 124 bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) { 125 if (type_string == "numbertype" || type_string == "numerictype") { 126 for (DataType dt : NumberTypes()) { 127 allowed->mutable_list()->add_type(dt); 128 } 129 } else if (type_string == "quantizedtype") { 130 for (DataType dt : QuantizedTypes()) { 131 allowed->mutable_list()->add_type(dt); 132 } 133 } else if (type_string == "realnumbertype" || 134 type_string == "realnumerictype") { 135 for (DataType dt : RealNumberTypes()) { 136 allowed->mutable_list()->add_type(dt); 137 } 138 } else { 139 return false; 140 } 141 return true; 142 } 143 144 void FinalizeAttr(StringPiece spec, OpDef* op_def, 145 std::vector<string>* errors) { 146 OpDef::AttrDef* attr = op_def->add_attr(); 147 StringPiece orig(spec); 148 149 // Parse "<name>:" at the beginning. 150 StringPiece tmp_name; 151 VERIFY(ConsumeAttrName(&spec, &tmp_name), "Trouble parsing '<name>:'"); 152 attr->set_name(tmp_name.data(), tmp_name.size()); 153 154 // Read "<type>" or "list(<type>)". 155 bool is_list = ConsumeListPrefix(&spec); 156 string type; 157 StringPiece type_string; // Used if type == "type" 158 if (spec.Consume("string")) { 159 type = "string"; 160 } else if (spec.Consume("int")) { 161 type = "int"; 162 } else if (spec.Consume("float")) { 163 type = "float"; 164 } else if (spec.Consume("bool")) { 165 type = "bool"; 166 } else if (spec.Consume("type")) { 167 type = "type"; 168 } else if (spec.Consume("shape")) { 169 type = "shape"; 170 } else if (spec.Consume("tensor")) { 171 type = "tensor"; 172 } else if (spec.Consume("func")) { 173 type = "func"; 174 } else if (ConsumeCompoundAttrType(&spec, &type_string)) { 175 type = "type"; 176 AttrValue* allowed = attr->mutable_allowed_values(); 177 VERIFY(ProcessCompoundType(type_string, allowed), 178 "Expected to see a compound type, saw: ", type_string); 179 } else if (spec.Consume("{")) { 180 // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }" 181 AttrValue* allowed = attr->mutable_allowed_values(); 182 str_util::RemoveLeadingWhitespace(&spec); 183 if (spec.starts_with("\"") || spec.starts_with("'")) { 184 type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }" 185 while (true) { 186 StringPiece escaped_string; 187 VERIFY(ConsumeQuotedString('"', &spec, &escaped_string) || 188 ConsumeQuotedString('\'', &spec, &escaped_string), 189 "Trouble parsing allowed string at '", spec, "'"); 190 string unescaped; 191 string error; 192 VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error), 193 "Trouble unescaping \"", escaped_string, 194 "\", got error: ", error); 195 allowed->mutable_list()->add_s(unescaped); 196 if (spec.Consume(",")) { 197 str_util::RemoveLeadingWhitespace(&spec); 198 if (spec.Consume("}")) break; // Allow ending with ", }". 199 } else { 200 VERIFY(spec.Consume("}"), 201 "Expected , or } after strings in list, not: '", spec, "'"); 202 break; 203 } 204 } 205 } else { // "{ bool, numbertype, string }" 206 type = "type"; 207 while (true) { 208 VERIFY(ConsumeAttrType(&spec, &type_string), 209 "Trouble parsing type string at '", spec, "'"); 210 if (ProcessCompoundType(type_string, allowed)) { 211 // Processed a compound type. 212 } else { 213 DataType dt; 214 VERIFY(DataTypeFromString(type_string, &dt), 215 "Unrecognized type string '", type_string, "'"); 216 allowed->mutable_list()->add_type(dt); 217 } 218 if (spec.Consume(",")) { 219 str_util::RemoveLeadingWhitespace(&spec); 220 if (spec.Consume("}")) break; // Allow ending with ", }". 221 } else { 222 VERIFY(spec.Consume("}"), 223 "Expected , or } after types in list, not: '", spec, "'"); 224 break; 225 } 226 } 227 } 228 } else { // if spec.Consume("{") 229 VERIFY(false, "Trouble parsing type string at '", spec, "'"); 230 } 231 str_util::RemoveLeadingWhitespace(&spec); 232 233 // Write the type into *attr. 234 if (is_list) { 235 VERIFY(spec.Consume(")"), "Expected ) to close 'list(', not: '", spec, "'"); 236 str_util::RemoveLeadingWhitespace(&spec); 237 attr->set_type(strings::StrCat("list(", type, ")")); 238 } else { 239 attr->set_type(type); 240 } 241 242 // Read optional minimum constraint at the end. 243 if ((is_list || type == "int") && spec.Consume(">=")) { 244 int64 min_limit = -999; 245 VERIFY(ConsumeAttrNumber(&spec, &min_limit), 246 "Could not parse integer lower limit after '>=', found '", spec, 247 "' instead"); 248 attr->set_has_minimum(true); 249 attr->set_minimum(min_limit); 250 } 251 252 // Parse default value, if present. 253 if (spec.Consume("=")) { 254 str_util::RemoveLeadingWhitespace(&spec); 255 VERIFY(ParseAttrValue(attr->type(), spec, attr->mutable_default_value()), 256 "Could not parse default value '", spec, "'"); 257 } else { 258 VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end"); 259 } 260 } 261 262 #undef VERIFY 263 264 string InOutError(bool is_output, StringPiece orig, const string& op_name) { 265 return strings::StrCat(" from ", is_output ? "Output" : "Input", "(\"", orig, 266 "\") for Op ", op_name); 267 } 268 269 bool ConsumeInOutName(StringPiece* sp, StringPiece* out) { 270 return Scanner(*sp) 271 .One(Scanner::LOWERLETTER) 272 .Any(Scanner::LOWERLETTER_DIGIT_UNDERSCORE) 273 .StopCapture() 274 .AnySpace() 275 .OneLiteral(":") 276 .AnySpace() 277 .GetResult(sp, out); 278 } 279 280 bool ConsumeInOutRefOpen(StringPiece* sp) { 281 return Scanner(*sp) 282 .OneLiteral("Ref") 283 .AnySpace() 284 .OneLiteral("(") 285 .AnySpace() 286 .GetResult(sp); 287 } 288 289 bool ConsumeInOutRefClose(StringPiece* sp) { 290 return Scanner(*sp).OneLiteral(")").AnySpace().GetResult(sp); 291 } 292 293 bool ConsumeInOutNameOrType(StringPiece* sp, StringPiece* out) { 294 return Scanner(*sp) 295 .One(Scanner::LETTER) 296 .Any(Scanner::LETTER_DIGIT_UNDERSCORE) 297 .StopCapture() 298 .AnySpace() 299 .GetResult(sp, out); 300 } 301 302 bool ConsumeInOutTimesType(StringPiece* sp, StringPiece* out) { 303 return Scanner(*sp) 304 .OneLiteral("*") 305 .AnySpace() 306 .RestartCapture() 307 .One(Scanner::LETTER) 308 .Any(Scanner::LETTER_DIGIT_UNDERSCORE) 309 .StopCapture() 310 .AnySpace() 311 .GetResult(sp, out); 312 } 313 314 #define VERIFY(expr, ...) \ 315 do { \ 316 if (!(expr)) { \ 317 errors->push_back(strings::StrCat( \ 318 __VA_ARGS__, InOutError(is_output, orig, op_def->name()))); \ 319 return; \ 320 } \ 321 } while (false) 322 323 void FinalizeInputOrOutput(StringPiece spec, bool is_output, OpDef* op_def, 324 std::vector<string>* errors) { 325 OpDef::ArgDef* arg = 326 is_output ? op_def->add_output_arg() : op_def->add_input_arg(); 327 328 StringPiece orig(spec); 329 330 // Parse "<name>:" at the beginning. 331 StringPiece tmp_name; 332 VERIFY(ConsumeInOutName(&spec, &tmp_name), "Trouble parsing 'name:'"); 333 arg->set_name(tmp_name.data(), tmp_name.size()); 334 335 // Detect "Ref(...)". 336 if (ConsumeInOutRefOpen(&spec)) { 337 arg->set_is_ref(true); 338 } 339 340 { // Parse "<name|type>" or "<name>*<name|type>". 341 StringPiece first, second, type_or_attr; 342 VERIFY(ConsumeInOutNameOrType(&spec, &first), 343 "Trouble parsing either a type or an attr name at '", spec, "'"); 344 if (ConsumeInOutTimesType(&spec, &second)) { 345 arg->set_number_attr(first.data(), first.size()); 346 type_or_attr = second; 347 } else { 348 type_or_attr = first; 349 } 350 DataType dt; 351 if (DataTypeFromString(type_or_attr, &dt)) { 352 arg->set_type(dt); 353 } else { 354 const OpDef::AttrDef* attr = FindAttr(type_or_attr, *op_def); 355 VERIFY(attr != nullptr, "Reference to unknown attr '", type_or_attr, "'"); 356 if (attr->type() == "type") { 357 arg->set_type_attr(type_or_attr.data(), type_or_attr.size()); 358 } else { 359 VERIFY(attr->type() == "list(type)", "Reference to attr '", 360 type_or_attr, "' with type ", attr->type(), 361 " that isn't type or list(type)"); 362 arg->set_type_list_attr(type_or_attr.data(), type_or_attr.size()); 363 } 364 } 365 } 366 367 // Closing ) for Ref(. 368 if (arg->is_ref()) { 369 VERIFY(ConsumeInOutRefClose(&spec), 370 "Did not find closing ')' for 'Ref(', instead found: '", spec, "'"); 371 } 372 373 // Should not have anything else. 374 VERIFY(spec.empty(), "Extra '", spec, "' unparsed at the end"); 375 376 // Int attrs that are the length of an input or output get a default 377 // minimum of 1. 378 if (!arg->number_attr().empty()) { 379 OpDef::AttrDef* attr = FindAttrMutable(arg->number_attr(), op_def); 380 if (attr != nullptr && !attr->has_minimum()) { 381 attr->set_has_minimum(true); 382 attr->set_minimum(1); 383 } 384 } else if (!arg->type_list_attr().empty()) { 385 // If an input or output has type specified by a list(type) attr, 386 // it gets a default minimum of 1 as well. 387 OpDef::AttrDef* attr = FindAttrMutable(arg->type_list_attr(), op_def); 388 if (attr != nullptr && attr->type() == "list(type)" && 389 !attr->has_minimum()) { 390 attr->set_has_minimum(true); 391 attr->set_minimum(1); 392 } 393 } 394 395 // If the arg's dtype is resource we should mark the op as stateful as it 396 // likely touches a resource manager. This deliberately doesn't cover inputs / 397 // outputs which resolve to resource via Attrs as those mostly operate on 398 // resource handles as an opaque type (as opposed to ops which explicitly take 399 // / produce resources). 400 if (arg->type() == DT_RESOURCE) { 401 op_def->set_is_stateful(true); 402 } 403 } 404 405 #undef VERIFY 406 407 int num_leading_spaces(StringPiece s) { 408 size_t i = 0; 409 while (i < s.size() && s[i] == ' ') { 410 ++i; 411 } 412 return i; 413 } 414 415 bool ConsumeDocNameColon(StringPiece* sp, StringPiece* out) { 416 return Scanner(*sp) 417 .One(Scanner::LETTER) 418 .Any(Scanner::LETTER_DIGIT_UNDERSCORE) 419 .StopCapture() 420 .AnySpace() 421 .OneLiteral(":") 422 .AnySpace() 423 .GetResult(sp, out); 424 } 425 426 bool IsDocNameColon(StringPiece s) { 427 return ConsumeDocNameColon(&s, nullptr /* out */); 428 } 429 430 void FinalizeDoc(const string& text, OpDef* op_def, 431 std::vector<string>* errors) { 432 std::vector<string> lines = str_util::Split(text, '\n'); 433 434 // Remove trailing spaces. 435 for (string& line : lines) { 436 str_util::StripTrailingWhitespace(&line); 437 } 438 439 // First non-blank line -> summary. 440 int l = 0; 441 while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l; 442 if (static_cast<size_t>(l) < lines.size()) { 443 op_def->set_summary(lines[l]); 444 ++l; 445 } 446 while (static_cast<size_t>(l) < lines.size() && lines[l].empty()) ++l; 447 448 // Lines until we see name: -> description. 449 int start_l = l; 450 while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) { 451 ++l; 452 } 453 int end_l = l; 454 // Trim trailing blank lines from the description. 455 while (start_l < end_l && lines[end_l - 1].empty()) --end_l; 456 string desc = str_util::Join( 457 gtl::ArraySlice<string>(lines.data() + start_l, end_l - start_l), "\n"); 458 if (!desc.empty()) op_def->set_description(desc); 459 460 // name: description 461 // possibly continued on the next line 462 // if so, we remove the minimum indent 463 StringPiece name; 464 std::vector<StringPiece> description; 465 while (static_cast<size_t>(l) < lines.size()) { 466 description.clear(); 467 description.push_back(lines[l]); 468 ConsumeDocNameColon(&description.back(), &name); 469 ++l; 470 while (static_cast<size_t>(l) < lines.size() && !IsDocNameColon(lines[l])) { 471 description.push_back(lines[l]); 472 ++l; 473 } 474 // Remove any trailing blank lines. 475 while (!description.empty() && description.back().empty()) { 476 description.pop_back(); 477 } 478 // Compute the minimum indent of all lines after the first. 479 int min_indent = -1; 480 for (size_t i = 1; i < description.size(); ++i) { 481 if (!description[i].empty()) { 482 int indent = num_leading_spaces(description[i]); 483 if (min_indent < 0 || indent < min_indent) min_indent = indent; 484 } 485 } 486 // Remove min_indent spaces from all lines after the first. 487 for (size_t i = 1; i < description.size(); ++i) { 488 if (!description[i].empty()) description[i].remove_prefix(min_indent); 489 } 490 // Concatenate lines into a single string. 491 const string complete(str_util::Join(description, "\n")); 492 493 // Find name. 494 bool found = false; 495 for (int i = 0; !found && i < op_def->input_arg_size(); ++i) { 496 if (op_def->input_arg(i).name() == name) { 497 op_def->mutable_input_arg(i)->set_description(complete); 498 found = true; 499 } 500 } 501 for (int i = 0; !found && i < op_def->output_arg_size(); ++i) { 502 if (op_def->output_arg(i).name() == name) { 503 op_def->mutable_output_arg(i)->set_description(complete); 504 found = true; 505 } 506 } 507 for (int i = 0; !found && i < op_def->attr_size(); ++i) { 508 if (op_def->attr(i).name() == name) { 509 op_def->mutable_attr(i)->set_description(complete); 510 found = true; 511 } 512 } 513 if (!found) { 514 errors->push_back( 515 strings::StrCat("No matching input/output/attr for name '", name, 516 "' from Doc() for Op ", op_def->name())); 517 return; 518 } 519 } 520 } 521 522 } // namespace 523 524 OpDefBuilder::OpDefBuilder(StringPiece op_name) { 525 op_def()->set_name(op_name.ToString()); // NOLINT 526 } 527 528 OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) { 529 attrs_.emplace_back(spec.data(), spec.size()); 530 return *this; 531 } 532 533 OpDefBuilder& OpDefBuilder::Input(StringPiece spec) { 534 inputs_.emplace_back(spec.data(), spec.size()); 535 return *this; 536 } 537 538 OpDefBuilder& OpDefBuilder::Output(StringPiece spec) { 539 outputs_.emplace_back(spec.data(), spec.size()); 540 return *this; 541 } 542 543 #ifndef TF_LEAN_BINARY 544 OpDefBuilder& OpDefBuilder::Doc(StringPiece text) { 545 if (!doc_.empty()) { 546 errors_.push_back( 547 strings::StrCat("Extra call to Doc() for Op ", op_def()->name())); 548 } else { 549 doc_.assign(text.data(), text.size()); 550 } 551 return *this; 552 } 553 #endif 554 555 OpDefBuilder& OpDefBuilder::SetIsCommutative() { 556 op_def()->set_is_commutative(true); 557 return *this; 558 } 559 560 OpDefBuilder& OpDefBuilder::SetIsAggregate() { 561 op_def()->set_is_aggregate(true); 562 return *this; 563 } 564 565 OpDefBuilder& OpDefBuilder::SetIsStateful() { 566 op_def()->set_is_stateful(true); 567 return *this; 568 } 569 570 OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() { 571 op_def()->set_allows_uninitialized_input(true); 572 return *this; 573 } 574 575 OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) { 576 if (op_def()->has_deprecation()) { 577 errors_.push_back( 578 strings::StrCat("Deprecated called twice for Op ", op_def()->name())); 579 } else { 580 OpDeprecation* deprecation = op_def()->mutable_deprecation(); 581 deprecation->set_version(version); 582 deprecation->set_explanation(explanation.ToString()); 583 } 584 return *this; 585 } 586 587 OpDefBuilder& OpDefBuilder::SetShapeFn( 588 Status (*fn)(shape_inference::InferenceContext*)) { 589 if (op_reg_data_.shape_inference_fn != nullptr) { 590 errors_.push_back( 591 strings::StrCat("SetShapeFn called twice for Op ", op_def()->name())); 592 } else { 593 op_reg_data_.shape_inference_fn = OpShapeInferenceFn(fn); 594 } 595 return *this; 596 } 597 598 Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const { 599 std::vector<string> errors = errors_; 600 *op_reg_data = op_reg_data_; 601 602 OpDef* op_def = &op_reg_data->op_def; 603 for (StringPiece attr : attrs_) { 604 FinalizeAttr(attr, op_def, &errors); 605 } 606 for (StringPiece input : inputs_) { 607 FinalizeInputOrOutput(input, false, op_def, &errors); 608 } 609 for (StringPiece output : outputs_) { 610 FinalizeInputOrOutput(output, true, op_def, &errors); 611 } 612 FinalizeDoc(doc_, op_def, &errors); 613 614 if (errors.empty()) return Status::OK(); 615 return errors::InvalidArgument(str_util::Join(errors, "\n")); 616 } 617 618 } // namespace tensorflow 619