Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/op_gen_lib.h"
     17 
     18 #include <algorithm>
     19 #include <vector>
     20 #include "tensorflow/core/framework/attr_value.pb.h"
     21 #include "tensorflow/core/lib/core/errors.h"
     22 #include "tensorflow/core/lib/gtl/map_util.h"
     23 #include "tensorflow/core/lib/strings/str_util.h"
     24 #include "tensorflow/core/lib/strings/strcat.h"
     25 #include "tensorflow/core/platform/protobuf.h"
     26 #include "tensorflow/core/util/proto/proto_utils.h"
     27 
     28 namespace tensorflow {
     29 
     30 string WordWrap(StringPiece prefix, StringPiece str, int width) {
     31   const string indent_next_line = "\n" + Spaces(prefix.size());
     32   width -= prefix.size();
     33   string result;
     34   strings::StrAppend(&result, prefix);
     35 
     36   while (!str.empty()) {
     37     if (static_cast<int>(str.size()) <= width) {
     38       // Remaining text fits on one line.
     39       strings::StrAppend(&result, str);
     40       break;
     41     }
     42     auto space = str.rfind(' ', width);
     43     if (space == StringPiece::npos) {
     44       // Rather make a too-long line and break at a space.
     45       space = str.find(' ');
     46       if (space == StringPiece::npos) {
     47         strings::StrAppend(&result, str);
     48         break;
     49       }
     50     }
     51     // Breaking at character at position <space>.
     52     StringPiece to_append = str.substr(0, space);
     53     str.remove_prefix(space + 1);
     54     // Remove spaces at break.
     55     while (str_util::EndsWith(to_append, " ")) {
     56       to_append.remove_suffix(1);
     57     }
     58     while (str_util::ConsumePrefix(&str, " ")) {
     59     }
     60 
     61     // Go on to the next line.
     62     strings::StrAppend(&result, to_append);
     63     if (!str.empty()) strings::StrAppend(&result, indent_next_line);
     64   }
     65 
     66   return result;
     67 }
     68 
     69 bool ConsumeEquals(StringPiece* description) {
     70   if (str_util::ConsumePrefix(description, "=")) {
     71     while (str_util::ConsumePrefix(description,
     72                                    " ")) {  // Also remove spaces after "=".
     73     }
     74     return true;
     75   }
     76   return false;
     77 }
     78 
     79 // Split `*orig` into two pieces at the first occurrence of `split_ch`.
     80 // Returns whether `split_ch` was found. Afterwards, `*before_split`
     81 // contains the maximum prefix of the input `*orig` that doesn't
     82 // contain `split_ch`, and `*orig` contains everything after the
     83 // first `split_ch`.
     84 static bool SplitAt(char split_ch, StringPiece* orig,
     85                     StringPiece* before_split) {
     86   auto pos = orig->find(split_ch);
     87   if (pos == StringPiece::npos) {
     88     *before_split = *orig;
     89     *orig = StringPiece();
     90     return false;
     91   } else {
     92     *before_split = orig->substr(0, pos);
     93     orig->remove_prefix(pos + 1);
     94     return true;
     95   }
     96 }
     97 
     98 // Does this line start with "<spaces><field>:" where "<field>" is
     99 // in multi_line_fields? Sets *colon_pos to the position of the colon.
    100 static bool StartsWithFieldName(StringPiece line,
    101                                 const std::vector<string>& multi_line_fields) {
    102   StringPiece up_to_colon;
    103   if (!SplitAt(':', &line, &up_to_colon)) return false;
    104   while (str_util::ConsumePrefix(&up_to_colon, " "))
    105     ;  // Remove leading spaces.
    106   for (const auto& field : multi_line_fields) {
    107     if (up_to_colon == field) {
    108       return true;
    109     }
    110   }
    111   return false;
    112 }
    113 
    114 static bool ConvertLine(StringPiece line,
    115                         const std::vector<string>& multi_line_fields,
    116                         string* ml) {
    117   // Is this a field we should convert?
    118   if (!StartsWithFieldName(line, multi_line_fields)) {
    119     return false;
    120   }
    121   // Has a matching field name, so look for "..." after the colon.
    122   StringPiece up_to_colon;
    123   StringPiece after_colon = line;
    124   SplitAt(':', &after_colon, &up_to_colon);
    125   while (str_util::ConsumePrefix(&after_colon, " "))
    126     ;  // Remove leading spaces.
    127   if (!str_util::ConsumePrefix(&after_colon, "\"")) {
    128     // We only convert string fields, so don't convert this line.
    129     return false;
    130   }
    131   auto last_quote = after_colon.rfind('\"');
    132   if (last_quote == StringPiece::npos) {
    133     // Error: we don't see the expected matching quote, abort the conversion.
    134     return false;
    135   }
    136   StringPiece escaped = after_colon.substr(0, last_quote);
    137   StringPiece suffix = after_colon.substr(last_quote + 1);
    138   // We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>'
    139 
    140   string unescaped;
    141   if (!str_util::CUnescape(escaped, &unescaped, nullptr)) {
    142     // Error unescaping, abort the conversion.
    143     return false;
    144   }
    145   // No more errors possible at this point.
    146 
    147   // Find a string to mark the end that isn't in unescaped.
    148   string end = "END";
    149   for (int s = 0; unescaped.find(end) != string::npos; ++s) {
    150     end = strings::StrCat("END", s);
    151   }
    152 
    153   // Actually start writing the converted output.
    154   strings::StrAppend(ml, up_to_colon, ": <<", end, "\n", unescaped, "\n", end);
    155   if (!suffix.empty()) {
    156     // Output suffix, in case there was a trailing comment in the source.
    157     strings::StrAppend(ml, suffix);
    158   }
    159   strings::StrAppend(ml, "\n");
    160   return true;
    161 }
    162 
    163 string PBTxtToMultiline(StringPiece pbtxt,
    164                         const std::vector<string>& multi_line_fields) {
    165   string ml;
    166   // Probably big enough, since the input and output are about the
    167   // same size, but just a guess.
    168   ml.reserve(pbtxt.size() * (17. / 16));
    169   StringPiece line;
    170   while (!pbtxt.empty()) {
    171     // Split pbtxt into its first line and everything after.
    172     SplitAt('\n', &pbtxt, &line);
    173     // Convert line or output it unchanged
    174     if (!ConvertLine(line, multi_line_fields, &ml)) {
    175       strings::StrAppend(&ml, line, "\n");
    176     }
    177   }
    178   return ml;
    179 }
    180 
    181 // Given a single line of text `line` with first : at `colon`, determine if
    182 // there is an "<<END" expression after the colon and if so return true and set
    183 // `*end` to everything after the "<<".
    184 static bool FindMultiline(StringPiece line, size_t colon, string* end) {
    185   if (colon == StringPiece::npos) return false;
    186   line.remove_prefix(colon + 1);
    187   while (str_util::ConsumePrefix(&line, " ")) {
    188   }
    189   if (str_util::ConsumePrefix(&line, "<<")) {
    190     *end = string(line);
    191     return true;
    192   }
    193   return false;
    194 }
    195 
    196 string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
    197   string pbtxt;
    198   // Probably big enough, since the input and output are about the
    199   // same size, but just a guess.
    200   pbtxt.reserve(multiline_pbtxt.size() * (33. / 32));
    201   StringPiece line;
    202   while (!multiline_pbtxt.empty()) {
    203     // Split multiline_pbtxt into its first line and everything after.
    204     if (!SplitAt('\n', &multiline_pbtxt, &line)) {
    205       strings::StrAppend(&pbtxt, line);
    206       break;
    207     }
    208 
    209     string end;
    210     auto colon = line.find(':');
    211     if (!FindMultiline(line, colon, &end)) {
    212       // Normal case: not a multi-line string, just output the line as-is.
    213       strings::StrAppend(&pbtxt, line, "\n");
    214       continue;
    215     }
    216 
    217     // Multi-line case:
    218     //     something: <<END
    219     // xx
    220     // yy
    221     // END
    222     // Should be converted to:
    223     //     something: "xx\nyy"
    224 
    225     // Output everything up to the colon ("    something:").
    226     strings::StrAppend(&pbtxt, line.substr(0, colon + 1));
    227 
    228     // Add every line to unescaped until we see the "END" string.
    229     string unescaped;
    230     bool first = true;
    231     while (!multiline_pbtxt.empty()) {
    232       SplitAt('\n', &multiline_pbtxt, &line);
    233       if (str_util::ConsumePrefix(&line, end)) break;
    234       if (first) {
    235         first = false;
    236       } else {
    237         unescaped.push_back('\n');
    238       }
    239       strings::StrAppend(&unescaped, line);
    240       line = StringPiece();
    241     }
    242 
    243     // Escape what we extracted and then output it in quotes.
    244     strings::StrAppend(&pbtxt, " \"", str_util::CEscape(unescaped), "\"", line,
    245                        "\n");
    246   }
    247   return pbtxt;
    248 }
    249 
    250 static void StringReplace(const string& from, const string& to, string* s) {
    251   // Split *s into pieces delimited by `from`.
    252   std::vector<string> split;
    253   string::size_type pos = 0;
    254   while (pos < s->size()) {
    255     auto found = s->find(from, pos);
    256     if (found == string::npos) {
    257       split.push_back(s->substr(pos));
    258       break;
    259     } else {
    260       split.push_back(s->substr(pos, found - pos));
    261       pos = found + from.size();
    262       if (pos == s->size()) {  // handle case where `from` is at the very end.
    263         split.push_back("");
    264       }
    265     }
    266   }
    267   // Join the pieces back together with a new delimiter.
    268   *s = str_util::Join(split, to.c_str());
    269 }
    270 
    271 static void RenameInDocs(const string& from, const string& to,
    272                          ApiDef* api_def) {
    273   const string from_quoted = strings::StrCat("`", from, "`");
    274   const string to_quoted = strings::StrCat("`", to, "`");
    275   for (int i = 0; i < api_def->in_arg_size(); ++i) {
    276     if (!api_def->in_arg(i).description().empty()) {
    277       StringReplace(from_quoted, to_quoted,
    278                     api_def->mutable_in_arg(i)->mutable_description());
    279     }
    280   }
    281   for (int i = 0; i < api_def->out_arg_size(); ++i) {
    282     if (!api_def->out_arg(i).description().empty()) {
    283       StringReplace(from_quoted, to_quoted,
    284                     api_def->mutable_out_arg(i)->mutable_description());
    285     }
    286   }
    287   for (int i = 0; i < api_def->attr_size(); ++i) {
    288     if (!api_def->attr(i).description().empty()) {
    289       StringReplace(from_quoted, to_quoted,
    290                     api_def->mutable_attr(i)->mutable_description());
    291     }
    292   }
    293   if (!api_def->summary().empty()) {
    294     StringReplace(from_quoted, to_quoted, api_def->mutable_summary());
    295   }
    296   if (!api_def->description().empty()) {
    297     StringReplace(from_quoted, to_quoted, api_def->mutable_description());
    298   }
    299 }
    300 
    301 namespace {
    302 
    303 // Initializes given ApiDef with data in OpDef.
    304 void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) {
    305   api_def->set_graph_op_name(op_def.name());
    306   api_def->set_visibility(ApiDef::VISIBLE);
    307 
    308   auto* endpoint = api_def->add_endpoint();
    309   endpoint->set_name(op_def.name());
    310 
    311   for (const auto& op_in_arg : op_def.input_arg()) {
    312     auto* api_in_arg = api_def->add_in_arg();
    313     api_in_arg->set_name(op_in_arg.name());
    314     api_in_arg->set_rename_to(op_in_arg.name());
    315     api_in_arg->set_description(op_in_arg.description());
    316 
    317     *api_def->add_arg_order() = op_in_arg.name();
    318   }
    319   for (const auto& op_out_arg : op_def.output_arg()) {
    320     auto* api_out_arg = api_def->add_out_arg();
    321     api_out_arg->set_name(op_out_arg.name());
    322     api_out_arg->set_rename_to(op_out_arg.name());
    323     api_out_arg->set_description(op_out_arg.description());
    324   }
    325   for (const auto& op_attr : op_def.attr()) {
    326     auto* api_attr = api_def->add_attr();
    327     api_attr->set_name(op_attr.name());
    328     api_attr->set_rename_to(op_attr.name());
    329     if (op_attr.has_default_value()) {
    330       *api_attr->mutable_default_value() = op_attr.default_value();
    331     }
    332     api_attr->set_description(op_attr.description());
    333   }
    334   api_def->set_summary(op_def.summary());
    335   api_def->set_description(op_def.description());
    336 }
    337 
    338 // Updates base_arg based on overrides in new_arg.
    339 void MergeArg(ApiDef::Arg* base_arg, const ApiDef::Arg& new_arg) {
    340   if (!new_arg.rename_to().empty()) {
    341     base_arg->set_rename_to(new_arg.rename_to());
    342   }
    343   if (!new_arg.description().empty()) {
    344     base_arg->set_description(new_arg.description());
    345   }
    346 }
    347 
    348 // Updates base_attr based on overrides in new_attr.
    349 void MergeAttr(ApiDef::Attr* base_attr, const ApiDef::Attr& new_attr) {
    350   if (!new_attr.rename_to().empty()) {
    351     base_attr->set_rename_to(new_attr.rename_to());
    352   }
    353   if (new_attr.has_default_value()) {
    354     *base_attr->mutable_default_value() = new_attr.default_value();
    355   }
    356   if (!new_attr.description().empty()) {
    357     base_attr->set_description(new_attr.description());
    358   }
    359 }
    360 
    361 // Updates base_api_def based on overrides in new_api_def.
    362 Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
    363   // Merge visibility
    364   if (new_api_def.visibility() != ApiDef::DEFAULT_VISIBILITY) {
    365     base_api_def->set_visibility(new_api_def.visibility());
    366   }
    367   // Merge endpoints
    368   if (new_api_def.endpoint_size() > 0) {
    369     base_api_def->clear_endpoint();
    370     std::copy(
    371         new_api_def.endpoint().begin(), new_api_def.endpoint().end(),
    372         protobuf::RepeatedFieldBackInserter(base_api_def->mutable_endpoint()));
    373   }
    374   // Merge args
    375   for (const auto& new_arg : new_api_def.in_arg()) {
    376     bool found_base_arg = false;
    377     for (int i = 0; i < base_api_def->in_arg_size(); ++i) {
    378       auto* base_arg = base_api_def->mutable_in_arg(i);
    379       if (base_arg->name() == new_arg.name()) {
    380         MergeArg(base_arg, new_arg);
    381         found_base_arg = true;
    382         break;
    383       }
    384     }
    385     if (!found_base_arg) {
    386       return errors::FailedPrecondition("Argument ", new_arg.name(),
    387                                         " not defined in base api for ",
    388                                         base_api_def->graph_op_name());
    389     }
    390   }
    391   for (const auto& new_arg : new_api_def.out_arg()) {
    392     bool found_base_arg = false;
    393     for (int i = 0; i < base_api_def->out_arg_size(); ++i) {
    394       auto* base_arg = base_api_def->mutable_out_arg(i);
    395       if (base_arg->name() == new_arg.name()) {
    396         MergeArg(base_arg, new_arg);
    397         found_base_arg = true;
    398         break;
    399       }
    400     }
    401     if (!found_base_arg) {
    402       return errors::FailedPrecondition("Argument ", new_arg.name(),
    403                                         " not defined in base api for ",
    404                                         base_api_def->graph_op_name());
    405     }
    406   }
    407   // Merge arg order
    408   if (new_api_def.arg_order_size() > 0) {
    409     // Validate that new arg_order is correct.
    410     if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) {
    411       return errors::FailedPrecondition(
    412           "Invalid number of arguments ", new_api_def.arg_order_size(), " for ",
    413           base_api_def->graph_op_name(),
    414           ". Expected: ", base_api_def->arg_order_size());
    415     }
    416     if (!std::is_permutation(new_api_def.arg_order().begin(),
    417                              new_api_def.arg_order().end(),
    418                              base_api_def->arg_order().begin())) {
    419       return errors::FailedPrecondition(
    420           "Invalid arg_order: ", str_util::Join(new_api_def.arg_order(), ", "),
    421           " for ", base_api_def->graph_op_name(),
    422           ". All elements in arg_order override must match base arg_order: ",
    423           str_util::Join(base_api_def->arg_order(), ", "));
    424     }
    425 
    426     base_api_def->clear_arg_order();
    427     std::copy(
    428         new_api_def.arg_order().begin(), new_api_def.arg_order().end(),
    429         protobuf::RepeatedFieldBackInserter(base_api_def->mutable_arg_order()));
    430   }
    431   // Merge attributes
    432   for (const auto& new_attr : new_api_def.attr()) {
    433     bool found_base_attr = false;
    434     for (int i = 0; i < base_api_def->attr_size(); ++i) {
    435       auto* base_attr = base_api_def->mutable_attr(i);
    436       if (base_attr->name() == new_attr.name()) {
    437         MergeAttr(base_attr, new_attr);
    438         found_base_attr = true;
    439         break;
    440       }
    441     }
    442     if (!found_base_attr) {
    443       return errors::FailedPrecondition("Attribute ", new_attr.name(),
    444                                         " not defined in base api for ",
    445                                         base_api_def->graph_op_name());
    446     }
    447   }
    448   // Merge summary
    449   if (!new_api_def.summary().empty()) {
    450     base_api_def->set_summary(new_api_def.summary());
    451   }
    452   // Merge description
    453   auto description = new_api_def.description().empty()
    454                          ? base_api_def->description()
    455                          : new_api_def.description();
    456 
    457   if (!new_api_def.description_prefix().empty()) {
    458     description =
    459         strings::StrCat(new_api_def.description_prefix(), "\n", description);
    460   }
    461   if (!new_api_def.description_suffix().empty()) {
    462     description =
    463         strings::StrCat(description, "\n", new_api_def.description_suffix());
    464   }
    465   base_api_def->set_description(description);
    466   return Status::OK();
    467 }
    468 }  // namespace
    469 
    470 ApiDefMap::ApiDefMap(const OpList& op_list) {
    471   for (const auto& op : op_list.op()) {
    472     ApiDef api_def;
    473     InitApiDefFromOpDef(op, &api_def);
    474     map_[op.name()] = api_def;
    475   }
    476 }
    477 
    478 ApiDefMap::~ApiDefMap() {}
    479 
    480 Status ApiDefMap::LoadFileList(Env* env, const std::vector<string>& filenames) {
    481   for (const auto& filename : filenames) {
    482     TF_RETURN_IF_ERROR(LoadFile(env, filename));
    483   }
    484   return Status::OK();
    485 }
    486 
    487 Status ApiDefMap::LoadFile(Env* env, const string& filename) {
    488   if (filename.empty()) return Status::OK();
    489   string contents;
    490   TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents));
    491   Status status = LoadApiDef(contents);
    492   if (!status.ok()) {
    493     // Return failed status annotated with filename to aid in debugging.
    494     return Status(status.code(),
    495                   strings::StrCat("Error parsing ApiDef file ", filename, ": ",
    496                                   status.error_message()));
    497   }
    498   return Status::OK();
    499 }
    500 
    501 Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) {
    502   const string contents = PBTxtFromMultiline(api_def_file_contents);
    503   ApiDefs api_defs;
    504   TF_RETURN_IF_ERROR(
    505       proto_utils::ParseTextFormatFromString(contents, &api_defs));
    506   for (const auto& api_def : api_defs.op()) {
    507     // Check if the op definition is loaded. If op definition is not
    508     // loaded, then we just skip this ApiDef.
    509     if (map_.find(api_def.graph_op_name()) != map_.end()) {
    510       // Overwrite current api def with data in api_def.
    511       TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def));
    512     }
    513   }
    514   return Status::OK();
    515 }
    516 
    517 void ApiDefMap::UpdateDocs() {
    518   for (auto& name_and_api_def : map_) {
    519     auto& api_def = name_and_api_def.second;
    520     CHECK_GT(api_def.endpoint_size(), 0);
    521     const string canonical_name = api_def.endpoint(0).name();
    522     if (api_def.graph_op_name() != canonical_name) {
    523       RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def);
    524     }
    525     for (const auto& in_arg : api_def.in_arg()) {
    526       if (in_arg.name() != in_arg.rename_to()) {
    527         RenameInDocs(in_arg.name(), in_arg.rename_to(), &api_def);
    528       }
    529     }
    530     for (const auto& out_arg : api_def.out_arg()) {
    531       if (out_arg.name() != out_arg.rename_to()) {
    532         RenameInDocs(out_arg.name(), out_arg.rename_to(), &api_def);
    533       }
    534     }
    535     for (const auto& attr : api_def.attr()) {
    536       if (attr.name() != attr.rename_to()) {
    537         RenameInDocs(attr.name(), attr.rename_to(), &api_def);
    538       }
    539     }
    540   }
    541 }
    542 
    543 const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const {
    544   return gtl::FindOrNull(map_, name);
    545 }
    546 }  // namespace tensorflow
    547