Home | History | Annotate | Download | only in distributed_runtime
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/distributed_runtime/master.h"
     17 
     18 #include <map>
     19 #include <memory>
     20 
     21 #include "grpc++/grpc++.h"
     22 
     23 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
     24 #include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
     25 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
     26 #include "tensorflow/core/framework/allocator.h"
     27 #include "tensorflow/core/framework/tensor.h"
     28 #include "tensorflow/core/framework/tensor_testutil.h"
     29 #include "tensorflow/core/graph/testlib.h"
     30 #include "tensorflow/core/lib/core/errors.h"
     31 #include "tensorflow/core/lib/core/notification.h"
     32 #include "tensorflow/core/lib/core/status_test_util.h"
     33 #include "tensorflow/core/lib/core/threadpool.h"
     34 #include "tensorflow/core/lib/gtl/map_util.h"
     35 #include "tensorflow/core/platform/logging.h"
     36 #include "tensorflow/core/platform/mutex.h"
     37 #include "tensorflow/core/platform/test.h"
     38 #include "tensorflow/core/platform/types.h"
     39 #include "tensorflow/core/protobuf/master.pb.h"
     40 #include "tensorflow/core/protobuf/master_service.grpc.pb.h"
     41 
     42 namespace tensorflow {
     43 
     44 class MasterTest : public ::testing::Test {
     45  protected:
     46   MasterTest() {
     47     std::vector<string> targets;
     48     SessionOptions options;
     49     (*options.config.mutable_device_count())["CPU"] = 1;
     50     (*options.config.mutable_device_count())["GPU"] = 0;
     51     TF_CHECK_OK(test::TestCluster::MakeTestCluster(options, 2, &cluster_));
     52     SharedGrpcChannelPtr channel_ptr;
     53     TF_CHECK_OK(NewHostPortGrpcChannel(cluster_->targets()[0], &channel_ptr));
     54     master_ = grpc::MasterService::NewStub(channel_ptr);
     55   }
     56 
     57   std::unique_ptr<test::TestCluster> cluster_;
     58   std::unique_ptr<grpc::MasterService::Stub> master_;
     59 
     60   // Helpers for MasterService.{CreateSession,RunStep,CloseSession}
     61   // rpc calls.
     62 
     63   Status CreateSession(const GraphDef& def, string* handle,
     64                        int64* initial_version) {
     65     ::grpc::ClientContext ctx;
     66     CreateSessionRequest req;
     67     *(req.mutable_graph_def()) = def;
     68     // Invokes placement frequently.
     69     req.mutable_config()->set_placement_period(1);
     70     CreateSessionResponse resp;
     71     const Status s = FromGrpcStatus(master_->CreateSession(&ctx, req, &resp));
     72     if (s.ok()) {
     73       *handle = resp.session_handle();
     74       *initial_version = resp.graph_version();
     75     }
     76     return s;
     77   }
     78 
     79   Status ExtendSession(const string& handle, const GraphDef& def,
     80                        int64 current_version, int64* new_version) {
     81     ::grpc::ClientContext ctx;
     82     ExtendSessionRequest req;
     83     req.set_session_handle(handle);
     84     *(req.mutable_graph_def()) = def;
     85     req.set_current_graph_version(current_version);
     86     ExtendSessionResponse resp;
     87     const Status s = FromGrpcStatus(master_->ExtendSession(&ctx, req, &resp));
     88     if (s.ok()) {
     89       *new_version = resp.new_graph_version();
     90     }
     91     return s;
     92   }
     93 
     94   Status RunStep(const string& handle,
     95                  const std::vector<std::pair<string, const Tensor*> >& feed,
     96                  const std::map<string, Tensor*>& fetch) {
     97     ::grpc::ClientContext ctx;
     98     RunStepRequest req;
     99     req.set_session_handle(handle);
    100     for (const auto& p : feed) {
    101       const string& feed_name = p.first;
    102       const Tensor* feed_tensor = p.second;
    103       auto f = req.add_feed();
    104       f->set_name(feed_name);
    105       feed_tensor->AsProtoTensorContent(f->mutable_tensor());
    106     }
    107     for (const auto& p : fetch) {
    108       const string& fetch_name = p.first;
    109       req.add_fetch(fetch_name);
    110     }
    111     RunStepResponse resp;
    112     const Status s = FromGrpcStatus(master_->RunStep(&ctx, req, &resp));
    113     if (s.ok()) {
    114       for (const auto& fetch_resp : resp.tensor()) {
    115         auto it = fetch.find(fetch_resp.name());
    116         CHECK(it != fetch.end());
    117         CHECK(it->second->FromProto(fetch_resp.tensor()));
    118       }
    119     }
    120     return s;
    121   }
    122 
    123   Status CloseSession(const string& handle) {
    124     ::grpc::ClientContext ctx;
    125     CloseSessionRequest req;
    126     req.set_session_handle(handle);
    127     CloseSessionResponse resp;
    128     return FromGrpcStatus(master_->CloseSession(&ctx, req, &resp));
    129   }
    130 
    131   Status Reset() {
    132     ::grpc::ClientContext ctx;
    133     ResetRequest req;
    134     ResetResponse resp;
    135     return FromGrpcStatus(master_->Reset(&ctx, req, &resp));
    136   }
    137 };
    138 
    139 TEST_F(MasterTest, CreateClose) {
    140   GraphDef def;  // Empty.
    141   string handle;
    142   int64 initial_version;
    143   TF_ASSERT_OK(CreateSession(def, &handle, &initial_version));
    144   EXPECT_TRUE(errors::IsAborted(CloseSession("randombits")));
    145   EXPECT_TRUE(CloseSession(handle).ok());
    146 }
    147 
    148 TEST_F(MasterTest, ListDevices) {
    149   ::grpc::ClientContext ctx;
    150   ListDevicesRequest req;
    151   ListDevicesResponse resp;
    152   const Status s = FromGrpcStatus(master_->ListDevices(&ctx, req, &resp));
    153   TF_EXPECT_OK(s);
    154   EXPECT_EQ(1, resp.local_device_size());
    155   EXPECT_EQ("CPU", resp.local_device(0).device_type());
    156 }
    157 
    158 TEST_F(MasterTest, Reset) {
    159   GraphDef def;  // Empty.
    160   string s1, s2;
    161   int64 initial_version1, initial_version2;
    162   TF_ASSERT_OK(CreateSession(def, &s1, &initial_version1));
    163   TF_ASSERT_OK(CreateSession(def, &s2, &initial_version2));
    164   EXPECT_TRUE(Reset().ok());
    165   EXPECT_TRUE(errors::IsAborted(CloseSession(s1)));
    166   EXPECT_TRUE(errors::IsAborted(CloseSession(s2)));
    167 }
    168 
    169 TEST_F(MasterTest, Extend) {
    170   GraphDef def_0;  // Empty.
    171   string handle;
    172   int64 initial_version;
    173   TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
    174 
    175   Tensor A_expected(DT_FLOAT, TensorShape({2, 2}));
    176   test::FillValues<float>(&A_expected, {3.0, 2.0, -1.0, 0.0});
    177 
    178   Tensor x_expected(DT_FLOAT, TensorShape({2, 1}));
    179   test::FillValues<float>(&x_expected, {2.0, 2.0});
    180 
    181   Graph graph_1(OpRegistry::Global());
    182   test::graph::Constant(&graph_1, A_expected, "A");
    183   GraphDef def_1;
    184   test::graph::ToGraphDef(&graph_1, &def_1);
    185   int64 version_1;
    186   TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
    187   EXPECT_GT(version_1, initial_version);
    188   Tensor A(DT_FLOAT, TensorShape({2, 2}));
    189   TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
    190   test::ExpectTensorEqual<float>(A, A_expected);
    191 
    192   Graph graph_2(OpRegistry::Global());
    193   test::graph::Constant(&graph_2, x_expected, "x");
    194   GraphDef def_2;
    195   test::graph::ToGraphDef(&graph_2, &def_2);
    196   int64 version_2;
    197   EXPECT_TRUE(errors::IsAborted(
    198       ExtendSession("randombits", def_2, version_1, &version_2)));
    199   TF_ASSERT_OK(ExtendSession(handle, def_2, version_1, &version_2));
    200   EXPECT_GT(version_2, version_1);
    201 
    202   Tensor x(DT_FLOAT, TensorShape({2, 1}));
    203   TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}, {"x:0", &x}}));
    204   test::ExpectTensorEqual<float>(A, A_expected);
    205   test::ExpectTensorEqual<float>(x, x_expected);
    206 
    207   TF_ASSERT_OK(CloseSession(handle));
    208 }
    209 
    210 TEST_F(MasterTest, ExtendUpdateStatefulFails) {
    211   GraphDef def_0;  // Empty.
    212   string handle;
    213   int64 initial_version;
    214   TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
    215 
    216   Graph graph_1(OpRegistry::Global());
    217   test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512}));
    218   GraphDef def_1;
    219   test::graph::ToGraphDef(&graph_1, &def_1);
    220 
    221   int64 version_1, version_2;
    222   TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
    223   EXPECT_GT(version_1, initial_version);
    224   EXPECT_TRUE(errors::IsInvalidArgument(
    225       ExtendSession(handle, def_1, version_1, &version_2)));
    226   TF_ASSERT_OK(CloseSession(handle));
    227 }
    228 
    229 TEST_F(MasterTest, ExtendTwiceFails) {
    230   GraphDef def_0;  // Empty.
    231   string handle;
    232   int64 initial_version;
    233   TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
    234 
    235   Graph graph_1(OpRegistry::Global());
    236   test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512}));
    237   GraphDef def_1;
    238   test::graph::ToGraphDef(&graph_1, &def_1);
    239 
    240   int64 version_1;
    241   TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
    242   EXPECT_GT(version_1, initial_version);
    243   EXPECT_TRUE(errors::IsAborted(
    244       ExtendSession(handle, def_1, initial_version, &version_1)));
    245   TF_ASSERT_OK(CloseSession(handle));
    246 }
    247 
    248 TEST_F(MasterTest, ConcurrentExtendOnlyOneSucceeds) {
    249   GraphDef def_0;  // Empty.
    250   string handle;
    251   int64 initial_version;
    252   TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
    253 
    254   Graph graph_1(OpRegistry::Global());
    255   test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512}));
    256   GraphDef def_1;
    257   test::graph::ToGraphDef(&graph_1, &def_1);
    258 
    259   Notification n;
    260   mutex mu;
    261   int succeeded = 0;
    262   int failed = 0;
    263   auto extend_fn = [this, handle, def_1, initial_version, &n, &mu, &succeeded,
    264                     &failed]() {
    265     n.WaitForNotification();
    266     int64 new_version;
    267     Status s = ExtendSession(handle, def_1, initial_version, &new_version);
    268     EXPECT_TRUE(s.ok() || errors::IsAborted(s));
    269     {
    270       mutex_lock l(mu);
    271       if (s.ok()) {
    272         ++succeeded;
    273       } else {
    274         ++failed;
    275       }
    276     }
    277   };
    278 
    279   // Run 100 concurrent Extend calls and expect only one to succeed.
    280   {
    281     thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 100);
    282     for (int i = 0; i < 100; ++i) {
    283       thread_pool.Schedule(extend_fn);
    284     }
    285     n.Notify();
    286   }
    287 
    288   EXPECT_EQ(failed, 99);
    289   EXPECT_EQ(succeeded, 1);
    290   TF_ASSERT_OK(CloseSession(handle));
    291 }
    292 
    293 TEST_F(MasterTest, ConcurrentExtendAndRun) {
    294   Graph graph_0(OpRegistry::Global());
    295   Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
    296   test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
    297   test::graph::Constant(&graph_0, a_tensor, "A");
    298   GraphDef def_0;
    299   test::graph::ToGraphDef(&graph_0, &def_0);
    300 
    301   string handle;
    302   int64 initial_version;
    303   TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
    304 
    305   Graph graph_1(OpRegistry::Global());
    306   Tensor b_tensor(DT_FLOAT, TensorShape({2, 2}));
    307   test::FillValues<float>(&b_tensor, {1, 0, 0, 1});
    308   test::graph::Constant(&graph_1, b_tensor, "B");
    309   GraphDef def_1;
    310   test::graph::ToGraphDef(&graph_1, &def_1);
    311 
    312   Notification extend_done;
    313   Notification extend_can_start;
    314 
    315   auto get_a_fn = [this, handle, &extend_done]() {
    316     Tensor A(DT_FLOAT, TensorShape({2, 2}));
    317     while (!extend_done.HasBeenNotified()) {
    318       TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
    319     }
    320     // Run at least once after the Extend has completed.
    321     TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
    322   };
    323 
    324   auto get_a_and_b_fn = [this, handle, &extend_done, &extend_can_start]() {
    325     Tensor A(DT_FLOAT, TensorShape({2, 2}));
    326     Tensor B(DT_FLOAT, TensorShape({2, 2}));
    327 
    328     // Run at least once before the Extend has completed.
    329     EXPECT_TRUE(
    330         errors::IsNotFound(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}})));
    331     extend_can_start.Notify();
    332 
    333     // Concurrent with the Extend, we will either fail (as above), or
    334     // succeed (as below).
    335     while (!extend_done.HasBeenNotified()) {
    336       Status s = RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}});
    337       EXPECT_TRUE(errors::IsNotFound(s) || s.ok());
    338     }
    339 
    340     // Run at least once after the Extend has completed.
    341     TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}}));
    342   };
    343 
    344   auto extend_fn = [this, handle, def_1, initial_version, &extend_done,
    345                     &extend_can_start]() {
    346     extend_can_start.WaitForNotification();
    347     int64 version_1;
    348     TF_ASSERT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
    349     extend_done.Notify();
    350   };
    351 
    352   {
    353     thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 3);
    354     thread_pool.Schedule(get_a_fn);
    355     thread_pool.Schedule(get_a_and_b_fn);
    356     thread_pool.Schedule(extend_fn);
    357   }
    358 
    359   TF_ASSERT_OK(CloseSession(handle));
    360 }
    361 
    362 TEST_F(MasterTest, EigenProblem) {
    363   // A = [3 2; -1 0]; x = rand(2, 1);
    364   // for i=1:100; x = A * x; end
    365   // We'll try to compute the largest eigenvalue for A.
    366   Graph graph(OpRegistry::Global());
    367   Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
    368   // Store rows [3, 2] and [-1, 0] in row major format.
    369   test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
    370   Node* a_node = test::graph::Constant(&graph, a_tensor);
    371 
    372   // x is from the feed.
    373   Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
    374   test::FillValues<float>(&x_tensor, {0, 0});
    375   Node* x_node = test::graph::Constant(&graph, x_tensor);
    376 
    377   // y = A * x
    378   Node* y_node = test::graph::Matmul(&graph, a_node, x_node, false, false);
    379 
    380   GraphDef def;
    381   test::graph::ToGraphDef(&graph, &def);
    382 
    383   string handle;
    384   int64 initial_version;
    385   TF_CHECK_OK(CreateSession(def, &handle, &initial_version));
    386 
    387   // Temps supporting the computation of the convergence condition.
    388   const Eigen::array<Eigen::DenseIndex, 1> sum_along_dim(0);
    389   const Eigen::array<Eigen::DenseIndex, 2> matrix_transpose({1, 0});
    390   Tensor x(DT_FLOAT, TensorShape({2, 1}));
    391   Tensor y(DT_FLOAT, TensorShape({2, 1}));
    392   Eigen::Tensor<float, 1, Eigen::RowMajor> y_square_sum;
    393   Eigen::Tensor<float, 2, Eigen::RowMajor> y_normalized(2, 1);
    394   y_normalized.setRandom();
    395   Eigen::Tensor<float, 1, Eigen::RowMajor> error_square_sum;
    396   float lambda;
    397 
    398   // The computation loop.
    399   bool converged = false;
    400   while (!converged) {
    401     // Run one step of the graph.
    402     auto x_matrix = x.matrix<float>();
    403     x_matrix = y_normalized;
    404     TF_EXPECT_OK(
    405         RunStep(handle, {{x_node->name(), &x}}, {{y_node->name() + ":0", &y}}));
    406     auto y_matrix = y.matrix<float>();
    407 
    408     // Client code computes the convergence condition.
    409     {
    410       lambda = y_matrix(0, 0) / x_matrix(0, 0);
    411       y_square_sum = y.matrix<float>().square().sum(sum_along_dim);
    412       const float norm = static_cast<float>(sqrt(y_square_sum(0)));
    413       y_normalized = y_matrix * (1 / norm);
    414       error_square_sum = (x_matrix - y_normalized).square().sum(sum_along_dim);
    415       VLOG(1) << "x = [" << x_matrix.shuffle(matrix_transpose) << "] y = ["
    416               << y_matrix.shuffle(matrix_transpose) << "] lambda = " << lambda;
    417       converged = sqrt(error_square_sum(0)) < 1e-10;
    418     }
    419   }
    420   EXPECT_NEAR(lambda, 2.0, 0.01);
    421   TF_EXPECT_OK(CloseSession(handle));
    422 }
    423 
    424 }  // namespace tensorflow
    425