/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"

#include <stddef.h>
#include <algorithm>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

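// InstanceRec::out_mu is logically held across asynchronous callbacks: the
// thread that must release it first marks it unavailable (out_mu_available =
// false), and the callback thread later re-acquires it, marks it available,
// and signals out_cv.  Other threads must therefore call WaitForOutMu after
// locking out_mu, to block until the mutex is also logically available.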
void CollectiveParamResolverLocal::InstanceRec::WaitForOutMu(mutex_lock& lock) {
  while (!out_mu_available) out_cv.wait(lock);
}

CollectiveParamResolverLocal::CollectiveParamResolverLocal(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    DeviceResolverInterface* dev_resolver, const string& task_name)
    : nccl_(config.experimental().collective_nccl()),
      dev_mgr_(dev_mgr),
      dev_resolver_(dev_resolver),
      task_name_(task_name) {}

void CollectiveParamResolverLocal::CompleteGroupAsync(
    const CompleteGroupRequest* request, CompleteGroupResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  done(
      errors::Internal("CompleteGroup is not implemented by "
                       "CollectiveParamResolverLocal which is "
                       "intended only for non-distributed deployment."));
}

void CollectiveParamResolverLocal::CompleteGroupLocal(
    const string& device, CollectiveParams* cp, const GroupRecCallback& done) {
  VLOG(1) << "CompleteGroupLocal device=" << device << " cp: " << cp << ": "
          << cp->ToString();
  std::vector<StatusCallback> to_be_called;
  GroupRec* gr = nullptr;
  {
    mutex_lock l(group_mu_);
    auto it = group_table_.find(cp->group.group_key);
    if (it == group_table_.end()) {
      gr = new GroupRec;
      gr->group.group_key = cp->group.group_key;
      gr->group.group_size = cp->group.group_size;
      gr->group.device_type = cp->group.device_type;
      group_table_[gr->group.group_key].reset(gr);
      VLOG(2) << "New group_key=" << gr->group.group_key
              << " group_size=" << gr->group.group_size;
    } else {
      gr = it->second.get();
    }
  }
  Status status;
  {
    mutex_lock gr_lock(gr->mu);
    if (!gr->device_set.empty()) {
      // Check for consistency with existing GroupRec.
      if (cp->group.device_type != gr->group.device_type) {
        status = errors::Internal(
            "Collective Op ", cp->name, " is assigned to device ", device,
            " with type ", cp->group.device_type.type_string(),
            " and group_key ", cp->group.group_key, " but that group has type ",
            gr->group.device_type.type_string());
      } else if (cp->group.group_size != gr->group.group_size) {
        status = errors::Internal(
            "Collective Op ", cp->name, " has group_size ",
            cp->group.group_size, " and group_key ", cp->group.group_key,
            " but that group has size ", gr->group.group_size);
      }
    }
    if (status.ok()) {
      // Insert device if not already present.
      auto it = gr->device_set.find(device);
      if (it == gr->device_set.end()) {
        if (gr->device_set.size() == gr->group.group_size) {
          // The group is already full.
          status = errors::Internal(
              "Collective Op ", cp->name, " is assigned to device ", device,
              " and group_key ", cp->group.group_key,
              " but that group doesn't contain that device.");
        } else {
          // This is a new device that has not yet joined the group.
          gr->device_set.insert(device);
          gr->device_list.push_back(device);
          DeviceNameUtils::ParsedName parsed_device;
          DeviceNameUtils::ParseFullName(device, &parsed_device);
          string task_name = strings::StrCat("/job:", parsed_device.job,
                                             "/replica:", parsed_device.replica,
                                             "/task:", parsed_device.task);
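          // For example (illustrative device name), parsing
          // "/job:worker/replica:0/task:1/device:GPU:0" yields the task name
          // "/job:worker/replica:0/task:1".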
          gr->task_set.insert(task_name);
          gr->task_list.push_back(task_name);
          gr->group.num_tasks = static_cast<int32>(gr->task_set.size());
          VLOG(1) << "group_key=" << gr->group.group_key
                  << " group_size=" << gr->group.group_size
                  << " dev_set=" << gr->device_set.size();
        }
      }
    }

    if (status.ok()) {
      // If the group is not yet complete, queue to wait for it.
      VLOG(2) << "group_size " << gr->group.group_size << " set size "
              << gr->device_set.size() << " gr " << gr;

      if (gr->device_set.size() < gr->group.group_size) {
        gr->waiting.push_back(std::bind(done, std::placeholders::_1, gr));
        return;
      }
      CHECK_EQ(gr->device_set.size(), gr->group.group_size);
      if (!gr->waiting.empty()) {
        std::swap(to_be_called, gr->waiting);
      }
    }
  }
  done(status, gr);
  for (int i = 0; i < to_be_called.size(); ++i) {
    to_be_called[i](Status::OK());
  }
}

namespace {
struct DevRec {
  string task;
  string device;
  int original_rank;
  int local_rank;
  int global_rank;
  const DeviceLocality* locality;
};
typedef std::unordered_map<string, DevRec> TaskDeviceMap;
typedef std::unordered_map<string, TaskDeviceMap> GlobalDeviceMap;
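// A GlobalDeviceMap maps task name -> (device name -> DevRec), i.e. it holds
// one TaskDeviceMap per task in the group.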

// Create a populated GlobalDeviceMap from CollInstanceParams and localities.
GlobalDeviceMap BuildDevRecs(const CollInstanceParams& ip,
                             const std::vector<DeviceLocality>& localities) {
  GlobalDeviceMap gdm;
  CHECK_EQ(ip.device_names.size(), ip.task_names.size());
  CHECK_EQ(ip.device_names.size(), localities.size());
  for (int i = 0; i < ip.device_names.size(); ++i) {
    TaskDeviceMap& tdm = gdm[ip.task_names[i]];
    DevRec* dr = &tdm[ip.device_names[i]];
    dr->task = ip.task_names[i];
    dr->device = ip.device_names[i];
    dr->original_rank = i;
    dr->local_rank = 0;   // Will be populated later by OrderTaskDeviceMap.
    dr->global_rank = 0;  // Will be populated later by EstablishGlobalRank.
    dr->locality = &localities[i];
  }
  return gdm;
}

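// Parse the comma-separated list of GPU ids in gpu_ring_order_str into a
// local rank for each device in *tdm.  Returns true iff the string parses
// cleanly and covers exactly the devices present in *tdm; on failure the
// caller falls back to ranking by link strength.  For example (hypothetical
// value), gpu_ring_order_str = "3,0,1,2" assigns local_rank 0 to the device
// with GPU id 3, local_rank 1 to GPU id 0, and so on.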
bool ParseRingOrder(const string& gpu_ring_order_str, TaskDeviceMap* tdm) {
  std::vector<int32> gpu_ring_order_vec;
  if (!str_util::SplitAndParseAsInts(gpu_ring_order_str, ',',
                                     &gpu_ring_order_vec)) {
    return false;
  }
  if (gpu_ring_order_vec.size() != tdm->size()) return false;
  // gpu id -> local rank
  gtl::FlatMap<int32, int32> gpu_ranks;
  for (int32 rank = 0; rank < static_cast<int32>(gpu_ring_order_vec.size());
       ++rank) {
    gpu_ranks[gpu_ring_order_vec[rank]] = rank;
  }

  for (auto& tdm_it : *tdm) {
    DeviceNameUtils::ParsedName parsed_name;
    DevRec* dr = &tdm_it.second;
    if (!DeviceNameUtils::ParseFullName(dr->device, &parsed_name)) {
      return false;
    }
    auto rank_it = gpu_ranks.find(parsed_name.id);
    if (rank_it == gpu_ranks.end()) return false;
    dr->local_rank = rank_it->second;
  }
  VLOG(2) << "Assigned local ranks based on ring order " << gpu_ring_order_str;
  return true;
}

void OrderTaskDeviceMap(const string& gpu_ring_order, TaskDeviceMap* tdm) {
  CHECK_GT(tdm->size(), 0);  // Should never be called with 0 devices

  // If a valid ring order has been passed in via ConfigProto, use that.
  if (ParseRingOrder(gpu_ring_order, tdm)) return;

  // Either no ring order was passed in, or the format was unexpected.
  // We now assign a ring order based on link strengths.  Note that this
  // algorithm is not optimal and may not always find the best ring order.
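  // The greedy procedure below: start at the device with the lowest
  // original_rank, then repeatedly follow the strongest InterconnectLink to a
  // participating, not-yet-selected device; if no such link exists, fall back
  // to the lowest-ranked remaining device.  For example (hypothetical
  // strengths), with only two links GPU:0-GPU:2 (strength 2) and GPU:0-GPU:1
  // (strength 1), the ring starting at GPU:0 continues to GPU:2 along the
  // strongest edge, then reaches GPU:1 via the fallback.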
  int least_rank = -1;
  string next_device;
  std::set<string> selected;
  // Starting device is one with the least initial rank.
  for (const auto& it : *tdm) {
    if (least_rank < 0 || it.second.original_rank < least_rank) {
      least_rank = it.second.original_rank;
      next_device = it.second.device;
    }
  }
  CHECK_GE(least_rank, 0);
  DeviceNameUtils::ParsedName parsed_name;
  CHECK(DeviceNameUtils::ParseFullName(next_device, &parsed_name));
  // NOTE: InterconnectLink has only a device_id, nothing more, so for
  // the time being if there's more than one device at a task we
  // assume they're all GPUs.

  int next_rank = 0;
  while (true) {
    selected.insert(next_device);
    auto next_dev_it = tdm->find(next_device);
    CHECK(next_dev_it != tdm->end());
    DevRec* dr = &next_dev_it->second;
    dr->local_rank = next_rank;
    ++next_rank;
    if (selected.size() == tdm->size()) {
      break;
    }
    // For the present time we assume Locality links only cover GPUs.
    // For multiple CPUs, just take them in order.
    const InterconnectLink* best_link = nullptr;
    if (parsed_name.type == "GPU") {
      for (const InterconnectLink& il : dr->locality->links().link()) {
        parsed_name.id = il.device_id();
        string endpoint_device =
            DeviceNameUtils::ParsedNameToString(parsed_name);
        // Skip the device if we've already seen it.
        if (selected.find(endpoint_device) != selected.end()) {
          continue;
        }
        // Skip the device if it is not participating in this collective
        // instance.
        if (tdm->find(endpoint_device) == tdm->end()) {
          continue;
        }
        if (best_link == nullptr || il.strength() > best_link->strength()) {
          best_link = &il;
        }
      }
    }
    if (best_link != nullptr) {
      // Follow the best edge
      parsed_name.id = best_link->device_id();
      next_device = DeviceNameUtils::ParsedNameToString(parsed_name);
    } else {
      // No good edges, alas. Pick the lowest initial rank among remaining
      // devices.
      least_rank = -1;
      for (const auto& it : *tdm) {
        if (selected.find(it.second.device) != selected.end()) {
          continue;
        }
        if (least_rank < 0 || it.second.original_rank < least_rank) {
          least_rank = it.second.original_rank;
          next_device = it.second.device;
        }
      }
      CHECK_GE(least_rank, 0);
    }
  }
}

// The first time a shared CollectiveParams is established for a
// shared set of instances we compute a good rank order for all the
// devices in the group, one that is appropriate for a ring algorithm.
// Where there is more than one good order, it need not be the same
// across different instance groups sharing the same device group.
GlobalDeviceMap EstablishGlobalRank(
    CollectiveParams* cp, const std::vector<DeviceLocality>& localities) {
  VLOG(1) << "EstablishGlobalRank";
  GlobalDeviceMap gdm = BuildDevRecs(cp->instance, localities);
  for (auto& iter : gdm) {
    TaskDeviceMap& tdm = iter.second;
    OrderTaskDeviceMap(cp->instance.gpu_ring_order, &tdm);
  }
  // Connect the global rank order by the order in which tasks first appear.
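  // Each task's devices keep their within-task local_rank order, and tasks
  // are concatenated in first-appearance order.  For example (hypothetical),
  // if task A appears first with two devices and task B also has two, A's
  // devices take global ranks 0-1 and B's take ranks 2-3.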
  std::set<string> ordered_tasks;
  int next_rank = 0;
  for (int i = 0; i < cp->instance.task_names.size(); ++i) {
    const string& task_name = cp->instance.task_names[i];
    if (ordered_tasks.find(task_name) != ordered_tasks.end()) {
      continue;
    }
    ordered_tasks.insert(task_name);
    TaskDeviceMap* tdm = &gdm[task_name];
    for (auto& it : *tdm) {
      it.second.global_rank = it.second.local_rank + next_rank;
    }
    next_rank += tdm->size();
  }
  return gdm;
}

// Count the devices associated with each task and set
// cp->instance.num_devices_per_task and
// cp->instance.same_num_devices_per_task.  Requires that
// cp->instance.task_names be sorted.
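// For example (hypothetical), sorted task_names [t0, t0, t1, t1] yields
// num_devices_per_task = {t0: 2, t1: 2} and same_num_devices_per_task = true.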
void SetDevPerTask(CollectiveParams* cp) {
  cp->instance.num_devices_per_task.clear();
  const string* last_task_name = &cp->instance.task_names[0];
  int count = 0;
  for (const string& task_name : cp->instance.task_names) {
    if (task_name == *last_task_name) {
      ++count;
    } else {
      cp->instance.num_devices_per_task[*last_task_name] = count;
      count = 1;
      last_task_name = &task_name;
    }
  }
  cp->instance.num_devices_per_task[*last_task_name] = count;

  cp->instance.same_num_devices_per_task = false;
  int dev_per_task = -1;
  for (const auto& task_dev : cp->instance.num_devices_per_task) {
    if (dev_per_task == -1) {
      dev_per_task = task_dev.second;
    } else if (dev_per_task != task_dev.second) {
      return;
    }
  }
  cp->instance.same_num_devices_per_task = true;
  CHECK_EQ((cp->group.group_size % cp->group.num_tasks), 0);
}

// Sort cp->instance.device_names lexicographically, but do so by first
// computing a reordering permutation so that cp->instance.task_names can
// be kept in corresponding order.
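// For example (hypothetical), device_names [d2, d0, d1] produces the
// permutation [1, 2, 0]; applying it reorders device_names to [d0, d1, d2]
// and task_names by the same index mapping.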
void SortDevicesAndTasks(CollectiveParams* cp) {
  VLOG(1) << "SortDevicesAndTasks " << cp << " instance " << &cp->instance;
  CHECK(cp);
  CHECK_EQ(cp->group.group_size, cp->instance.device_names.size());
  CHECK_EQ(cp->group.group_size, cp->instance.task_names.size());
  std::vector<int> perm(cp->group.group_size);
  // TODO(tucker): substitute std::iota when the windows build supports it.
  // std::iota(perm.begin(), perm.end(), 0);
  for (int i = 0; i < perm.size(); ++i) {
    perm[i] = i;
  }
  std::sort(perm.begin(), perm.end(), [cp](int a, int b) {
    return cp->instance.device_names[a] < cp->instance.device_names[b];
  });
  std::vector<string> new_devs;
  std::vector<string> new_tasks;
  new_devs.reserve(cp->group.group_size);
  new_tasks.reserve(cp->group.group_size);
  for (int pi : perm) {
    new_devs.push_back(cp->instance.device_names[pi]);
    new_tasks.push_back(cp->instance.task_names[pi]);
  }
  cp->instance.device_names = std::move(new_devs);
  cp->instance.task_names = std::move(new_tasks);
  VLOG(1) << "Modified device_names on " << cp;
  SetDevPerTask(cp);
}
}  // namespace

void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
                                                       CollectiveParams* cp) {
  cp->task.is_local.resize(cp->group.group_size, false);
  for (int i = 0; i < cp->group.group_size; ++i) {
    cp->task.is_local[i] = (cp->instance.task_names[i] == task_name);
  }
}

void CollectiveParamResolverLocal::SetDefaultRank(const string& device,
                                                  CollectiveParams* cp) {
  CHECK_EQ(cp->group.group_size, cp->instance.device_names.size()) << cp;
  for (int i = 0; i < cp->group.group_size; ++i) {
    if (cp->instance.device_names[i] == device) {
      cp->default_rank = i;
      break;
    }
  }
}

void CollectiveParamResolverLocal::InitInstanceSharedParams(
    const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
    const StatusCallback& done) {
  ir->shared.instance = cp->instance;
  {
    mutex_lock gl(gr->mu);
    ir->shared.group = gr->group;
    ir->shared.instance.device_names.assign(gr->device_list.begin(),
                                            gr->device_list.end());
    ir->shared.instance.task_names.assign(gr->task_list.begin(),
                                          gr->task_list.end());
    VLOG(2) << "Initialized names for instance: "
            << ir->shared.instance.ToString();
  }
  ir->shared.default_rank = -1;

  // Sort device_names lexicographically, keeping task_names in corresponding
  // order.  Also set number of devices per task.
  SortDevicesAndTasks(&ir->shared);

  // Get Locality data for all devices.

  // Set is_local and task_names in *shared prior to invoking
  // GetDeviceLocalitiesAsync.  In a distributed context this function can be
  // called by a derived class, in which case some of the devices may be
  // non-local and GetDeviceLocalitiesAsync will use those fields to launch
  // RPCs.
  CompleteTaskIsLocal(task_name_, &ir->shared);

  // Because the callback may execute in a different thread, we release
  // ir->out_mu here.  Before releasing, we mark it as unavailable for other
  // threads.
  ir->out_mu_available = false;
  ir->out_mu.unlock();
  std::vector<DeviceLocality>* localities = new std::vector<DeviceLocality>;
  dev_resolver_->GetDeviceLocalitiesAsync(
      ir->shared.instance, localities,
      [this, gr, cp, ir, localities, done](const Status& s)
          EXCLUSIVE_LOCK_FUNCTION(ir->out_mu) {
            // We then reacquire the lock in the callback thread, which will
            // hold it through the rest of the call chain.  Signal the cv now;
            // any waiting threads will wake only when out_mu is released
            // later.
            ir->out_mu.lock();
            DCHECK(!ir->out_mu_available);
            ir->out_mu_available = true;
            ir->out_cv.notify_all();
            if (s.ok()) {
              CompleteDefaultRanking(gr, cp, ir, *localities);
              done(Status::OK());
            } else {
              done(s);
            }
            delete localities;
          });
}

// NOTE(ayushd): The DeviceLocality objects in localities will have LocalLinks
// to all devices that they are physically connected to and visible to the
// TensorFlow runtime.  This set of devices may be a superset of the devices
// participating in this instance of collectives.
void CollectiveParamResolverLocal::CompleteDefaultRanking(
    const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
    const std::vector<DeviceLocality>& localities) {
  // Establish an instance-specific default rank order for devices
  // based on localities.  This rank order should be a good ring
  // order, if possible.
  GlobalDeviceMap gdm = EstablishGlobalRank(&ir->shared, localities);
  // Reflect the new global ranking on shared
  size_t num_devices = ir->shared.group.group_size;
  std::vector<string> new_device_names(num_devices, "");
  std::vector<string> new_task_names(num_devices, "");
  for (const auto& git : gdm) {
    const TaskDeviceMap& tdm = git.second;
    for (const auto& tit : tdm) {
      const DevRec& dr = tit.second;
      new_device_names[dr.global_rank] =
          ir->shared.instance.device_names[dr.original_rank];
      new_task_names[dr.global_rank] =
          ir->shared.instance.task_names[dr.original_rank];
    }
  }

  ir->shared.instance.device_names = new_device_names;
  ir->shared.instance.task_names = new_task_names;
  if (VLOG_IS_ON(2)) {
    string buf;
    for (const auto& d : new_device_names) strings::StrAppend(&buf, "\n", d);
    VLOG(2) << "Optimized device order for " << ir->shared.name << ": " << buf;
  }
}

void CollectiveParamResolverLocal::CallbackWithStatus(
    const InstanceRecCallback& done, InstanceRec* irec) {
  Status s;
  {
    mutex_lock l(irec->out_mu);
    irec->WaitForOutMu(l);
    s = irec->status;
  }
  done(s, irec);
}

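// Look up (or create) the InstanceRec for cp->instance.instance_key.  Three
// cases: the record already exists and is initialized, in which case done is
// invoked outside the locks; it exists but is still being initialized, in
// which case done is queued on init_waiters; or it does not exist yet, in
// which case it is created and initialization begins.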
void CollectiveParamResolverLocal::FindInstanceRec(
    const GroupRec* gr, CollectiveParams* cp, const InstanceRecCallback& done) {
  InstanceRec* irec = nullptr;
  bool exit_outside_locks = false;
  {
    mutex_lock l(instance_mu_);
    auto it = instance_table_.find(cp->instance.instance_key);
    if (it != instance_table_.end()) {
      irec = it->second.get();
      {
        mutex_lock l(irec->in_mu);
        if (irec->is_init) {
          exit_outside_locks = true;
        } else {
          irec->init_waiters.push_back([this, done](InstanceRec* irec) {
            CallbackWithStatus(done, irec);
          });
          return;
        }
      }
    } else {
      // Create new InstanceRec.
      irec = new InstanceRec;
      instance_table_[cp->instance.instance_key].reset(irec);
    }
  }
  if (exit_outside_locks) {
    CallbackWithStatus(done, irec);
    return;
  }

  CallInitInstanceSharedParams(gr, cp, irec, done);
}

void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
    const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
    const InstanceRecCallback& done) NO_THREAD_SAFETY_ANALYSIS {
  // This function serves merely to make a function call that should
  // be thread/mutex safe but violates the simple model applied by
  // static analysis, so we turn off analysis only within this
  // function body.
  //
  // A lock on ir->out_mu must be held* throughout the _bodies_ of the
  // chain of function calls initiated here, each of which calls
  // another as its last action, but it will be dropped within the
  // callback defined below, which means that the lock can be dropped
  // before all the function stack frames pop. The static analysis will
  // not allow that.
  //
  // *the lock is dropped just before calling GetDeviceLocalitiesAsync, because
  // there is no guarantee that the thread that executes the callback is the
  // same as the one that locked ir->out_mu.  To prevent other threads from
  // grabbing ir->out_mu, we mark ir->out_mu_available as false.  Hence, in
  // principle, the lock is held throughout.
  ir->out_mu.lock();
  DCHECK(ir->out_mu_available);
  ir->known.resize(cp->group.group_size, false);
  InitInstanceSharedParams(
      gr, cp, ir,
      [this, ir, done](const Status& s) UNLOCK_FUNCTION(ir->out_mu) {
        DCHECK(ir->out_mu_available);
        ir->status.Update(s);
        ir->out_mu.unlock();
        // Prepare to invoke any waiters that accumulated during
        // initialization.
        std::vector<IRConsumer> init_waiters;
        {
          mutex_lock tl(instance_mu_);
          {
            mutex_lock l(ir->in_mu);
            ir->is_init = true;
            if (!ir->init_waiters.empty()) {
              std::swap(init_waiters, ir->init_waiters);
            }
          }
        }
        CallbackWithStatus(done, ir);
        for (auto& f : init_waiters) {
          f(ir);
        }
      });
}

void CollectiveParamResolverLocal::CompleteParamsAsync(
    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
    const StatusCallback& done) {
  VLOG(1) << "CompleteParams local " << device << " for " << cp << ": "
          << cp->ToString();
  CompleteGroupLocal(
      device, cp,
      [this, device, cp, done](const Status& s, const GroupRec* gr) {
        if (s.ok()) {
          CompleteInstanceLocal(device, gr, cp, cp->is_source, done);
        } else {
          done(s);
        }
      });
}

void CollectiveParamResolverLocal::CompleteInstanceAsync(
    const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  done(
      errors::Internal("CompleteInstance is not implemented by "
                       "CollectiveParamResolverLocal which is "
                       "intended only for non-distributed deployment."));
}

// TODO(b/111897089): we need a better way to pick the collective
// implementation.  Ideally the choice would take the topology and link
// strength into account before picking a particular implementation.
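// The current mapping: BROADCAST_COLLECTIVE -> HierarchicalTreeBroadcast;
// REDUCTION_COLLECTIVE -> NcclReduce when collective_nccl is enabled in the
// ConfigProto, else RingReduce; GATHER_COLLECTIVE -> RingGather.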
void CollectiveParamResolverLocal::AssignCollectiveType(CollectiveParams* cp) {
  if (cp->instance.type == BROADCAST_COLLECTIVE) {
    cp->instance.impl_details.collective_name = "HierarchicalTreeBroadcast";
  } else if (cp->instance.type == REDUCTION_COLLECTIVE) {
    if (nccl_) {
      cp->instance.impl_details.collective_name = "NcclReduce";
    } else {
      cp->instance.impl_details.collective_name = "RingReduce";
    }
  } else if (cp->instance.type == GATHER_COLLECTIVE) {
    cp->instance.impl_details.collective_name = "RingGather";
  } else {
    cp->instance.impl_details.collective_name = "undef";
  }
  VLOG(1) << "AssignCollectiveType "
          << cp->instance.impl_details.collective_name;
}

void CollectiveParamResolverLocal::CompleteInstanceLocal(
    const string& device, const GroupRec* gr, CollectiveParams* cp,
    bool is_source, const StatusCallback& done) {
  VLOG(1) << "CompleteInstanceLocal " << device
          << " instance_key: " << cp->instance.instance_key << " gr " << gr;

  // Populate the group portion of *cp from *gr.  Most of it should already
  // match.
  DCHECK_EQ(cp->group.group_key, gr->group.group_key);
  DCHECK_EQ(cp->group.group_size, gr->group.group_size);
  DCHECK_EQ(cp->group.device_type, gr->group.device_type);
  cp->group = gr->group;

  // Get the shared InstanceRec for this instance.
  FindInstanceRec(gr, cp,
                  [this, device, gr, cp, is_source, done](const Status& s,
                                                          InstanceRec* ir) {
                    if (s.ok()) {
                      CompleteInstanceFromInitializedIRec(device, gr, cp, ir,
                                                          is_source, done);
                    } else {
                      done(s);
                    }
                  });
}

void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
    const string& device, const GroupRec* gr, CollectiveParams* cp,
    InstanceRec* ir, bool is_source, const StatusCallback& done) {
  // Populate the fields common across instance.
  {
    mutex_lock l(ir->out_mu);
    ir->WaitForOutMu(l);
    // custom operator= does a deep copy.
    cp->instance = ir->shared.instance;
  }
  // Populate the fields common across task.
  AssignCollectiveType(cp);
  SetDefaultRank(device, cp);
  CompleteTaskIsLocal(task_name_, cp);

  CollectiveImplementationInterface* col_impl;
  Status status = CollectiveRegistry::LookupParamResolverInstance(
      cp->instance.impl_details.collective_name, &col_impl);
  if (status.ok()) {
    status = col_impl->InitializeInstanceBeforeGroupDiscovery(cp);
  }
  if (!status.ok()) {
    done(status);
    return;
  }

  //  We may need to wait for the group if:
  //  * this is a broadcast, for source discovery;
  //  * we are using NCCL with more than 1 worker, for the communicator key from
  //    rank 0.
  bool broadcast = cp->instance.type == BROADCAST_COLLECTIVE;
  bool nccl = cp->instance.type == REDUCTION_COLLECTIVE &&
              cp->instance.impl_details.collective_name == "NcclReduce" &&
              cp->group.num_tasks > 1;
  if (broadcast || nccl) {
    WaitForGroup(ir, cp, is_source, broadcast, nccl,
                 [col_impl, ir, device, cp, done](InstanceRec* irec) {
                   Status s;
                   if (ir != irec) {
                     s = errors::Internal("Expected ir ", ir, " and irec ",
                                          irec, " to be equal");
                   } else {
                     mutex_lock l(irec->out_mu);
                     irec->WaitForOutMu(l);
                     s = irec->status;
                     cp->source_rank = irec->source_rank;
                     cp->instance.communicator_key = irec->communicator_key;
                   }
                   if (s.ok()) {
                     s = col_impl->InitializeCollectiveParams(cp);
                   }
                   done(s);
                 });
  } else {
    done(col_impl->InitializeCollectiveParams(cp));
  }
}

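// Record that the member with cp->default_rank has checked in, and invoke f
// once all group_size members are known.  With init_source, the is_source
// member claims ir->source_rank; with init_nccl, the rank 0 member supplies
// ir->communicator_key.  The last member to arrive runs f itself and then
// invokes every queued waiter.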
void CollectiveParamResolverLocal::WaitForGroup(
    InstanceRec* ir, CollectiveParams* cp, bool is_source, bool init_source,
    bool init_nccl, const IRConsumer& f) {
  std::vector<IRConsumer> ready_waiters;
  {
    mutex_lock l(ir->out_mu);
    ir->WaitForOutMu(l);
    CHECK_EQ(cp->group.group_size, ir->known.size());
    CHECK_GE(cp->default_rank, 0);
    if (!ir->known[cp->default_rank]) {
      ir->known[cp->default_rank] = true;
      ++ir->known_count;
      if (init_source && is_source) {
        // Initialize source rank.
        if (ir->source_rank >= 0) {
          ir->status = errors::Internal("Instance ", cp->instance.instance_key,
                                        " already has source ", ir->source_rank,
                                        ", received second claim from ",
                                        cp->default_rank);
        } else {
          ir->source_rank = cp->default_rank;
        }
      }
      if (init_nccl && cp->default_rank == 0) {
        // Initialize communicator key.
        if (!ir->communicator_key.empty()) {
          ir->status =
              errors::Internal("Instance ", cp->instance.instance_key,
                               " already has communicator_key ",
                               str_util::CEscape(ir->communicator_key),
                               ", received second claim from device ",
                               cp->instance.device_names[cp->default_rank]);
        } else {
          ir->communicator_key = cp->instance.communicator_key;
        }
      }
    }
    if (ir->known_count < ir->shared.group.group_size) {
      ir->known_waiters.push_back(f);
      return;
    }
    CHECK_EQ(ir->known_count, ir->shared.group.group_size);
    if (init_source && ir->source_rank < 0) {
      // NOTE(ayushd): changing the error message below would also require
      // updating CompleteParamsBroadcastForgotSend test in
      // CollectiveParamResolverLocalTest.
      ir->status =
          errors::Internal("Instance ", cp->instance.instance_key,
                           " found no source for broadcast.  This "
                           "could mean that there were group_size=",
                           ir->known_count, " BcastRecvs but no BcastSend.");
    }
    if (init_nccl && ir->communicator_key.empty()) {
      ir->status = errors::Internal(
          "Instance ", cp->instance.instance_key, " device ",
          cp->instance.device_names[cp->default_rank],
          " did not find rank 0 for setting communicator key.  This is an "
          "internal error in collective param resolution");
    }
    if (!ir->known_waiters.empty()) {
      ready_waiters = std::move(ir->known_waiters);
    }
  }
  f(ir);
  for (auto& f : ready_waiters) {
    f(ir);
  }
}

}  // namespace tensorflow