/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"

#include <algorithm>
#include <memory>
#include <set>
#include <vector>

#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"

namespace xla {
namespace gpu {

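// Returns true if `hlo` has already been assigned a stream.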
bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const {
  return hlo_to_stream_number_.count(&hlo);
}

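// Returns the stream assigned to `hlo`. Dies if no stream has been assigned.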
int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const {
  return FindOrDie(hlo_to_stream_number_, &hlo);
}

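// Records that `hlo` should run on stream `stream_no`, growing the total
// stream count if `stream_no` has not been used before.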
void StreamAssignment::AssignStreamToHlo(const HloInstruction* hlo,
                                         int stream_no) {
  CHECK_GE(stream_no, 0);
  if (stream_no >= stream_count_) {
    stream_count_ = stream_no + 1;
  }
  InsertOrDie(&hlo_to_stream_number_, hlo, stream_no);
  VLOG(2) << "Assign stream #" << stream_no << " to " << hlo->ToString();
}

namespace {

// Returns whether the two HLOs can run concurrently, i.e., neither is a
// transitive consumer of the other.
bool CanRunConcurrently(const HloInstruction& a, const HloInstruction& b,
                        const HloReachabilityMap& reachability) {
  return !reachability.IsConnected(&a, &b);
}

// Returns the stream to assign to `hlo`, or -1 if no stream is needed.
// `stream_assignment` is the existing stream assignment for all instructions
// topologically before `hlo`. `seen_gemms` contains all GEMMs that are
// topologically before `hlo`.
int ComputeStreamToAssign(
    const HloInstruction& hlo, const StreamAssignment& stream_assignment,
    const HloReachabilityMap& reachability,
    const std::vector<const HloInstruction*>& seen_gemms) {
  if (hlo.opcode() == HloOpcode::kParameter ||
      hlo.opcode() == HloOpcode::kConstant) {
    // kParameter and kConstant do not need a thunk.
    return -1;
  }

  if (hlo.GetModule()
          ->config()
          .debug_options()
          .xla_gpu_disable_multi_streaming()) {
    return 0;
  }

  if (!ImplementedAsGemm(hlo)) {
    // If `hlo` is not implemented as a GEMM, keep it close to its operands to
    // avoid excessive synchronization.
    int stream_no = -1;
    for (const auto* operand : hlo.operands()) {
      if (stream_assignment.HasStreamAssigned(*operand)) {
        stream_no =
            std::max(stream_no, stream_assignment.StreamNumberForHlo(*operand));
      }
    }
    if (stream_no == -1) {
      stream_no = 0;
    }
    return stream_no;
  }

  // Assign different streams to concurrent GEMMs. The code below uses a greedy
  // approach: first, collect in `forbidden_stream_numbers` the streams already
  // assigned to GEMMs that can run concurrently with `hlo`; then assign `hlo`
  // a stream outside that set.
  std::set<int> forbidden_stream_numbers;
  for (const auto* seen_gemm : seen_gemms) {
    int stream_no = stream_assignment.StreamNumberForHlo(*seen_gemm);
    if (!forbidden_stream_numbers.count(stream_no) &&
        CanRunConcurrently(*seen_gemm, hlo, reachability)) {
      forbidden_stream_numbers.insert(stream_no);
    }
  }

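  // Pick the lowest-numbered existing stream that is not forbidden, or open a
  // new stream if every existing stream is forbidden.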
  for (int stream_no = 0; stream_no < stream_assignment.StreamCount();
       ++stream_no) {
    if (!forbidden_stream_numbers.count(stream_no)) {
      return stream_no;
    }
  }
  return stream_assignment.StreamCount();
}

}  // namespace

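// Assigns a stream to every thunk-producing instruction in the entry
// computation of `module`, placing independent GEMMs on separate streams so
// they can run concurrently.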
std::unique_ptr<StreamAssignment> AssignStreams(const HloModule& module) {
  auto stream_assignment = MakeUnique<StreamAssignment>();
  const HloComputation& computation = *module.entry_computation();
  std::unique_ptr<HloReachabilityMap> reachability =
      computation.ComputeReachability();
  std::vector<const HloInstruction*> seen_gemms;
  for (const auto* hlo : computation.MakeInstructionPostOrder()) {
    int stream_no = ComputeStreamToAssign(*hlo, *stream_assignment,
                                          *reachability, seen_gemms);
    if (stream_no != -1) {
      stream_assignment->AssignStreamToHlo(hlo, stream_no);
    }
    if (ImplementedAsGemm(*hlo)) {
      seen_gemms.push_back(hlo);
    }
  }
  return stream_assignment;
}

}  // namespace gpu
}  // namespace xla