1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" 17 #include "tensorflow/compiler/xla/array2d.h" 18 #include "tensorflow/compiler/xla/map_util.h" 19 #include "tensorflow/compiler/xla/types.h" 20 21 namespace xla { 22 namespace gpu { 23 24 void ThunkSchedule::AddDependenciesOnTransitiveOperands( 25 const Thunk& thunk, const HloInstruction& operand, 26 const std::unordered_map<const HloInstruction*, Thunk*>& hlo_to_thunk) { 27 if (hlo_to_thunk.count(&operand)) { 28 // If `operand` is mapped to a thunk, adds `operand` to `thunk`'s dependency 29 // list if `operand` is assigned to a different stream. As an optimization, 30 // we skip `operand`'s operands because `operand` depends on them already. 31 if (stream_assignment_->StreamNumberForHlo(operand) != 32 stream_assignment_->StreamNumberForHlo(*thunk.hlo_instruction())) { 33 depends_on_[&thunk].push_back(FindOrDie(hlo_to_thunk, &operand)); 34 } 35 } else { 36 // If `operand` doesn't need a thunk (e.g. bitcast), continue with its 37 // operands. 38 for (const auto* operand_of_operand : operand.operands()) { 39 AddDependenciesOnTransitiveOperands(thunk, *operand_of_operand, 40 hlo_to_thunk); 41 } 42 } 43 } 44 45 ThunkSchedule::ThunkSchedule( 46 std::unique_ptr<ThunkSequence> thunks, 47 std::unique_ptr<StreamAssignment> stream_assignment, 48 const std::vector<const HloInstruction*>& hlo_total_order) 49 : thunks_(std::move(thunks)), 50 stream_assignment_(std::move(stream_assignment)) { 51 std::unordered_map<const HloInstruction*, Thunk*> hlo_to_thunk; 52 for (const auto& thunk : *thunks_) { 53 InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get()); 54 } 55 56 for (const HloInstruction* hlo : hlo_total_order) { 57 if (hlo_to_thunk.count(hlo)) { 58 thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo)); 59 } 60 } 61 62 for (const Thunk* thunk : thunk_total_order_) { 63 const auto* dst = thunk->hlo_instruction(); 64 CHECK(stream_assignment_->HasStreamAssigned(*dst)); 65 for (const auto* src : dst->operands()) { 66 AddDependenciesOnTransitiveOperands(*thunk, *src, hlo_to_thunk); 67 } 68 } 69 70 RemoveRedundantDependencyEdges(); 71 72 // Compute `depended_by_`, the inverse of `depends_on_`. 73 for (const auto& dependency : depends_on_) { 74 for (const auto* depended : dependency.second) { 75 depended_by_.insert(depended); 76 } 77 } 78 } 79 80 void ThunkSchedule::RemoveRedundantDependencyEdges() { 81 std::unordered_map<const Thunk*, int> thunk_to_total_order; 82 for (int i = 0; i < thunk_total_order_.size(); ++i) { 83 InsertOrDie(&thunk_to_total_order, thunk_total_order_[i], i); 84 } 85 86 int stream_count = stream_assignment_->StreamCount(); 87 // S1 S2 88 // 89 // T1<----+ 90 // | 91 // T3<--+ | 92 // | | depends on 93 // T4 | 94 // | 95 // T2-+ 96 // 97 // Suppose thunk T1 and T3 are scheduled on stream S1, and T2 and T4 are on 98 // stream S2. If T2 depends on T1 and T4 depends on T3, and 99 // order(T1)<order(T3)<order(T4)<order(T2), the dependency of T2 on T1 is 100 // redundant. 101 // 102 // To efficiently detect such redundancy, we leverage array `last_dependency`. 103 // last_dependency[S1][S2] indicates the last thunk (with the maximum order 104 // number) on stream S2 that thunks on S1 depends on. Therefore, if a future 105 // S1 thunk depends on a S2 thunk ordered <=last_dependency[S1][S2], that is a 106 // redundant dependency edge. 107 Array2D<int> last_dependency(stream_count, stream_count, -1); 108 for (const Thunk* dst : thunk_total_order_) { 109 if (!depends_on_.count(dst)) { 110 continue; 111 } 112 113 int dst_stream = 114 stream_assignment_->StreamNumberForHlo(*dst->hlo_instruction()); 115 std::list<const Thunk*>& sources = FindOrDie(depends_on_, dst); 116 for (auto iter = sources.begin(); iter != sources.end();) { 117 const Thunk* src = *iter; 118 // `dst` depends on `src`. 119 int src_stream = 120 stream_assignment_->StreamNumberForHlo(*src->hlo_instruction()); 121 int src_order = FindOrDie(thunk_to_total_order, src); 122 if (src_order <= last_dependency(dst_stream, src_stream)) { 123 iter = sources.erase(iter); 124 } else { 125 last_dependency(dst_stream, src_stream) = src_order; 126 ++iter; 127 } 128 } 129 if (sources.empty()) { 130 depends_on_.erase(dst); 131 } 132 } 133 } 134 135 const std::list<const Thunk*>& ThunkSchedule::DependsOn( 136 const Thunk* thunk) const { 137 if (depends_on_.count(thunk)) { 138 return FindOrDie(depends_on_, thunk); 139 } else { 140 return empty_thunk_list_; 141 } 142 } 143 144 string ThunkSchedule::ToString() const { 145 string result = "Total order:\n"; 146 for (Thunk* thunk : thunk_total_order_) { 147 tensorflow::strings::StrAppend(&result, "\t", 148 thunk->hlo_instruction()->ToString(), "\n"); 149 } 150 tensorflow::strings::StrAppend(&result, "Dependencies:\n"); 151 for (const auto& entry : depends_on_) { 152 const Thunk* dependent = entry.first; 153 for (const Thunk* dependency : entry.second) { 154 tensorflow::strings::StrAppend( 155 &result, "\t", dependent->hlo_instruction()->name(), " depends on ", 156 dependency->hlo_instruction()->name(), "\n"); 157 } 158 } 159 return result; 160 } 161 162 } // namespace gpu 163 } // namespace xla 164