1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/sendrecv_ops.h" 17 18 #include "tensorflow/core/framework/op.h" 19 #include "tensorflow/core/framework/op_kernel.h" 20 #include "tensorflow/core/lib/strings/numbers.h" 21 #include "tensorflow/core/lib/strings/strcat.h" 22 #include "tensorflow/core/platform/logging.h" 23 24 namespace tensorflow { 25 26 static string GetRendezvousKeyPrefix(const string& send_device, 27 const string& recv_device, 28 const uint64 send_device_incarnation, 29 const string& tensor_name) { 30 return strings::StrCat(send_device, ";", 31 strings::FpToString(send_device_incarnation), ";", 32 recv_device, ";", tensor_name); 33 } 34 35 static void GetRendezvousKey(const string& key_prefix, 36 const FrameAndIter& frame_iter, string* key) { 37 key->clear(); 38 strings::StrAppend(key, key_prefix, ";", frame_iter.frame_id, ":", 39 frame_iter.iter_id); 40 } 41 42 static FrameAndIter GetFrameAndIter(OpKernelContext* ctx, 43 bool hostmem_sendrecv) { 44 if (hostmem_sendrecv && ctx->call_frame() != nullptr) { 45 // Host memory send/recv pairs are added by 46 // common_runtime/memory_types.cc. When the pair of nodes are 47 // added inside a function, we need to use the function call frame 48 // to formulate the unique rendezvous key. 49 return FrameAndIter(reinterpret_cast<uint64>(ctx->call_frame()), 0); 50 } else { 51 return ctx->frame_iter(); 52 } 53 } 54 55 SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 56 string send_device; 57 OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device)); 58 string recv_device; 59 OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device)); 60 uint64 send_device_incarnation; 61 OP_REQUIRES_OK( 62 ctx, ctx->GetAttr("send_device_incarnation", 63 reinterpret_cast<int64*>(&send_device_incarnation))); 64 string tensor_name; 65 OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); 66 key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device, 67 send_device_incarnation, tensor_name); 68 // The vast majority of Send nodes are outside any loop context, so 69 // proactively cache the rendezvous key for the top-level. 70 GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_); 71 OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_)); 72 if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { 73 hostmem_sendrecv_ = false; 74 } 75 } 76 77 void SendOp::Compute(OpKernelContext* ctx) { 78 OP_REQUIRES( 79 ctx, ctx->rendezvous() != nullptr, 80 errors::Internal("Op kernel context needs to provide a rendezvous.")); 81 82 // The device context may be passed between the Send/Recv 83 // boundary, so that the device context used to produce the Tensor 84 // is used when performing the copy on the recv side (which may be 85 // a different device). 86 Rendezvous::Args args; 87 args.device_context = ctx->op_device_context(); 88 args.alloc_attrs = ctx->input_alloc_attr(0); 89 90 FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_); 91 if (frame_iter == FrameAndIter(0, 0)) { 92 // Use the cached rendezvous key. 93 VLOG(2) << "Send " << parsed_key_.buf_; 94 ctx->SetStatus(ctx->rendezvous()->Send(parsed_key_, args, ctx->input(0), 95 ctx->is_input_dead())); 96 return; 97 } else { 98 Rendezvous::ParsedKey in_loop_parsed; 99 GetRendezvousKey(key_prefix_, frame_iter, &in_loop_parsed.buf_); 100 VLOG(2) << "Send " << in_loop_parsed.buf_; 101 OP_REQUIRES_OK(ctx, 102 Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed)); 103 104 ctx->SetStatus(ctx->rendezvous()->Send(in_loop_parsed, args, ctx->input(0), 105 ctx->is_input_dead())); 106 return; 107 } 108 } 109 110 REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp); 111 REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_GPU), SendOp); 112 113 #ifdef TENSORFLOW_USE_SYCL 114 REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_SYCL), SendOp); 115 REGISTER_KERNEL_BUILDER( 116 Name("_HostSend").Device(DEVICE_SYCL).HostMemory("tensor"), SendOp); 117 #endif // TENSORFLOW_USE_SYCL 118 119 REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp); 120 REGISTER_KERNEL_BUILDER( 121 Name("_HostSend").Device(DEVICE_GPU).HostMemory("tensor"), SendOp); 122 123 RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { 124 string send_device; 125 OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device)); 126 string recv_device; 127 OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device)); 128 uint64 send_device_incarnation; 129 OP_REQUIRES_OK( 130 ctx, ctx->GetAttr("send_device_incarnation", 131 reinterpret_cast<int64*>(&send_device_incarnation))); 132 string tensor_name; 133 OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name)); 134 key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device, 135 send_device_incarnation, tensor_name); 136 // The vast majority of Recv nodes are outside any loop context, so 137 // proactively cache the rendezvous key for the top-level. 138 GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_); 139 OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_)); 140 if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) { 141 hostmem_sendrecv_ = false; 142 } 143 } 144 145 namespace { 146 Rendezvous::DoneCallback make_recv_callback(OpKernelContext* ctx, 147 AsyncOpKernel::DoneCallback done) { 148 using namespace std::placeholders; 149 return std::bind( 150 [ctx](AsyncOpKernel::DoneCallback done, 151 // Begin unbound arguments. 152 const Status& s, const Rendezvous::Args& send_args, 153 const Rendezvous::Args& recv_args, const Tensor& val, 154 bool is_dead) { 155 ctx->SetStatus(s); 156 if (s.ok()) { 157 // 'ctx' allocates the output tensor of the expected type. 158 // The runtime checks whether the tensor received here is 159 // the same type. 160 if (!is_dead) { 161 ctx->set_output(0, val); 162 } 163 *ctx->is_output_dead() = is_dead; 164 } 165 done(); 166 }, 167 std::move(done), _1, _2, _3, _4, _5); 168 } 169 } // namespace 170 171 void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { 172 OP_REQUIRES( 173 ctx, ctx->rendezvous() != nullptr, 174 errors::Internal("Op kernel context needs to provide a rendezvous.")); 175 176 Rendezvous::Args args; 177 args.device_context = ctx->op_device_context(); 178 args.alloc_attrs = ctx->output_alloc_attr(0); 179 180 FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_); 181 if (frame_iter == FrameAndIter(0, 0)) { 182 VLOG(2) << "Recv " << parsed_key_.buf_; 183 ctx->rendezvous()->RecvAsync(parsed_key_, args, 184 make_recv_callback(ctx, std::move(done))); 185 } else { 186 Rendezvous::ParsedKey in_loop_parsed; 187 GetRendezvousKey(key_prefix_, frame_iter, &in_loop_parsed.buf_); 188 VLOG(2) << "Recv " << in_loop_parsed.buf_; 189 OP_REQUIRES_OK_ASYNC( 190 ctx, Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed), done); 191 ctx->rendezvous()->RecvAsync(in_loop_parsed, args, 192 make_recv_callback(ctx, std::move(done))); 193 } 194 } 195 196 REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp); 197 REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_GPU), RecvOp); 198 199 #ifdef TENSORFLOW_USE_SYCL 200 REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_SYCL), RecvOp); 201 #endif // TENSORFLOW_USE_SYCL 202 203 REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_CPU), RecvOp); 204 REGISTER_KERNEL_BUILDER( 205 Name("_HostRecv").Device(DEVICE_GPU).HostMemory("tensor"), RecvOp); 206 207 #ifdef TENSORFLOW_USE_SYCL 208 REGISTER_KERNEL_BUILDER( 209 Name("_HostRecv").Device(DEVICE_SYCL).HostMemory("tensor"), RecvOp); 210 #endif // TENSORFLOW_USE_SYCL 211 212 } // end namespace tensorflow 213