Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/function.h"
     17 
     18 #include <map>
     19 #include <unordered_map>
     20 #include <utility>
     21 #include <vector>
     22 
     23 #include "absl/container/flat_hash_set.h"
     24 #include "absl/strings/str_join.h"
     25 #include "tensorflow/core/framework/common_shape_fns.h"
     26 #include "tensorflow/core/framework/function.pb_text.h"
     27 #include "tensorflow/core/framework/graph.pb.h"
     28 #include "tensorflow/core/framework/node_def.pb.h"
     29 #include "tensorflow/core/framework/node_def_util.h"
     30 #include "tensorflow/core/framework/op.h"
     31 #include "tensorflow/core/graph/graph.h"
     32 #include "tensorflow/core/lib/core/errors.h"
     33 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     34 #include "tensorflow/core/lib/gtl/map_util.h"
     35 #include "tensorflow/core/lib/strings/str_util.h"
     36 #include "tensorflow/core/util/device_name_utils.h"
     37 #include "tensorflow/core/util/equal_graph_def.h"
     38 
     39 namespace tensorflow {
     40 
     41 // Extracts the actual type from "attr_values" based on its definition
     42 // "arg_def".
     43 //
     44 // If "arg_def" is a N*T type, *is_type_list is set to false, and
     45 // *dtypes is set to be a vector of size N and each element is T.
     46 //
     47 // If "arg_def" is a list(type), *is_type_list is set to true, and
     48 // *dtypes is set to be a vector of types specified in attrs for
     49 // arg_def.
     50 //
     51 // Otherwise (arg_def is a simple type T), *is_type_list is set to
     52 // false, and *dtypes is set to a single element vector, whose only
     53 // element is T.
     54 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
     55                   bool* is_type_list, DataTypeVector* dtypes) {
     56   dtypes->clear();
     57   if (!arg_def.type_list_attr().empty()) {
     58     const AttrValue* v = attrs.Find(arg_def.type_list_attr());
     59     if (v == nullptr) {
     60       return errors::NotFound("type attr not found: ",
     61                               arg_def.type_list_attr());
     62     }
     63     *is_type_list = true;
     64     for (int i = 0; i < v->list().type_size(); ++i) {
     65       dtypes->push_back(v->list().type(i));
     66     }
     67     return Status::OK();
     68   }
     69 
     70   *is_type_list = false;
     71   int num = 1;
     72   if (!arg_def.number_attr().empty()) {
     73     const AttrValue* v = attrs.Find(arg_def.number_attr());
     74     if (v == nullptr) {
     75       return errors::NotFound("type attr not found: ", arg_def.type_attr());
     76     }
     77     num = v->i();
     78   }
     79 
     80   DataType dtype;
     81   if (arg_def.type() != DT_INVALID) {
     82     dtype = arg_def.type();
     83   } else if (arg_def.type_attr().empty()) {
     84     dtype = DT_INVALID;
     85   } else {
     86     const AttrValue* v = attrs.Find(arg_def.type_attr());
     87     if (v == nullptr) {
     88       return errors::NotFound("type attr not found: ", arg_def.type_attr());
     89     }
     90     dtype = v->type();
     91   }
     92   dtypes->resize(num, dtype);
     93   return Status::OK();
     94 }
     95 
     96 namespace {
     97 
     98 template <typename T>
     99 void AddAttr(const string& name, const T& val, NodeDef* ndef) {
    100   SetAttrValue(val, &((*ndef->mutable_attr())[name]));
    101 }
    102 
    103 Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
    104   // attr_values should specify all attrs defined in fdef.
    105   for (const auto& a : sig.attr()) {
    106     const AttrValue* v = attr_values.Find(a.name());
    107     if (!v) {
    108       return errors::NotFound("Attr ", a.name(), " is not found from ",
    109                               SummarizeOpDef(sig));
    110     }
    111     Status status = AttrValueHasType(*v, a.type());
    112     if (!status.ok()) {
    113       errors::AppendToMessage(&status, "for attr '", a.name(), "'");
    114       return status;
    115     }
    116   }
    117 
    118 // TODO(josh11b): Enable this code once it works with function gradients.
    119 // Right now the C++ function gradient code assumes it can pass
    120 // all the attrs of the function to the gradient, and any attrs that
    121 // the gradient doesn't care about will be ignored.
    122 #if 0
    123   if (attr_values.size() != sig.attr_size()) {
    124     for (const auto& a : attr_values) {
    125       // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
    126       bool found = false;
    127       for (const auto& s : sig.attr()) {
    128         if (a.first == s.name()) {
    129           found = true;
    130           break;
    131         }
    132       }
    133       if (!found) {
    134         return errors::NotFound("Attr ", a.first, " is not found in ",
    135                                 SummarizeOpDef(sig));
    136       }
    137     }
    138   }
    139 #endif
    140 
    141   return Status::OK();
    142 }
    143 
    144 // A helper class for instantiating functions. This contains shared information
    145 // like the resulting graph and node name index.
    146 class FunctionInstantiationHelper {
    147  public:
    148   FunctionInstantiationHelper(GetFunctionSignature get_function,
    149                               InstantiationResult* result)
    150       : get_function_(std ::move(get_function)), result_(*result) {
    151     result_.nodes.clear();
    152   }
    153 
    154   // Builds index for nodes that can be used as node's input arguments.
    155   Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
    156                             bool ints_on_device) {
    157     bool is_type_list;
    158     DataTypeVector dtypes;
    159     TF_RETURN_IF_ERROR(
    160         ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
    161     CHECK_GE(dtypes.size(), size_t{1});
    162     int arg_index = result_.nodes.size();
    163     TF_RETURN_IF_ERROR(
    164         AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
    165     // Creates dtypes.size() nodes in the graph.
    166     for (size_t i = 0; i < dtypes.size(); ++i) {
    167       TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
    168                                  {true, arg_index, 0, false, {dtypes[i]}}));
    169       DCHECK_EQ(arg_index, result_.nodes.size());
    170       string name = arg_def.name();
    171       if (dtypes.size() > 1) {
    172         strings::StrAppend(&name, "_", i);
    173       }
    174       NodeDef* gnode = AddNode(name);
    175       if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
    176         gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
    177       } else {
    178         gnode->set_op(FunctionLibraryDefinition::kArgOp);
    179       }
    180       AddAttr("T", dtypes[i], gnode);
    181       AddAttr("index", arg_index, gnode);
    182       result_.arg_types.push_back(dtypes[i]);
    183       ++arg_index;
    184     }
    185     return Status::OK();
    186   }
    187 
    188   Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
    189                               const int arg_index) {
    190     const OpDef* node_sig = nullptr;
    191     TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
    192     if (node_sig->output_arg_size() == 0) {
    193       return AddItem(node.name(), {false, arg_index, 0, false, {}});
    194     }
    195     const int num_retval = node_sig->output_arg_size();
    196     int start = 0;
    197     bool is_type_list;
    198     DataTypeVector dtypes;
    199     for (int i = 0; i < num_retval; ++i) {
    200       TF_RETURN_IF_ERROR(
    201           ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
    202       // Note that we rely on the backwards-compatibility test enforcing
    203       // that output_arg(*).name() doesn't change here.
    204       const string base_name =
    205           strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
    206       TF_RETURN_IF_ERROR(
    207           AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
    208       for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
    209         TF_RETURN_IF_ERROR(
    210             AddItem(strings::StrCat(base_name, ":", j),
    211                     {false, arg_index, start + j, false, {dtypes[j]}}));
    212       }
    213       start += dtypes.size();
    214     }
    215     return Status::OK();
    216   }
    217 
    218   Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
    219     const OpDef* fnode_sig = nullptr;
    220     TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
    221     NodeDef* gnode = AddNode(fnode.name());
    222     gnode->set_op(fnode.op());
    223     gnode->set_device(fnode.device());
    224     int gnode_idx = nodes_.size() - 1;
    225 
    226     // Input
    227     const int num_args = fnode_sig->input_arg_size();
    228     bool is_type_list;  // ignored
    229     DataTypeVector dtypes;
    230     int fnode_arg_index = 0;
    231     for (int i = 0; i < num_args; ++i) {
    232       TF_RETURN_IF_ERROR(
    233           ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
    234       // Consume inputs (indexed by fnode_arg_index) until we have
    235       // matched each element of dtypes (indexed by j).
    236       for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
    237         if (fnode_arg_index >= fnode.input_size()) {
    238           // Should never happen if we computed dtypes correctly.
    239           return errors::InvalidArgument(
    240               "Attempt to access beyond input size: ", fnode_arg_index,
    241               " >= ", fnode.input_size());
    242         }
    243         // Look up the next input.
    244         const string& input_name = fnode.input(fnode_arg_index);
    245         const auto* item = GetItemOrNull(input_name);
    246         if (item == nullptr) {
    247           return errors::InvalidArgument(
    248               "input ", input_name,
    249               " is not found: ", FormatNodeDefForError(fnode));
    250         }
    251         if (item->dtypes.size() > dtypes.size() - j) {
    252           return errors::InvalidArgument("Input ", input_name, " too long for ",
    253                                          fnode_sig->input_arg(i).name());
    254         }
    255         // Match up all the elements of this input (indexed by k) with
    256         // elements of dtypes (advancing j).
    257         for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
    258           if (item->dtypes[k] != dtypes[j]) {
    259             return errors::InvalidArgument(
    260                 "input ", fnode_sig->input_arg(i).name(), "[", j,
    261                 "] expected type ", DataTypeString(dtypes[j]),
    262                 " != ", DataTypeString(item->dtypes[k]), ", the type of ",
    263                 input_name, "[", k, "]");
    264           }
    265           if (item->is_func_arg) {
    266             AddInput(gnode_idx, item->nid + k, 0);
    267           } else {
    268             AddInput(gnode_idx, item->nid, item->idx + k);
    269           }
    270         }
    271       }
    272     }
    273 
    274     // Control deps.
    275     for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
    276       const string& input = fnode.input(i);
    277       if (input.empty() || input[0] != '^') {
    278         return errors::InvalidArgument("Expected input[", i, "] == '", input,
    279                                        "' to be a control input.");
    280       }
    281       int nid = -1;
    282       const string node_name = input.substr(1);
    283       const string node_colon = node_name + ":";
    284       const string node_colon_bound = node_name + ";";
    285       // index_ is a map sorted lexicographically, so the key we are looking for
    286       // must lie in the range [node_name, node_colon_bound).
    287       auto it = index_.lower_bound(node_name);
    288       while (it != index_.end() && it->first <= node_colon_bound) {
    289         if (it->first == node_name ||
    290             tensorflow::str_util::StartsWith(it->first, node_colon)) {
    291           nid = it->second.nid;
    292           break;
    293         }
    294         ++it;
    295       }
    296       if (nid == -1) {
    297         return errors::InvalidArgument("input[", i, "] == '", input,
    298                                        "', is not found.");
    299       }
    300       AddDep(gnode_idx, nid);
    301     }
    302 
    303     // Attrs.
    304     for (const auto& p : attrs) {
    305       (*gnode->mutable_attr())[p.first] = p.second;
    306     }
    307 
    308     return Status::OK();
    309   }
    310 
    311   Status AddReturnNode(
    312       const OpDef::ArgDef& ret_def, AttrSlice attrs,
    313       const ::tensorflow::protobuf::Map<string, string>& ret_map,
    314       bool ints_on_device, int* ret_index) {
    315     auto ret_iter = ret_map.find(ret_def.name());
    316     if (ret_iter == ret_map.end()) {
    317       return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
    318     }
    319     bool is_type_list;
    320     DataTypeVector dtypes;
    321     TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
    322     CHECK_GE(dtypes.size(), size_t{1});
    323     const auto* item = GetItemOrNull(ret_iter->second);
    324     if (item == nullptr) {
    325       return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
    326                                      ret_iter->second, " is not found.");
    327     }
    328     if (dtypes != item->dtypes) {
    329       return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
    330                                      " : ", DataTypeVectorString(dtypes),
    331                                      " vs. ",
    332                                      DataTypeVectorString(item->dtypes));
    333     }
    334     for (size_t i = 0; i < dtypes.size(); ++i) {
    335       string name = strings::StrCat(ret_def.name(), "_RetVal");
    336       if (dtypes.size() > 1) {
    337         strings::StrAppend(&name, "_", i);
    338       }
    339       NodeDef* gnode = AddNode(name);
    340       if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
    341         gnode->set_op(FunctionLibraryDefinition::kDeviceRetOp);
    342       } else {
    343         gnode->set_op(FunctionLibraryDefinition::kRetOp);
    344       }
    345       AddInput(nodes_.size() - 1, item->nid, item->idx + i);
    346       AddAttr("T", dtypes[i], gnode);
    347       AddAttr("index", (*ret_index)++, gnode);
    348       result_.ret_types.push_back(dtypes[i]);
    349     }
    350     return Status::OK();
    351   }
    352 
    353   // Adds the actual node inputs to the result graph by converting indexes to
    354   // the node names.
    355   void AddNodeInputs() {
    356     for (int i = 0; i < result_.nodes.size(); i++) {
    357       NodeInfo& node_info = nodes_[i];
    358       for (const auto& p : node_info.data_inputs) {
    359         result_.nodes[i].add_input(Name(p.first, p.second));
    360       }
    361       for (int index : node_info.control_inputs) {
    362         result_.nodes[i].add_input(Dep(index));
    363       }
    364     }
    365   }
    366 
    367  private:
    368   // This is used to build a small index for all names that can be used as a
    369   // node's input arguments.
    370   //
    371   // If is_func_arg is true, the name is a function's argument.  In
    372   // this case, the produced graph def has node[nid:nid + dtype.size()].
    373   //
    374   // Otherwise, the name is a function body's node return value.  In
    375   // this case, the produced graph def has one node node[nid] and
    376   // the node's output index [idx ... idx + num) corresponds to the
    377   // named outputs.
    378   //
    379   // In all cases, "dtype" specifies the data type.
    380   struct NameInfoItem {
    381     bool is_func_arg;
    382     int nid;
    383     int idx;
    384     bool is_type_list;
    385     DataTypeVector dtypes;
    386   };
    387 
    388   // Adds an item into the input name index.
    389   Status AddItem(const string& name, const NameInfoItem& item) {
    390     if (!index_.insert({name, item}).second) {
    391       return errors::InvalidArgument(
    392           strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
    393                           " name: "),
    394           name);
    395     }
    396     return Status::OK();
    397   }
    398 
    399   const NameInfoItem* GetItemOrNull(const string& name) const {
    400     return gtl::FindOrNull(index_, name);
    401   }
    402 
    403   string Dep(int node_index) const {
    404     return strings::StrCat("^", Name(node_index));
    405   }
    406 
    407   string Name(int node_index) const {
    408     CHECK_LT(node_index, nodes_.size());
    409     return nodes_[node_index].name;
    410   }
    411 
    412   string Name(int node_index, int output_index) const {
    413     if (output_index == 0) {
    414       return Name(node_index);
    415     } else {
    416       return strings::StrCat(Name(node_index), ":", output_index);
    417     }
    418   }
    419 
    420   NodeDef* AddNode(const string& name) {
    421     result_.nodes.emplace_back();
    422     NodeDef* gnode = &result_.nodes.back();
    423     gnode->set_name(name);
    424     nodes_.push_back({name, {}, {}});
    425     CHECK_EQ(result_.nodes.size(), nodes_.size());
    426     return gnode;
    427   }
    428 
    429   void AddInput(int node_index, int output_node, int output_index) {
    430     CHECK_LT(node_index, nodes_.size());
    431     nodes_[node_index].data_inputs.push_back(
    432         std::make_pair(output_node, output_index));
    433   }
    434 
    435   void AddDep(int node_index, int dep_index) {
    436     CHECK_LT(node_index, nodes_.size());
    437     nodes_[node_index].control_inputs.push_back(dep_index);
    438   }
    439 
    440   GetFunctionSignature get_function_;
    441   InstantiationResult& result_;
    442   // A small index for all names that can be used as a node's input arguments.
    443   std::map<string, NameInfoItem> index_;
    444   // This contains information about a node in the new graph including the node
    445   // names and input nodes' indexes.
    446   struct NodeInfo {
    447     string name;
    448     // Data inputs where <n, k> means arg k of node n.
    449     std::vector<std::pair<int, int>> data_inputs;
    450     // Control inputs (dependencies).
    451     std::vector<int> control_inputs;
    452   };
    453   // nodes_[i] is the information about result_.nodes[i].
    454   std::vector<NodeInfo> nodes_;
    455 };
    456 
    457 // Various helpers Print(proto) to print relevant protos to ascii.
    458 string Print(const OpDef::ArgDef& arg) {
    459   string out;
    460   strings::StrAppend(&out, arg.name(), ":");
    461   if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
    462   if (!arg.number_attr().empty()) {
    463     strings::StrAppend(&out, arg.number_attr(), "*");
    464   }
    465   if (arg.type() != DT_INVALID) {
    466     strings::StrAppend(&out, DataTypeString(arg.type()));
    467   } else {
    468     strings::StrAppend(&out, arg.type_attr());
    469   }
    470   if (arg.is_ref()) strings::StrAppend(&out, ")");
    471   return out;
    472 }
    473 
    474 // TODO(josh11b): Merge this with SummarizeAttrValue().
    475 string Print(const AttrValue& attr_value) {
    476   if (attr_value.value_case() == AttrValue::kType) {
    477     return DataTypeString(attr_value.type());
    478   } else if ((attr_value.value_case() == AttrValue::kList) &&
    479              (attr_value.list().type_size() > 0)) {
    480     string ret = "{";
    481     for (int i = 0; i < attr_value.list().type_size(); ++i) {
    482       if (i > 0) strings::StrAppend(&ret, ", ");
    483       strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
    484     }
    485     strings::StrAppend(&ret, "}");
    486     return ret;
    487   } else if (attr_value.value_case() == AttrValue::kFunc) {
    488     if (attr_value.func().attr_size() == 0) {
    489       return attr_value.func().name();
    490     }
    491     std::vector<string> entries;
    492     for (auto p : attr_value.func().attr()) {
    493       entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
    494     }
    495     std::sort(entries.begin(), entries.end());
    496     return strings::StrCat(attr_value.func().name(), "[",
    497                            str_util::Join(entries, ", "), "]");
    498   }
    499   return SummarizeAttrValue(attr_value);
    500 }
    501 
    502 // TODO(josh11b): Merge this with SummarizeNodeDef().
    503 string Print(const NodeDef& n) {
    504   string out;
    505   strings::StrAppend(&out, n.name(), " = ", n.op());
    506   if (n.attr_size() > 0) {
    507     std::vector<string> entries;
    508     for (auto& a : n.attr()) {
    509       entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
    510     }
    511     std::sort(entries.begin(), entries.end());
    512     // Add a short device string at the end of all attributes.
    513     if (!n.device().empty()) {
    514       DeviceNameUtils::ParsedName parsed;
    515       if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
    516         entries.push_back(
    517             strings::StrCat("device=", parsed.type, ":", parsed.id));
    518       } else {
    519         entries.push_back("device=<FAILED_TO_PARSE>");
    520       }
    521     }
    522     strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
    523   }
    524   strings::StrAppend(&out, "(");
    525   std::vector<StringPiece> dat;
    526   std::vector<string> dep;
    527   for (StringPiece s : n.input()) {
    528     if (str_util::ConsumePrefix(&s, "^")) {
    529       dep.emplace_back(s);
    530     } else {
    531       dat.push_back(s);
    532     }
    533   }
    534   strings::StrAppend(&out, str_util::Join(dat, ", "), ")");
    535   if (!dep.empty()) {
    536     strings::StrAppend(&out, " @ ", str_util::Join(dep, ", "));
    537   }
    538   return out;
    539 }
    540 
    541 string Print(const FunctionDef& fdef) {
    542   string out;
    543   const OpDef& sig = fdef.signature();
    544   strings::StrAppend(&out, "\n", sig.name());
    545   if (sig.attr_size() > 0) {
    546     strings::StrAppend(&out, "[");
    547     for (int i = 0; i < sig.attr_size(); ++i) {
    548       const auto& a = sig.attr(i);
    549       if (i > 0) strings::StrAppend(&out, ", ");
    550       if (a.type() == "type") {
    551         strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
    552       } else {
    553         strings::StrAppend(&out, a.name(), ":", a.type());
    554       }
    555     }
    556     strings::StrAppend(&out, "]");
    557   }
    558   strings::StrAppend(&out, "(");
    559   for (int i = 0; i < sig.input_arg_size(); ++i) {
    560     if (i > 0) strings::StrAppend(&out, ", ");
    561     strings::StrAppend(&out, Print(sig.input_arg(i)));
    562   }
    563   strings::StrAppend(&out, ") -> (");
    564   for (int i = 0; i < sig.output_arg_size(); ++i) {
    565     if (i > 0) strings::StrAppend(&out, ", ");
    566     strings::StrAppend(&out, Print(sig.output_arg(i)));
    567   }
    568   strings::StrAppend(&out, ") {\n");
    569   for (const auto& n : fdef.node_def()) {
    570     strings::StrAppend(&out, "  ", Print(n), "\n");
    571   }
    572   for (const auto& cr : fdef.control_ret()) {
    573     strings::StrAppend(&out, "  @return ", cr.first, " = ", cr.second, "\n");
    574   }
    575   for (const auto& r : fdef.ret()) {
    576     strings::StrAppend(&out, "  return ", r.first, " = ", r.second, "\n");
    577   }
    578   strings::StrAppend(&out, "}\n");
    579   return out;
    580 }
    581 
    582 string Print(gtl::ArraySlice<const NodeDef*> nodes) {
    583   std::vector<const NodeDef*> arg;
    584   std::vector<const NodeDef*> ret;
    585   std::vector<const NodeDef*> body;
    586   for (const NodeDef* n : nodes) {
    587     if (n->op() == FunctionLibraryDefinition::kArgOp ||
    588         n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
    589       arg.push_back(n);
    590     } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
    591                n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
    592       ret.push_back(n);
    593     } else {
    594       body.push_back(n);
    595     }
    596   }
    597   auto comp = [](const NodeDef* x, const NodeDef* y) {
    598     int xi;
    599     TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
    600     int yi;
    601     TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
    602     return xi < yi;
    603   };
    604   std::sort(arg.begin(), arg.end(), comp);
    605   std::sort(ret.begin(), ret.end(), comp);
    606   string out;
    607   strings::StrAppend(&out, "\n(");
    608   auto get_type_and_device = [](const NodeDef& n) {
    609     DataType dt;
    610     if (!GetNodeAttr(n, "T", &dt).ok()) {
    611       dt = DT_INVALID;
    612     }
    613     if (!n.device().empty()) {
    614       DeviceNameUtils::ParsedName parsed;
    615       if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
    616         return strings::StrCat(DataTypeString(dt), "@", parsed.type, ":",
    617                                parsed.id);
    618       } else {
    619         LOG(WARNING) << "Failed to parse device \"" << n.device() << "\" in "
    620                      << n.op() << ":" << n.name();
    621         return strings::StrCat(DataTypeString(dt), "@",
    622                                "<FAILED_TO_PARSE_DEVICE>");
    623       }
    624     }
    625     return DataTypeString(dt);
    626   };
    627   for (size_t i = 0; i < arg.size(); ++i) {
    628     const NodeDef* n = arg[i];
    629     if (i > 0) strings::StrAppend(&out, ", ");
    630     CHECK_GE(n->attr_size(), 2);
    631     strings::StrAppend(&out, n->name(), ":", get_type_and_device(*n));
    632   }
    633   strings::StrAppend(&out, ") -> (");
    634   for (size_t i = 0; i < ret.size(); ++i) {
    635     const NodeDef* n = ret[i];
    636     if (i > 0) strings::StrAppend(&out, ", ");
    637     CHECK_LE(2, n->attr_size());
    638 
    639     // The _RetVal op should have a unique non-control input. We assert that
    640     // here and add it to the output.
    641     bool found_non_control_input = false;
    642     for (const string& input : n->input()) {
    643       if (!input.empty() && input[0] != '^') {
    644         DCHECK_EQ(found_non_control_input, false)
    645             << "RetVal node has more than one non-control input: "
    646             << absl::StrJoin(n->input(), ", ");
    647         strings::StrAppend(&out, n->input(0), ":", get_type_and_device(*n));
    648         found_non_control_input = true;
    649       }
    650     }
    651     DCHECK_EQ(found_non_control_input, true)
    652         << "RetVal did not have any non-control inputs: "
    653         << absl::StrJoin(n->input(), ", ");
    654   }
    655   strings::StrAppend(&out, ") {\n");
    656   for (size_t i = 0; i < body.size(); ++i) {
    657     strings::StrAppend(&out, "  ", Print(*body[i]), "\n");
    658   }
    659   strings::StrAppend(&out, "}\n");
    660   return out;
    661 }
    662 
    663 Status AddDefaultAttrs(const string& op,
    664                        const GetFunctionSignature& get_function,
    665                        AttrValueMap* attrs) {
    666   const OpDef* op_def = nullptr;
    667   TF_RETURN_IF_ERROR(get_function(op, &op_def));
    668   AttrSlice attr_slice(attrs);
    669   for (const auto& attr_def : op_def->attr()) {
    670     if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
    671       if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
    672         return errors::Internal("Somehow duplicated: ", attr_def.name());
    673       }
    674     }
    675   }
    676   return Status::OK();
    677 }
    678 
    679 }  // end namespace
    680 
    681 // TODO(shikharagarwal): Transmit original node names correctly in file.
    682 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
    683                            GetFunctionSignature get_function,
    684                            InstantiationResult* result) {
    685   VLOG(4) << "Instantiation Function: " << Print(fdef);
    686 
    687   const OpDef& sig = fdef.signature();
    688   TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
    689 
    690   bool ints_on_device =
    691       fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
    692       fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b();
    693 
    694   FunctionInstantiationHelper helper(get_function, result);
    695   Status s;
    696   for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
    697     s = helper.BuildInputArgIndex(arg_def, attr_values, ints_on_device);
    698     if (!s.ok()) {
    699       errors::AppendToMessage(&s, "In ", Print(arg_def));
    700       return s;
    701     }
    702   }
    703 
    704   auto substitute = [attr_values](StringPiece name, AttrValue* val) {
    705     if (const AttrValue* v = attr_values.Find(name)) {
    706       *val = *v;
    707       return true;
    708     }
    709     return false;
    710   };
    711 
    712   // Makes a copy of all attrs in fdef and substitutes placeholders.
    713   // After this step, every attr is bound to a concrete value.
    714   std::vector<AttrValueMap> node_attrs;
    715   node_attrs.resize(fdef.node_def_size());
    716   for (int i = 0; i < fdef.node_def_size(); ++i) {
    717     for (auto attr : fdef.node_def(i).attr()) {
    718       if (!SubstitutePlaceholders(substitute, &attr.second)) {
    719         return errors::InvalidArgument("Failed to bind all placeholders in ",
    720                                        SummarizeAttrValue(attr.second));
    721       }
    722       if (!node_attrs[i].insert(attr).second) {
    723         return errors::Internal("Somehow duplicated: ", attr.first);
    724       }
    725     }
    726     TF_RETURN_IF_ERROR(
    727         AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
    728   }
    729 
    730   for (int i = 0; i < fdef.node_def_size(); ++i) {
    731     s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
    732                                     result->nodes.size() + i);
    733     if (!s.ok()) {
    734       errors::AppendToMessage(&s, "In ",
    735                               FormatNodeDefForError(fdef.node_def(i)));
    736       return s;
    737     }
    738   }
    739   // Emits one node for each fdef.node_def.
    740   for (int i = 0; i < fdef.node_def_size(); ++i) {
    741     s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
    742     if (!s.ok()) {
    743       errors::AppendToMessage(&s, "In ",
    744                               FormatNodeDefForError(fdef.node_def(i)));
    745       return s;
    746     }
    747   }
    748 
    749   // Emits nodes for the function's return values.
    750   int ret_index = 0;
    751   for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
    752     s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), ints_on_device,
    753                              &ret_index);
    754     if (!s.ok()) {
    755       errors::AppendToMessage(&s, "In function output ", Print(ret_def));
    756       return s;
    757     }
    758   }
    759 
    760   // Adds the actual node inputs using the input indexes.
    761   helper.AddNodeInputs();
    762 
    763   return Status::OK();
    764 }
    765 
    766 string DebugString(const FunctionDef& func_def) { return Print(func_def); }
    767 
    768 string DebugString(const GraphDef& instantiated_func_def) {
    769   std::vector<const NodeDef*> ptrs;
    770   for (const NodeDef& n : instantiated_func_def.node()) {
    771     ptrs.push_back(&n);
    772   }
    773   return Print(ptrs);
    774 }
    775 
    776 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
    777   std::vector<const NodeDef*> ptrs;
    778   for (const NodeDef& n : instantiated_func_nodes) {
    779     ptrs.push_back(&n);
    780   }
    781   return Print(ptrs);
    782 }
    783 
    784 string DebugStringWhole(const GraphDef& gdef) {
    785   string ret;
    786   for (const auto& fdef : gdef.library().function()) {
    787     strings::StrAppend(&ret, Print(fdef));
    788   }
    789   strings::StrAppend(&ret, "\n");
    790   for (const auto& ndef : gdef.node()) {
    791     strings::StrAppend(&ret, Print(ndef), "\n");
    792   }
    793   return ret;
    794 }
    795 
    796 namespace {
    797 
    798 // Returns the name -> attr mapping of fdef's attrs that have a value set. In
    799 // Python, it's possible to access unset attrs, which returns a default value
    800 // and adds an unset attr to the map.
    801 std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
    802   std::map<string, AttrValue> set_attrs;
    803   for (auto pair : fdef.attr()) {
    804     if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
    805       set_attrs[pair.first] = pair.second;
    806     }
    807   }
    808   return set_attrs;
    809 }
    810 
    811 }  // end namespace
    812 
    813 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
    814   if (!OpDefEqual(f1.signature(), f2.signature())) return false;
    815 
    816   std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
    817   std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
    818   if (f1_attrs.size() != f2_attrs.size()) return false;
    819   for (auto iter1 : f1_attrs) {
    820     auto iter2 = f2_attrs.find(iter1.first);
    821     if (iter2 == f2_attrs.end()) return false;
    822     if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
    823   }
    824 
    825   if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
    826     return false;
    827   }
    828 
    829   std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
    830   std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
    831   if (ret1 != ret2) return false;
    832 
    833   std::map<string, string> control_ret1(f1.control_ret().begin(),
    834                                         f1.control_ret().end());
    835   std::map<string, string> control_ret2(f2.control_ret().begin(),
    836                                         f2.control_ret().end());
    837   if (control_ret1 != control_ret2) return false;
    838 
    839   return true;
    840 }
    841 
    842 uint64 FunctionDefHash(const FunctionDef& fdef) {
    843   // signature
    844   uint64 h = OpDefHash(fdef.signature());
    845 
    846   // attrs
    847   std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
    848   for (const auto& p : attrs) {
    849     h = Hash64(p.first.data(), p.first.size(), h);
    850     h = Hash64Combine(AttrValueHash(p.second), h);
    851   }
    852 
    853   // node defs
    854   h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
    855 
    856   // output names
    857   std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
    858   for (const auto& p : ret) {
    859     h = Hash64(p.first.data(), p.first.size(), h);
    860     h = Hash64(p.second.data(), p.second.size(), h);
    861   }
    862 
    863   // control output names
    864   std::map<string, string> control_ret(fdef.control_ret().begin(),
    865                                        fdef.control_ret().end());
    866   for (const auto& p : control_ret) {
    867     h = Hash64(p.first.data(), p.first.size(), h);
    868     h = Hash64(p.second.data(), p.second.size(), h);
    869   }
    870 
    871   return h;
    872 }
    873 
    874 static constexpr const char* const kExecutorAttr = "_executor";
    875 
    876 /* static */
    877 string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
    878                                             AttrSlice attrs) {
    879   if (!options.executor_type.empty()) {
    880     return options.executor_type;
    881   } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) {
    882     return executor_attr->s();
    883   } else {
    884     return string();
    885   }
    886 }
    887 
    888 string Canonicalize(const string& funcname, AttrSlice attrs,
    889                     const FunctionLibraryRuntime::InstantiateOptions& options) {
    890   std::vector<string> entries;
    891   entries.reserve(attrs.size() + static_cast<int>(options.target.empty()) +
    892                   options.input_devices.size());
    893   for (auto p : attrs) {
    894     if (p.first != kExecutorAttr) {
    895       entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
    896     }
    897   }
    898   if (!options.target.empty()) {
    899     entries.push_back(
    900         strings::StrCat("_target", "=", str_util::CEscape(options.target)));
    901   }
    902   for (int i = 0; i < options.input_devices.size(); ++i) {
    903     entries.push_back(strings::StrCat(
    904         "_input_dev", i, "=", str_util::CEscape(options.input_devices[i])));
    905   }
    906   for (int i = 0; i < options.output_devices.size(); ++i) {
    907     entries.push_back(strings::StrCat(
    908         "_output_dev", i, "=", str_util::CEscape(options.output_devices[i])));
    909   }
    910   if (options.overlay_lib) {
    911     entries.push_back(strings::StrCat(
    912         "_overlay_lib", "=", reinterpret_cast<uintptr_t>(options.overlay_lib)));
    913   }
    914   if (!options.state_handle.empty()) {
    915     entries.push_back(
    916         strings::StrCat("_state_handle", "=", options.state_handle));
    917   }
    918   string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs);
    919   if (!executor_type.empty()) {
    920     entries.push_back(strings::StrCat(kExecutorAttr, "=", executor_type));
    921   }
    922   string config_proto_serialized;
    923   options.config_proto.SerializeToString(&config_proto_serialized);
    924   if (!config_proto_serialized.empty()) {
    925     entries.push_back(strings::StrCat(
    926         "_config_proto", "=", str_util::CEscape(config_proto_serialized)));
    927   }
    928   std::sort(entries.begin(), entries.end());
    929   return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
    930 }
    931 
    932 FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
    933                                      DataTypeSlice ret_types)
    934     : arg_types_(arg_types.begin(), arg_types.end()),
    935       ret_types_(ret_types.begin(), ret_types.end()) {
    936   args_.resize(arg_types_.size());
    937   rets_.resize(ret_types_.size());
    938 }
    939 
    940 FunctionCallFrame::~FunctionCallFrame() {}
    941 
    942 Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
    943   // Input type checks.
    944   if (args.size() != arg_types_.size()) {
    945     return errors::InvalidArgument("Expects ", arg_types_.size(),
    946                                    " arguments, but ", args.size(),
    947                                    " is provided");
    948   }
    949   for (size_t i = 0; i < args.size(); ++i) {
    950     if (arg_types_[i] != args[i].dtype()) {
    951       return errors::InvalidArgument(
    952           "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
    953           DataTypeString(args[i].dtype()), " is provided");
    954     }
    955     args_[i] = args[i];
    956   }
    957   return Status::OK();
    958 }
    959 
    960 Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
    961   rets->clear();
    962   rets->reserve(rets_.size());
    963   for (size_t i = 0; i < rets_.size(); ++i) {
    964     const auto& item = rets_[i];
    965     if (item.has_val) {
    966       rets->push_back(item.val);
    967     } else {
    968       return errors::Internal("Retval[", i, "] does not have value");
    969     }
    970   }
    971   return Status::OK();
    972 }
    973 
    974 Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets,
    975                                          bool allow_dead_tensors) {
    976   rets->clear();
    977   rets->reserve(rets_.size());
    978   for (size_t i = 0; i < rets_.size(); ++i) {
    979     if (rets_[i].has_val) {
    980       rets->emplace_back(std::move(rets_[i].val));
    981     } else if (allow_dead_tensors) {
    982       rets->emplace_back();
    983     } else {
    984       return errors::Internal("Retval[", i, "] does not have value");
    985     }
    986   }
    987   return Status::OK();
    988 }
    989 
    990 Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
    991   if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
    992     return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
    993                                    args_.size(), ")");
    994   }
    995   *val = args_[index];
    996   return Status::OK();
    997 }
    998 
    999 Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
   1000   if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
   1001     return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
   1002                                    rets_.size(), ")");
   1003   }
   1004   if (val.dtype() != ret_types_[index]) {
   1005     return errors::InvalidArgument(
   1006         "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
   1007         ", but ", DataTypeString(val.dtype()), " is provided.");
   1008   }
   1009   Retval* item = &rets_[index];
   1010   if (!item->has_val) {
   1011     item->has_val = true;
   1012     item->val = val;
   1013   } else {
   1014     return errors::Internal("Retval[", index, "] has already been set.");
   1015   }
   1016   return Status::OK();
   1017 }
   1018 
   1019 FunctionLibraryDefinition::FunctionDefAndOpRegistration::
   1020     FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
   1021     : fdef(fdef_in),
   1022       // Exact shape inference for functions is handled by ShapeRefiner.
   1023       // Here we pass a dummy shape inference function for legacy code paths.
   1024       op_registration_data(fdef.signature(), shape_inference::UnknownShape,
   1025                            true /* is_function */) {}
   1026 
   1027 FunctionLibraryDefinition::FunctionLibraryDefinition(
   1028     const FunctionLibraryDefinition& other)
   1029     : default_registry_(other.default_registry_) {
   1030   tf_shared_lock l(other.mu_);
   1031   for (const auto& it : other.function_defs_) {
   1032     TF_CHECK_OK(AddFunctionDef(it.second->fdef));
   1033   }
   1034   func_grad_ = other.func_grad_;
   1035 }
   1036 
   1037 FunctionLibraryDefinition::FunctionLibraryDefinition(
   1038     const OpRegistryInterface* default_registry,
   1039     const FunctionDefLibrary& def_lib)
   1040     : default_registry_(default_registry),
   1041       function_defs_(def_lib.function_size()) {
   1042   for (const auto& fdef : def_lib.function()) {
   1043     // The latter function definition wins.
   1044     auto& ptr = function_defs_[fdef.signature().name()];
   1045     ptr.reset(new FunctionDefAndOpRegistration(fdef));
   1046   }
   1047   for (const auto& grad : def_lib.gradient()) {
   1048     func_grad_[grad.function_name()] = grad.gradient_func();
   1049   }
   1050 }
   1051 
   1052 FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
   1053 
   1054 bool FunctionLibraryDefinition::Contains(const string& func) const {
   1055   tf_shared_lock l(mu_);
   1056   return function_defs_.find(func) != function_defs_.end();
   1057 }
   1058 
   1059 const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
   1060   tf_shared_lock l(mu_);
   1061   return FindHelper(func);
   1062 }
   1063 
   1064 const FunctionDef* FunctionLibraryDefinition::FindHelper(
   1065     const string& func) const {
   1066   auto iter = function_defs_.find(func);
   1067   if (iter == function_defs_.end()) {
   1068     return nullptr;
   1069   } else {
   1070     return &iter->second->fdef;
   1071   }
   1072 }
   1073 
   1074 Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
   1075   mutex_lock l(mu_);
   1076   bool added;
   1077   return AddFunctionDefHelper(fdef, &added);
   1078 }
   1079 
   1080 Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
   1081                                                        bool* added) {
   1082   *added = false;
   1083   std::unique_ptr<FunctionDefAndOpRegistration>* entry =
   1084       &function_defs_[fdef.signature().name()];
   1085   if (*entry != nullptr) {
   1086     if (!FunctionDefsEqual((*entry)->fdef, fdef)) {
   1087       return errors::InvalidArgument(
   1088           "Cannot add function '", fdef.signature().name(),
   1089           "' because a different function with the same name already "
   1090           "exists.");
   1091     }
   1092     // Ignore duplicate FunctionDefs
   1093     return Status::OK();
   1094   }
   1095   const OpDef* op_def;
   1096   if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
   1097     return errors::InvalidArgument(
   1098         "Cannot add function '", fdef.signature().name(),
   1099         "' because an op with the same name already exists.");
   1100   }
   1101   entry->reset(new FunctionDefAndOpRegistration(fdef));
   1102   *added = true;
   1103   return Status::OK();
   1104 }
   1105 
   1106 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
   1107   mutex_lock l(mu_);
   1108   bool added;
   1109   return AddGradientDefHelper(grad, &added);
   1110 }
   1111 
   1112 Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
   1113                                                        bool* added) {
   1114   *added = false;
   1115   string* entry = &func_grad_[grad.function_name()];
   1116   if (!entry->empty()) {
   1117     if (*entry != grad.gradient_func()) {
   1118       return errors::InvalidArgument(
   1119           "Cannot assign gradient function '", grad.gradient_func(), "' to '",
   1120           grad.function_name(), "' because it already has gradient function ",
   1121           "'", *entry, "'");
   1122     }
   1123     // Ignore duplicate GradientDefs
   1124     return Status::OK();
   1125   }
   1126   *entry = grad.gradient_func();
   1127   *added = true;
   1128   return Status::OK();
   1129 }
   1130 
   1131 Status FunctionLibraryDefinition::AddLibrary(
   1132     const FunctionLibraryDefinition& other) {
   1133   // Clone `other` to ensure thread-safety (grabbing `other`'s lock for
   1134   // the duration of the function could lead to deadlock).
   1135   FunctionLibraryDefinition clone(other);
   1136   mutex_lock l(mu_);
   1137   // Remember the funcs and grads that we added successfully so that
   1138   // we can roll them back on error.
   1139   std::vector<string> funcs;
   1140   std::vector<string> funcs_with_grads;
   1141   Status s;
   1142   bool added;
   1143   for (auto iter : clone.function_defs_) {
   1144     s = AddFunctionDefHelper(iter.second->fdef, &added);
   1145     if (!s.ok()) {
   1146       Remove(funcs, funcs_with_grads);
   1147       return s;
   1148     }
   1149     if (added) {
   1150       funcs.push_back(iter.second->fdef.signature().name());
   1151     }
   1152   }
   1153   for (auto iter : clone.func_grad_) {
   1154     GradientDef grad;
   1155     grad.set_function_name(iter.first);
   1156     grad.set_gradient_func(iter.second);
   1157     s = AddGradientDefHelper(grad, &added);
   1158     if (!s.ok()) {
   1159       Remove(funcs, funcs_with_grads);
   1160       return s;
   1161     }
   1162     if (added) {
   1163       funcs_with_grads.push_back(grad.function_name());
   1164     }
   1165   }
   1166   return Status::OK();
   1167 }
   1168 
   1169 Status FunctionLibraryDefinition::AddLibrary(
   1170     const FunctionDefLibrary& lib_def) {
   1171   // Remember the funcs and grads that we added successfully so that
   1172   // we can roll them back on error.
   1173   mutex_lock l(mu_);
   1174   std::vector<string> funcs;
   1175   std::vector<string> funcs_with_grads;
   1176   Status s;
   1177   bool added;
   1178   for (const FunctionDef& fdef : lib_def.function()) {
   1179     s = AddFunctionDefHelper(fdef, &added);
   1180     if (!s.ok()) {
   1181       Remove(funcs, funcs_with_grads);
   1182       return s;
   1183     }
   1184     if (added) {
   1185       funcs.push_back(fdef.signature().name());
   1186     }
   1187   }
   1188   for (const GradientDef& grad : lib_def.gradient()) {
   1189     s = AddGradientDefHelper(grad, &added);
   1190     if (!s.ok()) {
   1191       Remove(funcs, funcs_with_grads);
   1192       return s;
   1193     }
   1194     if (added) {
   1195       funcs_with_grads.push_back(grad.function_name());
   1196     }
   1197   }
   1198   return Status::OK();
   1199 }
   1200 
   1201 Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
   1202                                                   const FunctionDef& fdef) {
   1203   mutex_lock l(mu_);
   1204   bool added;
   1205   TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
   1206   TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, &added));
   1207   return Status::OK();
   1208 }
   1209 
   1210 Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
   1211   mutex_lock l(mu_);
   1212   bool added;
   1213   TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
   1214   TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
   1215   return Status::OK();
   1216 }
   1217 
   1218 Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
   1219   mutex_lock l(mu_);
   1220   TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
   1221   return Status::OK();
   1222 }
   1223 
   1224 Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) {
   1225   const auto& i = function_defs_.find(func);
   1226   if (i == function_defs_.end()) {
   1227     return errors::InvalidArgument("Tried to remove non-existent function ",
   1228                                    func);
   1229   }
   1230   function_defs_.erase(i);
   1231   return Status::OK();
   1232 }
   1233 
   1234 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
   1235   const auto& i = func_grad_.find(func);
   1236   if (i == func_grad_.end()) {
   1237     return errors::InvalidArgument("Tried to remove non-existent gradient ",
   1238                                    func);
   1239   }
   1240   func_grad_.erase(i);
   1241   return Status::OK();
   1242 }
   1243 
   1244 void FunctionLibraryDefinition::Remove(
   1245     const std::vector<string>& funcs,
   1246     const std::vector<string>& funcs_with_grads) {
   1247   for (const string& f : funcs) {
   1248     Status s = RemoveFunctionHelper(f);
   1249     DCHECK(s.ok());
   1250   }
   1251   for (const string& f : funcs_with_grads) {
   1252     Status s = RemoveGradient(f);
   1253     DCHECK(s.ok());
   1254   }
   1255 }
   1256 
   1257 string FunctionLibraryDefinition::FindGradient(const string& func) const {
   1258   tf_shared_lock l(mu_);
   1259   return gtl::FindWithDefault(func_grad_, func, "");
   1260 }
   1261 
   1262 string FunctionLibraryDefinition::FindGradientHelper(const string& func) const {
   1263   return gtl::FindWithDefault(func_grad_, func, "");
   1264 }
   1265 
   1266 Status FunctionLibraryDefinition::LookUp(
   1267     const string& op, const OpRegistrationData** op_reg_data) const {
   1268   tf_shared_lock l(mu_);
   1269   auto iter = function_defs_.find(op);
   1270   if (iter != function_defs_.end()) {
   1271     *op_reg_data = &iter->second->op_registration_data;
   1272     return Status::OK();
   1273   }
   1274   return default_registry_->LookUp(op, op_reg_data);
   1275 }
   1276 
   1277 string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
   1278   tf_shared_lock l(mu_);
   1279   int index = 0;
   1280   string name = strings::StrCat(prefix, index);
   1281   while (function_defs_.find(name) != function_defs_.end()) {
   1282     ++index;
   1283     name = strings::StrCat(prefix, index);
   1284   }
   1285   return name;
   1286 }
   1287 
   1288 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
   1289     const NodeDef& ndef) const {
   1290   if (ndef.op() != kGradientOp) {
   1291     // If 'ndef' calls a function and the function's def has the attr,
   1292     // returns it.
   1293     return Find(ndef.op());
   1294   }
   1295 
   1296   // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
   1297   // Foo's attributes.
   1298   const NameAttrList* forward_func_attrs;
   1299   if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) {
   1300     return nullptr;
   1301   }
   1302   const string& func_name = forward_func_attrs->name();
   1303   {
   1304     tf_shared_lock l(mu_);
   1305     const string& grad_name = FindGradientHelper(func_name);
   1306     // If 'func' has a user-defined gradient function, uses the grad
   1307     // function's attrs to see if noinline is specified. Otherwise,
   1308     // uses func's attrs.
   1309     if (!grad_name.empty()) {
   1310       return FindHelper(grad_name);
   1311     }
   1312     return FindHelper(func_name);
   1313   }
   1314 }
   1315 
   1316 std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const {
   1317   std::vector<string> function_names;
   1318   tf_shared_lock l(mu_);
   1319   function_names.reserve(function_defs_.size());
   1320   for (const auto& it : function_defs_) {
   1321     function_names.emplace_back(it.first);
   1322   }
   1323   return function_names;
   1324 }
   1325 
   1326 FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
   1327   FunctionDefLibrary lib;
   1328   tf_shared_lock l(mu_);
   1329   for (const auto& f : function_defs_) {
   1330     *lib.add_function() = f.second->fdef;
   1331   }
   1332   for (const auto& g : func_grad_) {
   1333     GradientDef* gd = lib.add_gradient();
   1334     gd->set_function_name(g.first);
   1335     gd->set_gradient_func(g.second);
   1336   }
   1337   return lib;
   1338 }
   1339 
   1340 template <typename T>
   1341 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
   1342                                           const string& attr, T* value) const {
   1343   const FunctionDef* fdef = GetAttrImpl(ndef);
   1344   if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
   1345     return Status::OK();
   1346   }
   1347   return errors::InvalidArgument("Attr ", attr, " is not defined.");
   1348 }
   1349 
   1350 template <typename T>
   1351 Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
   1352                                           T* value) const {
   1353   return GetAttr(node.def(), attr, value);
   1354 }
   1355 
   1356 #define GET_ATTR(T)                                                            \
   1357   template Status FunctionLibraryDefinition::GetAttr(const Node&,              \
   1358                                                      const string&, T*) const; \
   1359   template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,           \
   1360                                                      const string&, T*) const;
   1361 GET_ATTR(string)
   1362 GET_ATTR(bool)
   1363 #undef GET_ATTR
   1364 
   1365 namespace {
   1366 
   1367 constexpr char kApiImplements[] = "api_implements";
   1368 
   1369 absl::flat_hash_set<string> ReachableFunctions(
   1370     const FunctionLibraryDefinition& flib,
   1371     const protobuf::RepeatedPtrField<NodeDef>& nodes) {
   1372   // Functions that are reachable from the graph.
   1373   absl::flat_hash_set<string> reachable_funcs;
   1374 
   1375   // For any functions, if it has attribute "api_implements" =
   1376   // "some_interface" and it is reachable, then it means any other
   1377   // function with same attribute name and value could also be potentially
   1378   // reachable, eg via implementation_selector swapping the
   1379   // nodedef.
   1380   absl::flat_hash_set<string> reachable_api_interface;
   1381 
   1382   // Functions might be reachable from the nested function calls, so we keep a
   1383   // queue of functions that we have to check.
   1384   gtl::InlinedVector<const FunctionDef*, 4> func_queue;
   1385 
   1386   // Add reachable and not already processed functions to the functions queue.
   1387   const auto add_to_func_queue = [&](const string& func_name) {
   1388     const FunctionDef* func = flib.Find(func_name);
   1389     if (func && reachable_funcs.find(func_name) == reachable_funcs.end()) {
   1390       func_queue.push_back(func);
   1391     }
   1392   };
   1393 
   1394   // Add all the functions that are reachable from the given node to the queue.
   1395   const auto process_node = [&](const NodeDef& node) {
   1396     // Node itself can be a call to the function.
   1397     add_to_func_queue(node.op());
   1398 
   1399     // Or node can have an attribute referencing a function.
   1400     for (const auto& attr : node.attr()) {
   1401       const auto& attr_value = attr.second;
   1402 
   1403       // 1. AttrValue.func
   1404       if (attr_value.has_func()) {
   1405         add_to_func_queue(attr_value.func().name());
   1406       }
   1407 
   1408       // 2. AttrValue.ListValue.func
   1409       if (attr_value.has_list()) {
   1410         for (const auto& func : attr_value.list().func()) {
   1411           add_to_func_queue(func.name());
   1412         }
   1413       }
   1414     }
   1415   };
   1416 
   1417   // Add all functions that are directly called from the optimized graph.
   1418   std::for_each(nodes.begin(), nodes.end(), process_node);
   1419 
   1420   // Process all reachable functions.
   1421   while (!func_queue.empty()) {
   1422     const FunctionDef* func = func_queue.back();
   1423     func_queue.pop_back();
   1424 
   1425     const string& func_name = func->signature().name();
   1426     reachable_funcs.insert(func_name);
   1427 
   1428     const auto attr_it = func->attr().find(kApiImplements);
   1429     if (attr_it != func->attr().end()) {
   1430       reachable_api_interface.insert(attr_it->second.s());
   1431     }
   1432 
   1433     // Find all the functions called from the function body.
   1434     const auto& func_body = func->node_def();
   1435     std::for_each(func_body.begin(), func_body.end(), process_node);
   1436 
   1437     // Check if the function has a registered gradient.
   1438     const string grad_func_name = flib.FindGradient(func_name);
   1439     if (!grad_func_name.empty()) add_to_func_queue(grad_func_name);
   1440   }
   1441 
   1442   for (const auto& func_name : flib.ListFunctionNames()) {
   1443     const auto& func_def = flib.Find(func_name);
   1444     const auto attr_it = func_def->attr().find(kApiImplements);
   1445     if (attr_it != func_def->attr().end()) {
   1446       if (reachable_api_interface.contains(attr_it->second.s())) {
   1447         reachable_funcs.insert(func_name);
   1448       }
   1449     }
   1450   }
   1451 
   1452   return reachable_funcs;
   1453 }
   1454 
   1455 FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
   1456     const FunctionLibraryDefinition& flib,
   1457     const protobuf::RepeatedPtrField<NodeDef>& nodes) {
   1458   absl::flat_hash_set<string> reachable_funcs = ReachableFunctions(flib, nodes);
   1459 
   1460   FunctionLibraryDefinition reachable_flib(flib.default_registry(),
   1461                                            FunctionDefLibrary());
   1462 
   1463   for (const string& func_name : reachable_funcs) {
   1464     const FunctionDef* func = flib.Find(func_name);
   1465     DCHECK_NE(func, nullptr);
   1466     // That should never fail, because we copy functions from valid flib and use
   1467     // the same default registry.
   1468     const Status added = reachable_flib.AddFunctionDef(*func);
   1469     DCHECK(added.ok());
   1470 
   1471     const string grad_func_name = flib.FindGradient(func_name);
   1472     if (!grad_func_name.empty()) {
   1473       GradientDef grad;
   1474       grad.set_function_name(func_name);
   1475       grad.set_gradient_func(grad_func_name);
   1476       // It can only fail if function already has a gradient function.
   1477       const Status added_grad = reachable_flib.AddGradientDef(grad);
   1478       DCHECK(added_grad.ok());
   1479     }
   1480   }
   1481 
   1482   return reachable_flib;
   1483 }
   1484 
   1485 }  // namespace
   1486 
   1487 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
   1488     const GraphDef& graph) const {
   1489   return ReachableFunctionLibraryDefinition(*this, graph.node());
   1490 }
   1491 
   1492 FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
   1493     const FunctionDef& func) const {
   1494   return ReachableFunctionLibraryDefinition(*this, func.node_def());
   1495 }
   1496 
   1497 void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
   1498   if (val.size() >= 2 && val[0] == '$') {
   1499     proto.set_placeholder(val.data() + 1, val.size() - 1);
   1500   } else {
   1501     SetAttrValue(val, &proto);
   1502   }
   1503 }
   1504 
   1505 FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
   1506     const string& name,
   1507     gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
   1508   AttrValueWrapper ret;
   1509   ret.proto.mutable_func()->set_name(name);
   1510   for (const auto& a : attrs) {
   1511     ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
   1512   }
   1513   return ret;
   1514 }
   1515 
   1516 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
   1517   NodeDef n;
   1518   n.set_op(this->op);
   1519   n.set_name(this->ret[0]);
   1520   for (const auto& a : this->attr) {
   1521     n.mutable_attr()->insert({a.first, a.second.proto});
   1522   }
   1523   for (const string& a : this->arg) {
   1524     n.add_input(a);
   1525   }
   1526   for (const string& d : this->dep) {
   1527     n.add_input(strings::StrCat("^", d));
   1528   }
   1529   if (!this->device.empty()) {
   1530     n.set_device(this->device);
   1531   }
   1532   return n;
   1533 }
   1534 
   1535 /* static */
   1536 FunctionDef FunctionDefHelper::Create(
   1537     const string& function_name, gtl::ArraySlice<string> in_def,
   1538     gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
   1539     gtl::ArraySlice<Node> node_def,
   1540     gtl::ArraySlice<std::pair<string, string>> ret_def,
   1541     gtl::ArraySlice<std::pair<string, string>> control_ret_def) {
   1542   FunctionDef fdef;
   1543 
   1544   // Signature
   1545   OpDefBuilder b(function_name);
   1546   for (const auto& i : in_def) b.Input(i);
   1547   for (const auto& o : out_def) b.Output(o);
   1548   for (const auto& a : attr_def) b.Attr(a);
   1549   for (const auto& c : control_ret_def) b.ControlOutput(c.first);
   1550 
   1551   OpRegistrationData op_reg_data;
   1552   TF_CHECK_OK(b.Finalize(&op_reg_data));
   1553   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
   1554 
   1555   // Function body
   1556   for (const auto& n : node_def) {
   1557     *(fdef.add_node_def()) = n.ToNodeDef();
   1558   }
   1559 
   1560   // Returns
   1561   for (const auto& r : ret_def) {
   1562     fdef.mutable_ret()->insert({r.first, r.second});
   1563   }
   1564 
   1565   // Control returns
   1566   for (const auto& cr : control_ret_def) {
   1567     fdef.mutable_control_ret()->insert({cr.first, cr.second});
   1568   }
   1569 
   1570   auto* op_def_registry = OpRegistry::Global();
   1571   // Check if any op is stateful.
   1572   for (const auto& n : node_def) {
   1573     const OpDef* op_def = nullptr;
   1574     auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
   1575     // Lookup can fail if e.g. we are calling a function that was not yet
   1576     // defined.  If it happens, conservatively assume the op is stateful.
   1577     if (!status.ok() || op_def->is_stateful()) {
   1578       fdef.mutable_signature()->set_is_stateful(true);
   1579     }
   1580   }
   1581 
   1582   return fdef;
   1583 }
   1584 
   1585 /* static */
   1586 FunctionDef FunctionDefHelper::Create(
   1587     const string& function_name, gtl::ArraySlice<string> in_def,
   1588     gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
   1589     gtl::ArraySlice<Node> node_def,
   1590     gtl::ArraySlice<std::pair<string, string>> ret_def) {
   1591   return Create(function_name, in_def, out_def, attr_def, node_def, ret_def,
   1592                 /*control_ret_def=*/{});
   1593 }
   1594 
   1595 /* static */
   1596 FunctionDef FunctionDefHelper::Define(const string& name,
   1597                                       gtl::ArraySlice<string> arg_def,
   1598                                       gtl::ArraySlice<string> ret_def,
   1599                                       gtl::ArraySlice<string> attr_def,
   1600                                       gtl::ArraySlice<Node> node_def) {
   1601   FunctionDef fdef;
   1602   OpDefBuilder b(name);
   1603   for (const auto& a : arg_def) b.Input(a);
   1604   for (const auto& r : ret_def) b.Output(r);
   1605   for (const auto& a : attr_def) b.Attr(a);
   1606 
   1607   OpRegistrationData op_reg_data;
   1608   TF_CHECK_OK(b.Finalize(&op_reg_data));
   1609   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
   1610 
   1611   // Mapping from legacy output names to NodeDef outputs.
   1612   std::unordered_map<string, string> ret_index;
   1613   for (const auto& a : fdef.signature().input_arg()) {
   1614     ret_index[a.name()] = a.name();
   1615   }
   1616 
   1617   // For looking up OpDefs
   1618   auto* op_def_registry = OpRegistry::Global();
   1619 
   1620   // Function body
   1621   for (const auto& src : node_def) {
   1622     NodeDef* n = fdef.add_node_def();
   1623     n->set_op(src.op);
   1624     n->set_name(src.ret[0]);
   1625     for (const auto& a : src.attr) {
   1626       n->mutable_attr()->insert({a.first, a.second.proto});
   1627     }
   1628     for (const string& a : src.arg) {
   1629       const auto iter = ret_index.find(a);
   1630       CHECK(iter != ret_index.end())
   1631           << "Node input '" << a << "' in '" << src.ret[0] << "' of " << name;
   1632       n->add_input(iter->second);
   1633     }
   1634     for (const string& d : src.dep) {
   1635       n->add_input(strings::StrCat("^", d));
   1636     }
   1637 
   1638     // Add the outputs of this node to ret_index.
   1639     const OpDef* op_def = nullptr;
   1640     TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
   1641     CHECK(op_def != nullptr) << n->op();
   1642     NameRangeMap output_names;
   1643     TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
   1644     for (const auto& o : output_names) {
   1645       CHECK_LE(o.second.second, src.ret.size())
   1646           << "Missing ret for output '" << o.first << "' in '" << src.ret[0]
   1647           << "' of " << name;
   1648       for (int i = o.second.first; i < o.second.second; ++i) {
   1649         ret_index[src.ret[i]] =
   1650             strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
   1651       }
   1652     }
   1653     if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
   1654   }
   1655 
   1656   // Returns
   1657   for (const auto& r : fdef.signature().output_arg()) {
   1658     const auto iter = ret_index.find(r.name());
   1659     CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
   1660     fdef.mutable_ret()->insert({r.name(), iter->second});
   1661   }
   1662   return fdef;
   1663 }
   1664 
   1665 FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
   1666                                       gtl::ArraySlice<string> ret_def,
   1667                                       gtl::ArraySlice<string> attr_def,
   1668                                       gtl::ArraySlice<Node> node_def) {
   1669   return Define("_", arg_def, ret_def, attr_def, node_def);
   1670 }
   1671 
   1672 namespace gradient {
   1673 
   1674 typedef std::unordered_map<string, Creator> OpGradFactory;
   1675 
   1676 OpGradFactory* GetOpGradFactory() {
   1677   static OpGradFactory* factory = new OpGradFactory;
   1678   return factory;
   1679 }
   1680 
   1681 bool RegisterOp(const string& op, Creator func) {
   1682   CHECK(GetOpGradFactory()->insert({op, func}).second)
   1683       << "Duplicated gradient for " << op;
   1684   return true;
   1685 }
   1686 
   1687 Status GetOpGradientCreator(const string& op, Creator* creator) {
   1688   auto fac = GetOpGradFactory();
   1689   auto iter = fac->find(op);
   1690   if (iter == fac->end()) {
   1691     return errors::NotFound("No gradient defined for op: ", op);
   1692   }
   1693   *creator = iter->second;
   1694   return Status::OK();
   1695 }
   1696 
   1697 }  // end namespace gradient
   1698 
   1699 }  // namespace tensorflow
   1700