/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api.h"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/platform/host_info.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif  // TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"

using tensorflow::int64;
using tensorflow::string;

namespace {
bool IsCPU(const tensorflow::Device* d) {
  return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}

bool IsXLA(const tensorflow::Device* d) {
  if (d == nullptr) return false;
  const auto& device_type = d->attributes().device_type();
  return device_type.find("XLA") != std::string::npos;
}

string DeviceName(const tensorflow::Device* d) {
  return (d == nullptr) ? "cpu:0" : d->name();
}

tensorflow::Status GetAllRemoteDevices(
    const std::vector<string>& remote_workers,
    tensorflow::WorkerCacheInterface* worker_cache,
    std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
  std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
  tensorflow::Status status;
  // TODO(nareshmodi) do this in parallel instead of serially.
  for (const string& remote_worker : remote_workers) {
    tensorflow::Notification n;
    tensorflow::NewRemoteDevices(
        tensorflow::Env::Default(), worker_cache, remote_worker,
        [&status, &n, &remote_devices](
            const tensorflow::Status& s,
            std::vector<tensorflow::Device*>* devices) {
          status = s;
          if (s.ok()) {
            for (tensorflow::Device* d : *devices) {
              remote_devices.emplace_back(d);
            }
          }
          n.Notify();
        });
    n.WaitForNotification();
  }
  std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
      new tensorflow::DeviceMgr(std::move(remote_devices)));

  TF_RETURN_IF_ERROR(status);

  *device_mgr = std::move(remote_device_mgr);
  return tensorflow::Status::OK();
}

tensorflow::Status CreateRemoteContexts(
    const std::vector<string>& remote_workers, int64 rendezvous_id,
    int keep_alive_secs, const tensorflow::ServerDef& server_def,
    tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
    tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
  for (int i = 0; i < remote_workers.size(); i++) {
    const string& remote_worker = remote_workers[i];

    tensorflow::eager::CreateContextRequest request;
    tensorflow::eager::CreateContextResponse response;
    request.set_rendezvous_id(rendezvous_id);
    tensorflow::DeviceNameUtils::ParsedName parsed_name;
    if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
                                                    &parsed_name)) {
      return tensorflow::errors::InvalidArgument(
          "Unable to parse ", remote_worker, " as a device name");
    }
    *request.mutable_server_def() = server_def;
    request.mutable_server_def()->set_job_name(parsed_name.job);
    request.mutable_server_def()->set_task_index(parsed_name.task);
    request.set_async(async);
    request.set_keep_alive_secs(keep_alive_secs);
    auto* eager_client = remote_eager_workers->GetClient(remote_worker);
    if (eager_client == nullptr) {
      return tensorflow::errors::Internal(
          "Cannot find a client for the given target: ", remote_worker);
    }
    tensorflow::Notification n;
    tensorflow::Status status;
    // TODO(nareshmodi) do this in parallel instead of serially.
    eager_client->CreateContextAsync(
        &request, &response, [&status, &n](const tensorflow::Status& s) {
          status = s;
          n.Notify();
        });
    n.WaitForNotification();
    TF_RETURN_IF_ERROR(status);

    remote_contexts->emplace(remote_worker, response.context_id());
  }
  return tensorflow::Status::OK();
}

tensorflow::Status UpdateTFE_ContextWithServerDef(
    int keep_alive_secs, const tensorflow::ServerDef& server_def,
    TFE_Context* ctx) {
  // We don't use the TF_RETURN_IF_ERROR macro directly here, since that
  // destroys the server object (which currently CHECK-fails) and the error
  // would be lost. Instead, we log the error and then return it, so the user
  // still sees the message.
#define LOG_AND_RETURN_IF_ERROR(...)                    \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(ERROR) << _status.error_message();            \
      return _status;                                   \
    }                                                   \
  } while (0);

  string worker_name =
      tensorflow::strings::StrCat("/job:", server_def.job_name(),
                                  "/replica:0/task:", server_def.task_index());

  std::unique_ptr<tensorflow::ServerInterface> server;
  LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server));

  tensorflow::GrpcServer* grpc_server =
      dynamic_cast<tensorflow::GrpcServer*>(server.get());
  if (grpc_server == nullptr) {
    LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
        "Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
  }

  LOG_AND_RETURN_IF_ERROR(grpc_server->Start());

  int64 rendezvous_id = tensorflow::random::New64();

  std::vector<string> remote_workers;
  grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers);
  remote_workers.erase(
      std::remove(remote_workers.begin(), remote_workers.end(), worker_name),
      remote_workers.end());

  std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr;
  LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
      remote_workers, grpc_server->master_env()->worker_cache,
      &remote_device_mgr));

  std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
      grpc_server->channel_cache();
  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
      tensorflow::eager::NewGrpcEagerClientCache(channel_cache));

  // Initialize remote eager workers.
  tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
  LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
      remote_workers, rendezvous_id, keep_alive_secs, server_def,
      remote_eager_workers.get(), ctx->context.Async(), &remote_contexts));

  tensorflow::RemoteRendezvous* r =
      grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);

  auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
  TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
      session_name, server_def, true));

  std::shared_ptr<tensorflow::WorkerSession> worker_session;
  TF_RETURN_IF_ERROR(
      grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
          session_name, &worker_session));

  // Initialize remote tensor communication based on worker session.
  TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));

  auto* device_mgr = grpc_server->worker_env()->device_mgr;

  return ctx->context.InitializeRemote(
      std::move(server), std::move(remote_eager_workers),
      std::move(remote_device_mgr), remote_contexts, r, device_mgr,
      keep_alive_secs);
#undef LOG_AND_RETURN_IF_ERROR
}

tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
                                           TFE_TensorHandle* input) {
  TFE_OpInferenceContext* ictx = op->inference_ctx.get();
  const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
  if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
    // Some clients that still set their input attributes manually add a list
    // input to their op by calling `TFE_OpAddInput` once per element instead
    // of calling `TFE_OpAddInputList`. When this happens, we cannot detect
    // where such a list ends and so lose track of the input arguments in the
    // op definition. To guarantee backward compatibility with those clients,
    // disable automatic inference in this case.
    op->inference_ctx.reset(nullptr);
    return tensorflow::Status::OK();
  }
  const std::string& type_attr = input_def.type_attr();
  if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) {
    op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype);
    ictx->attrs.insert(type_attr);
  }
  return tensorflow::Status::OK();
}
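
// To make the backward-compatibility note above concrete, here is a sketch of
// the two client call patterns (illustrative only; `op`, `inputs`, `num_inputs`
// and `status` are assumed to exist). Only the second pattern keeps automatic
// attribute inference enabled:
//
//   // Disables inference: the end of the list cannot be detected.
//   for (int i = 0; i < num_inputs; ++i) {
//     TFE_OpAddInput(op, inputs[i], status);
//   }
//
//   // Keeps inference enabled: the whole list is added at once.
//   TFE_OpAddInputList(op, inputs, num_inputs, status);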

void OpInferSingleTypeInputListAttrs(TFE_Op* op,
                                     const tensorflow::OpDef::ArgDef& input_def,
                                     TFE_TensorHandle** inputs,
                                     int num_inputs) {
  TFE_OpInferenceContext* ictx = op->inference_ctx.get();
  if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) {
    op->operation.MutableAttrs()->Set(input_def.number_attr(), num_inputs);
    ictx->attrs.insert(input_def.number_attr());
  }
  if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) {
    op->operation.MutableAttrs()->Set(input_def.type_attr(),
                                      inputs[0]->handle->dtype);
    ictx->attrs.insert(input_def.type_attr());
  }
}

void OpInferMixedTypeInputListAttrs(TFE_Op* op,
                                    const tensorflow::OpDef::ArgDef& input_def,
                                    TFE_TensorHandle** inputs, int num_inputs) {
  TFE_OpInferenceContext* ictx = op->inference_ctx.get();
  if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) {
    std::unique_ptr<tensorflow::DataType[]> dtypes(
        new tensorflow::DataType[num_inputs]);
    for (int i = 0; i < num_inputs; ++i) {
      dtypes[i] = inputs[i]->handle->dtype;
    }
    op->operation.MutableAttrs()->Set(
        input_def.type_list_attr(),
        tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.get(),
                                                                num_inputs));
    ictx->attrs.insert(input_def.type_list_attr());
  }
}

tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
                                         int num_inputs) {
  TFE_OpInferenceContext* ictx = op->inference_ctx.get();
  const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
  if (!input_def.type_list_attr().empty()) {
    OpInferMixedTypeInputListAttrs(op, input_def, inputs, num_inputs);
  } else if (!input_def.type_attr().empty() &&
             !input_def.number_attr().empty()) {
    OpInferSingleTypeInputListAttrs(op, input_def, inputs, num_inputs);
  } else {
    return tensorflow::errors::InvalidArgument("Invalid input list definition");
  }
  return tensorflow::Status::OK();
}

}  // namespace

extern "C" {

TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }

void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
                                 size_t proto_len, TF_Status* status) {
  TF_SetConfig(&options->session_options, proto, proto_len, status);
}

void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
                                unsigned char enable) {
  options->async = enable;
}

void TFE_ContextOptionsSetDevicePlacementPolicy(
    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
  options->policy = policy;
}

TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
                                                        unsigned char enable,
                                                        TF_Status* status) {
  status->status = ctx->context.SetAsyncForThread(enable);
}

void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }

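// A minimal usage sketch for context creation and teardown (illustrative only,
// with error handling elided; TF_NewStatus/TF_GetCode/TF_DeleteStatus come
// from the TF C API in tensorflow/c/c_api.h):
//
//   TF_Status* status = TF_NewStatus();
//   TFE_ContextOptions* opts = TFE_NewContextOptions();
//   TFE_ContextOptionsSetAsync(opts, 0);
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   TFE_DeleteContextOptions(opts);
//   if (TF_GetCode(status) != TF_OK) { /* handle the error */ }
//   // ... build and execute ops ...
//   TFE_DeleteContext(ctx);
//   TF_DeleteStatus(status);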
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  status->status = tensorflow::DeviceFactory::AddDevices(
      opts->session_options.options, "/job:localhost/replica:0/task:0",
      &devices);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
      new tensorflow::DeviceMgr(std::move(devices)));

  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());

  return new TFE_Context(opts->session_options.options, opts->policy,
                         opts->async, device_mgr.release(),
                         /*device_mgr_owned*/ true, r);
}

TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
                                       TF_Session* sess, TF_Status* status) {
  const tensorflow::DeviceMgr* device_mgr = nullptr;
  status->status = sess->session->LocalDeviceManager(&device_mgr);
  if (!status->status.ok()) return nullptr;
  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr);
  return new TFE_Context(opts->session_options.options, opts->policy,
                         opts->async, device_mgr, /*device_mgr_owned*/ false,
                         r);
}

void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }

TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
  TF_DeviceList* list = new TF_DeviceList;
  ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response);
  if (ctx->context.remote_device_mgr()) {
    ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response);
  }
  return list;
}

void TFE_ContextClearCaches(TFE_Context* ctx, TF_Status* status) {
  status->status = ctx->context.ClearCaches();
}

// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
                                                   int keep_alive_secs,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  status->status =
      UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx);
}
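
// A sketch of how a client might call TFE_ContextSetServerDef (illustrative;
// assumes `server_def` is a tensorflow::ServerDef already populated for the
// cluster, and 600 is an arbitrary keep-alive value chosen for the example):
//
//   std::string serialized;
//   server_def.SerializeToString(&serialized);
//   TFE_ContextSetServerDef(ctx, /*keep_alive_secs=*/600, serialized.data(),
//                           serialized.size(), status);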

void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
  ctx->context.SetThreadLocalDevicePlacementPolicy(
      static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}

// Note: this function looks up a thread-local policy, so it should be called
// from the appropriate client thread. In particular, in async mode, it may not
// be safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
    TFE_Context* ctx) {
  return static_cast<TFE_ContextDevicePlacementPolicy>(
      ctx->context.GetDevicePlacementPolicy());
}
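
// For example, a thread that wants tensors copied silently to the op's device
// could scope the policy to itself like this (a sketch; the enum values are
// declared in c_api.h):
//
//   TFE_ContextSetThreadLocalDevicePlacementPolicy(
//       ctx, TFE_DEVICE_PLACEMENT_SILENT);
//   TFE_ContextDevicePlacementPolicy policy =
//       TFE_ContextGetDevicePlacementPolicy(ctx);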

void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
  status->status = ctx->context.AsyncWait();
}

void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
  status->status = ctx->context.GetStatus();
}

void TFE_ContextAsyncClearError(TFE_Context* ctx) {
  ctx->context.ClearAsyncError();
}

TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
  tensorflow::Tensor tensor;
  status->status = tensorflow::TF_TensorToTensor(t, &tensor);
  if (!status->status.ok()) return nullptr;
  return new TFE_TensorHandle(tensor, nullptr, nullptr);
}

void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
  if (h == nullptr) return;
  VLOG(1) << "Deleting tensor handle " << h << " with internal handle "
          << h->handle;
  if (h->handle) {
    h->handle->Unref();
  }
  delete h;
}

TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
  return static_cast<TF_DataType>(h->handle->dtype);
}

int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr || h->handle == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return -1;
  }
  int result;
  status->status = h->handle->NumDims(&result);
  return result;
}

int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr || h->handle == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return -1;
  }
  tensorflow::int64 result;
  status->status = h->handle->NumElements(&result);
  return result;
}

int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
                            TF_Status* status) {
  if (h == nullptr || h->handle == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return -1;
  }
  tensorflow::int64 result;
  status->status = h->handle->Dim(dim_index, &result);
  return result;
}

const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr || h->handle == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }
  tensorflow::Device* d = h->handle->op_device();
  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                        : d->name().c_str();
}

const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
                                              TF_Status* status) {
  if (h == nullptr || h->handle == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }
  tensorflow::Device* d = h->handle->device();
  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                        : d->name().c_str();
}

TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
    TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr || h->handle == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }

  h->handle->Ref();

  return new TFE_TensorHandle(h->handle);
}

TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr || h->handle == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "The passed in handle is a nullptr");
    return nullptr;
  }
  // TODO(agarwal): move this implementation inside TFE_TensorHandle.
  const tensorflow::Tensor* t = nullptr;
  tensorflow::TensorHandle* h_cpu = nullptr;
  tensorflow::Device* d = nullptr;
  tensorflow::Device* op_device = nullptr;

  if (h->handle->IsRemote()) {
    status->status = EagerCopyToDevice(
        h->handle, h->handle->Context(),
        h->handle->Context()->HostCPU()->name().c_str(), &h_cpu);
    if (!status->status.ok()) {
      return nullptr;
    }
    status->status = h_cpu->TensorAndDevice(&t, &d, &op_device);
    if (!status->status.ok()) {
      h_cpu->Unref();
      return nullptr;
    }
  } else {
    status->status = h->handle->TensorAndDevice(&t, &d, &op_device);
    if (!status->status.ok()) return nullptr;

    if (!IsCPU(d)) {
      status->status = h->handle->CopyToDevice(
          h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu);
      if (!status->status.ok()) {
        return nullptr;
      }
      status->status = h_cpu->TensorAndDevice(&t, &d, &op_device);
      if (!status->status.ok()) {
        h_cpu->Unref();
        return nullptr;
      }
    }
  }
  TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
  if (h_cpu != nullptr) {
    h_cpu->Unref();
  }
  return retval;
}
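
// A short sketch of resolving a handle to host memory and reading the data
// (illustrative; assumes `h` holds a DT_FLOAT scalar, and TF_TensorData /
// TF_DeleteTensor are the usual TF C API accessors):
//
//   TF_Tensor* t = TFE_TensorHandleResolve(h, status);
//   if (TF_GetCode(status) == TF_OK) {
//     float value = *static_cast<float*>(TF_TensorData(t));
//     TF_DeleteTensor(t);
//   }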

TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  const char* name = op_or_function_name;  // Shorthand
  const tensorflow::AttrTypeMap* types;
  bool is_function = false;
  status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
  if (!status->status.ok()) {
    return nullptr;
  }
  if (!is_function) {
    const tensorflow::OpDef* op_def;
    status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
    if (!status->status.ok()) {
      return nullptr;
    }
    return new TFE_Op(ctx, name, false, types,
                      new TFE_OpInferenceContext(op_def));
  }
  if (!ctx->context.FindFunctionByName(name)) {
    status->status = tensorflow::errors::NotFound(
        "'", name,
        "' is neither a type of a primitive operation nor the name of a "
        "function registered in the binary running on ",
        tensorflow::port::Hostname(),
        ". Make sure the operation or function is "
        "registered in the binary running in this process.");
    return nullptr;
  }
  return new TFE_Op(ctx, name, true, types, nullptr);
}

void TFE_DeleteOp(TFE_Op* op) { delete op; }

void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
  status->status = op->operation.SetDevice(device_name);
}

const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
  tensorflow::Device* device = (op->operation.Device() == nullptr)
                                   ? op->operation.EagerContext()->HostCPU()
                                   : op->operation.Device();
  return device->name().c_str();
}

void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
  op->operation.SetUseXla(enable);
#ifndef TENSORFLOW_EAGER_USE_XLA
  LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
                  "built with XLA support.";
#endif  // TENSORFLOW_EAGER_USE_XLA
}

void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
  op->operation.AddInput(input->handle);
  if (op->inference_ctx) {
    status->status = OpInferSingleInputAttrs(op, input);
  }
}

void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
                        TF_Status* status) {
  for (int i = 0; i < num_inputs; ++i) {
    op->operation.AddInput(inputs[i]->handle);
  }
  if (op->inference_ctx) {
    status->status = OpInferInputListAttrs(op, inputs, num_inputs);
  }
}

TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                              unsigned char* is_list, TF_Status* status) {
  TF_AttrType ret;
  status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
                                              attr_name, &ret, is_list);
  return ret;
}

TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
                                  const char* op_or_function_name,
                                  const char* attr_name, unsigned char* is_list,
                                  TF_Status* status) {
  TF_AttrType ret;
  TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
  if (!status->status.ok()) {
    return TF_ATTR_INT;  // Same dummy return as TFE_OpGetAttrType.
  }
  ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
  TFE_DeleteOp(op);
  return ret;
}

void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
                         size_t length) {
  op->operation.MutableAttrs()->Set(
      attr_name,
      tensorflow::StringPiece(static_cast<const char*>(value), length));
}

void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
  op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
}

void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
  op->operation.MutableAttrs()->Set(attr_name, value);
}

void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
  op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
}

void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
  op->operation.MutableAttrs()->Set(attr_name,
                                    static_cast<tensorflow::DataType>(value));
}

void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
                        const int num_dims, TF_Status* out_status) {
  if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 tensorflow::strings::StrCat(
                     "Value specified for `", attr_name, "` has ", num_dims,
                     " dimensions which is over the limit of ",
                     tensorflow::TensorShape::MaxDimensions(), ".")
                     .c_str());
    return;
  }
  tensorflow::TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }
  op->operation.MutableAttrs()->Set(attr_name, proto);
}
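
// Shape attributes follow the TensorShapeProto conventions used above:
// num_dims == -1 means unknown rank, and a dimension size of -1 means an
// unknown dimension. For example (illustrative; "shape" is a hypothetical
// attr name):
//
//   TFE_OpSetAttrShape(op, "shape", nullptr, -1, status);  // unknown rank
//   int64_t dims[] = {-1, 28, 28, 3};                      // batch unknown
//   TFE_OpSetAttrShape(op, "shape", dims, 4, status);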

void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
                           const TFE_Op* value) {
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(value->operation.Name());
  value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
  op->operation.MutableAttrs()->Set(attr_name, attr_value);
}

void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
                               const char* data, size_t length) {
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(data, length);
  op->operation.MutableAttrs()->Set(attr_name, attr_value);
}

void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
                         TF_Status* status) {
  tensorflow::Tensor t;
  status->status = TF_TensorToTensor(tensor, &t);
  if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
}

void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
                             const void* const* values, const size_t* lengths,
                             int num_values) {
  std::vector<tensorflow::StringPiece> v(num_values);
  for (int i = 0; i < num_values; ++i) {
    v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
                                   lengths[i]);
  }
  op->operation.MutableAttrs()->Set(attr_name, v);
}

void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
                            const float* values, int num_values) {
  op->operation.MutableAttrs()->Set(
      attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
}

void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                          const int64_t* values, int num_values) {
  op->operation.MutableAttrs()->Set(
      attr_name, tensorflow::gtl::ArraySlice<const int64>(
                     reinterpret_cast<const int64*>(values), num_values));
}

void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                           const TF_DataType* values, int num_values) {
  op->operation.MutableAttrs()->Set(
      attr_name,
      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
          reinterpret_cast<const tensorflow::DataType*>(values), num_values));
}

void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
                           const unsigned char* values, int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  op->operation.MutableAttrs()->Set(
      attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}

void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                            const int64_t** dims, const int* num_dims,
                            int num_values, TF_Status* out_status) {
  std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
      new tensorflow::TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
      TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                   tensorflow::strings::StrCat(
                       "Value specified for `", attr_name, "` has ", num_dims_i,
                       " dimensions which is over the limit of ",
                       tensorflow::TensorShape::MaxDimensions(), ".")
                       .c_str());
      return;
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  op->operation.MutableAttrs()->Set(
      attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
                     proto.get(), num_values));
}

void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
                               const TFE_Op** value, int num_values) {
  std::unique_ptr<tensorflow::NameAttrList[]> funcs(
      new tensorflow::NameAttrList[num_values]);
  for (int i = 0; i < num_values; i++) {
    funcs[i].set_name(value[i]->operation.Name());
    value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
  }
  op->operation.MutableAttrs()->Set(
      attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
                     funcs.get(), num_values));
}

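// A minimal end-to-end sketch of building and executing an op (illustrative;
// assumes `ctx`, two float handles `a` and `b`, and `status` already exist,
// and relies on attribute inference to fill in MatMul's type attr):
//
//   TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
//   TFE_OpAddInput(matmul, a, status);
//   TFE_OpAddInput(matmul, b, status);
//   int num_retvals = 1;
//   TFE_TensorHandle* retvals[1];
//   TFE_Execute(matmul, retvals, &num_retvals, status);
//   TFE_DeleteOp(matmul);
//   // ... use retvals[0], then TFE_DeleteTensorHandle(retvals[0]);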
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                 TF_Status* status) {
  VLOG(1) << "Calling TFE_Execute() on op " << op;
  tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
      *num_retvals);
  status->status =
      tensorflow::EagerExecute(&op->operation, &handle_retvals, num_retvals);
  if (!status->status.ok()) {
    return;
  }
  for (int i = 0; i < *num_retvals; ++i) {
    retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
  }
}

TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                               TFE_Context* ctx,
                                               const char* device_name,
                                               TF_Status* status) {
  tensorflow::TensorHandle* handle;
  status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context,
                                                 device_name, &handle);
  if (status->status.ok()) {
    return new TFE_TensorHandle(handle);
  }
  return nullptr;
}

void TFE_ContextAddFunctionDef(TFE_Context* ctx,
                               const char* serialized_function_def, size_t size,
                               TF_Status* status) {
  tensorflow::FunctionDef function_def;
  if (!function_def.ParseFromArray(serialized_function_def, size)) {
    status->status =
        tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
    return;
  }
  status->status = ctx->context.AddFunctionDef(function_def);
}

void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                            TF_Status* status) {
  status->status = ctx->context.AddFunctionDef(function->fdef);
}

unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
  return ctx->context.FindFunctionDef(name) != nullptr;
}

void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
  ctx->context.SetShouldStoreGraphs(true);
  ctx->context.SetShouldStoreStepStats(true);
}

void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
  ctx->context.SetShouldStoreGraphs(false);
  ctx->context.SetShouldStoreStepStats(false);
}

}  // extern "C"

TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
  return new TFE_TensorHandle(t, nullptr, nullptr);
}

const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
    TFE_TensorHandle* h, TF_Status* status) {
  if (!h->handle->OnHostCPU()) {
    status->status = tensorflow::errors::FailedPrecondition(
        "TFE_TensorHandle is placed in device (not host) memory. Cannot return "
        "a tensorflow::Tensor");
    return nullptr;
  }
  tensorflow::Device* d = nullptr;
  tensorflow::Device* op_device = nullptr;
  const tensorflow::Tensor* t = nullptr;
  status->status = h->handle->TensorAndDevice(&t, &d, &op_device);
  if (!status->status.ok()) return nullptr;
  return t;
}

void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                  TF_Status* status) {
  TFE_ContextAsyncWait(ctx, status);
  if (!status->status.ok()) return;
  tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
  status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf);
  ctx->context.ClearRunMetadata();
}

namespace {
TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
                TF_Status* status) {
  TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
  for (const auto& attr : func.attr()) {
    if (TF_GetCode(status) != TF_OK) return nullptr;
    SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }
  return func_op;
}
}  // namespace

void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); }

void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); }

namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
                          const tensorflow::AttrValue& default_value,
                          const char* attr_name, TF_Status* status) {
  switch (default_value.value_case()) {
    case tensorflow::AttrValue::kS: {
      const string& v = default_value.s();
      TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
      break;
    }
    case tensorflow::AttrValue::kI:
      TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
      break;
    case tensorflow::AttrValue::kF:
      TFE_OpSetAttrFloat(op, attr_name, default_value.f());
      break;
    case tensorflow::AttrValue::kB:
      TFE_OpSetAttrBool(op, attr_name, default_value.b());
      break;
    case tensorflow::AttrValue::kType:
      TFE_OpSetAttrType(op, attr_name,
                        static_cast<TF_DataType>(default_value.type()));
      break;
    case tensorflow::AttrValue::kShape: {
      const auto& tensor_shape = default_value.shape();
      if (tensor_shape.unknown_rank()) {
        TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
      } else {
        const auto num_dims = tensor_shape.dim_size();
        std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
        for (int i = 0; i < num_dims; ++i) {
          dims[i] = tensor_shape.dim(i).size();
        }
        TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
      }
    } break;
    case tensorflow::AttrValue::kFunc: {
      const auto func_op = GetFunc(ctx, default_value.func(), status);
      if (TF_GetCode(status) != TF_OK) return;
      // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
      // require a TFE_Op* and just convert it internally to a NameAttrList, so
      // consider adding an overload to the C API to make this case easier.
      TFE_OpSetAttrFunction(op, attr_name, func_op);
      // GetFunc allocated func_op; release it now that its attrs were copied.
      TFE_DeleteOp(func_op);
    } break;
    case tensorflow::AttrValue::kList:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::kTensor:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::kPlaceholder:
      TF_FALLTHROUGH_INTENDED;
    case tensorflow::AttrValue::VALUE_NOT_SET:
      TF_SetStatus(
          status, TF_UNIMPLEMENTED,
          tensorflow::strings::StrCat("Unable to set attr from default value: ",
                                      default_value.DebugString())
              .data());
  }
}
}  // namespace tensorflow