1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/cc/ops/const_op.h" 17 #include "tensorflow/cc/ops/image_ops.h" 18 #include "tensorflow/cc/ops/nn_ops.h" 19 #include "tensorflow/cc/ops/standard_ops.h" 20 #include "tensorflow/core/common_runtime/function.h" 21 #include "tensorflow/core/framework/tensor_shape.h" 22 #include "tensorflow/core/framework/tensor_testutil.h" 23 #include "tensorflow/core/graph/default_device.h" 24 #include "tensorflow/core/graph/node_builder.h" 25 #include "tensorflow/core/graph/testlib.h" 26 #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h" 27 #include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h" 28 #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" 29 #include "tensorflow/core/lib/core/status_test_util.h" 30 #include "tensorflow/core/platform/test.h" 31 #include "tensorflow/core/public/session.h" 32 #include "tensorflow/tools/graph_transforms/transform_utils.h" 33 34 namespace tensorflow { 35 namespace graph_transforms { 36 37 // Declared here so we don't have to put it in a public header. 38 Status FuseRemoteGraph(const GraphDef& input_graph_def, 39 const TransformFuncContext& context, 40 GraphDef* output_graph_def); 41 42 Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def, 43 const TransformFuncContext& context, 44 GraphDef* output_graph_def); 45 46 namespace { 47 constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME = 48 "remote_fused_graph_executor_name"; 49 constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME = 50 "remote_fused_graph_node_name"; 51 constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME0 = 52 "fuse_test_remote_fused_graph_executor0"; 53 constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME1 = 54 "fuse_test_remote_fused_graph_executor1"; 55 56 Status BuildRemoteFusedGraphExecutor0( 57 std::unique_ptr<IRemoteFusedGraphExecutor>* executor) { 58 executor->reset( 59 new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0)); 60 return Status::OK(); 61 } 62 63 Status BuildRemoteFusedGraphExecutor1( 64 std::unique_ptr<IRemoteFusedGraphExecutor>* executor) { 65 executor->reset(new TestRemoteFusedGraphExecutor( 66 {"Const", "Mul"}, REMOTE_FUSED_EXECUTOR_NAME1)); 67 return Status::OK(); 68 } 69 70 class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test { 71 protected: 72 void SetUp() final { 73 TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph( 74 &input_graph_def_)); 75 RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar 76 hexagon_remote_fused_graph_executor_build( 77 REMOTE_FUSED_GRAPH_EXECUTOR_NAME, 78 [](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status { 79 return Status::OK(); 80 }); 81 RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar 82 test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0, 83 BuildRemoteFusedGraphExecutor0); 84 85 RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar 86 test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1, 87 BuildRemoteFusedGraphExecutor1); 88 } 89 90 void TearDown() final {} 91 92 Status Fuse() { return FuseInternal(/*only_place_args=*/false); } 93 94 Status PlaceFuseArgs() { return FuseInternal(/*only_place_args*/ true); } 95 96 Status FuseWithPlacedArgs() { 97 const std::vector<std::pair<string, Tensor>> input_tensors{ 98 {"A", {DT_FLOAT, {1, 1, 1, 1}}}}; 99 return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments( 100 input_graph_def_with_fuse_args_, input_tensors, &output_graph_def_); 101 } 102 103 Status FuseInternal(bool only_place_args) { 104 TransformFuncContext context; 105 context.input_names = inputs_; 106 context.output_names = outputs_; 107 108 if (!input_types_.empty()) { 109 context.params.insert(std::pair<string, std::vector<string>>( 110 {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES, 111 {input_types_}})); 112 } 113 if (!input_shapes_.empty()) { 114 context.params.insert(std::pair<string, std::vector<string>>( 115 {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES, 116 {input_shapes_}})); 117 } 118 if (!fused_node_names_str_.empty()) { 119 context.params.insert(std::pair<string, std::vector<string>>( 120 {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES, 121 {fused_node_names_str_}})); 122 } 123 124 if (!border_inputs_str_.empty()) { 125 context.params.insert(std::pair<string, std::vector<string>>( 126 {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS, 127 {border_inputs_str_}})); 128 } 129 if (!border_outputs_str_.empty()) { 130 context.params.insert(std::pair<string, std::vector<string>>( 131 {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS, 132 {border_outputs_str_}})); 133 } 134 135 if (!fused_op_types_str_.empty()) { 136 context.params.insert(std::pair<string, std::vector<string>>( 137 {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES, 138 {fused_op_types_str_}})); 139 } 140 141 if (fuse_by_executor_) { 142 context.params.insert(std::pair<string, std::vector<string>>( 143 {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR, 144 {"true"}})); 145 } 146 147 context.params.insert(std::pair<string, std::vector<string>>( 148 {RemoteFusedGraphExecuteUtils:: 149 TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME, 150 {remote_fused_graph_executor_name_}})); 151 context.params.insert(std::pair<string, std::vector<string>>( 152 {RemoteFusedGraphExecuteUtils:: 153 TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME, 154 {REMOTE_FUSED_GRAPH_NODE_NAME}})); 155 156 if (only_place_args) { 157 return PlaceRemoteGraphArguments(input_graph_def_, context, 158 &input_graph_def_with_fuse_args_); 159 } else { 160 return FuseRemoteGraph(input_graph_def_, context, &output_graph_def_); 161 } 162 } 163 164 void SetInputShapeType() { 165 input_types_ = "float"; 166 input_shapes_ = "1,1,1,1"; 167 } 168 169 void ReplaceOpType(const std::unordered_set<string>& op_name, 170 const string& new_op_type) { 171 for (NodeDef& node_def : *input_graph_def_.mutable_node()) { 172 if (op_name.count(node_def.name()) > 0) { 173 node_def.set_op(new_op_type); 174 } 175 } 176 } 177 178 void CheckGraph(int expected_node_count, int expected_cluster_count) { 179 EXPECT_EQ(expected_node_count, output_graph_def_.node_size()); 180 181 int cluster_count = 0; 182 for (const NodeDef& node_def : output_graph_def_.node()) { 183 const string& name = node_def.name(); 184 if (StringPiece(name).starts_with(REMOTE_FUSED_GRAPH_NODE_NAME)) { 185 ++cluster_count; 186 RemoteFusedGraphExecuteInfo info; 187 string serialized_proto; 188 TF_ASSERT_OK( 189 GetNodeAttr(node_def, 190 RemoteFusedGraphExecuteUtils:: 191 ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO, 192 &serialized_proto)); 193 info.ParseFromString(serialized_proto); 194 CHECK_EQ(remote_fused_graph_executor_name_, info.executor_name()); 195 } 196 } 197 EXPECT_EQ(expected_cluster_count, cluster_count); 198 } 199 200 public: 201 const std::vector<string> inputs_{"A"}; 202 const std::vector<string> outputs_{"K"}; 203 GraphDef input_graph_def_; 204 string input_types_; 205 string input_shapes_; 206 GraphDef input_graph_def_with_fuse_args_; 207 GraphDef output_graph_def_; 208 string fused_node_names_str_; 209 string border_inputs_str_; 210 string border_outputs_str_; 211 string fused_op_types_str_; 212 string remote_fused_graph_executor_name_{REMOTE_FUSED_GRAPH_EXECUTOR_NAME}; 213 bool fuse_by_executor_{false}; 214 }; 215 216 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, 217 FuseRemoteGraphByNodesWithShapeType_HIJ) { 218 SetInputShapeType(); 219 fused_node_names_str_ = "H,I,J"; 220 TF_ASSERT_OK(Fuse()); 221 CheckGraph(9, 1); 222 } 223 224 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, 225 FuseRemoteGraphByNodesWithoutShapeType_HIJ) { 226 fused_node_names_str_ = "H,I,J"; 227 TF_ASSERT_OK(Fuse()); 228 CheckGraph(9, 1); 229 } 230 231 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, 232 FuseRemoteGraphByNodesWithShapeType_ABCDEFGHIJK) { 233 SetInputShapeType(); 234 fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K"; 235 TF_ASSERT_OK(Fuse()); 236 CheckGraph(3, 1); 237 } 238 239 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, 240 FuseRemoteGraphByNodesWithoutShapeType_ABCDEFGHIJK) { 241 fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K"; 242 TF_ASSERT_OK(Fuse()); 243 CheckGraph(3, 1); 244 } 245 246 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, 247 FuseRemoteGraphByBorderWithShapeType_FCG_J) { 248 SetInputShapeType(); 249 border_inputs_str_ = "F:0,C:0,G"; 250 border_outputs_str_ = "J:0"; 251 TF_ASSERT_OK(Fuse()); 252 CheckGraph(9, 1); 253 } 254 255 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, 256 FuseRemoteGraphByBorderWithoutShapeType_FCG_J) { 257 border_inputs_str_ = "F:0,C:0,G"; 258 border_outputs_str_ = "J:0"; 259 TF_ASSERT_OK(Fuse()); 260 CheckGraph(9, 1); 261 } 262 263 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, 264 FuseRemoteGraphByBorderWithShapeType_ABCDE_K) { 265 SetInputShapeType(); 266 border_inputs_str_ = "A,B,C,D,E"; 267 border_outputs_str_ = "K"; 268 TF_ASSERT_OK(Fuse()); 269 CheckGraph(7, 1); 270 } 271 272 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, 273 FuseRemoteGraphByBorderWithoutShapeType_ABCDE_K) { 274 border_inputs_str_ = "A,B,C,D,E"; 275 border_outputs_str_ = "K"; 276 TF_ASSERT_OK(Fuse()); 277 CheckGraph(7, 1); 278 } 279 280 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, 281 FuseRemoteGraphByOpTypes_HIJ) { 282 ReplaceOpType({"H", "I", "J"}, "Mul"); 283 fused_op_types_str_ = "Mul"; 284 TF_ASSERT_OK(Fuse()); 285 CheckGraph(9, 1); 286 } 287 288 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, 289 FuseRemoteGraphByOpTypes_FGHIJ) { 290 ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul"); 291 fused_op_types_str_ = "Const,Mul"; 292 TF_ASSERT_OK(Fuse()); 293 CheckGraph(3, 1); 294 } 295 296 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, 297 FuseRemoteGraphByExecutor_HIJ) { 298 ReplaceOpType({"H", "I", "J"}, "Mul"); 299 remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME0; 300 fuse_by_executor_ = true; 301 TF_ASSERT_OK(Fuse()); 302 CheckGraph(9, 1); 303 } 304 305 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, 306 FuseRemoteGraphByExecutor_FGHIJ) { 307 ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul"); 308 remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME1; 309 fuse_by_executor_ = true; 310 TF_ASSERT_OK(Fuse()); 311 CheckGraph(3, 1); 312 } 313 314 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_HIJ) { 315 fused_node_names_str_ = "H,I,J"; 316 TF_ASSERT_OK(PlaceFuseArgs()); 317 TF_ASSERT_OK(FuseWithPlacedArgs()); 318 CheckGraph(9, 1); 319 } 320 321 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_ABCDEFGHIJK) { 322 fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K"; 323 TF_ASSERT_OK(PlaceFuseArgs()); 324 TF_ASSERT_OK(FuseWithPlacedArgs()); 325 CheckGraph(3, 1); 326 } 327 328 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_FCG_J) { 329 border_inputs_str_ = "F:0,C:0,G"; 330 border_outputs_str_ = "J:0"; 331 TF_ASSERT_OK(PlaceFuseArgs()); 332 TF_ASSERT_OK(FuseWithPlacedArgs()); 333 CheckGraph(9, 1); 334 } 335 336 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_ABCDE_K) { 337 SetInputShapeType(); 338 border_inputs_str_ = "A,B,C,D,E"; 339 border_outputs_str_ = "K"; 340 TF_ASSERT_OK(PlaceFuseArgs()); 341 TF_ASSERT_OK(FuseWithPlacedArgs()); 342 CheckGraph(7, 1); 343 } 344 345 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_MUL_HIJ) { 346 SetInputShapeType(); 347 ReplaceOpType({"H", "I", "J"}, "Mul"); 348 fused_op_types_str_ = "Mul"; 349 350 TF_ASSERT_OK(PlaceFuseArgs()); 351 TF_ASSERT_OK(FuseWithPlacedArgs()); 352 CheckGraph(9, 1); 353 } 354 355 TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, 356 PlaceAndFuse_CONST_MUL_FGHIJ) { 357 SetInputShapeType(); 358 ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul"); 359 fused_op_types_str_ = "Const,Mul"; 360 361 TF_ASSERT_OK(PlaceFuseArgs()); 362 TF_ASSERT_OK(FuseWithPlacedArgs()); 363 CheckGraph(3, 1); 364 } 365 366 } // namespace 367 } // namespace graph_transforms 368 } // namespace tensorflow 369