1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifdef TENSORFLOW_USE_GDR 17 18 #include "tensorflow/contrib/gdr/gdr_memory_manager.h" 19 20 #include <atomic> 21 #include <cerrno> 22 #include <fstream> 23 #include <list> 24 #include <map> 25 #include <set> 26 27 #include <fcntl.h> 28 #include <rdma/rdma_cma.h> 29 #include <rdma/rdma_verbs.h> 30 #include <sys/epoll.h> 31 32 #include "tensorflow/contrib/gdr/gdr.pb.h" 33 #include "tensorflow/core/common_runtime/bfc_allocator.h" 34 #include "tensorflow/core/common_runtime/device.h" 35 #include "tensorflow/core/common_runtime/dma_helper.h" 36 #if GOOGLE_CUDA 37 #include "tensorflow/core/common_runtime/gpu/gpu_util.h" 38 #include "tensorflow/core/common_runtime/gpu/process_state.h" 39 #endif // GOOGLE_CUDA 40 #include "tensorflow/core/framework/allocator_registry.h" 41 #include "tensorflow/core/lib/core/status.h" 42 #include "tensorflow/core/platform/macros.h" 43 #include "tensorflow/core/platform/mutex.h" 44 45 namespace tensorflow { 46 47 namespace { 48 49 bool IsGDRAvailable() { 50 #if defined(__APPLE__) 51 return false; 52 #elif defined(PLATFORM_WINDOWS) 53 return false; 54 #else 55 std::ifstream ifs("/proc/modules"); 56 string line; 57 while (std::getline(ifs, line)) { 58 auto sep = line.find(' '); 59 CHECK_NE(sep, std::string::npos); 60 if (line.substr(0, sep) == "nv_peer_mem") { 61 return true; 62 } 63 } 64 return false; 65 #endif 66 } 67 68 int TryToReadNumaNode(ibv_device* device) { 69 #if defined(__APPLE__) 70 LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0"; 71 return 0; 72 #elif defined(PLATFORM_WINDOWS) 73 // Windows support for NUMA is not currently implemented. Return node 0. 74 return 0; 75 #else 76 VLOG(2) << "Trying to read NUMA node for device: " << device->name; 77 static const int kUnknownNumaNode = -1; 78 79 auto filename = string(device->ibdev_path) + "/device/numa_node"; 80 81 std::ifstream ifs(filename.c_str()); 82 string content; 83 CHECK(std::getline(ifs, content)); 84 85 int32 value; 86 if (strings::safe_strto32(content, &value)) { 87 if (value < 0) { 88 LOG(INFO) << "Successful NUMA node read from SysFS had negative value (" 89 << value 90 << "), but there must be at least one NUMA node" 91 ", so returning NUMA node zero"; 92 return 0; 93 } 94 LOG(INFO) << "NUMA node for device: " << device->name << " is " << value; 95 return value; 96 } 97 return kUnknownNumaNode; 98 #endif 99 } 100 101 void EndpointDeleter(rdma_cm_id* id) { 102 if (id) { 103 rdma_destroy_ep(id); 104 } 105 } 106 107 void MRDeleter(ibv_mr* mr) { 108 if (mr) { 109 rdma_dereg_mr(mr); 110 } 111 } 112 113 using RdmaEndpointPtr = std::unique_ptr<rdma_cm_id, decltype(&EndpointDeleter)>; 114 115 using MemoryRegionPtr = std::unique_ptr<ibv_mr, decltype(&MRDeleter)>; 116 117 class GdrMemoryManager : public RemoteMemoryManager { 118 public: 119 GdrMemoryManager(const string& host, const string& port); 120 121 virtual ~GdrMemoryManager(); 122 123 virtual Status Init() override; 124 125 virtual void Run() override; 126 127 virtual void Stop() override; 128 129 virtual void TransportOptionsFromTensor( 130 ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor, 131 Device* device, DeviceContext* device_context, bool on_host, 132 StatusCallback done) override; 133 134 virtual void TensorFromTransportOptions( 135 Tensor* tensor, const ::google::protobuf::Any& transport_options, 136 Device* device, DeviceContext* device_context, bool on_host, 137 StatusCallback done) override; 138 139 protected: 140 Status CreateEndpoint(const string& host, const string& port, 141 RdmaEndpointPtr& endpoint); 142 143 static bool Comparator(const void* ptr, const MemoryRegionPtr& other) { 144 return ptr < reinterpret_cast<char*>(other->addr) + other->length; 145 } 146 147 ibv_mr* FindMemoryRegion(void* addr, size_t length); 148 149 void InsertMemoryRegion(void* addr, size_t length); 150 151 void EvictMemoryRegion(void* addr, size_t length); 152 153 private: 154 const string host_; 155 const string port_; 156 RdmaEndpointPtr listening_; 157 std::atomic<bool> stopped_; 158 int epfd_; 159 160 // Server side endpoints 161 // Accessed sequentially in Run() so not protected by lock 162 std::list<RdmaEndpointPtr> server_clients_; 163 164 using TensorKey = uint32_t; 165 std::atomic<TensorKey> next_key_; 166 167 // Server side on-the-fly tensor buffers 168 mutex server_mu_; 169 std::map<TensorKey, const TensorBuffer*> tensor_buffers_ 170 GUARDED_BY(server_mu_); 171 172 // Client side endpoints 173 mutex client_mu_; 174 std::map<std::pair<string, string>, RdmaEndpointPtr> clients_ 175 GUARDED_BY(cient_mu_); 176 177 // Managed memory regions 178 mutex alloc_mu_; 179 std::vector<MemoryRegionPtr> mrs_ GUARDED_BY(alloc_mu_); 180 181 TF_DISALLOW_COPY_AND_ASSIGN(GdrMemoryManager); 182 }; 183 184 // TODO(byronyi): remove this class duplicated from the one in 185 // common/runtime/gpu/pool_allocator.h when it is available in common_runtime 186 class BasicCPUAllocator : public SubAllocator { 187 public: 188 ~BasicCPUAllocator() override {} 189 190 void* Alloc(size_t alignment, size_t num_bytes) override { 191 return port::AlignedMalloc(num_bytes, alignment); 192 } 193 void Free(void* ptr, size_t) override { port::AlignedFree(ptr); } 194 }; 195 196 // TODO(byronyi): remove this class and its registration when the default 197 // cpu_allocator() returns visitable allocator 198 class BFCRdmaAllocator : public BFCAllocator { 199 public: 200 BFCRdmaAllocator() 201 : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") { 202 } 203 }; 204 205 REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator); 206 207 GdrMemoryManager::GdrMemoryManager(const string& host, const string& port) 208 : host_(host), 209 port_(port), 210 listening_(nullptr, EndpointDeleter), 211 stopped_(true), 212 next_key_(0) {} 213 214 GdrMemoryManager::~GdrMemoryManager() { close(epfd_); } 215 216 Status GdrMemoryManager::Init() { 217 epfd_ = epoll_create1(0); 218 if (epfd_ == -1) { 219 return errors::Unavailable(strerror(errno), ": ", "epoll_create"); 220 } 221 222 rdma_addrinfo* addrinfo; 223 rdma_addrinfo hints = {}; 224 hints.ai_port_space = RDMA_PS_TCP; 225 hints.ai_flags = RAI_PASSIVE; 226 if (rdma_getaddrinfo(const_cast<char*>(host_.c_str()), 227 const_cast<char*>(port_.c_str()), &hints, &addrinfo)) { 228 return errors::Unavailable(strerror(errno), ": ", "cannot resolve rdma://", 229 host_, ":", port_); 230 } 231 232 ibv_qp_init_attr init_attr = {}; 233 init_attr.qp_type = IBV_QPT_RC; 234 init_attr.cap.max_recv_wr = 32; 235 init_attr.cap.max_send_wr = 1; 236 init_attr.cap.max_recv_sge = 1; 237 init_attr.cap.max_send_sge = 1; 238 239 // Create listening endpoint 240 rdma_cm_id* id; 241 if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) { 242 return errors::Unavailable(strerror(errno), ": ", "cannot bind to rdma://", 243 host_, ":", port_); 244 } 245 listening_.reset(id); 246 rdma_freeaddrinfo(addrinfo); 247 248 // Listen without backlog 249 if (rdma_listen(listening_.get(), 0)) { 250 return errors::Unavailable(strerror(errno), ": ", 251 "cannot listen on rdma://", host_, ":", port_); 252 } 253 LOG(INFO) << "RDMA server is listening on " << host_ << ":" << port_; 254 255 if (listening_->verbs == nullptr) { 256 return errors::Unimplemented( 257 "Unsupported address ", host_, ":", port_, 258 " as it does not bind to a particular RDMA device"); 259 } 260 261 int flags = fcntl(listening_->channel->fd, F_GETFL, 0); 262 if (fcntl(listening_->channel->fd, F_SETFL, flags | O_NONBLOCK)) { 263 return errors::Unavailable(strerror(errno), ": ", 264 "cannot set server to non-blocking mode"); 265 } 266 267 epoll_event event = {}; 268 event.events = EPOLLIN | EPOLLPRI; 269 event.data.ptr = listening_.get(); 270 if (epoll_ctl(epfd_, EPOLL_CTL_ADD, listening_->channel->fd, &event)) { 271 return errors::Unavailable(strerror(errno), ": ", 272 "cannot add server to epoll"); 273 } 274 275 Allocator* allocators[] = { 276 #if GOOGLE_CUDA 277 ProcessState::singleton()->GetCUDAHostAllocator(0), 278 ProcessState::singleton()->GetCPUAllocator(0), 279 #endif // GOOGLE_CUDA 280 cpu_allocator(), 281 }; 282 283 using namespace std::placeholders; 284 VisitableAllocator::Visitor alloc_visitor = 285 std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2); 286 VisitableAllocator::Visitor free_visitor = 287 std::bind(&GdrMemoryManager::EvictMemoryRegion, this, _1, _2); 288 289 std::set<Allocator*> instrumented_; 290 291 // Host memory allocators 292 for (Allocator* allocator : allocators) { 293 auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator); 294 CHECK(visitable_allocator) 295 << "is not visitable for instrumentation" << allocator->Name(); 296 // Make sure we don't instrument the same allocator twice 297 if (instrumented_.find(allocator) == std::end(instrumented_)) { 298 visitable_allocator->AddAllocVisitor(alloc_visitor); 299 visitable_allocator->AddFreeVisitor(free_visitor); 300 instrumented_.insert(allocator); 301 LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name(); 302 } 303 } 304 305 #if GOOGLE_CUDA 306 VisitableAllocator::Visitor cuda_alloc_visitor = 307 std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2); 308 if (IsGDRAvailable()) { 309 // Note we don't free allocated GPU memory so there is no free visitor 310 int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1; 311 ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor); 312 LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id; 313 } 314 #endif // GOOGLE_CUDA 315 316 return Status::OK(); 317 } 318 319 void GdrMemoryManager::Run() { 320 stopped_ = false; 321 while (!stopped_) { 322 epoll_event events[32]; 323 int ret = epoll_wait(epfd_, events, 32, 1); 324 if (ret == -1) { 325 LOG(ERROR) << "epoll_wait: " << strerror(errno); 326 return; 327 } 328 for (int i = 0; i < ret; i++) { 329 rdma_cm_id* id = static_cast<rdma_cm_id*>(events[i].data.ptr); 330 if (id == listening_.get()) { 331 // Accept incoming connections 332 if (!rdma_get_request(listening_.get(), &id)) { 333 if (!rdma_accept(id, nullptr)) { 334 LOG(INFO) << "Accepted new RDMA connection"; 335 if (ibv_req_notify_cq(id->recv_cq, 0)) { 336 LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed"; 337 EndpointDeleter(id); 338 continue; 339 } 340 for (int i = 0; i < 32; i++) { 341 if (rdma_post_recvv(id, nullptr, nullptr, 0)) { 342 LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed"; 343 EndpointDeleter(id); 344 continue; 345 } 346 } 347 int flags = fcntl(id->recv_cq_channel->fd, F_GETFL, 0); 348 if (fcntl(id->recv_cq_channel->fd, F_SETFL, flags | O_NONBLOCK)) { 349 LOG(ERROR) << strerror(errno) 350 << ": cannot set server_client to non-blocking mode"; 351 EndpointDeleter(id); 352 continue; 353 } 354 epoll_event event = {}; 355 event.events = EPOLLIN | EPOLLPRI; 356 event.data.ptr = id; 357 if (epoll_ctl(epfd_, EPOLL_CTL_ADD, id->recv_cq_channel->fd, 358 &event)) { 359 LOG(ERROR) << strerror(errno) 360 << ": cannot add server client to epoll"; 361 EndpointDeleter(id); 362 continue; 363 } 364 server_clients_.push_back({id, EndpointDeleter}); 365 } 366 } 367 } else { 368 // Polling work completions 369 ibv_cq* cq; 370 void* context; 371 if (!ibv_get_cq_event(id->recv_cq_channel, &cq, &context)) { 372 ibv_ack_cq_events(id->recv_cq, 1); 373 if (ibv_req_notify_cq(id->recv_cq, 0)) { 374 LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed"; 375 continue; 376 } 377 ibv_wc wc[32]; 378 int ret = ibv_poll_cq(id->recv_cq, 32, wc); 379 if (ret < 0) { 380 LOG(ERROR) << "ibv_poll_cq failed"; 381 continue; 382 } 383 for (int i = 0; i < ret; i++) { 384 if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) { 385 LOG(ERROR) << "Received unknown operation " << wc[i].opcode; 386 } 387 if (wc[i].status != 0) { 388 LOG(ERROR) << ibv_wc_status_str(wc[i].status); 389 } 390 TensorKey tensor_key = ntohl(wc[i].imm_data); 391 { 392 mutex_lock l(server_mu_); 393 auto iter = tensor_buffers_.find(tensor_key); 394 if (iter == std::end(tensor_buffers_)) { 395 LOG(ERROR) << "Cannot find tensor buffer for tensor key " 396 << tensor_key; 397 } else { 398 const TensorBuffer* buffer = iter->second; 399 buffer->Unref(); 400 tensor_buffers_.erase(iter); 401 } 402 } 403 if (rdma_post_recvv(id, nullptr, nullptr, 0)) { 404 perror("rdma_post_recvv"); 405 LOG(ERROR) << "rdma_post_recvv failed"; 406 continue; 407 } 408 } 409 } 410 } 411 } 412 } 413 } 414 415 void GdrMemoryManager::Stop() { stopped_ = true; } 416 417 void GdrMemoryManager::TransportOptionsFromTensor( 418 ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor, 419 Device* device, DeviceContext* device_context, bool on_host, 420 StatusCallback done) { 421 auto buffer = DMAHelper::buffer(&tensor); 422 void* addr = buffer->data(); 423 size_t length = buffer->size(); 424 if (length == 0) { 425 done(errors::Unavailable("Cannot register tensor buffer of size 0")); 426 return; 427 } 428 429 ibv_mr* mr = FindMemoryRegion(addr, length); 430 431 #if GOOGLE_CUDA 432 if (!on_host) { 433 Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); 434 Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape()); 435 GPUUtil::CopyGPUTensorToCPU( 436 device, device_context, &tensor, host_copy, 437 [done, host_copy, mutable_transport_options, this](const Status& s) { 438 if (!s.ok()) { 439 done(s); 440 delete host_copy; 441 return; 442 } 443 auto buffer = DMAHelper::buffer(host_copy); 444 void* addr = buffer->data(); 445 size_t length = buffer->size(); 446 ibv_mr* mr = FindMemoryRegion(addr, length); 447 448 if (mr == nullptr) { 449 done(errors::Unavailable("Cannot find pinned memory region")); 450 delete host_copy; 451 return; 452 } 453 454 buffer->Ref(); 455 TensorKey tensor_key = next_key_++; 456 { 457 mutex_lock l(server_mu_); 458 tensor_buffers_.insert(std::make_pair(tensor_key, buffer)); 459 } 460 461 uint64_t checksum = 0; 462 if (VLOG_IS_ON(2)) { 463 checksum = GPUUtil::Checksum(*host_copy); 464 } 465 466 RemoteMemoryRegion remote_mr; 467 remote_mr.set_host(host_); 468 remote_mr.set_port(port_); 469 remote_mr.set_addr(reinterpret_cast<uint64_t>(addr)); 470 remote_mr.set_rkey(mr->rkey); 471 remote_mr.set_tensor_key(tensor_key); 472 remote_mr.set_checksum(checksum); 473 mutable_transport_options->PackFrom(remote_mr); 474 475 done(Status::OK()); 476 delete host_copy; 477 }); 478 return; 479 } 480 #endif 481 482 if (mr == nullptr) { 483 done(errors::Unavailable("Cannot find pinned memory region")); 484 return; 485 } 486 487 buffer->Ref(); 488 TensorKey tensor_key = next_key_++; 489 { 490 mutex_lock l(server_mu_); 491 tensor_buffers_.insert(std::make_pair(tensor_key, buffer)); 492 } 493 494 uint64_t checksum = 0; 495 if (VLOG_IS_ON(2)) { 496 #ifdef GOOGLE_CUDA 497 if (!on_host) { 498 checksum = GPUUtil::Checksum(device, device_context, tensor); 499 } else { 500 checksum = GPUUtil::Checksum(tensor); 501 } 502 #endif 503 } 504 505 RemoteMemoryRegion remote_mr; 506 remote_mr.set_host(host_); 507 remote_mr.set_port(port_); 508 remote_mr.set_addr(reinterpret_cast<uint64_t>(addr)); 509 remote_mr.set_rkey(mr->rkey); 510 remote_mr.set_tensor_key(tensor_key); 511 remote_mr.set_checksum(checksum); 512 mutable_transport_options->PackFrom(remote_mr); 513 514 done(Status::OK()); 515 } 516 517 void GdrMemoryManager::TensorFromTransportOptions( 518 Tensor* tensor, const ::google::protobuf::Any& transport_options, 519 Device* device, DeviceContext* device_context, bool on_host, 520 StatusCallback done) { 521 RemoteMemoryRegion remote_mr; 522 if (!transport_options.UnpackTo(&remote_mr)) { 523 done(errors::NotFound("No RDMA transport options found")); 524 return; 525 } 526 527 auto buffer = DMAHelper::buffer(tensor); 528 void* addr = buffer->data(); 529 size_t length = buffer->size(); 530 ibv_mr* mr = FindMemoryRegion(addr, length); 531 532 Tensor host_copy; 533 #if GOOGLE_CUDA 534 if (mr == nullptr && !on_host) { 535 Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0); 536 host_copy = Tensor(alloc, tensor->dtype(), tensor->shape()); 537 buffer = DMAHelper::buffer(&host_copy); 538 addr = buffer->data(); 539 length = buffer->size(); 540 mr = FindMemoryRegion(addr, length); 541 } 542 #endif // GOOGLE_CUDA 543 544 if (mr == nullptr) { 545 done(errors::Unavailable("Cannot find pinned memory region")); 546 return; 547 } 548 549 decltype(clients_)::iterator iter; 550 bool success; 551 { 552 mutex_lock l(client_mu_); 553 std::tie(iter, success) = clients_.insert( 554 std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()), 555 RdmaEndpointPtr(nullptr, EndpointDeleter))); 556 if (success || iter->second.get() == nullptr) { 557 Status s = 558 CreateEndpoint(remote_mr.host(), remote_mr.port(), iter->second); 559 if (!s.ok()) { 560 done(s); 561 return; 562 } 563 } 564 } 565 rdma_cm_id* id = iter->second.get(); 566 567 uint64_t start = Env::Default()->NowMicros(); 568 569 if (rdma_post_read(id, nullptr, buffer->data(), buffer->size(), mr, 0, 570 remote_mr.addr(), remote_mr.rkey())) { 571 done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed")); 572 return; 573 } 574 575 ibv_send_wr wr = {}; 576 wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; 577 wr.imm_data = htonl(remote_mr.tensor_key()); 578 wr.send_flags = IBV_SEND_SIGNALED; 579 ibv_send_wr* bad_wr; 580 if (ibv_post_send(id->qp, &wr, &bad_wr)) { 581 done(errors::Unavailable(strerror(errno), ": ", "ibv_post_send failed")); 582 return; 583 } 584 585 ibv_wc wc = {}; 586 int ret; 587 while ((ret = ibv_poll_cq(id->send_cq, 1, &wc)) == 0) 588 ; 589 if (ret < 0 || wc.status) { 590 done(errors::Unavailable(ibv_wc_status_str(wc.status))); 591 return; 592 } 593 594 #if GOOGLE_CUDA 595 if (host_copy.NumElements() > 0) { 596 uint64_t checksum = 0; 597 if (VLOG_IS_ON(2)) { 598 checksum = GPUUtil::Checksum(host_copy); 599 CHECK(checksum == remote_mr.checksum()) 600 << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); 601 } 602 Tensor* ref = new Tensor; 603 std::swap(host_copy, *ref); 604 GPUUtil::CopyCPUTensorToGPU( 605 ref, device_context, device, tensor, 606 [ref, done, buffer, remote_mr, start](const Status& s) { 607 if (!s.ok()) { 608 done(s); 609 delete ref; 610 return; 611 } 612 uint64_t end = Env::Default()->NowMicros(); 613 614 VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey() 615 << " of size " << buffer->size() << " with tensor key " 616 << remote_mr.tensor_key() << " took " << (end - start) 617 << " micros"; 618 done(Status::OK()); 619 delete ref; 620 }); 621 return; 622 } 623 #endif // GOOGLE_CUDA 624 625 uint64_t end = Env::Default()->NowMicros(); 626 627 VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey() 628 << " of size " << buffer->size() << " with tensor key " 629 << remote_mr.tensor_key() << " took " << (end - start) << " micros"; 630 631 uint64_t checksum = 0; 632 if (VLOG_IS_ON(2)) { 633 #ifdef GOOGLE_CUDA 634 if (device->tensorflow_gpu_device_info() && (!on_host)) { 635 checksum = GPUUtil::Checksum(device, device_context, *tensor); 636 } else { 637 checksum = GPUUtil::Checksum(*tensor); 638 } 639 CHECK(checksum == remote_mr.checksum()) 640 << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); 641 #endif 642 } 643 done(Status::OK()); 644 } 645 646 Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port, 647 RdmaEndpointPtr& endpoint) { 648 rdma_addrinfo* addrinfo; 649 rdma_addrinfo hints = {}; 650 hints.ai_port_space = RDMA_PS_TCP; 651 if (rdma_getaddrinfo(const_cast<char*>(host.c_str()), 652 const_cast<char*>(port.c_str()), &hints, &addrinfo)) { 653 return errors::InvalidArgument( 654 strerror(errno), ": ", "cannot connect to rdma://", host, ":", port); 655 } 656 657 ibv_qp_init_attr init_attr = {}; 658 init_attr.qp_type = IBV_QPT_RC; 659 init_attr.cap.max_recv_wr = 1; 660 init_attr.cap.max_send_wr = 32; 661 init_attr.cap.max_recv_sge = 1; 662 init_attr.cap.max_send_sge = 1; 663 664 rdma_cm_id* id; 665 if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) { 666 rdma_freeaddrinfo(addrinfo); 667 return errors::Unavailable(strerror(errno), ": ", 668 "cannot create endpoint to rdma://", host, ":", 669 port); 670 } 671 rdma_freeaddrinfo(addrinfo); 672 673 if (rdma_connect(id, nullptr)) { 674 rdma_destroy_ep(id); 675 return errors::Unavailable(strerror(errno), ": ", 676 "cannot connect to rdma://", host, ":", port); 677 } 678 679 LOG(INFO) << "RDMA endpoint connected to rdma://" << host << ":" << port; 680 endpoint = RdmaEndpointPtr(id, EndpointDeleter); 681 return Status::OK(); 682 } 683 684 ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) { 685 if (length == 0) return nullptr; 686 mutex_lock l(alloc_mu_); 687 auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); 688 if (iter == std::end(mrs_) || iter->get()->addr > addr) { 689 return nullptr; 690 } else { 691 return iter->get(); 692 } 693 } 694 695 void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) { 696 if (length == 0) return; 697 ibv_mr* mr = rdma_reg_read(listening_.get(), addr, length); 698 if (mr != nullptr) { 699 mutex_lock l(alloc_mu_); 700 auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); 701 mrs_.insert(iter, {mr, &MRDeleter}); 702 } else { 703 LOG(WARNING) << "Cannot register memory region"; 704 } 705 } 706 707 void GdrMemoryManager::EvictMemoryRegion(void* addr, size_t length) { 708 if (length == 0) return; 709 mutex_lock l(alloc_mu_); 710 auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator); 711 if (iter != std::end(mrs_) && iter->get()->addr == addr) { 712 mrs_.erase(iter); 713 } else { 714 LOG(WARNING) << "Failed to de-register memory region"; 715 } 716 } 717 718 } // namespace 719 720 RemoteMemoryManager* CreateRemoteMemoryManager(const string& host, 721 const string& port) { 722 return new GdrMemoryManager(host, port); 723 } 724 725 } // namespace tensorflow 726 727 #endif // TENSORFLOW_USE_GDR 728