/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"

#include <utility>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/reffed_status_callback.h"

namespace tensorflow {

const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";

Status ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
    DistributedFunctionLibraryRuntime* parent, const string& function_name,
    const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options) {
  mutex_lock l(mu_);
  if (!init_started_) {
    init_started_ = true;
    init_result_ = parent->Instantiate(function_name, lib_def, attrs, options,
                                       &local_handle_);
  }
  return init_result_;
}

ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
    const DeviceMgr* device_mgr, Env* env, int graph_def_version,
    const FunctionLibraryDefinition* lib_def,
    const OptimizerOptions& optimizer_options,
    thread::ThreadPool* default_thread_pool,
    DistributedFunctionLibraryRuntime* parent)
    : env_(env),
      device_mgr_(device_mgr),
      lib_def_(lib_def),
      default_thread_pool_(default_thread_pool),
      next_handle_(0),
      parent_(parent) {
  if (device_mgr == nullptr) {
    flr_map_[nullptr] = NewFunctionLibraryRuntime(
        nullptr, env, nullptr, graph_def_version, lib_def_, default_thread_pool,
        optimizer_options, this);
    return;
  }
  for (Device* d : device_mgr->ListDevices()) {
    flr_map_[d] = NewFunctionLibraryRuntime(
        device_mgr, env, d, graph_def_version, lib_def_, default_thread_pool,
        optimizer_options, this);
  }
}

ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
    const DeviceMgr* device_mgr, Env* env, int graph_def_version,
    const FunctionLibraryDefinition* lib_def,
    const OptimizerOptions& optimizer_options,
    CustomKernelCreator custom_kernel_creator,
    thread::ThreadPool* default_thread_pool,
    DistributedFunctionLibraryRuntime* parent)
    : env_(env),
      device_mgr_(device_mgr),
      lib_def_(lib_def),
      default_thread_pool_(default_thread_pool),
      next_handle_(0),
      parent_(parent) {
  if (device_mgr == nullptr) {
    flr_map_[nullptr] = NewFunctionLibraryRuntime(
        nullptr, env, nullptr, graph_def_version, lib_def_, default_thread_pool,
        optimizer_options, std::move(custom_kernel_creator), this);
    return;
  }
  for (Device* d : device_mgr->ListDevices()) {
    flr_map_[d] = NewFunctionLibraryRuntime(
        device_mgr, env, d, graph_def_version, lib_def_, default_thread_pool,
        optimizer_options, custom_kernel_creator, this);
  }
}

/* static */
Status ProcessFunctionLibraryRuntime::SendTensors(
    const string& source_device, const string& target_device,
    const string& key_prefix, int64 src_incarnation,
    gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
    const std::vector<AllocatorAttributes>& alloc_attrs,
    Rendezvous* rendezvous) {
  std::vector<string> keys;
  for (int i = 0; i < tensors_to_send.size(); ++i) {
    string name = strings::StrCat(key_prefix, i);
    string key = Rendezvous::CreateKey(source_device, src_incarnation,
                                       target_device, name, FrameAndIter(0, 0));
    keys.push_back(key);
  }
  TF_RETURN_IF_ERROR(SendTensorsToRendezvous(
      rendezvous, device_context, alloc_attrs, keys, tensors_to_send));
  return Status::OK();
}

/* static */
void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
    const string& source_device, const string& target_device,
    const string& key_prefix, int64 src_incarnation, int64 num_tensors,
    DeviceContext* device_context,
    const std::vector<AllocatorAttributes>& alloc_attrs, Rendezvous* rendezvous,
    std::vector<Tensor>* received_tensors, StatusCallback done) {
  std::vector<string> keys;
  for (int64 i = 0; i < num_tensors; ++i) {
    string name = strings::StrCat(key_prefix, i);
    string key = Rendezvous::CreateKey(source_device, src_incarnation,
                                       target_device, name, FrameAndIter(0, 0));
    keys.push_back(key);
  }
  RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys,
                                 received_tensors, std::move(done));
}
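
// Illustrative note (not part of the runtime): SendTensors and
// ReceiveTensorsAsync must be invoked with the same `key_prefix` and
// `src_incarnation` so that both sides derive identical rendezvous keys;
// with key_prefix "arg_" the keys look roughly like
//   <source_device>;<incarnation>;<target_device>;arg_0;0:0
// A hypothetical pairing (device names made up):
//
//   TF_RETURN_IF_ERROR(ProcessFunctionLibraryRuntime::SendTensors(
//       "/job:a/replica:0/task:0/device:CPU:0",
//       "/job:a/replica:0/task:0/device:GPU:0", "arg_", src_incarnation,
//       args, device_context, args_alloc_attrs, rendezvous));
//   std::vector<Tensor> received;
//   ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
//       "/job:a/replica:0/task:0/device:CPU:0",
//       "/job:a/replica:0/task:0/device:GPU:0", "arg_", src_incarnation,
//       /*num_tensors=*/args.size(), device_context, args_alloc_attrs,
//       rendezvous, &received, done);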

Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
    const string& device_name, int64* incarnation) const {
  FunctionLibraryRuntime* flr = GetFLR(device_name);
  if (flr == nullptr) {
    return errors::InvalidArgument("Device name: ", device_name, " not found.");
  }
  *incarnation = flr->device()->attributes().incarnation();
  return Status::OK();
}

Status ProcessFunctionLibraryRuntime::GetDeviceContext(
    const string& device_name, DeviceContext** device_context) const {
  *device_context = nullptr;
  FunctionLibraryRuntime* flr = GetFLR(device_name);
  if (flr == nullptr) {
    return errors::InvalidArgument("Device name: ", device_name, " not found.");
  }
  Device* device = flr->device();
  string device_type = device->parsed_name().type;
  if (device_type == "CPU" || device_type == "TPU_SYSTEM") {
    // "TPU_SYSTEM" indicates that `device` is a CPU.
    return Status::OK();
  }
  if (device_type == "GPU" || device_type == "TPU") {
    auto* dev_info = flr->device()->tensorflow_gpu_device_info();
    if (dev_info) {
      *device_context = dev_info->default_context;
      return Status::OK();
    }
  }
  return errors::Internal("Device type: ", device_type,
                          " is currently unsupported for remote ",
                          "function executions");
}

FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
    const string& device_name) const {
  Device* device = nullptr;
  if (device_name != kDefaultFLRDevice) {
    if (!device_mgr_->LookupDevice(device_name, &device).ok()) {
      VLOG(1) << "Could not find device: " << device_name;
      return nullptr;
    }
  }
  const auto& iter = flr_map_.find(device);
  if (iter == flr_map_.end()) {
    LOG(ERROR) << "Could not find device: " << device_name;
    return nullptr;
  }
  return iter->second.get();
}
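
// Illustrative usage (hypothetical device name): callers typically look up
// the per-device FLR and fall back to the distributed parent when GetFLR
// returns nullptr, as Instantiate and Run below do.
//
//   FunctionLibraryRuntime* flr =
//       pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
//   if (flr != nullptr) {
//     TF_RETURN_IF_ERROR(flr->Instantiate(function_name, attrs, opts, &h));
//   }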

FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
    const string& function_key, const string& device_name,
    FunctionLibraryRuntime::LocalHandle local_handle) {
  mutex_lock l(mu_);
  return AddHandleLocked(function_key, device_name, local_handle);
}

FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandleLocked(
    const string& function_key, const string& device_name,
    FunctionLibraryRuntime::LocalHandle local_handle) {
  auto h = next_handle_;
  function_data_[h] =
      MakeUnique<FunctionData>(device_name, local_handle, function_key);
  table_[function_key] = h;
  next_handle_++;
  return h;
}

FunctionLibraryRuntime::Handle
ProcessFunctionLibraryRuntime::AddMultiDeviceHandle(
    std::unique_ptr<MultiDeviceFunctionData> data, const string& function_key) {
  mutex_lock l(mu_);
  auto h = next_handle_;
  mdevice_data_[h] = std::move(data);
  table_[function_key] = h;
  next_handle_++;
  return h;
}

FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
    const string& function_key) const {
  tf_shared_lock l(mu_);
  return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
}

bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
    const string& device_name, FunctionLibraryRuntime::Handle handle) const {
  return GetHandleOnDevice(device_name, handle) != kInvalidHandle;
}

FunctionLibraryRuntime::LocalHandle
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
    const string& device_name, FunctionLibraryRuntime::Handle handle) const {
  tf_shared_lock l(mu_);

  auto miter = mdevice_data_.find(handle);
  if (miter != mdevice_data_.end()) {
    return kInvalidLocalHandle;
  }

  auto iter = function_data_.find(handle);
  if (iter == function_data_.end()) {
    return kInvalidLocalHandle;
  }
  FunctionData* function_data = iter->second.get();
  if (function_data->target_device() != device_name) {
    return kInvalidLocalHandle;
  }
  return function_data->local_handle();
}

string ProcessFunctionLibraryRuntime::GetDeviceName(
    FunctionLibraryRuntime::Handle handle) const {
  tf_shared_lock l(mu_);
  auto iter = function_data_.find(handle);
  CHECK(iter != function_data_.end());
  FunctionData* function_data = iter->second.get();
  return function_data->target_device();
}

ProcessFunctionLibraryRuntime::MultiDeviceFunctionData*
ProcessFunctionLibraryRuntime::IsMultiDevice(
    FunctionLibraryRuntime::Handle handle) const {
  tf_shared_lock l(mu_);
  const auto& it = mdevice_data_.find(handle);
  if (it != mdevice_data_.end()) {
    return it->second.get();
  }
  return nullptr;
}

namespace {
// Sets `group` to the first colocation group specified in `node`. If no
// group is specified, does not touch `group`.
void GetColocationGroup(const Node* node, string* group) {
  // We hoist the conversion from C-style string literal to string here,
  // so that we can avoid the many repeated calls to strlen().
  static const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
  const AttrValue* attr_value =
      node->attrs().Find(kColocationAttrNameStringPiece);
  if (attr_value != nullptr && attr_value->has_list() &&
      attr_value->list().s_size() > 0) {
    *group = attr_value->list().s(0);
  }
}

const string* AssignedOrRequestedDeviceName(const Node& node) {
  if (node.has_assigned_device_name()) {
    return &node.assigned_device_name();
  }
  return &node.requested_device();
}

}  // anonymous namespace
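
// Background note: colocation constraints are carried in the "_class" node
// attribute (kColocationAttrName) as strings of the form "loc:@<node_name>".
// An illustrative NodeDef snippet:
//
//   attr {
//     key: "_class"
//     value { list { s: "loc:@my_variable" } }
//   }
//
// GetColocationGroup above reads only the first such entry.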

Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
    const std::vector<string>& input_devices,
    const std::vector<string>& output_devices, const DeviceSet& device_set,
    Graph* graph) const {
  // If output_devices are not specified, we want to set the output device
  // based on the device of the output producing node. The output producing
  // node can be an arg node because functions can simply return their
  // arguments. To make sure that the output producing nodes have assigned
  // devices, we assign them to arguments first.
  for (Node* node : graph->op_nodes()) {
    if (node->IsArg()) {
      const AttrValue* attr_value;
      TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
      int64 index = attr_value->i();
      node->set_assigned_device_name(input_devices[index]);
    }
  }

  for (Node* node : graph->op_nodes()) {
    if (node->IsRetval()) {
      if (output_devices.empty()) {
        VLOG(3) << "Trying to determine device for node " << node->name();
        // If output_devices are empty, the node producing the retval must
        // have an explicitly assigned device or a colocation constraint
        // on a node with an explicitly assigned device.
        for (const auto& it : node->in_edges()) {
          if (!it->IsControlEdge()) {
            Node* src_node = it->src();
            const string* src_device = AssignedOrRequestedDeviceName(*src_node);
            string colocation_group = "";
            GetColocationGroup(src_node, &colocation_group);
            VLOG(3) << "Considering src: " << src_node->name()
                    << " src_device: " << *src_device
                    << " colo group: " << colocation_group;
            while (src_device->empty() && colocation_group.empty() &&
                   src_node->IsIdentity()) {
              src_node = *src_node->in_nodes().begin();
              src_device = AssignedOrRequestedDeviceName(*src_node);
              GetColocationGroup(src_node, &colocation_group);
              VLOG(3) << "Considering src: " << src_node->name()
                      << " src_device: " << *src_device
                      << " colo group: " << colocation_group;
            }

            if (!colocation_group.empty()) {
              std::vector<string> colo_slice = {colocation_group};
              node->AddAttr(kColocationAttrName, colo_slice);
            } else if (!src_device->empty()) {
              // src_device can be a partially specified device. Find the
              // matching device in the device_set.
              DeviceNameUtils::ParsedName parsed;
              if (!DeviceNameUtils::ParseFullName(*src_device, &parsed)) {
                return errors::InvalidArgument(
                    "Failed to parse explicit device specification ",
                    *src_device);
              }
              std::vector<Device*> matching_devices;
              device_set.FindMatchingDevices(parsed, &matching_devices);
              if (matching_devices.empty()) {
                return errors::InvalidArgument(
                    "Unable to find any devices for spec ", *src_device);
              } else if (matching_devices.size() != 1) {
                // Convert a vector of devices to a string.
                // Using absl::StrJoin did not work in Android builds.
                string devices = "[";
                for (Device* device : matching_devices) {
                  devices.append(device->name());
                  devices.append(", ");
                }
                if (devices.size() > 2) {
                  devices.resize(devices.size() - 2);
                }
                devices.append("]");

                return errors::InvalidArgument(
                    "When FunctionLibraryRuntime::Options.output_devices are "
                    "not specified for a multi-device function, the device "
                    "specification on the output node must match exactly one "
                    "device. Matched devices are ",
                    devices);
              }
              VLOG(3) << "Setting output device to "
                      << matching_devices[0]->name() << " for node "
                      << node->DebugString();
              node->set_assigned_device_name(matching_devices[0]->name());
            }
          }
        }
      } else {
        const AttrValue* attr_value;
        TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
        int64 index = attr_value->i();
        // output_devices size is checked in InstantiateMultiDevice
        DCHECK_GT(output_devices.size(), index);
        VLOG(3) << "Setting output device to " << output_devices[index]
                << " for return at index " << index;
        node->set_assigned_device_name(output_devices[index]);
      }
    }
  }
  return Status::OK();
}
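
// Worked example (illustrative): given a body like
//   y = Identity(x); ret = Identity(y)   // no devices assigned anywhere
// with input_devices = {"/device:GPU:0"} and empty output_devices, the _Arg
// node for x is pinned to GPU:0 first; the retval walk then follows the
// Identity chain back to the _Arg, finds its assigned device, and pins the
// _Retval to GPU:0. If the chain ends at a node with neither a device nor a
// colocation group, the retval is left unpinned for the Placer to decide.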

namespace {

Status ValidateNoListArguments(
    const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, const char* arg_type,
    const string& function_name) {
  for (const OpDef::ArgDef& arg : args) {
    if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) {
      return errors::InvalidArgument(
          "Function ", function_name, " has an ", arg_type, " named \"",
          arg.name(),
          "\" that is a list of tensors. Multi-device functions support "
          "only single-tensor inputs and outputs.");
    }
  }
  return Status::OK();
}

Status ValidateMultiDeviceOptions(
    const FunctionDef& fdef,
    const FunctionLibraryRuntime::InstantiateOptions& options) {
  const OpDef& signature = fdef.signature();
  // Multi-device functions don't currently support list inputs or outputs.
  TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.input_arg(), "input",
                                             signature.name()));
  TF_RETURN_IF_ERROR(ValidateNoListArguments(signature.output_arg(), "output",
                                             signature.name()));

  if (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 &&
      fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) {
    return errors::Unimplemented(
        "Function '", signature.name(), "' has `",
        FunctionLibraryDefinition::kIntsOnDeviceAttr,
        "` attribute set. This attribute is not currently supported by "
        "multi-device functions.");
  }

  if (options.input_devices.size() != signature.input_arg_size()) {
    return errors::InvalidArgument(
        "InstantiateOptions.input_devices must have the same length "
        "as the number of arguments: input_devices length = ",
        options.input_devices.size(),
        " number of arguments = ", signature.input_arg_size());
  }
  if (!options.output_devices.empty() &&
      options.output_devices.size() != signature.output_arg_size()) {
    return errors::InvalidArgument(
        "InstantiateOptions.output_devices must either be empty or have "
        "the same length as the number of arguments: output_devices length "
        "= ",
        options.output_devices.size(),
        " number of arguments = ", signature.output_arg_size());
  }

  if (!options.state_handle.empty()) {
    return errors::Unimplemented(
        "InstantiateOptions.state_handle is not supported for multi-device "
        "functions. Function: ",
        signature.name());
  }
  if (options.create_kernels_eagerly) {
    return errors::Unimplemented(
        "InstantiateOptions.create_kernels_eagerly is not supported for "
        "multi-device functions. Function: ",
        signature.name());
  }

  return Status::OK();
}

Status GetGraphAndRets(const string& function_name, AttrSlice attrs,
                       const FunctionDef* fdef,
                       const FunctionLibraryDefinition* lib_def,
                       std::unique_ptr<Graph>* graph,
                       std::vector<string>* ret_node_names,
                       std::vector<string>* control_ret_node_names) {
  auto get_func_sig = [lib_def](const string& op, const OpDef** sig) {
    return lib_def->LookUpOpDef(op, sig);
  };
  FunctionBody* tmp_fbody;
  // TODO(iga): FunctionDefToBodyHelper copies fdef. Avoid this copy.
  TF_RETURN_IF_ERROR(
      FunctionDefToBodyHelper(*fdef, attrs, lib_def, get_func_sig, &tmp_fbody));
  if (tmp_fbody == nullptr) {
    LOG(ERROR) << "Failed to get FunctionBody for \"" << function_name << "\"";
    return errors::Internal("Failed to construct FunctionBody for ",
                            function_name);
  }
  std::unique_ptr<FunctionBody> fbody(tmp_fbody);
  *graph = std::unique_ptr<Graph>(fbody->graph);
  fbody->graph = nullptr;
  ret_node_names->reserve(fbody->ret_nodes.size());
  for (const Node* node : fbody->ret_nodes) {
    ret_node_names->push_back(node->name());
  }
  control_ret_node_names->reserve(fbody->control_ret_nodes.size());
  for (const Node* node : fbody->control_ret_nodes) {
    control_ret_node_names->push_back(node->name());
  }
  return Status::OK();
}

}  // anonymous namespace

Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
    const string& function_name, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::Handle* handle) {
  // Check if this function has already been instantiated.
  const string& function_key = Canonicalize(function_name, attrs, options);

  {
    mutex_lock l(mu_);
    const auto& it = table_.find(function_key);
    if (it != table_.end()) {
      *handle = it->second;
      ++mdevice_data_[*handle]->instantiation_counter_;
      return Status::OK();
    }
  }

  VLOG(1) << "Instantiating MultiDevice function \"" << function_name
          << "\" on default device \"" << options.target << "\"";
  if (VLOG_IS_ON(3)) {
    VLOG(3) << "Requested input devices:";
    for (const string& device : options.input_devices) {
      VLOG(3) << "    " << device;
    }
    VLOG(3) << "Requested output devices:";
    for (const string& device : options.output_devices) {
      VLOG(3) << "    " << device;
    }
  }

  const FunctionLibraryDefinition* lib_def =
      options.overlay_lib == nullptr ? lib_def_ : options.overlay_lib;

  const FunctionDef* fdef = lib_def->Find(function_name);
  if (fdef == nullptr) {
    return errors::InvalidArgument("Failed to find function \"", function_name,
                                   "\" in function library: ", lib_def);
  }

  TF_RETURN_IF_ERROR(ValidateMultiDeviceOptions(*fdef, options));

  std::unique_ptr<Graph> graph;
  std::vector<string> ret_node_names;
  std::vector<string> control_ret_node_names;

  TF_RETURN_IF_ERROR(GetGraphAndRets(function_name, attrs, fdef, lib_def,
                                     &graph, &ret_node_names,
                                     &control_ret_node_names));

  if (options.graph_collector != nullptr) {
    GraphDef def;
    graph->ToGraphDef(&def);
    *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
    options.graph_collector->CollectRawGraph(def);
  }

  DeviceSet device_set;
  for (auto d : device_mgr_->ListDevices()) {
    device_set.AddDevice(d);
  }

  TF_RETURN_IF_ERROR(PinArgsAndRets(
      options.input_devices, options.output_devices, device_set, graph.get()));

  std::unique_ptr<MultiDeviceFunctionData> data =
      MakeUnique<MultiDeviceFunctionData>(function_name, function_key,
                                          ret_node_names.size(),
                                          lib_def->ReachableDefinitions(*fdef));

  GraphOptimizationPassOptions optimization_options;
  // TODO(iga): Thread other relevant options from SessionOptions.
  SessionOptions session_options;
  session_options.env = env_;
  session_options.config = options.config_proto;
  optimization_options.session_options = &session_options;
  optimization_options.graph = &graph;
  optimization_options.flib_def = &data->overlay_lib_;
  optimization_options.device_set = &device_set;

  DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));

  DumpGraph("Before calling Placer", graph.get());
  // Make the FunctionLibraryRuntime's device the default device if
  // nothing else is hard coded. This allows the same function definition
  // to be specialized to different devices depending on the
  // PartitionedCallOp's device.
  Device* default_device = nullptr;
  if (!options.target.empty()) {
    FunctionLibraryRuntime* flr = GetFLR(options.target);
    if (flr == nullptr) {
      return errors::InvalidArgument(
          "Cannot instantiate multi-device function with target device ",
          options.target);
    }
    default_device = flr->device();
  }

  // TODO(b/124993244): Smartly merge options in nested defuns, and raise
  // exceptions/warnings in case where nested function call options are ignored.
  Placer placer(graph.get(), &device_set, default_device,
                options.config_proto.allow_soft_placement(),
                options.config_proto.log_device_placement());
  TF_RETURN_IF_ERROR(placer.Run());

  DumpGraph("Before running POST_PLACEMENT passes", graph.get());
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PLACEMENT, optimization_options));

  Device* cpu_device;
  TF_RETURN_IF_ERROR(device_mgr_->LookupDevice("CPU:0", &cpu_device));

  if (options.optimize_graph_fn) {
    DumpGraph("Before running graph optimization fn", graph.get());
    Status status = options.optimize_graph_fn(
        std::move(ret_node_names), std::move(control_ret_node_names),
        &data->overlay_lib_, device_set, cpu_device, &graph);
    if (!status.ok()) {
      LOG(WARNING) << "Ignoring multi-device function optimization failure: "
                   << status.ToString();
    }
    DumpGraph("After optimization", graph.get());
  }

  DumpGraph("Before running POST_REWRITE_FOR_EXEC passes", graph.get());
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
  DumpGraph("After all optimization passes", graph.get());

  if (options.graph_collector != nullptr) {
    GraphDef def;
    graph->ToGraphDef(&def);
    *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
    options.graph_collector->CollectOptimizedGraph(def);
  }

  std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
  TF_RETURN_IF_ERROR(
      PartitionFunctionGraph(device_set, std::move(graph), &subgraphs));

  if (options.graph_collector != nullptr) {
    for (const auto& pair : subgraphs) {
      GraphDef def;
      pair.second->ToGraphDef(&def);
      *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
      options.graph_collector->CollectPartitionedGraph(def);
    }
  }

  FunctionNameGenerator name_generator(&data->overlay_lib_, function_name);
  for (const auto& pair : subgraphs) {
    // TODO(iga): Fail gracefully if the set of devices corresponds
    // to more than one address space.
    const string& target = pair.first;
    Graph* subgraph = pair.second.get();

    ComponentFunctionData* comp_data = &data->glue_[target];
    TF_RETURN_IF_ERROR(UpdateArgAndRetvalMetadata(
        subgraph, &comp_data->arg_indices_, &comp_data->ret_indices_,
        &comp_data->arg_alloc_attrs_, &comp_data->ret_alloc_attrs_));
    FunctionDef shard;
    string unique_name = name_generator.GetName();
    TF_RETURN_IF_ERROR(GraphToFunctionDef(*subgraph, unique_name, &shard));
    FunctionLibraryRuntime* target_flr = GetFLR(target);
    TF_RETURN_IF_ERROR(data->overlay_lib_.AddFunctionDef(shard));
    FunctionLibraryRuntime::InstantiateOptions opts;
    opts.executor_type = options.executor_type;
    opts.target = target;
    opts.overlay_lib = &data->overlay_lib_;
    FunctionLibraryRuntime::Handle component_handle;

    TF_RETURN_IF_ERROR(target_flr->Instantiate(
        unique_name, AttrSlice(&shard.attr()), opts, &component_handle));
    VLOG(1) << "Instantiated component function " << unique_name
            << " on device " << target << " with component handle "
            << component_handle;
    VLOG(2) << DebugString(shard);
    comp_data->handle_ = component_handle;
  }

  *handle = AddMultiDeviceHandle(std::move(data), function_key);
  VLOG(2) << "Instantiated MultiDevice function \"" << function_name
          << "\" with handle " << *handle;
  return Status::OK();
}
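
// Illustrative end-to-end usage (names hypothetical): a multi-device
// function is instantiated once per canonicalized key and then run through
// the returned handle.
//
//   FunctionLibraryRuntime::InstantiateOptions opts;
//   opts.is_multi_device_function = true;
//   opts.input_devices = {"/device:CPU:0", "/device:GPU:0"};
//   // opts.output_devices may stay empty; see PinArgsAndRets above.
//   FunctionLibraryRuntime::Handle handle;
//   TF_RETURN_IF_ERROR(pflr->Instantiate("MyFunc", attrs, opts, &handle));
//
// Instantiation places and partitions the graph into one component function
// per device; each component is added to data->overlay_lib_ under a
// generated unique name and instantiated on its device's FLR.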

Status ProcessFunctionLibraryRuntime::GetOutputDevices(
    FunctionLibraryRuntime::Handle handle,
    std::vector<Device*>* output_devices) const {
  const MultiDeviceFunctionData* data = IsMultiDevice(handle);
  if (data == nullptr) {
    return errors::InvalidArgument(
        "Failed to find multi-device function handle ", handle);
  }

  for (const auto& pair : data->glue_) {
    const ComponentFunctionData& comp_data = pair.second;
    DCHECK(comp_data.ret_alloc_attrs_.size() == comp_data.ret_indices_.size());

    const string& target = pair.first;
    FunctionLibraryRuntime* target_flr = GetFLR(target);
    Device* target_device = target_flr->device();
    const FunctionBody* fbody = target_flr->GetFunctionBody(comp_data.handle_);
    DCHECK(fbody != nullptr);

    output_devices->resize(data->num_outputs_);
    for (int j = 0; j < comp_data.ret_indices_.size(); ++j) {
      int ret_index = comp_data.ret_indices_[j];
      if (fbody->ret_types[j] == DT_RESOURCE) {
        (*output_devices)[ret_index] = target_device;
      } else {
        (*output_devices)[ret_index] =
            comp_data.ret_alloc_attrs_[j].on_host() ? nullptr : target_device;
      }
    }
  }

  return Status::OK();
}

void ProcessFunctionLibraryRuntime::RunMultiDevice(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets,
    FunctionLibraryRuntime::DoneCallback done) const {
  if (opts.create_rendezvous) {
    // FLR->Run() is the default entry point. It checks for cancellation,
    // creates a rendezvous, etc.
    // Letting create_rendezvous through would do the wrong thing: each
    // component function would get a separate rendezvous created by its FLR.
    done(
        errors::Internal("Cannot call ProcessFunctionLibraryRuntime::Run with "
                         "create_rendezvous=true. Please run the function "
                         "using FunctionLibraryRuntime::Run"));
    return;
  }

  const MultiDeviceFunctionData* data = IsMultiDevice(handle);
  if (data == nullptr) {
    done(
        errors::InvalidArgument("Failed to find multi-device function handle ",
                                handle, ". Was the function instantiated?"));
    return;
  }

  if (data->glue_.empty()) {
    // Trivial case where the function body is empty.
    done(Status::OK());
    return;
  }

  auto* refcounted_done = new ReffedStatusCallback(std::move(done));
  for (int i = 0; i < data->glue_.size(); ++i) {
    refcounted_done->Ref();
  }

  FunctionLibraryRuntime::Options opts_copy = opts;
  for (const auto& pair : data->glue_) {
    const string& target = pair.first;
    const ComponentFunctionData& comp_data = pair.second;
    FunctionLibraryRuntime::Handle handle = pair.second.handle_;
    VLOG(1) << "Running function shard on device " << target << " with handle "
            << handle;

    opts_copy.args_alloc_attrs = comp_data.arg_alloc_attrs_;
    opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs_;
    opts_copy.remote_execution = false;
    std::vector<Tensor> comp_args =
        GetArgsForIndices(comp_data.arg_indices_, args);
    std::vector<Tensor>* comp_rets = new std::vector<Tensor>;
    rets->resize(data->num_outputs_);
    GetFLR(target)->Run(
        opts_copy, handle, comp_args, comp_rets,
        [comp_rets, rets, comp_data, refcounted_done](const Status& status) {
          if (!status.ok()) {
            LOG(ERROR) << "Component function execution failed: " << status;
            refcounted_done->UpdateStatus(status);
          } else {
            for (int i = 0; i < comp_rets->size(); ++i) {
              (*rets)[comp_data.ret_indices_[i]] = (*comp_rets)[i];
            }
          }
          delete comp_rets;
          // refcounted_done is thread-safe
          refcounted_done->Unref();
        });
  }
  refcounted_done->Unref();
}

Status ProcessFunctionLibraryRuntime::Instantiate(
    const string& function_name, AttrSlice attrs,
    const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::Handle* handle) {
  if (options.is_multi_device_function) {
    return InstantiateMultiDevice(function_name, attrs, options, handle);
  }

  *handle = kInvalidHandle;
  FunctionLibraryRuntime* flr = GetFLR(options.target);
  if (flr != nullptr) {
    return flr->Instantiate(function_name, attrs, options, handle);
  }
  if (parent_ == nullptr) {
    return errors::Internal(
        "Currently don't support instantiating functions on device: ",
        options.target);
  }
  VLOG(1) << "ProcessFLR Instantiate: " << function_name
          << " on: " << options.target;
  string function_key = Canonicalize(function_name, attrs, options);
  FunctionData* f;
  {
    mutex_lock l(mu_);
    FunctionLibraryRuntime::Handle h =
        gtl::FindWithDefault(table_, function_key, kInvalidHandle);
    if (h == kInvalidHandle || function_data_.count(h) == 0) {
      h = AddHandleLocked(function_key, options.target, kInvalidHandle);
    }
    f = function_data_[h].get();
    *handle = h;
  }
  TF_RETURN_IF_ERROR(
      f->DistributedInit(parent_, function_name, *lib_def_, attrs, options));
  VLOG(1) << "ProcessFLR Instantiate [success]: " << function_name
          << " on: " << options.target << " with handle: " << *handle
          << " (this: " << this << ")";
  return Status::OK();
}
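
// Note: on the remote (non-multi-device) path above, the returned handle
// indexes this process-wide function_data_ table; the device-local handle
// is filled in by FunctionData::DistributedInit. A sketch of the caller
// side (target name hypothetical):
//
//   FunctionLibraryRuntime::InstantiateOptions opts;
//   opts.target = "/job:worker/replica:0/task:1/device:CPU:0";
//   FunctionLibraryRuntime::Handle h;
//   TF_RETURN_IF_ERROR(pflr->Instantiate("MyFunc", attrs, opts, &h));
//   // Run(h) with opts.remote_execution = true ships args via SendTensors
//   // and receives results via ReceiveTensorsAsync (see Run below).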

Status ProcessFunctionLibraryRuntime::RemoveHandle(
    FunctionLibraryRuntime::Handle handle) {
  mutex_lock l(mu_);
  table_.erase(function_data_[handle]->function_key());
  function_data_.erase(handle);
  return Status::OK();
}

Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle(
    FunctionLibraryRuntime::Handle handle) {
  std::unique_ptr<MultiDeviceFunctionData> mdata;
  {
    mutex_lock l(mu_);
    auto it = mdevice_data_.find(handle);
    --it->second->instantiation_counter_;
    if (it->second->instantiation_counter_ != 0) {
      return Status::OK();
    }
    mdata = std::move(it->second);
    table_.erase(mdata->function_key_);
    mdevice_data_.erase(it);
  }

  // If we are here we are releasing the last instantiation of `handle`.
  // Release all component function handles.
  Status overall_status;
  for (const auto& it : mdata->glue_) {
    const string& device = it.first;
    FunctionLibraryRuntime::Handle flr_handle = it.second.handle_;
    FunctionLibraryRuntime* flr = GetFLR(device);
    if (flr == nullptr) {
      return errors::InvalidArgument(
          "Failed to find FunctionLibraryRuntime for device ", device,
          " when releasing multi-device function handle ", handle);
    }
    Status status = flr->ReleaseHandle(flr_handle);
    if (!status.ok()) {
      overall_status = status;
    }
  }

  return overall_status;
}

Status ProcessFunctionLibraryRuntime::ReleaseHandle(
    FunctionLibraryRuntime::Handle handle) {
  if (IsMultiDevice(handle)) {
    return ReleaseMultiDeviceHandle(handle);
  }

  FunctionLibraryRuntime* flr = nullptr;
  string target_device;
  {
    mutex_lock l(mu_);
    CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
    target_device = function_data_[handle]->target_device();
  }
  flr = GetFLR(target_device);
  if (flr != nullptr) {
    return flr->ReleaseHandle(handle);
  }
  return errors::InvalidArgument("Handle not found: ", handle);
}

void ProcessFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets,
    FunctionLibraryRuntime::DoneCallback done) const {
  bool multi_device;
  {
    tf_shared_lock l(mu_);
    multi_device = mdevice_data_.find(handle) != mdevice_data_.end();
  }
  if (multi_device) {
    return RunMultiDevice(opts, handle, args, rets, done);
  }

  FunctionLibraryRuntime* flr = nullptr;
  string target_device;
  FunctionLibraryRuntime::LocalHandle local_handle;
  {
    tf_shared_lock l(mu_);
    auto iter = function_data_.find(handle);
    if (iter == function_data_.end()) {
      done(errors::NotFound("Handle: ", handle, " not found."));
      return;
    }
    FunctionData* function_data = iter->second.get();
    target_device = function_data->target_device();
    local_handle = function_data->local_handle();
  }

  if (!opts.remote_execution) {
    done(
        errors::InvalidArgument("ProcessFunctionLibraryRuntime::Run should "
                                "only be called for multi-device functions or "
                                "for remote execution."));
    return;
  }

  flr = GetFLR(target_device);
  if (flr != nullptr) {
    auto rendezvous = opts.rendezvous;
    string source_device = opts.source_device;
    DeviceContext* device_context;
    Status s = GetDeviceContext(source_device, &device_context);
    if (!s.ok()) {
      done(s);
      return;
    }
    int64 src_incarnation, target_incarnation;
    s = GetDeviceIncarnation(source_device, &src_incarnation);
    s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
    if (!s.ok()) {
      done(s);
      return;
    }

    // Send the args over to the target device.
    s = SendTensors(source_device, target_device, "arg_", src_incarnation, args,
                    device_context, opts.args_alloc_attrs, rendezvous);
    if (!s.ok()) {
      done(s);
      return;
    }
    const std::vector<AllocatorAttributes>& rets_alloc_attrs =
        opts.rets_alloc_attrs;
    std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
    flr->Run(opts, handle, args, remote_rets,
             std::bind(
                 [source_device, target_device, target_incarnation, rendezvous,
                  device_context, rets_alloc_attrs, remote_rets,
                  rets](const Status& status,
                        FunctionLibraryRuntime::DoneCallback& done) {
                   if (!status.ok()) {
                     delete remote_rets;
                     done(status);
                     return;
                   }
                   int64 num_returns = remote_rets->size();
                   delete remote_rets;
                   // Now receive the return values from the target.
                   ReceiveTensorsAsync(target_device, source_device, "ret_",
                                       target_incarnation, num_returns,
                                       device_context, rets_alloc_attrs,
                                       rendezvous, rets, std::move(done));
                 },
                 std::placeholders::_1, std::move(done)));
    return;
  }
  if (parent_ != nullptr) {
    parent_->Run(opts, local_handle, args, rets, std::move(done));
    return;
  }
  done(errors::Internal("Could not find device"));
}

Status ProcessFunctionLibraryRuntime::Clone(
    Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
    CustomKernelCreator custom_kernel_creator,
    std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr) const {
  out_lib_def->reset(new FunctionLibraryDefinition(*lib_def_));
  out_pflr->reset(new ProcessFunctionLibraryRuntime(
      device_mgr_, env, graph_def_version, out_lib_def->get(),
      optimizer_options, std::move(custom_kernel_creator), default_thread_pool_,
      parent_));
  return Status::OK();
}

}  // namespace tensorflow