      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifdef TENSORFLOW_USE_VERBS
     17 
     18 #include <fcntl.h>
     19 #include <cstdlib>
     20 
     21 #include "tensorflow/contrib/verbs/rdma.h"
     22 #include "tensorflow/contrib/verbs/verbs_service.pb.h"
     23 #include "tensorflow/core/common_runtime/device_mgr.h"
     24 #include "tensorflow/core/common_runtime/dma_helper.h"
     25 #include "tensorflow/core/common_runtime/process_util.h"
     26 #if GOOGLE_CUDA
     27 #include "tensorflow/core/common_runtime/gpu/gpu_util.h"
     28 #include "tensorflow/core/common_runtime/gpu/process_state.h"
     29 #endif
     30 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
     31 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
     32 #include "tensorflow/core/distributed_runtime/session_mgr.h"
     34 #include "tensorflow/core/framework/rendezvous.h"
     35 #include "tensorflow/core/framework/tensor.h"
     36 #include "tensorflow/core/lib/core/status.h"
     37 #include "tensorflow/core/lib/core/stringpiece.h"
     38 #include "tensorflow/core/lib/core/threadpool.h"
     39 #include "tensorflow/core/lib/hash/hash.h"
     40 #include "tensorflow/core/lib/random/random.h"
     41 
     42 namespace tensorflow {
     43 
     44 #define RoCE_V2 "RoCE v2"
     45 
     46 namespace {
     47 
      48 // Convenience function for printing the message type
     49 string MessageTypeToString(RdmaMessageType rmt) {
     50   switch (rmt) {
     51     case RDMA_MESSAGE_META_DATA_UPDATE:
     52       return "RDMA_MESSAGE_META_DATA_UPDATE";
     53       break;
     54     case RDMA_MESSAGE_TENSOR_RE_REQUEST:
     55       return "RDMA_MESSAGE_TENSOR_RE_REQUEST";
     56       break;
     57     case RDMA_MESSAGE_TENSOR_REQUEST:
     58       return "RDMA_MESSAGE_TENSOR_REQUEST";
     59       break;
     60     default:
     61       return "UNKNOWN MESSAGE";
     62   }
     63 }
     64 }  // namespace
     65 
     66 // Function to get environment variable
     67 // Args:
      68 //    var_name - the name of the environment variable
      69 // Returns:
      70 //    string with its value, or an empty string if not set
     71 string get_env_var(char const* var_name) {
     72   char const* var_temp = getenv(var_name);
     73 
     74   return (var_temp == NULL) ? string() : string(var_temp);
     75 }
     76 
     77 // Function to open device
     78 // Args:
      79 //   ibv_dev - the device to open
     80 // Returns:
     81 //   context of the opened device
     82 ibv_context* open_device(ibv_device* ibv_dev) {
     83   ibv_context* context = ibv_open_device(ibv_dev);
     84 
     85   CHECK(context) << "Open context failed for " << ibv_get_device_name(ibv_dev);
     86   return context;
     87 }
     88 
      89 // Function to count the number of active ports for a device
      90 // Args:
      91 //   device - the device whose active ports are counted
     92 // Returns:
     93 //   number of active ports of the given device
     94 int get_dev_active_port_count(ibv_device* device) {
     95   ibv_device_attr device_att;
     96   ibv_port_attr port_attr;
     97   ibv_context* context = NULL;
     98   int rc, port_index, active_ports = 0;
     99 
    100   context = ibv_open_device(device);
    101   CHECK(context) << "Open context failed for " << ibv_get_device_name(device);
    102   rc = ibv_query_device(context, &device_att);
    103   CHECK(!rc) << "Failed to query the device";
    104 
    105   for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
    106     rc = ibv_query_port(context, port_index, &port_attr);
     107     CHECK(!rc) << "Failed to query the port " << port_index;
    108     if (port_attr.state == IBV_PORT_ACTIVE) {
    109       active_ports++;
    110     }
    111   }
    112   ibv_close_device(context);
    113   return active_ports;
    114 }
    115 
     116 // Function to set the device. If RDMA_DEVICE is not set, search for a device
     117 // with an active port.
     118 // Fails if more than one device with an active port is found.
    119 // Returns:
    120 //   device to use
    121 ibv_device* set_device() {
    122   ibv_device** dev_list;
    123   int dev_num, device_index, device_to_open = 0;
    124   int num_devs_with_active_port = 0;
    125   string env_p_rdma_device, str_port_num;
    126 
    127   dev_list = ibv_get_device_list(&dev_num);
    128   CHECK(dev_list) << "No InfiniBand device found";
    129 
    130   env_p_rdma_device = get_env_var("RDMA_DEVICE");
    131   if (!env_p_rdma_device.empty()) {
    132     for (device_index = 0; device_index < dev_num; device_index++) {
    133       if (!env_p_rdma_device.compare(
    134               ibv_get_device_name(dev_list[device_index]))) {
    135         CHECK(get_dev_active_port_count(dev_list[device_index]) != 0)
    136             << "Device " << ibv_get_device_name(dev_list[device_index])
    137             << " has no active ports";
    138         return dev_list[device_index];
    139       }
    140     }
    141     // check validity of input device
    142     CHECK(false) << "The device " << env_p_rdma_device << " wasn't found";
    143   } else {
    144     // set default device
    145     str_port_num = get_env_var("RDMA_DEVICE_PORT");
    146     CHECK(str_port_num.empty())
    147         << "RDMA_DEVICE should be provided if RDMA_DEVICE_PORT is set by user";
    148     for (device_index = 0; device_index < dev_num; device_index++) {
    149       // get port_num
    150       if (get_dev_active_port_count(dev_list[device_index]) > 0) {
    151         num_devs_with_active_port++;
    152         CHECK(num_devs_with_active_port <= 1) << ". More than one device with "
    153                                                  "active port in the system. "
    154                                                  "Please enter RDMA_DEVICE";
    155         // found device with at least 1 active port
    156         device_to_open = device_index;
    157       }
    158     }
    159     CHECK(num_devs_with_active_port > 0)
    160         << "There is no active port in the system";
    161     return dev_list[device_to_open];
    162   }
    163   CHECK(false) << "No device was set!";
    164   return NULL;  // never happens
    165 }
    166 
    167 // Function to set port for device.
     168 // If RDMA_DEVICE_PORT is not set, the first active port of the device is used.
     169 // Args:
     170 //   context - the device context
    171 // Returns:
    172 //   port to use
    173 uint8_t set_port(ibv_context* context) {
    174   uint8_t port_num = 0;  // 0 is illegal port number
    175   string str_port_num;
    176   ibv_device_attr device_att;
    177   ibv_port_attr port_attr;
    178   int rc, port_index;
    179 
    180   rc = ibv_query_device(context, &device_att);
    181   CHECK(!rc) << "Failed to query the device\n";
    182 
    183   str_port_num = get_env_var("RDMA_DEVICE_PORT");
    184   // user defined port
    185   if (!str_port_num.empty()) {
    186     port_num = stoi(str_port_num);
    187     CHECK(port_num > 0) << "RDMA_DEVICE_PORT should be positive";
     188     CHECK(port_num <= device_att.phys_port_cnt) << "RDMA_DEVICE_PORT should be "
     189                                                    "less than or equal to the "
     190                                                    "number of available ports";
    191     rc = ibv_query_port(context, port_num, &port_attr);
     192     CHECK(!rc) << "Failed to query the port " << port_num;
     193     // check if the port is active
    194     CHECK(port_attr.state == IBV_PORT_ACTIVE)
    195         << "Selected RDMA_DEVICE_PORT is not active";
    196   } else {  // set default port
    197     for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) {
    198       rc = ibv_query_port(context, port_index, &port_attr);
     199       CHECK(!rc) << "Failed to query the port " << port_index;
    200       if (port_attr.state == IBV_PORT_ACTIVE) {
    201         port_num = port_index;
    202         break;
    203       }
    204     }
    205     CHECK_GT(port_num, 0) << "No active ports";
    206   }
    207   return port_num;
    208 }
    209 
     210 // Function to read from a sysfs file
     211 // Args:
     212 //   dir - directory
     213 //   file - file name
     214 //   buf - buffer for the result
     215 //   size - buffer size
     216 // Returns:
     217 //   number of bytes read, or -1 on failure
    218 int read_sysfs_file(const char* dir, const char* file, char* buf, size_t size) {
    219   char* path;
    220   int fd;
    221   int len;
    222 
    223   if (asprintf(&path, "%s/%s", dir, file) < 0) return -1;
    224 
    225   fd = open(path, O_RDONLY);
    226   if (fd < 0) {
    227     free(path);
    228     return -1;
    229   }
    230 
    231   len = read(fd, buf, size);
    232 
    233   close(fd);
    234   free(path);
    235 
    236   if (len > 0 && buf[len - 1] == '\n') buf[--len] = '\0';
    237 
    238   return len;
    239 }
    240 
     241 // Function to check if a GID index supports RoCE v2
     242 // Args:
     243 //   context - device context
     244 //   port_num - port number
     245 //   index - GID index
     246 // Returns:
     247 //   true if the GID supports RoCE v2, otherwise false
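         // Note: on a typical Linux RDMA stack this reads the sysfs entry
         //   <ibdev_path>/ports/<port_num>/gid_attrs/types/<index>
         // where context->device->ibdev_path usually points at
         // /sys/class/infiniband/<device_name>, and the file contains the string
         // "RoCE v2" for RoCE v2 GIDs.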
    248 bool is_gid_type_roce_v2(ibv_context* context, uint8_t port_num,
    249                          uint8_t index) {
    250   char name[32];
    251   char buff[41];
    252 
    253   snprintf(name, sizeof(name), "ports/%d/gid_attrs/types/%d", port_num, index);
    254   if (read_sysfs_file(context->device->ibdev_path, name, buff, sizeof(buff)) <=
    255       0) {
    256     return false;
    257   }
    258   return !strcmp(buff, RoCE_V2);
    259 }
    260 
    261 // Function to set GID index.
    262 // If the port link is IB, no GID index should be selected.
     263 // If the link layer is Ethernet and RDMA_GID_INDEX is not set, a GID index
     264 //   that supports RoCE v2 will be chosen (fails if more than one IP is configured).
    265 // Args:
    266 //   context - device context
    267 //   port_num - port number
    268 // Returns:
    269 //   GID index to use
    270 uint8_t set_gid(uint8_t port_num, ibv_context* context) {
    271   ibv_port_attr port_attr;
    272   string gid_str;
    273   int rc, i, gids_num = 0, v2_ip_num = 0;
    274   union ibv_gid gid;
    275   uint8_t gid_index = 0;
    276 
    277   rc = ibv_query_port(context, port_num, &port_attr);
     278   CHECK(!rc) << "Failed to query the port " << port_num;
    279 
    280   for (i = 0; i < port_attr.gid_tbl_len; i++) {
    281     rc = ibv_query_gid(context, port_num, i, &gid);
     282     CHECK(!rc) << "Failed to query the GID for port " << (int)port_num << " index "
    283                << i;
    284     if (gid.global.interface_id) {
    285       gids_num++;
    286       if (gid.global.subnet_prefix == 0 &&
    287           is_gid_type_roce_v2(context, port_num, i)) {
    288         if (v2_ip_num == 0) {
    289           // can be overwritten by RDMA_GID_INDEX later
    290           gid_index = i;
    291         }
    292         v2_ip_num++;
    293       }
    294     }
    295   }
    296   switch (port_attr.link_layer) {
    297     case (IBV_LINK_LAYER_ETHERNET):
    298       gid_str = get_env_var("RDMA_GID_INDEX");
    299       if (!gid_str.empty()) {
    300         gid_index = stoi(gid_str);
    301         CHECK(gid_index < gids_num)
     302             << "RDMA_GID_INDEX should be less than the number of GIDs: " << gids_num;
    303       } else {
    304         CHECK(v2_ip_num <= 1)
     305             << "More than one IP is available, please specify RDMA_GID_INDEX";
    306       }
    307       break;
    308     case (IBV_LINK_LAYER_INFINIBAND):  // no need in GID index
    309       break;
    310     default:
    311       LOG(INFO) << "Unknown port link layer. Currently supporting Ethernet and "
    312                    "InfiniBand only. ";
    313   }
    314   if (!is_gid_type_roce_v2(context, port_num, gid_index)) {
    315     LOG(INFO) << "RoCE v2 is not configured for GID_INDEX " << (int)gid_index;
    316   }
    317   return gid_index;
    318 }
    319 
     320 // Set a configuration parameter to its default or environment-provided value.
     321 // Args:
     322 //   default_val - the default value for this parameter
     323 //   env_param - the name of the environment variable
    324 // Returns:
    325 //   32-bit value
    326 uint32_t set_param(uint32_t default_val, const char* env_param) {
    327   uint32_t val = default_val;
    328   string val_s;
    329 
    330   val_s = get_env_var(env_param);
    331 
    332   if (!val_s.empty()) {
    333     val = stoi(val_s);
    334   }
    335   return val;
    336 }
    337 
    338 enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) {
    339   ibv_port_attr port_attr;
    340   enum ibv_mtu mtu = IBV_MTU_512;
    341   string mtu_s;
    342   int rc, mtu_i;
    343 
    344   rc = ibv_query_port(context, port_num, &port_attr);
     345   CHECK(!rc) << "Failed to query the port " << port_num;
    346 
    347   mtu_s = get_env_var("RDMA_MTU");
    348 
    349   if (!mtu_s.empty()) {
    350     mtu_i = stoi(mtu_s);
    351     switch (mtu_i) {
    352       case 256:
    353         mtu = IBV_MTU_256;
    354         break;
    355       case 512:
    356         mtu = IBV_MTU_512;
    357         break;
    358       case 1024:
    359         mtu = IBV_MTU_1024;
    360         break;
    361       case 2048:
    362         mtu = IBV_MTU_2048;
    363         break;
    364       case 4096:
    365         mtu = IBV_MTU_4096;
    366         break;
    367       default:
    368         CHECK(0) << "Error: MTU input value must be one of the following: 256, "
    369                     "512, 1024, 2048, 4096. MTU "
     370                  << mtu_i << " is invalid\n";
    371         break;
    372     }
     373     CHECK(mtu <= port_attr.active_mtu)
    374         << "MTU configuration for the QPs is larger than active MTU";
    375   } else {
    376     mtu = port_attr.active_mtu;
    377   }
    378   return mtu;
    379 }
    380 
    381 RdmaParams params_init(ibv_context* context) {
    382   RdmaParams params;
    383 
    384   params.port_num = set_port(context);
    385   params.sgid_index = set_gid(params.port_num, context);
    386   params.pkey_index = (uint8_t)set_param(PKEY_DEFAULT, "RDMA_PKEY");
    387   params.queue_depth = set_param(QUEUE_DEPTH_DEFAULT, "RDMA_QUEUE_DEPTH");
    388   params.timeout = (uint8_t)set_param(TIMEOUT_DEFAULT, "RDMA_TIMEOUT");
    389   params.retry_cnt = (uint8_t)set_param(RETRY_CNT_DEFAULT, "RDMA_RETRY_CNT");
    390   params.sl = (uint8_t)set_param(SL_DEFAULT, "RDMA_SL");
    391   CHECK(params.sl <= 7) << "SL value is " << (int)params.sl
    392                         << ". Valid values are 0-7.";
    393   params.mtu = set_mtu(params.port_num, context);
    394   params.traffic_class = set_param(TRAFFIC_CLASS, "RDMA_TRAFFIC_CLASS");
    395   return params;
    396 }
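
         // Illustrative environment configuration (hypothetical example values; the
         // variable names are the ones read by set_device()/set_port()/set_gid()/
         // set_param()/set_mtu() above, and every one of them is optional):
         //   RDMA_DEVICE=mlx5_0 RDMA_DEVICE_PORT=1 RDMA_GID_INDEX=3 RDMA_PKEY=0
         //   RDMA_QUEUE_DEPTH=1024 RDMA_TIMEOUT=14 RDMA_RETRY_CNT=7 RDMA_SL=0
         //   RDMA_MTU=1024 RDMA_TRAFFIC_CLASS=0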
    397 
    398 ibv_pd* alloc_protection_domain(ibv_context* context) {
    399   ibv_pd* pd = ibv_alloc_pd(context);
    400   CHECK(pd) << "Failed to allocate protection domain";
    401   return pd;
    402 }
    403 
    404 RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
    405     : context_(open_device(set_device())),
    406       params_(params_init(context_)),
    407       pd_(alloc_protection_domain(context_)),
    408       worker_env_(worker_env) {
    409   event_channel_ = ibv_create_comp_channel(context_);
    410   CHECK(event_channel_) << "Failed to create completion channel";
    411   cq_ = ibv_create_cq(context_, MAX_CONCURRENT_WRITES * 2, NULL, event_channel_,
    412                       0);
    413   CHECK(cq_) << "Failed to create completion queue";
    414   CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification";
    415 }
    416 
    417 RdmaAdapter::~RdmaAdapter() {
    418   polling_thread_.reset();
    419   CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ";
    420   CHECK(!ibv_destroy_comp_channel(event_channel_))
    421       << "Failed to destroy channel";
    422   CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD";
    423   CHECK(!ibv_close_device(context_)) << "Failed to release context";
    424 }
    425 
    426 void RdmaAdapter::StartPolling() {
    427   polling_thread_.reset(Env::Default()->StartThread(
    428       ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); }));
    429   VLOG(2) << "Start RdmaAdapter: " << name();
    430 }
    431 
    432 string RdmaAdapter::name() const { return string(context_->device->name); }
    433 
    434 // Function to process incoming messages
     435 // There are two types of work completions:
     436 // 1. IBV_WC_RECV_RDMA_WITH_IMM (receive)
     437 // 2. IBV_WC_RDMA_WRITE (send)
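         // For incoming writes (IBV_WC_RECV_RDMA_WITH_IMM) the immediate data selects
         // the handler, as implemented below:
         //   RDMA_IMM_DATA_ACK          - ack of a previously sent control message
         //   <= RDMA_IMM_MAX_REQUEST_ID - tensor content for that request index
         //   RDMA_IMM_DATA_MESSAGE      - a control message parsed from
         //                                rx_message_buffer_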
    438 void RdmaAdapter::Process_CQ() {
    439   while (true) {
    440     ibv_cq* cq;
    441     void* cq_context;
    442     CHECK(!ibv_get_cq_event(event_channel_, &cq, &cq_context));
    443     CHECK(cq == cq_);
    444     ibv_ack_cq_events(cq, 1);
    445     CHECK(!ibv_req_notify_cq(cq_, 0));
    446 
    447     int ne =
    448         ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2, static_cast<ibv_wc*>(wc_));
    449     CHECK_GE(ne, 0);
    450     for (int i = 0; i < ne; ++i) {
    451       CHECK(wc_[i].status == IBV_WC_SUCCESS)
    452           << "Failed status \n"
    453           << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
    454           << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
    455       if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
    456         RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
    457         // put back a recv wr.
    458         rc->Recv();
    459         // imm_data is the index of RX buffer in the buffer table.
    460         uint32_t imm_data = wc_[i].imm_data;
    461         RdmaMessageBuffer* rb;
    462         RdmaMessage rm;
    463 
    464         if (imm_data == RDMA_IMM_DATA_ACK) {
    465           // receive an ack to a message
    466           rb = rc->tx_message_buffer_;
    467           rb->SetBufferStatus(remote, idle);
    468           rb->SendNextItem();
    469           continue;
    470         }
    471 
    472         if (imm_data <= RDMA_IMM_MAX_REQUEST_ID) {
    473           // receive a tensor RDMA write
    474           uint32_t request_index = imm_data;
    475           RdmaTensorRequest* request = rc->GetTensorRequest(request_index);
    476           request->RecvTensorContent();
    477           continue;
    478         }
    479 
    480         // receive a control message
    481         rb = rc->rx_message_buffer_;
    482         RdmaMessage::ParseMessage(rm, rb->buffer_);
    483         RdmaMessageBuffer::SendAck(rc);
    484         RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
    485                     << ": Received " << MessageTypeToString(rm.type_) << " "
    486                     << "#" << rm.request_index_ << ": " << rm.name_;
    487 
    488         if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) {
    489           RdmaTensorResponse* response = rc->AddTensorResponse(rm);
    490           response->Start();
    491         } else if (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) {
    492           RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_);
    493           request->RecvTensorMetaData(rm.data_type_, rm.tensor_shape_,
    494                                       rm.is_dead_, rm.tensor_bytes_);
    495 #ifdef RDMA_DATA_VALIDATION
    496           request->RecvTensorChecksum(rm.checksum_);
    497 #endif
    498         } else if (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST) {
    499           RdmaTensorResponse* response = rc->UpdateTensorResponse(rm);
    500           response->Resume();
    501         } else if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
    502           RdmaTensorRequest* request = rc->GetTensorRequest(rm.request_index_);
    503           request->RecvErrorStatus(rm.status_);
    504         }
    505       } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
    506         RdmaWriteID* wr_id = reinterpret_cast<RdmaWriteID*>(wc_[i].wr_id);
    507         RDMA_LOG(2) << "Write complete of type " << wr_id->write_type;
    508         switch (wr_id->write_type) {
    509           case RDMA_WRITE_ID_ACK:
    510             break;
    511           case RDMA_WRITE_ID_MESSAGE: {
    512             RdmaMessageBuffer* rb =
    513                 reinterpret_cast<RdmaMessageBuffer*>(wr_id->write_context);
    514             rb->SetBufferStatus(local, idle);
    515             rb->SendNextItem();
    516             break;
    517           }
    518           case RDMA_WRITE_ID_TENSOR_WRITE: {
    519             RdmaTensorResponse* response =
    520                 reinterpret_cast<RdmaTensorResponse*>(wr_id->write_context);
    521             response->Destroy();
    522           }
    523         }
    524         delete wr_id;
    525       }
    526     }
    527   }
    528 }
    529 
    530 int RdmaChannel::PingPostRecv() {
    531   struct ibv_recv_wr wr, *bad_wr;
    532   memset(&wr, 0, sizeof(wr));
    533   wr.sg_list = &ping_sge_list_;
    534   wr.num_sge = 1;
    535   wr.wr_id = kPingRecvWrid;
    536 
    537   return ibv_post_recv(qp_, &wr, &bad_wr);
    538 }
    539 
    540 int RdmaChannel::PingPostSend() {
    541   struct ibv_send_wr wr, *bad_wr;
    542   memset(&wr, 0, sizeof(wr));
    543   wr.wr_id = (uint64_t)this;
    544   wr.sg_list = &ping_sge_list_;
    545   wr.num_sge = 1;
    546   wr.opcode = IBV_WR_SEND;
    547   wr.send_flags = IBV_SEND_SIGNALED;
    548 
    549   return ibv_post_send(qp_, &wr, &bad_wr);
    550 }
    551 
    552 RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
    553                          const string remote_name)
    554     : adapter_(adapter),
    555       local_name_(local_name),
    556       remote_name_(remote_name),
    557       request_serial_(0) {
    558   struct ibv_sge list;
    559 
    560   mr_ = ibv_reg_mr(adapter_->pd_, ping_buff_, kPingBuffSize,
    561                    IBV_ACCESS_LOCAL_WRITE);
    562   CHECK(mr_) << "Failed to register memory region";
    563 
    564   memset(&list, 0, sizeof(list));
    565   list.addr = (uintptr_t)ping_buff_;
    566   list.length = kPingBuffSize;
    567   list.lkey = mr_->lkey;
    568 
    569   ping_sge_list_ = list;
    570   // Create queue pair
    571   {
    572     struct ibv_qp_init_attr attr;
    573     memset(&attr, 0, sizeof(ibv_qp_init_attr));
    574     attr.send_cq = adapter_->cq_;
    575     attr.recv_cq = adapter_->cq_;
    576     attr.cap.max_send_wr = adapter_->params_.queue_depth;
    577     attr.cap.max_recv_wr = adapter_->params_.queue_depth;
    578     attr.cap.max_send_sge = 1;
    579     attr.cap.max_recv_sge = 1;
    580     attr.qp_type = IBV_QPT_RC;
    581 
    582     qp_ = ibv_create_qp(adapter_->pd_, &attr);
    583     CHECK(qp_) << "Failed to create queue pair";
    584   }
    585 
    586   // Init queue pair
    587   {
    588     struct ibv_qp_attr attr;
    589     memset(&attr, 0, sizeof(ibv_qp_attr));
    590     attr.qp_state = IBV_QPS_INIT;
    591     attr.pkey_index = adapter_->params_.pkey_index;
    592     attr.port_num = adapter_->params_.port_num;
    593     attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
    594 
    595     int mask =
    596         IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;
    597     CHECK(!ibv_modify_qp(qp_, &attr, mask)) << "Failed to set QP to INIT";
    598   }
    599 
    600   // Local address
    601   {
    602     struct ibv_port_attr attr;
    603     CHECK(
    604         !ibv_query_port(adapter_->context_, adapter_->params_.port_num, &attr))
    605         << "Query port";
    606     self_.lid = attr.lid;
    607     self_.qpn = qp_->qp_num;
    608     self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
    609     union ibv_gid gid;
    610     CHECK(!ibv_query_gid(adapter_->context_, adapter_->params_.port_num,
    611                          adapter_->params_.sgid_index, &gid))
    612         << "Query gid";
    613     self_.snp = gid.global.subnet_prefix;
    614     self_.iid = gid.global.interface_id;
    615   }
    616 
    617   // create message and ack buffers, then initialize the tables.
    618   {
    619     const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer"};
    620     tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]);
    621     rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]);
    622     message_buffers_.reserve(kNumMessageBuffers);
    623     message_buffers_.push_back(tx_message_buffer_);
    624     message_buffers_.push_back(rx_message_buffer_);
    625     // create buffer on host
    626     tx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
    627     rx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
    628   }
    629   CHECK(PingPostRecv() == 0) << "Couldn't post receive from " << remote_name_
    630                              << " with error " << std::strerror(errno);
    631 }
    632 
    633 RdmaChannel::~RdmaChannel() {
    634   ibv_dereg_mr(mr_);
    635   CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP";
    636   delete tx_message_buffer_;
    637   delete rx_message_buffer_;
    638 }
    639 
    640 void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
    641   mutex_lock lock{mu_};
    642   if ((override) || (!remote_set_)) {
    643     remote_.lid = ra.lid;
    644     remote_.qpn = ra.qpn;
    645     remote_.psn = ra.psn;
    646     remote_.snp = ra.snp;
    647     remote_.iid = ra.iid;
    648     remote_set_ = true;
    649   } else {
    650     CHECK(remote_.lid == ra.lid);
    651     CHECK(remote_.qpn == ra.qpn);
    652     CHECK(remote_.psn == ra.psn);
    653     CHECK(remote_.snp == ra.snp);
    654     CHECK(remote_.iid == ra.iid);
    655   }
    656 }
    657 
    658 // Adding tokens to the completion queue
    659 // Tokens are needed to process future messages.
    660 void RdmaChannel::Recv() {
    661   struct ibv_recv_wr wr;
    662   memset(&wr, 0, sizeof(wr));
    663   wr.wr_id = (uint64_t)this;
    664   struct ibv_recv_wr* bad_wr;
    665   CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
    666 }
    667 
    668 RdmaTensorRequest* RdmaChannel::InsertTensorRequest(
    669     const string& key, int64 step_id, Device* dst_dev,
    670     const Rendezvous::Args recv_args,
    671     const RdmaTensorRequest::RecvDoneCallback& done) {
    672   mutex_lock lock{ct_mu_};
    673   uint32_t request_index = request_serial_++;
    674   if (request_serial_ > RDMA_IMM_MAX_REQUEST_ID) {
    675     request_serial_ = 0;
    676   }
    677   RdmaTensorRequest request(request_index, key, step_id, this, dst_dev,
    678                             recv_args, done);
    679   auto it = request_table_.emplace(request_index, request);
    680   return &it.first->second;
    681 }
    682 
    683 void RdmaChannel::RemoveTensorRequest(uint32_t request_index) {
    684   mutex_lock lock{ct_mu_};
    685   request_table_.erase(request_index);
    686 }
    687 
    688 RdmaTensorRequest* RdmaChannel::GetTensorRequest(uint32_t request_index) {
    689   mutex_lock lock{ct_mu_};
    690   RequestTable::iterator iter = request_table_.find(request_index);
    691   CHECK(iter != request_table_.end());
    692   return &iter->second;
    693 }
    694 
    695 void RdmaChannel::Connect() {
    696   {
    697     mutex_lock lock{mu_};
    698     CHECK(remote_set_) << "remote channel is not set";
    699   }
    700   Connect(remote_);
    701 }
    702 
    703 // Setup channel to a remote node
    704 // Args:
    705 //   remoteAddr: the rdma address of a remote channel.
    706 // Returns:
    707 //   None
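         // The QP was moved to the INIT state in the constructor; the two
         // ibv_modify_qp() calls below transition it INIT -> RTR (ready to receive)
         // -> RTS (ready to send), using the remote LID/QPN/PSN/GID.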
    708 void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
    709   mutex_lock lock{mu_};
    710   if (!connected_) {
    711     struct ibv_qp_attr attr;
    712     memset(&attr, 0, sizeof(ibv_qp_attr));
    713     attr.qp_state = IBV_QPS_RTR;
    714 
    715     // This assumes both QP's ports are configured with the same MTU
    716     attr.path_mtu = adapter_->params_.mtu;
    717     attr.dest_qp_num = remoteAddr.qpn;
    718     attr.rq_psn = remoteAddr.psn;
    719     attr.max_dest_rd_atomic = 1;
    720     attr.min_rnr_timer = 12;
    721     attr.ah_attr.is_global = 1;
    722     attr.ah_attr.grh.dgid.global.subnet_prefix = remoteAddr.snp;
    723     attr.ah_attr.grh.dgid.global.interface_id = remoteAddr.iid;
    724     attr.ah_attr.grh.flow_label = 0;
    725     attr.ah_attr.grh.hop_limit = 255;
    726     attr.ah_attr.dlid = remoteAddr.lid;
    727     attr.ah_attr.sl = adapter_->params_.sl;
    728     attr.ah_attr.src_path_bits = 0;
    729     attr.ah_attr.port_num = adapter_->params_.port_num;
    730     attr.ah_attr.grh.sgid_index = adapter_->params_.sgid_index;
    731     attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class;
    732 
    733     int r;
    734     CHECK(!(r = ibv_modify_qp(qp_, &attr,
    735                               IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
    736                                   IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
    737                                   IBV_QP_MAX_DEST_RD_ATOMIC |
    738                                   IBV_QP_MIN_RNR_TIMER)))
     739         << "Failed to set QP to Ready To Receive, error: " << r;
    740 
    741     memset(&attr, 0, sizeof(ibv_qp_attr));
    742     attr.qp_state = IBV_QPS_RTS;
    743     attr.sq_psn = self_.psn;
    744     attr.timeout = adapter_->params_.timeout;
    745     attr.retry_cnt = adapter_->params_.retry_cnt;
    746     attr.rnr_retry = 7; /* infinite */
    747     attr.max_rd_atomic = 1;
    748 
    749     CHECK(!(r = ibv_modify_qp(qp_, &attr,
    750                               IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
    751                                   IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
    752                                   IBV_QP_MAX_QP_RD_ATOMIC)))
     753         << "Failed to set QP to Ready To Send, error: " << r;
    754 
    755     connected_ = true;
    756   } else {
    757     RDMA_LOG(2) << "channel already connected";
    758   }
    759 }
    760 
    761 RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
    762     : channel_(channel), name_(name) {}
    763 
    764 RdmaMessageBuffer::~RdmaMessageBuffer() {
    765   CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
    766   FreeBuffer();
    767 }
    768 
    769 void RdmaMessageBuffer::FreeBuffer() {
    770   if ((buffer_ != nullptr) && buffer_on_host_) {
    771     free(buffer_);
    772   }
    773 }
    774 
    775 // Allocate CPU memory for the Rdma buffer
    776 // Args:
    777 //   size: to-be-allocated memory size
    778 //   lock: whether or not mutex_lock the process to protect concurrency.
    779 // Returns:
    780 //   None
    781 void RdmaMessageBuffer::CreateCPUBuffer(size_t size, bool lock) {
    782   CHECK(size > 0);
    783   if (lock) {
    784     mu_.lock();
    785   }
    786   if (local_status_ != none) {
    787     // delete existing buffer
    788     CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
    789     FreeBuffer();
    790   }
    791   size_ = size;
    792   buffer_ = malloc(size_);
    793   self_ = ibv_reg_mr(channel_->adapter_->pd_, buffer_, size_,
    794                      IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
    795   CHECK(self_) << "Failed to register memory region";
    796   buffer_on_host_ = true;
    797   local_status_ = idle;
    798   if (lock) {
    799     mu_.unlock();
    800   }
    801 }
    802 
    803 // Set address of remote memory region
    804 // Args:
    805 //   rmr: address of remote memory region
    806 //   override: whether override existing information
    807 // Returns:
    808 //   None
    809 void RdmaMessageBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
    810   mutex_lock lock{mu_};
    811   if ((override) || (remote_status_ == none)) {
    812     remote_.remote_addr = rmr.remote_addr;
    813     remote_.rkey = rmr.rkey;
    814     remote_status_ = idle;
    815   } else {
    816     CHECK(remote_.remote_addr == rmr.remote_addr);
    817     CHECK(remote_.rkey == rmr.rkey);
    818   }
    819 }
    820 
    821 // Put a task in the buffer's job queue
    822 void RdmaMessageBuffer::EnqueueItem(string item) {
    823   mutex_lock lock{mu_};
    824   queue_.push(item);
    825 }
    826 
    827 // Rdma-Write the content of the buffer
    828 void RdmaMessageBuffer::Write(uint32_t imm_data, size_t buffer_size) {
    829   Write(channel_, imm_data, buffer_size, (uint64_t)buffer_, self_->lkey,
    830         remote_.remote_addr, remote_.rkey, RDMA_WRITE_ID_MESSAGE, this);
    831 }
    832 
    833 // Generalized Write method
    834 void RdmaMessageBuffer::Write(const RdmaChannel* channel, uint32_t imm_data,
    835                               size_t buffer_size, uint64_t src_addr,
    836                               uint32_t lkey, uint64_t remote_addr,
    837                               uint32_t rkey, RdmaWriteIDType write_type,
    838                               void* write_context) {
    839   struct ibv_sge list;
    840   list.addr = src_addr;
    841   list.length = buffer_size;
    842   list.lkey = lkey;
    843 
    844   struct ibv_send_wr wr;
    845   memset(&wr, 0, sizeof(wr));
    846   wr.wr_id = (uint64_t) new RdmaWriteID(write_type, write_context);
    847   wr.sg_list = &list;
    848   wr.num_sge = 1;
    849   wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
    850   wr.send_flags = IBV_SEND_SIGNALED;
    851   wr.imm_data = imm_data;
    852   wr.wr.rdma.remote_addr = remote_addr;
    853   wr.wr.rdma.rkey = rkey;
    854 
    855   struct ibv_send_wr* bad_wr;
    856   CHECK(!ibv_post_send(channel->qp_, &wr, &bad_wr)) << "Failed to post send";
    857 }
    858 
     859 // Send an ACK: a zero-byte RDMA write carrying RDMA_IMM_DATA_ACK.
    860 void RdmaMessageBuffer::SendAck(const RdmaChannel* channel) {
    861   Write(channel, RDMA_IMM_DATA_ACK, 0, 0, 0, 0, 0, RDMA_WRITE_ID_ACK, nullptr);
    862 }
    863 
    864 // Send the next message from the buffer's job queue.
    865 void RdmaMessageBuffer::SendNextItem() {
    866   uint32_t imm_data = RDMA_IMM_DATA_MESSAGE;
    867   mu_.lock();
    868   if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) {
    869     local_status_ = busy;
    870     remote_status_ = busy;
    871     string message = queue_.front();
    872     queue_.pop();
    873     // local/remote_status_ won't be set back to idle
     874     // until Write() is successful
    875     mu_.unlock();
    876     memcpy(buffer_, message.data(), message.size());
    877     Write(imm_data, message.size());
    878   } else {
    879     mu_.unlock();
    880   }
    881 }
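
         // Control-message flow (see Process_CQ() above): the sender enqueues a
         // serialized RdmaMessage and calls SendNextItem(), which RDMA-writes the
         // bytes into the peer's rx_message_buffer_ with RDMA_IMM_DATA_MESSAGE. The
         // peer parses the message and replies with SendAck(); when the ack arrives,
         // the sender marks the remote buffer idle and sends the next queued message.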
    882 
    883 #if GOOGLE_CUDA
    884 static void CountCopies(const std::string& key, void* src_addr, void* dst_addr,
    885                         size_t tensor_bytes, bool is_gpu_to_cpu) {
    886 #ifdef RDMA_COUNT_COPIES
    887   static uint64_t numGPUToCPUCopies = 0;
    888   static uint64_t numGPUToCPUCopiedBytes = 0;
    889   static uint64_t numCPUToGPUCopies = 0;
    890   static uint64_t numCPUToGPUCopiedBytes = 0;
    891   static uint64_t numTotalCopies = 0;
    892 
    893   if (is_gpu_to_cpu) {
    894     ++numGPUToCPUCopies;
    895     numGPUToCPUCopiedBytes += tensor_bytes;
    896   } else {
    897     ++numCPUToGPUCopies;
    898     numCPUToGPUCopiedBytes += tensor_bytes;
    899   }
    900   if ((++numTotalCopies % 0x400) == 0) {
    901     RDMA_LOG(0) << "Tensor copies:"
    902                 << " GPU to CPU: " << numGPUToCPUCopies << " ("
    903                 << numGPUToCPUCopiedBytes << " Bytes)"
    904                 << " CPU to GPU: " << numCPUToGPUCopies << " ("
    905                 << numCPUToGPUCopiedBytes << " Bytes)";
    906   }
    907   RDMA_LOG(2) << "Copying tensor " << key << " From: " << src_addr
    908               << " To: " << dst_addr;
    909 #endif  // RDMA_COUNT_COPIES
    910 }
    911 #endif  // GOOGLE_CUDA
    912 
    913 #ifdef RDMA_DATA_VALIDATION
    914 static uint64_t Checksum(Device* device, const DeviceContext* device_context,
    915                          const Tensor& in) {
    916   uint64 checksum = 0;
    917   if (DataTypeCanUseMemcpy(in.dtype())) {
    918 #if GOOGLE_CUDA
    919     if (in.TotalBytes() == 0) {
    920       return 0;
    921     }
    922     checksum = (device_context != nullptr)
    923                    ? GPUUtil::Checksum(device, device_context, in)
    924                    : GPUUtil::Checksum(in);
    925 #endif  // GOOGLE_CUDA
    926   } else {
    927     string s = in.SummarizeValue(999999);
    928     checksum = Hash64(s.c_str(), s.size(), 0);
    929   }
    930   return checksum;
    931 }
    932 
    933 static void ValidateChecksum(uint64_t expected, uint64_t actual,
    934                              const Tensor& in, uint32_t request_index,
    935                              const std::string& key, const std::string& msg) {
    936   RDMA_LOG(2) << "Request #" << request_index << ": " << key
    937               << ": Checksum: " << std::hex << " Expected = 0x" << expected
    938               << ". Actual = 0x" << actual << ".";
    939 
    940   if (expected != actual) {
    941     // Checksum failed. There is one case where this is allowed - if the
    942     // tensor is an AssignAdd of the global step. Since the data-validation
    943     // always postpones the Tensor response in order to send a checksum message,
    944     // it is possible that the global-step was updated while the response was
    945     // still in queue.
    946     if ((in.TotalBytes() == 8) && (in.dtype() == DT_INT64)) {
    947       int64_t prev_val = *(int64_t*)DMAHelper::base(&in) - 1;
    948       actual = Hash64((const char*)&prev_val, 8, 0);
    949     }
    950     if (expected != actual) {
    951       LOG(FATAL) << "[" << msg << "]: Checksum validation failed for request #"
    952                  << request_index << ": " << key << std::hex << " "
    953                  << DataTypeString(in.dtype()) << " "
    954                  << in.shape().DebugString() << " (0x" << in.TotalBytes()
    955                  << " bytes): "
    956                  << " Expected 0x" << expected << ". Got 0x" << actual << ".";
    957     }
    958   }
    959 }
    960 #endif  // RDMA_DATA_VALIDATION
    961 
    962 #if GOOGLE_CUDA
    963 // Sync the 'done' operation on the GPU stream, but without all the data
    964 // copying.
    965 static void StreamGPUOp(Device* gpu_device, const DeviceContext* device_context,
    966                         StatusCallback done) {
    967   Tensor dummy1, dummy2;
    968   GPUUtil::CopyGPUTensorToCPU(gpu_device, device_context, &dummy1, &dummy2,
    969                               done);
    970 }
    971 #endif  // GOOGLE_CUDA
    972 
    973 RdmaTensorResponse* RdmaChannel::AddTensorResponse(const RdmaMessage& rm) {
    974   mutex_lock lock{mu_};
    975   auto it =
    976       responses_table_.emplace(rm.request_index_, RdmaTensorResponse(this, rm));
    977   CHECK(it.second) << "Response with the ID " << rm.request_index_
    978                    << " already exists.";
    979   return &it.first->second;
    980 }
    981 
    982 RdmaTensorResponse* RdmaChannel::UpdateTensorResponse(const RdmaMessage& rm) {
    983   mutex_lock lock{mu_};
    984   auto it = responses_table_.find(rm.request_index_);
    985   CHECK(it != responses_table_.end()) << "No response found.";
    986   RdmaTensorResponse* response = &it->second;
    987   response->Update(rm);
    988   return response;
    989 }
    990 
    991 void RdmaChannel::RemoveTensorResponse(uint32_t request_index) {
    992   mutex_lock lock{mu_};
    993   responses_table_.erase(request_index);
    994 }
    995 
    996 void RdmaTensorResponse::Start() {
    997   Rendezvous::ParsedKey parsed;
    998   Status s = Rendezvous::ParseKey(rm_.name_, &parsed);
    999   if (!s.ok()) {
   1000     SendErrorStatus(s);
   1001     return;
   1002   }
   1003 
   1004   channel_->adapter_->worker_env_->rendezvous_mgr->RecvLocalAsync(
   1005       rm_.step_id_, parsed,
   1006       [this, parsed](const Status& status, const Rendezvous::Args& send_args,
   1007                      const Rendezvous::Args& recv_args, const Tensor& in,
   1008                      bool is_dead) {
   1009         CHECK(status.ok()) << "RecvLocalAsync was not ok."
   1010                            << " error message: " << status.error_message();
   1011         RecvHandler(parsed, send_args, recv_args, in, is_dead);
   1012       });
   1013 }
   1014 
   1015 void RdmaTensorResponse::Resume() { SendContent(*tensor_, *proto_, is_dead_); }
   1016 
   1017 // Helper for RecvTensor. Validates "key" and returns the source
   1018 // device in "*src_dev".
   1019 Status RdmaTensorResponse::PrepareRecvTensor(
   1020     const Rendezvous::ParsedKey& parsed, Device** src_dev) {
   1021   // Figures out which device the tensor is hosted on.
   1022   string local_name = DeviceNameUtils::LocalName(parsed.src_device);
   1023   TF_RETURN_IF_ERROR(channel_->adapter_->worker_env_->device_mgr->LookupDevice(
   1024       local_name, src_dev));
   1025 
   1026   // Does the device have the right incarnation number we expect?
   1027   if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
   1028     return errors::Aborted(
   1029         "RecvTensor expects a different device incarnation: ",
   1030         parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
   1031         ". Your worker job was probably restarted. Check your "
   1032         "worker job for the reason why it was restarted.");
   1033   }
   1034 
   1035   return Status::OK();
   1036 }
   1037 
   1038 void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed,
   1039                                      const Rendezvous::Args& send_args,
   1040                                      const Rendezvous::Args& recv_args,
   1041                                      const Tensor& in, bool is_dead) {
   1042   Status s = PrepareRecvTensor(parsed, &src_dev_);
   1043   if (!s.ok()) {
   1044     SendErrorStatus(s);
   1045     return;
   1046   }
   1047 
   1048   meta_data_changed_ = TensorMetaDataChanged(in, is_dead);
   1049 #ifdef RDMA_DATA_VALIDATION
   1050   // Always send a meta data message with the source checksum
   1051   meta_data_changed_ = rm_.type_ == RDMA_MESSAGE_TENSOR_REQUEST;
   1052   checksum_ = Checksum(src_dev_, send_args.device_context, in);
   1053 #endif
   1054   bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
   1055   // string tensor needs to be serialized
   1056   Tensor copy;
   1057   TensorProto proto;
   1058   const bool on_host = send_args.alloc_attrs.on_host();
   1059   if (src_dev_->tensorflow_gpu_device_info() && !on_host) {
   1060 #if GOOGLE_CUDA
   1061     DeviceContext* send_dev_context = send_args.device_context;
   1062     CHECK(send_dev_context)
   1063         << "send dev name: " << src_dev_->name()
   1064         << " gpu_info: " << src_dev_->tensorflow_gpu_device_info();
   1065 
   1066     if (can_memcpy) {
   1067       // If the tensor is located on a GDR compatible GPU, there is no need to
   1068       // copy it. We can send directly from the source, just need to make sure
   1069       // we are in sync with the GPU stream.
   1070       // If the tensor's meta-data changed however, we will need to clone it,
   1071       // so anyway we'll have to copy it from GPU to CPU first. If at some
   1072       // point in time Clone() is changed to only save a shallow copy, we can
   1073       // skip the copy here as well.
   1074       if ((in.TotalBytes() > 0) && !meta_data_changed_ &&
   1075           (RdmaMemoryMgr::Singleton().FindMemoryRegion(
   1076                (void*)DMAHelper::base(&in), in.TotalBytes()) != nullptr)) {
   1077         StreamGPUOp(src_dev_, send_dev_context,
   1078                     [this, in, proto, is_dead](const Status& s) {
   1079                       Send(in, proto, is_dead, s);
   1080                     });
   1081         return;
   1082       }
   1083 
   1084       // The tensor must be copied from GPU to CPU, because either:
   1085       // 1. The tensor is located on a non GDR compatible GPU.
   1086       // 2. The tensor's meta-data has changed.
   1087       Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
   1088       copy = Tensor(alloc, in.dtype(), in.shape());
   1089       CountCopies(rm_.name_, (void*)DMAHelper::base(&in),
   1090                   (void*)DMAHelper::base(&copy), in.TotalBytes(), true);
   1091       GPUUtil::CopyGPUTensorToCPU(
   1092           src_dev_, send_dev_context, &in, &copy,
   1093           [this, copy, proto, is_dead](const Status& s) {
   1094             Send(copy, proto, is_dead, s);
   1095           });
   1096     } else {
   1097       GPUUtil::SetProtoFromGPU(
   1098           in, src_dev_, send_args.device_context, &proto, is_dead,
   1099           [this, in, proto, is_dead](const Status& s) mutable {
   1100             Send(in, proto, is_dead, s);
   1101           });
   1102     }
   1103 #else
   1104     SendErrorStatus(errors::Internal("No GPU device in process"));
   1105 #endif  // GOOGLE_CUDA
   1106   } else {
   1107     // tensor is in CPU memory.
   1108     if (!can_memcpy) {
   1109       in.AsProtoTensorContent(&proto);
   1110     }
   1111     Send(in, proto, is_dead, Status::OK());
   1112   }
   1113 }
   1114 
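         // If the tensor meta-data (or the serialized proto size) no longer matches
         // what the requester expects, the data is cloned and an
         // RDMA_MESSAGE_META_DATA_UPDATE is sent first; the requester is then expected
         // to re-request with RDMA_MESSAGE_TENSOR_RE_REQUEST, which resumes this
         // response and writes the cloned content (see Process_CQ() and Resume()).
         // Otherwise the content is written directly into the requester's buffer.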
   1115 void RdmaTensorResponse::Send(const Tensor& in, const TensorProto& proto,
   1116                               bool is_dead, const Status& status) {
   1117   if (!status.ok()) {
   1118     SendErrorStatus(status);
   1119     return;
   1120   }
   1121   bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
   1122   bool proto_size_changed =
   1123       (!can_memcpy) && (proto.ByteSize() != rm_.tensor_bytes_);
   1124   if (meta_data_changed_ || proto_size_changed) {
   1125     Clone(in, proto, is_dead);
   1126     SendMetaData(in, proto, is_dead);
   1127   } else {
   1128     SendContent(in, proto, is_dead);
   1129   }
   1130 }
   1131 
   1132 bool RdmaTensorResponse::TensorMetaDataChanged(const Tensor& in, bool is_dead) {
   1133   return (rm_.data_type_ != in.dtype()) || (rm_.tensor_shape_ != in.shape()) ||
   1134          (rm_.is_dead_ != is_dead);
   1135 }
   1136 
   1137 void RdmaTensorResponse::Clone(const Tensor& in, const TensorProto& proto,
   1138                                bool is_dead) {
   1139   // Clone the data to be sent later. For simplicity, we clone the tensor's
   1140   // data even if it is already a copy. Performance is less of a concern here
   1141   // since the meta-data hardly ever changes. The reason we create a copy, is
   1142   // that some tensors share their buffer between different step-ids, so the
   1143   // tensor content may change before re-request was completed.
   1144   bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
   1145   if (can_memcpy && (in.TotalBytes() > 0)) {
   1146     AllocatorAttributes host_alloc_attrs;
   1147     host_alloc_attrs.set_nic_compatible(true);
   1148     host_alloc_attrs.set_on_host(true);
   1149     Allocator* allocator = src_dev_->GetAllocator(host_alloc_attrs);
   1150     tensor_ = new Tensor(allocator, in.dtype(), in.shape());
   1151     memcpy(DMAHelper::base(tensor_), DMAHelper::base(&in), in.TotalBytes());
   1152   } else {
   1153     tensor_ = new Tensor(in.dtype(), in.shape());
   1154   }
   1155   if (!can_memcpy) {
   1156     proto_ = new TensorProto(proto);
   1157   }
   1158   is_dead_ = is_dead;
   1159 }
   1160 
   1161 void RdmaTensorResponse::SendMetaData(const Tensor& in,
   1162                                       const TensorProto& proto, bool is_dead) {
   1163   RDMA_LOG(2) << "Request #" << rm_.request_index_
   1164               << ": Meta data changed: " << rm_.name_;
   1165   bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
   1166   size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize();
   1167 
   1168   // Send meta-data update:
   1169   RdmaMessage rm;
   1170   rm.type_ = RDMA_MESSAGE_META_DATA_UPDATE;
   1171   rm.name_size_ = rm_.name_.size();
   1172   rm.name_ = rm_.name_;
   1173   rm.tensor_shape_ = in.shape();
   1174   rm.data_type_ = in.dtype();
   1175   rm.step_id_ = rm_.step_id_;
   1176   rm.is_dead_ = is_dead;
   1177   rm.tensor_bytes_ = tensor_bytes;
   1178   rm.request_index_ = rm_.request_index_;
   1179 #ifdef RDMA_DATA_VALIDATION
   1180   rm.checksum_ = checksum_;
   1181 #endif
   1182   RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
   1183               << ": Sending RDMA_MESSAGE_META_DATA_UPDATE #"
   1184               << rm.request_index_ << ": " << rm.name_
   1185               << " (shape = " << rm.tensor_shape_.DebugString() << "."
   1186               << " data-type = " << DataTypeString(rm.data_type_) << "."
   1187               << " is-dead = " << rm.is_dead_ << ")";
   1188 
   1189   string message = RdmaMessage::CreateMessage(rm);
   1190   channel_->tx_message_buffer_->EnqueueItem(message);
   1191   channel_->tx_message_buffer_->SendNextItem();
   1192 }
   1193 
   1194 void RdmaTensorResponse::SendContent(const Tensor& in, const TensorProto& proto,
   1195                                      bool is_dead) {
   1196   bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
   1197   size_t tensor_bytes = (can_memcpy) ? in.TotalBytes() : proto.ByteSize();
   1198   uint32_t imm_data = rm_.request_index_;
   1199   if (!is_dead) {
   1200     if (can_memcpy) {
   1201       src_buffer_ = const_cast<TensorBuffer*>(DMAHelper::buffer(&in));
   1202       if (src_buffer_ != nullptr) {
   1203         src_buffer_->Ref();  // Keep buffer alive until write is complete
   1204         src_addr_ = src_buffer_->data();
   1205         mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(src_addr_,
   1206                                                           tensor_bytes);
   1207       }
   1208     } else {
   1209       RDMA_LOG(2) << "Encoding proto: " << rm_.name_
   1210                   << " (Size: " << tensor_bytes << ") " << in.DebugString();
   1211       src_addr_ = malloc(tensor_bytes);
   1212       mr_ = ibv_reg_mr(channel_->adapter_->pd_, src_addr_, tensor_bytes,
   1213                        IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
   1214       proto.SerializeToArray(src_addr_, tensor_bytes);
   1215     }
   1216   } else {
   1217     tensor_bytes = 0;
   1218   }
   1219 
   1220   uint32_t lkey = (mr_ == nullptr) ? 0 : mr_->lkey;
   1221   RDMA_LOG(1) << "Step 0x" << std::hex << rm_.step_id_ << std::dec
   1222               << ": Sending tensor content #" << rm_.request_index_ << " from "
   1223               << std::hex << src_addr_ << " (0x" << lkey << ")"
   1224               << " to " << rm_.remote_addr_ << " (0x" << rm_.rkey_
   1225               << "): " << rm_.name_ << " (size: 0x" << std::hex << tensor_bytes
   1226               << ")";
   1227 
   1228   RdmaMessageBuffer::Write(channel_, imm_data, tensor_bytes,
   1229                            (uint64_t)src_addr_, lkey, rm_.remote_addr_,
   1230                            rm_.rkey_, RDMA_WRITE_ID_TENSOR_WRITE, this);
   1231 }
   1232 
   1233 void RdmaTensorResponse::SendErrorStatus(const Status& status) {
   1234   RdmaMessage rm;
   1235   rm.type_ = RDMA_MESSAGE_ERROR_STATUS;
   1236   rm.name_size_ = rm_.name_.size();
   1237   rm.name_ = rm_.name_;
   1238   rm.step_id_ = rm_.step_id_;
   1239   rm.request_index_ = rm_.request_index_;
   1240   rm.status_ = status;
   1241   LOG(ERROR) << "Step 0x" << std::hex << rm.step_id_ << std::dec
   1242              << ": Sending RDMA_MESSAGE_ERROR_STATUS #" << rm.request_index_
   1243              << ": " << rm.name_ << ". Status: " << status.ToString();
   1244 
   1245   string message = RdmaMessage::CreateMessage(rm);
   1246   channel_->tx_message_buffer_->EnqueueItem(message);
   1247   channel_->tx_message_buffer_->SendNextItem();
   1248 
   1249   // Destroy the response.
   1250   Destroy();
   1251 }
   1252 
   1253 void RdmaTensorResponse::Destroy() {
   1254   if (src_buffer_ != nullptr) {
   1255     src_buffer_->Unref();
   1256   }
   1257   if (tensor_ != nullptr) {
   1258     delete tensor_;
   1259   }
   1260   if (proto_ != nullptr) {
   1261     ibv_dereg_mr(mr_);
   1262     free(src_addr_);
   1263     delete proto_;
   1264   }
   1265   // Remove response from the pending list:
   1266   channel_->RemoveTensorResponse(rm_.request_index_);
   1267 }
   1268 
   1269 // Create a RdmaMessage according to the pre-defined format
   1270 // Args:
   1271 //   rm: the message structure
   1272 // Returns:
   1273 //   message in string format
   1274 string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
   1275   // Rdma Message format
   1276   // type|name_size|name|step_id|request_index|remote_addr|rkey|is_dead|...
   1277   //   1B|    2B   | 512|  8B   |     8B      |       8B  | 4B |    1B |...
   1278   // ...|data_type|tensor_shape|tensor_bytes|error_status          |
   1279   // ...|   XB    |    XB      |    8B      |size - 4B, proto - XB |
   1280   //
   1281   // ACK:             Imm-type: ACK
   1282   // TENSOR_REQUEST:  Imm-type: MESSAGE
   1283   //                  Fields: type, request_index, name, step_id, remote_addr,
   1284   //                      rkey, is_dead, data_type, tensor_shape, tensor_bytes
   1285   // META_DATA_UPDATE: Imm-type: MESSAGE
   1286   //                  Fields: type, request_index, is_dead, data_type,
   1287   //                      tensor_shape, tensor_bytes
    1288   // TENSOR_RE_REQUEST: Imm-type: MESSAGE
   1289   //                  Fields: type, request_index, name, step_id, remote_addr,
   1290   //                      rkey, is_dead, data_type, tensor_shape, tensor_bytes
   1291   // ERROR_STATUS:    Imm-type: MESSAGE
   1292   //                  Fields: type, request_index, name, step_id, error_status
   1293   // Tensor content:  Imm-type: request_index
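           // CreateMessage() below serializes into this layout and ParseMessage() is
           // its exact inverse; both rely on the kXXXStartIndex offsets (presumably
           // declared in rdma.h alongside kMessageTotalBytes).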
   1294   size_t message_size = kMessageTotalBytes;
   1295   char message[kMessageTotalBytes + kErrorStatusMaxSize];
   1296   // type
   1297   message[kTypeStartIndex] = static_cast<char>(rm.type_) & 0xff;
   1298   // request index
   1299   memcpy(&message[kRequestIndexStartIndex], &rm.request_index_,
   1300          sizeof(rm.request_index_));
   1301   // name, step_id, remote_addr, rkey
   1302   if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
   1303       (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
   1304     memcpy(&message[kNameSizeStartIndex], &rm.name_size_,
   1305            sizeof(rm.name_size_));
   1306     memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
   1307     memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
   1308            sizeof(rm.remote_addr_));
   1309     memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
   1310     memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_));
   1311   }
   1312   // is_dead, data_type, tensor_shape, tensor_bytes
   1313   if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
   1314       (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) ||
   1315       (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
   1316     memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_));
   1317 
   1318     memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
   1319            sizeof(rm.data_type_));
   1320     memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_,
   1321            sizeof(rm.tensor_shape_));
   1322     memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
   1323            sizeof(rm.tensor_bytes_));
   1324   }
   1325   // checksum
   1326 #ifdef RDMA_DATA_VALIDATION
   1327   memcpy(&message[kChecksumStartIndex], &rm.checksum_, sizeof(rm.checksum_));
   1328 #endif
   1329   // error status
   1330   if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
   1331     ::grpc::Status gs = ToGrpcStatus(rm.status_);
   1332     ErrorStatusProto gsProto;
   1333     gsProto.set_error_code(gs.error_code());
   1334     gsProto.set_error_message(gs.error_message());
   1335     gsProto.set_error_details(gs.error_details());
   1336     uint32_t gsProtoSize = gsProto.ByteSize();
   1337     if (gsProtoSize + 4 > kErrorStatusMaxSize) {
   1338       LOG(ERROR) << "Error status (" << gsProtoSize + 4 << " bytes) "
   1339                  << "is too big to fit in RDMA message (" << kErrorStatusMaxSize
   1340                  << " bytes). Truncated.";
   1341       gsProtoSize = kErrorStatusMaxSize - 4;
   1342     }
   1343     uint32_t* proto_size = (uint32_t*)&message[kErrorStatusStartIndex];
   1344     *proto_size = gsProtoSize;
   1345     gsProto.SerializeToArray(&message[kErrorStatusStartIndex + 4], gsProtoSize);
   1346     message_size += gsProtoSize + 4;
   1347   }
   1348   return string(message, message_size);
   1349 }
   1350 
   1351 // Parse a RdmaMessage according to the pre-defined format
   1352 // Args:
   1353 //   rm: the message structure where the parsed message will be saved
   1354 //   buffer: the place where the raw message is stored
   1355 // Returns:
   1356 //   None
   1357 void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
   1358   char* message = static_cast<char*>(buffer);
   1359   // type
   1360   rm.type_ = static_cast<RdmaMessageType>(message[kTypeStartIndex]);
   1361   // request index
   1362   memcpy(&rm.request_index_, &message[kRequestIndexStartIndex],
   1363          sizeof(rm.request_index_));
   1364   // name, step_id, remote_addr, rkey
   1365   if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
   1366       (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
   1367     memcpy(&rm.name_size_, &message[kNameSizeStartIndex],
   1368            sizeof(rm.name_size_));
   1369     rm.name_ = string(&message[kNameStartIndex], rm.name_size_);
   1370     memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex],
   1371            sizeof(rm.remote_addr_));
   1372     memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_));
   1373     memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_));
   1374   }
   1375   // is_dead, data_type, tensor_shape, tensor_bytes
   1376   if ((rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) ||
   1377       (rm.type_ == RDMA_MESSAGE_META_DATA_UPDATE) ||
   1378       (rm.type_ == RDMA_MESSAGE_TENSOR_RE_REQUEST)) {
   1379     memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_));
   1380     memcpy(&rm.data_type_, &message[kDataTypeStartIndex],
   1381            sizeof(rm.data_type_));
   1382     memcpy(&rm.tensor_shape_, &message[kTensorShapeStartIndex],
   1383            sizeof(rm.tensor_shape_));
   1384     memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex],
   1385            sizeof(rm.tensor_bytes_));
   1386   }
   1387   // checksum
   1388 #ifdef RDMA_DATA_VALIDATION
   1389   memcpy(&rm.checksum_, &message[kChecksumStartIndex], sizeof(rm.checksum_));
   1390 #endif
   1391   // error status
   1392   if (rm.type_ == RDMA_MESSAGE_ERROR_STATUS) {
   1393     ErrorStatusProto gsProto;
   1394     uint32_t gsProtoSize = *(uint32_t*)&message[kErrorStatusStartIndex];
   1395     CHECK(ParseProtoUnlimited(&gsProto, &message[kErrorStatusStartIndex + 4],
   1396                               gsProtoSize))
   1397         << "Failed to parse error status proto from message. Aborting.";
   1398     ::grpc::Status gs((::grpc::StatusCode)gsProto.error_code(),
   1399                       gsProto.error_message(), gsProto.error_details());
   1400     rm.status_ = FromGrpcStatus(gs);
   1401   }
   1402 }
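
// Illustrative round-trip (field values are arbitrary; not part of the
// transport path): a message serialized with CreateMessage() can be
// reconstructed with ParseMessage().
//
//   RdmaMessage request;
//   request.type_ = RDMA_MESSAGE_TENSOR_REQUEST;
//   request.request_index_ = 1;
//   request.step_id_ = 7;
//   request.name_ = "example/tensor";
//   request.name_size_ = request.name_.size();
//   request.remote_addr_ = 0;
//   request.rkey_ = 0;
//   request.is_dead_ = false;
//   request.data_type_ = DT_INVALID;
//   request.tensor_bytes_ = 0;
//   string wire = RdmaMessage::CreateMessage(request);
//
//   RdmaMessage parsed;
//   RdmaMessage::ParseMessage(parsed, const_cast<char*>(wire.data()));
//   // parsed now carries the same type, request index, name, step id and
//   // meta-data fields that were serialized above.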
   1403 
   1404 //*****************************************************************************
   1405 // RdmaMemoryMgr
   1406 //*****************************************************************************
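// RdmaMemoryMgr is a process-wide singleton that owns the protection domain
// (pd_), the list of ibv memory regions registered with the device, and the
// per-tensor meta-data cache shared by tensor requests and responses.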
   1407 
   1408 ibv_mr* RdmaMemoryMgr::FindMemoryRegion(void* addr, size_t length) {
   1409   mutex_lock l(mrs_mu_);
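  // mrs_ is kept ordered with the same Comparator used at insertion time (see
  // InsertMemoryRegion below), so upper_bound yields the candidate region for
  // addr; if that candidate starts above addr, the address was never
  // registered here.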
   1410   auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
   1411   if (iter == std::end(mrs_) || iter->get()->addr > addr) {
   1412     return nullptr;
   1413   } else {
   1414     return iter->get();
   1415   }
   1416 }
   1417 
   1418 void RdmaMemoryMgr::InsertMemoryRegion(void* addr, size_t length,
   1419                                        const std::string& allocator_name) {
   1420   if (length == 0) return;
   1421   ibv_mr* mr = ibv_reg_mr(pd_, addr, length,
   1422                           IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
   1423   if (mr != nullptr) {
   1424     RDMA_LOG(1) << "Insert memory region 0x" << std::hex << mr->rkey << ". ["
   1425                 << addr << "-" << (void*)((uint64_t)addr + length - 1) << "]"
   1426                 << " SIZE: 0x" << length << " (" << allocator_name << ").";
   1427     mutex_lock l(mrs_mu_);
   1428     auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
   1429     mrs_.insert(iter, {mr, &MRDeleter});
   1430   } else {
   1431     LOG(WARNING) << "Cannot register memory region";
   1432   }
   1433 }
   1434 
   1435 void RdmaMemoryMgr::EvictMemoryRegion(void* addr, size_t length) {
   1436   if (length == 0) return;
   1437   mutex_lock l(mrs_mu_);
   1438   auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
   1439   if (iter != std::end(mrs_) && iter->get()->addr == addr) {
   1440     RDMA_LOG(1) << "Evict memory region 0x" << std::hex << iter->get()->rkey;
   1441     mrs_.erase(iter);
   1443   } else {
   1444     LOG(WARNING) << "Failed to de-register memory region";
   1445   }
   1446 }
   1447 
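// Per-tensor meta-data cache. A request that already knows a tensor's dtype,
// shape, proto size and is_dead flag can pre-allocate its destination buffer
// and skip the meta-data exchange on subsequent steps.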
   1448 const TensorMetaData* RdmaMemoryMgr::GetTensorMetaData(
   1449     const std::string& tensor_name) {
   1450   mutex_lock l(tensor_meta_data_mu_);
   1451   auto it = tensors_meta_data_.find(tensor_name);
   1452   if (it == tensors_meta_data_.end()) {
   1453     return nullptr;
   1454   }
   1455   return &it->second;
   1456 }
   1457 
   1458 const TensorMetaData* RdmaMemoryMgr::SetTensorMetaData(
   1459     const std::string& tensor_name, DataType dtype, const TensorShape& shape,
   1460     bool is_dead, size_t proto_size) {
   1461   mutex_lock l(tensor_meta_data_mu_);
   1462   TensorMetaData& meta_data = tensors_meta_data_[tensor_name];
   1463   meta_data.data_type_ = dtype;
   1464   meta_data.tensor_shape_ = shape;
   1465   meta_data.proto_size_ = proto_size;
   1466   meta_data.is_dead_ = is_dead;
   1467   return &meta_data;
   1468 }
   1469 
   1470 //*****************************************************************************
   1471 // RdmaTensorRequest
   1472 //*****************************************************************************
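// Lifecycle of a tensor request:
//   1. Start() sends RDMA_MESSAGE_TENSOR_REQUEST. If the tensor's meta-data is
//      already cached, the destination buffer is allocated first so the
//      responder can RDMA-write straight into it.
//   2. If the responder reports different meta-data, RecvTensorMetaData()
//      updates the cache, reallocates and sends RDMA_MESSAGE_TENSOR_RE_REQUEST.
//   3. When the content arrives, RecvTensorContent() finalizes the tensor
//      (GPU copy from the proxy buffer, or TensorProto decode) and Done()
//      invokes the rendezvous callback.
//   4. RecvErrorStatus() aborts the request with the responder's status.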
   1473 
   1474 RdmaTensorRequest::RdmaTensorRequest(
   1475     uint32_t index, const string& key, int64 step_id, RdmaChannel* channel,
   1476     Device* dst_dev, const Rendezvous::Args recv_args,
   1477     const RdmaTensorRequest::RecvDoneCallback& done)
   1478     : index_(index),
   1479       key_(key),
   1480       step_id_(step_id),
   1481       channel_(channel),
   1482       dst_dev_(dst_dev),
   1483       recv_args_(recv_args),
   1484       meta_data_(RdmaMemoryMgr::Singleton().GetTensorMetaData(key)),
   1485       result_tensor_(nullptr),
   1486       proxy_tensor_(nullptr),
   1487       rdma_addr_(nullptr),
   1488       mr_(nullptr),
   1489       done_(done) {}
   1490 
   1491 RdmaTensorRequest::~RdmaTensorRequest() { DeallocateTensors(); }
   1492 
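// Complete the request: move the received tensor out of the heap allocation,
// release all buffers, remove the request from the channel's pending map, and
// invoke the rendezvous callback with the final status.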
   1493 void RdmaTensorRequest::Done(const Status& s) {
   1494   Tensor val = std::move(*result_tensor_);
   1495 
   1496 #ifdef RDMA_DATA_VALIDATION
   1497   // Validate checksum
   1498   // Unfortunately we can't always do a Checksum directly on the result tensor.
   1499   // If the result tensor is on GPU, then we need to copy it back to CPU. If
   1500   // we happen to be in the midst of a proxy callback, then the copying will
   1501   // get stuck.
   1502   uint64_t checksum = (proxy_tensor_ != nullptr)
   1503                           ? Checksum(nullptr, nullptr, *proxy_tensor_)
   1504                           : Checksum(dst_dev_, recv_args_.device_context, val);
   1505   ValidateChecksum(checksum_, checksum, val, index_, key_, "RDMA");
   1506 #endif
   1507 
   1508   Rendezvous::Args recv_args = std::move(recv_args_);
   1509   bool is_dead = (meta_data_ == nullptr) ? false : meta_data_->is_dead_;
   1510   RecvDoneCallback done = done_;
   1511   DeallocateTensors();
   1512   channel_->RemoveTensorRequest(index_);
   1513   done(s, Rendezvous::Args(), recv_args, val, is_dead);
   1514 }
   1515 
   1516 void RdmaTensorRequest::DeallocateTensors() {
   1517   if (result_tensor_ != nullptr) {
   1518     delete result_tensor_;
   1519     result_tensor_ = nullptr;
   1520   }
   1521   if (proxy_tensor_ != nullptr) {
   1522     delete proxy_tensor_;
   1523     proxy_tensor_ = nullptr;
   1524   }
   1525 }
   1526 
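// Allocate the destination buffers for the incoming tensor:
//  * For memcpy-able dtypes, the responder RDMA-writes directly into the
//    result tensor's memory, provided it lies in a registered region. On a
//    GPU device whose memory is not registered, a CUDA-host proxy tensor is
//    used instead and copied to the GPU on arrival.
//  * For other dtypes, the content arrives as a serialized TensorProto, so a
//    plain host buffer is malloc'ed and registered ad hoc.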
   1527 bool RdmaTensorRequest::AllocateTensors() {
   1528   result_tensor_ =
   1529       new Tensor(dst_dev_->GetAllocator(recv_args_.alloc_attrs),
   1530                  meta_data_->data_type_, meta_data_->tensor_shape_);
   1531 
   1532   size_t tensor_size = result_tensor_->TotalBytes();
   1533   bool can_memcpy = DataTypeCanUseMemcpy(result_tensor_->dtype());
   1534   if (can_memcpy) {
   1535     if (tensor_size == 0) {
   1536       return true;
   1537     }
   1538     rdma_addr_ = DMAHelper::base(result_tensor_);
   1539     mr_ = RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size);
   1540 #if GOOGLE_CUDA
   1541     if (mr_ == nullptr) {
   1542       // Can't RDMA directly to result. Use a proxy.
   1543       proxy_tensor_ =
   1544           new Tensor(ProcessState::singleton()->GetCUDAHostAllocator(0),
   1545                      result_tensor_->dtype(), result_tensor_->shape());
   1546       rdma_addr_ = DMAHelper::base(proxy_tensor_);
   1547       mr_ =
   1548           RdmaMemoryMgr::Singleton().FindMemoryRegion(rdma_addr_, tensor_size);
   1549     }
   1550 #endif
   1551   } else {
   1552     uint32_t proto_size = meta_data_->proto_size_;
   1553     rdma_addr_ = malloc(proto_size);
   1554     mr_ = ibv_reg_mr(RdmaMemoryMgr::Singleton().pd_, rdma_addr_, proto_size,
   1555                      IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
   1556   }
   1557   CHECK(mr_ != nullptr) << "No memory region found for address " << rdma_addr_
   1558                         << ": " << key_;
   1559   return true;
   1560 }
   1561 
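// Allocate the tensors and invoke 'done' once the allocation is usable. For a
// GPU destination written directly (no proxy), completion is deferred to the
// device's stream via StreamGPUOp so the request message is only sent after
// the GPU memory is ready; otherwise 'done' runs immediately.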
   1562 void RdmaTensorRequest::AllocateTensorsAsync(StatusCallback done) {
   1563   AllocateTensors();
   1564   bool on_host = recv_args_.alloc_attrs.on_host();
   1565   if (dst_dev_->tensorflow_gpu_device_info() && !on_host &&
   1566       (proxy_tensor_ == nullptr)) {
   1567 #if GOOGLE_CUDA
   1568     // We need to sync the memory allocation on the GPU:
   1569     StreamGPUOp(dst_dev_, recv_args_.device_context, done);
   1570 #endif
   1571   } else {
   1572     done(Status::OK());
   1573   }
   1574 }
   1575 
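// Compose and enqueue the request message. It carries the local RDMA
// destination address and rkey so the responder can write the tensor content
// remotely, plus the locally cached meta-data (or DT_INVALID when nothing is
// cached yet) so the responder can detect a mismatch and reply with
// RDMA_MESSAGE_META_DATA_UPDATE before sending content.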
   1576 void RdmaTensorRequest::Send(RdmaMessageType message_type) {
   1577   RdmaMessageBuffer* rb = channel_->tx_message_buffer_;
   1578   RdmaMessage rm;
   1579   rm.type_ = message_type;
   1580   rm.request_index_ = index_;
   1581   rm.name_size_ = key_.size();
   1582   rm.name_ = key_;
   1583   rm.step_id_ = step_id_;
   1584   rm.remote_addr_ = (uint64_t)rdma_addr_;
   1585   if (meta_data_ != nullptr) {
   1586     rm.data_type_ = meta_data_->data_type_;
   1587     rm.tensor_shape_ = meta_data_->tensor_shape_;
   1588     rm.is_dead_ = meta_data_->is_dead_;
   1589     rm.tensor_bytes_ = meta_data_->proto_size_;
   1590   } else {
   1591     rm.data_type_ = DT_INVALID;
   1592   }
   1593   rm.rkey_ = (mr_ == nullptr) ? 0 : mr_->rkey;
   1594 
   1595   RDMA_LOG(1) << "Step 0x" << std::hex << rm.step_id_ << std::dec
   1596               << ": Sending  " << MessageTypeToString(message_type) << " #"
   1597               << index_ << ": " << rm.name_ << " on " << rdma_addr_
   1598               << " (rkey: 0x" << std::hex << rm.rkey_ << ")";
   1599 
   1600   string message = RdmaMessage::CreateMessage(rm);
   1601   rb->EnqueueItem(message);
   1602   rb->SendNextItem();
   1603 }
   1604 
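// The responder reported meta-data that differs from what was requested (or
// none was cached locally). Store the new meta-data, drop any buffers
// allocated under the old assumptions, reallocate, and re-request the tensor.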
   1605 void RdmaTensorRequest::RecvTensorMetaData(DataType dtype, TensorShape shape,
   1606                                            bool is_dead, size_t proto_size) {
   1607   meta_data_ = RdmaMemoryMgr::Singleton().SetTensorMetaData(
   1608       key_, dtype, shape, is_dead, proto_size);
   1609 
   1610   DeallocateTensors();
   1611   AllocateTensorsAsync(
   1612       [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_RE_REQUEST); });
   1613 }
   1614 
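// The tensor content has been RDMA-written into the local buffer. Either the
// result tensor is already complete (memcpy path), needs a CPU-to-GPU copy
// from the proxy tensor, or must be decoded from the received TensorProto.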
   1615 void RdmaTensorRequest::RecvTensorContent() {
   1616   bool can_memcpy = DataTypeCanUseMemcpy(meta_data_->data_type_);
   1617   size_t message_size =
   1618       can_memcpy ? result_tensor_->TotalBytes() : meta_data_->proto_size_;
   1619   RDMA_LOG(1) << "Step 0x" << std::hex << step_id_ << std::dec
   1620               << ": Received tensor content #" << index_ << ": " << key_
   1621               << " (Size: 0x" << std::hex << message_size << ")";
   1622 
   1623   Tensor val;
   1624 
   1625 #if GOOGLE_CUDA
   1626   if (proxy_tensor_ != nullptr) {
   1627     CountCopies(key_, (void*)DMAHelper::base(proxy_tensor_),
   1628                 (void*)DMAHelper::base(result_tensor_),
   1629                 result_tensor_->TotalBytes(), false);
   1630     GPUUtil::CopyCPUTensorToGPU(proxy_tensor_, recv_args_.device_context,
   1631                                 dst_dev_, result_tensor_,
   1632                                 [this](const Status& s) {
   1633                                   CHECK(s.ok()) << "Failed to copy tensor from CPU proxy to GPU";
   1634                                   Done(s);
   1635                                 });
   1636     return;
   1637   }
   1638 #endif
   1639 
   1640   if (can_memcpy) {
   1641     Done(Status::OK());
   1642   } else {
   1643     RDMA_LOG(2) << "Decoding proto: " << key_
   1644                 << " (Size: " << meta_data_->proto_size_ << ")";
   1645     TensorProto proto;
   1646     CHECK(ParseProtoUnlimited(&proto, rdma_addr_, meta_data_->proto_size_))
   1647         << "Failed to parse TensorProto from the received buffer";
   1648     ibv_dereg_mr(mr_);
   1649     free(rdma_addr_);
   1650     Status s = dst_dev_->MakeTensorFromProto(proto, recv_args_.alloc_attrs,
   1651                                              result_tensor_);
   1652     Done(s);
   1653   }
   1654 }
   1655 
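// The responder could not produce the tensor. Complete the request with the
// received error status; an empty result tensor is created so that Done()
// can still hand a value to the rendezvous callback.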
   1656 void RdmaTensorRequest::RecvErrorStatus(const Status& status) {
   1657   if (result_tensor_ == nullptr) {
   1658     result_tensor_ = new Tensor();
   1659   }
   1660   LOG(ERROR) << "Received RDMA_MESSAGE_ERROR_STATUS: " << status.ToString();
   1661   Done(status);
   1662 }
   1663 
   1664 void RdmaTensorRequest::Start() {
   1665   meta_data_ = RdmaMemoryMgr::Singleton().GetTensorMetaData(key_);
   1666   if (meta_data_ != nullptr) {
   1667     AllocateTensorsAsync(
   1668         [this](const Status& s) { Send(RDMA_MESSAGE_TENSOR_REQUEST); });
   1669   } else {
   1670     Send(RDMA_MESSAGE_TENSOR_REQUEST);
   1671   }
   1672 }
   1673 
   1674 }  // end namespace tensorflow
   1675 
   1676 #endif
   1677