Home | History | Annotate | Download | only in proto_text
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h"
     17 
     18 #include <algorithm>
     19 #include <set>
     20 #include <unordered_set>
     21 
     22 #include "tensorflow/core/platform/logging.h"
     23 #include "tensorflow/core/platform/macros.h"
     24 #include "tensorflow/core/platform/types.h"
     25 
     26 using ::tensorflow::protobuf::Descriptor;
     27 using ::tensorflow::protobuf::EnumDescriptor;
     28 using ::tensorflow::protobuf::FieldDescriptor;
     29 using ::tensorflow::protobuf::FieldOptions;
     30 using ::tensorflow::protobuf::FileDescriptor;
     31 
     32 namespace tensorflow {
     33 
     34 namespace {
     35 
     36 template <typename... Args>
     37 string StrCat(const Args&... args) {
     38   std::ostringstream s;
     39   std::vector<int>{((s << args), 0)...};
     40   return s.str();
     41 }
     42 
     43 template <typename... Args>
     44 string StrAppend(string* to_append, const Args&... args) {
     45   *to_append += StrCat(args...);
     46   return *to_append;
     47 }
     48 
     49 // Class used to generate the code for proto text functions. One of these should
     50 // be created for each FileDescriptor whose code should be generated.
     51 //
     52 // This class has a notion of the current output Section.  The Print, Nested,
     53 // and Unnest functions apply their operations to the current output section,
     54 // which can be toggled with SetOutput.
     55 //
     56 // Note that on the generated code, various pieces are not optimized - for
     57 // example: map input and output, Cord input and output, comparisons against
     58 // the field names (it's a loop over all names), and tracking of has_seen.
     59 class Generator {
     60  public:
     61   explicit Generator(const string& tf_header_prefix)
     62       : tf_header_prefix_(tf_header_prefix),
     63         header_(&code_.header),
     64         header_impl_(&code_.header_impl),
     65         cc_(&code_.cc) {}
     66 
     67   void Generate(const FileDescriptor& fd);
     68 
     69   // The generated code; valid after Generate has been called.
     70   ProtoTextFunctionCode code() const { return code_; }
     71 
     72  private:
     73   struct Section {
     74     explicit Section(string* str) : str(str) {}
     75     string* str;
     76     string indent;
     77   };
     78 
     79   // Switches the currently active section to <section>.
     80   Generator& SetOutput(Section* section) {
     81     cur_ = section;
     82     return *this;
     83   }
     84 
     85   // Increases indent level.  Returns <*this>, to allow chaining.
     86   Generator& Nest() {
     87     StrAppend(&cur_->indent, "  ");
     88     return *this;
     89   }
     90 
     91   // Decreases indent level.  Returns <*this>, to allow chaining.
     92   Generator& Unnest() {
     93     cur_->indent = cur_->indent.substr(0, cur_->indent.size() - 2);
     94     return *this;
     95   }
     96 
     97   // Appends the concatenated args, with a trailing newline. Returns <*this>, to
     98   // allow chaining.
     99   template <typename... Args>
    100   Generator& Print(Args... args) {
    101     StrAppend(cur_->str, cur_->indent, args..., "\n");
    102     return *this;
    103   }
    104 
    105   // Appends the print code for a single field's value.
    106   // If <omit_default> is true, then the emitted code will not print zero-valued
    107   // values.
    108   // <field_expr> is code that when emitted yields the field's value.
    109   void AppendFieldValueAppend(const FieldDescriptor& field,
    110                               const bool omit_default,
    111                               const string& field_expr);
    112 
    113   // Appends the print code for as single field.
    114   void AppendFieldAppend(const FieldDescriptor& field);
    115 
    116   // Appends the print code for a message. May change which section is currently
    117   // active.
    118   void AppendDebugStringFunctions(const Descriptor& md);
    119 
    120   // Appends the print and parse functions for an enum. May change which
    121   // section is currently active.
    122   void AppendEnumFunctions(const EnumDescriptor& enum_d);
    123 
    124   // Appends the parse functions for a message. May change which section is
    125   // currently active.
    126   void AppendParseMessageFunction(const Descriptor& md);
    127 
    128   // Appends all functions for a message and its nested message and enum types.
    129   // May change which section is currently active.
    130   void AppendMessageFunctions(const Descriptor& md);
    131 
    132   // Appends lines to open or close namespace declarations.
    133   void AddNamespaceToCurrentSection(const string& package, bool open);
    134 
    135   // Appends the given headers as sorted #include lines.
    136   void AddHeadersToCurrentSection(const std::vector<string>& headers);
    137 
    138   // When adding #includes for tensorflow headers, prefix them with this.
    139   const string tf_header_prefix_;
    140   ProtoTextFunctionCode code_;
    141   Section* cur_ = nullptr;
    142   Section header_;
    143   Section header_impl_;
    144   Section cc_;
    145 
    146   std::unordered_set<string> map_append_signatures_included_;
    147 
    148   TF_DISALLOW_COPY_AND_ASSIGN(Generator);
    149 };
    150 
    151 // Returns the prefix needed to reference objects defined in <fd>. E.g.
    152 // "::tensorflow::test".
    153 string GetPackageReferencePrefix(const FileDescriptor* fd) {
    154   string result = "::";
    155   const string& package = fd->package();
    156   for (size_t i = 0; i < package.size(); ++i) {
    157     if (package[i] == '.') {
    158       result += "::";
    159     } else {
    160       result += package[i];
    161     }
    162   }
    163   result += "::";
    164   return result;
    165 }
    166 
    167 // Returns the name of the class generated by proto to represent <d>.
    168 string GetClassName(const Descriptor& d) {
    169   if (d.containing_type() == nullptr) return d.name();
    170   return StrCat(GetClassName(*d.containing_type()), "_", d.name());
    171 }
    172 
    173 // Returns the name of the class generated by proto to represent <ed>.
    174 string GetClassName(const EnumDescriptor& ed) {
    175   if (ed.containing_type() == nullptr) return ed.name();
    176   return StrCat(GetClassName(*ed.containing_type()), "_", ed.name());
    177 }
    178 
    179 // Returns the qualified name that refers to the class generated by proto to
    180 // represent <d>.
    181 string GetQualifiedName(const Descriptor& d) {
    182   return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d));
    183 }
    184 
    185 // Returns the qualified name that refers to the class generated by proto to
    186 // represent <ed>.
    187 string GetQualifiedName(const EnumDescriptor& d) {
    188   return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d));
    189 }
    190 
    191 // Returns the qualified name that refers to the generated
    192 // AppendProtoDebugString function for <d>.
    193 string GetQualifiedAppendFn(const Descriptor& d) {
    194   return StrCat(GetPackageReferencePrefix(d.file()),
    195                 "internal::AppendProtoDebugString");
    196 }
    197 
    198 // Returns the name of the generated function that returns an enum value's
    199 // string value.
    200 string GetEnumNameFn(const EnumDescriptor& enum_d) {
    201   return StrCat("EnumName_", GetClassName(enum_d));
    202 }
    203 
    204 // Returns the qualified name of the function returned by GetEnumNameFn().
    205 string GetQualifiedEnumNameFn(const EnumDescriptor& enum_d) {
    206   return StrCat(GetPackageReferencePrefix(enum_d.file()),
    207                 GetEnumNameFn(enum_d));
    208 }
    209 
    210 // Returns the name of a generated header file, either the public api (if impl
    211 // is false) or the internal implementation header (if impl is true).
    212 string GetProtoTextHeaderName(const FileDescriptor& fd, bool impl) {
    213   const int dot_index = fd.name().find_last_of('.');
    214   return fd.name().substr(0, dot_index) +
    215          (impl ? ".pb_text-impl.h" : ".pb_text.h");
    216 }
    217 
    218 // Returns the name of the header generated by the proto library for <fd>.
    219 string GetProtoHeaderName(const FileDescriptor& fd) {
    220   const int dot_index = fd.name().find_last_of('.');
    221   return fd.name().substr(0, dot_index) + ".pb.h";
    222 }
    223 
    224 // Returns the C++ class name for the given proto field.
    225 string GetCppClass(const FieldDescriptor& d) {
    226   string cpp_class = d.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE
    227                          ? GetQualifiedName(*d.message_type())
    228                          : d.cpp_type_name();
    229 
    230   // In open-source TensorFlow, the definition of int64 varies across
    231   // platforms. The following line, which is manipulated during internal-
    232   // external sync'ing, takes care of the variability.
    233   if (cpp_class == "int64") {
    234     cpp_class = kProtobufInt64Typename;
    235   }
    236 
    237   return cpp_class;
    238 }
    239 
    240 // Returns the string that can be used for a header guard for the generated
    241 // headers for <fd>, either for the public api (if impl is false) or the
    242 // internal implementation header (if impl is true).
    243 string GetHeaderGuard(const FileDescriptor& fd, bool impl) {
    244   string s = fd.name();
    245   std::replace(s.begin(), s.end(), '/', '_');
    246   std::replace(s.begin(), s.end(), '.', '_');
    247   return s + (impl ? "_IMPL_H_" : "_H_");
    248 }
    249 
    250 void Generator::AppendFieldValueAppend(const FieldDescriptor& field,
    251                                        const bool omit_default,
    252                                        const string& field_expr) {
    253   SetOutput(&cc_);
    254   switch (field.cpp_type()) {
    255     case FieldDescriptor::CPPTYPE_INT32:
    256     case FieldDescriptor::CPPTYPE_INT64:
    257     case FieldDescriptor::CPPTYPE_UINT32:
    258     case FieldDescriptor::CPPTYPE_UINT64:
    259     case FieldDescriptor::CPPTYPE_DOUBLE:
    260     case FieldDescriptor::CPPTYPE_FLOAT:
    261       Print("o->", omit_default ? "AppendNumericIfNotZero" : "AppendNumeric",
    262             "(\"", field.name(), "\", ", field_expr, ");");
    263       break;
    264     case FieldDescriptor::CPPTYPE_BOOL:
    265       Print("o->", omit_default ? "AppendBoolIfTrue" : "AppendBool", "(\"",
    266             field.name(), "\", ", field_expr, ");");
    267       break;
    268     case FieldDescriptor::CPPTYPE_STRING: {
    269       const auto ctype = field.options().ctype();
    270       CHECK(ctype == FieldOptions::CORD || ctype == FieldOptions::STRING)
    271           << "Unsupported ctype " << ctype;
    272 
    273       Print("o->", omit_default ? "AppendStringIfNotEmpty" : "AppendString",
    274             "(\"", field.name(), "\", ProtobufStringToString(", field_expr,
    275             "));");
    276       break;
    277     }
    278     case FieldDescriptor::CPPTYPE_ENUM:
    279       if (omit_default) {
    280         Print("if (", field_expr, " != 0) {").Nest();
    281       }
    282       Print("const char* enum_name = ",
    283             GetQualifiedEnumNameFn(*field.enum_type()), "(", field_expr, ");");
    284       Print("if (enum_name[0]) {").Nest();
    285       Print("o->AppendEnumName(\"", field.name(), "\", enum_name);");
    286       Unnest().Print("} else {").Nest();
    287       Print("o->AppendNumeric(\"", field.name(), "\", ", field_expr, ");");
    288       Unnest().Print("}");
    289       if (omit_default) {
    290         Unnest().Print("}");
    291       }
    292       break;
    293     case FieldDescriptor::CPPTYPE_MESSAGE:
    294       CHECK(!field.message_type()->options().map_entry());
    295       if (omit_default) {
    296         Print("if (msg.has_", field.name(), "()) {").Nest();
    297       }
    298       Print("o->OpenNestedMessage(\"", field.name(), "\");");
    299       Print(GetQualifiedAppendFn(*field.message_type()), "(o, ", field_expr,
    300             ");");
    301       Print("o->CloseNestedMessage();");
    302       if (omit_default) {
    303         Unnest().Print("}");
    304       }
    305       break;
    306   }
    307 }
    308 
    309 void Generator::AppendFieldAppend(const FieldDescriptor& field) {
    310   const string& name = field.name();
    311 
    312   if (field.is_map()) {
    313     Print("{").Nest();
    314     const auto& key_type = *field.message_type()->FindFieldByName("key");
    315     const auto& value_type = *field.message_type()->FindFieldByName("value");
    316 
    317     Print("std::vector<", key_type.cpp_type_name(), "> keys;");
    318     Print("for (const auto& e : msg.", name, "()) keys.push_back(e.first);");
    319     Print("std::stable_sort(keys.begin(), keys.end());");
    320     Print("for (const auto& key : keys) {").Nest();
    321     Print("o->OpenNestedMessage(\"", name, "\");");
    322     AppendFieldValueAppend(key_type, false /* omit_default */, "key");
    323     AppendFieldValueAppend(value_type, false /* omit_default */,
    324                            StrCat("msg.", name, "().at(key)"));
    325     Print("o->CloseNestedMessage();");
    326     Unnest().Print("}");
    327 
    328     Unnest().Print("}");
    329   } else if (field.is_repeated()) {
    330     Print("for (int i = 0; i < msg.", name, "_size(); ++i) {");
    331     Nest();
    332     AppendFieldValueAppend(field, false /* omit_default */,
    333                            "msg." + name + "(i)");
    334     Unnest().Print("}");
    335   } else {
    336     const auto* oneof = field.containing_oneof();
    337     if (oneof != nullptr) {
    338       string camel_name = field.camelcase_name();
    339       camel_name[0] = toupper(camel_name[0]);
    340       Print("if (msg.", oneof->name(), "_case() == ",
    341             GetQualifiedName(*oneof->containing_type()), "::k", camel_name,
    342             ") {");
    343       Nest();
    344       AppendFieldValueAppend(field, false /* omit_default */,
    345                              "msg." + name + "()");
    346       Unnest();
    347       Print("}");
    348     } else {
    349       AppendFieldValueAppend(field, true /* omit_default */,
    350                              "msg." + name + "()");
    351     }
    352   }
    353 }
    354 
    355 void Generator::AppendEnumFunctions(const EnumDescriptor& enum_d) {
    356   const string sig = StrCat("const char* ", GetEnumNameFn(enum_d), "(\n    ",
    357                             GetQualifiedName(enum_d), " value)");
    358   SetOutput(&header_);
    359   Print().Print("// Enum text output for ", string(enum_d.full_name()));
    360   Print(sig, ";");
    361 
    362   SetOutput(&cc_);
    363   Print().Print(sig, " {");
    364   Nest().Print("switch (value) {").Nest();
    365   for (int i = 0; i < enum_d.value_count(); ++i) {
    366     const auto& value = *enum_d.value(i);
    367     Print("case ", value.number(), ": return \"", value.name(), "\";");
    368   }
    369   Print("default: return \"\";");
    370   Unnest().Print("}");
    371   Unnest().Print("}");
    372 }
    373 
    374 void Generator::AppendParseMessageFunction(const Descriptor& md) {
    375   const bool map_append = (md.options().map_entry());
    376   string sig;
    377   if (!map_append) {
    378     sig = StrCat("bool ProtoParseFromString(\n    const string& s,\n    ",
    379                  GetQualifiedName(md), "* msg)");
    380     SetOutput(&header_).Print(sig, "\n        TF_MUST_USE_RESULT;");
    381 
    382     SetOutput(&cc_);
    383     Print().Print(sig, " {").Nest();
    384     Print("msg->Clear();");
    385     Print("Scanner scanner(s);");
    386     Print("if (!internal::ProtoParseFromScanner(",
    387           "&scanner, false, false, msg)) return false;");
    388     Print("scanner.Eos();");
    389     Print("return scanner.GetResult();");
    390     Unnest().Print("}");
    391   }
    392 
    393   // Parse from scanner - the real work here.
    394   sig = StrCat("bool ProtoParseFromScanner(",
    395                "\n    ::tensorflow::strings::Scanner* scanner, bool nested, "
    396                "bool close_curly,\n    ");
    397   const FieldDescriptor* key_type = nullptr;
    398   const FieldDescriptor* value_type = nullptr;
    399   if (map_append) {
    400     key_type = md.FindFieldByName("key");
    401     value_type = md.FindFieldByName("value");
    402     StrAppend(&sig, "::tensorflow::protobuf::Map<", GetCppClass(*key_type),
    403               ", ", GetCppClass(*value_type), ">* map)");
    404   } else {
    405     StrAppend(&sig, GetQualifiedName(md), "* msg)");
    406   }
    407 
    408   if (!map_append_signatures_included_.insert(sig).second) {
    409     // signature for function to append to a map of this type has
    410     // already been defined in this .cc file. Don't define it again.
    411     return;
    412   }
    413 
    414   if (!map_append) {
    415     SetOutput(&header_impl_).Print(sig, ";");
    416   }
    417 
    418   SetOutput(&cc_);
    419   Print().Print("namespace internal {");
    420   if (map_append) {
    421     Print("namespace {");
    422   }
    423   Print().Print(sig, " {").Nest();
    424   if (map_append) {
    425     Print(GetCppClass(*key_type), " map_key;");
    426     Print("bool set_map_key = false;");
    427     Print(GetCppClass(*value_type), " map_value;");
    428     Print("bool set_map_value = false;");
    429   }
    430   Print("std::vector<bool> has_seen(", md.field_count(), ", false);");
    431   Print("while(true) {").Nest();
    432   Print("ProtoSpaceAndComments(scanner);");
    433 
    434   // Emit success case
    435   Print("if (nested && (scanner->Peek() == (close_curly ? '}' : '>'))) {")
    436       .Nest();
    437   Print("scanner->One(Scanner::ALL);");
    438   Print("ProtoSpaceAndComments(scanner);");
    439   if (map_append) {
    440     Print("if (!set_map_key || !set_map_value) return false;");
    441     Print("(*map)[map_key] = map_value;");
    442   }
    443   Print("return true;");
    444   Unnest().Print("}");
    445 
    446   Print("if (!nested && scanner->empty()) { return true; }");
    447   Print("scanner->RestartCapture()");
    448   Print("    .Many(Scanner::LETTER_DIGIT_UNDERSCORE)");
    449   Print("    .StopCapture();");
    450   Print("StringPiece identifier;");
    451   Print("if (!scanner->GetResult(nullptr, &identifier)) return false;");
    452   Print("bool parsed_colon = false;");
    453   Print("(void)parsed_colon;"); // Avoid "set but not used" compiler warning
    454   Print("ProtoSpaceAndComments(scanner);");
    455   Print("if (scanner->Peek() == ':') {");
    456   Nest().Print("parsed_colon = true;");
    457   Print("scanner->One(Scanner::ALL);");
    458   Print("ProtoSpaceAndComments(scanner);");
    459   Unnest().Print("}");
    460   for (int i = 0; i < md.field_count(); ++i) {
    461     const FieldDescriptor* field = md.field(i);
    462     const string& field_name = field->name();
    463     string mutable_value_expr;
    464     string set_value_prefix;
    465     if (map_append) {
    466       mutable_value_expr = StrCat("&map_", field_name);
    467       set_value_prefix = StrCat("map_", field_name, " = ");
    468     } else if (field->is_repeated()) {
    469       if (field->is_map()) {
    470         mutable_value_expr = StrCat("msg->mutable_", field_name, "()");
    471         set_value_prefix =
    472             "UNREACHABLE";  // generator will never use this value.
    473       } else {
    474         mutable_value_expr = StrCat("msg->add_", field_name, "()");
    475         set_value_prefix = StrCat("msg->add_", field_name);
    476       }
    477     } else {
    478       mutable_value_expr = StrCat("msg->mutable_", field_name, "()");
    479       set_value_prefix = StrCat("msg->set_", field_name);
    480     }
    481 
    482     Print(i == 0 ? "" : "else ", "if (identifier == \"", field_name, "\") {");
    483     Nest();
    484 
    485     if (field->is_repeated()) {
    486       CHECK(!map_append);
    487 
    488       // Check to see if this is an array assignment, like a: [1, 2, 3]
    489       Print("const bool is_list = (scanner->Peek() == '[');");
    490       Print("do {");
    491       // [ or , // skip
    492       Nest().Print("if (is_list) {");
    493       Nest().Print("scanner->One(Scanner::ALL);");
    494       Print("ProtoSpaceAndComments(scanner);");
    495       Unnest().Print("}");
    496     } else if (field->containing_oneof() != nullptr) {
    497       CHECK(!map_append);
    498 
    499       // Detect duplicate oneof value.
    500       const string oneof_name = field->containing_oneof()->name();
    501       Print("if (msg->", oneof_name, "_case() != 0) return false;");
    502     }
    503 
    504     if (!field->is_repeated() && !map_append) {
    505       // Detect duplicate nested repeated message.
    506       Print("if (has_seen[", i, "]) return false;");
    507       Print("has_seen[", i, "] = true;");
    508     }
    509     if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
    510       Print("const char open_char = scanner->Peek();");
    511       Print("if (open_char != '{' && open_char != '<') return false;");
    512       Print("scanner->One(Scanner::ALL);");
    513       Print("ProtoSpaceAndComments(scanner);");
    514       if (field->is_map()) {
    515         Print("if (!ProtoParseFromScanner(");
    516       } else {
    517         Print("if (!", GetPackageReferencePrefix(field->message_type()->file()),
    518               "internal::ProtoParseFromScanner(");
    519       }
    520       Print("    scanner, true, open_char == '{', ", mutable_value_expr,
    521             ")) return false;");
    522     } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_STRING) {
    523       Print("string str_value;");
    524       Print(
    525           "if (!parsed_colon || "
    526           "!::tensorflow::strings::ProtoParseStringLiteralFromScanner(");
    527       Print("    scanner, &str_value)) return false;");
    528       Print("SetProtobufStringSwapAllowed(&str_value, ", mutable_value_expr,
    529             ");");
    530     } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
    531       Print("StringPiece value;");
    532       Print(
    533           "if (!parsed_colon || "
    534           "!scanner->RestartCapture().Many("
    535           "Scanner::LETTER_DIGIT_DASH_UNDERSCORE)."
    536           "GetResult(nullptr, &value)) return false;");
    537       const auto* enum_d = field->enum_type();
    538       string value_prefix;
    539       if (enum_d->containing_type() == nullptr) {
    540         value_prefix = GetPackageReferencePrefix(enum_d->file());
    541       } else {
    542         value_prefix = StrCat(GetQualifiedName(*enum_d), "_");
    543       }
    544 
    545       for (int enum_i = 0; enum_i < enum_d->value_count(); ++enum_i) {
    546         const auto* value_d = enum_d->value(enum_i);
    547         const string& value_name = value_d->name();
    548         string condition = StrCat("value == \"", value_name, "\"");
    549 
    550         Print(enum_i == 0 ? "" : "} else ", "if (", condition, ") {");
    551         Nest();
    552         Print(set_value_prefix, "(", value_prefix, value_name, ");");
    553         Unnest();
    554       }
    555       Print("} else {");
    556       Nest();
    557       // Proto3 allows all numeric values.
    558       Print("int32 int_value;");
    559       Print("if (strings::SafeStringToNumeric(value, &int_value)) {");
    560       Nest();
    561       Print(set_value_prefix, "(static_cast<", GetQualifiedName(*enum_d),
    562             ">(int_value));");
    563       Unnest();
    564       Print("} else {").Nest().Print("return false;").Unnest().Print("}");
    565       Unnest().Print("}");
    566     } else {
    567       Print(field->cpp_type_name(), " value;");
    568       switch (field->cpp_type()) {
    569         case FieldDescriptor::CPPTYPE_INT32:
    570         case FieldDescriptor::CPPTYPE_INT64:
    571         case FieldDescriptor::CPPTYPE_UINT32:
    572         case FieldDescriptor::CPPTYPE_UINT64:
    573         case FieldDescriptor::CPPTYPE_DOUBLE:
    574         case FieldDescriptor::CPPTYPE_FLOAT:
    575           Print(
    576               "if (!parsed_colon || "
    577               "!::tensorflow::strings::ProtoParseNumericFromScanner(",
    578               "scanner, &value)) return false;");
    579           break;
    580         case FieldDescriptor::CPPTYPE_BOOL:
    581           Print(
    582               "if (!parsed_colon || "
    583               "!::tensorflow::strings::ProtoParseBoolFromScanner(",
    584               "scanner, &value)) return false;");
    585           break;
    586         default:
    587           LOG(FATAL) << "handled earlier";
    588       }
    589       Print(set_value_prefix, "(value);");
    590     }
    591 
    592     if (field->is_repeated()) {
    593       Unnest().Print("} while (is_list && scanner->Peek() == ',');");
    594       Print(
    595           "if (is_list && "
    596           "!scanner->OneLiteral(\"]\").GetResult()) return false;");
    597     }
    598     if (map_append) {
    599       Print("set_map_", field_name, " = true;");
    600     }
    601     Unnest().Print("}");
    602   }
    603   Unnest().Print("}");
    604   Unnest().Print("}");
    605   Unnest().Print();
    606   if (map_append) {
    607     Print("}  // namespace");
    608   }
    609   Print("}  // namespace internal");
    610 }
    611 
    612 void Generator::AppendDebugStringFunctions(const Descriptor& md) {
    613   SetOutput(&header_impl_).Print();
    614   SetOutput(&header_).Print().Print("// Message-text conversion for ",
    615                                     string(md.full_name()));
    616 
    617   // Append the two debug string functions for <md>.
    618   for (int short_pass = 0; short_pass < 2; ++short_pass) {
    619     const bool short_debug = (short_pass == 1);
    620 
    621     // Make the Get functions.
    622     const string sig = StrCat(
    623         "string ", short_debug ? "ProtoShortDebugString" : "ProtoDebugString",
    624         "(\n    const ", GetQualifiedName(md), "& msg)");
    625     SetOutput(&header_).Print(sig, ";");
    626 
    627     SetOutput(&cc_);
    628     Print().Print(sig, " {").Nest();
    629     Print("string s;");
    630     Print("::tensorflow::strings::ProtoTextOutput o(&s, ",
    631           short_debug ? "true" : "false", ");");
    632     Print("internal::AppendProtoDebugString(&o, msg);");
    633     Print("o.CloseTopMessage();");
    634     Print("return s;");
    635     Unnest().Print("}");
    636   }
    637 
    638   // Make the Append function.
    639   const string sig =
    640       StrCat("void AppendProtoDebugString(\n",
    641              "    ::tensorflow::strings::ProtoTextOutput* o,\n    const ",
    642              GetQualifiedName(md), "& msg)");
    643   SetOutput(&header_impl_).Print(sig, ";");
    644   SetOutput(&cc_);
    645   Print().Print("namespace internal {").Print();
    646   Print(sig, " {").Nest();
    647   std::vector<const FieldDescriptor*> fields;
    648   fields.reserve(md.field_count());
    649   for (int i = 0; i < md.field_count(); ++i) {
    650     fields.push_back(md.field(i));
    651   }
    652   std::sort(fields.begin(), fields.end(),
    653             [](const FieldDescriptor* left, const FieldDescriptor* right) {
    654               return left->number() < right->number();
    655             });
    656 
    657   for (const FieldDescriptor* field : fields) {
    658     SetOutput(&cc_);
    659     AppendFieldAppend(*field);
    660   }
    661   Unnest().Print("}").Print().Print("}  // namespace internal");
    662 }
    663 
    664 void Generator::AppendMessageFunctions(const Descriptor& md) {
    665   if (md.options().map_entry()) {
    666     // The 'map entry' Message is not a user-visible message type.  Only its
    667     // parse function is created (and that actually parsed the whole Map, not
    668     // just the map entry). Printing of a map is done in the code generated for
    669     // the containing message.
    670     AppendParseMessageFunction(md);
    671     return;
    672   }
    673 
    674   // Recurse before adding the main message function, so that internal
    675   // map_append functions are available before they are needed.
    676   for (int i = 0; i < md.enum_type_count(); ++i) {
    677     AppendEnumFunctions(*md.enum_type(i));
    678   }
    679   for (int i = 0; i < md.nested_type_count(); ++i) {
    680     AppendMessageFunctions(*md.nested_type(i));
    681   }
    682 
    683   AppendDebugStringFunctions(md);
    684   AppendParseMessageFunction(md);
    685 }
    686 
    687 void Generator::AddNamespaceToCurrentSection(const string& package, bool open) {
    688   Print();
    689   std::vector<string> parts = {""};
    690   for (size_t i = 0; i < package.size(); ++i) {
    691     if (package[i] == '.') {
    692       parts.resize(parts.size() + 1);
    693     } else {
    694       parts.back() += package[i];
    695     }
    696   }
    697   if (open) {
    698     for (const auto& p : parts) {
    699       Print("namespace ", p, " {");
    700     }
    701   } else {
    702     for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
    703       Print("}  // namespace ", *it);
    704     }
    705   }
    706 }
    707 
    708 void Generator::AddHeadersToCurrentSection(const std::vector<string>& headers) {
    709   std::vector<string> sorted = headers;
    710   std::sort(sorted.begin(), sorted.end());
    711   for (const auto& h : sorted) {
    712     Print("#include \"", h, "\"");
    713   }
    714 }
    715 
    716 // Adds to <all_fd> and <all_d> with all descriptors recursively
    717 // reachable from the given descriptor.
    718 void GetAllFileDescriptorsFromFile(const FileDescriptor* fd,
    719                                    std::set<const FileDescriptor*>* all_fd,
    720                                    std::set<const Descriptor*>* all_d);
    721 
    722 // Adds to <all_fd> and <all_d> with all descriptors recursively
    723 // reachable from the given descriptor.
    724 void GetAllFileDescriptorsFromMessage(const Descriptor* d,
    725                                       std::set<const FileDescriptor*>* all_fd,
    726                                       std::set<const Descriptor*>* all_d) {
    727   if (!all_d->insert(d).second) return;
    728   GetAllFileDescriptorsFromFile(d->file(), all_fd, all_d);
    729   for (int i = 0; i < d->field_count(); ++i) {
    730     auto* f = d->field(i);
    731     switch (f->cpp_type()) {
    732       case FieldDescriptor::CPPTYPE_INT32:
    733       case FieldDescriptor::CPPTYPE_INT64:
    734       case FieldDescriptor::CPPTYPE_UINT32:
    735       case FieldDescriptor::CPPTYPE_UINT64:
    736       case FieldDescriptor::CPPTYPE_DOUBLE:
    737       case FieldDescriptor::CPPTYPE_FLOAT:
    738       case FieldDescriptor::CPPTYPE_BOOL:
    739       case FieldDescriptor::CPPTYPE_STRING:
    740         break;
    741       case FieldDescriptor::CPPTYPE_MESSAGE:
    742         GetAllFileDescriptorsFromMessage(f->message_type(), all_fd, all_d);
    743         break;
    744       case FieldDescriptor::CPPTYPE_ENUM:
    745         GetAllFileDescriptorsFromFile(f->enum_type()->file(), all_fd, all_d);
    746         break;
    747     }
    748   }
    749   for (int i = 0; i < d->nested_type_count(); ++i) {
    750     GetAllFileDescriptorsFromMessage(d->nested_type(i), all_fd, all_d);
    751   }
    752 }
    753 
    754 void GetAllFileDescriptorsFromFile(const FileDescriptor* fd,
    755                                    std::set<const FileDescriptor*>* all_fd,
    756                                    std::set<const Descriptor*>* all_d) {
    757   if (!all_fd->insert(fd).second) return;
    758   for (int i = 0; i < fd->message_type_count(); ++i) {
    759     GetAllFileDescriptorsFromMessage(fd->message_type(i), all_fd, all_d);
    760   }
    761 }
    762 
    763 void Generator::Generate(const FileDescriptor& fd) {
    764   // This does not emit code with proper proto2 semantics (e.g. it doesn't check
    765   // 'has' fields on non-messages), so check that only proto3 is passed.
    766   CHECK_EQ(fd.syntax(), FileDescriptor::SYNTAX_PROTO3) << fd.name();
    767 
    768   const string package = fd.package();
    769   std::set<const FileDescriptor*> all_fd;
    770   std::set<const Descriptor*> all_d;
    771   GetAllFileDescriptorsFromFile(&fd, &all_fd, &all_d);
    772 
    773   std::vector<string> headers;
    774 
    775   // Add header to header file.
    776   SetOutput(&header_);
    777   Print("// GENERATED FILE - DO NOT MODIFY");
    778   Print("#ifndef ", GetHeaderGuard(fd, false /* impl */));
    779   Print("#define ", GetHeaderGuard(fd, false /* impl */));
    780   Print();
    781   headers = {
    782       GetProtoHeaderName(fd),
    783       StrCat(tf_header_prefix_, "tensorflow/core/platform/macros.h"),
    784       StrCat(tf_header_prefix_, "tensorflow/core/platform/protobuf.h"),
    785       StrCat(tf_header_prefix_, "tensorflow/core/platform/types.h"),
    786   };
    787   for (const auto& h : headers) {
    788     Print("#include \"", h, "\"");
    789   }
    790   AddNamespaceToCurrentSection(package, true /* is_open */);
    791 
    792   // Add header to impl file.
    793   SetOutput(&header_impl_);
    794   Print("// GENERATED FILE - DO NOT MODIFY");
    795   Print("#ifndef ", GetHeaderGuard(fd, true /* impl */));
    796   Print("#define ", GetHeaderGuard(fd, true /* impl */));
    797   Print();
    798   headers = {
    799       GetProtoTextHeaderName(fd, false /* impl */),
    800       StrCat(tf_header_prefix_,
    801              "tensorflow/core/lib/strings/proto_text_util.h"),
    802       StrCat(tf_header_prefix_, "tensorflow/core/lib/strings/scanner.h"),
    803   };
    804   for (const FileDescriptor* d : all_fd) {
    805     if (d != &fd) {
    806       headers.push_back(GetProtoTextHeaderName(*d, true /* impl */));
    807     }
    808     headers.push_back(GetProtoHeaderName(*d));
    809   }
    810   AddHeadersToCurrentSection(headers);
    811   AddNamespaceToCurrentSection(package, true /* is_open */);
    812   SetOutput(&header_impl_).Print().Print("namespace internal {");
    813 
    814   // Add header to cc file.
    815   SetOutput(&cc_);
    816   Print("// GENERATED FILE - DO NOT MODIFY");
    817   Print();
    818   Print("#include <algorithm>");  // for `std::stable_sort()`
    819   Print();
    820   headers = {GetProtoTextHeaderName(fd, true /* impl */)};
    821   AddHeadersToCurrentSection(headers);
    822   Print();
    823   Print("using ::tensorflow::strings::Scanner;");
    824   Print("using ::tensorflow::strings::StrCat;");
    825   AddNamespaceToCurrentSection(package, true /* is_open */);
    826 
    827   // Add declarations and definitions.
    828   for (int i = 0; i < fd.enum_type_count(); ++i) {
    829     AppendEnumFunctions(*fd.enum_type(i));
    830   }
    831   for (int i = 0; i < fd.message_type_count(); ++i) {
    832     AppendMessageFunctions(*fd.message_type(i));
    833   }
    834 
    835   // Add footer to header file.
    836   SetOutput(&header_);
    837   AddNamespaceToCurrentSection(package, false /* is_open */);
    838   Print().Print("#endif  // ", GetHeaderGuard(fd, false /* impl */));
    839 
    840   // Add footer to header impl file.
    841   SetOutput(&header_impl_).Print().Print("}  // namespace internal");
    842   AddNamespaceToCurrentSection(package, false /* is_open */);
    843   Print().Print("#endif  // ", GetHeaderGuard(fd, true /* impl */));
    844 
    845   // Add footer to cc file.
    846   SetOutput(&cc_);
    847   AddNamespaceToCurrentSection(package, false /* is_open */);
    848 }
    849 
    850 }  // namespace
    851 
    852 ProtoTextFunctionCode GetProtoTextFunctionCode(const FileDescriptor& fd,
    853                                                const string& tf_header_prefix) {
    854   Generator gen(tf_header_prefix);
    855   gen.Generate(fd);
    856   return gen.code();
    857 }
    858 
    859 }  // namespace tensorflow
    860