Home | History | Annotate | Download | only in eager
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/python/eager/python_eager_op_gen.h"
     16 
     17 #include <stdio.h>
     18 #include <sstream>
     19 #include <unordered_map>
     20 #include "tensorflow/core/framework/api_def.pb.h"
     21 #include "tensorflow/core/framework/attr_value.pb.h"
     22 #include "tensorflow/core/framework/op.h"
     23 #include "tensorflow/core/framework/op_def.pb_text.h"
     24 #include "tensorflow/core/framework/op_def.pb.h"
     25 #include "tensorflow/core/framework/op_def_util.h"
     26 #include "tensorflow/core/framework/op_gen_lib.h"
     27 #include "tensorflow/core/framework/tensor.pb_text.h"
     28 #include "tensorflow/core/framework/types.h"
     29 #include "tensorflow/core/framework/types.pb.h"
     30 #include "tensorflow/core/lib/gtl/map_util.h"
     31 #include "tensorflow/core/lib/gtl/stl_util.h"
     32 #include "tensorflow/core/lib/strings/str_util.h"
     33 #include "tensorflow/core/lib/strings/strcat.h"
     34 #include "tensorflow/core/lib/strings/stringprintf.h"
     35 #include "tensorflow/core/platform/logging.h"
     36 #include "tensorflow/core/platform/macros.h"
     37 #include "tensorflow/core/platform/types.h"
     38 #include "tensorflow/python/framework/python_op_gen_internal.h"
     39 
     40 namespace tensorflow {
     41 namespace {
     42 
     43 const int kRightMargin = 78;
     44 
     45 constexpr char kEagerFallbackSuffix[] = "_eager_fallback";
     46 
     47 string AttrVarName(const string& attr_name,
     48                    std::unordered_map<string, string>* attr_expressions) {
     49   const string var = strings::StrCat("_attr_", attr_name);
     50   if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var;
     51   return var;
     52 }
     53 
     54 void AddInferredAttr(const string& indentation, const string& attr_name,
     55                      const string& value_expression, string* result,
     56                      std::unordered_map<string, string>* attr_expressions) {
     57   strings::StrAppend(result, indentation,
     58                      AttrVarName(attr_name, attr_expressions), " = ",
     59                      value_expression, "\n");
     60 }
     61 
     62 string VectorToTuple(const std::vector<string>& l) {
     63   if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
     64   string ret = "(";
     65   for (int i = 0; i < l.size(); ++i) {
     66     if (i > 0) {
     67       strings::StrAppend(&ret, ", ");
     68     }
     69     strings::StrAppend(&ret, l[i]);
     70   }
     71   strings::StrAppend(&ret, ")");
     72   return ret;
     73 }
     74 
     75 void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
     76                const string& var, string* result) {
     77   for (int i = 0; i < output_sizes.size(); ++i) {
     78     if (!output_sizes[i].empty()) {
     79       strings::StrAppend(result, prefix, var, " = ");
     80       if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
     81       if (i + 1 < output_sizes.size()) {
     82         // Special case i == 0 to avoid "0 +" in the generated code.
     83         if (i == 0) {
     84           strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
     85                              var, "[", output_sizes[i], ":]");
     86         } else {
     87           strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
     88                              output_sizes[i], "]] + ", var, "[", i, " + ",
     89                              output_sizes[i], ":]");
     90         }
     91       } else {
     92         strings::StrAppend(result, "[", var, "[", i, ":]]");
     93       }
     94       strings::StrAppend(result, "\n");
     95     }
     96   }
     97 }
     98 
     99 string TensorPBString(const TensorProto& pb) {
    100   // Note: This gets used in the argument list, and so must survive naive
    101   // word wrapping.
    102   return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
    103 }
    104 
    105 const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
    106   for (int i = 0; i < api_def.in_arg_size(); ++i) {
    107     if (api_def.in_arg(i).name() == name) {
    108       return &api_def.in_arg(i);
    109     }
    110   }
    111   return nullptr;
    112 }
    113 
    114 class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
    115  public:
    116   GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
    117                    const string& function_name)
    118       : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) {
    119     op_name_ = function_name_;
    120     op_name_.Consume("_");
    121   }
    122   ~GenEagerPythonOp() override {}
    123 
    124   string Code() override;
    125 
    126  protected:
    127   void HandleGraphMode(const string& function_setup);
    128 
    129   string GetEagerNotAllowedError();
    130   void ExpectListArg(const string& indentation, const string& arg_name,
    131                      string* output);
    132   bool GetEagerFunctionSetup(const string& indentation, string* function_setup);
    133   void GetOutputSizesAndNumOutputsExpr(std::vector<string>* output_sizes,
    134                                        string* num_outputs_expr);
    135 
    136   void AddEagerFunctionTeardown(const string& indentation,
    137                                 const std::vector<string>& output_sizes,
    138                                 bool execute_record_gradient);
    139 
    140   bool AddEagerFastPathAndGraphCode(const string& parameters,
    141                                     const std::vector<string>& output_sizes,
    142                                     const string& eager_not_allowed_error);
    143   bool AddEagerFallbackCode(const string& parameters,
    144                             const std::vector<string>& output_sizes,
    145                             const string& num_outputs_expr,
    146                             const string& eager_not_allowed_error);
    147   void AddEagerFastPathExecute();
    148 
    149   void AddEagerInferredAttrs(const string& indentation);
    150   void AddEagerInputCasts(const string& indentation);
    151   void AddEagerAttrs(const string& indentation);
    152   void AddEagerExecute(const string& indentation,
    153                        const string& num_outputs_expr);
    154 
    155   void AddAttrForArg(const string& attr, int arg_index) {
    156     gtl::InsertIfNotPresent(&inferred_attrs_, attr,
    157                             op_def_.input_arg(arg_index).name());
    158     auto iter = attr_to_args_.find(attr);
    159     if (iter == attr_to_args_.end()) {
    160       attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
    161     } else {
    162       iter->second.push_back(arg_index);
    163     }
    164   }
    165 
    166   // Returns a string expression representing a flattened list of all
    167   // the inputs given by `*input_indices` (or all inputs if
    168   // `input_indices` is nullptr).  `*output_sizes` can be used to unflatten.
    169   string FlattenInputs(const std::vector<int>* input_indices,
    170                        std::vector<string>* output_sizes) const;
    171 
    172   StringPiece op_name_;
    173   typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
    174   AttrToArgMap attr_to_args_;
    175   std::unordered_map<string, string> attr_expressions_;
    176   // This has all the input args followed by those attrs that don't have
    177   // defaults.
    178   std::vector<python_op_gen_internal::ParamNames> params_no_default_;
    179   // The parameters with defaults (these have to be listed after those without).
    180   // No input args are included, just attrs.
    181   std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
    182       params_with_default_;
    183 };
    184 
    185 string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
    186                         const string& function_name) {
    187   return GenEagerPythonOp(op_def, api_def, function_name).Code();
    188 }
    189 
    190 string GenEagerPythonOp::FlattenInputs(
    191     const std::vector<int>* input_indices,
    192     std::vector<string>* output_sizes) const {
    193   string inputs;
    194   enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
    195   const int n = input_indices != nullptr ? input_indices->size()
    196                                          : op_def_.input_arg_size();
    197   for (int j = 0; j < n; ++j) {
    198     const int i = input_indices ? (*input_indices)[j] : j;
    199     const auto& arg(op_def_.input_arg(i));
    200     const bool is_list =
    201         !arg.type_list_attr().empty() || !arg.number_attr().empty();
    202     if (is_list) {
    203       if (inputs_state == WAS_SOLO_INPUT) {
    204         strings::StrAppend(&inputs, "] + ");
    205       } else if (inputs_state == WAS_LIST_INPUT) {
    206         strings::StrAppend(&inputs, " + ");
    207       }
    208       strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
    209       inputs_state = WAS_LIST_INPUT;
    210       if (output_sizes != nullptr) {
    211         if (!arg.number_attr().empty()) {
    212           output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
    213         } else {
    214           output_sizes->emplace_back(
    215               strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
    216         }
    217       }
    218     } else {
    219       if (inputs_state == WAS_SOLO_INPUT) {
    220         strings::StrAppend(&inputs, ", ");
    221       } else if (inputs_state == WAS_LIST_INPUT) {
    222         strings::StrAppend(&inputs, " + [");
    223       } else {
    224         strings::StrAppend(&inputs, "[");
    225       }
    226       strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
    227       inputs_state = WAS_SOLO_INPUT;
    228       if (output_sizes != nullptr) output_sizes->emplace_back();
    229     }
    230   }
    231   if (inputs_state == STARTING) return "[]";
    232   if (inputs_state == WAS_SOLO_INPUT) {
    233     strings::StrAppend(&inputs, "]");
    234   }
    235   return inputs;
    236 }
    237 
    238 string GenEagerPythonOp::Code() {
    239   if (api_def_.visibility() == ApiDef::SKIP) {
    240     return "";
    241   }
    242 
    243   for (int i = 0; i < api_def_.arg_order_size(); ++i) {
    244     const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
    245     const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
    246     params_no_default_.emplace_back(api_def_arg.name(),
    247                                     api_def_arg.rename_to());
    248     if (!arg.type_attr().empty()) {
    249       AddAttrForArg(arg.type_attr(), i);
    250     } else if (!arg.type_list_attr().empty()) {
    251       AddAttrForArg(arg.type_list_attr(), i);
    252     }
    253     if (!arg.number_attr().empty()) {
    254       AddAttrForArg(arg.number_attr(), i);
    255     }
    256   }
    257   for (int i = 0; i < op_def_.attr_size(); ++i) {
    258     const auto& attr(op_def_.attr(i));
    259     const auto& api_def_attr(api_def_.attr(i));
    260     // Do not add inferred attrs to the Python function signature.
    261     if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
    262       if (api_def_attr.has_default_value()) {
    263         if (attr.type() == "tensor") {
    264           params_with_default_.emplace_back(
    265               python_op_gen_internal::ParamNames(api_def_attr.name(),
    266                                                  api_def_attr.rename_to()),
    267               strings::StrCat(
    268                   "_execute.make_tensor(",
    269                   TensorPBString(api_def_attr.default_value().tensor()), ", \"",
    270                   api_def_attr.rename_to(), "\")"));
    271         } else if (attr.type() == "list(tensor)") {
    272           std::vector<string> pbtxt;
    273           for (const auto& pb : api_def_attr.default_value().list().tensor()) {
    274             pbtxt.emplace_back(TensorPBString(pb));
    275           }
    276           params_with_default_.emplace_back(
    277               python_op_gen_internal::ParamNames(api_def_attr.name(),
    278                                                  api_def_attr.rename_to()),
    279               strings::StrCat("[_execute.make_tensor(_pb, \"",
    280                               api_def_attr.rename_to(), "\") for _pb in ",
    281                               VectorToTuple(pbtxt), "]"));
    282         } else {
    283           params_with_default_.emplace_back(
    284               python_op_gen_internal::ParamNames(api_def_attr.name(),
    285                                                  api_def_attr.rename_to()),
    286               python_op_gen_internal::AttrValueToPython(
    287                   attr.type(), api_def_attr.default_value(), "_dtypes."));
    288         }
    289       } else {
    290         params_no_default_.emplace_back(api_def_attr.name(),
    291                                         api_def_attr.rename_to());
    292       }
    293     }
    294   }
    295 
    296   // Save the list of attr parameters (attrs that won't be inferred),
    297   // those with defaults go at the end.
    298   // Get the attrs in the order we want by taking the attrs without defaults
    299   // from the end of params_no_default_, and adding params_no_default_.
    300   attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() +
    301                  params_with_default_.size());
    302   for (int i = op_def_.input_arg_size(); i < params_no_default_.size(); ++i) {
    303     attrs_.push_back(params_no_default_[i].GetName());
    304   }
    305   for (const auto& p : params_with_default_) {
    306     attrs_.push_back(p.first.GetName());
    307   }
    308 
    309   param_names_.reserve(params_no_default_.size() + params_with_default_.size());
    310   param_names_.insert(param_names_.begin(), params_no_default_.begin(),
    311                       params_no_default_.end());
    312   for (const auto& param_and_default : params_with_default_) {
    313     param_names_.push_back(param_and_default.first);
    314   }
    315 
    316   string parameters;
    317   for (const auto& param : params_no_default_) {
    318     if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    319     strings::StrAppend(&parameters, param.GetRenameTo());
    320   }
    321   for (const auto& param_and_default : params_with_default_) {
    322     if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    323     strings::StrAppend(&parameters, param_and_default.first.GetRenameTo(), "=",
    324                        param_and_default.second);
    325   }
    326   if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    327   strings::StrAppend(&parameters, "name=None");
    328 
    329   // Add attr_expressions_ for attrs that are params.
    330   for (int i = 0; i < attrs_.size(); ++i) {
    331     const string& attr_name = attrs_[i];
    332     const string& attr_api_name =
    333         param_names_[i + op_def_.input_arg_size()].GetRenameTo();
    334     attr_expressions_[attr_name] = attr_api_name;
    335   }
    336   // Add attr_expressions_ for attrs that are inferred.
    337   for (int i = 0; i < op_def_.attr_size(); ++i) {
    338     const auto& attr(op_def_.attr(i));
    339     if (attr.type() == "int") {
    340       auto arg_list = attr_to_args_.find(attr.name());
    341       if (arg_list != attr_to_args_.end()) {
    342         AttrVarName(attr.name(), &attr_expressions_);
    343       }
    344     }
    345   }
    346 
    347   string num_outputs_expr;
    348   std::vector<string> output_sizes(num_outs_);
    349   GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr);
    350 
    351   string eager_not_allowed_error = GetEagerNotAllowedError();
    352 
    353   if (!AddEagerFastPathAndGraphCode(parameters, output_sizes,
    354                                     eager_not_allowed_error)) {
    355     return result_;
    356   }
    357 
    358   if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
    359                             eager_not_allowed_error)) {
    360     return result_;
    361   }
    362 
    363   return prelude_ + result_;
    364 }
    365 
    366 void GenEagerPythonOp::HandleGraphMode(const string& function_setup) {
    367   // Handle graph-mode case
    368   strings::StrAppend(&result_,
    369                      "  _ctx = _context.context()\n"
    370                      "  if _ctx.in_graph_mode():\n",
    371                      function_setup,
    372                      "    _, _, _op = _op_def_lib._apply_op_helper(\n");
    373   AddBodyNoReturn("        ");
    374   if (num_outs_ > 0) {
    375     strings::StrAppend(&result_, "    _result = _op.outputs[:]\n");
    376     // Special case handling for stateful op with single list output
    377     // that might be empty.
    378     if (num_outs_ == 1 && op_def_.is_stateful() &&
    379         (!op_def_.output_arg(0).number_attr().empty() ||
    380          !op_def_.output_arg(0).type_list_attr().empty())) {
    381       // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
    382       // a constraint indicating that this can never be empty.
    383       strings::StrAppend(&result_,
    384                          "    if not _result:\n"
    385                          "      return _op\n");
    386     }
    387     strings::StrAppend(&result_, "    _inputs_flat = _op.inputs\n");
    388 
    389     // Compute graph-mode attrs.
    390     if (op_def_.attr_size() > 0) {
    391       string attr_values;
    392       for (int i = 0; i < op_def_.attr_size(); ++i) {
    393         if (i > 0) strings::StrAppend(&attr_values, ", ");
    394         const auto& attr_name(op_def_.attr(i).name());
    395         strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"",
    396                            attr_name, "\")");
    397       }
    398       strings::StrAppend(&attr_values, ")");
    399       strings::StrAppend(&result_,
    400                          WordWrap("    _attrs = (", attr_values, kRightMargin),
    401                          "\n");
    402     } else {
    403       strings::StrAppend(&result_, "    _attrs = None\n");
    404     }
    405   } else {
    406     strings::StrAppend(&result_, "    return _op\n");
    407   }
    408 }
    409 
    410 string GenEagerPythonOp::GetEagerNotAllowedError() {
    411   bool eager_allowed = true;
    412   string ref_arg;
    413   for (int i = 0; i < op_def_.input_arg_size(); ++i) {
    414     const auto& arg = op_def_.input_arg(i);
    415     if (arg.is_ref()) {
    416       eager_allowed = false;
    417       DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
    418       ref_arg = api_def_.in_arg(i).rename_to();
    419     }
    420   }
    421   for (int i = 0; i < op_def_.output_arg_size(); ++i) {
    422     const auto& arg = op_def_.output_arg(i);
    423     if (arg.is_ref()) {
    424       eager_allowed = false;
    425       DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
    426       ref_arg = api_def_.out_arg(i).rename_to();
    427     }
    428   }
    429 
    430   if (eager_allowed) return "";
    431 
    432   return strings::StrCat("raise RuntimeError(\"", op_name_,
    433                          " op does not support eager execution. ", "Arg '",
    434                          ref_arg, "' is a ref.\")\n");
    435 }
    436 
    437 void GenEagerPythonOp::ExpectListArg(const string& indentation,
    438                                      const string& arg_name, string* output) {
    439   strings::StrAppend(output, indentation, "if not isinstance(", arg_name,
    440                      ", (list, tuple)):\n", indentation, "  raise TypeError(\n",
    441                      indentation, "      \"Expected list for '", arg_name,
    442                      "' argument to \"\n", indentation, "      \"'", op_name_,
    443                      "' Op, not %r.\" % ", arg_name, ")\n");
    444 }
    445 
    446 bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation,
    447                                              string* function_setup) {
    448   // Validate list inputs, infer length attrs.
    449   for (int i = 0; i < op_def_.attr_size(); ++i) {
    450     const auto& attr(op_def_.attr(i));
    451     if (attr.type() == "int") {
    452       auto arg_list = attr_to_args_.find(attr.name());
    453       if (arg_list != attr_to_args_.end()) {
    454         // Inferred int attrs are the lengths of inputs. Validate those
    455         // inputs are lists and have the same length.
    456         for (auto iter = arg_list->second.begin();
    457              iter != arg_list->second.end(); ++iter) {
    458           const string& arg_api_name = param_names_[*iter].GetRenameTo();
    459           ExpectListArg(indentation, arg_api_name, function_setup);
    460           if (iter == arg_list->second.begin()) {
    461             AddInferredAttr(indentation, attr.name(),
    462                             strings::StrCat("len(", arg_api_name, ")"),
    463                             function_setup, &attr_expressions_);
    464           } else {
    465             const auto& attr_var = attr_expressions_[attr.name()];
    466             strings::StrAppend(
    467                 function_setup, indentation, "if len(", arg_api_name,
    468                 ") != ", attr_var, ":\n", indentation, "  raise ValueError(\n",
    469                 indentation, "      \"List argument '", arg_api_name, "' to '",
    470                 op_name_, "' Op with length %d \"\n", indentation,
    471                 "      \"must match length %d of argument '",
    472                 inferred_attrs_[attr.name()], "'.\" %\n", indentation,
    473                 "      (len(", arg_api_name, "), ", attr_var, "))\n");
    474           }
    475         }
    476       }
    477     }
    478   }
    479 
    480   for (int i = 0; i < attrs_.size(); ++i) {
    481     const string& attr_name = attrs_[i];
    482     const auto& param = param_names_[i + op_def_.input_arg_size()];
    483     const auto& attr = *FindAttr(attr_name, op_def_);
    484     const string& attr_api_name = param.GetRenameTo();
    485     StringPiece attr_type = attr.type();
    486     attr_expressions_[attr_name] = attr_api_name;
    487     const int default_index = i - (attrs_.size() - params_with_default_.size());
    488     if (default_index >= 0) {
    489       const string& default_value = params_with_default_[default_index].second;
    490       strings::StrAppend(function_setup, indentation, "if ", attr_api_name,
    491                          " is None:\n");
    492       strings::StrAppend(function_setup, indentation, "  ", attr_api_name,
    493                          " = ", default_value, "\n");
    494     }
    495     if (attr_type.starts_with("list(")) {
    496       ExpectListArg(indentation, attr_api_name, function_setup);
    497     }
    498 
    499     if (attr_type == "string") {
    500       strings::StrAppend(function_setup, indentation, attr_api_name,
    501                          " = _execute.make_str(", attr_api_name, ", \"",
    502                          attr_api_name, "\")\n");
    503     } else if (attr_type == "list(string)") {
    504       strings::StrAppend(function_setup, indentation, attr_api_name,
    505                          " = [_execute.make_str(_s, \"", attr_api_name,
    506                          "\") for _s in ", attr_api_name, "]\n");
    507     } else if (attr_type == "int") {
    508       strings::StrAppend(function_setup, indentation, attr_api_name,
    509                          " = _execute.make_int(", attr_api_name, ", \"",
    510                          attr_api_name, "\")\n");
    511     } else if (attr_type == "list(int)") {
    512       strings::StrAppend(function_setup, indentation, attr_api_name,
    513                          " = [_execute.make_int(_i, \"", attr_api_name,
    514                          "\") for _i in ", attr_api_name, "]\n");
    515     } else if (attr_type == "float") {
    516       strings::StrAppend(function_setup, indentation, attr_api_name,
    517                          " = _execute.make_float(", attr_api_name, ", \"",
    518                          attr_api_name, "\")\n");
    519     } else if (attr_type == "list(float)") {
    520       strings::StrAppend(function_setup, indentation, attr_api_name,
    521                          " = [_execute.make_float(_f, \"", attr_api_name,
    522                          "\") for _f in ", attr_api_name, "]\n");
    523     } else if (attr_type == "bool") {
    524       strings::StrAppend(function_setup, indentation, attr_api_name,
    525                          " = _execute.make_bool(", attr_api_name, ", \"",
    526                          attr_api_name, "\")\n");
    527     } else if (attr_type == "list(bool)") {
    528       strings::StrAppend(function_setup, indentation, attr_api_name,
    529                          " = [_execute.make_bool(_b, \"", attr_api_name,
    530                          "\") for _b in ", attr_api_name, "]\n");
    531     } else if (attr_type == "type") {
    532       strings::StrAppend(function_setup, indentation, attr_api_name,
    533                          " = _execute.make_type(", attr_api_name, ", \"",
    534                          attr_api_name, "\")\n");
    535     } else if (attr_type == "list(type)") {
    536       strings::StrAppend(function_setup, indentation, attr_api_name,
    537                          " = [_execute.make_type(_t, \"", attr_api_name,
    538                          "\") for _t in ", attr_api_name, "]\n");
    539     } else if (attr_type == "shape") {
    540       strings::StrAppend(function_setup, indentation, attr_api_name,
    541                          " = _execute.make_shape(", attr_api_name, ", \"",
    542                          attr_api_name, "\")\n");
    543     } else if (attr_type == "list(shape)") {
    544       strings::StrAppend(function_setup, indentation, attr_api_name,
    545                          " = [_execute.make_shape(_s, \"", attr_api_name,
    546                          "\") for _s in ", attr_api_name, "]\n");
    547     } else if (attr_type == "tensor") {
    548       strings::StrAppend(function_setup, indentation, attr_api_name,
    549                          " = _execute.make_tensor(", attr_api_name, ", \"",
    550                          attr_api_name, "\")\n");
    551     } else if (attr_type == "list(tensor)") {
    552       strings::StrAppend(function_setup, indentation, attr_api_name,
    553                          " = [_execute.make_tensor(_t, \"", attr_api_name,
    554                          "\") for _t in ", attr_api_name, "]\n");
    555     } else if (attr_type != "func") {
    556       *function_setup =
    557           strings::StrCat("# No definition for ", function_name_,
    558                           " since we don't support attrs with type\n"
    559                           "# '",
    560                           attr_type, "' right now.\n\n");
    561       return false;
    562     }
    563   }
    564   return true;
    565 }
    566 
    567 // If output i is list output, output_sizes[i] will be set to a
    568 // string with the python expression that will evaluate to its
    569 // length. output_sizes[i] is empty for non-list outputs.
    570 void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr(
    571     std::vector<string>* output_sizes, string* num_outputs_expr) {
    572   // Expression representing the number of outputs.
    573   int num_fixed_outputs = 0;
    574   for (int i = 0; i < num_outs_; ++i) {
    575     const auto& arg(op_def_.output_arg(i));
    576     if (!arg.number_attr().empty()) {
    577       if (!num_outputs_expr->empty()) {
    578         strings::StrAppend(num_outputs_expr, " + ");
    579       }
    580       (*output_sizes)[i] = attr_expressions_[arg.number_attr()];
    581       strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
    582     } else if (!arg.type_list_attr().empty()) {
    583       if (!num_outputs_expr->empty()) {
    584         strings::StrAppend(num_outputs_expr, " + ");
    585       }
    586       // Have to be careful to use an expression that works in both
    587       // graph and eager paths here.
    588       const auto iter = inferred_attrs_.find(arg.type_list_attr());
    589       if (iter == inferred_attrs_.end()) {
    590         (*output_sizes)[i] = strings::StrCat(
    591             "len(", attr_expressions_[arg.type_list_attr()], ")");
    592       } else {
    593         (*output_sizes)[i] = strings::StrCat("len(", iter->second, ")");
    594       }
    595       strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
    596     } else {
    597       ++num_fixed_outputs;
    598     }
    599   }
    600   if (num_fixed_outputs > 0) {
    601     if (!num_outputs_expr->empty()) {
    602       strings::StrAppend(num_outputs_expr, " + ");
    603     }
    604     strings::StrAppend(num_outputs_expr, num_fixed_outputs);
    605   } else if (num_outputs_expr->empty()) {
    606     *num_outputs_expr = "0";
    607   }
    608 }
    609 
    610 void GenEagerPythonOp::AddEagerFunctionTeardown(
    611     const string& indentation, const std::vector<string>& output_sizes,
    612     bool execute_record_gradient) {
    613   if (num_outs_ > 0) {
    614     if (execute_record_gradient) {
    615       strings::StrAppend(&result_, indentation, "_execute.record_gradient(\n",
    616                          "      \"", op_def_.name(),
    617                          "\", _inputs_flat, _attrs, _result, name)\n");
    618     }
    619     if (num_outs_ == 1 && !output_sizes[0].empty()) {
    620       // Single list result.
    621     } else if (num_outs_ == 1) {
    622       // Execute returns a single-element list which we need to destructure.
    623       strings::StrAppend(&result_, indentation, "_result, = _result\n");
    624     } else {
    625       // Have multiple outputs, so we will need to reformat the return
    626       // value of execute() to be a list with one entry per op output
    627       // (that entry will be a list of tensors if that output is of list
    628       // type).
    629       // For list outputs, convert the right subrange of _result into a list.
    630       Unflatten(indentation, output_sizes, "_result", &result_);
    631       // Convert to a named tuple.
    632       strings::StrAppend(&result_, indentation, "_result = _", op_def_.name(),
    633                          "Output._make(_result)\n");
    634     }
    635   } else {
    636     strings::StrAppend(&result_, indentation, "_result = None\n");
    637   }
    638   strings::StrAppend(&result_, indentation, "return _result\n\n");
    639 }
    640 
    641 bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
    642     const string& parameters, const std::vector<string>& output_sizes,
    643     const string& eager_not_allowed_error) {
    644   AddExport();
    645   AddDefLine(function_name_, parameters);
    646   AddDocStringDescription();
    647   AddDocStringArgs();
    648   AddDocStringInputs();
    649   AddDocStringAttrs();
    650   AddDocStringNameArg();
    651   AddOutputGlobals();  // Added to prelude_
    652   AddDocStringOutputs();
    653   strings::StrAppend(&result_, "  \"\"\"\n");
    654 
    655   // Handle graph-mode case
    656   string function_setup;
    657   if (!GetEagerFunctionSetup("    ", &function_setup)) {
    658     result_ = function_setup;
    659     return false;
    660   }
    661   HandleGraphMode(function_setup);
    662   AddEagerFunctionTeardown("    ", output_sizes,
    663                            true /* execute_record_gradient */);
    664 
    665   // Handle eager-mode case
    666   strings::StrAppend(&result_, "  else:\n");
    667 
    668   if (eager_not_allowed_error.empty()) {
    669     AddEagerFastPathExecute();
    670   } else {
    671     strings::StrAppend(&result_, "    ", eager_not_allowed_error);
    672   }
    673 
    674   strings::StrAppend(&result_, "\n\n");
    675   return true;
    676 }
    677 
    678 bool GenEagerPythonOp::AddEagerFallbackCode(
    679     const string& parameters, const std::vector<string>& output_sizes,
    680     const string& num_outputs_expr, const string& eager_not_allowed_error) {
    681   if (!eager_not_allowed_error.empty()) {
    682     strings::StrAppend(&result_, "  ", eager_not_allowed_error);
    683     return true;
    684   }
    685 
    686   AddDefLine(strings::StrCat(function_name_, kEagerFallbackSuffix), parameters);
    687   strings::StrAppend(
    688       &result_, "  r\"\"\"This is the slowpath function for Eager mode.\n");
    689   strings::StrAppend(&result_, "  This is for function ", function_name_,
    690                      "\n  \"\"\"\n");
    691 
    692   strings::StrAppend(&result_, "  _ctx = _context.context()\n");
    693 
    694   string function_setup;
    695   if (!GetEagerFunctionSetup("  ", &function_setup)) {
    696     result_ = function_setup;
    697     return false;
    698   }
    699   strings::StrAppend(&result_, function_setup);
    700 
    701   AddEagerInferredAttrs("  ");
    702   AddEagerInputCasts("  ");
    703   strings::StrAppend(
    704       &result_, "  _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n");
    705   AddEagerAttrs("  ");
    706   AddEagerExecute("  ", num_outputs_expr);
    707 
    708   AddEagerFunctionTeardown("  ", output_sizes,
    709                            true /* execute_record_gradient */);
    710 
    711   return true;
    712 }
    713 
    714 void GenEagerPythonOp::AddEagerFastPathExecute() {
    715   string fastpath_execute_params = strings::StrCat(
    716       "_ctx._handle, _ctx.device_name, \"", op_def_.name(), "\", ",
    717       "_execute.record_gradient, name, _ctx._post_execution_callbacks");
    718   string fallback_params;
    719 
    720   for (int i = 0; i < api_def_.in_arg_size(); i++) {
    721     const string param_name = param_names_[i].GetRenameTo();
    722     strings::StrAppend(&fastpath_execute_params, ", ", param_name);
    723     if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
    724     strings::StrAppend(&fallback_params, param_name);
    725   }
    726 
    727   for (const auto& attr : api_def_.attr()) {
    728     if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
    729       strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ",
    730                          attr.rename_to());
    731 
    732       if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
    733       strings::StrAppend(&fallback_params, attr.rename_to(), "=",
    734                          attr.rename_to());
    735     }
    736   }
    737 
    738   if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
    739   strings::StrAppend(&fallback_params, "name=name");
    740 
    741   strings::StrAppend(&result_, "    try:\n");
    742   strings::StrAppend(
    743       &result_, "      ",
    744       "_result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n",
    745       WordWrap(strings::StrCat("        "),
    746                strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
    747       "\n");
    748 
    749   if (op_def_.output_arg_size() > 1) {
    750     const string output_tuple_name =
    751         strings::StrCat("_", op_def_.name(), "Output");
    752     strings::StrAppend(&result_, "      ", "_result = ", output_tuple_name,
    753                        "._make(_result)\n");
    754   }
    755   strings::StrAppend(&result_, "      ", "return _result\n");
    756 
    757   // Handle fallback.
    758   strings::StrAppend(&result_, "    ", "except _core._FallbackException:\n");
    759   strings::StrAppend(
    760       &result_, "      ", "return ", function_name_, kEagerFallbackSuffix,
    761       "(\n",
    762       WordWrap(strings::StrCat("          "),
    763                strings::StrCat(fallback_params, ")"), kRightMargin),
    764       "\n");
    765 
    766   // Any errors thrown from execute need to be unwrapped from
    767   // _NotOkStatusException.
    768   strings::StrAppend(&result_, "    ",
    769                      "except _core._NotOkStatusException as e:\n");
    770   strings::StrAppend(&result_, "      ", "if name is not None:\n");
    771   strings::StrAppend(&result_, "        ",
    772                      "message = e.message + \" name: \" + name\n");
    773   strings::StrAppend(&result_, "      ", "else:\n");
    774   strings::StrAppend(&result_, "        ", "message = e.message\n");
    775   strings::StrAppend(
    776       &result_, "      ",
    777       "_six.raise_from(_core._status_to_exception(e.code, message), None)\n");
    778 }
    779 
    780 void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
    781   // Figure out values for inferred attrs, and cast to eager tensors.
    782   for (int i = 0; i < op_def_.attr_size(); ++i) {
    783     const auto& attr(op_def_.attr(i));
    784     const auto& api_def_attr(api_def_.attr(i));
    785     auto arg_list = attr_to_args_.find(attr.name());
    786     if (arg_list != attr_to_args_.end()) {
    787       if (attr.type() == "type") {
    788         std::vector<string> output_sizes;
    789         const string flattened =
    790             FlattenInputs(&arg_list->second, &output_sizes);
    791         string conversion = strings::StrCat("_execute.args_to_matching_eager(",
    792                                             flattened, ", _ctx");
    793         if (attr.has_default_value()) {
    794           strings::StrAppend(
    795               &conversion, ", ",
    796               python_op_gen_internal::AttrValueToPython(
    797                   attr.type(), api_def_attr.default_value(), "_dtypes."));
    798         }
    799         strings::StrAppend(&conversion, ")");
    800         const string var_name = AttrVarName(attr.name(), &attr_expressions_);
    801         if (output_sizes.size() == 1) {
    802           // Avoid creating a temporary variable in the case where
    803           // we can easily assign to the right value directly.
    804           const string inputs_var =
    805               param_names_[arg_list->second.front()].GetRenameTo();
    806           if (output_sizes.front().empty()) {
    807             strings::StrAppend(&result_, indentation, var_name, ", (",
    808                                inputs_var, ",) = ", conversion, "\n");
    809           } else {
    810             strings::StrAppend(&result_, indentation, var_name, ", ",
    811                                inputs_var, " = ", conversion, "\n");
    812           }
    813         } else {
    814           const string inputs_var = strings::StrCat("_inputs_", attr.name());
    815           strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
    816                              " = ", conversion, "\n");
    817           // Convert from a flat list of eager tensors back to the
    818           // parameter variables.
    819           Unflatten(indentation, output_sizes, inputs_var, &result_);
    820           std::vector<string> p;
    821           for (int j : arg_list->second) {
    822             p.emplace_back(param_names_[j].GetRenameTo());
    823           }
    824           strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ",
    825                              inputs_var, "\n");
    826         }
    827       } else if (attr.type() == "list(type)") {
    828         // NOTE: We ignore default values for these attrs, since it is
    829         // unclear how you would use it, and the one use case is
    830         // parse_single_sequence_example which only needs it for
    831         // backwards compatibility.
    832         const string var_name = AttrVarName(attr.name(), &attr_expressions_);
    833         string inputs_var;
    834         string conversion;
    835         if (arg_list->second.size() > 1) {
    836           // If you have more than one list(tensor) argument, their types
    837           // have to match.
    838           std::vector<string> lists;
    839           for (auto iter = arg_list->second.begin();
    840                iter != arg_list->second.end(); ++iter) {
    841             lists.push_back(param_names_[*iter].GetRenameTo());
    842           }
    843           inputs_var = VectorToTuple(lists);
    844           conversion = "_execute.args_to_mixed_eager_tensors";
    845         } else {
    846           // For one list(tensor) argument, we just convert every
    847           // element of the list to an eager tensor.
    848           inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
    849           conversion = "_execute.convert_to_mixed_eager_tensors";
    850         }
    851         strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
    852                            " = ", conversion, "(", inputs_var, ", _ctx)\n");
    853       }
    854     }
    855   }
    856 }
    857 
    858 void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) {
    859   // Cast remaining args to eager tensors
    860   for (int i = 0; i < op_def_.input_arg_size(); ++i) {
    861     const auto& arg(op_def_.input_arg(i));
    862     if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
    863     const string& param = param_names_[i].GetRenameTo();
    864     const string fn = arg.number_attr().empty() ? "" : "n_";
    865     const string dtype =
    866         python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
    867     strings::StrAppend(&result_, indentation, param, " = _ops.convert_", fn,
    868                        "to_tensor(", param, ", ", dtype, ")\n");
    869   }
    870 }
    871 
    872 void GenEagerPythonOp::AddEagerAttrs(const string& indentation) {
    873   // Compute eager attrs
    874   if (op_def_.attr_size() > 0) {
    875     string attr_values;
    876     for (int i = 0; i < op_def_.attr_size(); ++i) {
    877       if (i > 0) strings::StrAppend(&attr_values, ", ");
    878       const auto& attr_name(op_def_.attr(i).name());
    879       strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
    880                          attr_expressions_[attr_name]);
    881     }
    882     strings::StrAppend(&attr_values, ")");
    883     strings::StrAppend(
    884         &result_,
    885         WordWrap(indentation, strings::StrCat("_attrs = (", attr_values),
    886                  kRightMargin),
    887         "\n");
    888   } else {
    889     strings::StrAppend(&result_, indentation, "_attrs = None\n");
    890   }
    891 }
    892 
    893 void GenEagerPythonOp::AddEagerExecute(const string& indentation,
    894                                        const string& num_outputs_expr) {
    895   const string return_prefix =
    896       strings::StrCat(indentation, "_result = _execute.execute(");
    897   const string return_args = strings::StrCat(
    898       "b\"", op_def_.name(), "\", ", num_outputs_expr,
    899       ", inputs=_inputs_flat, attrs=_attrs, ctx=_ctx, name=name)");
    900   strings::StrAppend(&result_,
    901                      // Wrap the arguments, and indent to the (.
    902                      WordWrap(return_prefix, return_args, kRightMargin), "\n");
    903 }
    904 
    905 string GetEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs,
    906                          const std::vector<string>& hidden_ops,
    907                          bool require_shapes,
    908                          const string& source_file_name = "") {
    909   string result;
    910   // Header
    911   // TODO(josh11b): Mention the library for which wrappers are being generated.
    912   strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.
    913 
    914 This file is MACHINE GENERATED! Do not edit.
    915 )");
    916 
    917   // Mention the original source file so someone tracing back through
    918   // generated Python code will know where to look next.
    919   if (!source_file_name.empty()) {
    920     strings::StrAppend(&result, "Original C++ source file: ");
    921     strings::StrAppend(&result, source_file_name);
    922     strings::StrAppend(&result, "\n");
    923   }
    924 
    925   strings::StrAppend(&result, R"("""
    926 
    927 import collections as _collections
    928 import six as _six
    929 
    930 from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
    931 from tensorflow.python.eager import context as _context
    932 from tensorflow.python.eager import core as _core
    933 from tensorflow.python.eager import execute as _execute
    934 from tensorflow.python.framework import dtypes as _dtypes
    935 from tensorflow.python.framework import errors as _errors
    936 from tensorflow.python.framework import tensor_shape as _tensor_shape
    937 
    938 from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
    939 # Needed to trigger the call to _set_call_cpp_shape_fn.
    940 from tensorflow.python.framework import common_shapes as _common_shapes
    941 from tensorflow.python.framework import op_def_registry as _op_def_registry
    942 from tensorflow.python.framework import ops as _ops
    943 from tensorflow.python.framework import op_def_library as _op_def_library
    944 from tensorflow.python.util.tf_export import tf_export
    945 
    946 )");
    947 
    948   // We'll make a copy of ops that filters out descriptions.
    949   OpList cleaned_ops;
    950   auto out = cleaned_ops.mutable_op();
    951   out->Reserve(ops.op_size());
    952   for (const auto& op_def : ops.op()) {
    953     const auto* api_def = api_defs.GetApiDef(op_def.name());
    954 
    955     if (api_def->visibility() == ApiDef::SKIP) {
    956       continue;
    957     }
    958 
    959     // An op is hidden if either its ApiDef visibility is HIDDEN
    960     // or it is in the hidden_ops list.
    961     bool is_hidden = api_def->visibility() == ApiDef::HIDDEN;
    962     if (!is_hidden) {
    963       for (const string& hidden : hidden_ops) {
    964         if (op_def.name() == hidden) {
    965           is_hidden = true;
    966           break;
    967         }
    968       }
    969     }
    970 
    971     string function_name;
    972     python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
    973                                                     &function_name);
    974     if (is_hidden) function_name = strings::StrCat("_", function_name);
    975 
    976     // When users create custom python wrappers, they may link in the
    977     // default op registry by accident, and because they can't
    978     // enumerate all 'hidden' symbols, this guard is to prevent
    979     // instantiating a python reserved word in their wrapper.
    980     if (python_op_gen_internal::IsPythonReserved(function_name)) {
    981       continue;
    982     }
    983 
    984     strings::StrAppend(&result,
    985                        GetEagerPythonOp(op_def, *api_def, function_name));
    986 
    987     if (!require_shapes) {
    988       strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
    989                          "\")(None)\n\n");
    990     }
    991 
    992     auto added = out->Add();
    993     *added = op_def;
    994     RemoveNonDeprecationDescriptionsFromOpDef(added);
    995   }
    996 
    997   result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
    998   op_list = _op_def_pb2.OpList()
    999   op_list.ParseFromString(op_list_proto_bytes)
   1000   _op_def_registry.register_op_list(op_list)
   1001   op_def_lib = _op_def_library.OpDefLibrary()
   1002   op_def_lib.add_op_list(op_list)
   1003   return op_def_lib
   1004 )");
   1005 
   1006   result.append("# ");
   1007   auto ops_text = ProtoDebugString(cleaned_ops);
   1008   str_util::StripTrailingWhitespace(&ops_text);
   1009   result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true));
   1010   result.append("\n");
   1011   strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n",
   1012                    str_util::CEscape(cleaned_ops.SerializeAsString()).c_str());
   1013   return result;
   1014 }
   1015 
   1016 }  // namespace
   1017 
   1018 void PrintEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs,
   1019                          const std::vector<string>& hidden_ops,
   1020                          bool require_shapes, const string& source_file_name) {
   1021   printf("%s", GetEagerPythonOps(ops, api_defs, hidden_ops, require_shapes,
   1022                                  source_file_name)
   1023                    .c_str());
   1024 }
   1025 
   1026 string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) {
   1027   string op_list_str(op_list_buf, op_list_len);
   1028   OpList ops;
   1029   ops.ParseFromString(op_list_str);
   1030 
   1031   ApiDefMap api_def_map(ops);
   1032   return GetEagerPythonOps(ops, api_def_map, {}, false);
   1033 }
   1034 
   1035 }  // namespace tensorflow
   1036