1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/python/framework/python_op_gen.h" 17 18 #include <stdio.h> 19 #include <sstream> 20 #include <unordered_map> 21 #include "tensorflow/core/framework/api_def.pb.h" 22 #include "tensorflow/core/framework/attr_value.pb.h" 23 #include "tensorflow/core/framework/op.h" 24 #include "tensorflow/core/framework/op_def.pb_text.h" 25 #include "tensorflow/core/framework/op_def.pb.h" 26 #include "tensorflow/core/framework/op_def_util.h" 27 #include "tensorflow/core/framework/op_gen_lib.h" 28 #include "tensorflow/core/framework/tensor.pb_text.h" 29 #include "tensorflow/core/framework/tensor.pb.h" 30 #include "tensorflow/core/framework/tensor_shape.pb.h" 31 #include "tensorflow/core/framework/types.h" 32 #include "tensorflow/core/framework/types.pb.h" 33 #include "tensorflow/core/lib/gtl/map_util.h" 34 #include "tensorflow/core/lib/gtl/stl_util.h" 35 #include "tensorflow/core/lib/strings/str_util.h" 36 #include "tensorflow/core/lib/strings/strcat.h" 37 #include "tensorflow/core/lib/strings/stringprintf.h" 38 #include "tensorflow/core/platform/logging.h" 39 #include "tensorflow/core/platform/macros.h" 40 #include "tensorflow/core/platform/types.h" 41 #include "tensorflow/python/framework/python_op_gen_internal.h" 42 43 namespace tensorflow { 44 namespace python_op_gen_internal { 45 46 const int kRightMargin = 78; 47 48 bool IsPythonReserved(const string& s) { 49 static const std::set<string>* const kPythonReserved = new std::set<string>( 50 {// Keywords in Python, from: 51 // import keyword 52 // print keyword.kwlist 53 "and", "as", "assert", "break", "class", "continue", "def", "del", 54 "elif", "else", "except", "exec", "finally", "for", "from", "global", 55 "if", "import", "in", "is", "lambda", "not", "or", "pass", "print", 56 "raise", "return", "try", "while", "with", "yield", 57 // Built-in functions and types in Python, from: 58 // [x for x in dir(__builtins__) if not x[0].islower()] 59 "ArithmeticError", "AssertionError", "AttributeError", "BaseException", 60 "BufferError", "BytesWarning", "DeprecationWarning", "EOFError", 61 "Ellipsis", "EnvironmentError", "Exception", "False", 62 "FloatingPointError", "FutureWarning", "GeneratorExit", "IOError", 63 "ImportError", "ImportWarning", "IndentationError", "IndexError", 64 "KeyError", "KeyboardInterrupt", "LookupError", "MemoryError", 65 "NameError", "None", "NotImplemented", "NotImplementedError", "OSError", 66 "OverflowError", "PendingDeprecationWarning", "ReferenceError", 67 "RuntimeError", "RuntimeWarning", "StandardError", "StopIteration", 68 "SyntaxError", "SyntaxWarning", "SystemError", "SystemExit", "TabError", 69 "True", "TypeError", "UnboundLocalError", "UnicodeDecodeError", 70 "UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError", 71 "UnicodeWarning", "UserWarning", "ValueError", "Warning", 72 "ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__", 73 "__package__"}); 74 75 return kPythonReserved->count(s) > 0; 76 } 77 78 string AvoidPythonReserved(const string& s) { 79 if (IsPythonReserved(s)) return strings::StrCat(s, "_"); 80 return s; 81 } 82 83 // Indent the first line by "initial" spaces and all following lines 84 // by "rest" spaces. 85 string Indent(int initial, int rest, StringPiece in) { 86 // TODO(josh11b): Also word-wrapping? 87 string copy(in.data(), in.size()); 88 str_util::StripTrailingWhitespace(©); 89 std::vector<string> v = str_util::Split(copy, '\n'); 90 91 string result; 92 bool first = true; 93 for (const string& line : v) { 94 if (first) { 95 result = strings::StrCat(Spaces(initial), line, "\n"); 96 first = false; 97 } else { 98 if (line.empty()) { 99 strings::StrAppend(&result, "\n"); 100 } else { 101 strings::StrAppend(&result, Spaces(rest), line, "\n"); 102 } 103 } 104 } 105 return result; 106 } 107 108 // Adds append to *dest, with a space if the first line will be <= width, 109 // or a newline otherwise. 110 void AppendWithinWidth(string* dest, StringPiece append, int width) { 111 auto first_line = append.find('\n'); 112 if (first_line == string::npos) first_line = append.size(); 113 if (dest->size() + first_line + 1 /* space */ > static_cast<size_t>(width)) { 114 strings::StrAppend(dest, "\n", append); 115 } else { 116 strings::StrAppend(dest, " ", append); 117 } 118 } 119 120 // Like DataTypeString() but uses the Python names for the 121 // float types. 122 string PythonDataTypeString(DataType dtype) { 123 switch (dtype) { 124 case DT_FLOAT: 125 return "float32"; 126 case DT_DOUBLE: 127 return "float64"; 128 default: 129 return DataTypeString(dtype); 130 } 131 } 132 133 string TypeString(DataType dtype, bool ref) { 134 if (ref) { 135 return strings::StrCat("mutable `", PythonDataTypeString(dtype), "`"); 136 } else { 137 return strings::StrCat("`", PythonDataTypeString(dtype), "`"); 138 } 139 } 140 141 string TypeListString(const AttrValue& value) { 142 string ret; 143 for (int t : value.list().type()) { 144 if (!ret.empty()) strings::StrAppend(&ret, ", "); 145 DataType dtype = static_cast<DataType>(t); 146 if (IsRefType(dtype)) { 147 strings::StrAppend(&ret, PythonDataTypeString(RemoveRefType(dtype)), 148 " mutable"); 149 } else { 150 strings::StrAppend(&ret, "`", PythonDataTypeString(dtype), "`"); 151 } 152 } 153 return ret; 154 } 155 156 string SingleTensorName(DataType dtype, bool is_ref) { 157 const string type_str = TypeString(dtype, is_ref); 158 return strings::StrCat("A `Tensor` of type ", type_str, "."); 159 } 160 161 const char kUnknownTensorType[] = {"A `Tensor`."}; 162 163 string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg, 164 const std::unordered_map<string, string>& inferred_attrs, 165 bool is_output) { 166 if (!arg.number_attr().empty()) { 167 // N Tensors with the same type 168 const string* original_arg = 169 gtl::FindOrNull(inferred_attrs, arg.number_attr()); 170 string prefix; 171 if (original_arg == nullptr) { 172 prefix = strings::StrCat("A list of `", arg.number_attr(), "`"); 173 } else if (*original_arg == arg.name()) { 174 const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def); 175 if (attr->has_minimum() && attr->minimum() > 0) { 176 prefix = strings::StrCat("A list of at least ", attr->minimum()); 177 } else { 178 prefix = "A list of"; 179 } 180 } else { 181 prefix = strings::StrCat("A list with the same length as `", 182 AvoidPythonReserved(*original_arg), "` of"); 183 } 184 185 if (arg.type() != DT_INVALID) { 186 return strings::StrCat(prefix, " `Tensor` objects with type ", 187 TypeString(arg.type(), arg.is_ref()), "."); 188 } else { 189 original_arg = gtl::FindOrNull(inferred_attrs, arg.type_attr()); 190 if (arg.is_ref()) { 191 strings::StrAppend(&prefix, " mutable"); 192 } 193 if (original_arg == nullptr) { 194 return strings::StrCat(prefix, " `Tensor` objects with type `", 195 arg.type_attr(), "`."); 196 } else if (*original_arg == arg.name()) { 197 const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def); 198 if (attr->has_allowed_values()) { 199 return strings::StrCat(prefix, 200 " `Tensor` objects with the same type in: ", 201 TypeListString(attr->allowed_values()), "."); 202 } else { 203 return strings::StrCat(prefix, 204 " `Tensor` objects with the same type."); 205 } 206 } else { 207 return strings::StrCat(prefix, 208 " `Tensor` objects with the same type as `", 209 AvoidPythonReserved(*original_arg), "`."); 210 } 211 } 212 } else if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) { 213 const bool is_list = !arg.type_list_attr().empty(); 214 const string attr_name = is_list ? arg.type_list_attr() : arg.type_attr(); 215 const OpDef::AttrDef* attr = FindAttr(attr_name, op_def); 216 const string mutable_str = arg.is_ref() ? "mutable " : ""; 217 const string prefix = 218 is_list ? strings::StrCat("A list of ", mutable_str, "`Tensor` objects") 219 : strings::StrCat("A ", mutable_str, "`Tensor`"); 220 const string* original_arg = gtl::FindOrNull(inferred_attrs, attr_name); 221 if (original_arg == nullptr) { 222 return strings::StrCat(prefix, " of type `", attr_name, "`."); 223 } else if (*original_arg == arg.name()) { 224 if (attr->has_allowed_values()) { 225 if (is_list) { 226 return strings::StrCat(prefix, " with types from: ", 227 TypeListString(attr->allowed_values()), "."); 228 } else { 229 return strings::StrCat( 230 prefix, is_output ? ". Has one of the following types: " 231 : ". Must be one of the following types: ", 232 TypeListString(attr->allowed_values()), "."); 233 } 234 } else { 235 return strings::StrCat(prefix, "."); 236 } 237 } else { 238 return strings::StrCat(prefix, 239 is_output ? ". Has the same type as `" 240 : ". Must have the same type as `", 241 AvoidPythonReserved(*original_arg), "`."); 242 } 243 } else { 244 return SingleTensorName(arg.type(), arg.is_ref()); 245 } 246 } 247 248 string GetReturns(const OpDef& op_def, 249 const std::vector<string>& output_type_string) { 250 string result; 251 DCHECK_EQ(op_def.output_arg_size(), output_type_string.size()); 252 const int num_outs = op_def.output_arg_size(); 253 strings::StrAppend(&result, "\n Returns:\n"); 254 if (num_outs == 0) { 255 strings::StrAppend(&result, " The created Operation.\n"); 256 } else { 257 if (num_outs == 1) { 258 StringPiece description = op_def.output_arg(0).description(); 259 if (ConsumeEquals(&description)) { // Skip the generated type info. 260 strings::StrAppend(&result, Indent(4, 4, description)); 261 } else { 262 // Special case of one output, don't use the name of the output unless 263 // there is no description. 264 string desc = output_type_string.empty() ? kUnknownTensorType 265 : output_type_string[0]; 266 if (desc == kUnknownTensorType) { 267 // Special case where we don't understand how the output tensor type 268 // depends on the input tensor types, just use the output arg 269 // description if we can. 270 if (!description.empty()) { 271 desc = op_def.output_arg(0).description(); 272 } else if (!op_def.output_arg(0).name().empty()) { 273 desc = strings::StrCat(" The ", op_def.output_arg(0).name(), 274 " `Tensor`."); 275 } 276 } else if (!description.empty()) { 277 AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */); 278 } 279 strings::StrAppend(&result, Indent(4, 4, desc)); 280 } 281 } else { 282 std::vector<string> out_names(num_outs); 283 for (int i = 0; i < num_outs; ++i) { 284 if (!op_def.output_arg(i).name().empty()) { 285 out_names[i] = op_def.output_arg(i).name(); 286 } else { 287 out_names[i] = strings::StrCat("output", i); 288 } 289 } 290 strings::StrAppend(&result, " A tuple of `Tensor` objects (", 291 str_util::Join(out_names, ", "), ").\n\n"); 292 for (int i = 0; i < num_outs; ++i) { 293 string desc = strings::StrCat(out_names[i], ": "); 294 StringPiece description = op_def.output_arg(i).description(); 295 if (ConsumeEquals(&description)) { // Skip the generated type info. 296 strings::StrAppend(&desc, description); 297 } else { 298 const string type = static_cast<size_t>(i) < output_type_string.size() 299 ? output_type_string[i] 300 : kUnknownTensorType; 301 if (!description.empty()) { 302 if (type == kUnknownTensorType) { 303 // Special case where we don't understand how the output tensor 304 // type depends on the input tensor types, so we just use the 305 // output arg description. 306 strings::StrAppend(&desc, description); 307 } else { 308 strings::StrAppend(&desc, type, " ", description); 309 } 310 } else { 311 strings::StrAppend(&desc, type); 312 } 313 } 314 strings::StrAppend(&result, Indent(4, 6, desc)); 315 } 316 } 317 } 318 return result; 319 } 320 321 string StringToPython(const string& str) { 322 return strings::StrCat("\"", str_util::CEscape(str), "\""); 323 } 324 325 string DataTypeToPython(DataType dtype, const string& dtype_module) { 326 return strings::StrCat(dtype_module, PythonDataTypeString(dtype)); 327 } 328 329 string ShapeToPython(const TensorShapeProto& shape) { 330 if (shape.unknown_rank()) { 331 return "None"; 332 } 333 string python = "["; 334 for (const auto& dim : shape.dim()) { 335 if (python.size() > 1) strings::StrAppend(&python, ", "); 336 if (!dim.name().empty()) { 337 strings::StrAppend(&python, "(", StringToPython(dim.name()), ", ", 338 dim.size(), ")"); 339 } else { 340 strings::StrAppend(&python, dim.size()); 341 } 342 } 343 strings::StrAppend(&python, "]"); 344 return python; 345 } 346 347 string TensorToPython(const TensorProto& proto) { 348 return ProtoShortDebugString(proto); 349 } 350 351 string AttrListToPython(const AttrValue& value, 352 const string& dtype_module = "tf.") { 353 string ret; 354 if (value.list().s_size() > 0) { 355 for (int i = 0; i < value.list().s_size(); ++i) { 356 if (i > 0) strings::StrAppend(&ret, ", "); 357 strings::StrAppend(&ret, StringToPython(value.list().s(i))); 358 } 359 } else if (value.list().i_size() > 0) { 360 for (int i = 0; i < value.list().i_size(); ++i) { 361 if (i > 0) strings::StrAppend(&ret, ", "); 362 strings::StrAppend(&ret, value.list().i(i)); 363 } 364 } else if (value.list().f_size() > 0) { 365 for (int i = 0; i < value.list().f_size(); ++i) { 366 if (i > 0) strings::StrAppend(&ret, ", "); 367 strings::StrAppend(&ret, value.list().f(i)); 368 } 369 } else if (value.list().b_size() > 0) { 370 for (int i = 0; i < value.list().b_size(); ++i) { 371 if (i > 0) strings::StrAppend(&ret, ", "); 372 strings::StrAppend(&ret, value.list().b(i) ? "True" : "False"); 373 } 374 } else if (value.list().type_size() > 0) { 375 for (int i = 0; i < value.list().type_size(); ++i) { 376 if (i > 0) strings::StrAppend(&ret, ", "); 377 strings::StrAppend(&ret, 378 DataTypeToPython(value.list().type(i), dtype_module)); 379 } 380 } else if (value.list().shape_size() > 0) { 381 for (int i = 0; i < value.list().shape_size(); ++i) { 382 if (i > 0) strings::StrAppend(&ret, ", "); 383 strings::StrAppend(&ret, ShapeToPython(value.list().shape(i))); 384 } 385 } else if (value.list().tensor_size() > 0) { 386 for (int i = 0; i < value.list().tensor_size(); ++i) { 387 if (i > 0) strings::StrAppend(&ret, ", "); 388 strings::StrAppend(&ret, TensorToPython(value.list().tensor(i))); 389 } 390 } else if (value.list().func_size() > 0) { 391 for (int i = 0; i < value.list().func_size(); ++i) { 392 if (i > 0) strings::StrAppend(&ret, ", "); 393 strings::StrAppend(&ret, StringToPython(value.list().func(i).name())); 394 } 395 } 396 return ret; 397 } 398 399 // NOTE: The return value may contain spaces (for example, it could be 400 // a string "foo bar" with an embedded space) and is not safe to pass 401 // to WordWrap(). 402 string AttrValueToPython(const string& type, const AttrValue& value, 403 const string& dtype_module) { 404 if (type == "string") { 405 return StringToPython(value.s()); 406 } else if (type == "int") { 407 return strings::StrCat(value.i()); 408 } else if (type == "float") { 409 if (std::isnan(value.f()) || std::isinf(value.f())) { 410 return strings::StrCat("float('", value.f(), "')"); 411 } else { 412 return strings::StrCat(value.f()); 413 } 414 } else if (type == "bool") { 415 return value.b() ? "True" : "False"; 416 } else if (type == "type") { 417 return DataTypeToPython(value.type(), dtype_module); 418 } else if (type == "shape") { 419 return ShapeToPython(value.shape()); 420 } else if (type == "tensor") { 421 return TensorToPython(value.tensor()); 422 } else if (type == "func") { 423 return StringToPython(value.func().name()); 424 } else if (StringPiece(type).starts_with("list(")) { 425 return strings::StrCat("[", AttrListToPython(value, dtype_module), "]"); 426 } else { 427 return "?"; 428 } 429 } 430 431 void GenerateLowerCaseOpName(const string& str, string* result) { 432 const char joiner = '_'; 433 const int last_index = str.size() - 1; 434 for (int i = 0; i <= last_index; ++i) { 435 const char c = str[i]; 436 // Emit a joiner only if a previous-lower-to-now-upper or a 437 // now-upper-to-next-lower transition happens. 438 if (isupper(c) && (i > 0)) { 439 if (islower(str[i - 1]) || ((i < last_index) && islower(str[i + 1]))) { 440 result->push_back(joiner); 441 } 442 } 443 result->push_back(tolower(c)); 444 } 445 } 446 447 static void AddDelimiter(string* append_to, const string& delim) { 448 if (!append_to->empty()) strings::StrAppend(append_to, delim); 449 } 450 451 const ApiDef::Attr* FindAttr(StringPiece name, const ApiDef& api_def) { 452 for (int i = 0; i < api_def.attr_size(); ++i) { 453 if (api_def.attr(i).name() == name) { 454 return &api_def.attr(i); 455 } 456 } 457 return nullptr; 458 } 459 460 const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) { 461 for (int i = 0; i < api_def.in_arg_size(); ++i) { 462 if (api_def.in_arg(i).name() == name) { 463 return &api_def.in_arg(i); 464 } 465 } 466 return nullptr; 467 } 468 469 GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def, 470 const string& function_name) 471 : op_def_(op_def), 472 api_def_(api_def), 473 function_name_(function_name), 474 num_outs_(op_def.output_arg_size()) {} 475 476 GenPythonOp::~GenPythonOp() {} 477 478 string GenPythonOp::Code() { 479 // This has all the input args followed by those attrs that don't have 480 // defaults. 481 std::vector<ParamNames> params_no_default; 482 // The parameters with defaults (these have to be listed after those without). 483 // No input args are included, just attrs. 484 std::vector<ParamNames> params_with_default; 485 486 for (int i = 0; i < api_def_.arg_order_size(); ++i) { 487 const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_); 488 const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_); 489 params_no_default.emplace_back(api_def_arg.name(), api_def_arg.rename_to()); 490 if (!arg.type_attr().empty()) { 491 gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_attr(), arg.name()); 492 } else if (!arg.type_list_attr().empty()) { 493 gtl::InsertIfNotPresent(&inferred_attrs_, arg.type_list_attr(), 494 arg.name()); 495 } 496 if (!arg.number_attr().empty()) { 497 gtl::InsertIfNotPresent(&inferred_attrs_, arg.number_attr(), arg.name()); 498 } 499 } 500 for (int i = 0; i < api_def_.attr_size(); ++i) { 501 const auto& attr(api_def_.attr(i)); 502 // Do not add inferred attrs to the Python function signature. 503 if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) { 504 if (attr.has_default_value()) { 505 params_with_default.emplace_back(attr.name(), attr.rename_to()); 506 } else { 507 params_no_default.emplace_back(attr.name(), attr.rename_to()); 508 } 509 } 510 } 511 512 // Save the list of attr parameters (attrs that won't be inferred), 513 // those with defaults go at the end. 514 // Get the attrs in the order we want by taking the attrs without defaults 515 // from the end of args_no_default, and adding args_no_default. 516 attrs_.reserve(params_no_default.size() - op_def_.input_arg_size() + 517 params_with_default.size()); 518 for (int i = op_def_.input_arg_size(); i < params_no_default.size(); ++i) { 519 attrs_.push_back(params_no_default[i].GetName()); 520 } 521 for (int i = 0; i < params_with_default.size(); ++i) { 522 attrs_.push_back(params_with_default[i].GetName()); 523 } 524 525 param_names_.reserve(params_no_default.size() + params_with_default.size()); 526 param_names_.insert(param_names_.begin(), params_no_default.begin(), 527 params_no_default.end()); 528 for (const auto& param : params_with_default) { 529 param_names_.push_back(param); 530 } 531 532 string parameters; 533 for (const auto& param : params_no_default) { 534 AddDelimiter(¶meters, ", "); 535 strings::StrAppend(¶meters, param.GetRenameTo()); 536 } 537 for (const auto& param_and_default : params_with_default) { 538 AddDelimiter(¶meters, ", "); 539 strings::StrAppend(¶meters, param_and_default.GetRenameTo(), "=None"); 540 } 541 AddDelimiter(¶meters, ", "); 542 strings::StrAppend(¶meters, "name=None"); 543 544 AddExport(); 545 AddDefLine(parameters); 546 AddDocStringDescription(); 547 AddDocStringArgs(); 548 AddDocStringInputs(); 549 AddDocStringAttrs(); 550 AddDocStringNameArg(); 551 AddOutputGlobals(); 552 AddDocStringOutputs(); 553 strings::StrAppend(&result_, " \"\"\"\n"); 554 AddBody(" "); 555 strings::StrAppend(&result_, "\n\n"); 556 557 return prelude_ + result_; 558 } 559 560 void GenPythonOp::AddExport() { 561 if (api_def_.visibility() != ApiDef::VISIBLE) { 562 return; 563 } 564 565 strings::StrAppend(&result_, "@tf_export("); 566 567 // Add all endpoint names to tf_export. 568 bool first_endpoint = true; 569 for (const auto& endpoint : api_def_.endpoint()) { 570 if (!first_endpoint) { 571 strings::StrAppend(&result_, ", "); 572 } else { 573 first_endpoint = false; 574 } 575 string endpoint_name; 576 python_op_gen_internal::GenerateLowerCaseOpName(endpoint.name(), 577 &endpoint_name); 578 strings::StrAppend(&result_, "'", endpoint_name, "'"); 579 } 580 strings::StrAppend(&result_, ")\n"); 581 } 582 583 void GenPythonOp::AddDefLine(const string& function_name, 584 const string& parameters) { 585 strings::StrAppend(&result_, "def ", function_name, "(", parameters, "):\n"); 586 } 587 588 void GenPythonOp::AddDefLine(const string& parameters) { 589 AddDefLine(function_name_, parameters); 590 } 591 592 void GenPythonOp::AddDocStringDescription() { 593 string comment; 594 if (api_def_.summary().empty()) { 595 comment = "TODO: add doc.\n"; 596 } else { 597 comment = strings::StrCat(api_def_.summary(), "\n"); 598 if (!api_def_.description().empty()) { 599 strings::StrAppend(&comment, "\n", Indent(2, 2, api_def_.description())); 600 } 601 } 602 strings::StrAppend(&result_, " r\"\"\"", comment, "\n"); 603 } 604 605 void GenPythonOp::AddDocStringArgs() { 606 strings::StrAppend(&result_, " Args:\n"); 607 } 608 609 void GenPythonOp::AddDocStringInputs() { 610 for (int i = 0; i < api_def_.arg_order_size(); ++i) { 611 const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_); 612 const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_); 613 StringPiece description = api_def_arg.description(); 614 string desc; 615 if (ConsumeEquals(&description)) { // Skip the generated type info. 616 desc = strings::StrCat(param_names_[i].GetRenameTo(), ": "); 617 } else { 618 desc = strings::StrCat(param_names_[i].GetRenameTo(), ": ", 619 ArgTypeName(op_def_, arg, inferred_attrs_, false)); 620 } 621 if (!description.empty()) { 622 AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */); 623 } 624 strings::StrAppend(&result_, Indent(4, 6, desc)); 625 } 626 } 627 628 void GenPythonOp::AddDocStringAttrs() { 629 for (const string& name : attrs_) { 630 const auto& attr = *FindAttr(name, op_def_); 631 const auto& api_def_attr = *FindAttr(name, api_def_); 632 string desc = 633 strings::StrCat(AvoidPythonReserved(api_def_attr.rename_to()), ": "); 634 635 static const char* const kAttrTypeName[][2] = { 636 {"string", "`string`"}, 637 {"list(string)", "list of `strings`"}, 638 {"int", "`int`"}, 639 {"list(int)", "list of `ints`"}, 640 {"float", "`float`"}, 641 {"list(float)", "list of `floats`"}, 642 {"bool", "`bool`"}, 643 {"list(bool)", "list of `bools`"}, 644 {"type", "`tf.DType`"}, 645 {"list(type)", "list of `tf.DTypes`"}, 646 {"shape", "`tf.TensorShape` or list of `ints`"}, 647 {"list(shape)", 648 "list of shapes (each a `tf.TensorShape` or list of `ints`)"}, 649 {"tensor", "`tf.TensorProto`"}, 650 {"list(tensor)", "list of `tf.TensorProto` objects"}, 651 {"func", "function decorated with @Defun"}, 652 {"list(func)", "list of functions decorated with @Defun"}, 653 }; 654 for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) { 655 if (attr.type() == kAttrTypeName[i][0]) { 656 string s; 657 if (api_def_attr.has_default_value()) { 658 s = strings::StrCat("optional ", kAttrTypeName[i][1]); 659 } else { 660 s = kAttrTypeName[i][1]; 661 } 662 if (s[0] == 'o' || (s[0] == '`' && (s[1] == 'i' || s[1] == 'o'))) { 663 strings::StrAppend(&desc, "An ", s); 664 } else { 665 strings::StrAppend(&desc, "A ", s); 666 } 667 break; 668 } 669 } 670 671 if (attr.has_allowed_values()) { 672 strings::StrAppend(&desc, " from: `", 673 AttrListToPython(attr.allowed_values()), "`"); 674 } 675 676 if (attr.has_minimum()) { 677 if (attr.type() == "int") { 678 strings::StrAppend(&desc, " that is `>= ", attr.minimum(), "`"); 679 } else if (attr.minimum() > 0) { 680 strings::StrAppend(&desc, " that has length `>= ", attr.minimum(), "`"); 681 } 682 } 683 684 strings::StrAppend(&desc, "."); 685 686 if (api_def_attr.has_default_value()) { 687 strings::StrAppend( 688 &desc, " Defaults to `", 689 AttrValueToPython(attr.type(), api_def_attr.default_value()), "`."); 690 } 691 if (!api_def_attr.description().empty()) { 692 AppendWithinWidth(&desc, api_def_attr.description(), 693 kRightMargin - 4 /* indent */); 694 } 695 strings::StrAppend(&result_, Indent(4, 6, desc)); 696 } 697 } 698 699 void GenPythonOp::AddDocStringNameArg() { 700 strings::StrAppend(&result_, 701 " name: A name for the operation (optional).\n"); 702 } 703 704 void GenPythonOp::AddOutputGlobals() { 705 // Prepare a NamedTuple type to hold the outputs, if there are multiple 706 if (num_outs_ > 1) { 707 // Prepare the list of output names 708 std::vector<string> out_names(num_outs_); 709 for (int i = 0; i < num_outs_; ++i) { 710 if (!api_def_.out_arg(i).rename_to().empty()) { 711 out_names[i] = api_def_.out_arg(i).rename_to(); 712 } else { 713 out_names[i] = strings::StrCat("output", i); 714 } 715 } 716 string out_names_list = 717 strings::StrCat("[\"", str_util::Join(out_names, "\", \""), "\"]"); 718 719 // Provide the output names as a Python list 720 string lower_op_name_outputs = 721 strings::StrCat("_", function_name_, "_outputs"); 722 const string outputs_prefix = strings::StrCat(lower_op_name_outputs, " = "); 723 strings::StrAppend(&prelude_, "\n", 724 WordWrap(outputs_prefix, out_names_list, kRightMargin), 725 "\n"); 726 727 strings::StrAppend(&prelude_, "_", op_def_.name(), 728 "Output = _collections.namedtuple(\n"); 729 const string tuple_type_prefix = " "; 730 const string tuple_type_suffix = strings::StrCat( 731 "\"", op_def_.name(), "\", ", lower_op_name_outputs, ")"); 732 strings::StrAppend( 733 &prelude_, WordWrap(tuple_type_prefix, tuple_type_suffix, kRightMargin), 734 "\n\n"); 735 } 736 strings::StrAppend(&prelude_, "\n"); 737 } 738 739 void GenPythonOp::AddDocStringOutputs() { 740 std::vector<string> output_type_string; 741 output_type_string.reserve(num_outs_); 742 for (int i = 0; i < num_outs_; ++i) { 743 output_type_string.push_back( 744 ArgTypeName(op_def_, op_def_.output_arg(i), inferred_attrs_, true)); 745 } 746 strings::StrAppend(&result_, GetReturns(op_def_, output_type_string)); 747 } 748 749 void GenPythonOp::AddBody(const string& prefix) { 750 const string apply_prefix = 751 strings::StrCat(prefix, "_result = _op_def_lib.apply_op("); 752 AddBodyNoReturn(apply_prefix); 753 if (num_outs_ > 1) { 754 strings::StrAppend(&result_, prefix, "_result = _", op_def_.name(), 755 "Output._make(_result)\n"); 756 } 757 strings::StrAppend(&result_, prefix, "return _result\n"); 758 } 759 760 void GenPythonOp::AddBodyNoReturn(const string& apply_prefix) { 761 string args = strings::StrCat("\"", op_def_.name(), "\", "); 762 for (size_t i = 0; i < param_names_.size(); ++i) { 763 strings::StrAppend(&args, AvoidPythonReserved(param_names_[i].GetName()), 764 "=", param_names_[i].GetRenameTo(), ", "); 765 } 766 strings::StrAppend(&args, "name=name)"); 767 768 strings::StrAppend(&result_, 769 // Wrap the arguments, and indent to the (. 770 WordWrap(apply_prefix, args, kRightMargin), "\n"); 771 } 772 773 } // namespace python_op_gen_internal 774 775 string GetPythonOp(const OpDef& op_def, const ApiDef& api_def, 776 const string& function_name) { 777 return python_op_gen_internal::GenPythonOp(op_def, api_def, function_name) 778 .Code(); 779 } 780 781 string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs, 782 const std::vector<string>& hidden_ops, 783 bool require_shapes) { 784 string result; 785 // Header 786 // TODO(josh11b): Mention the library for which wrappers are being generated. 787 strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops. 788 789 This file is MACHINE GENERATED! Do not edit. 790 """ 791 792 import collections as _collections 793 794 from tensorflow.core.framework import op_def_pb2 as _op_def_pb2 795 796 # Needed to trigger the call to _set_call_cpp_shape_fn. 797 from tensorflow.python.framework import common_shapes as _common_shapes 798 799 from tensorflow.python.framework import op_def_registry as _op_def_registry 800 from tensorflow.python.framework import ops as _ops 801 from tensorflow.python.framework import op_def_library as _op_def_library 802 from tensorflow.python.util.tf_export import tf_export 803 )"); 804 805 // We'll make a copy of ops that filters out descriptions. 806 OpList cleaned_ops; 807 auto out = cleaned_ops.mutable_op(); 808 out->Reserve(ops.op_size()); 809 for (const auto& op_def : ops.op()) { 810 const auto* api_def = api_defs.GetApiDef(op_def.name()); 811 812 if (api_def->visibility() == ApiDef::SKIP) { 813 continue; 814 } 815 816 // An op is hidden if either its ApiDef visibility is HIDDEN 817 // or it is in the hidden_ops list. 818 bool is_hidden = api_def->visibility() == ApiDef::HIDDEN; 819 if (!is_hidden) { 820 for (const string& hidden : hidden_ops) { 821 if (op_def.name() == hidden) { 822 is_hidden = true; 823 break; 824 } 825 } 826 } 827 828 string function_name; 829 python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(), 830 &function_name); 831 if (is_hidden) function_name = strings::StrCat("_", function_name); 832 833 // When users create custom python wrappers, they may link in the 834 // default op registry by accident, and because they can't 835 // enumerate all 'hidden' symbols, this guard is to prevent 836 // instantiating a python reserved word in their wrapper. 837 if (python_op_gen_internal::IsPythonReserved(function_name)) { 838 continue; 839 } 840 841 strings::StrAppend(&result, GetPythonOp(op_def, *api_def, function_name)); 842 843 if (!require_shapes) { 844 strings::StrAppend(&result, "_ops.RegisterShape(\"", op_def.name(), 845 "\")(None)\n"); 846 } 847 848 auto added = out->Add(); 849 *added = op_def; 850 RemoveNonDeprecationDescriptionsFromOpDef(added); 851 } 852 853 result.append(R"(def _InitOpDefLibrary(op_list_proto_bytes): 854 op_list = _op_def_pb2.OpList() 855 op_list.ParseFromString(op_list_proto_bytes) 856 _op_def_registry.register_op_list(op_list) 857 op_def_lib = _op_def_library.OpDefLibrary() 858 op_def_lib.add_op_list(op_list) 859 return op_def_lib 860 861 862 )"); 863 864 result.append("# "); 865 auto ops_text = ProtoDebugString(cleaned_ops); 866 str_util::StripTrailingWhitespace(&ops_text); 867 result.append(str_util::StringReplace(ops_text, "\n", "\n# ", true)); 868 result.append("\n"); 869 strings::Appendf(&result, "_op_def_lib = _InitOpDefLibrary(b\"%s\")\n", 870 str_util::CEscape(cleaned_ops.SerializeAsString()).c_str()); 871 return result; 872 } 873 874 void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs, 875 const std::vector<string>& hidden_ops, 876 bool require_shapes) { 877 printf("%s", GetPythonOps(ops, api_defs, hidden_ops, require_shapes).c_str()); 878 } 879 880 string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) { 881 string op_list_str(op_list_buf, op_list_len); 882 OpList ops; 883 ops.ParseFromString(op_list_str); 884 ApiDefMap api_def_map(ops); 885 return GetPythonOps(ops, api_def_map, {}, false); 886 } 887 888 } // namespace tensorflow 889