1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <utility> 17 18 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" 19 20 #include "tensorflow/cc/framework/ops.h" 21 #include "tensorflow/cc/ops/standard_ops.h" 22 #include "tensorflow/core/framework/function_testlib.h" 23 #include "tensorflow/core/graph/graph_constructor.h" 24 #include "tensorflow/core/graph/graph_def_builder.h" 25 #include "tensorflow/core/lib/core/status_test_util.h" 26 #include "tensorflow/core/platform/test.h" 27 #include "tensorflow/core/util/equal_graph_def.h" 28 29 namespace tensorflow { 30 namespace { 31 32 template <class Tkey, class Tvalue> 33 bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a, 34 const ::tensorflow::protobuf::Map<Tkey, Tvalue>& b, 35 const std::function<string(const Tkey&)>& key_to_string, 36 const std::function<string(const Tvalue&)>& value_to_string, 37 const std::function<bool(const Tkey&, const Tvalue&, 38 const Tvalue&)>& compare, 39 const string& map_name, string* diff) { 40 for (const auto& elt_a : a) { 41 const auto iter = b.find(elt_a.first); 42 if (iter == b.end()) { 43 if (diff) { 44 *diff = strings::StrCat( 45 map_name, " expected: contains element with key '", 46 key_to_string(elt_a.first), "' got: map has no such element"); 47 } 48 return false; 49 } 50 if (!compare(elt_a.first, 
elt_a.second, iter->second)) { 51 if (diff) { 52 *diff = strings::StrCat(map_name, " expected: element with key '", 53 key_to_string(elt_a.first), " has value '", 54 value_to_string(elt_a.second), "' got: '", 55 value_to_string(iter->second), "'"); 56 } 57 return false; 58 } 59 } 60 for (const auto& elt_b : b) { 61 const auto iter = a.find(elt_b.first); 62 if (iter == a.end()) { 63 if (diff) { 64 *diff = strings::StrCat(map_name, " got: contains element with key '", 65 key_to_string(elt_b.first), 66 "' expected: map has no such element"); 67 } 68 return false; 69 } 70 } 71 return true; 72 } 73 74 bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, 75 const string& diff_preamble, string* diff) { 76 if (a.op() != b.op()) { 77 if (diff) { 78 *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), 79 ", expected op '", a.op(), "' got '", b.op()); 80 } 81 return false; 82 } 83 if (a.device() != b.device()) { 84 if (diff) { 85 *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), 86 ", expected device '", a.device(), "' got '", 87 b.device()); 88 } 89 return false; 90 } 91 if (a.input_size() != b.input_size()) { 92 if (diff) { 93 *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), 94 ", expected ", a.input_size(), " inputs got ", 95 b.input_size(), " expected:\n", a.DebugString(), 96 "\ngot:\n", b.DebugString()); 97 } 98 return false; 99 } 100 for (int i = 0; i < a.input_size(); ++i) { 101 if (a.input(i) != b.input(i)) { 102 if (diff) { 103 *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), 104 " input ", i, ", expected ", a.input(i), 105 " got ", b.input(i), " expected:\n", 106 a.DebugString(), "\ngot:\n", b.DebugString()); 107 } 108 return false; 109 } 110 } 111 return EqualProtoMap<string, AttrValue>( 112 a.attr(), b.attr(), [](const string& s) { return s; }, 113 [](const AttrValue& v) { return v.DebugString(); }, 114 [](const string& key, const AttrValue& av, const AttrValue& bv) { 
115 if (key == "shape_inference_graph") { 116 // Default serialization of GraphDef is unstable because maps don't 117 // serialize deterministically. Rather than go through the hoops to 118 // turn on deterministic serialization of this attr just for this 119 // test, add logic here to compare determinstically. 120 GraphDef ga; 121 if (!ga.ParseFromString(av.s())) { 122 return false; 123 } 124 GraphDef gb; 125 if (!gb.ParseFromString(bv.s())) { 126 return false; 127 } 128 return EqualGraphDef(ga, gb, nullptr); 129 } else { 130 return av.DebugString() == bv.DebugString(); 131 } 132 }, 133 strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()), 134 diff); 135 } 136 137 bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, 138 string* diff) { 139 if (a.signature().DebugString() != b.signature().DebugString()) { 140 if (diff) { 141 *diff = strings::StrCat("Signature mismatch for function ", 142 a.signature().name(), ", expected:\n", 143 a.signature().DebugString(), "\ngot:\n", 144 b.signature().DebugString()); 145 } 146 return false; 147 } 148 if (!EqualProtoMap<string, AttrValue>( 149 a.attr(), b.attr(), [](const string& s) { return s; }, 150 [](const AttrValue& v) { return v.DebugString(); }, 151 [](const string& key, const AttrValue& av, const AttrValue& bv) { 152 return av.DebugString() == bv.DebugString(); 153 }, 154 strings::StrCat("attr mismatch for function ", a.signature().name()), 155 diff)) { 156 return false; 157 } 158 if (!EqualProtoMap<string, string>( 159 a.ret(), b.ret(), [](const string& s) { return s; }, 160 [](const string& s) { return s; }, 161 [](const string& key, const string& av, const string& bv) { 162 return av == bv; 163 }, 164 strings::StrCat("ret mismatch for function ", a.signature().name()), 165 diff)) { 166 return false; 167 } 168 for (int i = 0; i < a.node_def_size(); ++i) { 169 bool found = false; 170 for (int j = 0; j < b.node_def_size(); ++j) { 171 if (a.node_def(i).name() == b.node_def(j).name()) { 172 if 
(!EqualFunctionNodeDef( 173 a.node_def(i), b.node_def(j), 174 strings::StrCat("Function ", a.signature().name()), diff)) { 175 return false; 176 } 177 found = true; 178 break; 179 } 180 } 181 if (!found) { 182 if (diff) { 183 *diff = strings::StrCat("Function ", a.signature().name(), 184 ", expected: has node '", a.node_def(i).name(), 185 "' got: no node of that name"); 186 } 187 return false; 188 } 189 } 190 for (int i = 0; i < b.node_def_size(); ++i) { 191 bool found = false; 192 for (int j = 0; j < a.node_def_size(); ++j) { 193 if (b.node_def(i).name() == a.node_def(j).name()) { 194 found = true; 195 break; 196 } 197 } 198 if (!found) { 199 if (diff) { 200 *diff = strings::StrCat("Function ", a.signature().name(), 201 ", got: has node '", b.node_def(i).name(), 202 "' expected: no node of that name"); 203 } 204 return false; 205 } 206 } 207 return true; 208 } 209 210 bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, 211 const FunctionDefLibrary& actual, string* diff) { 212 std::unordered_map<string, const FunctionDef*> actual_index; 213 for (const FunctionDef& function : actual.function()) { 214 actual_index[function.signature().name()] = &function; 215 } 216 217 for (const FunctionDef& expected_function : expected.function()) { 218 auto it = actual_index.find(expected_function.signature().name()); 219 if (it == actual_index.end()) { 220 if (diff) { 221 *diff = strings::StrCat("Did not find expected function '", 222 expected_function.signature().name(), "'"); 223 } 224 return false; 225 } 226 if (!EqualFunctionDef(expected_function, *it->second, diff)) return false; 227 actual_index.erase(it); 228 } 229 230 if (!actual_index.empty()) { 231 if (diff != nullptr) { 232 *diff = strings::StrCat("Found unexpected function '", 233 actual_index.begin()->second->signature().name(), 234 "'"); 235 } 236 return false; 237 } 238 239 return true; 240 } 241 242 #define TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(expected, actual) \ 243 do { \ 244 string diff; \ 245 
EXPECT_TRUE(EqualFunctionDefLibrary(expected, actual, &diff)) \ 246 << diff << "\nActual: " << actual.DebugString(); \ 247 } while (false) 248 249 // TODO(misard): remove these fake registrations once there are real Ops to be 250 // compiled. 251 REGISTER_OP("_XlaHostCompute") 252 .Input("inputs: Tinputs") 253 .Output("outputs: Toutputs") 254 .Attr("Tinputs: list(type) >= 0") 255 .Attr("Toutputs: list(type) >= 0") 256 .Attr("key: string") 257 .SetShapeFn(::tensorflow::shape_inference::UnknownShape); 258 259 REGISTER_OP("_XlaSendFromHost") 260 .Input("input: Tinputs") 261 .Attr("Tinputs: list(type) >= 0") 262 .Attr("key: string") 263 .SetShapeFn(::tensorflow::shape_inference::UnknownShape); 264 265 REGISTER_OP("_XlaRecvAtHost") 266 .Output("output: Toutputs") 267 .Attr("Toutputs: list(type) >= 0") 268 .Attr("key: string") 269 .SetShapeFn(::tensorflow::shape_inference::UnknownShape); 270 271 REGISTER_OP("InputTest") 272 .Output("o: float") 273 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 274 c->set_output(0, c->UnknownShape()); 275 return Status::OK(); 276 }); 277 278 REGISTER_OP("InputTestShaped") 279 .Output("o: float") 280 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 281 c->set_output(0, c->Vector(2)); 282 return Status::OK(); 283 }); 284 285 REGISTER_OP("UnaryTest") 286 .Input("a: float") 287 .Output("o: float") 288 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 289 ::tensorflow::shape_inference::ShapeHandle o; 290 TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o)); 291 c->set_output(0, o); 292 return Status::OK(); 293 }); 294 REGISTER_OP("BinaryTest") 295 .Input("a: float") 296 .Input("b: float") 297 .Output("o: float") 298 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 299 ::tensorflow::shape_inference::ShapeHandle o; 300 TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o)); 301 c->set_output(0, o); 302 return Status::OK(); 303 }); 304 
REGISTER_OP("BinaryTest2") 305 .Input("a: float") 306 .Input("b: float") 307 .Output("o: float") 308 .SetShapeFn(::tensorflow::shape_inference::UnknownShape); 309 310 REGISTER_OP("AddNLikeTest") 311 .Input("inputs: N * T") 312 .Output("sum: T") 313 .Attr("N: int >= 1") 314 .Attr("T: numbertype") 315 .SetIsCommutative() 316 .SetIsAggregate(); 317 318 Node* NoOp(const GraphDefBuilder::Options& opts) { 319 return ops::SourceOp("NoOp", opts); 320 } 321 322 Node* Input(const GraphDefBuilder::Options& opts) { 323 return ops::SourceOp("InputTest", opts); 324 } 325 326 Node* InputShaped(const GraphDefBuilder::Options& opts) { 327 return ops::SourceOp("InputTestShaped", opts); 328 } 329 330 Node* KnownShape(const gtl::ArraySlice<int>& shape, 331 const GraphDefBuilder::Options& opts) { 332 if (opts.HaveError()) return nullptr; 333 NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const", 334 opts.op_registry()); 335 TensorProto value; 336 value.set_dtype(DT_FLOAT); 337 for (int dim : shape) { 338 value.mutable_tensor_shape()->add_dim()->set_size(dim); 339 } 340 return opts.WithAttr("value", value) 341 .WithAttr("dtype", DT_FLOAT) 342 .FinalizeBuilder(&node_builder); 343 } 344 345 Node* RecvAtHost(const string& key, const gtl::ArraySlice<DataType>& dtypes, 346 const GraphDefBuilder::Options& opts) { 347 if (opts.HaveError()) return nullptr; 348 NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"), 349 "_XlaRecvAtHost", opts.op_registry()); 350 return opts.WithAttr("Toutputs", dtypes) 351 .WithAttr("key", key) 352 .FinalizeBuilder(&node_builder); 353 } 354 355 Node* SendFromHost(const string& key, const std::vector<ops::NodeOut>& inputs, 356 const GraphDefBuilder::Options& opts) { 357 if (opts.HaveError()) return nullptr; 358 NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"), 359 "_XlaSendFromHost", opts.op_registry()); 360 node_builder.Input(inputs); 361 std::vector<DataType> dtypes; 362 for (const auto& node : inputs) { 363 dtypes.push_back(node.dt); 
364 } 365 return opts.WithAttr("key", key) 366 .WithAttr("Tinputs", dtypes) 367 .FinalizeBuilder(&node_builder); 368 } 369 370 Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) { 371 return ops::UnaryOp("UnaryTest", std::move(a), opts); 372 } 373 374 Node* Binary(ops::NodeOut a, ops::NodeOut b, 375 const GraphDefBuilder::Options& opts) { 376 return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts); 377 } 378 379 Node* BinaryUnknownShape(ops::NodeOut a, ops::NodeOut b, 380 const GraphDefBuilder::Options& opts) { 381 return ops::BinaryOp("BinaryTest2", std::move(a), std::move(b), opts); 382 } 383 384 Node* AddNLike(const std::vector<ops::NodeOut>& inputs, 385 const GraphDefBuilder::Options& opts) { 386 if (opts.HaveError()) return nullptr; 387 NodeBuilder node_builder(opts.GetNameForOp("AddN"), "AddNLikeTest", 388 opts.op_registry()); 389 node_builder.Input(inputs); 390 return opts.FinalizeBuilder(&node_builder); 391 } 392 393 Node* ArgOp(int index, DataType type, const GraphDefBuilder::Options& opts) { 394 return ops::SourceOp("_Arg", 395 opts.WithAttr("T", type).WithAttr("index", index)); 396 } 397 398 Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) { 399 if (opts.HaveError()) return nullptr; 400 NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval", 401 opts.op_registry()); 402 node_builder.Input(std::move(a)).Attr("index", index); 403 return opts.FinalizeBuilder(&node_builder); 404 } 405 406 Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { 407 Status s; 408 // Convert the GraphDef to a Graph 409 std::unique_ptr<FunctionLibraryDefinition> lib_def( 410 new FunctionLibraryDefinition(OpRegistry::Global(), *library)); 411 GraphConstructorOptions options; 412 options.allow_internal_ops = true; 413 std::unique_ptr<Graph> graph(new Graph(lib_def.get())); 414 s = ConvertGraphDefToGraph(options, *graphdef, graph.get()); 415 if (!s.ok()) return s; 416 417 std::unique_ptr<Graph> 
graph_out; 418 s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph, 419 /*rewrite_subgraph_fn=*/{}, 420 /*parallel_checking=*/false, 421 /*reuse_existing_functions=*/false, 422 &graph_out, lib_def.get()); 423 if (!s.ok()) return s; 424 425 GraphDef graphdef_out; 426 graph_out->ToGraphDef(&graphdef_out); 427 graphdef->Swap(&graphdef_out); 428 429 *library = lib_def->ToProto(); 430 return s; 431 } 432 433 // If there are no marked nodes, funcification should be a no-op. 434 TEST(EncapsulateSubgraphsTest, NoFunctions) { 435 GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); 436 437 Node* a = Input(builder.opts().WithName("A")); 438 Node* b = Input(builder.opts().WithName("B")); 439 Node* c = Unary(a, builder.opts().WithName("C")); 440 Binary(b, c, builder.opts().WithName("D")); 441 442 GraphDef graphdef_in; 443 FunctionDefLibrary library_in; 444 TF_EXPECT_OK(builder.ToGraphDef(&graphdef_in)); 445 *library_in.add_function() = test::function::XTimesTwo(); 446 447 GraphDef graphdef_out = graphdef_in; 448 FunctionDefLibrary library_out = library_in; 449 TF_EXPECT_OK(Encapsulate(&graphdef_out, &library_out)); 450 451 // If there are no marked nodes, funcification should be a no-op. 452 TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out); 453 TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out); 454 } 455 456 // Test with one function to transform. 457 TEST(EncapsulateSubgraphsTest, OneFunction) { 458 FunctionDefLibrary library; 459 GraphDef graphdef; 460 461 { 462 *library.add_function() = test::function::XTimesTwo(); 463 464 GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); 465 Node* a = Input(b1.opts().WithName("A")); 466 Node* b = Input(b1.opts().WithName("B")); 467 // Give nodes 'c' and 'd' names that collide after lowercasing. 
468 Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); 469 Node* d = Binary(b, c, b1.opts().WithName("c").WithControlInput(c).WithAttr( 470 "_encapsulate", "F1")); 471 Binary(a, d, b1.opts().WithName("E")); 472 TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); 473 } 474 475 TF_EXPECT_OK(Encapsulate(&graphdef, &library)); 476 477 FunctionDefLibrary library_expected; 478 GraphDef graphdef_expected; 479 480 *library_expected.add_function() = test::function::XTimesTwo(); 481 *library_expected.add_function() = FunctionDefHelper::Create( 482 "F1", {"a_0_arg:float", "b_0_arg:float"}, {"c_0_retval:float"}, {}, 483 { 484 {{"C"}, "UnaryTest", {"a_0_arg"}}, 485 {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}}, 486 }, 487 {{"c_0_retval", "c:o:0"}}); 488 489 { 490 std::unique_ptr<FunctionLibraryDefinition> lib_def( 491 new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); 492 GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); 493 Node* a = Input(b2.opts().WithName("A")); 494 Node* b = Input(b2.opts().WithName("B")); 495 496 NodeBuilder node_builder("F1", "F1", lib_def.get()); 497 node_builder.Input(a).Input(b); 498 Node* call = b2.opts().FinalizeBuilder(&node_builder); 499 500 Binary(a, call, b2.opts().WithName("E")); 501 TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); 502 } 503 504 TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); 505 TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); 506 } 507 508 // Test with two functions to transform. 
509 TEST(EncapsulateSubgraphsTest, TwoFunctions) { 510 FunctionDefLibrary library; 511 GraphDef graphdef; 512 513 { 514 *library.add_function() = test::function::XTimesTwo(); 515 516 GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); 517 Node* a = Input(b1.opts().WithName("A")); 518 Node* b = Input(b1.opts().WithName("B")); 519 Node* control = Input(b1.opts().WithName("Control")); 520 Node* c = 521 Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr( 522 "_encapsulate", "F1")); 523 Node* d = 524 Binary(b, c, b1.opts().WithName("D").WithControlInput(control).WithAttr( 525 "_encapsulate", "F2")); 526 Binary(a, d, b1.opts().WithName("E")); 527 TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); 528 } 529 530 TF_EXPECT_OK(Encapsulate(&graphdef, &library)); 531 532 FunctionDefLibrary library_expected; 533 GraphDef graphdef_expected; 534 535 *library_expected.add_function() = test::function::XTimesTwo(); 536 *library_expected.add_function() = FunctionDefHelper::Create( 537 "F1", {"a_0_arg:float"}, {"c_0_retval:float"}, {}, 538 { 539 {{"C"}, "UnaryTest", {"a_0_arg"}}, 540 }, 541 {{"c_0_retval", "C:o:0"}}); 542 *library_expected.add_function() = FunctionDefHelper::Create( 543 "F2", {"b_0_arg:float", "c_0_arg:float"}, {"d_0_retval:float"}, {}, 544 { 545 {{"D"}, "BinaryTest", {"b_0_arg", "c_0_arg"}}, 546 }, 547 {{"d_0_retval", "D:o:0"}}); 548 549 { 550 std::unique_ptr<FunctionLibraryDefinition> lib_def( 551 new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); 552 GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); 553 Node* a = Input(b2.opts().WithName("A")); 554 Node* b = Input(b2.opts().WithName("B")); 555 Node* control = Input(b2.opts().WithName("Control")); 556 557 NodeBuilder nb("F1", "F1", lib_def.get()); 558 nb.Input(a).ControlInput(control); 559 Node* call1 = b2.opts().FinalizeBuilder(&nb); 560 561 NodeBuilder nb2("F2", "F2", lib_def.get()); 562 nb2.Input(b).Input(call1).ControlInput(control); 563 Node* call2 = 
b2.opts().FinalizeBuilder(&nb2); 564 565 Binary(a, call2, b2.opts().WithName("E")); 566 TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); 567 } 568 569 // If there are no marked nodes, funcification should be a no-op. 570 TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); 571 TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); 572 } 573 574 // Returns a vector of node names in 'graph', sorted by name. 575 std::vector<string> GraphNodes(const Graph& graph) { 576 std::vector<string> nodes; 577 for (const auto& node : graph.nodes()) { 578 if (!node->IsSource() && !node->IsSink()) { 579 nodes.push_back(node->name()); 580 } 581 } 582 std::sort(nodes.begin(), nodes.end()); 583 return nodes; 584 } 585 586 // Returns a sorted vector of (src, dst) edges in 'graph'. 587 std::vector<std::pair<string, string>> GraphEdges(const Graph& graph) { 588 std::vector<std::pair<string, string>> edges; 589 for (const Edge* edge : graph.edges()) { 590 if (edge->src()->IsSource() || edge->dst()->IsSink()) continue; 591 edges.emplace_back( 592 strings::StrCat(edge->src()->name(), ":", edge->src_output()), 593 strings::StrCat(edge->dst()->name(), ":", edge->dst_input())); 594 } 595 std::sort(edges.begin(), edges.end()); 596 return edges; 597 } 598 599 TEST(EncapsulateSubgraphsTest, InputDeduplication) { 600 Scope root = Scope::NewRootScope().ExitOnError().WithDevice( 601 "/job:localhost/replica:0/task:0/cpu:0"); 602 auto x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT); 603 auto add1 = ops::Add(root.WithOpName("add1"), x, x); 604 add1.node()->AddAttr("_cluster", "cluster1"); 605 auto add2 = ops::Add(root.WithOpName("add2"), add1, add1); 606 add2.node()->AddAttr("_cluster", "cluster2"); 607 auto out = ops::Mul(root.WithOpName("mul"), add1, add2); 608 609 Graph graph_before_encapsulation(OpRegistry::Global()); 610 TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation)); 611 612 FunctionLibraryDefinition library(OpRegistry::Global(), {}); 613 std::unique_ptr<Graph> graph; 614 
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( 615 "_cluster", "_outside", graph_before_encapsulation, 616 /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false, 617 /*reuse_existing_functions=*/false, &graph, &library)); 618 619 std::vector<string> expected_nodes = {"cluster1", "cluster2", "mul", "x"}; 620 EXPECT_EQ(expected_nodes, GraphNodes(*graph)); 621 622 std::vector<std::pair<string, string>> expected_edges = { 623 {"cluster1:0", "cluster2:0"}, 624 {"cluster1:0", "mul:0"}, 625 {"cluster2:0", "mul:1"}, 626 {"x:0", "cluster1:0"}}; 627 EXPECT_EQ(expected_edges, GraphEdges(*graph)); 628 } 629 630 TEST(EncapsulateSubgraphsTest, ParallelChecking) { 631 Scope root = Scope::NewRootScope().ExitOnError().WithDevice( 632 "/job:localhost/replica:0/task:0/cpu:0"); 633 auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); 634 auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); 635 auto add1 = ops::Add(root.WithOpName("add1"), x1, x2); 636 add1.node()->AddAttr("_cluster", "cluster1"); 637 auto add2 = ops::Add(root.WithOpName("add2"), add1, x2); 638 add2.node()->AddAttr("_cluster", "cluster1"); 639 auto out = ops::Mul(root.WithOpName("mul"), x1, add2); 640 641 Graph graph_before_encapsulation(OpRegistry::Global()); 642 TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation)); 643 644 FunctionLibraryDefinition library(OpRegistry::Global(), {}); 645 std::unique_ptr<Graph> graph; 646 TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( 647 "_cluster", "_outside", graph_before_encapsulation, 648 /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/true, 649 /*reuse_existing_functions=*/false, &graph, &library)); 650 651 std::vector<string> expected_nodes = { 652 "add1", "add2", "cluster1", "cluster1_parallel_check/_0", 653 "mul", "x1", "x2"}; 654 EXPECT_EQ(expected_nodes, GraphNodes(*graph)); 655 656 std::vector<std::pair<string, string>> expected_edges = { 657 {"add1:0", "add2:0"}, 658 {"add2:0", "cluster1_parallel_check/_0:0"}, 659 {"cluster1:0", 
"cluster1_parallel_check/_0:1"}, 660 {"cluster1_parallel_check/_0:0", "mul:1"}, 661 {"x1:0", "add1:0"}, 662 {"x1:0", "cluster1:0"}, 663 {"x1:0", "mul:0"}, 664 {"x2:0", "add1:1"}, 665 {"x2:0", "add2:1"}, 666 {"x2:0", "cluster1:1"}, 667 }; 668 EXPECT_EQ(expected_edges, GraphEdges(*graph)); 669 } 670 671 const Node* FindNodeByName(const Graph& graph, const string& name) { 672 for (const Node* node : graph.nodes()) { 673 if (node->name() == name) return node; 674 } 675 return nullptr; 676 } 677 678 bool HasGuaranteeConstAttr(const Node& n) { 679 bool is_guaranteed_constant = false; 680 if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", 681 &is_guaranteed_constant) 682 .ok()) { 683 return false; 684 } 685 return is_guaranteed_constant; 686 } 687 688 TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { 689 Scope root = Scope::NewRootScope().ExitOnError().WithDevice( 690 "/job:localhost/replica:0/task:0/cpu:0"); 691 auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); 692 auto const_x2 = ops::Const(root.WithOpName("const_x2"), 10.0f); 693 auto const_guarantee_x1 = 694 ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1); 695 auto add1 = ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_x2); 696 add1.node()->AddAttr("_encapsulate", "encapsulate1"); 697 698 Graph graph_before(OpRegistry::Global()); 699 TF_ASSERT_OK(root.ToGraph(&graph_before)); 700 701 std::unique_ptr<Graph> graph_after; 702 FunctionLibraryDefinition library(OpRegistry::Global(), {}); 703 int guaranteed_consts = 0; 704 TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( 705 "_encapsulate", "_outside", graph_before, 706 /*rewrite_subgraph_fn=*/ 707 [&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr, 708 std::vector<int>* input_permutation, 709 std::vector<int>* output_permutation, 710 NodeDef* call_def) { 711 Graph* graph = graph_ptr->get(); 712 for (const Node* n : graph->nodes()) { 713 if (n->type_string() == "_Arg" && 714 StringPiece(n->name()).starts_with("const")) { 
715 ++guaranteed_consts; 716 EXPECT_TRUE(HasGuaranteeConstAttr(*n)); 717 } else { 718 EXPECT_FALSE(HasGuaranteeConstAttr(*n)); 719 } 720 } 721 return Status::OK(); 722 }, 723 /*parallel_checking=*/false, 724 /*reuse_existing_functions=*/false, &graph_after, &library)); 725 EXPECT_EQ(2, guaranteed_consts); 726 } 727 728 TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { 729 Scope root = Scope::NewRootScope().ExitOnError().WithDevice( 730 "/job:localhost/replica:0/task:0/cpu:0"); 731 auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); 732 auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); 733 auto const_guarantee_x1 = 734 ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1); 735 auto const_guarantee_x2 = 736 ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2); 737 auto const_guarantee_add1 = ops::Add(root.WithOpName("const_guarantee_add1"), 738 const_guarantee_x1, const_guarantee_x2); 739 auto add2 = ops::Add(root.WithOpName("add2"), const_guarantee_x1, x2); 740 auto mul1 = ops::Mul(root.WithOpName("mul1"), const_guarantee_add1, add2); 741 mul1.node()->AddAttr("_encapsulate", "encapsulate1"); 742 743 Graph graph_before(OpRegistry::Global()); 744 TF_ASSERT_OK(root.ToGraph(&graph_before)); 745 746 std::unique_ptr<Graph> graph_after; 747 FunctionLibraryDefinition library(OpRegistry::Global(), {}); 748 int guaranteed_consts = 0; 749 TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( 750 "_encapsulate", "_outside", graph_before, 751 /*rewrite_subgraph_fn=*/ 752 [&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr, 753 std::vector<int>* input_permutation, 754 std::vector<int>* output_permutation, 755 NodeDef* call_def) { 756 Graph* graph = graph_ptr->get(); 757 for (const Node* n : graph->nodes()) { 758 if (n->type_string() == "_Arg" && 759 StringPiece(n->name()).starts_with("const")) { 760 ++guaranteed_consts; 761 EXPECT_TRUE(HasGuaranteeConstAttr(*n)); 762 } else { 763 EXPECT_FALSE(HasGuaranteeConstAttr(*n)); 764 } 765 } 766 
return Status::OK(); 767 }, 768 /*parallel_checking=*/false, 769 /*reuse_existing_functions=*/false, &graph_after, &library)); 770 // Only 1 runtime const, which is const_guarantee_add1. Add2 has one const 771 // and another non-const, so overall non-const. 772 EXPECT_EQ(1, guaranteed_consts); 773 } 774 775 // Test with one function to transform and one outside_compilation cluster. 776 TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { 777 FunctionDefLibrary library; 778 GraphDef graphdef; 779 780 { 781 *library.add_function() = test::function::XTimesTwo(); 782 783 GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); 784 Node* a = Input(b1.opts().WithName("A")); 785 Node* b = Input(b1.opts().WithName("B")); 786 // Give nodes 'c' and 'd' names that collide after lowercasing. 787 Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); 788 Node* d = Binary(b, c, 789 b1.opts().WithName("c").WithControlInput(c).WithAttr( 790 "_encapsulate", "F1")); 791 Node* e = Binary(c, d, 792 b1.opts() 793 .WithName("E") 794 .WithControlInputs({b, d}) 795 .WithAttr("_encapsulate", "F1") 796 .WithAttr("_outside", "O1")); 797 Node* f = Binary(c, e, 798 b1.opts().WithName("F").WithControlInput(e).WithAttr( 799 "_encapsulate", "F1")); 800 Binary(a, f, b1.opts().WithName("G").WithControlInput(e)); 801 TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); 802 } 803 804 TF_EXPECT_OK(Encapsulate(&graphdef, &library)); 805 806 FunctionDefLibrary library_expected; 807 GraphDef graphdef_expected; 808 809 string shape_string_expected; 810 { 811 GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); 812 Node* recv = 813 RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, 814 shape.opts().WithName("outside_compilation_F1_O1_recv")); 815 Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), 816 shape.opts().WithName("E")); 817 SendFromHost("host_compute_channel_F1_O1", {e}, 818 shape.opts().WithName("outside_compilation_F1_O1_send")); 819 GraphDef shape_graph; 
820 TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); 821 EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); 822 } 823 824 *library_expected.add_function() = test::function::XTimesTwo(); 825 *library_expected.add_function() = FunctionDefHelper::Create( 826 "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, 827 { 828 {{"C"}, "UnaryTest", {"a_0_arg"}}, 829 {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}}, 830 {{"F"}, 831 "BinaryTest", 832 {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, 833 {}, 834 {"outside_compilation_O1_host_compute"}}, 835 {{"outside_compilation_O1_host_compute"}, 836 "_XlaHostCompute", 837 {"C:o:0", "c:o:0"}, 838 {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}, 839 {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})}, 840 {"key", "host_compute_channel_F1_O1"}, 841 {"shape_inference_graph", shape_string_expected}, 842 {"shapes", gtl::ArraySlice<DataType>({})}}, 843 {"c"}}, 844 }, 845 {{"f_0_retval", "F:o:0"}}); 846 847 { 848 std::unique_ptr<FunctionLibraryDefinition> lib_def( 849 new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); 850 GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); 851 Node* a = Input(b2.opts().WithName("A")); 852 Node* b = Input(b2.opts().WithName("B")); 853 854 NodeBuilder node_builder("F1", "F1", lib_def.get()); 855 node_builder.Input(a).Input(b); 856 Node* call = b2.opts().FinalizeBuilder(&node_builder); 857 858 Node* recv = 859 RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, 860 b2.opts().WithName("outside_compilation_F1_O1_recv")); 861 Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), 862 b2.opts().WithName("E").WithControlInputs({recv, b})); 863 Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, 864 b2.opts() 865 .WithName("outside_compilation_F1_O1_send") 866 .WithControlInput(e)); 867 868 Node* s = NoOp( 869 b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); 870 871 
Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e}));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}

// Test with one function to transform and two outside_compilation clusters.
TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Input graph: C, D, F and I are compiled into F1. E forms
  // outside_compilation cluster O1; G and H form cluster O2. O2 depends on O1
  // through E, and the compiled node F both consumes E and feeds G.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Binary(c, d,
                     b1.opts()
                         .WithName("E")
                         .WithControlInputs({b, d})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Node* g = Binary(e, f,
                     b1.opts()
                         .WithName("G")
                         .WithControlInputs({e, f})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    Node* h = Binary(d, e,
                     b1.opts()
                         .WithName("H")
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O2"));
    Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F1"));
    Binary(g, i, b1.opts().WithName("J"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  TF_EXPECT_OK(Encapsulate(&graphdef, &library));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected shape-inference graph for cluster O1: E computed from the two
  // values received on channel F1_O1, with the result sent back.
  string shape_string_expected_1;
  {
    GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
    Node* recv =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                   shape1.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
                     shape1.opts().WithName("E"));
    SendFromHost("host_compute_channel_F1_O1", {e},
                 shape1.opts().WithName("outside_compilation_F1_O1_send"));
    GraphDef shape1_graph;
    TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph));
    EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1));
  }

  // Expected shape-inference graph for cluster O2. It also contains cluster
  // O1's recv and E, because H (in O2) consumes E.
  string shape_string_expected_2;
  {
    GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
    Node* recv1 =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                   shape2.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                     shape2.opts().WithName("E"));
    Node* recv2 =
        RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT},
                   shape2.opts().WithName("outside_compilation_F1_O2_recv"));
    Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H"));
    SendFromHost("host_compute_channel_F1_O2", {h},
                 shape2.opts().WithName("outside_compilation_F1_O2_send"));
    GraphDef shape2_graph;
    TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph));
    EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2));
  }

  // Expected F1: each outside_compilation cluster is replaced by an
  // _XlaHostCompute node whose "key" names the channel used by the matching
  // RecvAtHost/SendFromHost nodes in the outer graph.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
          {{"I"},
           "UnaryTest",
           {"outside_compilation_O2_host_compute:outputs:0"}},
          {{"F"},
           "BinaryTest",
           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O2_host_compute"},
           "_XlaHostCompute",
           {"D:o:0", "F:o:0"},
           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"key", "host_compute_channel_F1_O2"},
            {"shape_inference_graph", shape_string_expected_2},
            {"shapes", gtl::ArraySlice<DataType>({})}},
           {"F"}},
          {{"outside_compilation_O1_host_compute"},
           "_XlaHostCompute",
           {"C:o:0", "D:o:0"},
           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", shape_string_expected_1},
            {"shapes", gtl::ArraySlice<DataType>({})}},
           {"D"}},
      },
      {{"i_0_retval", "I:o:0"}});

  // Expected outer graph: the outside_compilation nodes stay outside F1 and
  // talk to it over the per-cluster channels.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    NodeBuilder node_builder("F1", "F1", lib_def.get());
    node_builder.Input(a).Input(b);
    Node* call = b2.opts().FinalizeBuilder(&node_builder);

    Node* recv1 =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                     b2.opts().WithName("E").WithControlInputs({recv1, b}));
    Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e},
                               b2.opts()
                                   .WithName("outside_compilation_F1_O1_send")
                                   .WithControlInput(e));

    // The O2 host compute's inputs are (D, F), so G reads F's value as
    // output 1 of the O2 recv and H reads D as output 0.
    Node* recv2 =
        RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT},
                   b2.opts().WithName("outside_compilation_F1_O2_recv"));
    Node* g = Binary(e, ops::NodeOut(recv2, 1),
                     b2.opts().WithName("G").WithControlInputs({recv2, e}));
    Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H"));
    Node* send2 =
        SendFromHost("host_compute_channel_F1_O2", {h},
                     b2.opts().WithName("outside_compilation_F1_O2_send"));

    // F1_sequencer collects control edges from every recv/send of F1 and
    // gates the downstream consumer J.
    Node* s = NoOp(b2.opts()
                       .WithName("F1_sequencer")
                       .WithControlInputs({recv1, send1, recv2, send2}));

    Binary(g, call, b2.opts().WithName("J").WithControlInput(s));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}

// Test with two functions to transform, each with one outside_compilation
// cluster.
TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Input graph: C, D, F are compiled into F1 with E in its cluster O1; G and
  // I are compiled into F2 with H in its cluster O1. Both clusters are named
  // O1, so the expected channel keys differ only by function name.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = InputShaped(b1.opts().WithName("A"));
    Node* b = InputShaped(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Binary(c, d,
                     b1.opts()
                         .WithName("E")
                         .WithControlInputs({b, d})
                         .WithAttr("_encapsulate", "F1")
                         .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Node* g = Binary(e, f,
                     b1.opts().WithName("G").WithControlInputs({e, f}).WithAttr(
                         "_encapsulate", "F2"));
    Node* h = Binary(d, g,
                     b1.opts()
                         .WithName("H")
                         .WithAttr("_encapsulate", "F2")
                         .WithAttr("_outside", "O1"));
    Node* i =
        Binary(f, h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2"));
    Binary(g, i, b1.opts().WithName("J"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  TF_EXPECT_OK(Encapsulate(&graphdef, &library));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected shape-inference graph for F1's cluster O1.
  string shape_string_expected;
  {
    GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
    Node* recv =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                   shape.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
                     shape.opts().WithName("E"));
    SendFromHost("host_compute_channel_F1_O1", {e},
                 shape.opts().WithName("outside_compilation_F1_O1_send"));
    GraphDef shape_graph;
    TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
    EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
  }

  // A rank-1 shape of size 2, used as F2's host-compute output shape below.
  TensorShapeProto shape_proto_expected;
  shape_proto_expected.add_dim()->set_size(2);

  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"},
      {"f_0_retval:float", "d_0_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "BinaryTest",
           {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "_XlaHostCompute",
           {"C:o:0", "D:o:0"},
           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", shape_string_expected},
            {"shapes", gtl::ArraySlice<DataType>({})}},
           {"D"}},
      },
      {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}});

  // Unlike F1, F2's host compute carries an explicit output shape and an
  // empty shape_inference_graph. NOTE(review): presumably because the shape
  // of H is statically known from the InputShaped sources — confirm.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F2", {"e_0_arg:float", "f_0_arg:float"},
      {"g_0_retval:float", "i_0_retval:float"}, {},
      {
          {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
          {{"I"},
           "BinaryTest",
           {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}},
          {{"outside_compilation_O1_host_compute"},
           "_XlaHostCompute",
           {"G:o:0"},
           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"key", "host_compute_channel_F2_O1"},
            {"shape_inference_graph", ""},
            {"shapes",
             gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
      },
      {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});

  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = InputShaped(b2.opts().WithName("A"));
    Node* b = InputShaped(b2.opts().WithName("B"));

    Node* recv1 =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
                     b2.opts().WithName("E").WithControlInputs({recv1, b}));
    Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e},
                               b2.opts()
                                   .WithName("outside_compilation_F1_O1_send")
                                   .WithControlInput(e));
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
    Node* s1 = NoOp(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));

    // F1's outputs are (f, d) per its retval list, so call1 output 1 is D;
    // H reads it directly, and reads G from F2 over the F2_O1 channel.
    Node* recv2 =
        RecvAtHost("host_compute_channel_F2_O1", {DT_FLOAT},
                   b2.opts().WithName("outside_compilation_F2_O1_recv"));
    Node* h = Binary(ops::NodeOut(call1, 1), recv2,
                     b2.opts().WithName("H").WithControlInput(s1));
    Node* send2 =
        SendFromHost("host_compute_channel_F2_O1", {h},
                     b2.opts().WithName("outside_compilation_F2_O1_send"));

    // G's original control inputs were E and F; F now lives in call1.
    NodeBuilder node_builder2("F2", "F2", lib_def.get());
    node_builder2.Input(e).Input(call1);
    Node* call2 = b2.opts()
                      .WithControlInputs({s1, e, call1})
                      .FinalizeBuilder(&node_builder2);
    Node* s2 = NoOp(
        b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}));
    Binary(call2, ops::NodeOut(call2, 1),
           b2.opts().WithName("J").WithControlInput(s2));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}

// Test with one outside_compilation cluster that has no inputs from the
// compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // E (cluster O1) reads only A, which is outside F1, but its output feeds
  // the compiled node F.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = InputShaped(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(a, b1.opts()
                           .WithName("E")
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f =
        Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
    Unary(f, b1.opts().WithName("G"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  TF_EXPECT_OK(Encapsulate(&graphdef, &library));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  TensorShapeProto shape_proto_expected;
  shape_proto_expected.add_dim()->set_size(2);

  // With no inputs, the host compute has empty Tinputs and no data inputs.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "BinaryTest",
           {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
          {{"outside_compilation_O1_host_compute"},
           "_XlaHostCompute",
           {},
           {{"Tinputs", gtl::ArraySlice<DataType>({})},
            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", ""},
            {"shapes",
             gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
      },
      {{"f_0_retval", "F:o:0"}});

  // The outer graph has a SendFromHost for E but no RecvAtHost.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = InputShaped(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* e = Unary(a, b2.opts().WithName("E"));
    Node* send1 =
        SendFromHost("host_compute_channel_F1_O1", {e},
                     b2.opts().WithName("outside_compilation_F1_O1_send"));
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
    Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(send1));

    Unary(call1, b2.opts().WithName("G").WithControlInput(s1));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}

// Test with one outside_compilation cluster that has no data inputs but has a
// control input from the compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // E (cluster O1) reads only A, which is outside F1. Its only dependency on
  // the compiled subgraph is a control edge from D.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = InputShaped(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(a, b1.opts()
                           .WithName("E")
                           .WithControlInput(d)
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f =
        Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
    Unary(f, b1.opts().WithName("G"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  TF_EXPECT_OK(Encapsulate(&graphdef, &library));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  TensorShapeProto shape_proto_expected;
  shape_proto_expected.add_dim()->set_size(2);

  // The host compute takes no data inputs, but keeps D as a control input.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "BinaryTest",
           {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
          {{"outside_compilation_O1_host_compute"},
           "_XlaHostCompute",
           {},
           {{"Tinputs", gtl::ArraySlice<DataType>({})},
            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", ""},
            {"shapes",
             gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}},
           {"D"}},
      },
      {{"f_0_retval", "F:o:0"}});

  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = InputShaped(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    // A RecvAtHost with no output types still exists to carry the control
    // dependency into E.
    Node* recv1 =
        RecvAtHost("host_compute_channel_F1_O1", {},
                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1));
    Node* send1 =
        SendFromHost("host_compute_channel_F1_O1", {e},
                     b2.opts().WithName("outside_compilation_F1_O1_send"));
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
    Node* s1 = NoOp(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));

    Unary(call1, b2.opts().WithName("G").WithControlInput(s1));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}

// Test with one outside_compilation cluster that has no outputs to the
// compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // E (cluster O1) reads D from F1, but its output is consumed only by G,
  // which is outside F1.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(d, b1.opts()
                           .WithName("E")
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
    Binary(e, f, b1.opts().WithName("G"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  TF_EXPECT_OK(Encapsulate(&graphdef, &library));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // The host compute sends D to the host and has empty Toutputs; nothing in
  // F1 consumes its result.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"}, "UnaryTest", {"D:o:0"}},
          {{"outside_compilation_O1_host_compute"},
           "_XlaHostCompute",
           {"D:o:0"},
           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"Toutputs", gtl::ArraySlice<DataType>({})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", ""},
            {"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
      },
      {{"f_0_retval", "F:o:0"}});

  // The outer graph has a RecvAtHost delivering D to E, but no SendFromHost.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* recv1 =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Unary(recv1, b2.opts().WithName("E"));
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
    Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(recv1));

    Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}

// Test with one outside_compilation cluster that has no data outputs but has a
// control output to the compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // E (cluster O1) reads D from F1. The compiled node F depends on E only
  // through a control edge.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(d, b1.opts()
                           .WithName("E")
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f = Unary(d, b1.opts().WithName("F").WithControlInput(e).WithAttr(
                           "_encapsulate", "F1"));
    Binary(e, f, b1.opts().WithName("G"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  TF_EXPECT_OK(Encapsulate(&graphdef, &library));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // The host compute has empty Toutputs, and F keeps a control dependency on
  // it in place of the original control edge from E.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"},
           "UnaryTest",
           {"D:o:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "_XlaHostCompute",
           {"D:o:0"},
           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"Toutputs", gtl::ArraySlice<DataType>({})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", ""},
            {"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
      },
      {{"f_0_retval", "F:o:0"}});

  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* recv1 =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = Unary(recv1, b2.opts().WithName("E"));
    // The SendFromHost carries no data; it exists only to forward E's
    // control edge back to F1.
    Node* send1 = SendFromHost("host_compute_channel_F1_O1", {},
                               b2.opts()
                                   .WithName("outside_compilation_F1_O1_send")
                                   .WithControlInput(e));
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
    Node* s1 = NoOp(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));

    Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}

// Test with one outside_compilation cluster that has neither inputs from nor
// outputs to the compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // E (cluster O1) reads only A and is consumed only by G, both outside F1.
  {
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = Input(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
    Node* d =
        Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
    Node* e = Unary(a, b1.opts()
                           .WithName("E")
                           .WithAttr("_encapsulate", "F1")
                           .WithAttr("_outside", "O1"));
    Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
    Binary(e, f, b1.opts().WithName("G"));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  TF_EXPECT_OK(Encapsulate(&graphdef, &library));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // No _XlaHostCompute is expected in F1 for a disconnected cluster.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
      {
          {{"C"}, "UnaryTest", {"a_0_arg"}},
          {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
          {{"F"}, "UnaryTest", {"D:o:0"}},
      },
      {{"f_0_retval", "F:o:0"}});

  // E simply remains in the outer graph: no RecvAtHost, SendFromHost or
  // sequencer node is expected.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = Input(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));

    Node* e = Unary(a, b2.opts().WithName("E"));
    NodeBuilder node_builder1("F1", "F1", lib_def.get());
    node_builder1.Input(a).Input(b);
    Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);

    Binary(e, call1, b2.opts().WithName("G"));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}

// Test for shape inference of outside compilation.
TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  {
    *library.add_function() = test::function::XTimesTwo();

    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = InputShaped(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    // Give nodes c and d the names "C" and "c", which collide after
    // lowercasing. "C" is outside F1 and becomes the argument "c_0_arg", while
    // "c" stays a node inside F1; the expected function checks both survive.
    Node* c = Unary(a, b1.opts().WithName("C"));
    Node* d = Unary(b, b1.opts().WithName("c").WithControlInput(c).WithAttr(
                           "_encapsulate", "F1"));
    // NOTE(review): BinaryUnknownShape presumably has no static output shape,
    // which forces the pass to emit a shape_inference_graph — confirm.
    Node* e = BinaryUnknownShape(c, d,
                                 b1.opts()
                                     .WithName("E")
                                     .WithControlInputs({b, d})
                                     .WithAttr("_encapsulate", "F1")
                                     .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  TF_EXPECT_OK(Encapsulate(&graphdef, &library));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected shape-inference graph: C's value (computed outside F1) is stood
  // in for by a KnownShape node of shape {2}; the value from inside F1
  // arrives over the channel.
  string shape_string_expected;
  {
    GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
    Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_0"));
    Node* recv =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
                   shape.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E"));
    SendFromHost("host_compute_channel_F1_O1", {e},
                 shape.opts().WithName("outside_compilation_F1_O1_send"));
    GraphDef shape_graph;
    TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
    EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
  }

  *library_expected.add_function() = test::function::XTimesTwo();
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval:float"}, {},
      {
          {{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}},
          {{"F"},
           "BinaryTest",
           {"c_0_arg", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "_XlaHostCompute",
           {"c:o:0"},
           {{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
            {"key", "host_compute_channel_F1_O1"},
            {"shape_inference_graph", shape_string_expected},
            {"shapes", gtl::ArraySlice<DataType>({})}},
           {"c"}},
      },
      {{"f_0_retval", "F:o:0"}});

  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = InputShaped(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));
    Node* c = Unary(a, b2.opts().WithName("C"));

    // "c"'s original control input from "C" is carried by the F1 call node.
    NodeBuilder node_builder("F1", "F1", lib_def.get());
    node_builder.Input(b).Input(c);
    Node* call =
        b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder);

    Node* recv =
        RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
                   b2.opts().WithName("outside_compilation_F1_O1_recv"));
    Node* e = BinaryUnknownShape(
        c, ops::NodeOut(recv, 0),
        b2.opts().WithName("E").WithControlInputs({recv, b}));
    Node* send = SendFromHost("host_compute_channel_F1_O1", {e},
                              b2.opts()
                                  .WithName("outside_compilation_F1_O1_send")
                                  .WithControlInput(e));

    Node* s = NoOp(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}));

    Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e}));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}

}  // namespace
}  // namespace tensorflow