1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/cc/framework/ops.h" 17 #include "tensorflow/cc/framework/scope.h" 18 #include "tensorflow/cc/ops/const_op.h" 19 #include "tensorflow/core/framework/fake_input.h" 20 #include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/graph/graph.h" 23 #include "tensorflow/core/graph/node_builder.h" 24 #include "tensorflow/core/graph/testlib.h" 25 #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h" 26 #include "tensorflow/core/kernels/ops_testutil.h" 27 #include "tensorflow/core/kernels/ops_util.h" 28 #include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h" 29 #include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" 30 #include "tensorflow/core/lib/core/status_test_util.h" 31 #include "tensorflow/core/platform/test.h" 32 #include "tensorflow/core/platform/test_benchmark.h" 33 #include "tensorflow/core/public/session.h" 34 #include "tensorflow/core/public/session_options.h" 35 36 namespace tensorflow { 37 38 class RemoteFusedGraphExecuteTest : public OpsTestBase {}; 39 40 TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithOneDataType) { 41 DataTypeVector input_types({DT_FLOAT, DT_FLOAT}); 42 DataTypeVector output_types({DT_FLOAT}); 43 TF_ASSERT_OK( 44 NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute") 45 .Input(FakeInput(2, DT_FLOAT)) 46 .Attr("Tinputs", input_types) 47 .Attr("Toutputs", output_types) 48 .Attr("serialized_remote_fused_graph_execute_info", "") 49 .Finalize(node_def())); 50 TF_ASSERT_OK(InitOp()); 51 // TODO(satok): Add benchmark 52 } 53 54 TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithWrongDataType) { 55 DataTypeVector input_types({DT_INT32, DT_INT32}); 56 DataTypeVector output_types({DT_FLOAT}); 57 ASSERT_FALSE( 58 NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute") 59 .Input(FakeInput(2, DT_FLOAT)) 60 .Attr("Tinputs", input_types) 61 .Attr("Toutputs", output_types) 62 .Attr("serialized_remote_fused_graph_execute_info", "") 63 .Finalize(node_def()) 64 .ok()); 65 // TODO(satok): Add benchmark 66 } 67 68 //////////////////////////// 69 // End-to-end test: Begin // 70 //////////////////////////// 71 // This test does a end-to-end test for a simple usage of 72 // RemoteFusedGraphExecuteOp. 73 74 constexpr const char* const NAME_A = "a"; 75 constexpr const char* const NAME_B = "b"; 76 constexpr const char* const NAME_A_PLUS_B = "a_plus_b"; 77 constexpr const char* const REMOTE_FUSED_EXECUTE_OP_NODE_NAME = 78 "remote_fused_execute_op"; 79 constexpr const char* const REMOTE_FUSED_EXECUTOR_NAME = 80 "build_test_remote_fused_graph_executor"; 81 82 constexpr float NODE_A_VAL = 2.0f; 83 constexpr float NODE_A_VAL2 = 10.0f; 84 constexpr float NODE_B_VAL = 3.0f; 85 constexpr float FLOAT_VALUE_TOLERANCE = 1e-8f; 86 87 // Utility functions // 88 static Output BuildPlaceHolderOp(const string& name, const DataType dt, 89 const TensorShape& tensor_shape, Scope* root) { 90 const Scope& scope = root->WithOpName(name); 91 Node* ret; 92 const string unique_name = scope.GetUniqueNameForOp("Placeholder"); 93 NodeBuilder builder = NodeBuilder(unique_name, "Placeholder") 94 .Attr("dtype", dt) 95 .Attr("shape", tensor_shape); 96 scope.UpdateBuilder(&builder); 97 scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); 98 CHECK(scope.ok()); 99 return Output(ret, 0); 100 } 101 102 static Output BuildRemoteFusedGraphExecuteOp( 103 const string& name, const std::vector<Output>& output_list, 104 const int output_node_count, 105 const RemoteFusedGraphExecuteInfo& execute_info, Scope* root) { 106 const Scope& scope = root->WithOpName(name); 107 Node* ret; 108 CHECK(scope.ok()); 109 auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list)); 110 const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute"); 111 112 DataTypeVector input_types{DT_FLOAT}; 113 DataTypeVector output_types{DT_FLOAT}; 114 115 auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute") 116 .Input(node_out_list) 117 .Attr("Tinputs", input_types) 118 .Attr("Toutputs", output_types) 119 .Attr("serialized_remote_fused_graph_execute_info", 120 StringPiece(execute_info.SerializeAsString())); 121 CHECK(scope.ok()); 122 scope.UpdateBuilder(&builder); 123 scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); 124 CHECK(scope.ok()); 125 return Output(ret, 0); 126 } 127 128 static RemoteFusedGraphExecuteInfo BuildRemoteFusedGraphExecuteInfo( 129 const GraphDef& original_graph) { 130 RemoteFusedGraphExecuteInfo execute_info; 131 execute_info.set_executor_name(REMOTE_FUSED_EXECUTOR_NAME); 132 133 // In this example, simply copy all nodes. Basically, you don't need to add 134 // unused node for inference. 135 for (const NodeDef& node : original_graph.node()) { 136 NodeDef& copied_node = *execute_info.mutable_remote_graph()->add_node(); 137 copied_node = node; 138 // Adding tensor shape type to the node 139 // TODO(satok): Use TensorShapeMap to detime tensor shape type 140 RemoteFusedGraphExecuteUtils::AddOutputTensorShapeType( 141 std::vector<DataType>({DT_FLOAT}), 142 std::vector<TensorShape>({TensorShape()}), &copied_node); 143 } 144 145 // Add node A as input 146 execute_info.add_graph_input_node_name(NAME_A); 147 RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_a = 148 *execute_info.add_default_graph_input_tensor_shape(); 149 shape_a.set_dtype(DT_FLOAT); 150 // (skip setting shape to shape_a as it's shape is rank = 0.) 151 152 // Add node A + B as output 153 execute_info.add_graph_output_node_name(NAME_A_PLUS_B); 154 RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_a_plus_b = 155 *execute_info.add_default_graph_output_tensor_shape(); 156 shape_a_plus_b.set_dtype(DT_FLOAT); 157 // (skip setting shape to shape_a_plus_b as it's shape is rank = 0.) 158 159 return execute_info; 160 } 161 162 // 1. Create SampleRemoteFusedGraphExecutor to execute your fused graph 163 class SampleRemoteFusedGraphExecutor final : public IRemoteFusedGraphExecutor { 164 public: 165 int GetVersion() final { return 1; } 166 bool Init(const RemoteFusedGraphExecuteInfo& info) final { 167 info_ = &info; 168 for (const NodeDef& node_def : info.remote_graph().node()) { 169 node_def_map_.emplace(node_def.name(), &node_def); 170 } 171 return true; 172 } 173 bool Finalize() final { return true; } 174 bool SetupGraph() final { return true; } 175 bool ExecuteGraph() final { 176 CHECK(info_ != nullptr); 177 // TODO(satok): Add utilities to implement this function more easily. 178 // CAVEAT: This test only handles add op. You can implement here as you 179 // like. 180 CHECK_EQ(1, info_->graph_input_node_name_size()); 181 const string& input_node_name = info_->graph_input_node_name(0); 182 const Tensor& input_tensor = input_tensor_cache_[input_node_name]; 183 const float input_val = *input_tensor.scalar<float>().data(); 184 // TODO(satok): Read NAME_B from node_a_plus_b 185 const NodeDef& node_b = *node_def_map_.at(NAME_B); 186 const TensorProto* proto = nullptr; 187 TF_CHECK_OK(GetNodeAttr(node_b, "value", &proto)); 188 Tensor const_tensor; 189 TF_CHECK_OK(RemoteFusedGraphExecuteUtils::MakeTensorFromProto( 190 *proto, &const_tensor)); 191 const float b_val = *const_tensor.scalar<float>().data(); 192 Tensor output_a_plus_b(DT_FLOAT, {}); 193 output_a_plus_b.flat<float>().data()[0] = input_val + b_val; 194 output_tensor_buf_.emplace(info_->graph_output_node_name(0), 195 output_a_plus_b); 196 return true; 197 } 198 199 bool TeardownGraph() final { return true; } 200 201 bool FillInputNode(const string& node_name, const Tensor& tensor) final { 202 input_tensor_cache_[node_name] = tensor; 203 return true; 204 } 205 206 bool ReadOutputNode(const string& node_name, 207 TensorAllocatorFunc tensor_allocator) final { 208 // TODO(satok): Specify tensor shape by using default_graph_tensor_shape. 209 const Tensor& buffered_output_tensor = output_tensor_buf_.at(node_name); 210 const TensorShape& output_shape = buffered_output_tensor.shape(); 211 Tensor* output_tensor = tensor_allocator(output_shape); 212 CHECK_EQ(buffered_output_tensor.dtype(), output_tensor->dtype()); 213 CHECK(output_tensor->CopyFrom(buffered_output_tensor, output_shape)); 214 return true; 215 } 216 217 Status FuseRemoteGraph(const GraphDef& original_graph_def, 218 const std::vector<string>& /*inputs*/, 219 const std::vector<string>& /*outputs*/, 220 GraphDef* fused_graph_def) final { 221 *fused_graph_def = original_graph_def; 222 return Status::OK(); 223 } 224 225 bool IsEnabled() const final { return true; } 226 227 private: 228 const RemoteFusedGraphExecuteInfo* info_; 229 std::unordered_map<string, Tensor> input_tensor_cache_; 230 std::unordered_map<string, const NodeDef*> node_def_map_; 231 std::unordered_map<string, Tensor> output_tensor_buf_; 232 }; 233 234 // 2. Register a builder of your custom executor 235 namespace remote_fused_graph_execute_op { 236 Status BuildRemoteFusedGraphExecutor( 237 std::unique_ptr<IRemoteFusedGraphExecutor>* executor) { 238 executor->reset(new SampleRemoteFusedGraphExecutor()); 239 return Status::OK(); 240 } 241 242 // This class instantiation registers executor to the 243 // RemoteFusedGraphExecuteOp. This architecture makes executors to be 244 // pluggable in order not to link unnecessary libraries. 245 static RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar 246 k_test_remote_fused_graph_executor_build(REMOTE_FUSED_EXECUTOR_NAME, 247 BuildRemoteFusedGraphExecutor); 248 } // namespace remote_fused_graph_execute_op 249 250 // 3. Create Graph transform function to fuse your graph 251 static Status RewriteGraphToFusedGraph(const GraphDef& original_graph, 252 GraphDef* fused_graph) { 253 Scope root = Scope::NewRootScope(); 254 std::vector<Output> output_list; 255 const Output op_a = BuildPlaceHolderOp(NAME_A, DT_FLOAT, {}, &root); 256 output_list.emplace_back(op_a); 257 const RemoteFusedGraphExecuteInfo execute_info = 258 BuildRemoteFusedGraphExecuteInfo(original_graph); 259 BuildRemoteFusedGraphExecuteOp(REMOTE_FUSED_EXECUTE_OP_NODE_NAME, output_list, 260 1, execute_info, &root); 261 GraphDef fused_graph_def; 262 TF_CHECK_OK(root.ToGraphDef(&fused_graph_def)); 263 *fused_graph = fused_graph_def; 264 return Status::OK(); 265 } 266 267 // 4. Register transform function 268 // You can register transform function by REGISTER_GRAPH_TRANSFORM. 269 // In this test, we don't use graph transform tool to avoid linking to 270 // the graph transform library. 271 // To register transform function, you need to change the interface of 272 // BuildFusedGraphDefOfAddGraph to 273 // Status BuildFusedGraphDefOfAddGraph( 274 // const GraphDef& original_graph, const TransformFuncContext& context, 275 // GraphDef* output_graph_def); 276 // Then register the function like: 277 // REGISTER_GRAPH_TRANSFORM("rewrite_graph", RewriteGraph); 278 279 // 5. Fuse the original graph and run the inference the new fused graph 280 TEST(RemoteFusedExecuteGraphOp, EndToEndTest) { 281 // 5.1 Load original graph 282 GraphDef original_graph; 283 TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph( 284 NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B, &original_graph)); 285 286 // 5.2 Fuse graph 287 GraphDef fused_graph; 288 TF_ASSERT_OK(RewriteGraphToFusedGraph(original_graph, &fused_graph)); 289 290 // 5.3 Setup session 291 std::vector<Tensor> output_tensors; 292 SessionOptions session_options; 293 session_options.env = Env::Default(); 294 std::unique_ptr<Session> session = 295 std::unique_ptr<Session>(NewSession(session_options)); 296 Status status = session->Create(fused_graph); 297 ASSERT_TRUE(status.ok()); 298 RunOptions run_options; 299 run_options.set_trace_level(RunOptions::FULL_TRACE); 300 RunMetadata run_metadata; 301 302 // 5.4 Setup input 303 Tensor input_a(DT_FLOAT, {}); 304 input_a.flat<float>().data()[0] = NODE_A_VAL2; 305 std::vector<std::pair<string, Tensor>> inputs; 306 inputs.emplace_back(NAME_A, input_a); 307 308 // 5.5 Setup output 309 const std::vector<string> outputs{REMOTE_FUSED_EXECUTE_OP_NODE_NAME}; 310 311 // 5.6 Run inference with all node as output 312 status = session->Run(run_options, inputs, outputs, {}, &output_tensors, 313 &run_metadata); 314 ASSERT_TRUE(status.ok()); 315 316 // 5.7 Check output tensor value 317 ASSERT_EQ(1, output_tensors.size()); 318 EXPECT_NEAR(NODE_A_VAL2 + NODE_B_VAL, 319 output_tensors.at(0).flat<float>().data()[0], 320 FLOAT_VALUE_TOLERANCE); 321 } 322 323 //////////////////////////// 324 // End-to-end test: End // 325 //////////////////////////// 326 327 } // namespace tensorflow 328