Home | History | Annotate | Download | only in utils
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/core/grappler/utils/functions.h"
     16 
     17 #include "absl/container/flat_hash_map.h"
     18 #include "absl/container/flat_hash_set.h"
     19 #include "absl/strings/str_cat.h"
     20 #include "absl/strings/substitute.h"
     21 #include "tensorflow/core/framework/attr_value.pb.h"
     22 #include "tensorflow/core/framework/function.h"
     23 #include "tensorflow/core/framework/function.pb.h"
     24 #include "tensorflow/core/framework/graph_def_util.h"
     25 #include "tensorflow/core/framework/node_def.pb.h"
     26 #include "tensorflow/core/framework/op.h"
     27 #include "tensorflow/core/framework/tensor_shape.pb.h"
     28 #include "tensorflow/core/framework/types.pb.h"
     29 #include "tensorflow/core/framework/versions.pb.h"
     30 #include "tensorflow/core/grappler/op_types.h"
     31 #include "tensorflow/core/grappler/utils.h"
     32 #include "tensorflow/core/lib/strings/scanner.h"
     33 
     34 namespace tensorflow {
     35 namespace grappler {
     36 
     37 namespace {
     38 
     39 Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration,
     40                                    const NodeDef& node,
     41                                    GrapplerFunctionConnectivity* connectivity) {
     42   tensorflow::NameRangeMap outputs_range_map;
     43   TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
     44       node, registration.op_def, nullptr, &outputs_range_map));
     45   connectivity->RegisterFunctionBodyOutputs(node.name(),
     46                                             std::move(outputs_range_map));
     47   return Status::OK();
     48 }
     49 
     50 Status RegisterFunctionBodyOutputs(const FunctionLibraryDefinition& flib,
     51                                    const NodeDef& node,
     52                                    GrapplerFunctionConnectivity* connectivity) {
     53   const OpRegistrationData* registration;
     54   TF_RETURN_IF_ERROR(flib.LookUp(node.op(), &registration));
     55   return RegisterFunctionBodyOutputs(*registration, node, connectivity);
     56 }
     57 
     58 // Replace the placeholder attribute values with the values specified in
     59 // instantiation attributes.
     60 Status ResolveFunctionBodyNodeAttrPlaceholders(
     61     const AttrSlice& func_instantiation_attr, NodeDef* node) {
     62   for (auto& attr : *node->mutable_attr()) {
     63     const string& placeholder = attr.second.placeholder();
     64     if (placeholder.empty()) continue;
     65 
     66     const AttrValue* attr_value = func_instantiation_attr.Find(placeholder);
     67     if (attr_value) {
     68       attr.second = *attr_value;
     69     } else {
     70       return errors::InvalidArgument("Can't resolve placeholder: ",
     71                                      placeholder);
     72     }
     73   }
     74   return Status::OK();
     75 }
     76 
     77 }  // namespace
     78 
     79 void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
     80     InputArgExpansion input_arg_expansion) {
     81   string input_name = input_arg_expansion.input_name;
     82   const auto& placeholders = input_arg_expansion.placeholders;
     83 
     84   for (int i = 0; i < placeholders.size(); ++i) {
     85     const string& placeholder = input_arg_expansion.placeholders[i];
     86     input_arg_placeholders_.insert(
     87         {placeholder, InputArgPlaceholder{input_name, /*input_index=*/i}});
     88   }
     89   input_arg_expansions_.insert(
     90       {std::move(input_name), std::move(input_arg_expansion)});
     91 }
     92 
     93 void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs(
     94     const string& node_name, tensorflow::NameRangeMap&& outputs) {
     95   function_body_outputs_[node_name] = std::move(outputs);
     96 }
     97 
     98 Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
     99     const string& func_def_input, std::vector<string>* graph_def_inputs) const {
    100   using ::tensorflow::strings::Scanner;
    101 
    102   if (IsControlInput(func_def_input)) {
    103     graph_def_inputs->push_back(func_def_input);
    104     return Status::OK();
    105   }
    106 
    107   // Parse input format: "node_name[:node_output][:position]"
    108   string node_name;
    109   string node_output;
    110   int position = -1;
    111 
    112   StringPiece capture;
    113   StringPiece remaining;
    114 
    115   // Parse "node_name"
    116   if (Scanner(func_def_input)
    117           .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
    118           .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
    119           .GetResult(&remaining, &capture)) {
    120     node_name = string(capture.data(), capture.size());
    121   }
    122 
    123   // Parse "node_output" if it exists
    124   if (Scanner(remaining)
    125           .OneLiteral(":")
    126           .RestartCapture()
    127           .One(strings::Scanner::LETTER)
    128           .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
    129           .GetResult(&remaining, &capture)) {
    130     node_output = string(capture.data(), capture.size());
    131   }
    132 
    133   // Parse "position" if it exists
    134   if (Scanner(remaining)
    135           .OneLiteral(":")
    136           .RestartCapture()
    137           .Many(strings::Scanner::DIGIT)
    138           .GetResult(nullptr, &capture)) {
    139     CHECK(strings::safe_strto32(capture, &position));
    140   }
    141 
    142   // If "node_output" is not empty, it must be an output of a function body node
    143   bool is_function_body_output = !node_output.empty();
    144 
    145   // Function input argument: "node_name[:position]"
    146   if (!is_function_body_output) {
    147     auto input_arg = input_arg_expansions_.find(node_name);
    148     if (input_arg != input_arg_expansions_.end()) {
    149       const InputArgExpansion& input_arg_expansion = input_arg->second;
    150       const auto& placeholders = input_arg_expansion.placeholders;
    151 
    152       if (position == -1) {
    153         // If position is not defined use all placeholders
    154         graph_def_inputs->reserve(placeholders.size());
    155         for (const string& placeholder : placeholders) {
    156           graph_def_inputs->push_back(placeholder);
    157         }
    158       } else {
    159         if (position > input_arg_expansion.placeholders.size() - 1) {
    160           return errors::InvalidArgument("Invalid input ", node_name,
    161                                          "position: ", position,
    162                                          " (out of range)");
    163         }
    164         graph_def_inputs->push_back(input_arg_expansion.placeholders[position]);
    165       }
    166 
    167       return Status::OK();
    168     }
    169   }
    170 
    171   // Function body output: "node_name:node_output[:position]"
    172   if (is_function_body_output) {
    173     auto function_body_outputs = function_body_outputs_.find(node_name);
    174     if (function_body_outputs != function_body_outputs_.end()) {
    175       const tensorflow::NameRangeMap& outputs = function_body_outputs->second;
    176       auto output = outputs.find(node_output);
    177       if (output != outputs.end()) {
    178         const auto& output_range = output->second;
    179 
    180         if (position == -1) {
    181           graph_def_inputs->reserve(graph_def_inputs->size() +
    182                                     output_range.second - output_range.first);
    183           // If position is not defined expand node output range
    184           for (int i = output_range.first; i < output_range.second; ++i) {
    185             graph_def_inputs->push_back(
    186                 i == 0 ? node_name : absl::StrCat(node_name, ":", i));
    187           }
    188         } else {
    189           if (position > (output_range.second - output_range.first)) {
    190             return errors::InvalidArgument(
    191                 "Invalid node ", node_name, " output ", node_output,
    192                 " position: ", position, " (out of range)");
    193           }
    194           int pos = output_range.first + position;
    195           graph_def_inputs->push_back(
    196               pos == 0 ? node_name : absl::StrCat(node_name, ":", pos));
    197         }
    198 
    199         return Status::OK();
    200       }
    201     }
    202   }
    203 
    204   return errors::InvalidArgument("Failed to expand a function def input: ",
    205                                  func_def_input);
    206 }
    207 
    208 Status GrapplerFunctionConnectivity::ExpandNodeInputs(
    209     NodeDef* function_body_node) const {
    210   std::vector<string> expanded_inputs;
    211 
    212   for (const string& function_def_input : function_body_node->input()) {
    213     TF_RETURN_IF_ERROR(
    214         ExpandFunctionDefInput(function_def_input, &expanded_inputs));
    215   }
    216 
    217   function_body_node->clear_input();
    218   for (string& expanded_input : expanded_inputs)
    219     function_body_node->add_input(std::move(expanded_input));
    220   return Status::OK();
    221 }
    222 
    223 Status GrapplerFunctionConnectivity::AsFunctionDefInput(
    224     const string& graph_def_input, string* func_def_input) const {
    225   if (IsControlInput(graph_def_input)) {
    226     *func_def_input = graph_def_input;
    227     return Status::OK();
    228   }
    229 
    230   const TensorId tensor = ParseTensorName(graph_def_input);
    231   DCHECK_GE(tensor.index(), 0);
    232 
    233   const absl::string_view node_name = tensor.node();
    234   const int index = tensor.index();
    235 
    236   // Check if it's an input arg placeholder
    237   if (tensor.index() == 0) {
    238     const auto is_input_placeholder = input_arg_placeholders_.find(node_name);
    239     if (is_input_placeholder != input_arg_placeholders_.end()) {
    240       const InputArgPlaceholder& placeholder = is_input_placeholder->second;
    241       *func_def_input =
    242           absl::StrCat(placeholder.input_name, ":", placeholder.input_index);
    243       return Status::OK();
    244     }
    245   }
    246 
    247   // It must be output from one of the function body nodes
    248   const auto is_body_output = function_body_outputs_.find(tensor.node());
    249   if (is_body_output != function_body_outputs_.end()) {
    250     const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second;
    251 
    252     for (const auto& el : outputs_range_map) {
    253       const auto& output_name = el.first;
    254       const auto& output_range = el.second;
    255       if (index >= output_range.first && index < output_range.second) {
    256         int pos = index - output_range.first;
    257         *func_def_input = absl::StrCat(node_name, ":", output_name, ":", pos);
    258         return Status::OK();
    259       }
    260     }
    261   }
    262 
    263   return errors::InvalidArgument("Unknown graph def input: ", graph_def_input);
    264 }
    265 
    266 Status GrapplerFunctionConnectivity::AsFunctionDefNode(
    267     NodeDef* function_body_node) const {
    268   string func_def_input;
    269 
    270   for (int i = 0; i < function_body_node->input_size(); ++i) {
    271     TF_RETURN_IF_ERROR(
    272         AsFunctionDefInput(function_body_node->input(i), &func_def_input));
    273     function_body_node->set_input(i, func_def_input);
    274   }
    275 
    276   return Status::OK();
    277 }
    278 
    279 Status GrapplerFunctionItemInstantiation::GetTypeAttr(
    280     const string& type_attr_name, DataType* data_type) const {
    281   const AttrValue* type_attr = func_instantiation_attr_.Find(type_attr_name);
    282   if (type_attr == nullptr) {
    283     return errors::InvalidArgument("Type attribute ", type_attr_name,
    284                                    " is not defined");
    285   } else if (type_attr->type() == DT_INVALID) {
    286     return errors::InvalidArgument("Type attribute ", type_attr_name,
    287                                    " is not defined with a valid type");
    288   } else {
    289     *data_type = type_attr->type();
    290   }
    291   return Status::OK();
    292 }
    293 
    294 Status GrapplerFunctionItemInstantiation::GetArgType(
    295     const OpDef::ArgDef& arg, DataType* data_type) const {
    296   if (arg.type() != DT_INVALID) {
    297     *data_type = arg.type();
    298   } else {
    299     if (!arg.type_list_attr().empty() || !arg.number_attr().empty()) {
    300       return errors::InvalidArgument(
    301           "Arguments with sequence of tensors are not supported. Unsupported "
    302           "argument name: ",
    303           arg.name());
    304     }
    305     TF_RETURN_IF_ERROR(GetTypeAttr(arg.type_attr(), data_type));
    306   }
    307   return Status::OK();
    308 }
    309 
    310 GrapplerFunctionItem::GrapplerFunctionItem(
    311     string func_name, string description, AttrSlice func_attr,
    312     std::vector<InputArgExpansion> input_arg_expansions,
    313     std::vector<OutputArgExpansion> output_arg_expansions,
    314     std::vector<ControlOutput> control_outputs, const int graph_def_version,
    315     const bool is_stateful, GraphDef&& function_body)
    316     : description_(std::move(description)),
    317       func_attr_(func_attr),
    318       input_arg_expansions_(std::move(input_arg_expansions)),
    319       output_arg_expansions_(std::move(output_arg_expansions)),
    320       control_outputs_(std::move(control_outputs)),
    321       is_stateful_(is_stateful) {
    322   id = std::move(func_name);
    323   graph = std::move(function_body);
    324 
    325   graph.mutable_versions()->set_producer(graph_def_version);
    326   // Fill the feed nodes with input placeholders.
    327   for (const InputArgExpansion& input_arg : input_arg_expansions_) {
    328     for (const string& placeholder : input_arg.placeholders) {
    329       feed.push_back({placeholder, Tensor()});
    330     }
    331   }
    332   // Fill the fetch nodes with outputs.
    333   for (const OutputArgExpansion& output_arg : output_arg_expansions_) {
    334     for (const string& output_node : output_arg.output_nodes) {
    335       fetch.push_back(output_node);
    336     }
    337   }
    338   // We must keep all control output nodes.
    339   for (const ControlOutput& control_output : control_outputs_) {
    340     keep_ops.push_back(control_output.node_name);
    341   }
    342 
    343   // Tensorflow functions execution semantics is different from the main graph,
    344   // and we need to preserve it when we do graph optimizations.
    345   optimization_options().allow_pruning_stateful_and_dataset_ops = false;
    346 }
    347 
    348 const string& GrapplerFunctionItem::description() const { return description_; }
    349 
    350 const std::vector<InputArgExpansion>& GrapplerFunctionItem::inputs() const {
    351   return input_arg_expansions_;
    352 }
    353 
    354 const InputArgExpansion& GrapplerFunctionItem::input(int i) const {
    355   return input_arg_expansions_[i];
    356 }
    357 
    358 const std::size_t GrapplerFunctionItem::input_size() const {
    359   return input_arg_expansions_.size();
    360 }
    361 
    362 const std::vector<OutputArgExpansion>& GrapplerFunctionItem::outputs() const {
    363   return output_arg_expansions_;
    364 }
    365 
    366 const OutputArgExpansion& GrapplerFunctionItem::output(int i) const {
    367   return output_arg_expansions_[i];
    368 }
    369 
    370 const std::size_t GrapplerFunctionItem::output_size() const {
    371   return output_arg_expansions_.size();
    372 }
    373 
    374 const std::vector<ControlOutput>& GrapplerFunctionItem::control_outputs()
    375     const {
    376   return control_outputs_;
    377 }
    378 
    379 const std::size_t GrapplerFunctionItem::control_output_size() const {
    380   return control_outputs_.size();
    381 }
    382 
    383 const AttrSlice& GrapplerFunctionItem::func_attr() const { return func_attr_; }
    384 
    385 const GraphDef& GrapplerFunctionItem::function_body() const { return graph; }
    386 
    387 GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; }
    388 
    389 bool GrapplerFunctionItem::is_stateful() const { return is_stateful_; }
    390 
    391 GrapplerFunctionItem& GrapplerFunctionItem::SwapFunctionBody(GraphDef&& other) {
    392   graph.Swap(&other);
    393   return *this;
    394 }
    395 
    396 bool HasParametrizedType(const FunctionDef& func) {
    397   const auto is_type_parametrized = [](const OpDef::ArgDef& arg) {
    398     return !arg.type_attr().empty() || !arg.number_attr().empty() ||
    399            !arg.type_list_attr().empty();
    400   };
    401 
    402   const auto& input = func.signature().input_arg();
    403   const auto& output = func.signature().output_arg();
    404   return std::any_of(input.begin(), input.end(), is_type_parametrized) ||
    405          std::any_of(output.begin(), output.end(), is_type_parametrized);
    406 }
    407 
    408 bool HasParametrizedBody(const FunctionDef& func) {
    409   const auto is_parametrized = [&](const NodeDef& node) {
    410     for (const auto& attr : node.attr()) {
    411       if (!attr.second.placeholder().empty()) return true;
    412     }
    413     return false;
    414   };
    415   return std::any_of(func.node_def().begin(), func.node_def().end(),
    416                      is_parametrized);
    417 }
    418 
    419 bool IsParametrized(const FunctionDef& func) {
    420   return HasParametrizedType(func) || HasParametrizedBody(func);
    421 }
    422 
    423 Status InstantiationTypeParameters(
    424     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
    425     absl::flat_hash_map<string, DataType>* type_parameters) {
    426   if (!type_parameters->empty()) {
    427     return errors::InvalidArgument("Type parameters output map must be empty");
    428   }
    429 
    430   GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr);
    431 
    432   const auto resolve_type_attr = [&](const OpDef::ArgDef& arg) {
    433     // Check if it's unknown and unresolved type.
    434     if (arg.type() == DT_INVALID &&
    435         type_parameters->find(arg.type_attr()) == type_parameters->end()) {
    436       DataType data_type;
    437       TF_RETURN_IF_ERROR(instantiation.GetArgType(arg, &data_type));
    438       type_parameters->insert({arg.type_attr(), data_type});
    439     }
    440     return Status::OK();
    441   };
    442 
    443   for (const auto& input : func.signature().input_arg())
    444     TF_RETURN_IF_ERROR(resolve_type_attr(input));
    445   for (const auto& output : func.signature().output_arg())
    446     TF_RETURN_IF_ERROR(resolve_type_attr(output));
    447 
    448   return Status::OK();
    449 }
    450 
    451 Status InstantiationBodyParameters(
    452     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
    453     absl::flat_hash_map<string, AttrValue>* body_parameters) {
    454   if (!body_parameters->empty()) {
    455     return errors::InvalidArgument("Body parameters output map must be empty");
    456   }
    457 
    458   for (const NodeDef& func_body_node : func.node_def()) {
    459     for (auto& attr : func_body_node.attr()) {
    460       const string& placeholder = attr.second.placeholder();
    461 
    462       if (placeholder.empty() ||
    463           body_parameters->find(placeholder) != body_parameters->end()) {
    464         continue;
    465       }
    466 
    467       const AttrValue* placeholder_value =
    468           func_instantiation_attr.Find(placeholder);
    469       if (placeholder_value) {
    470         body_parameters->insert({placeholder, *placeholder_value});
    471       } else {
    472         return errors::InvalidArgument("Can't resolve placeholder: ",
    473                                        placeholder);
    474       }
    475     }
    476   }
    477 
    478   return Status::OK();
    479 }
    480 
    481 Status MakeGrapplerFunctionItem(const FunctionDef& func,
    482                                 const AttrSlice& func_instantiation_attr,
    483                                 const FunctionLibraryDefinition& flib,
    484                                 const int graph_def_version,
    485                                 GrapplerFunctionItem* item) {
    486   const OpDef& signature = func.signature();
    487 
    488   if (signature.name().empty()) {
    489     return errors::InvalidArgument("Function name must be specified");
    490   }
    491 
    492   // Function types will be resolved from function instantiation attributes. All
    493   // other attributes will be lost during conversion to FunctionDef.
    494   for (const OpDef::AttrDef& attr : signature.attr()) {
    495     if (attr.type() != "type") {
    496       return errors::InvalidArgument(
    497           "Function signature must have only type attributes");
    498     }
    499   }
    500 
    501   // Helper methods to lookup function instantiation attributes
    502   GrapplerFunctionItemInstantiation instantiation(func_instantiation_attr);
    503 
    504   // Mapping from FunctionDef input format (name[:output][:position]) to
    505   // GraphDef input format (name[:position])
    506   GrapplerFunctionConnectivity connectivity;
    507 
    508   // Instantiate function body into a statically defined graph def.
    509   GraphDef function_body;
    510 
    511   // Function body shares the library with the graph that instantiated it. We do
    512   // not need a full copy of the function library, just the reachable subset.
    513   *function_body.mutable_library() = flib.ReachableDefinitions(func).ToProto();
    514 
    515   VLOG(3) << absl::Substitute(
    516       "Deleted $0 unreachable functions from the Grappler function item "
    517       "instantiation of $1 (library size = $2)",
    518       flib.num_functions() - function_body.library().function_size(),
    519       signature.name(), function_body.library().function_size());
    520 
    521   // TODO(ezhulenev): support functions with tensor sequence inputs/outputs
    522 
    523   // Make sure that there are no tensor lists in inputs or outputs.
    524   for (const OpDef::ArgDef& input : signature.input_arg()) {
    525     if (!input.type_list_attr().empty() || !input.number_attr().empty()) {
    526       return errors::InvalidArgument(
    527           "Inputs with lists of tensors are not supported. Input: ",
    528           input.name());
    529     }
    530   }
    531   for (const OpDef::ArgDef& output : signature.output_arg()) {
    532     if (!output.type_list_attr().empty() || !output.number_attr().empty()) {
    533       return errors::InvalidArgument(
    534           "Outputs with lists of tensors are not supported. Output: ",
    535           output.name());
    536     }
    537   }
    538 
    539   std::vector<InputArgExpansion> inputs;
    540   inputs.reserve(signature.input_arg_size());
    541 
    542   // For each input argument create a placeholder in function body.
    543   for (const OpDef::ArgDef& input : signature.input_arg()) {
    544     DataType input_data_type;
    545     TF_RETURN_IF_ERROR(instantiation.GetArgType(input, &input_data_type));
    546 
    547     NodeDef* placeholder = function_body.add_node();
    548     placeholder->set_name(input.name());
    549     placeholder->set_op("Placeholder");
    550     (*placeholder->mutable_attr())["dtype"].set_type(input_data_type);
    551     (*placeholder->mutable_attr())["shape"].mutable_shape()->set_unknown_rank(
    552         true);
    553 
    554     InputArgExpansion input_expansion{/*input_name=*/input.name(),
    555                                       /*data_type=*/input_data_type,
    556                                       /*is_ref=*/input.is_ref(),
    557                                       /*placeholders=*/{input.name()}};
    558     connectivity.RegisterInputArgExpansion(input_expansion);
    559     inputs.push_back(std::move(input_expansion));
    560   }
    561 
    562   // Keep names of all nodes in the function body to guarantee that we do not
    563   // add an identity with a duplicate name.
    564   absl::flat_hash_set<absl::string_view> func_body_nodes;
    565 
    566   // Generate unique output node name: "${out_arg_name}_output_node_${index}".
    567   const auto output_node_name = [&func_body_nodes](const OpDef::ArgDef& out,
    568                                                    int index) -> string {
    569     string name = absl::StrCat(out.name(), "_output_node_", index);
    570     int i = 1;
    571     while (func_body_nodes.find(name) != func_body_nodes.end()) {
    572       name = absl::StrCat(out.name(), "_output_node_", index, "_", i++);
    573     }
    574     return name;
    575   };
    576 
    577   // Add all function nodes to the function body.
    578   for (const NodeDef& func_def_node : func.node_def()) {
    579     func_body_nodes.insert(func_def_node.name());
    580 
    581     NodeDef* new_node = function_body.add_node();
    582     *new_node = func_def_node;
    583 
    584     const OpRegistrationData* registration;
    585     TF_RETURN_IF_ERROR(flib.LookUp(func_def_node.op(), &registration));
    586 
    587     // Resolve all placeholder values using function instantiation attributes.
    588     TF_RETURN_IF_ERROR(ResolveFunctionBodyNodeAttrPlaceholders(
    589         func_instantiation_attr, new_node));
    590 
    591     // Register node output range in a function connectivity.
    592     TF_RETURN_IF_ERROR(RegisterFunctionBodyOutputs(*registration, func_def_node,
    593                                                    &connectivity));
    594   }
    595 
    596   // Rewrite inputs to use GraphDef format
    597   for (NodeDef& node : *function_body.mutable_node()) {
    598     TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node));
    599   }
    600 
    601   std::vector<OutputArgExpansion> outputs;
    602   outputs.reserve(signature.output_arg_size());
    603 
    604   // For each function output argument we create an Identity node in the
    605   // function body, that reads output tensor from the function body node.
    606   for (const OpDef::ArgDef& out : signature.output_arg()) {
    607     DataType output_data_type;
    608     TF_RETURN_IF_ERROR(instantiation.GetArgType(out, &output_data_type));
    609 
    610     std::vector<string> output_tensors;
    611     auto ret = func.ret().find(out.name());
    612     TF_RETURN_IF_ERROR(
    613         ret != func.ret().end()
    614             // Expand outputs using provided output mapping
    615             ? connectivity.ExpandFunctionDefInput(ret->second, &output_tensors)
    616             // Otherwise output must be one of the function inputs
    617             : connectivity.ExpandFunctionDefInput(out.name(), &output_tensors));
    618 
    619     absl::InlinedVector<string, 1> output_nodes;
    620     for (int i = 0; i < output_tensors.size(); ++i) {
    621       const string& output_tensor = output_tensors[i];
    622 
    623       NodeDef* identity = function_body.add_node();
    624       identity->set_name(output_node_name(out, i));
    625       identity->set_op("Identity");
    626       (*identity->mutable_attr())["T"].set_type(output_data_type);
    627       identity->add_input(output_tensor);
    628 
    629       output_nodes.push_back(identity->name());
    630     }
    631 
    632     OutputArgExpansion output{/*output_name=*/out.name(),
    633                               /*data_type=*/output_data_type,
    634                               /*is_ref=*/out.is_ref(),
    635                               /*output_nodes=*/std::move(output_nodes)};
    636     outputs.push_back(std::move(output));
    637   }
    638 
    639   // Control outputs ensure that all side-effectful nodes in the function body
    640   // will execute, even if they are not required to compute regular output args.
    641   std::vector<ControlOutput> control_outputs;
    642   control_outputs.reserve(func.control_ret_size());
    643   for (const auto& control_ret : func.control_ret()) {
    644     control_outputs.push_back({control_ret.first, control_ret.second});
    645   }
    646 
    647   *item = GrapplerFunctionItem(
    648       /*func_name=*/signature.name(),
    649       /*description=*/signature.description(),
    650       /*func_attr=*/AttrSlice(&func.attr()), std::move(inputs),
    651       std::move(outputs), std::move(control_outputs), graph_def_version,
    652       signature.is_stateful(), std::move(function_body));
    653   return Status::OK();
    654 }
    655 
    656 Status MakeGrapplerFunctionItem(const FunctionDef& func,
    657                                 const FunctionLibraryDefinition& flib,
    658                                 const int graph_def_version,
    659                                 GrapplerFunctionItem* item) {
    660   return MakeGrapplerFunctionItem(func, AttrSlice(), flib, graph_def_version,
    661                                   item);
    662 }
    663 
    664 // Register GrapplerFunctionItem input arg expansion and function body outputs
    665 // in the GrapplerFunctionConnectivity.
    666 Status RegisterGrapplerFunctionConnectivity(
    667     const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
    668     GrapplerFunctionConnectivity* connectivity) {
    669   for (const InputArgExpansion& input : item.inputs()) {
    670     connectivity->RegisterInputArgExpansion(input);
    671   }
    672   for (const NodeDef& func_body_node : item.function_body().node()) {
    673     TF_RETURN_IF_ERROR(
    674         RegisterFunctionBodyOutputs(flib, func_body_node, connectivity));
    675   }
    676   return Status::OK();
    677 }
    678 
    679 Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
    680                              GrapplerFunctionItem* item) {
    681   if (!IsConstant(input_const)) {
    682     return errors::InvalidArgument("Input node ", input_const.name(),
    683                                    " is not a constant");
    684   }
    685 
    686   auto& inputs = item->input_arg_expansions_;
    687 
    688   // Find input arg expansion and input placeholder position in it for the
    689   // given function input position.
    690   InputArgExpansion* input_arg_expansion = nullptr;
    691   int placeholder_idx = input_index;
    692 
    693   for (InputArgExpansion& input : inputs) {
    694     if (placeholder_idx < input.placeholders.size()) {
    695       input_arg_expansion = &input;
    696       break;
    697     }
    698     placeholder_idx -= input.placeholders.size();
    699   }
    700 
    701   if (input_arg_expansion == nullptr) {
    702     return errors::InvalidArgument("Input placeholder not found: input_index=",
    703                                    input_index, " function=", item->id);
    704   }
    705 
    706   // Delete placeholder from input expansion.
    707   string placeholder_name = input_arg_expansion->placeholders[placeholder_idx];
    708   input_arg_expansion->placeholders.erase(
    709       input_arg_expansion->placeholders.begin() + placeholder_idx);
    710 
    711   // Delete empty input expansions.
    712   inputs.erase(std::remove_if(inputs.begin(), inputs.end(),
    713                               [](const InputArgExpansion& input) {
    714                                 return input.placeholders.empty();
    715                               }),
    716                inputs.end());
    717 
    718   // Replace placeholder node in the function body with a const node.
    719   for (NodeDef& node : *item->graph.mutable_node()) {
    720     if (node.name() == placeholder_name) {
    721       node = input_const;
    722       node.set_name(placeholder_name);
    723       node.clear_input();   // remove potential control inputs
    724       node.clear_device();  // device placement is defined by instantiating node
    725     }
    726   }
    727 
    728   return Status::OK();
    729 }
    730 
    731 Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
    732                              GrapplerFunctionItem* item,
    733                              std::vector<std::pair<int, int>>* output_mapping) {
    734   DCHECK(output_mapping->empty());
    735 
    736   // Code below assumes that we do not support tensor list outputs and there is
    737   // a 1-to-1 mapping between output tensor and output argument expansion.
    738   for (const OutputArgExpansion& out_arg : item->outputs()) {
    739     DCHECK(out_arg.output_nodes.size() == 1)
    740         << "Output arg expansion must have single output";
    741   }
    742 
    743   // Do some sanity checking of the removed outputs positions.
    744   for (int remove_output : remove_outputs) {
    745     if (remove_output < 0 || remove_output >= item->output_size()) {
    746       return errors::InvalidArgument(
    747           "Function output index is out of bound: index=", remove_output,
    748           " max_output_index=", item->output_size());
    749     }
    750   }
    751 
    752   absl::flat_hash_set<const OutputArgExpansion*> remove_output_args;
    753   const auto is_remove_output_arg = [&](const OutputArgExpansion& output) {
    754     return remove_output_args.find(&output) != remove_output_args.end();
    755   };
    756 
    757   for (int i = 0; i < item->output_size(); ++i) {
    758     const OutputArgExpansion& output = item->output(i);
    759     if (remove_outputs.find(i) != remove_outputs.end()) {
    760       VLOG(3) << "Remove functions output: output_name=" << output.output_name
    761               << "(index = " << i << ")";
    762       remove_output_args.insert(&output);
    763     } else if (!remove_output_args.empty()) {
    764       // Add output mapping only if output position changed.
    765       output_mapping->push_back({i, i - remove_output_args.size()});
    766     }
    767   }
    768 
    769   auto& o = item->output_arg_expansions_;
    770   o.erase(std::remove_if(o.begin(), o.end(), is_remove_output_arg), o.end());
    771 
    772   return Status::OK();
    773 }
    774 
// Converts a GrapplerFunctionItem back into a FunctionDef written to `func`.
//
// - The signature name, description and is_stateful flag are copied from
//   `item`.
// - One signature input arg is added per input arg expansion, one output arg
//   (plus a `ret` entry) per output arg expansion, and one control output
//   (plus a `control_ret` entry) per item control output.
// - Function attributes are copied from `item.func_attr()`.
// - Body nodes are copied into `node_def`, except:
//     * input placeholders that stand in for expanded input arguments, and
//     * output Identity nodes with exactly one input, which are bypassed.
//       The `ret` then refers directly to that Identity's input.
//   Node inputs are rewritten into FunctionDef input format via
//   GrapplerFunctionConnectivity.
//
// `flib` is passed through to RegisterGrapplerFunctionConnectivity.
// Returns the first error from building the connectivity or from converting
// output tensors or nodes, otherwise OK.
Status MakeFunctionDef(const GrapplerFunctionItem& item,
                       const FunctionLibraryDefinition& flib,
                       FunctionDef* func) {
  func->mutable_signature()->set_name(item.id);
  func->mutable_signature()->set_description(item.description());
  func->mutable_signature()->set_is_stateful(item.is_stateful());

  // Keep track of placeholders that were added to the graph in place of
  // expanded function input arguments.
  // The string_view keys borrow from `item`, which outlives this function.
  absl::flat_hash_set<absl::string_view> input_placeholders;
  for (const InputArgExpansion& input_arg : item.inputs()) {
    for (const string& placeholder : input_arg.placeholders) {
      input_placeholders.insert(placeholder);
    }
  }

  // Keep track of identity nodes that were added to the graph in place of
  // expanded function output arguments.
  absl::flat_hash_set<absl::string_view> output_nodes;
  for (const OutputArgExpansion& output_arg : item.outputs()) {
    for (const string& output_node : output_arg.output_nodes) {
      output_nodes.insert(output_node);
    }
  }

  // If the output identity node was not modified by any optimizer, we can
  // bypass it and return the function value from its input.
  // Maps output node name -> the tensor name to return instead of it.
  // NOTE(review): a bypassed node is also dropped from the copied body (see
  // the copy loop below). If any other body node still consumes it, that
  // consumer would reference a missing node. Confirm that output Identity
  // nodes never have other consumers here.
  absl::flat_hash_map<absl::string_view, string> output_tensors;
  for (const NodeDef& func_body_node : item.function_body().node()) {
    if (!IsIdentity(func_body_node)) continue;

    const string& node_name = func_body_node.name();
    if (output_nodes.find(node_name) != output_nodes.end()) {
      // Grappler optimizers might optimize nodes in the fanin of the output
      // node, and forward their control dependencies. We can't express control
      // dependencies in a function signature, so we have to keep the node.
      if (func_body_node.input_size() == 1) {
        VLOG(3) << "Bypass function output node: " << node_name << " -> "
                << func_body_node.input(0);
        output_tensors.emplace(node_name, func_body_node.input(0));
      } else {
        VLOG(3) << "Keep function output node: " << node_name;
      }
    }
  }

  // Return output tensor name (input of the output node) if it's safe to bypass
  // output node, otherwise returns the output node name.
  // The returned reference points either into `output_arg` or into
  // `output_tensors`. Both stay alive for the rest of this function.
  const auto output_tensor =
      [&output_tensors](const OutputArgExpansion& output_arg) -> const string& {
    const string& output_node = output_arg.output_nodes[0];
    const auto is_output_tensor = output_tensors.find(output_node);
    return is_output_tensor == output_tensors.end() ? output_node
                                                    : is_output_tensor->second;
  };

  // Build a GrapplerFunctionConnectivity from inputs and new function body.
  GrapplerFunctionConnectivity connectivity;
  TF_RETURN_IF_ERROR(
      RegisterGrapplerFunctionConnectivity(item, flib, &connectivity));

  // Add function input arguments.
  // NOTE(review): the single-placeholder restriction is only enforced by
  // DCHECK (presumably compiled out in release builds). A multi-placeholder
  // expansion would still produce one arg here while all of its placeholders
  // are dropped from the body. Confirm callers never pass tensor-list inputs.
  for (const InputArgExpansion& input_arg : item.inputs()) {
    DCHECK(input_arg.placeholders.size() == 1)  // do some sanity checking
        << "Inputs of tensor lists are not supported";

    OpDef::ArgDef arg_def;
    arg_def.set_name(input_arg.input_name);
    arg_def.set_type(input_arg.data_type);
    arg_def.set_is_ref(input_arg.is_ref);
    *func->mutable_signature()->add_input_arg() = arg_def;
  }

  // Add function output arguments.
  for (const OutputArgExpansion& output_arg : item.outputs()) {
    DCHECK(output_arg.output_nodes.size() == 1)  // do some sanity checking
        << "Outputs of tensor lists are not supported";

    OpDef::ArgDef arg_def;
    arg_def.set_name(output_arg.output_name);
    arg_def.set_type(output_arg.data_type);
    arg_def.set_is_ref(output_arg.is_ref);
    *func->mutable_signature()->add_output_arg() = arg_def;

    // ret[output_name] = the returned tensor, converted to FunctionDef input
    // format.
    TF_RETURN_IF_ERROR(connectivity.AsFunctionDefInput(
        output_tensor(output_arg),
        &(*func->mutable_ret())[output_arg.output_name]));
  }

  // Add function control outputs.
  for (const ControlOutput& control_out : item.control_outputs()) {
    func->mutable_control_ret()->insert(
        {control_out.output_name, control_out.node_name});
    *func->mutable_signature()->add_control_output() = control_out.output_name;
  }

  // Copy function definition specific attributes.
  for (const auto& attr : item.func_attr()) {
    const auto& attr_name = attr.first;
    const auto& attr_value = attr.second;
    (*func->mutable_attr())[attr_name] = attr_value;
  }

  // Copy function body nodes to the FunctionDef and update input format
  for (const NodeDef& func_node : item.function_body().node()) {
    const string& name = func_node.name();

    // Do not copy input placeholders.
    if (IsPlaceholder(func_node) && input_placeholders.count(name)) continue;
    // Do not copy output nodes that we bypassed.
    if (IsIdentity(func_node) && output_tensors.count(name)) continue;

    NodeDef* func_def_node = func->add_node_def();
    *func_def_node = func_node;
    // Rewrite the copied node's inputs into FunctionDef input format.
    TF_RETURN_IF_ERROR(connectivity.AsFunctionDefNode(func_def_node));
  }

  return Status::OK();
}
    894 
    895 }  // end namespace grappler
    896 }  // end namespace tensorflow
    897