/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_GDR

#include "tensorflow/contrib/gdr/gdr_memory_manager.h"

#include <atomic>
#include <cerrno>
#include <fstream>
#include <list>
#include <map>
#include <set>

#include <fcntl.h>
#include <rdma/rdma_cma.h>
#include <rdma/rdma_verbs.h>
#include <sys/epoll.h>
#include <unistd.h>

#include "tensorflow/contrib/gdr/gdr.pb.h"
#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#endif  // GOOGLE_CUDA
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

namespace {

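// Returns true if the nv_peer_mem kernel module is loaded, i.e. if GPUDirect
// RDMA is available on this machine (never the case on OS X or Windows).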
bool IsGDRAvailable() {
#if defined(__APPLE__)
  return false;
#elif defined(PLATFORM_WINDOWS)
  return false;
#else
  std::ifstream ifs("/proc/modules");
  string line;
  while (std::getline(ifs, line)) {
    auto sep = line.find(' ');
    CHECK_NE(sep, std::string::npos);
    if (line.substr(0, sep) == "nv_peer_mem") {
      return true;
    }
  }
  return false;
#endif
}

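// Attempts to read the NUMA node of the given RDMA device from SysFS.
// Returns kUnknownNumaNode (-1) if the node cannot be determined.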
int TryToReadNumaNode(ibv_device* device) {
#if defined(__APPLE__)
  LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
  return 0;
#elif defined(PLATFORM_WINDOWS)
  // Windows support for NUMA is not currently implemented. Return node 0.
  return 0;
#else
  VLOG(2) << "Trying to read NUMA node for device: " << device->name;
  static const int kUnknownNumaNode = -1;

  auto filename = string(device->ibdev_path) + "/device/numa_node";

  std::ifstream ifs(filename.c_str());
  string content;
  CHECK(std::getline(ifs, content));

  int32 value;
  if (strings::safe_strto32(content, &value)) {
    if (value < 0) {
      LOG(INFO) << "Successful NUMA node read from SysFS had negative value ("
                << value
                << "), but there must be at least one NUMA node"
                   ", so returning NUMA node zero";
      return 0;
    }
    LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
    return value;
  }
  return kUnknownNumaNode;
#endif
}

void EndpointDeleter(rdma_cm_id* id) {
  if (id) {
    rdma_destroy_ep(id);
  }
}

void MRDeleter(ibv_mr* mr) {
  if (mr) {
    rdma_dereg_mr(mr);
  }
}

using RdmaEndpointPtr = std::unique_ptr<rdma_cm_id, decltype(&EndpointDeleter)>;

using MemoryRegionPtr = std::unique_ptr<ibv_mr, decltype(&MRDeleter)>;

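// GdrMemoryManager implements tensor transport over RDMA. The server side
// listens for RC connections and keeps each published tensor buffer alive
// until the remote reader acknowledges it; the client side pulls remote
// tensors with one-sided RDMA reads and acknowledges each transfer with a
// zero-length RDMA write-with-immediate that carries the tensor key.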
class GdrMemoryManager : public RemoteMemoryManager {
 public:
  GdrMemoryManager(const string& host, const string& port);

  virtual ~GdrMemoryManager();

  Status Init() override;

  void Run() override;

  void Stop() override;

  void TransportOptionsFromTensor(
      ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
      Device* device, DeviceContext* device_context, bool on_host,
      StatusCallback done) override;

  void TensorFromTransportOptions(
      Tensor* tensor, const ::google::protobuf::Any& transport_options,
      Device* device, DeviceContext* device_context, bool on_host,
      StatusCallback done) override;

 protected:
  Status CreateEndpoint(const string& host, const string& port,
                        RdmaEndpointPtr& endpoint);

  // Orders a raw pointer against the end address of a memory region; used
  // with std::upper_bound to find the first region that may contain ptr.
  static bool Comparator(const void* ptr, const MemoryRegionPtr& other) {
    return ptr < reinterpret_cast<char*>(other->addr) + other->length;
  }

  ibv_mr* FindMemoryRegion(void* addr, size_t length);

  void InsertMemoryRegion(void* addr, size_t length);

  void EvictMemoryRegion(void* addr, size_t length);

 private:
  const string host_;
  const string port_;
  RdmaEndpointPtr listening_;
  std::atomic<bool> stopped_;
  int epfd_;

  // Server side endpoints.
  // Accessed sequentially in Run() so not protected by lock.
  std::list<RdmaEndpointPtr> server_clients_;

  using TensorKey = uint32_t;
  std::atomic<TensorKey> next_key_;

  // Server side on-the-fly tensor buffers.
  mutex server_mu_;
  std::map<TensorKey, const TensorBuffer*> tensor_buffers_
      GUARDED_BY(server_mu_);

  // Client side endpoints, keyed by (host, port).
  mutex client_mu_;
  std::map<std::pair<string, string>, RdmaEndpointPtr> clients_
      GUARDED_BY(client_mu_);

  // Managed memory regions, sorted by region end address.
  mutex alloc_mu_;
  std::vector<MemoryRegionPtr> mrs_ GUARDED_BY(alloc_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(GdrMemoryManager);
};

// TODO(byronyi): remove this class duplicated from the one in
// common/runtime/gpu/pool_allocator.h when it is available in common_runtime
class BasicCPUAllocator : public SubAllocator {
 public:
  ~BasicCPUAllocator() override {}

  void* Alloc(size_t alignment, size_t num_bytes) override {
    return port::AlignedMalloc(num_bytes, alignment);
  }
  void Free(void* ptr, size_t) override { port::AlignedFree(ptr); }
};

// TODO(byronyi): remove this class and its registration when the default
// cpu_allocator() returns visitable allocator
class BFCRdmaAllocator : public BFCAllocator {
 public:
  BFCRdmaAllocator()
      : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") {
  }
};

REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator);

GdrMemoryManager::GdrMemoryManager(const string& host, const string& port)
    : host_(host),
      port_(port),
      listening_(nullptr, EndpointDeleter),
      stopped_(true),
      next_key_(0) {}

GdrMemoryManager::~GdrMemoryManager() { close(epfd_); }

Status GdrMemoryManager::Init() {
  epfd_ = epoll_create1(0);
  if (epfd_ == -1) {
    return errors::Unavailable(strerror(errno), ": ", "epoll_create");
  }

  rdma_addrinfo* addrinfo;
  rdma_addrinfo hints = {};
  hints.ai_port_space = RDMA_PS_TCP;
  hints.ai_flags = RAI_PASSIVE;
  if (rdma_getaddrinfo(const_cast<char*>(host_.c_str()),
                       const_cast<char*>(port_.c_str()), &hints, &addrinfo)) {
    return errors::Unavailable(strerror(errno), ": ", "cannot resolve rdma://",
                               host_, ":", port_);
  }

  ibv_qp_init_attr init_attr = {};
  init_attr.qp_type = IBV_QPT_RC;
  init_attr.cap.max_recv_wr = 32;
  init_attr.cap.max_send_wr = 1;
  init_attr.cap.max_recv_sge = 1;
  init_attr.cap.max_send_sge = 1;

  // Create listening endpoint
  rdma_cm_id* id;
  if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) {
    rdma_freeaddrinfo(addrinfo);
    return errors::Unavailable(strerror(errno), ": ", "cannot bind to rdma://",
                               host_, ":", port_);
  }
  listening_.reset(id);
  rdma_freeaddrinfo(addrinfo);

  // Listen without backlog
  if (rdma_listen(listening_.get(), 0)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot listen on rdma://", host_, ":", port_);
  }
  LOG(INFO) << "RDMA server is listening on " << host_ << ":" << port_;

  if (listening_->verbs == nullptr) {
    return errors::Unimplemented(
        "Unsupported address ", host_, ":", port_,
        " as it does not bind to a particular RDMA device");
  }

  int flags = fcntl(listening_->channel->fd, F_GETFL, 0);
  if (fcntl(listening_->channel->fd, F_SETFL, flags | O_NONBLOCK)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot set server to non-blocking mode");
  }

  epoll_event event = {};
  event.events = EPOLLIN | EPOLLPRI;
  event.data.ptr = listening_.get();
  if (epoll_ctl(epfd_, EPOLL_CTL_ADD, listening_->channel->fd, &event)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot add server to epoll");
  }

  Allocator* allocators[] = {
#if GOOGLE_CUDA
    ProcessState::singleton()->GetCUDAHostAllocator(0),
    ProcessState::singleton()->GetCPUAllocator(0),
#endif  // GOOGLE_CUDA
    cpu_allocator(),
  };

  using namespace std::placeholders;
  VisitableAllocator::Visitor alloc_visitor =
      std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
  VisitableAllocator::Visitor free_visitor =
      std::bind(&GdrMemoryManager::EvictMemoryRegion, this, _1, _2);

  std::set<Allocator*> instrumented;

  // Host memory allocators
  for (Allocator* allocator : allocators) {
    auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
    CHECK(visitable_allocator)
        << "Allocator " << allocator->Name()
        << " is not visitable for instrumentation";
    // Make sure we don't instrument the same allocator twice
    if (instrumented.find(allocator) == std::end(instrumented)) {
      visitable_allocator->AddAllocVisitor(alloc_visitor);
      visitable_allocator->AddFreeVisitor(free_visitor);
      instrumented.insert(allocator);
      LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
    }
  }

#if GOOGLE_CUDA
  VisitableAllocator::Visitor cuda_alloc_visitor =
      std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
  if (IsGDRAvailable()) {
    // Note we don't free allocated GPU memory so there is no free visitor
    int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1;
    ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor);
    LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
  }
#endif  // GOOGLE_CUDA

  return Status::OK();
}

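// Server event loop: accepts incoming connections and processes completion
// events. Each RDMA write-with-immediate received from a client carries the
// tensor key of a buffer whose reference can now be released.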
void GdrMemoryManager::Run() {
  stopped_ = false;
  while (!stopped_) {
    epoll_event events[32];
    int ret = epoll_wait(epfd_, events, 32, 1);
    if (ret == -1) {
      LOG(ERROR) << "epoll_wait: " << strerror(errno);
      return;
    }
    for (int i = 0; i < ret; i++) {
      rdma_cm_id* id = static_cast<rdma_cm_id*>(events[i].data.ptr);
      if (id == listening_.get()) {
        // Accept incoming connections
        if (!rdma_get_request(listening_.get(), &id)) {
          if (!rdma_accept(id, nullptr)) {
            LOG(INFO) << "Accepted new RDMA connection";
            if (ibv_req_notify_cq(id->recv_cq, 0)) {
              LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed";
              EndpointDeleter(id);
              continue;
            }
            // Pre-post 32 zero-byte receives to match clients' zero-length
            // writes-with-immediate; drop the endpoint if posting fails.
            bool posted_all = true;
            for (int j = 0; j < 32; j++) {
              if (rdma_post_recvv(id, nullptr, nullptr, 0)) {
                LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed";
                EndpointDeleter(id);
                posted_all = false;
                break;
              }
            }
            if (!posted_all) {
              continue;
            }
            int flags = fcntl(id->recv_cq_channel->fd, F_GETFL, 0);
            if (fcntl(id->recv_cq_channel->fd, F_SETFL, flags | O_NONBLOCK)) {
              LOG(ERROR) << strerror(errno)
                         << ": cannot set server_client to non-blocking mode";
              EndpointDeleter(id);
              continue;
            }
            epoll_event event = {};
            event.events = EPOLLIN | EPOLLPRI;
            event.data.ptr = id;
            if (epoll_ctl(epfd_, EPOLL_CTL_ADD, id->recv_cq_channel->fd,
                          &event)) {
              LOG(ERROR) << strerror(errno)
                         << ": cannot add server client to epoll";
              EndpointDeleter(id);
              continue;
            }
            server_clients_.push_back({id, EndpointDeleter});
          }
        }
      } else {
        // Polling work completions
        ibv_cq* cq;
        void* context;
        if (!ibv_get_cq_event(id->recv_cq_channel, &cq, &context)) {
          ibv_ack_cq_events(id->recv_cq, 1);
          if (ibv_req_notify_cq(id->recv_cq, 0)) {
            LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed";
            continue;
          }
          ibv_wc wc[32];
          int num_wc = ibv_poll_cq(id->recv_cq, 32, wc);
          if (num_wc < 0) {
            LOG(ERROR) << "ibv_poll_cq failed";
            continue;
          }
          for (int j = 0; j < num_wc; j++) {
            if (wc[j].opcode != IBV_WC_RECV_RDMA_WITH_IMM) {
              LOG(ERROR) << "Received unknown operation " << wc[j].opcode;
            }
            if (wc[j].status != 0) {
              LOG(ERROR) << ibv_wc_status_str(wc[j].status);
            }
            TensorKey tensor_key = ntohl(wc[j].imm_data);
            {
              mutex_lock l(server_mu_);
              auto iter = tensor_buffers_.find(tensor_key);
              if (iter == std::end(tensor_buffers_)) {
                LOG(ERROR) << "Cannot find tensor buffer for tensor key "
                           << tensor_key;
              } else {
                const TensorBuffer* buffer = iter->second;
                buffer->Unref();
                tensor_buffers_.erase(iter);
              }
            }
            // Replenish the receive slot consumed by this completion.
            if (rdma_post_recvv(id, nullptr, nullptr, 0)) {
              LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed";
              continue;
            }
          }
        }
      }
    }
  }
}

void GdrMemoryManager::Stop() { stopped_ = true; }

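// Server side of the transport: pins the tensor buffer (copying GPU tensors
// into registered host memory first if necessary), takes a reference on it,
// and publishes its address, rkey and a fresh tensor key in the transport
// options for the remote side to read from directly.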
void GdrMemoryManager::TransportOptionsFromTensor(
    ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
    Device* device, DeviceContext* device_context, bool on_host,
    StatusCallback done) {
  auto buffer = DMAHelper::buffer(&tensor);
  void* addr = buffer->data();
  size_t length = buffer->size();
  if (length == 0) {
    done(errors::Unavailable("Cannot register tensor buffer of size 0"));
    return;
  }

  ibv_mr* mr = FindMemoryRegion(addr, length);

#if GOOGLE_CUDA
  if (!on_host) {
    Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
    Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape());
    GPUUtil::CopyGPUTensorToCPU(
        device, device_context, &tensor, host_copy,
        [done, host_copy, mutable_transport_options, this](const Status& s) {
          if (!s.ok()) {
            done(s);
            delete host_copy;
            return;
          }
          auto buffer = DMAHelper::buffer(host_copy);
          void* addr = buffer->data();
          size_t length = buffer->size();
          ibv_mr* mr = FindMemoryRegion(addr, length);

          if (mr == nullptr) {
            done(errors::Unavailable("Cannot find pinned memory region"));
            delete host_copy;
            return;
          }

          buffer->Ref();
          TensorKey tensor_key = next_key_++;
          {
            mutex_lock l(server_mu_);
            tensor_buffers_.insert(std::make_pair(tensor_key, buffer));
          }

          uint64_t checksum = 0;
          if (VLOG_IS_ON(2)) {
            checksum = GPUUtil::Checksum(*host_copy);
          }

          RemoteMemoryRegion remote_mr;
          remote_mr.set_host(host_);
          remote_mr.set_port(port_);
          remote_mr.set_addr(reinterpret_cast<uint64_t>(addr));
          remote_mr.set_rkey(mr->rkey);
          remote_mr.set_tensor_key(tensor_key);
          remote_mr.set_checksum(checksum);
          mutable_transport_options->PackFrom(remote_mr);

          done(Status::OK());
          delete host_copy;
        });
    return;
  }
#endif  // GOOGLE_CUDA

  if (mr == nullptr) {
    done(errors::Unavailable("Cannot find pinned memory region"));
    return;
  }

  buffer->Ref();
  TensorKey tensor_key = next_key_++;
  {
    mutex_lock l(server_mu_);
    tensor_buffers_.insert(std::make_pair(tensor_key, buffer));
  }

  uint64_t checksum = 0;
  if (VLOG_IS_ON(2)) {
#if GOOGLE_CUDA
    if (!on_host) {
      checksum = GPUUtil::Checksum(device, device_context, tensor);
    } else {
      checksum = GPUUtil::Checksum(tensor);
    }
#endif  // GOOGLE_CUDA
  }

  RemoteMemoryRegion remote_mr;
  remote_mr.set_host(host_);
  remote_mr.set_port(port_);
  remote_mr.set_addr(reinterpret_cast<uint64_t>(addr));
  remote_mr.set_rkey(mr->rkey);
  remote_mr.set_tensor_key(tensor_key);
  remote_mr.set_checksum(checksum);
  mutable_transport_options->PackFrom(remote_mr);

  done(Status::OK());
}

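// Client side of the transport: connects to (or reuses a cached endpoint for)
// the server named in the transport options, pulls the remote buffer with a
// one-sided RDMA read, then sends a zero-length RDMA write-with-immediate
// carrying the tensor key so the server can drop its buffer reference.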
void GdrMemoryManager::TensorFromTransportOptions(
    Tensor* tensor, const ::google::protobuf::Any& transport_options,
    Device* device, DeviceContext* device_context, bool on_host,
    StatusCallback done) {
  RemoteMemoryRegion remote_mr;
  if (!transport_options.UnpackTo(&remote_mr)) {
    done(errors::NotFound("No RDMA transport options found"));
    return;
  }

  auto buffer = DMAHelper::buffer(tensor);
  void* addr = buffer->data();
  size_t length = buffer->size();
  ibv_mr* mr = FindMemoryRegion(addr, length);

  Tensor host_copy;
#if GOOGLE_CUDA
  if (mr == nullptr && !on_host) {
    Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
    host_copy = Tensor(alloc, tensor->dtype(), tensor->shape());
    buffer = DMAHelper::buffer(&host_copy);
    addr = buffer->data();
    length = buffer->size();
    mr = FindMemoryRegion(addr, length);
  }
#endif  // GOOGLE_CUDA

  if (mr == nullptr) {
    done(errors::Unavailable("Cannot find pinned memory region"));
    return;
  }

  decltype(clients_)::iterator iter;
  bool success;
  {
    mutex_lock l(client_mu_);
    std::tie(iter, success) = clients_.insert(
        std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()),
                       RdmaEndpointPtr(nullptr, EndpointDeleter)));
    if (success || iter->second.get() == nullptr) {
      Status s =
          CreateEndpoint(remote_mr.host(), remote_mr.port(), iter->second);
      if (!s.ok()) {
        done(s);
        return;
      }
    }
  }
  rdma_cm_id* id = iter->second.get();

    568 
    569   if (rdma_post_read(id, nullptr, buffer->data(), buffer->size(), mr, 0,
    570                      remote_mr.addr(), remote_mr.rkey())) {
    571     done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed"));
    572     return;
    573   }
    574 
    575   ibv_send_wr wr = {};
    576   wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
    577   wr.imm_data = htonl(remote_mr.tensor_key());
    578   wr.send_flags = IBV_SEND_SIGNALED;
    579   ibv_send_wr* bad_wr;
    580   if (ibv_post_send(id->qp, &wr, &bad_wr)) {
    581     done(errors::Unavailable(strerror(errno), ": ", "ibv_post_send failed"));
    582     return;
    583   }
    584 
    585   ibv_wc wc = {};
    586   int ret;
    587   while ((ret = ibv_poll_cq(id->send_cq, 1, &wc)) == 0)
    588     ;
    589   if (ret < 0 || wc.status) {
    590     done(errors::Unavailable(ibv_wc_status_str(wc.status)));
    591     return;
    592   }

#if GOOGLE_CUDA
  if (host_copy.NumElements() > 0) {
    uint64_t checksum = 0;
    if (VLOG_IS_ON(2)) {
      checksum = GPUUtil::Checksum(host_copy);
      CHECK(checksum == remote_mr.checksum())
          << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum();
    }
    Tensor* ref = new Tensor;
    std::swap(host_copy, *ref);
    GPUUtil::CopyCPUTensorToGPU(
        ref, device_context, device, tensor,
        [ref, done, buffer, remote_mr, start](const Status& s) {
          if (!s.ok()) {
            done(s);
            delete ref;
            return;
          }
          uint64_t end = Env::Default()->NowMicros();

          VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey()
                  << " of size " << buffer->size() << " with tensor key "
                  << remote_mr.tensor_key() << " took " << (end - start)
                  << " micros";
          done(Status::OK());
          delete ref;
        });
    return;
  }
#endif  // GOOGLE_CUDA

  uint64_t end = Env::Default()->NowMicros();

  VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey()
          << " of size " << buffer->size() << " with tensor key "
          << remote_mr.tensor_key() << " took " << (end - start) << " micros";

  uint64_t checksum = 0;
  if (VLOG_IS_ON(2)) {
#if GOOGLE_CUDA
    if (device->tensorflow_gpu_device_info() && (!on_host)) {
      checksum = GPUUtil::Checksum(device, device_context, *tensor);
    } else {
      checksum = GPUUtil::Checksum(*tensor);
    }
    CHECK(checksum == remote_mr.checksum())
        << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum();
#endif  // GOOGLE_CUDA
  }
  done(Status::OK());
}

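// Resolves host:port and establishes a reliable-connected (RC) endpoint to
// the remote RDMA server, returned through `endpoint` on success.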
Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port,
                                        RdmaEndpointPtr& endpoint) {
  rdma_addrinfo* addrinfo;
  rdma_addrinfo hints = {};
  hints.ai_port_space = RDMA_PS_TCP;
  if (rdma_getaddrinfo(const_cast<char*>(host.c_str()),
                       const_cast<char*>(port.c_str()), &hints, &addrinfo)) {
    return errors::InvalidArgument(
        strerror(errno), ": ", "cannot connect to rdma://", host, ":", port);
  }

  ibv_qp_init_attr init_attr = {};
  init_attr.qp_type = IBV_QPT_RC;
  init_attr.cap.max_recv_wr = 1;
  init_attr.cap.max_send_wr = 32;
  init_attr.cap.max_recv_sge = 1;
  init_attr.cap.max_send_sge = 1;

  rdma_cm_id* id;
  if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) {
    rdma_freeaddrinfo(addrinfo);
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot create endpoint to rdma://", host, ":",
                               port);
  }
  rdma_freeaddrinfo(addrinfo);

  if (rdma_connect(id, nullptr)) {
    rdma_destroy_ep(id);
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot connect to rdma://", host, ":", port);
  }

  LOG(INFO) << "RDMA endpoint connected to rdma://" << host << ":" << port;
  endpoint = RdmaEndpointPtr(id, EndpointDeleter);
  return Status::OK();
}

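// Returns the registered memory region containing addr, or nullptr if none
// exists. mrs_ is kept sorted by region end address, so std::upper_bound
// with Comparator locates the only candidate in logarithmic time.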
ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) {
  if (length == 0) return nullptr;
  mutex_lock l(alloc_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter == std::end(mrs_) || iter->get()->addr > addr) {
    return nullptr;
  } else {
    return iter->get();
  }
}

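// Allocation visitor: registers [addr, addr + length) for remote reads and
// records the region, keeping mrs_ sorted by region end address.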
void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) {
  if (length == 0) return;
  ibv_mr* mr = rdma_reg_read(listening_.get(), addr, length);
  if (mr != nullptr) {
    mutex_lock l(alloc_mu_);
    auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
    mrs_.insert(iter, {mr, &MRDeleter});
  } else {
    LOG(WARNING) << "Cannot register memory region";
  }
}

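// Free visitor: drops the region starting at addr; erasing the entry invokes
// MRDeleter, which de-registers the memory region.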
void GdrMemoryManager::EvictMemoryRegion(void* addr, size_t length) {
  if (length == 0) return;
  mutex_lock l(alloc_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter != std::end(mrs_) && iter->get()->addr == addr) {
    mrs_.erase(iter);
  } else {
    LOG(WARNING) << "Failed to de-register memory region";
  }
}

}  // namespace

RemoteMemoryManager* CreateRemoteMemoryManager(const string& host,
                                               const string& port) {
  return new GdrMemoryManager(host, port);
}

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_GDR