1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/c/eager/c_api.h" 17 18 #include <algorithm> 19 #include <cstddef> 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "absl/memory/memory.h" 25 #include "tensorflow/c/c_api.h" 26 #include "tensorflow/c/c_api_internal.h" 27 #include "tensorflow/c/eager/c_api_internal.h" 28 #include "tensorflow/core/platform/host_info.h" 29 #ifdef TENSORFLOW_EAGER_USE_XLA 30 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 31 #endif // TENSORFLOW_EAGER_USE_XLA 32 #include "tensorflow/core/common_runtime/copy_tensor.h" 33 #include "tensorflow/core/common_runtime/device_factory.h" 34 #include "tensorflow/core/common_runtime/device_mgr.h" 35 #include "tensorflow/core/common_runtime/device_set.h" 36 #include "tensorflow/core/common_runtime/eager/attr_builder.h" 37 #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" 38 #include "tensorflow/core/common_runtime/eager/execute.h" 39 #include "tensorflow/core/common_runtime/function.h" 40 #include "tensorflow/core/common_runtime/rendezvous_mgr.h" 41 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" 42 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" 43 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" 44 #include "tensorflow/core/distributed_runtime/server_lib.h" 45 #include "tensorflow/core/distributed_runtime/worker_env.h" 46 #include "tensorflow/core/framework/node_def_util.h" 47 #include "tensorflow/core/framework/rendezvous.h" 48 #include "tensorflow/core/framework/tensor_shape.pb.h" 49 #include "tensorflow/core/framework/types.h" 50 #include "tensorflow/core/lib/core/refcount.h" 51 #include "tensorflow/core/lib/core/stringpiece.h" 52 #include "tensorflow/core/lib/gtl/cleanup.h" 53 #include "tensorflow/core/lib/gtl/flatmap.h" 54 #include "tensorflow/core/lib/gtl/map_util.h" 55 #include "tensorflow/core/lib/gtl/stl_util.h" 56 #include "tensorflow/core/lib/random/random.h" 57 #include "tensorflow/core/platform/env.h" 58 #include "tensorflow/core/platform/mutex.h" 59 #include "tensorflow/core/platform/thread_annotations.h" 60 #include "tensorflow/core/public/version.h" 61 62 using tensorflow::int64; 63 using tensorflow::string; 64 65 namespace { 66 bool IsCPU(const tensorflow::Device* d) { 67 return d == nullptr || d->tensorflow_gpu_device_info() == nullptr; 68 } 69 70 bool IsXLA(const tensorflow::Device* d) { 71 if (d == nullptr) return false; 72 const auto& device_type = d->attributes().device_type(); 73 return device_type.find("XLA") != std::string::npos; 74 } 75 76 string DeviceName(const tensorflow::Device* d) { 77 return (d == nullptr) ? "cpu:0" : d->name(); 78 } 79 80 tensorflow::Status GetAllRemoteDevices( 81 const std::vector<string>& remote_workers, 82 tensorflow::WorkerCacheInterface* worker_cache, 83 std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) { 84 std::vector<std::unique_ptr<tensorflow::Device>> remote_devices; 85 tensorflow::Status status; 86 // TODO(nareshmodi) do this in parallel instead of serially. 87 for (const string& remote_worker : remote_workers) { 88 tensorflow::Notification n; 89 tensorflow::NewRemoteDevices( 90 tensorflow::Env::Default(), worker_cache, remote_worker, 91 [&status, &n, &remote_devices]( 92 const tensorflow::Status& s, 93 std::vector<tensorflow::Device*>* devices) { 94 status = s; 95 if (s.ok()) { 96 for (tensorflow::Device* d : *devices) { 97 remote_devices.emplace_back(d); 98 } 99 } 100 n.Notify(); 101 }); 102 n.WaitForNotification(); 103 } 104 std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr( 105 new tensorflow::DeviceMgr(std::move(remote_devices))); 106 107 TF_RETURN_IF_ERROR(status); 108 109 *device_mgr = std::move(remote_device_mgr); 110 return tensorflow::Status::OK(); 111 } 112 113 tensorflow::Status CreateRemoteContexts( 114 const std::vector<string>& remote_workers, int64 rendezvous_id, 115 int keep_alive_secs, const tensorflow::ServerDef& server_def, 116 tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, 117 tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) { 118 for (int i = 0; i < remote_workers.size(); i++) { 119 const string& remote_worker = remote_workers[i]; 120 121 tensorflow::eager::CreateContextRequest request; 122 tensorflow::eager::CreateContextResponse response; 123 request.set_rendezvous_id(rendezvous_id); 124 tensorflow::DeviceNameUtils::ParsedName parsed_name; 125 if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, 126 &parsed_name)) { 127 return tensorflow::errors::InvalidArgument( 128 "Unable to parse ", remote_worker, " as a device name"); 129 } 130 *request.mutable_server_def() = server_def; 131 request.mutable_server_def()->set_job_name(parsed_name.job); 132 request.mutable_server_def()->set_task_index(parsed_name.task); 133 request.set_async(async); 134 request.set_keep_alive_secs(keep_alive_secs); 135 auto* eager_client = remote_eager_workers->GetClient(remote_worker); 136 if (eager_client == nullptr) { 137 return tensorflow::errors::Internal( 138 "Cannot find a client for the given target:", remote_worker); 139 } 140 tensorflow::Notification n; 141 tensorflow::Status status; 142 // TODO(nareshmodi) do this in parallel instead of serially. 143 eager_client->CreateContextAsync( 144 &request, &response, [&status, &n](const tensorflow::Status& s) { 145 status = s; 146 n.Notify(); 147 }); 148 n.WaitForNotification(); 149 TF_RETURN_IF_ERROR(status); 150 151 remote_contexts->emplace(remote_worker, response.context_id()); 152 } 153 return tensorflow::Status::OK(); 154 } 155 156 tensorflow::Status UpdateTFE_ContextWithServerDef( 157 int keep_alive_secs, const tensorflow::ServerDef& server_def, 158 TFE_Context* ctx) { 159 // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the 160 // server object (which currently CHECK-fails) and we miss the error, instead, 161 // we log the error, and then return to allow the user to see the error 162 // message. 163 #define LOG_AND_RETURN_IF_ERROR(...) \ 164 do { \ 165 const ::tensorflow::Status _status = (__VA_ARGS__); \ 166 if (TF_PREDICT_FALSE(!_status.ok())) { \ 167 LOG(ERROR) << _status.error_message(); \ 168 return _status; \ 169 } \ 170 } while (0); 171 172 string worker_name = 173 tensorflow::strings::StrCat("/job:", server_def.job_name(), 174 "/replica:0/task:", server_def.task_index()); 175 176 std::unique_ptr<tensorflow::ServerInterface> server; 177 LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server)); 178 179 tensorflow::GrpcServer* grpc_server = 180 dynamic_cast<tensorflow::GrpcServer*>(server.get()); 181 if (grpc_server == nullptr) { 182 LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal( 183 "Currently, TFE_NewContext only supports tensorflow::GrpcServer.")); 184 } 185 186 LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); 187 188 int64 rendezvous_id = tensorflow::random::New64(); 189 190 std::vector<string> remote_workers; 191 grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers); 192 remote_workers.erase( 193 std::remove(remote_workers.begin(), remote_workers.end(), worker_name), 194 remote_workers.end()); 195 196 std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr; 197 LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices( 198 remote_workers, grpc_server->master_env()->worker_cache, 199 &remote_device_mgr)); 200 201 std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache = 202 grpc_server->channel_cache(); 203 std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers( 204 tensorflow::eager::NewGrpcEagerClientCache(channel_cache)); 205 206 // Initialize remote eager workers. 207 tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts; 208 LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( 209 remote_workers, rendezvous_id, keep_alive_secs, server_def, 210 remote_eager_workers.get(), ctx->context.Async(), &remote_contexts)); 211 212 tensorflow::RemoteRendezvous* r = 213 grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); 214 215 auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id); 216 TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession( 217 session_name, server_def, true)); 218 219 std::shared_ptr<tensorflow::WorkerSession> worker_session; 220 TF_RETURN_IF_ERROR( 221 grpc_server->worker_env()->session_mgr->WorkerSessionForSession( 222 session_name, &worker_session)); 223 224 // Initialize remote tensor communication based on worker session. 225 TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); 226 227 auto* device_mgr = grpc_server->worker_env()->device_mgr; 228 229 return ctx->context.InitializeRemote( 230 std::move(server), std::move(remote_eager_workers), 231 std::move(remote_device_mgr), remote_contexts, r, device_mgr, 232 keep_alive_secs); 233 #undef LOG_AND_RETURN_IF_ERROR 234 } 235 236 tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op, 237 TFE_TensorHandle* input) { 238 TFE_OpInferenceContext* ictx = op->inference_ctx.get(); 239 const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++); 240 if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) { 241 // Some clients that are still setting their input attributes manually are 242 // adding input list to their op by calling `TFE_OpAddInput` for each of 243 // its elements instead of calling `TFE_OpAddInputList`. When this happens, 244 // we cannot detect the end of such list, thus lose track of the input 245 // arguments in the op definition. To guarantee backward compatibility with 246 // those clients, disable automatic inference in this case. 247 op->inference_ctx.reset(nullptr); 248 return tensorflow::Status::OK(); 249 } 250 const std::string& type_attr = input_def.type_attr(); 251 if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) { 252 op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype); 253 ictx->attrs.insert(type_attr); 254 } 255 return tensorflow::Status::OK(); 256 } 257 258 void OpInferSingleTypeInputListAttrs(TFE_Op* op, 259 const tensorflow::OpDef::ArgDef& input_def, 260 TFE_TensorHandle** inputs, 261 int num_inputs) { 262 TFE_OpInferenceContext* ictx = op->inference_ctx.get(); 263 if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) { 264 op->operation.MutableAttrs()->Set(input_def.number_attr(), num_inputs); 265 ictx->attrs.insert(input_def.number_attr()); 266 } 267 if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) { 268 op->operation.MutableAttrs()->Set(input_def.type_attr(), 269 inputs[0]->handle->dtype); 270 ictx->attrs.insert(input_def.type_attr()); 271 } 272 } 273 274 void OpInferMixedTypeInputListAttrs(TFE_Op* op, 275 const tensorflow::OpDef::ArgDef& input_def, 276 TFE_TensorHandle** inputs, int num_inputs) { 277 TFE_OpInferenceContext* ictx = op->inference_ctx.get(); 278 if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) { 279 std::unique_ptr<tensorflow::DataType[]> dtypes( 280 new tensorflow::DataType[num_inputs]); 281 for (int i = 0; i < num_inputs; ++i) { 282 dtypes[i] = inputs[i]->handle->dtype; 283 } 284 op->operation.MutableAttrs()->Set( 285 input_def.type_list_attr(), 286 tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.get(), 287 num_inputs)); 288 ictx->attrs.insert(input_def.type_list_attr()); 289 } 290 } 291 292 tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs, 293 int num_inputs) { 294 TFE_OpInferenceContext* ictx = op->inference_ctx.get(); 295 const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++); 296 if (!input_def.type_list_attr().empty()) { 297 OpInferMixedTypeInputListAttrs(op, input_def, inputs, num_inputs); 298 } else if (!input_def.type_attr().empty() && 299 !input_def.number_attr().empty()) { 300 OpInferSingleTypeInputListAttrs(op, input_def, inputs, num_inputs); 301 } else { 302 return tensorflow::errors::InvalidArgument("Invalid input list definition"); 303 } 304 return tensorflow::Status::OK(); 305 } 306 307 } // namespace 308 309 extern "C" { 310 311 TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; } 312 313 void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, 314 size_t proto_len, TF_Status* status) { 315 TF_SetConfig(&options->session_options, proto, proto_len, status); 316 } 317 318 void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options, 319 unsigned char enable) { 320 options->async = enable; 321 } 322 323 void TFE_ContextOptionsSetDevicePlacementPolicy( 324 TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { 325 options->policy = policy; 326 } 327 328 TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, 329 unsigned char enable, 330 TF_Status* status) { 331 status->status = ctx->context.SetAsyncForThread(enable); 332 } 333 334 void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } 335 336 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { 337 std::vector<std::unique_ptr<tensorflow::Device>> devices; 338 status->status = tensorflow::DeviceFactory::AddDevices( 339 opts->session_options.options, "/job:localhost/replica:0/task:0", 340 &devices); 341 if (!status->status.ok()) return nullptr; 342 std::unique_ptr<tensorflow::DeviceMgr> device_mgr( 343 new tensorflow::DeviceMgr(std::move(devices))); 344 345 tensorflow::Rendezvous* r = 346 new tensorflow::IntraProcessRendezvous(device_mgr.get()); 347 348 return new TFE_Context(opts->session_options.options, opts->policy, 349 opts->async, device_mgr.release(), 350 /*device_mgr_owned*/ true, r); 351 } 352 353 TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts, 354 TF_Session* sess, TF_Status* status) { 355 const tensorflow::DeviceMgr* device_mgr = nullptr; 356 status->status = sess->session->LocalDeviceManager(&device_mgr); 357 if (!status->status.ok()) return nullptr; 358 tensorflow::Rendezvous* r = 359 new tensorflow::IntraProcessRendezvous(device_mgr); 360 return new TFE_Context(opts->session_options.options, opts->policy, 361 opts->async, device_mgr, /*device_mgr_owned*/ false, 362 r); 363 } 364 365 void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; } 366 367 TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { 368 TF_DeviceList* list = new TF_DeviceList; 369 ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response); 370 if (ctx->context.remote_device_mgr()) { 371 ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response); 372 } 373 return list; 374 } 375 376 void TFE_ContextClearCaches(TFE_Context* ctx, TF_Status* status) { 377 status->status = ctx->context.ClearCaches(); 378 } 379 380 // Set server_def on the context, possibly updating it. 381 TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, 382 int keep_alive_secs, 383 const void* proto, 384 size_t proto_len, 385 TF_Status* status) { 386 tensorflow::ServerDef server_def; 387 if (!server_def.ParseFromArray(proto, proto_len)) { 388 status->status = tensorflow::errors::InvalidArgument( 389 "Invalid tensorflow.ServerDef protocol buffer"); 390 return; 391 } 392 status->status = 393 UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx); 394 } 395 396 void TFE_ContextSetThreadLocalDevicePlacementPolicy( 397 TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { 398 ctx->context.SetThreadLocalDevicePlacementPolicy( 399 static_cast<tensorflow::ContextDevicePlacementPolicy>(policy)); 400 } 401 402 // Note: this function looks up a thread local policy. So it should be called in 403 // the appropriate client thread. In particular, in async mode, it may not be 404 // safe to call this function from the async EagerExecutor threads. 405 extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( 406 TFE_Context* ctx) { 407 return static_cast<TFE_ContextDevicePlacementPolicy>( 408 ctx->context.GetDevicePlacementPolicy()); 409 } 410 411 void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) { 412 status->status = ctx->context.AsyncWait(); 413 } 414 415 void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) { 416 status->status = ctx->context.GetStatus(); 417 } 418 419 void TFE_ContextAsyncClearError(TFE_Context* ctx) { 420 ctx->context.ClearAsyncError(); 421 } 422 423 TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { 424 tensorflow::Tensor tensor; 425 status->status = tensorflow::TF_TensorToTensor(t, &tensor); 426 if (!status->status.ok()) return nullptr; 427 return new TFE_TensorHandle(tensor, nullptr, nullptr); 428 } 429 430 void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { 431 if (h == nullptr) return; 432 VLOG(1) << "Deleting tensor handle " << h << " with internal handle " 433 << h->handle; 434 if (h->handle) { 435 h->handle->Unref(); 436 } 437 delete h; 438 } 439 440 TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { 441 return static_cast<TF_DataType>(h->handle->dtype); 442 } 443 444 int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { 445 if (h == nullptr || h->handle == nullptr) { 446 status->status = tensorflow::errors::InvalidArgument( 447 "The passed in handle is a nullptr"); 448 return -1; 449 } 450 int result; 451 status->status = h->handle->NumDims(&result); 452 return result; 453 } 454 455 int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) { 456 if (h == nullptr || h->handle == nullptr) { 457 status->status = tensorflow::errors::InvalidArgument( 458 "The passed in handle is a nullptr"); 459 return -1; 460 } 461 tensorflow::int64 result; 462 status->status = h->handle->NumElements(&result); 463 return result; 464 } 465 466 int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, 467 TF_Status* status) { 468 if (h == nullptr || h->handle == nullptr) { 469 status->status = tensorflow::errors::InvalidArgument( 470 "The passed in handle is a nullptr"); 471 return -1; 472 } 473 tensorflow::int64 result; 474 status->status = h->handle->Dim(dim_index, &result); 475 return result; 476 } 477 478 const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { 479 if (h == nullptr || h->handle == nullptr) { 480 status->status = tensorflow::errors::InvalidArgument( 481 "The passed in handle is a nullptr"); 482 return nullptr; 483 } 484 tensorflow::Device* d = h->handle->op_device(); 485 return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" 486 : d->name().c_str(); 487 } 488 489 const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h, 490 TF_Status* status) { 491 if (h == nullptr || h->handle == nullptr) { 492 status->status = tensorflow::errors::InvalidArgument( 493 "The passed in handle is a nullptr"); 494 return nullptr; 495 } 496 tensorflow::Device* d = h->handle->device(); 497 return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" 498 : d->name().c_str(); 499 } 500 501 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( 502 TFE_TensorHandle* h, TF_Status* status) { 503 if (h == nullptr || h->handle == nullptr) { 504 status->status = tensorflow::errors::InvalidArgument( 505 "The passed in handle is a nullptr"); 506 return nullptr; 507 } 508 509 h->handle->Ref(); 510 511 return new TFE_TensorHandle(h->handle); 512 } 513 514 TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { 515 if (h == nullptr || h->handle == nullptr) { 516 status->status = tensorflow::errors::InvalidArgument( 517 "The passed in handle is a nullptr"); 518 return nullptr; 519 } 520 // TODO(agarwal): move this implementation inside TFE_TensorHandle. 521 const tensorflow::Tensor* t = nullptr; 522 tensorflow::TensorHandle* h_cpu = nullptr; 523 tensorflow::Device* d = nullptr; 524 tensorflow::Device* op_device = nullptr; 525 526 if (h->handle->IsRemote()) { 527 status->status = EagerCopyToDevice( 528 h->handle, h->handle->Context(), 529 h->handle->Context()->HostCPU()->name().c_str(), &h_cpu); 530 if (!status->status.ok()) { 531 return nullptr; 532 } 533 status->status = h_cpu->TensorAndDevice(&t, &d, &op_device); 534 if (!status->status.ok()) { 535 h_cpu->Unref(); 536 return nullptr; 537 } 538 } else { 539 status->status = h->handle->TensorAndDevice(&t, &d, &op_device); 540 if (!status->status.ok()) return nullptr; 541 542 if (!IsCPU(d)) { 543 status->status = h->handle->CopyToDevice( 544 h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu); 545 if (!status->status.ok()) { 546 return nullptr; 547 } 548 status->status = h_cpu->TensorAndDevice(&t, &d, &op_device); 549 if (!status->status.ok()) { 550 h_cpu->Unref(); 551 return nullptr; 552 } 553 } 554 } 555 TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status); 556 if (h_cpu != nullptr) { 557 h_cpu->Unref(); 558 } 559 return retval; 560 } 561 562 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, 563 TF_Status* status) { 564 const char* name = op_or_function_name; // Shorthand 565 const tensorflow::AttrTypeMap* types; 566 bool is_function = false; 567 status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function); 568 if (!status->status.ok()) { 569 return nullptr; 570 } 571 if (!is_function) { 572 const tensorflow::OpDef* op_def; 573 status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def); 574 if (!status->status.ok()) { 575 return nullptr; 576 } 577 return new TFE_Op(ctx, name, false, types, 578 new TFE_OpInferenceContext(op_def)); 579 } 580 if (!ctx->context.FindFunctionByName(name)) { 581 status->status = tensorflow::errors::NotFound( 582 "'", name, 583 "' is neither a type of a primitive operation nor a name " 584 "of a function registered in binary running on ", 585 tensorflow::port::Hostname(), 586 ". Make sure the operation or function is " 587 "registered in the binary running in this process."); 588 return nullptr; 589 } 590 return new TFE_Op(ctx, name, true, types, nullptr); 591 } 592 593 void TFE_DeleteOp(TFE_Op* op) { delete op; } 594 595 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { 596 status->status = op->operation.SetDevice(device_name); 597 } 598 599 const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { 600 tensorflow::Device* device = (op->operation.Device() == nullptr) 601 ? op->operation.EagerContext()->HostCPU() 602 : op->operation.Device(); 603 return device->name().c_str(); 604 } 605 606 void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { 607 op->operation.SetUseXla(enable); 608 #ifndef TENSORFLOW_EAGER_USE_XLA 609 LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not " 610 "built with XLA support."; 611 #endif // TENSORFLOW_EAGER_USE_XLA 612 } 613 614 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) { 615 op->operation.AddInput(input->handle); 616 if (op->inference_ctx) { 617 status->status = OpInferSingleInputAttrs(op, input); 618 } 619 } 620 621 void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, 622 TF_Status* status) { 623 for (int i = 0; i < num_inputs; ++i) { 624 op->operation.AddInput(inputs[i]->handle); 625 } 626 if (op->inference_ctx) { 627 status->status = OpInferInputListAttrs(op, inputs, num_inputs); 628 } 629 } 630 631 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, 632 unsigned char* is_list, TF_Status* status) { 633 TF_AttrType ret; 634 status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(), 635 attr_name, &ret, is_list); 636 return ret; 637 } 638 639 TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, 640 const char* op_or_function_name, 641 const char* attr_name, unsigned char* is_list, 642 TF_Status* status) { 643 TF_AttrType ret; 644 TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status); 645 if (!status->status.ok()) { 646 return TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType. 647 } 648 ret = TFE_OpGetAttrType(op, attr_name, is_list, status); 649 TFE_DeleteOp(op); 650 return ret; 651 } 652 653 void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value, 654 size_t length) { 655 op->operation.MutableAttrs()->Set( 656 attr_name, 657 tensorflow::StringPiece(static_cast<const char*>(value), length)); 658 } 659 660 void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { 661 op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value)); 662 } 663 664 void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) { 665 op->operation.MutableAttrs()->Set(attr_name, value); 666 } 667 668 void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) { 669 op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true); 670 } 671 672 void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) { 673 op->operation.MutableAttrs()->Set(attr_name, 674 static_cast<tensorflow::DataType>(value)); 675 } 676 677 void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims, 678 const int num_dims, TF_Status* out_status) { 679 if (num_dims > tensorflow::TensorShape::MaxDimensions()) { 680 TF_SetStatus(out_status, TF_INVALID_ARGUMENT, 681 tensorflow::strings::StrCat( 682 "Value specified for `", attr_name, "` has ", num_dims, 683 " dimensions which is over the limit of ", 684 tensorflow::TensorShape::MaxDimensions(), ".") 685 .c_str()); 686 return; 687 } 688 tensorflow::TensorShapeProto proto; 689 if (num_dims < 0) { 690 proto.set_unknown_rank(true); 691 } else { 692 for (int d = 0; d < num_dims; ++d) { 693 proto.add_dim()->set_size(dims[d]); 694 } 695 } 696 op->operation.MutableAttrs()->Set(attr_name, proto); 697 } 698 699 void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, 700 const TFE_Op* value) { 701 tensorflow::AttrValue attr_value; 702 tensorflow::NameAttrList* func = attr_value.mutable_func(); 703 func->set_name(value->operation.Name()); 704 value->operation.Attrs().FillAttrValueMap(func->mutable_attr()); 705 op->operation.MutableAttrs()->Set(attr_name, attr_value); 706 } 707 708 void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, 709 const char* data, size_t length) { 710 tensorflow::AttrValue attr_value; 711 tensorflow::NameAttrList* func = attr_value.mutable_func(); 712 func->set_name(data, length); 713 op->operation.MutableAttrs()->Set(attr_name, attr_value); 714 } 715 716 void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor, 717 TF_Status* status) { 718 tensorflow::Tensor t; 719 status->status = TF_TensorToTensor(tensor, &t); 720 if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t); 721 } 722 723 void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, 724 const void* const* values, const size_t* lengths, 725 int num_values) { 726 std::vector<tensorflow::StringPiece> v(num_values); 727 for (int i = 0; i < num_values; ++i) { 728 v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]), 729 lengths[i]); 730 } 731 op->operation.MutableAttrs()->Set(attr_name, v); 732 } 733 734 void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, 735 const float* values, int num_values) { 736 op->operation.MutableAttrs()->Set( 737 attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values)); 738 } 739 740 void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, 741 const int64_t* values, int num_values) { 742 op->operation.MutableAttrs()->Set( 743 attr_name, tensorflow::gtl::ArraySlice<const int64>( 744 reinterpret_cast<const int64*>(values), num_values)); 745 } 746 747 void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, 748 const TF_DataType* values, int num_values) { 749 op->operation.MutableAttrs()->Set( 750 attr_name, 751 tensorflow::gtl::ArraySlice<const tensorflow::DataType>( 752 reinterpret_cast<const tensorflow::DataType*>(values), num_values)); 753 } 754 755 void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, 756 const unsigned char* values, int num_values) { 757 std::unique_ptr<bool[]> b(new bool[num_values]); 758 for (int i = 0; i < num_values; ++i) { 759 b[i] = values[i]; 760 } 761 op->operation.MutableAttrs()->Set( 762 attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values)); 763 } 764 765 void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, 766 const int64_t** dims, const int* num_dims, 767 int num_values, TF_Status* out_status) { 768 std::unique_ptr<tensorflow::TensorShapeProto[]> proto( 769 new tensorflow::TensorShapeProto[num_values]); 770 for (int i = 0; i < num_values; ++i) { 771 const auto num_dims_i = num_dims[i]; 772 773 if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) { 774 TF_SetStatus(out_status, TF_INVALID_ARGUMENT, 775 tensorflow::strings::StrCat( 776 "Value specified for `", attr_name, "` has ", num_dims_i, 777 " dimensions which is over the limit of ", 778 tensorflow::TensorShape::MaxDimensions(), ".") 779 .c_str()); 780 return; 781 } 782 if (num_dims_i < 0) { 783 proto[i].set_unknown_rank(true); 784 } else { 785 const int64_t* dims_i = dims[i]; 786 auto proto_i = &proto[i]; 787 for (int d = 0; d < num_dims_i; ++d) { 788 proto_i->add_dim()->set_size(dims_i[d]); 789 } 790 } 791 } 792 op->operation.MutableAttrs()->Set( 793 attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>( 794 proto.get(), num_values)); 795 } 796 797 void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, 798 const TFE_Op** value, int num_values) { 799 std::unique_ptr<tensorflow::NameAttrList[]> funcs( 800 new tensorflow::NameAttrList[num_values]); 801 for (int i = 0; i < num_values; i++) { 802 funcs[i].set_name(value[i]->operation.Name()); 803 value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr()); 804 } 805 op->operation.MutableAttrs()->Set( 806 attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>( 807 funcs.get(), num_values)); 808 } 809 810 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, 811 TF_Status* status) { 812 VLOG(1) << "Calling TFE_Execute() on op " << op; 813 tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals( 814 *num_retvals); 815 status->status = 816 tensorflow::EagerExecute(&op->operation, &handle_retvals, num_retvals); 817 if (!status->status.ok()) { 818 return; 819 } 820 for (int i = 0; i < *num_retvals; ++i) { 821 retvals[i] = new TFE_TensorHandle(handle_retvals[i]); 822 } 823 } 824 825 TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, 826 TFE_Context* ctx, 827 const char* device_name, 828 TF_Status* status) { 829 tensorflow::TensorHandle* handle; 830 status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context, 831 device_name, &handle); 832 if (status->status.ok()) { 833 return new TFE_TensorHandle(handle); 834 } 835 return nullptr; 836 } 837 838 void TFE_ContextAddFunctionDef(TFE_Context* ctx, 839 const char* serialized_function_def, size_t size, 840 TF_Status* status) { 841 tensorflow::FunctionDef function_def; 842 if (!function_def.ParseFromArray(serialized_function_def, size)) { 843 status->status = 844 tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); 845 return; 846 } 847 status->status = ctx->context.AddFunctionDef(function_def); 848 } 849 850 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, 851 TF_Status* status) { 852 status->status = ctx->context.AddFunctionDef(function->fdef); 853 } 854 855 unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { 856 return ctx->context.FindFunctionDef(name) != nullptr; 857 } 858 859 void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { 860 ctx->context.SetShouldStoreGraphs(true); 861 ctx->context.SetShouldStoreStepStats(true); 862 } 863 864 void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { 865 ctx->context.SetShouldStoreGraphs(false); 866 ctx->context.SetShouldStoreStepStats(false); 867 } 868 869 } // extern "C" 870 871 TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) { 872 return new TFE_TensorHandle(t, nullptr, nullptr); 873 } 874 875 const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( 876 TFE_TensorHandle* h, TF_Status* status) { 877 if (!h->handle->OnHostCPU()) { 878 status->status = tensorflow::errors::FailedPrecondition( 879 "TFE_TensorHandle is placed in device (not host) memory. Cannot return " 880 "a tensorflow::Tensor"); 881 return nullptr; 882 } 883 tensorflow::Device* d = nullptr; 884 tensorflow::Device* op_device = nullptr; 885 const tensorflow::Tensor* t = nullptr; 886 status->status = h->handle->TensorAndDevice(&t, &d, &op_device); 887 if (!status->status.ok()) return nullptr; 888 return t; 889 } 890 891 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, 892 TF_Status* status) { 893 TFE_ContextAsyncWait(ctx, status); 894 if (!status->status.ok()) return; 895 tensorflow::mutex_lock ml(*ctx->context.MetadataMu()); 896 status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf); 897 ctx->context.ClearRunMetadata(); 898 } 899 900 namespace { 901 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, 902 TF_Status* status) { 903 TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status); 904 for (const auto& attr : func.attr()) { 905 if (TF_GetCode(status) != TF_OK) return nullptr; 906 SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status); 907 if (TF_GetCode(status) != TF_OK) return nullptr; 908 } 909 return func_op; 910 } 911 } // namespace 912 913 void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); } 914 915 void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); } 916 917 namespace tensorflow { 918 void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, 919 const tensorflow::AttrValue& default_value, 920 const char* attr_name, TF_Status* status) { 921 switch (default_value.value_case()) { 922 case tensorflow::AttrValue::kS: { 923 const string& v = default_value.s(); 924 TFE_OpSetAttrString(op, attr_name, v.data(), v.size()); 925 break; 926 } 927 case tensorflow::AttrValue::kI: 928 TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i())); 929 break; 930 case tensorflow::AttrValue::kF: 931 TFE_OpSetAttrFloat(op, attr_name, default_value.f()); 932 break; 933 case tensorflow::AttrValue::kB: 934 TFE_OpSetAttrBool(op, attr_name, default_value.b()); 935 break; 936 case tensorflow::AttrValue::kType: 937 TFE_OpSetAttrType(op, attr_name, 938 static_cast<TF_DataType>(default_value.type())); 939 break; 940 case tensorflow::AttrValue::kShape: { 941 const auto& tensor_shape = default_value.shape(); 942 if (tensor_shape.unknown_rank()) { 943 TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status); 944 } else { 945 const auto num_dims = tensor_shape.dim_size(); 946 std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]); 947 for (int i = 0; i < num_dims; ++i) { 948 dims[i] = tensor_shape.dim(i).size(); 949 } 950 TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status); 951 } 952 } break; 953 case tensorflow::AttrValue::kFunc: { 954 const auto func_op = GetFunc(ctx, default_value.func(), status); 955 if (TF_GetCode(status) != TF_OK) return; 956 // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList 957 // require TFE_Op* and just convert it internally a NameAttrValue, so 958 // consider adding an overload to the C API to make this case easier. 959 TFE_OpSetAttrFunction(op, attr_name, func_op); 960 } break; 961 case tensorflow::AttrValue::kList: 962 TF_FALLTHROUGH_INTENDED; 963 case tensorflow::AttrValue::kTensor: 964 TF_FALLTHROUGH_INTENDED; 965 case tensorflow::AttrValue::kPlaceholder: 966 TF_FALLTHROUGH_INTENDED; 967 case tensorflow::AttrValue::VALUE_NOT_SET: 968 TF_SetStatus( 969 status, TF_UNIMPLEMENTED, 970 tensorflow::strings::StrCat("Unable to get setfor default value: ", 971 default_value.DebugString()) 972 .data()); 973 } 974 } 975 } // namespace tensorflow 976