1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" 17 18 #include <limits> 19 #include <map> 20 #include <unordered_map> 21 22 #include "grpc++/create_channel.h" 23 24 #include "tensorflow/core/lib/core/errors.h" 25 #include "tensorflow/core/lib/core/status.h" 26 #include "tensorflow/core/lib/gtl/map_util.h" 27 #include "tensorflow/core/lib/strings/numbers.h" 28 #include "tensorflow/core/lib/strings/str_util.h" 29 #include "tensorflow/core/lib/strings/strcat.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/macros.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/thread_annotations.h" 34 #include "tensorflow/core/platform/types.h" 35 #include "tensorflow/core/util/device_name_utils.h" 36 37 namespace tensorflow { 38 39 namespace { 40 41 string MakeAddress(const string& job, int task) { 42 return strings::StrCat("/job:", job, "/replica:0/task:", task); 43 } 44 45 Status ValidateHostPortPair(const string& host_port) { 46 uint32 port; 47 std::vector<string> parts = str_util::Split(host_port, ':'); 48 // Must be host:port, port must be a number, host must not contain a '/'. 49 if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) || 50 parts[0].find("/") != string::npos) { 51 return errors::InvalidArgument("Could not interpret \"", host_port, 52 "\" as a host-port pair."); 53 } 54 return Status::OK(); 55 } 56 } // namespace 57 58 Status NewHostPortGrpcChannel(const string& target, 59 SharedGrpcChannelPtr* channel_pointer) { 60 // Minimally ensure that the target is valid 61 TF_RETURN_IF_ERROR(ValidateHostPortPair(target)); 62 63 // TODO(mrry): Implement secure channels. 64 ::grpc::ChannelArguments args; 65 args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max()); 66 // NOTE(mrry): Some versions of gRPC use a 20-second minimum backoff 67 // on connection failure, which makes our tests time out. 68 args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 1000); 69 *channel_pointer = ::grpc::CreateCustomChannel( 70 "dns:///" + target, ::grpc::InsecureChannelCredentials(), args); 71 return Status::OK(); 72 } 73 74 ChannelCreationFunction ConvertToChannelCreationFunction( 75 const std::function<Status(string, SharedGrpcChannelPtr*)>& 76 new_channel_func_ptr) { 77 return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr { 78 SharedGrpcChannelPtr channel_ptr; 79 if (new_channel_func_ptr(target, &channel_ptr).ok()) { 80 return channel_ptr; 81 } else { 82 return nullptr; 83 } 84 }; 85 } 86 87 Status GrpcChannelSpec::AddHostPortsJob(const string& job_id, 88 const std::vector<string>& host_ports) { 89 std::map<int, string> host_ports_map; 90 for (size_t i = 0; i < host_ports.size(); ++i) { 91 host_ports_map[i] = host_ports[i]; 92 } 93 return AddHostPortsJob(job_id, host_ports_map); 94 } 95 96 Status GrpcChannelSpec::AddHostPortsJob( 97 const string& job_id, const std::map<int, string>& host_ports) { 98 if (!job_ids_.insert(job_id).second) { 99 return errors::InvalidArgument( 100 "Duplicate job ID in cluster specification: ", job_id); 101 } 102 for (const auto& id_host_port : host_ports) { 103 TF_RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second)); 104 } 105 host_ports_jobs_.emplace_back(job_id, host_ports); 106 return Status::OK(); 107 } 108 109 namespace { 110 111 // GrpcChannelCache that caches results to FindWorkerChannel() calls. 112 class CachingGrpcChannelCache : public GrpcChannelCache { 113 public: 114 CachingGrpcChannelCache() {} 115 116 ~CachingGrpcChannelCache() override {} 117 118 SharedGrpcChannelPtr FindWorkerChannel(const string& target) override { 119 SharedGrpcChannelPtr ch = nullptr; 120 { 121 mutex_lock l(mu_); // could use reader lock 122 ch = gtl::FindPtrOrNull(channels_, target); 123 if (ch) { 124 return ch; 125 } 126 } 127 ch = FindChannelOnce(target); 128 if (ch) { 129 mutex_lock l(mu_); 130 channels_.insert({target, ch}); 131 } 132 return ch; 133 } 134 135 protected: 136 // Find the ClientChannel for "target". Only called when no channel was 137 // found in the channels_ cache for "target". A non nullptr result will be 138 // cached in channels_. 139 virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0; 140 141 private: 142 // TODO(zhifengc): Eviction when the map becomes too big. 143 mutex mu_; 144 std::unordered_map<string, SharedGrpcChannelPtr> channels_ GUARDED_BY(mu_); 145 }; 146 147 // A ChannelCache that is the union of multiple ChannelCaches. 148 // Takes ownership of the caches passed to the constructor. 149 class MultiGrpcChannelCache : public CachingGrpcChannelCache { 150 public: 151 explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches) 152 : CachingGrpcChannelCache(), caches_(caches) {} 153 154 ~MultiGrpcChannelCache() override { 155 for (GrpcChannelCache* cache : caches_) { 156 delete cache; 157 } 158 } 159 160 void ListWorkers(std::vector<string>* workers) override { 161 for (GrpcChannelCache* cache : caches_) { 162 cache->ListWorkers(workers); 163 } 164 } 165 166 string TranslateTask(const string& target) override { 167 mutex_lock l(mu_); // could use reader lock 168 GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target); 169 if (cache == nullptr) { 170 for (GrpcChannelCache* c : caches_) { 171 string r = c->TranslateTask(target); 172 if (!r.empty()) { 173 target_caches_.insert({target, c}); 174 cache = c; 175 break; 176 } 177 } 178 } 179 CHECK(cache) << "Could not find GrpcChannelCache holding channel for " 180 << target; 181 return cache->TranslateTask(target); 182 } 183 184 protected: 185 SharedGrpcChannelPtr FindChannelOnce(const string& target) override { 186 for (GrpcChannelCache* cache : caches_) { 187 SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target)); 188 if (ch) { 189 mutex_lock l(mu_); 190 target_caches_.insert({target, cache}); 191 return ch; 192 } 193 } 194 return nullptr; 195 } 196 197 private: 198 // List of channels used by this MultiGrpcChannelCache. 199 const std::vector<GrpcChannelCache*> caches_; 200 201 mutex mu_; 202 // Cache of channels keyed by the target they are handling. 203 // The same GrpcChannelCache can appear multiple times in the cache. 204 std::unordered_map<string, GrpcChannelCache*> target_caches_ GUARDED_BY(mu_); 205 }; 206 207 class SparseGrpcChannelCache : public CachingGrpcChannelCache { 208 public: 209 SparseGrpcChannelCache(const string& job_id, 210 const std::map<int, string>& host_ports, 211 ChannelCreationFunction channel_func) 212 : job_id_(job_id), 213 host_ports_(host_ports), 214 channel_func_(std::move(channel_func)) { 215 LOG(INFO) << "Initialize GrpcChannelCache for job " << ToString(); 216 } 217 ~SparseGrpcChannelCache() override {} 218 219 void ListWorkers(std::vector<string>* workers) override { 220 workers->reserve(workers->size() + host_ports_.size()); 221 for (const auto& id_host_port : host_ports_) { 222 workers->emplace_back(MakeAddress(job_id_, id_host_port.first)); 223 } 224 } 225 226 string TranslateTask(const string& target) override { 227 DeviceNameUtils::ParsedName parsed; 228 if (!DeviceNameUtils::ParseFullName(target, &parsed)) { 229 LOG(WARNING) << "Invalid target: " << target; 230 return ""; 231 } 232 233 if (!parsed.has_job || parsed.job != job_id_) { 234 return ""; 235 } 236 if (!parsed.has_replica || parsed.replica != 0) { 237 LOG(WARNING) << "Replica ID must be 0 in target: " << target; 238 return ""; 239 } 240 int32 task = parsed.has_task ? parsed.task : -1; 241 auto iter = host_ports_.find(task); 242 if (iter == host_ports_.end()) { 243 LOG(WARNING) << "Task " << task << " was not defined in sparse job " 244 << job_id_ << ": " << target; 245 return ""; 246 } 247 return iter->second; 248 } 249 250 protected: 251 SharedGrpcChannelPtr FindChannelOnce(const string& target) override { 252 const string host_port = TranslateTask(target); 253 if (host_port.empty()) { 254 return nullptr; 255 } 256 return channel_func_(host_port); 257 } 258 259 private: 260 string ToString() { 261 std::vector<string> task_strings; 262 task_strings.reserve(host_ports_.size()); 263 for (const auto& id_host_port : host_ports_) { 264 task_strings.emplace_back( 265 strings::StrCat(id_host_port.first, " -> ", id_host_port.second)); 266 } 267 return strings::StrCat(job_id_, " -> {", str_util::Join(task_strings, ", "), 268 "}"); 269 } 270 271 const string job_id_; 272 const std::map<int, string> host_ports_; 273 const ChannelCreationFunction channel_func_; 274 TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache); 275 }; 276 277 } // namespace 278 279 GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec, 280 ChannelCreationFunction channel_func) { 281 const int num_jobs = spec.host_ports_jobs().size(); 282 if (!num_jobs) { 283 LOG(ERROR) << "Empty channel spec."; 284 return nullptr; 285 } 286 std::vector<GrpcChannelCache*> caches; 287 caches.reserve(num_jobs); 288 for (auto& job : spec.host_ports_jobs()) { 289 caches.push_back( 290 new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func)); 291 } 292 return caches.size() == 1 ? caches[0] : new MultiGrpcChannelCache(caches); 293 } 294 295 } // end namespace tensorflow 296