Home | History | Annotate | Download | only in rpc
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"

#include <limits>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>

#include "grpc++/create_channel.h"

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
     36 
     37 namespace tensorflow {
     38 
     39 namespace {
     40 
     41 string MakeAddress(const string& job, int task) {
     42   return strings::StrCat("/job:", job, "/replica:0/task:", task);
     43 }
     44 
     45 Status ValidateHostPortPair(const string& host_port) {
     46   uint32 port;
     47   std::vector<string> parts = str_util::Split(host_port, ':');
     48   // Must be host:port, port must be a number, host must not contain a '/'.
     49   if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) ||
     50       parts[0].find("/") != string::npos) {
     51     return errors::InvalidArgument("Could not interpret \"", host_port,
     52                                    "\" as a host-port pair.");
     53   }
     54   return Status::OK();
     55 }
     56 }  // namespace
     57 
     58 Status NewHostPortGrpcChannel(const string& target,
     59                               SharedGrpcChannelPtr* channel_pointer) {
     60   // Minimally ensure that the target is valid
     61   TF_RETURN_IF_ERROR(ValidateHostPortPair(target));
     62 
     63   // TODO(mrry): Implement secure channels.
     64   ::grpc::ChannelArguments args;
     65   args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
     66   // NOTE(mrry): Some versions of gRPC use a 20-second minimum backoff
     67   // on connection failure, which makes our tests time out.
     68   args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 1000);
     69   *channel_pointer = ::grpc::CreateCustomChannel(
     70       "dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
     71   return Status::OK();
     72 }
     73 
     74 ChannelCreationFunction ConvertToChannelCreationFunction(
     75     const std::function<Status(string, SharedGrpcChannelPtr*)>&
     76         new_channel_func_ptr) {
     77   return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
     78     SharedGrpcChannelPtr channel_ptr;
     79     if (new_channel_func_ptr(target, &channel_ptr).ok()) {
     80       return channel_ptr;
     81     } else {
     82       return nullptr;
     83     }
     84   };
     85 }
     86 
     87 Status GrpcChannelSpec::AddHostPortsJob(const string& job_id,
     88                                         const std::vector<string>& host_ports) {
     89   std::map<int, string> host_ports_map;
     90   for (size_t i = 0; i < host_ports.size(); ++i) {
     91     host_ports_map[i] = host_ports[i];
     92   }
     93   return AddHostPortsJob(job_id, host_ports_map);
     94 }
     95 
     96 Status GrpcChannelSpec::AddHostPortsJob(
     97     const string& job_id, const std::map<int, string>& host_ports) {
     98   if (!job_ids_.insert(job_id).second) {
     99     return errors::InvalidArgument(
    100         "Duplicate job ID in cluster specification: ", job_id);
    101   }
    102   for (const auto& id_host_port : host_ports) {
    103     TF_RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second));
    104   }
    105   host_ports_jobs_.emplace_back(job_id, host_ports);
    106   return Status::OK();
    107 }
    108 
    109 namespace {
    110 
    111 // GrpcChannelCache that caches results to FindWorkerChannel() calls.
    112 class CachingGrpcChannelCache : public GrpcChannelCache {
    113  public:
    114   CachingGrpcChannelCache() {}
    115 
    116   ~CachingGrpcChannelCache() override {}
    117 
    118   SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
    119     SharedGrpcChannelPtr ch = nullptr;
    120     {
    121       mutex_lock l(mu_);  // could use reader lock
    122       ch = gtl::FindPtrOrNull(channels_, target);
    123       if (ch) {
    124         return ch;
    125       }
    126     }
    127     ch = FindChannelOnce(target);
    128     if (ch) {
    129       mutex_lock l(mu_);
    130       channels_.insert({target, ch});
    131     }
    132     return ch;
    133   }
    134 
    135  protected:
    136   // Find the ClientChannel for "target".  Only called when no channel was
    137   // found in the channels_ cache for "target".  A non nullptr result will be
    138   // cached in channels_.
    139   virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;
    140 
    141  private:
    142   // TODO(zhifengc): Eviction when the map becomes too big.
    143   mutex mu_;
    144   std::unordered_map<string, SharedGrpcChannelPtr> channels_ GUARDED_BY(mu_);
    145 };
    146 
    147 // A ChannelCache that is the union of multiple ChannelCaches.
    148 // Takes ownership of the caches passed to the constructor.
    149 class MultiGrpcChannelCache : public CachingGrpcChannelCache {
    150  public:
    151   explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches)
    152       : CachingGrpcChannelCache(), caches_(caches) {}
    153 
    154   ~MultiGrpcChannelCache() override {
    155     for (GrpcChannelCache* cache : caches_) {
    156       delete cache;
    157     }
    158   }
    159 
    160   void ListWorkers(std::vector<string>* workers) override {
    161     for (GrpcChannelCache* cache : caches_) {
    162       cache->ListWorkers(workers);
    163     }
    164   }
    165 
    166   string TranslateTask(const string& target) override {
    167     mutex_lock l(mu_);  // could use reader lock
    168     GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
    169     if (cache == nullptr) {
    170       for (GrpcChannelCache* c : caches_) {
    171         string r = c->TranslateTask(target);
    172         if (!r.empty()) {
    173           target_caches_.insert({target, c});
    174           cache = c;
    175           break;
    176         }
    177       }
    178     }
    179     CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
    180                  << target;
    181     return cache->TranslateTask(target);
    182   }
    183 
    184  protected:
    185   SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
    186     for (GrpcChannelCache* cache : caches_) {
    187       SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
    188       if (ch) {
    189         mutex_lock l(mu_);
    190         target_caches_.insert({target, cache});
    191         return ch;
    192       }
    193     }
    194     return nullptr;
    195   }
    196 
    197  private:
    198   // List of channels used by this MultiGrpcChannelCache.
    199   const std::vector<GrpcChannelCache*> caches_;
    200 
    201   mutex mu_;
    202   // Cache of channels keyed by the target they are handling.
    203   // The same GrpcChannelCache can appear multiple times in the cache.
    204   std::unordered_map<string, GrpcChannelCache*> target_caches_ GUARDED_BY(mu_);
    205 };
    206 
    207 class SparseGrpcChannelCache : public CachingGrpcChannelCache {
    208  public:
    209   SparseGrpcChannelCache(const string& job_id,
    210                          const std::map<int, string>& host_ports,
    211                          ChannelCreationFunction channel_func)
    212       : job_id_(job_id),
    213         host_ports_(host_ports),
    214         channel_func_(std::move(channel_func)) {
    215     LOG(INFO) << "Initialize GrpcChannelCache for job " << ToString();
    216   }
    217   ~SparseGrpcChannelCache() override {}
    218 
    219   void ListWorkers(std::vector<string>* workers) override {
    220     workers->reserve(workers->size() + host_ports_.size());
    221     for (const auto& id_host_port : host_ports_) {
    222       workers->emplace_back(MakeAddress(job_id_, id_host_port.first));
    223     }
    224   }
    225 
    226   string TranslateTask(const string& target) override {
    227     DeviceNameUtils::ParsedName parsed;
    228     if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
    229       LOG(WARNING) << "Invalid target: " << target;
    230       return "";
    231     }
    232 
    233     if (!parsed.has_job || parsed.job != job_id_) {
    234       return "";
    235     }
    236     if (!parsed.has_replica || parsed.replica != 0) {
    237       LOG(WARNING) << "Replica ID must be 0 in target: " << target;
    238       return "";
    239     }
    240     int32 task = parsed.has_task ? parsed.task : -1;
    241     auto iter = host_ports_.find(task);
    242     if (iter == host_ports_.end()) {
    243       LOG(WARNING) << "Task " << task << " was not defined in sparse job "
    244                    << job_id_ << ": " << target;
    245       return "";
    246     }
    247     return iter->second;
    248   }
    249 
    250  protected:
    251   SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
    252     const string host_port = TranslateTask(target);
    253     if (host_port.empty()) {
    254       return nullptr;
    255     }
    256     return channel_func_(host_port);
    257   }
    258 
    259  private:
    260   string ToString() {
    261     std::vector<string> task_strings;
    262     task_strings.reserve(host_ports_.size());
    263     for (const auto& id_host_port : host_ports_) {
    264       task_strings.emplace_back(
    265           strings::StrCat(id_host_port.first, " -> ", id_host_port.second));
    266     }
    267     return strings::StrCat(job_id_, " -> {", str_util::Join(task_strings, ", "),
    268                            "}");
    269   }
    270 
    271   const string job_id_;
    272   const std::map<int, string> host_ports_;
    273   const ChannelCreationFunction channel_func_;
    274   TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache);
    275 };
    276 
    277 }  // namespace
    278 
    279 GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
    280                                       ChannelCreationFunction channel_func) {
    281   const int num_jobs = spec.host_ports_jobs().size();
    282   if (!num_jobs) {
    283     LOG(ERROR) << "Empty channel spec.";
    284     return nullptr;
    285   }
    286   std::vector<GrpcChannelCache*> caches;
    287   caches.reserve(num_jobs);
    288   for (auto& job : spec.host_ports_jobs()) {
    289     caches.push_back(
    290         new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func));
    291   }
    292   return caches.size() == 1 ? caches[0] : new MultiGrpcChannelCache(caches);
    293 }
    294 
    295 }  // end namespace tensorflow
    296