Home | History | Annotate | Download | only in common_runtime
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
     16 
     17 #include <utility>
     18 
     19 #include "tensorflow/core/common_runtime/function.h"
     20 #include "tensorflow/core/common_runtime/rendezvous_util.h"
     21 #include "tensorflow/core/lib/gtl/map_util.h"
     22 #include "tensorflow/core/util/device_name_utils.h"
     23 
     24 namespace tensorflow {
     25 
     // Sentinel device name. GetFLR() skips the DeviceMgr lookup when asked for
     // this name and returns the runtime stored under the null Device* key,
     // which the constructors register when no DeviceMgr is supplied.
     26 const char ProcessFunctionLibraryRuntime::kDefaultFLRDevice[] = "null";
     27 
     28 ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
     29     const DeviceMgr* device_mgr, Env* env, int graph_def_version,
     30     const FunctionLibraryDefinition* lib_def,
     31     const OptimizerOptions& optimizer_options,
     32     DistributedFunctionLibraryRuntime* parent)
     33     : device_mgr_(device_mgr),
     34       lib_def_(lib_def),
     35       next_handle_(0),
     36       parent_(parent) {
     37   if (device_mgr == nullptr) {
     38     flr_map_[nullptr] =
     39         NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version,
     40                                   lib_def, optimizer_options, this);
     41     return;
     42   }
     43   for (Device* d : device_mgr->ListDevices()) {
     44     flr_map_[d] =
     45         NewFunctionLibraryRuntime(device_mgr, env, d, graph_def_version,
     46                                   lib_def, optimizer_options, this);
     47   }
     48 }
     49 
     50 ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
     51     const DeviceMgr* device_mgr, Env* env, int graph_def_version,
     52     const FunctionLibraryDefinition* lib_def,
     53     const OptimizerOptions& optimizer_options,
     54     CustomKernelCreator custom_kernel_creator,
     55     DistributedFunctionLibraryRuntime* parent)
     56     : device_mgr_(device_mgr),
     57       lib_def_(lib_def),
     58       next_handle_(0),
     59       parent_(parent) {
     60   if (device_mgr == nullptr) {
     61     flr_map_[nullptr] = NewFunctionLibraryRuntime(
     62         nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options,
     63         std::move(custom_kernel_creator), this);
     64     return;
     65   }
     66   for (Device* d : device_mgr->ListDevices()) {
     67     flr_map_[d] = NewFunctionLibraryRuntime(
     68         device_mgr, env, d, graph_def_version, lib_def, optimizer_options,
     69         custom_kernel_creator, this);
     70   }
     71 }
     72 
     73 /* static */
     74 Status ProcessFunctionLibraryRuntime::SendTensors(
     75     const string& source_device, const string& target_device,
     76     const string& key_prefix, int64 src_incarnation,
     77     gtl::ArraySlice<Tensor> tensors_to_send, DeviceContext* device_context,
     78     const std::vector<AllocatorAttributes>& alloc_attrs,
     79     Rendezvous* rendezvous) {
     80   std::vector<string> keys;
     81   for (int i = 0; i < tensors_to_send.size(); ++i) {
     82     string name = strings::StrCat(key_prefix, i);
     83     string key = Rendezvous::CreateKey(source_device, src_incarnation,
     84                                        target_device, name, FrameAndIter(0, 0));
     85     keys.push_back(key);
     86   }
     87   TF_RETURN_IF_ERROR(SendTensorsToRendezvous(
     88       rendezvous, device_context, alloc_attrs, keys, tensors_to_send));
     89   return Status::OK();
     90 }
     91 
     92 /* static */
     93 void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
     94     const string& source_device, const string& target_device,
     95     const string& key_prefix, int64 src_incarnation, int64 num_tensors,
     96     DeviceContext* device_context,
     97     const std::vector<AllocatorAttributes>& alloc_attrs, Rendezvous* rendezvous,
     98     std::vector<Tensor>* received_tensors, const StatusCallback& done) {
     99   std::vector<string> keys;
    100   for (int64 i = 0; i < num_tensors; ++i) {
    101     string name = strings::StrCat(key_prefix, i);
    102     string key = Rendezvous::CreateKey(source_device, src_incarnation,
    103                                        target_device, name, FrameAndIter(0, 0));
    104     keys.push_back(key);
    105   }
    106   RecvOutputsFromRendezvousAsync(
    107       rendezvous, device_context, alloc_attrs, keys, received_tensors,
    108       [done](const Status& status) { done(status); });
    109 }
    110 
    111 Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
    112     const string& device_name, int64* incarnation) {
    113   FunctionLibraryRuntime* flr = GetFLR(device_name);
    114   if (flr == nullptr) {
    115     return errors::InvalidArgument("Device name: ", device_name, " not found");
    116   }
    117   *incarnation = flr->device()->attributes().incarnation();
    118   return Status::OK();
    119 }
    120 
    121 Status ProcessFunctionLibraryRuntime::GetDeviceContext(
    122     const string& device_name, DeviceContext** device_context) {
    123   *device_context = nullptr;
    124   FunctionLibraryRuntime* flr = GetFLR(device_name);
    125   if (flr == nullptr) {
    126     return errors::InvalidArgument("Device name: ", device_name, " not found.");
    127   }
    128   Device* device = flr->device();
    129   string device_type = device->parsed_name().type;
    130   if (device_type == "CPU") return Status::OK();
    131   if (device_type == "GPU") {
    132     auto* dev_info = flr->device()->tensorflow_gpu_device_info();
    133     if (dev_info) {
    134       *device_context = dev_info->default_context;
    135       return Status::OK();
    136     }
    137   }
    138   return errors::Internal("Device type: ", device_type,
    139                           " is currently unsupported for remote ",
    140                           "function executions");
    141 }
    142 
    143 FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
    144     const string& device_name) const {
    145   Device* device = nullptr;
    146   if (device_name != kDefaultFLRDevice) {
    147     if (!device_mgr_->LookupDevice(device_name, &device).ok()) {
    148       LOG(ERROR) << "Could not find device: " << device_name;
    149       return nullptr;
    150     }
    151   }
    152   const auto& iter = flr_map_.find(device);
    153   if (iter == flr_map_.end()) {
    154     LOG(ERROR) << "Could not find device: " << device_name;
    155     return nullptr;
    156   }
    157   return iter->second.get();
    158 }
    159 
    160 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
    161     const string& function_key, const string& device_name,
    162     FunctionLibraryRuntime::LocalHandle local_handle) {
    163   mutex_lock l(mu_);
    164   FunctionLibraryRuntime::Handle h =
    165       gtl::FindWithDefault(table_, function_key, kInvalidHandle);
    166   if (h != kInvalidHandle) {
    167     if (function_data_.count(h) != 0) return h;
    168   }
    169   h = next_handle_;
    170   function_data_.insert({h, FunctionData(device_name, local_handle)});
    171   table_[function_key] = h;
    172   next_handle_++;
    173   return h;
    174 }
    175 
    176 FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
    177     const string& function_key) const {
    178   mutex_lock l(mu_);
    179   FunctionLibraryRuntime::Handle h =
    180       gtl::FindWithDefault(table_, function_key, kInvalidHandle);
    181   if (h != kInvalidHandle) {
    182     if (function_data_.count(h) == 0) return kInvalidHandle;
    183   }
    184   return h;
    185 }
    186 
    187 bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
    188     const string& device_name, FunctionLibraryRuntime::Handle handle) {
    189   return GetHandleOnDevice(device_name, handle) != kInvalidHandle;
    190 }
    191 
    192 FunctionLibraryRuntime::LocalHandle
    193 ProcessFunctionLibraryRuntime::GetHandleOnDevice(
    194     const string& device_name, FunctionLibraryRuntime::Handle handle) {
    195   mutex_lock l(mu_);
    196   if (function_data_.count(handle) == 0) {
    197     return kInvalidLocalHandle;
    198   }
    199   const FunctionData& function_data = function_data_[handle];
    200   if (function_data.target_device != device_name) {
    201     return kInvalidLocalHandle;
    202   }
    203   return function_data.local_handle;
    204 }
    205 
    206 string ProcessFunctionLibraryRuntime::GetDeviceName(
    207     FunctionLibraryRuntime::Handle handle) {
    208   mutex_lock l(mu_);
    209   CHECK_EQ(1, function_data_.count(handle));
    210   const FunctionData& function_data = function_data_[handle];
    211   return function_data.target_device;
    212 }
    213 
    214 Status ProcessFunctionLibraryRuntime::Instantiate(
    215     const string& function_name, AttrSlice attrs,
    216     const FunctionLibraryRuntime::InstantiateOptions& options,
    217     FunctionLibraryRuntime::Handle* handle) {
    218   *handle = kInvalidHandle;
    219   FunctionLibraryRuntime* flr = GetFLR(options.target);
    220   if (flr != nullptr) {
    221     return flr->Instantiate(function_name, attrs, options, handle);
    222   }
    223   if (parent_ == nullptr) {
    224     return errors::Internal(
    225         "Currently don't support instantiating functions on device: ",
    226         options.target);
    227   }
    228   FunctionLibraryRuntime::Handle cluster_handle;
    229   TF_RETURN_IF_ERROR(parent_->Instantiate(function_name, *lib_def_, attrs,
    230                                           options, &cluster_handle));
    231   string function_key = Canonicalize(function_name, attrs);
    232   *handle = AddHandle(function_key, options.target, cluster_handle);
    233   return Status::OK();
    234 }
    235 
    236 Status ProcessFunctionLibraryRuntime::RemoveHandle(
    237     FunctionLibraryRuntime::Handle handle) {
    238   mutex_lock l(mu_);
    239   function_data_.erase(handle);
    240   return Status::OK();
    241 }
    242 
    243 Status ProcessFunctionLibraryRuntime::ReleaseHandle(
    244     FunctionLibraryRuntime::Handle handle) {
    245   FunctionLibraryRuntime* flr = nullptr;
    246   string target_device;
    247   {
    248     mutex_lock l(mu_);
    249     CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle;
    250     target_device = function_data_[handle].target_device;
    251   }
    252   flr = GetFLR(target_device);
    253   if (flr != nullptr) {
    254     return flr->ReleaseHandle(handle);
    255   }
    256   return errors::InvalidArgument("Handle not found: ", handle);
    257 }
    258 
    259 void ProcessFunctionLibraryRuntime::Run(
    260     const FunctionLibraryRuntime::Options& opts,
    261     FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
    262     std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
    263   if (!opts.remote_execution) {
    264     done(errors::InvalidArgument(
    265         "ProcessFunctionLibraryRuntime::Run should only be called when there ",
    266         "is a remote execution."));
    267     return;
    268   }
    269 
    270   FunctionLibraryRuntime* flr = nullptr;
    271   string target_device;
    272   FunctionLibraryRuntime::LocalHandle local_handle;
    273   {
    274     mutex_lock l(mu_);
    275     if (function_data_.count(handle) == 0) {
    276       done(errors::NotFound("Handle: ", handle, " not found."));
    277       return;
    278     }
    279     target_device = function_data_[handle].target_device;
    280     local_handle = function_data_[handle].local_handle;
    281   }
    282   flr = GetFLR(target_device);
    283   if (flr != nullptr) {
    284     auto rendezvous = opts.rendezvous;
    285     string source_device = opts.source_device;
    286     DeviceContext* device_context;
    287     Status s = GetDeviceContext(source_device, &device_context);
    288     if (!s.ok()) {
    289       done(s);
    290       return;
    291     }
    292     int64 src_incarnation, target_incarnation;
    293     s = GetDeviceIncarnation(source_device, &src_incarnation);
    294     s.Update(GetDeviceIncarnation(target_device, &target_incarnation));
    295     if (!s.ok()) {
    296       done(s);
    297       return;
    298     }
    299 
    300     // Send the args over to the target device.
    301     s = SendTensors(source_device, target_device, "arg_", src_incarnation, args,
    302                     device_context, opts.args_alloc_attrs, rendezvous);
    303     if (!s.ok()) {
    304       done(s);
    305       return;
    306     }
    307     const std::vector<AllocatorAttributes>& rets_alloc_attrs =
    308         opts.rets_alloc_attrs;
    309     std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
    310     flr->Run(opts, handle, args, remote_rets,
    311              [source_device, target_device, target_incarnation, rendezvous,
    312               device_context, rets_alloc_attrs, remote_rets, rets,
    313               done](const Status& status) {
    314                if (!status.ok()) {
    315                  delete remote_rets;
    316                  done(status);
    317                  return;
    318                }
    319                int64 num_returns = remote_rets->size();
    320                delete remote_rets;
    321                // Now receive the return values from the target.
    322                ReceiveTensorsAsync(target_device, source_device, "ret_",
    323                                    target_incarnation, num_returns,
    324                                    device_context, rets_alloc_attrs, rendezvous,
    325                                    rets, done);
    326              });
    327     return;
    328   }
    329   if (parent_ != nullptr) {
    330     parent_->Run(opts, local_handle, args, rets, done);
    331     return;
    332   }
    333   done(errors::Internal("Could not find device"));
    334 }
    335 
    336 Status ProcessFunctionLibraryRuntime::Clone(
    337     Env* env, int graph_def_version, const OptimizerOptions& optimizer_options,
    338     CustomKernelCreator custom_kernel_creator,
    339     std::unique_ptr<FunctionLibraryDefinition>* out_lib_def,
    340     std::unique_ptr<ProcessFunctionLibraryRuntime>* out_pflr) {
    341   out_lib_def->reset(new FunctionLibraryDefinition(*lib_def_));
    342   out_pflr->reset(new ProcessFunctionLibraryRuntime(
    343       device_mgr_, env, graph_def_version, out_lib_def->get(),
    344       optimizer_options, std::move(custom_kernel_creator), parent_));
    345   return Status::OK();
    346 }
    347 
    348 }  // namespace tensorflow
    349