Home | History | Annotate | Download | only in training
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/cc/training/coordinator.h"
     17 
     18 namespace tensorflow {
     19 
     20 Coordinator::Coordinator() : Coordinator(std::vector<error::Code>()) {}
     21 
     22 Coordinator::Coordinator(const std::vector<error::Code>& clean_stop_errors)
     23     : should_stop_(false) {
     24   if (clean_stop_errors.empty()) {
     25     clean_stop_errors_.insert(error::OUT_OF_RANGE);
     26   } else {
     27     for (const auto& code : clean_stop_errors) {
     28       clean_stop_errors_.insert(static_cast<int>(code));
     29     }
     30   }
     31 }
     32 
     33 Coordinator::~Coordinator() {
     34   RequestStop().IgnoreError();
     35   Join().IgnoreError();
     36 }
     37 
     38 Status Coordinator::RegisterRunner(std::unique_ptr<RunnerInterface> runner) {
     39   {
     40     mutex_lock l(mu_);
     41     if (should_stop_) {
     42       return Status(error::FAILED_PRECONDITION,
     43                     "The coordinator has been stopped.");
     44     }
     45   }
     46   mutex_lock l(runners_lock_);
     47   runners_.push_back(std::move(runner));
     48   return Status::OK();
     49 }
     50 
     51 bool Coordinator::AllRunnersStopped() {
     52   mutex_lock l(runners_lock_);
     53   for (const auto& runner : runners_) {
     54     if (runner->IsRunning()) {
     55       return false;
     56     }
     57   }
     58   return true;
     59 }
     60 
     61 Status Coordinator::RequestStop() {
     62   mutex_lock l(mu_);
     63   if (should_stop_) {
     64     return Status(error::FAILED_PRECONDITION,
     65                   "The Coordinator is not running.");
     66   }
     67   should_stop_ = true;
     68   wait_for_stop_.notify_all();
     69   return Status::OK();
     70 }
     71 
     72 bool Coordinator::ShouldStop() {
     73   mutex_lock l(mu_);
     74   return should_stop_;
     75 }
     76 
     77 Status Coordinator::Join() {
     78   // TODO(yuefengz): deal with stragglers.
     79   {
     80     mutex_lock l(mu_);
     81     if (!should_stop_) {
     82       return Status(error::FAILED_PRECONDITION,
     83                     "Joining coordinator without requesting to stop.");
     84     }
     85   }
     86 
     87   {
     88     mutex_lock l(runners_lock_);
     89     for (const auto& t : runners_) {
     90       ReportStatus(t->Join());
     91     }
     92     runners_.clear();
     93   }
     94   return GetStatus();
     95 }
     96 
     97 void Coordinator::ReportStatus(const Status& status) {
     98   mutex_lock l(status_lock_);
     99   if (status.ok() || !status_.ok() ||
    100       clean_stop_errors_.count(static_cast<int>(status.code())) > 0) {
    101     return;
    102   }
    103   status_ = status;
    104 }
    105 
    106 Status Coordinator::GetStatus() {
    107   mutex_lock l(status_lock_);
    108   return status_;
    109 }
    110 
    111 void Coordinator::WaitForStop() {
    112   mutex_lock l(mu_);
    113   while (!should_stop_) {
    114     wait_for_stop_.wait(l);
    115   }
    116 }
    117 
    118 Status Coordinator::ExportCostGraph(CostGraphDef* cost_graph) const {
    119   mutex_lock l(runners_lock_);
    120   for (auto& t : runners_) {
    121     Status s = t->ExportCostGraph(cost_graph);
    122     if (!s.ok()) {
    123       return s;
    124     }
    125   }
    126   return Status::OK();
    127 }
    128 
    129 }  // namespace tensorflow
    130