Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/python/framework/python_op_gen.h"
     17 
     18 #include <stdio.h>
     19 #include <sstream>
     20 #include <unordered_map>
     21 #include "tensorflow/core/framework/api_def.pb.h"
     22 #include "tensorflow/core/framework/attr_value.pb.h"
     23 #include "tensorflow/core/framework/op.h"
     24 #include "tensorflow/core/framework/op_def.pb_text.h"
     25 #include "tensorflow/core/framework/op_def.pb.h"
     26 #include "tensorflow/core/framework/op_def_util.h"
     27 #include "tensorflow/core/framework/op_gen_lib.h"
     28 #include "tensorflow/core/framework/tensor.pb_text.h"
     29 #include "tensorflow/core/framework/tensor.pb.h"
     30 #include "tensorflow/core/framework/tensor_shape.pb.h"
     31 #include "tensorflow/core/framework/types.h"
     32 #include "tensorflow/core/framework/types.pb.h"
     33 #include "tensorflow/core/lib/gtl/map_util.h"
     34 #include "tensorflow/core/lib/gtl/stl_util.h"
     35 #include "tensorflow/core/lib/strings/str_util.h"
     36 #include "tensorflow/core/lib/strings/strcat.h"
     37 #include "tensorflow/core/lib/strings/stringprintf.h"
     38 #include "tensorflow/core/platform/logging.h"
     39 #include "tensorflow/core/platform/macros.h"
     40 #include "tensorflow/core/platform/types.h"
     41 #include "tensorflow/python/framework/python_op_gen_internal.h"
     42 
     43 namespace tensorflow {
     44 namespace python_op_gen_internal {
     45 
     46 const int kRightMargin = 78;
     47 
     48 bool IsPythonReserved(const string& s) {
     49   static const std::set<string>* const kPythonReserved = new std::set<string>(
     50       {// Keywords in Python, from:
     51        //   import keyword
     52        //   print keyword.kwlist
     53        "and", "as", "assert", "break", "class", "continue", "def", "del",
     54        "elif", "else", "except", "exec", "finally", "for", "from", "global",
     55        "if", "import", "in", "is", "lambda", "not", "or", "pass", "print",
     56        "raise", "return", "try", "while", "with", "yield",
     57        // Built-in functions and types in Python, from:
     58        //   [x for x in dir(__builtins__) if not x[0].islower()]
     59        "ArithmeticError", "AssertionError", "AttributeError", "BaseException",
     60        "BufferError", "BytesWarning", "DeprecationWarning", "EOFError",
     61        "Ellipsis", "EnvironmentError", "Exception", "False",
     62        "FloatingPointError", "FutureWarning", "GeneratorExit", "IOError",
     63        "ImportError", "ImportWarning", "IndentationError", "IndexError",
     64        "KeyError", "KeyboardInterrupt", "LookupError", "MemoryError",
     65        "NameError", "None", "NotImplemented", "NotImplementedError", "OSError",
     66        "OverflowError", "PendingDeprecationWarning", "ReferenceError",
     67        "RuntimeError", "RuntimeWarning", "StandardError", "StopIteration",
     68        "SyntaxError", "SyntaxWarning", "SystemError", "SystemExit", "TabError",
     69        "True", "TypeError", "UnboundLocalError", "UnicodeDecodeError",
     70        "UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError",
     71        "UnicodeWarning", "UserWarning", "ValueError", "Warning",
     72        "ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__",
     73        "__package__"});
     74 
     75   return kPythonReserved->count(s) > 0;
     76 }
     77 
     78 string AvoidPythonReserved(const string& s) {
     79   if (IsPythonReserved(s)) return strings::StrCat(s, "_");
     80   return s;
     81 }
     82 
     83 // Indent the first line by "initial" spaces and all following lines
     84 // by "rest" spaces.
     85 string Indent(int initial, int rest, StringPiece in) {
     86   // TODO(josh11b): Also word-wrapping?
     87   string copy(in.data(), in.size());
     88   str_util::StripTrailingWhitespace(&copy);
     89   std::vector<string> v = str_util::Split(copy, '\n');
     90 
     91   string result;
     92   bool first = true;
     93   for (const string& line : v) {
     94     if (first) {
     95       result = strings::StrCat(Spaces(initial), line, "\n");
     96       first = false;
     97     } else {
     98       if (line.empty()) {
     99         strings::StrAppend(&result, "\n");
    100       } else {
    101         strings::StrAppend(&result, Spaces(rest), line, "\n");
    102       }
    103     }
    104   }
    105   return result;
    106 }
    107 
    108 // Adds append to *dest, with a space if the first line will be <= width,
    109 // or a newline otherwise.
    110 void AppendWithinWidth(string* dest, StringPiece append, int width) {
    111   auto first_line = append.find('\n');
    112   if (first_line == string::npos) first_line = append.size();
    113   if (dest->size() + first_line + 1 /* space */ > static_cast<size_t>(width)) {
    114     strings::StrAppend(dest, "\n", append);
    115   } else {
    116     strings::StrAppend(dest, " ", append);
    117   }
    118 }
    119 
    120 // Like DataTypeString() but uses the Python names for the
    121 // float types.
    122 string PythonDataTypeString(DataType dtype) {
    123   switch (dtype) {
    124     case DT_FLOAT:
    125       return "float32";
    126     case DT_DOUBLE:
    127       return "float64";
    128     default:
    129       return DataTypeString(dtype);
    130   }
    131 }
    132 
    133 string TypeString(DataType dtype, bool ref) {
    134   if (ref) {
    135     return strings::StrCat("mutable `", PythonDataTypeString(dtype), "`");
    136   } else {
    137     return strings::StrCat("`", PythonDataTypeString(dtype), "`");
    138   }
    139 }
    140 
    141 string TypeListString(const AttrValue& value) {
    142   string ret;
    143   for (int t : value.list().type()) {
    144     if (!ret.empty()) strings::StrAppend(&ret, ", ");
    145     DataType dtype = static_cast<DataType>(t);
    146     if (IsRefType(dtype)) {
    147       strings::StrAppend(&ret, PythonDataTypeString(RemoveRefType(dtype)),
    148                          " mutable");
    149     } else {
    150       strings::StrAppend(&ret, "`", PythonDataTypeString(dtype), "`");
    151     }
    152   }
    153   return ret;
    154 }
    155 
    156 string SingleTensorName(DataType dtype, bool is_ref) {
    157   const string type_str = TypeString(dtype, is_ref);
    158   return strings::StrCat("A `Tensor` of type ", type_str, ".");
    159 }
    160 
    161 const char kUnknownTensorType[] = {"A `Tensor`."};
    162 
    163 string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg,
    164                    const std::unordered_map<string, string>& inferred_attrs,
    165                    bool is_output) {
    166   if (!arg.number_attr().empty()) {
    167     // N Tensors with the same type
    168     const string* original_arg =
    169         gtl::FindOrNull(inferred_attrs, arg.number_attr());
    170     string prefix;
    171     if (original_arg == nullptr) {
    172       prefix = strings::StrCat("A list of `", arg.number_attr(), "`");
    173     } else if (*original_arg == arg.name()) {
    174       const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
    175       if (attr->has_minimum() && attr->minimum() > 0) {
    176         prefix = strings::StrCat("A list of at least ", attr->minimum());
    177       } else {
    178         prefix = "A list of";
    179       }
    180     } else {
    181       prefix = strings::StrCat("A list with the same length as `",
    182                                AvoidPythonReserved(*original_arg), "` of");
    183     }
    184 
    185     if (arg.type() != DT_INVALID) {
    186       return strings::StrCat(prefix, " `Tensor` objects with type ",
    187                              TypeString(arg.type(), arg.is_ref()), ".");
    188     } else {
    189       original_arg = gtl::FindOrNull(inferred_attrs, arg.type_attr());
    190       if (arg.is_ref()) {
    191         strings::StrAppend(&prefix, " mutable");
    192       }
    193       if (original_arg == nullptr) {
    194         return strings::StrCat(prefix, " `Tensor` objects with type `",
    195                                arg.type_attr(), "`.");
    196       } else if (*original_arg == arg.name()) {
    197         const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
    198         if (attr->has_allowed_values()) {
    199           return strings::StrCat(prefix,
    200                                  " `Tensor` objects with the same type in: ",
    201                                  TypeListString(attr->allowed_values()), ".");
    202         } else {
    203           return strings::StrCat(prefix,
    204                                  " `Tensor` objects with the same type.");
    205         }
    206       } else {
    207         return strings::StrCat(prefix,
    208                                " `Tensor` objects with the same type as `",
    209                                AvoidPythonReserved(*original_arg), "`.");
    210       }
    211     }
    212   } else if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) {
    213     const bool is_list = !arg.type_list_attr().empty();
    214     const string attr_name = is_list ? arg.type_list_attr() : arg.type_attr();
    215     const OpDef::AttrDef* attr = FindAttr(attr_name, op_def);
    216     const string mutable_str = arg.is_ref() ? "mutable " : "";
    217     const string prefix =
    218         is_list ? strings::StrCat("A list of ", mutable_str, "`Tensor` objects")
    219                 : strings::StrCat("A ", mutable_str, "`Tensor`");
    220     const string* original_arg = gtl::FindOrNull(inferred_attrs, attr_name);
    221     if (original_arg == nullptr) {
    222       return strings::StrCat(prefix, " of type `", attr_name, "`.");
    223     } else if (*original_arg == arg.name()) {
    224       if (attr->has_allowed_values()) {
    225         if (is_list) {
    226           return strings::StrCat(prefix, " with types from: ",
    227                                  TypeListString(attr->allowed_values()), ".");
    228         } else {
    229           return strings::StrCat(
    230               prefix, is_output ? ". Has one of the following types: "
    231                                 : ". Must be one of the following types: ",
    232               TypeListString(attr->allowed_values()), ".");
    233         }
    234       } else {
    235         return strings::StrCat(prefix, ".");
    236       }
    237     } else {
    238       return strings::StrCat(prefix,
    239                              is_output ? ". Has the same type as `"
    240                                        : ". Must have the same type as `",
    241                              AvoidPythonReserved(*original_arg), "`.");
    242     }
    243   } else {
    244     return SingleTensorName(arg.type(), arg.is_ref());
    245   }
    246 }
    247 
    248 string GetReturns(const OpDef& op_def,
    249                   const std::vector<string>& output_type_string) {
    250   string result;
    251   DCHECK_EQ(op_def.output_arg_size(), output_type_string.size());
    252   const int num_outs = op_def.output_arg_size();
    253   strings::StrAppend(&result, "\n  Returns:\n");
    254   if (num_outs == 0) {
    255     strings::StrAppend(&result, "    The created Operation.\n");
    256   } else {
    257     if (num_outs == 1) {
    258       StringPiece description = op_def.output_arg(0).description();
    259       if (ConsumeEquals(&description)) {  // Skip the generated type info.
    260         strings::StrAppend(&result, Indent(4, 4, description));
    261       } else {
    262         // Special case of one output, don't use the name of the output unless
    263         // there is no description.
    264         string desc = output_type_string.empty() ? kUnknownTensorType
    265                                                  : output_type_string[0];
    266         if (desc == kUnknownTensorType) {
    267           // Special case where we don't understand how the output tensor type
    268           // depends on the input tensor types, just use the output arg
    269           // description if we can.
    270           if (!description.empty()) {
    271             desc = op_def.output_arg(0).description();
    272           } else if (!op_def.output_arg(0).name().empty()) {
    273             desc = strings::StrCat(" The ", op_def.output_arg(0).name(),
    274                                    " `Tensor`.");
    275           }
    276         } else if (!description.empty()) {
    277           AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
    278         }
    279         strings::StrAppend(&result, Indent(4, 4, desc));
    280       }
    281     } else {
    282       std::vector<string> out_names(num_outs);
    283       for (int i = 0; i < num_outs; ++i) {
    284         if (!op_def.output_arg(i).name().empty()) {
    285           out_names[i] = op_def.output_arg(i).name();
    286         } else {
    287           out_names[i] = strings::StrCat("output", i);
    288         }
    289       }
    290       strings::StrAppend(&result, "    A tuple of `Tensor` objects (",
    291                          str_util::Join(out_names, ", "), ").\n\n");
    292       for (int i = 0; i < num_outs; ++i) {
    293         string desc = strings::StrCat(out_names[i], ": ");
    294         StringPiece description = op_def.output_arg(i).description();
    295         if (ConsumeEquals(&description)) {  // Skip the generated type info.
    296           strings::StrAppend(&desc, description);
    297         } else {
    298           const string type = static_cast<size_t>(i) < output_type_string.size()
    299                                   ? output_type_string[i]
    300                                   : kUnknownTensorType;
    301           if (!description.empty()) {
    302             if (type == kUnknownTensorType) {
    303               // Special case where we don't understand how the output tensor
    304               // type depends on the input tensor types, so we just use the
    305               // output arg description.
    306               strings::StrAppend(&desc, description);
    307             } else {
    308               strings::StrAppend(&desc, type, " ", description);
    309             }
    310           } else {
    311             strings::StrAppend(&desc, type);
    312           }
    313         }
    314         strings::StrAppend(&result, Indent(4, 6, desc));
    315       }
    316     }
    317   }
    318   return result;
    319 }
    320 
    321 string StringToPython(const string& str) {
    322   return strings::StrCat("\"", str_util::CEscape(str), "\"");
    323 }
    324 
    325 string DataTypeToPython(DataType dtype, const string& dtype_module) {
    326   return strings::StrCat(dtype_module, PythonDataTypeString(dtype));
    327 }
    328 
    329 string ShapeToPython(const TensorShapeProto& shape) {
    330   if (shape.unknown_rank()) {
    331     return "None";
    332   }
    333   string python = "[";
    334   for (const auto& dim : shape.dim()) {
    335     if (python.size() > 1) strings::StrAppend(&python, ", ");
    336     if (!dim.name().empty()) {
    337       strings::StrAppend(&python, "(", StringToPython(dim.name()), ", ",
    338                          dim.size(), ")");
    339     } else {
    340       strings::StrAppend(&python, dim.size());
    341     }
    342   }
    343   strings::StrAppend(&python, "]");
    344   return python;
    345 }
    346 
    347 string TensorToPython(const TensorProto& proto) {
    348   return ProtoShortDebugString(proto);
    349 }
    350 
    351 string AttrListToPython(const AttrValue& value,
    352                         const string& dtype_module = "tf.") {
    353   string ret;
    354   if (value.list().s_size() > 0) {
    355     for (int i = 0; i < value.list().s_size(); ++i) {
    356       if (i > 0) strings::StrAppend(&ret, ", ");
    357       strings::StrAppend(&ret, StringToPython(value.list().s(i)));
    358     }
    359   } else if (value.list().i_size() > 0) {
    360     for (int i = 0; i < value.list().i_size(); ++i) {
    361       if (i > 0) strings::StrAppend(&ret, ", ");
    362       strings::StrAppend(&ret, value.list().i(i));
    363     }
    364   } else if (value.list().f_size() > 0) {
    365     for (int i = 0; i < value.list().f_size(); ++i) {
    366       if (i > 0) strings::StrAppend(&ret, ", ");
    367       strings::StrAppend(&ret, value.list().f(i));
    368     }
    369   } else if (value.list().b_size() > 0) {
    370     for (int i = 0; i < value.list().b_size(); ++i) {
    371       if (i > 0) strings::StrAppend(&ret, ", ");
    372       strings::StrAppend(&ret, value.list().b(i) ? "True" : "False");
    373     }
    374   } else if (value.list().type_size() > 0) {
    375     for (int i = 0; i < value.list().type_size(); ++i) {
    376       if (i > 0) strings::StrAppend(&ret, ", ");
    377       strings::StrAppend(&ret,
    378                          DataTypeToPython(value.list().type(i), dtype_module));
    379     }
    380   } else if (value.list().shape_size() > 0) {
    381     for (int i = 0; i < value.list().shape_size(); ++i) {
    382       if (i > 0) strings::StrAppend(&ret, ", ");
    383       strings::StrAppend(&ret, ShapeToPython(value.list().shape(i)));
    384     }
    385   } else if (value.list().tensor_size() > 0) {
    386     for (int i = 0; i < value.list().tensor_size(); ++i) {
    387       if (i > 0) strings::StrAppend(&ret, ", ");
    388       strings::StrAppend(&ret, TensorToPython(value.list().tensor(i)));
    389     }
    390   } else if (value.list().func_size() > 0) {
    391     for (int i = 0; i < value.list().func_size(); ++i) {
    392       if (i > 0) strings::StrAppend(&ret, ", ");
    393       strings::StrAppend(&ret, StringToPython(value.list().func(i).name()));
    394     }
    395   }
    396   return ret;
    397 }
    398 
    399 // NOTE: The return value may contain spaces (for example, it could be
    400 // a string "foo bar" with an embedded space) and is not safe to pass
    401 // to WordWrap().
    402 string AttrValueToPython(const string& type, const AttrValue& value,
    403                          const string& dtype_module) {
    404   if (type == "string") {
    405     return StringToPython(value.s());
    406   } else if (type == "int") {
    407     return strings::StrCat(value.i());
    408   } else if (type == "float") {
    409     if (std::isnan(value.f()) || std::isinf(value.f())) {
    410       return strings::StrCat("float('", value.f(), "')");
    411     } else {
    412       return strings::StrCat(value.f());
    413     }
    414   } else if (type == "bool") {
    415     return value.b() ? "True" : "False";
    416   } else if (type == "type") {
    417     return DataTypeToPython(value.type(), dtype_module);
    418   } else if (type == "shape") {
    419     return ShapeToPython(value.shape());
    420   } else if (type == "tensor") {
    421     return TensorToPython(value.tensor());
    422   } else if (type == "func") {
    423     return StringToPython(value.func().name());
    424   } else if (StringPiece(type).starts_with("list(")) {
    425     return strings::StrCat("[", AttrListToPython(value, dtype_module), "]");
    426   } else {
    427     return "?";
    428   }
    429 }
    430 
    431 void GenerateLowerCaseOpName(const string& str, string* result) {
    432   const char joiner = '_';
    433   const int last_index = str.size() - 1;
    434   for (int i = 0; i <= last_index; ++i) {
    435     const char c = str[i];
    436     // Emit a joiner only if a previous-lower-to-now-upper or a
    437     // now-upper-to-next-lower transition happens.
    438     if (isupper(c) && (i > 0)) {
    439       if (islower(str[i - 1]) || ((i < last_index) && islower(str[i + 1]))) {
    440         result->push_back(joiner);
    441       }
    442     }
    443     result->push_back(tolower(c));
    444   }
    445 }
    446 
    447 static void AddDelimiter(string* append_to, const string& delim) {
    448   if (!append_to->empty()) strings::StrAppend(append_to, delim);
    449 }
    450 
    451 const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) {
    452   for (int i = 0; i < api_def.attr_size(); ++i) {
    453     if (api_def.attr(i).name() == name) {
    454       return &api_def.attr(i);
    455     }
    456   }
    457   return nullptr;
    458 }
    459 
    460 const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
    461   for (int i = 0; i < api_def.in_arg_size(); ++i) {
    462     if (api_def.in_arg(i).name() == name) {
    463       return &api_def.in_arg(i);
    464     }
    465   }
    466   return nullptr;
    467 }
    468 
    469 GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def,
    470                          const string& function_name)
    471     : op_def_(op_def),
    472       api_def_(api_def),
    473       function_name_(function_name),
    474       num_outs_(op_def.output_arg_size()) {}
    475 
    476 GenPythonOp::~GenPythonOp() {}
    477 
    478 string GenPythonOp::Code() {
    479   // This has all the input args followed by those attrs that don't have
    480   // defaults.
    481   std::vector<ParamNames> params_no_default;
    482   // The parameters with defaults (these have to be listed after those without).
    483   // No input args are included, just attrs.
    484   std::vector<ParamNames> params_with_default;
    485 
    486   for (int i = 0; i < api_def_.arg_order_size(); ++i) {
    487     const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
    488     const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
    489     params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to());
    490     if (!arg.type_attr().empty()) {
    491       gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name());
    492     } else if (!arg.type_list_attr().empty()) {
    493       gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_list_attr(),
    494                               arg.name());
    495     }
    496     if (!arg.number_attr().empty()) {
    497       gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name());
    498     }
    499   }
    500   for (int i = 0; i < api_def_.attr_size(); ++i) {
    501     const auto& attr(api_def_.attr(i));
    502     // Do not add inferred attrs to the Python function signature.
    503     if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
    504       if (attr.has_default_value()) {
    505         params_with_default.emplace_back(attr.name(), attr.rename_to());
    506       } else {
    507         params_no_default.emplace_back(attr.name(), attr.rename_to());
    508       }
    509     }
    510   }
    511 
    512   // Save the list of attr parameters (attrs that won't be inferred),
    513   // those with defaults go at the end.
    514   // Get the attrs in the order we want by taking the attrs without defaults
    515   // from the end of args_no_default, and adding args_no_default.
    516   attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() +
    517                  params_with_default.size());
    518   for (int i = op_def_.input_arg_size(); i < params_no_default.size(); ++i) {
    519     attrs_.push_back(params_no_default[i].GetName());
    520   }
    521   for (int i = 0; i < params_with_default.size(); ++i) {
    522     attrs_.push_back(params_with_default[i].GetName());
    523   }
    524 
    525   param_names_.reserve(params_no_default.size() + params_with_default.size());
    526   param_names_.insert(param_names_.begin(), params_no_default.begin(),
    527                       params_no_default.end());
    528   for (const auto& param : params_with_default) {
    529     param_names_.push_back(param);
    530   }
    531 
    532   string parameters;
    533   for (const auto& param : params_no_default) {
    534     AddDelimiter(&parameters, ", ");
    535     strings::StrAppend(&parameters, param.GetRenameTo());
    536   }
    537   for (const auto& param_and_default : params_with_default) {
    538     AddDelimiter(&parameters, ", ");
    539     strings::StrAppend(&parameters, param_and_default.GetRenameTo(), "=None");
    540   }
    541   AddDelimiter(&parameters, ", ");
    542   strings::StrAppend(&parameters, "name=None");
    543 
    544   AddExport();
    545   AddDefLine(parameters);
    546   AddDocStringDescription();
    547   AddDocStringArgs();
    548   AddDocStringInputs();
    549   AddDocStringAttrs();
    550   AddDocStringNameArg();
    551   AddOutputGlobals();
    552   AddDocStringOutputs();
    553   strings::StrAppend(&result_, "  \"\"\"\n");
    554   AddBody("  ");
    555   strings::StrAppend(&result_, "\n\n");
    556 
    557   return prelude_ + result_;
    558 }
    559 
    560 void GenPythonOp::AddExport() {
    561   if (api_def_.visibility() != ApiDef::VISIBLE) {
    562     return;
    563   }
    564 
    565   strings::StrAppend(&result_, "@tf_export(");
    566 
    567   // Add all endpoint names to tf_export.
    568   bool first_endpoint = true;
    569   for (const auto& endpoint : api_def_.endpoint()) {
    570     if (!first_endpoint) {
    571       strings::StrAppend(&result_, ", ");
    572     } else {
    573       first_endpoint = false;
    574     }
    575     string endpoint_name;
    576     python_op_gen_internal::GenerateLowerCaseOpName(endpoint.name(),
    577                                                     &endpoint_name);
    578     strings::StrAppend(&result_, "'", endpoint_name, "'");
    579   }
    580   strings::StrAppend(&result_, ")\n");
    581 }
    582 
    583 void GenPythonOp::AddDefLine(const string& function_name,
    584                              const string& parameters) {
    585   strings::StrAppend(&result_, "def ", function_name, "(", parameters, "):\n");
    586 }
    587 
    588 void GenPythonOp::AddDefLine(const string& parameters) {
    589   AddDefLine(function_name_, parameters);
    590 }
    591 
    592 void GenPythonOp::AddDocStringDescription() {
    593   string comment;
    594   if (api_def_.summary().empty()) {
    595     comment = "TODO: add doc.\n";
    596   } else {
    597     comment = strings::StrCat(api_def_.summary(), "\n");
    598     if (!api_def_.description().empty()) {
    599       strings::StrAppend(&comment, "\n", Indent(2, 2, api_def_.description()));
    600     }
    601   }
    602   strings::StrAppend(&result_, "  r\"\"\"", comment, "\n");
    603 }
    604 
    605 void GenPythonOp::AddDocStringArgs() {
    606   strings::StrAppend(&result_, "  Args:\n");
    607 }
    608 
    609 void GenPythonOp::AddDocStringInputs() {
    610   for (int i = 0; i < api_def_.arg_order_size(); ++i) {
    611     const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
    612     const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
    613     StringPiece description = api_def_arg.description();
    614     string desc;
    615     if (ConsumeEquals(&description)) {  // Skip the generated type info.
    616       desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ");
    617     } else {
    618       desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ",
    619                              ArgTypeName(op_def_, arg, inferred_attrs_, false));
    620     }
    621     if (!description.empty()) {
    622       AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
    623     }
    624     strings::StrAppend(&result_, Indent(4, 6, desc));
    625   }
    626 }
    627 
    628 void GenPythonOp::AddDocStringAttrs() {
    629   for (const string& name : attrs_) {
    630     const auto& attr = *FindAttr(name, op_def_);
    631     const auto& api_def_attr = *FindAttr(name, api_def_);
    632     string desc =
    633         strings::StrCat(AvoidPythonReserved(api_def_attr.rename_to()), ": ");
    634 
    635     static const char* const kAttrTypeName[][2] = {
    636         {"string", "`string`"},
    637         {"list(string)", "list of `strings`"},
    638         {"int", "`int`"},
    639         {"list(int)", "list of `ints`"},
    640         {"float", "`float`"},
    641         {"list(float)", "list of `floats`"},
    642         {"bool", "`bool`"},
    643         {"list(bool)", "list of `bools`"},
    644         {"type", "`tf.DType`"},
    645         {"list(type)", "list of `tf.DTypes`"},
    646         {"shape", "`tf.TensorShape` or list of `ints`"},
    647         {"list(shape)",
    648          "list of shapes (each a `tf.TensorShape` or list of `ints`)"},
    649         {"tensor", "`tf.TensorProto`"},
    650         {"list(tensor)", "list of `tf.TensorProto` objects"},
    651         {"func", "function decorated with @Defun"},
    652         {"list(func)", "list of functions decorated with @Defun"},
    653     };
    654     for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
    655       if (attr.type() == kAttrTypeName[i][0]) {
    656         string s;
    657         if (api_def_attr.has_default_value()) {
    658           s = strings::StrCat("optional ", kAttrTypeName[i][1]);
    659         } else {
    660           s = kAttrTypeName[i][1];
    661         }
    662         if (s[0] == 'o' || (s[0] == '`' && (s[1] == 'i' || s[1] == 'o'))) {
    663           strings::StrAppend(&desc, "An ", s);
    664         } else {
    665           strings::StrAppend(&desc, "A ", s);
    666         }
    667         break;
    668       }
    669     }
    670 
    671     if (attr.has_allowed_values()) {
    672       strings::StrAppend(&desc, " from: `",
    673                          AttrListToPython(attr.allowed_values()), "`");
    674     }
    675 
    676     if (attr.has_minimum()) {
    677       if (attr.type() == "int") {
    678         strings::StrAppend(&desc, " that is `>= ", attr.minimum(), "`");
    679       } else if (attr.minimum() > 0) {
    680         strings::StrAppend(&desc, " that has length `>= ", attr.minimum(), "`");
    681       }
    682     }
    683 
    684     strings::StrAppend(&desc, ".");
    685 
    686     if (api_def_attr.has_default_value()) {
    687       strings::StrAppend(
    688           &desc, " Defaults to `",
    689           AttrValueToPython(attr.type(), api_def_attr.default_value()), "`.");
    690     }
    691     if (!api_def_attr.description().empty()) {
    692       AppendWithinWidth(&desc, api_def_attr.description(),
    693                         kRightMargin - 4 /* indent */);
    694     }
    695     strings::StrAppend(&result_, Indent(4, 6, desc));
    696   }
    697 }
    698 
    699 void GenPythonOp::AddDocStringNameArg() {
    700   strings::StrAppend(&result_,
    701                      "    name: A name for the operation (optional).\n");
    702 }
    703 
    704 void GenPythonOp::AddOutputGlobals() {
    705   // Prepare a NamedTuple type to hold the outputs, if there are multiple
    706   if (num_outs_ > 1) {
    707     // Prepare the list of output names
    708     std::vector<string> out_names(num_outs_);
    709     for (int i = 0; i < num_outs_; ++i) {
    710       if (!api_def_.out_arg(i).rename_to().empty()) {
    711         out_names[i] = api_def_.out_arg(i).rename_to();
    712       } else {
    713         out_names[i] = strings::StrCat("output", i);
    714       }
    715     }
    716     string out_names_list =
    717         strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]");
    718 
    719     // Provide the output names as a Python list
    720     string lower_op_name_outputs =
    721         strings::StrCat("_", function_name_, "_outputs");
    722     const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = ");
    723     strings::StrAppend(&prelude_, "\n",
    724                        WordWrap(outputs_prefix, out_names_list, kRightMargin),
    725                        "\n");
    726 
    727     strings::StrAppend(&prelude_, "_", op_def_.name(),
    728                        "Output = _collections.namedtuple(\n");
    729     const string tuple_type_prefix = "    ";
    730     const string tuple_type_suffix = strings::StrCat(
    731         "\"", op_def_.name(), "\", ", lower_op_name_outputs, ")");
    732     strings::StrAppend(
    733         &prelude_, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin),
    734         "\n\n");
    735   }
    736   strings::StrAppend(&prelude_, "\n");
    737 }
    738 
    739 void GenPythonOp::AddDocStringOutputs() {
    740   std::vector<string> output_type_string;
    741   output_type_string.reserve(num_outs_);
    742   for (int i = 0; i < num_outs_; ++i) {
    743     output_type_string.push_back(
    744         ArgTypeName(op_def_, op_def_.output_arg(i), inferred_attrs_, true));
    745   }
    746   strings::StrAppend(&result_, GetReturns(op_def_, output_type_string));
    747 }
    748 
    749 void GenPythonOp::AddBody(const string& prefix) {
    750   const string apply_prefix =
    751       strings::StrCat(prefix, "_result = _op_def_lib.apply_op(");
    752   AddBodyNoReturn(apply_prefix);
    753   if (num_outs_ > 1) {
    754     strings::StrAppend(&result_, prefix, "_result = _", op_def_.name(),
    755                        "Output._make(_result)\n");
    756   }
    757   strings::StrAppend(&result_, prefix, "return _result\n");
    758 }
    759 
    760 void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) {
    761   string args = strings::StrCat("\"", op_def_.name(), "\", ");
    762   for (size_t i = 0; i < param_names_.size(); ++i) {
    763     strings::StrAppend(&args, AvoidPythonReserved(param_names_[i].GetName()),
    764                        "=", param_names_[i].GetRenameTo(), ", ");
    765   }
    766   strings::StrAppend(&args, "name=name)");
    767 
    768   strings::StrAppend(&result_,
    769                      // Wrap the arguments, and indent to the (.
    770                      WordWrap(apply_prefix, args, kRightMargin), "\n");
    771 }
    772 
    773 }  // namespace python_op_gen_internal
    774 
    775 string GetPythonOp(const OpDef& op_def, const ApiDef& api_def,
    776                    const string& function_name) {
    777   return python_op_gen_internal::GenPythonOp(op_def, api_def, function_name)
    778       .Code();
    779 }
    780 
    781 string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
    782                     const std::vector<string>& hidden_ops,
    783                     bool require_shapes) {
    784   string result;
    785   // Header
    786   // TODO(josh11b): Mention the library for which wrappers are being generated.
    787   strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.
    788 
    789 This file is MACHINE GENERATED! Do not edit.
    790 """
    791 
    792 import collections as _collections
    793 
    794 from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
    795 
    796 # Needed to trigger the call to _set_call_cpp_shape_fn.
    797 from tensorflow.python.framework import common_shapes as _common_shapes
    798 
    799 from tensorflow.python.framework import op_def_registry as _op_def_registry
    800 from tensorflow.python.framework import ops as _ops
    801 from tensorflow.python.framework import op_def_library as _op_def_library
    802 from tensorflow.python.util.tf_export import tf_export
    803 )");
    804 
    805   // We'll make a copy of ops that filters out descriptions.
    806   OpList cleaned_ops;
    807   auto out = cleaned_ops.mutable_op();
    808   out->Reserve(ops.op_size());
    809   for (const auto& op_def : ops.op()) {
    810     const auto* api_def = api_defs.GetApiDef(op_def.name());
    811 
    812     if (api_def->visibility() == ApiDef::SKIP) {
    813       continue;
    814     }
    815 
    816     // An op is hidden if either its ApiDef visibility is HIDDEN
    817     // or it is in the hidden_ops list.
    818     bool is_hidden = api_def->visibility() == ApiDef::HIDDEN;
    819     if (!is_hidden) {
    820       for (const string& hidden : hidden_ops) {
    821         if (op_def.name() == hidden) {
    822           is_hidden = true;
    823           break;
    824         }
    825       }
    826     }
    827 
    828     string function_name;
    829     python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
    830                                                     &function_name);
    831     if (is_hidden) function_name = strings::StrCat("_", function_name);
    832 
    833     // When users create custom python wrappers, they may link in the
    834     // default op registry by accident, and because they can't
    835     // enumerate all 'hidden' symbols, this guard is to prevent
    836     // instantiating a python reserved word in their wrapper.
    837     if (python_op_gen_internal::IsPythonReserved(function_name)) {
    838       continue;
    839     }
    840 
    841     strings::StrAppend(&result, GetPythonOp(op_def, *api_def, function_name));
    842 
    843     if (!require_shapes) {
    844       strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
    845                          "\")(None)\n");
    846     }
    847 
    848     auto added = out->Add();
    849     *added = op_def;
    850     RemoveNonDeprecationDescriptionsFromOpDef(added);
    851   }
    852 
    853   result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
    854   op_list = _op_def_pb2.OpList()
    855   op_list.ParseFromString(op_list_proto_bytes)
    856   _op_def_registry.register_op_list(op_list)
    857   op_def_lib = _op_def_library.OpDefLibrary()
    858   op_def_lib.add_op_list(op_list)
    859   return op_def_lib
    860 
    861 
    862 )");
    863 
    864   result.append("# ");
    865   auto ops_text = ProtoDebugString(cleaned_ops);
    866   str_util::StripTrailingWhitespace(&ops_text);
    867   result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true));
    868   result.append("\n");
    869   strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n",
    870                    str_util::CEscape(cleaned_ops.SerializeAsString()).c_str());
    871   return result;
    872 }
    873 
    874 void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
    875                     const std::vector<string>& hidden_ops,
    876                     bool require_shapes) {
    877   printf("%s", GetPythonOps(ops, api_defs, hidden_ops, require_shapes).c_str());
    878 }
    879 
    880 string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
    881   string op_list_str(op_list_buf, op_list_len);
    882   OpList ops;
    883   ops.ParseFromString(op_list_str);
    884   ApiDefMap api_def_map(ops);
    885   return GetPythonOps(ops, api_def_map, {}, false);
    886 }
    887 
    888 }  // namespace tensorflow
    889