Home | History | Annotate | Download | only in eager
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
     17 
     18 #include "absl/memory/memory.h"
     19 #include "tensorflow/c/c_api_internal.h"
     20 #include "tensorflow/c/tf_status_helper.h"
     21 #include "tensorflow/core/common_runtime/device_mgr.h"
     22 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
     23 #include "tensorflow/core/common_runtime/eager/execute.h"
     24 #include "tensorflow/core/common_runtime/process_util.h"
     25 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
     26 #include "tensorflow/core/distributed_runtime/server_lib.h"
     27 #include "tensorflow/core/distributed_runtime/session_mgr.h"
     28 #include "tensorflow/core/distributed_runtime/worker_cache.h"
     29 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
     30 #include "tensorflow/core/distributed_runtime/worker_env.h"
     31 #include "tensorflow/core/framework/rendezvous.h"
     32 #include "tensorflow/core/lib/core/error_codes.pb.h"
     33 #include "tensorflow/core/lib/core/errors.h"
     34 #include "tensorflow/core/lib/gtl/cleanup.h"
     35 #include "tensorflow/core/lib/random/random.h"
     36 #include "tensorflow/core/lib/strings/strcat.h"
     37 #include "tensorflow/core/lib/strings/stringprintf.h"
     38 #include "tensorflow/core/platform/cpu_info.h"
     39 #include "tensorflow/core/platform/env.h"
     40 #include "tensorflow/core/platform/host_info.h"
     41 
     42 namespace tensorflow {
     43 namespace eager {
     44 
     45 namespace {
     46 Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
     47                      const google::protobuf::Map<string, tensorflow::AttrValue>& attrs,
     48                      int* num_retvals) {
     49   const tensorflow::OpRegistrationData* op_reg_data = nullptr;
     50   auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data);
     51   if (errors::IsNotFound(status)) {
     52     status = context->FindFunctionOpData(op_name, &op_reg_data);
     53   }
     54   TF_RETURN_IF_ERROR(status);
     55 
     56   const tensorflow::OpDef& op_def = op_reg_data->op_def;
     57 
     58   for (const auto& output_arg : op_def.output_arg()) {
     59     if (!output_arg.number_attr().empty()) {
     60       auto iter = attrs.find(output_arg.number_attr());
     61       if (iter == attrs.end()) {
     62         return errors::InvalidArgument("Unable to find number_attr ",
     63                                        output_arg.number_attr(),
     64                                        " for Op: ", op_name);
     65       }
     66       *num_retvals += iter->second.i();
     67     } else if (!output_arg.type_list_attr().empty()) {
     68       auto iter = attrs.find(output_arg.type_list_attr());
     69       if (iter == attrs.end()) {
     70         return errors::InvalidArgument("Unable to find type_list_attr ",
     71                                        output_arg.type_list_attr(),
     72                                        " for Op: ", op_name);
     73       }
     74       *num_retvals += iter->second.list().type_size();
     75     } else {
     76       *num_retvals += 1;
     77     }
     78   }
     79 
     80   return Status::OK();
     81 }
     82 }  // namespace
     83 
     84 Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
     85                                        CreateContextResponse* response) {
     86   // make sure env_ , env_->rendezvous_mgr available
     87   if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
     88     return tensorflow::errors::Internal(
     89         "invalid eager env_ or env_->rendezvous_mgr.");
     90   }
     91   std::vector<std::unique_ptr<tensorflow::Device>> devices;
     92 
     93   TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
     94       // TODO(nareshmodi): Correctly set the SessionOptions.
     95       SessionOptions(),
     96       strings::Printf("/job:%s/replica:0/task:%d",
     97                       request->server_def().job_name().data(),
     98                       request->server_def().task_index()),
     99       &devices));
    100   response->mutable_device_attributes()->Reserve(devices.size());
    101   for (const auto& d : devices) {
    102     *response->add_device_attributes() = d->attributes();
    103   }
    104 
    105   std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
    106       absl::make_unique<DeviceMgr>(std::move(devices));
    107 
    108   auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id());
    109   auto session_name = strings::StrCat("eager_", request->rendezvous_id());
    110   TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
    111       session_name, request->server_def(), true));
    112 
    113   std::shared_ptr<WorkerSession> worker_session;
    114   TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
    115       session_name, &worker_session));
    116 
    117   // Initialize remote tensor communication based on worker session.
    118   TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
    119 
    120   std::unique_ptr<tensorflow::EagerContext> ctx(new tensorflow::EagerContext(
    121       SessionOptions(),
    122       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
    123       request->async(), std::move(device_mgr), r));
    124 
    125   uint64 context_id;
    126   {
    127     mutex_lock l(contexts_mu_);
    128     do {
    129       context_id = random::New64();
    130     } while (contexts_.find(context_id) != contexts_.end());
    131     contexts_.emplace(
    132         context_id,
    133         new ServerContext(std::move(ctx), request->keep_alive_secs(), env_));
    134   }
    135   response->set_context_id(context_id);
    136 
    137   return Status::OK();
    138 }
    139 
    140 Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
    141   const tensorflow::Tensor* t = nullptr;
    142 
    143   // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
    144   TF_RETURN_IF_ERROR(handle->Tensor(&t));
    145 
    146   t->shape().AsProto(proto);
    147 
    148   return Status::OK();
    149 }
    150 
    151 Status EagerServiceImpl::ExecuteOp(const Operation& operation,
    152                                    ServerContext* server_context,
    153                                    QueueResponse* queue_response) {
    154   std::unique_ptr<tensorflow::EagerOperation> op;
    155   const char* name = operation.name().c_str();  // Shorthand
    156   const tensorflow::AttrTypeMap* types;
    157   bool is_function = false;
    158   TF_RETURN_IF_ERROR(tensorflow::AttrTypeMapForOp(name, &types, &is_function));
    159   if (is_function && !server_context->Context()->FindFunctionByName(name)) {
    160     return errors::NotFound(
    161         "'", name,
    162         "' is neither a type of a primitive operation nor a name "
    163         "of a function registered in binary running on ",
    164         port::Hostname(),
    165         ". Make sure the operation or function is "
    166         "registered in the binary running in this process.");
    167   }
    168   op.reset(new tensorflow::EagerOperation(server_context->Context(), name,
    169                                           is_function, types));
    170 
    171   TF_RETURN_IF_ERROR(op->SetDevice(operation.device().c_str()));
    172 
    173   for (const auto& remote_handle : operation.inputs()) {
    174     tensorflow::TensorHandle* handle;
    175     TF_RETURN_IF_ERROR(server_context->GetTensorHandle(
    176         RemoteTensorHandleInternal(remote_handle), &handle));
    177 
    178     op->AddInput(handle);
    179   }
    180 
    181   for (const auto& attr : operation.attrs()) {
    182     op->MutableAttrs()->Set(attr.first, attr.second);
    183   }
    184 
    185   int num_retvals = 0;
    186   // TODO(nareshmodi): Consider caching this.
    187   TF_RETURN_IF_ERROR(GetNumRetvals(server_context->Context(), operation.name(),
    188                                    operation.attrs(), &num_retvals));
    189 
    190   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals;
    191   TF_RETURN_IF_ERROR(EagerExecute(op.get(), &retvals, &num_retvals));
    192 
    193   server_context->AddOperationOutputs(retvals, operation.id());
    194 
    195   for (auto* handle : retvals) {
    196     TF_RETURN_IF_ERROR(TensorHandleShape(handle, queue_response->add_shape()));
    197   }
    198 
    199   return Status::OK();
    200 }
    201 
    202 Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
    203                                  EnqueueResponse* response) {
    204   ServerContext* context = nullptr;
    205   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
    206   core::ScopedUnref context_unref(context);
    207 
    208   for (const auto& item : request->queue()) {
    209     auto* queue_response = response->add_queue_response();
    210     if (item.has_operation()) {
    211       TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context, queue_response));
    212     } else {
    213       TF_RETURN_IF_ERROR(context->DeleteTensorHandle(
    214           RemoteTensorHandleInternal(item.handle_to_decref())));
    215     }
    216   }
    217 
    218   return Status::OK();
    219 }
    220 
    221 Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request,
    222                                        WaitQueueDoneResponse* response) {
    223   ServerContext* context = nullptr;
    224   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
    225   core::ScopedUnref context_unref(context);
    226 
    227   if (request->op_id_size() > 0) {
    228     return errors::Unimplemented(
    229         "EagerServiceImpl::WaitQueueDone is not "
    230         "implemented for particular op IDs.");
    231   }
    232   return context->Context()->AsyncWait();
    233 }
    234 
    235 Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request,
    236                                    KeepAliveResponse* response) {
    237   ServerContext* context = nullptr;
    238   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
    239   core::ScopedUnref context_unref(context);
    240 
    241   return Status::OK();
    242 }
    243 
    244 Status EagerServiceImpl::CloseContext(const CloseContextRequest* request,
    245                                       CloseContextResponse* response) {
    246   ServerContext* context = nullptr;
    247   if (!GetServerContext(request->context_id(), &context).ok()) {
    248     // Swallow the error here.
    249     return Status::OK();
    250   }
    251 
    252   core::ScopedUnref context_unref(context);
    253 
    254   mutex_lock l(contexts_mu_);
    255   contexts_.erase(request->context_id());
    256 
    257   // GetServerContext returns a newly Reffed copy of ServerContext, which is
    258   // unreffed by context_unref. Additionally, we need to unref it one time since
    259   // we are releasing it from the map.
    260   context->Unref();
    261 
    262   return Status::OK();
    263 }
    264 
    265 Status EagerServiceImpl::RegisterFunction(
    266     const RegisterFunctionRequest* request,
    267     RegisterFunctionResponse* response) {
    268   ServerContext* context = nullptr;
    269   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
    270   core::ScopedUnref context_unref(context);
    271 
    272   return context->Context()->AddFunctionDef(request->function_def());
    273 }
    274 
    275 Status EagerServiceImpl::SendTensor(const SendTensorRequest* request,
    276                                     SendTensorResponse* response) {
    277   ServerContext* context = nullptr;
    278   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
    279   core::ScopedUnref context_unref(context);
    280 
    281   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors;
    282   for (const auto& tensor_proto : request->tensors()) {
    283     Tensor tensor;
    284     if (!tensor.FromProto(tensor_proto)) {
    285       return errors::InvalidArgument("Unable to parse tensor proto");
    286     }
    287 
    288     TensorHandle* tensor_handle =
    289         new TensorHandle(tensor, nullptr, nullptr, nullptr);
    290 
    291     TensorHandle* copied_handle = nullptr;
    292     TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, context->Context(),
    293                                          request->device_name().c_str(),
    294                                          &copied_handle));
    295     tensors.push_back(copied_handle);
    296     tensor_handle->Unref();
    297   }
    298 
    299   context->AddOperationOutputs(tensors, request->op_id());
    300 
    301   return Status::OK();
    302 }
    303 
    304 tensorflow::Status EagerServiceImpl::GetServerContext(
    305     uint64 context_id, ServerContext** server_context) {
    306   mutex_lock l(contexts_mu_);
    307   auto iter = contexts_.find(context_id);
    308   if (iter == contexts_.end()) {
    309     *server_context = nullptr;
    310     return errors::InvalidArgument(strings::Printf(
    311         "Unable to find a context_id matching the specified one "
    312         "(%lld). Perhaps the worker was restarted, or the context was GC'd?",
    313         context_id));
    314   }
    315 
    316   *server_context = iter->second;
    317   (*server_context)->Ref();
    318 
    319   (*server_context)->RecordAccess();
    320 
    321   return Status::OK();
    322 }
    323 
    324 }  // namespace eager
    325 }  // namespace tensorflow
    326