// tensorflow/contrib/verbs/rdma_mgr.cc
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifdef TENSORFLOW_USE_VERBS
     17 
#include "tensorflow/contrib/verbs/rdma_mgr.h"

#include <fstream>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"
     30 
     31 namespace tensorflow {
     32 
     33 RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
     34                  GrpcChannelCache* const channel_cache)
     35     : worker_env_(worker_env), channel_cache_(channel_cache) {
     36   rdma_adapter_ = new RdmaAdapter(worker_env_);
     37   // hardcoded to default session (legacy_session_)
     38   // TODO: use WorkerSessionForSession
     39   // need to pass in session handle
     40   local_worker_ = worker_env_->session_mgr->LegacySession()->worker_name;
     41   std::vector<string> workers;
     42   worker_env_->session_mgr->LegacySession()->worker_cache->ListWorkers(
     43       &workers);
     44   num_remote_workers_ = workers.size() - 1;
     45   VLOG(2) << "rmda_mgr on local worker: " << local_worker_;
     46   for (size_t i = 0; i < workers.size(); i++) {
     47     if (local_worker_.compare(workers[i]) != 0) {
     48       channel_table_.insert(
     49           {workers[i],
     50            new RdmaChannel(rdma_adapter_, local_worker_, workers[i])});
     51     }
     52   }
     53 }
     54 
     55 // Setup Rdma channels between peers.
     56 // This is done at the beginning of the server setup.
     57 
     58 void RdmaMgr::SetupChannels() {
     59   for (const auto& p : channel_table_) {
     60     string worker_name = p.first;
     61     RDMA_LOG(2) << "Connecting to remote node " << worker_name;
     62     RdmaChannel* rc = p.second;
     63     GetRemoteAddressRequest req;
     64     GetRemoteAddressResponse resp;
     65     // get the channel cache
     66     SharedGrpcChannelPtr client_channel =
     67         channel_cache_->FindWorkerChannel(worker_name);
     68     GrpcVerbsClient* client = new GrpcVerbsClient(client_channel);
     69     CHECK(client != nullptr) << "No worker known as " << worker_name;
     70 
     71     // setting up request
     72     req.set_host_name(local_worker_);
     73     Channel* channel_info = req.mutable_channel();
     74     channel_info->set_lid(rc->self_.lid);
     75     channel_info->set_qpn(rc->self_.qpn);
     76     channel_info->set_psn(rc->self_.psn);
     77     channel_info->set_snp(rc->self_.snp);
     78     channel_info->set_iid(rc->self_.iid);
     79     for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
     80       MemoryRegion* mr = req.add_mr();
     81       mr->set_remote_addr(
     82           reinterpret_cast<uint64_t>(rc->message_buffers_[i]->buffer_));
     83       mr->set_rkey(rc->message_buffers_[i]->self_->rkey);
     84     }
     85     // synchronous call
     86     Status s;
     87     int attempts = 0;
     88     static const int max_num_attempts = 5;
     89     do {
     90       s = client->GetRemoteAddress(&req, &resp);
     91       // save obtained remote addresses
     92       // connect to the remote channel
     93       if (s.ok()) {
     94         CHECK(worker_name.compare(resp.host_name()) == 0);
     95         RdmaAddress ra;
     96         ra.lid = resp.channel().lid();
     97         ra.qpn = resp.channel().qpn();
     98         ra.psn = resp.channel().psn();
     99         ra.snp = resp.channel().snp();
    100         ra.iid = resp.channel().iid();
    101         rc->SetRemoteAddress(ra, false);
    102         rc->Connect();
    103         int i = 0;
    104         int idx[] = {1, 0};
    105         for (const auto& mr : resp.mr()) {
    106           // the connections are crossed, i.e.
    107           // local tx_message_buffer <---> remote rx_message_buffer_
    108           // local rx_message_buffer <---> remote tx_message_buffer_
    109           // hence idx[] = {1, 0}.
    110           RdmaMessageBuffer* rb = rc->message_buffers_[idx[i]];
    111           RemoteMR rmr;
    112           rmr.remote_addr = mr.remote_addr();
    113           rmr.rkey = mr.rkey();
    114           rb->SetRemoteMR(rmr, false);
    115           i++;
    116         }
    117         CHECK(i == RdmaChannel::kNumMessageBuffers);
    118       } else {
    119         LOG(ERROR) << "Connecting to " << worker_name << ": Got "
    120                    << s.error_message() << ". Retrying (" << (attempts + 1)
    121                    << "/" << max_num_attempts << ")...";
    122         if (++attempts == max_num_attempts) {
    123           break;
    124         }
    125         worker_env_->env->SleepForMicroseconds(2000000);
    126       }
    127     } while (!s.ok());
    128     RDMA_LOG(0) << "Connected to remote node " << worker_name;
    129     delete client;
    130   }
    131 }
    132 
// Check connectivity by pinging every channel.
// Posts one ping send per remote peer plus a batch of receive buffers,
// then drains the adapter's completion queue until one send completion
// and one receive completion have been observed per peer. Completions
// are classified by wr_id: the sentinel kPingRecvWrid marks a ping
// receive; any other wr_id is treated as a send completion whose wr_id
// is the RdmaChannel pointer itself (set when the send was posted).
// Returns true iff both counts match num_remote_workers_.
// NOTE(review): must run before StartPolling() takes over the CQ, since
// this function polls rdma_adapter_->cq_ directly.
bool RdmaMgr::ConnectivityCheck() {
  int i, rcnt = 0, scnt = 0;

  for (const auto& p : channel_table_) {
    string worker_name = p.first;
    RdmaChannel* rc = p.second;

    VLOG(2) << "Ping to " << worker_name;
    CHECK(rc->PingPostSend() == 0) << "Couldn't post send  to " << worker_name
                                   << " with error: " << std::strerror(errno);
    // Pre-post receive buffers (queue_depth - 1 of them) so the peer's
    // pings have somewhere to land.
    for (i = 0; i < rc->adapter_->params_.queue_depth - 1; i++) {
      rc->Recv();
    }
  }

  // Drain completions until every peer has produced one send and one
  // receive completion.
  while (rcnt < num_remote_workers_ || scnt < num_remote_workers_) {
    int ne;
    do {
      // Poll up to 2 completions per peer (one send + one recv each).
      ne = ibv_poll_cq(rdma_adapter_->cq_, 2 * num_remote_workers_,
                       rdma_adapter_->wc_);
      CHECK(ne >= 0) << "poll CQ failed " << ne << "with error"
                     << std::strerror(errno);
    } while (ne < 1);

    for (i = 0; i < ne; ++i) {
      ibv_wc_status s = rdma_adapter_->wc_[i].status;
      // recv complete
      if ((int)rdma_adapter_->wc_[i].wr_id == RdmaChannel::kPingRecvWrid) {
        CHECK(s == IBV_WC_SUCCESS)
            << ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "("
            << rdma_adapter_->wc_[i].status << ") for PING_RECV_WRID";
        ++rcnt;
        // send complete
      } else {
        // wr_id of a send completion carries the channel pointer.
        RdmaChannel* rc =
            reinterpret_cast<RdmaChannel*>(rdma_adapter_->wc_[i].wr_id);
        CHECK(s == IBV_WC_SUCCESS)
            << ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "("
            << rdma_adapter_->wc_[i].status << ") to " << rc->remote_name_;
        ++scnt;
      }
    }  // for
  }    // while
  CHECK(rcnt == scnt) << "Connectivity check failed!";
  // Hand the completion queue over to the background polling thread.
  rdma_adapter_->StartPolling();
  return (num_remote_workers_ == rcnt) && (num_remote_workers_ == scnt);
}
    181 
    182 RdmaMgr::~RdmaMgr() {
    183   for (const auto& p : channel_table_) delete p.second;
    184   channel_table_.clear();
    185   delete rdma_adapter_;
    186 }
    187 
    188 // Find a channel via the given name.
    189 // Args:
    190 //   name: peer name, e.g. worker1
    191 // Returns
    192 //   channel object that is connected to the named peer.
    193 RdmaChannel* RdmaMgr::FindChannel(const string& name) {
    194   ChannelTable::iterator iter = channel_table_.find(name);
    195   CHECK(iter != channel_table_.end());
    196   return iter->second;
    197 }
    198 
    199 bool IsGDRAvailable() {
    200 #if defined(__APPLE__)
    201   return false;
    202 #elif defined(PLATFORM_WINDOWS)
    203   return false;
    204 #else
    205   std::ifstream ifs("/proc/modules");
    206   string line;
    207   while (std::getline(ifs, line)) {
    208     auto sep = line.find(' ');
    209     CHECK_NE(sep, std::string::npos);
    210     if (line.substr(0, sep) == "nv_peer_mem") {
    211       return true;
    212     }
    213   }
    214   return false;
    215 #endif
    216 }
    217 
    218 int TryToReadNumaNode(ibv_device* device) {
    219 #if defined(__APPLE__)
    220   LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
    221   return 0;
    222 #elif defined(PLATFORM_WINDOWS)
    223   // Windows support for NUMA is not currently implemented. Return node 0.
    224   return 0;
    225 #else
    226   VLOG(2) << "Trying to read NUMA node for device: " << device->name;
    227   static const int kUnknownNumaNode = -1;
    228 
    229   auto filename = string(device->ibdev_path) + "/device/numa_node";
    230 
    231   std::ifstream ifs(filename.c_str());
    232   string content;
    233   CHECK(std::getline(ifs, content));
    234 
    235   int32 value;
    236   if (strings::safe_strto32(content, &value)) {
    237     if (value < 0) {
    238       LOG(INFO) << "Successful NUMA node read from SysFS had negative value ("
    239                 << value
    240                 << "), but there must be at least one NUMA node"
    241                    ", so returning NUMA node zero";
    242       return 0;
    243     }
    244     LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
    245     return value;
    246   }
    247   return kUnknownNumaNode;
    248 #endif
    249 }
    250 
    251 void MRDeleter(ibv_mr* mr) {
    252   if (mr) {
    253     ibv_dereg_mr(mr);
    254   }
    255 }
    256 
    257 // TODO(byronyi): remove this class duplicated from the one in
    258 // common/runtime/gpu/pool_allocator.h when it is available in common_runtime
    259 class BasicCPUAllocator : public SubAllocator {
    260  public:
    261   ~BasicCPUAllocator() override {}
    262 
    263   void* Alloc(size_t alignment, size_t num_bytes) override {
    264     return port::AlignedMalloc(num_bytes, alignment);
    265   }
    266   void Free(void* ptr, size_t) override { port::AlignedFree(ptr); }
    267 };
    268 
    269 // TODO(byronyi): remove this class and its registration when the default
    270 // cpu_allocator() returns visitable allocator
    271 class BFCRdmaAllocator : public BFCAllocator {
    272  public:
    273   BFCRdmaAllocator()
    274       : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") {
    275   }
    276 };
    277 
    278 REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator);
    279 
    280 void RdmaMgr::InitAllocators() {
    281   RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_;
    282 
    283   Allocator* allocators[] = {
    284 #if GOOGLE_CUDA
    285     ProcessState::singleton()->GetCUDAHostAllocator(0),
    286     ProcessState::singleton()->GetCPUAllocator(0),
    287 #endif  // GOOGLE_CUDA
    288     cpu_allocator(),
    289   };
    290 
    291   using namespace std::placeholders;
    292 
    293   std::set<Allocator*> instrumented_;
    294 
    295   // Host memory allocators
    296   for (Allocator* allocator : allocators) {
    297     VisitableAllocator::Visitor alloc_visitor =
    298         std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
    299                   &RdmaMemoryMgr::Singleton(), _1, _2, allocator->Name());
    300     VisitableAllocator::Visitor free_visitor = std::bind(
    301         &RdmaMemoryMgr::EvictMemoryRegion, &RdmaMemoryMgr::Singleton(), _1, _2);
    302 
    303     auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
    304     CHECK(visitable_allocator)
    305         << "is not visitable for instrumentation" << allocator->Name();
    306     // Make sure we don't instrument the same allocator twice
    307     if (instrumented_.find(allocator) == std::end(instrumented_)) {
    308       visitable_allocator->AddAllocVisitor(alloc_visitor);
    309       visitable_allocator->AddFreeVisitor(free_visitor);
    310       instrumented_.insert(allocator);
    311       LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
    312     }
    313   }
    314 
    315 #if GOOGLE_CUDA
    316   if (IsGDRAvailable()) {
    317     // Note we don't free allocated GPU memory so there is no free visitor
    318     int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1;
    319 
    320     char buf[8];
    321     sprintf(buf, "gpu");
    322     VisitableAllocator::Visitor cuda_alloc_visitor =
    323         std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
    324                   &RdmaMemoryMgr::Singleton(), _1, _2, std::string(buf));
    325 
    326     ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor);
    327     LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
    328   }
    329 #endif  // GOOGLE_CUDA
    330 }
    331 
    332 }  // end namespace tensorflow
    333 
    334 #endif
    335