Home | History | Annotate | Download | only in rpc
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
     17 
     18 #include <unordered_map>
     19 
     20 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
     21 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
     22 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
     23 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
     24 #include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
     25 #include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
     26 #include "tensorflow/core/distributed_runtime/worker_interface.h"
     27 #include "tensorflow/core/platform/env.h"
     28 #include "tensorflow/core/platform/mutex.h"
     29 
     30 namespace tensorflow {
     31 
     32 namespace {
     33 
     34 class GrpcWorkerCache : public WorkerCachePartial {
     35  public:
     36   // TODO(ncteisen): consider adding a config var or flag for this
     37   static constexpr const size_t kGrpcWorkerCacheThreadCount = 2;
     38 
     39   explicit GrpcWorkerCache(GrpcChannelCache* channel_cache,
     40                            WorkerInterface* local_worker,
     41                            const string& local_target)
     42       : local_target_(local_target),
     43         local_worker_(local_worker),
     44         channel_cache_(channel_cache),
     45         threads_(kGrpcWorkerCacheThreadCount),
     46         next_round_robin_assignment_(0) {}
     47 
     48   // Explicit destructor to control destruction order.
     49   ~GrpcWorkerCache() override {
     50     threads_.clear();  // Blocks until threads exit.
     51     delete channel_cache_;
     52   }
     53 
     54   void ListWorkers(std::vector<string>* workers) const override {
     55     channel_cache_->ListWorkers(workers);
     56   }
     57 
     58   WorkerInterface* CreateWorker(const string& target) override {
     59     if (target == local_target_) {
     60       return local_worker_;
     61     } else {
     62       SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
     63       if (!channel) return nullptr;
     64       return NewGrpcRemoteWorker(
     65           channel, threads_[AssignWorkerToThread(target)].completion_queue(),
     66           &logger_);
     67     }
     68   }
     69 
     70   void ReleaseWorker(const string& target, WorkerInterface* worker) override {
     71     if (target == local_target_) {
     72       CHECK_EQ(worker, local_worker_)
     73           << "Releasing a worker that was not returned by this WorkerCache";
     74     } else {
     75       WorkerCacheInterface::ReleaseWorker(target, worker);
     76     }
     77   }
     78 
     79   void SetLogging(bool v) override { logger_.SetLogging(v); }
     80 
     81   void ClearLogs() override { logger_.ClearLogs(); }
     82 
     83   bool RetrieveLogs(int64 step_id, StepStats* ss) override {
     84     return logger_.RetrieveLogs(step_id, ss);
     85   }
     86 
     87  private:
     88   // Thread wrapping class that drives work over a single gRPC
     89   // CompletionQueue.
     90   class GrpcWorkerCacheThread {
     91    public:
     92     GrpcWorkerCacheThread() {
     93       thread_.reset(Env::Default()->StartThread(
     94           ThreadOptions(), "grpc_worker_cache", [this]() {
     95             void* tag;
     96             bool ok;
     97             while (completion_queue_.Next(&tag, &ok)) {
     98               GrpcClientCQTag* callback_tag =
     99                   static_cast<GrpcClientCQTag*>(tag);
    100               callback_tag->OnCompleted(ok);
    101             }
    102           }));
    103     }
    104 
    105     ~GrpcWorkerCacheThread() {
    106       completion_queue_.Shutdown();
    107       thread_.reset();
    108     }
    109 
    110     ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }
    111 
    112    private:
    113     ::grpc::CompletionQueue completion_queue_;
    114     std::unique_ptr<Thread> thread_;
    115   };  // GrpcWorkerCacheThread
    116 
    117   size_t AssignWorkerToThread(const string& target) {
    118     // Round-robin target assignment, but keeps the same target on the same
    119     // polling thread always, as this is important for gRPC performace
    120     mutex_lock lock(assignment_mu_);
    121     auto it = target_assignments_.find(target);
    122     if (it == target_assignments_.end()) {
    123       it = target_assignments_
    124                .insert(std::make_pair(
    125                    target, (next_round_robin_assignment_++) % threads_.size()))
    126                .first;
    127     }
    128     return it->second;
    129   }
    130 
    131   const string local_target_;
    132   WorkerInterface* const local_worker_;  // Not owned.
    133   GrpcChannelCache* channel_cache_;      // Owned.
    134   WorkerCacheLogger logger_;
    135   std::vector<GrpcWorkerCacheThread> threads_;
    136 
    137   mutex assignment_mu_;
    138   std::unordered_map<std::string, size_t> target_assignments_
    139       GUARDED_BY(assignment_mu_);
    140   size_t next_round_robin_assignment_ GUARDED_BY(assignment_mu_);
    141 };
    142 
    143 }  // namespace
    144 
    145 WorkerCacheInterface* NewGrpcWorkerCache(GrpcChannelCache* cc) {
    146   return new GrpcWorkerCache(cc, nullptr, "");
    147 }
    148 
    149 WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
    150     GrpcChannelCache* cc, WorkerInterface* local_worker,
    151     const string& local_target) {
    152   return new GrpcWorkerCache(cc, local_worker, local_target);
    153 }
    154 
    155 }  // namespace tensorflow
    156