1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" 17 #include "tensorflow/cc/framework/scope.h" 18 #include "tensorflow/core/common_runtime/shape_refiner.h" 19 #include "tensorflow/core/framework/node_def.pb.h" 20 #include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h" 21 #include "tensorflow/core/lib/core/status.h" 22 #include "tensorflow/core/lib/core/status_test_util.h" 23 #include "tensorflow/core/platform/test.h" 24 25 namespace tensorflow { 26 namespace { 27 28 using ClusterInfo = RemoteFusedGraphExecuteUtils::ClusterInfo; 29 30 constexpr const char* const NAME_A = "A"; 31 constexpr const char* const NAME_B = "B"; 32 constexpr const char* const NAME_A_PLUS_B = "A_PLUS_B"; 33 constexpr float NODE_A_VAL = 2.0f; 34 constexpr float NODE_B_VAL = 3.0f; 35 constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f; 36 constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME0 = 37 "fuse_test_remote_fused_graph_executor0"; 38 constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME1 = 39 "fuse_test_remote_fused_graph_executor1"; 40 41 static NodeDef* GetNodeDef(const string& name, GraphDef* def) { 42 CHECK_NE(def, nullptr); 43 for (NodeDef& node_def : *def->mutable_node()) { 44 if (node_def.name() == name) { 45 return &node_def; 46 } 47 } 48 return nullptr; 49 } 50 51 Status BuildRemoteFusedGraphExecutor0( 52 std::unique_ptr<IRemoteFusedGraphExecutor>* executor) { 53 executor->reset( 54 new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0)); 55 return Status::OK(); 56 } 57 58 Status BuildRemoteFusedGraphExecutor1( 59 std::unique_ptr<IRemoteFusedGraphExecutor>* executor) { 60 executor->reset(new TestRemoteFusedGraphExecutor( 61 {"Const", "Mul"}, REMOTE_FUSED_EXECUTOR_NAME1)); 62 return Status::OK(); 63 } 64 65 class FuseRemoteGraphMultipleAddOpsTest : public ::testing::Test { 66 protected: 67 void SetUp() final { 68 TF_ASSERT_OK( 69 RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def_)); 70 RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar 71 hexagon_remote_fused_graph_executor_build( 72 "remote_graph_executor_name", 73 [](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status { 74 return Status::OK(); 75 }); 76 RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar 77 test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0, 78 BuildRemoteFusedGraphExecutor0); 79 80 RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar 81 test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1, 82 BuildRemoteFusedGraphExecutor1); 83 } 84 85 void TearDown() final {} 86 87 Status FuseByInOut() { 88 // Feed output shapes and types 89 RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map; 90 GraphDef graph_def_with_shapetype = graph_def_; 91 TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes( 92 input_tensors_, /*dry_run_inference*/ true, &graph_def_with_shapetype)); 93 94 return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder( 95 graph_def_with_shapetype, inputs_, outputs_, 96 "remote_fused_graph_node_names", subgraph_input_names_, 97 subgraph_output_names_, "remote_graph_executor_name", 98 /*require_shape_type=*/true, &result_graph_def_); 99 } 100 101 Status FuseByNodes() { 102 return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames( 103 graph_def_, inputs_, outputs_, "remote_fused_graph_node_names", 104 subgraph_node_names_, "remote_graph_executor_name", 105 /*require_shape_type=*/false, &result_graph_def_); 106 } 107 108 Status FuseByOpTypes() { 109 return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes( 110 graph_def_, inputs_, outputs_, "remote_fused_graph_node_names", 111 subgraph_op_types_, "remote_graph_executor_name", 112 /*require_shape_type=*/false, &result_graph_def_); 113 } 114 115 Status FuseByExecutor0() { 116 return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor( 117 graph_def_, inputs_, outputs_, REMOTE_FUSED_EXECUTOR_NAME0, 118 &result_graph_def_); 119 } 120 121 Status FuseByExecutor1() { 122 return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor( 123 graph_def_, inputs_, outputs_, REMOTE_FUSED_EXECUTOR_NAME1, 124 &result_graph_def_); 125 } 126 127 Status BuildAndAddTensorShape() { 128 return RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes( 129 input_tensors_, /*dry_run_inference=*/true, &graph_def_); 130 } 131 132 Status PlaceRemoteGraphArguments() { 133 return RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments( 134 inputs_, outputs_, subgraph_node_names_, subgraph_input_names_, 135 subgraph_output_names_, subgraph_op_types_, 136 "remote_fused_graph_node_names", "remote_graph_executor_name", 137 &graph_def_); 138 } 139 140 Status FuseByPlacedArguments() { 141 const Status status = 142 RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments( 143 graph_def_, input_tensors_, &graph_def_); 144 result_graph_def_ = graph_def_; 145 return status; 146 } 147 148 bool IsFuseReady() { 149 return RemoteFusedGraphExecuteUtils::IsFuseReady(graph_def_, 150 input_tensors_); 151 } 152 153 void ReplaceOpType(const std::unordered_set<string>& op_name, 154 const string& new_op_type) { 155 for (NodeDef& node_def : *graph_def_.mutable_node()) { 156 if (op_name.count(node_def.name()) > 0) { 157 node_def.set_op(new_op_type); 158 } 159 } 160 } 161 162 public: 163 const std::vector<std::pair<string, Tensor>> input_tensors_{ 164 {"A", {DT_FLOAT, {1, 1, 1, 1}}}}; 165 const std::vector<string> inputs_{"A"}; 166 const std::vector<string> outputs_{"K"}; 167 GraphDef graph_def_; 168 GraphDef result_graph_def_; 169 std::vector<string> subgraph_input_names_; 170 std::vector<string> subgraph_output_names_; 171 std::unordered_set<string> subgraph_node_names_; 172 std::unordered_set<string> subgraph_op_types_; 173 }; 174 175 void SetSubgraphArguments(const std::vector<string>& input_names, 176 const std::vector<string>& output_names, 177 FuseRemoteGraphMultipleAddOpsTest* fixture) { 178 for (const string& input_name : input_names) { 179 fixture->subgraph_input_names_.emplace_back(input_name); 180 } 181 182 fixture->subgraph_output_names_ = output_names; 183 } 184 185 template <typename T> 186 static string IterToString(const T& set) { 187 string out; 188 for (const string& val : set) { 189 if (!out.empty()) { 190 out += ", "; 191 } 192 out += val; 193 } 194 return out; 195 } 196 197 static string SummarizeGraphDef(const GraphDef& graph_def) { 198 string out; 199 for (const NodeDef& node : graph_def.node()) { 200 out += strings::StrCat("node: ", node.name(), "\n input: "); 201 for (const string& input : node.input()) { 202 out += strings::StrCat(input, ", "); 203 } 204 out += "\n"; 205 } 206 return out; 207 } 208 209 static string DumpInOutNames(const std::vector<ClusterInfo>& ci_vec) { 210 for (int i = 0; i < ci_vec.size(); ++i) { 211 LOG(INFO) << "Cluster(" << i << ")"; 212 LOG(INFO) << "input: " << IterToString(std::get<1>(ci_vec.at(i))); 213 LOG(INFO) << "output: " << IterToString(std::get<2>(ci_vec.at(i))); 214 } 215 return ""; 216 } 217 218 static void ClearCluster(ClusterInfo* cluster) { 219 std::get<0>(*cluster).clear(); 220 std::get<1>(*cluster).clear(); 221 std::get<2>(*cluster).clear(); 222 } 223 224 TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphA) { 225 GraphDef def; 226 TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( 227 NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); 228 std::pair<string, Tensor> input_node_info; 229 input_node_info.first = NAME_A; 230 input_node_info.second = Tensor(DT_FLOAT, {}); 231 input_node_info.second.scalar<float>()() = 1.0f; 232 const std::vector<std::pair<string, Tensor>> inputs{input_node_info}; 233 std::vector<string> outputs = {NAME_B, NAME_A_PLUS_B}; 234 std::vector<tensorflow::Tensor> output_tensors; 235 Status status = RemoteFusedGraphExecuteUtils::DryRunInference( 236 def, inputs, outputs, false /* initialize_by_zero */, &output_tensors); 237 ASSERT_TRUE(status.ok()) << status; 238 EXPECT_EQ(outputs.size(), output_tensors.size()); 239 EXPECT_NEAR(NODE_B_VAL, output_tensors.at(0).scalar<float>()(), 240 VALUE_TOLERANCE_FLOAT); 241 EXPECT_NEAR(1.0f + NODE_B_VAL, output_tensors.at(1).scalar<float>()(), 242 VALUE_TOLERANCE_FLOAT); 243 } 244 245 TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphAUninitialized) { 246 GraphDef def; 247 TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( 248 NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); 249 std::pair<string, Tensor> input_node_info; 250 input_node_info.first = NAME_A; 251 input_node_info.second = Tensor(DT_FLOAT, {}); 252 const std::vector<std::pair<string, Tensor>> inputs{input_node_info}; 253 std::vector<string> outputs = {NAME_B, NAME_A_PLUS_B}; 254 std::vector<tensorflow::Tensor> output_tensors; 255 Status status = RemoteFusedGraphExecuteUtils::DryRunInference( 256 def, inputs, outputs, true /* initialize_by_zero */, &output_tensors); 257 ASSERT_TRUE(status.ok()) << status; 258 EXPECT_EQ(outputs.size(), output_tensors.size()); 259 EXPECT_NEAR(NODE_B_VAL, output_tensors.at(0).scalar<float>()(), 260 VALUE_TOLERANCE_FLOAT); 261 EXPECT_NEAR(NODE_B_VAL, output_tensors.at(1).scalar<float>()(), 262 VALUE_TOLERANCE_FLOAT); 263 } 264 265 TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphAB) { 266 GraphDef def; 267 TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( 268 NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); 269 std::pair<string, Tensor> input_node_info_a; 270 input_node_info_a.first = NAME_A; 271 input_node_info_a.second = Tensor(DT_FLOAT, {}); 272 input_node_info_a.second.scalar<float>()() = NODE_A_VAL; 273 std::pair<string, Tensor> input_node_info_b; 274 input_node_info_b.first = NAME_B; 275 input_node_info_b.second = Tensor(DT_FLOAT, {}); 276 input_node_info_b.second.scalar<float>()() = NODE_B_VAL; 277 const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a, 278 input_node_info_b}; 279 std::vector<string> outputs = {NAME_A_PLUS_B}; 280 std::vector<tensorflow::Tensor> output_tensors; 281 Status status = RemoteFusedGraphExecuteUtils::DryRunInference( 282 def, inputs, outputs, false /* initialize_by_zero */, &output_tensors); 283 ASSERT_TRUE(status.ok()) << status; 284 EXPECT_EQ(outputs.size(), output_tensors.size()); 285 EXPECT_NEAR(NODE_A_VAL + NODE_B_VAL, output_tensors.at(0).scalar<float>()(), 286 VALUE_TOLERANCE_FLOAT); 287 } 288 289 TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphForAllNodes) { 290 // Set Node "A" as an input with value (= 1.0f) 291 std::pair<string, Tensor> input_node_info_a; 292 input_node_info_a.first = NAME_A; 293 input_node_info_a.second = Tensor(DT_FLOAT, {}); 294 input_node_info_a.second.scalar<float>()() = 1.0f; 295 296 // Setup dryrun arguments 297 const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a}; 298 RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map; 299 300 GraphDef def; 301 TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( 302 NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); 303 304 // dryrun 305 const Status status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode( 306 def, inputs, false /* initialize_by_zero */, &tensor_shape_map); 307 308 ASSERT_TRUE(status.ok()) << status; 309 310 // Assert output node count 311 ASSERT_EQ(3, tensor_shape_map.size()); 312 ASSERT_EQ(1, tensor_shape_map.count(NAME_A)); 313 ASSERT_EQ(1, tensor_shape_map.count(NAME_B)); 314 ASSERT_EQ(1, tensor_shape_map.count(NAME_A_PLUS_B)); 315 316 const RemoteFusedGraphExecuteUtils::TensorShapeType* tst = 317 RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map, 318 NAME_B); 319 EXPECT_NE(tst, nullptr); 320 EXPECT_EQ(DT_FLOAT, tst->first); 321 EXPECT_EQ(0, tst->second.dims()); 322 323 tst = RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map, 324 NAME_A_PLUS_B); 325 EXPECT_NE(tst, nullptr); 326 EXPECT_EQ(DT_FLOAT, tst->first); 327 EXPECT_EQ(0, tst->second.dims()); 328 } 329 330 TEST(RemoteFusedGraphExecuteUtils, PropagateAndBuildTensorShapeMap) { 331 std::pair<string, Tensor> input_node_info_a; 332 input_node_info_a.first = NAME_A; 333 input_node_info_a.second = Tensor(DT_FLOAT, {}); 334 input_node_info_a.second.scalar<float>()() = NODE_A_VAL; 335 std::pair<string, Tensor> input_node_info_b; 336 input_node_info_b.first = NAME_B; 337 input_node_info_b.second = Tensor(DT_FLOAT, {}); 338 input_node_info_b.second.scalar<float>()() = NODE_B_VAL; 339 const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a, 340 input_node_info_b}; 341 342 RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map; 343 GraphDef def; 344 TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( 345 NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); 346 ImportGraphDefOptions opts; 347 Graph graph(OpRegistry::Global()); 348 ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); 349 Status status = ImportGraphDef(opts, def, &graph, &shape_refiner); 350 ASSERT_TRUE(RemoteFusedGraphExecuteUtils::PropagateShapeInference( 351 def, inputs, &graph, &shape_refiner) 352 .ok()); 353 ASSERT_TRUE(RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph( 354 graph, shape_refiner, &tensor_shape_map) 355 .ok()); 356 357 ASSERT_EQ(3, tensor_shape_map.size()); 358 ASSERT_EQ(1, tensor_shape_map.count(NAME_A)); 359 ASSERT_EQ(1, tensor_shape_map.count(NAME_B)); 360 ASSERT_EQ(1, tensor_shape_map.count(NAME_A_PLUS_B)); 361 362 const RemoteFusedGraphExecuteUtils::TensorShapeType* tst = 363 RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map, 364 NAME_B); 365 EXPECT_NE(tst, nullptr); 366 EXPECT_EQ(DT_FLOAT, tst->first); 367 EXPECT_EQ(0, tst->second.dims()); 368 369 tst = RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map, 370 NAME_A_PLUS_B); 371 EXPECT_NE(tst, nullptr); 372 EXPECT_EQ(DT_FLOAT, tst->first); 373 EXPECT_EQ(0, tst->second.dims()); 374 375 { 376 NodeDef* node_def = GetNodeDef(NAME_B, &def); 377 TF_ASSERT_OK( 378 RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( 379 tensor_shape_map, node_def)); 380 std::vector<DataType> data_types; 381 TF_ASSERT_OK(GetNodeAttr( 382 *node_def, RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES, 383 &data_types)); 384 ASSERT_EQ(1, data_types.size()); 385 EXPECT_EQ(DT_FLOAT, data_types.at(0)); 386 387 std::vector<TensorShape> shapes; 388 TF_ASSERT_OK(GetNodeAttr( 389 *node_def, RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES, &shapes)); 390 ASSERT_EQ(1, shapes.size()); 391 EXPECT_EQ(0, shapes.at(0).dims()); 392 } 393 394 { 395 NodeDef* node_def = GetNodeDef(NAME_A_PLUS_B, &def); 396 TF_ASSERT_OK( 397 RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( 398 tensor_shape_map, node_def)); 399 std::vector<DataType> data_types; 400 TF_ASSERT_OK(GetNodeAttr( 401 *node_def, RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES, 402 &data_types)); 403 ASSERT_EQ(1, data_types.size()); 404 EXPECT_EQ(DT_FLOAT, data_types.at(0)); 405 406 std::vector<TensorShape> shapes; 407 TF_ASSERT_OK(GetNodeAttr( 408 *node_def, RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES, &shapes)); 409 ASSERT_EQ(1, shapes.size()); 410 EXPECT_EQ(0, shapes.at(0).dims()); 411 } 412 } 413 414 TEST(RemoteFusedGraphExecuteUtils, 415 BuildRemoteFusedGraphExecuteInfoWithShapeInference) { 416 // Build inputs 417 std::pair<string, Tensor> input_node_info_a; 418 input_node_info_a.first = NAME_A; 419 input_node_info_a.second = Tensor(DT_FLOAT, {}); 420 input_node_info_a.second.scalar<float>()() = NODE_A_VAL; 421 std::pair<string, Tensor> input_node_info_b; 422 input_node_info_b.first = NAME_B; 423 input_node_info_b.second = Tensor(DT_FLOAT, {}); 424 input_node_info_b.second.scalar<float>()() = NODE_B_VAL; 425 const std::vector<std::pair<string, Tensor>> input_tensors{input_node_info_a, 426 input_node_info_b}; 427 const std::vector<string> inputs{NAME_A, NAME_B}; 428 429 // Build outputs 430 const std::vector<string> outputs = {NAME_A_PLUS_B}; 431 432 GraphDef def; 433 TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( 434 NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); 435 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes( 436 input_tensors, /*dry_run_inference*/ true, &def)); 437 438 RemoteFusedGraphExecuteInfo execute_info0; 439 DataTypeVector input_types0; 440 DataTypeVector output_types0; 441 442 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo( 443 "executor", def, inputs, outputs, /*require_shape_type=*/true, 444 &execute_info0, &input_types0, &output_types0)); 445 446 EXPECT_EQ(inputs.size(), 447 execute_info0.default_graph_input_tensor_shape_size()); 448 EXPECT_EQ(outputs.size(), 449 execute_info0.default_graph_output_tensor_shape_size()); 450 EXPECT_EQ(inputs.size(), input_types0.size()); 451 EXPECT_EQ(outputs.size(), output_types0.size()); 452 453 EXPECT_EQ(def.node_size(), execute_info0.remote_graph().node_size()); 454 } 455 456 TEST(RemoteFusedGraphExecuteUtils, BuildRemoteFusedGraphExecuteOpNode) { 457 const std::vector<string> inputs{NAME_A, NAME_B}; 458 459 // Build outputs 460 const std::vector<string> outputs = {NAME_A_PLUS_B}; 461 462 GraphDef def; 463 TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( 464 NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def)); 465 466 Graph graph(OpRegistry::Global()); 467 ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); 468 TF_ASSERT_OK(ImportGraphDef({}, def, &graph, &shape_refiner)); 469 470 Node* node; 471 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( 472 "fused_name", "executor", def, inputs, outputs, 473 /*require_shape_type=*/false, &graph, &node)); 474 } 475 476 TEST(RemoteFusedGraphExecuteUtils, ExtractSubgraphNodes) { 477 GraphDef graph_def; 478 TF_ASSERT_OK( 479 RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def)); 480 ClusterInfo cluster; 481 const std::unordered_set<string>& node_names = std::get<0>(cluster); 482 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder( 483 {"H", "I"}, {"J"}, graph_def, &cluster)); 484 EXPECT_EQ(1, node_names.size()) << IterToString(node_names); 485 486 ClearCluster(&cluster); 487 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder( 488 {"F", "C", "G"}, {"J"}, graph_def, &cluster)); 489 EXPECT_EQ(3, node_names.size()) << IterToString(node_names); 490 491 ClearCluster(&cluster); 492 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder( 493 {"A", "B", "C", "D", "E"}, {"J"}, graph_def, &cluster)); 494 EXPECT_EQ(5, node_names.size()) << IterToString(node_names); 495 496 ClearCluster(&cluster); 497 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder( 498 {"A", "B", "C", "D", "E"}, {"K"}, graph_def, &cluster)); 499 EXPECT_EQ(6, node_names.size()) << IterToString(node_names); 500 501 ClearCluster(&cluster); 502 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder( 503 {"F"}, {"H"}, graph_def, &cluster)); 504 EXPECT_EQ(2, node_names.size()) << IterToString(node_names); 505 } 506 507 TEST(RemoteFusedGraphExecuteUtils, ClusterizeNodes) { 508 GraphDef graph_def; 509 TF_ASSERT_OK( 510 RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def)); 511 512 std::vector<ClusterInfo> ci_vec; 513 TF_ASSERT_OK( 514 RemoteFusedGraphExecuteUtils::ClusterizeNodes({"J"}, graph_def, &ci_vec)); 515 ASSERT_EQ(1, ci_vec.size()); 516 EXPECT_EQ(2, std::get<1>(ci_vec.at(0)).size()) << DumpInOutNames(ci_vec); 517 EXPECT_EQ(1, std::get<2>(ci_vec.at(0)).size()) << DumpInOutNames(ci_vec); 518 519 ci_vec.clear(); 520 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::ClusterizeNodes( 521 {"H", "I", "J"}, graph_def, &ci_vec)); 522 ASSERT_EQ(1, ci_vec.size()); 523 EXPECT_EQ(3, std::get<1>(ci_vec.at(0)).size()) << DumpInOutNames(ci_vec); 524 EXPECT_EQ(1, std::get<2>(ci_vec.at(0)).size()) << DumpInOutNames(ci_vec); 525 526 ci_vec.clear(); 527 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::ClusterizeNodes( 528 {"F", "C", "G", "H", "I", "J"}, graph_def, &ci_vec)); 529 ASSERT_EQ(1, ci_vec.size()); 530 EXPECT_EQ(4, std::get<1>(ci_vec.at(0)).size()) << DumpInOutNames(ci_vec); 531 EXPECT_EQ(2, std::get<2>(ci_vec.at(0)).size()) << DumpInOutNames(ci_vec); 532 533 ci_vec.clear(); 534 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::ClusterizeNodes( 535 {"A", "B", "C", "D", "E"}, graph_def, &ci_vec)); 536 ASSERT_EQ(5, ci_vec.size()); 537 538 ci_vec.clear(); 539 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::ClusterizeNodes( 540 {"A", "B", "D", "E", "F", "G"}, graph_def, &ci_vec)); 541 ASSERT_EQ(2, ci_vec.size()); 542 } 543 544 TEST(RemoteFusedGraphExecuteUtils, BuildSubgraphDefByInOut) { 545 GraphDef graph_def; 546 TF_ASSERT_OK( 547 RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def)); 548 549 ClusterInfo cluster; 550 GraphDef subgraph_def; 551 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder( 552 std::vector<string>{"H", "I"}, std::vector<string>{"J"}, graph_def, 553 &cluster)); 554 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef( 555 cluster, graph_def, &subgraph_def)); 556 EXPECT_EQ(3, subgraph_def.node_size()); 557 558 ClearCluster(&cluster); 559 subgraph_def.Clear(); 560 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder( 561 std::vector<string>{"F", "C", "G"}, std::vector<string>{"J"}, graph_def, 562 &cluster)); 563 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef( 564 cluster, graph_def, &subgraph_def)); 565 EXPECT_EQ(6, subgraph_def.node_size()); 566 567 ClearCluster(&cluster); 568 subgraph_def.Clear(); 569 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder( 570 std::vector<string>{"A", "B", "C", "D", "E"}, std::vector<string>{"J"}, 571 graph_def, &cluster)); 572 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef( 573 cluster, graph_def, &subgraph_def)); 574 EXPECT_EQ(10, subgraph_def.node_size()); 575 576 ClearCluster(&cluster); 577 subgraph_def.Clear(); 578 579 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder( 580 std::vector<string>{"A", "B", "C", "D", "E"}, std::vector<string>{"K"}, 581 graph_def, &cluster)); 582 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef( 583 cluster, graph_def, &subgraph_def)); 584 EXPECT_EQ(11, subgraph_def.node_size()); 585 586 ClearCluster(&cluster); 587 subgraph_def.Clear(); 588 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder( 589 std::vector<string>{"F"}, std::vector<string>{"H"}, graph_def, &cluster)); 590 TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef( 591 cluster, graph_def, &subgraph_def)); 592 EXPECT_EQ(3, subgraph_def.node_size()); 593 } 594 595 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_HI_J) { 596 SetSubgraphArguments(std::vector<string>{"H", "I"}, std::vector<string>{"J"}, 597 this); 598 599 TF_ASSERT_OK(FuseByInOut()); 600 601 EXPECT_EQ(11, graph_def_.node_size()); 602 EXPECT_EQ(11, result_graph_def_.node_size()) 603 << "=== Before: \n" 604 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 605 << SummarizeGraphDef(result_graph_def_); 606 } 607 608 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_FCG_J) { 609 SetSubgraphArguments(std::vector<string>{"F", "C", "G"}, 610 std::vector<string>{"J"}, this); 611 612 TF_ASSERT_OK(FuseByInOut()); 613 614 EXPECT_EQ(11, graph_def_.node_size()); 615 EXPECT_EQ(9, result_graph_def_.node_size()) 616 << "=== Before: \n" 617 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 618 << SummarizeGraphDef(result_graph_def_); 619 } 620 621 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_ABCDE_J) { 622 SetSubgraphArguments(std::vector<string>{"A", "B", "C", "D", "E"}, 623 std::vector<string>{"J"}, this); 624 625 TF_ASSERT_OK(FuseByInOut()); 626 627 EXPECT_EQ(11, graph_def_.node_size()); 628 EXPECT_EQ(8, result_graph_def_.node_size()) 629 << "=== Before: \n" 630 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 631 << SummarizeGraphDef(result_graph_def_); 632 } 633 634 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_ABCDE_K) { 635 SetSubgraphArguments(std::vector<string>{"A", "B", "C", "D", "E"}, 636 std::vector<string>{"K"}, this); 637 638 TF_ASSERT_OK(FuseByInOut()); 639 640 EXPECT_EQ(11, graph_def_.node_size()); 641 EXPECT_EQ(7, result_graph_def_.node_size()) 642 << "=== Before: \n" 643 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 644 << SummarizeGraphDef(result_graph_def_); 645 } 646 647 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_H) { 648 subgraph_node_names_ = {"H"}; 649 650 TF_ASSERT_OK(FuseByNodes()); 651 652 EXPECT_EQ(11, graph_def_.node_size()); 653 EXPECT_EQ(11, result_graph_def_.node_size()) 654 << "=== Before: \n" 655 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 656 << SummarizeGraphDef(result_graph_def_); 657 } 658 659 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_HIJ) { 660 subgraph_node_names_ = {"H", "I", "J"}; 661 662 TF_ASSERT_OK(FuseByNodes()); 663 664 EXPECT_EQ(11, graph_def_.node_size()); 665 EXPECT_EQ(9, result_graph_def_.node_size()) 666 << "=== Before: \n" 667 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 668 << SummarizeGraphDef(result_graph_def_); 669 } 670 671 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_CFGHIJ) { 672 subgraph_node_names_ = {"C", "F", "G", "H", "I", "J"}; 673 674 TF_ASSERT_OK(FuseByNodes()); 675 676 EXPECT_EQ(11, graph_def_.node_size()); 677 EXPECT_EQ(6, result_graph_def_.node_size()) 678 << "=== Before: \n" 679 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 680 << SummarizeGraphDef(result_graph_def_); 681 } 682 683 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_ABCDEFGHIJ) { 684 subgraph_node_names_ = {"A", "B", "C", "D", "E", "F", "G", "H", "I", "J"}; 685 686 TF_ASSERT_OK(FuseByNodes()); 687 688 EXPECT_EQ(11, graph_def_.node_size()); 689 EXPECT_EQ(3, result_graph_def_.node_size()) // "A", "RFG", "K" 690 << "=== Before: \n" 691 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 692 << SummarizeGraphDef(result_graph_def_); 693 } 694 695 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_ABCDEFGHIJK) { 696 subgraph_node_names_ = {"A", "B", "C", "D", "E", "F", 697 "G", "H", "I", "J", "K"}; 698 699 TF_ASSERT_OK(FuseByNodes()); 700 701 EXPECT_EQ(11, graph_def_.node_size()); 702 EXPECT_EQ(3, result_graph_def_.node_size()) // "A", "RFG", "K" 703 << "=== Before: \n" 704 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 705 << SummarizeGraphDef(result_graph_def_); 706 } 707 708 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByOpTypes_HIJ) { 709 subgraph_op_types_ = {"Mul"}; 710 ReplaceOpType({"H", "I", "J"}, "Mul"); 711 712 TF_ASSERT_OK(FuseByOpTypes()); 713 714 EXPECT_EQ(11, graph_def_.node_size()); 715 EXPECT_EQ(9, result_graph_def_.node_size()) 716 << "=== Before: \n" 717 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 718 << SummarizeGraphDef(result_graph_def_); 719 } 720 721 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByOpTypes_FGHIJ) { 722 subgraph_op_types_ = {"Const", "Mul"}; 723 ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul"); 724 725 TF_ASSERT_OK(FuseByOpTypes()); 726 727 EXPECT_EQ(11, graph_def_.node_size()); 728 EXPECT_EQ(3, result_graph_def_.node_size()) 729 << "=== Before: \n" 730 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 731 << SummarizeGraphDef(result_graph_def_); 732 } 733 734 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByExecutor_HIJ) { 735 ReplaceOpType({"H", "I", "J"}, "Mul"); 736 737 TF_ASSERT_OK(FuseByExecutor0()); 738 739 EXPECT_EQ(11, graph_def_.node_size()); 740 EXPECT_EQ(9, result_graph_def_.node_size()) 741 << "=== Before: \n" 742 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 743 << SummarizeGraphDef(result_graph_def_); 744 } 745 746 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByExecutor_FGHIJ) { 747 ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul"); 748 749 TF_ASSERT_OK(FuseByExecutor1()); 750 751 EXPECT_EQ(11, graph_def_.node_size()); 752 EXPECT_EQ(3, result_graph_def_.node_size()) 753 << "=== Before: \n" 754 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 755 << SummarizeGraphDef(result_graph_def_); 756 } 757 758 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_H) { 759 subgraph_node_names_ = {"H"}; 760 761 TF_ASSERT_OK(PlaceRemoteGraphArguments()); 762 ASSERT_TRUE(IsFuseReady()); 763 TF_ASSERT_OK(BuildAndAddTensorShape()); 764 765 EXPECT_EQ(11, graph_def_.node_size()); 766 767 TF_ASSERT_OK(FuseByPlacedArguments()); 768 769 EXPECT_EQ(11, result_graph_def_.node_size()) 770 << "=== Before: \n" 771 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 772 << SummarizeGraphDef(result_graph_def_); 773 } 774 775 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_CFGHIJ) { 776 subgraph_node_names_ = {"C", "F", "G", "H", "I", "J"}; 777 778 TF_ASSERT_OK(PlaceRemoteGraphArguments()); 779 ASSERT_TRUE(IsFuseReady()); 780 TF_ASSERT_OK(BuildAndAddTensorShape()); 781 782 EXPECT_EQ(11, graph_def_.node_size()); 783 784 TF_ASSERT_OK(FuseByPlacedArguments()); 785 786 EXPECT_EQ(6, result_graph_def_.node_size()) 787 << "=== Before: \n" 788 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 789 << SummarizeGraphDef(result_graph_def_); 790 } 791 792 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_ABCDEFGHIJK) { 793 subgraph_node_names_ = {"A", "B", "C", "D", "E", "F", 794 "G", "H", "I", "J", "K"}; 795 796 TF_ASSERT_OK(PlaceRemoteGraphArguments()); 797 ASSERT_TRUE(IsFuseReady()); 798 TF_ASSERT_OK(BuildAndAddTensorShape()); 799 800 EXPECT_EQ(11, graph_def_.node_size()); 801 802 TF_ASSERT_OK(FuseByPlacedArguments()); 803 804 EXPECT_EQ(3, result_graph_def_.node_size()) // "A", "RFG", "K" 805 << "=== Before: \n" 806 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 807 << SummarizeGraphDef(result_graph_def_); 808 } 809 810 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_HI_J) { 811 SetSubgraphArguments(std::vector<string>{"H", "I"}, std::vector<string>{"J"}, 812 this); 813 814 TF_ASSERT_OK(PlaceRemoteGraphArguments()); 815 ASSERT_TRUE(IsFuseReady()); 816 TF_ASSERT_OK(BuildAndAddTensorShape()); 817 818 EXPECT_EQ(11, graph_def_.node_size()); 819 820 TF_ASSERT_OK(FuseByPlacedArguments()); 821 822 EXPECT_EQ(11, result_graph_def_.node_size()) 823 << "=== Before: \n" 824 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 825 << SummarizeGraphDef(result_graph_def_); 826 } 827 828 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_FCG_J) { 829 SetSubgraphArguments(std::vector<string>{"F", "C", "G"}, 830 std::vector<string>{"J"}, this); 831 832 TF_ASSERT_OK(PlaceRemoteGraphArguments()); 833 ASSERT_TRUE(IsFuseReady()); 834 TF_ASSERT_OK(BuildAndAddTensorShape()); 835 836 EXPECT_EQ(11, graph_def_.node_size()); 837 838 TF_ASSERT_OK(FuseByPlacedArguments()); 839 840 EXPECT_EQ(9, result_graph_def_.node_size()) 841 << "=== Before: \n" 842 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 843 << SummarizeGraphDef(result_graph_def_); 844 } 845 846 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_ABCDE_K) { 847 SetSubgraphArguments(std::vector<string>{"A", "B", "C", "D", "E"}, 848 std::vector<string>{"K"}, this); 849 850 TF_ASSERT_OK(PlaceRemoteGraphArguments()); 851 ASSERT_TRUE(IsFuseReady()); 852 TF_ASSERT_OK(BuildAndAddTensorShape()); 853 854 EXPECT_EQ(11, graph_def_.node_size()); 855 856 TF_ASSERT_OK(FuseByPlacedArguments()); 857 858 EXPECT_EQ(7, result_graph_def_.node_size()) 859 << "=== Before: \n" 860 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 861 << SummarizeGraphDef(result_graph_def_); 862 } 863 864 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_MUL_HIJ) { 865 ReplaceOpType({"H", "I", "J"}, "Mul"); 866 subgraph_op_types_ = {"Mul"}; 867 868 TF_ASSERT_OK(PlaceRemoteGraphArguments()); 869 ASSERT_TRUE(IsFuseReady()); 870 TF_ASSERT_OK(BuildAndAddTensorShape()); 871 872 EXPECT_EQ(11, graph_def_.node_size()); 873 874 TF_ASSERT_OK(FuseByPlacedArguments()); 875 876 EXPECT_EQ(9, result_graph_def_.node_size()) 877 << "=== Before: \n" 878 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 879 << SummarizeGraphDef(result_graph_def_); 880 } 881 882 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_CONST_MUL_FGHIJ) { 883 ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul"); 884 subgraph_op_types_ = {"Const", "Mul"}; 885 886 TF_ASSERT_OK(PlaceRemoteGraphArguments()); 887 ASSERT_TRUE(IsFuseReady()); 888 TF_ASSERT_OK(BuildAndAddTensorShape()); 889 890 EXPECT_EQ(11, graph_def_.node_size()); 891 892 TF_ASSERT_OK(FuseByPlacedArguments()); 893 894 EXPECT_EQ(3, result_graph_def_.node_size()) 895 << "=== Before: \n" 896 << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n" 897 << SummarizeGraphDef(result_graph_def_); 898 } 899 900 } // namespace 901 } // namespace tensorflow 902