/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"

#include <unordered_map>

#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

namespace {

// WorkerCache implementation that hands out gRPC-backed remote workers
// (via NewGrpcRemoteWorker), plus an optional in-process "local" worker
// for `local_target_`. Completion-queue polling is spread over a small
// fixed pool of threads; each remote target is pinned to one polling
// thread (see AssignWorkerToThread).
class GrpcWorkerCache : public WorkerCachePartial {
 public:
  // Number of completion-queue polling threads.
  // TODO(ncteisen): consider adding a config var or flag for this
  static constexpr const size_t kGrpcWorkerCacheThreadCount = 2;

  // Takes ownership of `channel_cache` (deleted in the destructor).
  // `local_worker` is NOT owned; it is returned as-is for requests whose
  // target equals `local_target` (may be nullptr if there is no local
  // worker — see NewGrpcWorkerCache below).
  explicit GrpcWorkerCache(GrpcChannelCache* channel_cache,
                           WorkerInterface* local_worker,
                           const string& local_target)
      : local_target_(local_target),
        local_worker_(local_worker),
        channel_cache_(channel_cache),
        threads_(kGrpcWorkerCacheThreadCount),
        next_round_robin_assignment_(0) {}

  // Explicit destructor to control destruction order: the polling threads
  // must be shut down and joined (threads_.clear()) before the channel
  // cache they indirectly depend on is deleted.
  ~GrpcWorkerCache() override {
    threads_.clear();  // Blocks until threads exit.
    delete channel_cache_;
  }

  void ListWorkers(std::vector<string>* workers) const override {
    channel_cache_->ListWorkers(workers);
  }

  // Returns the shared local worker for `local_target_`, or a newly
  // allocated remote worker bound to one of the polling threads'
  // completion queues. Returns nullptr if no channel exists for `target`.
  WorkerInterface* CreateWorker(const string& target) override {
    if (target == local_target_) {
      return local_worker_;
    } else {
      SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
      if (!channel) return nullptr;
      return NewGrpcRemoteWorker(
          channel, threads_[AssignWorkerToThread(target)].completion_queue(),
          &logger_);
    }
  }

  // The local worker is shared (never deleted here); remote workers are
  // released through the base class, which deletes them.
  void ReleaseWorker(const string& target, WorkerInterface* worker) override {
    if (target == local_target_) {
      CHECK_EQ(worker, local_worker_)
          << "Releasing a worker that was not returned by this WorkerCache";
    } else {
      WorkerCacheInterface::ReleaseWorker(target, worker);
    }
  }

  void SetLogging(bool v) override { logger_.SetLogging(v); }

  void ClearLogs() override { logger_.ClearLogs(); }

  bool RetrieveLogs(int64 step_id, StepStats* ss) override {
    return logger_.RetrieveLogs(step_id, ss);
  }

 private:
  // Thread wrapping class that drives work over a single gRPC
  // CompletionQueue. The thread loops on Next() dispatching completed
  // tags; the destructor shuts the queue down (which drains remaining
  // events) and then joins the thread.
  class GrpcWorkerCacheThread {
   public:
    GrpcWorkerCacheThread() {
      thread_.reset(Env::Default()->StartThread(
          ThreadOptions(), "grpc_worker_cache", [this]() {
            void* tag;
            bool ok;
            // Next() returns false only once the queue is shut down and
            // fully drained, so this loop ends during destruction.
            while (completion_queue_.Next(&tag, &ok)) {
              GrpcClientCQTag* callback_tag =
                  static_cast<GrpcClientCQTag*>(tag);
              callback_tag->OnCompleted(ok);
            }
          }));
    }

    ~GrpcWorkerCacheThread() {
      completion_queue_.Shutdown();
      thread_.reset();  // Joins the polling thread.
    }

    ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }

   private:
    ::grpc::CompletionQueue completion_queue_;
    std::unique_ptr<Thread> thread_;
  };  // GrpcWorkerCacheThread

  // Returns the index of the polling thread assigned to `target`.
  // Round-robin target assignment, but keeps the same target on the same
  // polling thread always, as this is important for gRPC performance.
  size_t AssignWorkerToThread(const string& target) {
    mutex_lock lock(assignment_mu_);
    auto it = target_assignments_.find(target);
    if (it == target_assignments_.end()) {
      it = target_assignments_
               .insert(std::make_pair(
                   target, (next_round_robin_assignment_++) % threads_.size()))
               .first;
    }
    return it->second;
  }

  const string local_target_;            // Target name served by local_worker_.
  WorkerInterface* const local_worker_;  // Not owned.
  GrpcChannelCache* channel_cache_;      // Owned.
  WorkerCacheLogger logger_;
  // NOTE: never resized after construction; GrpcWorkerCacheThread's thread
  // lambda captures `this`, so elements must not be moved.
  std::vector<GrpcWorkerCacheThread> threads_;

  mutex assignment_mu_;
  std::unordered_map<std::string, size_t> target_assignments_
      GUARDED_BY(assignment_mu_);
  size_t next_round_robin_assignment_ GUARDED_BY(assignment_mu_);
};

}  // namespace

// Creates a cache with no local worker: every target, including "", is
// resolved through the channel cache. Takes ownership of `cc`.
WorkerCacheInterface* NewGrpcWorkerCache(GrpcChannelCache* cc) {
  return new GrpcWorkerCache(cc, nullptr, "");
}

// Creates a cache that serves `local_worker` (not owned) for
// `local_target` and gRPC remote workers for everything else.
// Takes ownership of `cc`.
WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
    GrpcChannelCache* cc, WorkerInterface* local_worker,
    const string& local_target) {
  return new GrpcWorkerCache(cc, local_worker, local_target);
}

}  // namespace tensorflow