/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"

#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {

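// Constructs the RPC-based collective executor manager for this task.
// group_leader_ is left empty when this task is itself the configured
// collective group leader; otherwise it holds the leader's task name.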
RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    std::unique_ptr<DeviceResolverDistributed> dev_resolver,
    std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
    WorkerCacheInterface* worker_cache, const string& task_name)
    : CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
                            std::move(param_resolver)),
      worker_cache_(worker_cache),
      task_name_(task_name) {
  group_leader_ = (task_name == config.experimental().collective_group_leader())
                      ? ""
                      : config.experimental().collective_group_leader();
}

RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() {
  for (auto it : sequence_table_) {
    delete it.second;
  }
}

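// Creates a CollectiveExecutor for step_id whose remote device access goes
// through RPC (CollectiveRemoteAccessDistributed) via the worker cache.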
CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
  CollectiveRemoteAccessDistributed* rma =
      new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
                                            worker_cache_, step_id);
  return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
                                    &gpu_ring_order_);
}

namespace {
// StepId must leave the most-significant 7 bits empty for future use.
static const int64 kStepIdMask = (((1uLL << 56) - 1) | (1uLL << 56));

int64 NewRandomStepId() {
  int64 step_id = random::New64();
  // Leave the most-significant 7 bits clear for future use.
  step_id &= kStepIdMask;
  return step_id;
}
}  // namespace

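// Refreshes the next step id for graph_key.  If this task is the group
// leader, a fresh random step id is generated locally; otherwise the group
// leader is asked for its current sequence via GetStepSequenceAsync and the
// local table is updated from the response.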
void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
    int64 graph_key, const StatusCallback& done) {
  if (group_leader_.empty()) {
    mutex_lock l(sequence_mu_);
    GraphKeySequence* gks = nullptr;
    auto it = sequence_table_.find(graph_key);
    if (it == sequence_table_.end()) {
      gks = new GraphKeySequence(graph_key);
      sequence_table_[graph_key] = gks;
    } else {
      gks = it->second;
    }
    gks->next_step_id_ = NewRandomStepId();
    done(Status::OK());
  } else {
    WorkerInterface* wi = worker_cache_->CreateWorker(group_leader_);
    GetStepSequenceRequest* req = new GetStepSequenceRequest;
    GetStepSequenceResponse* resp = new GetStepSequenceResponse;
    req->add_graph_key(graph_key);
    wi->GetStepSequenceAsync(
        req, resp, [this, req, resp, done](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Bad response [" << s
                       << "] from GetStepSequenceAsync call to "
                       << group_leader_;
            done(s);
          } else {
            done(UpdateStepSequences(*resp));
          }
          delete req;
          delete resp;
        });
  }
}

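// Handles a step-sequence request from another task.  Only the group leader
// (group_leader_ empty) services these; each requested graph_key gets an
// entry in sequence_table_, initialized with a random step id if new.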
void RpcCollectiveExecutorMgr::GetStepSequenceAsync(
    const GetStepSequenceRequest* request, GetStepSequenceResponse* response,
    const StatusCallback& done) {
  if (!group_leader_.empty()) {
    LOG(ERROR) << "GetStepSequenceAsync called at non-group-leader";
    done(errors::Internal("GetStepSequenceAsync called at non-group-leader"));
  } else {
    mutex_lock l(sequence_mu_);
    for (int64 graph_key : request->graph_key()) {
      auto it = sequence_table_.find(graph_key);
      GraphKeySequence* gks = nullptr;
      if (it == sequence_table_.end()) {
        gks = new GraphKeySequence(graph_key);
        gks->next_step_id_ = NewRandomStepId();
        sequence_table_[graph_key] = gks;
      } else {
        gks = it->second;
      }
      StepSequence* ss = response->add_step_sequence();
      ss->set_graph_key(graph_key);
      ss->set_next_step_id(gks->next_step_id_);
    }
    done(Status::OK());
  }
}

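// Merges the step sequences returned by the group leader into the local
// sequence_table_, creating entries for graph keys not seen before.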
Status RpcCollectiveExecutorMgr::UpdateStepSequences(
    const GetStepSequenceResponse& resp) {
  mutex_lock l(sequence_mu_);
  for (const StepSequence& ss : resp.step_sequence()) {
    GraphKeySequence* gks = nullptr;
    auto it = sequence_table_.find(ss.graph_key());
    if (it == sequence_table_.end()) {
      gks = new GraphKeySequence(ss.graph_key());
      sequence_table_[ss.graph_key()] = gks;
    } else {
      gks = it->second;
    }
    gks->next_step_id_ = ss.next_step_id();
  }
  return Status::OK();
}

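// Returns the cached next step id for graph_key, or
// CollectiveExecutor::kInvalidId if no sequence has been established yet.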
int64 RpcCollectiveExecutorMgr::NextStepId(int64 graph_key) {
  mutex_lock l(sequence_mu_);
  auto it = sequence_table_.find(graph_key);
  if (it != sequence_table_.end()) {
    return it->second->next_step_id_;
  }
  return CollectiveExecutor::kInvalidId;
}

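// Marks step_id as used for graph_key.  If it matches the expected next step
// id the sequence advances by one (staying within kStepIdMask); otherwise the
// sequence is invalidated so callers must refresh it.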
void RpcCollectiveExecutorMgr::RetireStepId(int64 graph_key, int64 step_id) {
  mutex_lock l(sequence_mu_);
  auto it = sequence_table_.find(graph_key);
  if (it != sequence_table_.end()) {
    if (step_id == it->second->next_step_id_) {
      it->second->next_step_id_ = (it->second->next_step_id_ + 1) & kStepIdMask;
    } else {
      it->second->next_step_id_ = CollectiveExecutor::kInvalidId;
    }
  } else {
    LOG(ERROR) << "Failed to find graph_key " << graph_key << " to retire.";
  }
}

}  // namespace tensorflow