1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h" 17 18 #include "absl/memory/memory.h" 19 #include "tensorflow/c/c_api_internal.h" 20 #include "tensorflow/c/tf_status_helper.h" 21 #include "tensorflow/core/common_runtime/device_mgr.h" 22 #include "tensorflow/core/common_runtime/eager/eager_operation.h" 23 #include "tensorflow/core/common_runtime/eager/execute.h" 24 #include "tensorflow/core/common_runtime/process_util.h" 25 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" 26 #include "tensorflow/core/distributed_runtime/server_lib.h" 27 #include "tensorflow/core/distributed_runtime/session_mgr.h" 28 #include "tensorflow/core/distributed_runtime/worker_cache.h" 29 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h" 30 #include "tensorflow/core/distributed_runtime/worker_env.h" 31 #include "tensorflow/core/framework/rendezvous.h" 32 #include "tensorflow/core/lib/core/error_codes.pb.h" 33 #include "tensorflow/core/lib/core/errors.h" 34 #include "tensorflow/core/lib/gtl/cleanup.h" 35 #include "tensorflow/core/lib/random/random.h" 36 #include "tensorflow/core/lib/strings/strcat.h" 37 #include "tensorflow/core/lib/strings/stringprintf.h" 38 #include "tensorflow/core/platform/cpu_info.h" 39 #include "tensorflow/core/platform/env.h" 40 #include "tensorflow/core/platform/host_info.h" 41 42 namespace tensorflow { 43 namespace eager { 44 45 namespace { 46 Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name, 47 const google::protobuf::Map<string, tensorflow::AttrValue>& attrs, 48 int* num_retvals) { 49 const tensorflow::OpRegistrationData* op_reg_data = nullptr; 50 auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data); 51 if (errors::IsNotFound(status)) { 52 status = context->FindFunctionOpData(op_name, &op_reg_data); 53 } 54 TF_RETURN_IF_ERROR(status); 55 56 const tensorflow::OpDef& op_def = op_reg_data->op_def; 57 58 for (const auto& output_arg : op_def.output_arg()) { 59 if (!output_arg.number_attr().empty()) { 60 auto iter = attrs.find(output_arg.number_attr()); 61 if (iter == attrs.end()) { 62 return errors::InvalidArgument("Unable to find number_attr ", 63 output_arg.number_attr(), 64 " for Op: ", op_name); 65 } 66 *num_retvals += iter->second.i(); 67 } else if (!output_arg.type_list_attr().empty()) { 68 auto iter = attrs.find(output_arg.type_list_attr()); 69 if (iter == attrs.end()) { 70 return errors::InvalidArgument("Unable to find type_list_attr ", 71 output_arg.type_list_attr(), 72 " for Op: ", op_name); 73 } 74 *num_retvals += iter->second.list().type_size(); 75 } else { 76 *num_retvals += 1; 77 } 78 } 79 80 return Status::OK(); 81 } 82 } // namespace 83 84 Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, 85 CreateContextResponse* response) { 86 // make sure env_ , env_->rendezvous_mgr available 87 if (env_ == nullptr || env_->rendezvous_mgr == nullptr) { 88 return tensorflow::errors::Internal( 89 "invalid eager env_ or env_->rendezvous_mgr."); 90 } 91 std::vector<std::unique_ptr<tensorflow::Device>> devices; 92 93 TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices( 94 // TODO(nareshmodi): Correctly set the SessionOptions. 95 SessionOptions(), 96 strings::Printf("/job:%s/replica:0/task:%d", 97 request->server_def().job_name().data(), 98 request->server_def().task_index()), 99 &devices)); 100 response->mutable_device_attributes()->Reserve(devices.size()); 101 for (const auto& d : devices) { 102 *response->add_device_attributes() = d->attributes(); 103 } 104 105 std::unique_ptr<tensorflow::DeviceMgr> device_mgr = 106 absl::make_unique<DeviceMgr>(std::move(devices)); 107 108 auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id()); 109 auto session_name = strings::StrCat("eager_", request->rendezvous_id()); 110 TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession( 111 session_name, request->server_def(), true)); 112 113 std::shared_ptr<WorkerSession> worker_session; 114 TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession( 115 session_name, &worker_session)); 116 117 // Initialize remote tensor communication based on worker session. 118 TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); 119 120 std::unique_ptr<tensorflow::EagerContext> ctx(new tensorflow::EagerContext( 121 SessionOptions(), 122 tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, 123 request->async(), std::move(device_mgr), r)); 124 125 uint64 context_id; 126 { 127 mutex_lock l(contexts_mu_); 128 do { 129 context_id = random::New64(); 130 } while (contexts_.find(context_id) != contexts_.end()); 131 contexts_.emplace( 132 context_id, 133 new ServerContext(std::move(ctx), request->keep_alive_secs(), env_)); 134 } 135 response->set_context_id(context_id); 136 137 return Status::OK(); 138 } 139 140 Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) { 141 const tensorflow::Tensor* t = nullptr; 142 143 // TODO(nareshmodi): This call makes async calls sync calls. Fix this. 144 TF_RETURN_IF_ERROR(handle->Tensor(&t)); 145 146 t->shape().AsProto(proto); 147 148 return Status::OK(); 149 } 150 151 Status EagerServiceImpl::ExecuteOp(const Operation& operation, 152 ServerContext* server_context, 153 QueueResponse* queue_response) { 154 std::unique_ptr<tensorflow::EagerOperation> op; 155 const char* name = operation.name().c_str(); // Shorthand 156 const tensorflow::AttrTypeMap* types; 157 bool is_function = false; 158 TF_RETURN_IF_ERROR(tensorflow::AttrTypeMapForOp(name, &types, &is_function)); 159 if (is_function && !server_context->Context()->FindFunctionByName(name)) { 160 return errors::NotFound( 161 "'", name, 162 "' is neither a type of a primitive operation nor a name " 163 "of a function registered in binary running on ", 164 port::Hostname(), 165 ". Make sure the operation or function is " 166 "registered in the binary running in this process."); 167 } 168 op.reset(new tensorflow::EagerOperation(server_context->Context(), name, 169 is_function, types)); 170 171 TF_RETURN_IF_ERROR(op->SetDevice(operation.device().c_str())); 172 173 for (const auto& remote_handle : operation.inputs()) { 174 tensorflow::TensorHandle* handle; 175 TF_RETURN_IF_ERROR(server_context->GetTensorHandle( 176 RemoteTensorHandleInternal(remote_handle), &handle)); 177 178 op->AddInput(handle); 179 } 180 181 for (const auto& attr : operation.attrs()) { 182 op->MutableAttrs()->Set(attr.first, attr.second); 183 } 184 185 int num_retvals = 0; 186 // TODO(nareshmodi): Consider caching this. 187 TF_RETURN_IF_ERROR(GetNumRetvals(server_context->Context(), operation.name(), 188 operation.attrs(), &num_retvals)); 189 190 tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals; 191 TF_RETURN_IF_ERROR(EagerExecute(op.get(), &retvals, &num_retvals)); 192 193 server_context->AddOperationOutputs(retvals, operation.id()); 194 195 for (auto* handle : retvals) { 196 TF_RETURN_IF_ERROR(TensorHandleShape(handle, queue_response->add_shape())); 197 } 198 199 return Status::OK(); 200 } 201 202 Status EagerServiceImpl::Enqueue(const EnqueueRequest* request, 203 EnqueueResponse* response) { 204 ServerContext* context = nullptr; 205 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); 206 core::ScopedUnref context_unref(context); 207 208 for (const auto& item : request->queue()) { 209 auto* queue_response = response->add_queue_response(); 210 if (item.has_operation()) { 211 TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context, queue_response)); 212 } else { 213 TF_RETURN_IF_ERROR(context->DeleteTensorHandle( 214 RemoteTensorHandleInternal(item.handle_to_decref()))); 215 } 216 } 217 218 return Status::OK(); 219 } 220 221 Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request, 222 WaitQueueDoneResponse* response) { 223 ServerContext* context = nullptr; 224 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); 225 core::ScopedUnref context_unref(context); 226 227 if (request->op_id_size() > 0) { 228 return errors::Unimplemented( 229 "EagerServiceImpl::WaitQueueDone is not " 230 "implemented for particular op IDs."); 231 } 232 return context->Context()->AsyncWait(); 233 } 234 235 Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request, 236 KeepAliveResponse* response) { 237 ServerContext* context = nullptr; 238 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); 239 core::ScopedUnref context_unref(context); 240 241 return Status::OK(); 242 } 243 244 Status EagerServiceImpl::CloseContext(const CloseContextRequest* request, 245 CloseContextResponse* response) { 246 ServerContext* context = nullptr; 247 if (!GetServerContext(request->context_id(), &context).ok()) { 248 // Swallow the error here. 249 return Status::OK(); 250 } 251 252 core::ScopedUnref context_unref(context); 253 254 mutex_lock l(contexts_mu_); 255 contexts_.erase(request->context_id()); 256 257 // GetServerContext returns a newly Reffed copy of ServerContext, which is 258 // unreffed by context_unref. Additionally, we need to unref it one time since 259 // we are releasing it from the map. 260 context->Unref(); 261 262 return Status::OK(); 263 } 264 265 Status EagerServiceImpl::RegisterFunction( 266 const RegisterFunctionRequest* request, 267 RegisterFunctionResponse* response) { 268 ServerContext* context = nullptr; 269 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); 270 core::ScopedUnref context_unref(context); 271 272 return context->Context()->AddFunctionDef(request->function_def()); 273 } 274 275 Status EagerServiceImpl::SendTensor(const SendTensorRequest* request, 276 SendTensorResponse* response) { 277 ServerContext* context = nullptr; 278 TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); 279 core::ScopedUnref context_unref(context); 280 281 tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> tensors; 282 for (const auto& tensor_proto : request->tensors()) { 283 Tensor tensor; 284 if (!tensor.FromProto(tensor_proto)) { 285 return errors::InvalidArgument("Unable to parse tensor proto"); 286 } 287 288 TensorHandle* tensor_handle = 289 new TensorHandle(tensor, nullptr, nullptr, nullptr); 290 291 TensorHandle* copied_handle = nullptr; 292 TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, context->Context(), 293 request->device_name().c_str(), 294 &copied_handle)); 295 tensors.push_back(copied_handle); 296 tensor_handle->Unref(); 297 } 298 299 context->AddOperationOutputs(tensors, request->op_id()); 300 301 return Status::OK(); 302 } 303 304 tensorflow::Status EagerServiceImpl::GetServerContext( 305 uint64 context_id, ServerContext** server_context) { 306 mutex_lock l(contexts_mu_); 307 auto iter = contexts_.find(context_id); 308 if (iter == contexts_.end()) { 309 *server_context = nullptr; 310 return errors::InvalidArgument(strings::Printf( 311 "Unable to find a context_id matching the specified one " 312 "(%lld). Perhaps the worker was restarted, or the context was GC'd?", 313 context_id)); 314 } 315 316 *server_context = iter->second; 317 (*server_context)->Ref(); 318 319 (*server_context)->RecordAccess(); 320 321 return Status::OK(); 322 } 323 324 } // namespace eager 325 } // namespace tensorflow 326