Home | History | Annotate | Download | only in gpu
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
     17 
     18 #include "absl/container/flat_hash_set.h"
     19 #include "absl/memory/memory.h"
     20 #include "tensorflow/compiler/xla/map_util.h"
     21 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
     22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     23 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
     24 
     25 namespace xla {
     26 namespace gpu {
     27 
     28 bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const {
     29   return hlo_to_stream_number_.contains(&hlo);
     30 }
     31 
     32 int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const {
     33   return FindOrDie(hlo_to_stream_number_, &hlo);
     34 }
     35 
     36 void StreamAssignment::AssignStreamToHlo(const HloInstruction* hlo,
     37                                          int stream_num) {
     38   CHECK_GE(stream_num, 0);
     39   if (stream_num >= stream_count_) {
     40     stream_count_ = stream_num + 1;
     41   }
     42   InsertOrDie(&hlo_to_stream_number_, hlo, stream_num);
     43   VLOG(2) << "Assign stream #" << stream_num << " to " << hlo->ToString();
     44 }
     45 
     46 namespace {
     47 
     48 // Returns whether the two HLOs can run concurrently, i.e., neither is a
     49 // transitive consumer of the other.
     50 bool CanRunConcurrently(const HloInstruction& a, const HloInstruction& b,
     51                         const HloReachabilityMap& reachability) {
     52   return !reachability.IsConnected(&a, &b);
     53 }
     54 
     55 constexpr int kInvalidStreamNum = -1;
     56 //  Returns true iff `stream_num` is an invalid stream number.
     57 inline bool IsStreamNumValid(int stream_num) {
     58   return stream_num != kInvalidStreamNum;
     59 }
     60 
     61 // Returns which existing stream to assign to `hlo`, or -1 if a stream is not
     62 // needed. `stream_assignment` is the existing stream assignment for all
     63 // instructions topologically before `hlo`. `seen_gemms` contains all GEMMs that
     64 // are topologically before `hlo`.
     65 int ComputeStreamToAssign(
     66     const HloInstruction& hlo, const StreamAssignment& stream_assignment,
     67     const HloReachabilityMap& reachability,
     68     const std::vector<const HloInstruction*>& seen_gemms) {
     69   if (hlo.opcode() == HloOpcode::kParameter ||
     70       hlo.opcode() == HloOpcode::kConstant) {
     71     // kParameter and kConstant do not need a thunk.
     72     return kInvalidStreamNum;
     73   }
     74 
     75   if (hlo.GetModule()
     76           ->config()
     77           .debug_options()
     78           .xla_gpu_disable_multi_streaming()) {
     79     return 0;
     80   }
     81 
     82   if (!ImplementedAsGemm(hlo)) {
     83     // If `hlo` is not implemented as a GEMM, keep it close to its operands to
     84     // avoid excessive synchronization.
     85     int stream_num = -1;
     86     for (const auto* operand : hlo.operands()) {
     87       if (stream_assignment.HasStreamAssigned(*operand)) {
     88         stream_num = std::max(stream_num,
     89                               stream_assignment.StreamNumberForHlo(*operand));
     90       }
     91     }
     92     if (!IsStreamNumValid(stream_num)) {
     93       stream_num = 0;
     94     }
     95     return stream_num;
     96   }
     97 
     98   // Assign different streams to concurrent GEMMs. The code below uses a
     99   // greedy approach. First, we compute as forbidden_stream_numbers the
    100   // streams assigned to GEMMs that are concurrent with `hlo`. Then, we assign
    101   // `hlo` a different stream.
    102   absl::flat_hash_set<int> forbidden_stream_numbers;
    103   for (const auto* seen_gemm : seen_gemms) {
    104     int stream_num = stream_assignment.StreamNumberForHlo(*seen_gemm);
    105     if (!forbidden_stream_numbers.contains(stream_num) &&
    106         CanRunConcurrently(*seen_gemm, hlo, reachability)) {
    107       forbidden_stream_numbers.insert(stream_num);
    108     }
    109   }
    110 
    111   for (int stream_num = 0; stream_num < stream_assignment.StreamCount();
    112        ++stream_num) {
    113     if (!forbidden_stream_numbers.contains(stream_num)) {
    114       return stream_num;
    115     }
    116   }
    117   return stream_assignment.StreamCount();
    118 }
    119 
    120 }  // namespace
    121 
    122 std::unique_ptr<StreamAssignment> AssignStreams(const HloModule& module) {
    123   auto stream_assignment = absl::make_unique<StreamAssignment>();
    124   const HloComputation& computation = *module.entry_computation();
    125   std::unique_ptr<HloReachabilityMap> reachability =
    126       HloReachabilityMap::Build(&computation);
    127   std::vector<const HloInstruction*> seen_gemms;
    128   // The execution of different RNG Hlo instructions in the same module updates
    129   // a common global variable. To avoid a race condition, we simply assign all
    130   // RNG kernels to the same stream to make them run sequentially.
    131   //
    132   // TODO(b/111791052): If we remove such a common variable, we will need to
    133   // clean up the code here.
    134   int stream_num_for_rng = kInvalidStreamNum;
    135   for (const auto* hlo : computation.MakeInstructionPostOrder()) {
    136     // If we ever enable fusion of RNG instructions, we will need to extend this
    137     // code to look inside a fused instruction.
    138     int stream_num = (hlo->opcode() == HloOpcode::kRng &&
    139                       IsStreamNumValid(stream_num_for_rng))
    140                          ? stream_num_for_rng
    141                          : ComputeStreamToAssign(*hlo, *stream_assignment,
    142                                                  *reachability, seen_gemms);
    143     if (IsStreamNumValid(stream_num)) {
    144       stream_assignment->AssignStreamToHlo(hlo, stream_num);
    145       if (hlo->opcode() == HloOpcode::kRng &&
    146           !IsStreamNumValid(stream_num_for_rng)) {
    147         stream_num_for_rng = stream_num;
    148       }
    149     }
    150     if (ImplementedAsGemm(*hlo)) {
    151       seen_gemms.push_back(hlo);
    152     }
    153   }
    154   return stream_assignment;
    155 }
    156 
    157 }  // namespace gpu
    158 }  // namespace xla
    159