/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api.h"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/runtime.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif  // TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"

using tensorflow::int64;
using tensorflow::string;

namespace {
bool IsCPU(const tensorflow::Device* d) {
  return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}

bool IsXLA(const tensorflow::Device* d) {
  if (d == nullptr) return false;
  const auto& device_type = d->attributes().device_type();
  return device_type.find("XLA") != std::string::npos;
}

string DeviceName(const tensorflow::Device* d) {
  return (d == nullptr) ? "cpu:0" : d->name();
}

#ifdef TENSORFLOW_EAGER_USE_XLA
std::atomic_int_fast64_t func_id_generator(0);
#endif  // TENSORFLOW_EAGER_USE_XLA
}  // namespace

extern "C" {

TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }

void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
                                 size_t proto_len, TF_Status* status) {
  TF_SetConfig(&options->session_options, proto, proto_len, status);
}

void TFE_ContextOptionsSetDevicePlacementPolicy(
    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
  options->policy = policy;
}

void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }

TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  TF_Graph* graph = TF_NewGraph();
  TF_Session* session = TF_NewSession(graph, &opts->session_options, status);
  if (status->status.ok()) {
    if (session->device_mgr == nullptr || session->devices.empty()) {
      status->status = tensorflow::errors::InvalidArgument(
          "Provided TF_SessionOptions are not compatible with eager execution "
          "(perhaps the TF_SessionOptions alluded to session execution in a "
          "remote address space?)");
    }
  }
  if (!status->status.ok()) {
    TF_DeleteGraph(graph);
    return nullptr;
  }

  return new TFE_Context(*opts, session);
}

void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
  status->status = tensorflow::Status::OK();
  {
    tensorflow::mutex_lock ml(ctx->cache_mu);
    tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
  }
  TF_Graph* graph = ctx->session->graph;
  TF_DeleteSession(ctx->session, status);
  TF_DeleteGraph(graph);
  ctx->rendezvous->Unref();
  delete ctx;
}

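// A minimal usage sketch for context setup and teardown (error checking
// elided; TF_NewStatus/TF_DeleteStatus come from the C API in c_api.h):
//
//   TF_Status* status = TF_NewStatus();
//   TFE_ContextOptions* opts = TFE_NewContextOptions();
//   TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   TFE_DeleteContextOptions(opts);
//   // ... build and execute ops ...
//   TFE_DeleteContext(ctx, status);
//   TF_DeleteStatus(status);
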
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
  return TF_SessionListDevices(ctx->session, status);
}

void TFE_ContextClearCaches(TFE_Context* ctx) {
  tensorflow::mutex_lock ml(ctx->cache_mu);
  tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
}

void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
  tensorflow::mutex_lock ml(ctx->policy_map_mu);
  ctx->thread_local_policies[std::this_thread::get_id()] = policy;
}

extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
    TFE_Context* ctx) {
  tensorflow::mutex_lock ml(ctx->policy_map_mu);
  auto policy_map_it =
      ctx->thread_local_policies.find(std::this_thread::get_id());
  if (policy_map_it != ctx->thread_local_policies.end()) {
    return policy_map_it->second;
  }
  return ctx->policy;
}

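// A small sketch of overriding the placement policy for the calling thread
// only; other threads keep the policy chosen in TFE_ContextOptions:
//
//   TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, TFE_DEVICE_PLACEMENT_WARN);
//   TFE_ContextDevicePlacementPolicy p = TFE_ContextGetDevicePlacementPolicy(ctx);
//   // On this thread, p is now TFE_DEVICE_PLACEMENT_WARN.
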
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
  tensorflow::Tensor tensor;
  status->status = tensorflow::TF_TensorToTensor(t, &tensor);
  if (!status->status.ok()) return nullptr;
  return new TFE_TensorHandle(tensor, nullptr);
}

void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; }

TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
  return static_cast<TF_DataType>(h->t.dtype());
}

int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); }

int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) {
  return h->t.dim_size(dim_index);
}

const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) {
  // TODO(apassos) this will be potentially incorrect in the distributed case as
  // our local device will have a name which depends on the ClusterSpec and
  // hence will require the context to resolve.
  return (h->d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                           : h->d->name().c_str();
}

TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
  if (!IsCPU(h->d)) {
    TF_SetStatus(status, TF_UNIMPLEMENTED,
                 tensorflow::strings::StrCat(
                     "TFE_TensorHandle can be resolved iff it is on CPU (this "
                     "handle is on ",
                     h->d->name(),
                     "). Consider using TFE_TensorHandleCopyToDevice to get a "
                     "copy of the tensor on CPU")
                     .c_str());
    return nullptr;
  }
  return tensorflow::TF_TensorFromTensor(h->t, status);
}

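// A minimal sketch of wrapping a host TF_Tensor in a handle and reading it
// back (a 2x2 float tensor; `status` as in the sketch above):
//
//   int64_t dims[] = {2, 2};
//   TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, 4 * sizeof(float));
//   // ... fill TF_TensorData(t) with 4 floats ...
//   TFE_TensorHandle* h = TFE_NewTensorHandle(t, status);
//   TF_Tensor* resolved = TFE_TensorHandleResolve(h, status);  // h is on CPU.
//   TF_DeleteTensor(resolved);
//   TFE_DeleteTensorHandle(h);
//   TF_DeleteTensor(t);
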
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                               TFE_Context* ctx,
                                               const char* device_name,
                                               TF_Status* status) {
  tensorflow::Device* dstd = ctx->devices()[0];
  if (device_name != nullptr && strlen(device_name) > 0) {
    status->status = ctx->session->device_mgr->LookupDevice(device_name, &dstd);
    if (!status->status.ok()) return nullptr;
  }

  tensorflow::Device* srcd = h->d == nullptr ? ctx->devices()[0] : h->d;
  bool is_same_device =
      (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
  const bool dst_cpu = IsCPU(dstd);
  const bool src_cpu = IsCPU(srcd);
  // both_on_cpu can be true and yet is_same_device is false, if one of src/dst
  // has device type XLA_CPU, and the other CPU.
  const bool both_on_cpu = src_cpu && dst_cpu;
  if (is_same_device || both_on_cpu) {
    return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd);
  }
  tensorflow::Tensor* src = &(h->t);
  if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
                   !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Can't copy Tensor with type ",
                                    tensorflow::DataTypeString(src->dtype()),
                                    " to device ", DeviceName(dstd), ".")
            .c_str());
    return nullptr;
  }
  tensorflow::AllocatorAttributes attr;
  if (src->dtype() == tensorflow::DT_VARIANT) {
    attr.set_on_host(true);
  }
  tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
  if (src->shape().num_elements() == 0) {
    return new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd);
  }
  tensorflow::DeviceContext* src_device_context = nullptr;
  if (!src_cpu) {
    src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
  }
  tensorflow::DeviceContext* dst_device_context = nullptr;
  if (!dst_cpu) {
    dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
  }
  // TODO(ashankar): The Sync() call below may be more aggressive than
  // necessary. It is based on knowledge of implementation details - that
  // GPU devices are implemented using 3 streams - one for host->device copies,
  // one for device->host copies and one for sending operations to the GPU.
  // With that setup, Sync()ing across all 3 streams should be sufficient
  // but more than necessary (since it waits for operations that might have
  // nothing to do with this tensor to complete).
  status->status = srcd->Sync();
  tensorflow::Notification n;
  tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
                                 srcd, dstd, tensorflow::AllocatorAttributes(),
                                 tensorflow::AllocatorAttributes(), src, &dst,
                                 [status, &n](const tensorflow::Status& s) {
                                   status->status = s;
                                   n.Notify();
                                 });
  n.WaitForNotification();
  return (TF_GetCode(status) == TF_OK)
             ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd)
             : nullptr;
}

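// A hedged sketch of moving a handle to another device. The device name below
// is only an example; it must match a name returned by TFE_ContextListDevices:
//
//   TFE_TensorHandle* h_gpu = TFE_TensorHandleCopyToDevice(
//       h, ctx, "/job:localhost/replica:0/task:0/device:GPU:0", status);
//   if (TF_GetCode(status) == TF_OK) {
//     // ... feed h_gpu to an op ...
//     TFE_DeleteTensorHandle(h_gpu);
//   }
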
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  const char* name = op_or_function_name;  // Shorthand
  const tensorflow::AttrTypeMap* types;
  status->status = tensorflow::AttrTypeMapForOp(name, &types);
  if (status->status.ok()) return new TFE_Op(ctx, name, types);
  if (TF_GetCode(status) == TF_NOT_FOUND) {
    tensorflow::mutex_lock l(ctx->functions_mu);
    if (ctx->func_lib_def.Find(name) != nullptr) {
      status->status = tensorflow::Status::OK();
      return new TFE_Op(ctx, name, nullptr);
    }
  }
  return nullptr;
}

void TFE_DeleteOp(TFE_Op* op) { delete op; }

void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
  tensorflow::Device* d = nullptr;
  if (device_name != nullptr && strlen(device_name) > 0) {
    status->status =
        op->ctx->session->device_mgr->LookupDevice(device_name, &d);
    if (!status->status.ok()) return;
  }
  op->device = d;
}

const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
  tensorflow::Device* device =
      (op->device == nullptr) ? op->ctx->devices()[0] : op->device;
  return device->name().c_str();
}

void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
  op->use_xla = enable;
#ifndef TENSORFLOW_EAGER_USE_XLA
  LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
                  "built with XLA support.";
#endif  // TENSORFLOW_EAGER_USE_XLA
}

void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
  // Questionable heuristic ...
  //
  // Motivation: After an 'op' is placed on GPU because some of its earlier
  // inputs are on GPU, we want to keep the 'op' there, even if some later
  // inputs of it are not on GPU.
  if (IsCPU(op->device) && !IsCPU(h->d)) {
    op->device = h->d;
  }
  if (!status->status.ok()) return;
  op->inputs.push_back(h->t);
  op->input_devices.push_back(h->d);
  op->attrs.NumInputs(op->inputs.size());
}

TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                              unsigned char* is_list, TF_Status* status) {
  TF_AttrType ret;
  if (op->is_function()) {
    status->status = tensorflow::errors::Unimplemented(
        "TODO(apassos): Support for attributes for TensorFlow functions is not "
        "ready yet.");
    return TF_ATTR_INT;  // The compiler requires that we return something.
  }
  status->status =
      tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list);
  return ret;
}

TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
                                  const char* op_or_function_name,
                                  const char* attr_name, unsigned char* is_list,
                                  TF_Status* status) {
  TF_AttrType ret;
  TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
  if (!status->status.ok()) {
    return TF_ATTR_INT;  // Same dummy return as TFE_OpGetAttrType.
  }
  ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
  TFE_DeleteOp(op);
  return ret;
}

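// A short sketch of querying attribute metadata without constructing inputs;
// MatMul's "transpose_a" is a scalar bool attribute, so is_list comes back 0:
//
//   unsigned char is_list = 0;
//   TF_AttrType type =
//       TFE_OpNameGetAttrType(ctx, "MatMul", "transpose_a", &is_list, status);
//   // On success, type == TF_ATTR_BOOL and is_list == 0.
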
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) {
  op->attrs.Set(attr_name, value);
}

void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
  op->attrs.Set(attr_name, static_cast<int64>(value));
}

void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
  op->attrs.Set(attr_name, value);
}

void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
  op->attrs.Set(attr_name, (value == 0) ? false : true);
}

void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
  op->attrs.Set(attr_name, static_cast<tensorflow::DataType>(value));
}

void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
                        const int num_dims, TF_Status* out_status) {
  if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 tensorflow::strings::StrCat(
                     "Value specified for `", attr_name, "` has ", num_dims,
                     " dimensions which is over the limit of ",
                     tensorflow::TensorShape::MaxDimensions(), ".")
                     .c_str());
    return;
  }
  tensorflow::TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }
  op->attrs.Set(attr_name, proto);
}

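// A minimal sketch of the shape-attribute encoding, assuming `op` has a shape
// attribute named "shape": a negative num_dims marks an unknown rank, while -1
// entries in `dims` mark individual unknown dimensions:
//
//   int64_t dims[] = {-1, 28, 28, 3};                      // unknown batch dim
//   TFE_OpSetAttrShape(op, "shape", dims, 4, status);
//   TFE_OpSetAttrShape(op, "shape", nullptr, -1, status);  // unknown rank
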
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
                           const TFE_Op* value) {
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(value->name);
  value->attrs.FillAttrValueMap(func->mutable_attr());
  op->attrs.Set(attr_name, attr_value);
}

#define TFE_OP_SET_ATTR_LIST(fn, type)                                \
  void fn(TFE_Op* op, const char* attr_name, const type* values,      \
          int num_values) {                                           \
    op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice<const type>( \
                                 values, num_values));                \
  }
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*)
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
#undef TFE_OP_SET_ATTR_LIST

void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                          const int64_t* values, int num_values) {
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const int64>(
                    reinterpret_cast<const int64*>(values), num_values));
}

void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                           const TF_DataType* values, int num_values) {
  op->attrs.Set(
      attr_name,
      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
          reinterpret_cast<const tensorflow::DataType*>(values), num_values));
}

void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
                           const unsigned char* values, int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}

void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                            const int64_t** dims, const int* num_dims,
                            int num_values, TF_Status* out_status) {
  std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
      new tensorflow::TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
      TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                   tensorflow::strings::StrCat(
                       "Value specified for `", attr_name, "` has ", num_dims_i,
                       " dimensions which is over the limit of ",
                       tensorflow::TensorShape::MaxDimensions(), ".")
                       .c_str());
      return;
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
                    proto.get(), num_values));
}

void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
                               const TFE_Op** value, int num_values) {
  std::unique_ptr<tensorflow::NameAttrList[]> funcs(
      new tensorflow::NameAttrList[num_values]);
  for (int i = 0; i < num_values; i++) {
    funcs[i].set_name(value[i]->name);
    value[i]->attrs.FillAttrValueMap(funcs[i].mutable_attr());
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
                    funcs.get(), num_values));
}

namespace {

tensorflow::Status ValidateInputTypeAndPlacement(
    TFE_Context* ctx, tensorflow::Device* host_device,
    tensorflow::Device* op_device, TFE_Op* op,
    const tensorflow::OpKernel* kernel,
    std::vector<TFE_TensorHandle*>* copied_tensors) {
  const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
  if (memtypes.size() != op->inputs.size()) {
    return tensorflow::errors::InvalidArgument(
        "expected ", memtypes.size(), " inputs, got ", op->inputs.size());
  }
  for (int i = 0; i < op->inputs.size(); ++i) {
    const tensorflow::Device* expected_device =
        memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
    const tensorflow::Device* actual_device =
        op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
    if (expected_device != actual_device) {
      switch (TFE_ContextGetDevicePlacementPolicy(ctx)) {
        case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32:
          // TODO(xpan): See if we could bubble Python-related errors up
          // to the Python level.
          if (op->inputs[i].dtype() == tensorflow::DT_INT32) {
            // Note: enabling silent copies of int32 tensors to match behavior
            // of graph mode.
            break;
          }
          TF_FALLTHROUGH_INTENDED;
        case TFE_DEVICE_PLACEMENT_EXPLICIT:
          return tensorflow::errors::InvalidArgument(
              "Tensors on conflicting devices:"
              " cannot compute ",
              op->name, " as input #", i, " was expected to be on ",
              expected_device->name(), " but is actually on ",
              actual_device->name(), " (operation running on ",
              op_device->name(), ")",
              " Tensors can be copied explicitly using .gpu() or .cpu(),"
              " or transparently copied by using tfe.enable_eager_execution("
              "tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices"
              " may slow down your model");
        case TFE_DEVICE_PLACEMENT_WARN:
          LOG(WARNING) << "before computing " << op->name << " input #" << i
                       << " was expected to be on " << expected_device->name()
                       << " but is actually on " << actual_device->name()
                       << " (operation running on " << op_device->name()
                       << "). This triggers a copy which can be a performance "
                          "bottleneck.";
          break;
        case TFE_DEVICE_PLACEMENT_SILENT:  // Do nothing.
          break;
      }
      // We are only here if the policy is warn or silent copies, so we should
      // trigger a copy.
      TFE_TensorHandle original{op->inputs[i], op->input_devices[i]};
      TF_Status* s = TF_NewStatus();
      TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
          &original, ctx, expected_device->name().c_str(), s);
      if (!s->status.ok()) {
        tensorflow::Status status = s->status;
        delete s;
        return tensorflow::errors::Internal(
            "Failed copying input tensor from ", actual_device->name(), " to ",
            expected_device->name(), " in order to run ", op->name, ": ",
            status.error_message());
      }
      op->inputs[i] = copied_tensor->t;
      copied_tensors->push_back(copied_tensor);
      op->input_devices[i] = copied_tensor->d;
      delete s;
    }
    if (op->inputs[i].dtype() != kernel->input_type(i)) {
      return tensorflow::errors::InvalidArgument(
          "cannot compute ", op->name, " as input #", i,
          " was expected to be a ",
          tensorflow::DataTypeString(kernel->input_type(i)),
          " tensor but is a ",
          tensorflow::DataTypeString(op->inputs[i].dtype()), " tensor");
    }
  }
  return tensorflow::Status::OK();
}

#ifdef TENSORFLOW_EAGER_USE_XLA
// Synthesizes and returns a wrapper function over `op`, which must be a
// primitive op (e.g. matmul).
//
// The wrapper function conforms to the function signature expected by
// _XlaLaunchOp, with input params ordered by <constants, (variable) args and
// resources>. For example, if the op has input params <Const1, Arg2, Const3,
// Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5,
// Resource4> as the input params to the synthesized function.
//
// It populates `const_input_types`, `arg_input_types` and
// `op_input_to_func_input` based on the reordering results, so that the caller
// can use them to build an _XlaLaunchOp. On error, it returns nullptr and sets
// `status` accordingly.
const tensorflow::FunctionDef* OpToFunction(
    TFE_Op* op, std::vector<TF_DataType>* const_input_types,
    std::vector<TF_DataType>* arg_input_types,
    tensorflow::gtl::FlatMap<int, int>* op_input_to_func_input,
    TF_Status* status) {
  DCHECK(!op->is_function());

  tensorflow::FunctionDef fdef;

  // Get the OpDef of the op we are trying to encapsulate.
  TFE_Context* ctx = op->ctx;
  const tensorflow::OpRegistrationData* op_data;
  {
    tensorflow::tf_shared_lock l(ctx->functions_mu);
    status->status = ctx->func_lib_def.LookUp(op->name, &op_data);
    if (!status->status.ok()) {
      return nullptr;
    }
  }
  const tensorflow::OpDef& op_def = op_data->op_def;

  tensorflow::OpDef* signature = fdef.mutable_signature();

  // Handle constant inputs.
  const std::unordered_set<string> const_inputs(
      *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name));

  // First add placeholders for the input args, so that we can refer to them by
  // position in the next loop. Also tally up the resource inputs.
  int num_resource_inputs = 0;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) {
      ++num_resource_inputs;
    }
    signature->add_input_arg();
  }

  // Now we map the input params from `op_def` to `signature`, where the param
  // ordering for `signature` is: <constants, args, resources>.
  int const_index = 0;
  int arg_index = const_inputs.size();
  int resource_index = op_def.input_arg_size() - num_resource_inputs;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i);
    tensorflow::OpDef::ArgDef* func_input_arg = nullptr;
    if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) {
      VLOG(1) << "For const input, mapping op input " << i << " to func input "
              << const_index;
      (*op_input_to_func_input)[i] = const_index;
      func_input_arg = signature->mutable_input_arg(const_index++);
      const_input_types->push_back(
          static_cast<TF_DataType>(op->inputs[i].dtype()));
    } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) {
      VLOG(1) << "For resource input, mapping op input " << i
              << " to func input " << resource_index;
      (*op_input_to_func_input)[i] = resource_index;
      func_input_arg = signature->mutable_input_arg(resource_index++);
    } else {
      VLOG(1) << "For arg input, mapping op input " << i << " to func input "
              << arg_index;
      (*op_input_to_func_input)[i] = arg_index;
      func_input_arg = signature->mutable_input_arg(arg_index++);
      arg_input_types->push_back(
          static_cast<TF_DataType>(op->inputs[i].dtype()));
    }

    func_input_arg->set_name(op_input_arg.name());
    func_input_arg->set_type(op->inputs[i].dtype());
  }
  VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString();

  // Resource args are at the end of the function input params, and we should
  // have iterated over all of them.
  DCHECK_EQ(signature->input_arg_size(), resource_index);

  // Make the synthesized function's name unique.
  signature->set_name(tensorflow::strings::StrCat(
      op_def.name(), func_id_generator.fetch_add(1)));

  // Add the node def and set its input names to match op_def's names.
  const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
  DCHECK_EQ(signature->input_arg_size(), ndef.input_size());
  *fdef.add_node_def() = ndef;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name());
  }
  VLOG(1) << "Added NodeDef: " << fdef.DebugString();

  // Fix the output names and set output types.
  for (int i = 0; i < op_def.output_arg_size(); ++i) {
    tensorflow::OpDef::ArgDef* arg = signature->add_output_arg();
    const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i);
    const string& out_tensor_name = tensorflow::strings::StrCat(
        ndef.name(), ":", op_def_arg.name(), ":", 0);
    arg->set_name(op_def_arg.name());
    (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name;
    const string& type_attr = op_def_arg.type_attr();
    if (!type_attr.empty()) {
      auto i = ndef.attr().find(type_attr);
      if (i == ndef.attr().end()) {
        status->status = tensorflow::errors::InvalidArgument(
            tensorflow::strings::StrCat("Could not find attr ", type_attr,
                                        " in NodeDef ", ndef.DebugString()));
        return nullptr;
      }
      arg->set_type(i->second.type());
    }
  }
  VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString();

  tensorflow::mutex_lock l(ctx->functions_mu);
  status->status = ctx->func_lib_def.AddFunctionDef(fdef);
  if (!status->status.ok()) return nullptr;
  const auto ret = ctx->func_lib_def.Find(signature->name());
  DCHECK(ret != nullptr);
  return ret;
}

// Builds an _XlaLaunchOp as a wrapper over 'op', so that 'op' can be executed
// via XLA.
std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
  VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name;
  auto launch_op =
      std::unique_ptr<TFE_Op>(TFE_NewOp(op->ctx, "_XlaLaunch", status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  if (op->device) {
    TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }

  const tensorflow::FunctionDef* fdef;
  {
    tensorflow::tf_shared_lock l(op->ctx->functions_mu);
    fdef = op->ctx->func_lib_def.Find(op->name);
  }
  std::vector<TF_DataType> const_input_types;
  std::vector<TF_DataType> arg_input_types;
  tensorflow::gtl::FlatMap<int, int> op_input_to_func_input;
  if (fdef == nullptr) {
    // See if this is a primitive op, and if so create a function for it, so
    // that _XlaLaunchOp can access it.
    fdef = OpToFunction(op, &const_input_types, &arg_input_types,
                        &op_input_to_func_input, status);
    if (!status->status.ok()) return nullptr;
  } else {
    // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for
    // functions, so we need to find another way to handle constant inputs.
    for (int i = const_input_types.size();
         i < fdef->signature().input_arg_size(); ++i) {
      VLOG(1) << "Adding Targs from input arg " << i;
      const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i);
      arg_input_types.push_back(static_cast<TF_DataType>(arg.type()));
    }
  }
  DCHECK(fdef != nullptr);

  // Copy inputs and their devices.
  // Since input param reordering may have occurred between `op` and `launch_op`
  // via `op_input_to_func_input`, adjust the actual inputs accordingly.
  launch_op->inputs = op->inputs;
  launch_op->input_devices = op->input_devices;
  if (!op_input_to_func_input.empty()) {
    DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size());
    if (!op->input_devices.empty()) {
      DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size());
    }
    for (int i = 0; i < op_input_to_func_input.size(); ++i) {
      VLOG(1) << "mapping op input " << i << " to func input "
              << op_input_to_func_input[i];

      launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i];
      if (!op->input_devices.empty()) {
        launch_op->input_devices[op_input_to_func_input[i]] =
            op->input_devices[i];
      }
    }
  }
  launch_op->attrs.NumInputs(op->inputs.size());

  TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(),
                        const_input_types.size());

  // Set Targs and Nresources attrs.
  TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(),
                        arg_input_types.size());
  const int num_resource_inputs = fdef->signature().input_arg_size() -
                                  const_input_types.size() -
                                  arg_input_types.size();
  TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs);

  // Set Tresults attr.
  std::vector<TF_DataType> tresults;
  for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) {
    tresults.push_back(static_cast<TF_DataType>(arg.type()));
  }
  TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(),
                        tresults.size());

  // Set function attr.
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(fdef->signature().name());
  launch_op->attrs.Set("function", attr_value);

  return launch_op;
}
#endif  // TENSORFLOW_EAGER_USE_XLA
}  // namespace

void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                 TF_Status* status) {
  TFE_Context* ctx = op->ctx;
  // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU
  tensorflow::Device* device =
      (op->device == nullptr) ? ctx->devices()[0] : op->device;

#ifdef TENSORFLOW_EAGER_USE_XLA
  std::unique_ptr<TFE_Op> xla_launch_op;
  if (op->use_xla && op->name != "_XlaLaunch") {
    xla_launch_op = BuildXlaLaunch(op, status);
    if (!status->status.ok()) {
      return;
    }
    op = xla_launch_op.get();
  }
#endif  // TENSORFLOW_EAGER_USE_XLA

  std::vector<tensorflow::Tensor> outputs(1);
  const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
  tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name());
  tensorflow::KernelAndDevice* kernel;
  {
    tensorflow::tf_shared_lock l(ctx->cache_mu);
    kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
  }
  if (kernel == nullptr) {
    const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
    kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
    // Knowledge of the implementation of Init (and, in turn,
    // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
    // will be accessed, so grab on to the lock.
    // See WARNING comment below - would be nice to rework to avoid this
    // subtlety.
    tensorflow::tf_shared_lock l(ctx->functions_mu);
    status->status =
        tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
    if (!status->status.ok()) {
      delete kernel;
      return;
    }
    tensorflow::mutex_lock ml(ctx->cache_mu);
    tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
  }
  std::vector<TFE_TensorHandle*> copied_tensors;
  status->status = ValidateInputTypeAndPlacement(
      ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors);
  output_memory_types = &kernel->kernel()->output_memory_types();
  if (!status->status.ok()) {
    for (auto* t : copied_tensors) {
      TFE_DeleteTensorHandle(t);
    }
    return;
  }
  std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
  if (ctx->should_store_metadata.load()) {
    maybe_stats.reset(new tensorflow::NodeExecStats);
    maybe_stats->set_node_name(op->name);
    maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros());
    maybe_stats->set_op_start_rel_micros(0);
    maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros());
    // TODO(apassos) track referenced tensors
  }
  // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
  // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
  // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation
  // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
  // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
  // This is quite subtle. Re-work things to make this better?  (Would it make
  // sense for FunctionLibraryRuntime to ensure thread-safe access to
  // FunctionLibraryDefinition?).  TODO(apassos) figure out how to record stats
  // for ops which are a part of functions.
  status->status = kernel->Run(&op->inputs, &outputs, maybe_stats.get());
  for (auto* t : copied_tensors) {
    TFE_DeleteTensorHandle(t);
  }
  if (!status->status.ok()) return;
  if (maybe_stats != nullptr) {
    maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() -
                                       maybe_stats->all_start_micros());
    tensorflow::mutex_lock ml(ctx->metadata_mu);
    if (ctx->should_store_metadata.load()) {
      auto* step_stats = ctx->run_metadata.mutable_step_stats();
      // Lazily initialize the RunMetadata with information about all devices if
      // this is the first call.
      while (step_stats->dev_stats_size() < ctx->devices().size()) {
        step_stats->add_dev_stats();
      }
      // Find the current device's index.
      int device_idx = 0;
      for (int i = 0; i < ctx->devices().size(); ++i) {
        if (ctx->devices()[i] == device) {
          device_idx = i;
          break;
        }
      }
      // Populate the device stats for this device.
      auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
      dev_stats->set_device(device->name());
      *dev_stats->add_node_stats() = *maybe_stats;
    }
  }
  *num_retvals = std::min<int>(*num_retvals, outputs.size());
  for (int i = 0; i < *num_retvals; ++i) {
    tensorflow::Device* d = IsCPU(device) ? nullptr : device;
    if (d != nullptr && output_memory_types != nullptr &&
        (*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
      d = nullptr;
    }
    retvals[i] = new TFE_TensorHandle(outputs[i], d);
  }
}

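// A fuller sketch of eager execution of a primitive op, assuming `a` and `b`
// are CPU float matrix handles built as in the TFE_NewTensorHandle sketch
// above (error checking elided):
//
//   TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
//   TFE_OpAddInput(matmul, a, status);
//   TFE_OpAddInput(matmul, b, status);
//   TFE_OpSetAttrType(matmul, "T", TFE_TensorHandleDataType(a));
//   TFE_OpSetAttrBool(matmul, "transpose_a", 0);
//   TFE_OpSetAttrBool(matmul, "transpose_b", 0);
//   TFE_TensorHandle* retvals[1];
//   int num_retvals = 1;
//   TFE_Execute(matmul, retvals, &num_retvals, status);
//   TFE_DeleteOp(matmul);
//   // ... use retvals[0], then TFE_DeleteTensorHandle(retvals[0]) ...
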
void TFE_ContextAddFunctionDef(TFE_Context* ctx,
                               const char* serialized_function_def, size_t size,
                               TF_Status* status) {
  tensorflow::FunctionDef function_def;
  if (!function_def.ParseFromArray(serialized_function_def, size)) {
    status->status =
        tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
    return;
  }
  tensorflow::mutex_lock l(ctx->functions_mu);
  status->status = ctx->func_lib_def.AddFunctionDef(function_def);
}

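// A brief sketch of registering a function so that TFE_NewOp can find it by
// name; `fdef` stands in for a tensorflow::FunctionDef built elsewhere:
//
//   std::string serialized = fdef.SerializeAsString();
//   TFE_ContextAddFunctionDef(ctx, serialized.data(), serialized.size(), status);
//   TFE_Op* call = TFE_NewOp(ctx, fdef.signature().name().c_str(), status);
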
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                            TF_Status* status) {
  tensorflow::mutex_lock l(ctx->functions_mu);
  status->status = ctx->func_lib_def.AddFunctionDef(function->fdef);
}

}  // extern "C"

TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
  return new TFE_TensorHandle(t, nullptr);
}

const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
    TFE_TensorHandle* h, TF_Status* status) {
  if (h->d != nullptr) {
    status->status = tensorflow::errors::FailedPrecondition(
        "TFE_TensorHandle is placed in device (not host) memory. Cannot return "
        "a tensorflow::Tensor");
    return nullptr;
  }
  return &h->t;
}

void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
  ctx->should_store_metadata.store(true);
}

void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
  tensorflow::mutex_lock ml(ctx->metadata_mu);
  ctx->should_store_metadata.store(false);
  ctx->run_metadata.Clear();
}

void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                  TF_Status* status) {
  tensorflow::mutex_lock ml(ctx->metadata_mu);
  status->status = MessageToBuffer(ctx->run_metadata, buf);
  ctx->run_metadata.Clear();
}
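
// A small sketch of collecting per-op metadata: enable collection, run ops,
// then export, which also clears the accumulated RunMetadata:
//
//   TFE_ContextEnableRunMetadata(ctx);
//   // ... TFE_Execute(...) one or more ops ...
//   TF_Buffer* buf = TF_NewBuffer();
//   TFE_ContextExportRunMetadata(ctx, buf, status);
//   // buf->data now holds a serialized tensorflow.RunMetadata proto.
//   TF_DeleteBuffer(buf);
//   TFE_ContextDisableRunMetadata(ctx);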
    942