Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/op_gen_lib.h"
     17 
     18 #include <vector>
     19 #include "tensorflow/core/framework/attr_value.pb.h"
     20 #include "tensorflow/core/lib/core/errors.h"
     21 #include "tensorflow/core/lib/gtl/map_util.h"
     22 #include "tensorflow/core/lib/strings/str_util.h"
     23 #include "tensorflow/core/lib/strings/strcat.h"
     24 #include "tensorflow/core/platform/protobuf.h"
     25 
     26 namespace tensorflow {
     27 
     28 string WordWrap(StringPiece prefix, StringPiece str, int width) {
     29   const string indent_next_line = "\n" + Spaces(prefix.size());
     30   width -= prefix.size();
     31   string result;
     32   strings::StrAppend(&result, prefix);
     33 
     34   while (!str.empty()) {
     35     if (static_cast<int>(str.size()) <= width) {
     36       // Remaining text fits on one line.
     37       strings::StrAppend(&result, str);
     38       break;
     39     }
     40     auto space = str.rfind(' ', width);
     41     if (space == StringPiece::npos) {
     42       // Rather make a too-long line and break at a space.
     43       space = str.find(' ');
     44       if (space == StringPiece::npos) {
     45         strings::StrAppend(&result, str);
     46         break;
     47       }
     48     }
     49     // Breaking at character at position <space>.
     50     StringPiece to_append = str.substr(0, space);
     51     str.remove_prefix(space + 1);
     52     // Remove spaces at break.
     53     while (to_append.ends_with(" ")) {
     54       to_append.remove_suffix(1);
     55     }
     56     while (str.Consume(" ")) {
     57     }
     58 
     59     // Go on to the next line.
     60     strings::StrAppend(&result, to_append);
     61     if (!str.empty()) strings::StrAppend(&result, indent_next_line);
     62   }
     63 
     64   return result;
     65 }
     66 
     67 bool ConsumeEquals(StringPiece* description) {
     68   if (description->Consume("=")) {
     69     while (description->Consume(" ")) {  // Also remove spaces after "=".
     70     }
     71     return true;
     72   }
     73   return false;
     74 }
     75 
     76 // Split `*orig` into two pieces at the first occurrence of `split_ch`.
     77 // Returns whether `split_ch` was found. Afterwards, `*before_split`
     78 // contains the maximum prefix of the input `*orig` that doesn't
     79 // contain `split_ch`, and `*orig` contains everything after the
     80 // first `split_ch`.
     81 static bool SplitAt(char split_ch, StringPiece* orig,
     82                     StringPiece* before_split) {
     83   auto pos = orig->find(split_ch);
     84   if (pos == StringPiece::npos) {
     85     *before_split = *orig;
     86     *orig = StringPiece();
     87     return false;
     88   } else {
     89     *before_split = orig->substr(0, pos);
     90     orig->remove_prefix(pos + 1);
     91     return true;
     92   }
     93 }
     94 
     95 // Does this line start with "<spaces><field>:" where "<field>" is
     96 // in multi_line_fields? Sets *colon_pos to the position of the colon.
     97 static bool StartsWithFieldName(StringPiece line,
     98                                 const std::vector<string>& multi_line_fields) {
     99   StringPiece up_to_colon;
    100   if (!SplitAt(':', &line, &up_to_colon)) return false;
    101   while (up_to_colon.Consume(" "))
    102     ;  // Remove leading spaces.
    103   for (const auto& field : multi_line_fields) {
    104     if (up_to_colon == field) {
    105       return true;
    106     }
    107   }
    108   return false;
    109 }
    110 
    111 static bool ConvertLine(StringPiece line,
    112                         const std::vector<string>& multi_line_fields,
    113                         string* ml) {
    114   // Is this a field we should convert?
    115   if (!StartsWithFieldName(line, multi_line_fields)) {
    116     return false;
    117   }
    118   // Has a matching field name, so look for "..." after the colon.
    119   StringPiece up_to_colon;
    120   StringPiece after_colon = line;
    121   SplitAt(':', &after_colon, &up_to_colon);
    122   while (after_colon.Consume(" "))
    123     ;  // Remove leading spaces.
    124   if (!after_colon.Consume("\"")) {
    125     // We only convert string fields, so don't convert this line.
    126     return false;
    127   }
    128   auto last_quote = after_colon.rfind('\"');
    129   if (last_quote == StringPiece::npos) {
    130     // Error: we don't see the expected matching quote, abort the conversion.
    131     return false;
    132   }
    133   StringPiece escaped = after_colon.substr(0, last_quote);
    134   StringPiece suffix = after_colon.substr(last_quote + 1);
    135   // We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>'
    136 
    137   string unescaped;
    138   if (!str_util::CUnescape(escaped, &unescaped, nullptr)) {
    139     // Error unescaping, abort the conversion.
    140     return false;
    141   }
    142   // No more errors possible at this point.
    143 
    144   // Find a string to mark the end that isn't in unescaped.
    145   string end = "END";
    146   for (int s = 0; unescaped.find(end) != string::npos; ++s) {
    147     end = strings::StrCat("END", s);
    148   }
    149 
    150   // Actually start writing the converted output.
    151   strings::StrAppend(ml, up_to_colon, ": <<", end, "\n", unescaped, "\n", end);
    152   if (!suffix.empty()) {
    153     // Output suffix, in case there was a trailing comment in the source.
    154     strings::StrAppend(ml, suffix);
    155   }
    156   strings::StrAppend(ml, "\n");
    157   return true;
    158 }
    159 
    160 string PBTxtToMultiline(StringPiece pbtxt,
    161                         const std::vector<string>& multi_line_fields) {
    162   string ml;
    163   // Probably big enough, since the input and output are about the
    164   // same size, but just a guess.
    165   ml.reserve(pbtxt.size() * (17. / 16));
    166   StringPiece line;
    167   while (!pbtxt.empty()) {
    168     // Split pbtxt into its first line and everything after.
    169     SplitAt('\n', &pbtxt, &line);
    170     // Convert line or output it unchanged
    171     if (!ConvertLine(line, multi_line_fields, &ml)) {
    172       strings::StrAppend(&ml, line, "\n");
    173     }
    174   }
    175   return ml;
    176 }
    177 
    178 // Given a single line of text `line` with first : at `colon`, determine if
    179 // there is an "<<END" expression after the colon and if so return true and set
    180 // `*end` to everything after the "<<".
    181 static bool FindMultiline(StringPiece line, size_t colon, string* end) {
    182   if (colon == StringPiece::npos) return false;
    183   line.remove_prefix(colon + 1);
    184   while (line.Consume(" ")) {
    185   }
    186   if (line.Consume("<<")) {
    187     *end = line.ToString();
    188     return true;
    189   }
    190   return false;
    191 }
    192 
    193 string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
    194   string pbtxt;
    195   // Probably big enough, since the input and output are about the
    196   // same size, but just a guess.
    197   pbtxt.reserve(multiline_pbtxt.size() * (33. / 32));
    198   StringPiece line;
    199   while (!multiline_pbtxt.empty()) {
    200     // Split multiline_pbtxt into its first line and everything after.
    201     if (!SplitAt('\n', &multiline_pbtxt, &line)) {
    202       strings::StrAppend(&pbtxt, line);
    203       break;
    204     }
    205 
    206     string end;
    207     auto colon = line.find(':');
    208     if (!FindMultiline(line, colon, &end)) {
    209       // Normal case: not a multi-line string, just output the line as-is.
    210       strings::StrAppend(&pbtxt, line, "\n");
    211       continue;
    212     }
    213 
    214     // Multi-line case:
    215     //     something: <<END
    216     // xx
    217     // yy
    218     // END
    219     // Should be converted to:
    220     //     something: "xx\nyy"
    221 
    222     // Output everything up to the colon ("    something:").
    223     strings::StrAppend(&pbtxt, line.substr(0, colon + 1));
    224 
    225     // Add every line to unescaped until we see the "END" string.
    226     string unescaped;
    227     bool first = true;
    228     string suffix;
    229     while (!multiline_pbtxt.empty()) {
    230       SplitAt('\n', &multiline_pbtxt, &line);
    231       if (line.Consume(end)) break;
    232       if (first) {
    233         first = false;
    234       } else {
    235         unescaped.push_back('\n');
    236       }
    237       strings::StrAppend(&unescaped, line);
    238       line = StringPiece();
    239     }
    240 
    241     // Escape what we extracted and then output it in quotes.
    242     strings::StrAppend(&pbtxt, " \"", str_util::CEscape(unescaped), "\"", line,
    243                        "\n");
    244   }
    245   return pbtxt;
    246 }
    247 
    248 static void StringReplace(const string& from, const string& to, string* s) {
    249   // Split *s into pieces delimited by `from`.
    250   std::vector<string> split;
    251   string::size_type pos = 0;
    252   while (pos < s->size()) {
    253     auto found = s->find(from, pos);
    254     if (found == string::npos) {
    255       split.push_back(s->substr(pos));
    256       break;
    257     } else {
    258       split.push_back(s->substr(pos, found - pos));
    259       pos = found + from.size();
    260       if (pos == s->size()) {  // handle case where `from` is at the very end.
    261         split.push_back("");
    262       }
    263     }
    264   }
    265   // Join the pieces back together with a new delimiter.
    266   *s = str_util::Join(split, to.c_str());
    267 }
    268 
    269 static void RenameInDocs(const string& from, const string& to,
    270                          ApiDef* api_def) {
    271   const string from_quoted = strings::StrCat("`", from, "`");
    272   const string to_quoted = strings::StrCat("`", to, "`");
    273   for (int i = 0; i < api_def->in_arg_size(); ++i) {
    274     if (!api_def->in_arg(i).description().empty()) {
    275       StringReplace(from_quoted, to_quoted,
    276                     api_def->mutable_in_arg(i)->mutable_description());
    277     }
    278   }
    279   for (int i = 0; i < api_def->out_arg_size(); ++i) {
    280     if (!api_def->out_arg(i).description().empty()) {
    281       StringReplace(from_quoted, to_quoted,
    282                     api_def->mutable_out_arg(i)->mutable_description());
    283     }
    284   }
    285   for (int i = 0; i < api_def->attr_size(); ++i) {
    286     if (!api_def->attr(i).description().empty()) {
    287       StringReplace(from_quoted, to_quoted,
    288                     api_def->mutable_attr(i)->mutable_description());
    289     }
    290   }
    291   if (!api_def->summary().empty()) {
    292     StringReplace(from_quoted, to_quoted, api_def->mutable_summary());
    293   }
    294   if (!api_def->description().empty()) {
    295     StringReplace(from_quoted, to_quoted, api_def->mutable_description());
    296   }
    297 }
    298 
    299 namespace {
    300 
    301 // Initializes given ApiDef with data in OpDef.
    302 void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) {
    303   api_def->set_graph_op_name(op_def.name());
    304   api_def->set_visibility(ApiDef::VISIBLE);
    305 
    306   auto* endpoint = api_def->add_endpoint();
    307   endpoint->set_name(op_def.name());
    308   if (op_def.has_deprecation()) {
    309     endpoint->set_deprecation_version(op_def.deprecation().version());
    310   }
    311 
    312   for (const auto& op_in_arg : op_def.input_arg()) {
    313     auto* api_in_arg = api_def->add_in_arg();
    314     api_in_arg->set_name(op_in_arg.name());
    315     api_in_arg->set_rename_to(op_in_arg.name());
    316     api_in_arg->set_description(op_in_arg.description());
    317 
    318     *api_def->add_arg_order() = op_in_arg.name();
    319   }
    320   for (const auto& op_out_arg : op_def.output_arg()) {
    321     auto* api_out_arg = api_def->add_out_arg();
    322     api_out_arg->set_name(op_out_arg.name());
    323     api_out_arg->set_rename_to(op_out_arg.name());
    324     api_out_arg->set_description(op_out_arg.description());
    325   }
    326   for (const auto& op_attr : op_def.attr()) {
    327     auto* api_attr = api_def->add_attr();
    328     api_attr->set_name(op_attr.name());
    329     api_attr->set_rename_to(op_attr.name());
    330     if (op_attr.has_default_value()) {
    331       *api_attr->mutable_default_value() = op_attr.default_value();
    332     }
    333     api_attr->set_description(op_attr.description());
    334   }
    335   api_def->set_summary(op_def.summary());
    336   api_def->set_description(op_def.description());
    337 }
    338 
    339 // Updates base_arg based on overrides in new_arg.
    340 void MergeArg(ApiDef::Arg* base_arg, const ApiDef::Arg& new_arg) {
    341   if (!new_arg.rename_to().empty()) {
    342     base_arg->set_rename_to(new_arg.rename_to());
    343   }
    344   if (!new_arg.description().empty()) {
    345     base_arg->set_description(new_arg.description());
    346   }
    347 }
    348 
    349 // Updates base_attr based on overrides in new_attr.
    350 void MergeAttr(ApiDef::Attr* base_attr, const ApiDef::Attr& new_attr) {
    351   if (!new_attr.rename_to().empty()) {
    352     base_attr->set_rename_to(new_attr.rename_to());
    353   }
    354   if (new_attr.has_default_value()) {
    355     *base_attr->mutable_default_value() = new_attr.default_value();
    356   }
    357   if (!new_attr.description().empty()) {
    358     base_attr->set_description(new_attr.description());
    359   }
    360 }
    361 
    362 // Updates base_api_def based on overrides in new_api_def.
    363 Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
    364   // Merge visibility
    365   if (new_api_def.visibility() != ApiDef::DEFAULT_VISIBILITY) {
    366     base_api_def->set_visibility(new_api_def.visibility());
    367   }
    368   // Merge endpoints
    369   if (new_api_def.endpoint_size() > 0) {
    370     base_api_def->clear_endpoint();
    371     std::copy(
    372         new_api_def.endpoint().begin(), new_api_def.endpoint().end(),
    373         protobuf::RepeatedFieldBackInserter(base_api_def->mutable_endpoint()));
    374   }
    375   // Merge args
    376   for (const auto& new_arg : new_api_def.in_arg()) {
    377     bool found_base_arg = false;
    378     for (int i = 0; i < base_api_def->in_arg_size(); ++i) {
    379       auto* base_arg = base_api_def->mutable_in_arg(i);
    380       if (base_arg->name() == new_arg.name()) {
    381         MergeArg(base_arg, new_arg);
    382         found_base_arg = true;
    383         break;
    384       }
    385     }
    386     if (!found_base_arg) {
    387       return errors::FailedPrecondition("Argument ", new_arg.name(),
    388                                         " not defined in base api for ",
    389                                         base_api_def->graph_op_name());
    390     }
    391   }
    392   for (const auto& new_arg : new_api_def.out_arg()) {
    393     bool found_base_arg = false;
    394     for (int i = 0; i < base_api_def->out_arg_size(); ++i) {
    395       auto* base_arg = base_api_def->mutable_out_arg(i);
    396       if (base_arg->name() == new_arg.name()) {
    397         MergeArg(base_arg, new_arg);
    398         found_base_arg = true;
    399         break;
    400       }
    401     }
    402     if (!found_base_arg) {
    403       return errors::FailedPrecondition("Argument ", new_arg.name(),
    404                                         " not defined in base api for ",
    405                                         base_api_def->graph_op_name());
    406     }
    407   }
    408   // Merge arg order
    409   if (new_api_def.arg_order_size() > 0) {
    410     // Validate that new arg_order is correct.
    411     if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) {
    412       return errors::FailedPrecondition(
    413           "Invalid number of arguments ", new_api_def.arg_order_size(), " for ",
    414           base_api_def->graph_op_name(),
    415           ". Expected: ", base_api_def->arg_order_size());
    416     }
    417     if (!std::is_permutation(new_api_def.arg_order().begin(),
    418                              new_api_def.arg_order().end(),
    419                              base_api_def->arg_order().begin())) {
    420       return errors::FailedPrecondition(
    421           "Invalid arg_order: ", str_util::Join(new_api_def.arg_order(), ", "),
    422           " for ", base_api_def->graph_op_name(),
    423           ". All elements in arg_order override must match base arg_order: ",
    424           str_util::Join(base_api_def->arg_order(), ", "));
    425     }
    426 
    427     base_api_def->clear_arg_order();
    428     std::copy(
    429         new_api_def.arg_order().begin(), new_api_def.arg_order().end(),
    430         protobuf::RepeatedFieldBackInserter(base_api_def->mutable_arg_order()));
    431   }
    432   // Merge attributes
    433   for (const auto& new_attr : new_api_def.attr()) {
    434     bool found_base_attr = false;
    435     for (int i = 0; i < base_api_def->attr_size(); ++i) {
    436       auto* base_attr = base_api_def->mutable_attr(i);
    437       if (base_attr->name() == new_attr.name()) {
    438         MergeAttr(base_attr, new_attr);
    439         found_base_attr = true;
    440         break;
    441       }
    442     }
    443     if (!found_base_attr) {
    444       return errors::FailedPrecondition("Attribute ", new_attr.name(),
    445                                         " not defined in base api for ",
    446                                         base_api_def->graph_op_name());
    447     }
    448   }
    449   // Merge summary
    450   if (!new_api_def.summary().empty()) {
    451     base_api_def->set_summary(new_api_def.summary());
    452   }
    453   // Merge description
    454   auto description = new_api_def.description().empty()
    455                          ? base_api_def->description()
    456                          : new_api_def.description();
    457 
    458   if (!new_api_def.description_prefix().empty()) {
    459     description =
    460         strings::StrCat(new_api_def.description_prefix(), "\n", description);
    461   }
    462   if (!new_api_def.description_suffix().empty()) {
    463     description =
    464         strings::StrCat(description, "\n", new_api_def.description_suffix());
    465   }
    466   base_api_def->set_description(description);
    467   return Status::OK();
    468 }
    469 }  // namespace
    470 
    471 ApiDefMap::ApiDefMap(const OpList& op_list) {
    472   for (const auto& op : op_list.op()) {
    473     ApiDef api_def;
    474     InitApiDefFromOpDef(op, &api_def);
    475     map_[op.name()] = api_def;
    476   }
    477 }
    478 
    479 ApiDefMap::~ApiDefMap() {}
    480 
    481 Status ApiDefMap::LoadFileList(Env* env, const std::vector<string>& filenames) {
    482   for (const auto& filename : filenames) {
    483     TF_RETURN_IF_ERROR(LoadFile(env, filename));
    484   }
    485   return Status::OK();
    486 }
    487 
    488 Status ApiDefMap::LoadFile(Env* env, const string& filename) {
    489   if (filename.empty()) return Status::OK();
    490   string contents;
    491   TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents));
    492   TF_RETURN_IF_ERROR(LoadApiDef(contents));
    493   return Status::OK();
    494 }
    495 
    496 Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) {
    497   const string contents = PBTxtFromMultiline(api_def_file_contents);
    498   ApiDefs api_defs;
    499   protobuf::TextFormat::ParseFromString(contents, &api_defs);
    500   for (const auto& api_def : api_defs.op()) {
    501     // Check if the op definition is loaded. If op definition is not
    502     // loaded, then we just skip this ApiDef.
    503     if (map_.find(api_def.graph_op_name()) != map_.end()) {
    504       // Overwrite current api def with data in api_def.
    505       TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def));
    506     }
    507   }
    508   return Status::OK();
    509 }
    510 
    511 void ApiDefMap::UpdateDocs() {
    512   for (auto& name_and_api_def : map_) {
    513     auto& api_def = name_and_api_def.second;
    514     CHECK_GT(api_def.endpoint_size(), 0);
    515     const string canonical_name = api_def.endpoint(0).name();
    516     if (api_def.graph_op_name() != canonical_name) {
    517       RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def);
    518     }
    519     for (const auto& in_arg : api_def.in_arg()) {
    520       if (in_arg.name() != in_arg.rename_to()) {
    521         RenameInDocs(in_arg.name(), in_arg.rename_to(), &api_def);
    522       }
    523     }
    524     for (const auto& out_arg : api_def.out_arg()) {
    525       if (out_arg.name() != out_arg.rename_to()) {
    526         RenameInDocs(out_arg.name(), out_arg.rename_to(), &api_def);
    527       }
    528     }
    529     for (const auto& attr : api_def.attr()) {
    530       if (attr.name() != attr.rename_to()) {
    531         RenameInDocs(attr.name(), attr.rename_to(), &api_def);
    532       }
    533     }
    534   }
    535 }
    536 
    537 const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const {
    538   return gtl::FindOrNull(map_, name);
    539 }
    540 }  // namespace tensorflow
    541