Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/rendezvous.h"
     17 
     18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     19 #include "tensorflow/core/framework/tensor.h"
     20 #include "tensorflow/core/framework/tensor_shape.h"
     21 #include "tensorflow/core/framework/tensor_types.h"
     22 #include "tensorflow/core/framework/types.pb.h"
     23 #include "tensorflow/core/lib/core/errors.h"
     24 #include "tensorflow/core/lib/core/notification.h"
     25 #include "tensorflow/core/lib/core/status_test_util.h"
     26 #include "tensorflow/core/lib/core/threadpool.h"
     27 #include "tensorflow/core/lib/random/simple_philox.h"
     28 #include "tensorflow/core/lib/strings/strcat.h"
     29 #include "tensorflow/core/platform/env.h"
     30 #include "tensorflow/core/platform/logging.h"
     31 #include "tensorflow/core/platform/mutex.h"
     32 #include "tensorflow/core/platform/test.h"
     33 #include "tensorflow/core/platform/test_benchmark.h"
     34 #include "tensorflow/core/platform/types.h"
     35 
     36 namespace tensorflow {
     37 namespace {
     38 
     39 TEST(RendezvousTest, Key) {
     40   const string key = Rendezvous::CreateKey(
     41       "/job:mnist/replica:1/task:2/CPU:0", 7890,
     42       "/job:mnist/replica:1/task:2/device:GPU:0", "var0", FrameAndIter(0, 0));
     43   EXPECT_EQ(key,
     44             "/job:mnist/replica:1/task:2/CPU:0;"
     45             "0000000000001ed2;"  // 7890 = 0x1ed2
     46             "/job:mnist/replica:1/task:2/device:GPU:0;"
     47             "var0;"
     48             "0:0");
     49   Rendezvous::ParsedKey parsed;
     50   TF_EXPECT_OK(Rendezvous::ParseKey(key, &parsed));
     51   EXPECT_EQ(parsed.src_device, "/job:mnist/replica:1/task:2/CPU:0");
     52   EXPECT_EQ(parsed.src_incarnation, 7890);
     53   EXPECT_EQ(parsed.src.type, "CPU");
     54   EXPECT_EQ(parsed.dst_device, "/job:mnist/replica:1/task:2/device:GPU:0");
     55   EXPECT_EQ(parsed.dst.type, "GPU");
     56 
     57   EXPECT_FALSE(Rendezvous::ParseKey("foo;bar;baz", &parsed).ok());
     58   EXPECT_FALSE(Rendezvous::ParseKey("/job:mnist/replica:1/task:2/CPU:0;"
     59                                     "/job:mnist/replica:1/task:2/device:GPU:0;",
     60                                     &parsed)
     61                    .ok());
     62   EXPECT_FALSE(
     63       Rendezvous::ParseKey(strings::StrCat(key, ";", key), &parsed).ok());
     64 }
     65 
     66 class LocalRendezvousTest : public ::testing::Test {
     67  public:
     68   LocalRendezvousTest() : threads_(Env::Default(), "test", 16) {
     69     rendez_ = NewLocalRendezvous();
     70   }
     71 
     72   ~LocalRendezvousTest() override { rendez_->Unref(); }
     73 
     74   void SchedClosure(std::function<void()> fn) {
     75     threads_.Schedule(std::move(fn));
     76   }
     77 
     78   Rendezvous* rendez_;
     79 
     80  private:
     81   thread::ThreadPool threads_;
     82 };
     83 
     84 // string -> Tensor<string>
     85 Tensor V(const string& content) {
     86   Tensor tensor(DT_STRING, TensorShape({}));
     87   tensor.scalar<string>()() = content;
     88   return tensor;
     89 }
     90 
     91 // Tensor<string> -> string
     92 string V(const Tensor& tensor) {
     93   CHECK_EQ(tensor.dtype(), DT_STRING);
     94   CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
     95   return tensor.scalar<string>()();
     96 }
     97 
     98 Rendezvous::ParsedKey MakeKey(const string& name) {
     99   string s = Rendezvous::CreateKey("/job:mnist/replica:1/task:2/CPU:0", 7890,
    100                                    "/job:mnist/replica:1/task:2/device:GPU:0",
    101                                    name, FrameAndIter(0, 0));
    102   Rendezvous::ParsedKey k;
    103   TF_EXPECT_OK(Rendezvous::ParseKey(s, &k));
    104   return k;
    105 }
    106 
    107 const Rendezvous::ParsedKey& KeyFoo() {
    108   static auto key = MakeKey("foo");
    109   return key;
    110 }
    111 
    112 const Rendezvous::ParsedKey& KeyBar() {
    113   static auto key = MakeKey("bar");
    114   return key;
    115 }
    116 
    117 TEST_F(LocalRendezvousTest, SendRecv) {
    118   Rendezvous::Args args;
    119   TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
    120   Tensor val(DT_STRING);
    121   bool is_dead = false;
    122   TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &val, &is_dead));
    123   EXPECT_EQ("hello", V(val));
    124 }
    125 
    126 TEST_F(LocalRendezvousTest, RecvSend) {
    127   SchedClosure([this]() {
    128     Env::Default()->SleepForMicroseconds(10000);
    129     Rendezvous::Args args;
    130     TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
    131   });
    132   Tensor val(DT_STRING);
    133   bool is_dead = false;
    134   Rendezvous::Args args;
    135   TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &val, &is_dead));
    136   EXPECT_EQ("hello", V(val));
    137 }
    138 
    139 TEST_F(LocalRendezvousTest, PingPong) {
    140   SchedClosure([this]() {
    141     Tensor t(DT_STRING);
    142     bool is_dead = false;
    143     Rendezvous::Args args;
    144     TF_ASSERT_OK(rendez_->Recv(KeyFoo(), args, &t, &is_dead));
    145     TF_ASSERT_OK(rendez_->Send(KeyBar(), args, t, is_dead));
    146   });
    147   Env::Default()->SleepForMicroseconds(1000000);
    148   Tensor val(DT_STRING);
    149   bool val_dead = false;
    150   Rendezvous::Args args;
    151   TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("secret msg"), val_dead));
    152   TF_ASSERT_OK(rendez_->Recv(KeyBar(), args, &val, &val_dead));
    153   EXPECT_EQ("secret msg", V(val));
    154 }
    155 
    156 // A simple structure that behaves a bit like a blocking counter.  The
    157 // user that decrements counter to 0 does done.Notify(), and the main
    158 // thread waits for done to be notified.
    159 struct BlockingState {
    160   mutex lock;
    161   int counter = 0;
    162   Notification done;
    163 };
    164 
    165 TEST_F(LocalRendezvousTest, RandomSendRecv) {
    166   // We are scheduling 2*N closures in the this->threads_, which is
    167   // configured with only 16 threads. Furthermore, because the
    168   // threadpool may execute the closures in an arbitrary order, we
    169   // must use RecvAsync below. Otherwise, blocking Recv() may run
    170   // before all all the Send() and deadlock.
    171   static const int N = 100;
    172   random::PhiloxRandom philox(testing::RandomSeed(), 17);
    173   random::SimplePhilox rnd(&philox);
    174   BlockingState state;
    175   state.counter = N;
    176   for (int i = 0; i < N; ++i) {
    177     int micros = 100 + rnd.Uniform(1000);
    178     SchedClosure([this, i, micros]() {
    179       Env::Default()->SleepForMicroseconds(micros);
    180       Rendezvous::Args args;
    181       TF_ASSERT_OK(rendez_->Send(MakeKey(strings::StrCat(i)), args,
    182                                  V(strings::StrCat(i)), false));
    183     });
    184     auto recv_done = [this, &state, i](const Status& status,
    185                                        const Rendezvous::Args& sender_args,
    186                                        const Rendezvous::Args& recver_args,
    187                                        const Tensor& val, const bool val_dead) {
    188       EXPECT_EQ(strings::StrCat(i), V(val));
    189       bool done = false;
    190       {
    191         mutex_lock l(state.lock);
    192         state.counter--;
    193         if (state.counter == 0) {
    194           done = true;
    195         }
    196       }
    197       if (done) {
    198         state.done.Notify();
    199       }
    200     };
    201     micros = 100 + rnd.Uniform(1000);
    202     SchedClosure([this, i, micros, recv_done]() {
    203       Env::Default()->SleepForMicroseconds(micros);
    204       rendez_->RecvAsync(MakeKey(strings::StrCat(i)), Rendezvous::Args(),
    205                          recv_done);
    206     });
    207   }
    208 
    209   state.done.WaitForNotification();
    210 }
    211 
    212 void RandomSleep() {
    213   if (std::rand() % 10 == 0) {
    214     Env::Default()->SleepForMicroseconds(1000);
    215   }
    216 }
    217 
    218 TEST_F(LocalRendezvousTest, MultiSends) {
    219   static const int N = 100;
    220   const auto& key_foo = KeyFoo();
    221   Rendezvous::Args args;
    222   SchedClosure([=]() {
    223     for (int i = 0; i < N; ++i) {
    224       TF_ASSERT_OK(rendez_->Send(key_foo, args, V(strings::StrCat(i)), false));
    225       RandomSleep();
    226     }
    227   });
    228   Tensor val;
    229   bool val_dead;
    230   for (int i = 0; i < N; ++i) {
    231     TF_ASSERT_OK(rendez_->Recv(key_foo, args, &val, &val_dead));
    232     RandomSleep();
    233   }
    234 }
    235 
    236 TEST_F(LocalRendezvousTest, RecvAbort) {
    237   rendez_->Ref();
    238   SchedClosure([this]() {
    239     rendez_->StartAbort(errors::Aborted(""));  // abort
    240     rendez_->Unref();
    241   });
    242   Tensor val(DT_STRING);
    243   bool val_dead = false;
    244   Rendezvous::Args args;
    245   Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
    246   EXPECT_TRUE(errors::IsAborted(status));
    247 }
    248 
    249 // Similar to RecvAbort. But this test case ensures the main thread
    250 // Recv() call happens after StartAbort().
    251 TEST_F(LocalRendezvousTest, RecvSleepAbort) {
    252   rendez_->Ref();
    253   SchedClosure([this]() {
    254     Env::Default()->SleepForMicroseconds(1000000);
    255     rendez_->StartAbort(errors::Aborted(""));  // abort
    256     rendez_->Unref();
    257   });
    258   Tensor val(DT_STRING);
    259   bool val_dead = false;
    260   Rendezvous::Args args;
    261   Status status = rendez_->Recv(KeyFoo(), args, &val, &val_dead);
    262   EXPECT_TRUE(errors::IsAborted(status));
    263 }
    264 
    265 TEST_F(LocalRendezvousTest, AbortThenRecvOrSend) {
    266   rendez_->StartAbort(errors::Aborted(""));
    267   Tensor val(DT_STRING);
    268   bool val_dead = false;
    269   Rendezvous::Args args;
    270   EXPECT_TRUE(errors::IsAborted(rendez_->Send(KeyFoo(), args, val, val_dead)));
    271   EXPECT_TRUE(
    272       errors::IsAborted(rendez_->Recv(KeyFoo(), args, &val, &val_dead)));
    273 }
    274 
    275 class DummyDeviceContext : public DeviceContext {
    276  public:
    277   explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
    278   ~DummyDeviceContext() override {}
    279   int stream_id() const { return stream_id_; }
    280 
    281  private:
    282   const int stream_id_;
    283 };
    284 
    285 TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
    286   Rendezvous::Args args;
    287   args.device_context = new DummyDeviceContext(123);
    288 
    289   TF_ASSERT_OK(rendez_->Send(KeyFoo(), args, V("hello"), false));
    290 
    291   Notification n;
    292   Rendezvous::Args args1;
    293   args1.device_context = new DummyDeviceContext(1);
    294   rendez_->RecvAsync(
    295       KeyFoo(), args1,
    296       [&n](const Status& s, const Rendezvous::Args& send_args,
    297            const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) {
    298         CHECK_EQ(123, dynamic_cast<const DummyDeviceContext*>(
    299                           send_args.device_context)
    300                           ->stream_id());
    301         n.Notify();
    302       });
    303 
    304   n.WaitForNotification();
    305   args.device_context->Unref();
    306   args1.device_context->Unref();
    307 }
    308 
    309 void BM_SendRecv(int iters) {
    310   Rendezvous* rendez = NewLocalRendezvous();
    311   Tensor orig = V("val");
    312   Tensor val(DT_STRING, TensorShape({}));
    313   bool is_dead = false;
    314   Rendezvous::Args args;
    315   Status s;
    316   if (iters > 0) {
    317     while (iters--) {
    318       TF_CHECK_OK(rendez->Send(KeyFoo(), args, orig, is_dead));
    319       TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &val, &is_dead));
    320     }
    321     CHECK_EQ(V(val), V(orig));
    322   }
    323   rendez->Unref();
    324 }
    325 BENCHMARK(BM_SendRecv);
    326 
    327 void BM_PingPong(int iters) {
    328   CHECK_GT(iters, 0);
    329   thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1);
    330 
    331   // The main thread sends "foo" for iters times and receives "bar"
    332   // for iters times.  The other thread sends "bar" for iters times
    333   // and receives "foo" for iters times.
    334   Rendezvous* rendez = NewLocalRendezvous();
    335   pool->Schedule([rendez, iters]() {
    336     Tensor bar = V("bar");
    337     Tensor foo(DT_STRING, TensorShape({}));
    338     bool is_dead = false;
    339     Rendezvous::Args args;
    340     Status s;
    341     for (int i = 0; i < iters; ++i) {
    342       TF_CHECK_OK(rendez->Recv(KeyFoo(), args, &foo, &is_dead));
    343       TF_CHECK_OK(rendez->Send(KeyBar(), args, bar, is_dead));
    344     }
    345     CHECK_EQ("foo", V(foo));
    346   });
    347   Tensor foo = V("foo");
    348   Tensor bar(DT_STRING, TensorShape({}));
    349   bool is_dead = false;
    350   Rendezvous::Args args;
    351   Status s;
    352   for (int i = 0; i < iters; ++i) {
    353     TF_CHECK_OK(rendez->Send(KeyFoo(), args, foo, is_dead));
    354     TF_CHECK_OK(rendez->Recv(KeyBar(), args, &bar, &is_dead));
    355   }
    356   CHECK_EQ("bar", V(bar));
    357   delete pool;
    358 }
    359 BENCHMARK(BM_PingPong);
    360 
    361 }  // namespace
    362 }  // namespace tensorflow
    363