/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_VERBS

#include <fcntl.h>
#include <unistd.h>

#include <cstdlib>
#include <cstring>

#include "tensorflow/contrib/verbs/rdma.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#endif
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {

#define RoCE_V2 "RoCE v2"

namespace {

// Convenience function for printing message types.
string MessageTypeToString(RdmaMessageType rmt) {
  switch (rmt) {
    case RDMA_MESSAGE_META_DATA_UPDATE:
      return "RDMA_MESSAGE_META_DATA_UPDATE";
    case RDMA_MESSAGE_TENSOR_RE_REQUEST:
      return "RDMA_MESSAGE_TENSOR_RE_REQUEST";
    case RDMA_MESSAGE_TENSOR_REQUEST:
      return "RDMA_MESSAGE_TENSOR_REQUEST";
    case RDMA_MESSAGE_ERROR_STATUS:
      return "RDMA_MESSAGE_ERROR_STATUS";
    default:
      return "UNKNOWN MESSAGE";
  }
}
}  // namespace

// Function to get an environment variable.
// Args:
//   var_name - the name of the environment variable
// Returns:
//   string with its value, or an empty string if not set
string get_env_var(char const* var_name) {
  char const* var_temp = getenv(var_name);

  return (var_temp == NULL) ? string() : string(var_temp);
}

// Function to open a device.
// Args:
//   ibv_dev - device to open
// Returns:
//   context of the opened device
ibv_context* open_device(ibv_device* ibv_dev) {
  ibv_context* context = ibv_open_device(ibv_dev);

  CHECK(context) << "Open context failed for " << ibv_get_device_name(ibv_dev);
  return context;
}

// Function to count the number of active ports for a device.
// Args:
//   device - device to check for active ports
// Returns:
//   number of active ports of the given device
int get_dev_active_port_count(ibv_device* device) {
  ibv_device_attr device_att;
  ibv_port_attr port_attr;
  ibv_context* context = NULL;
  int rc, port_index, active_ports = 0;

  context = ibv_open_device(device);
  CHECK(context) << "Open context failed for " << ibv_get_device_name(device);
  rc = ibv_query_device(context, &device_att);
  CHECK(!rc) << "Failed to query the device";

  for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
    rc = ibv_query_port(context, port_index, &port_attr);
    CHECK(!rc) << "Failed to query port " << port_index;
    if (port_attr.state == IBV_PORT_ACTIVE) {
      active_ports++;
    }
  }
  ibv_close_device(context);
  return active_ports;
}

// Function to select the device. If RDMA_DEVICE is not set, search for a
// device with an active port.
// Fails if more than one device with an active port is found.
// Returns:
//   device to use
ibv_device* set_device() {
  ibv_device** dev_list;
  int dev_num, device_index, device_to_open = 0;
  int num_devs_with_active_port = 0;
  string env_p_rdma_device, str_port_num;

  dev_list = ibv_get_device_list(&dev_num);
  CHECK(dev_list) << "No InfiniBand device found";

  env_p_rdma_device = get_env_var("RDMA_DEVICE");
  if (!env_p_rdma_device.empty()) {
    for (device_index = 0; device_index < dev_num; device_index++) {
      if (!env_p_rdma_device.compare(
              ibv_get_device_name(dev_list[device_index]))) {
        CHECK(get_dev_active_port_count(dev_list[device_index]) != 0)
            << "Device " << ibv_get_device_name(dev_list[device_index])
            << " has no active ports";
        return dev_list[device_index];
      }
    }
    // The requested device was not found.
    CHECK(false) << "The device " << env_p_rdma_device << " wasn't found";
  } else {
    // Select a default device.
    str_port_num = get_env_var("RDMA_DEVICE_PORT");
    CHECK(str_port_num.empty())
        << "RDMA_DEVICE should be provided if RDMA_DEVICE_PORT is set by user";
    for (device_index = 0; device_index < dev_num; device_index++) {
      // get port_num
      if (get_dev_active_port_count(dev_list[device_index]) > 0) {
        num_devs_with_active_port++;
        CHECK(num_devs_with_active_port <= 1)
            << "More than one device with an active port in the system. "
               "Please set RDMA_DEVICE";
        // found a device with at least one active port
        device_to_open = device_index;
      }
    }
    CHECK(num_devs_with_active_port > 0)
        << "There is no active port in the system";
    return dev_list[device_to_open];
  }
  CHECK(false) << "No device was set!";
  return NULL;  // never happens
}
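
// The device and port selection above is driven entirely by environment
// variables. A minimal sketch of how a process could pin the adapter before
// TensorFlow initializes verbs (the device name "mlx5_0" and the index
// values are examples only, not defaults of this module):
#if 0
void ConfigureRdmaEnvExample() {
  setenv("RDMA_DEVICE", "mlx5_0", 1);  // must name a device with active ports
  setenv("RDMA_DEVICE_PORT", "1", 1);  // only valid when RDMA_DEVICE is set
  setenv("RDMA_GID_INDEX", "3", 1);    // optional: RoCE v2 GID selection
}
#endif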

// Function to select the port for a device.
// If RDMA_DEVICE_PORT is not set, the first active port of the device is
// selected.
// Args:
//   context - context of the device
// Returns:
//   port to use
uint8_t set_port(ibv_context* context) {
  uint8_t port_num = 0;  // 0 is an illegal port number
  string str_port_num;
  ibv_device_attr device_att;
  ibv_port_attr port_attr;
  int rc, port_index;

  rc = ibv_query_device(context, &device_att);
  CHECK(!rc) << "Failed to query the device";

  str_port_num = get_env_var("RDMA_DEVICE_PORT");
  // user-defined port
  if (!str_port_num.empty()) {
    port_num = stoi(str_port_num);
    CHECK(port_num > 0) << "RDMA_DEVICE_PORT should be positive";
    CHECK(port_num <= device_att.phys_port_cnt)
        << "RDMA_DEVICE_PORT should be less than or equal to the number of "
           "available ports";
    rc = ibv_query_port(context, port_num, &port_attr);
    CHECK(!rc) << "Failed to query port " << static_cast<int>(port_num);
    // check if the port is active
    CHECK(port_attr.state == IBV_PORT_ACTIVE)
        << "Selected RDMA_DEVICE_PORT is not active";
  } else {  // set default port
    for (port_index = 1; port_index <= device_att.phys_port_cnt;
         port_index++) {
      rc = ibv_query_port(context, port_index, &port_attr);
      CHECK(!rc) << "Failed to query port " << port_index;
      if (port_attr.state == IBV_PORT_ACTIVE) {
        port_num = port_index;
        break;
      }
    }
    CHECK_GT(port_num, 0) << "No active ports";
  }
  return port_num;
}

// Function to read from a sysfs file.
// Args:
//   dir - directory
//   file - file name
//   buf - buffer for the result
//   size - buffer size
// Returns:
//   number of bytes read, or -1 on failure
int read_sysfs_file(const char* dir, const char* file, char* buf, size_t size) {
  char* path;
  int fd;
  int len;

  if (asprintf(&path, "%s/%s", dir, file) < 0) return -1;

  fd = open(path, O_RDONLY);
  if (fd < 0) {
    free(path);
    return -1;
  }

  len = read(fd, buf, size);

  close(fd);
  free(path);

  if (len > 0 && buf[len - 1] == '\n') buf[--len] = '\0';

  return len;
}

// Function to check whether a GID index supports RoCE V2.
// Args:
//   context - device context
//   port_num - port number
//   index - GID index
// Returns:
//   true if the GID supports RoCE V2, otherwise false.
bool is_gid_type_roce_v2(ibv_context* context, uint8_t port_num,
                         uint8_t index) {
  char name[32];
  char buff[41];

  snprintf(name, sizeof(name), "ports/%d/gid_attrs/types/%d", port_num, index);
  if (read_sysfs_file(context->device->ibdev_path, name, buff, sizeof(buff)) <=
      0) {
    return false;
  }
  return !strcmp(buff, RoCE_V2);
}
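
// Note: read_sysfs_file() resolves its path relative to the device's
// ibdev_path, so the query above typically reads a file such as
// /sys/class/infiniband/<device>/ports/<port>/gid_attrs/types/<index>
// (the exact layout depends on the kernel's RDMA sysfs tree).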

// Function to select the GID index.
// If the port link layer is InfiniBand, no GID index needs to be selected.
// If it is Ethernet and RDMA_GID_INDEX is not set, a GID index that supports
// RoCE V2 is chosen (fails if more than one IP is configured).
// Args:
//   context - device context
//   port_num - port number
// Returns:
//   GID index to use
uint8_t set_gid(uint8_t port_num, ibv_context* context) {
  ibv_port_attr port_attr;
  string gid_str;
  int rc, i, gids_num = 0, v2_ip_num = 0;
  union ibv_gid gid;
  uint8_t gid_index = 0;

  rc = ibv_query_port(context, port_num, &port_attr);
  CHECK(!rc) << "Failed to query port " << static_cast<int>(port_num);

  for (i = 0; i < port_attr.gid_tbl_len; i++) {
    rc = ibv_query_gid(context, port_num, i, &gid);
    CHECK(!rc) << "Failed to query gid on port " << (int)port_num << " index "
               << i;
    if (gid.global.interface_id) {
      gids_num++;
      if (gid.global.subnet_prefix == 0 &&
          is_gid_type_roce_v2(context, port_num, i)) {
        if (v2_ip_num == 0) {
          // can be overwritten by RDMA_GID_INDEX later
          gid_index = i;
        }
        v2_ip_num++;
      }
    }
  }
  switch (port_attr.link_layer) {
    case (IBV_LINK_LAYER_ETHERNET):
      gid_str = get_env_var("RDMA_GID_INDEX");
      if (!gid_str.empty()) {
        gid_index = stoi(gid_str);
        CHECK(gid_index < gids_num)
            << "RDMA_GID_INDEX should be less than the number of GIDs: "
            << gids_num;
      } else {
        CHECK(v2_ip_num <= 1)
            << "More than one IP is available, please specify RDMA_GID_INDEX";
      }
      break;
    case (IBV_LINK_LAYER_INFINIBAND):  // no GID index is needed
      break;
    default:
      LOG(INFO) << "Unknown port link layer. Currently supporting Ethernet "
                   "and InfiniBand only.";
  }
  if (!is_gid_type_roce_v2(context, port_num, gid_index)) {
    LOG(INFO) << "RoCE v2 is not configured for GID_INDEX " << (int)gid_index;
  }
  return gid_index;
}

// Set a configuration parameter from the environment, falling back to a
// default value.
// Args:
//   default_val - the default value for this parameter
//   env_param - the environment variable's name
// Returns:
//   32-bit value
uint32_t set_param(uint32_t default_val, const char* env_param) {
  uint32_t val = default_val;
  string val_s;

  val_s = get_env_var(env_param);

  if (!val_s.empty()) {
    val = stoi(val_s);
  }
  return val;
}
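
// For example, with the *_DEFAULT constants used by params_init() below, an
// unset RDMA_TIMEOUT makes set_param(TIMEOUT_DEFAULT, "RDMA_TIMEOUT") return
// TIMEOUT_DEFAULT, while exporting RDMA_TIMEOUT=14 makes the same call return
// 14. Note that stoi() is applied unchecked, so a non-numeric value will
// terminate the process.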

enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) {
  ibv_port_attr port_attr;
  enum ibv_mtu mtu = IBV_MTU_512;
  string mtu_s;
  int rc, mtu_i;

  rc = ibv_query_port(context, port_num, &port_attr);
  CHECK(!rc) << "Failed to query port " << static_cast<int>(port_num);

  mtu_s = get_env_var("RDMA_MTU");

  if (!mtu_s.empty()) {
    mtu_i = stoi(mtu_s);
    switch (mtu_i) {
      case 256:
        mtu = IBV_MTU_256;
        break;
      case 512:
        mtu = IBV_MTU_512;
        break;
      case 1024:
        mtu = IBV_MTU_1024;
        break;
      case 2048:
        mtu = IBV_MTU_2048;
        break;
      case 4096:
        mtu = IBV_MTU_4096;
        break;
      default:
        CHECK(0) << "Error: MTU input value must be one of the following: "
                    "256, 512, 1024, 2048, 4096. MTU "
                 << mtu_i << " is invalid";
        break;
    }
    CHECK(mtu <= port_attr.active_mtu)
        << "MTU configuration for the QPs is larger than active MTU";
  } else {
    mtu = port_attr.active_mtu;
  }
  return mtu;
}

RdmaParams params_init(ibv_context* context) {
  RdmaParams params;

  params.port_num = set_port(context);
  params.sgid_index = set_gid(params.port_num, context);
  params.pkey_index = (uint8_t)set_param(PKEY_DEFAULT, "RDMA_PKEY");
  params.queue_depth = set_param(QUEUE_DEPTH_DEFAULT, "RDMA_QUEUE_DEPTH");
  params.timeout = (uint8_t)set_param(TIMEOUT_DEFAULT, "RDMA_TIMEOUT");
  params.retry_cnt = (uint8_t)set_param(RETRY_CNT_DEFAULT, "RDMA_RETRY_CNT");
  params.sl = (uint8_t)set_param(SL_DEFAULT, "RDMA_SL");
  CHECK(params.sl <= 7) << "SL value is " << (int)params.sl
                        << ". Valid values are 0-7.";
  params.mtu = set_mtu(params.port_num, context);
  params.traffic_class = set_param(TRAFFIC_CLASS, "RDMA_TRAFFIC_CLASS");
  return params;
}

ibv_pd* alloc_protection_domain(ibv_context* context) {
  ibv_pd* pd = ibv_alloc_pd(context);
  CHECK(pd) << "Failed to allocate protection domain";
  return pd;
}

RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
    : context_(open_device(set_device())),
      params_(params_init(context_)),
      pd_(alloc_protection_domain(context_)),
      worker_env_(worker_env) {
  event_channel_ = ibv_create_comp_channel(context_);
  CHECK(event_channel_) << "Failed to create completion channel";
  cq_ = ibv_create_cq(context_, MAX_CONCURRENT_WRITES * 2, NULL,
                      event_channel_, 0);
  CHECK(cq_) << "Failed to create completion queue";
  CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification";
}

RdmaAdapter::~RdmaAdapter() {
  polling_thread_.reset();
  CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ";
  CHECK(!ibv_destroy_comp_channel(event_channel_))
      << "Failed to destroy channel";
  CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD";
  CHECK(!ibv_close_device(context_)) << "Failed to release context";
}

void RdmaAdapter::StartPolling() {
  polling_thread_.reset(Env::Default()->StartThread(
      ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); }));
  VLOG(2) << "Start RdmaAdapter: " << name();
}

string RdmaAdapter::name() const { return string(context_->device->name); }

// Function to process work completions from the completion queue.
// There are two types of completions:
//   1. IBV_WC_RECV_RDMA_WITH_IMM (receive)
//   2. IBV_WC_RDMA_WRITE (send)
void RdmaAdapter::Process_CQ() {
  while (true) {
    ibv_cq* cq;
    void* cq_context;
    CHECK(!ibv_get_cq_event(event_channel_, &cq, &cq_context));
    CHECK(cq == cq_);
    ibv_ack_cq_events(cq, 1);
    CHECK(!ibv_req_notify_cq(cq_, 0));

    int ne =
        ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2, static_cast<ibv_wc*>(wc_));
    CHECK_GE(ne, 0);
    for (int i = 0; i < ne; ++i) {
      CHECK(wc_[i].status == IBV_WC_SUCCESS)
          << "Failed status \n"
          << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
          << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
      if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
        RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
        // put back a recv wr.
        rc->Recv();
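
        // The 32-bit immediate value demultiplexes the incoming write, using
        // the constants referenced below:
        //   RDMA_IMM_DATA_ACK           - an ACK for a previously sent
        //                                 message.
        //   0..RDMA_IMM_MAX_REQUEST_ID  - a tensor-content write; the value
        //                                 is the index of the pending request.
        //   RDMA_IMM_DATA_MESSAGE       - a control message in the RX buffer.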
        uint32_t imm_data = wc_[i].imm_data;
        RdmaMessageBuffer* rb;
        RdmaMessage rm;

        if (imm_data == RDMA_IMM_DATA_ACK) {
          // received an ACK for a message
          rb = rc->tx_message_buffer_;
          rb->SetBufferStatus(remote, idle);
          rb->SendNextItem();
          continue;
        }

        if (imm_data <= RDMA_IMM_MAX_REQUEST_ID) {
          // received a tensor RDMA write
          uint32_t request_index = imm_data;
          RdmaTensorRequest* request = rc->GetTensorRequest(request_index);
          request->RecvTensorContent();
          continue;
        }

        // received a control message
        rb = rc->rx_message_buffer_;
        RdmaMessage::ParseMessage(rm, rb->buffer_);
        RdmaMessageBuffer::SendAck(rc);
        RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
                    << ": Received " << MessageTypeToString(rm.type_) << " "
                    << "#" << rm.request_index_ << ": " << rm.name_;

        if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) {
          RdmaTensorResponse* response = rc->AddTensorResponse(rm);
          response->Start();
        } else if (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) {
          RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_);
          request->RecvTensorMetaData(rm.data_type_, rm.tensor_shape_,
                                      rm.is_dead_, rm.tensor_bytes_);
#ifdef RDMA_DATA_VALIDATION
          request->RecvTensorChecksum(rm.checksum_);
#endif
        } else if (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST) {
          RdmaTensorResponse* response = rc->UpdateTensorResponse(rm);
          response->Resume();
        } else if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
          RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_);
          request->RecvErrorStatus(rm.status_);
        }
      } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
        RdmaWriteID* wr_id = reinterpret_cast<RdmaWriteID*>(wc_[i].wr_id);
        RDMA_LOG(2) << "Write complete of type " << wr_id->write_type;
        switch (wr_id->write_type) {
          case RDMA_WRITE_ID_ACK:
            break;
          case RDMA_WRITE_ID_MESSAGE: {
            RdmaMessageBuffer* rb =
                reinterpret_cast<RdmaMessageBuffer*>(wr_id->write_context);
            rb->SetBufferStatus(local, idle);
            rb->SendNextItem();
            break;
          }
          case RDMA_WRITE_ID_TENSOR_WRITE: {
            RdmaTensorResponse* response =
                reinterpret_cast<RdmaTensorResponse*>(wr_id->write_context);
            response->Destroy();
            break;
          }
        }
        delete wr_id;
      }
    }
  }
}

int RdmaChannel::PingPostRecv() {
  struct ibv_recv_wr wr, *bad_wr;
  memset(&wr, 0, sizeof(wr));
  wr.sg_list = &ping_sge_list_;
  wr.num_sge = 1;
  wr.wr_id = kPingRecvWrid;

  return ibv_post_recv(qp_, &wr, &bad_wr);
}

int RdmaChannel::PingPostSend() {
  struct ibv_send_wr wr, *bad_wr;
  memset(&wr, 0, sizeof(wr));
  wr.wr_id = (uint64_t)this;
  wr.sg_list = &ping_sge_list_;
  wr.num_sge = 1;
  wr.opcode = IBV_WR_SEND;
  wr.send_flags = IBV_SEND_SIGNALED;

  return ibv_post_send(qp_, &wr, &bad_wr);
}

RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
                         const string remote_name)
    : adapter_(adapter),
      local_name_(local_name),
      remote_name_(remote_name),
      request_serial_(0) {
  struct ibv_sge list;

  mr_ = ibv_reg_mr(adapter_->pd_, ping_buff_, kPingBuffSize,
                   IBV_ACCESS_LOCAL_WRITE);
  CHECK(mr_) << "Failed to register memory region";

  memset(&list, 0, sizeof(list));
  list.addr = (uintptr_t)ping_buff_;
  list.length = kPingBuffSize;
  list.lkey = mr_->lkey;

  ping_sge_list_ = list;
  // Create queue pair
  {
    struct ibv_qp_init_attr attr;
    memset(&attr, 0, sizeof(ibv_qp_init_attr));
    attr.send_cq = adapter_->cq_;
    attr.recv_cq = adapter_->cq_;
    attr.cap.max_send_wr = adapter_->params_.queue_depth;
    attr.cap.max_recv_wr = adapter_->params_.queue_depth;
    attr.cap.max_send_sge = 1;
    attr.cap.max_recv_sge = 1;
    attr.qp_type = IBV_QPT_RC;

    qp_ = ibv_create_qp(adapter_->pd_, &attr);
    CHECK(qp_) << "Failed to create queue pair";
  }

  // Init queue pair
  {
    struct ibv_qp_attr attr;
    memset(&attr, 0, sizeof(ibv_qp_attr));
    attr.qp_state = IBV_QPS_INIT;
    attr.pkey_index = adapter_->params_.pkey_index;
    attr.port_num = adapter_->params_.port_num;
    attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;

    int mask =
        IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;
    CHECK(!ibv_modify_qp(qp_, &attr, mask)) << "Failed to set QP to INIT";
  }

  // Local address
  {
    struct ibv_port_attr attr;
    CHECK(
        !ibv_query_port(adapter_->context_, adapter_->params_.port_num, &attr))
        << "Failed to query port";
    self_.lid = attr.lid;
    self_.qpn = qp_->qp_num;
    self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
    union ibv_gid gid;
    CHECK(!ibv_query_gid(adapter_->context_, adapter_->params_.port_num,
                         adapter_->params_.sgid_index, &gid))
        << "Failed to query gid";
    self_.snp = gid.global.subnet_prefix;
    self_.iid = gid.global.interface_id;
  }

  // Create message and ack buffers, then initialize the tables.
  {
    const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer"};
    tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]);
    rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]);
    message_buffers_.reserve(kNumMessageBuffers);
    message_buffers_.push_back(tx_message_buffer_);
    message_buffers_.push_back(rx_message_buffer_);
    // create buffers on host
    tx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
    rx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
  }
  CHECK(PingPostRecv() == 0) << "Couldn't post receive from " << remote_name_
                             << " with error " << std::strerror(errno);
}

RdmaChannel::~RdmaChannel() {
  ibv_dereg_mr(mr_);
  CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP";
  delete tx_message_buffer_;
  delete rx_message_buffer_;
}

void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
  mutex_lock lock{mu_};
  if ((override) || (!remote_set_)) {
    remote_.lid = ra.lid;
    remote_.qpn = ra.qpn;
    remote_.psn = ra.psn;
    remote_.snp = ra.snp;
    remote_.iid = ra.iid;
    remote_set_ = true;
  } else {
    CHECK(remote_.lid == ra.lid);
    CHECK(remote_.qpn == ra.qpn);
    CHECK(remote_.psn == ra.psn);
    CHECK(remote_.snp == ra.snp);
    CHECK(remote_.iid == ra.iid);
  }
}
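
// Channel bring-up as seen from this class (an illustrative sketch only; the
// exchange of RdmaAddress structs between workers happens out-of-band, e.g.
// over the gRPC service in contrib/verbs):
#if 0
void ExampleChannelBringUp(RdmaChannel* channel, const RdmaAddress& peer) {
  channel->SetRemoteAddress(peer, /*override=*/false);
  channel->Connect();  // transitions the QP to RTR and then RTS
  channel->Recv();     // pre-post a receive for the first incoming message
}
#endif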

// Post a receive work request to the queue pair.
// A receive must be outstanding for each future incoming message.
void RdmaChannel::Recv() {
  struct ibv_recv_wr wr;
  memset(&wr, 0, sizeof(wr));
  wr.wr_id = (uint64_t)this;
  struct ibv_recv_wr* bad_wr;
  CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
}

RdmaTensorRequest* RdmaChannel::InsertTensorRequest(
    const string& key, int64 step_id, Device* dst_dev,
    const Rendezvous::Args recv_args,
    const RdmaTensorRequest::RecvDoneCallback& done) {
  mutex_lock lock{ct_mu_};
  uint32_t request_index = request_serial_++;
  if (request_serial_ > RDMA_IMM_MAX_REQUEST_ID) {
    request_serial_ = 0;
  }
  RdmaTensorRequest request(request_index, key, step_id, this, dst_dev,
                            recv_args, done);
  auto it = request_table_.emplace(request_index, request);
  return &it.first->second;
}

void RdmaChannel::RemoveTensorRequest(uint32_t request_index) {
  mutex_lock lock{ct_mu_};
  request_table_.erase(request_index);
}

RdmaTensorRequest* RdmaChannel::GetTensorRequest(uint32_t request_index) {
  mutex_lock lock{ct_mu_};
  RequestTable::iterator iter = request_table_.find(request_index);
  CHECK(iter != request_table_.end());
  return &iter->second;
}

void RdmaChannel::Connect() {
  {
    mutex_lock lock{mu_};
    CHECK(remote_set_) << "remote channel is not set";
  }
  Connect(remote_);
}

// Set up the channel to a remote node.
// Args:
//   remoteAddr: the RDMA address of the remote channel.
// Returns:
//   None
void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
  mutex_lock lock{mu_};
  if (!connected_) {
    struct ibv_qp_attr attr;
    memset(&attr, 0, sizeof(ibv_qp_attr));
    attr.qp_state = IBV_QPS_RTR;

    // This assumes both QPs' ports are configured with the same MTU
    attr.path_mtu = adapter_->params_.mtu;
    attr.dest_qp_num = remoteAddr.qpn;
    attr.rq_psn = remoteAddr.psn;
    attr.max_dest_rd_atomic = 1;
    attr.min_rnr_timer = 12;
    attr.ah_attr.is_global = 1;
    attr.ah_attr.grh.dgid.global.subnet_prefix = remoteAddr.snp;
    attr.ah_attr.grh.dgid.global.interface_id = remoteAddr.iid;
    attr.ah_attr.grh.flow_label = 0;
    attr.ah_attr.grh.hop_limit = 255;
    attr.ah_attr.dlid = remoteAddr.lid;
    attr.ah_attr.sl = adapter_->params_.sl;
    attr.ah_attr.src_path_bits = 0;
    attr.ah_attr.port_num = adapter_->params_.port_num;
    attr.ah_attr.grh.sgid_index = adapter_->params_.sgid_index;
    attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class;

    int r;
    CHECK(!(r = ibv_modify_qp(qp_, &attr,
                              IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
                                  IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
                                  IBV_QP_MAX_DEST_RD_ATOMIC |
                                  IBV_QP_MIN_RNR_TIMER)))
        << "Failed to set QP to RTR: " << r;

    memset(&attr, 0, sizeof(ibv_qp_attr));
    attr.qp_state = IBV_QPS_RTS;
    attr.sq_psn = self_.psn;
    attr.timeout = adapter_->params_.timeout;
    attr.retry_cnt = adapter_->params_.retry_cnt;
    attr.rnr_retry = 7; /* infinite */
    attr.max_rd_atomic = 1;

    CHECK(!(r = ibv_modify_qp(qp_, &attr,
                              IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
                                  IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
                                  IBV_QP_MAX_QP_RD_ATOMIC)))
        << "Failed to set QP to RTS: " << r;

    connected_ = true;
  } else {
    RDMA_LOG(2) << "channel already connected";
  }
}

RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
    : channel_(channel), name_(name) {}

RdmaMessageBuffer::~RdmaMessageBuffer() {
  CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
  FreeBuffer();
}

void RdmaMessageBuffer::FreeBuffer() {
  if ((buffer_ != nullptr) && buffer_on_host_) {
    free(buffer_);
  }
}

// Allocate CPU memory for the RDMA buffer.
// Args:
//   size: to-be-allocated memory size
//   lock: whether to acquire mu_ to protect against concurrent access
// Returns:
//   None
void RdmaMessageBuffer::CreateCPUBuffer(size_t size, bool lock) {
  CHECK(size > 0);
  if (lock) {
    mu_.lock();
  }
  if (local_status_ != none) {
    // delete existing buffer
    CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
    FreeBuffer();
  }
  size_ = size;
  buffer_ = malloc(size_);
  self_ = ibv_reg_mr(channel_->adapter_->pd_, buffer_, size_,
                     IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
  CHECK(self_) << "Failed to register memory region";
  buffer_on_host_ = true;
  local_status_ = idle;
  if (lock) {
    mu_.unlock();
  }
}

// Set the address of the remote memory region.
// Args:
//   rmr: address of the remote memory region
//   override: whether to override existing information
// Returns:
//   None
void RdmaMessageBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
  mutex_lock lock{mu_};
  if ((override) || (remote_status_ == none)) {
    remote_.remote_addr = rmr.remote_addr;
    remote_.rkey = rmr.rkey;
    remote_status_ = idle;
  } else {
    CHECK(remote_.remote_addr == rmr.remote_addr);
    CHECK(remote_.rkey == rmr.rkey);
  }
}

// Put a task in the buffer's job queue.
void RdmaMessageBuffer::EnqueueItem(string item) {
  mutex_lock lock{mu_};
  queue_.push(item);
}

// RDMA-write the content of the buffer.
void RdmaMessageBuffer::Write(uint32_t imm_data, size_t buffer_size) {
  Write(channel_, imm_data, buffer_size, (uint64_t)buffer_, self_->lkey,
        remote_.remote_addr, remote_.rkey, RDMA_WRITE_ID_MESSAGE, this);
}

// Generalized Write method
void RdmaMessageBuffer::Write(const RdmaChannel* channel, uint32_t imm_data,
                              size_t buffer_size, uint64_t src_addr,
                              uint32_t lkey, uint64_t remote_addr,
                              uint32_t rkey, RdmaWriteIDType write_type,
                              void* write_context) {
  struct ibv_sge list;
  list.addr = src_addr;
  list.length = buffer_size;
  list.lkey = lkey;

  struct ibv_send_wr wr;
  memset(&wr, 0, sizeof(wr));
  wr.wr_id = (uint64_t) new RdmaWriteID(write_type, write_context);
  wr.sg_list = &list;
  wr.num_sge = 1;
  wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
  wr.send_flags = IBV_SEND_SIGNALED;
  wr.imm_data = imm_data;
  wr.wr.rdma.remote_addr = remote_addr;
  wr.wr.rdma.rkey = rkey;

  struct ibv_send_wr* bad_wr;
  CHECK(!ibv_post_send(channel->qp_, &wr, &bad_wr)) << "Failed to post send";
}

// Send an ack for a received message.
void RdmaMessageBuffer::SendAck(const RdmaChannel* channel) {
  Write(channel, RDMA_IMM_DATA_ACK, 0, 0, 0, 0, 0, RDMA_WRITE_ID_ACK, nullptr);
}
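
// Message/ACK flow between two channels (a summary of SendNextItem() below
// together with Process_CQ() above):
//   1. Sender: EnqueueItem() queues a serialized RdmaMessage; SendNextItem()
//      marks both buffer sides busy and RDMA-writes the message with an
//      immediate value of RDMA_IMM_DATA_MESSAGE.
//   2. Receiver: Process_CQ() parses the message and replies with SendAck().
//   3. Sender: on receiving the ACK, it marks the remote buffer idle and
//      sends the next queued item; the local side turns idle when the write
//      completion (RDMA_WRITE_ID_MESSAGE) arrives.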

// Send the next message from the buffer's job queue.
void RdmaMessageBuffer::SendNextItem() {
  uint32_t imm_data = RDMA_IMM_DATA_MESSAGE;
  mu_.lock();
  if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) {
    local_status_ = busy;
    remote_status_ = busy;
    string message = queue_.front();
    queue_.pop();
    // local_status_/remote_status_ won't be set back to idle
    // until the Write() completes successfully
    mu_.unlock();
    memcpy(buffer_, message.data(), message.size());
    Write(imm_data, message.size());
  } else {
    mu_.unlock();
  }
}

#if GOOGLE_CUDA
static void CountCopies(const std::string& key, void* src_addr, void* dst_addr,
                        size_t tensor_bytes, bool is_gpu_to_cpu) {
#ifdef RDMA_COUNT_COPIES
  static uint64_t numGPUToCPUCopies = 0;
  static uint64_t numGPUToCPUCopiedBytes = 0;
  static uint64_t numCPUToGPUCopies = 0;
  static uint64_t numCPUToGPUCopiedBytes = 0;
  static uint64_t numTotalCopies = 0;

  if (is_gpu_to_cpu) {
    ++numGPUToCPUCopies;
    numGPUToCPUCopiedBytes += tensor_bytes;
  } else {
    ++numCPUToGPUCopies;
    numCPUToGPUCopiedBytes += tensor_bytes;
  }
  if ((++numTotalCopies % 0x400) == 0) {
    RDMA_LOG(0) << "Tensor copies:"
                << " GPU to CPU: " << numGPUToCPUCopies << " ("
                << numGPUToCPUCopiedBytes << " Bytes)"
                << " CPU to GPU: " << numCPUToGPUCopies << " ("
                << numCPUToGPUCopiedBytes << " Bytes)";
  }
  RDMA_LOG(2) << "Copying tensor " << key << " From: " << src_addr
              << " To: " << dst_addr;
#endif  // RDMA_COUNT_COPIES
}
#endif  // GOOGLE_CUDA

#ifdef RDMA_DATA_VALIDATION
static uint64_t Checksum(Device* device, const DeviceContext* device_context,
                         const Tensor& in) {
  uint64 checksum = 0;
  if (DataTypeCanUseMemcpy(in.dtype())) {
#if GOOGLE_CUDA
    if (in.TotalBytes() == 0) {
      return 0;
    }
    checksum = (device_context != nullptr)
                   ? GPUUtil::Checksum(device, device_context, in)
                   : GPUUtil::Checksum(in);
#endif  // GOOGLE_CUDA
  } else {
    string s = in.SummarizeValue(999999);
    checksum = Hash64(s.c_str(), s.size(), 0);
  }
  return checksum;
}

static void ValidateChecksum(uint64_t expected, uint64_t actual,
                             const Tensor& in, uint32_t request_index,
                             const std::string& key, const std::string& msg) {
  RDMA_LOG(2) << "Request #" << request_index << ": " << key
              << ": Checksum: " << std::hex << " Expected = 0x" << expected
              << ". Actual = 0x" << actual << ".";

  if (expected != actual) {
    // Checksum failed. There is one case where this is allowed: if the
    // tensor is an AssignAdd of the global step. Since the data validation
    // always postpones the tensor response in order to send a checksum
    // message, it is possible that the global step was updated while the
    // response was still in queue.
    if ((in.TotalBytes() == 8) && (in.dtype() == DT_INT64)) {
      int64_t prev_val = *(int64_t*)DMAHelper::base(&in) - 1;
      actual = Hash64((const char*)&prev_val, 8, 0);
    }
    if (expected != actual) {
      LOG(FATAL) << "[" << msg << "]: Checksum validation failed for request #"
                 << request_index << ": " << key << std::hex << " "
                 << DataTypeString(in.dtype()) << " "
                 << in.shape().DebugString() << " (0x" << in.TotalBytes()
                 << " bytes): "
                 << " Expected 0x" << expected << ". Got 0x" << actual << ".";
    }
  }
}
#endif  // RDMA_DATA_VALIDATION

#if GOOGLE_CUDA
// Sync the 'done' operation on the GPU stream, but without all the data
// copying.
static void StreamGPUOp(Device* gpu_device,
                        const DeviceContext* device_context,
                        StatusCallback done) {
  Tensor dummy1, dummy2;
  GPUUtil::CopyGPUTensorToCPU(gpu_device, device_context, &dummy1, &dummy2,
                              done);
}
#endif  // GOOGLE_CUDA

RdmaTensorResponse* RdmaChannel::AddTensorResponse(const RdmaMessage& rm) {
  mutex_lock lock{mu_};
  auto it =
      responses_table_.emplace(rm.request_index_, RdmaTensorResponse(this, rm));
  CHECK(it.second) << "Response with the ID " << rm.request_index_
                   << " already exists.";
  return &it.first->second;
}

RdmaTensorResponse* RdmaChannel::UpdateTensorResponse(const RdmaMessage& rm) {
  mutex_lock lock{mu_};
  auto it = responses_table_.find(rm.request_index_);
  CHECK(it != responses_table_.end()) << "No response found.";
  RdmaTensorResponse* response = &it->second;
  response->Update(rm);
  return response;
}

void RdmaChannel::RemoveTensorResponse(uint32_t request_index) {
  mutex_lock lock{mu_};
  responses_table_.erase(request_index);
}

void RdmaTensorResponse::Start() {
  Rendezvous::ParsedKey parsed;
  Status s = Rendezvous::ParseKey(rm_.name_, &parsed);
  if (!s.ok()) {
    SendErrorStatus(s);
    return;
  }

  channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(
      rm_.step_id_, parsed,
      [this, parsed](const Status& status, const Rendezvous::Args& send_args,
                     const Rendezvous::Args& recv_args, const Tensor& in,
                     bool is_dead) {
        CHECK(status.ok()) << "RecvLocalAsync was not ok."
                           << " error message: " << status.error_message();
        RecvHandler(parsed, send_args, recv_args, in, is_dead);
      });
}

void RdmaTensorResponse::Resume() { SendContent(*tensor_, *proto_, is_dead_); }
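
// Life cycle of a tensor response (derived from the methods below):
//   Start()       - parse the rendezvous key and fetch the tensor locally.
//   RecvHandler() - choose between sending directly from the source buffer,
//                   copying GPU-to-CPU first, or serializing to a proto, then
//                   call Send().
//   Send()        - if the requester's cached meta-data is stale, Clone() the
//                   tensor and send a META_DATA_UPDATE; otherwise send the
//                   tensor content immediately.
//   Resume()      - re-send the content after a TENSOR_RE_REQUEST arrives.
//   Destroy()     - release buffers once the RDMA write completes.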

// Helper for RecvTensor. Validates "key" and returns the source
// device in "*src_dev".
Status RdmaTensorResponse::PrepareRecvTensor(
    const Rendezvous::ParsedKey& parsed, Device** src_dev) {
  // Figure out which device the tensor is hosted on.
  string local_name = DeviceNameUtils::LocalName(parsed.src_device);
  TF_RETURN_IF_ERROR(channel_->adapter_->worker_env_->device_mgr->LookupDevice(
      local_name, src_dev));

  // Does the device have the right incarnation number we expect?
  if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
    return errors::Aborted(
        "RecvTensor expects a different device incarnation: ",
        parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
        ". Your worker job was probably restarted. Check your "
        "worker job for the reason why it was restarted.");
  }

  return Status::OK();
}

void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed,
                                     const Rendezvous::Args& send_args,
                                     const Rendezvous::Args& recv_args,
                                     const Tensor& in, bool is_dead) {
  Status s = PrepareRecvTensor(parsed, &src_dev_);
  if (!s.ok()) {
    SendErrorStatus(s);
    return;
  }

  meta_data_changed_ = TensorMetaDataChanged(in, is_dead);
#ifdef RDMA_DATA_VALIDATION
  // Always send a meta-data message with the source checksum
  meta_data_changed_ = rm_.type_ == RDMA_MESSAGE_TENSOR_REQUEST;
  checksum_ = Checksum(src_dev_, send_args.device_context, in);
#endif
  bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
  // string tensors need to be serialized
  Tensor copy;
  TensorProto proto;
  const bool on_host = send_args.alloc_attrs.on_host();
  if (src_dev_->tensorflow_gpu_device_info() && !on_host) {
#if GOOGLE_CUDA
    DeviceContext* send_dev_context = send_args.device_context;
    CHECK(send_dev_context)
        << "send dev name: " << src_dev_->name()
        << " gpu_info: " << src_dev_->tensorflow_gpu_device_info();

    if (can_memcpy) {
      // If the tensor is located on a GDR-compatible GPU, there is no need
      // to copy it. We can send directly from the source; we just need to
      // make sure we are in sync with the GPU stream.
      // If the tensor's meta-data changed, however, we will need to clone
      // it, so we'll have to copy it from GPU to CPU first anyway. If at
      // some point in time Clone() is changed to only save a shallow copy,
      // we can skip the copy here as well.
      if ((in.TotalBytes() > 0) && !meta_data_changed_ &&
          (RdmaMemoryMgr::Singleton().FindMemoryRegion(
               (void*)DMAHelper::base(&in), in.TotalBytes()) != nullptr)) {
        StreamGPUOp(src_dev_, send_dev_context,
                    [this, in, proto, is_dead](const Status& s) {
                      Send(in, proto, is_dead, s);
                    });
        return;
      }

      // The tensor must be copied from GPU to CPU, because either:
      // 1. The tensor is located on a non GDR-compatible GPU.
      // 2. The tensor's meta-data has changed.
      Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
      copy = Tensor(alloc, in.dtype(), in.shape());
      CountCopies(rm_.name_, (void*)DMAHelper::base(&in),
                  (void*)DMAHelper::base(&copy), in.TotalBytes(), true);
      GPUUtil::CopyGPUTensorToCPU(
          src_dev_, send_dev_context, &in, &copy,
          [this, copy, proto, is_dead](const Status& s) {
            Send(copy, proto, is_dead, s);
          });
    } else {
      GPUUtil::SetProtoFromGPU(
          in, src_dev_, send_args.device_context, &proto, is_dead,
          [this, in, proto, is_dead](const Status& s) mutable {
            Send(in, proto, is_dead, s);
          });
    }
#else
    SendErrorStatus(errors::Internal("No GPU device in process"));
#endif  // GOOGLE_CUDA
  } else {
    // The tensor is in CPU memory.
    if (!can_memcpy) {
      in.AsProtoTensorContent(&proto);
    }
    Send(in, proto, is_dead, Status::OK());
  }
}

void RdmaTensorResponse::Send(const Tensor& in, const TensorProto& proto,
                              bool is_dead, const Status& status) {
  if (!status.ok()) {
    SendErrorStatus(status);
    return;
  }
  bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
  bool proto_size_changed =
      (!can_memcpy) && (proto.ByteSize() != rm_.tensor_bytes_);
  if (meta_data_changed_ || proto_size_changed) {
    Clone(in, proto, is_dead);
    SendMetaData(in, proto, is_dead);
  } else {
    SendContent(in, proto, is_dead);
  }
}

bool RdmaTensorResponse::TensorMetaDataChanged(const Tensor& in, bool is_dead) {
  return (rm_.data_type_ != in.dtype()) || (rm_.tensor_shape_ != in.shape()) ||
         (rm_.is_dead_ != is_dead);
}

void RdmaTensorResponse::Clone(const Tensor& in, const TensorProto& proto,
                               bool is_dead) {
  // Clone the data to be sent later. For simplicity, we clone the tensor's
  // data even if it is already a copy. Performance is less of a concern here
  // since the meta-data hardly ever changes. The reason we create a copy is
  // that some tensors share their buffer between different step-ids, so the
  // tensor content may change before the re-request is completed.
  bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
  if (can_memcpy && (in.TotalBytes() > 0)) {
    AllocatorAttributes host_alloc_attrs;
    host_alloc_attrs.set_nic_compatible(true);
    host_alloc_attrs.set_on_host(true);
    Allocator* allocator = src_dev_->GetAllocator(host_alloc_attrs);
    tensor_ = new Tensor(allocator, in.dtype(), in.shape());
    memcpy(DMAHelper::base(tensor_), DMAHelper::base(&in), in.TotalBytes());
  } else {
    tensor_ = new Tensor(in.dtype(), in.shape());
  }
  if (!can_memcpy) {
    proto_ = new TensorProto(proto);
  }
  is_dead_ = is_dead;
}

void RdmaTensorResponse::SendMetaData(const Tensor& in,
                                      const TensorProto& proto, bool is_dead) {
  RDMA_LOG(2) << "Request #" << rm_.request_index_
              << ": Meta data changed: " << rm_.name_;
  bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
  size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize();

  // Send meta-data update:
  RdmaMessage rm;
  rm.type_ = RDMA_MESSAGE_META_DATA_UPDATE;
  rm.name_size_ = rm_.name_.size();
  rm.name_ = rm_.name_;
  rm.tensor_shape_ = in.shape();
  rm.data_type_ = in.dtype();
  rm.step_id_ = rm_.step_id_;
  rm.is_dead_ = is_dead;
  rm.tensor_bytes_ = tensor_bytes;
  rm.request_index_ = rm_.request_index_;
#ifdef RDMA_DATA_VALIDATION
  rm.checksum_ = checksum_;
#endif
  RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
              << ": Sending RDMA_MESSAGE_META_DATA_UPDATE #"
              << rm.request_index_ << ": " << rm.name_
              << " (shape = " << rm.tensor_shape_.DebugString() << "."
              << " data-type = " << DataTypeString(rm.data_type_) << "."
              << " is-dead = " << rm.is_dead_ << ")";

  string message = RdmaMessage::CreateMessage(rm);
  channel_->tx_message_buffer_->EnqueueItem(message);
  channel_->tx_message_buffer_->SendNextItem();
}
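
// The content is written straight out of the tensor's own pre-registered
// buffer when the data type allows memcpy; otherwise the proto is serialized
// into a temporary buffer that is registered ad hoc. The immediate value is
// the request index, which routes the write to the pending request on the
// receiving side (see Process_CQ()).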
void RdmaTensorResponse::SendContent(const Tensor& in, const TensorProto& proto,
                                     bool is_dead) {
  bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
  size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize();
  uint32_t imm_data = rm_.request_index_;
  if (!is_dead) {
    if (can_memcpy) {
      src_buffer_ = const_cast<TensorBuffer*>(DMAHelper::buffer(&in));
      if (src_buffer_ != nullptr) {
        src_buffer_->Ref();  // Keep buffer alive until write is complete
        src_addr_ = src_buffer_->data();
        mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(src_addr_,
                                                          tensor_bytes);
      }
    } else {
      RDMA_LOG(2) << "Encoding proto: " << rm_.name_
                  << " (Size: " << tensor_bytes << ") " << in.DebugString();
      src_addr_ = malloc(tensor_bytes);
      mr_ = ibv_reg_mr(channel_->adapter_->pd_, src_addr_, tensor_bytes,
                       IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
      proto.SerializeToArray(src_addr_, tensor_bytes);
    }
  } else {
    tensor_bytes = 0;
  }

  uint32_t lkey = (mr_ == nullptr) ? 0 : mr_->lkey;
  RDMA_LOG(1) << "Step 0x" << std::hex << rm_.step_id_ << std::dec
              << ": Sending tensor content #" << rm_.request_index_ << " from "
              << std::hex << src_addr_ << " (0x" << lkey << ")"
              << " to " << rm_.remote_addr_ << " (0x" << rm_.rkey_
              << "): " << rm_.name_ << " (size: 0x" << std::hex << tensor_bytes
              << ")";

  RdmaMessageBuffer::Write(channel_, imm_data, tensor_bytes,
                           (uint64_t)src_addr_, lkey, rm_.remote_addr_,
                           rm_.rkey_, RDMA_WRITE_ID_TENSOR_WRITE, this);
}

void RdmaTensorResponse::SendErrorStatus(const Status& status) {
  RdmaMessage rm;
  rm.type_ = RDMA_MESSAGE_ERROR_STATUS;
  rm.name_size_ = rm_.name_.size();
  rm.name_ = rm_.name_;
  rm.step_id_ = rm_.step_id_;
  rm.request_index_ = rm_.request_index_;
  rm.status_ = status;
  LOG(ERROR) << "Step 0x" << std::hex << rm.step_id_ << std::dec
             << ": Sending RDMA_MESSAGE_ERROR_STATUS #" << rm.request_index_
             << ": " << rm.name_ << ". Status: " << status.ToString();

  string message = RdmaMessage::CreateMessage(rm);
  channel_->tx_message_buffer_->EnqueueItem(message);
  channel_->tx_message_buffer_->SendNextItem();

  // Destroy the response.
  Destroy();
}

void RdmaTensorResponse::Destroy() {
  if (src_buffer_ != nullptr) {
    src_buffer_->Unref();
  }
  if (tensor_ != nullptr) {
    delete tensor_;
  }
  if (proto_ != nullptr) {
    ibv_dereg_mr(mr_);
    free(src_addr_);
    delete proto_;
  }
  // Remove response from the pending list:
  channel_->RemoveTensorResponse(rm_.request_index_);
}

// Create an RdmaMessage according to the pre-defined format.
// Args:
//   rm: the message structure
// Returns:
//   message in string format
string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
  // RdmaMessage format:
  // type|name_size|name|step_id|request_index|remote_addr|rkey|is_dead|...
  //   1B|       2B| 512|     8B|           8B|         8B|  4B|     1B|...
  // ...|data_type|tensor_shape|tensor_bytes|     error_status      |
  // ...|       XB|          XB|          8B|size - 4B, proto - XB  |
  //
  // ACK:               Imm-type: ACK
  // TENSOR_REQUEST:    Imm-type: MESSAGE
  //                    Fields: type, request_index, name, step_id,
  //                            remote_addr, rkey, is_dead, data_type,
  //                            tensor_shape, tensor_bytes
  // META_DATA_UPDATE:  Imm-type: MESSAGE
  //                    Fields: type, request_index, is_dead, data_type,
  //                            tensor_shape, tensor_bytes
  // TENSOR_RE_REQUEST: Imm-type: MESSAGE
  //                    Fields: type, request_index, name, step_id,
  //                            remote_addr, rkey, is_dead, data_type,
  //                            tensor_shape, tensor_bytes
  // ERROR_STATUS:      Imm-type: MESSAGE
  //                    Fields: type, request_index, name, step_id,
  //                            error_status
  // Tensor content:    Imm-type: request_index
  size_t message_size = kMessageTotalBytes;
  char message[kMessageTotalBytes + kErrorStatusMaxSize];
  // type
  message[kTypeStartIndex] = static_cast<char>(rm.type_) & 0xff;
  // request index
  memcpy(&message[kRequestIndexStartIndex], &rm.request_index_,
         sizeof(rm.request_index_));
  // name, step_id, remote_addr, rkey
  if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
      (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
    memcpy(&message[kNameSizeStartIndex], &rm.name_size_,
           sizeof(rm.name_size_));
    memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
    memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
           sizeof(rm.remote_addr_));
    memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
    memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_));
  }
  // is_dead, data_type, tensor_shape, tensor_bytes
  if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
      (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) ||
      (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
    memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_));

    memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
           sizeof(rm.data_type_));
    memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_,
           sizeof(rm.tensor_shape_));
    memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
           sizeof(rm.tensor_bytes_));
  }
// checksum
#ifdef RDMA_DATA_VALIDATION
  memcpy(&message[kChecksumStartIndex], &rm.checksum_, sizeof(rm.checksum_));
#endif
  // error status
  if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
    ::grpc::Status gs = ToGrpcStatus(rm.status_);
    ErrorStatusProto gsProto;
    gsProto.set_error_code(gs.error_code());
    gsProto.set_error_message(gs.error_message());
    gsProto.set_error_details(gs.error_details());
    uint32_t gsProtoSize = gsProto.ByteSize();
    if (gsProtoSize + 4 > kErrorStatusMaxSize) {
      LOG(ERROR) << "Error status (" << gsProtoSize + 4 << " bytes) "
                 << "is too big to fit in the RDMA message ("
                 << kErrorStatusMaxSize << " bytes). Truncated.";
      gsProtoSize = kErrorStatusMaxSize - 4;
    }
    uint32_t* proto_size = (uint32_t*)&message[kErrorStatusStartIndex];
    *proto_size = gsProtoSize;
    gsProto.SerializeToArray(&message[kErrorStatusStartIndex + 4], gsProtoSize);
    message_size += gsProtoSize + 4;
  }
  return string(message, message_size);
}
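
// CreateMessage() and ParseMessage() below are inverses for the fields that
// a given message type carries. A minimal round-trip sketch (illustrative
// only; the field values are arbitrary):
#if 0
void ExampleMessageRoundTrip() {
  RdmaMessage out;
  out.type_ = RDMA_MESSAGE_TENSOR_REQUEST;
  out.request_index_ = 7;
  out.name_ = "example_tensor";
  out.name_size_ = out.name_.size();
  out.step_id_ = 42;
  out.remote_addr_ = 0;
  out.rkey_ = 0;
  out.is_dead_ = false;
  out.data_type_ = DT_FLOAT;
  out.tensor_shape_ = TensorShape({2, 3});
  out.tensor_bytes_ = 24;
  string wire = RdmaMessage::CreateMessage(out);

  RdmaMessage in;
  RdmaMessage::ParseMessage(in, const_cast<char*>(wire.data()));
  CHECK(in.type_ == out.type_ && in.name_ == out.name_ &&
        in.step_id_ == out.step_id_);
}
#endif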

// Parse an RdmaMessage according to the pre-defined format.
// Args:
//   rm: the message structure where the parsed message will be saved
//   buffer: the place where the raw message is stored
// Returns:
//   None
void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
  char* message = static_cast<char*>(buffer);
  // type
  rm.type_ = static_cast<RdmaMessageType>(message[kTypeStartIndex]);
  // request index
  memcpy(&rm.request_index_, &message[kRequestIndexStartIndex],
         sizeof(rm.request_index_));
  // name, step_id, remote_addr, rkey
  if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
      (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
    memcpy(&rm.name_size_, &message[kNameSizeStartIndex],
           sizeof(rm.name_size_));
    rm.name_ = string(&message[kNameStartIndex], rm.name_size_);
    memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex],
           sizeof(rm.remote_addr_));
    memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_));
    memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_));
  }
  // is_dead, data_type, tensor_shape, tensor_bytes
  if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
      (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) ||
      (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
    memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_));
    memcpy(&rm.data_type_, &message[kDataTypeStartIndex],
           sizeof(rm.data_type_));
    memcpy(&rm.tensor_shape_, &message[kTensorShapeStartIndex],
           sizeof(rm.tensor_shape_));
    memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex],
           sizeof(rm.tensor_bytes_));
  }
// checksum
#ifdef RDMA_DATA_VALIDATION
  memcpy(&rm.checksum_, &message[kChecksumStartIndex], sizeof(rm.checksum_));
#endif
  // error status
  if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
    ErrorStatusProto gsProto;
    uint32_t gsProtoSize = *(uint32_t*)&message[kErrorStatusStartIndex];
    CHECK(ParseProtoUnlimited(&gsProto, &message[kErrorStatusStartIndex + 4],
                              gsProtoSize))
        << "Failed to parse error status proto from message. Aborting.";
    ::grpc::Status gs((::grpc::StatusCode)gsProto.error_code(),
                      gsProto.error_message(), gsProto.error_details());
    rm.status_ = FromGrpcStatus(gs);
  }
}

//*****************************************************************************
// RdmaMemoryMgr
//*****************************************************************************

ibv_mr* RdmaMemoryMgr::FindMemoryRegion(void* addr, size_t length) {
  mutex_lock l(mrs_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter == std::end(mrs_) || iter->get()->addr > addr) {
    return nullptr;
  } else {
    return iter->get();
  }
}
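
// mrs_ is kept sorted in the order induced by Comparator (the same predicate
// is used for every upper_bound() call here), so lookup, insertion and
// eviction are all O(log n) binary searches over the registered regions;
// FindMemoryRegion() above returns nullptr when the address does not fall
// inside any registered region.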
void RdmaMemoryMgr::InsertMemoryRegion(void* addr, size_t length,
                                       const std::string& allocator_name) {
  if (length == 0) return;
  ibv_mr* mr = ibv_reg_mr(pd_, addr, length,
                          IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
  if (mr != nullptr) {
    RDMA_LOG(1) << "Insert memory region 0x" << std::hex << mr->rkey << ". ["
                << addr << "-" << (void*)((uint64_t)addr + length - 1) << "]"
                << " SIZE: 0x" << length << " (" << allocator_name << ").";
    mutex_lock l(mrs_mu_);
    auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
    mrs_.insert(iter, {mr, &MRDeleter});
  } else {
    LOG(WARNING) << "Cannot register memory region";
  }
}

void RdmaMemoryMgr::EvictMemoryRegion(void* addr, size_t length) {
  if (length == 0) return;
  mutex_lock l(mrs_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter != std::end(mrs_) && iter->get()->addr == addr) {
    RDMA_LOG(1) << "Evict memory region 0x" << std::hex << iter->get()->rkey;
    mrs_.erase(iter);
  } else {
    LOG(WARNING) << "Failed to find memory region to evict";
  }
}

const TensorMetaData* RdmaMemoryMgr::GetTensorMetaData(
    const std::string& tensor_name) {
  mutex_lock l(tensor_meta_data_mu_);
  auto it = tensors_meta_data_.find(tensor_name);
  if (it == tensors_meta_data_.end()) {
    return nullptr;
  }
  return &it->second;
}

const TensorMetaData* RdmaMemoryMgr::SetTensorMetaData(
    const std::string& tensor_name, DataType dtype, const TensorShape& shape,
    bool is_dead, size_t proto_size) {
  mutex_lock l(tensor_meta_data_mu_);
  TensorMetaData& meta_data = tensors_meta_data_[tensor_name];
  meta_data.data_type_ = dtype;
  meta_data.tensor_shape_ = shape;
  meta_data.proto_size_ = proto_size;
  meta_data.is_dead_ = is_dead;
  return &meta_data;
}

//*****************************************************************************
// RdmaTensorRequest
//*****************************************************************************

RdmaTensorRequest::RdmaTensorRequest(
    uint32_t index, const string& key, int64 step_id, RdmaChannel* channel,
    Device* dst_dev, const Rendezvous::Args recv_args,
    const RdmaTensorRequest::RecvDoneCallback& done)
    : index_(index),
      key_(key),
      step_id_(step_id),
      channel_(channel),
      dst_dev_(dst_dev),
      recv_args_(recv_args),
      meta_data_(RdmaMemoryMgr::Singleton().GetTensorMetaData(key)),
      result_tensor_(nullptr),
      proxy_tensor_(nullptr),
      rdma_addr_(nullptr),
      mr_(nullptr),
      done_(done) {}

RdmaTensorRequest::~RdmaTensorRequest() { DeallocateTensors(); }
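
// Life cycle of a tensor request (the client-side mirror of
// RdmaTensorResponse):
//   Start()              - send a TENSOR_REQUEST; if meta-data is already
//                          cached, allocate the destination buffers up front.
//   RecvTensorMetaData() - cache the updated meta-data, re-allocate, and send
//                          a TENSOR_RE_REQUEST.
//   RecvTensorContent()  - the data has landed via RDMA write; copy or
//                          deserialize if needed and call Done().
//   RecvErrorStatus()    - the server side failed; report the error via
//                          Done().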
void RdmaTensorRequest::Done(const Status& s) {
  Tensor val = std::move(*result_tensor_);

#ifdef RDMA_DATA_VALIDATION
  // Validate checksum
  // Unfortunately we can't always do a Checksum directly on the result
  // tensor. If the result tensor is on GPU, then we need to copy it back to
  // CPU. If we happen to be in the midst of a proxy callback, then the
  // copying will get stuck.
  uint64_t checksum = (proxy_tensor_ != nullptr)
                          ? Checksum(nullptr, nullptr, *proxy_tensor_)
                          : Checksum(dst_dev_, recv_args_.device_context, val);
  ValidateChecksum(checksum_, checksum, val, index_, key_, "RDMA");
#endif

  Rendezvous::Args recv_args = std::move(recv_args_);
  bool is_dead = (meta_data_ == nullptr) ? false : meta_data_->is_dead_;
  RecvDoneCallback done = done_;
  DeallocateTensors();
  channel_->RemoveTensorRequest(index_);
  done(s, Rendezvous::Args(), recv_args, val, is_dead);
}

void RdmaTensorRequest::DeallocateTensors() {
  if (result_tensor_ != nullptr) {
    delete result_tensor_;
    result_tensor_ = nullptr;
  }
  if (proxy_tensor_ != nullptr) {
    delete proxy_tensor_;
    proxy_tensor_ = nullptr;
  }
}

bool RdmaTensorRequest::AllocateTensors() {
  result_tensor_ =
      new Tensor(dst_dev_->GetAllocator(recv_args_.alloc_attrs),
                 meta_data_->data_type_, meta_data_->tensor_shape_);

  size_t tensor_size = result_tensor_->TotalBytes();
  bool can_memcpy = DataTypeCanUseMemcpy(result_tensor_->dtype());
  if (can_memcpy) {
    if (tensor_size == 0) {
      return true;
    }
    rdma_addr_ = DMAHelper::base(result_tensor_);
    mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size);
#if GOOGLE_CUDA
    if (mr_ == nullptr) {
      // Can't RDMA directly to the result. Use a proxy.
      proxy_tensor_ =
          new Tensor(ProcessState::singleton()->GetCUDAHostAllocator(0),
                     result_tensor_->dtype(), result_tensor_->shape());
      rdma_addr_ = DMAHelper::base(proxy_tensor_);
      mr_ =
          RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size);
    }
#endif
  } else {
    uint32_t proto_size = meta_data_->proto_size_;
    rdma_addr_ = malloc(proto_size);
    mr_ = ibv_reg_mr(RdmaMemoryMgr::Singleton().pd_, rdma_addr_, proto_size,
                     IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
  }
  CHECK(mr_ != nullptr) << "No memory region found for address " << rdma_addr_
                        << ": " << key_;
  return true;
}

void RdmaTensorRequest::AllocateTensorsAsync(StatusCallback done) {
  AllocateTensors();
  bool on_host = recv_args_.alloc_attrs.on_host();
  if (dst_dev_->tensorflow_gpu_device_info() && !on_host &&
      (proxy_tensor_ == nullptr)) {
#if GOOGLE_CUDA
    // We need to sync the memory allocation on the GPU:
    StreamGPUOp(dst_dev_, recv_args_.device_context, done);
#endif
  } else {
    done(Status::OK());
  }
}

void RdmaTensorRequest::Send(RdmaMessageType message_type) {
  RdmaMessageBuffer* rb = channel_->tx_message_buffer_;
  RdmaMessage rm;
  rm.type_ = message_type;
  rm.request_index_ = index_;
  rm.name_size_ = key_.size();
  rm.name_ = key_;
  rm.step_id_ = step_id_;
  rm.remote_addr_ = (uint64_t)rdma_addr_;
  if (meta_data_ != nullptr) {
    rm.data_type_ = meta_data_->data_type_;
    rm.tensor_shape_ = meta_data_->tensor_shape_;
    rm.is_dead_ = meta_data_->is_dead_;
    rm.tensor_bytes_ = meta_data_->proto_size_;
  } else {
    rm.data_type_ = DT_INVALID;
  }
  rm.rkey_ = (mr_ == nullptr) ? 0 : mr_->rkey;

  RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
              << ": Sending " << MessageTypeToString(message_type) << " #"
              << index_ << ": " << rm.name_ << " on " << rdma_addr_
              << " (rkey: 0x" << std::hex << rm.rkey_ << ")";

  string message = RdmaMessage::CreateMessage(rm);
  rb->EnqueueItem(message);
  rb->SendNextItem();
}

void RdmaTensorRequest::RecvTensorMetaData(DataType dtype, TensorShape shape,
                                           bool is_dead, size_t proto_size) {
  meta_data_ = RdmaMemoryMgr::Singleton().SetTensorMetaData(
      key_, dtype, shape, is_dead, proto_size);

  DeallocateTensors();
  AllocateTensorsAsync(
      [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_RE_REQUEST); });
}

void RdmaTensorRequest::RecvTensorContent() {
  bool can_memcpy = DataTypeCanUseMemcpy(meta_data_->data_type_);
  size_t message_size =
      can_memcpy ? result_tensor_->TotalBytes() : meta_data_->proto_size_;
  RDMA_LOG(1) << "Step 0x" << std::hex << step_id_ << std::dec
              << ": Received tensor content #" << index_ << ": " << key_
              << " (Size: 0x" << std::hex << message_size << ")";

  Tensor val;

#if GOOGLE_CUDA
  if (proxy_tensor_ != nullptr) {
    CountCopies(key_, (void*)DMAHelper::base(proxy_tensor_),
                (void*)DMAHelper::base(result_tensor_),
                result_tensor_->TotalBytes(), false);
    GPUUtil::CopyCPUTensorToGPU(proxy_tensor_, recv_args_.device_context,
                                dst_dev_, result_tensor_,
                                [this](const Status& s) {
                                  CHECK(s.ok()) << "Copy tensor to GPU failed";
                                  Done(s);
                                });
    return;
  }
#endif

  if (can_memcpy) {
    Done(Status::OK());
  } else {
    RDMA_LOG(2) << "Decoding proto: " << key_
                << " (Size: " << meta_data_->proto_size_ << ")";
    TensorProto proto;
    CHECK(ParseProtoUnlimited(&proto, rdma_addr_, meta_data_->proto_size_))
        << "Failed to parse proto from array";
    ibv_dereg_mr(mr_);
    free(rdma_addr_);
    Status s = dst_dev_->MakeTensorFromProto(proto, recv_args_.alloc_attrs,
                                             result_tensor_);
    Done(s);
  }
}

void RdmaTensorRequest::RecvErrorStatus(const Status& status) {
  if (result_tensor_ == nullptr) {
    result_tensor_ = new Tensor();
  }
  LOG(ERROR) << "Received RDMA_MESSAGE_ERROR_STATUS: " << status.ToString();
  Done(status);
}

void RdmaTensorRequest::Start() {
  meta_data_ = RdmaMemoryMgr::Singleton().GetTensorMetaData(key_);
  if (meta_data_ != nullptr) {
    AllocateTensorsAsync(
        [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_REQUEST); });
  } else {
    Send(RDMA_MESSAGE_TENSOR_REQUEST);
  }
}

}  // end namespace tensorflow

#endif  // TENSORFLOW_USE_VERBS