Home | History | Annotate | Download | only in proto_text
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h"
     17 
     18 #include <algorithm>
     19 #include <set>
     20 #include <unordered_set>
     21 
     22 #include "tensorflow/core/platform/logging.h"
     23 #include "tensorflow/core/platform/macros.h"
     24 #include "tensorflow/core/platform/types.h"
     25 
     26 using ::tensorflow::protobuf::Descriptor;
     27 using ::tensorflow::protobuf::EnumDescriptor;
     28 using ::tensorflow::protobuf::FieldDescriptor;
     29 using ::tensorflow::protobuf::FieldOptions;
     30 using ::tensorflow::protobuf::FileDescriptor;
     31 
     32 namespace tensorflow {
     33 
     34 namespace {
     35 
     36 template <typename... Args>
     37 string StrCat(const Args&... args) {
     38   std::ostringstream s;
     39   std::vector<int>{((s << args), 0)...};
     40   return s.str();
     41 }
     42 
     43 template <typename... Args>
     44 string StrAppend(string* to_append, const Args&... args) {
     45   *to_append += StrCat(args...);
     46   return *to_append;
     47 }
     48 
     49 // Class used to generate the code for proto text functions. One of these should
     50 // be created for each FileDescriptor whose code should be generated.
     51 //
     52 // This class has a notion of the current output Section.  The Print, Nested,
     53 // and Unnest functions apply their operations to the current output section,
     54 // which can be toggled with SetOutput.
     55 //
     56 // Note that on the generated code, various pieces are not optimized - for
     57 // example: map input and output, Cord input and output, comparisons against
     58 // the field names (it's a loop over all names), and tracking of has_seen.
     59 class Generator {
     60  public:
     61   explicit Generator(const string& tf_header_prefix)
     62       : tf_header_prefix_(tf_header_prefix),
     63         header_(&code_.header),
     64         header_impl_(&code_.header_impl),
     65         cc_(&code_.cc) {}
     66 
     67   void Generate(const FileDescriptor& fd);
     68 
     69   // The generated code; valid after Generate has been called.
     70   ProtoTextFunctionCode code() const { return code_; }
     71 
     72  private:
     73   struct Section {
     74     explicit Section(string* str) : str(str) {}
     75     string* str;
     76     string indent;
     77   };
     78 
     79   // Switches the currently active section to <section>.
     80   Generator& SetOutput(Section* section) {
     81     cur_ = section;
     82     return *this;
     83   }
     84 
     85   // Increases indent level.  Returns <*this>, to allow chaining.
     86   Generator& Nest() {
     87     StrAppend(&cur_->indent, "  ");
     88     return *this;
     89   }
     90 
     91   // Decreases indent level.  Returns <*this>, to allow chaining.
     92   Generator& Unnest() {
     93     cur_->indent = cur_->indent.substr(0, cur_->indent.size() - 2);
     94     return *this;
     95   }
     96 
     97   // Appends the concatenated args, with a trailing newline. Returns <*this>, to
     98   // allow chaining.
     99   template <typename... Args>
    100   Generator& Print(Args... args) {
    101     StrAppend(cur_->str, cur_->indent, args..., "\n");
    102     return *this;
    103   }
    104 
    105   // Appends the print code for a single field's value.
    106   // If <omit_default> is true, then the emitted code will not print zero-valued
    107   // values.
    108   // <field_expr> is code that when emitted yields the field's value.
    109   void AppendFieldValueAppend(const FieldDescriptor& field,
    110                               const bool omit_default,
    111                               const string& field_expr);
    112 
    113   // Appends the print code for as single field.
    114   void AppendFieldAppend(const FieldDescriptor& field);
    115 
    116   // Appends the print code for a message. May change which section is currently
    117   // active.
    118   void AppendDebugStringFunctions(const Descriptor& md);
    119 
    120   // Appends the print and parse functions for an enum. May change which
    121   // section is currently active.
    122   void AppendEnumFunctions(const EnumDescriptor& enum_d);
    123 
    124   // Appends the parse functions for a message. May change which section is
    125   // currently active.
    126   void AppendParseMessageFunction(const Descriptor& md);
    127 
    128   // Appends all functions for a message and its nested message and enum types.
    129   // May change which section is currently active.
    130   void AppendMessageFunctions(const Descriptor& md);
    131 
    132   // Appends lines to open or close namespace declarations.
    133   void AddNamespaceToCurrentSection(const string& package, bool open);
    134 
    135   // Appends the given headers as sorted #include lines.
    136   void AddHeadersToCurrentSection(const std::vector<string>& headers);
    137 
    138   // When adding #includes for tensorflow headers, prefix them with this.
    139   const string tf_header_prefix_;
    140   ProtoTextFunctionCode code_;
    141   Section* cur_ = nullptr;
    142   Section header_;
    143   Section header_impl_;
    144   Section cc_;
    145 
    146   std::unordered_set<string> map_append_signatures_included_;
    147 
    148   TF_DISALLOW_COPY_AND_ASSIGN(Generator);
    149 };
    150 
    151 // Returns the prefix needed to reference objects defined in <fd>. E.g.
    152 // "::tensorflow::test".
    153 string GetPackageReferencePrefix(const FileDescriptor* fd) {
    154   string result = "::";
    155   const string& package = fd->package();
    156   for (size_t i = 0; i < package.size(); ++i) {
    157     if (package[i] == '.') {
    158       result += "::";
    159     } else {
    160       result += package[i];
    161     }
    162   }
    163   result += "::";
    164   return result;
    165 }
    166 
    167 // Returns the name of the class generated by proto to represent <d>.
    168 string GetClassName(const Descriptor& d) {
    169   if (d.containing_type() == nullptr) return d.name();
    170   return StrCat(GetClassName(*d.containing_type()), "_", d.name());
    171 }
    172 
    173 // Returns the name of the class generated by proto to represent <ed>.
    174 string GetClassName(const EnumDescriptor& ed) {
    175   if (ed.containing_type() == nullptr) return ed.name();
    176   return StrCat(GetClassName(*ed.containing_type()), "_", ed.name());
    177 }
    178 
    179 // Returns the qualified name that refers to the class generated by proto to
    180 // represent <d>.
    181 string GetQualifiedName(const Descriptor& d) {
    182   return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d));
    183 }
    184 
    185 // Returns the qualified name that refers to the class generated by proto to
    186 // represent <ed>.
    187 string GetQualifiedName(const EnumDescriptor& d) {
    188   return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d));
    189 }
    190 
    191 // Returns the qualified name that refers to the generated
    192 // AppendProtoDebugString function for <d>.
    193 string GetQualifiedAppendFn(const Descriptor& d) {
    194   return StrCat(GetPackageReferencePrefix(d.file()),
    195                 "internal::AppendProtoDebugString");
    196 }
    197 
    198 // Returns the name of the generated function that returns an enum value's
    199 // string value.
    200 string GetEnumNameFn(const EnumDescriptor& enum_d) {
    201   return StrCat("EnumName_", GetClassName(enum_d));
    202 }
    203 
    204 // Returns the qualified name of the function returned by GetEnumNameFn().
    205 string GetQualifiedEnumNameFn(const EnumDescriptor& enum_d) {
    206   return StrCat(GetPackageReferencePrefix(enum_d.file()),
    207                 GetEnumNameFn(enum_d));
    208 }
    209 
    210 // Returns the name of a generated header file, either the public api (if impl
    211 // is false) or the internal implementation header (if impl is true).
    212 string GetProtoTextHeaderName(const FileDescriptor& fd, bool impl) {
    213   const int dot_index = fd.name().find_last_of('.');
    214   return fd.name().substr(0, dot_index) +
    215          (impl ? ".pb_text-impl.h" : ".pb_text.h");
    216 }
    217 
    218 // Returns the name of the header generated by the proto library for <fd>.
    219 string GetProtoHeaderName(const FileDescriptor& fd) {
    220   const int dot_index = fd.name().find_last_of('.');
    221   return fd.name().substr(0, dot_index) + ".pb.h";
    222 }
    223 
    224 // Returns the C++ class name for the given proto field.
    225 string GetCppClass(const FieldDescriptor& d) {
    226   string cpp_class = d.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE
    227                          ? GetQualifiedName(*d.message_type())
    228                          : d.cpp_type_name();
    229 
    230   // In open-source TensorFlow, the definition of int64 varies across
    231   // platforms. The following line, which is manipulated during internal-
    232   // external sync'ing, takes care of the variability.
    233   if (cpp_class == "int64") {
    234     cpp_class = kProtobufInt64Typename;
    235   }
    236 
    237   return cpp_class;
    238 }
    239 
    240 // Returns the string that can be used for a header guard for the generated
    241 // headers for <fd>, either for the public api (if impl is false) or the
    242 // internal implementation header (if impl is true).
    243 string GetHeaderGuard(const FileDescriptor& fd, bool impl) {
    244   string s = fd.name();
    245   std::replace(s.begin(), s.end(), '/', '_');
    246   std::replace(s.begin(), s.end(), '.', '_');
    247   return s + (impl ? "_IMPL_H_" : "_H_");
    248 }
    249 
    250 void Generator::AppendFieldValueAppend(const FieldDescriptor& field,
    251                                        const bool omit_default,
    252                                        const string& field_expr) {
    253   SetOutput(&cc_);
    254   switch (field.cpp_type()) {
    255     case FieldDescriptor::CPPTYPE_INT32:
    256     case FieldDescriptor::CPPTYPE_INT64:
    257     case FieldDescriptor::CPPTYPE_UINT32:
    258     case FieldDescriptor::CPPTYPE_UINT64:
    259     case FieldDescriptor::CPPTYPE_DOUBLE:
    260     case FieldDescriptor::CPPTYPE_FLOAT:
    261       Print("o->", omit_default ? "AppendNumericIfNotZero" : "AppendNumeric",
    262             "(\"", field.name(), "\", ", field_expr, ");");
    263       break;
    264     case FieldDescriptor::CPPTYPE_BOOL:
    265       Print("o->", omit_default ? "AppendBoolIfTrue" : "AppendBool", "(\"",
    266             field.name(), "\", ", field_expr, ");");
    267       break;
    268     case FieldDescriptor::CPPTYPE_STRING: {
    269       const auto ctype = field.options().ctype();
    270       CHECK(ctype == FieldOptions::CORD || ctype == FieldOptions::STRING)
    271           << "Unsupported ctype " << ctype;
    272 
    273       Print("o->", omit_default ? "AppendStringIfNotEmpty" : "AppendString",
    274             "(\"", field.name(), "\", ProtobufStringToString(", field_expr,
    275             "));");
    276       break;
    277     }
    278     case FieldDescriptor::CPPTYPE_ENUM:
    279       if (omit_default) {
    280         Print("if (", field_expr, " != 0) {").Nest();
    281       }
    282       Print("o->AppendEnumName(\"", field.name(), "\", ",
    283             GetQualifiedEnumNameFn(*field.enum_type()), "(", field_expr, "));");
    284       if (omit_default) {
    285         Unnest().Print("}");
    286       }
    287       break;
    288     case FieldDescriptor::CPPTYPE_MESSAGE:
    289       CHECK(!field.message_type()->options().map_entry());
    290       if (omit_default) {
    291         Print("if (msg.has_", field.name(), "()) {").Nest();
    292       }
    293       Print("o->OpenNestedMessage(\"", field.name(), "\");");
    294       Print(GetQualifiedAppendFn(*field.message_type()), "(o, ", field_expr,
    295             ");");
    296       Print("o->CloseNestedMessage();");
    297       if (omit_default) {
    298         Unnest().Print("}");
    299       }
    300       break;
    301   }
    302 }
    303 
    304 void Generator::AppendFieldAppend(const FieldDescriptor& field) {
    305   const string& name = field.name();
    306 
    307   if (field.is_map()) {
    308     Print("{").Nest();
    309     const auto& key_type = *field.message_type()->FindFieldByName("key");
    310     const auto& value_type = *field.message_type()->FindFieldByName("value");
    311 
    312     Print("std::vector<", key_type.cpp_type_name(), "> keys;");
    313     Print("for (const auto& e : msg.", name, "()) keys.push_back(e.first);");
    314     Print("std::stable_sort(keys.begin(), keys.end());");
    315     Print("for (const auto& key : keys) {").Nest();
    316     Print("o->OpenNestedMessage(\"", name, "\");");
    317     AppendFieldValueAppend(key_type, false /* omit_default */, "key");
    318     AppendFieldValueAppend(value_type, false /* omit_default */,
    319                            StrCat("msg.", name, "().at(key)"));
    320     Print("o->CloseNestedMessage();");
    321     Unnest().Print("}");
    322 
    323     Unnest().Print("}");
    324   } else if (field.is_repeated()) {
    325     Print("for (int i = 0; i < msg.", name, "_size(); ++i) {");
    326     Nest();
    327     AppendFieldValueAppend(field, false /* omit_default */,
    328                            "msg." + name + "(i)");
    329     Unnest().Print("}");
    330   } else {
    331     const auto* oneof = field.containing_oneof();
    332     if (oneof != nullptr) {
    333       string camel_name = field.camelcase_name();
    334       camel_name[0] = toupper(camel_name[0]);
    335       Print("if (msg.", oneof->name(), "_case() == ",
    336             GetQualifiedName(*oneof->containing_type()), "::k", camel_name,
    337             ") {");
    338       Nest();
    339       AppendFieldValueAppend(field, false /* omit_default */,
    340                              "msg." + name + "()");
    341       Unnest();
    342       Print("}");
    343     } else {
    344       AppendFieldValueAppend(field, true /* omit_default */,
    345                              "msg." + name + "()");
    346     }
    347   }
    348 }
    349 
    350 void Generator::AppendEnumFunctions(const EnumDescriptor& enum_d) {
    351   const string sig = StrCat("const char* ", GetEnumNameFn(enum_d), "(\n    ",
    352                             GetQualifiedName(enum_d), " value)");
    353   SetOutput(&header_);
    354   Print().Print("// Enum text output for ", string(enum_d.full_name()));
    355   Print(sig, ";");
    356 
    357   SetOutput(&cc_);
    358   Print().Print(sig, " {");
    359   Nest().Print("switch (value) {").Nest();
    360   for (int i = 0; i < enum_d.value_count(); ++i) {
    361     const auto& value = *enum_d.value(i);
    362     Print("case ", value.number(), ": return \"", value.name(), "\";");
    363   }
    364   Print("default: return \"\";");
    365   Unnest().Print("}");
    366   Unnest().Print("}");
    367 }
    368 
// Emits the text-parsing functions for <md>.
//
// For an ordinary message:
// - ProtoParseFromString(const string& s, Msg* msg) is declared in header_ and
//   defined in cc_. It clears msg, runs internal::ProtoParseFromScanner on a
//   Scanner over s, calls scanner.Eos() and returns scanner.GetResult().
// - internal::ProtoParseFromScanner(Scanner*, bool nested, bool close_curly,
//   Msg*) is declared in header_impl_ and defined in cc_.
//
// For a map-entry message (options().map_entry()), only the scanner-based
// function is emitted, inside an anonymous namespace within `internal`. It
// parses one key/value entry into a ::tensorflow::protobuf::Map<K, V>*. Its
// signature depends only on K and V, so a signature already recorded in
// map_append_signatures_included_ is not emitted again.
//
// The generated scanner function loops reading `identifier [':'] value`. It
// returns true on the closing '}' / '>' when nested, or at end of input when
// not nested. It returns false on malformed input or when a non-repeated
// field appears twice.
void Generator::AppendParseMessageFunction(const Descriptor& md) {
  const bool map_append = (md.options().map_entry());
  string sig;
  if (!map_append) {
    sig = StrCat("bool ProtoParseFromString(\n    const string& s,\n    ",
                 GetQualifiedName(md), "* msg)");
    SetOutput(&header_).Print(sig, "\n        TF_MUST_USE_RESULT;");

    SetOutput(&cc_);
    Print().Print(sig, " {").Nest();
    Print("msg->Clear();");
    Print("Scanner scanner(s);");
    Print("if (!internal::ProtoParseFromScanner(",
          "&scanner, false, false, msg)) return false;");
    Print("scanner.Eos();");
    Print("return scanner.GetResult();");
    Unnest().Print("}");
  }

  // Parse from scanner - the real work here.
  sig = StrCat("bool ProtoParseFromScanner(",
               "\n    ::tensorflow::strings::Scanner* scanner, bool nested, "
               "bool close_curly,\n    ");
  const FieldDescriptor* key_type = nullptr;
  const FieldDescriptor* value_type = nullptr;
  if (map_append) {
    key_type = md.FindFieldByName("key");
    value_type = md.FindFieldByName("value");
    StrAppend(&sig, "::tensorflow::protobuf::Map<", GetCppClass(*key_type),
              ", ", GetCppClass(*value_type), ">* map)");
  } else {
    StrAppend(&sig, GetQualifiedName(md), "* msg)");
  }

  if (!map_append_signatures_included_.insert(sig).second) {
    // signature for function to append to a map of this type has
    // already been defined in this .cc file. Don't define it again.
    return;
  }

  if (!map_append) {
    SetOutput(&header_impl_).Print(sig, ";");
  }

  SetOutput(&cc_);
  Print().Print("namespace internal {");
  if (map_append) {
    Print("namespace {");
  }
  Print().Print(sig, " {").Nest();
  if (map_append) {
    // Locals that collect one entry; it is inserted into the map only once
    // both key and value have been seen.
    Print(GetCppClass(*key_type), " map_key;");
    Print("bool set_map_key = false;");
    Print(GetCppClass(*value_type), " map_value;");
    Print("bool set_map_value = false;");
  }
  Print("std::vector<bool> has_seen(", md.field_count(), ", false);");
  Print("while(true) {").Nest();
  Print("ProtoSpaceAndComments(scanner);");

  // Emit success case
  Print("if (nested && (scanner->Peek() == (close_curly ? '}' : '>'))) {")
      .Nest();
  Print("scanner->One(Scanner::ALL);");
  Print("ProtoSpaceAndComments(scanner);");
  if (map_append) {
    Print("if (!set_map_key || !set_map_value) return false;");
    Print("(*map)[map_key] = map_value;");
  }
  Print("return true;");
  Unnest().Print("}");

  // Read the field identifier and an optional ':' after it.
  Print("if (!nested && scanner->empty()) { return true; }");
  Print("scanner->RestartCapture()");
  Print("    .Many(Scanner::LETTER_DIGIT_UNDERSCORE)");
  Print("    .StopCapture();");
  Print("StringPiece identifier;");
  Print("if (!scanner->GetResult(nullptr, &identifier)) return false;");
  Print("bool parsed_colon = false;");
  Print("(void)parsed_colon;"); // Avoid "set but not used" compiler warning
  Print("ProtoSpaceAndComments(scanner);");
  Print("if (scanner->Peek() == ':') {");
  Nest().Print("parsed_colon = true;");
  Print("scanner->One(Scanner::ALL);");
  Print("ProtoSpaceAndComments(scanner);");
  Unnest().Print("}");
  // One `if (identifier == "<name>")` branch per field, chained with else.
  // NOTE(review): the chain has no final `else`, so an identifier that matches
  // no field is not rejected here; the loop just continues scanning. Confirm
  // this is intended.
  for (int i = 0; i < md.field_count(); ++i) {
    const FieldDescriptor* field = md.field(i);
    const string& field_name = field->name();
    // mutable_value_expr: pointer to the storage that receives the value.
    // set_value_prefix: call (or assignment) prefix used for scalar values.
    string mutable_value_expr;
    string set_value_prefix;
    if (map_append) {
      mutable_value_expr = StrCat("&map_", field_name);
      set_value_prefix = StrCat("map_", field_name, " = ");
    } else if (field->is_repeated()) {
      if (field->is_map()) {
        mutable_value_expr = StrCat("msg->mutable_", field_name, "()");
        set_value_prefix =
            "UNREACHABLE";  // generator will never use this value.
      } else {
        mutable_value_expr = StrCat("msg->add_", field_name, "()");
        set_value_prefix = StrCat("msg->add_", field_name);
      }
    } else {
      mutable_value_expr = StrCat("msg->mutable_", field_name, "()");
      set_value_prefix = StrCat("msg->set_", field_name);
    }

    Print(i == 0 ? "" : "else ", "if (identifier == \"", field_name, "\") {");
    Nest();

    if (field->is_repeated()) {
      CHECK(!map_append);

      // Check to see if this is an array assignment, like a: [1, 2, 3]
      Print("const bool is_list = (scanner->Peek() == '[');");
      Print("do {");
      // [ or , // skip
      Nest().Print("if (is_list) {");
      Nest().Print("scanner->One(Scanner::ALL);");
      Print("ProtoSpaceAndComments(scanner);");
      Unnest().Print("}");
    } else if (field->containing_oneof() != nullptr) {
      CHECK(!map_append);

      // Detect duplicate oneof value.
      const string oneof_name = field->containing_oneof()->name();
      Print("if (msg->", oneof_name, "_case() != 0) return false;");
    }

    if (!field->is_repeated() && !map_append) {
      // Reject a second occurrence of a non-repeated field.
      Print("if (has_seen[", i, "]) return false;");
      Print("has_seen[", i, "] = true;");
    }
    if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
      // Nested messages may use either {...} or <...>; open_char selects the
      // closing delimiter for the recursive call.
      Print("const char open_char = scanner->Peek();");
      Print("if (open_char != '{' && open_char != '<') return false;");
      Print("scanner->One(Scanner::ALL);");
      Print("ProtoSpaceAndComments(scanner);");
      if (field->is_map()) {
        // Unqualified: resolves to the map-entry overload emitted earlier in
        // this .cc file's internal anonymous namespace.
        Print("if (!ProtoParseFromScanner(");
      } else {
        Print("if (!", GetPackageReferencePrefix(field->message_type()->file()),
              "internal::ProtoParseFromScanner(");
      }
      Print("    scanner, true, open_char == '{', ", mutable_value_expr,
            ")) return false;");
    } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_STRING) {
      Print("string str_value;");
      Print(
          "if (!parsed_colon || "
          "!::tensorflow::strings::ProtoParseStringLiteralFromScanner(");
      Print("    scanner, &str_value)) return false;");
      Print("SetProtobufStringSwapAllowed(&str_value, ", mutable_value_expr,
            ");");
    } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
      // Enum values are accepted by name or by decimal number ("-0" is also
      // accepted for value 0).
      Print("StringPiece value;");
      Print(
          "if (!parsed_colon || "
          "!scanner->RestartCapture().Many("
          "Scanner::LETTER_DIGIT_DASH_UNDERSCORE)."
          "GetResult(nullptr, &value)) return false;");
      const auto* enum_d = field->enum_type();
      string value_prefix;
      if (enum_d->containing_type() == nullptr) {
        value_prefix = GetPackageReferencePrefix(enum_d->file());
      } else {
        value_prefix = StrCat(GetQualifiedName(*enum_d), "_");
      }

      for (int enum_i = 0; enum_i < enum_d->value_count(); ++enum_i) {
        const auto* value_d = enum_d->value(enum_i);
        const string& value_name = value_d->name();
        string condition = StrCat("value == \"", value_name,
                                  "\" || value == \"", value_d->number(), "\"");
        if (value_d->number() == 0) {
          StrAppend(&condition, " || value == \"-0\"");
        }

        Print(enum_i == 0 ? "" : "} else ", "if (", condition, ") {");
        Nest();
        Print(set_value_prefix, "(", value_prefix, value_name, ");");
        Unnest();
      }
      Print("} else {").Nest().Print("return false;").Unnest().Print("}");
    } else {
      // Numeric and bool fields: the local uses cpp_type_name() directly and
      // is converted by the set_/add_ call or map assignment.
      Print(field->cpp_type_name(), " value;");
      switch (field->cpp_type()) {
        case FieldDescriptor::CPPTYPE_INT32:
        case FieldDescriptor::CPPTYPE_INT64:
        case FieldDescriptor::CPPTYPE_UINT32:
        case FieldDescriptor::CPPTYPE_UINT64:
        case FieldDescriptor::CPPTYPE_DOUBLE:
        case FieldDescriptor::CPPTYPE_FLOAT:
          Print(
              "if (!parsed_colon || "
              "!::tensorflow::strings::ProtoParseNumericFromScanner(",
              "scanner, &value)) return false;");
          break;
        case FieldDescriptor::CPPTYPE_BOOL:
          Print(
              "if (!parsed_colon || "
              "!::tensorflow::strings::ProtoParseBoolFromScanner(",
              "scanner, &value)) return false;");
          break;
        default:
          LOG(FATAL) << "handled earlier";
      }
      Print(set_value_prefix, "(value);");
    }

    if (field->is_repeated()) {
      Unnest().Print("} while (is_list && scanner->Peek() == ',');");
      Print(
          "if (is_list && "
          "!scanner->OneLiteral(\"]\").GetResult()) return false;");
    }
    if (map_append) {
      Print("set_map_", field_name, " = true;");
    }
    Unnest().Print("}");
  }
  Unnest().Print("}");  // Closes while(true).
  Unnest().Print("}");  // Closes the function body.
  // NOTE(review): this third Unnest() has no matching Nest() in this function.
  // With an empty indent, Unnest() leaves it empty, so this only emits a
  // blank line.
  Unnest().Print();
  if (map_append) {
    Print("}  // namespace");
  }
  Print("}  // namespace internal");
}
    600 
    601 void Generator::AppendDebugStringFunctions(const Descriptor& md) {
    602   SetOutput(&header_impl_).Print();
    603   SetOutput(&header_).Print().Print("// Message-text conversion for ",
    604                                     string(md.full_name()));
    605 
    606   // Append the two debug string functions for <md>.
    607   for (int short_pass = 0; short_pass < 2; ++short_pass) {
    608     const bool short_debug = (short_pass == 1);
    609 
    610     // Make the Get functions.
    611     const string sig = StrCat(
    612         "string ", short_debug ? "ProtoShortDebugString" : "ProtoDebugString",
    613         "(\n    const ", GetQualifiedName(md), "& msg)");
    614     SetOutput(&header_).Print(sig, ";");
    615 
    616     SetOutput(&cc_);
    617     Print().Print(sig, " {").Nest();
    618     Print("string s;");
    619     Print("::tensorflow::strings::ProtoTextOutput o(&s, ",
    620           short_debug ? "true" : "false", ");");
    621     Print("internal::AppendProtoDebugString(&o, msg);");
    622     Print("o.CloseTopMessage();");
    623     Print("return s;");
    624     Unnest().Print("}");
    625   }
    626 
    627   // Make the Append function.
    628   const string sig =
    629       StrCat("void AppendProtoDebugString(\n",
    630              "    ::tensorflow::strings::ProtoTextOutput* o,\n    const ",
    631              GetQualifiedName(md), "& msg)");
    632   SetOutput(&header_impl_).Print(sig, ";");
    633   SetOutput(&cc_);
    634   Print().Print("namespace internal {").Print();
    635   Print(sig, " {").Nest();
    636   std::vector<const FieldDescriptor*> fields;
    637   fields.reserve(md.field_count());
    638   for (int i = 0; i < md.field_count(); ++i) {
    639     fields.push_back(md.field(i));
    640   }
    641   std::sort(fields.begin(), fields.end(),
    642             [](const FieldDescriptor* left, const FieldDescriptor* right) {
    643               return left->number() < right->number();
    644             });
    645 
    646   for (const FieldDescriptor* field : fields) {
    647     SetOutput(&cc_);
    648     AppendFieldAppend(*field);
    649   }
    650   Unnest().Print("}").Print().Print("}  // namespace internal");
    651 }
    652 
    653 void Generator::AppendMessageFunctions(const Descriptor& md) {
    654   if (md.options().map_entry()) {
    655     // The 'map entry' Message is not a user-visible message type.  Only its
    656     // parse function is created (and that actually parsed the whole Map, not
    657     // just the map entry). Printing of a map is done in the code generated for
    658     // the containing message.
    659     AppendParseMessageFunction(md);
    660     return;
    661   }
    662 
    663   // Recurse before adding the main message function, so that internal
    664   // map_append functions are available before they are needed.
    665   for (int i = 0; i < md.enum_type_count(); ++i) {
    666     AppendEnumFunctions(*md.enum_type(i));
    667   }
    668   for (int i = 0; i < md.nested_type_count(); ++i) {
    669     AppendMessageFunctions(*md.nested_type(i));
    670   }
    671 
    672   AppendDebugStringFunctions(md);
    673   AppendParseMessageFunction(md);
    674 }
    675 
    676 void Generator::AddNamespaceToCurrentSection(const string& package, bool open) {
    677   Print();
    678   std::vector<string> parts = {""};
    679   for (size_t i = 0; i < package.size(); ++i) {
    680     if (package[i] == '.') {
    681       parts.resize(parts.size() + 1);
    682     } else {
    683       parts.back() += package[i];
    684     }
    685   }
    686   if (open) {
    687     for (const auto& p : parts) {
    688       Print("namespace ", p, " {");
    689     }
    690   } else {
    691     for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
    692       Print("}  // namespace ", *it);
    693     }
    694   }
    695 }
    696 
    697 void Generator::AddHeadersToCurrentSection(const std::vector<string>& headers) {
    698   std::vector<string> sorted = headers;
    699   std::sort(sorted.begin(), sorted.end());
    700   for (const auto& h : sorted) {
    701     Print("#include \"", h, "\"");
    702   }
    703 }
    704 
// Adds <fd> to <all_fd>, then (via GetAllFileDescriptorsFromMessage) every
// message and file reachable from the top-level messages of <fd>. Does nothing
// if <fd> is already in <all_fd>. Declared ahead of its definition because it
// and GetAllFileDescriptorsFromMessage call each other.
void GetAllFileDescriptorsFromFile(const FileDescriptor* fd,
                                   std::set<const FileDescriptor*>* all_fd,
                                   std::set<const Descriptor*>* all_d);
    710 
    711 // Adds to <all_fd> and <all_d> with all descriptors recursively
    712 // reachable from the given descriptor.
    713 void GetAllFileDescriptorsFromMessage(const Descriptor* d,
    714                                       std::set<const FileDescriptor*>* all_fd,
    715                                       std::set<const Descriptor*>* all_d) {
    716   if (!all_d->insert(d).second) return;
    717   GetAllFileDescriptorsFromFile(d->file(), all_fd, all_d);
    718   for (int i = 0; i < d->field_count(); ++i) {
    719     auto* f = d->field(i);
    720     switch (f->cpp_type()) {
    721       case FieldDescriptor::CPPTYPE_INT32:
    722       case FieldDescriptor::CPPTYPE_INT64:
    723       case FieldDescriptor::CPPTYPE_UINT32:
    724       case FieldDescriptor::CPPTYPE_UINT64:
    725       case FieldDescriptor::CPPTYPE_DOUBLE:
    726       case FieldDescriptor::CPPTYPE_FLOAT:
    727       case FieldDescriptor::CPPTYPE_BOOL:
    728       case FieldDescriptor::CPPTYPE_STRING:
    729         break;
    730       case FieldDescriptor::CPPTYPE_MESSAGE:
    731         GetAllFileDescriptorsFromMessage(f->message_type(), all_fd, all_d);
    732         break;
    733       case FieldDescriptor::CPPTYPE_ENUM:
    734         GetAllFileDescriptorsFromFile(f->enum_type()->file(), all_fd, all_d);
    735         break;
    736     }
    737   }
    738   for (int i = 0; i < d->nested_type_count(); ++i) {
    739     GetAllFileDescriptorsFromMessage(d->nested_type(i), all_fd, all_d);
    740   }
    741 }
    742 
    743 void GetAllFileDescriptorsFromFile(const FileDescriptor* fd,
    744                                    std::set<const FileDescriptor*>* all_fd,
    745                                    std::set<const Descriptor*>* all_d) {
    746   if (!all_fd->insert(fd).second) return;
    747   for (int i = 0; i < fd->message_type_count(); ++i) {
    748     GetAllFileDescriptorsFromMessage(fd->message_type(i), all_fd, all_d);
    749   }
    750 }
    751 
    752 void Generator::Generate(const FileDescriptor& fd) {
    753   // This does not emit code with proper proto2 semantics (e.g. it doesn't check
    754   // 'has' fields on non-messages), so check that only proto3 is passed.
    755   CHECK_EQ(fd.syntax(), FileDescriptor::SYNTAX_PROTO3) << fd.name();
    756 
    757   const string package = fd.package();
    758   std::set<const FileDescriptor*> all_fd;
    759   std::set<const Descriptor*> all_d;
    760   GetAllFileDescriptorsFromFile(&fd, &all_fd, &all_d);
    761 
    762   std::vector<string> headers;
    763 
    764   // Add header to header file.
    765   SetOutput(&header_);
    766   Print("// GENERATED FILE - DO NOT MODIFY");
    767   Print("#ifndef ", GetHeaderGuard(fd, false /* impl */));
    768   Print("#define ", GetHeaderGuard(fd, false /* impl */));
    769   Print();
    770   headers = {
    771       GetProtoHeaderName(fd),
    772       StrCat(tf_header_prefix_, "tensorflow/core/platform/macros.h"),
    773       StrCat(tf_header_prefix_, "tensorflow/core/platform/protobuf.h"),
    774       StrCat(tf_header_prefix_, "tensorflow/core/platform/types.h"),
    775   };
    776   for (const auto& h : headers) {
    777     Print("#include \"", h, "\"");
    778   }
    779   AddNamespaceToCurrentSection(package, true /* is_open */);
    780 
    781   // Add header to impl file.
    782   SetOutput(&header_impl_);
    783   Print("// GENERATED FILE - DO NOT MODIFY");
    784   Print("#ifndef ", GetHeaderGuard(fd, true /* impl */));
    785   Print("#define ", GetHeaderGuard(fd, true /* impl */));
    786   Print();
    787   headers = {
    788       GetProtoTextHeaderName(fd, false /* impl */),
    789       StrCat(tf_header_prefix_,
    790              "tensorflow/core/lib/strings/proto_text_util.h"),
    791       StrCat(tf_header_prefix_, "tensorflow/core/lib/strings/scanner.h"),
    792   };
    793   for (const FileDescriptor* d : all_fd) {
    794     if (d != &fd) {
    795       headers.push_back(GetProtoTextHeaderName(*d, true /* impl */));
    796     }
    797     headers.push_back(GetProtoHeaderName(*d));
    798   }
    799   AddHeadersToCurrentSection(headers);
    800   AddNamespaceToCurrentSection(package, true /* is_open */);
    801   SetOutput(&header_impl_).Print().Print("namespace internal {");
    802 
    803   // Add header to cc file.
    804   SetOutput(&cc_);
    805   Print("// GENERATED FILE - DO NOT MODIFY");
    806   headers = {GetProtoTextHeaderName(fd, true /* impl */)};
    807   AddHeadersToCurrentSection(headers);
    808   Print();
    809   Print("using ::tensorflow::strings::Scanner;");
    810   Print("using ::tensorflow::strings::StrCat;");
    811   AddNamespaceToCurrentSection(package, true /* is_open */);
    812 
    813   // Add declarations and definitions.
    814   for (int i = 0; i < fd.enum_type_count(); ++i) {
    815     AppendEnumFunctions(*fd.enum_type(i));
    816   }
    817   for (int i = 0; i < fd.message_type_count(); ++i) {
    818     AppendMessageFunctions(*fd.message_type(i));
    819   }
    820 
    821   // Add footer to header file.
    822   SetOutput(&header_);
    823   AddNamespaceToCurrentSection(package, false /* is_open */);
    824   Print().Print("#endif  // ", GetHeaderGuard(fd, false /* impl */));
    825 
    826   // Add footer to header impl file.
    827   SetOutput(&header_impl_).Print().Print("}  // namespace internal");
    828   AddNamespaceToCurrentSection(package, false /* is_open */);
    829   Print().Print("#endif  // ", GetHeaderGuard(fd, true /* impl */));
    830 
    831   // Add footer to cc file.
    832   SetOutput(&cc_);
    833   AddNamespaceToCurrentSection(package, false /* is_open */);
    834 }
    835 
    836 }  // namespace
    837 
    838 ProtoTextFunctionCode GetProtoTextFunctionCode(const FileDescriptor& fd,
    839                                                const string& tf_header_prefix) {
    840   Generator gen(tf_header_prefix);
    841   gen.Generate(fd);
    842   return gen.code();
    843 }
    844 
    845 }  // namespace tensorflow
    846