Home | History | Annotate | Download | only in distributed_runtime
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/distributed_runtime/partial_run_mgr.h"
     17 
     18 #include "tensorflow/core/lib/core/notification.h"
     19 #include "tensorflow/core/platform/test.h"
     20 
     21 namespace tensorflow {
     22 namespace {
     23 
     24 TEST(PartialRunMgrFindOrCreate, Create) {
     25   // Basic test of PartialRunMgr CancellationManager creation.
     26   PartialRunMgr partial_run_mgr;
     27   int step_id = 1;
     28   CancellationManager* cancellation_manager;
     29   partial_run_mgr.FindOrCreate(step_id, &cancellation_manager);
     30   EXPECT_TRUE(cancellation_manager != nullptr);
     31 }
     32 
     33 TEST(PartialRunMgrFindOrCreate, Find) {
     34   // Basic test of PartialRunMgr CancellationManager find.
     35   PartialRunMgr partial_run_mgr;
     36   int step_id = 1;
     37   CancellationManager* cancellation_manager;
     38   partial_run_mgr.FindOrCreate(step_id, &cancellation_manager);
     39   // Looking for the same step should return the same cancellation_manager.
     40   CancellationManager* found_cancellation_manager;
     41   partial_run_mgr.FindOrCreate(step_id, &found_cancellation_manager);
     42   EXPECT_EQ(cancellation_manager, found_cancellation_manager);
     43 }
     44 
     45 TEST(PartialRunMgrFindOrCreate, NewCreate) {
     46   // Test that PartialRunMgr creates a new CancellationManager for new steps.
     47   PartialRunMgr partial_run_mgr;
     48   int step_id = 1;
     49   CancellationManager* cancellation_manager;
     50   partial_run_mgr.FindOrCreate(step_id, &cancellation_manager);
     51   // FindOrCreate on a new step should return a new cancellation_manager.
     52   int new_step_id = 2;
     53   CancellationManager* new_cancellation_manager;
     54   partial_run_mgr.FindOrCreate(new_step_id, &new_cancellation_manager);
     55   EXPECT_NE(cancellation_manager, new_cancellation_manager);
     56 }
     57 
     58 TEST(PartialRunMgr, PartialRunRemoved) {
     59   // Test that PartialRunMgr ensures that the PartialRun is deleted after
     60   // ExecutorDone and PartialRunDone are called.
     61   PartialRunMgr partial_run_mgr;
     62   int step_id = 1;
     63   CancellationManager* cancellation_manager;
     64   partial_run_mgr.FindOrCreate(step_id, &cancellation_manager);
     65 
     66   int called = 0;
     67   partial_run_mgr.PartialRunDone(
     68       step_id, [&called](Status status) { called++; }, Status::OK());
     69   partial_run_mgr.ExecutorDone(step_id, Status::OK());
     70 
     71   // Calling ExecutorDone and PartialRunDone on the step_id should still only
     72   // result in the callback being called once.
     73   // This proves that the original PartialRun has been removed.
     74   partial_run_mgr.PartialRunDone(
     75       step_id, [&called](Status status) { called++; }, Status::OK());
     76   partial_run_mgr.ExecutorDone(step_id, Status::OK());
     77   EXPECT_EQ(1, called);
     78 }
     79 
     80 struct StatusTestParam {
     81   Status executor_status;
     82   Status partial_run_status;
     83   Status expected_status;
     84 };
     85 
     86 class StatusPropagationTest : public ::testing::TestWithParam<StatusTestParam> {
     87  protected:
     88   PartialRunMgr partial_run_mgr_;
     89 
     90   // State to help keep track of when the callback is called.
     91   Notification invoked_;
     92   Status status_;
     93 
     94   void set_status(const Status& status) {
     95     status_ = status;
     96     invoked_.Notify();
     97   }
     98 
     99   // Blocks until status is set.
    100   Status status() {
    101     invoked_.WaitForNotification();
    102     return status_;
    103   }
    104 };
    105 
    106 TEST_P(StatusPropagationTest, ExecutorDoneFirst) {
    107   // Tests error propagation when ExecutorDone is called first.
    108   StatusTestParam param = GetParam();
    109   int step_id = 1;
    110 
    111   CancellationManager* cancellation_manager;
    112   partial_run_mgr_.FindOrCreate(step_id, &cancellation_manager);
    113 
    114   partial_run_mgr_.ExecutorDone(step_id, param.executor_status);
    115   partial_run_mgr_.PartialRunDone(step_id,
    116                                   [this](Status status) { set_status(status); },
    117                                   param.partial_run_status);
    118 
    119   EXPECT_EQ(status(), param.expected_status);
    120 }
    121 
    122 TEST_P(StatusPropagationTest, PartialRunDoneFirst) {
    123   // Tests error propagation when PartialRunDone is called first.
    124   StatusTestParam param = GetParam();
    125   int step_id = 1;
    126 
    127   CancellationManager* cancellation_manager;
    128   partial_run_mgr_.FindOrCreate(step_id, &cancellation_manager);
    129 
    130   partial_run_mgr_.PartialRunDone(step_id,
    131                                   [this](Status status) { set_status(status); },
    132                                   param.partial_run_status);
    133   partial_run_mgr_.ExecutorDone(step_id, param.executor_status);
    134 
    135   EXPECT_EQ(status(), param.expected_status);
    136 }
    137 
    138 // Instantiate tests for all error orderings, for both call orders of
    139 // ExecutorDone and PartialRunDone.
    140 Status ExecutorError() { return errors::Internal("executor error"); }
    141 Status PartialRunError() { return errors::Internal("partial run error"); }
    142 INSTANTIATE_TEST_CASE_P(
    143     PartialRunMgr, StatusPropagationTest,
    144     ::testing::Values(
    145         StatusTestParam{Status::OK(), Status::OK(), Status::OK()},
    146         StatusTestParam{ExecutorError(), Status::OK(), ExecutorError()},
    147         StatusTestParam{Status::OK(), PartialRunError(), PartialRunError()},
    148         StatusTestParam{ExecutorError(), PartialRunError(), ExecutorError()}));
    149 
    150 }  // namespace
    151 }  // namespace tensorflow
    152