1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h" 17 18 #include <algorithm> 19 #include <set> 20 #include <unordered_set> 21 22 #include "tensorflow/core/platform/logging.h" 23 #include "tensorflow/core/platform/macros.h" 24 #include "tensorflow/core/platform/types.h" 25 26 using ::tensorflow::protobuf::Descriptor; 27 using ::tensorflow::protobuf::EnumDescriptor; 28 using ::tensorflow::protobuf::FieldDescriptor; 29 using ::tensorflow::protobuf::FieldOptions; 30 using ::tensorflow::protobuf::FileDescriptor; 31 32 namespace tensorflow { 33 34 namespace { 35 36 template <typename... Args> 37 string StrCat(const Args&... args) { 38 std::ostringstream s; 39 std::vector<int>{((s << args), 0)...}; 40 return s.str(); 41 } 42 43 template <typename... Args> 44 string StrAppend(string* to_append, const Args&... args) { 45 *to_append += StrCat(args...); 46 return *to_append; 47 } 48 49 // Class used to generate the code for proto text functions. One of these should 50 // be created for each FileDescriptor whose code should be generated. 51 // 52 // This class has a notion of the current output Section. The Print, Nested, 53 // and Unnest functions apply their operations to the current output section, 54 // which can be toggled with SetOutput. 55 // 56 // Note that on the generated code, various pieces are not optimized - for 57 // example: map input and output, Cord input and output, comparisons against 58 // the field names (it's a loop over all names), and tracking of has_seen. 59 class Generator { 60 public: 61 explicit Generator(const string& tf_header_prefix) 62 : tf_header_prefix_(tf_header_prefix), 63 header_(&code_.header), 64 header_impl_(&code_.header_impl), 65 cc_(&code_.cc) {} 66 67 void Generate(const FileDescriptor& fd); 68 69 // The generated code; valid after Generate has been called. 70 ProtoTextFunctionCode code() const { return code_; } 71 72 private: 73 struct Section { 74 explicit Section(string* str) : str(str) {} 75 string* str; 76 string indent; 77 }; 78 79 // Switches the currently active section to <section>. 80 Generator& SetOutput(Section* section) { 81 cur_ = section; 82 return *this; 83 } 84 85 // Increases indent level. Returns <*this>, to allow chaining. 86 Generator& Nest() { 87 StrAppend(&cur_->indent, " "); 88 return *this; 89 } 90 91 // Decreases indent level. Returns <*this>, to allow chaining. 92 Generator& Unnest() { 93 cur_->indent = cur_->indent.substr(0, cur_->indent.size() - 2); 94 return *this; 95 } 96 97 // Appends the concatenated args, with a trailing newline. Returns <*this>, to 98 // allow chaining. 99 template <typename... Args> 100 Generator& Print(Args... args) { 101 StrAppend(cur_->str, cur_->indent, args..., "\n"); 102 return *this; 103 } 104 105 // Appends the print code for a single field's value. 106 // If <omit_default> is true, then the emitted code will not print zero-valued 107 // values. 108 // <field_expr> is code that when emitted yields the field's value. 109 void AppendFieldValueAppend(const FieldDescriptor& field, 110 const bool omit_default, 111 const string& field_expr); 112 113 // Appends the print code for as single field. 114 void AppendFieldAppend(const FieldDescriptor& field); 115 116 // Appends the print code for a message. May change which section is currently 117 // active. 118 void AppendDebugStringFunctions(const Descriptor& md); 119 120 // Appends the print and parse functions for an enum. May change which 121 // section is currently active. 122 void AppendEnumFunctions(const EnumDescriptor& enum_d); 123 124 // Appends the parse functions for a message. May change which section is 125 // currently active. 126 void AppendParseMessageFunction(const Descriptor& md); 127 128 // Appends all functions for a message and its nested message and enum types. 129 // May change which section is currently active. 130 void AppendMessageFunctions(const Descriptor& md); 131 132 // Appends lines to open or close namespace declarations. 133 void AddNamespaceToCurrentSection(const string& package, bool open); 134 135 // Appends the given headers as sorted #include lines. 136 void AddHeadersToCurrentSection(const std::vector<string>& headers); 137 138 // When adding #includes for tensorflow headers, prefix them with this. 139 const string tf_header_prefix_; 140 ProtoTextFunctionCode code_; 141 Section* cur_ = nullptr; 142 Section header_; 143 Section header_impl_; 144 Section cc_; 145 146 std::unordered_set<string> map_append_signatures_included_; 147 148 TF_DISALLOW_COPY_AND_ASSIGN(Generator); 149 }; 150 151 // Returns the prefix needed to reference objects defined in <fd>. E.g. 152 // "::tensorflow::test". 153 string GetPackageReferencePrefix(const FileDescriptor* fd) { 154 string result = "::"; 155 const string& package = fd->package(); 156 for (size_t i = 0; i < package.size(); ++i) { 157 if (package[i] == '.') { 158 result += "::"; 159 } else { 160 result += package[i]; 161 } 162 } 163 result += "::"; 164 return result; 165 } 166 167 // Returns the name of the class generated by proto to represent <d>. 168 string GetClassName(const Descriptor& d) { 169 if (d.containing_type() == nullptr) return d.name(); 170 return StrCat(GetClassName(*d.containing_type()), "_", d.name()); 171 } 172 173 // Returns the name of the class generated by proto to represent <ed>. 174 string GetClassName(const EnumDescriptor& ed) { 175 if (ed.containing_type() == nullptr) return ed.name(); 176 return StrCat(GetClassName(*ed.containing_type()), "_", ed.name()); 177 } 178 179 // Returns the qualified name that refers to the class generated by proto to 180 // represent <d>. 181 string GetQualifiedName(const Descriptor& d) { 182 return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d)); 183 } 184 185 // Returns the qualified name that refers to the class generated by proto to 186 // represent <ed>. 187 string GetQualifiedName(const EnumDescriptor& d) { 188 return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d)); 189 } 190 191 // Returns the qualified name that refers to the generated 192 // AppendProtoDebugString function for <d>. 193 string GetQualifiedAppendFn(const Descriptor& d) { 194 return StrCat(GetPackageReferencePrefix(d.file()), 195 "internal::AppendProtoDebugString"); 196 } 197 198 // Returns the name of the generated function that returns an enum value's 199 // string value. 200 string GetEnumNameFn(const EnumDescriptor& enum_d) { 201 return StrCat("EnumName_", GetClassName(enum_d)); 202 } 203 204 // Returns the qualified name of the function returned by GetEnumNameFn(). 205 string GetQualifiedEnumNameFn(const EnumDescriptor& enum_d) { 206 return StrCat(GetPackageReferencePrefix(enum_d.file()), 207 GetEnumNameFn(enum_d)); 208 } 209 210 // Returns the name of a generated header file, either the public api (if impl 211 // is false) or the internal implementation header (if impl is true). 212 string GetProtoTextHeaderName(const FileDescriptor& fd, bool impl) { 213 const int dot_index = fd.name().find_last_of('.'); 214 return fd.name().substr(0, dot_index) + 215 (impl ? ".pb_text-impl.h" : ".pb_text.h"); 216 } 217 218 // Returns the name of the header generated by the proto library for <fd>. 219 string GetProtoHeaderName(const FileDescriptor& fd) { 220 const int dot_index = fd.name().find_last_of('.'); 221 return fd.name().substr(0, dot_index) + ".pb.h"; 222 } 223 224 // Returns the C++ class name for the given proto field. 225 string GetCppClass(const FieldDescriptor& d) { 226 string cpp_class = d.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE 227 ? GetQualifiedName(*d.message_type()) 228 : d.cpp_type_name(); 229 230 // In open-source TensorFlow, the definition of int64 varies across 231 // platforms. The following line, which is manipulated during internal- 232 // external sync'ing, takes care of the variability. 233 if (cpp_class == "int64") { 234 cpp_class = kProtobufInt64Typename; 235 } 236 237 return cpp_class; 238 } 239 240 // Returns the string that can be used for a header guard for the generated 241 // headers for <fd>, either for the public api (if impl is false) or the 242 // internal implementation header (if impl is true). 243 string GetHeaderGuard(const FileDescriptor& fd, bool impl) { 244 string s = fd.name(); 245 std::replace(s.begin(), s.end(), '/', '_'); 246 std::replace(s.begin(), s.end(), '.', '_'); 247 return s + (impl ? "_IMPL_H_" : "_H_"); 248 } 249 250 void Generator::AppendFieldValueAppend(const FieldDescriptor& field, 251 const bool omit_default, 252 const string& field_expr) { 253 SetOutput(&cc_); 254 switch (field.cpp_type()) { 255 case FieldDescriptor::CPPTYPE_INT32: 256 case FieldDescriptor::CPPTYPE_INT64: 257 case FieldDescriptor::CPPTYPE_UINT32: 258 case FieldDescriptor::CPPTYPE_UINT64: 259 case FieldDescriptor::CPPTYPE_DOUBLE: 260 case FieldDescriptor::CPPTYPE_FLOAT: 261 Print("o->", omit_default ? "AppendNumericIfNotZero" : "AppendNumeric", 262 "(\"", field.name(), "\", ", field_expr, ");"); 263 break; 264 case FieldDescriptor::CPPTYPE_BOOL: 265 Print("o->", omit_default ? "AppendBoolIfTrue" : "AppendBool", "(\"", 266 field.name(), "\", ", field_expr, ");"); 267 break; 268 case FieldDescriptor::CPPTYPE_STRING: { 269 const auto ctype = field.options().ctype(); 270 CHECK(ctype == FieldOptions::CORD || ctype == FieldOptions::STRING) 271 << "Unsupported ctype " << ctype; 272 273 Print("o->", omit_default ? "AppendStringIfNotEmpty" : "AppendString", 274 "(\"", field.name(), "\", ProtobufStringToString(", field_expr, 275 "));"); 276 break; 277 } 278 case FieldDescriptor::CPPTYPE_ENUM: 279 if (omit_default) { 280 Print("if (", field_expr, " != 0) {").Nest(); 281 } 282 Print("o->AppendEnumName(\"", field.name(), "\", ", 283 GetQualifiedEnumNameFn(*field.enum_type()), "(", field_expr, "));"); 284 if (omit_default) { 285 Unnest().Print("}"); 286 } 287 break; 288 case FieldDescriptor::CPPTYPE_MESSAGE: 289 CHECK(!field.message_type()->options().map_entry()); 290 if (omit_default) { 291 Print("if (msg.has_", field.name(), "()) {").Nest(); 292 } 293 Print("o->OpenNestedMessage(\"", field.name(), "\");"); 294 Print(GetQualifiedAppendFn(*field.message_type()), "(o, ", field_expr, 295 ");"); 296 Print("o->CloseNestedMessage();"); 297 if (omit_default) { 298 Unnest().Print("}"); 299 } 300 break; 301 } 302 } 303 304 void Generator::AppendFieldAppend(const FieldDescriptor& field) { 305 const string& name = field.name(); 306 307 if (field.is_map()) { 308 Print("{").Nest(); 309 const auto& key_type = *field.message_type()->FindFieldByName("key"); 310 const auto& value_type = *field.message_type()->FindFieldByName("value"); 311 312 Print("std::vector<", key_type.cpp_type_name(), "> keys;"); 313 Print("for (const auto& e : msg.", name, "()) keys.push_back(e.first);"); 314 Print("std::stable_sort(keys.begin(), keys.end());"); 315 Print("for (const auto& key : keys) {").Nest(); 316 Print("o->OpenNestedMessage(\"", name, "\");"); 317 AppendFieldValueAppend(key_type, false /* omit_default */, "key"); 318 AppendFieldValueAppend(value_type, false /* omit_default */, 319 StrCat("msg.", name, "().at(key)")); 320 Print("o->CloseNestedMessage();"); 321 Unnest().Print("}"); 322 323 Unnest().Print("}"); 324 } else if (field.is_repeated()) { 325 Print("for (int i = 0; i < msg.", name, "_size(); ++i) {"); 326 Nest(); 327 AppendFieldValueAppend(field, false /* omit_default */, 328 "msg." + name + "(i)"); 329 Unnest().Print("}"); 330 } else { 331 const auto* oneof = field.containing_oneof(); 332 if (oneof != nullptr) { 333 string camel_name = field.camelcase_name(); 334 camel_name[0] = toupper(camel_name[0]); 335 Print("if (msg.", oneof->name(), "_case() == ", 336 GetQualifiedName(*oneof->containing_type()), "::k", camel_name, 337 ") {"); 338 Nest(); 339 AppendFieldValueAppend(field, false /* omit_default */, 340 "msg." + name + "()"); 341 Unnest(); 342 Print("}"); 343 } else { 344 AppendFieldValueAppend(field, true /* omit_default */, 345 "msg." + name + "()"); 346 } 347 } 348 } 349 350 void Generator::AppendEnumFunctions(const EnumDescriptor& enum_d) { 351 const string sig = StrCat("const char* ", GetEnumNameFn(enum_d), "(\n ", 352 GetQualifiedName(enum_d), " value)"); 353 SetOutput(&header_); 354 Print().Print("// Enum text output for ", string(enum_d.full_name())); 355 Print(sig, ";"); 356 357 SetOutput(&cc_); 358 Print().Print(sig, " {"); 359 Nest().Print("switch (value) {").Nest(); 360 for (int i = 0; i < enum_d.value_count(); ++i) { 361 const auto& value = *enum_d.value(i); 362 Print("case ", value.number(), ": return \"", value.name(), "\";"); 363 } 364 Print("default: return \"\";"); 365 Unnest().Print("}"); 366 Unnest().Print("}"); 367 } 368 369 void Generator::AppendParseMessageFunction(const Descriptor& md) { 370 const bool map_append = (md.options().map_entry()); 371 string sig; 372 if (!map_append) { 373 sig = StrCat("bool ProtoParseFromString(\n const string& s,\n ", 374 GetQualifiedName(md), "* msg)"); 375 SetOutput(&header_).Print(sig, "\n TF_MUST_USE_RESULT;"); 376 377 SetOutput(&cc_); 378 Print().Print(sig, " {").Nest(); 379 Print("msg->Clear();"); 380 Print("Scanner scanner(s);"); 381 Print("if (!internal::ProtoParseFromScanner(", 382 "&scanner, false, false, msg)) return false;"); 383 Print("scanner.Eos();"); 384 Print("return scanner.GetResult();"); 385 Unnest().Print("}"); 386 } 387 388 // Parse from scanner - the real work here. 389 sig = StrCat("bool ProtoParseFromScanner(", 390 "\n ::tensorflow::strings::Scanner* scanner, bool nested, " 391 "bool close_curly,\n "); 392 const FieldDescriptor* key_type = nullptr; 393 const FieldDescriptor* value_type = nullptr; 394 if (map_append) { 395 key_type = md.FindFieldByName("key"); 396 value_type = md.FindFieldByName("value"); 397 StrAppend(&sig, "::tensorflow::protobuf::Map<", GetCppClass(*key_type), 398 ", ", GetCppClass(*value_type), ">* map)"); 399 } else { 400 StrAppend(&sig, GetQualifiedName(md), "* msg)"); 401 } 402 403 if (!map_append_signatures_included_.insert(sig).second) { 404 // signature for function to append to a map of this type has 405 // already been defined in this .cc file. Don't define it again. 406 return; 407 } 408 409 if (!map_append) { 410 SetOutput(&header_impl_).Print(sig, ";"); 411 } 412 413 SetOutput(&cc_); 414 Print().Print("namespace internal {"); 415 if (map_append) { 416 Print("namespace {"); 417 } 418 Print().Print(sig, " {").Nest(); 419 if (map_append) { 420 Print(GetCppClass(*key_type), " map_key;"); 421 Print("bool set_map_key = false;"); 422 Print(GetCppClass(*value_type), " map_value;"); 423 Print("bool set_map_value = false;"); 424 } 425 Print("std::vector<bool> has_seen(", md.field_count(), ", false);"); 426 Print("while(true) {").Nest(); 427 Print("ProtoSpaceAndComments(scanner);"); 428 429 // Emit success case 430 Print("if (nested && (scanner->Peek() == (close_curly ? '}' : '>'))) {") 431 .Nest(); 432 Print("scanner->One(Scanner::ALL);"); 433 Print("ProtoSpaceAndComments(scanner);"); 434 if (map_append) { 435 Print("if (!set_map_key || !set_map_value) return false;"); 436 Print("(*map)[map_key] = map_value;"); 437 } 438 Print("return true;"); 439 Unnest().Print("}"); 440 441 Print("if (!nested && scanner->empty()) { return true; }"); 442 Print("scanner->RestartCapture()"); 443 Print(" .Many(Scanner::LETTER_DIGIT_UNDERSCORE)"); 444 Print(" .StopCapture();"); 445 Print("StringPiece identifier;"); 446 Print("if (!scanner->GetResult(nullptr, &identifier)) return false;"); 447 Print("bool parsed_colon = false;"); 448 Print("(void)parsed_colon;"); // Avoid "set but not used" compiler warning 449 Print("ProtoSpaceAndComments(scanner);"); 450 Print("if (scanner->Peek() == ':') {"); 451 Nest().Print("parsed_colon = true;"); 452 Print("scanner->One(Scanner::ALL);"); 453 Print("ProtoSpaceAndComments(scanner);"); 454 Unnest().Print("}"); 455 for (int i = 0; i < md.field_count(); ++i) { 456 const FieldDescriptor* field = md.field(i); 457 const string& field_name = field->name(); 458 string mutable_value_expr; 459 string set_value_prefix; 460 if (map_append) { 461 mutable_value_expr = StrCat("&map_", field_name); 462 set_value_prefix = StrCat("map_", field_name, " = "); 463 } else if (field->is_repeated()) { 464 if (field->is_map()) { 465 mutable_value_expr = StrCat("msg->mutable_", field_name, "()"); 466 set_value_prefix = 467 "UNREACHABLE"; // generator will never use this value. 468 } else { 469 mutable_value_expr = StrCat("msg->add_", field_name, "()"); 470 set_value_prefix = StrCat("msg->add_", field_name); 471 } 472 } else { 473 mutable_value_expr = StrCat("msg->mutable_", field_name, "()"); 474 set_value_prefix = StrCat("msg->set_", field_name); 475 } 476 477 Print(i == 0 ? "" : "else ", "if (identifier == \"", field_name, "\") {"); 478 Nest(); 479 480 if (field->is_repeated()) { 481 CHECK(!map_append); 482 483 // Check to see if this is an array assignment, like a: [1, 2, 3] 484 Print("const bool is_list = (scanner->Peek() == '[');"); 485 Print("do {"); 486 // [ or , // skip 487 Nest().Print("if (is_list) {"); 488 Nest().Print("scanner->One(Scanner::ALL);"); 489 Print("ProtoSpaceAndComments(scanner);"); 490 Unnest().Print("}"); 491 } else if (field->containing_oneof() != nullptr) { 492 CHECK(!map_append); 493 494 // Detect duplicate oneof value. 495 const string oneof_name = field->containing_oneof()->name(); 496 Print("if (msg->", oneof_name, "_case() != 0) return false;"); 497 } 498 499 if (!field->is_repeated() && !map_append) { 500 // Detect duplicate nested repeated message. 501 Print("if (has_seen[", i, "]) return false;"); 502 Print("has_seen[", i, "] = true;"); 503 } 504 if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { 505 Print("const char open_char = scanner->Peek();"); 506 Print("if (open_char != '{' && open_char != '<') return false;"); 507 Print("scanner->One(Scanner::ALL);"); 508 Print("ProtoSpaceAndComments(scanner);"); 509 if (field->is_map()) { 510 Print("if (!ProtoParseFromScanner("); 511 } else { 512 Print("if (!", GetPackageReferencePrefix(field->message_type()->file()), 513 "internal::ProtoParseFromScanner("); 514 } 515 Print(" scanner, true, open_char == '{', ", mutable_value_expr, 516 ")) return false;"); 517 } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_STRING) { 518 Print("string str_value;"); 519 Print( 520 "if (!parsed_colon || " 521 "!::tensorflow::strings::ProtoParseStringLiteralFromScanner("); 522 Print(" scanner, &str_value)) return false;"); 523 Print("SetProtobufStringSwapAllowed(&str_value, ", mutable_value_expr, 524 ");"); 525 } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { 526 Print("StringPiece value;"); 527 Print( 528 "if (!parsed_colon || " 529 "!scanner->RestartCapture().Many(" 530 "Scanner::LETTER_DIGIT_DASH_UNDERSCORE)." 531 "GetResult(nullptr, &value)) return false;"); 532 const auto* enum_d = field->enum_type(); 533 string value_prefix; 534 if (enum_d->containing_type() == nullptr) { 535 value_prefix = GetPackageReferencePrefix(enum_d->file()); 536 } else { 537 value_prefix = StrCat(GetQualifiedName(*enum_d), "_"); 538 } 539 540 for (int enum_i = 0; enum_i < enum_d->value_count(); ++enum_i) { 541 const auto* value_d = enum_d->value(enum_i); 542 const string& value_name = value_d->name(); 543 string condition = StrCat("value == \"", value_name, 544 "\" || value == \"", value_d->number(), "\""); 545 if (value_d->number() == 0) { 546 StrAppend(&condition, " || value == \"-0\""); 547 } 548 549 Print(enum_i == 0 ? "" : "} else ", "if (", condition, ") {"); 550 Nest(); 551 Print(set_value_prefix, "(", value_prefix, value_name, ");"); 552 Unnest(); 553 } 554 Print("} else {").Nest().Print("return false;").Unnest().Print("}"); 555 } else { 556 Print(field->cpp_type_name(), " value;"); 557 switch (field->cpp_type()) { 558 case FieldDescriptor::CPPTYPE_INT32: 559 case FieldDescriptor::CPPTYPE_INT64: 560 case FieldDescriptor::CPPTYPE_UINT32: 561 case FieldDescriptor::CPPTYPE_UINT64: 562 case FieldDescriptor::CPPTYPE_DOUBLE: 563 case FieldDescriptor::CPPTYPE_FLOAT: 564 Print( 565 "if (!parsed_colon || " 566 "!::tensorflow::strings::ProtoParseNumericFromScanner(", 567 "scanner, &value)) return false;"); 568 break; 569 case FieldDescriptor::CPPTYPE_BOOL: 570 Print( 571 "if (!parsed_colon || " 572 "!::tensorflow::strings::ProtoParseBoolFromScanner(", 573 "scanner, &value)) return false;"); 574 break; 575 default: 576 LOG(FATAL) << "handled earlier"; 577 } 578 Print(set_value_prefix, "(value);"); 579 } 580 581 if (field->is_repeated()) { 582 Unnest().Print("} while (is_list && scanner->Peek() == ',');"); 583 Print( 584 "if (is_list && " 585 "!scanner->OneLiteral(\"]\").GetResult()) return false;"); 586 } 587 if (map_append) { 588 Print("set_map_", field_name, " = true;"); 589 } 590 Unnest().Print("}"); 591 } 592 Unnest().Print("}"); 593 Unnest().Print("}"); 594 Unnest().Print(); 595 if (map_append) { 596 Print("} // namespace"); 597 } 598 Print("} // namespace internal"); 599 } 600 601 void Generator::AppendDebugStringFunctions(const Descriptor& md) { 602 SetOutput(&header_impl_).Print(); 603 SetOutput(&header_).Print().Print("// Message-text conversion for ", 604 string(md.full_name())); 605 606 // Append the two debug string functions for <md>. 607 for (int short_pass = 0; short_pass < 2; ++short_pass) { 608 const bool short_debug = (short_pass == 1); 609 610 // Make the Get functions. 611 const string sig = StrCat( 612 "string ", short_debug ? "ProtoShortDebugString" : "ProtoDebugString", 613 "(\n const ", GetQualifiedName(md), "& msg)"); 614 SetOutput(&header_).Print(sig, ";"); 615 616 SetOutput(&cc_); 617 Print().Print(sig, " {").Nest(); 618 Print("string s;"); 619 Print("::tensorflow::strings::ProtoTextOutput o(&s, ", 620 short_debug ? "true" : "false", ");"); 621 Print("internal::AppendProtoDebugString(&o, msg);"); 622 Print("o.CloseTopMessage();"); 623 Print("return s;"); 624 Unnest().Print("}"); 625 } 626 627 // Make the Append function. 628 const string sig = 629 StrCat("void AppendProtoDebugString(\n", 630 " ::tensorflow::strings::ProtoTextOutput* o,\n const ", 631 GetQualifiedName(md), "& msg)"); 632 SetOutput(&header_impl_).Print(sig, ";"); 633 SetOutput(&cc_); 634 Print().Print("namespace internal {").Print(); 635 Print(sig, " {").Nest(); 636 std::vector<const FieldDescriptor*> fields; 637 fields.reserve(md.field_count()); 638 for (int i = 0; i < md.field_count(); ++i) { 639 fields.push_back(md.field(i)); 640 } 641 std::sort(fields.begin(), fields.end(), 642 [](const FieldDescriptor* left, const FieldDescriptor* right) { 643 return left->number() < right->number(); 644 }); 645 646 for (const FieldDescriptor* field : fields) { 647 SetOutput(&cc_); 648 AppendFieldAppend(*field); 649 } 650 Unnest().Print("}").Print().Print("} // namespace internal"); 651 } 652 653 void Generator::AppendMessageFunctions(const Descriptor& md) { 654 if (md.options().map_entry()) { 655 // The 'map entry' Message is not a user-visible message type. Only its 656 // parse function is created (and that actually parsed the whole Map, not 657 // just the map entry). Printing of a map is done in the code generated for 658 // the containing message. 659 AppendParseMessageFunction(md); 660 return; 661 } 662 663 // Recurse before adding the main message function, so that internal 664 // map_append functions are available before they are needed. 665 for (int i = 0; i < md.enum_type_count(); ++i) { 666 AppendEnumFunctions(*md.enum_type(i)); 667 } 668 for (int i = 0; i < md.nested_type_count(); ++i) { 669 AppendMessageFunctions(*md.nested_type(i)); 670 } 671 672 AppendDebugStringFunctions(md); 673 AppendParseMessageFunction(md); 674 } 675 676 void Generator::AddNamespaceToCurrentSection(const string& package, bool open) { 677 Print(); 678 std::vector<string> parts = {""}; 679 for (size_t i = 0; i < package.size(); ++i) { 680 if (package[i] == '.') { 681 parts.resize(parts.size() + 1); 682 } else { 683 parts.back() += package[i]; 684 } 685 } 686 if (open) { 687 for (const auto& p : parts) { 688 Print("namespace ", p, " {"); 689 } 690 } else { 691 for (auto it = parts.rbegin(); it != parts.rend(); ++it) { 692 Print("} // namespace ", *it); 693 } 694 } 695 } 696 697 void Generator::AddHeadersToCurrentSection(const std::vector<string>& headers) { 698 std::vector<string> sorted = headers; 699 std::sort(sorted.begin(), sorted.end()); 700 for (const auto& h : sorted) { 701 Print("#include \"", h, "\""); 702 } 703 } 704 705 // Adds to <all_fd> and <all_d> with all descriptors recursively 706 // reachable from the given descriptor. 707 void GetAllFileDescriptorsFromFile(const FileDescriptor* fd, 708 std::set<const FileDescriptor*>* all_fd, 709 std::set<const Descriptor*>* all_d); 710 711 // Adds to <all_fd> and <all_d> with all descriptors recursively 712 // reachable from the given descriptor. 713 void GetAllFileDescriptorsFromMessage(const Descriptor* d, 714 std::set<const FileDescriptor*>* all_fd, 715 std::set<const Descriptor*>* all_d) { 716 if (!all_d->insert(d).second) return; 717 GetAllFileDescriptorsFromFile(d->file(), all_fd, all_d); 718 for (int i = 0; i < d->field_count(); ++i) { 719 auto* f = d->field(i); 720 switch (f->cpp_type()) { 721 case FieldDescriptor::CPPTYPE_INT32: 722 case FieldDescriptor::CPPTYPE_INT64: 723 case FieldDescriptor::CPPTYPE_UINT32: 724 case FieldDescriptor::CPPTYPE_UINT64: 725 case FieldDescriptor::CPPTYPE_DOUBLE: 726 case FieldDescriptor::CPPTYPE_FLOAT: 727 case FieldDescriptor::CPPTYPE_BOOL: 728 case FieldDescriptor::CPPTYPE_STRING: 729 break; 730 case FieldDescriptor::CPPTYPE_MESSAGE: 731 GetAllFileDescriptorsFromMessage(f->message_type(), all_fd, all_d); 732 break; 733 case FieldDescriptor::CPPTYPE_ENUM: 734 GetAllFileDescriptorsFromFile(f->enum_type()->file(), all_fd, all_d); 735 break; 736 } 737 } 738 for (int i = 0; i < d->nested_type_count(); ++i) { 739 GetAllFileDescriptorsFromMessage(d->nested_type(i), all_fd, all_d); 740 } 741 } 742 743 void GetAllFileDescriptorsFromFile(const FileDescriptor* fd, 744 std::set<const FileDescriptor*>* all_fd, 745 std::set<const Descriptor*>* all_d) { 746 if (!all_fd->insert(fd).second) return; 747 for (int i = 0; i < fd->message_type_count(); ++i) { 748 GetAllFileDescriptorsFromMessage(fd->message_type(i), all_fd, all_d); 749 } 750 } 751 752 void Generator::Generate(const FileDescriptor& fd) { 753 // This does not emit code with proper proto2 semantics (e.g. it doesn't check 754 // 'has' fields on non-messages), so check that only proto3 is passed. 755 CHECK_EQ(fd.syntax(), FileDescriptor::SYNTAX_PROTO3) << fd.name(); 756 757 const string package = fd.package(); 758 std::set<const FileDescriptor*> all_fd; 759 std::set<const Descriptor*> all_d; 760 GetAllFileDescriptorsFromFile(&fd, &all_fd, &all_d); 761 762 std::vector<string> headers; 763 764 // Add header to header file. 765 SetOutput(&header_); 766 Print("// GENERATED FILE - DO NOT MODIFY"); 767 Print("#ifndef ", GetHeaderGuard(fd, false /* impl */)); 768 Print("#define ", GetHeaderGuard(fd, false /* impl */)); 769 Print(); 770 headers = { 771 GetProtoHeaderName(fd), 772 StrCat(tf_header_prefix_, "tensorflow/core/platform/macros.h"), 773 StrCat(tf_header_prefix_, "tensorflow/core/platform/protobuf.h"), 774 StrCat(tf_header_prefix_, "tensorflow/core/platform/types.h"), 775 }; 776 for (const auto& h : headers) { 777 Print("#include \"", h, "\""); 778 } 779 AddNamespaceToCurrentSection(package, true /* is_open */); 780 781 // Add header to impl file. 782 SetOutput(&header_impl_); 783 Print("// GENERATED FILE - DO NOT MODIFY"); 784 Print("#ifndef ", GetHeaderGuard(fd, true /* impl */)); 785 Print("#define ", GetHeaderGuard(fd, true /* impl */)); 786 Print(); 787 headers = { 788 GetProtoTextHeaderName(fd, false /* impl */), 789 StrCat(tf_header_prefix_, 790 "tensorflow/core/lib/strings/proto_text_util.h"), 791 StrCat(tf_header_prefix_, "tensorflow/core/lib/strings/scanner.h"), 792 }; 793 for (const FileDescriptor* d : all_fd) { 794 if (d != &fd) { 795 headers.push_back(GetProtoTextHeaderName(*d, true /* impl */)); 796 } 797 headers.push_back(GetProtoHeaderName(*d)); 798 } 799 AddHeadersToCurrentSection(headers); 800 AddNamespaceToCurrentSection(package, true /* is_open */); 801 SetOutput(&header_impl_).Print().Print("namespace internal {"); 802 803 // Add header to cc file. 804 SetOutput(&cc_); 805 Print("// GENERATED FILE - DO NOT MODIFY"); 806 headers = {GetProtoTextHeaderName(fd, true /* impl */)}; 807 AddHeadersToCurrentSection(headers); 808 Print(); 809 Print("using ::tensorflow::strings::Scanner;"); 810 Print("using ::tensorflow::strings::StrCat;"); 811 AddNamespaceToCurrentSection(package, true /* is_open */); 812 813 // Add declarations and definitions. 814 for (int i = 0; i < fd.enum_type_count(); ++i) { 815 AppendEnumFunctions(*fd.enum_type(i)); 816 } 817 for (int i = 0; i < fd.message_type_count(); ++i) { 818 AppendMessageFunctions(*fd.message_type(i)); 819 } 820 821 // Add footer to header file. 822 SetOutput(&header_); 823 AddNamespaceToCurrentSection(package, false /* is_open */); 824 Print().Print("#endif // ", GetHeaderGuard(fd, false /* impl */)); 825 826 // Add footer to header impl file. 827 SetOutput(&header_impl_).Print().Print("} // namespace internal"); 828 AddNamespaceToCurrentSection(package, false /* is_open */); 829 Print().Print("#endif // ", GetHeaderGuard(fd, true /* impl */)); 830 831 // Add footer to cc file. 832 SetOutput(&cc_); 833 AddNamespaceToCurrentSection(package, false /* is_open */); 834 } 835 836 } // namespace 837 838 ProtoTextFunctionCode GetProtoTextFunctionCode(const FileDescriptor& fd, 839 const string& tf_header_prefix) { 840 Generator gen(tf_header_prefix); 841 gen.Generate(fd); 842 return gen.code(); 843 } 844 845 } // namespace tensorflow 846