1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/distributed_runtime/message_wrappers.h" 17 18 #include "tensorflow/core/framework/cost_graph.pb.h" 19 #include "tensorflow/core/framework/step_stats.pb.h" 20 #include "tensorflow/core/framework/tensor_testutil.h" 21 #include "tensorflow/core/lib/core/status_test_util.h" 22 #include "tensorflow/core/platform/test.h" 23 #include "tensorflow/core/protobuf/config.pb.h" 24 25 namespace tensorflow { 26 namespace { 27 28 Tensor TensorA() { 29 Tensor a_tensor(DT_INT32, TensorShape({2, 2})); 30 test::FillValues<int32>(&a_tensor, {3, 2, -1, 0}); 31 return a_tensor; 32 } 33 34 Tensor TensorB() { 35 Tensor b_tensor(DT_INT32, TensorShape({1, 2})); 36 test::FillValues<int32>(&b_tensor, {1, 2}); 37 return b_tensor; 38 } 39 40 void BuildRunStepRequest(MutableRunStepRequestWrapper* request) { 41 request->set_session_handle("handle"); 42 request->set_partial_run_handle("partial_handle"); 43 request->add_feed("feed_a:0", TensorA()); 44 request->add_feed("feed_b:0", TensorB()); 45 request->add_fetch("fetch_x:0"); 46 request->add_fetch("fetch_y:0"); 47 request->add_target("target_i"); 48 request->add_target("target_j"); 49 request->mutable_options()->set_timeout_in_ms(37); 50 } 51 52 void CheckRunStepRequest(const RunStepRequestWrapper& request) { 53 EXPECT_EQ("handle", request.session_handle()); 54 EXPECT_EQ("partial_handle", request.partial_run_handle()); 55 EXPECT_EQ(2, request.num_feeds()); 56 EXPECT_EQ("feed_a:0", request.feed_name(0)); 57 EXPECT_EQ("feed_b:0", request.feed_name(1)); 58 Tensor val; 59 TF_EXPECT_OK(request.FeedValue(0, &val)); 60 test::ExpectTensorEqual<int32>(TensorA(), val); 61 TF_EXPECT_OK(request.FeedValue(1, &val)); 62 test::ExpectTensorEqual<int32>(TensorB(), val); 63 64 EXPECT_EQ(2, request.num_fetches()); 65 EXPECT_EQ("fetch_x:0", request.fetch_name(0)); 66 EXPECT_EQ("fetch_y:0", request.fetch_name(1)); 67 EXPECT_EQ("target_i", request.target_name(0)); 68 EXPECT_EQ("target_j", request.target_name(1)); 69 EXPECT_EQ(37, request.options().timeout_in_ms()); 70 } 71 72 void BuildRunGraphRequest(const RunStepRequestWrapper& run_step_request, 73 MutableRunGraphRequestWrapper* run_graph_request) { 74 run_graph_request->set_graph_handle("graph_handle"); 75 run_graph_request->set_step_id(13); 76 run_graph_request->mutable_exec_opts()->set_record_timeline(true); 77 TF_EXPECT_OK(run_graph_request->AddSendFromRunStepRequest(run_step_request, 0, 78 "send_0")); 79 TF_EXPECT_OK(run_graph_request->AddSendFromRunStepRequest(run_step_request, 1, 80 "send_1")); 81 run_graph_request->add_recv_key("recv_2"); 82 run_graph_request->add_recv_key("recv_3"); 83 run_graph_request->set_is_partial(true); 84 } 85 86 void CheckRunGraphRequest(const RunGraphRequestWrapper& request) { 87 EXPECT_EQ("graph_handle", request.graph_handle()); 88 EXPECT_EQ(13, request.step_id()); 89 EXPECT_FALSE(request.exec_opts().record_costs()); 90 EXPECT_TRUE(request.exec_opts().record_timeline()); 91 EXPECT_FALSE(request.exec_opts().record_partition_graphs()); 92 EXPECT_EQ(2, request.num_sends()); 93 Tensor val; 94 TF_EXPECT_OK(request.SendValue(0, &val)); 95 test::ExpectTensorEqual<int32>(TensorA(), val); 96 TF_EXPECT_OK(request.SendValue(1, &val)); 97 test::ExpectTensorEqual<int32>(TensorB(), val); 98 EXPECT_TRUE(request.is_partial()); 99 EXPECT_FALSE(request.is_last_partial_run()); 100 } 101 102 void BuildRunGraphResponse(MutableRunGraphResponseWrapper* run_graph_response) { 103 run_graph_response->AddRecv("recv_2", TensorA()); 104 run_graph_response->AddRecv("recv_3", TensorB()); 105 run_graph_response->mutable_step_stats()->add_dev_stats()->set_device( 106 "/cpu:0"); 107 run_graph_response->mutable_cost_graph()->add_node()->set_name("cost_node"); 108 GraphDef graph_def; 109 graph_def.mutable_versions()->set_producer(1234); 110 graph_def.mutable_versions()->set_min_consumer(1234); 111 run_graph_response->AddPartitionGraph(graph_def); 112 } 113 114 void CheckRunGraphResponse(MutableRunGraphResponseWrapper* response) { 115 ASSERT_EQ(2, response->num_recvs()); 116 EXPECT_EQ("recv_2", response->recv_key(0)); 117 EXPECT_EQ("recv_3", response->recv_key(1)); 118 Tensor val; 119 TF_EXPECT_OK(response->RecvValue(0, &val)); 120 test::ExpectTensorEqual<int32>(TensorA(), val); 121 TF_EXPECT_OK(response->RecvValue(1, &val)); 122 test::ExpectTensorEqual<int32>(TensorB(), val); 123 ASSERT_EQ(1, response->mutable_step_stats()->dev_stats_size()); 124 EXPECT_EQ("/cpu:0", response->mutable_step_stats()->dev_stats(0).device()); 125 ASSERT_EQ(1, response->mutable_cost_graph()->node_size()); 126 EXPECT_EQ("cost_node", response->mutable_cost_graph()->node(0).name()); 127 ASSERT_EQ(1, response->num_partition_graphs()); 128 EXPECT_EQ(1234, response->mutable_partition_graph(0)->versions().producer()); 129 EXPECT_EQ(1234, 130 response->mutable_partition_graph(0)->versions().min_consumer()); 131 } 132 133 void BuildRunStepResponse(MutableRunGraphResponseWrapper* run_graph_response, 134 MutableRunStepResponseWrapper* run_step_response) { 135 TF_EXPECT_OK(run_step_response->AddTensorFromRunGraphResponse( 136 "fetch_x:0", run_graph_response, 0)); 137 TF_EXPECT_OK(run_step_response->AddTensorFromRunGraphResponse( 138 "fetch_y:0", run_graph_response, 1)); 139 *run_step_response->mutable_metadata()->mutable_step_stats() = 140 *run_graph_response->mutable_step_stats(); 141 protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs = 142 run_step_response->mutable_metadata()->mutable_partition_graphs(); 143 for (size_t i = 0; i < run_graph_response->num_partition_graphs(); i++) { 144 partition_graph_defs->Add()->Swap( 145 run_graph_response->mutable_partition_graph(i)); 146 } 147 } 148 149 void CheckRunStepResponse(const MutableRunStepResponseWrapper& response) { 150 ASSERT_EQ(2, response.num_tensors()); 151 EXPECT_EQ("fetch_x:0", response.tensor_name(0)); 152 EXPECT_EQ("fetch_y:0", response.tensor_name(1)); 153 Tensor val; 154 TF_EXPECT_OK(response.TensorValue(0, &val)); 155 test::ExpectTensorEqual<int32>(TensorA(), val); 156 TF_EXPECT_OK(response.TensorValue(1, &val)); 157 test::ExpectTensorEqual<int32>(TensorB(), val); 158 ASSERT_EQ(1, response.metadata().step_stats().dev_stats_size()); 159 EXPECT_EQ("/cpu:0", response.metadata().step_stats().dev_stats(0).device()); 160 ASSERT_EQ(1, response.metadata().partition_graphs_size()); 161 EXPECT_EQ(1234, 162 response.metadata().partition_graphs(0).versions().producer()); 163 EXPECT_EQ(1234, 164 response.metadata().partition_graphs(0).versions().min_consumer()); 165 } 166 167 TEST(MessageWrappers, RunStepRequest_Basic) { 168 InMemoryRunStepRequest in_memory_request; 169 BuildRunStepRequest(&in_memory_request); 170 CheckRunStepRequest(in_memory_request); 171 172 MutableProtoRunStepRequest proto_request; 173 BuildRunStepRequest(&proto_request); 174 CheckRunStepRequest(proto_request); 175 176 CheckRunStepRequest(ProtoRunStepRequest(&in_memory_request.ToProto())); 177 CheckRunStepRequest(ProtoRunStepRequest(&proto_request.ToProto())); 178 } 179 180 TEST(MessageWrappers, RunGraphRequest_Basic) { 181 InMemoryRunStepRequest in_memory_run_step_request; 182 BuildRunStepRequest(&in_memory_run_step_request); 183 184 MutableProtoRunStepRequest mutable_proto_run_step_request; 185 BuildRunStepRequest(&mutable_proto_run_step_request); 186 187 ProtoRunStepRequest proto_run_step_request( 188 &mutable_proto_run_step_request.ToProto()); 189 190 // Client -(in memory)-> Master -(in memory)-> Worker. 191 { 192 InMemoryRunGraphRequest request; 193 BuildRunGraphRequest(in_memory_run_step_request, &request); 194 CheckRunGraphRequest(request); 195 CheckRunGraphRequest(ProtoRunGraphRequest(&request.ToProto())); 196 } 197 198 // Client -(mutable proto)-> Master -(in memory)-> Worker. 199 { 200 InMemoryRunGraphRequest request; 201 BuildRunGraphRequest(mutable_proto_run_step_request, &request); 202 CheckRunGraphRequest(request); 203 CheckRunGraphRequest(ProtoRunGraphRequest(&request.ToProto())); 204 } 205 206 // Client -(proto)-> Master -(in memory)-> Worker. 207 { 208 InMemoryRunGraphRequest request; 209 BuildRunGraphRequest(proto_run_step_request, &request); 210 CheckRunGraphRequest(request); 211 CheckRunGraphRequest(ProtoRunGraphRequest(&request.ToProto())); 212 } 213 214 // Client -(in memory)-> Master -(mutable proto)-> Worker. 215 { 216 MutableProtoRunGraphRequest request; 217 BuildRunGraphRequest(in_memory_run_step_request, &request); 218 CheckRunGraphRequest(request); 219 CheckRunGraphRequest(ProtoRunGraphRequest(&request.ToProto())); 220 } 221 222 // Client -(mutable proto)-> Master -(mutable proto)-> Worker. 223 { 224 MutableProtoRunGraphRequest request; 225 BuildRunGraphRequest(mutable_proto_run_step_request, &request); 226 CheckRunGraphRequest(request); 227 CheckRunGraphRequest(ProtoRunGraphRequest(&request.ToProto())); 228 } 229 230 // Client -(proto)-> Master -(mutable proto)-> Worker. 231 { 232 MutableProtoRunGraphRequest request; 233 BuildRunGraphRequest(proto_run_step_request, &request); 234 CheckRunGraphRequest(request); 235 CheckRunGraphRequest(ProtoRunGraphRequest(&request.ToProto())); 236 } 237 } 238 239 TEST(MessageWrappers, RunGraphResponse_Basic) { 240 InMemoryRunGraphResponse in_memory_response; 241 BuildRunGraphResponse(&in_memory_response); 242 CheckRunGraphResponse(&in_memory_response); 243 244 OwnedProtoRunGraphResponse owned_proto_response; 245 BuildRunGraphResponse(&owned_proto_response); 246 CheckRunGraphResponse(&owned_proto_response); 247 248 RunGraphResponse response_proto; 249 NonOwnedProtoRunGraphResponse non_owned_proto_response(&response_proto); 250 BuildRunGraphResponse(&non_owned_proto_response); 251 CheckRunGraphResponse(&non_owned_proto_response); 252 } 253 254 TEST(MessageWrappers, RunStepResponse_Basic) { 255 { 256 // Worker -(in memory)-> Master -(in memory)-> Client. 257 InMemoryRunGraphResponse run_graph_response; 258 BuildRunGraphResponse(&run_graph_response); 259 InMemoryRunStepResponse response; 260 BuildRunStepResponse(&run_graph_response, &response); 261 CheckRunStepResponse(response); 262 } 263 264 { 265 // Worker -(in memory)-> Master -(owned proto)-> Client. 266 InMemoryRunGraphResponse run_graph_response; 267 BuildRunGraphResponse(&run_graph_response); 268 OwnedProtoRunStepResponse response; 269 BuildRunStepResponse(&run_graph_response, &response); 270 CheckRunStepResponse(response); 271 } 272 273 { 274 // Worker -(in memory)-> Master -(non-owned proto)-> Client. 275 InMemoryRunGraphResponse run_graph_response; 276 BuildRunGraphResponse(&run_graph_response); 277 RunStepResponse response_proto; 278 NonOwnedProtoRunStepResponse response(&response_proto); 279 BuildRunStepResponse(&run_graph_response, &response); 280 CheckRunStepResponse(response); 281 } 282 283 { 284 // Worker -(owned proto)-> Master -(in memory)-> Client. 285 OwnedProtoRunGraphResponse run_graph_response; 286 BuildRunGraphResponse(&run_graph_response); 287 InMemoryRunStepResponse response; 288 BuildRunStepResponse(&run_graph_response, &response); 289 CheckRunStepResponse(response); 290 } 291 292 { 293 // Worker -(owned proto)-> Master -(owned proto)-> Client. 294 OwnedProtoRunGraphResponse run_graph_response; 295 BuildRunGraphResponse(&run_graph_response); 296 OwnedProtoRunStepResponse response; 297 BuildRunStepResponse(&run_graph_response, &response); 298 CheckRunStepResponse(response); 299 } 300 301 { 302 // Worker -(owned proto)-> Master -(non-owned proto)-> Client. 303 OwnedProtoRunGraphResponse run_graph_response; 304 BuildRunGraphResponse(&run_graph_response); 305 RunStepResponse response_proto; 306 NonOwnedProtoRunStepResponse response(&response_proto); 307 BuildRunStepResponse(&run_graph_response, &response); 308 CheckRunStepResponse(response); 309 } 310 311 { 312 // Worker -(non-owned proto)-> Master -(in memory)-> Client. 313 RunGraphResponse run_graph_response_proto; 314 NonOwnedProtoRunGraphResponse run_graph_response(&run_graph_response_proto); 315 BuildRunGraphResponse(&run_graph_response); 316 InMemoryRunStepResponse response; 317 BuildRunStepResponse(&run_graph_response, &response); 318 CheckRunStepResponse(response); 319 } 320 321 { 322 // Worker -(non-owned proto)-> Master -(owned proto)-> Client. 323 RunGraphResponse run_graph_response_proto; 324 NonOwnedProtoRunGraphResponse run_graph_response(&run_graph_response_proto); 325 BuildRunGraphResponse(&run_graph_response); 326 OwnedProtoRunStepResponse response; 327 BuildRunStepResponse(&run_graph_response, &response); 328 CheckRunStepResponse(response); 329 } 330 331 { 332 // Worker -(non-owned proto)-> Master -(non-owned proto)-> Client. 333 RunGraphResponse run_graph_response_proto; 334 NonOwnedProtoRunGraphResponse run_graph_response(&run_graph_response_proto); 335 BuildRunGraphResponse(&run_graph_response); 336 RunStepResponse response_proto; 337 NonOwnedProtoRunStepResponse response(&response_proto); 338 BuildRunStepResponse(&run_graph_response, &response); 339 CheckRunStepResponse(response); 340 } 341 } 342 343 } // namespace 344 } // namespace tensorflow 345