Home | History | Annotate | Download | only in distributed_runtime
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
     17 
     18 #include "tensorflow/core/framework/cost_graph.pb.h"
     19 #include "tensorflow/core/framework/step_stats.pb.h"
     20 #include "tensorflow/core/framework/tensor_testutil.h"
     21 #include "tensorflow/core/lib/core/status_test_util.h"
     22 #include "tensorflow/core/platform/test.h"
     23 #include "tensorflow/core/protobuf/config.pb.h"
     24 
     25 namespace tensorflow {
     26 namespace {
     27 
     28 Tensor TensorA() {
     29   Tensor a_tensor(DT_INT32, TensorShape({2, 2}));
     30   test::FillValues<int32>(&a_tensor, {3, 2, -1, 0});
     31   return a_tensor;
     32 }
     33 
     34 Tensor TensorB() {
     35   Tensor b_tensor(DT_INT32, TensorShape({1, 2}));
     36   test::FillValues<int32>(&b_tensor, {1, 2});
     37   return b_tensor;
     38 }
     39 
     40 void BuildRunStepRequest(MutableRunStepRequestWrapper* request) {
     41   request->set_session_handle("handle");
     42   request->set_partial_run_handle("partial_handle");
     43   request->add_feed("feed_a:0", TensorA());
     44   request->add_feed("feed_b:0", TensorB());
     45   request->add_fetch("fetch_x:0");
     46   request->add_fetch("fetch_y:0");
     47   request->add_target("target_i");
     48   request->add_target("target_j");
     49   request->mutable_options()->set_timeout_in_ms(37);
     50 }
     51 
     52 void CheckRunStepRequest(const RunStepRequestWrapper& request) {
     53   EXPECT_EQ("handle", request.session_handle());
     54   EXPECT_EQ("partial_handle", request.partial_run_handle());
     55   EXPECT_EQ(2, request.num_feeds());
     56   EXPECT_EQ("feed_a:0", request.feed_name(0));
     57   EXPECT_EQ("feed_b:0", request.feed_name(1));
     58   Tensor val;
     59   TF_EXPECT_OK(request.FeedValue(0, &val));
     60   test::ExpectTensorEqual<int32>(TensorA(), val);
     61   TF_EXPECT_OK(request.FeedValue(1, &val));
     62   test::ExpectTensorEqual<int32>(TensorB(), val);
     63 
     64   EXPECT_EQ(2, request.num_fetches());
     65   EXPECT_EQ("fetch_x:0", request.fetch_name(0));
     66   EXPECT_EQ("fetch_y:0", request.fetch_name(1));
     67   EXPECT_EQ("target_i", request.target_name(0));
     68   EXPECT_EQ("target_j", request.target_name(1));
     69   EXPECT_EQ(37, request.options().timeout_in_ms());
     70 }
     71 
     72 void BuildRunGraphRequest(const RunStepRequestWrapper& run_step_request,
     73                           MutableRunGraphRequestWrapper* run_graph_request) {
     74   run_graph_request->set_graph_handle("graph_handle");
     75   run_graph_request->set_step_id(13);
     76   run_graph_request->mutable_exec_opts()->set_record_timeline(true);
     77   TF_EXPECT_OK(run_graph_request->AddSendFromRunStepRequest(run_step_request, 0,
     78                                                             "send_0"));
     79   TF_EXPECT_OK(run_graph_request->AddSendFromRunStepRequest(run_step_request, 1,
     80                                                             "send_1"));
     81   run_graph_request->add_recv_key("recv_2");
     82   run_graph_request->add_recv_key("recv_3");
     83   run_graph_request->set_is_partial(true);
     84 }
     85 
     86 void CheckRunGraphRequest(const RunGraphRequestWrapper& request) {
     87   EXPECT_EQ("graph_handle", request.graph_handle());
     88   EXPECT_EQ(13, request.step_id());
     89   EXPECT_FALSE(request.exec_opts().record_costs());
     90   EXPECT_TRUE(request.exec_opts().record_timeline());
     91   EXPECT_FALSE(request.exec_opts().record_partition_graphs());
     92   EXPECT_EQ(2, request.num_sends());
     93   Tensor val;
     94   TF_EXPECT_OK(request.SendValue(0, &val));
     95   test::ExpectTensorEqual<int32>(TensorA(), val);
     96   TF_EXPECT_OK(request.SendValue(1, &val));
     97   test::ExpectTensorEqual<int32>(TensorB(), val);
     98   EXPECT_TRUE(request.is_partial());
     99   EXPECT_FALSE(request.is_last_partial_run());
    100 }
    101 
    102 void BuildRunGraphResponse(MutableRunGraphResponseWrapper* run_graph_response) {
    103   run_graph_response->AddRecv("recv_2", TensorA());
    104   run_graph_response->AddRecv("recv_3", TensorB());
    105   run_graph_response->mutable_step_stats()->add_dev_stats()->set_device(
    106       "/cpu:0");
    107   run_graph_response->mutable_cost_graph()->add_node()->set_name("cost_node");
    108   GraphDef graph_def;
    109   graph_def.mutable_versions()->set_producer(1234);
    110   graph_def.mutable_versions()->set_min_consumer(1234);
    111   run_graph_response->AddPartitionGraph(graph_def);
    112 }
    113 
    114 void CheckRunGraphResponse(MutableRunGraphResponseWrapper* response) {
    115   ASSERT_EQ(2, response->num_recvs());
    116   EXPECT_EQ("recv_2", response->recv_key(0));
    117   EXPECT_EQ("recv_3", response->recv_key(1));
    118   Tensor val;
    119   TF_EXPECT_OK(response->RecvValue(0, &val));
    120   test::ExpectTensorEqual<int32>(TensorA(), val);
    121   TF_EXPECT_OK(response->RecvValue(1, &val));
    122   test::ExpectTensorEqual<int32>(TensorB(), val);
    123   ASSERT_EQ(1, response->mutable_step_stats()->dev_stats_size());
    124   EXPECT_EQ("/cpu:0", response->mutable_step_stats()->dev_stats(0).device());
    125   ASSERT_EQ(1, response->mutable_cost_graph()->node_size());
    126   EXPECT_EQ("cost_node", response->mutable_cost_graph()->node(0).name());
    127   ASSERT_EQ(1, response->num_partition_graphs());
    128   EXPECT_EQ(1234, response->mutable_partition_graph(0)->versions().producer());
    129   EXPECT_EQ(1234,
    130             response->mutable_partition_graph(0)->versions().min_consumer());
    131 }
    132 
    133 void BuildRunStepResponse(MutableRunGraphResponseWrapper* run_graph_response,
    134                           MutableRunStepResponseWrapper* run_step_response) {
    135   TF_EXPECT_OK(run_step_response->AddTensorFromRunGraphResponse(
    136       "fetch_x:0", run_graph_response, 0));
    137   TF_EXPECT_OK(run_step_response->AddTensorFromRunGraphResponse(
    138       "fetch_y:0", run_graph_response, 1));
    139   *run_step_response->mutable_metadata()->mutable_step_stats() =
    140       *run_graph_response->mutable_step_stats();
    141   protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
    142       run_step_response->mutable_metadata()->mutable_partition_graphs();
    143   for (size_t i = 0; i < run_graph_response->num_partition_graphs(); i++) {
    144     partition_graph_defs->Add()->Swap(
    145         run_graph_response->mutable_partition_graph(i));
    146   }
    147 }
    148 
    149 void CheckRunStepResponse(const MutableRunStepResponseWrapper& response) {
    150   ASSERT_EQ(2, response.num_tensors());
    151   EXPECT_EQ("fetch_x:0", response.tensor_name(0));
    152   EXPECT_EQ("fetch_y:0", response.tensor_name(1));
    153   Tensor val;
    154   TF_EXPECT_OK(response.TensorValue(0, &val));
    155   test::ExpectTensorEqual<int32>(TensorA(), val);
    156   TF_EXPECT_OK(response.TensorValue(1, &val));
    157   test::ExpectTensorEqual<int32>(TensorB(), val);
    158   ASSERT_EQ(1, response.metadata().step_stats().dev_stats_size());
    159   EXPECT_EQ("/cpu:0", response.metadata().step_stats().dev_stats(0).device());
    160   ASSERT_EQ(1, response.metadata().partition_graphs_size());
    161   EXPECT_EQ(1234,
    162             response.metadata().partition_graphs(0).versions().producer());
    163   EXPECT_EQ(1234,
    164             response.metadata().partition_graphs(0).versions().min_consumer());
    165 }
    166 
    167 TEST(MessageWrappers, RunStepRequest_Basic) {
    168   InMemoryRunStepRequest in_memory_request;
    169   BuildRunStepRequest(&in_memory_request);
    170   CheckRunStepRequest(in_memory_request);
    171 
    172   MutableProtoRunStepRequest proto_request;
    173   BuildRunStepRequest(&proto_request);
    174   CheckRunStepRequest(proto_request);
    175 
    176   CheckRunStepRequest(ProtoRunStepRequest(&in_memory_request.ToProto()));
    177   CheckRunStepRequest(ProtoRunStepRequest(&proto_request.ToProto()));
    178 }
    179 
    180 TEST(MessageWrappers, RunGraphRequest_Basic) {
    181   InMemoryRunStepRequest in_memory_run_step_request;
    182   BuildRunStepRequest(&in_memory_run_step_request);
    183 
    184   MutableProtoRunStepRequest mutable_proto_run_step_request;
    185   BuildRunStepRequest(&mutable_proto_run_step_request);
    186 
    187   ProtoRunStepRequest proto_run_step_request(
    188       &mutable_proto_run_step_request.ToProto());
    189 
    190   // Client -(in memory)-> Master -(in memory)-> Worker.
    191   {
    192     InMemoryRunGraphRequest request;
    193     BuildRunGraphRequest(in_memory_run_step_request, &request);
    194     CheckRunGraphRequest(request);
    195     CheckRunGraphRequest(ProtoRunGraphRequest(&request.ToProto()));
    196   }
    197 
    198   // Client -(mutable proto)-> Master -(in memory)-> Worker.
    199   {
    200     InMemoryRunGraphRequest request;
    201     BuildRunGraphRequest(mutable_proto_run_step_request, &request);
    202     CheckRunGraphRequest(request);
    203     CheckRunGraphRequest(ProtoRunGraphRequest(&request.ToProto()));
    204   }
    205 
    206   // Client -(proto)-> Master -(in memory)-> Worker.
    207   {
    208     InMemoryRunGraphRequest request;
    209     BuildRunGraphRequest(proto_run_step_request, &request);
    210     CheckRunGraphRequest(request);
    211     CheckRunGraphRequest(ProtoRunGraphRequest(&request.ToProto()));
    212   }
    213 
    214   // Client -(in memory)-> Master -(mutable proto)-> Worker.
    215   {
    216     MutableProtoRunGraphRequest request;
    217     BuildRunGraphRequest(in_memory_run_step_request, &request);
    218     CheckRunGraphRequest(request);
    219     CheckRunGraphRequest(ProtoRunGraphRequest(&request.ToProto()));
    220   }
    221 
    222   // Client -(mutable proto)-> Master -(mutable proto)-> Worker.
    223   {
    224     MutableProtoRunGraphRequest request;
    225     BuildRunGraphRequest(mutable_proto_run_step_request, &request);
    226     CheckRunGraphRequest(request);
    227     CheckRunGraphRequest(ProtoRunGraphRequest(&request.ToProto()));
    228   }
    229 
    230   // Client -(proto)-> Master -(mutable proto)-> Worker.
    231   {
    232     MutableProtoRunGraphRequest request;
    233     BuildRunGraphRequest(proto_run_step_request, &request);
    234     CheckRunGraphRequest(request);
    235     CheckRunGraphRequest(ProtoRunGraphRequest(&request.ToProto()));
    236   }
    237 }
    238 
    239 TEST(MessageWrappers, RunGraphResponse_Basic) {
    240   InMemoryRunGraphResponse in_memory_response;
    241   BuildRunGraphResponse(&in_memory_response);
    242   CheckRunGraphResponse(&in_memory_response);
    243 
    244   OwnedProtoRunGraphResponse owned_proto_response;
    245   BuildRunGraphResponse(&owned_proto_response);
    246   CheckRunGraphResponse(&owned_proto_response);
    247 
    248   RunGraphResponse response_proto;
    249   NonOwnedProtoRunGraphResponse non_owned_proto_response(&response_proto);
    250   BuildRunGraphResponse(&non_owned_proto_response);
    251   CheckRunGraphResponse(&non_owned_proto_response);
    252 }
    253 
    254 TEST(MessageWrappers, RunStepResponse_Basic) {
    255   {
    256     // Worker -(in memory)-> Master -(in memory)-> Client.
    257     InMemoryRunGraphResponse run_graph_response;
    258     BuildRunGraphResponse(&run_graph_response);
    259     InMemoryRunStepResponse response;
    260     BuildRunStepResponse(&run_graph_response, &response);
    261     CheckRunStepResponse(response);
    262   }
    263 
    264   {
    265     // Worker -(in memory)-> Master -(owned proto)-> Client.
    266     InMemoryRunGraphResponse run_graph_response;
    267     BuildRunGraphResponse(&run_graph_response);
    268     OwnedProtoRunStepResponse response;
    269     BuildRunStepResponse(&run_graph_response, &response);
    270     CheckRunStepResponse(response);
    271   }
    272 
    273   {
    274     // Worker -(in memory)-> Master -(non-owned proto)-> Client.
    275     InMemoryRunGraphResponse run_graph_response;
    276     BuildRunGraphResponse(&run_graph_response);
    277     RunStepResponse response_proto;
    278     NonOwnedProtoRunStepResponse response(&response_proto);
    279     BuildRunStepResponse(&run_graph_response, &response);
    280     CheckRunStepResponse(response);
    281   }
    282 
    283   {
    284     // Worker -(owned proto)-> Master -(in memory)-> Client.
    285     OwnedProtoRunGraphResponse run_graph_response;
    286     BuildRunGraphResponse(&run_graph_response);
    287     InMemoryRunStepResponse response;
    288     BuildRunStepResponse(&run_graph_response, &response);
    289     CheckRunStepResponse(response);
    290   }
    291 
    292   {
    293     // Worker -(owned proto)-> Master -(owned proto)-> Client.
    294     OwnedProtoRunGraphResponse run_graph_response;
    295     BuildRunGraphResponse(&run_graph_response);
    296     OwnedProtoRunStepResponse response;
    297     BuildRunStepResponse(&run_graph_response, &response);
    298     CheckRunStepResponse(response);
    299   }
    300 
    301   {
    302     // Worker -(owned proto)-> Master -(non-owned proto)-> Client.
    303     OwnedProtoRunGraphResponse run_graph_response;
    304     BuildRunGraphResponse(&run_graph_response);
    305     RunStepResponse response_proto;
    306     NonOwnedProtoRunStepResponse response(&response_proto);
    307     BuildRunStepResponse(&run_graph_response, &response);
    308     CheckRunStepResponse(response);
    309   }
    310 
    311   {
    312     // Worker -(non-owned proto)-> Master -(in memory)-> Client.
    313     RunGraphResponse run_graph_response_proto;
    314     NonOwnedProtoRunGraphResponse run_graph_response(&run_graph_response_proto);
    315     BuildRunGraphResponse(&run_graph_response);
    316     InMemoryRunStepResponse response;
    317     BuildRunStepResponse(&run_graph_response, &response);
    318     CheckRunStepResponse(response);
    319   }
    320 
    321   {
    322     // Worker -(non-owned proto)-> Master -(owned proto)-> Client.
    323     RunGraphResponse run_graph_response_proto;
    324     NonOwnedProtoRunGraphResponse run_graph_response(&run_graph_response_proto);
    325     BuildRunGraphResponse(&run_graph_response);
    326     OwnedProtoRunStepResponse response;
    327     BuildRunStepResponse(&run_graph_response, &response);
    328     CheckRunStepResponse(response);
    329   }
    330 
    331   {
    332     // Worker -(non-owned proto)-> Master -(non-owned proto)-> Client.
    333     RunGraphResponse run_graph_response_proto;
    334     NonOwnedProtoRunGraphResponse run_graph_response(&run_graph_response_proto);
    335     BuildRunGraphResponse(&run_graph_response);
    336     RunStepResponse response_proto;
    337     NonOwnedProtoRunStepResponse response(&response_proto);
    338     BuildRunStepResponse(&run_graph_response, &response);
    339     CheckRunStepResponse(response);
    340   }
    341 }
    342 
    343 }  // namespace
    344 }  // namespace tensorflow
    345