/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api.h"

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstring>
#include <memory>
#include <string>
#include <thread>
#include <unordered_set>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/runtime.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif  // TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"

using tensorflow::int64;
using tensorflow::string;

namespace {

bool IsCPU(const tensorflow::Device* d) {
  return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}

bool IsXLA(const tensorflow::Device* d) {
  if (d == nullptr) return false;
  const auto& device_type = d->attributes().device_type();
  return device_type.find("XLA") != std::string::npos;
}

string DeviceName(const tensorflow::Device* d) {
  return (d == nullptr) ? "cpu:0" : d->name();
}

#ifdef TENSORFLOW_EAGER_USE_XLA
std::atomic_int_fast64_t func_id_generator(0);
#endif  // TENSORFLOW_EAGER_USE_XLA

}  // namespace

extern "C" {

TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }

void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
                                 size_t proto_len, TF_Status* status) {
  TF_SetConfig(&options->session_options, proto, proto_len, status);
}

void TFE_ContextOptionsSetDevicePlacementPolicy(
    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
  options->policy = policy;
}

void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }

TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  TF_Graph* graph = TF_NewGraph();
  TF_Session* session = TF_NewSession(graph, &opts->session_options, status);
  if (status->status.ok()) {
    if (session->device_mgr == nullptr || session->devices.empty()) {
      status->status = tensorflow::errors::InvalidArgument(
          "Provided TF_SessionOptions are not compatible with eager execution "
          "(perhaps the TF_SessionOptions alluded to session execution in a "
          "remote address space?)");
    }
  }
  if (!status->status.ok()) {
    TF_DeleteGraph(graph);
    return nullptr;
  }

  return new TFE_Context(*opts, session);
}
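// Typical context lifecycle (a minimal sketch; error checking elided):
//   TF_Status* status = TF_NewStatus();
//   TFE_ContextOptions* opts = TFE_NewContextOptions();
//   TFE_ContextOptionsSetDevicePlacementPolicy(opts,
//                                              TFE_DEVICE_PLACEMENT_SILENT);
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   TFE_DeleteContextOptions(opts);  // TFE_NewContext copies the options.
//   ... run ops ...
//   TFE_DeleteContext(ctx, status);
//   TF_DeleteStatus(status);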
"cpu:0" : d->name(); 63 } 64 65 #ifdef TENSORFLOW_EAGER_USE_XLA 66 std::atomic_int_fast64_t func_id_generator(0); 67 #endif // TENSORFLOW_EAGER_USE_XLA 68 } // namespace 69 70 extern "C" { 71 72 TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; } 73 74 void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, 75 size_t proto_len, TF_Status* status) { 76 TF_SetConfig(&options->session_options, proto, proto_len, status); 77 } 78 79 void TFE_ContextOptionsSetDevicePlacementPolicy( 80 TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { 81 options->policy = policy; 82 } 83 84 void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } 85 86 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { 87 TF_Graph* graph = TF_NewGraph(); 88 TF_Session* session = TF_NewSession(graph, &opts->session_options, status); 89 if (status->status.ok()) { 90 if (session->device_mgr == nullptr || session->devices.empty()) { 91 status->status = tensorflow::errors::InvalidArgument( 92 "Provided TF_SessionOptions are not compatible with eager execution " 93 "(perhaps the TF_SessionOptions alluded to session execution in a " 94 "remote address space?)"); 95 } 96 } 97 if (!status->status.ok()) { 98 TF_DeleteGraph(graph); 99 return nullptr; 100 } 101 102 return new TFE_Context(*opts, session); 103 } 104 105 void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { 106 status->status = tensorflow::Status::OK(); 107 { 108 tensorflow::mutex_lock ml(ctx->cache_mu); 109 tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); 110 } 111 TF_Graph* graph = ctx->session->graph; 112 TF_DeleteSession(ctx->session, status); 113 TF_DeleteGraph(graph); 114 ctx->rendezvous->Unref(); 115 delete ctx; 116 } 117 118 TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { 119 return TF_SessionListDevices(ctx->session, status); 120 } 121 122 void TFE_ContextClearCaches(TFE_Context* ctx) { 123 tensorflow::mutex_lock ml(ctx->cache_mu); 124 tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); 125 } 126 127 void TFE_ContextSetThreadLocalDevicePlacementPolicy( 128 TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { 129 tensorflow::mutex_lock ml(ctx->policy_map_mu); 130 ctx->thread_local_policies[std::this_thread::get_id()] = policy; 131 } 132 133 extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( 134 TFE_Context* ctx) { 135 tensorflow::mutex_lock ml(ctx->policy_map_mu); 136 auto policy_map_it = 137 ctx->thread_local_policies.find(std::this_thread::get_id()); 138 if (policy_map_it != ctx->thread_local_policies.end()) { 139 return policy_map_it->second; 140 } 141 return ctx->policy; 142 } 143 144 TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { 145 tensorflow::Tensor tensor; 146 status->status = tensorflow::TF_TensorToTensor(t, &tensor); 147 if (!status->status.ok()) return nullptr; 148 return new TFE_TensorHandle(tensor, nullptr); 149 } 150 151 void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; } 152 153 TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { 154 return static_cast<TF_DataType>(h->t.dtype()); 155 } 156 157 int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); } 158 159 int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) { 160 return h->t.dim_size(dim_index); 161 } 162 163 const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) { 164 // TODO(apassos) this will be potentially incorrect in the 
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) {
  // TODO(apassos) this will be potentially incorrect in the distributed case
  // as our local device will have a name which depends on the ClusterSpec and
  // hence will require the context to resolve.
  return (h->d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                           : h->d->name().c_str();
}

TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
  if (!IsCPU(h->d)) {
    TF_SetStatus(status, TF_UNIMPLEMENTED,
                 tensorflow::strings::StrCat(
                     "TFE_TensorHandle can be resolved iff it is on CPU (this "
                     "handle is on ",
                     h->d->name(),
                     "). Consider using TFE_TensorHandleCopyToDevice to get a "
                     "copy of the tensor on CPU")
                     .c_str());
    return nullptr;
  }
  return tensorflow::TF_TensorFromTensor(h->t, status);
}
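// Reading back data (a minimal sketch; assumes the handle holds DT_FLOAT data
// resident on CPU):
//   TF_Tensor* t = TFE_TensorHandleResolve(h, status);
//   if (TF_GetCode(status) == TF_OK) {
//     const float* data = static_cast<const float*>(TF_TensorData(t));
//     ...
//     TF_DeleteTensor(t);
//   }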
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                               TFE_Context* ctx,
                                               const char* device_name,
                                               TF_Status* status) {
  tensorflow::Device* dstd = ctx->devices()[0];
  if (device_name != nullptr && strlen(device_name) > 0) {
    status->status =
        ctx->session->device_mgr->LookupDevice(device_name, &dstd);
    if (!status->status.ok()) return nullptr;
  }

  tensorflow::Device* srcd = h->d == nullptr ? ctx->devices()[0] : h->d;
  bool is_same_device =
      (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
  const bool dst_cpu = IsCPU(dstd);
  const bool src_cpu = IsCPU(srcd);
  // both_on_cpu can be true and yet is_same_device is false, if one of src/dst
  // has device type XLA_CPU, and the other CPU.
  const bool both_on_cpu = src_cpu && dst_cpu;
  if (is_same_device || both_on_cpu) {
    return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd);
  }
  tensorflow::Tensor* src = &(h->t);
  if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
                   !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Can't copy Tensor with type ",
                                    tensorflow::DataTypeString(src->dtype()),
                                    " to device ", DeviceName(dstd), ".")
            .c_str());
    return nullptr;
  }
  tensorflow::AllocatorAttributes attr;
  if (src->dtype() == tensorflow::DT_VARIANT) {
    attr.set_on_host(true);
  }
  tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
  if (src->shape().num_elements() == 0) {
    return new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd);
  }
  tensorflow::DeviceContext* src_device_context = nullptr;
  if (!src_cpu) {
    src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
  }
  tensorflow::DeviceContext* dst_device_context = nullptr;
  if (!dst_cpu) {
    dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
  }
  // TODO(ashankar): The Sync() call below may be more aggressive than
  // necessary. It is based on knowledge of implementation details - that
  // GPU devices are implemented using 3 streams - one for host->device copies,
  // one for device->host copies and one for sending operations to the GPU.
  // With that setup, Sync()ing across all 3 streams should be sufficient
  // but more than necessary (since it waits for operations that might have
  // nothing to do with this tensor to complete).
  status->status = srcd->Sync();
  tensorflow::Notification n;
  tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
                                 srcd, dstd, tensorflow::AllocatorAttributes(),
                                 tensorflow::AllocatorAttributes(), src, &dst,
                                 [status, &n](const tensorflow::Status& s) {
                                   status->status = s;
                                   n.Notify();
                                 });
  n.WaitForNotification();
  return (TF_GetCode(status) == TF_OK)
             ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd)
             : nullptr;
}
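// Moving a tensor between devices (sketch; the device name is illustrative
// and depends on the devices actually present):
//   TFE_TensorHandle* h_gpu =
//       TFE_TensorHandleCopyToDevice(h, ctx, "/device:GPU:0", status);
//   if (TF_GetCode(status) == TF_OK) {
//     ...
//     TFE_DeleteTensorHandle(h_gpu);
//   }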
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  const char* name = op_or_function_name;  // Shorthand
  const tensorflow::AttrTypeMap* types;
  status->status = tensorflow::AttrTypeMapForOp(name, &types);
  if (status->status.ok()) return new TFE_Op(ctx, name, types);
  if (TF_GetCode(status) == TF_NOT_FOUND) {
    tensorflow::mutex_lock l(ctx->functions_mu);
    if (ctx->func_lib_def.Find(name) != nullptr) {
      status->status = tensorflow::Status::OK();
      return new TFE_Op(ctx, name, nullptr);
    }
  }
  return nullptr;
}

void TFE_DeleteOp(TFE_Op* op) { delete op; }

void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
  tensorflow::Device* d = nullptr;
  if (device_name != nullptr && strlen(device_name) > 0) {
    status->status =
        op->ctx->session->device_mgr->LookupDevice(device_name, &d);
    if (!status->status.ok()) return;
  }
  op->device = d;
}

const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
  tensorflow::Device* device =
      (op->device == nullptr) ? op->ctx->devices()[0] : op->device;
  return device->name().c_str();
}

void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
  op->use_xla = enable;
#ifndef TENSORFLOW_EAGER_USE_XLA
  LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
                  "built with XLA support.";
#endif  // TENSORFLOW_EAGER_USE_XLA
}

void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
  // Questionable heuristic ...
  //
  // Motivation: After an 'op' is placed on GPU because some of its earlier
  // inputs are on GPU, we want to keep the 'op' there, even if some later
  // inputs of it are not on GPU.
  if (IsCPU(op->device) && !IsCPU(h->d)) {
    op->device = h->d;
  }
  if (!status->status.ok()) return;
  op->inputs.push_back(h->t);
  op->input_devices.push_back(h->d);
  op->attrs.NumInputs(op->inputs.size());
}

TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                              unsigned char* is_list, TF_Status* status) {
  TF_AttrType ret;
  if (op->is_function()) {
    status->status = tensorflow::errors::Unimplemented(
        "TODO(apassos): Support for attributes for TensorFlow functions is "
        "not ready yet.");
    return TF_ATTR_INT;  // The compiler requires that we return something.
  }
  status->status =
      tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list);
  return ret;
}

TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
                                  const char* op_or_function_name,
                                  const char* attr_name, unsigned char* is_list,
                                  TF_Status* status) {
  TF_AttrType ret;
  TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
  if (!status->status.ok()) {
    return TF_ATTR_INT;  // Same dummy return as TFE_OpGetAttrType.
  }
  ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
  TFE_DeleteOp(op);
  return ret;
}

void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) {
  op->attrs.Set(attr_name, value);
}

void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
  op->attrs.Set(attr_name, static_cast<int64>(value));
}

void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
  op->attrs.Set(attr_name, value);
}

void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
  op->attrs.Set(attr_name, (value == 0) ? false : true);
}

void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
  op->attrs.Set(attr_name, static_cast<tensorflow::DataType>(value));
}

void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
                        const int num_dims, TF_Status* out_status) {
  if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 tensorflow::strings::StrCat(
                     "Value specified for `", attr_name, "` has ", num_dims,
                     " dimensions which is over the limit of ",
                     tensorflow::TensorShape::MaxDimensions(), ".")
                     .c_str());
    return;
  }
  tensorflow::TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }
  op->attrs.Set(attr_name, proto);
}

void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
                           const TFE_Op* value) {
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(value->name);
  value->attrs.FillAttrValueMap(func->mutable_attr());
  op->attrs.Set(attr_name, attr_value);
}

#define TFE_OP_SET_ATTR_LIST(fn, type)                                \
  void fn(TFE_Op* op, const char* attr_name, const type* values,     \
          int num_values) {                                           \
    op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice<const type>( \
                                 values, num_values));                \
  }
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*)
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
#undef TFE_OP_SET_ATTR_LIST
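// Building an op (a minimal sketch; assumes `ctx` and two CPU float handles
// `a` and `b` already exist; statuses should be checked after each call):
//   TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
//   TFE_OpAddInput(matmul, a, status);
//   TFE_OpAddInput(matmul, b, status);
//   TFE_OpSetAttrType(matmul, "T", TFE_TensorHandleDataType(a));
//   TFE_OpSetAttrBool(matmul, "transpose_a", 0);
//   ... execute via TFE_Execute (see below), then TFE_DeleteOp(matmul).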
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                          const int64_t* values, int num_values) {
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const int64>(
                    reinterpret_cast<const int64*>(values), num_values));
}

void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                           const TF_DataType* values, int num_values) {
  op->attrs.Set(
      attr_name,
      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
          reinterpret_cast<const tensorflow::DataType*>(values), num_values));
}

void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
                           const unsigned char* values, int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}

void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                            const int64_t** dims, const int* num_dims,
                            int num_values, TF_Status* out_status) {
  std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
      new tensorflow::TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
      TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                   tensorflow::strings::StrCat(
                       "Value specified for `", attr_name, "` has ", num_dims_i,
                       " dimensions which is over the limit of ",
                       tensorflow::TensorShape::MaxDimensions(), ".")
                       .c_str());
      return;
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
                    proto.get(), num_values));
}

void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
                               const TFE_Op** value, int num_values) {
  std::unique_ptr<tensorflow::NameAttrList[]> funcs(
      new tensorflow::NameAttrList[num_values]);
  for (int i = 0; i < num_values; i++) {
    funcs[i].set_name(value[i]->name);
    value[i]->attrs.FillAttrValueMap(funcs[i].mutable_attr());
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
                    funcs.get(), num_values));
}
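// Shape attributes (sketch): a dimension size of -1 marks an unknown
// dimension, and num_dims == -1 marks an entirely unknown rank:
//   int64_t dims[] = {-1, 784};  // e.g. an unknown batch dimension
//   TFE_OpSetAttrShape(op, "shape", dims, 2, status);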
namespace {

tensorflow::Status ValidateInputTypeAndPlacement(
    TFE_Context* ctx, tensorflow::Device* host_device,
    tensorflow::Device* op_device, TFE_Op* op,
    const tensorflow::OpKernel* kernel,
    std::vector<TFE_TensorHandle*>* copied_tensors) {
  const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
  if (memtypes.size() != op->inputs.size()) {
    return tensorflow::errors::InvalidArgument(
        "expected ", memtypes.size(), " inputs, got ", op->inputs.size());
  }
  for (int i = 0; i < op->inputs.size(); ++i) {
    const tensorflow::Device* expected_device =
        memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
    const tensorflow::Device* actual_device =
        op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
    if (expected_device != actual_device) {
      switch (TFE_ContextGetDevicePlacementPolicy(ctx)) {
        case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32:
          // TODO(xpan): See if we could bubble python related error up
          // to python level.
          if (op->inputs[i].dtype() == tensorflow::DT_INT32) {
            // Note: enabling silent copies of int32 tensors to match behavior
            // of graph mode.
            break;
          }
          TF_FALLTHROUGH_INTENDED;
        case TFE_DEVICE_PLACEMENT_EXPLICIT:
          return tensorflow::errors::InvalidArgument(
              "Tensors on conflicting devices:"
              " cannot compute ",
              op->name, " as input #", i, " was expected to be on ",
              expected_device->name(), " but is actually on ",
              actual_device->name(), " (operation running on ",
              op_device->name(), ")",
              " Tensors can be copied explicitly using .gpu() or .cpu(),"
              " or transparently copied by using tfe.enable_eager_execution("
              "tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices"
              " may slow down your model");
        case TFE_DEVICE_PLACEMENT_WARN:
          LOG(WARNING) << "before computing " << op->name << " input #" << i
                       << " was expected to be on " << expected_device->name()
                       << " but is actually on " << actual_device->name()
                       << " (operation running on " << op_device->name()
                       << "). This triggers a copy which can be a performance "
                          "bottleneck.";
          break;
        case TFE_DEVICE_PLACEMENT_SILENT:  // Do nothing.
          break;
      }
      // We are only here if the policy is warn or silent copies, so we should
      // trigger a copy.
      TFE_TensorHandle original{op->inputs[i], op->input_devices[i]};
      TF_Status* s = TF_NewStatus();
      TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
          &original, ctx, expected_device->name().c_str(), s);
      if (!s->status.ok()) {
        tensorflow::Status status = s->status;
        delete s;
        return tensorflow::errors::Internal(
            "Failed copying input tensor from ", actual_device->name(), " to ",
            expected_device->name(), " in order to run ", op->name, ": ",
            status.error_message());
      }
      op->inputs[i] = copied_tensor->t;
      copied_tensors->push_back(copied_tensor);
      op->input_devices[i] = copied_tensor->d;
      delete s;
    }
    if (op->inputs[i].dtype() != kernel->input_type(i)) {
      return tensorflow::errors::InvalidArgument(
          "cannot compute ", op->name, " as input #", i,
          " was expected to be a ",
          tensorflow::DataTypeString(kernel->input_type(i)),
          " tensor but is a ",
          tensorflow::DataTypeString(op->inputs[i].dtype()), " tensor");
    }
  }
  return tensorflow::Status::OK();
}
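// Opting into silent copies (sketch): a thread can relax its placement policy
// so mismatched inputs are transparently copied rather than rejected:
//   TFE_ContextSetThreadLocalDevicePlacementPolicy(
//       ctx, TFE_DEVICE_PLACEMENT_SILENT);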
#ifdef TENSORFLOW_EAGER_USE_XLA
// Synthesizes and returns a wrapper function over `op`, which must be a
// primitive op (e.g. matmul).
//
// The wrapper function conforms to the function signature expected by
// _XlaLaunchOp, with input params ordered by <constants, (variable) args and
// resources>. For example, if the op has input params <Const1, Arg2, Const3,
// Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5,
// Resource4> as the input params to the synthesized function.
//
// It populates `const_input_types`, `arg_input_types` and
// `op_input_to_func_input` based on the reordering results, so that the
// caller can use them to build an _XlaLaunchOp. On error, it returns NULL,
// and sets `status` accordingly.
const tensorflow::FunctionDef* OpToFunction(
    TFE_Op* op, std::vector<TF_DataType>* const_input_types,
    std::vector<TF_DataType>* arg_input_types,
    tensorflow::gtl::FlatMap<int, int>* op_input_to_func_input,
    TF_Status* status) {
  DCHECK(!op->is_function());

  tensorflow::FunctionDef fdef;

  // Get the OpDef of the op we are trying to encapsulate.
  TFE_Context* ctx = op->ctx;
  const tensorflow::OpRegistrationData* op_data;
  {
    tensorflow::tf_shared_lock l(ctx->functions_mu);
    status->status = ctx->func_lib_def.LookUp(op->name, &op_data);
    if (!status->status.ok()) {
      return nullptr;
    }
  }
  const tensorflow::OpDef& op_def = op_data->op_def;

  tensorflow::OpDef* signature = fdef.mutable_signature();

  // Handle constant inputs.
  const std::unordered_set<string> const_inputs(
      *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name));

  // First add placeholders for the input args, so that we can refer to them
  // by position in the next loop. Also tally up the resource inputs.
  int num_resource_inputs = 0;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) {
      ++num_resource_inputs;
    }
    signature->add_input_arg();
  }

  // Now we map the input params from `op_def` to `signature`, where the param
  // ordering for `signature` is: <constants, args, resources>.
  int const_index = 0;
  int arg_index = const_inputs.size();
  int resource_index = op_def.input_arg_size() - num_resource_inputs;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i);
    tensorflow::OpDef::ArgDef* func_input_arg = nullptr;
    if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) {
      VLOG(1) << "For const input, mapping op input " << i << " to func input "
              << const_index;
      (*op_input_to_func_input)[i] = const_index;
      func_input_arg = signature->mutable_input_arg(const_index++);
      const_input_types->push_back(
          static_cast<TF_DataType>(op->inputs[i].dtype()));
    } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) {
      VLOG(1) << "For resource input, mapping op input " << i
              << " to func input " << resource_index;
      (*op_input_to_func_input)[i] = resource_index;
      func_input_arg = signature->mutable_input_arg(resource_index++);
    } else {
      VLOG(1) << "For arg input, mapping op input " << i << " to func input "
              << arg_index;
      (*op_input_to_func_input)[i] = arg_index;
      func_input_arg = signature->mutable_input_arg(arg_index++);
      arg_input_types->push_back(
          static_cast<TF_DataType>(op->inputs[i].dtype()));
    }

    func_input_arg->set_name(op_input_arg.name());
    func_input_arg->set_type(op->inputs[i].dtype());
  }
  VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString();

  // Resource args are at the end of the function input params, and we should
  // have iterated over all of them.
  DCHECK_EQ(signature->input_arg_size(), resource_index);

  // Make the synthesized function's name unique.
  signature->set_name(tensorflow::strings::StrCat(
      op_def.name(), func_id_generator.fetch_add(1)));

  // Add the node def and set its input names to match op_def's names.
  const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
  DCHECK_EQ(signature->input_arg_size(), ndef.input_size());
  *fdef.add_node_def() = ndef;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name());
  }
  VLOG(1) << "Added NodeDef: " << fdef.DebugString();
  // Fix the output names and set output types.
  for (int i = 0; i < op_def.output_arg_size(); ++i) {
    tensorflow::OpDef::ArgDef* arg = signature->add_output_arg();
    const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i);
    const string& out_tensor_name = tensorflow::strings::StrCat(
        ndef.name(), ":", op_def_arg.name(), ":", 0);
    arg->set_name(op_def_arg.name());
    (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name;
    const string& type_attr = op_def_arg.type_attr();
    if (!type_attr.empty()) {
      auto i = ndef.attr().find(type_attr);
      if (i == ndef.attr().end()) {
        status->status = tensorflow::errors::InvalidArgument(
            tensorflow::strings::StrCat("Could not find attr ", type_attr,
                                        " in NodeDef ", ndef.DebugString()));
        return nullptr;
      }
      arg->set_type(i->second.type());
    }
  }
  VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString();

  tensorflow::mutex_lock l(ctx->functions_mu);
  status->status = ctx->func_lib_def.AddFunctionDef(fdef);
  if (!status->status.ok()) return nullptr;
  const auto ret = ctx->func_lib_def.Find(signature->name());
  DCHECK(ret != nullptr);
  return ret;
}

// Builds an _XlaLaunchOp as a wrapper over 'op', so that 'op' can be executed
// via XLA.
std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
  VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name;
  auto launch_op =
      std::unique_ptr<TFE_Op>(TFE_NewOp(op->ctx, "_XlaLaunch", status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  if (op->device) {
    TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }

  const tensorflow::FunctionDef* fdef;
  {
    tensorflow::tf_shared_lock l(op->ctx->functions_mu);
    fdef = op->ctx->func_lib_def.Find(op->name);
  }
  std::vector<TF_DataType> const_input_types;
  std::vector<TF_DataType> arg_input_types;
  tensorflow::gtl::FlatMap<int, int> op_input_to_func_input;
  if (fdef == nullptr) {
    // See if this is a primitive op, and if so create a function for it, so
    // that _XlaLaunchOp can access it.
    fdef = OpToFunction(op, &const_input_types, &arg_input_types,
                        &op_input_to_func_input, status);
    if (!status->status.ok()) return nullptr;
  } else {
    // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work
    // for functions, so we need to find another way to handle constant
    // inputs.
    for (int i = const_input_types.size();
         i < fdef->signature().input_arg_size(); ++i) {
      VLOG(1) << "Adding Targs from input arg " << i;
      const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i);
      arg_input_types.push_back(static_cast<TF_DataType>(arg.type()));
    }
  }
  DCHECK(fdef != nullptr);

  // Copy inputs and their devices.
  // Since input param reordering may have occurred between `op` and
  // `launch_op` via `op_input_to_func_input`, adjust the actual inputs
  // accordingly.
  launch_op->inputs = op->inputs;
  launch_op->input_devices = op->input_devices;
  if (!op_input_to_func_input.empty()) {
    DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size());
    if (!op->input_devices.empty()) {
      DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size());
    }
    for (int i = 0; i < op_input_to_func_input.size(); ++i) {
      VLOG(1) << "mapping op input " << i << " to func input "
              << op_input_to_func_input[i];

      launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i];
      if (!op->input_devices.empty()) {
        launch_op->input_devices[op_input_to_func_input[i]] =
            op->input_devices[i];
      }
    }
  }
  launch_op->attrs.NumInputs(op->inputs.size());

  TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants",
                        const_input_types.data(), const_input_types.size());

  // Set Targs and Nresources attrs.
  TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(),
                        arg_input_types.size());
  const int num_resource_inputs = fdef->signature().input_arg_size() -
                                  const_input_types.size() -
                                  arg_input_types.size();
  TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs);

  // Set Tresults attr.
  std::vector<TF_DataType> tresults;
  for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) {
    tresults.push_back(static_cast<TF_DataType>(arg.type()));
  }
  TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(),
                        tresults.size());

  // Set function attr.
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(fdef->signature().name());
  launch_op->attrs.Set("function", attr_value);

  return launch_op;
}
#endif  // TENSORFLOW_EAGER_USE_XLA

}  // namespace
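// Requesting XLA compilation (sketch): when the library is built with
// TENSORFLOW_EAGER_USE_XLA, TFE_Execute below wraps such an op in _XlaLaunch;
// otherwise the call merely logs a warning:
//   TFE_OpSetXLACompilation(op, 1);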
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                 TF_Status* status) {
  TFE_Context* ctx = op->ctx;
  // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU
  tensorflow::Device* device =
      (op->device == nullptr) ? ctx->devices()[0] : op->device;

#ifdef TENSORFLOW_EAGER_USE_XLA
  std::unique_ptr<TFE_Op> xla_launch_op;
  if (op->use_xla && op->name != "_XlaLaunch") {
    xla_launch_op = BuildXlaLaunch(op, status);
    if (!status->status.ok()) {
      return;
    }
    op = xla_launch_op.get();
  }
#endif  // TENSORFLOW_EAGER_USE_XLA

  std::vector<tensorflow::Tensor> outputs(1);
  const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
  tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name());
  tensorflow::KernelAndDevice* kernel;
  {
    tensorflow::tf_shared_lock l(ctx->cache_mu);
    kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
  }
  if (kernel == nullptr) {
    const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
    kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
    // Knowledge of the implementation of Init (and in turn
    // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
    // will be accessed, so grab on to the lock.
    // See WARNING comment below - would be nice to rework to avoid this
    // subtlety.
    tensorflow::tf_shared_lock l(ctx->functions_mu);
    status->status =
        tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
    if (!status->status.ok()) {
      delete kernel;
      return;
    }
    tensorflow::mutex_lock ml(ctx->cache_mu);
    tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
  }
  std::vector<TFE_TensorHandle*> copied_tensors;
  status->status = ValidateInputTypeAndPlacement(
      ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors);
  output_memory_types = &kernel->kernel()->output_memory_types();
  if (!status->status.ok()) {
    for (auto* t : copied_tensors) {
      TFE_DeleteTensorHandle(t);
    }
    return;
  }
  std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
  if (ctx->should_store_metadata.load()) {
    maybe_stats.reset(new tensorflow::NodeExecStats);
    maybe_stats->set_node_name(op->name);
    maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros());
    maybe_stats->set_op_start_rel_micros(0);
    maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros());
    // TODO(apassos) track referenced tensors
  }
  // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
  // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
  // which is GUARDED_BY(ctx->functions_mu). But knowledge of the
  // implementation of FunctionLibraryRuntime tells us that func_lib_def is
  // not accessed by FunctionLibraryRuntime::Run(), so there is no
  // thread-safety concern here. This is quite subtle. Re-work things to make
  // this better? (Would it make sense for FunctionLibraryRuntime to ensure
  // thread-safe access to FunctionLibraryDefinition?).
  // TODO(apassos) figure out how to record stats for ops which are a part of
  // functions.
  status->status = kernel->Run(&op->inputs, &outputs, maybe_stats.get());
  for (auto* t : copied_tensors) {
    TFE_DeleteTensorHandle(t);
  }
  if (!status->status.ok()) return;
  if (maybe_stats != nullptr) {
    maybe_stats->set_op_end_rel_micros(
        tensorflow::Env::Default()->NowMicros() -
        maybe_stats->all_start_micros());
    tensorflow::mutex_lock ml(ctx->metadata_mu);
    if (ctx->should_store_metadata.load()) {
      auto* step_stats = ctx->run_metadata.mutable_step_stats();
      // Lazily initialize the RunMetadata with information about all devices
      // if this is the first call.
      while (step_stats->dev_stats_size() < ctx->devices().size()) {
        step_stats->add_dev_stats();
      }
      // Find the current device's index.
      int device_idx = 0;
      for (int i = 0; i < ctx->devices().size(); ++i) {
        if (ctx->devices()[i] == device) {
          device_idx = i;
          break;
        }
      }
      // Populate the device stats for this device.
      auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
      dev_stats->set_device(device->name());
      *dev_stats->add_node_stats() = *maybe_stats;
    }
  }
  *num_retvals = std::min<int>(*num_retvals, outputs.size());
  for (int i = 0; i < *num_retvals; ++i) {
    tensorflow::Device* d = IsCPU(device) ? nullptr : device;
    if (d != nullptr && output_memory_types != nullptr &&
        (*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
      d = nullptr;
    }
    retvals[i] = new TFE_TensorHandle(outputs[i], d);
  }
}
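// Executing an op (a minimal sketch, continuing the MatMul example above;
// statuses should be checked after each call):
//   TFE_TensorHandle* retvals[1];
//   int num_retvals = 1;
//   TFE_Execute(matmul, retvals, &num_retvals, status);
//   TF_Tensor* product = TFE_TensorHandleResolve(retvals[0], status);
//   ...
//   TF_DeleteTensor(product);
//   TFE_DeleteTensorHandle(retvals[0]);
//   TFE_DeleteOp(matmul);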
void TFE_ContextAddFunctionDef(TFE_Context* ctx,
                               const char* serialized_function_def, size_t size,
                               TF_Status* status) {
  tensorflow::FunctionDef function_def;
  if (!function_def.ParseFromArray(serialized_function_def, size)) {
    status->status =
        tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
    return;
  }
  tensorflow::mutex_lock l(ctx->functions_mu);
  status->status = ctx->func_lib_def.AddFunctionDef(function_def);
}

void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                            TF_Status* status) {
  tensorflow::mutex_lock l(ctx->functions_mu);
  status->status = ctx->func_lib_def.AddFunctionDef(function->fdef);
}

}  // extern "C"

TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
  return new TFE_TensorHandle(t, nullptr);
}

const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
    TFE_TensorHandle* h, TF_Status* status) {
  if (h->d != nullptr) {
    status->status = tensorflow::errors::FailedPrecondition(
        "TFE_TensorHandle is placed in device (not host) memory. Cannot return "
        "a tensorflow::Tensor");
    return nullptr;
  }
  return &h->t;
}

void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
  ctx->should_store_metadata.store(true);
}

void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
  tensorflow::mutex_lock ml(ctx->metadata_mu);
  ctx->should_store_metadata.store(false);
  ctx->run_metadata.Clear();
}

void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                  TF_Status* status) {
  tensorflow::mutex_lock ml(ctx->metadata_mu);
  status->status = MessageToBuffer(ctx->run_metadata, buf);
  ctx->run_metadata.Clear();
}
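// Collecting per-op stats (sketch): enable metadata collection, run some ops,
// then export the serialized RunMetadata proto:
//   TFE_ContextEnableRunMetadata(ctx);
//   ... execute ops ...
//   TF_Buffer* buf = TF_NewBuffer();
//   TFE_ContextExportRunMetadata(ctx, buf, status);
//   TF_DeleteBuffer(buf);
//   TFE_ContextDisableRunMetadata(ctx);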