/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/eager/python_eager_op_gen.h"

#include <stdio.h>
#include <sstream>
#include <unordered_map>
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/framework/python_op_gen_internal.h"

namespace tensorflow {
namespace {

const int kRightMargin = 78;

constexpr char kEagerFallbackSuffix[] = "_eager_fallback";

string AttrVarName(const string& attr_name,
                   std::unordered_map<string, string>* attr_expressions) {
  const string var = strings::StrCat("_attr_", attr_name);
  if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var;
  return var;
}

void AddInferredAttr(const string& indentation, const string& attr_name,
                     const string& value_expression, string* result,
                     std::unordered_map<string, string>* attr_expressions) {
  strings::StrAppend(result, indentation,
                     AttrVarName(attr_name, attr_expressions), " = ",
                     value_expression, "\n");
}

string VectorToTuple(const std::vector<string>& l) {
  if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
  string ret = "(";
  for (int i = 0; i < l.size(); ++i) {
    if (i > 0) {
      strings::StrAppend(&ret, ", ");
    }
    strings::StrAppend(&ret, l[i]);
  }
  strings::StrAppend(&ret, ")");
  return ret;
}

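// Sketch of the Python that Unflatten below emits (names are hypothetical):
// with var "_result" and output_sizes = {"", "_attr_N", ""}, i.e. the second
// of three outputs is a list of length _attr_N, the emitted assignment is
// roughly
//   _result = _result[:1] + [_result[1:1 + _attr_N]] + _result[1 + _attr_N:]
// so the flat result list ends up with one entry per op output.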
void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
               const string& var, string* result) {
  for (int i = 0; i < output_sizes.size(); ++i) {
    if (!output_sizes[i].empty()) {
      strings::StrAppend(result, prefix, var, " = ");
      if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
      if (i + 1 < output_sizes.size()) {
        // Special case i == 0 to avoid "0 +" in the generated code.
        if (i == 0) {
          strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
                             var, "[", output_sizes[i], ":]");
        } else {
          strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
                             output_sizes[i], "]] + ", var, "[", i, " + ",
                             output_sizes[i], ":]");
        }
      } else {
        strings::StrAppend(result, "[", var, "[", i, ":]]");
      }
      strings::StrAppend(result, "\n");
    }
  }
}

string TensorPBString(const TensorProto& pb) {
  // Note: This gets used in the argument list, and so must survive naive
  // word wrapping.
  return strings::StrCat("\"\"\"", ProtoShortDebugString(pb), "\"\"\"");
}

const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
  for (int i = 0; i < api_def.in_arg_size(); ++i) {
    if (api_def.in_arg(i).name() == name) {
      return &api_def.in_arg(i);
    }
  }
  return nullptr;
}

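// Generates the Python wrapper for a single op. The emitted code contains a
// public function with a graph-mode branch and an eager fast-path branch,
// plus a companion "<name>_eager_fallback" function (kEagerFallbackSuffix)
// used when the fast path raises _FallbackException. Roughly, for a
// hypothetical op, the output looks like:
//
//   def my_op(x, axis, name=None):
//     _ctx = _context.context()
//     if _ctx.in_graph_mode():
//       ...  # _op_def_lib._apply_op_helper(...)
//     else:
//       ...  # _pywrap_tensorflow.TFE_Py_FastPathExecute(...)
//
//   def my_op_eager_fallback(x, axis, name=None):
//     ...  # _execute.execute(...)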
class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
 public:
  GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                   const string& function_name)
      : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) {
    op_name_ = function_name_;
    op_name_.Consume("_");
  }
  ~GenEagerPythonOp() override {}

  string Code() override;

 protected:
  void HandleGraphMode(const string& function_setup);

  string GetEagerNotAllowedError();
  void ExpectListArg(const string& indentation, const string& arg_name,
                     string* output);
  bool GetEagerFunctionSetup(const string& indentation,
                             string* function_setup);
  void GetOutputSizesAndNumOutputsExpr(std::vector<string>* output_sizes,
                                       string* num_outputs_expr);

  void AddEagerFunctionTeardown(const string& indentation,
                                const std::vector<string>& output_sizes,
                                bool execute_record_gradient);

  bool AddEagerFastPathAndGraphCode(const string& parameters,
                                    const std::vector<string>& output_sizes,
                                    const string& eager_not_allowed_error);
  bool AddEagerFallbackCode(const string& parameters,
                            const std::vector<string>& output_sizes,
                            const string& num_outputs_expr,
                            const string& eager_not_allowed_error);
  void AddEagerFastPathExecute();

  void AddEagerInferredAttrs(const string& indentation);
  void AddEagerInputCasts(const string& indentation);
  void AddEagerAttrs(const string& indentation);
  void AddEagerExecute(const string& indentation,
                       const string& num_outputs_expr);

  void AddAttrForArg(const string& attr, int arg_index) {
    gtl::InsertIfNotPresent(&inferred_attrs_, attr,
                            op_def_.input_arg(arg_index).name());
    auto iter = attr_to_args_.find(attr);
    if (iter == attr_to_args_.end()) {
      attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
    } else {
      iter->second.push_back(arg_index);
    }
  }

  // Returns a string expression representing a flattened list of all
  // the inputs given by `*input_indices` (or all inputs if
  // `input_indices` is nullptr). `*output_sizes` can be used to unflatten.
  string FlattenInputs(const std::vector<int>* input_indices,
                       std::vector<string>* output_sizes) const;

  StringPiece op_name_;
  typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
  AttrToArgMap attr_to_args_;
  std::unordered_map<string, string> attr_expressions_;
  // This has all the input args followed by those attrs that don't have
  // defaults.
  std::vector<python_op_gen_internal::ParamNames> params_no_default_;
  // The parameters with defaults (these have to be listed after those
  // without). No input args are included, just attrs.
  std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
      params_with_default_;
};

string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                        const string& function_name) {
  return GenEagerPythonOp(op_def, api_def, function_name).Code();
}

string GenEagerPythonOp::FlattenInputs(
    const std::vector<int>* input_indices,
    std::vector<string>* output_sizes) const {
  string inputs;
  enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
  const int n = input_indices != nullptr ? input_indices->size()
                                         : op_def_.input_arg_size();
  for (int j = 0; j < n; ++j) {
    const int i = input_indices ? (*input_indices)[j] : j;
    const auto& arg(op_def_.input_arg(i));
    const bool is_list =
        !arg.type_list_attr().empty() || !arg.number_attr().empty();
    if (is_list) {
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, "] + ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + ");
      }
      strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
      inputs_state = WAS_LIST_INPUT;
      if (output_sizes != nullptr) {
        if (!arg.number_attr().empty()) {
          output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
        } else {
          output_sizes->emplace_back(
              strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
        }
      }
    } else {
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, ", ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + [");
      } else {
        strings::StrAppend(&inputs, "[");
      }
      strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
      inputs_state = WAS_SOLO_INPUT;
      if (output_sizes != nullptr) output_sizes->emplace_back();
    }
  }
  if (inputs_state == STARTING) return "[]";
  if (inputs_state == WAS_SOLO_INPUT) {
    strings::StrAppend(&inputs, "]");
  }
  return inputs;
}

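// For example (hypothetical names), given a plain input renamed to `x`
// followed by a list input `values` whose length is carried by attr `N`,
// FlattenInputs above returns the expression "[x] + list(values)" and fills
// *output_sizes with {"", "_attr_N"}.
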
string GenEagerPythonOp::Code() {
  if (api_def_.visibility() == ApiDef::SKIP) {
    return "";
  }

  for (int i = 0; i < api_def_.arg_order_size(); ++i) {
    const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
    const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
    params_no_default_.emplace_back(api_def_arg.name(),
                                    api_def_arg.rename_to());
    if (!arg.type_attr().empty()) {
      AddAttrForArg(arg.type_attr(), i);
    } else if (!arg.type_list_attr().empty()) {
      AddAttrForArg(arg.type_list_attr(), i);
    }
    if (!arg.number_attr().empty()) {
      AddAttrForArg(arg.number_attr(), i);
    }
  }
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    const auto& api_def_attr(api_def_.attr(i));
    // Do not add inferred attrs to the Python function signature.
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      if (api_def_attr.has_default_value()) {
        if (attr.type() == "tensor") {
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat(
                  "_execute.make_tensor(",
                  TensorPBString(api_def_attr.default_value().tensor()), ", \"",
                  api_def_attr.rename_to(), "\")"));
        } else if (attr.type() == "list(tensor)") {
          std::vector<string> pbtxt;
          for (const auto& pb : api_def_attr.default_value().list().tensor()) {
            pbtxt.emplace_back(TensorPBString(pb));
          }
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat("[_execute.make_tensor(_pb, \"",
                              api_def_attr.rename_to(), "\") for _pb in ",
                              VectorToTuple(pbtxt), "]"));
        } else {
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
      } else {
        params_no_default_.emplace_back(api_def_attr.name(),
                                        api_def_attr.rename_to());
      }
    }
  }

  // Save the list of attr parameters (attrs that won't be inferred);
  // those with defaults go at the end.
  // Get the attrs in the order we want by taking the attrs without defaults
  // from the end of params_no_default_, and then adding params_with_default_.
  attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() +
                 params_with_default_.size());
  for (int i = op_def_.input_arg_size(); i < params_no_default_.size(); ++i) {
    attrs_.push_back(params_no_default_[i].GetName());
  }
  for (const auto& p : params_with_default_) {
    attrs_.push_back(p.first.GetName());
  }

  param_names_.reserve(params_no_default_.size() +
                       params_with_default_.size());
  param_names_.insert(param_names_.begin(), params_no_default_.begin(),
                      params_no_default_.end());
  for (const auto& param_and_default : params_with_default_) {
    param_names_.push_back(param_and_default.first);
  }

  string parameters;
  for (const auto& param : params_no_default_) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    strings::StrAppend(&parameters, param.GetRenameTo());
  }
  for (const auto& param_and_default : params_with_default_) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    strings::StrAppend(&parameters, param_and_default.first.GetRenameTo(), "=",
                       param_and_default.second);
  }
  if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
  strings::StrAppend(&parameters, "name=None");

  // Add attr_expressions_ for attrs that are params.
  for (int i = 0; i < attrs_.size(); ++i) {
    const string& attr_name = attrs_[i];
    const string& attr_api_name =
        param_names_[i + op_def_.input_arg_size()].GetRenameTo();
    attr_expressions_[attr_name] = attr_api_name;
  }
  // Add attr_expressions_ for attrs that are inferred.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    if (attr.type() == "int") {
      auto arg_list = attr_to_args_.find(attr.name());
      if (arg_list != attr_to_args_.end()) {
        AttrVarName(attr.name(), &attr_expressions_);
      }
    }
  }

  string num_outputs_expr;
  std::vector<string> output_sizes(num_outs_);
  GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr);

  string eager_not_allowed_error = GetEagerNotAllowedError();

  if (!AddEagerFastPathAndGraphCode(parameters, output_sizes,
                                    eager_not_allowed_error)) {
    return result_;
  }

  if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
                            eager_not_allowed_error)) {
    return result_;
  }

  return prelude_ + result_;
}

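// HandleGraphMode below emits the graph-mode branch of the wrapper; for a
// hypothetical op with a single type attr "T", the generated code is roughly:
//   _ctx = _context.context()
//   if _ctx.in_graph_mode():
//     <function_setup>
//     _, _, _op = _op_def_lib._apply_op_helper(...)
//     _result = _op.outputs[:]
//     _inputs_flat = _op.inputs
//     _attrs = ("T", _op.get_attr("T"))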
void GenEagerPythonOp::HandleGraphMode(const string& function_setup) {
  // Handle graph-mode case
  strings::StrAppend(&result_,
                     "  _ctx = _context.context()\n"
                     "  if _ctx.in_graph_mode():\n",
                     function_setup,
                     "    _, _, _op = _op_def_lib._apply_op_helper(\n");
  AddBodyNoReturn("        ");
  if (num_outs_ > 0) {
    strings::StrAppend(&result_, "    _result = _op.outputs[:]\n");
    // Special case handling for stateful op with single list output
    // that might be empty.
    if (num_outs_ == 1 && op_def_.is_stateful() &&
        (!op_def_.output_arg(0).number_attr().empty() ||
         !op_def_.output_arg(0).type_list_attr().empty())) {
      // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
      // a constraint indicating that this can never be empty.
      strings::StrAppend(&result_,
                         "    if not _result:\n"
                         "      return _op\n");
    }
    strings::StrAppend(&result_, "    _inputs_flat = _op.inputs\n");

    // Compute graph-mode attrs.
    if (op_def_.attr_size() > 0) {
      string attr_values;
      for (int i = 0; i < op_def_.attr_size(); ++i) {
        if (i > 0) strings::StrAppend(&attr_values, ", ");
        const auto& attr_name(op_def_.attr(i).name());
        strings::StrAppend(&attr_values, "\"", attr_name, "\", _op.get_attr(\"",
                           attr_name, "\")");
      }
      strings::StrAppend(&attr_values, ")");
      strings::StrAppend(&result_,
                         WordWrap("    _attrs = (", attr_values, kRightMargin),
                         "\n");
    } else {
      strings::StrAppend(&result_, "    _attrs = None\n");
    }
  } else {
    strings::StrAppend(&result_, "    return _op\n");
  }
}

string GenEagerPythonOp::GetEagerNotAllowedError() {
  bool eager_allowed = true;
  string ref_arg;
  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
    const auto& arg = op_def_.input_arg(i);
    if (arg.is_ref()) {
      eager_allowed = false;
      DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
      ref_arg = api_def_.in_arg(i).rename_to();
    }
  }
  for (int i = 0; i < op_def_.output_arg_size(); ++i) {
    const auto& arg = op_def_.output_arg(i);
    if (arg.is_ref()) {
      eager_allowed = false;
      DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
      ref_arg = api_def_.out_arg(i).rename_to();
    }
  }

  if (eager_allowed) return "";

  return strings::StrCat("raise RuntimeError(\"", op_name_,
                         " op does not support eager execution. ", "Arg '",
                         ref_arg, "' is a ref.\")\n");
}

void GenEagerPythonOp::ExpectListArg(const string& indentation,
                                     const string& arg_name, string* output) {
  strings::StrAppend(output, indentation, "if not isinstance(", arg_name,
                     ", (list, tuple)):\n", indentation, "  raise TypeError(\n",
                     indentation, "      \"Expected list for '", arg_name,
                     "' argument to \"\n", indentation, "      \"'", op_name_,
                     "' Op, not %r.\" % ", arg_name, ")\n");
}

bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation,
                                             string* function_setup) {
  // Validate list inputs, infer length attrs.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    if (attr.type() == "int") {
      auto arg_list = attr_to_args_.find(attr.name());
      if (arg_list != attr_to_args_.end()) {
        // Inferred int attrs are the lengths of inputs. Validate those
        // inputs are lists and have the same length.
        for (auto iter = arg_list->second.begin();
             iter != arg_list->second.end(); ++iter) {
          const string& arg_api_name = param_names_[*iter].GetRenameTo();
          ExpectListArg(indentation, arg_api_name, function_setup);
          if (iter == arg_list->second.begin()) {
            AddInferredAttr(indentation, attr.name(),
                            strings::StrCat("len(", arg_api_name, ")"),
                            function_setup, &attr_expressions_);
          } else {
            const auto& attr_var = attr_expressions_[attr.name()];
            strings::StrAppend(
                function_setup, indentation, "if len(", arg_api_name,
                ") != ", attr_var, ":\n", indentation, "  raise ValueError(\n",
                indentation, "      \"List argument '", arg_api_name, "' to '",
                op_name_, "' Op with length %d \"\n", indentation,
                "      \"must match length %d of argument '",
                inferred_attrs_[attr.name()], "'.\" %\n", indentation,
                "      (len(", arg_api_name, "), ", attr_var, "))\n");
          }
        }
      }
    }
  }

  for (int i = 0; i < attrs_.size(); ++i) {
    const string& attr_name = attrs_[i];
    const auto& param = param_names_[i + op_def_.input_arg_size()];
    const auto& attr = *FindAttr(attr_name, op_def_);
    const string& attr_api_name = param.GetRenameTo();
    StringPiece attr_type = attr.type();
    attr_expressions_[attr_name] = attr_api_name;
    const int default_index =
        i - (attrs_.size() - params_with_default_.size());
    if (default_index >= 0) {
      const string& default_value = params_with_default_[default_index].second;
      strings::StrAppend(function_setup, indentation, "if ", attr_api_name,
                         " is None:\n");
      strings::StrAppend(function_setup, indentation, "  ", attr_api_name,
                         " = ", default_value, "\n");
    }
    if (attr_type.starts_with("list(")) {
      ExpectListArg(indentation, attr_api_name, function_setup);
    }

    if (attr_type == "string") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_str(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(string)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_str(_s, \"", attr_api_name,
                         "\") for _s in ", attr_api_name, "]\n");
    } else if (attr_type == "int") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_int(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(int)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_int(_i, \"", attr_api_name,
                         "\") for _i in ", attr_api_name, "]\n");
    } else if (attr_type == "float") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_float(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(float)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_float(_f, \"", attr_api_name,
                         "\") for _f in ", attr_api_name, "]\n");
    } else if (attr_type == "bool") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_bool(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(bool)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_bool(_b, \"", attr_api_name,
                         "\") for _b in ", attr_api_name, "]\n");
    } else if (attr_type == "type") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_type(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(type)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_type(_t, \"", attr_api_name,
                         "\") for _t in ", attr_api_name, "]\n");
    } else if (attr_type == "shape") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_shape(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(shape)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_shape(_s, \"", attr_api_name,
                         "\") for _s in ", attr_api_name, "]\n");
    } else if (attr_type == "tensor") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = _execute.make_tensor(", attr_api_name, ", \"",
                         attr_api_name, "\")\n");
    } else if (attr_type == "list(tensor)") {
      strings::StrAppend(function_setup, indentation, attr_api_name,
                         " = [_execute.make_tensor(_t, \"", attr_api_name,
                         "\") for _t in ", attr_api_name, "]\n");
    } else if (attr_type != "func") {
      *function_setup =
          strings::StrCat("# No definition for ", function_name_,
                          " since we don't support attrs with type\n"
                          "# '",
                          attr_type, "' right now.\n\n");
      return false;
    }
  }
  return true;
}

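// The setup emitted by GetEagerFunctionSetup above validates list arguments
// and canonicalizes attr parameters. For a hypothetical list input `values`
// driving an inferred length attr `N`, plus an `axis` attr of type "int"
// with a default of 0, the generated setup is roughly:
//   if not isinstance(values, (list, tuple)):
//     raise TypeError(...)
//   _attr_N = len(values)
//   if axis is None:
//     axis = 0
//   axis = _execute.make_int(axis, "axis")
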
// If output i is a list output, output_sizes[i] will be set to a
// string with the python expression that will evaluate to its
// length. output_sizes[i] is empty for non-list outputs.
void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr(
    std::vector<string>* output_sizes, string* num_outputs_expr) {
  // Expression representing the number of outputs.
  int num_fixed_outputs = 0;
  for (int i = 0; i < num_outs_; ++i) {
    const auto& arg(op_def_.output_arg(i));
    if (!arg.number_attr().empty()) {
      if (!num_outputs_expr->empty()) {
        strings::StrAppend(num_outputs_expr, " + ");
      }
      (*output_sizes)[i] = attr_expressions_[arg.number_attr()];
      strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
    } else if (!arg.type_list_attr().empty()) {
      if (!num_outputs_expr->empty()) {
        strings::StrAppend(num_outputs_expr, " + ");
      }
      // Have to be careful to use an expression that works in both
      // graph and eager paths here.
      const auto iter = inferred_attrs_.find(arg.type_list_attr());
      if (iter == inferred_attrs_.end()) {
        (*output_sizes)[i] = strings::StrCat(
            "len(", attr_expressions_[arg.type_list_attr()], ")");
      } else {
        (*output_sizes)[i] = strings::StrCat("len(", iter->second, ")");
      }
      strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
    } else {
      ++num_fixed_outputs;
    }
  }
  if (num_fixed_outputs > 0) {
    if (!num_outputs_expr->empty()) {
      strings::StrAppend(num_outputs_expr, " + ");
    }
    strings::StrAppend(num_outputs_expr, num_fixed_outputs);
  } else if (num_outputs_expr->empty()) {
    *num_outputs_expr = "0";
  }
}

void GenEagerPythonOp::AddEagerFunctionTeardown(
    const string& indentation, const std::vector<string>& output_sizes,
    bool execute_record_gradient) {
  if (num_outs_ > 0) {
    if (execute_record_gradient) {
      strings::StrAppend(&result_, indentation, "_execute.record_gradient(\n",
                         "      \"", op_def_.name(),
                         "\", _inputs_flat, _attrs, _result, name)\n");
    }
    if (num_outs_ == 1 && !output_sizes[0].empty()) {
      // Single list result.
    } else if (num_outs_ == 1) {
      // Execute returns a single-element list which we need to destructure.
      strings::StrAppend(&result_, indentation, "_result, = _result\n");
    } else {
      // Have multiple outputs, so we will need to reformat the return
      // value of execute() to be a list with one entry per op output
      // (that entry will be a list of tensors if that output is of list
      // type).
      // For list outputs, convert the right subrange of _result into a list.
      Unflatten(indentation, output_sizes, "_result", &result_);
      // Convert to a named tuple.
      strings::StrAppend(&result_, indentation, "_result = _", op_def_.name(),
                         "Output._make(_result)\n");
    }
  } else {
    strings::StrAppend(&result_, indentation, "_result = None\n");
  }
  strings::StrAppend(&result_, indentation, "return _result\n\n");
}

bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
    const string& parameters, const std::vector<string>& output_sizes,
    const string& eager_not_allowed_error) {
  AddExport();
  AddDefLine(function_name_, parameters);
  AddDocStringDescription();
  AddDocStringArgs();
  AddDocStringInputs();
  AddDocStringAttrs();
  AddDocStringNameArg();
  AddOutputGlobals();  // Added to prelude_
  AddDocStringOutputs();
  strings::StrAppend(&result_, "  \"\"\"\n");

  // Handle graph-mode case
  string function_setup;
  if (!GetEagerFunctionSetup("    ", &function_setup)) {
    result_ = function_setup;
    return false;
  }
  HandleGraphMode(function_setup);
  AddEagerFunctionTeardown("    ", output_sizes,
                           true /* execute_record_gradient */);

  // Handle eager-mode case
  strings::StrAppend(&result_, "  else:\n");

  if (eager_not_allowed_error.empty()) {
    AddEagerFastPathExecute();
  } else {
    strings::StrAppend(&result_, "    ", eager_not_allowed_error);
  }

  strings::StrAppend(&result_, "\n\n");
  return true;
}

bool GenEagerPythonOp::AddEagerFallbackCode(
    const string& parameters, const std::vector<string>& output_sizes,
    const string& num_outputs_expr, const string& eager_not_allowed_error) {
  if (!eager_not_allowed_error.empty()) {
    strings::StrAppend(&result_, "  ", eager_not_allowed_error);
    return true;
  }

  AddDefLine(strings::StrCat(function_name_, kEagerFallbackSuffix), parameters);
  strings::StrAppend(
      &result_, "  r\"\"\"This is the slowpath function for Eager mode.\n");
  strings::StrAppend(&result_, "  This is for function ", function_name_,
                     "\n  \"\"\"\n");

  strings::StrAppend(&result_, "  _ctx = _context.context()\n");

  string function_setup;
  if (!GetEagerFunctionSetup("  ", &function_setup)) {
    result_ = function_setup;
    return false;
  }
  strings::StrAppend(&result_, function_setup);

  AddEagerInferredAttrs("  ");
  AddEagerInputCasts("  ");
  strings::StrAppend(
      &result_, "  _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n");
  AddEagerAttrs("  ");
  AddEagerExecute("  ", num_outputs_expr);

  AddEagerFunctionTeardown("  ", output_sizes,
                           true /* execute_record_gradient */);

  return true;
}

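// AddEagerFastPathExecute below emits the eager branch of the wrapper,
// roughly (hypothetical op name `MyOp` and parameters):
//   try:
//     _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(
//       _ctx._handle, _ctx.device_name, "MyOp", _execute.record_gradient,
//       name, _ctx._post_execution_callbacks, x, "axis", axis)
//     return _result
//   except _core._FallbackException:
//     return my_op_eager_fallback(x, axis=axis, name=name)
//   except _core._NotOkStatusException as e:
//     ...  # re-raised via _six.raise_from(_core._status_to_exception(...))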
void GenEagerPythonOp::AddEagerFastPathExecute() {
  string fastpath_execute_params = strings::StrCat(
      "_ctx._handle, _ctx.device_name, \"", op_def_.name(), "\", ",
      "_execute.record_gradient, name, _ctx._post_execution_callbacks");
  string fallback_params;

  for (int i = 0; i < api_def_.in_arg_size(); i++) {
    const string param_name = param_names_[i].GetRenameTo();
    strings::StrAppend(&fastpath_execute_params, ", ", param_name);
    if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
    strings::StrAppend(&fallback_params, param_name);
  }

  for (const auto& attr : api_def_.attr()) {
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ",
                         attr.rename_to());

      if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
      strings::StrAppend(&fallback_params, attr.rename_to(), "=",
                         attr.rename_to());
    }
  }

  if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
  strings::StrAppend(&fallback_params, "name=name");

  strings::StrAppend(&result_, "    try:\n");
  strings::StrAppend(
      &result_, "      ",
      "_result = _pywrap_tensorflow.TFE_Py_FastPathExecute(\n",
      WordWrap(strings::StrCat("        "),
               strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
      "\n");

  if (op_def_.output_arg_size() > 1) {
    const string output_tuple_name =
        strings::StrCat("_", op_def_.name(), "Output");
    strings::StrAppend(&result_, "      ", "_result = ", output_tuple_name,
                       "._make(_result)\n");
  }
  strings::StrAppend(&result_, "      ", "return _result\n");

  // Handle fallback.
  strings::StrAppend(&result_, "    ", "except _core._FallbackException:\n");
  strings::StrAppend(
      &result_, "      ", "return ", function_name_, kEagerFallbackSuffix,
      "(\n",
      WordWrap(strings::StrCat("          "),
               strings::StrCat(fallback_params, ")"), kRightMargin),
      "\n");

  // Any errors thrown from execute need to be unwrapped from
  // _NotOkStatusException.
  strings::StrAppend(&result_, "    ",
                     "except _core._NotOkStatusException as e:\n");
  strings::StrAppend(&result_, "      ", "if name is not None:\n");
  strings::StrAppend(&result_, "        ",
                     "message = e.message + \" name: \" + name\n");
  strings::StrAppend(&result_, "      ", "else:\n");
  strings::StrAppend(&result_, "        ", "message = e.message\n");
  strings::StrAppend(
      &result_, "      ",
      "_six.raise_from(_core._status_to_exception(e.code, message), None)\n");
}

void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
  // Figure out values for inferred attrs, and cast to eager tensors.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    const auto& api_def_attr(api_def_.attr(i));
    auto arg_list = attr_to_args_.find(attr.name());
    if (arg_list != attr_to_args_.end()) {
      if (attr.type() == "type") {
        std::vector<string> output_sizes;
        const string flattened =
            FlattenInputs(&arg_list->second, &output_sizes);
        string conversion = strings::StrCat("_execute.args_to_matching_eager(",
                                            flattened, ", _ctx");
        if (attr.has_default_value()) {
          strings::StrAppend(
              &conversion, ", ",
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
        strings::StrAppend(&conversion, ")");
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        if (output_sizes.size() == 1) {
          // Avoid creating a temporary variable in the case where
          // we can easily assign to the right value directly.
          const string inputs_var =
              param_names_[arg_list->second.front()].GetRenameTo();
          if (output_sizes.front().empty()) {
            strings::StrAppend(&result_, indentation, var_name, ", (",
                               inputs_var, ",) = ", conversion, "\n");
          } else {
            strings::StrAppend(&result_, indentation, var_name, ", ",
                               inputs_var, " = ", conversion, "\n");
          }
        } else {
          const string inputs_var = strings::StrCat("_inputs_", attr.name());
          strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
                             " = ", conversion, "\n");
          // Convert from a flat list of eager tensors back to the
          // parameter variables.
          Unflatten(indentation, output_sizes, inputs_var, &result_);
          std::vector<string> p;
          for (int j : arg_list->second) {
            p.emplace_back(param_names_[j].GetRenameTo());
          }
          strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ",
                             inputs_var, "\n");
        }
      } else if (attr.type() == "list(type)") {
        // NOTE: We ignore default values for these attrs, since it is
        // unclear how you would use it, and the one use case is
        // parse_single_sequence_example which only needs it for
        // backwards compatibility.
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        string inputs_var;
        string conversion;
        if (arg_list->second.size() > 1) {
          // If you have more than one list(tensor) argument, their types
          // have to match.
          std::vector<string> lists;
          for (auto iter = arg_list->second.begin();
               iter != arg_list->second.end(); ++iter) {
            lists.push_back(param_names_[*iter].GetRenameTo());
          }
          inputs_var = VectorToTuple(lists);
          conversion = "_execute.args_to_mixed_eager_tensors";
        } else {
          // For one list(tensor) argument, we just convert every
          // element of the list to an eager tensor.
          inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
          conversion = "_execute.convert_to_mixed_eager_tensors";
        }
        strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
                           " = ", conversion, "(", inputs_var, ", _ctx)\n");
      }
    }
  }
}

void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) {
  // Cast remaining args to eager tensors
  for (int i = 0; i < op_def_.input_arg_size(); ++i) {
    const auto& arg(op_def_.input_arg(i));
    if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
    const string& param = param_names_[i].GetRenameTo();
    const string fn = arg.number_attr().empty() ? "" : "n_";
    const string dtype =
        python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
    strings::StrAppend(&result_, indentation, param, " = _ops.convert_", fn,
                       "to_tensor(", param, ", ", dtype, ")\n");
  }
}

void GenEagerPythonOp::AddEagerAttrs(const string& indentation) {
  // Compute eager attrs
  if (op_def_.attr_size() > 0) {
    string attr_values;
    for (int i = 0; i < op_def_.attr_size(); ++i) {
      if (i > 0) strings::StrAppend(&attr_values, ", ");
      const auto& attr_name(op_def_.attr(i).name());
      strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
                         attr_expressions_[attr_name]);
    }
    strings::StrAppend(&attr_values, ")");
    strings::StrAppend(
        &result_,
        WordWrap(indentation, strings::StrCat("_attrs = (", attr_values),
                 kRightMargin),
        "\n");
  } else {
    strings::StrAppend(&result_, indentation, "_attrs = None\n");
  }
}

void GenEagerPythonOp::AddEagerExecute(const string& indentation,
                                       const string& num_outputs_expr) {
  const string return_prefix =
      strings::StrCat(indentation, "_result = _execute.execute(");
  const string return_args = strings::StrCat(
      "b\"", op_def_.name(), "\", ", num_outputs_expr,
      ", inputs=_inputs_flat, attrs=_attrs, ctx=_ctx, name=name)");
  strings::StrAppend(&result_,
                     // Wrap the arguments, and indent to the (.
                     WordWrap(return_prefix, return_args, kRightMargin), "\n");
}

string GetEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                         const std::vector<string>& hidden_ops,
                         bool require_shapes,
                         const string& source_file_name = "") {
  string result;
  // Header
  // TODO(josh11b): Mention the library for which wrappers are being generated.
  strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
)");

  // Mention the original source file so someone tracing back through
  // generated Python code will know where to look next.
  if (!source_file_name.empty()) {
    strings::StrAppend(&result, "Original C++ source file: ");
    strings::StrAppend(&result, source_file_name);
    strings::StrAppend(&result, "\n");
  }

  strings::StrAppend(&result, R"("""

import collections as _collections
import six as _six

from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import errors as _errors
from tensorflow.python.framework import tensor_shape as _tensor_shape

from tensorflow.core.framework import op_def_pb2 as _op_def_pb2
# Needed to trigger the call to _set_call_cpp_shape_fn.
from tensorflow.python.framework import common_shapes as _common_shapes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.tf_export import tf_export

)");

  // We'll make a copy of ops that filters out descriptions.
  OpList cleaned_ops;
  auto out = cleaned_ops.mutable_op();
  out->Reserve(ops.op_size());
  for (const auto& op_def : ops.op()) {
    const auto* api_def = api_defs.GetApiDef(op_def.name());

    if (api_def->visibility() == ApiDef::SKIP) {
      continue;
    }

    // An op is hidden if either its ApiDef visibility is HIDDEN
    // or it is in the hidden_ops list.
    bool is_hidden = api_def->visibility() == ApiDef::HIDDEN;
    if (!is_hidden) {
      for (const string& hidden : hidden_ops) {
        if (op_def.name() == hidden) {
          is_hidden = true;
          break;
        }
      }
    }

    string function_name;
    python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
                                                    &function_name);
    if (is_hidden) function_name = strings::StrCat("_", function_name);

    // When users create custom python wrappers, they may link in the
    // default op registry by accident, and because they can't
    // enumerate all 'hidden' symbols, this guard is to prevent
    // instantiating a python reserved word in their wrapper.
    if (python_op_gen_internal::IsPythonReserved(function_name)) {
      continue;
    }

    strings::StrAppend(&result,
                       GetEagerPythonOp(op_def, *api_def, function_name));

    if (!require_shapes) {
      strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(),
                         "\")(None)\n\n");
    }

    auto added = out->Add();
    *added = op_def;
    RemoveNonDeprecationDescriptionsFromOpDef(added);
  }

  result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes):
  op_list = _op_def_pb2.OpList()
  op_list.ParseFromString(op_list_proto_bytes)
  _op_def_registry.register_op_list(op_list)
  op_def_lib = _op_def_library.OpDefLibrary()
  op_def_lib.add_op_list(op_list)
  return op_def_lib
)");

  result.append("# ");
  auto ops_text = ProtoDebugString(cleaned_ops);
  str_util::StripTrailingWhitespace(&ops_text);
  result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true));
  result.append("\n");
  strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n",
                   str_util::CEscape(cleaned_ops.SerializeAsString()).c_str());
  return result;
}

}  // namespace

void PrintEagerPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                         const std::vector<string>& hidden_ops,
                         bool require_shapes, const string& source_file_name) {
  printf("%s", GetEagerPythonOps(ops, api_defs, hidden_ops, require_shapes,
                                 source_file_name)
                   .c_str());
}

string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) {
  string op_list_str(op_list_buf, op_list_len);
  OpList ops;
  ops.ParseFromString(op_list_str);

  ApiDefMap api_def_map(ops);
  return GetEagerPythonOps(ops, api_def_map, {}, false);
}

}  // namespace tensorflow