Home | History | Annotate | Download | only in service
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
     17 
     18 #include <utility>
     19 #include <vector>
     20 
     21 #include "absl/strings/str_cat.h"
     22 #include "absl/strings/str_format.h"
     23 #include "absl/strings/str_join.h"
     24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     25 #include "tensorflow/compiler/xla/shape_util.h"
     26 #include "tensorflow/compiler/xla/status_macros.h"
     27 #include "tensorflow/compiler/xla/statusor.h"
     28 #include "tensorflow/compiler/xla/types.h"
     29 #include "tensorflow/compiler/xla/util.h"
     30 #include "tensorflow/core/lib/core/errors.h"
     31 #include "tensorflow/core/platform/logging.h"
     32 
     33 namespace xla {
     34 
     35 bool HloOrdering::ExecutesBefore(const HloInstruction* a,
     36                                  const HloInstruction* b) const {
     37   // 'a' and 'b' may be in different computations. In this case, find the
     38   // callgraph ancestor instructions which call (potentially transitively) the
     39   // computations containing 'a' and 'b' and use these ancestor instructions to
     40   // compare order.
     41   const HloInstruction* a_ancestor;
     42   const HloInstruction* b_ancestor;
     43   std::tie(a_ancestor, b_ancestor) =
     44       call_graph_->NearestAncestorsInSameComputation(
     45           const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b));
     46 
     47   if (a_ancestor == nullptr) {
     48     // Ancestors in a common computation could not be found so consider the
     49     // instructions 'a' and 'b' to be unordered.
     50     return false;
     51   }
     52   // a_ancestor and b_ancestor must be either both null or both non-null.
     53   CHECK_NE(b_ancestor, nullptr);
     54   CHECK_EQ(a_ancestor->parent(), b_ancestor->parent());
     55 
     56   // If the common ancestor is a while instruction there is an additional
     57   // ordering criteria which may apply. The condition computation is considered
     58   // to execute before the body computation so if 'a' is in the condition and
     59   // 'b' is in the body, then 'a' executes before 'b'.
     60   if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) {
     61     const HloComputation* body = a_ancestor->while_body();
     62     const HloComputation* condition = a_ancestor->while_condition();
     63     if (call_graph_->InstructionIsNestedIn(a, condition) &&
     64         call_graph_->InstructionIsNestedIn(b, body)) {
     65       return true;
     66     }
     67   }
     68 
     69   // If the common ancestor is a conditional instruction, even though the branch
     70   // computations are not really ordered per-se, we define the 0th branch
     71   // computation to be ordered before the 1st one, before the 2nd and so forth.
     72   // This ensures that buffers can still be shared among branch computations
     73   // as they will forcibly have disjoint liveness.
     74   if (a_ancestor == b_ancestor &&
     75       (a_ancestor->opcode() == HloOpcode::kConditional)) {
     76     int a_branch = -1;
     77     int b_branch = -1;
     78     for (int j = 0; j < a_ancestor->branch_count(); ++j) {
     79       if (call_graph_->InstructionIsNestedIn(
     80               a, a_ancestor->branch_computation(j))) {
     81         a_branch = j;
     82       }
     83       if (call_graph_->InstructionIsNestedIn(
     84               b, a_ancestor->branch_computation(j))) {
     85         b_branch = j;
     86       }
     87     }
     88     if (a_branch != -1 && a_branch < b_branch) {
     89       return true;
     90     }
     91     // If 'b' is the conditional ancestor, and 'a' is within a branch
     92     // computation, 'a' executes before 'b'.
     93     if (b == a_ancestor && a_branch != -1) {
     94       return true;
     95     }
     96   }
     97 
     98   return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor);
     99 }
    100 
    101 bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
    102   // Entry parameter should always be defined before other instructions.
    103   const HloModule* module = b.defining_instruction()->parent()->parent();
    104   if (b.defining_instruction()->parent() == module->entry_computation() &&
    105       b.defining_instruction()->opcode() == HloOpcode::kParameter) {
    106     return false;
    107   }
    108 
    109   if (a.defining_instruction()->parent() == module->entry_computation() &&
    110       a.defining_instruction()->opcode() == HloOpcode::kParameter) {
    111     return true;
    112   }
    113 
    114   // Phi values require special handling. Because XLA does not have a phi
    115   // instruction, the definition instruction of the phis values are
    116   // placeholders: either the subcomputation parameter (body or condition) or
    117   // the while instruction. However, the program point where these values are
    118   // logically defined does not necessarily coincide exactly with program point
    119   // of these place-holder instructions. So we explicitly define the following
    120   // order for phi values:
    121   //
    122   //   body/condition parameter phi:
    123   //     Defined before all values defined in its computation excepting other
    124   //     phis.
    125   //
    126   //   while phi:
    127   //     defined after all values defined in the condition or body.
    128   //
    129   auto is_body_or_condition_phi = [](const HloValue& v) {
    130     return v.is_phi() &&
    131            v.defining_instruction()->opcode() == HloOpcode::kParameter;
    132   };
    133   if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
    134       call_graph_->InstructionIsNestedIn(b.defining_instruction(),
    135                                          a.defining_instruction()->parent())) {
    136     return true;
    137   }
    138   if (is_body_or_condition_phi(b) &&
    139       call_graph_->InstructionIsNestedIn(a.defining_instruction(),
    140                                          b.defining_instruction()->parent())) {
    141     return false;
    142   }
    143 
    144   // If 'b' is a while phi and 'a' is in the body or condition, then 'a'
    145   // executes before 'b'.
    146   if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
    147       (call_graph_->InstructionIsNestedIn(
    148            a.defining_instruction(), b.defining_instruction()->while_body()) ||
    149        call_graph_->InstructionIsNestedIn(
    150            a.defining_instruction(),
    151            b.defining_instruction()->while_condition()))) {
    152     return true;
    153   }
    154   // If 'b' is a conditional phi and 'a' is in some branch computation, then 'a'
    155   // executes before 'b'.
    156   if (b.is_phi() &&
    157       b.defining_instruction()->opcode() == HloOpcode::kConditional) {
    158     for (int j = 0; j < b.defining_instruction()->branch_count(); ++j) {
    159       if (call_graph_->InstructionIsNestedIn(
    160               a.defining_instruction(),
    161               b.defining_instruction()->branch_computation(j))) {
    162         return true;
    163       }
    164     }
    165   }
    166   return ExecutesBefore(a.defining_instruction(), b.defining_instruction());
    167 }
    168 
    169 /* static */
    170 bool HloOrdering::UseIsBeforeValueDefinition(
    171     const HloUse& use, const HloValue& value,
    172     const HloDataflowAnalysis& dataflow) const {
    173   VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
    174           << ", value=" << value.ToShortString() << ")";
    175   if (ExecutesBefore(use.instruction, value.defining_instruction())) {
    176     VLOG(4) << "  use instruction executes before value-defining instruction";
    177     return true;
    178   }
    179 
    180   // If the use is at the instruction where the value is defined, then the use
    181   // is before the def if the instruction allows buffer sharing (in place
    182   // computation).
    183   if (use.instruction == value.defining_instruction() &&
    184       dataflow.CanShareOperandBufferWithUser(
    185           use.instruction->mutable_operand(use.operand_number),
    186           use.operand_index, value.defining_instruction(),
    187           value.defining_index())) {
    188     VLOG(4) << "  use is value def, and instruction can share use buffer";
    189     return true;
    190   }
    191 
    192   // The use at a while is an input to a phi, and logically occurs before values
    193   // are defined in the body or condition computations.
    194   if (use.instruction->opcode() == HloOpcode::kWhile) {
    195     const HloInstruction* xla_while = use.instruction;
    196     if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
    197                                            xla_while->while_body()) ||
    198         call_graph_->InstructionIsNestedIn(value.defining_instruction(),
    199                                            xla_while->while_condition())) {
    200       VLOG(4) << "  use is while " << use.instruction->name()
    201               << " and def is in condition or body";
    202       return true;
    203     }
    204   }
    205 
    206   // Similarly if the value is defined at a while, it logically occurs after any
    207   // uses in the body or condition computations.
    208   if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
    209     CHECK(value.is_phi());
    210     const HloInstruction* xla_while = value.defining_instruction();
    211     if (call_graph_->InstructionIsNestedIn(use.instruction,
    212                                            xla_while->while_body()) ||
    213         call_graph_->InstructionIsNestedIn(use.instruction,
    214                                            xla_while->while_condition())) {
    215       VLOG(4) << "  value is while " << value.defining_instruction()->name()
    216               << " and use is in condition or body";
    217       return true;
    218     }
    219   }
    220 
    221   // The use at a call occurs before values that are defined in the called
    222   // computation.
    223   if (use.instruction->opcode() == HloOpcode::kCall) {
    224     const HloInstruction* call = use.instruction;
    225     if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
    226                                            call->to_apply())) {
    227       VLOG(4) << "  use is call " << use.instruction->name()
    228               << " and def is in called computation";
    229       return true;
    230     }
    231   }
    232 
    233   if (use.instruction->opcode() == HloOpcode::kConditional) {
    234     const HloInstruction* conditional = use.instruction;
    235     for (int j = 0; j < conditional->branch_count(); ++j) {
    236       if (call_graph_->InstructionIsNestedIn(
    237               value.defining_instruction(),
    238               conditional->branch_computation(j))) {
    239         VLOG(4) << "  use is conditional " << use.instruction->name()
    240                 << " and def is in " << j << "th branch computation";
    241         return true;
    242       }
    243     }
    244     if (value.defining_instruction() == use.instruction) {
    245       VLOG(4) << "  use is conditional " << use << " and def is "
    246               << value.ToShortString();
    247       return true;
    248     }
    249   }
    250 
    251   VLOG(4) << "  use is not before value";
    252   return false;
    253 }
    254 
    255 bool HloOrdering::LiveRangeStrictlyBefore(
    256     const HloValue& a, const HloValue& b,
    257     const HloDataflowAnalysis& dataflow) const {
    258   VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
    259           << ", b = " << b.ToShortString() << ")";
    260   if (!IsDefinedBefore(a, b)) {
    261     VLOG(4) << a << " not defined before " << b;
    262     return false;
    263   }
    264 
    265   if (a.live_out_of_module()) {
    266     VLOG(4) << a << " is live out of module and defined before " << b;
    267     return false;
    268   }
    269 
    270   // All uses of 'a' must be before 'b' is defined.
    271   for (const HloUse& use : a.uses()) {
    272     if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
    273                                          use.instruction)) {
    274       continue;
    275     }
    276     if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
    277       VLOG(4) << "use of " << a << " (" << use << ") not before " << b
    278               << " is defined";
    279       return false;
    280     }
    281   }
    282 
    283   if (a.instruction()->parent() == b.instruction()->parent()) {
    284     for (const HloPosition& position : a.positions()) {
    285       if (position.instruction ==
    286           a.instruction()->parent()->root_instruction()) {
    287         VLOG(4) << a << " is live out of computation and defined before " << b
    288                 << " which is in same computation";
    289         return false;
    290       }
    291     }
    292   }
    293 
    294   return true;
    295 }
    296 
    297 bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
    298                                const HloDataflowAnalysis& dataflow) const {
    299   // Buffers without disjoint liveness may interfere.
    300   return !LiveRangeStrictlyBefore(a, b, dataflow) &&
    301          !LiveRangeStrictlyBefore(b, a, dataflow);
    302 }
    303 
// Constructs the ordering for 'module' by forwarding it to the HloOrdering
// base class. The constructor body is empty; predecessor information is not
// populated here.
PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
    : HloOrdering(module) {}
    306 
    307 bool PredecessorHloOrdering::ExecutesBeforeInSameComputation(
    308     const HloInstruction* a, const HloInstruction* b) const {
    309   CHECK_EQ(a->parent(), b->parent());
    310 
    311   // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'.
    312   return a != b && predecessors_.at(a->parent())->IsReachable(a, b);
    313 }
    314 
    315 string PredecessorHloOrdering::ToStringHelper(const string& name) const {
    316   std::vector<string> pieces;
    317   pieces.push_back(name);
    318   for (auto* computation : module_->MakeNonfusionComputations()) {
    319     pieces.push_back(absl::StrFormat("computation %s:", computation->name()));
    320     const auto all = computation->MakeInstructionPostOrder();
    321     for (auto instruction : all) {
    322       pieces.push_back(
    323           absl::StrFormat("  %s predecessors:", instruction->name()));
    324       for (auto predecessor : all) {
    325         if (predecessors_.at(computation)
    326                 ->IsReachable(predecessor, instruction)) {
    327           pieces.push_back(absl::StrFormat("    %s", predecessor->name()));
    328         }
    329       }
    330     }
    331   }
    332   return absl::StrJoin(pieces, "\n");
    333 }
    334 
    335 DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
    336     : PredecessorHloOrdering(module) {
    337   // Compute predecessor relationships between all instructions to determine
    338   // ordering based on dependencies. ExecutesBefore will return true iff there
    339   // exists a path in the HLO computation graph from 'a' to 'b'.
    340   for (auto* computation : module->MakeNonfusionComputations()) {
    341     predecessors_.emplace(computation, HloReachabilityMap::Build(computation));
    342   }
    343 }
    344 
// Returns the textual form of this ordering: the output of ToStringHelper
// labelled with the name "DependencyHloOrdering".
string DependencyHloOrdering::ToString() const {
  return ToStringHelper("DependencyHloOrdering");
}
    348 
// Constructs the ordering from 'schedule', copying it into schedule_, and
// then calls Initialize() to build the instruction position map.
SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
    : HloOrdering(schedule.module()), schedule_(schedule) {
  Initialize();
}
    353 
// Constructs the ordering from 'schedule', moving it into schedule_, and then
// calls Initialize() to build the instruction position map. The module is
// read from 'schedule' for the base-class initializer before the schedule is
// moved into the member.
SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
    : HloOrdering(schedule.module()), schedule_(std::move(schedule)) {
  Initialize();
}
    358 
    359 void SequentialHloOrdering::Initialize() {
    360   // Create a map from instruction to its order position.
    361   TF_DCHECK_OK(schedule_.Verify());
    362   for (const auto& computation_sequence : schedule_.sequences()) {
    363     const auto& order = computation_sequence.second.instructions();
    364     for (int i = 0; i < order.size(); ++i) {
    365       InsertOrDie(&order_position_, order[i], i);
    366     }
    367   }
    368 }
    369 
    370 bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
    371     const HloInstruction* a, const HloInstruction* b) const {
    372   CHECK_EQ(a->parent(), b->parent());
    373   // If either instruction is not in the order, then 'a' and 'b' are unordered.
    374   if (!order_position_.contains(a) || !order_position_.contains(b)) {
    375     return false;
    376   }
    377   return order_position_.at(a) < order_position_.at(b);
    378 }
    379 
    380 const HloInstructionSequence* SequentialHloOrdering::SequentialOrder(
    381     const HloComputation& computation) const {
    382   return schedule_.is_computation_scheduled(&computation)
    383              ? &schedule_.sequence(&computation)
    384              : nullptr;
    385 }
    386 
    387 string SequentialHloOrdering::ToString() const {
    388   return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString());
    389 }
    390 
    391 }  // namespace xla
    392