1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/grappler/utils/functions.h" 16 17 #include "absl/container/flat_hash_map.h" 18 #include "absl/container/flat_hash_set.h" 19 #include "absl/strings/str_cat.h" 20 #include "absl/strings/substitute.h" 21 #include "tensorflow/core/framework/attr_value.pb.h" 22 #include "tensorflow/core/framework/function.h" 23 #include "tensorflow/core/framework/function.pb.h" 24 #include "tensorflow/core/framework/graph_def_util.h" 25 #include "tensorflow/core/framework/node_def.pb.h" 26 #include "tensorflow/core/framework/op.h" 27 #include "tensorflow/core/framework/tensor_shape.pb.h" 28 #include "tensorflow/core/framework/types.pb.h" 29 #include "tensorflow/core/framework/versions.pb.h" 30 #include "tensorflow/core/grappler/op_types.h" 31 #include "tensorflow/core/grappler/utils.h" 32 #include "tensorflow/core/lib/strings/scanner.h" 33 34 namespace tensorflow { 35 namespace grappler { 36 37 namespace { 38 39 Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration, 40 const NodeDef& node, 41 GrapplerFunctionConnectivity* connectivity) { 42 tensorflow::NameRangeMap outputs_range_map; 43 TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode( 44 node, registration.op_def, nullptr, &outputs_range_map)); 45 connectivity->RegisterFunctionBodyOutputs(node.name(), 46 std::move(outputs_range_map)); 47 return Status::OK(); 48 } 49 50 Status RegisterFunctionBodyOutputs(const FunctionLibraryDefinition& flib, 51 const NodeDef& node, 52 GrapplerFunctionConnectivity* connectivity) { 53 const OpRegistrationData* registration; 54 TF_RETURN_IF_ERROR(flib.LookUp(node.op(), ®istration)); 55 return RegisterFunctionBodyOutputs(*registration, node, connectivity); 56 } 57 58 // Replace the placeholder attribute values with the values specified in 59 // instantiation attributes. 60 Status ResolveFunctionBodyNodeAttrPlaceholders( 61 const AttrSlice& func_instantiation_attr, NodeDef* node) { 62 for (auto& attr : *node->mutable_attr()) { 63 const string& placeholder = attr.second.placeholder(); 64 if (placeholder.empty()) continue; 65 66 const AttrValue* attr_value = func_instantiation_attr.Find(placeholder); 67 if (attr_value) { 68 attr.second = *attr_value; 69 } else { 70 return errors::InvalidArgument("Can't resolve placeholder: ", 71 placeholder); 72 } 73 } 74 return Status::OK(); 75 } 76 77 } // namespace 78 79 void GrapplerFunctionConnectivity::RegisterInputArgExpansion( 80 InputArgExpansion input_arg_expansion) { 81 string input_name = input_arg_expansion.input_name; 82 const auto& placeholders = input_arg_expansion.placeholders; 83 84 for (int i = 0; i < placeholders.size(); ++i) { 85 const string& placeholder = input_arg_expansion.placeholders[i]; 86 input_arg_placeholders_.insert( 87 {placeholder, InputArgPlaceholder{input_name, /*input_index=*/i}}); 88 } 89 input_arg_expansions_.insert( 90 {std::move(input_name), std::move(input_arg_expansion)}); 91 } 92 93 void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs( 94 const string& node_name, tensorflow::NameRangeMap&& outputs) { 95 function_body_outputs_[node_name] = std::move(outputs); 96 } 97 98 Status GrapplerFunctionConnectivity::ExpandFunctionDefInput( 99 const string& func_def_input, std::vector<string>* graph_def_inputs) const { 100 using ::tensorflow::strings::Scanner; 101 102 if (IsControlInput(func_def_input)) { 103 graph_def_inputs->push_back(func_def_input); 104 return Status::OK(); 105 } 106 107 // Parse input format: "node_name[:node_output][:position]" 108 string node_name; 109 string node_output; 110 int position = -1; 111 112 StringPiece capture; 113 StringPiece remaining; 114 115 // Parse "node_name" 116 if (Scanner(func_def_input) 117 .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE) 118 .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) 119 .GetResult(&remaining, &capture)) { 120 node_name = string(capture.data(), capture.size()); 121 } 122 123 // Parse "node_output" if it exists 124 if (Scanner(remaining) 125 .OneLiteral(":") 126 .RestartCapture() 127 .One(strings::Scanner::LETTER) 128 .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE) 129 .GetResult(&remaining, &capture)) { 130 node_output = string(capture.data(), capture.size()); 131 } 132 133 // Parse "position" if it exists 134 if (Scanner(remaining) 135 .OneLiteral(":") 136 .RestartCapture() 137 .Many(strings::Scanner::DIGIT) 138 .GetResult(nullptr, &capture)) { 139 CHECK(strings::safe_strto32(capture, &position)); 140 } 141 142 // If "node_output" is not empty, it must be an output of a function body node 143 bool is_function_body_output = !node_output.empty(); 144 145 // Function input argument: "node_name[:position]" 146 if (!is_function_body_output) { 147 auto input_arg = input_arg_expansions_.find(node_name); 148 if (input_arg != input_arg_expansions_.end()) { 149 const InputArgExpansion& input_arg_expansion = input_arg->second; 150 const auto& placeholders = input_arg_expansion.placeholders; 151 152 if (position == -1) { 153 // If position is not defined use all placeholders 154 graph_def_inputs->reserve(placeholders.size()); 155 for (const string& placeholder : placeholders) { 156 graph_def_inputs->push_back(placeholder); 157 } 158 } else { 159 if (position > input_arg_expansion.placeholders.size() - 1) { 160 return errors::InvalidArgument("Invalid input ", node_name, 161 "position: ", position, 162 " (out of range)"); 163 } 164 graph_def_inputs->push_back(input_arg_expansion.placeholders[position]); 165 } 166 167 return Status::OK(); 168 } 169 } 170 171 // Function body output: "node_name:node_output[:position]" 172 if (is_function_body_output) { 173 auto function_body_outputs = function_body_outputs_.find(node_name); 174 if (function_body_outputs != function_body_outputs_.end()) { 175 const tensorflow::NameRangeMap& outputs = function_body_outputs->second; 176 auto output = outputs.find(node_output); 177 if (output != outputs.end()) { 178 const auto& output_range = output->second; 179 180 if (position == -1) { 181 graph_def_inputs->reserve(graph_def_inputs->size() + 182 output_range.second - output_range.first); 183 // If position is not defined expand node output range 184 for (int i = output_range.first; i < output_range.second; ++i) { 185 graph_def_inputs->push_back( 186 i == 0 ? node_name : absl::StrCat(node_name, ":", i)); 187 } 188 } else { 189 if (position > (output_range.second - output_range.first)) { 190 return errors::InvalidArgument( 191 "Invalid node ", node_name, " output ", node_output, 192 " position: ", position, " (out of range)"); 193 } 194 int pos = output_range.first + position; 195 graph_def_inputs->push_back( 196 pos == 0 ? node_name : absl::StrCat(node_name, ":", pos)); 197 } 198 199 return Status::OK(); 200 } 201 } 202 } 203 204 return errors::InvalidArgument("Failed to expand a function def input: ", 205 func_def_input); 206 } 207 208 Status GrapplerFunctionConnectivity::ExpandNodeInputs( 209 NodeDef* function_body_node) const { 210 std::vector<string> expanded_inputs; 211 212 for (const string& function_def_input : function_body_node->input()) { 213 TF_RETURN_IF_ERROR( 214 ExpandFunctionDefInput(function_def_input, &expanded_inputs)); 215 } 216 217 function_body_node->clear_input(); 218 for (string& expanded_input : expanded_inputs) 219 function_body_node->add_input(std::move(expanded_input)); 220 return Status::OK(); 221 } 222 223 Status GrapplerFunctionConnectivity::AsFunctionDefInput( 224 const string& graph_def_input, string* func_def_input) const { 225 if (IsControlInput(graph_def_input)) { 226 *func_def_input = graph_def_input; 227 return Status::OK(); 228 } 229 230 const TensorId tensor = ParseTensorName(graph_def_input); 231 DCHECK_GE(tensor.index(), 0); 232 233 const absl::string_view node_name = tensor.node(); 234 const int index = tensor.index(); 235 236 // Check if it's an input arg placeholder 237 if (tensor.index() == 0) { 238 const auto is_input_placeholder = input_arg_placeholders_.find(node_name); 239 if (is_input_placeholder != input_arg_placeholders_.end()) { 240 const InputArgPlaceholder& placeholder = is_input_placeholder->second; 241 *func_def_input = 242 absl::StrCat(placeholder.input_name, ":", placeholder.input_index); 243 return Status::OK(); 244 } 245 } 246 247 // It must be output from one of the function body nodes 248 const auto is_body_output = function_body_outputs_.find(tensor.node()); 249 if (is_body_output != function_body_outputs_.end()) { 250 const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second; 251 252 for (const auto& el : outputs_range_map) { 253 const auto& output_name = el.first; 254 const auto& output_range = el.second; 255 if (index >= output_range.first && index < output_range.second) { 256 int pos = index - output_range.first; 257 *func_def_input = absl::StrCat(node_name, ":", output_name, ":", pos); 258 return Status::OK(); 259 } 260 } 261 } 262 263 return errors::InvalidArgument("Unknown graph def input: ", graph_def_input); 264 } 265 266 Status GrapplerFunctionConnectivity::AsFunctionDefNode( 267 NodeDef* function_body_node) const { 268 string func_def_input; 269 270 for (int i = 0; i < function_body_node->input_size(); ++i) { 271 TF_RETURN_IF_ERROR( 272 AsFunctionDefInput(function_body_node->input(i), &func_def_input)); 273 function_body_node->set_input(i, func_def_input); 274 } 275 276 return Status::OK(); 277 } 278 279 Status GrapplerFunctionItemInstantiation::GetTypeAttr( 280 const string& type_attr_name, DataType* data_type) const { 281 const AttrValue* type_attr = func_instantiation_attr_.Find(type_attr_name); 282 if (type_attr == nullptr) { 283 return errors::InvalidArgument("Type attribute ", type_attr_name, 284 " is not defined"); 285 } else if (type_attr->type() == DT_INVALID) { 286 return errors::InvalidArgument("Type attribute ", type_attr_name, 287 " is not defined with a valid type"); 288 } else { 289 *data_type = type_attr->type(); 290 } 291 return Status::OK(); 292 } 293 294 Status GrapplerFunctionItemInstantiation::GetArgType( 295 const OpDef::ArgDef& arg, DataType* data_type) const { 296 if (arg.type() != DT_INVALID) { 297 *data_type = arg.type(); 298 } else { 299 if (!arg.type_list_attr().empty() || !arg.number_attr().empty()) { 300 return errors::InvalidArgument( 301 "Arguments with sequence of tensors are not supported. Unsupported " 302 "argument name: ", 303 arg.name()); 304 } 305 TF_RETURN_IF_ERROR(GetTypeAttr(arg.type_attr(), data_type)); 306 } 307 return Status::OK(); 308 } 309 310 GrapplerFunctionItem::GrapplerFunctionItem( 311 string func_name, string description, AttrSlice func_attr, 312 std::vector<InputArgExpansion> input_arg_expansions, 313 std::vector<OutputArgExpansion> output_arg_expansions, 314 std::vector<ControlOutput> control_outputs, const int graph_def_version, 315 const bool is_stateful, GraphDef&& function_body) 316 : description_(std::move(description)), 317 func_attr_(func_attr), 318 input_arg_expansions_(std::move(input_arg_expansions)), 319 output_arg_expansions_(std::move(output_arg_expansions)), 320 control_outputs_(std::move(control_outputs)), 321 is_stateful_(is_stateful) { 322 id = std::move(func_name); 323 graph = std::move(function_body); 324 325 graph.mutable_versions()->set_producer(graph_def_version); 326 // Fill the feed nodes with input placeholders. 327 for (const InputArgExpansion& input_arg : input_arg_expansions_) { 328 for (const string& placeholder : input_arg.placeholders) { 329 feed.push_back({placeholder, Tensor()}); 330 } 331 } 332 // Fill the fetch nodes with outputs. 333 for (const OutputArgExpansion& output_arg : output_arg_expansions_) { 334 for (const string& output_node : output_arg.output_nodes) { 335 fetch.push_back(output_node); 336 } 337 } 338 // We must keep all control output nodes. 339 for (const ControlOutput& control_output : control_outputs_) { 340 keep_ops.push_back(control_output.node_name); 341 } 342 343 // Tensorflow functions execution semantics is different from the main graph, 344 // and we need to preserve it when we do graph optimizations. 345 optimization_options().allow_pruning_stateful_and_dataset_ops = false; 346 } 347 348 const string& GrapplerFunctionItem::description() const { return description_; } 349 350 const std::vector<InputArgExpansion>& GrapplerFunctionItem::inputs() const { 351 return input_arg_expansions_; 352 } 353 354 const InputArgExpansion& GrapplerFunctionItem::input(int i) const { 355 return input_arg_expansions_[i]; 356 } 357 358 const std::size_t GrapplerFunctionItem::input_size() const { 359 return input_arg_expansions_.size(); 360 } 361 362 const std::vector<OutputArgExpansion>& GrapplerFunctionItem::outputs() const { 363 return output_arg_expansions_; 364 } 365 366 const OutputArgExpansion& GrapplerFunctionItem::output(int i) const { 367 return output_arg_expansions_[i]; 368 } 369 370 const std::size_t GrapplerFunctionItem::output_size() const { 371 return output_arg_expansions_.size(); 372 } 373 374 const std::vector<ControlOutput>& GrapplerFunctionItem::control_outputs() 375 const { 376 return control_outputs_; 377 } 378 379 const std::size_t GrapplerFunctionItem::control_output_size() const { 380 return control_outputs_.size(); 381 } 382 383 const AttrSlice& GrapplerFunctionItem::func_attr() const { return func_attr_; } 384 385 const GraphDef& GrapplerFunctionItem::function_body() const { return graph; } 386 387 GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; } 388 389 bool GrapplerFunctionItem::is_stateful() const { return is_stateful_; } 390 391 GrapplerFunctionItem& GrapplerFunctionItem::SwapFunctionBody(GraphDef&& other) { 392 graph.Swap(&other); 393 return *this; 394 } 395 396 bool HasParametrizedType(const FunctionDef& func) { 397 const auto is_type_parametrized = [](const OpDef::ArgDef& arg) { 398 return !arg.type_attr().empty() || !arg.number_attr().empty() || 399 !arg.type_list_attr().empty(); 400 }; 401 402 const auto& input = func.signature().input_arg(); 403 const auto& output = func.signature().output_arg(); 404 return std::any_of(input.begin(), input.end(), is_type_parametrized) || 405 std::any_of(output.begin(), output.end(), is_type_parametrized); 406 } 407 408 bool HasParametrizedBody(const FunctionDef& func) { 409 const auto is_parametrized = [&](const NodeDef& node) { 410 for (const auto& attr : node.attr()) { 411 if (!attr.second.placeholder().empty()) return true; 412 } 413 return false; 414 }; 415 return std::any_of(func.node_def().begin(), func.node_def().end(), 416 is_parametrized); 417 } 418 419 bool IsParametrized(const FunctionDef& func) { 420 return HasParametrizedType(func) || HasParametrizedBody(func); 421 } 422 423 Status InstantiationTypeParameters( 424 const FunctionDef& func, const AttrSlice& func_instantiation_attr, 425 absl::flat_hash_map<string, DataType>* type_parameters) { 426 if (!type_parameters->empty()) { 427 return errors::InvalidArgument("Type parameters output map must be empty"); 428 } 429 430 GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr); 431 432 const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) { 433 // Check if it's unknown and unresolved type. 434 if (arg.type() == DT_INVALID && 435 type_parameters->find(arg.type_attr()) == type_parameters->end()) { 436 DataType data_type; 437 TF_RETURN_IF_ERROR(instantiation.GetArgType(arg, &data_type)); 438 type_parameters->insert({arg.type_attr(), data_type}); 439 } 440 return Status::OK(); 441 }; 442 443 for (const auto& input : func.signature().input_arg()) 444 TF_RETURN_IF_ERROR(resolve_type_attr(input)); 445 for (const auto& output : func.signature().output_arg()) 446 TF_RETURN_IF_ERROR(resolve_type_attr(output)); 447 448 return Status::OK(); 449 } 450 451 Status InstantiationBodyParameters( 452 const FunctionDef& func, const AttrSlice& func_instantiation_attr, 453 absl::flat_hash_map<string, AttrValue>* body_parameters) { 454 if (!body_parameters->empty()) { 455 return errors::InvalidArgument("Body parameters output map must be empty"); 456 } 457 458 for (const NodeDef& func_body_node : func.node_def()) { 459 for (auto& attr : func_body_node.attr()) { 460 const string& placeholder = attr.second.placeholder(); 461 462 if (placeholder.empty() || 463 body_parameters->find(placeholder) != body_parameters->end()) { 464 continue; 465 } 466 467 const AttrValue* placeholder_value = 468 func_instantiation_attr.Find(placeholder); 469 if (placeholder_value) { 470 body_parameters->insert({placeholder, *placeholder_value}); 471 } else { 472 return errors::InvalidArgument("Can't resolve placeholder: ", 473 placeholder); 474 } 475 } 476 } 477 478 return Status::OK(); 479 } 480 481 Status MakeGrapplerFunctionItem(const FunctionDef& func, 482 const AttrSlice& func_instantiation_attr, 483 const FunctionLibraryDefinition& flib, 484 const int graph_def_version, 485 GrapplerFunctionItem* item) { 486 const OpDef& signature = func.signature(); 487 488 if (signature.name().empty()) { 489 return errors::InvalidArgument("Function name must be specified"); 490 } 491 492 // Function types will be resolved from function instantiation attributes. All 493 // other attributes will be lost during conversion to FunctionDef. 494 for (const OpDef::AttrDef& attr : signature.attr()) { 495 if (attr.type() != "type") { 496 return errors::InvalidArgument( 497 "Function signature must have only type attributes"); 498 } 499 } 500 501 // Helper methods to lookup function instantiation attributes 502 GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr); 503 504 // Mapping from FunctionDef input format (name[:output][:position]) to 505 // GraphDef input format (name[:position]) 506 GrapplerFunctionConnectivity connectivity; 507 508 // Instantiate function body into a statically defined graph def. 509 GraphDef function_body; 510 511 // Function body shares the library with the graph that instantiated it. We do 512 // not need a full copy of the function library, just the reachable subset. 513 *function_body.mutable_library() = flib.ReachableDefinitions(func).ToProto(); 514 515 VLOG(3) << absl::Substitute( 516 "Deleted $0 unreachable functions from the Grappler function item " 517 "instantiation of $1 (library size = $2)", 518 flib.num_functions() - function_body.library().function_size(), 519 signature.name(), function_body.library().function_size()); 520 521 // TODO(ezhulenev): support functions with tensor sequence inputs/outputs 522 523 // Make sure that there are no tensor lists in inputs or outputs. 524 for (const OpDef::ArgDef& input : signature.input_arg()) { 525 if (!input.type_list_attr().empty() || !input.number_attr().empty()) { 526 return errors::InvalidArgument( 527 "Inputs with lists of tensors are not supported. Input: ", 528 input.name()); 529 } 530 } 531 for (const OpDef::ArgDef& output : signature.output_arg()) { 532 if (!output.type_list_attr().empty() || !output.number_attr().empty()) { 533 return errors::InvalidArgument( 534 "Outputs with lists of tensors are not supported. Output: ", 535 output.name()); 536 } 537 } 538 539 std::vector<InputArgExpansion> inputs; 540 inputs.reserve(signature.input_arg_size()); 541 542 // For each input argument create a placeholder in function body. 543 for (const OpDef::ArgDef& input : signature.input_arg()) { 544 DataType input_data_type; 545 TF_RETURN_IF_ERROR(instantiation.GetArgType(input, &input_data_type)); 546 547 NodeDef* placeholder = function_body.add_node(); 548 placeholder->set_name(input.name()); 549 placeholder->set_op("Placeholder"); 550 (*placeholder->mutable_attr())["dtype"].set_type(input_data_type); 551 (*placeholder->mutable_attr())["shape"].mutable_shape()->set_unknown_rank( 552 true); 553 554 InputArgExpansion input_expansion{/*input_name=*/input.name(), 555 /*data_type=*/input_data_type, 556 /*is_ref=*/input.is_ref(), 557 /*placeholders=*/{input.name()}}; 558 connectivity.RegisterInputArgExpansion(input_expansion); 559 inputs.push_back(std::move(input_expansion)); 560 } 561 562 // Keep names of all nodes in the function body to guarantee that we do not 563 // add an identity with a duplicate name. 564 absl::flat_hash_set<absl::string_view> func_body_nodes; 565 566 // Generate unique output node name: "${out_arg_name}_output_node_${index}". 567 const auto output_node_name = [&func_body_nodes](const OpDef::ArgDef& out, 568 int index) -> string { 569 string name = absl::StrCat(out.name(), "_output_node_", index); 570 int i = 1; 571 while (func_body_nodes.find(name) != func_body_nodes.end()) { 572 name = absl::StrCat(out.name(), "_output_node_", index, "_", i++); 573 } 574 return name; 575 }; 576 577 // Add all function nodes to the function body. 578 for (const NodeDef& func_def_node : func.node_def()) { 579 func_body_nodes.insert(func_def_node.name()); 580 581 NodeDef* new_node = function_body.add_node(); 582 *new_node = func_def_node; 583 584 const OpRegistrationData* registration; 585 TF_RETURN_IF_ERROR(flib.LookUp(func_def_node.op(), ®istration)); 586 587 // Resolve all placeholder values using function instantiation attributes. 588 TF_RETURN_IF_ERROR(ResolveFunctionBodyNodeAttrPlaceholders( 589 func_instantiation_attr, new_node)); 590 591 // Register node output range in a function connectivity. 592 TF_RETURN_IF_ERROR(RegisterFunctionBodyOutputs(*registration, func_def_node, 593 &connectivity)); 594 } 595 596 // Rewrite inputs to use GraphDef format 597 for (NodeDef& node : *function_body.mutable_node()) { 598 TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node)); 599 } 600 601 std::vector<OutputArgExpansion> outputs; 602 outputs.reserve(signature.output_arg_size()); 603 604 // For each function output argument we create an Identity node in the 605 // function body, that reads output tensor from the function body node. 606 for (const OpDef::ArgDef& out : signature.output_arg()) { 607 DataType output_data_type; 608 TF_RETURN_IF_ERROR(instantiation.GetArgType(out, &output_data_type)); 609 610 std::vector<string> output_tensors; 611 auto ret = func.ret().find(out.name()); 612 TF_RETURN_IF_ERROR( 613 ret != func.ret().end() 614 // Expand outputs using provided output mapping 615 ? connectivity.ExpandFunctionDefInput(ret->second, &output_tensors) 616 // Otherwise output must be one of the function inputs 617 : connectivity.ExpandFunctionDefInput(out.name(), &output_tensors)); 618 619 absl::InlinedVector<string, 1> output_nodes; 620 for (int i = 0; i < output_tensors.size(); ++i) { 621 const string& output_tensor = output_tensors[i]; 622 623 NodeDef* identity = function_body.add_node(); 624 identity->set_name(output_node_name(out, i)); 625 identity->set_op("Identity"); 626 (*identity->mutable_attr())["T"].set_type(output_data_type); 627 identity->add_input(output_tensor); 628 629 output_nodes.push_back(identity->name()); 630 } 631 632 OutputArgExpansion output{/*output_name=*/out.name(), 633 /*data_type=*/output_data_type, 634 /*is_ref=*/out.is_ref(), 635 /*output_nodes=*/std::move(output_nodes)}; 636 outputs.push_back(std::move(output)); 637 } 638 639 // Control outputs ensure that all side-effectful nodes in the function body 640 // will execute, even if they are not required to compute regular output args. 641 std::vector<ControlOutput> control_outputs; 642 control_outputs.reserve(func.control_ret_size()); 643 for (const auto& control_ret : func.control_ret()) { 644 control_outputs.push_back({control_ret.first, control_ret.second}); 645 } 646 647 *item = GrapplerFunctionItem( 648 /*func_name=*/signature.name(), 649 /*description=*/signature.description(), 650 /*func_attr=*/AttrSlice(&func.attr()), std::move(inputs), 651 std::move(outputs), std::move(control_outputs), graph_def_version, 652 signature.is_stateful(), std::move(function_body)); 653 return Status::OK(); 654 } 655 656 Status MakeGrapplerFunctionItem(const FunctionDef& func, 657 const FunctionLibraryDefinition& flib, 658 const int graph_def_version, 659 GrapplerFunctionItem* item) { 660 return MakeGrapplerFunctionItem(func, AttrSlice(), flib, graph_def_version, 661 item); 662 } 663 664 // Register GrapplerFunctionItem input arg expansion and function body outputs 665 // in the GrapplerFunctionConnectivity. 666 Status RegisterGrapplerFunctionConnectivity( 667 const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib, 668 GrapplerFunctionConnectivity* connectivity) { 669 for (const InputArgExpansion& input : item.inputs()) { 670 connectivity->RegisterInputArgExpansion(input); 671 } 672 for (const NodeDef& func_body_node : item.function_body().node()) { 673 TF_RETURN_IF_ERROR( 674 RegisterFunctionBodyOutputs(flib, func_body_node, connectivity)); 675 } 676 return Status::OK(); 677 } 678 679 Status ReplaceInputWithConst(const NodeDef& input_const, int input_index, 680 GrapplerFunctionItem* item) { 681 if (!IsConstant(input_const)) { 682 return errors::InvalidArgument("Input node ", input_const.name(), 683 " is not a constant"); 684 } 685 686 auto& inputs = item->input_arg_expansions_; 687 688 // Find input arg expansion and input placeholder position in it for the 689 // given function input position. 690 InputArgExpansion* input_arg_expansion = nullptr; 691 int placeholder_idx = input_index; 692 693 for (InputArgExpansion& input : inputs) { 694 if (placeholder_idx < input.placeholders.size()) { 695 input_arg_expansion = &input; 696 break; 697 } 698 placeholder_idx -= input.placeholders.size(); 699 } 700 701 if (input_arg_expansion == nullptr) { 702 return errors::InvalidArgument("Input placeholder not found: input_index=", 703 input_index, " function=", item->id); 704 } 705 706 // Delete placeholder from input expansion. 707 string placeholder_name = input_arg_expansion->placeholders[placeholder_idx]; 708 input_arg_expansion->placeholders.erase( 709 input_arg_expansion->placeholders.begin() + placeholder_idx); 710 711 // Delete empty input expansions. 712 inputs.erase(std::remove_if(inputs.begin(), inputs.end(), 713 [](const InputArgExpansion& input) { 714 return input.placeholders.empty(); 715 }), 716 inputs.end()); 717 718 // Replace placeholder node in the function body with a const node. 719 for (NodeDef& node : *item->graph.mutable_node()) { 720 if (node.name() == placeholder_name) { 721 node = input_const; 722 node.set_name(placeholder_name); 723 node.clear_input(); // remove potential control inputs 724 node.clear_device(); // device placement is defined by instantiating node 725 } 726 } 727 728 return Status::OK(); 729 } 730 731 Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs, 732 GrapplerFunctionItem* item, 733 std::vector<std::pair<int, int>>* output_mapping) { 734 DCHECK(output_mapping->empty()); 735 736 // Code below assumes that we do not support tensor list outputs and there is 737 // a 1-to-1 mapping between output tensor and output argument expansion. 738 for (const OutputArgExpansion& out_arg : item->outputs()) { 739 DCHECK(out_arg.output_nodes.size() == 1) 740 << "Output arg expansion must have single output"; 741 } 742 743 // Do some sanity checking of the removed outputs positions. 744 for (int remove_output : remove_outputs) { 745 if (remove_output < 0 || remove_output >= item->output_size()) { 746 return errors::InvalidArgument( 747 "Function output index is out of bound: index=", remove_output, 748 " max_output_index=", item->output_size()); 749 } 750 } 751 752 absl::flat_hash_set<const OutputArgExpansion*> remove_output_args; 753 const auto is_remove_output_arg = [&](const OutputArgExpansion& output) { 754 return remove_output_args.find(&output) != remove_output_args.end(); 755 }; 756 757 for (int i = 0; i < item->output_size(); ++i) { 758 const OutputArgExpansion& output = item->output(i); 759 if (remove_outputs.find(i) != remove_outputs.end()) { 760 VLOG(3) << "Remove functions output: output_name=" << output.output_name 761 << "(index = " << i << ")"; 762 remove_output_args.insert(&output); 763 } else if (!remove_output_args.empty()) { 764 // Add output mapping only if output position changed. 765 output_mapping->push_back({i, i - remove_output_args.size()}); 766 } 767 } 768 769 auto& o = item->output_arg_expansions_; 770 o.erase(std::remove_if(o.begin(), o.end(), is_remove_output_arg), o.end()); 771 772 return Status::OK(); 773 } 774 775 Status MakeFunctionDef(const GrapplerFunctionItem& item, 776 const FunctionLibraryDefinition& flib, 777 FunctionDef* func) { 778 func->mutable_signature()->set_name(item.id); 779 func->mutable_signature()->set_description(item.description()); 780 func->mutable_signature()->set_is_stateful(item.is_stateful()); 781 782 // Keep track of placeholders that were added to the graph in place of 783 // expanded function input arguments. 784 absl::flat_hash_set<absl::string_view> input_placeholders; 785 for (const InputArgExpansion& input_arg : item.inputs()) { 786 for (const string& placeholder : input_arg.placeholders) { 787 input_placeholders.insert(placeholder); 788 } 789 } 790 791 // Keep track of identity nodes that were added to the graph in place of 792 // expanded function output arguments. 793 absl::flat_hash_set<absl::string_view> output_nodes; 794 for (const OutputArgExpansion& output_arg : item.outputs()) { 795 for (const string& output_node : output_arg.output_nodes) { 796 output_nodes.insert(output_node); 797 } 798 } 799 800 // If the output identity node was not modified by any optimizer, we can 801 // bypass it and returns the function value from its input. 802 absl::flat_hash_map<absl::string_view, string> output_tensors; 803 for (const NodeDef& func_body_node : item.function_body().node()) { 804 if (!IsIdentity(func_body_node)) continue; 805 806 const string& node_name = func_body_node.name(); 807 if (output_nodes.find(node_name) != output_nodes.end()) { 808 // Grappler optimizers might optimize nodes in the fanin of the output 809 // node, and forward their control dependencies. We can't express control 810 // dependencies in a function signature, so we have to keep the node. 811 if (func_body_node.input_size() == 1) { 812 VLOG(3) << "Bypass function output node: " << node_name << " -> " 813 << func_body_node.input(0); 814 output_tensors.emplace(node_name, func_body_node.input(0)); 815 } else { 816 VLOG(3) << "Keep function output node: " << node_name; 817 } 818 } 819 } 820 821 // Return output tensor name (input of the output node) if it's safe to bypass 822 // output node, otherwise returns the output node name. 823 const auto output_tensor = 824 [&output_tensors](const OutputArgExpansion& output_arg) -> const string& { 825 const string& output_node = output_arg.output_nodes[0]; 826 const auto is_output_tensor = output_tensors.find(output_node); 827 return is_output_tensor == output_tensors.end() ? output_node 828 : is_output_tensor->second; 829 }; 830 831 // Build a GrapplerFunctionConnectivity from inputs and new function body. 832 GrapplerFunctionConnectivity connectivity; 833 TF_RETURN_IF_ERROR( 834 RegisterGrapplerFunctionConnectivity(item, flib, &connectivity)); 835 836 // Add function input arguments. 837 for (const InputArgExpansion& input_arg : item.inputs()) { 838 DCHECK(input_arg.placeholders.size() == 1) // do some sanity checking 839 << "Inputs of tensor lists are not supported"; 840 841 OpDef::ArgDef arg_def; 842 arg_def.set_name(input_arg.input_name); 843 arg_def.set_type(input_arg.data_type); 844 arg_def.set_is_ref(input_arg.is_ref); 845 *func->mutable_signature()->add_input_arg() = arg_def; 846 } 847 848 // Add function output arguments. 849 for (const OutputArgExpansion& output_arg : item.outputs()) { 850 DCHECK(output_arg.output_nodes.size() == 1) // do some sanity checking 851 << "Outputs of tensor lists are not supported"; 852 853 OpDef::ArgDef arg_def; 854 arg_def.set_name(output_arg.output_name); 855 arg_def.set_type(output_arg.data_type); 856 arg_def.set_is_ref(output_arg.is_ref); 857 *func->mutable_signature()->add_output_arg() = arg_def; 858 859 TF_RETURN_IF_ERROR(connectivity.AsFunctionDefInput( 860 output_tensor(output_arg), 861 &(*func->mutable_ret())[output_arg.output_name])); 862 } 863 864 // Add function control outputs. 865 for (const ControlOutput& control_out : item.control_outputs()) { 866 func->mutable_control_ret()->insert( 867 {control_out.output_name, control_out.node_name}); 868 *func->mutable_signature()->add_control_output() = control_out.output_name; 869 } 870 871 // Copy function definition specific attributes. 872 for (const auto& attr : item.func_attr()) { 873 const auto& attr_name = attr.first; 874 const auto& attr_value = attr.second; 875 (*func->mutable_attr())[attr_name] = attr_value; 876 } 877 878 // Copy function body nodes to the FunctionDef and update input format 879 for (const NodeDef& func_node : item.function_body().node()) { 880 const string& name = func_node.name(); 881 882 // Do not copy input placeholders. 883 if (IsPlaceholder(func_node) && input_placeholders.count(name)) continue; 884 // Do not copy output nodes that we bypassed. 885 if (IsIdentity(func_node) && output_tensors.count(name)) continue; 886 887 NodeDef* func_def_node = func->add_node_def(); 888 *func_def_node = func_node; 889 TF_RETURN_IF_ERROR(connectivity.AsFunctionDefNode(func_def_node)); 890 } 891 892 return Status::OK(); 893 } 894 895 } // end namespace grappler 896 } // end namespace tensorflow 897