Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
     17 #include "tensorflow/cc/framework/scope.h"
     18 #include "tensorflow/core/common_runtime/shape_refiner.h"
     19 #include "tensorflow/core/framework/node_def.pb.h"
     20 #include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
     21 #include "tensorflow/core/lib/core/status.h"
     22 #include "tensorflow/core/lib/core/status_test_util.h"
     23 #include "tensorflow/core/platform/test.h"
     24 
     25 namespace tensorflow {
     26 namespace {
     27 
     28 using ClusterInfo = RemoteFusedGraphExecuteUtils::ClusterInfo;
     29 
     // Node names and constant values passed to BuildAddGraph() by the tests below.
     constexpr const char* const NAME_A{"A"};
     constexpr const char* const NAME_B{"B"};
     constexpr const char* const NAME_A_PLUS_B{"A_PLUS_B"};
     constexpr float NODE_A_VAL{2.0f};
     constexpr float NODE_B_VAL{3.0f};

     // Absolute tolerance handed to EXPECT_NEAR when comparing float results.
     constexpr float VALUE_TOLERANCE_FLOAT{1e-8f};

     // Names under which the two test executor builders are registered.
     constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME0{
         "fuse_test_remote_fused_graph_executor0"};
     constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME1{
         "fuse_test_remote_fused_graph_executor1"};
     40 
     41 static NodeDef* GetNodeDef(const string& name, GraphDef* def) {
     42   CHECK_NE(def, nullptr);
     43   for (NodeDef& node_def : *def->mutable_node()) {
     44     if (node_def.name() == name) {
     45       return &node_def;
     46     }
     47   }
     48   return nullptr;
     49 }
     50 
     51 Status BuildRemoteFusedGraphExecutor0(
     52     std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
     53   executor->reset(
     54       new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0));
     55   return Status::OK();
     56 }
     57 
     58 Status BuildRemoteFusedGraphExecutor1(
     59     std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
     60   executor->reset(new TestRemoteFusedGraphExecutor(
     61       {"Const", "Mul"}, REMOTE_FUSED_EXECUTOR_NAME1));
     62   return Status::OK();
     63 }
     64 
     65 class FuseRemoteGraphMultipleAddOpsTest : public ::testing::Test {
     66  protected:
     67   void SetUp() final {
     68     TF_ASSERT_OK(
     69         RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def_));
     70     RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
     71         hexagon_remote_fused_graph_executor_build(
     72             "remote_graph_executor_name",
     73             [](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
     74               return Status::OK();
     75             });
     76     RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
     77         test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0,
     78                                                 BuildRemoteFusedGraphExecutor0);
     79 
     80     RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
     81         test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1,
     82                                                 BuildRemoteFusedGraphExecutor1);
     83   }
     84 
     85   void TearDown() final {}
     86 
     87   Status FuseByInOut() {
     88     // Feed output shapes and types
     89     RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
     90     GraphDef graph_def_with_shapetype = graph_def_;
     91     TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
     92         input_tensors_, /*dry_run_inference*/ true, &graph_def_with_shapetype));
     93 
     94     return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder(
     95         graph_def_with_shapetype, inputs_, outputs_,
     96         "remote_fused_graph_node_names", subgraph_input_names_,
     97         subgraph_output_names_, "remote_graph_executor_name",
     98         /*require_shape_type=*/true, &result_graph_def_);
     99   }
    100 
    101   Status FuseByNodes() {
    102     return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
    103         graph_def_, inputs_, outputs_, "remote_fused_graph_node_names",
    104         subgraph_node_names_, "remote_graph_executor_name",
    105         /*require_shape_type=*/false, &result_graph_def_);
    106   }
    107 
    108   Status FuseByOpTypes() {
    109     return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes(
    110         graph_def_, inputs_, outputs_, "remote_fused_graph_node_names",
    111         subgraph_op_types_, "remote_graph_executor_name",
    112         /*require_shape_type=*/false, &result_graph_def_);
    113   }
    114 
    115   Status FuseByExecutor0() {
    116     return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
    117         graph_def_, inputs_, outputs_, REMOTE_FUSED_EXECUTOR_NAME0,
    118         &result_graph_def_);
    119   }
    120 
    121   Status FuseByExecutor1() {
    122     return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
    123         graph_def_, inputs_, outputs_, REMOTE_FUSED_EXECUTOR_NAME1,
    124         &result_graph_def_);
    125   }
    126 
    127   Status BuildAndAddTensorShape() {
    128     return RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
    129         input_tensors_, /*dry_run_inference=*/true, &graph_def_);
    130   }
    131 
    132   Status PlaceRemoteGraphArguments() {
    133     return RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments(
    134         inputs_, outputs_, subgraph_node_names_, subgraph_input_names_,
    135         subgraph_output_names_, subgraph_op_types_,
    136         "remote_fused_graph_node_names", "remote_graph_executor_name",
    137         &graph_def_);
    138   }
    139 
    140   Status FuseByPlacedArguments() {
    141     const Status status =
    142         RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments(
    143             graph_def_, input_tensors_, &graph_def_);
    144     result_graph_def_ = graph_def_;
    145     return status;
    146   }
    147 
    148   bool IsFuseReady() {
    149     return RemoteFusedGraphExecuteUtils::IsFuseReady(graph_def_,
    150                                                      input_tensors_);
    151   }
    152 
    153   void ReplaceOpType(const std::unordered_set<string>& op_name,
    154                      const string& new_op_type) {
    155     for (NodeDef& node_def : *graph_def_.mutable_node()) {
    156       if (op_name.count(node_def.name()) > 0) {
    157         node_def.set_op(new_op_type);
    158       }
    159     }
    160   }
    161 
    162  public:
    163   const std::vector<std::pair<string, Tensor>> input_tensors_{
    164       {"A", {DT_FLOAT, {1, 1, 1, 1}}}};
    165   const std::vector<string> inputs_{"A"};
    166   const std::vector<string> outputs_{"K"};
    167   GraphDef graph_def_;
    168   GraphDef result_graph_def_;
    169   std::vector<string> subgraph_input_names_;
    170   std::vector<string> subgraph_output_names_;
    171   std::unordered_set<string> subgraph_node_names_;
    172   std::unordered_set<string> subgraph_op_types_;
    173 };
    174 
    175 void SetSubgraphArguments(const std::vector<string>& input_names,
    176                           const std::vector<string>& output_names,
    177                           FuseRemoteGraphMultipleAddOpsTest* fixture) {
    178   for (const string& input_name : input_names) {
    179     fixture->subgraph_input_names_.emplace_back(input_name);
    180   }
    181 
    182   fixture->subgraph_output_names_ = output_names;
    183 }
    184 
    // Joins the strings of an iterable container with ", " in iteration order.
    // Returns an empty string for an empty container. The separator is emitted
    // between every pair of elements, including empty ones, so an empty element
    // stays visible in the output instead of being silently dropped.
    template <typename T>
    static string IterToString(const T& set) {
      string out;
      bool is_first = true;
      for (const string& val : set) {
        if (!is_first) {
          out += ", ";
        }
        is_first = false;
        out += val;
      }
      return out;
    }
    196 
    197 static string SummarizeGraphDef(const GraphDef& graph_def) {
    198   string out;
    199   for (const NodeDef& node : graph_def.node()) {
    200     out += strings::StrCat("node: ", node.name(), "\n    input: ");
    201     for (const string& input : node.input()) {
    202       out += strings::StrCat(input, ", ");
    203     }
    204     out += "\n";
    205   }
    206   return out;
    207 }
    208 
    209 static string DumpInOutNames(const std::vector<ClusterInfo>& ci_vec) {
    210   for (int i = 0; i < ci_vec.size(); ++i) {
    211     LOG(INFO) << "Cluster(" << i << ")";
    212     LOG(INFO) << "input: " << IterToString(std::get<1>(ci_vec.at(i)));
    213     LOG(INFO) << "output: " << IterToString(std::get<2>(ci_vec.at(i)));
    214   }
    215   return "";
    216 }
    217 
    218 static void ClearCluster(ClusterInfo* cluster) {
    219   std::get<0>(*cluster).clear();
    220   std::get<1>(*cluster).clear();
    221   std::get<2>(*cluster).clear();
    222 }
    223 
    224 TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphA) {
    225   GraphDef def;
    226   TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
    227       NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def));
    228   std::pair<string, Tensor> input_node_info;
    229   input_node_info.first = NAME_A;
    230   input_node_info.second = Tensor(DT_FLOAT, {});
    231   input_node_info.second.scalar<float>()() = 1.0f;
    232   const std::vector<std::pair<string, Tensor>> inputs{input_node_info};
    233   std::vector<string> outputs = {NAME_B, NAME_A_PLUS_B};
    234   std::vector<tensorflow::Tensor> output_tensors;
    235   Status status = RemoteFusedGraphExecuteUtils::DryRunInference(
    236       def, inputs, outputs, false /* initialize_by_zero */, &output_tensors);
    237   ASSERT_TRUE(status.ok()) << status;
    238   EXPECT_EQ(outputs.size(), output_tensors.size());
    239   EXPECT_NEAR(NODE_B_VAL, output_tensors.at(0).scalar<float>()(),
    240               VALUE_TOLERANCE_FLOAT);
    241   EXPECT_NEAR(1.0f + NODE_B_VAL, output_tensors.at(1).scalar<float>()(),
    242               VALUE_TOLERANCE_FLOAT);
    243 }
    244 
    245 TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphAUninitialized) {
    246   GraphDef def;
    247   TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
    248       NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def));
    249   std::pair<string, Tensor> input_node_info;
    250   input_node_info.first = NAME_A;
    251   input_node_info.second = Tensor(DT_FLOAT, {});
    252   const std::vector<std::pair<string, Tensor>> inputs{input_node_info};
    253   std::vector<string> outputs = {NAME_B, NAME_A_PLUS_B};
    254   std::vector<tensorflow::Tensor> output_tensors;
    255   Status status = RemoteFusedGraphExecuteUtils::DryRunInference(
    256       def, inputs, outputs, true /* initialize_by_zero */, &output_tensors);
    257   ASSERT_TRUE(status.ok()) << status;
    258   EXPECT_EQ(outputs.size(), output_tensors.size());
    259   EXPECT_NEAR(NODE_B_VAL, output_tensors.at(0).scalar<float>()(),
    260               VALUE_TOLERANCE_FLOAT);
    261   EXPECT_NEAR(NODE_B_VAL, output_tensors.at(1).scalar<float>()(),
    262               VALUE_TOLERANCE_FLOAT);
    263 }
    264 
    265 TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphAB) {
    266   GraphDef def;
    267   TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
    268       NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def));
    269   std::pair<string, Tensor> input_node_info_a;
    270   input_node_info_a.first = NAME_A;
    271   input_node_info_a.second = Tensor(DT_FLOAT, {});
    272   input_node_info_a.second.scalar<float>()() = NODE_A_VAL;
    273   std::pair<string, Tensor> input_node_info_b;
    274   input_node_info_b.first = NAME_B;
    275   input_node_info_b.second = Tensor(DT_FLOAT, {});
    276   input_node_info_b.second.scalar<float>()() = NODE_B_VAL;
    277   const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a,
    278                                                       input_node_info_b};
    279   std::vector<string> outputs = {NAME_A_PLUS_B};
    280   std::vector<tensorflow::Tensor> output_tensors;
    281   Status status = RemoteFusedGraphExecuteUtils::DryRunInference(
    282       def, inputs, outputs, false /* initialize_by_zero */, &output_tensors);
    283   ASSERT_TRUE(status.ok()) << status;
    284   EXPECT_EQ(outputs.size(), output_tensors.size());
    285   EXPECT_NEAR(NODE_A_VAL + NODE_B_VAL, output_tensors.at(0).scalar<float>()(),
    286               VALUE_TOLERANCE_FLOAT);
    287 }
    288 
    289 TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphForAllNodes) {
    290   // Set Node "A" as an input with value (= 1.0f)
    291   std::pair<string, Tensor> input_node_info_a;
    292   input_node_info_a.first = NAME_A;
    293   input_node_info_a.second = Tensor(DT_FLOAT, {});
    294   input_node_info_a.second.scalar<float>()() = 1.0f;
    295 
    296   // Setup dryrun arguments
    297   const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a};
    298   RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
    299 
    300   GraphDef def;
    301   TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
    302       NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def));
    303 
    304   // dryrun
    305   const Status status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
    306       def, inputs, false /* initialize_by_zero */, &tensor_shape_map);
    307 
    308   ASSERT_TRUE(status.ok()) << status;
    309 
    310   // Assert output node count
    311   ASSERT_EQ(3, tensor_shape_map.size());
    312   ASSERT_EQ(1, tensor_shape_map.count(NAME_A));
    313   ASSERT_EQ(1, tensor_shape_map.count(NAME_B));
    314   ASSERT_EQ(1, tensor_shape_map.count(NAME_A_PLUS_B));
    315 
    316   const RemoteFusedGraphExecuteUtils::TensorShapeType* tst =
    317       RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
    318                                                        NAME_B);
    319   EXPECT_NE(tst, nullptr);
    320   EXPECT_EQ(DT_FLOAT, tst->first);
    321   EXPECT_EQ(0, tst->second.dims());
    322 
    323   tst = RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
    324                                                          NAME_A_PLUS_B);
    325   EXPECT_NE(tst, nullptr);
    326   EXPECT_EQ(DT_FLOAT, tst->first);
    327   EXPECT_EQ(0, tst->second.dims());
    328 }
    329 
    330 TEST(RemoteFusedGraphExecuteUtils, PropagateAndBuildTensorShapeMap) {
    331   std::pair<string, Tensor> input_node_info_a;
    332   input_node_info_a.first = NAME_A;
    333   input_node_info_a.second = Tensor(DT_FLOAT, {});
    334   input_node_info_a.second.scalar<float>()() = NODE_A_VAL;
    335   std::pair<string, Tensor> input_node_info_b;
    336   input_node_info_b.first = NAME_B;
    337   input_node_info_b.second = Tensor(DT_FLOAT, {});
    338   input_node_info_b.second.scalar<float>()() = NODE_B_VAL;
    339   const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a,
    340                                                       input_node_info_b};
    341 
    342   RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
    343   GraphDef def;
    344   TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
    345       NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def));
    346   ImportGraphDefOptions opts;
    347   Graph graph(OpRegistry::Global());
    348   ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
    349   Status status = ImportGraphDef(opts, def, &graph, &shape_refiner);
    350   ASSERT_TRUE(RemoteFusedGraphExecuteUtils::PropagateShapeInference(
    351                   def, inputs, &graph, &shape_refiner)
    352                   .ok());
    353   ASSERT_TRUE(RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph(
    354                   graph, shape_refiner, &tensor_shape_map)
    355                   .ok());
    356 
    357   ASSERT_EQ(3, tensor_shape_map.size());
    358   ASSERT_EQ(1, tensor_shape_map.count(NAME_A));
    359   ASSERT_EQ(1, tensor_shape_map.count(NAME_B));
    360   ASSERT_EQ(1, tensor_shape_map.count(NAME_A_PLUS_B));
    361 
    362   const RemoteFusedGraphExecuteUtils::TensorShapeType* tst =
    363       RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
    364                                                        NAME_B);
    365   EXPECT_NE(tst, nullptr);
    366   EXPECT_EQ(DT_FLOAT, tst->first);
    367   EXPECT_EQ(0, tst->second.dims());
    368 
    369   tst = RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
    370                                                          NAME_A_PLUS_B);
    371   EXPECT_NE(tst, nullptr);
    372   EXPECT_EQ(DT_FLOAT, tst->first);
    373   EXPECT_EQ(0, tst->second.dims());
    374 
    375   {
    376     NodeDef* node_def = GetNodeDef(NAME_B, &def);
    377     TF_ASSERT_OK(
    378         RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
    379             tensor_shape_map, node_def));
    380     std::vector<DataType> data_types;
    381     TF_ASSERT_OK(GetNodeAttr(
    382         *node_def, RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES,
    383         &data_types));
    384     ASSERT_EQ(1, data_types.size());
    385     EXPECT_EQ(DT_FLOAT, data_types.at(0));
    386 
    387     std::vector<TensorShape> shapes;
    388     TF_ASSERT_OK(GetNodeAttr(
    389         *node_def, RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES, &shapes));
    390     ASSERT_EQ(1, shapes.size());
    391     EXPECT_EQ(0, shapes.at(0).dims());
    392   }
    393 
    394   {
    395     NodeDef* node_def = GetNodeDef(NAME_A_PLUS_B, &def);
    396     TF_ASSERT_OK(
    397         RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
    398             tensor_shape_map, node_def));
    399     std::vector<DataType> data_types;
    400     TF_ASSERT_OK(GetNodeAttr(
    401         *node_def, RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES,
    402         &data_types));
    403     ASSERT_EQ(1, data_types.size());
    404     EXPECT_EQ(DT_FLOAT, data_types.at(0));
    405 
    406     std::vector<TensorShape> shapes;
    407     TF_ASSERT_OK(GetNodeAttr(
    408         *node_def, RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES, &shapes));
    409     ASSERT_EQ(1, shapes.size());
    410     EXPECT_EQ(0, shapes.at(0).dims());
    411   }
    412 }
    413 
    414 TEST(RemoteFusedGraphExecuteUtils,
    415      BuildRemoteFusedGraphExecuteInfoWithShapeInference) {
    416   // Build inputs
    417   std::pair<string, Tensor> input_node_info_a;
    418   input_node_info_a.first = NAME_A;
    419   input_node_info_a.second = Tensor(DT_FLOAT, {});
    420   input_node_info_a.second.scalar<float>()() = NODE_A_VAL;
    421   std::pair<string, Tensor> input_node_info_b;
    422   input_node_info_b.first = NAME_B;
    423   input_node_info_b.second = Tensor(DT_FLOAT, {});
    424   input_node_info_b.second.scalar<float>()() = NODE_B_VAL;
    425   const std::vector<std::pair<string, Tensor>> input_tensors{input_node_info_a,
    426                                                              input_node_info_b};
    427   const std::vector<string> inputs{NAME_A, NAME_B};
    428 
    429   // Build outputs
    430   const std::vector<string> outputs = {NAME_A_PLUS_B};
    431 
    432   GraphDef def;
    433   TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
    434       NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def));
    435   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
    436       input_tensors, /*dry_run_inference*/ true, &def));
    437 
    438   RemoteFusedGraphExecuteInfo execute_info0;
    439   DataTypeVector input_types0;
    440   DataTypeVector output_types0;
    441 
    442   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
    443       "executor", def, inputs, outputs, /*require_shape_type=*/true,
    444       &execute_info0, &input_types0, &output_types0));
    445 
    446   EXPECT_EQ(inputs.size(),
    447             execute_info0.default_graph_input_tensor_shape_size());
    448   EXPECT_EQ(outputs.size(),
    449             execute_info0.default_graph_output_tensor_shape_size());
    450   EXPECT_EQ(inputs.size(), input_types0.size());
    451   EXPECT_EQ(outputs.size(), output_types0.size());
    452 
    453   EXPECT_EQ(def.node_size(), execute_info0.remote_graph().node_size());
    454 }
    455 
    456 TEST(RemoteFusedGraphExecuteUtils, BuildRemoteFusedGraphExecuteOpNode) {
    457   const std::vector<string> inputs{NAME_A, NAME_B};
    458 
    459   // Build outputs
    460   const std::vector<string> outputs = {NAME_A_PLUS_B};
    461 
    462   GraphDef def;
    463   TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
    464       NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &def));
    465 
    466   Graph graph(OpRegistry::Global());
    467   ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
    468   TF_ASSERT_OK(ImportGraphDef({}, def, &graph, &shape_refiner));
    469 
    470   Node* node;
    471   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
    472       "fused_name", "executor", def, inputs, outputs,
    473       /*require_shape_type=*/false, &graph, &node));
    474 }
    475 
    476 TEST(RemoteFusedGraphExecuteUtils, ExtractSubgraphNodes) {
    477   GraphDef graph_def;
    478   TF_ASSERT_OK(
    479       RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def));
    480   ClusterInfo cluster;
    481   const std::unordered_set<string>& node_names = std::get<0>(cluster);
    482   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    483       {"H", "I"}, {"J"}, graph_def, &cluster));
    484   EXPECT_EQ(1, node_names.size()) << IterToString(node_names);
    485 
    486   ClearCluster(&cluster);
    487   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    488       {"F", "C", "G"}, {"J"}, graph_def, &cluster));
    489   EXPECT_EQ(3, node_names.size()) << IterToString(node_names);
    490 
    491   ClearCluster(&cluster);
    492   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    493       {"A", "B", "C", "D", "E"}, {"J"}, graph_def, &cluster));
    494   EXPECT_EQ(5, node_names.size()) << IterToString(node_names);
    495 
    496   ClearCluster(&cluster);
    497   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    498       {"A", "B", "C", "D", "E"}, {"K"}, graph_def, &cluster));
    499   EXPECT_EQ(6, node_names.size()) << IterToString(node_names);
    500 
    501   ClearCluster(&cluster);
    502   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    503       {"F"}, {"H"}, graph_def, &cluster));
    504   EXPECT_EQ(2, node_names.size()) << IterToString(node_names);
    505 }
    506 
    507 TEST(RemoteFusedGraphExecuteUtils, ClusterizeNodes) {
    508   GraphDef graph_def;
    509   TF_ASSERT_OK(
    510       RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def));
    511 
    512   std::vector<ClusterInfo> ci_vec;
    513   TF_ASSERT_OK(
    514       RemoteFusedGraphExecuteUtils::ClusterizeNodes({"J"}, graph_def, &ci_vec));
    515   ASSERT_EQ(1, ci_vec.size());
    516   EXPECT_EQ(2, std::get<1>(ci_vec.at(0)).size()) << DumpInOutNames(ci_vec);
    517   EXPECT_EQ(1, std::get<2>(ci_vec.at(0)).size()) << DumpInOutNames(ci_vec);
    518 
    519   ci_vec.clear();
    520   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::ClusterizeNodes(
    521       {"H", "I", "J"}, graph_def, &ci_vec));
    522   ASSERT_EQ(1, ci_vec.size());
    523   EXPECT_EQ(3, std::get<1>(ci_vec.at(0)).size()) << DumpInOutNames(ci_vec);
    524   EXPECT_EQ(1, std::get<2>(ci_vec.at(0)).size()) << DumpInOutNames(ci_vec);
    525 
    526   ci_vec.clear();
    527   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::ClusterizeNodes(
    528       {"F", "C", "G", "H", "I", "J"}, graph_def, &ci_vec));
    529   ASSERT_EQ(1, ci_vec.size());
    530   EXPECT_EQ(4, std::get<1>(ci_vec.at(0)).size()) << DumpInOutNames(ci_vec);
    531   EXPECT_EQ(2, std::get<2>(ci_vec.at(0)).size()) << DumpInOutNames(ci_vec);
    532 
    533   ci_vec.clear();
    534   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::ClusterizeNodes(
    535       {"A", "B", "C", "D", "E"}, graph_def, &ci_vec));
    536   ASSERT_EQ(5, ci_vec.size());
    537 
    538   ci_vec.clear();
    539   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::ClusterizeNodes(
    540       {"A", "B", "D", "E", "F", "G"}, graph_def, &ci_vec));
    541   ASSERT_EQ(2, ci_vec.size());
    542 }
    543 
    544 TEST(RemoteFusedGraphExecuteUtils, BuildSubgraphDefByInOut) {
    545   GraphDef graph_def;
    546   TF_ASSERT_OK(
    547       RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(&graph_def));
    548 
    549   ClusterInfo cluster;
    550   GraphDef subgraph_def;
    551   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    552       std::vector<string>{"H", "I"}, std::vector<string>{"J"}, graph_def,
    553       &cluster));
    554   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef(
    555       cluster, graph_def, &subgraph_def));
    556   EXPECT_EQ(3, subgraph_def.node_size());
    557 
    558   ClearCluster(&cluster);
    559   subgraph_def.Clear();
    560   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    561       std::vector<string>{"F", "C", "G"}, std::vector<string>{"J"}, graph_def,
    562       &cluster));
    563   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef(
    564       cluster, graph_def, &subgraph_def));
    565   EXPECT_EQ(6, subgraph_def.node_size());
    566 
    567   ClearCluster(&cluster);
    568   subgraph_def.Clear();
    569   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    570       std::vector<string>{"A", "B", "C", "D", "E"}, std::vector<string>{"J"},
    571       graph_def, &cluster));
    572   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef(
    573       cluster, graph_def, &subgraph_def));
    574   EXPECT_EQ(10, subgraph_def.node_size());
    575 
    576   ClearCluster(&cluster);
    577   subgraph_def.Clear();
    578 
    579   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    580       std::vector<string>{"A", "B", "C", "D", "E"}, std::vector<string>{"K"},
    581       graph_def, &cluster));
    582   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef(
    583       cluster, graph_def, &subgraph_def));
    584   EXPECT_EQ(11, subgraph_def.node_size());
    585 
    586   ClearCluster(&cluster);
    587   subgraph_def.Clear();
    588   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
    589       std::vector<string>{"F"}, std::vector<string>{"H"}, graph_def, &cluster));
    590   TF_ASSERT_OK(RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef(
    591       cluster, graph_def, &subgraph_def));
    592   EXPECT_EQ(3, subgraph_def.node_size());
    593 }
    594 
    595 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_HI_J) {
    596   SetSubgraphArguments(std::vector<string>{"H", "I"}, std::vector<string>{"J"},
    597                        this);
    598 
    599   TF_ASSERT_OK(FuseByInOut());
    600 
    601   EXPECT_EQ(11, graph_def_.node_size());
    602   EXPECT_EQ(11, result_graph_def_.node_size())
    603       << "=== Before: \n"
    604       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    605       << SummarizeGraphDef(result_graph_def_);
    606 }
    607 
    608 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_FCG_J) {
    609   SetSubgraphArguments(std::vector<string>{"F", "C", "G"},
    610                        std::vector<string>{"J"}, this);
    611 
    612   TF_ASSERT_OK(FuseByInOut());
    613 
    614   EXPECT_EQ(11, graph_def_.node_size());
    615   EXPECT_EQ(9, result_graph_def_.node_size())
    616       << "=== Before: \n"
    617       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    618       << SummarizeGraphDef(result_graph_def_);
    619 }
    620 
    621 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_ABCDE_J) {
    622   SetSubgraphArguments(std::vector<string>{"A", "B", "C", "D", "E"},
    623                        std::vector<string>{"J"}, this);
    624 
    625   TF_ASSERT_OK(FuseByInOut());
    626 
    627   EXPECT_EQ(11, graph_def_.node_size());
    628   EXPECT_EQ(8, result_graph_def_.node_size())
    629       << "=== Before: \n"
    630       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    631       << SummarizeGraphDef(result_graph_def_);
    632 }
    633 
    634 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByInOut_ABCDE_K) {
    635   SetSubgraphArguments(std::vector<string>{"A", "B", "C", "D", "E"},
    636                        std::vector<string>{"K"}, this);
    637 
    638   TF_ASSERT_OK(FuseByInOut());
    639 
    640   EXPECT_EQ(11, graph_def_.node_size());
    641   EXPECT_EQ(7, result_graph_def_.node_size())
    642       << "=== Before: \n"
    643       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    644       << SummarizeGraphDef(result_graph_def_);
    645 }
    646 
    647 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_H) {
    648   subgraph_node_names_ = {"H"};
    649 
    650   TF_ASSERT_OK(FuseByNodes());
    651 
    652   EXPECT_EQ(11, graph_def_.node_size());
    653   EXPECT_EQ(11, result_graph_def_.node_size())
    654       << "=== Before: \n"
    655       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    656       << SummarizeGraphDef(result_graph_def_);
    657 }
    658 
    659 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_HIJ) {
    660   subgraph_node_names_ = {"H", "I", "J"};
    661 
    662   TF_ASSERT_OK(FuseByNodes());
    663 
    664   EXPECT_EQ(11, graph_def_.node_size());
    665   EXPECT_EQ(9, result_graph_def_.node_size())
    666       << "=== Before: \n"
    667       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    668       << SummarizeGraphDef(result_graph_def_);
    669 }
    670 
    671 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_CFGHIJ) {
    672   subgraph_node_names_ = {"C", "F", "G", "H", "I", "J"};
    673 
    674   TF_ASSERT_OK(FuseByNodes());
    675 
    676   EXPECT_EQ(11, graph_def_.node_size());
    677   EXPECT_EQ(6, result_graph_def_.node_size())
    678       << "=== Before: \n"
    679       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    680       << SummarizeGraphDef(result_graph_def_);
    681 }
    682 
    683 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_ABCDEFGHIJ) {
    684   subgraph_node_names_ = {"A", "B", "C", "D", "E", "F", "G", "H", "I", "J"};
    685 
    686   TF_ASSERT_OK(FuseByNodes());
    687 
    688   EXPECT_EQ(11, graph_def_.node_size());
    689   EXPECT_EQ(3, result_graph_def_.node_size())  // "A", "RFG", "K"
    690       << "=== Before: \n"
    691       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    692       << SummarizeGraphDef(result_graph_def_);
    693 }
    694 
    695 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByNodes_ABCDEFGHIJK) {
    696   subgraph_node_names_ = {"A", "B", "C", "D", "E", "F",
    697                           "G", "H", "I", "J", "K"};
    698 
    699   TF_ASSERT_OK(FuseByNodes());
    700 
    701   EXPECT_EQ(11, graph_def_.node_size());
    702   EXPECT_EQ(3, result_graph_def_.node_size())  // "A", "RFG", "K"
    703       << "=== Before: \n"
    704       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    705       << SummarizeGraphDef(result_graph_def_);
    706 }
    707 
    708 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByOpTypes_HIJ) {
    709   subgraph_op_types_ = {"Mul"};
    710   ReplaceOpType({"H", "I", "J"}, "Mul");
    711 
    712   TF_ASSERT_OK(FuseByOpTypes());
    713 
    714   EXPECT_EQ(11, graph_def_.node_size());
    715   EXPECT_EQ(9, result_graph_def_.node_size())
    716       << "=== Before: \n"
    717       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    718       << SummarizeGraphDef(result_graph_def_);
    719 }
    720 
    721 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByOpTypes_FGHIJ) {
    722   subgraph_op_types_ = {"Const", "Mul"};
    723   ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
    724 
    725   TF_ASSERT_OK(FuseByOpTypes());
    726 
    727   EXPECT_EQ(11, graph_def_.node_size());
    728   EXPECT_EQ(3, result_graph_def_.node_size())
    729       << "=== Before: \n"
    730       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    731       << SummarizeGraphDef(result_graph_def_);
    732 }
    733 
    734 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByExecutor_HIJ) {
    735   ReplaceOpType({"H", "I", "J"}, "Mul");
    736 
    737   TF_ASSERT_OK(FuseByExecutor0());
    738 
    739   EXPECT_EQ(11, graph_def_.node_size());
    740   EXPECT_EQ(9, result_graph_def_.node_size())
    741       << "=== Before: \n"
    742       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    743       << SummarizeGraphDef(result_graph_def_);
    744 }
    745 
    746 TEST_F(FuseRemoteGraphMultipleAddOpsTest, FuseSubgraphByExecutor_FGHIJ) {
    747   ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
    748 
    749   TF_ASSERT_OK(FuseByExecutor1());
    750 
    751   EXPECT_EQ(11, graph_def_.node_size());
    752   EXPECT_EQ(3, result_graph_def_.node_size())
    753       << "=== Before: \n"
    754       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    755       << SummarizeGraphDef(result_graph_def_);
    756 }
    757 
    758 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_H) {
    759   subgraph_node_names_ = {"H"};
    760 
    761   TF_ASSERT_OK(PlaceRemoteGraphArguments());
    762   ASSERT_TRUE(IsFuseReady());
    763   TF_ASSERT_OK(BuildAndAddTensorShape());
    764 
    765   EXPECT_EQ(11, graph_def_.node_size());
    766 
    767   TF_ASSERT_OK(FuseByPlacedArguments());
    768 
    769   EXPECT_EQ(11, result_graph_def_.node_size())
    770       << "=== Before: \n"
    771       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    772       << SummarizeGraphDef(result_graph_def_);
    773 }
    774 
    775 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_CFGHIJ) {
    776   subgraph_node_names_ = {"C", "F", "G", "H", "I", "J"};
    777 
    778   TF_ASSERT_OK(PlaceRemoteGraphArguments());
    779   ASSERT_TRUE(IsFuseReady());
    780   TF_ASSERT_OK(BuildAndAddTensorShape());
    781 
    782   EXPECT_EQ(11, graph_def_.node_size());
    783 
    784   TF_ASSERT_OK(FuseByPlacedArguments());
    785 
    786   EXPECT_EQ(6, result_graph_def_.node_size())
    787       << "=== Before: \n"
    788       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    789       << SummarizeGraphDef(result_graph_def_);
    790 }
    791 
    792 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_ABCDEFGHIJK) {
    793   subgraph_node_names_ = {"A", "B", "C", "D", "E", "F",
    794                           "G", "H", "I", "J", "K"};
    795 
    796   TF_ASSERT_OK(PlaceRemoteGraphArguments());
    797   ASSERT_TRUE(IsFuseReady());
    798   TF_ASSERT_OK(BuildAndAddTensorShape());
    799 
    800   EXPECT_EQ(11, graph_def_.node_size());
    801 
    802   TF_ASSERT_OK(FuseByPlacedArguments());
    803 
    804   EXPECT_EQ(3, result_graph_def_.node_size())  // "A", "RFG", "K"
    805       << "=== Before: \n"
    806       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    807       << SummarizeGraphDef(result_graph_def_);
    808 }
    809 
    810 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_HI_J) {
    811   SetSubgraphArguments(std::vector<string>{"H", "I"}, std::vector<string>{"J"},
    812                        this);
    813 
    814   TF_ASSERT_OK(PlaceRemoteGraphArguments());
    815   ASSERT_TRUE(IsFuseReady());
    816   TF_ASSERT_OK(BuildAndAddTensorShape());
    817 
    818   EXPECT_EQ(11, graph_def_.node_size());
    819 
    820   TF_ASSERT_OK(FuseByPlacedArguments());
    821 
    822   EXPECT_EQ(11, result_graph_def_.node_size())
    823       << "=== Before: \n"
    824       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    825       << SummarizeGraphDef(result_graph_def_);
    826 }
    827 
    828 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_FCG_J) {
    829   SetSubgraphArguments(std::vector<string>{"F", "C", "G"},
    830                        std::vector<string>{"J"}, this);
    831 
    832   TF_ASSERT_OK(PlaceRemoteGraphArguments());
    833   ASSERT_TRUE(IsFuseReady());
    834   TF_ASSERT_OK(BuildAndAddTensorShape());
    835 
    836   EXPECT_EQ(11, graph_def_.node_size());
    837 
    838   TF_ASSERT_OK(FuseByPlacedArguments());
    839 
    840   EXPECT_EQ(9, result_graph_def_.node_size())
    841       << "=== Before: \n"
    842       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    843       << SummarizeGraphDef(result_graph_def_);
    844 }
    845 
    846 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_ABCDE_K) {
    847   SetSubgraphArguments(std::vector<string>{"A", "B", "C", "D", "E"},
    848                        std::vector<string>{"K"}, this);
    849 
    850   TF_ASSERT_OK(PlaceRemoteGraphArguments());
    851   ASSERT_TRUE(IsFuseReady());
    852   TF_ASSERT_OK(BuildAndAddTensorShape());
    853 
    854   EXPECT_EQ(11, graph_def_.node_size());
    855 
    856   TF_ASSERT_OK(FuseByPlacedArguments());
    857 
    858   EXPECT_EQ(7, result_graph_def_.node_size())
    859       << "=== Before: \n"
    860       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    861       << SummarizeGraphDef(result_graph_def_);
    862 }
    863 
    864 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_MUL_HIJ) {
    865   ReplaceOpType({"H", "I", "J"}, "Mul");
    866   subgraph_op_types_ = {"Mul"};
    867 
    868   TF_ASSERT_OK(PlaceRemoteGraphArguments());
    869   ASSERT_TRUE(IsFuseReady());
    870   TF_ASSERT_OK(BuildAndAddTensorShape());
    871 
    872   EXPECT_EQ(11, graph_def_.node_size());
    873 
    874   TF_ASSERT_OK(FuseByPlacedArguments());
    875 
    876   EXPECT_EQ(9, result_graph_def_.node_size())
    877       << "=== Before: \n"
    878       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    879       << SummarizeGraphDef(result_graph_def_);
    880 }
    881 
    882 TEST_F(FuseRemoteGraphMultipleAddOpsTest, PlaceAndFuse_CONST_MUL_FGHIJ) {
    883   ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
    884   subgraph_op_types_ = {"Const", "Mul"};
    885 
    886   TF_ASSERT_OK(PlaceRemoteGraphArguments());
    887   ASSERT_TRUE(IsFuseReady());
    888   TF_ASSERT_OK(BuildAndAddTensorShape());
    889 
    890   EXPECT_EQ(11, graph_def_.node_size());
    891 
    892   TF_ASSERT_OK(FuseByPlacedArguments());
    893 
    894   EXPECT_EQ(3, result_graph_def_.node_size())
    895       << "=== Before: \n"
    896       << SummarizeGraphDef(graph_def_) << "\n\n\n=== After: \n"
    897       << SummarizeGraphDef(result_graph_def_);
    898 }
    899 
    900 }  // namespace
    901 }  // namespace tensorflow
    902