1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h" 17 18 #include <algorithm> 19 #include <set> 20 #include <unordered_set> 21 22 #include "tensorflow/core/platform/logging.h" 23 #include "tensorflow/core/platform/macros.h" 24 #include "tensorflow/core/platform/types.h" 25 26 using ::tensorflow::protobuf::Descriptor; 27 using ::tensorflow::protobuf::EnumDescriptor; 28 using ::tensorflow::protobuf::FieldDescriptor; 29 using ::tensorflow::protobuf::FieldOptions; 30 using ::tensorflow::protobuf::FileDescriptor; 31 32 namespace tensorflow { 33 34 namespace { 35 36 template <typename... Args> 37 string StrCat(const Args&... args) { 38 std::ostringstream s; 39 std::vector<int>{((s << args), 0)...}; 40 return s.str(); 41 } 42 43 template <typename... Args> 44 string StrAppend(string* to_append, const Args&... args) { 45 *to_append += StrCat(args...); 46 return *to_append; 47 } 48 49 // Class used to generate the code for proto text functions. One of these should 50 // be created for each FileDescriptor whose code should be generated. 51 // 52 // This class has a notion of the current output Section. The Print, Nested, 53 // and Unnest functions apply their operations to the current output section, 54 // which can be toggled with SetOutput. 55 // 56 // Note that on the generated code, various pieces are not optimized - for 57 // example: map input and output, Cord input and output, comparisons against 58 // the field names (it's a loop over all names), and tracking of has_seen. 59 class Generator { 60 public: 61 explicit Generator(const string& tf_header_prefix) 62 : tf_header_prefix_(tf_header_prefix), 63 header_(&code_.header), 64 header_impl_(&code_.header_impl), 65 cc_(&code_.cc) {} 66 67 void Generate(const FileDescriptor& fd); 68 69 // The generated code; valid after Generate has been called. 70 ProtoTextFunctionCode code() const { return code_; } 71 72 private: 73 struct Section { 74 explicit Section(string* str) : str(str) {} 75 string* str; 76 string indent; 77 }; 78 79 // Switches the currently active section to <section>. 80 Generator& SetOutput(Section* section) { 81 cur_ = section; 82 return *this; 83 } 84 85 // Increases indent level. Returns <*this>, to allow chaining. 86 Generator& Nest() { 87 StrAppend(&cur_->indent, " "); 88 return *this; 89 } 90 91 // Decreases indent level. Returns <*this>, to allow chaining. 92 Generator& Unnest() { 93 cur_->indent = cur_->indent.substr(0, cur_->indent.size() - 2); 94 return *this; 95 } 96 97 // Appends the concatenated args, with a trailing newline. Returns <*this>, to 98 // allow chaining. 99 template <typename... Args> 100 Generator& Print(Args... args) { 101 StrAppend(cur_->str, cur_->indent, args..., "\n"); 102 return *this; 103 } 104 105 // Appends the print code for a single field's value. 106 // If <omit_default> is true, then the emitted code will not print zero-valued 107 // values. 108 // <field_expr> is code that when emitted yields the field's value. 109 void AppendFieldValueAppend(const FieldDescriptor& field, 110 const bool omit_default, 111 const string& field_expr); 112 113 // Appends the print code for as single field. 114 void AppendFieldAppend(const FieldDescriptor& field); 115 116 // Appends the print code for a message. May change which section is currently 117 // active. 118 void AppendDebugStringFunctions(const Descriptor& md); 119 120 // Appends the print and parse functions for an enum. May change which 121 // section is currently active. 122 void AppendEnumFunctions(const EnumDescriptor& enum_d); 123 124 // Appends the parse functions for a message. May change which section is 125 // currently active. 126 void AppendParseMessageFunction(const Descriptor& md); 127 128 // Appends all functions for a message and its nested message and enum types. 129 // May change which section is currently active. 130 void AppendMessageFunctions(const Descriptor& md); 131 132 // Appends lines to open or close namespace declarations. 133 void AddNamespaceToCurrentSection(const string& package, bool open); 134 135 // Appends the given headers as sorted #include lines. 136 void AddHeadersToCurrentSection(const std::vector<string>& headers); 137 138 // When adding #includes for tensorflow headers, prefix them with this. 139 const string tf_header_prefix_; 140 ProtoTextFunctionCode code_; 141 Section* cur_ = nullptr; 142 Section header_; 143 Section header_impl_; 144 Section cc_; 145 146 std::unordered_set<string> map_append_signatures_included_; 147 148 TF_DISALLOW_COPY_AND_ASSIGN(Generator); 149 }; 150 151 // Returns the prefix needed to reference objects defined in <fd>. E.g. 152 // "::tensorflow::test". 153 string GetPackageReferencePrefix(const FileDescriptor* fd) { 154 string result = "::"; 155 const string& package = fd->package(); 156 for (size_t i = 0; i < package.size(); ++i) { 157 if (package[i] == '.') { 158 result += "::"; 159 } else { 160 result += package[i]; 161 } 162 } 163 result += "::"; 164 return result; 165 } 166 167 // Returns the name of the class generated by proto to represent <d>. 168 string GetClassName(const Descriptor& d) { 169 if (d.containing_type() == nullptr) return d.name(); 170 return StrCat(GetClassName(*d.containing_type()), "_", d.name()); 171 } 172 173 // Returns the name of the class generated by proto to represent <ed>. 174 string GetClassName(const EnumDescriptor& ed) { 175 if (ed.containing_type() == nullptr) return ed.name(); 176 return StrCat(GetClassName(*ed.containing_type()), "_", ed.name()); 177 } 178 179 // Returns the qualified name that refers to the class generated by proto to 180 // represent <d>. 181 string GetQualifiedName(const Descriptor& d) { 182 return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d)); 183 } 184 185 // Returns the qualified name that refers to the class generated by proto to 186 // represent <ed>. 187 string GetQualifiedName(const EnumDescriptor& d) { 188 return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d)); 189 } 190 191 // Returns the qualified name that refers to the generated 192 // AppendProtoDebugString function for <d>. 193 string GetQualifiedAppendFn(const Descriptor& d) { 194 return StrCat(GetPackageReferencePrefix(d.file()), 195 "internal::AppendProtoDebugString"); 196 } 197 198 // Returns the name of the generated function that returns an enum value's 199 // string value. 200 string GetEnumNameFn(const EnumDescriptor& enum_d) { 201 return StrCat("EnumName_", GetClassName(enum_d)); 202 } 203 204 // Returns the qualified name of the function returned by GetEnumNameFn(). 205 string GetQualifiedEnumNameFn(const EnumDescriptor& enum_d) { 206 return StrCat(GetPackageReferencePrefix(enum_d.file()), 207 GetEnumNameFn(enum_d)); 208 } 209 210 // Returns the name of a generated header file, either the public api (if impl 211 // is false) or the internal implementation header (if impl is true). 212 string GetProtoTextHeaderName(const FileDescriptor& fd, bool impl) { 213 const int dot_index = fd.name().find_last_of('.'); 214 return fd.name().substr(0, dot_index) + 215 (impl ? ".pb_text-impl.h" : ".pb_text.h"); 216 } 217 218 // Returns the name of the header generated by the proto library for <fd>. 219 string GetProtoHeaderName(const FileDescriptor& fd) { 220 const int dot_index = fd.name().find_last_of('.'); 221 return fd.name().substr(0, dot_index) + ".pb.h"; 222 } 223 224 // Returns the C++ class name for the given proto field. 225 string GetCppClass(const FieldDescriptor& d) { 226 string cpp_class = d.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE 227 ? GetQualifiedName(*d.message_type()) 228 : d.cpp_type_name(); 229 230 // In open-source TensorFlow, the definition of int64 varies across 231 // platforms. The following line, which is manipulated during internal- 232 // external sync'ing, takes care of the variability. 233 if (cpp_class == "int64") { 234 cpp_class = kProtobufInt64Typename; 235 } 236 237 return cpp_class; 238 } 239 240 // Returns the string that can be used for a header guard for the generated 241 // headers for <fd>, either for the public api (if impl is false) or the 242 // internal implementation header (if impl is true). 243 string GetHeaderGuard(const FileDescriptor& fd, bool impl) { 244 string s = fd.name(); 245 std::replace(s.begin(), s.end(), '/', '_'); 246 std::replace(s.begin(), s.end(), '.', '_'); 247 return s + (impl ? "_IMPL_H_" : "_H_"); 248 } 249 250 void Generator::AppendFieldValueAppend(const FieldDescriptor& field, 251 const bool omit_default, 252 const string& field_expr) { 253 SetOutput(&cc_); 254 switch (field.cpp_type()) { 255 case FieldDescriptor::CPPTYPE_INT32: 256 case FieldDescriptor::CPPTYPE_INT64: 257 case FieldDescriptor::CPPTYPE_UINT32: 258 case FieldDescriptor::CPPTYPE_UINT64: 259 case FieldDescriptor::CPPTYPE_DOUBLE: 260 case FieldDescriptor::CPPTYPE_FLOAT: 261 Print("o->", omit_default ? "AppendNumericIfNotZero" : "AppendNumeric", 262 "(\"", field.name(), "\", ", field_expr, ");"); 263 break; 264 case FieldDescriptor::CPPTYPE_BOOL: 265 Print("o->", omit_default ? "AppendBoolIfTrue" : "AppendBool", "(\"", 266 field.name(), "\", ", field_expr, ");"); 267 break; 268 case FieldDescriptor::CPPTYPE_STRING: { 269 const auto ctype = field.options().ctype(); 270 CHECK(ctype == FieldOptions::CORD || ctype == FieldOptions::STRING) 271 << "Unsupported ctype " << ctype; 272 273 Print("o->", omit_default ? "AppendStringIfNotEmpty" : "AppendString", 274 "(\"", field.name(), "\", ProtobufStringToString(", field_expr, 275 "));"); 276 break; 277 } 278 case FieldDescriptor::CPPTYPE_ENUM: 279 if (omit_default) { 280 Print("if (", field_expr, " != 0) {").Nest(); 281 } 282 Print("const char* enum_name = ", 283 GetQualifiedEnumNameFn(*field.enum_type()), "(", field_expr, ");"); 284 Print("if (enum_name[0]) {").Nest(); 285 Print("o->AppendEnumName(\"", field.name(), "\", enum_name);"); 286 Unnest().Print("} else {").Nest(); 287 Print("o->AppendNumeric(\"", field.name(), "\", ", field_expr, ");"); 288 Unnest().Print("}"); 289 if (omit_default) { 290 Unnest().Print("}"); 291 } 292 break; 293 case FieldDescriptor::CPPTYPE_MESSAGE: 294 CHECK(!field.message_type()->options().map_entry()); 295 if (omit_default) { 296 Print("if (msg.has_", field.name(), "()) {").Nest(); 297 } 298 Print("o->OpenNestedMessage(\"", field.name(), "\");"); 299 Print(GetQualifiedAppendFn(*field.message_type()), "(o, ", field_expr, 300 ");"); 301 Print("o->CloseNestedMessage();"); 302 if (omit_default) { 303 Unnest().Print("}"); 304 } 305 break; 306 } 307 } 308 309 void Generator::AppendFieldAppend(const FieldDescriptor& field) { 310 const string& name = field.name(); 311 312 if (field.is_map()) { 313 Print("{").Nest(); 314 const auto& key_type = *field.message_type()->FindFieldByName("key"); 315 const auto& value_type = *field.message_type()->FindFieldByName("value"); 316 317 Print("std::vector<", key_type.cpp_type_name(), "> keys;"); 318 Print("for (const auto& e : msg.", name, "()) keys.push_back(e.first);"); 319 Print("std::stable_sort(keys.begin(), keys.end());"); 320 Print("for (const auto& key : keys) {").Nest(); 321 Print("o->OpenNestedMessage(\"", name, "\");"); 322 AppendFieldValueAppend(key_type, false /* omit_default */, "key"); 323 AppendFieldValueAppend(value_type, false /* omit_default */, 324 StrCat("msg.", name, "().at(key)")); 325 Print("o->CloseNestedMessage();"); 326 Unnest().Print("}"); 327 328 Unnest().Print("}"); 329 } else if (field.is_repeated()) { 330 Print("for (int i = 0; i < msg.", name, "_size(); ++i) {"); 331 Nest(); 332 AppendFieldValueAppend(field, false /* omit_default */, 333 "msg." + name + "(i)"); 334 Unnest().Print("}"); 335 } else { 336 const auto* oneof = field.containing_oneof(); 337 if (oneof != nullptr) { 338 string camel_name = field.camelcase_name(); 339 camel_name[0] = toupper(camel_name[0]); 340 Print("if (msg.", oneof->name(), "_case() == ", 341 GetQualifiedName(*oneof->containing_type()), "::k", camel_name, 342 ") {"); 343 Nest(); 344 AppendFieldValueAppend(field, false /* omit_default */, 345 "msg." + name + "()"); 346 Unnest(); 347 Print("}"); 348 } else { 349 AppendFieldValueAppend(field, true /* omit_default */, 350 "msg." + name + "()"); 351 } 352 } 353 } 354 355 void Generator::AppendEnumFunctions(const EnumDescriptor& enum_d) { 356 const string sig = StrCat("const char* ", GetEnumNameFn(enum_d), "(\n ", 357 GetQualifiedName(enum_d), " value)"); 358 SetOutput(&header_); 359 Print().Print("// Enum text output for ", string(enum_d.full_name())); 360 Print(sig, ";"); 361 362 SetOutput(&cc_); 363 Print().Print(sig, " {"); 364 Nest().Print("switch (value) {").Nest(); 365 for (int i = 0; i < enum_d.value_count(); ++i) { 366 const auto& value = *enum_d.value(i); 367 Print("case ", value.number(), ": return \"", value.name(), "\";"); 368 } 369 Print("default: return \"\";"); 370 Unnest().Print("}"); 371 Unnest().Print("}"); 372 } 373 374 void Generator::AppendParseMessageFunction(const Descriptor& md) { 375 const bool map_append = (md.options().map_entry()); 376 string sig; 377 if (!map_append) { 378 sig = StrCat("bool ProtoParseFromString(\n const string& s,\n ", 379 GetQualifiedName(md), "* msg)"); 380 SetOutput(&header_).Print(sig, "\n TF_MUST_USE_RESULT;"); 381 382 SetOutput(&cc_); 383 Print().Print(sig, " {").Nest(); 384 Print("msg->Clear();"); 385 Print("Scanner scanner(s);"); 386 Print("if (!internal::ProtoParseFromScanner(", 387 "&scanner, false, false, msg)) return false;"); 388 Print("scanner.Eos();"); 389 Print("return scanner.GetResult();"); 390 Unnest().Print("}"); 391 } 392 393 // Parse from scanner - the real work here. 394 sig = StrCat("bool ProtoParseFromScanner(", 395 "\n ::tensorflow::strings::Scanner* scanner, bool nested, " 396 "bool close_curly,\n "); 397 const FieldDescriptor* key_type = nullptr; 398 const FieldDescriptor* value_type = nullptr; 399 if (map_append) { 400 key_type = md.FindFieldByName("key"); 401 value_type = md.FindFieldByName("value"); 402 StrAppend(&sig, "::tensorflow::protobuf::Map<", GetCppClass(*key_type), 403 ", ", GetCppClass(*value_type), ">* map)"); 404 } else { 405 StrAppend(&sig, GetQualifiedName(md), "* msg)"); 406 } 407 408 if (!map_append_signatures_included_.insert(sig).second) { 409 // signature for function to append to a map of this type has 410 // already been defined in this .cc file. Don't define it again. 411 return; 412 } 413 414 if (!map_append) { 415 SetOutput(&header_impl_).Print(sig, ";"); 416 } 417 418 SetOutput(&cc_); 419 Print().Print("namespace internal {"); 420 if (map_append) { 421 Print("namespace {"); 422 } 423 Print().Print(sig, " {").Nest(); 424 if (map_append) { 425 Print(GetCppClass(*key_type), " map_key;"); 426 Print("bool set_map_key = false;"); 427 Print(GetCppClass(*value_type), " map_value;"); 428 Print("bool set_map_value = false;"); 429 } 430 Print("std::vector<bool> has_seen(", md.field_count(), ", false);"); 431 Print("while(true) {").Nest(); 432 Print("ProtoSpaceAndComments(scanner);"); 433 434 // Emit success case 435 Print("if (nested && (scanner->Peek() == (close_curly ? '}' : '>'))) {") 436 .Nest(); 437 Print("scanner->One(Scanner::ALL);"); 438 Print("ProtoSpaceAndComments(scanner);"); 439 if (map_append) { 440 Print("if (!set_map_key || !set_map_value) return false;"); 441 Print("(*map)[map_key] = map_value;"); 442 } 443 Print("return true;"); 444 Unnest().Print("}"); 445 446 Print("if (!nested && scanner->empty()) { return true; }"); 447 Print("scanner->RestartCapture()"); 448 Print(" .Many(Scanner::LETTER_DIGIT_UNDERSCORE)"); 449 Print(" .StopCapture();"); 450 Print("StringPiece identifier;"); 451 Print("if (!scanner->GetResult(nullptr, &identifier)) return false;"); 452 Print("bool parsed_colon = false;"); 453 Print("(void)parsed_colon;"); // Avoid "set but not used" compiler warning 454 Print("ProtoSpaceAndComments(scanner);"); 455 Print("if (scanner->Peek() == ':') {"); 456 Nest().Print("parsed_colon = true;"); 457 Print("scanner->One(Scanner::ALL);"); 458 Print("ProtoSpaceAndComments(scanner);"); 459 Unnest().Print("}"); 460 for (int i = 0; i < md.field_count(); ++i) { 461 const FieldDescriptor* field = md.field(i); 462 const string& field_name = field->name(); 463 string mutable_value_expr; 464 string set_value_prefix; 465 if (map_append) { 466 mutable_value_expr = StrCat("&map_", field_name); 467 set_value_prefix = StrCat("map_", field_name, " = "); 468 } else if (field->is_repeated()) { 469 if (field->is_map()) { 470 mutable_value_expr = StrCat("msg->mutable_", field_name, "()"); 471 set_value_prefix = 472 "UNREACHABLE"; // generator will never use this value. 473 } else { 474 mutable_value_expr = StrCat("msg->add_", field_name, "()"); 475 set_value_prefix = StrCat("msg->add_", field_name); 476 } 477 } else { 478 mutable_value_expr = StrCat("msg->mutable_", field_name, "()"); 479 set_value_prefix = StrCat("msg->set_", field_name); 480 } 481 482 Print(i == 0 ? "" : "else ", "if (identifier == \"", field_name, "\") {"); 483 Nest(); 484 485 if (field->is_repeated()) { 486 CHECK(!map_append); 487 488 // Check to see if this is an array assignment, like a: [1, 2, 3] 489 Print("const bool is_list = (scanner->Peek() == '[');"); 490 Print("do {"); 491 // [ or , // skip 492 Nest().Print("if (is_list) {"); 493 Nest().Print("scanner->One(Scanner::ALL);"); 494 Print("ProtoSpaceAndComments(scanner);"); 495 Unnest().Print("}"); 496 } else if (field->containing_oneof() != nullptr) { 497 CHECK(!map_append); 498 499 // Detect duplicate oneof value. 500 const string oneof_name = field->containing_oneof()->name(); 501 Print("if (msg->", oneof_name, "_case() != 0) return false;"); 502 } 503 504 if (!field->is_repeated() && !map_append) { 505 // Detect duplicate nested repeated message. 506 Print("if (has_seen[", i, "]) return false;"); 507 Print("has_seen[", i, "] = true;"); 508 } 509 if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { 510 Print("const char open_char = scanner->Peek();"); 511 Print("if (open_char != '{' && open_char != '<') return false;"); 512 Print("scanner->One(Scanner::ALL);"); 513 Print("ProtoSpaceAndComments(scanner);"); 514 if (field->is_map()) { 515 Print("if (!ProtoParseFromScanner("); 516 } else { 517 Print("if (!", GetPackageReferencePrefix(field->message_type()->file()), 518 "internal::ProtoParseFromScanner("); 519 } 520 Print(" scanner, true, open_char == '{', ", mutable_value_expr, 521 ")) return false;"); 522 } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_STRING) { 523 Print("string str_value;"); 524 Print( 525 "if (!parsed_colon || " 526 "!::tensorflow::strings::ProtoParseStringLiteralFromScanner("); 527 Print(" scanner, &str_value)) return false;"); 528 Print("SetProtobufStringSwapAllowed(&str_value, ", mutable_value_expr, 529 ");"); 530 } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { 531 Print("StringPiece value;"); 532 Print( 533 "if (!parsed_colon || " 534 "!scanner->RestartCapture().Many(" 535 "Scanner::LETTER_DIGIT_DASH_UNDERSCORE)." 536 "GetResult(nullptr, &value)) return false;"); 537 const auto* enum_d = field->enum_type(); 538 string value_prefix; 539 if (enum_d->containing_type() == nullptr) { 540 value_prefix = GetPackageReferencePrefix(enum_d->file()); 541 } else { 542 value_prefix = StrCat(GetQualifiedName(*enum_d), "_"); 543 } 544 545 for (int enum_i = 0; enum_i < enum_d->value_count(); ++enum_i) { 546 const auto* value_d = enum_d->value(enum_i); 547 const string& value_name = value_d->name(); 548 string condition = StrCat("value == \"", value_name, "\""); 549 550 Print(enum_i == 0 ? "" : "} else ", "if (", condition, ") {"); 551 Nest(); 552 Print(set_value_prefix, "(", value_prefix, value_name, ");"); 553 Unnest(); 554 } 555 Print("} else {"); 556 Nest(); 557 // Proto3 allows all numeric values. 558 Print("int32 int_value;"); 559 Print("if (strings::SafeStringToNumeric(value, &int_value)) {"); 560 Nest(); 561 Print(set_value_prefix, "(static_cast<", GetQualifiedName(*enum_d), 562 ">(int_value));"); 563 Unnest(); 564 Print("} else {").Nest().Print("return false;").Unnest().Print("}"); 565 Unnest().Print("}"); 566 } else { 567 Print(field->cpp_type_name(), " value;"); 568 switch (field->cpp_type()) { 569 case FieldDescriptor::CPPTYPE_INT32: 570 case FieldDescriptor::CPPTYPE_INT64: 571 case FieldDescriptor::CPPTYPE_UINT32: 572 case FieldDescriptor::CPPTYPE_UINT64: 573 case FieldDescriptor::CPPTYPE_DOUBLE: 574 case FieldDescriptor::CPPTYPE_FLOAT: 575 Print( 576 "if (!parsed_colon || " 577 "!::tensorflow::strings::ProtoParseNumericFromScanner(", 578 "scanner, &value)) return false;"); 579 break; 580 case FieldDescriptor::CPPTYPE_BOOL: 581 Print( 582 "if (!parsed_colon || " 583 "!::tensorflow::strings::ProtoParseBoolFromScanner(", 584 "scanner, &value)) return false;"); 585 break; 586 default: 587 LOG(FATAL) << "handled earlier"; 588 } 589 Print(set_value_prefix, "(value);"); 590 } 591 592 if (field->is_repeated()) { 593 Unnest().Print("} while (is_list && scanner->Peek() == ',');"); 594 Print( 595 "if (is_list && " 596 "!scanner->OneLiteral(\"]\").GetResult()) return false;"); 597 } 598 if (map_append) { 599 Print("set_map_", field_name, " = true;"); 600 } 601 Unnest().Print("}"); 602 } 603 Unnest().Print("}"); 604 Unnest().Print("}"); 605 Unnest().Print(); 606 if (map_append) { 607 Print("} // namespace"); 608 } 609 Print("} // namespace internal"); 610 } 611 612 void Generator::AppendDebugStringFunctions(const Descriptor& md) { 613 SetOutput(&header_impl_).Print(); 614 SetOutput(&header_).Print().Print("// Message-text conversion for ", 615 string(md.full_name())); 616 617 // Append the two debug string functions for <md>. 618 for (int short_pass = 0; short_pass < 2; ++short_pass) { 619 const bool short_debug = (short_pass == 1); 620 621 // Make the Get functions. 622 const string sig = StrCat( 623 "string ", short_debug ? "ProtoShortDebugString" : "ProtoDebugString", 624 "(\n const ", GetQualifiedName(md), "& msg)"); 625 SetOutput(&header_).Print(sig, ";"); 626 627 SetOutput(&cc_); 628 Print().Print(sig, " {").Nest(); 629 Print("string s;"); 630 Print("::tensorflow::strings::ProtoTextOutput o(&s, ", 631 short_debug ? "true" : "false", ");"); 632 Print("internal::AppendProtoDebugString(&o, msg);"); 633 Print("o.CloseTopMessage();"); 634 Print("return s;"); 635 Unnest().Print("}"); 636 } 637 638 // Make the Append function. 639 const string sig = 640 StrCat("void AppendProtoDebugString(\n", 641 " ::tensorflow::strings::ProtoTextOutput* o,\n const ", 642 GetQualifiedName(md), "& msg)"); 643 SetOutput(&header_impl_).Print(sig, ";"); 644 SetOutput(&cc_); 645 Print().Print("namespace internal {").Print(); 646 Print(sig, " {").Nest(); 647 std::vector<const FieldDescriptor*> fields; 648 fields.reserve(md.field_count()); 649 for (int i = 0; i < md.field_count(); ++i) { 650 fields.push_back(md.field(i)); 651 } 652 std::sort(fields.begin(), fields.end(), 653 [](const FieldDescriptor* left, const FieldDescriptor* right) { 654 return left->number() < right->number(); 655 }); 656 657 for (const FieldDescriptor* field : fields) { 658 SetOutput(&cc_); 659 AppendFieldAppend(*field); 660 } 661 Unnest().Print("}").Print().Print("} // namespace internal"); 662 } 663 664 void Generator::AppendMessageFunctions(const Descriptor& md) { 665 if (md.options().map_entry()) { 666 // The 'map entry' Message is not a user-visible message type. Only its 667 // parse function is created (and that actually parsed the whole Map, not 668 // just the map entry). Printing of a map is done in the code generated for 669 // the containing message. 670 AppendParseMessageFunction(md); 671 return; 672 } 673 674 // Recurse before adding the main message function, so that internal 675 // map_append functions are available before they are needed. 676 for (int i = 0; i < md.enum_type_count(); ++i) { 677 AppendEnumFunctions(*md.enum_type(i)); 678 } 679 for (int i = 0; i < md.nested_type_count(); ++i) { 680 AppendMessageFunctions(*md.nested_type(i)); 681 } 682 683 AppendDebugStringFunctions(md); 684 AppendParseMessageFunction(md); 685 } 686 687 void Generator::AddNamespaceToCurrentSection(const string& package, bool open) { 688 Print(); 689 std::vector<string> parts = {""}; 690 for (size_t i = 0; i < package.size(); ++i) { 691 if (package[i] == '.') { 692 parts.resize(parts.size() + 1); 693 } else { 694 parts.back() += package[i]; 695 } 696 } 697 if (open) { 698 for (const auto& p : parts) { 699 Print("namespace ", p, " {"); 700 } 701 } else { 702 for (auto it = parts.rbegin(); it != parts.rend(); ++it) { 703 Print("} // namespace ", *it); 704 } 705 } 706 } 707 708 void Generator::AddHeadersToCurrentSection(const std::vector<string>& headers) { 709 std::vector<string> sorted = headers; 710 std::sort(sorted.begin(), sorted.end()); 711 for (const auto& h : sorted) { 712 Print("#include \"", h, "\""); 713 } 714 } 715 716 // Adds to <all_fd> and <all_d> with all descriptors recursively 717 // reachable from the given descriptor. 718 void GetAllFileDescriptorsFromFile(const FileDescriptor* fd, 719 std::set<const FileDescriptor*>* all_fd, 720 std::set<const Descriptor*>* all_d); 721 722 // Adds to <all_fd> and <all_d> with all descriptors recursively 723 // reachable from the given descriptor. 724 void GetAllFileDescriptorsFromMessage(const Descriptor* d, 725 std::set<const FileDescriptor*>* all_fd, 726 std::set<const Descriptor*>* all_d) { 727 if (!all_d->insert(d).second) return; 728 GetAllFileDescriptorsFromFile(d->file(), all_fd, all_d); 729 for (int i = 0; i < d->field_count(); ++i) { 730 auto* f = d->field(i); 731 switch (f->cpp_type()) { 732 case FieldDescriptor::CPPTYPE_INT32: 733 case FieldDescriptor::CPPTYPE_INT64: 734 case FieldDescriptor::CPPTYPE_UINT32: 735 case FieldDescriptor::CPPTYPE_UINT64: 736 case FieldDescriptor::CPPTYPE_DOUBLE: 737 case FieldDescriptor::CPPTYPE_FLOAT: 738 case FieldDescriptor::CPPTYPE_BOOL: 739 case FieldDescriptor::CPPTYPE_STRING: 740 break; 741 case FieldDescriptor::CPPTYPE_MESSAGE: 742 GetAllFileDescriptorsFromMessage(f->message_type(), all_fd, all_d); 743 break; 744 case FieldDescriptor::CPPTYPE_ENUM: 745 GetAllFileDescriptorsFromFile(f->enum_type()->file(), all_fd, all_d); 746 break; 747 } 748 } 749 for (int i = 0; i < d->nested_type_count(); ++i) { 750 GetAllFileDescriptorsFromMessage(d->nested_type(i), all_fd, all_d); 751 } 752 } 753 754 void GetAllFileDescriptorsFromFile(const FileDescriptor* fd, 755 std::set<const FileDescriptor*>* all_fd, 756 std::set<const Descriptor*>* all_d) { 757 if (!all_fd->insert(fd).second) return; 758 for (int i = 0; i < fd->message_type_count(); ++i) { 759 GetAllFileDescriptorsFromMessage(fd->message_type(i), all_fd, all_d); 760 } 761 } 762 763 void Generator::Generate(const FileDescriptor& fd) { 764 // This does not emit code with proper proto2 semantics (e.g. it doesn't check 765 // 'has' fields on non-messages), so check that only proto3 is passed. 766 CHECK_EQ(fd.syntax(), FileDescriptor::SYNTAX_PROTO3) << fd.name(); 767 768 const string package = fd.package(); 769 std::set<const FileDescriptor*> all_fd; 770 std::set<const Descriptor*> all_d; 771 GetAllFileDescriptorsFromFile(&fd, &all_fd, &all_d); 772 773 std::vector<string> headers; 774 775 // Add header to header file. 776 SetOutput(&header_); 777 Print("// GENERATED FILE - DO NOT MODIFY"); 778 Print("#ifndef ", GetHeaderGuard(fd, false /* impl */)); 779 Print("#define ", GetHeaderGuard(fd, false /* impl */)); 780 Print(); 781 headers = { 782 GetProtoHeaderName(fd), 783 StrCat(tf_header_prefix_, "tensorflow/core/platform/macros.h"), 784 StrCat(tf_header_prefix_, "tensorflow/core/platform/protobuf.h"), 785 StrCat(tf_header_prefix_, "tensorflow/core/platform/types.h"), 786 }; 787 for (const auto& h : headers) { 788 Print("#include \"", h, "\""); 789 } 790 AddNamespaceToCurrentSection(package, true /* is_open */); 791 792 // Add header to impl file. 793 SetOutput(&header_impl_); 794 Print("// GENERATED FILE - DO NOT MODIFY"); 795 Print("#ifndef ", GetHeaderGuard(fd, true /* impl */)); 796 Print("#define ", GetHeaderGuard(fd, true /* impl */)); 797 Print(); 798 headers = { 799 GetProtoTextHeaderName(fd, false /* impl */), 800 StrCat(tf_header_prefix_, 801 "tensorflow/core/lib/strings/proto_text_util.h"), 802 StrCat(tf_header_prefix_, "tensorflow/core/lib/strings/scanner.h"), 803 }; 804 for (const FileDescriptor* d : all_fd) { 805 if (d != &fd) { 806 headers.push_back(GetProtoTextHeaderName(*d, true /* impl */)); 807 } 808 headers.push_back(GetProtoHeaderName(*d)); 809 } 810 AddHeadersToCurrentSection(headers); 811 AddNamespaceToCurrentSection(package, true /* is_open */); 812 SetOutput(&header_impl_).Print().Print("namespace internal {"); 813 814 // Add header to cc file. 815 SetOutput(&cc_); 816 Print("// GENERATED FILE - DO NOT MODIFY"); 817 Print(); 818 Print("#include <algorithm>"); // for `std::stable_sort()` 819 Print(); 820 headers = {GetProtoTextHeaderName(fd, true /* impl */)}; 821 AddHeadersToCurrentSection(headers); 822 Print(); 823 Print("using ::tensorflow::strings::Scanner;"); 824 Print("using ::tensorflow::strings::StrCat;"); 825 AddNamespaceToCurrentSection(package, true /* is_open */); 826 827 // Add declarations and definitions. 828 for (int i = 0; i < fd.enum_type_count(); ++i) { 829 AppendEnumFunctions(*fd.enum_type(i)); 830 } 831 for (int i = 0; i < fd.message_type_count(); ++i) { 832 AppendMessageFunctions(*fd.message_type(i)); 833 } 834 835 // Add footer to header file. 836 SetOutput(&header_); 837 AddNamespaceToCurrentSection(package, false /* is_open */); 838 Print().Print("#endif // ", GetHeaderGuard(fd, false /* impl */)); 839 840 // Add footer to header impl file. 841 SetOutput(&header_impl_).Print().Print("} // namespace internal"); 842 AddNamespaceToCurrentSection(package, false /* is_open */); 843 Print().Print("#endif // ", GetHeaderGuard(fd, true /* impl */)); 844 845 // Add footer to cc file. 846 SetOutput(&cc_); 847 AddNamespaceToCurrentSection(package, false /* is_open */); 848 } 849 850 } // namespace 851 852 ProtoTextFunctionCode GetProtoTextFunctionCode(const FileDescriptor& fd, 853 const string& tf_header_prefix) { 854 Generator gen(tf_header_prefix); 855 gen.Generate(fd); 856 return gen.code(); 857 } 858 859 } // namespace tensorflow 860