1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" 17 18 #include "tensorflow/compiler/xla/map_util.h" 19 #include "tensorflow/compiler/xla/ptr_util.h" 20 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" 21 #include "tensorflow/compiler/xla/service/hlo_computation.h" 22 #include "tensorflow/compiler/xla/service/hlo_reachability.h" 23 24 namespace xla { 25 namespace gpu { 26 27 bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const { 28 return hlo_to_stream_number_.count(&hlo); 29 } 30 31 int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const { 32 return FindOrDie(hlo_to_stream_number_, &hlo); 33 } 34 35 void StreamAssignment::AssignStreamToHlo(const HloInstruction* hlo, 36 int stream_no) { 37 CHECK_GE(stream_no, 0); 38 if (stream_no >= stream_count_) { 39 stream_count_ = stream_no + 1; 40 } 41 InsertOrDie(&hlo_to_stream_number_, hlo, stream_no); 42 VLOG(2) << "Assign stream #" << stream_no << " to " << hlo->ToString(); 43 } 44 45 namespace { 46 47 // Returns whether the two HLOs can run concurrently, i.e., neither is a 48 // transitive consumer of the other. 49 bool CanRunConcurrently(const HloInstruction& a, const HloInstruction& b, 50 const HloReachabilityMap& reachability) { 51 return !reachability.IsConnected(&a, &b); 52 } 53 54 // Returns which existing stream to assign to `hlo`, or -1 if a stream is not 55 // needed. `stream_assignment` is the existing stream assignment for all 56 // instructions topologically before `hlo`. `seen_gemms` contains all GEMMs that 57 // are topologically before `hlo`. 58 int ComputeStreamToAssign( 59 const HloInstruction& hlo, const StreamAssignment& stream_assignment, 60 const HloReachabilityMap& reachability, 61 const std::vector<const HloInstruction*>& seen_gemms) { 62 if (hlo.opcode() == HloOpcode::kParameter || 63 hlo.opcode() == HloOpcode::kConstant) { 64 // kParameter and kConstant do not need a thunk. 65 return -1; 66 } 67 68 if (hlo.GetModule() 69 ->config() 70 .debug_options() 71 .xla_gpu_disable_multi_streaming()) { 72 return 0; 73 } 74 75 if (!ImplementedAsGemm(hlo)) { 76 // If `hlo` is not implemented as a GEMM, keep it close to its operands to 77 // avoid excessive synchronization. 78 int stream_no = -1; 79 for (const auto* operand : hlo.operands()) { 80 if (stream_assignment.HasStreamAssigned(*operand)) { 81 stream_no = 82 std::max(stream_no, stream_assignment.StreamNumberForHlo(*operand)); 83 } 84 } 85 if (stream_no == -1) { 86 stream_no = 0; 87 } 88 return stream_no; 89 } 90 91 // Assign different streams to concurrent GEMMs. The code below uses a 92 // greedy approach. First, we compute as forbidden_stream_numbers the 93 // streams assigned to GEMMs that are concurrent with `hlo`. Then, we assign 94 // `hlo` a different stream. 95 std::set<int> forbidden_stream_numbers; 96 for (const auto* seen_gemm : seen_gemms) { 97 int stream_no = stream_assignment.StreamNumberForHlo(*seen_gemm); 98 if (!forbidden_stream_numbers.count(stream_no) && 99 CanRunConcurrently(*seen_gemm, hlo, reachability)) { 100 forbidden_stream_numbers.insert(stream_no); 101 } 102 } 103 104 for (int stream_no = 0; stream_no < stream_assignment.StreamCount(); 105 ++stream_no) { 106 if (!forbidden_stream_numbers.count(stream_no)) { 107 return stream_no; 108 } 109 } 110 return stream_assignment.StreamCount(); 111 } 112 113 } // namespace 114 115 std::unique_ptr<StreamAssignment> AssignStreams(const HloModule& module) { 116 auto stream_assignment = MakeUnique<StreamAssignment>(); 117 const HloComputation& computation = *module.entry_computation(); 118 std::unique_ptr<HloReachabilityMap> reachability = 119 computation.ComputeReachability(); 120 std::vector<const HloInstruction*> seen_gemms; 121 for (const auto* hlo : computation.MakeInstructionPostOrder()) { 122 int stream_no = ComputeStreamToAssign(*hlo, *stream_assignment, 123 *reachability, seen_gemms); 124 if (stream_no != -1) { 125 stream_assignment->AssignStreamToHlo(hlo, stream_no); 126 } 127 if (ImplementedAsGemm(*hlo)) { 128 seen_gemms.push_back(hlo); 129 } 130 } 131 return stream_assignment; 132 } 133 134 } // namespace gpu 135 } // namespace xla 136