1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/op_gen_lib.h" 17 18 #include <algorithm> 19 #include <vector> 20 #include "tensorflow/core/framework/attr_value.pb.h" 21 #include "tensorflow/core/lib/core/errors.h" 22 #include "tensorflow/core/lib/gtl/map_util.h" 23 #include "tensorflow/core/lib/strings/str_util.h" 24 #include "tensorflow/core/lib/strings/strcat.h" 25 #include "tensorflow/core/platform/protobuf.h" 26 #include "tensorflow/core/util/proto/proto_utils.h" 27 28 namespace tensorflow { 29 30 string WordWrap(StringPiece prefix, StringPiece str, int width) { 31 const string indent_next_line = "\n" + Spaces(prefix.size()); 32 width -= prefix.size(); 33 string result; 34 strings::StrAppend(&result, prefix); 35 36 while (!str.empty()) { 37 if (static_cast<int>(str.size()) <= width) { 38 // Remaining text fits on one line. 39 strings::StrAppend(&result, str); 40 break; 41 } 42 auto space = str.rfind(' ', width); 43 if (space == StringPiece::npos) { 44 // Rather make a too-long line and break at a space. 45 space = str.find(' '); 46 if (space == StringPiece::npos) { 47 strings::StrAppend(&result, str); 48 break; 49 } 50 } 51 // Breaking at character at position <space>. 52 StringPiece to_append = str.substr(0, space); 53 str.remove_prefix(space + 1); 54 // Remove spaces at break. 55 while (str_util::EndsWith(to_append, " ")) { 56 to_append.remove_suffix(1); 57 } 58 while (str_util::ConsumePrefix(&str, " ")) { 59 } 60 61 // Go on to the next line. 62 strings::StrAppend(&result, to_append); 63 if (!str.empty()) strings::StrAppend(&result, indent_next_line); 64 } 65 66 return result; 67 } 68 69 bool ConsumeEquals(StringPiece* description) { 70 if (str_util::ConsumePrefix(description, "=")) { 71 while (str_util::ConsumePrefix(description, 72 " ")) { // Also remove spaces after "=". 73 } 74 return true; 75 } 76 return false; 77 } 78 79 // Split `*orig` into two pieces at the first occurrence of `split_ch`. 80 // Returns whether `split_ch` was found. Afterwards, `*before_split` 81 // contains the maximum prefix of the input `*orig` that doesn't 82 // contain `split_ch`, and `*orig` contains everything after the 83 // first `split_ch`. 84 static bool SplitAt(char split_ch, StringPiece* orig, 85 StringPiece* before_split) { 86 auto pos = orig->find(split_ch); 87 if (pos == StringPiece::npos) { 88 *before_split = *orig; 89 *orig = StringPiece(); 90 return false; 91 } else { 92 *before_split = orig->substr(0, pos); 93 orig->remove_prefix(pos + 1); 94 return true; 95 } 96 } 97 98 // Does this line start with "<spaces><field>:" where "<field>" is 99 // in multi_line_fields? Sets *colon_pos to the position of the colon. 100 static bool StartsWithFieldName(StringPiece line, 101 const std::vector<string>& multi_line_fields) { 102 StringPiece up_to_colon; 103 if (!SplitAt(':', &line, &up_to_colon)) return false; 104 while (str_util::ConsumePrefix(&up_to_colon, " ")) 105 ; // Remove leading spaces. 106 for (const auto& field : multi_line_fields) { 107 if (up_to_colon == field) { 108 return true; 109 } 110 } 111 return false; 112 } 113 114 static bool ConvertLine(StringPiece line, 115 const std::vector<string>& multi_line_fields, 116 string* ml) { 117 // Is this a field we should convert? 118 if (!StartsWithFieldName(line, multi_line_fields)) { 119 return false; 120 } 121 // Has a matching field name, so look for "..." after the colon. 122 StringPiece up_to_colon; 123 StringPiece after_colon = line; 124 SplitAt(':', &after_colon, &up_to_colon); 125 while (str_util::ConsumePrefix(&after_colon, " ")) 126 ; // Remove leading spaces. 127 if (!str_util::ConsumePrefix(&after_colon, "\"")) { 128 // We only convert string fields, so don't convert this line. 129 return false; 130 } 131 auto last_quote = after_colon.rfind('\"'); 132 if (last_quote == StringPiece::npos) { 133 // Error: we don't see the expected matching quote, abort the conversion. 134 return false; 135 } 136 StringPiece escaped = after_colon.substr(0, last_quote); 137 StringPiece suffix = after_colon.substr(last_quote + 1); 138 // We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>' 139 140 string unescaped; 141 if (!str_util::CUnescape(escaped, &unescaped, nullptr)) { 142 // Error unescaping, abort the conversion. 143 return false; 144 } 145 // No more errors possible at this point. 146 147 // Find a string to mark the end that isn't in unescaped. 148 string end = "END"; 149 for (int s = 0; unescaped.find(end) != string::npos; ++s) { 150 end = strings::StrCat("END", s); 151 } 152 153 // Actually start writing the converted output. 154 strings::StrAppend(ml, up_to_colon, ": <<", end, "\n", unescaped, "\n", end); 155 if (!suffix.empty()) { 156 // Output suffix, in case there was a trailing comment in the source. 157 strings::StrAppend(ml, suffix); 158 } 159 strings::StrAppend(ml, "\n"); 160 return true; 161 } 162 163 string PBTxtToMultiline(StringPiece pbtxt, 164 const std::vector<string>& multi_line_fields) { 165 string ml; 166 // Probably big enough, since the input and output are about the 167 // same size, but just a guess. 168 ml.reserve(pbtxt.size() * (17. / 16)); 169 StringPiece line; 170 while (!pbtxt.empty()) { 171 // Split pbtxt into its first line and everything after. 172 SplitAt('\n', &pbtxt, &line); 173 // Convert line or output it unchanged 174 if (!ConvertLine(line, multi_line_fields, &ml)) { 175 strings::StrAppend(&ml, line, "\n"); 176 } 177 } 178 return ml; 179 } 180 181 // Given a single line of text `line` with first : at `colon`, determine if 182 // there is an "<<END" expression after the colon and if so return true and set 183 // `*end` to everything after the "<<". 184 static bool FindMultiline(StringPiece line, size_t colon, string* end) { 185 if (colon == StringPiece::npos) return false; 186 line.remove_prefix(colon + 1); 187 while (str_util::ConsumePrefix(&line, " ")) { 188 } 189 if (str_util::ConsumePrefix(&line, "<<")) { 190 *end = string(line); 191 return true; 192 } 193 return false; 194 } 195 196 string PBTxtFromMultiline(StringPiece multiline_pbtxt) { 197 string pbtxt; 198 // Probably big enough, since the input and output are about the 199 // same size, but just a guess. 200 pbtxt.reserve(multiline_pbtxt.size() * (33. / 32)); 201 StringPiece line; 202 while (!multiline_pbtxt.empty()) { 203 // Split multiline_pbtxt into its first line and everything after. 204 if (!SplitAt('\n', &multiline_pbtxt, &line)) { 205 strings::StrAppend(&pbtxt, line); 206 break; 207 } 208 209 string end; 210 auto colon = line.find(':'); 211 if (!FindMultiline(line, colon, &end)) { 212 // Normal case: not a multi-line string, just output the line as-is. 213 strings::StrAppend(&pbtxt, line, "\n"); 214 continue; 215 } 216 217 // Multi-line case: 218 // something: <<END 219 // xx 220 // yy 221 // END 222 // Should be converted to: 223 // something: "xx\nyy" 224 225 // Output everything up to the colon (" something:"). 226 strings::StrAppend(&pbtxt, line.substr(0, colon + 1)); 227 228 // Add every line to unescaped until we see the "END" string. 229 string unescaped; 230 bool first = true; 231 while (!multiline_pbtxt.empty()) { 232 SplitAt('\n', &multiline_pbtxt, &line); 233 if (str_util::ConsumePrefix(&line, end)) break; 234 if (first) { 235 first = false; 236 } else { 237 unescaped.push_back('\n'); 238 } 239 strings::StrAppend(&unescaped, line); 240 line = StringPiece(); 241 } 242 243 // Escape what we extracted and then output it in quotes. 244 strings::StrAppend(&pbtxt, " \"", str_util::CEscape(unescaped), "\"", line, 245 "\n"); 246 } 247 return pbtxt; 248 } 249 250 static void StringReplace(const string& from, const string& to, string* s) { 251 // Split *s into pieces delimited by `from`. 252 std::vector<string> split; 253 string::size_type pos = 0; 254 while (pos < s->size()) { 255 auto found = s->find(from, pos); 256 if (found == string::npos) { 257 split.push_back(s->substr(pos)); 258 break; 259 } else { 260 split.push_back(s->substr(pos, found - pos)); 261 pos = found + from.size(); 262 if (pos == s->size()) { // handle case where `from` is at the very end. 263 split.push_back(""); 264 } 265 } 266 } 267 // Join the pieces back together with a new delimiter. 268 *s = str_util::Join(split, to.c_str()); 269 } 270 271 static void RenameInDocs(const string& from, const string& to, 272 ApiDef* api_def) { 273 const string from_quoted = strings::StrCat("`", from, "`"); 274 const string to_quoted = strings::StrCat("`", to, "`"); 275 for (int i = 0; i < api_def->in_arg_size(); ++i) { 276 if (!api_def->in_arg(i).description().empty()) { 277 StringReplace(from_quoted, to_quoted, 278 api_def->mutable_in_arg(i)->mutable_description()); 279 } 280 } 281 for (int i = 0; i < api_def->out_arg_size(); ++i) { 282 if (!api_def->out_arg(i).description().empty()) { 283 StringReplace(from_quoted, to_quoted, 284 api_def->mutable_out_arg(i)->mutable_description()); 285 } 286 } 287 for (int i = 0; i < api_def->attr_size(); ++i) { 288 if (!api_def->attr(i).description().empty()) { 289 StringReplace(from_quoted, to_quoted, 290 api_def->mutable_attr(i)->mutable_description()); 291 } 292 } 293 if (!api_def->summary().empty()) { 294 StringReplace(from_quoted, to_quoted, api_def->mutable_summary()); 295 } 296 if (!api_def->description().empty()) { 297 StringReplace(from_quoted, to_quoted, api_def->mutable_description()); 298 } 299 } 300 301 namespace { 302 303 // Initializes given ApiDef with data in OpDef. 304 void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) { 305 api_def->set_graph_op_name(op_def.name()); 306 api_def->set_visibility(ApiDef::VISIBLE); 307 308 auto* endpoint = api_def->add_endpoint(); 309 endpoint->set_name(op_def.name()); 310 311 for (const auto& op_in_arg : op_def.input_arg()) { 312 auto* api_in_arg = api_def->add_in_arg(); 313 api_in_arg->set_name(op_in_arg.name()); 314 api_in_arg->set_rename_to(op_in_arg.name()); 315 api_in_arg->set_description(op_in_arg.description()); 316 317 *api_def->add_arg_order() = op_in_arg.name(); 318 } 319 for (const auto& op_out_arg : op_def.output_arg()) { 320 auto* api_out_arg = api_def->add_out_arg(); 321 api_out_arg->set_name(op_out_arg.name()); 322 api_out_arg->set_rename_to(op_out_arg.name()); 323 api_out_arg->set_description(op_out_arg.description()); 324 } 325 for (const auto& op_attr : op_def.attr()) { 326 auto* api_attr = api_def->add_attr(); 327 api_attr->set_name(op_attr.name()); 328 api_attr->set_rename_to(op_attr.name()); 329 if (op_attr.has_default_value()) { 330 *api_attr->mutable_default_value() = op_attr.default_value(); 331 } 332 api_attr->set_description(op_attr.description()); 333 } 334 api_def->set_summary(op_def.summary()); 335 api_def->set_description(op_def.description()); 336 } 337 338 // Updates base_arg based on overrides in new_arg. 339 void MergeArg(ApiDef::Arg* base_arg, const ApiDef::Arg& new_arg) { 340 if (!new_arg.rename_to().empty()) { 341 base_arg->set_rename_to(new_arg.rename_to()); 342 } 343 if (!new_arg.description().empty()) { 344 base_arg->set_description(new_arg.description()); 345 } 346 } 347 348 // Updates base_attr based on overrides in new_attr. 349 void MergeAttr(ApiDef::Attr* base_attr, const ApiDef::Attr& new_attr) { 350 if (!new_attr.rename_to().empty()) { 351 base_attr->set_rename_to(new_attr.rename_to()); 352 } 353 if (new_attr.has_default_value()) { 354 *base_attr->mutable_default_value() = new_attr.default_value(); 355 } 356 if (!new_attr.description().empty()) { 357 base_attr->set_description(new_attr.description()); 358 } 359 } 360 361 // Updates base_api_def based on overrides in new_api_def. 362 Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) { 363 // Merge visibility 364 if (new_api_def.visibility() != ApiDef::DEFAULT_VISIBILITY) { 365 base_api_def->set_visibility(new_api_def.visibility()); 366 } 367 // Merge endpoints 368 if (new_api_def.endpoint_size() > 0) { 369 base_api_def->clear_endpoint(); 370 std::copy( 371 new_api_def.endpoint().begin(), new_api_def.endpoint().end(), 372 protobuf::RepeatedFieldBackInserter(base_api_def->mutable_endpoint())); 373 } 374 // Merge args 375 for (const auto& new_arg : new_api_def.in_arg()) { 376 bool found_base_arg = false; 377 for (int i = 0; i < base_api_def->in_arg_size(); ++i) { 378 auto* base_arg = base_api_def->mutable_in_arg(i); 379 if (base_arg->name() == new_arg.name()) { 380 MergeArg(base_arg, new_arg); 381 found_base_arg = true; 382 break; 383 } 384 } 385 if (!found_base_arg) { 386 return errors::FailedPrecondition("Argument ", new_arg.name(), 387 " not defined in base api for ", 388 base_api_def->graph_op_name()); 389 } 390 } 391 for (const auto& new_arg : new_api_def.out_arg()) { 392 bool found_base_arg = false; 393 for (int i = 0; i < base_api_def->out_arg_size(); ++i) { 394 auto* base_arg = base_api_def->mutable_out_arg(i); 395 if (base_arg->name() == new_arg.name()) { 396 MergeArg(base_arg, new_arg); 397 found_base_arg = true; 398 break; 399 } 400 } 401 if (!found_base_arg) { 402 return errors::FailedPrecondition("Argument ", new_arg.name(), 403 " not defined in base api for ", 404 base_api_def->graph_op_name()); 405 } 406 } 407 // Merge arg order 408 if (new_api_def.arg_order_size() > 0) { 409 // Validate that new arg_order is correct. 410 if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) { 411 return errors::FailedPrecondition( 412 "Invalid number of arguments ", new_api_def.arg_order_size(), " for ", 413 base_api_def->graph_op_name(), 414 ". Expected: ", base_api_def->arg_order_size()); 415 } 416 if (!std::is_permutation(new_api_def.arg_order().begin(), 417 new_api_def.arg_order().end(), 418 base_api_def->arg_order().begin())) { 419 return errors::FailedPrecondition( 420 "Invalid arg_order: ", str_util::Join(new_api_def.arg_order(), ", "), 421 " for ", base_api_def->graph_op_name(), 422 ". All elements in arg_order override must match base arg_order: ", 423 str_util::Join(base_api_def->arg_order(), ", ")); 424 } 425 426 base_api_def->clear_arg_order(); 427 std::copy( 428 new_api_def.arg_order().begin(), new_api_def.arg_order().end(), 429 protobuf::RepeatedFieldBackInserter(base_api_def->mutable_arg_order())); 430 } 431 // Merge attributes 432 for (const auto& new_attr : new_api_def.attr()) { 433 bool found_base_attr = false; 434 for (int i = 0; i < base_api_def->attr_size(); ++i) { 435 auto* base_attr = base_api_def->mutable_attr(i); 436 if (base_attr->name() == new_attr.name()) { 437 MergeAttr(base_attr, new_attr); 438 found_base_attr = true; 439 break; 440 } 441 } 442 if (!found_base_attr) { 443 return errors::FailedPrecondition("Attribute ", new_attr.name(), 444 " not defined in base api for ", 445 base_api_def->graph_op_name()); 446 } 447 } 448 // Merge summary 449 if (!new_api_def.summary().empty()) { 450 base_api_def->set_summary(new_api_def.summary()); 451 } 452 // Merge description 453 auto description = new_api_def.description().empty() 454 ? base_api_def->description() 455 : new_api_def.description(); 456 457 if (!new_api_def.description_prefix().empty()) { 458 description = 459 strings::StrCat(new_api_def.description_prefix(), "\n", description); 460 } 461 if (!new_api_def.description_suffix().empty()) { 462 description = 463 strings::StrCat(description, "\n", new_api_def.description_suffix()); 464 } 465 base_api_def->set_description(description); 466 return Status::OK(); 467 } 468 } // namespace 469 470 ApiDefMap::ApiDefMap(const OpList& op_list) { 471 for (const auto& op : op_list.op()) { 472 ApiDef api_def; 473 InitApiDefFromOpDef(op, &api_def); 474 map_[op.name()] = api_def; 475 } 476 } 477 478 ApiDefMap::~ApiDefMap() {} 479 480 Status ApiDefMap::LoadFileList(Env* env, const std::vector<string>& filenames) { 481 for (const auto& filename : filenames) { 482 TF_RETURN_IF_ERROR(LoadFile(env, filename)); 483 } 484 return Status::OK(); 485 } 486 487 Status ApiDefMap::LoadFile(Env* env, const string& filename) { 488 if (filename.empty()) return Status::OK(); 489 string contents; 490 TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents)); 491 Status status = LoadApiDef(contents); 492 if (!status.ok()) { 493 // Return failed status annotated with filename to aid in debugging. 494 return Status(status.code(), 495 strings::StrCat("Error parsing ApiDef file ", filename, ": ", 496 status.error_message())); 497 } 498 return Status::OK(); 499 } 500 501 Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) { 502 const string contents = PBTxtFromMultiline(api_def_file_contents); 503 ApiDefs api_defs; 504 TF_RETURN_IF_ERROR( 505 proto_utils::ParseTextFormatFromString(contents, &api_defs)); 506 for (const auto& api_def : api_defs.op()) { 507 // Check if the op definition is loaded. If op definition is not 508 // loaded, then we just skip this ApiDef. 509 if (map_.find(api_def.graph_op_name()) != map_.end()) { 510 // Overwrite current api def with data in api_def. 511 TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def)); 512 } 513 } 514 return Status::OK(); 515 } 516 517 void ApiDefMap::UpdateDocs() { 518 for (auto& name_and_api_def : map_) { 519 auto& api_def = name_and_api_def.second; 520 CHECK_GT(api_def.endpoint_size(), 0); 521 const string canonical_name = api_def.endpoint(0).name(); 522 if (api_def.graph_op_name() != canonical_name) { 523 RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def); 524 } 525 for (const auto& in_arg : api_def.in_arg()) { 526 if (in_arg.name() != in_arg.rename_to()) { 527 RenameInDocs(in_arg.name(), in_arg.rename_to(), &api_def); 528 } 529 } 530 for (const auto& out_arg : api_def.out_arg()) { 531 if (out_arg.name() != out_arg.rename_to()) { 532 RenameInDocs(out_arg.name(), out_arg.rename_to(), &api_def); 533 } 534 } 535 for (const auto& attr : api_def.attr()) { 536 if (attr.name() != attr.rename_to()) { 537 RenameInDocs(attr.name(), attr.rename_to(), &api_def); 538 } 539 } 540 } 541 } 542 543 const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const { 544 return gtl::FindOrNull(map_, name); 545 } 546 } // namespace tensorflow 547