Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/function.h"
     17 
     18 #include <map>
     19 #include <unordered_map>
     20 #include <utility>
     21 #include <vector>
     22 
     23 #include "tensorflow/core/framework/common_shape_fns.h"
     24 #include "tensorflow/core/framework/function.pb_text.h"
     25 #include "tensorflow/core/framework/graph.pb.h"
     26 #include "tensorflow/core/framework/node_def.pb.h"
     27 #include "tensorflow/core/framework/node_def_util.h"
     28 #include "tensorflow/core/framework/op.h"
     29 #include "tensorflow/core/graph/graph.h"
     30 #include "tensorflow/core/lib/core/errors.h"
     31 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     32 #include "tensorflow/core/lib/gtl/map_util.h"
     33 #include "tensorflow/core/util/equal_graph_def.h"
     34 
     35 namespace tensorflow {
     36 
     37 // Extracts the actual type from "attr_values" based on its definition
     38 // "arg_def".
     39 //
     40 // If "arg_def" is a N*T type, *is_type_list is set to false, and
     41 // *dtypes is set to be a vector of size N and each element is T.
     42 //
     43 // If "arg_def" is a list(type), *is_type_list is set to true, and
     44 // *dtypes is set to be a vector of types specified in attrs for
     45 // arg_def.
     46 //
     47 // Otherwise (arg_def is a simple type T), *is_type_list is set to
     48 // false, and *dtypes is set to a single element vector, whose only
     49 // element is T.
     50 Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
     51                   bool* is_type_list, DataTypeVector* dtypes) {
     52   dtypes->clear();
     53   if (!arg_def.type_list_attr().empty()) {
     54     const AttrValue* v = attrs.Find(arg_def.type_list_attr());
     55     if (v == nullptr) {
     56       return errors::NotFound("type attr not found: ",
     57                               arg_def.type_list_attr());
     58     }
     59     *is_type_list = true;
     60     for (int i = 0; i < v->list().type_size(); ++i) {
     61       dtypes->push_back(v->list().type(i));
     62     }
     63     return Status::OK();
     64   }
     65 
     66   *is_type_list = false;
     67   int num = 1;
     68   if (!arg_def.number_attr().empty()) {
     69     const AttrValue* v = attrs.Find(arg_def.number_attr());
     70     if (v == nullptr) {
     71       return errors::NotFound("type attr not found: ", arg_def.type_attr());
     72     }
     73     num = v->i();
     74   }
     75 
     76   DataType dtype;
     77   if (arg_def.type() != DT_INVALID) {
     78     dtype = arg_def.type();
     79   } else if (arg_def.type_attr().empty()) {
     80     dtype = DT_INVALID;
     81   } else {
     82     const AttrValue* v = attrs.Find(arg_def.type_attr());
     83     if (v == nullptr) {
     84       return errors::NotFound("type attr not found: ", arg_def.type_attr());
     85     }
     86     dtype = v->type();
     87   }
     88   dtypes->resize(num, dtype);
     89   return Status::OK();
     90 }
     91 
     92 namespace {
     93 
     94 template <typename T>
     95 void AddAttr(const string& name, const T& val, NodeDef* ndef) {
     96   SetAttrValue(val, &((*ndef->mutable_attr())[name]));
     97 }
     98 
     99 Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
    100   // attr_values should specify all attrs defined in fdef.
    101   for (const auto& a : sig.attr()) {
    102     const AttrValue* v = attr_values.Find(a.name());
    103     if (!v) {
    104       return errors::NotFound("Attr ", a.name(), " is not found from ",
    105                               SummarizeOpDef(sig));
    106     }
    107     Status status = AttrValueHasType(*v, a.type());
    108     if (!status.ok()) {
    109       errors::AppendToMessage(&status, "for attr '", a.name(), "'");
    110       return status;
    111     }
    112   }
    113 
    114 // TODO(josh11b): Enable this code once it works with function gradients.
    115 // Right now the C++ function gradient code assumes it can pass
    116 // all the attrs of the function to the gradient, and any attrs that
    117 // the gradient doesn't care about will be ignored.
    118 #if 0
    119   if (attr_values.size() != sig.attr_size()) {
    120     for (const auto& a : attr_values) {
    121       // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
    122       bool found = false;
    123       for (const auto& s : sig.attr()) {
    124         if (a.first == s.name()) {
    125           found = true;
    126           break;
    127         }
    128       }
    129       if (!found) {
    130         return errors::NotFound("Attr ", a.first, " is not found in ",
    131                                 SummarizeOpDef(sig));
    132       }
    133     }
    134   }
    135 #endif
    136 
    137   return Status::OK();
    138 }
    139 
    140 // A helper class for instantiating functions. This contains shared information
    141 // like the resulting graph and node name index.
    142 class FunctionInstantiationHelper {
    143  public:
    144   FunctionInstantiationHelper(GetFunctionSignature get_function,
    145                               InstantiationResult* result)
    146       : get_function_(std ::move(get_function)), result_(*result) {
    147     result_.nodes.clear();
    148   }
    149 
    150   // Builds index for nodes that can be used as node's input arguments.
    151   Status BuildInputArgIndex(const OpDef::ArgDef& arg_def,
    152                             AttrSlice attr_values) {
    153     bool is_type_list;
    154     DataTypeVector dtypes;
    155     TF_RETURN_IF_ERROR(
    156         ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
    157     CHECK_GE(dtypes.size(), size_t{1});
    158     int arg_index = result_.nodes.size();
    159     TF_RETURN_IF_ERROR(
    160         AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
    161     // Creates dtypes.size() nodes in the graph.
    162     for (size_t i = 0; i < dtypes.size(); ++i) {
    163       TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
    164                                  {true, arg_index, 0, false, {dtypes[i]}}));
    165       DCHECK_EQ(arg_index, result_.nodes.size());
    166       string name = arg_def.name();
    167       if (dtypes.size() > 1) {
    168         strings::StrAppend(&name, "_", i);
    169       }
    170       NodeDef* gnode = AddNode(name);
    171       gnode->set_op("_Arg");
    172       AddAttr("T", dtypes[i], gnode);
    173       AddAttr("index", arg_index, gnode);
    174       result_.arg_types.push_back(dtypes[i]);
    175       ++arg_index;
    176     }
    177     return Status::OK();
    178   }
    179 
    180   Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
    181                               const int arg_index) {
    182     const OpDef* node_sig = nullptr;
    183     TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
    184     if (node_sig->output_arg_size() == 0) {
    185       return AddItem(node.name(), {false, arg_index, 0, false, {}});
    186     }
    187     const int num_retval = node_sig->output_arg_size();
    188     int start = 0;
    189     bool is_type_list;
    190     DataTypeVector dtypes;
    191     for (int i = 0; i < num_retval; ++i) {
    192       TF_RETURN_IF_ERROR(
    193           ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
    194       // Note that we rely on the backwards-compatibility test enforcing
    195       // that output_arg(*).name() doesn't change here.
    196       const string base_name =
    197           strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
    198       TF_RETURN_IF_ERROR(
    199           AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
    200       for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
    201         TF_RETURN_IF_ERROR(
    202             AddItem(strings::StrCat(base_name, ":", j),
    203                     {false, arg_index, start + j, false, {dtypes[j]}}));
    204       }
    205       start += dtypes.size();
    206     }
    207     return Status::OK();
    208   }
    209 
    210   Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
    211     const OpDef* fnode_sig = nullptr;
    212     TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig));
    213     NodeDef* gnode = AddNode(fnode.name());
    214     gnode->set_op(fnode.op());
    215     gnode->set_device(fnode.device());
    216     int gnode_idx = nodes_.size() - 1;
    217 
    218     // Input
    219     const int num_args = fnode_sig->input_arg_size();
    220     bool is_type_list;  // ignored
    221     DataTypeVector dtypes;
    222     int fnode_arg_index = 0;
    223     for (int i = 0; i < num_args; ++i) {
    224       TF_RETURN_IF_ERROR(
    225           ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
    226       // Consume inputs (indexed by fnode_arg_index) until we have
    227       // matched each element of dtypes (indexed by j).
    228       for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
    229         if (fnode_arg_index >= fnode.input_size()) {
    230           // Should never happen if we computed dtypes correctly.
    231           return errors::InvalidArgument(
    232               "Attempt to access beyond input size: ", fnode_arg_index,
    233               " >= ", fnode.input_size());
    234         }
    235         // Look up the next input.
    236         const string& input_name = fnode.input(fnode_arg_index);
    237         const auto* item = GetItemOrNull(input_name);
    238         if (item == nullptr) {
    239           return errors::InvalidArgument(
    240               "input ", input_name, " is not found: ", SummarizeNodeDef(fnode));
    241         }
    242         if (item->dtypes.size() > dtypes.size() - j) {
    243           return errors::InvalidArgument("Input ", input_name, " too long for ",
    244                                          fnode_sig->input_arg(i).name());
    245         }
    246         // Match up all the elements of this input (indexed by k) with
    247         // elements of dtypes (advancing j).
    248         for (int k = 0; k < item->dtypes.size(); ++k, ++j) {
    249           if (item->dtypes[k] != dtypes[j]) {
    250             return errors::InvalidArgument(
    251                 "input ", fnode_sig->input_arg(i).name(), "[", j,
    252                 "] expected type ", DataTypeString(dtypes[j]),
    253                 " != ", DataTypeString(item->dtypes[k]), ", the type of ",
    254                 input_name, "[", k, "]");
    255           }
    256           if (item->is_func_arg) {
    257             AddInput(gnode_idx, item->nid + k, 0);
    258           } else {
    259             AddInput(gnode_idx, item->nid, item->idx + k);
    260           }
    261         }
    262       }
    263     }
    264 
    265     // Control deps.
    266     for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
    267       const string& input = fnode.input(i);
    268       if (input.empty() || input[0] != '^') {
    269         return errors::InvalidArgument("Expected input[", i, "] == '", input,
    270                                        "' to be a control input.");
    271       }
    272       int nid = -1;
    273       const string node_name = input.substr(1);
    274       const string node_colon = node_name + ":";
    275       const string node_colon_bound = node_name + ";";
    276       // index_ is a map sorted lexicographically, so the key we are looking for
    277       // must lie in the range [node_name, node_colon_bound).
    278       auto it = index_.lower_bound(node_name);
    279       while (it != index_.end() && it->first <= node_colon_bound) {
    280         if (it->first == node_name ||
    281             tensorflow::StringPiece(it->first).starts_with(node_colon)) {
    282           nid = it->second.nid;
    283           break;
    284         }
    285         ++it;
    286       }
    287       if (nid == -1) {
    288         return errors::InvalidArgument("input[", i, "] == '", input,
    289                                        "', is not found.");
    290       }
    291       AddDep(gnode_idx, nid);
    292     }
    293 
    294     // Attrs.
    295     for (const auto& p : attrs) {
    296       (*gnode->mutable_attr())[p.first] = p.second;
    297     }
    298 
    299     return Status::OK();
    300   }
    301 
    302   Status AddReturnNode(
    303       const OpDef::ArgDef& ret_def, AttrSlice attrs,
    304       const ::tensorflow::protobuf::Map<string, string>& ret_map,
    305       int* ret_index) {
    306     auto ret_iter = ret_map.find(ret_def.name());
    307     if (ret_iter == ret_map.end()) {
    308       return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
    309     }
    310     bool is_type_list;
    311     DataTypeVector dtypes;
    312     TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
    313     CHECK_GE(dtypes.size(), size_t{1});
    314     const auto* item = GetItemOrNull(ret_iter->second);
    315     if (item == nullptr) {
    316       return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
    317                                      ret_iter->second, " is not found.");
    318     }
    319     if (dtypes != item->dtypes) {
    320       return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
    321                                      " : ", DataTypeVectorString(dtypes),
    322                                      " vs. ",
    323                                      DataTypeVectorString(item->dtypes));
    324     }
    325     for (size_t i = 0; i < dtypes.size(); ++i) {
    326       string name = strings::StrCat(ret_def.name(), "_RetVal");
    327       if (dtypes.size() > 1) {
    328         strings::StrAppend(&name, "_", i);
    329       }
    330       NodeDef* gnode = AddNode(name);
    331       gnode->set_op("_Retval");
    332       AddInput(nodes_.size() - 1, item->nid, item->idx + i);
    333       AddAttr("T", dtypes[i], gnode);
    334       AddAttr("index", (*ret_index)++, gnode);
    335       result_.ret_types.push_back(dtypes[i]);
    336     }
    337     return Status::OK();
    338   }
    339 
    340   // Adds the actual node inputs to the result graph by converting indexes to
    341   // the node names.
    342   void AddNodeInputs() {
    343     for (int i = 0; i < result_.nodes.size(); i++) {
    344       NodeInfo& node_info = nodes_[i];
    345       for (const auto& p : node_info.data_inputs) {
    346         result_.nodes[i].add_input(Name(p.first, p.second));
    347       }
    348       for (int index : node_info.control_inputs) {
    349         result_.nodes[i].add_input(Dep(index));
    350       }
    351     }
    352   }
    353 
    354  private:
    355   // This is used to build a small index for all names that can be used as a
    356   // node's input arguments.
    357   //
    358   // If is_func_arg is true, the name is a function's argument.  In
    359   // this case, the produced graph def has node[nid:nid + dtype.size()].
    360   //
    361   // Otherwise, the name is a function body's node return value.  In
    362   // this case, the produced graph def has one node node[nid] and
    363   // the node's output index [idx ... idx + num) corresponds to the
    364   // named outputs.
    365   //
    366   // In all cases, "dtype" specifies the data type.
    367   struct NameInfoItem {
    368     bool is_func_arg;
    369     int nid;
    370     int idx;
    371     bool is_type_list;
    372     DataTypeVector dtypes;
    373   };
    374 
    375   // Adds an item into the input name index.
    376   Status AddItem(const string& name, const NameInfoItem& item) {
    377     if (!index_.insert({name, item}).second) {
    378       return errors::InvalidArgument(
    379           strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
    380                           " name: "),
    381           name);
    382     }
    383     return Status::OK();
    384   }
    385 
    386   const NameInfoItem* GetItemOrNull(const string& name) const {
    387     return gtl::FindOrNull(index_, name);
    388   }
    389 
    390   string Dep(int node_index) const {
    391     return strings::StrCat("^", Name(node_index));
    392   }
    393 
    394   string Name(int node_index) const {
    395     CHECK_LT(node_index, nodes_.size());
    396     return nodes_[node_index].name;
    397   }
    398 
    399   string Name(int node_index, int output_index) const {
    400     if (output_index == 0) {
    401       return Name(node_index);
    402     } else {
    403       return strings::StrCat(Name(node_index), ":", output_index);
    404     }
    405   }
    406 
    407   NodeDef* AddNode(const string& name) {
    408     result_.nodes.emplace_back();
    409     NodeDef* gnode = &result_.nodes.back();
    410     gnode->set_name(name);
    411     nodes_.push_back({name, {}, {}});
    412     CHECK_EQ(result_.nodes.size(), nodes_.size());
    413     return gnode;
    414   }
    415 
    416   void AddInput(int node_index, int output_node, int output_index) {
    417     CHECK_LT(node_index, nodes_.size());
    418     nodes_[node_index].data_inputs.push_back(
    419         std::make_pair(output_node, output_index));
    420   }
    421 
    422   void AddDep(int node_index, int dep_index) {
    423     CHECK_LT(node_index, nodes_.size());
    424     nodes_[node_index].control_inputs.push_back(dep_index);
    425   }
    426 
    427   GetFunctionSignature get_function_;
    428   InstantiationResult& result_;
    429   // A small index for all names that can be used as a node's input arguments.
    430   std::map<string, NameInfoItem> index_;
    431   // This contains information about a node in the new graph including the node
    432   // names and input nodes' indexes.
    433   struct NodeInfo {
    434     string name;
    435     // Data inputs where <n, k> means arg k of node n.
    436     std::vector<std::pair<int, int>> data_inputs;
    437     // Control inputs (dependencies).
    438     std::vector<int> control_inputs;
    439   };
    440   // nodes_[i] is the information about result_.nodes[i].
    441   std::vector<NodeInfo> nodes_;
    442 };
    443 
    444 // Various helpers Print(proto) to print relevant protos to ascii.
    445 string Print(const OpDef::ArgDef& arg) {
    446   string out;
    447   strings::StrAppend(&out, arg.name(), ":");
    448   if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
    449   if (!arg.number_attr().empty()) {
    450     strings::StrAppend(&out, arg.number_attr(), "*");
    451   }
    452   if (arg.type() != DT_INVALID) {
    453     strings::StrAppend(&out, DataTypeString(arg.type()));
    454   } else {
    455     strings::StrAppend(&out, arg.type_attr());
    456   }
    457   if (arg.is_ref()) strings::StrAppend(&out, ")");
    458   return out;
    459 }
    460 
    461 // TODO(josh11b): Merge this with SummarizeAttrValue().
    462 string Print(const AttrValue& attr_value) {
    463   if (attr_value.value_case() == AttrValue::kType) {
    464     return DataTypeString(attr_value.type());
    465   } else if ((attr_value.value_case() == AttrValue::kList) &&
    466              (attr_value.list().type_size() > 0)) {
    467     string ret = "{";
    468     for (int i = 0; i < attr_value.list().type_size(); ++i) {
    469       if (i > 0) strings::StrAppend(&ret, ", ");
    470       strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
    471     }
    472     strings::StrAppend(&ret, "}");
    473     return ret;
    474   } else if (attr_value.value_case() == AttrValue::kFunc) {
    475     if (attr_value.func().attr_size() == 0) {
    476       return attr_value.func().name();
    477     }
    478     std::vector<string> entries;
    479     for (auto p : attr_value.func().attr()) {
    480       entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
    481     }
    482     std::sort(entries.begin(), entries.end());
    483     return strings::StrCat(attr_value.func().name(), "[",
    484                            str_util::Join(entries, ", "), "]");
    485   }
    486   return SummarizeAttrValue(attr_value);
    487 }
    488 
    489 // TODO(josh11b): Merge this with SummarizeNodeDef().
    490 string Print(const NodeDef& n) {
    491   string out;
    492   strings::StrAppend(&out, n.name(), " = ", n.op());
    493   if (n.attr_size() > 0) {
    494     std::vector<string> entries;
    495     for (auto& a : n.attr()) {
    496       entries.push_back(strings::StrCat(a.first, "=", Print(a.second)));
    497     }
    498     std::sort(entries.begin(), entries.end());
    499     strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]");
    500   }
    501   strings::StrAppend(&out, "(");
    502   std::vector<StringPiece> dat;
    503   std::vector<string> dep;
    504   for (StringPiece s : n.input()) {
    505     if (s.Consume("^")) {
    506       dep.push_back(s.ToString());
    507     } else {
    508       dat.push_back(s);
    509     }
    510   }
    511   strings::StrAppend(&out, str_util::Join(dat, ", "), ")");
    512   if (!dep.empty()) {
    513     strings::StrAppend(&out, " @ ", str_util::Join(dep, ", "));
    514   }
    515   return out;
    516 }
    517 
    518 string Print(const FunctionDef& fdef) {
    519   string out;
    520   const OpDef& sig = fdef.signature();
    521   strings::StrAppend(&out, "\n", sig.name());
    522   if (sig.attr_size() > 0) {
    523     strings::StrAppend(&out, "[");
    524     for (int i = 0; i < sig.attr_size(); ++i) {
    525       const auto& a = sig.attr(i);
    526       if (i > 0) strings::StrAppend(&out, ", ");
    527       if (a.type() == "type") {
    528         strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values()));
    529       } else {
    530         strings::StrAppend(&out, a.name(), ":", a.type());
    531       }
    532     }
    533     strings::StrAppend(&out, "]");
    534   }
    535   strings::StrAppend(&out, "(");
    536   for (int i = 0; i < sig.input_arg_size(); ++i) {
    537     if (i > 0) strings::StrAppend(&out, ", ");
    538     strings::StrAppend(&out, Print(sig.input_arg(i)));
    539   }
    540   strings::StrAppend(&out, ") -> (");
    541   for (int i = 0; i < sig.output_arg_size(); ++i) {
    542     if (i > 0) strings::StrAppend(&out, ", ");
    543     strings::StrAppend(&out, Print(sig.output_arg(i)));
    544   }
    545   strings::StrAppend(&out, ") {\n");
    546   for (const auto& n : fdef.node_def()) {
    547     strings::StrAppend(&out, "  ", Print(n), "\n");
    548   }
    549   for (const auto& r : fdef.ret()) {
    550     strings::StrAppend(&out, "  return ", r.first, " = ", r.second, "\n");
    551   }
    552   strings::StrAppend(&out, "}\n");
    553   return out;
    554 }
    555 
    556 string Print(gtl::ArraySlice<const NodeDef*> nodes) {
    557   std::vector<const NodeDef*> arg;
    558   std::vector<const NodeDef*> ret;
    559   std::vector<const NodeDef*> body;
    560   for (const NodeDef* n : nodes) {
    561     if (n->op() == "_Arg") {
    562       arg.push_back(n);
    563     } else if (n->op() == "_Retval") {
    564       ret.push_back(n);
    565     } else {
    566       body.push_back(n);
    567     }
    568   }
    569   auto comp = [](const NodeDef* x, const NodeDef* y) {
    570     int xi;
    571     TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
    572     int yi;
    573     TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
    574     return xi < yi;
    575   };
    576   std::sort(arg.begin(), arg.end(), comp);
    577   std::sort(ret.begin(), ret.end(), comp);
    578   string out;
    579   strings::StrAppend(&out, "\n(");
    580   auto get_type = [](const NodeDef& n) {
    581     DataType dt;
    582     if (!GetNodeAttr(n, "T", &dt).ok()) {
    583       dt = DT_INVALID;
    584     }
    585     return DataTypeString(dt);
    586   };
    587   for (size_t i = 0; i < arg.size(); ++i) {
    588     const NodeDef* n = arg[i];
    589     if (i > 0) strings::StrAppend(&out, ", ");
    590     CHECK_GE(n->attr_size(), 2);
    591     strings::StrAppend(&out, n->name(), ":", get_type(*n));
    592   }
    593   strings::StrAppend(&out, ") -> (");
    594   for (size_t i = 0; i < ret.size(); ++i) {
    595     const NodeDef* n = ret[i];
    596     if (i > 0) strings::StrAppend(&out, ", ");
    597     CHECK_LE(2, n->attr_size());
    598     CHECK_EQ(1, n->input_size());
    599     strings::StrAppend(&out, n->input(0), ":", get_type(*n));
    600   }
    601   strings::StrAppend(&out, ") {\n");
    602   for (size_t i = 0; i < body.size(); ++i) {
    603     strings::StrAppend(&out, "  ", Print(*body[i]), "\n");
    604   }
    605   strings::StrAppend(&out, "}\n");
    606   return out;
    607 }
    608 
    609 Status AddDefaultAttrs(const string& op,
    610                        const GetFunctionSignature& get_function,
    611                        AttrValueMap* attrs) {
    612   const OpDef* op_def = nullptr;
    613   TF_RETURN_IF_ERROR(get_function(op, &op_def));
    614   AttrSlice attr_slice(attrs);
    615   for (const auto& attr_def : op_def->attr()) {
    616     if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) {
    617       if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) {
    618         return errors::Internal("Somehow duplicated: ", attr_def.name());
    619       }
    620     }
    621   }
    622   return Status::OK();
    623 }
    624 
    625 }  // end namespace
    626 
    627 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
    628                            GetFunctionSignature get_function,
    629                            InstantiationResult* result) {
    630   VLOG(3) << "Instantiation Function: " << Print(fdef);
    631 
    632   const OpDef& sig = fdef.signature();
    633   TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
    634 
    635   FunctionInstantiationHelper helper(get_function, result);
    636   Status s;
    637   for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
    638     s = helper.BuildInputArgIndex(arg_def, attr_values);
    639     if (!s.ok()) {
    640       errors::AppendToMessage(&s, "In ", Print(arg_def));
    641       return s;
    642     }
    643   }
    644 
    645   auto substitute = [attr_values](StringPiece name, AttrValue* val) {
    646     if (const AttrValue* v = attr_values.Find(name)) {
    647       *val = *v;
    648       return true;
    649     }
    650     return false;
    651   };
    652 
    653   // Makes a copy of all attrs in fdef and substitutes placeholders.
    654   // After this step, every attr is bound to a concrete value.
    655   std::vector<AttrValueMap> node_attrs;
    656   node_attrs.resize(fdef.node_def_size());
    657   for (int i = 0; i < fdef.node_def_size(); ++i) {
    658     for (auto attr : fdef.node_def(i).attr()) {
    659       if (!SubstitutePlaceholders(substitute, &attr.second)) {
    660         return errors::InvalidArgument("Failed to bind all placeholders in ",
    661                                        SummarizeAttrValue(attr.second));
    662       }
    663       if (!node_attrs[i].insert(attr).second) {
    664         return errors::Internal("Somehow duplicated: ", attr.first);
    665       }
    666     }
    667     TF_RETURN_IF_ERROR(
    668         AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i]));
    669   }
    670 
    671   for (int i = 0; i < fdef.node_def_size(); ++i) {
    672     s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]),
    673                                     result->nodes.size() + i);
    674     if (!s.ok()) {
    675       errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
    676       return s;
    677     }
    678   }
    679   // Emits one node for each fdef.node_def.
    680   for (int i = 0; i < fdef.node_def_size(); ++i) {
    681     s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
    682     if (!s.ok()) {
    683       errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i)));
    684       return s;
    685     }
    686   }
    687 
    688   // Emits nodes for the function's return values.
    689   int ret_index = 0;
    690   for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
    691     s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), &ret_index);
    692     if (!s.ok()) {
    693       errors::AppendToMessage(&s, "In function output ", Print(ret_def));
    694       return s;
    695     }
    696   }
    697 
    698   // Adds the actual node inputs using the input indexes.
    699   helper.AddNodeInputs();
    700 
    701   return Status::OK();
    702 }
    703 
    704 string DebugString(const FunctionDef& func_def) { return Print(func_def); }
    705 
    706 string DebugString(const GraphDef& instantiated_func_def) {
    707   std::vector<const NodeDef*> ptrs;
    708   for (const NodeDef& n : instantiated_func_def.node()) {
    709     ptrs.push_back(&n);
    710   }
    711   return Print(ptrs);
    712 }
    713 
    714 string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
    715   std::vector<const NodeDef*> ptrs;
    716   for (const NodeDef& n : instantiated_func_nodes) {
    717     ptrs.push_back(&n);
    718   }
    719   return Print(ptrs);
    720 }
    721 
    722 string DebugStringWhole(const GraphDef& gdef) {
    723   string ret;
    724   for (const auto& fdef : gdef.library().function()) {
    725     strings::StrAppend(&ret, Print(fdef));
    726   }
    727   strings::StrAppend(&ret, "\n");
    728   for (const auto& ndef : gdef.node()) {
    729     strings::StrAppend(&ret, Print(ndef), "\n");
    730   }
    731   return ret;
    732 }
    733 
    734 namespace {
    735 
    736 // Returns the name -> attr mapping of fdef's attrs that have a value set. In
    737 // Python, it's possible to access unset attrs, which returns a default value
    738 // and adds an unset attr to the map.
    739 std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
    740   std::map<string, AttrValue> set_attrs;
    741   for (auto pair : fdef.attr()) {
    742     if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
    743       set_attrs[pair.first] = pair.second;
    744     }
    745   }
    746   return set_attrs;
    747 }
    748 
    749 }  // end namespace
    750 
    751 bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
    752   if (!OpDefEqual(f1.signature(), f2.signature())) return false;
    753 
    754   std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
    755   std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
    756   if (f1_attrs.size() != f2_attrs.size()) return false;
    757   for (auto iter1 : f1_attrs) {
    758     auto iter2 = f2_attrs.find(iter1.first);
    759     if (iter2 == f2_attrs.end()) return false;
    760     if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
    761   }
    762 
    763   if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
    764     return false;
    765   }
    766 
    767   std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
    768   std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
    769   if (ret1 != ret2) return false;
    770 
    771   return true;
    772 }
    773 
    774 uint64 FunctionDefHash(const FunctionDef& fdef) {
    775   // signature
    776   uint64 h = OpDefHash(fdef.signature());
    777 
    778   // attrs
    779   std::map<string, AttrValue> attrs = GetSetAttrs(fdef);
    780   for (const auto& p : attrs) {
    781     h = Hash64(p.first.data(), p.first.size(), h);
    782     h = Hash64Combine(AttrValueHash(p.second), h);
    783   }
    784 
    785   // node defs
    786   h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
    787 
    788   // output names
    789   std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end());
    790   for (const auto& p : ret) {
    791     h = Hash64(p.first.data(), p.first.size(), h);
    792     h = Hash64(p.second.data(), p.second.size(), h);
    793   }
    794 
    795   return h;
    796 }
    797 
    798 string Canonicalize(const string& funcname, AttrSlice attrs,
    799                     const FunctionLibraryRuntime::InstantiateOptions& options) {
    800   std::vector<string> entries;
    801   entries.reserve(options.target.empty() ? attrs.size() : (attrs.size() + 1));
    802   for (auto p : attrs) {
    803     entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
    804   }
    805   if (!options.target.empty()) {
    806     entries.push_back(
    807         strings::StrCat("_target", "=", str_util::CEscape(options.target)));
    808   }
    809   if (options.overlay_lib) {
    810     entries.push_back(strings::StrCat(
    811         "_overlay_lib", "=", reinterpret_cast<uintptr_t>(options.overlay_lib)));
    812   }
    813   if (!options.state_handle.empty()) {
    814     entries.push_back(
    815         strings::StrCat("_state_handle", "=", options.state_handle));
    816   }
    817   std::sort(entries.begin(), entries.end());
    818   return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
    819 }
    820 
    821 FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
    822                                      DataTypeSlice ret_types)
    823     : arg_types_(arg_types.begin(), arg_types.end()),
    824       ret_types_(ret_types.begin(), ret_types.end()) {
    825   args_.resize(arg_types_.size());
    826   rets_.resize(ret_types_.size());
    827 }
    828 
    829 FunctionCallFrame::~FunctionCallFrame() {}
    830 
    831 Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
    832   // Input type checks.
    833   if (args.size() != arg_types_.size()) {
    834     return errors::InvalidArgument("Expects ", arg_types_.size(),
    835                                    " arguments, but ", args.size(),
    836                                    " is provided");
    837   }
    838   for (size_t i = 0; i < args.size(); ++i) {
    839     if (arg_types_[i] != args[i].dtype()) {
    840       return errors::InvalidArgument(
    841           "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ",
    842           DataTypeString(args[i].dtype()), " is provided");
    843     }
    844     args_[i] = args[i];
    845   }
    846   return Status::OK();
    847 }
    848 
    849 Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
    850   rets->clear();
    851   rets->reserve(rets_.size());
    852   for (size_t i = 0; i < rets_.size(); ++i) {
    853     const auto& item = rets_[i];
    854     if (item.has_val) {
    855       rets->push_back(item.val);
    856     } else {
    857       return errors::Internal("Retval[", i, "] does not have value");
    858     }
    859   }
    860   return Status::OK();
    861 }
    862 
    863 Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets) {
    864   rets->clear();
    865   rets->reserve(rets_.size());
    866   for (size_t i = 0; i < rets_.size(); ++i) {
    867     if (rets_[i].has_val) {
    868       rets->emplace_back(std::move(rets_[i].val));
    869     } else {
    870       return errors::Internal("Retval[", i, "] does not have value");
    871     }
    872   }
    873   return Status::OK();
    874 }
    875 
    876 Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
    877   if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
    878     return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
    879                                    args_.size(), ")");
    880   }
    881   *val = args_[index];
    882   return Status::OK();
    883 }
    884 
    885 Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
    886   if (index < 0 || static_cast<size_t>(index) >= rets_.size()) {
    887     return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
    888                                    rets_.size(), ")");
    889   }
    890   if (val.dtype() != ret_types_[index]) {
    891     return errors::InvalidArgument(
    892         "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
    893         ", but ", DataTypeString(val.dtype()), " is provided.");
    894   }
    895   Retval* item = &rets_[index];
    896   if (!item->has_val) {
    897     item->has_val = true;
    898     item->val = val;
    899   } else {
    900     return errors::Internal("Retval[", index, "] has already been set.");
    901   }
    902   return Status::OK();
    903 }
    904 
    905 FunctionLibraryDefinition::FunctionDefAndOpRegistration::
    906     FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
    907     : fdef(fdef_in),
    908       // Exact shape inference for functions is handled by ShapeRefiner.
    909       // Here we pass a dummy shape inference function for legacy code paths.
    910       op_registration_data(fdef.signature(), shape_inference::UnknownShape,
    911                            true /* is_function */) {}
    912 
    913 FunctionLibraryDefinition::FunctionLibraryDefinition(
    914     const FunctionLibraryDefinition& other)
    915     : default_registry_(other.default_registry_), func_grad_(other.func_grad_) {
    916   for (const auto& it : other.function_defs_) {
    917     TF_CHECK_OK(AddFunctionDef(it.second->fdef));
    918   }
    919 }
    920 
    921 FunctionLibraryDefinition::FunctionLibraryDefinition(
    922     const OpRegistryInterface* default_registry,
    923     const FunctionDefLibrary& def_lib)
    924     : default_registry_(default_registry),
    925       function_defs_(def_lib.function_size()) {
    926   for (const auto& fdef : def_lib.function()) {
    927     // The latter function definition wins.
    928     auto& ptr = function_defs_[fdef.signature().name()];
    929     ptr.reset(new FunctionDefAndOpRegistration(fdef));
    930   }
    931   for (const auto& grad : def_lib.gradient()) {
    932     func_grad_[grad.function_name()] = grad.gradient_func();
    933   }
    934 }
    935 
    936 FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
    937 
    938 const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
    939   auto iter = function_defs_.find(name);
    940   if (iter == function_defs_.end()) {
    941     return nullptr;
    942   } else {
    943     return &iter->second->fdef;
    944   }
    945 }
    946 
    947 Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
    948   bool added;
    949   return AddFunctionDefHelper(fdef, &added);
    950 }
    951 
    952 Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
    953                                                        bool* added) {
    954   *added = false;
    955   std::unique_ptr<FunctionDefAndOpRegistration>* entry =
    956       &function_defs_[fdef.signature().name()];
    957   if (*entry != nullptr) {
    958     if (!FunctionDefsEqual((*entry)->fdef, fdef)) {
    959       return errors::InvalidArgument(
    960           "Cannot add function '", fdef.signature().name(),
    961           "' because a different function with the same name already "
    962           "exists.");
    963     }
    964     // Ignore duplicate FunctionDefs
    965     return Status::OK();
    966   }
    967   const OpDef* op_def;
    968   if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
    969     return errors::InvalidArgument(
    970         "Cannot add function '", fdef.signature().name(),
    971         "' because an op with the same name already exists.");
    972   }
    973   entry->reset(new FunctionDefAndOpRegistration(fdef));
    974   *added = true;
    975   return Status::OK();
    976 }
    977 
    978 Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
    979   bool added;
    980   return AddGradientDefHelper(grad, &added);
    981 }
    982 
    983 Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
    984                                                        bool* added) {
    985   *added = false;
    986   string* entry = &func_grad_[grad.function_name()];
    987   if (!entry->empty()) {
    988     if (*entry != grad.gradient_func()) {
    989       return errors::InvalidArgument(
    990           "Cannot assign gradient function '", grad.gradient_func(), "' to '",
    991           grad.function_name(), "' because it already has gradient function ",
    992           "'", *entry, "'");
    993     }
    994     // Ignore duplicate GradientDefs
    995     return Status::OK();
    996   }
    997   *entry = grad.gradient_func();
    998   *added = true;
    999   return Status::OK();
   1000 }
   1001 
   1002 Status FunctionLibraryDefinition::AddLibrary(
   1003     const FunctionLibraryDefinition& other) {
   1004   // Remember the funcs and grads that we added successfully so that
   1005   // we can roll them back on error.
   1006   std::vector<string> funcs;
   1007   std::vector<string> funcs_with_grads;
   1008   Status s;
   1009   bool added;
   1010   for (auto iter : other.function_defs_) {
   1011     s = AddFunctionDefHelper(iter.second->fdef, &added);
   1012     if (!s.ok()) {
   1013       Remove(funcs, funcs_with_grads);
   1014       return s;
   1015     }
   1016     if (added) {
   1017       funcs.push_back(iter.second->fdef.signature().name());
   1018     }
   1019   }
   1020   for (auto iter : other.func_grad_) {
   1021     GradientDef grad;
   1022     grad.set_function_name(iter.first);
   1023     grad.set_gradient_func(iter.second);
   1024     s = AddGradientDefHelper(grad, &added);
   1025     if (!s.ok()) {
   1026       Remove(funcs, funcs_with_grads);
   1027       return s;
   1028     }
   1029     if (added) {
   1030       funcs_with_grads.push_back(grad.function_name());
   1031     }
   1032   }
   1033   return Status::OK();
   1034 }
   1035 
   1036 Status FunctionLibraryDefinition::AddLibrary(
   1037     const FunctionDefLibrary& lib_def) {
   1038   // Remember the funcs and grads that we added successfully so that
   1039   // we can roll them back on error.
   1040   std::vector<string> funcs;
   1041   std::vector<string> funcs_with_grads;
   1042   Status s;
   1043   bool added;
   1044   for (const FunctionDef& fdef : lib_def.function()) {
   1045     s = AddFunctionDefHelper(fdef, &added);
   1046     if (!s.ok()) {
   1047       Remove(funcs, funcs_with_grads);
   1048       return s;
   1049     }
   1050     if (added) {
   1051       funcs.push_back(fdef.signature().name());
   1052     }
   1053   }
   1054   for (const GradientDef& grad : lib_def.gradient()) {
   1055     s = AddGradientDefHelper(grad, &added);
   1056     if (!s.ok()) {
   1057       Remove(funcs, funcs_with_grads);
   1058       return s;
   1059     }
   1060     if (added) {
   1061       funcs_with_grads.push_back(grad.function_name());
   1062     }
   1063   }
   1064   return Status::OK();
   1065 }
   1066 
   1067 Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
   1068   const auto& i = function_defs_.find(func);
   1069   if (i == function_defs_.end()) {
   1070     return errors::InvalidArgument("Tried to remove non-existent function ",
   1071                                    func);
   1072   }
   1073   function_defs_.erase(i);
   1074   return Status::OK();
   1075 }
   1076 
   1077 Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
   1078   const auto& i = func_grad_.find(func);
   1079   if (i == func_grad_.end()) {
   1080     return errors::InvalidArgument("Tried to remove non-existent gradient ",
   1081                                    func);
   1082   }
   1083   func_grad_.erase(i);
   1084   return Status::OK();
   1085 }
   1086 
   1087 void FunctionLibraryDefinition::Remove(
   1088     const std::vector<string>& funcs,
   1089     const std::vector<string>& funcs_with_grads) {
   1090   for (const string& f : funcs) {
   1091     Status s = RemoveFunction(f);
   1092     DCHECK(s.ok());
   1093   }
   1094   for (const string& f : funcs_with_grads) {
   1095     Status s = RemoveGradient(f);
   1096     DCHECK(s.ok());
   1097   }
   1098 }
   1099 
   1100 string FunctionLibraryDefinition::FindGradient(const string& func) const {
   1101   return gtl::FindWithDefault(func_grad_, func, "");
   1102 }
   1103 
   1104 Status FunctionLibraryDefinition::LookUp(
   1105     const string& op, const OpRegistrationData** op_reg_data) const {
   1106   auto iter = function_defs_.find(op);
   1107   if (iter != function_defs_.end()) {
   1108     *op_reg_data = &iter->second->op_registration_data;
   1109     return Status::OK();
   1110   }
   1111   return default_registry_->LookUp(op, op_reg_data);
   1112 }
   1113 
   1114 const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
   1115     const NodeDef& ndef) const {
   1116   if (ndef.op() != kGradientOp) {
   1117     // If 'ndef' calls a function and the function's def has the attr,
   1118     // returns it.
   1119     return Find(ndef.op());
   1120   }
   1121 
   1122   // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
   1123   // Foo's attributes.
   1124   const NameAttrList* forward_func_attrs;
   1125   if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) {
   1126     return nullptr;
   1127   }
   1128   const string& func_name = forward_func_attrs->name();
   1129   const string& grad_name = FindGradient(func_name);
   1130   // If 'func' has a user-defined gradient function, uses the grad
   1131   // function's attrs to see if noinline is specified. Otherwise,
   1132   // uses func's attrs.
   1133   if (!grad_name.empty()) {
   1134     return Find(grad_name);
   1135   }
   1136   return Find(func_name);
   1137 }
   1138 
   1139 FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
   1140   FunctionDefLibrary lib;
   1141   for (const auto& f : function_defs_) {
   1142     *lib.add_function() = f.second->fdef;
   1143   }
   1144   for (const auto& g : func_grad_) {
   1145     GradientDef* gd = lib.add_gradient();
   1146     gd->set_function_name(g.first);
   1147     gd->set_gradient_func(g.second);
   1148   }
   1149   return lib;
   1150 }
   1151 
   1152 template <typename T>
   1153 Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
   1154                                           const string& attr, T* value) const {
   1155   const FunctionDef* fdef = GetAttrImpl(ndef);
   1156   if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
   1157     return Status::OK();
   1158   }
   1159   return errors::InvalidArgument("Attr ", attr, " is not defined.");
   1160 }
   1161 
   1162 template <typename T>
   1163 Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
   1164                                           T* value) const {
   1165   return GetAttr(node.def(), attr, value);
   1166 }
   1167 
   1168 #define GET_ATTR(T)                                                            \
   1169   template Status FunctionLibraryDefinition::GetAttr(const Node&,              \
   1170                                                      const string&, T*) const; \
   1171   template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,           \
   1172                                                      const string&, T*) const;
   1173 GET_ATTR(string)
   1174 GET_ATTR(bool)
   1175 #undef GET_ATTR
   1176 
   1177 void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
   1178   if (val.size() >= 2 && val[0] == '$') {
   1179     proto.set_placeholder(val.data() + 1, val.size() - 1);
   1180   } else {
   1181     SetAttrValue(val, &proto);
   1182   }
   1183 }
   1184 
   1185 FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
   1186     const string& name,
   1187     gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
   1188   AttrValueWrapper ret;
   1189   ret.proto.mutable_func()->set_name(name);
   1190   for (const auto& a : attrs) {
   1191     ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto});
   1192   }
   1193   return ret;
   1194 }
   1195 
   1196 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
   1197   NodeDef n;
   1198   n.set_op(this->op);
   1199   n.set_name(this->ret[0]);
   1200   for (const auto& a : this->attr) {
   1201     n.mutable_attr()->insert({a.first, a.second.proto});
   1202   }
   1203   for (const string& a : this->arg) {
   1204     n.add_input(a);
   1205   }
   1206   for (const string& d : this->dep) {
   1207     n.add_input(strings::StrCat("^", d));
   1208   }
   1209   return n;
   1210 }
   1211 
   1212 /* static */
   1213 FunctionDef FunctionDefHelper::Create(
   1214     const string& function_name, gtl::ArraySlice<string> in_def,
   1215     gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
   1216     gtl::ArraySlice<Node> node_def,
   1217     gtl::ArraySlice<std::pair<string, string>> ret_def) {
   1218   FunctionDef fdef;
   1219 
   1220   // Signature
   1221   OpDefBuilder b(function_name);
   1222   for (const auto& i : in_def) b.Input(i);
   1223   for (const auto& o : out_def) b.Output(o);
   1224   for (const auto& a : attr_def) b.Attr(a);
   1225 
   1226   OpRegistrationData op_reg_data;
   1227   TF_CHECK_OK(b.Finalize(&op_reg_data));
   1228   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
   1229 
   1230   // Function body
   1231   for (const auto& n : node_def) {
   1232     *(fdef.add_node_def()) = n.ToNodeDef();
   1233   }
   1234 
   1235   // Returns
   1236   for (const auto& r : ret_def) {
   1237     fdef.mutable_ret()->insert({r.first, r.second});
   1238   }
   1239   return fdef;
   1240 }
   1241 
// Builds a FunctionDef named `name` from helper nodes that use "legacy"
// tensor names.
//
// The signature is assembled from `arg_def`, `ret_def` and `attr_def` with
// OpDefBuilder. Node inputs (`src.arg`) and the function's outputs refer to
// tensors by legacy names: either an input arg name, or one of the names
// listed in an earlier node's `ret`. Each such name is translated into the
// "node_name:output_arg_name:index" form written into the FunctionDef. Since
// names are resolved while walking `node_def` in order, an input may only
// refer to function args or to outputs of nodes that appear earlier.
//
// Every lookup failure is fatal (CHECK / TF_CHECK_OK):
//  * unknown input names;
//  * node ops missing from OpRegistry::Global();
//  * nodes that list fewer `ret` names than they have outputs;
//  * function outputs that no arg or node provides.
/* static */
FunctionDef FunctionDefHelper::Define(const string& name,
                                      gtl::ArraySlice<string> arg_def,
                                      gtl::ArraySlice<string> ret_def,
                                      gtl::ArraySlice<string> attr_def,
                                      gtl::ArraySlice<Node> node_def) {
  FunctionDef fdef;
  OpDefBuilder b(name);
  for (const auto& a : arg_def) b.Input(a);
  for (const auto& r : ret_def) b.Output(r);
  for (const auto& a : attr_def) b.Attr(a);

  OpRegistrationData op_reg_data;
  TF_CHECK_OK(b.Finalize(&op_reg_data));
  fdef.mutable_signature()->Swap(&op_reg_data.op_def);

  // Mapping from legacy output names to NodeDef outputs.
  // Function inputs are referenced by their own names, so each input arg
  // maps to itself.
  std::unordered_map<string, string> ret_index;
  for (const auto& a : fdef.signature().input_arg()) {
    ret_index[a.name()] = a.name();
  }

  // For looking up OpDefs
  // Only the global op registry is consulted, so nodes must call registered
  // ops rather than other functions.
  auto* op_def_registry = OpRegistry::Global();

  // Function body
  for (const auto& src : node_def) {
    NodeDef* n = fdef.add_node_def();
    n->set_op(src.op);
    // The node takes its name from its first legacy output name.
    n->set_name(src.ret[0]);
    for (const auto& a : src.attr) {
      n->mutable_attr()->insert({a.first, a.second.proto});
    }
    // Translate each legacy input name; it must already be in ret_index.
    for (const string& a : src.arg) {
      const auto iter = ret_index.find(a);
      CHECK(iter != ret_index.end())
          << "Node input '" << a << "' in '" << src.ret[0] << "' of " << name;
      n->add_input(iter->second);
    }
    // Control dependencies are passed through by node name.
    for (const string& d : src.dep) {
      n->add_input(strings::StrCat("^", d));
    }

    // Add the outputs of this node to ret_index.
    const OpDef* op_def = nullptr;
    TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
    CHECK(op_def != nullptr) << n->op();
    NameRangeMap output_names;
    TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
    // Each output arg `o.first` covers flat output indices
    // [o.second.first, o.second.second). The legacy name for flat output i
    // is src.ret[i], so `ret` must be long enough to cover every range.
    for (const auto& o : output_names) {
      CHECK_LE(o.second.second, src.ret.size())
          << "Missing ret for output '" << o.first << "' in '" << src.ret[0]
          << "' of " << name;
      for (int i = o.second.first; i < o.second.second; ++i) {
        // Later definitions of the same legacy name overwrite earlier ones.
        ret_index[src.ret[i]] =
            strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
      }
    }
  }

  // Returns
  // Every signature output must name a tensor resolved above.
  for (const auto& r : fdef.signature().output_arg()) {
    const auto iter = ret_index.find(r.name());
    CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
    fdef.mutable_ret()->insert({r.name(), iter->second});
  }
  return fdef;
}
   1310 
   1311 FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
   1312                                       gtl::ArraySlice<string> ret_def,
   1313                                       gtl::ArraySlice<string> attr_def,
   1314                                       gtl::ArraySlice<Node> node_def) {
   1315   return Define("_", arg_def, ret_def, attr_def, node_def);
   1316 }
   1317 
   1318 namespace gradient {
   1319 
   1320 typedef std::unordered_map<string, Creator> OpGradFactory;
   1321 
   1322 OpGradFactory* GetOpGradFactory() {
   1323   static OpGradFactory* factory = new OpGradFactory;
   1324   return factory;
   1325 }
   1326 
   1327 bool RegisterOp(const string& op, Creator func) {
   1328   CHECK(GetOpGradFactory()->insert({op, func}).second)
   1329       << "Duplicated gradient for " << op;
   1330   return true;
   1331 }
   1332 
   1333 Status GetOpGradientCreator(const string& op, Creator* creator) {
   1334   auto fac = GetOpGradFactory();
   1335   auto iter = fac->find(op);
   1336   if (iter == fac->end()) {
   1337     return errors::NotFound("No gradient defined for op: ", op);
   1338   }
   1339   *creator = iter->second;
   1340   return Status::OK();
   1341 }
   1342 
   1343 }  // end namespace gradient
   1344 
   1345 }  // end namespace tensorflow
   1346