/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/sendrecv_ops.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

static string GetRendezvousKeyPrefix(const string& send_device,
                                     const string& recv_device,
                                     const uint64 send_device_incarnation,
                                     const string& tensor_name) {
  return strings::StrCat(send_device, ";",
                         strings::FpToString(send_device_incarnation), ";",
                         recv_device, ";", tensor_name);
}

static void GetRendezvousKey(const string& key_prefix,
                             const FrameAndIter& frame_iter, string* key) {
  key->clear();
  strings::StrAppend(key, key_prefix, ";", frame_iter.frame_id, ":",
                     frame_iter.iter_id);
}
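
// Illustration only (the device names, incarnation, and tensor name depend on
// the graph): a fully formed key built by the two helpers above might look like
//   /job:worker/replica:0/task:0/cpu:0;000000000000002a;/job:worker/replica:0/task:0/gpu:0;edge_5_foo;0:0
// i.e. send_device;hex(incarnation);recv_device;tensor_name;frame_id:iter_id,
// which is the format Rendezvous::ParseKey expects.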

static FrameAndIter GetFrameAndIter(OpKernelContext* ctx,
                                    bool hostmem_sendrecv) {
  if (hostmem_sendrecv && ctx->call_frame() != nullptr) {
    // Host memory send/recv pairs are added by
    // common_runtime/memory_types.cc.  When the pair of nodes is
    // added inside a function, we need to use the function call frame
    // to formulate a unique rendezvous key.
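    // Using the call frame's address as the frame id gives each live
    // invocation of the function its own key, so concurrent calls to the
    // same function do not collide at the rendezvous.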
    return FrameAndIter(reinterpret_cast<uint64>(ctx->call_frame()), 0);
  } else {
    return ctx->frame_iter();
  }
}

SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  string send_device;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device));
  string recv_device;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device));
  uint64 send_device_incarnation;
  OP_REQUIRES_OK(
      ctx, ctx->GetAttr("send_device_incarnation",
                        reinterpret_cast<int64*>(&send_device_incarnation)));
  string tensor_name;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
  key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
                                       send_device_incarnation, tensor_name);
  // The vast majority of Send nodes are outside any loop context, so
  // proactively cache the rendezvous key for the top-level.
  GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_);
  OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_));
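  // "_hostmem_sendrecv" is optional; if the attr is absent, treat this as a
  // regular (non-hostmem) Send.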
  if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) {
    hostmem_sendrecv_ = false;
  }
}

void SendOp::Compute(OpKernelContext* ctx) {
  OP_REQUIRES(
      ctx, ctx->rendezvous() != nullptr,
      errors::Internal("Op kernel context needs to provide a rendezvous."));

  // The device context may be passed across the Send/Recv
  // boundary, so that the device context used to produce the Tensor
  // is used when performing the copy on the recv side (which may be
  // a different device).
  Rendezvous::Args args;
  args.device_context = ctx->op_device_context();
  args.alloc_attrs = ctx->input_alloc_attr(0);

  FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_);
  if (frame_iter == FrameAndIter(0, 0)) {
    // Use the cached rendezvous key.
    VLOG(2) << "Send " << parsed_key_.buf_;
    ctx->SetStatus(ctx->rendezvous()->Send(parsed_key_, args, ctx->input(0),
                                           ctx->is_input_dead()));
    return;
  } else {
    Rendezvous::ParsedKey in_loop_parsed;
    GetRendezvousKey(key_prefix_, frame_iter, &in_loop_parsed.buf_);
    VLOG(2) << "Send " << in_loop_parsed.buf_;
    OP_REQUIRES_OK(ctx,
                   Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed));

    ctx->SetStatus(ctx->rendezvous()->Send(in_loop_parsed, args, ctx->input(0),
                                           ctx->is_input_dead()));
    return;
  }
}

REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_GPU), SendOp);

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_SYCL), SendOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostSend").Device(DEVICE_SYCL).HostMemory("tensor"), SendOp);
#endif  // TENSORFLOW_USE_SYCL

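// _HostSend sends a tensor that already lives in host memory; for kernels
// registered on a device, the HostMemory("tensor") annotation marks the
// input as host-resident.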
REGISTER_KERNEL_BUILDER(Name("_HostSend").Device(DEVICE_CPU), SendOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostSend").Device(DEVICE_GPU).HostMemory("tensor"), SendOp);

RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
  string send_device;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("send_device", &send_device));
  string recv_device;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("recv_device", &recv_device));
  uint64 send_device_incarnation;
  OP_REQUIRES_OK(
      ctx, ctx->GetAttr("send_device_incarnation",
                        reinterpret_cast<int64*>(&send_device_incarnation)));
  string tensor_name;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
  key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
                                       send_device_incarnation, tensor_name);
  // The vast majority of Recv nodes are outside any loop context, so
  // proactively cache the rendezvous key for the top-level.
  GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_);
  OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_));
  if (!ctx->GetAttr("_hostmem_sendrecv", &hostmem_sendrecv_).ok()) {
    hostmem_sendrecv_ = false;
  }
}

namespace {
Rendezvous::DoneCallback make_recv_callback(OpKernelContext* ctx,
                                            AsyncOpKernel::DoneCallback done) {
  using namespace std::placeholders;
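  // Note: std::bind is used below (rather than capturing `done` in the
  // lambda) so that `done` can be moved into the functor; C++11 lambdas do
  // not support move-capture.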
  return std::bind(
      [ctx](AsyncOpKernel::DoneCallback done,
            // Begin unbound arguments.
            const Status& s, const Rendezvous::Args& send_args,
            const Rendezvous::Args& recv_args, const Tensor& val,
            bool is_dead) {
        ctx->SetStatus(s);
        if (s.ok()) {
          // 'ctx' allocates the output tensor of the expected type.
          // The runtime checks whether the tensor received here is
          // of the same type.
          if (!is_dead) {
            ctx->set_output(0, val);
          }
          *ctx->is_output_dead() = is_dead;
        }
        done();
      },
      std::move(done), _1, _2, _3, _4, _5);
}
}  // namespace

void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  OP_REQUIRES(
      ctx, ctx->rendezvous() != nullptr,
      errors::Internal("Op kernel context needs to provide a rendezvous."));

  Rendezvous::Args args;
  args.device_context = ctx->op_device_context();
  args.alloc_attrs = ctx->output_alloc_attr(0);

  FrameAndIter frame_iter = GetFrameAndIter(ctx, hostmem_sendrecv_);
  if (frame_iter == FrameAndIter(0, 0)) {
    VLOG(2) << "Recv " << parsed_key_.buf_;
    ctx->rendezvous()->RecvAsync(parsed_key_, args,
                                 make_recv_callback(ctx, std::move(done)));
  } else {
    Rendezvous::ParsedKey in_loop_parsed;
    GetRendezvousKey(key_prefix_, frame_iter, &in_loop_parsed.buf_);
    VLOG(2) << "Recv " << in_loop_parsed.buf_;
    OP_REQUIRES_OK_ASYNC(
        ctx, Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed), done);
    ctx->rendezvous()->RecvAsync(in_loop_parsed, args,
                                 make_recv_callback(ctx, std::move(done)));
  }
}

REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_GPU), RecvOp);

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_SYCL), RecvOp);
#endif  // TENSORFLOW_USE_SYCL

REGISTER_KERNEL_BUILDER(Name("_HostRecv").Device(DEVICE_CPU), RecvOp);
REGISTER_KERNEL_BUILDER(
    Name("_HostRecv").Device(DEVICE_GPU).HostMemory("tensor"), RecvOp);

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("_HostRecv").Device(DEVICE_SYCL).HostMemory("tensor"), RecvOp);
#endif  // TENSORFLOW_USE_SYCL

}  // end namespace tensorflow