Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/cc/ops/const_op.h"
     17 #include "tensorflow/cc/ops/image_ops.h"
     18 #include "tensorflow/cc/ops/nn_ops.h"
     19 #include "tensorflow/cc/ops/standard_ops.h"
     20 #include "tensorflow/core/common_runtime/function.h"
     21 #include "tensorflow/core/framework/tensor_shape.h"
     22 #include "tensorflow/core/framework/tensor_testutil.h"
     23 #include "tensorflow/core/graph/default_device.h"
     24 #include "tensorflow/core/graph/node_builder.h"
     25 #include "tensorflow/core/graph/testlib.h"
     26 #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
     27 #include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
     28 #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
     29 #include "tensorflow/core/lib/core/status_test_util.h"
     30 #include "tensorflow/core/platform/test.h"
     31 #include "tensorflow/core/public/session.h"
     32 #include "tensorflow/tools/graph_transforms/transform_utils.h"
     33 
     34 namespace tensorflow {
     35 namespace graph_transforms {
     36 
// Transform entry points exercised by this test. Neither is defined in this
// file; they are declared here so we don't have to put them in a public
// header.
//
// Called by the fixture's FuseInternal() for direct fusing. The fixture passes
// its input graph and a TransformFuncContext, and afterwards inspects the
// graph written through `output_graph_def`.
Status FuseRemoteGraph(const GraphDef& input_graph_def,
                       const TransformFuncContext& context,
                       GraphDef* output_graph_def);

// Called by the fixture's FuseInternal() when only the fuse arguments should
// be placed. The fixture later hands the graph written through
// `output_graph_def` to
// RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments.
Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def,
                                 const TransformFuncContext& context,
                                 GraphDef* output_graph_def);
     45 
     46 namespace {
     47 constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
     48     "remote_fused_graph_executor_name";
     49 constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME =
     50     "remote_fused_graph_node_name";
     51 constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME0 =
     52     "fuse_test_remote_fused_graph_executor0";
     53 constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME1 =
     54     "fuse_test_remote_fused_graph_executor1";
     55 
     56 Status BuildRemoteFusedGraphExecutor0(
     57     std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
     58   executor->reset(
     59       new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0));
     60   return Status::OK();
     61 }
     62 
     63 Status BuildRemoteFusedGraphExecutor1(
     64     std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
     65   executor->reset(new TestRemoteFusedGraphExecutor(
     66       {"Const", "Mul"}, REMOTE_FUSED_EXECUTOR_NAME1));
     67   return Status::OK();
     68 }
     69 
     70 class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
     71  protected:
     72   void SetUp() final {
     73     TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(
     74         &input_graph_def_));
     75     RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
     76         hexagon_remote_fused_graph_executor_build(
     77             REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
     78             [](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
     79               return Status::OK();
     80             });
     81     RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
     82         test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0,
     83                                                 BuildRemoteFusedGraphExecutor0);
     84 
     85     RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
     86         test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1,
     87                                                 BuildRemoteFusedGraphExecutor1);
     88   }
     89 
     90   void TearDown() final {}
     91 
     92   Status Fuse() { return FuseInternal(/*only_place_args=*/false); }
     93 
     94   Status PlaceFuseArgs() { return FuseInternal(/*only_place_args*/ true); }
     95 
     96   Status FuseWithPlacedArgs() {
     97     const std::vector<std::pair<string, Tensor>> input_tensors{
     98         {"A", {DT_FLOAT, {1, 1, 1, 1}}}};
     99     return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments(
    100         input_graph_def_with_fuse_args_, input_tensors, &output_graph_def_);
    101   }
    102 
    103   Status FuseInternal(bool only_place_args) {
    104     TransformFuncContext context;
    105     context.input_names = inputs_;
    106     context.output_names = outputs_;
    107 
    108     if (!input_types_.empty()) {
    109       context.params.insert(std::pair<string, std::vector<string>>(
    110           {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES,
    111            {input_types_}}));
    112     }
    113     if (!input_shapes_.empty()) {
    114       context.params.insert(std::pair<string, std::vector<string>>(
    115           {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES,
    116            {input_shapes_}}));
    117     }
    118     if (!fused_node_names_str_.empty()) {
    119       context.params.insert(std::pair<string, std::vector<string>>(
    120           {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES,
    121            {fused_node_names_str_}}));
    122     }
    123 
    124     if (!border_inputs_str_.empty()) {
    125       context.params.insert(std::pair<string, std::vector<string>>(
    126           {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS,
    127            {border_inputs_str_}}));
    128     }
    129     if (!border_outputs_str_.empty()) {
    130       context.params.insert(std::pair<string, std::vector<string>>(
    131           {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS,
    132            {border_outputs_str_}}));
    133     }
    134 
    135     if (!fused_op_types_str_.empty()) {
    136       context.params.insert(std::pair<string, std::vector<string>>(
    137           {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES,
    138            {fused_op_types_str_}}));
    139     }
    140 
    141     if (fuse_by_executor_) {
    142       context.params.insert(std::pair<string, std::vector<string>>(
    143           {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR,
    144            {"true"}}));
    145     }
    146 
    147     context.params.insert(std::pair<string, std::vector<string>>(
    148         {RemoteFusedGraphExecuteUtils::
    149              TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
    150          {remote_fused_graph_executor_name_}}));
    151     context.params.insert(std::pair<string, std::vector<string>>(
    152         {RemoteFusedGraphExecuteUtils::
    153              TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME,
    154          {REMOTE_FUSED_GRAPH_NODE_NAME}}));
    155 
    156     if (only_place_args) {
    157       return PlaceRemoteGraphArguments(input_graph_def_, context,
    158                                        &input_graph_def_with_fuse_args_);
    159     } else {
    160       return FuseRemoteGraph(input_graph_def_, context, &output_graph_def_);
    161     }
    162   }
    163 
    164   void SetInputShapeType() {
    165     input_types_ = "float";
    166     input_shapes_ = "1,1,1,1";
    167   }
    168 
    169   void ReplaceOpType(const std::unordered_set<string>& op_name,
    170                      const string& new_op_type) {
    171     for (NodeDef& node_def : *input_graph_def_.mutable_node()) {
    172       if (op_name.count(node_def.name()) > 0) {
    173         node_def.set_op(new_op_type);
    174       }
    175     }
    176   }
    177 
    178   void CheckGraph(int expected_node_count, int expected_cluster_count) {
    179     EXPECT_EQ(expected_node_count, output_graph_def_.node_size());
    180 
    181     int cluster_count = 0;
    182     for (const NodeDef& node_def : output_graph_def_.node()) {
    183       const string& name = node_def.name();
    184       if (StringPiece(name).starts_with(REMOTE_FUSED_GRAPH_NODE_NAME)) {
    185         ++cluster_count;
    186         RemoteFusedGraphExecuteInfo info;
    187         string serialized_proto;
    188         TF_ASSERT_OK(
    189             GetNodeAttr(node_def,
    190                         RemoteFusedGraphExecuteUtils::
    191                             ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
    192                         &serialized_proto));
    193         info.ParseFromString(serialized_proto);
    194         CHECK_EQ(remote_fused_graph_executor_name_, info.executor_name());
    195       }
    196     }
    197     EXPECT_EQ(expected_cluster_count, cluster_count);
    198   }
    199 
    200  public:
    201   const std::vector<string> inputs_{"A"};
    202   const std::vector<string> outputs_{"K"};
    203   GraphDef input_graph_def_;
    204   string input_types_;
    205   string input_shapes_;
    206   GraphDef input_graph_def_with_fuse_args_;
    207   GraphDef output_graph_def_;
    208   string fused_node_names_str_;
    209   string border_inputs_str_;
    210   string border_outputs_str_;
    211   string fused_op_types_str_;
    212   string remote_fused_graph_executor_name_{REMOTE_FUSED_GRAPH_EXECUTOR_NAME};
    213   bool fuse_by_executor_{false};
    214 };
    215 
    216 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
    217        FuseRemoteGraphByNodesWithShapeType_HIJ) {
    218   SetInputShapeType();
    219   fused_node_names_str_ = "H,I,J";
    220   TF_ASSERT_OK(Fuse());
    221   CheckGraph(9, 1);
    222 }
    223 
    224 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
    225        FuseRemoteGraphByNodesWithoutShapeType_HIJ) {
    226   fused_node_names_str_ = "H,I,J";
    227   TF_ASSERT_OK(Fuse());
    228   CheckGraph(9, 1);
    229 }
    230 
    231 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
    232        FuseRemoteGraphByNodesWithShapeType_ABCDEFGHIJK) {
    233   SetInputShapeType();
    234   fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K";
    235   TF_ASSERT_OK(Fuse());
    236   CheckGraph(3, 1);
    237 }
    238 
    239 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
    240        FuseRemoteGraphByNodesWithoutShapeType_ABCDEFGHIJK) {
    241   fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K";
    242   TF_ASSERT_OK(Fuse());
    243   CheckGraph(3, 1);
    244 }
    245 
    246 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
    247        FuseRemoteGraphByBorderWithShapeType_FCG_J) {
    248   SetInputShapeType();
    249   border_inputs_str_ = "F:0,C:0,G";
    250   border_outputs_str_ = "J:0";
    251   TF_ASSERT_OK(Fuse());
    252   CheckGraph(9, 1);
    253 }
    254 
    255 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
    256        FuseRemoteGraphByBorderWithoutShapeType_FCG_J) {
    257   border_inputs_str_ = "F:0,C:0,G";
    258   border_outputs_str_ = "J:0";
    259   TF_ASSERT_OK(Fuse());
    260   CheckGraph(9, 1);
    261 }
    262 
    263 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
    264        FuseRemoteGraphByBorderWithShapeType_ABCDE_K) {
    265   SetInputShapeType();
    266   border_inputs_str_ = "A,B,C,D,E";
    267   border_outputs_str_ = "K";
    268   TF_ASSERT_OK(Fuse());
    269   CheckGraph(7, 1);
    270 }
    271 
    272 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
    273        FuseRemoteGraphByBorderWithoutShapeType_ABCDE_K) {
    274   border_inputs_str_ = "A,B,C,D,E";
    275   border_outputs_str_ = "K";
    276   TF_ASSERT_OK(Fuse());
    277   CheckGraph(7, 1);
    278 }
    279 
    280 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
    281        FuseRemoteGraphByOpTypes_HIJ) {
    282   ReplaceOpType({"H", "I", "J"}, "Mul");
    283   fused_op_types_str_ = "Mul";
    284   TF_ASSERT_OK(Fuse());
    285   CheckGraph(9, 1);
    286 }
    287 
    288 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
    289        FuseRemoteGraphByOpTypes_FGHIJ) {
    290   ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
    291   fused_op_types_str_ = "Const,Mul";
    292   TF_ASSERT_OK(Fuse());
    293   CheckGraph(3, 1);
    294 }
    295 
    296 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
    297        FuseRemoteGraphByExecutor_HIJ) {
    298   ReplaceOpType({"H", "I", "J"}, "Mul");
    299   remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME0;
    300   fuse_by_executor_ = true;
    301   TF_ASSERT_OK(Fuse());
    302   CheckGraph(9, 1);
    303 }
    304 
    305 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
    306        FuseRemoteGraphByExecutor_FGHIJ) {
    307   ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
    308   remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME1;
    309   fuse_by_executor_ = true;
    310   TF_ASSERT_OK(Fuse());
    311   CheckGraph(3, 1);
    312 }
    313 
    314 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_HIJ) {
    315   fused_node_names_str_ = "H,I,J";
    316   TF_ASSERT_OK(PlaceFuseArgs());
    317   TF_ASSERT_OK(FuseWithPlacedArgs());
    318   CheckGraph(9, 1);
    319 }
    320 
    321 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_ABCDEFGHIJK) {
    322   fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K";
    323   TF_ASSERT_OK(PlaceFuseArgs());
    324   TF_ASSERT_OK(FuseWithPlacedArgs());
    325   CheckGraph(3, 1);
    326 }
    327 
    328 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_FCG_J) {
    329   border_inputs_str_ = "F:0,C:0,G";
    330   border_outputs_str_ = "J:0";
    331   TF_ASSERT_OK(PlaceFuseArgs());
    332   TF_ASSERT_OK(FuseWithPlacedArgs());
    333   CheckGraph(9, 1);
    334 }
    335 
    336 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_ABCDE_K) {
    337   SetInputShapeType();
    338   border_inputs_str_ = "A,B,C,D,E";
    339   border_outputs_str_ = "K";
    340   TF_ASSERT_OK(PlaceFuseArgs());
    341   TF_ASSERT_OK(FuseWithPlacedArgs());
    342   CheckGraph(7, 1);
    343 }
    344 
    345 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_MUL_HIJ) {
    346   SetInputShapeType();
    347   ReplaceOpType({"H", "I", "J"}, "Mul");
    348   fused_op_types_str_ = "Mul";
    349 
    350   TF_ASSERT_OK(PlaceFuseArgs());
    351   TF_ASSERT_OK(FuseWithPlacedArgs());
    352   CheckGraph(9, 1);
    353 }
    354 
    355 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
    356        PlaceAndFuse_CONST_MUL_FGHIJ) {
    357   SetInputShapeType();
    358   ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
    359   fused_op_types_str_ = "Const,Mul";
    360 
    361   TF_ASSERT_OK(PlaceFuseArgs());
    362   TF_ASSERT_OK(FuseWithPlacedArgs());
    363   CheckGraph(3, 1);
    364 }
    365 
    366 }  // namespace
    367 }  // namespace graph_transforms
    368 }  // namespace tensorflow
    369