1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/distributed_runtime/partial_run_mgr.h" 17 18 #include "tensorflow/core/lib/core/notification.h" 19 #include "tensorflow/core/platform/test.h" 20 21 namespace tensorflow { 22 namespace { 23 24 TEST(PartialRunMgrFindOrCreate, Create) { 25 // Basic test of PartialRunMgr CancellationManager creation. 26 PartialRunMgr partial_run_mgr; 27 int step_id = 1; 28 CancellationManager* cancellation_manager; 29 partial_run_mgr.FindOrCreate(step_id, &cancellation_manager); 30 EXPECT_TRUE(cancellation_manager != nullptr); 31 } 32 33 TEST(PartialRunMgrFindOrCreate, Find) { 34 // Basic test of PartialRunMgr CancellationManager find. 35 PartialRunMgr partial_run_mgr; 36 int step_id = 1; 37 CancellationManager* cancellation_manager; 38 partial_run_mgr.FindOrCreate(step_id, &cancellation_manager); 39 // Looking for the same step should return the same cancellation_manager. 40 CancellationManager* found_cancellation_manager; 41 partial_run_mgr.FindOrCreate(step_id, &found_cancellation_manager); 42 EXPECT_EQ(cancellation_manager, found_cancellation_manager); 43 } 44 45 TEST(PartialRunMgrFindOrCreate, NewCreate) { 46 // Test that PartialRunMgr creates a new CancellationManager for new steps. 47 PartialRunMgr partial_run_mgr; 48 int step_id = 1; 49 CancellationManager* cancellation_manager; 50 partial_run_mgr.FindOrCreate(step_id, &cancellation_manager); 51 // FindOrCreate on a new step should return a new cancellation_manager. 52 int new_step_id = 2; 53 CancellationManager* new_cancellation_manager; 54 partial_run_mgr.FindOrCreate(new_step_id, &new_cancellation_manager); 55 EXPECT_NE(cancellation_manager, new_cancellation_manager); 56 } 57 58 TEST(PartialRunMgr, PartialRunRemoved) { 59 // Test that PartialRunMgr ensures that the PartialRun is deleted after 60 // ExecutorDone and PartialRunDone are called. 61 PartialRunMgr partial_run_mgr; 62 int step_id = 1; 63 CancellationManager* cancellation_manager; 64 partial_run_mgr.FindOrCreate(step_id, &cancellation_manager); 65 66 int called = 0; 67 partial_run_mgr.PartialRunDone( 68 step_id, [&called](Status status) { called++; }, Status::OK()); 69 partial_run_mgr.ExecutorDone(step_id, Status::OK()); 70 71 // Calling ExecutorDone and PartialRunDone on the step_id should still only 72 // result in the callback being called once. 73 // This proves that the original PartialRun has been removed. 74 partial_run_mgr.PartialRunDone( 75 step_id, [&called](Status status) { called++; }, Status::OK()); 76 partial_run_mgr.ExecutorDone(step_id, Status::OK()); 77 EXPECT_EQ(1, called); 78 } 79 80 struct StatusTestParam { 81 Status executor_status; 82 Status partial_run_status; 83 Status expected_status; 84 }; 85 86 class StatusPropagationTest : public ::testing::TestWithParam<StatusTestParam> { 87 protected: 88 PartialRunMgr partial_run_mgr_; 89 90 // State to help keep track of when the callback is called. 91 Notification invoked_; 92 Status status_; 93 94 void set_status(const Status& status) { 95 status_ = status; 96 invoked_.Notify(); 97 } 98 99 // Blocks until status is set. 100 Status status() { 101 invoked_.WaitForNotification(); 102 return status_; 103 } 104 }; 105 106 TEST_P(StatusPropagationTest, ExecutorDoneFirst) { 107 // Tests error propagation when ExecutorDone is called first. 108 StatusTestParam param = GetParam(); 109 int step_id = 1; 110 111 CancellationManager* cancellation_manager; 112 partial_run_mgr_.FindOrCreate(step_id, &cancellation_manager); 113 114 partial_run_mgr_.ExecutorDone(step_id, param.executor_status); 115 partial_run_mgr_.PartialRunDone(step_id, 116 [this](Status status) { set_status(status); }, 117 param.partial_run_status); 118 119 EXPECT_EQ(status(), param.expected_status); 120 } 121 122 TEST_P(StatusPropagationTest, PartialRunDoneFirst) { 123 // Tests error propagation when PartialRunDone is called first. 124 StatusTestParam param = GetParam(); 125 int step_id = 1; 126 127 CancellationManager* cancellation_manager; 128 partial_run_mgr_.FindOrCreate(step_id, &cancellation_manager); 129 130 partial_run_mgr_.PartialRunDone(step_id, 131 [this](Status status) { set_status(status); }, 132 param.partial_run_status); 133 partial_run_mgr_.ExecutorDone(step_id, param.executor_status); 134 135 EXPECT_EQ(status(), param.expected_status); 136 } 137 138 // Instantiate tests for all error orderings, for both call orders of 139 // ExecutorDone and PartialRunDone. 140 Status ExecutorError() { return errors::Internal("executor error"); } 141 Status PartialRunError() { return errors::Internal("partial run error"); } 142 INSTANTIATE_TEST_CASE_P( 143 PartialRunMgr, StatusPropagationTest, 144 ::testing::Values( 145 StatusTestParam{Status::OK(), Status::OK(), Status::OK()}, 146 StatusTestParam{ExecutorError(), Status::OK(), ExecutorError()}, 147 StatusTestParam{Status::OK(), PartialRunError(), PartialRunError()}, 148 StatusTestParam{ExecutorError(), PartialRunError(), ExecutorError()})); 149 150 } // namespace 151 } // namespace tensorflow 152