Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/computation_tracker.h"
     17 
     18 #include <list>
     19 #include <string>
     20 #include <utility>
     21 #include <vector>
     22 
     23 #include "tensorflow/compiler/xla/ptr_util.h"
     24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     25 #include "tensorflow/compiler/xla/status_macros.h"
     26 #include "tensorflow/compiler/xla/types.h"
     27 #include "tensorflow/compiler/xla/util.h"
     28 #include "tensorflow/core/lib/strings/strcat.h"
     29 #include "tensorflow/core/lib/strings/stringprintf.h"
     30 #include "tensorflow/core/platform/logging.h"
     31 
     32 using ::tensorflow::strings::Appendf;
     33 
     34 namespace xla {
     35 
// Handle values start at 1, so handle value 0 is never allocated.
// NOTE(review): presumably 0 is reserved as the proto default / "unset"
// sentinel for ComputationHandle — confirm against the proto definition.
ComputationTracker::ComputationTracker() : next_computation_(1) {}
     37 
     38 ComputationHandle ComputationTracker::NewComputation(
     39     const string& computation_name) {
     40   tensorflow::mutex_lock lock(computation_mutex_);
     41   ComputationHandle computation_handle;
     42   int64 handle_value = next_computation_++;
     43   computation_handle.set_handle(handle_value);
     44   opaque_to_computation_[handle_value] =
     45       MakeUnique<UserComputation>(computation_name, computation_handle);
     46   return computation_handle;
     47 }
     48 
// Reconstitutes a serialized SessionModule into this tracker.  Fresh handles
// are allocated for the entry computation and every embedded computation, and
// all references inside the module are rewritten from the serialized (old)
// handles to the fresh (new) ones.  Returns the new entry-computation handle.
StatusOr<ComputationHandle> ComputationTracker::LoadSessionModule(
    const SessionModule& session_module) {
  tensorflow::mutex_lock lock(computation_mutex_);

  // For each embedded computation, create a new computation based on its
  // serialized data, and place the mapping from the old computation handle to
  // the new computation handle.

  // Build a mapping from old embedded computation handles to new computation
  // handles. We build the ID mapping first since the embedded computations are
  // in no particular order and may refer to each other.
  std::map<int64, ComputationHandle> old_to_new;
  for (const SessionComputation& computation :
       session_module.embedded_computations()) {
    const int64 old_handle = computation.computation_handle().handle();
    // emplace() reports an existing key via `second == false`; a duplicate
    // old handle means the serialized module is malformed.
    if (!old_to_new.emplace(old_handle, AllocateHandle()).second) {
      return InvalidArgument("Duplicate embedded computation handle %lld",
                             old_handle);
    }
  }

  // Create a new computation from each serialized embedded computation.
  // The full old_to_new table built above lets each computation remap
  // references to any sibling regardless of serialization order.
  for (const SessionComputation& computation :
       session_module.embedded_computations()) {
    const int64 old_handle = computation.computation_handle().handle();
    const ComputationHandle& new_handle = old_to_new[old_handle];
    TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()],
                        UserComputation::MakeWithRemapping(
                            computation, new_handle, old_to_new));
  }

  // Finally, place the entry computation in the tracker with all of the
  // remappings populated from the above.
  const int64 old_handle = session_module.entry().computation_handle().handle();
  TF_ASSIGN_OR_RETURN(
      old_to_new[old_handle],
      LoadSessionComputation(session_module.entry(), &old_to_new));
  return old_to_new[old_handle];
}
     88 
// Serializes the computation identified by `computation` (at its current
// version), together with every computation it transitively embeds, into a
// freshly allocated SessionModule.
StatusOr<std::unique_ptr<SessionModule>>
ComputationTracker::SnapshotComputation(const ComputationHandle& computation) {
  TF_ASSIGN_OR_RETURN(UserComputation * user_computation, Resolve(computation));
  const VersionedComputationHandle entry_versioned_handle =
      user_computation->GetVersionedHandle();
  std::set<VersionedComputationHandle> visited;
  std::list<VersionedComputationHandle> post_order;
  {
    // Scope the lock to the post-order walk only: the Resolve() calls below
    // re-acquire computation_mutex_ themselves.
    tensorflow::mutex_lock lock(computation_mutex_);
    ComputeComputationPostOrder(entry_versioned_handle, &visited, &post_order);
  }
  auto session_module = MakeUnique<SessionModule>();
  // The entry computation goes into its dedicated proto field...
  *session_module->mutable_entry() =
      Resolve(entry_versioned_handle.handle)
          .ValueOrDie()
          ->CloneSessionComputation(entry_versioned_handle.version);
  // ...while the embedded computations are appended in reverse post order.
  // The leading `++` skips the entry computation, which is the last element
  // of the post order and was already serialized above.
  for (auto it = ++post_order.rbegin(); it != post_order.rend(); ++it) {
    *session_module->add_embedded_computations() =
        Resolve(it->handle).ValueOrDie()->CloneSessionComputation(it->version);
  }
  return std::move(session_module);
}
    111 
    112 StatusOr<UserComputation*> ComputationTracker::Resolve(
    113     const ComputationHandle& computation) const {
    114   tensorflow::mutex_lock lock(computation_mutex_);
    115   return ResolveInternal(computation);
    116 }
    117 
    118 ComputationHandle ComputationTracker::AllocateHandle() {
    119   int64 handle_value = next_computation_++;
    120   ComputationHandle result;
    121   result.set_handle(handle_value);
    122   return result;
    123 }
    124 
// Deserializes a single SessionComputation under a freshly allocated handle,
// recording the old->new handle mapping in *old_to_new, and remapping the
// computation's internal references through that table.  Called with
// computation_mutex_ held (see LoadSessionModule).  Returns the new handle.
StatusOr<ComputationHandle> ComputationTracker::LoadSessionComputation(
    const SessionComputation& session_computation,
    std::map<int64, ComputationHandle>* old_to_new) {
  TF_RET_CHECK(old_to_new != nullptr);
  const ComputationHandle new_handle = AllocateHandle();
  // Insert this computation's own mapping before MakeWithRemapping so the
  // table passed below already contains it.
  (*old_to_new)[session_computation.computation_handle().handle()] = new_handle;
  TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()],
                      UserComputation::MakeWithRemapping(
                          session_computation, new_handle, *old_to_new));
  return new_handle;
}
    136 
    137 StatusOr<UserComputation*> ComputationTracker::ResolveInternal(
    138     const ComputationHandle& computation) const {
    139   auto it = opaque_to_computation_.find(computation.handle());
    140   if (it == opaque_to_computation_.end()) {
    141     return NotFound("computation handle not found: %lld", computation.handle());
    142   }
    143   UserComputation* user_computation = it->second.get();
    144   return user_computation;
    145 }
    146 
    147 void ComputationTracker::ComputeComputationPostOrder(
    148     const VersionedComputationHandle& versioned_handle,
    149     std::set<VersionedComputationHandle>* visited,
    150     std::list<VersionedComputationHandle>* post_order) const {
    151   if (visited->count(versioned_handle) > 0) {
    152     CHECK_EQ(1, visited->count(versioned_handle));
    153     return;
    154   }
    155 
    156   UserComputation* computation =
    157       ResolveInternal(versioned_handle.handle).ValueOrDie();
    158   std::vector<VersionedComputationHandle> embedded_handles =
    159       computation->GetEmbeddedComputations(versioned_handle.version);
    160 
    161   for (const auto& embedded_handle : embedded_handles) {
    162     ComputeComputationPostOrder(embedded_handle, visited, post_order);
    163   }
    164 
    165   visited->insert(versioned_handle);
    166   post_order->push_back(versioned_handle);
    167 }
    168 
// Lowers the UserComputation graph rooted at `entry_handle` into an
// HloModule.  Computations are lowered in post order so that, by the time a
// caller is built, every computation it calls already has an HloComputation
// the resolver lambda can hand back.
StatusOr<std::unique_ptr<HloModule>> ComputationTracker::BuildHloModule(
    const VersionedComputationHandle& entry_handle,
    const HloModuleConfig& config,
    bool include_unreachable_instructions) const {
  tensorflow::mutex_lock lock(computation_mutex_);

  VLOG(1) << "BuildHloModule(" << entry_handle
          << ", include_unreachable_instructions="
          << include_unreachable_instructions << ")";
  XLA_VLOG_LINES(1, ToStringInternal());

  TF_ASSIGN_OR_RETURN(UserComputation * entry_computation,
                      ResolveInternal(entry_handle.handle));

  // Build a topological sort of the entry and any embedded computations as a
  // list. The root of the computation will be the last element in the list.
  std::set<VersionedComputationHandle> visited;
  std::list<VersionedComputationHandle> post_order;
  ComputeComputationPostOrder(entry_handle, &visited, &post_order);

  // Map from ComputationHandle value and computation version to HloComputation.
  std::map<VersionedComputationHandle, HloComputation*> hlo_computations;

  // The resolver lambda resolves VersionedHandles to embedded
  // HloComputation*. This is required by UserComputation::BuildHloComputation
  // when lowering calling operations (map, reduce etc).  The CHECK holds
  // because the post-order loop below builds callees before their callers.
  auto resolver = [&hlo_computations](
      const VersionedComputationHandle& versioned_handle) -> HloComputation* {
    CHECK_GT(hlo_computations.count(versioned_handle), 0);
    return hlo_computations.at(versioned_handle);
  };

  // Print the post-order list for this entry computation.
  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Visiting UserComputations in post order:";
    for (const VersionedComputationHandle& versioned_handle : post_order) {
      VLOG(2) << "  " << versioned_handle;
    }
  }

  // The module is named after the entry computation.
  string module_name =
      tensorflow::strings::StrCat(entry_computation->name(), "_module");
  auto module = MakeUnique<HloModule>(module_name, entry_handle, config);
  for (auto versioned_handle : post_order) {
    UserComputation* computation =
        ResolveInternal(versioned_handle.handle).ValueOrDie();

    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloComputation> hlo_computation,
        computation->BuildHloComputation(versioned_handle.version, resolver,
                                         config.debug_options(),
                                         include_unreachable_instructions));

    // Add the newly created computation to VersionedHandle-to-HloComputation
    // map.
    DCHECK_EQ(0, hlo_computations.count(versioned_handle));
    hlo_computations[versioned_handle] = hlo_computation.get();

    // The module keeps ownership; we retain only the raw pointer stored in
    // hlo_computations above for resolver lookups.
    if (computation == entry_computation) {
      module->AddEntryComputation(std::move(hlo_computation));
    } else {
      module->AddEmbeddedComputation(std::move(hlo_computation));
    }
  }

  return std::move(module);
}
    236 
    237 string ComputationTracker::ToString() const {
    238   tensorflow::mutex_lock lock(computation_mutex_);
    239   return ToStringInternal();
    240 }
    241 
    242 string ComputationTracker::ToStringInternal() const {
    243   string out;
    244   Appendf(&out, "ComputationTracker(%p):\n", this);
    245   for (const auto& handle_computation : opaque_to_computation_) {
    246     int64 handle = handle_computation.first;
    247     const std::unique_ptr<UserComputation>& computation =
    248         handle_computation.second;
    249     Appendf(&out, "  %4lld : %s \"%s\"\n", handle,
    250             computation->GetVersionedHandle().ToString().c_str(),
    251             computation->name().c_str());
    252   }
    253   return out;
    254 }
    255 
    256 }  // namespace xla
    257