// tensorflow/compiler/xla/service/gpu/thunk_schedule.cc
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
     17 #include "tensorflow/compiler/xla/array2d.h"
     18 #include "tensorflow/compiler/xla/map_util.h"
     19 #include "tensorflow/compiler/xla/types.h"
     20 
     21 namespace xla {
     22 namespace gpu {
     23 
     24 void ThunkSchedule::AddDependenciesOnTransitiveOperands(
     25     const Thunk& thunk, const HloInstruction& operand,
     26     const std::unordered_map<const HloInstruction*, Thunk*>& hlo_to_thunk) {
     27   if (hlo_to_thunk.count(&operand)) {
     28     // If `operand` is mapped to a thunk, adds `operand` to `thunk`'s dependency
     29     // list if `operand` is assigned to a different stream. As an optimization,
     30     // we skip `operand`'s operands because `operand` depends on them already.
     31     if (stream_assignment_->StreamNumberForHlo(operand) !=
     32         stream_assignment_->StreamNumberForHlo(*thunk.hlo_instruction())) {
     33       depends_on_[&thunk].push_back(FindOrDie(hlo_to_thunk, &operand));
     34     }
     35   } else {
     36     // If `operand` doesn't need a thunk (e.g. bitcast), continue with its
     37     // operands.
     38     for (const auto* operand_of_operand : operand.operands()) {
     39       AddDependenciesOnTransitiveOperands(thunk, *operand_of_operand,
     40                                           hlo_to_thunk);
     41     }
     42   }
     43 }
     44 
     45 ThunkSchedule::ThunkSchedule(
     46     std::unique_ptr<ThunkSequence> thunks,
     47     std::unique_ptr<StreamAssignment> stream_assignment,
     48     const std::vector<const HloInstruction*>& hlo_total_order)
     49     : thunks_(std::move(thunks)),
     50       stream_assignment_(std::move(stream_assignment)) {
     51   std::unordered_map<const HloInstruction*, Thunk*> hlo_to_thunk;
     52   for (const auto& thunk : *thunks_) {
     53     InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get());
     54   }
     55 
     56   for (const HloInstruction* hlo : hlo_total_order) {
     57     if (hlo_to_thunk.count(hlo)) {
     58       thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo));
     59     }
     60   }
     61 
     62   for (const Thunk* thunk : thunk_total_order_) {
     63     const auto* dst = thunk->hlo_instruction();
     64     CHECK(stream_assignment_->HasStreamAssigned(*dst));
     65     for (const auto* src : dst->operands()) {
     66       AddDependenciesOnTransitiveOperands(*thunk, *src, hlo_to_thunk);
     67     }
     68   }
     69 
     70   RemoveRedundantDependencyEdges();
     71 
     72   // Compute `depended_by_`, the inverse of `depends_on_`.
     73   for (const auto& dependency : depends_on_) {
     74     for (const auto* depended : dependency.second) {
     75       depended_by_.insert(depended);
     76     }
     77   }
     78 }
     79 
     80 void ThunkSchedule::RemoveRedundantDependencyEdges() {
     81   std::unordered_map<const Thunk*, int> thunk_to_total_order;
     82   for (int i = 0; i < thunk_total_order_.size(); ++i) {
     83     InsertOrDie(&thunk_to_total_order, thunk_total_order_[i], i);
     84   }
     85 
     86   int stream_count = stream_assignment_->StreamCount();
     87   // S1  S2
     88   //
     89   // T1<----+
     90   //        |
     91   // T3<--+ |
     92   //      | | depends on
     93   //     T4 |
     94   //        |
     95   //     T2-+
     96   //
     97   // Suppose thunk T1 and T3 are scheduled on stream S1, and T2 and T4 are on
     98   // stream S2. If T2 depends on T1 and T4 depends on T3, and
     99   // order(T1)<order(T3)<order(T4)<order(T2), the dependency of T2 on T1 is
    100   // redundant.
    101   //
    102   // To efficiently detect such redundancy, we leverage array `last_dependency`.
    103   // last_dependency[S1][S2] indicates the last thunk (with the maximum order
    104   // number) on stream S2 that thunks on S1 depends on. Therefore, if a future
    105   // S1 thunk depends on a S2 thunk ordered <=last_dependency[S1][S2], that is a
    106   // redundant dependency edge.
    107   Array2D<int> last_dependency(stream_count, stream_count, -1);
    108   for (const Thunk* dst : thunk_total_order_) {
    109     if (!depends_on_.count(dst)) {
    110       continue;
    111     }
    112 
    113     int dst_stream =
    114         stream_assignment_->StreamNumberForHlo(*dst->hlo_instruction());
    115     std::list<const Thunk*>& sources = FindOrDie(depends_on_, dst);
    116     for (auto iter = sources.begin(); iter != sources.end();) {
    117       const Thunk* src = *iter;
    118       // `dst` depends on `src`.
    119       int src_stream =
    120           stream_assignment_->StreamNumberForHlo(*src->hlo_instruction());
    121       int src_order = FindOrDie(thunk_to_total_order, src);
    122       if (src_order <= last_dependency(dst_stream, src_stream)) {
    123         iter = sources.erase(iter);
    124       } else {
    125         last_dependency(dst_stream, src_stream) = src_order;
    126         ++iter;
    127       }
    128     }
    129     if (sources.empty()) {
    130       depends_on_.erase(dst);
    131     }
    132   }
    133 }
    134 
    135 const std::list<const Thunk*>& ThunkSchedule::DependsOn(
    136     const Thunk* thunk) const {
    137   if (depends_on_.count(thunk)) {
    138     return FindOrDie(depends_on_, thunk);
    139   } else {
    140     return empty_thunk_list_;
    141   }
    142 }
    143 
    144 string ThunkSchedule::ToString() const {
    145   string result = "Total order:\n";
    146   for (Thunk* thunk : thunk_total_order_) {
    147     tensorflow::strings::StrAppend(&result, "\t",
    148                                    thunk->hlo_instruction()->ToString(), "\n");
    149   }
    150   tensorflow::strings::StrAppend(&result, "Dependencies:\n");
    151   for (const auto& entry : depends_on_) {
    152     const Thunk* dependent = entry.first;
    153     for (const Thunk* dependency : entry.second) {
    154       tensorflow::strings::StrAppend(
    155           &result, "\t", dependent->hlo_instruction()->name(), " depends on ",
    156           dependency->hlo_instruction()->name(), "\n");
    157     }
    158   }
    159   return result;
    160 }
    161 
    162 }  // namespace gpu
    163 }  // namespace xla
    164