Home | History | Annotate | Download | only in jit
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <utility>
     17 
     18 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
     19 
     20 #include "tensorflow/cc/framework/ops.h"
     21 #include "tensorflow/cc/ops/standard_ops.h"
     22 #include "tensorflow/core/framework/function_testlib.h"
     23 #include "tensorflow/core/graph/graph_constructor.h"
     24 #include "tensorflow/core/graph/graph_def_builder.h"
     25 #include "tensorflow/core/lib/core/status_test_util.h"
     26 #include "tensorflow/core/platform/test.h"
     27 #include "tensorflow/core/util/equal_graph_def.h"
     28 
     29 namespace tensorflow {
     30 namespace {
     31 
     32 template <class Tkey, class Tvalue>
     33 bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
     34                    const ::tensorflow::protobuf::Map<Tkey, Tvalue>& b,
     35                    const std::function<string(const Tkey&)>& key_to_string,
     36                    const std::function<string(const Tvalue&)>& value_to_string,
     37                    const std::function<bool(const Tkey&, const Tvalue&,
     38                                             const Tvalue&)>& compare,
     39                    const string& map_name, string* diff) {
     40   for (const auto& elt_a : a) {
     41     const auto iter = b.find(elt_a.first);
     42     if (iter == b.end()) {
     43       if (diff) {
     44         *diff = strings::StrCat(
     45             map_name, " expected: contains element with key '",
     46             key_to_string(elt_a.first), "' got: map has no such element");
     47       }
     48       return false;
     49     }
     50     if (!compare(elt_a.first, elt_a.second, iter->second)) {
     51       if (diff) {
     52         *diff = strings::StrCat(map_name, " expected: element with key '",
     53                                 key_to_string(elt_a.first), " has value '",
     54                                 value_to_string(elt_a.second), "' got: '",
     55                                 value_to_string(iter->second), "'");
     56       }
     57       return false;
     58     }
     59   }
     60   for (const auto& elt_b : b) {
     61     const auto iter = a.find(elt_b.first);
     62     if (iter == a.end()) {
     63       if (diff) {
     64         *diff = strings::StrCat(map_name, " got: contains element with key '",
     65                                 key_to_string(elt_b.first),
     66                                 "' expected: map has no such element");
     67       }
     68       return false;
     69     }
     70   }
     71   return true;
     72 }
     73 
     74 bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
     75                           const string& diff_preamble, string* diff) {
     76   if (a.op() != b.op()) {
     77     if (diff) {
     78       *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
     79                               ", expected op '", a.op(), "' got '", b.op());
     80     }
     81     return false;
     82   }
     83   if (a.device() != b.device()) {
     84     if (diff) {
     85       *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
     86                               ", expected device '", a.device(), "' got '",
     87                               b.device());
     88     }
     89     return false;
     90   }
     91   if (a.input_size() != b.input_size()) {
     92     if (diff) {
     93       *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
     94                               ", expected ", a.input_size(), " inputs got ",
     95                               b.input_size(), " expected:\n", a.DebugString(),
     96                               "\ngot:\n", b.DebugString());
     97     }
     98     return false;
     99   }
    100   for (int i = 0; i < a.input_size(); ++i) {
    101     if (a.input(i) != b.input(i)) {
    102       if (diff) {
    103         *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
    104                                 " input ", i, ", expected ", a.input(i),
    105                                 " got ", b.input(i), " expected:\n",
    106                                 a.DebugString(), "\ngot:\n", b.DebugString());
    107       }
    108       return false;
    109     }
    110   }
    111   return EqualProtoMap<string, AttrValue>(
    112       a.attr(), b.attr(), [](const string& s) { return s; },
    113       [](const AttrValue& v) { return v.DebugString(); },
    114       [](const string& key, const AttrValue& av, const AttrValue& bv) {
    115         if (key == "shape_inference_graph") {
    116           // Default serialization of GraphDef is unstable because maps don't
    117           // serialize deterministically. Rather than go through the hoops to
    118           // turn on deterministic serialization of this attr just for this
    119           // test, add logic here to compare determinstically.
    120           GraphDef ga;
    121           if (!ga.ParseFromString(av.s())) {
    122             return false;
    123           }
    124           GraphDef gb;
    125           if (!gb.ParseFromString(bv.s())) {
    126             return false;
    127           }
    128           return EqualGraphDef(ga, gb, nullptr);
    129         } else {
    130           return av.DebugString() == bv.DebugString();
    131         }
    132       },
    133       strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()),
    134       diff);
    135 }
    136 
    137 bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
    138                       string* diff) {
    139   if (a.signature().DebugString() != b.signature().DebugString()) {
    140     if (diff) {
    141       *diff = strings::StrCat("Signature mismatch for function ",
    142                               a.signature().name(), ", expected:\n",
    143                               a.signature().DebugString(), "\ngot:\n",
    144                               b.signature().DebugString());
    145     }
    146     return false;
    147   }
    148   if (!EqualProtoMap<string, AttrValue>(
    149           a.attr(), b.attr(), [](const string& s) { return s; },
    150           [](const AttrValue& v) { return v.DebugString(); },
    151           [](const string& key, const AttrValue& av, const AttrValue& bv) {
    152             return av.DebugString() == bv.DebugString();
    153           },
    154           strings::StrCat("attr mismatch for function ", a.signature().name()),
    155           diff)) {
    156     return false;
    157   }
    158   if (!EqualProtoMap<string, string>(
    159           a.ret(), b.ret(), [](const string& s) { return s; },
    160           [](const string& s) { return s; },
    161           [](const string& key, const string& av, const string& bv) {
    162             return av == bv;
    163           },
    164           strings::StrCat("ret mismatch for function ", a.signature().name()),
    165           diff)) {
    166     return false;
    167   }
    168   for (int i = 0; i < a.node_def_size(); ++i) {
    169     bool found = false;
    170     for (int j = 0; j < b.node_def_size(); ++j) {
    171       if (a.node_def(i).name() == b.node_def(j).name()) {
    172         if (!EqualFunctionNodeDef(
    173                 a.node_def(i), b.node_def(j),
    174                 strings::StrCat("Function ", a.signature().name()), diff)) {
    175           return false;
    176         }
    177         found = true;
    178         break;
    179       }
    180     }
    181     if (!found) {
    182       if (diff) {
    183         *diff = strings::StrCat("Function ", a.signature().name(),
    184                                 ", expected: has node '", a.node_def(i).name(),
    185                                 "' got: no node of that name");
    186       }
    187       return false;
    188     }
    189   }
    190   for (int i = 0; i < b.node_def_size(); ++i) {
    191     bool found = false;
    192     for (int j = 0; j < a.node_def_size(); ++j) {
    193       if (b.node_def(i).name() == a.node_def(j).name()) {
    194         found = true;
    195         break;
    196       }
    197     }
    198     if (!found) {
    199       if (diff) {
    200         *diff = strings::StrCat("Function ", a.signature().name(),
    201                                 ", got: has node '", b.node_def(i).name(),
    202                                 "' expected: no node of that name");
    203       }
    204       return false;
    205     }
    206   }
    207   return true;
    208 }
    209 
    210 bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
    211                              const FunctionDefLibrary& actual, string* diff) {
    212   std::unordered_map<string, const FunctionDef*> actual_index;
    213   for (const FunctionDef& function : actual.function()) {
    214     actual_index[function.signature().name()] = &function;
    215   }
    216 
    217   for (const FunctionDef& expected_function : expected.function()) {
    218     auto it = actual_index.find(expected_function.signature().name());
    219     if (it == actual_index.end()) {
    220       if (diff) {
    221         *diff = strings::StrCat("Did not find expected function '",
    222                                 expected_function.signature().name(), "'");
    223       }
    224       return false;
    225     }
    226     if (!EqualFunctionDef(expected_function, *it->second, diff)) return false;
    227     actual_index.erase(it);
    228   }
    229 
    230   if (!actual_index.empty()) {
    231     if (diff != nullptr) {
    232       *diff = strings::StrCat("Found unexpected function '",
    233                               actual_index.begin()->second->signature().name(),
    234                               "'");
    235     }
    236     return false;
    237   }
    238 
    239   return true;
    240 }
    241 
    242 #define TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(expected, actual)         \
    243   do {                                                            \
    244     string diff;                                                  \
    245     EXPECT_TRUE(EqualFunctionDefLibrary(expected, actual, &diff)) \
    246         << diff << "\nActual: " << actual.DebugString();          \
    247   } while (false)
    248 
    249 // TODO(misard): remove these fake registrations once there are real Ops to be
    250 // compiled.
    251 REGISTER_OP("_XlaHostCompute")
    252     .Input("inputs: Tinputs")
    253     .Output("outputs: Toutputs")
    254     .Attr("Tinputs: list(type) >= 0")
    255     .Attr("Toutputs: list(type) >= 0")
    256     .Attr("key: string")
    257     .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
    258 
    259 REGISTER_OP("_XlaSendFromHost")
    260     .Input("input: Tinputs")
    261     .Attr("Tinputs: list(type) >= 0")
    262     .Attr("key: string")
    263     .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
    264 
    265 REGISTER_OP("_XlaRecvAtHost")
    266     .Output("output: Toutputs")
    267     .Attr("Toutputs: list(type) >= 0")
    268     .Attr("key: string")
    269     .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
    270 
    271 REGISTER_OP("InputTest")
    272     .Output("o: float")
    273     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
    274       c->set_output(0, c->UnknownShape());
    275       return Status::OK();
    276     });
    277 
    278 REGISTER_OP("InputTestShaped")
    279     .Output("o: float")
    280     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
    281       c->set_output(0, c->Vector(2));
    282       return Status::OK();
    283     });
    284 
    285 REGISTER_OP("UnaryTest")
    286     .Input("a: float")
    287     .Output("o: float")
    288     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
    289       ::tensorflow::shape_inference::ShapeHandle o;
    290       TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
    291       c->set_output(0, o);
    292       return Status::OK();
    293     });
    294 REGISTER_OP("BinaryTest")
    295     .Input("a: float")
    296     .Input("b: float")
    297     .Output("o: float")
    298     .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
    299       ::tensorflow::shape_inference::ShapeHandle o;
    300       TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
    301       c->set_output(0, o);
    302       return Status::OK();
    303     });
    304 REGISTER_OP("BinaryTest2")
    305     .Input("a: float")
    306     .Input("b: float")
    307     .Output("o: float")
    308     .SetShapeFn(::tensorflow::shape_inference::UnknownShape);
    309 
    310 REGISTER_OP("AddNLikeTest")
    311     .Input("inputs: N * T")
    312     .Output("sum: T")
    313     .Attr("N: int >= 1")
    314     .Attr("T: numbertype")
    315     .SetIsCommutative()
    316     .SetIsAggregate();
    317 
    318 Node* NoOp(const GraphDefBuilder::Options& opts) {
    319   return ops::SourceOp("NoOp", opts);
    320 }
    321 
    322 Node* Input(const GraphDefBuilder::Options& opts) {
    323   return ops::SourceOp("InputTest", opts);
    324 }
    325 
    326 Node* InputShaped(const GraphDefBuilder::Options& opts) {
    327   return ops::SourceOp("InputTestShaped", opts);
    328 }
    329 
    330 Node* KnownShape(const gtl::ArraySlice<int>& shape,
    331                  const GraphDefBuilder::Options& opts) {
    332   if (opts.HaveError()) return nullptr;
    333   NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const",
    334                            opts.op_registry());
    335   TensorProto value;
    336   value.set_dtype(DT_FLOAT);
    337   for (int dim : shape) {
    338     value.mutable_tensor_shape()->add_dim()->set_size(dim);
    339   }
    340   return opts.WithAttr("value", value)
    341       .WithAttr("dtype", DT_FLOAT)
    342       .FinalizeBuilder(&node_builder);
    343 }
    344 
    345 Node* RecvAtHost(const string& key, const gtl::ArraySlice<DataType>& dtypes,
    346                  const GraphDefBuilder::Options& opts) {
    347   if (opts.HaveError()) return nullptr;
    348   NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"),
    349                            "_XlaRecvAtHost", opts.op_registry());
    350   return opts.WithAttr("Toutputs", dtypes)
    351       .WithAttr("key", key)
    352       .FinalizeBuilder(&node_builder);
    353 }
    354 
    355 Node* SendFromHost(const string& key, const std::vector<ops::NodeOut>& inputs,
    356                    const GraphDefBuilder::Options& opts) {
    357   if (opts.HaveError()) return nullptr;
    358   NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"),
    359                            "_XlaSendFromHost", opts.op_registry());
    360   node_builder.Input(inputs);
    361   std::vector<DataType> dtypes;
    362   for (const auto& node : inputs) {
    363     dtypes.push_back(node.dt);
    364   }
    365   return opts.WithAttr("key", key)
    366       .WithAttr("Tinputs", dtypes)
    367       .FinalizeBuilder(&node_builder);
    368 }
    369 
    370 Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
    371   return ops::UnaryOp("UnaryTest", std::move(a), opts);
    372 }
    373 
    374 Node* Binary(ops::NodeOut a, ops::NodeOut b,
    375              const GraphDefBuilder::Options& opts) {
    376   return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts);
    377 }
    378 
    379 Node* BinaryUnknownShape(ops::NodeOut a, ops::NodeOut b,
    380                          const GraphDefBuilder::Options& opts) {
    381   return ops::BinaryOp("BinaryTest2", std::move(a), std::move(b), opts);
    382 }
    383 
    384 Node* AddNLike(const std::vector<ops::NodeOut>& inputs,
    385                const GraphDefBuilder::Options& opts) {
    386   if (opts.HaveError()) return nullptr;
    387   NodeBuilder node_builder(opts.GetNameForOp("AddN"), "AddNLikeTest",
    388                            opts.op_registry());
    389   node_builder.Input(inputs);
    390   return opts.FinalizeBuilder(&node_builder);
    391 }
    392 
    393 Node* ArgOp(int index, DataType type, const GraphDefBuilder::Options& opts) {
    394   return ops::SourceOp("_Arg",
    395                        opts.WithAttr("T", type).WithAttr("index", index));
    396 }
    397 
    398 Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) {
    399   if (opts.HaveError()) return nullptr;
    400   NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval",
    401                            opts.op_registry());
    402   node_builder.Input(std::move(a)).Attr("index", index);
    403   return opts.FinalizeBuilder(&node_builder);
    404 }
    405 
    406 Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
    407   Status s;
    408   // Convert the GraphDef to a Graph
    409   std::unique_ptr<FunctionLibraryDefinition> lib_def(
    410       new FunctionLibraryDefinition(OpRegistry::Global(), *library));
    411   GraphConstructorOptions options;
    412   options.allow_internal_ops = true;
    413   std::unique_ptr<Graph> graph(new Graph(lib_def.get()));
    414   s = ConvertGraphDefToGraph(options, *graphdef, graph.get());
    415   if (!s.ok()) return s;
    416 
    417   std::unique_ptr<Graph> graph_out;
    418   s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph,
    419                                       /*rewrite_subgraph_fn=*/{},
    420                                       /*parallel_checking=*/false,
    421                                       /*reuse_existing_functions=*/false,
    422                                       &graph_out, lib_def.get());
    423   if (!s.ok()) return s;
    424 
    425   GraphDef graphdef_out;
    426   graph_out->ToGraphDef(&graphdef_out);
    427   graphdef->Swap(&graphdef_out);
    428 
    429   *library = lib_def->ToProto();
    430   return s;
    431 }
    432 
    433 // If there are no marked nodes, funcification should be a no-op.
    434 TEST(EncapsulateSubgraphsTest, NoFunctions) {
    435   GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
    436 
    437   Node* a = Input(builder.opts().WithName("A"));
    438   Node* b = Input(builder.opts().WithName("B"));
    439   Node* c = Unary(a, builder.opts().WithName("C"));
    440   Binary(b, c, builder.opts().WithName("D"));
    441 
    442   GraphDef graphdef_in;
    443   FunctionDefLibrary library_in;
    444   TF_EXPECT_OK(builder.ToGraphDef(&graphdef_in));
    445   *library_in.add_function() = test::function::XTimesTwo();
    446 
    447   GraphDef graphdef_out = graphdef_in;
    448   FunctionDefLibrary library_out = library_in;
    449   TF_EXPECT_OK(Encapsulate(&graphdef_out, &library_out));
    450 
    451   // If there are no marked nodes, funcification should be a no-op.
    452   TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out);
    453   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out);
    454 }
    455 
    456 // Test with one function to transform.
    457 TEST(EncapsulateSubgraphsTest, OneFunction) {
    458   FunctionDefLibrary library;
    459   GraphDef graphdef;
    460 
    461   {
    462     *library.add_function() = test::function::XTimesTwo();
    463 
    464     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    465     Node* a = Input(b1.opts().WithName("A"));
    466     Node* b = Input(b1.opts().WithName("B"));
    467     // Give nodes 'c' and 'd' names that collide after lowercasing.
    468     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    469     Node* d = Binary(b, c, b1.opts().WithName("c").WithControlInput(c).WithAttr(
    470                                "_encapsulate", "F1"));
    471     Binary(a, d, b1.opts().WithName("E"));
    472     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
    473   }
    474 
    475   TF_EXPECT_OK(Encapsulate(&graphdef, &library));
    476 
    477   FunctionDefLibrary library_expected;
    478   GraphDef graphdef_expected;
    479 
    480   *library_expected.add_function() = test::function::XTimesTwo();
    481   *library_expected.add_function() = FunctionDefHelper::Create(
    482       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"c_0_retval:float"}, {},
    483       {
    484           {{"C"}, "UnaryTest", {"a_0_arg"}},
    485           {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
    486       },
    487       {{"c_0_retval", "c:o:0"}});
    488 
    489   {
    490     std::unique_ptr<FunctionLibraryDefinition> lib_def(
    491         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    492     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    493     Node* a = Input(b2.opts().WithName("A"));
    494     Node* b = Input(b2.opts().WithName("B"));
    495 
    496     NodeBuilder node_builder("F1", "F1", lib_def.get());
    497     node_builder.Input(a).Input(b);
    498     Node* call = b2.opts().FinalizeBuilder(&node_builder);
    499 
    500     Binary(a, call, b2.opts().WithName("E"));
    501     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
    502   }
    503 
    504   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
    505   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
    506 }
    507 
    508 // Test with two functions to transform.
    509 TEST(EncapsulateSubgraphsTest, TwoFunctions) {
    510   FunctionDefLibrary library;
    511   GraphDef graphdef;
    512 
    513   {
    514     *library.add_function() = test::function::XTimesTwo();
    515 
    516     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    517     Node* a = Input(b1.opts().WithName("A"));
    518     Node* b = Input(b1.opts().WithName("B"));
    519     Node* control = Input(b1.opts().WithName("Control"));
    520     Node* c =
    521         Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr(
    522                      "_encapsulate", "F1"));
    523     Node* d =
    524         Binary(b, c, b1.opts().WithName("D").WithControlInput(control).WithAttr(
    525                          "_encapsulate", "F2"));
    526     Binary(a, d, b1.opts().WithName("E"));
    527     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
    528   }
    529 
    530   TF_EXPECT_OK(Encapsulate(&graphdef, &library));
    531 
    532   FunctionDefLibrary library_expected;
    533   GraphDef graphdef_expected;
    534 
    535   *library_expected.add_function() = test::function::XTimesTwo();
    536   *library_expected.add_function() = FunctionDefHelper::Create(
    537       "F1", {"a_0_arg:float"}, {"c_0_retval:float"}, {},
    538       {
    539           {{"C"}, "UnaryTest", {"a_0_arg"}},
    540       },
    541       {{"c_0_retval", "C:o:0"}});
    542   *library_expected.add_function() = FunctionDefHelper::Create(
    543       "F2", {"b_0_arg:float", "c_0_arg:float"}, {"d_0_retval:float"}, {},
    544       {
    545           {{"D"}, "BinaryTest", {"b_0_arg", "c_0_arg"}},
    546       },
    547       {{"d_0_retval", "D:o:0"}});
    548 
    549   {
    550     std::unique_ptr<FunctionLibraryDefinition> lib_def(
    551         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    552     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    553     Node* a = Input(b2.opts().WithName("A"));
    554     Node* b = Input(b2.opts().WithName("B"));
    555     Node* control = Input(b2.opts().WithName("Control"));
    556 
    557     NodeBuilder nb("F1", "F1", lib_def.get());
    558     nb.Input(a).ControlInput(control);
    559     Node* call1 = b2.opts().FinalizeBuilder(&nb);
    560 
    561     NodeBuilder nb2("F2", "F2", lib_def.get());
    562     nb2.Input(b).Input(call1).ControlInput(control);
    563     Node* call2 = b2.opts().FinalizeBuilder(&nb2);
    564 
    565     Binary(a, call2, b2.opts().WithName("E"));
    566     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
    567   }
    568 
    569   // If there are no marked nodes, funcification should be a no-op.
    570   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
    571   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
    572 }
    573 
    574 // Returns a vector of node names in 'graph', sorted by name.
    575 std::vector<string> GraphNodes(const Graph& graph) {
    576   std::vector<string> nodes;
    577   for (const auto& node : graph.nodes()) {
    578     if (!node->IsSource() && !node->IsSink()) {
    579       nodes.push_back(node->name());
    580     }
    581   }
    582   std::sort(nodes.begin(), nodes.end());
    583   return nodes;
    584 }
    585 
    586 // Returns a sorted vector of (src, dst) edges in 'graph'.
    587 std::vector<std::pair<string, string>> GraphEdges(const Graph& graph) {
    588   std::vector<std::pair<string, string>> edges;
    589   for (const Edge* edge : graph.edges()) {
    590     if (edge->src()->IsSource() || edge->dst()->IsSink()) continue;
    591     edges.emplace_back(
    592         strings::StrCat(edge->src()->name(), ":", edge->src_output()),
    593         strings::StrCat(edge->dst()->name(), ":", edge->dst_input()));
    594   }
    595   std::sort(edges.begin(), edges.end());
    596   return edges;
    597 }
    598 
    599 TEST(EncapsulateSubgraphsTest, InputDeduplication) {
    600   Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
    601       "/job:localhost/replica:0/task:0/cpu:0");
    602   auto x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT);
    603   auto add1 = ops::Add(root.WithOpName("add1"), x, x);
    604   add1.node()->AddAttr("_cluster", "cluster1");
    605   auto add2 = ops::Add(root.WithOpName("add2"), add1, add1);
    606   add2.node()->AddAttr("_cluster", "cluster2");
    607   auto out = ops::Mul(root.WithOpName("mul"), add1, add2);
    608 
    609   Graph graph_before_encapsulation(OpRegistry::Global());
    610   TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation));
    611 
    612   FunctionLibraryDefinition library(OpRegistry::Global(), {});
    613   std::unique_ptr<Graph> graph;
    614   TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
    615       "_cluster", "_outside", graph_before_encapsulation,
    616       /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false,
    617       /*reuse_existing_functions=*/false, &graph, &library));
    618 
    619   std::vector<string> expected_nodes = {"cluster1", "cluster2", "mul", "x"};
    620   EXPECT_EQ(expected_nodes, GraphNodes(*graph));
    621 
    622   std::vector<std::pair<string, string>> expected_edges = {
    623       {"cluster1:0", "cluster2:0"},
    624       {"cluster1:0", "mul:0"},
    625       {"cluster2:0", "mul:1"},
    626       {"x:0", "cluster1:0"}};
    627   EXPECT_EQ(expected_edges, GraphEdges(*graph));
    628 }
    629 
    630 TEST(EncapsulateSubgraphsTest, ParallelChecking) {
    631   Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
    632       "/job:localhost/replica:0/task:0/cpu:0");
    633   auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
    634   auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT);
    635   auto add1 = ops::Add(root.WithOpName("add1"), x1, x2);
    636   add1.node()->AddAttr("_cluster", "cluster1");
    637   auto add2 = ops::Add(root.WithOpName("add2"), add1, x2);
    638   add2.node()->AddAttr("_cluster", "cluster1");
    639   auto out = ops::Mul(root.WithOpName("mul"), x1, add2);
    640 
    641   Graph graph_before_encapsulation(OpRegistry::Global());
    642   TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation));
    643 
    644   FunctionLibraryDefinition library(OpRegistry::Global(), {});
    645   std::unique_ptr<Graph> graph;
    646   TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
    647       "_cluster", "_outside", graph_before_encapsulation,
    648       /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/true,
    649       /*reuse_existing_functions=*/false, &graph, &library));
    650 
    651   std::vector<string> expected_nodes = {
    652       "add1", "add2", "cluster1", "cluster1_parallel_check/_0",
    653       "mul",  "x1",   "x2"};
    654   EXPECT_EQ(expected_nodes, GraphNodes(*graph));
    655 
    656   std::vector<std::pair<string, string>> expected_edges = {
    657       {"add1:0", "add2:0"},
    658       {"add2:0", "cluster1_parallel_check/_0:0"},
    659       {"cluster1:0", "cluster1_parallel_check/_0:1"},
    660       {"cluster1_parallel_check/_0:0", "mul:1"},
    661       {"x1:0", "add1:0"},
    662       {"x1:0", "cluster1:0"},
    663       {"x1:0", "mul:0"},
    664       {"x2:0", "add1:1"},
    665       {"x2:0", "add2:1"},
    666       {"x2:0", "cluster1:1"},
    667   };
    668   EXPECT_EQ(expected_edges, GraphEdges(*graph));
    669 }
    670 
    671 const Node* FindNodeByName(const Graph& graph, const string& name) {
    672   for (const Node* node : graph.nodes()) {
    673     if (node->name() == name) return node;
    674   }
    675   return nullptr;
    676 }
    677 
    678 bool HasGuaranteeConstAttr(const Node& n) {
    679   bool is_guaranteed_constant = false;
    680   if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant",
    681                    &is_guaranteed_constant)
    682            .ok()) {
    683     return false;
    684   }
    685   return is_guaranteed_constant;
    686 }
    687 
    688 TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
    689   Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
    690       "/job:localhost/replica:0/task:0/cpu:0");
    691   auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
    692   auto const_x2 = ops::Const(root.WithOpName("const_x2"), 10.0f);
    693   auto const_guarantee_x1 =
    694       ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1);
    695   auto add1 = ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_x2);
    696   add1.node()->AddAttr("_encapsulate", "encapsulate1");
    697 
    698   Graph graph_before(OpRegistry::Global());
    699   TF_ASSERT_OK(root.ToGraph(&graph_before));
    700 
    701   std::unique_ptr<Graph> graph_after;
    702   FunctionLibraryDefinition library(OpRegistry::Global(), {});
    703   int guaranteed_consts = 0;
    704   TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
    705       "_encapsulate", "_outside", graph_before,
    706       /*rewrite_subgraph_fn=*/
    707       [&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
    708                            std::vector<int>* input_permutation,
    709                            std::vector<int>* output_permutation,
    710                            NodeDef* call_def) {
    711         Graph* graph = graph_ptr->get();
    712         for (const Node* n : graph->nodes()) {
    713           if (n->type_string() == "_Arg" &&
    714               StringPiece(n->name()).starts_with("const")) {
    715             ++guaranteed_consts;
    716             EXPECT_TRUE(HasGuaranteeConstAttr(*n));
    717           } else {
    718             EXPECT_FALSE(HasGuaranteeConstAttr(*n));
    719           }
    720         }
    721         return Status::OK();
    722       },
    723       /*parallel_checking=*/false,
    724       /*reuse_existing_functions=*/false, &graph_after, &library));
    725   EXPECT_EQ(2, guaranteed_consts);
    726 }
    727 
    728 TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
    729   Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
    730       "/job:localhost/replica:0/task:0/cpu:0");
    731   auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
    732   auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT);
    733   auto const_guarantee_x1 =
    734       ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1);
    735   auto const_guarantee_x2 =
    736       ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2);
    737   auto const_guarantee_add1 = ops::Add(root.WithOpName("const_guarantee_add1"),
    738                                        const_guarantee_x1, const_guarantee_x2);
    739   auto add2 = ops::Add(root.WithOpName("add2"), const_guarantee_x1, x2);
    740   auto mul1 = ops::Mul(root.WithOpName("mul1"), const_guarantee_add1, add2);
    741   mul1.node()->AddAttr("_encapsulate", "encapsulate1");
    742 
    743   Graph graph_before(OpRegistry::Global());
    744   TF_ASSERT_OK(root.ToGraph(&graph_before));
    745 
    746   std::unique_ptr<Graph> graph_after;
    747   FunctionLibraryDefinition library(OpRegistry::Global(), {});
    748   int guaranteed_consts = 0;
    749   TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
    750       "_encapsulate", "_outside", graph_before,
    751       /*rewrite_subgraph_fn=*/
    752       [&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
    753                            std::vector<int>* input_permutation,
    754                            std::vector<int>* output_permutation,
    755                            NodeDef* call_def) {
    756         Graph* graph = graph_ptr->get();
    757         for (const Node* n : graph->nodes()) {
    758           if (n->type_string() == "_Arg" &&
    759               StringPiece(n->name()).starts_with("const")) {
    760             ++guaranteed_consts;
    761             EXPECT_TRUE(HasGuaranteeConstAttr(*n));
    762           } else {
    763             EXPECT_FALSE(HasGuaranteeConstAttr(*n));
    764           }
    765         }
    766         return Status::OK();
    767       },
    768       /*parallel_checking=*/false,
    769       /*reuse_existing_functions=*/false, &graph_after, &library));
    770   // Only 1 runtime const, which is const_guarantee_add1. Add2 has one const
    771   // and another non-const, so overall non-const.
    772   EXPECT_EQ(1, guaranteed_consts);
    773 }
    774 
// Test with one function to transform and one outside_compilation cluster.
//
// Input graph: C, c and F are tagged "_encapsulate"="F1". E is also tagged F1
// but additionally carries "_outside"="O1". A, B and G are untagged.
//
// Expected result:
// - Inside F1, E is replaced by an _XlaHostCompute node named
//   outside_compilation_O1_host_compute.
// - On the host, E sits between a RecvAtHost/SendFromHost pair keyed
//   "host_compute_channel_F1_O1".
// - An F1_sequencer NoOp depends on that Recv/Send pair.
TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Build the input graph. The library starts with XTimesTwo so the test can
  // check that pre-existing functions are kept.
  {
    *library.add_function() = test::function::XTimesTwo();

    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    // Give nodes 'c' and 'd' names that collide after lowercasing.
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d = Binary(b, c,
                     b1.opts().WithName("c").WithControlInput(c).WithAttr(
                         "_encapsulate", "F1"));
    Node* e = Binary(c, d,
                     b1.opts()
                         .WithName("E")
                         .WithControlInputs({b, d})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  // Encapsulate() is a helper defined earlier in this file. After it runs,
  // `graphdef` and `library` are compared against the expectations below.
  TF_EXPECT_OK(Encapsulate(&graphdef, &library));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Serialized shape-inference graph expected on the _XlaHostCompute node.
  // It contains E fed by a RecvAtHost and feeding a SendFromHost.
  string shape_string_expected;
  {
    GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
    Node* recv =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                   shape.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
                     shape.opts().WithName("E"));
    SendFromHost("host_compute_channel_F1_O1", {e},
                 shape.opts().WithName("outside_compilation_F1_O1_send"));
    GraphDef shape_graph;
    TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
    EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
  }

  // Expected library: the original XTimesTwo plus the generated F1. In F1,
  // F reads E's value from the host-compute node instead of from E.
  // NOTE(review): "shapes" is spelled as an empty ArraySlice<DataType> here,
  // while other tests use ArraySlice<TensorShapeProto>. Both are empty lists;
  // confirm they serialize identically.
  *library_expected.add_function() = test::function::XTimesTwo();
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
          {{"F"},
           "BinaryTest",
           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "_XlaHostCompute",
           {"C:o:0", "c:o:0"},
           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", shape_string_expected},
            {"shapes", gtl::ArraySlice<DataType>({})}},
           {"c"}},
      },
      {{"f_0_retval", "F:o:0"}});

  // Expected outer graph:
  // - a call node for F1;
  // - E on the host between its Recv/Send pair;
  // - an F1_sequencer NoOp, on which G has a control dependency.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    NodeBuilder node_builder("F1", "F1", lib_def.get());
    node_builder.Input(a).Input(b);
    Node* call = b2.opts().FinalizeBuilder(&node_builder);

    Node* recv =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
                     b2.opts().WithName("E").WithControlInputs({recv, b}));
    Node* send = SendFromHost("host_compute_channel_F1_O1", {e},
                              b2.opts()
                                  .WithName("outside_compilation_F1_O1_send")
                                  .WithControlInput(e));

    Node* s = NoOp(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}));

    Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e}));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
    878 
// Test with one function to transform and two outside_compilation clusters.
//
// Input graph:
// - C, D, F and I are tagged "_encapsulate"="F1".
// - E is tagged F1 and "_outside"="O1".
// - G and H are tagged F1 and "_outside"="O2".
// - A, B and J are untagged.
//
// Each cluster is expected to become its own _XlaHostCompute node inside F1
// and its own RecvAtHost/SendFromHost pair on the host.
TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Build the input graph.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Binary(c, d,
                     b1.opts()
                         .WithName("E")
                         .WithControlInputs({b, d})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Node* g = Binary(e, f,
                     b1.opts()
                         .WithName("G")
                         .WithControlInputs({e, f})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    Node* h = Binary(d, e,
                     b1.opts()
                         .WithName("H")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F1"));
    Binary(g, i, b1.opts().WithName("J"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  // Encapsulate() is a helper defined earlier in this file. After it runs,
  // `graphdef` and `library` are compared against the expectations below.
  TF_EXPECT_OK(Encapsulate(&graphdef, &library));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected shape-inference graph for cluster O1: E between the O1
  // RecvAtHost and SendFromHost.
  string shape_string_expected_1;
  {
    GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
    Node* recv =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                   shape1.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
                     shape1.opts().WithName("E"));
    SendFromHost("host_compute_channel_F1_O1", {e},
                 shape1.opts().WithName("outside_compilation_F1_O1_send"));
    GraphDef shape1_graph;
    TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph));
    EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1));
  }

  // Expected shape-inference graph for cluster O2. H's second input is E, so
  // this graph also contains the O1 Recv and E. Only H is sent back; G does
  // not appear here.
  string shape_string_expected_2;
  {
    GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
    Node* recv1 =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                   shape2.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                     shape2.opts().WithName("E"));
    Node* recv2 =
        RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT},
                   shape2.opts().WithName("outside_compilation_F1_O2_recv"));
    Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H"));
    SendFromHost("host_compute_channel_F1_O2", {h},
                 shape2.opts().WithName("outside_compilation_F1_O2_send"));
    GraphDef shape2_graph;
    TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph));
    EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2));
  }

  // Expected F1: one _XlaHostCompute node per cluster.
  // - O1 receives C and D and has a control dependency on D.
  // - O2 receives D and F and has a control dependency on F.
  // - I reads O2's output.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
          {{"I"},
           "UnaryTest",
           {"outside_compilation_O2_host_compute:outputs:0"}},
          {{"F"},
           "BinaryTest",
           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O2_host_compute"},
           "_XlaHostCompute",
           {"D:o:0", "F:o:0"},
           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"key", "host_compute_channel_F1_O2"},
            {"shape_inference_graph", shape_string_expected_2},
            {"shapes", gtl::ArraySlice<DataType>({})}},
           {"F"}},
          {{"outside_compilation_O1_host_compute"},
           "_XlaHostCompute",
           {"C:o:0", "D:o:0"},
           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", shape_string_expected_1},
            {"shapes", gtl::ArraySlice<DataType>({})}},
           {"D"}},
      },
      {{"i_0_retval", "I:o:0"}});

  // Expected outer graph:
  // - the F1 call;
  // - E (O1) between its Recv/Send pair on the host;
  // - G and H (O2) fed by the O2 Recv, with H sent back to F1;
  // - an F1_sequencer NoOp covering both Recv/Send pairs, on which J depends.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    NodeBuilder node_builder("F1", "F1", lib_def.get());
    node_builder.Input(a).Input(b);
    Node* call = b2.opts().FinalizeBuilder(&node_builder);

    Node* recv1 =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                     b2.opts().WithName("E").WithControlInputs({recv1, b}));
    Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e},
                               b2.opts()
                                   .WithName("outside_compilation_F1_O1_send")
                                   .WithControlInput(e));

    Node* recv2 =
        RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT},
                   b2.opts().WithName("outside_compilation_F1_O2_recv"));
    Node* g = Binary(e, ops::NodeOut(recv2, 1),
                     b2.opts().WithName("G").WithControlInputs({recv2, e}));
    Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H"));
    Node* send2 =
        SendFromHost("host_compute_channel_F1_O2", {h},
                     b2.opts().WithName("outside_compilation_F1_O2_send"));

    Node* s = NoOp(b2.opts()
                       .WithName("F1_sequencer")
                       .WithControlInputs({recv1, send1, recv2, send2}));

    Binary(g, call, b2.opts().WithName("J").WithControlInput(s));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
   1031 
// Test with two functions to transform, each with one outside_compilation
// cluster.
//
// Input graph:
// - C, D and F are tagged "_encapsulate"="F1".
// - E is tagged F1 and "_outside"="O1".
// - G and I are tagged F2.
// - H is tagged F2 and "_outside"="O1".
// - A and B are built with InputShaped, a helper defined earlier in this file.
//
// Expected differences between the two clusters:
// - F1's host-compute node carries a shape-inference graph and no static
//   shapes.
// - F2's host-compute node carries an empty shape-inference graph and a
//   static shape.
TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Build the input graph.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = InputShaped(b1.opts().WithName("A"));
    Node* b = InputShaped(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Binary(c, d,
                     b1.opts()
                         .WithName("E")
                         .WithControlInputs({b, d})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Node* g = Binary(e, f,
                     b1.opts().WithName("G").WithControlInputs({e, f}).WithAttr(
                         "_encapsulate", "F2"));
    Node* h = Binary(d, g,
                     b1.opts()
                         .WithName("H")
                         .WithAttr("_encapsulate", "F2")
                         .WithAttr("_outside", "O1"));
    Node* i =
        Binary(f, h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2"));
    Binary(g, i, b1.opts().WithName("J"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  // Encapsulate() is a helper defined earlier in this file. After it runs,
  // `graphdef` and `library` are compared against the expectations below.
  TF_EXPECT_OK(Encapsulate(&graphdef, &library));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected shape-inference graph for F1's cluster O1: E between its
  // RecvAtHost and SendFromHost.
  string shape_string_expected;
  {
    GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
    Node* recv =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                   shape.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
                     shape.opts().WithName("E"));
    SendFromHost("host_compute_channel_F1_O1", {e},
                 shape.opts().WithName("outside_compilation_F1_O1_send"));
    GraphDef shape_graph;
    TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
    EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
  }

  // Static shape expected on F2's host-compute output: a single dimension of
  // size 2. Presumably this matches the shape InputShaped gives A and B;
  // confirm against that helper.
  TensorShapeProto shape_proto_expected;
  shape_proto_expected.add_dim()->set_size(2);

  // Expected F1: returns F (output 0) and D (output 1). D is also consumed by
  // H on the host.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"},
      {"f_0_retval:float", "d_0_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "BinaryTest",
           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "_XlaHostCompute",
           {"C:o:0", "D:o:0"},
           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", shape_string_expected},
            {"shapes", gtl::ArraySlice<DataType>({})}},
           {"D"}},
      },
      {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}});

  // Expected F2: takes E and F as arguments. Its host-compute node sends G and
  // is expected with an empty shape_inference_graph plus a static shape.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F2", {"e_0_arg:float", "f_0_arg:float"},
      {"g_0_retval:float", "i_0_retval:float"}, {},
      {
          {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
          {{"I"},
           "BinaryTest",
           {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}},
          {{"outside_compilation_O1_host_compute"},
           "_XlaHostCompute",
           {"G:o:0"},
           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"key", "host_compute_channel_F2_O1"},
            {"shape_inference_graph", ""},
            {"shapes",
             gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
      },
      {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});

  // Expected outer graph:
  // - Host-side E for F1_O1, then the F1 call and its F1_sequencer.
  // - Host-side H for F2_O1, reading D from the F1 call (output 1) and G via
  //   the F2_O1 Recv.
  // - The F2 call, with control dependencies on F1_sequencer, E and the F1
  //   call.
  // - J, which depends on F2_sequencer.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = InputShaped(b2.opts().WithName("A"));
    Node* b = InputShaped(b2.opts().WithName("B"));

    Node* recv1 =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                     b2.opts().WithName("E").WithControlInputs({recv1, b}));
    Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e},
                               b2.opts()
                                   .WithName("outside_compilation_F1_O1_send")
                                   .WithControlInput(e));
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
    Node* s1 = NoOp(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));

    Node* recv2 =
        RecvAtHost("host_compute_channel_F2_O1", {DT_FLOAT},
                   b2.opts().WithName("outside_compilation_F2_O1_recv"));
    Node* h = Binary(ops::NodeOut(call1, 1), recv2,
                     b2.opts().WithName("H").WithControlInput(s1));
    Node* send2 =
        SendFromHost("host_compute_channel_F2_O1", {h},
                     b2.opts().WithName("outside_compilation_F2_O1_send"));

    NodeBuilder node_builder2("F2", "F2", lib_def.get());
    node_builder2.Input(e).Input(call1);
    Node* call2 = b2.opts()
                      .WithControlInputs({s1, e, call1})
                      .FinalizeBuilder(&node_builder2);
    Node* s2 = NoOp(
        b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}));
    Binary(call2, ops::NodeOut(call2, 1),
           b2.opts().WithName("J").WithControlInput(s2));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
   1180 
// Test with one outside_compilation cluster that has no inputs from the
// compiled subgraph.
//
// E (cluster O1 of F1) reads only A, which is not encapsulated. The expected
// results follow from that:
// - F1's _XlaHostCompute node has no inputs (empty Tinputs).
// - The host side has a SendFromHost but no RecvAtHost.
// - F1_sequencer depends only on that Send.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Build the input graph.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = InputShaped(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(a, b1.opts()
                           .WithName("E")
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f =
        Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
    Unary(f, b1.opts().WithName("G"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  // Encapsulate() is a helper defined earlier in this file. After it runs,
  // `graphdef` and `library` are compared against the expectations below.
  TF_EXPECT_OK(Encapsulate(&graphdef, &library));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Static shape expected on the host-compute output: one dimension of size
  // 2. Presumably this matches InputShaped(A); confirm against that helper.
  TensorShapeProto shape_proto_expected;
  shape_proto_expected.add_dim()->set_size(2);

  // Expected F1: the host-compute node takes no inputs, has an empty
  // shape_inference_graph, and carries the static shape above.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "BinaryTest",
           {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
          {{"outside_compilation_O1_host_compute"},
           "_XlaHostCompute",
           {},
           {{"Tinputs", gtl::ArraySlice<DataType>({})},
            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", ""},
            {"shapes",
             gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
      },
      {{"f_0_retval", "F:o:0"}});

  // Expected outer graph:
  // - E on the host, reading A directly and sending its result to F1;
  // - the F1 call;
  // - an F1_sequencer on which G depends.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = InputShaped(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* e = Unary(a, b2.opts().WithName("E"));
    Node* send1 =
        SendFromHost("host_compute_channel_F1_O1", {e},
                     b2.opts().WithName("outside_compilation_F1_O1_send"));
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
    Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(send1));

    Unary(call1, b2.opts().WithName("G").WithControlInput(s1));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
   1255 
// Test with one outside_compilation cluster that has no data inputs but has a
// control input from the compiled subgraph.
//
// E (cluster O1 of F1) reads only A but has a control edge from D, which is
// inside F1. Expected results:
// - F1's _XlaHostCompute node has no data inputs but a control dependency
//   on D.
// - On the host, a RecvAtHost with an empty type list precedes E through a
//   control edge.
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Build the input graph.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = InputShaped(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(a, b1.opts()
                           .WithName("E")
                           .WithControlInput(d)
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f =
        Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
    Unary(f, b1.opts().WithName("G"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  // Encapsulate() is a helper defined earlier in this file. After it runs,
  // `graphdef` and `library` are compared against the expectations below.
  TF_EXPECT_OK(Encapsulate(&graphdef, &library));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Static shape expected on the host-compute output: one dimension of size
  // 2. Presumably this matches InputShaped(A); confirm against that helper.
  TensorShapeProto shape_proto_expected;
  shape_proto_expected.add_dim()->set_size(2);

  // Expected F1: the host-compute node has no data inputs and a control
  // dependency on D (the trailing {"D"}).
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "BinaryTest",
           {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
          {{"outside_compilation_O1_host_compute"},
           "_XlaHostCompute",
           {},
           {{"Tinputs", gtl::ArraySlice<DataType>({})},
            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", ""},
            {"shapes",
             gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}},
           {"D"}},
      },
      {{"f_0_retval", "F:o:0"}});

  // Expected outer graph: a RecvAtHost with no data outputs that E depends on
  // through a control edge, E's SendFromHost, the F1 call, and an
  // F1_sequencer covering the Recv/Send pair on which G depends.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = InputShaped(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* recv1 =
        RecvAtHost("host_compute_channel_F1_O1", {},
                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1));
    Node* send1 =
        SendFromHost("host_compute_channel_F1_O1", {e},
                     b2.opts().WithName("outside_compilation_F1_O1_send"));
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
    Node* s1 = NoOp(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));

    Unary(call1, b2.opts().WithName("G").WithControlInput(s1));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
   1336 
// Test with one outside_compilation cluster that has no outputs to the
// compiled subgraph.
   1339 TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
   1340   FunctionDefLibrary library;
   1341   GraphDef graphdef;
   1342 
   1343   {
   1344     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
   1345     Node* a = Input(b1.opts().WithName("A"));
   1346     Node* b = Input(b1.opts().WithName("B"));
   1347     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
   1348     Node* d =
   1349         Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
   1350     Node* e = Unary(d, b1.opts()
   1351                            .WithName("E")
   1352                            .WithAttr("_encapsulate", "F1")
   1353                            .WithAttr("_outside", "O1"));
   1354     Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
   1355     Binary(e, f, b1.opts().WithName("G"));
   1356     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
   1357   }
   1358 
   1359   TF_EXPECT_OK(Encapsulate(&graphdef, &library));
   1360 
   1361   FunctionDefLibrary library_expected;
   1362   GraphDef graphdef_expected;
   1363 
   1364   *library_expected.add_function() = FunctionDefHelper::Create(
   1365       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
   1366       {
   1367           {{"C"}, "UnaryTest", {"a_0_arg"}},
   1368           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
   1369           {{"F"}, "UnaryTest", {"D:o:0"}},
   1370           {{"outside_compilation_O1_host_compute"},
   1371            "_XlaHostCompute",
   1372            {"D:o:0"},
   1373            {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
   1374             {"Toutputs", gtl::ArraySlice<DataType>({})},
   1375             {"key", "host_compute_channel_F1_O1"},
   1376             {"shape_inference_graph", ""},
   1377             {"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
   1378       },
   1379       {{"f_0_retval", "F:o:0"}});
   1380 
   1381   {
   1382     std::unique_ptr<FunctionLibraryDefinition> lib_def(
   1383         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
   1384     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
   1385     Node* a = Input(b2.opts().WithName("A"));
   1386     Node* b = Input(b2.opts().WithName("B"));
   1387 
   1388     Node* recv1 =
   1389         RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
   1390                    b2.opts().WithName("outside_compilation_F1_O1_recv"));
   1391     Node* e = Unary(recv1, b2.opts().WithName("E"));
   1392     NodeBuilder node_builder1("F1", "F1", lib_def.get());
   1393     node_builder1.Input(a).Input(b);
   1394     Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
   1395     Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(recv1));
   1396 
   1397     Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1));
   1398     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
   1399   }
   1400 
   1401   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
   1402   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
   1403 }
   1404 
   1405 // Test with one outside_compilation cluster that has no data outputs but has a
   1406 // control output to the compiled subgraph.
   1407 TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
   1408   FunctionDefLibrary library;
   1409   GraphDef graphdef;
   1410 
   1411   {
   1412     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
   1413     Node* a = Input(b1.opts().WithName("A"));
   1414     Node* b = Input(b1.opts().WithName("B"));
   1415     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
   1416     Node* d =
   1417         Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
   1418     Node* e = Unary(d, b1.opts()
   1419                            .WithName("E")
   1420                            .WithAttr("_encapsulate", "F1")
   1421                            .WithAttr("_outside", "O1"));
   1422     Node* f = Unary(d, b1.opts().WithName("F").WithControlInput(e).WithAttr(
   1423                            "_encapsulate", "F1"));
   1424     Binary(e, f, b1.opts().WithName("G"));
   1425     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
   1426   }
   1427 
   1428   TF_EXPECT_OK(Encapsulate(&graphdef, &library));
   1429 
   1430   FunctionDefLibrary library_expected;
   1431   GraphDef graphdef_expected;
   1432 
   1433   *library_expected.add_function() = FunctionDefHelper::Create(
   1434       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
   1435       {
   1436           {{"C"}, "UnaryTest", {"a_0_arg"}},
   1437           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
   1438           {{"F"},
   1439            "UnaryTest",
   1440            {"D:o:0"},
   1441            {},
   1442            {"outside_compilation_O1_host_compute"}},
   1443           {{"outside_compilation_O1_host_compute"},
   1444            "_XlaHostCompute",
   1445            {"D:o:0"},
   1446            {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
   1447             {"Toutputs", gtl::ArraySlice<DataType>({})},
   1448             {"key", "host_compute_channel_F1_O1"},
   1449             {"shape_inference_graph", ""},
   1450             {"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
   1451       },
   1452       {{"f_0_retval", "F:o:0"}});
   1453 
   1454   {
   1455     std::unique_ptr<FunctionLibraryDefinition> lib_def(
   1456         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
   1457     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
   1458     Node* a = Input(b2.opts().WithName("A"));
   1459     Node* b = Input(b2.opts().WithName("B"));
   1460 
   1461     Node* recv1 =
   1462         RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
   1463                    b2.opts().WithName("outside_compilation_F1_O1_recv"));
   1464     Node* e = Unary(recv1, b2.opts().WithName("E"));
   1465     Node* send1 = SendFromHost("host_compute_channel_F1_O1", {},
   1466                                b2.opts()
   1467                                    .WithName("outside_compilation_F1_O1_send")
   1468                                    .WithControlInput(e));
   1469     NodeBuilder node_builder1("F1", "F1", lib_def.get());
   1470     node_builder1.Input(a).Input(b);
   1471     Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
   1472     Node* s1 = NoOp(
   1473         b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
   1474 
   1475     Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1));
   1476     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
   1477   }
   1478 
   1479   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
   1480   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
   1481 }
   1482 
   1483 // Test with one outside_compilation cluster that has no outputs from the
   1484 // compiled subgraph.
   1485 TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
   1486   FunctionDefLibrary library;
   1487   GraphDef graphdef;
   1488 
   1489   {
   1490     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
   1491     Node* a = Input(b1.opts().WithName("A"));
   1492     Node* b = Input(b1.opts().WithName("B"));
   1493     Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
   1494     Node* d =
   1495         Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
   1496     Node* e = Unary(a, b1.opts()
   1497                            .WithName("E")
   1498                            .WithAttr("_encapsulate", "F1")
   1499                            .WithAttr("_outside", "O1"));
   1500     Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
   1501     Binary(e, f, b1.opts().WithName("G"));
   1502     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
   1503   }
   1504 
   1505   TF_EXPECT_OK(Encapsulate(&graphdef, &library));
   1506 
   1507   FunctionDefLibrary library_expected;
   1508   GraphDef graphdef_expected;
   1509 
   1510   *library_expected.add_function() = FunctionDefHelper::Create(
   1511       "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
   1512       {
   1513           {{"C"}, "UnaryTest", {"a_0_arg"}},
   1514           {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
   1515           {{"F"}, "UnaryTest", {"D:o:0"}},
   1516       },
   1517       {{"f_0_retval", "F:o:0"}});
   1518 
   1519   {
   1520     std::unique_ptr<FunctionLibraryDefinition> lib_def(
   1521         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
   1522     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
   1523     Node* a = Input(b2.opts().WithName("A"));
   1524     Node* b = Input(b2.opts().WithName("B"));
   1525 
   1526     Node* e = Unary(a, b2.opts().WithName("E"));
   1527     NodeBuilder node_builder1("F1", "F1", lib_def.get());
   1528     node_builder1.Input(a).Input(b);
   1529     Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
   1530 
   1531     Binary(e, call1, b2.opts().WithName("G"));
   1532     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
   1533   }
   1534 
   1535   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
   1536   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
   1537 }
   1538 
   1539 // Test for shape inference of outside compilation.
   1540 TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
   1541   FunctionDefLibrary library;
   1542   GraphDef graphdef;
   1543 
   1544   {
   1545     *library.add_function() = test::function::XTimesTwo();
   1546 
   1547     GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
   1548     Node* a = InputShaped(b1.opts().WithName("A"));
   1549     Node* b = Input(b1.opts().WithName("B"));
   1550     // Give nodes 'c' and 'd' names that collide after lowercasing.
   1551     Node* c = Unary(a, b1.opts().WithName("C"));
   1552     Node* d = Unary(b, b1.opts().WithName("c").WithControlInput(c).WithAttr(
   1553                            "_encapsulate", "F1"));
   1554     Node* e = BinaryUnknownShape(c, d,
   1555                                  b1.opts()
   1556                                      .WithName("E")
   1557                                      .WithControlInputs({b, d})
   1558                                      .WithAttr("_encapsulate", "F1")
   1559                                      .WithAttr("_outside", "O1"));
   1560     Node* f = Binary(c, e,
   1561                      b1.opts().WithName("F").WithControlInput(e).WithAttr(
   1562                          "_encapsulate", "F1"));
   1563     Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
   1564     TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
   1565   }
   1566 
   1567   TF_EXPECT_OK(Encapsulate(&graphdef, &library));
   1568 
   1569   FunctionDefLibrary library_expected;
   1570   GraphDef graphdef_expected;
   1571 
   1572   string shape_string_expected;
   1573   {
   1574     GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
   1575     Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_0"));
   1576     Node* recv =
   1577         RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
   1578                    shape.opts().WithName("outside_compilation_F1_O1_recv"));
   1579     Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E"));
   1580     SendFromHost("host_compute_channel_F1_O1", {e},
   1581                  shape.opts().WithName("outside_compilation_F1_O1_send"));
   1582     GraphDef shape_graph;
   1583     TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
   1584     EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
   1585   }
   1586 
   1587   *library_expected.add_function() = test::function::XTimesTwo();
   1588   *library_expected.add_function() = FunctionDefHelper::Create(
   1589       "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval:float"}, {},
   1590       {
   1591           {{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}},
   1592           {{"F"},
   1593            "BinaryTest",
   1594            {"c_0_arg", "outside_compilation_O1_host_compute:outputs:0"},
   1595            {},
   1596            {"outside_compilation_O1_host_compute"}},
   1597           {{"outside_compilation_O1_host_compute"},
   1598            "_XlaHostCompute",
   1599            {"c:o:0"},
   1600            {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
   1601             {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
   1602             {"key", "host_compute_channel_F1_O1"},
   1603             {"shape_inference_graph", shape_string_expected},
   1604             {"shapes", gtl::ArraySlice<DataType>({})}},
   1605            {"c"}},
   1606       },
   1607       {{"f_0_retval", "F:o:0"}});
   1608 
   1609   {
   1610     std::unique_ptr<FunctionLibraryDefinition> lib_def(
   1611         new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
   1612     GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
   1613     Node* a = InputShaped(b2.opts().WithName("A"));
   1614     Node* b = Input(b2.opts().WithName("B"));
   1615     Node* c = Unary(a, b2.opts().WithName("C"));
   1616 
   1617     NodeBuilder node_builder("F1", "F1", lib_def.get());
   1618     node_builder.Input(b).Input(c);
   1619     Node* call =
   1620         b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder);
   1621 
   1622     Node* recv =
   1623         RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
   1624                    b2.opts().WithName("outside_compilation_F1_O1_recv"));
   1625     Node* e = BinaryUnknownShape(
   1626         c, ops::NodeOut(recv, 0),
   1627         b2.opts().WithName("E").WithControlInputs({recv, b}));
   1628     Node* send = SendFromHost("host_compute_channel_F1_O1", {e},
   1629                               b2.opts()
   1630                                   .WithName("outside_compilation_F1_O1_send")
   1631                                   .WithControlInput(e));
   1632 
   1633     Node* s = NoOp(
   1634         b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}));
   1635 
   1636     Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e}));
   1637     TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
   1638   }
   1639 
   1640   TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
   1641   TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
   1642 }
   1643 
   1644 }  // namespace
   1645 }  // namespace tensorflow
   1646