/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/graph_mgr.h"

#include <vector>

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr)
    : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
  // The default value of sync_on_finish will be flipped soon and this
  // environment variable will be removed as well.
  Status status =
      ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
  if (!status.ok()) {
    LOG(ERROR) << status.error_message();
  }
}

GraphMgr::~GraphMgr() {
  for (auto p : table_) p.second->Unref();
}

GraphMgr::Item::~Item() {
  for (const auto& unit : this->units) {
    CHECK_NOTNULL(unit.device);
    if (!graph_mgr->skip_cost_models_) {
      graph_mgr->cost_model_manager_.RemoveCostModelForGraph(unit.graph);
    }
    delete unit.root;
    unit.device->op_segment()->RemoveHold(this->session);
  }
}

// NOTE: node->device_name() is not set by GraphConstructor. We expect
// that every NodeDef in a GraphDef given to workers fully specifies its
// device name.
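// For example, nodes whose assigned device is
// "/job:worker/replica:0/task:0/device:GPU:0" all land in the same
// partition (the device string here is illustrative; actual job/task names
// depend on the cluster configuration).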
static string SplitByDevice(const Node* node) {
  return node->assigned_device_name();
}

// Validates "gdef" device specifications.
static Status ValidateGraphDefForDevices(const GraphDef& gdef) {
  DeviceNameUtils::ParsedName parsed;
  for (const auto& ndef : gdef.node()) {
    if (!DeviceNameUtils::ParseFullName(ndef.device(), &parsed)) {
      return errors::InvalidArgument("Missing device name in: ",
                                     SummarizeNodeDef(ndef));
    }
  }
  return Status::OK();
}

Status GraphMgr::DecorateAndPublishGraphForDebug(
    const DebugOptions& debug_options, Graph* graph, Device* device) {
  std::unique_ptr<DebugGraphDecoratorInterface> decorator;
  TF_RETURN_IF_ERROR(
      DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
  TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
  TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
  return Status::OK();
}

// Creates executors given a graph definition "gdef" of a "session".
// If a node in "gdef" is shared by other graphs in "session", the
// same op kernel is reused. E.g., typically a params node is shared
// by multiple graphs in a session.
//
// If "gdef" is assigned to multiple devices, extra nodes (e.g.,
// send/recv nodes) may be added. The extra nodes' names are generated
// by calling "new_name(old_name)".
//
// On success, one executor per device is created, and the caller takes
// ownership of the returned executors.
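//
// (The executors end up as the "root" members of "item->units"; each
// ExecutionUnit pairs one device's subgraph with the executor that runs it.)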
Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
                          const GraphOptions& graph_options,
                          const DebugOptions& debug_options,
                          DistributedFunctionLibraryRuntime* cluster_flr,
                          Item* item) {
  item->session = session;
  item->lib_def.reset(
      new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library()));

  TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef));

  if (gdef.versions().producer() >= 5) {
    // Validate the graph: we assume that merging two valid graphs
    // should maintain graph validity.
    TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *item->lib_def));
  }

  item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
      device_mgr_, worker_env_->env, gdef.versions().producer(),
      item->lib_def.get(), graph_options.optimizer_options(), cluster_flr));

  // Constructs the graph out of "gdef".
  Graph graph(OpRegistry::Global());
  GraphConstructorOptions opts;
  opts.allow_internal_ops = true;
  opts.expect_device_spec = true;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));

  // Splits "graph" into multiple subgraphs by device names.
  std::unordered_map<string, GraphDef> partitions;
  PartitionOptions popts;
  popts.node_to_loc = SplitByDevice;
  popts.new_name = [this](const string& prefix) {
    mutex_lock l(mu_);
    return strings::StrCat(prefix, "_G", next_id_++);
  };
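  // For example, new_name("x") returns names like "x_G7"; the numeric
  // suffix comes from the shared next_id_ counter.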
  popts.get_incarnation = [this](const string& name) -> int64 {
    Device* device = nullptr;
    Status s = device_mgr_->LookupDevice(name, &device);
    if (s.ok()) {
      return device->attributes().incarnation();
    } else {
      return PartitionOptions::kIllegalIncarnation;
    }
  };
  popts.flib_def = &graph.flib_def();
  popts.control_flow_added = true;
  popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
  TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
  if (popts.scheduling_for_recvs) {
    TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
  }

  std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs;
  for (const auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global()));
    GraphConstructorOptions device_opts;
    // By this point the partitions contain internal operations (e.g.,
    // send/recv), so we must allow them.
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second,
                                              device_graph.get()));
    partition_graphs.emplace(partition.first, std::move(device_graph));
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.flib_def = item->lib_def.get();
  optimization_options.partition_graphs = &partition_graphs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));

  LocalExecutorParams params;

  item->units.reserve(partitions.size());
  item->graph_mgr = this;
  const auto& optimizer_opts = graph_options.optimizer_options();
  GraphOptimizer optimizer(optimizer_opts);
  for (auto& p : partition_graphs) {
    const string& device_name = p.first;
    std::unique_ptr<Graph>& subgraph = p.second;
    item->units.resize(item->units.size() + 1);
    ExecutionUnit* unit = &(item->units.back());

    // Find the device.
    Status s = device_mgr_->LookupDevice(device_name, &unit->device);
    if (!s.ok()) {
      // Remove the empty unit from the item, as the item destructor wants all
      // units to have valid devices.
      item->units.pop_back();
      return s;
    }

    // Give the device an opportunity to rewrite its subgraph.
    TF_RETURN_IF_ERROR(unit->device->MaybeRewriteGraph(&subgraph));

    // Top-level nodes in the graph use the op segment to cache kernels.
    // Therefore, as long as the executor is alive, we need to ensure that
    // the kernels cached for the session stay alive.
    auto opseg = unit->device->op_segment();
    opseg->AddHold(session);

    // Function library runtime.
    FunctionLibraryRuntime* lib = item->proc_flr->GetFLR(unit->device->name());
    if (lib == nullptr) {
      return errors::InvalidArgument("Cannot find FLR for device: ",
                                     unit->device->name());
    }

    // Construct the root executor for the subgraph.
    params.device = unit->device;
    params.function_library = lib;
    params.create_kernel = [session, lib, opseg](const NodeDef& ndef,
                                                 OpKernel** kernel) {
      // We do not share the kernel via the OpSegment if the node is
      // stateless, or a function.
      // NOTE(mrry): We must not share function kernels (implemented
      // using `CallOp`) between subgraphs, because `CallOp::handle_`
      // is tied to a particular subgraph. Even if the function itself
      // is stateful, the `CallOp` that invokes it is not.
      if (!lib->IsStateful(ndef.op()) ||
          lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
        return lib->CreateKernel(ndef, kernel);
      }
      auto create_fn = [lib, &ndef](OpKernel** kernel) {
        return lib->CreateKernel(ndef, kernel);
      };
      // Kernels created for subgraph nodes need to be cached.  On
      // cache miss, create_fn() is invoked to create a kernel based
      // on the function library here + global op registry.
      return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn);
    };
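    // For example, a stateful op such as "Variable" is cached in the op
    // segment and shared across graphs in the session, while a stateless op
    // such as "MatMul" gets a fresh kernel per executor. (Op names here are
    // illustrative; statefulness comes from the op registry.)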
    params.delete_kernel = [lib](OpKernel* kernel) {
      // If the node is stateful, opseg owns it. Otherwise, delete it.
      if (kernel && !lib->IsStateful(kernel->type_string())) {
        delete kernel;
      }
    };

    optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
                       /*shape_map=*/nullptr);

    // EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) into the graph.
    if (!debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          debug_options, subgraph.get(), params.device));
    }

    TF_RETURN_IF_ERROR(
        EnsureMemoryTypes(DeviceType(unit->device->device_type()),
                          unit->device->name(), subgraph.get()));
    unit->graph = subgraph.get();
    unit->build_cost_model = graph_options.build_cost_model();
    if (unit->build_cost_model > 0) {
      skip_cost_models_ = false;
    }
    TF_RETURN_IF_ERROR(
        NewLocalExecutor(params, std::move(subgraph), &unit->root));
  }
  return Status::OK();
}
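
// A typical lifecycle, sketched from the API in this file (names below are
// illustrative, not taken from real call sites):
//
//   string handle;
//   TF_RETURN_IF_ERROR(graph_mgr->Register(session_name, gdef, graph_options,
//                                          debug_options, cluster_flr,
//                                          &handle));
//   graph_mgr->ExecuteAsync(handle, step_id, worker_session, exec_opts,
//                           collector, response, cancellation_mgr, inputs,
//                           [](const Status& s) { /* step finished */ });
//   TF_RETURN_IF_ERROR(graph_mgr->Deregister(handle));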
Status GraphMgr::Register(const string& session, const GraphDef& gdef,
                          const GraphOptions& graph_options,
                          const DebugOptions& debug_options,
                          DistributedFunctionLibraryRuntime* cluster_flr,
                          string* handle) {
  Item* item = new Item;
  Status s =
      InitItem(session, gdef, graph_options, debug_options, cluster_flr, item);
  if (!s.ok()) {
    item->Unref();
    return s;
  }

  // Inserts one item into table_.
  {
    mutex_lock l(mu_);
    *handle = strings::Printf("%016llx", ++next_id_);
    item->handle = *handle;
    CHECK(table_.insert({*handle, item}).second);
  }
  return Status::OK();
}

Status GraphMgr::Deregister(const string& handle) {
  Item* item = nullptr;
  // Removes one item from table_.
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter == table_.end()) {
      return errors::Aborted("Graph handle is not found: ", handle,
                             ". Possibly, this worker just restarted.");
    }
    item = iter->second;
    table_.erase(iter);
  }
  item->Unref();
  return Status::OK();
}

Status GraphMgr::DeregisterAll() {
  std::vector<Item*> items;
  // Removes all items from table_.
  {
    mutex_lock l(mu_);
    for (const auto& entry : table_) {
      items.push_back(entry.second);
    }
    table_.clear();
  }
  for (auto item : items) {
    item->Unref();
  }
  return Status::OK();
}
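
// The keys in "in" are full rendezvous keys. A key produced by
// Rendezvous::CreateKey() looks roughly like (illustrative example):
//
//   /job:worker/replica:0/task:0/cpu:0;<hex incarnation>;
//   /job:worker/replica:0/task:0/gpu:0;edge_5_x;0:0
//
// i.e., "src_device;src_incarnation;dst_device;tensor_name;frame_and_iter".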
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  std::vector<string> keys;
  std::vector<Tensor> tensors_to_send;
  keys.reserve(in.size());
  tensors_to_send.reserve(in.size());
  for (const auto& p : in) {
    keys.push_back(p.first);
    tensors_to_send.push_back(p.second);
  }
  Status s =
      SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  rendezvous->Unref();
  return s;
}

Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
  rendezvous->Unref();
  return s;
}

void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
                                StatusCallback done) {
  Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  std::vector<string> keys;
  // Heap-allocated so it outlives this frame; the callback below deletes it.
  std::vector<Tensor>* received_tensors = new std::vector<Tensor>;
  keys.reserve(out->size());
  received_tensors->reserve(out->size());
  for (const auto& p : *out) {
    keys.push_back(p.first);
    received_tensors->push_back(p.second);
  }
  RecvOutputsFromRendezvousAsync(
      rendezvous, nullptr, {}, keys, received_tensors,
      [done, rendezvous, received_tensors, out, keys](const Status& s) {
        rendezvous->Unref();
        for (size_t i = 0; i < keys.size(); ++i) {
          (*out)[keys[i]] = (*received_tensors)[i];
        }
        delete received_tensors;
        done(s);
      });
}

void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
                            WorkerSession* session, const ExecutorOpts& opts,
                            StepStatsCollector* collector,
                            MutableRunGraphResponseWrapper* response,
                            CancellationManager* cancellation_manager,
                            const NamedTensors& in, StatusCallback done) {
  // Look up the item and hold one ref while executing.
  Item* item = nullptr;
  {
    mutex_lock l(mu_);
    auto iter = table_.find(handle);
    if (iter != table_.end()) {
      item = iter->second;
      item->Ref();
    }
  }

  if (item == nullptr) {
    done(errors::Aborted("Graph handle is not found: ", handle));
    return;
  }

  CostGraphDef* cost_graph = nullptr;
  if (response != nullptr) {
    cost_graph = response->mutable_cost_graph();
    if (opts.record_partition_graphs()) {
      for (const ExecutionUnit& unit : item->units) {
        GraphDef graph_def;
        unit.graph->ToGraphDef(&graph_def);
        response->AddPartitionGraph(graph_def);
      }
    }
  }

  RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
  Status s = rendezvous->Initialize(session);

  // Sends values specified by the caller.
  if (s.ok()) {
    std::vector<string> keys;
    std::vector<Tensor> tensors_to_send;
    keys.reserve(in.size());
    tensors_to_send.reserve(in.size());
    for (auto& p : in) {
      keys.push_back(p.first);
      tensors_to_send.push_back(p.second);
    }
    s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
  }

  if (!s.ok()) {
    done(s);
    item->Unref();
    rendezvous->Unref();
    return;
  }

  StartParallelExecutors(handle, step_id, item, rendezvous, collector,
                         cost_graph, cancellation_manager,
                         [this, item, rendezvous, done](const Status& s) {
                           done(s);
                           rendezvous->Unref();
                           item->Unref();
                         });
}

void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
                                      Item* item, Rendezvous* rendezvous,
                                      StepStatsCollector* collector,
                                      CostGraphDef* cost_graph,
                                      CancellationManager* cancellation_manager,
                                      StatusCallback done) {
  const int num_units = item->units.size();
  CHECK_GE(num_units, 1);
  ScopedStepContainer* step_container = new ScopedStepContainer(
      step_id,
      [this](const string& name) { device_mgr_->ClearContainers({name}); });
  // NOTE: One ref each on "rendezvous" and "item" is transferred to the
  // "done" callback.
  ExecutorBarrier* barrier =
      new ExecutorBarrier(num_units, rendezvous,
                          [this, item, collector, cost_graph, step_container,
                           done](const Status& s) {
                            BuildCostModel(item, collector, cost_graph);
                            done(s);
                            delete step_container;
                          });
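  // (ExecutorBarrier invokes the callback above exactly once, after all
  // num_units executors have completed, with their merged status; roughly,
  // if one executor fails, the rendezvous is aborted so the others unblock.)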
  Executor::Args args;
  {
    mutex_lock l(mu_);
    args.step_id = ++next_id_;
  }
  args.rendezvous = rendezvous;
  args.cancellation_manager = cancellation_manager;
  args.stats_collector = collector;
  args.step_container = step_container;
  args.sync_on_finish = sync_on_finish_;
  if (LogMemory::IsEnabled()) {
    LogMemory::RecordStep(args.step_id, handle);
  }
  thread::ThreadPool* pool = worker_env_->compute_pool;
  using std::placeholders::_1;
  // The line below is equivalent to this code, but makes one less indirect
  // call:
  //   args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
  auto default_runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
  for (const auto& unit : item->units) {
    // TODO(zhengxq): if the device picks its own threadpool, we need to
    // assign fewer threads to the main compute pool by default.
    thread::ThreadPool* device_thread_pool =
        unit.device->tensorflow_device_thread_pool();
    if (!device_thread_pool) {
      args.runner = default_runner;
    } else {
      args.runner =
          std::bind(&thread::ThreadPool::Schedule, device_thread_pool, _1);
    }
    unit.root->RunAsync(args, barrier->Get());
  }
}

void GraphMgr::BuildCostModel(Item* item, StepStatsCollector* collector,
                              CostGraphDef* cost_graph) {
  if (collector && !skip_cost_models_) {
    // Build the cost model.
    std::unordered_map<string, const Graph*> device_to_graph;
    for (const auto& unit : item->units) {
      if (unit.build_cost_model > 0) {
        device_to_graph[unit.device->name()] = unit.graph;
      }
    }
    collector->BuildCostModel(&cost_model_manager_, device_to_graph);

    if (cost_graph != nullptr) {
      for (const auto& unit : item->units) {
        cost_model_manager_.AddToCostGraphDef(unit.graph, cost_graph)
            .IgnoreError();
      }
    }
  }
}

}  // end namespace tensorflow