Home | History | Annotate | Download | only in service
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
     17 
     18 #include <utility>
     19 #include <vector>
     20 
     21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     22 #include "tensorflow/compiler/xla/service/liveness_util.h"
     23 #include "tensorflow/compiler/xla/shape_util.h"
     24 #include "tensorflow/compiler/xla/status_macros.h"
     25 #include "tensorflow/compiler/xla/statusor.h"
     26 #include "tensorflow/compiler/xla/types.h"
     27 #include "tensorflow/compiler/xla/util.h"
     28 #include "tensorflow/core/lib/core/errors.h"
     29 #include "tensorflow/core/lib/strings/str_util.h"
     30 #include "tensorflow/core/lib/strings/stringprintf.h"
     31 #include "tensorflow/core/platform/logging.h"
     32 
     33 namespace xla {
     34 
     35 bool HloOrdering::ExecutesBefore(const HloInstruction* a,
     36                                  const HloInstruction* b) const {
     37   // 'a' and 'b' may be in different computations. In this case, find the
     38   // callgraph ancestor instructions which call (potentially transitively) the
     39   // computations containing 'a' and 'b' and use these ancestor instructions to
     40   // compare order.
     41   const HloInstruction* a_ancestor;
     42   const HloInstruction* b_ancestor;
     43   std::tie(a_ancestor, b_ancestor) =
     44       call_graph_->NearestAncestorsInSameComputation(
     45           const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b));
     46 
     47   if (a_ancestor == nullptr) {
     48     // Ancestors in a common computation could not be found so consider the
     49     // instructions 'a' and 'b' to be unordered.
     50     return false;
     51   }
     52   // a_ancestor and b_ancestor must be either both null or both non-null.
     53   CHECK_NE(b_ancestor, nullptr);
     54   CHECK_EQ(a_ancestor->parent(), b_ancestor->parent());
     55 
     56   // If the common ancestor is a while instruction there is an additional
     57   // ordering criteria which may apply. The condition computation is considered
     58   // to execute before the body computation so if 'a' is in the condition and
     59   // 'b' is in the body, then 'a' executes before 'b'.
     60   if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) {
     61     const HloComputation* body = a_ancestor->while_body();
     62     const HloComputation* condition = a_ancestor->while_condition();
     63     if (call_graph_->InstructionIsNestedIn(a, condition) &&
     64         call_graph_->InstructionIsNestedIn(b, body)) {
     65       return true;
     66     }
     67   }
     68 
     69   return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor);
     70 }
     71 
     72 bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
     73   // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b'
     74   // is live into the module.
     75   const HloModule* module = b.defining_instruction()->parent()->parent();
     76   if (b.defining_instruction()->parent() == module->entry_computation() &&
     77       b.defining_instruction()->opcode() == HloOpcode::kParameter) {
     78     return false;
     79   }
     80 
     81   // Phi values require special handling. Because XLA does not have a phi
     82   // instruction, the definition instruction of the phis values are
     83   // placeholders: either the subcomputation parameter (body or condition) or
     84   // the while instruction. However, the program point where these values are
     85   // logically defined does not necessarily coincide exactly with program point
     86   // of these place-holder instructions. So we explicitly define the following
     87   // order for phi values:
     88   //
     89   //   body/condition parameter phi:
     90   //     Defined before all values defined in its computation excepting other
     91   //     phis.
     92   //
     93   //   while phi:
     94   //     defined after all values defined in the condition or body.
     95   //
     96   auto is_body_or_condition_phi = [](const HloValue& v) {
     97     return v.is_phi() &&
     98            v.defining_instruction()->opcode() == HloOpcode::kParameter;
     99   };
    100   if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
    101       call_graph_->InstructionIsNestedIn(b.defining_instruction(),
    102                                          a.defining_instruction()->parent())) {
    103     return true;
    104   }
    105   if (is_body_or_condition_phi(b) &&
    106       call_graph_->InstructionIsNestedIn(a.defining_instruction(),
    107                                          b.defining_instruction()->parent())) {
    108     return false;
    109   }
    110 
    111   // If 'b' is a while phi and 'a' is in the body or condition, then 'a'
    112   // executes before 'b'.
    113   if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
    114       (call_graph_->InstructionIsNestedIn(
    115            a.defining_instruction(), b.defining_instruction()->while_body()) ||
    116        call_graph_->InstructionIsNestedIn(
    117            a.defining_instruction(),
    118            b.defining_instruction()->while_condition()))) {
    119     return true;
    120   }
    121 
    122   return ExecutesBefore(a.defining_instruction(), b.defining_instruction());
    123 }
    124 
    125 /* static */
    126 bool HloOrdering::UseIsBeforeValueDefinition(
    127     const HloUse& use, const HloValue& value,
    128     const HloDataflowAnalysis& dataflow) const {
    129   VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
    130           << ", value=" << value.ToShortString() << ")";
    131   if (ExecutesBefore(use.instruction, value.defining_instruction())) {
    132     VLOG(4) << "  use instruction executes before value-defining instruction";
    133     return true;
    134   }
    135 
    136   // If the use is at the instruction where the value is defined, then the use
    137   // is before the def if the instruction allows buffer sharing (in place
    138   // computation).
    139   if (use.instruction == value.defining_instruction() &&
    140       CanShareOperandBufferWithUser(
    141           use.instruction->mutable_operand(use.operand_number),
    142           use.operand_index, value.defining_instruction(),
    143           value.defining_index(), dataflow)) {
    144     VLOG(4) << "  use is value def, and instruction can share use buffer";
    145     return true;
    146   }
    147 
    148   // The use at a while is an input to a phi, and logically occurs before values
    149   // are defined in the body or condition computations.
    150   if (use.instruction->opcode() == HloOpcode::kWhile) {
    151     const HloInstruction* xla_while = use.instruction;
    152     if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
    153                                            xla_while->while_body()) ||
    154         call_graph_->InstructionIsNestedIn(value.defining_instruction(),
    155                                            xla_while->while_condition())) {
    156       VLOG(4) << "  use is while " << use.instruction->name()
    157               << " and def is in condition or body";
    158       return true;
    159     }
    160   }
    161 
    162   // Similarly if the value is defined at a while, it logically occurs after any
    163   // uses in the body or condition computations.
    164   if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
    165     CHECK(value.is_phi());
    166     const HloInstruction* xla_while = value.defining_instruction();
    167     if (call_graph_->InstructionIsNestedIn(use.instruction,
    168                                            xla_while->while_body()) ||
    169         call_graph_->InstructionIsNestedIn(use.instruction,
    170                                            xla_while->while_condition())) {
    171       VLOG(4) << "  value is while " << value.defining_instruction()->name()
    172               << " and use is in condition or body";
    173       return true;
    174     }
    175   }
    176 
    177   // The use at a call occurs before values that are defined in the called
    178   // computation.
    179   if (use.instruction->opcode() == HloOpcode::kCall) {
    180     const HloInstruction* call = use.instruction;
    181     if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
    182                                            call->to_apply())) {
    183       VLOG(4) << "  use is call " << use.instruction->name()
    184               << " and def is in called computation";
    185       return true;
    186     }
    187   }
    188 
    189   if (use.instruction->opcode() == HloOpcode::kConditional) {
    190     const HloInstruction* conditional = use.instruction;
    191     if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
    192                                            conditional->true_computation())) {
    193       VLOG(4) << "  use is conditional " << use.instruction->name()
    194               << " and def is in TRUE computation";
    195       return true;
    196     }
    197     if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
    198                                            conditional->false_computation())) {
    199       VLOG(4) << "  use is conditional " << use.instruction->name()
    200               << " and def is in FALSE computation";
    201       return true;
    202     }
    203   }
    204 
    205   VLOG(4) << "  use is not before value";
    206   return false;
    207 }
    208 
    209 bool HloOrdering::LiveRangeStrictlyBefore(
    210     const HloValue& a, const HloValue& b,
    211     const HloDataflowAnalysis& dataflow) const {
    212   VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
    213           << ", b = " << b.ToShortString() << ")";
    214   if (!IsDefinedBefore(a, b)) {
    215     VLOG(4) << "a not defined before b";
    216     return false;
    217   }
    218 
    219   // All uses of 'a' must be before 'b' is defined.
    220   for (const HloUse& use : a.uses()) {
    221     if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
    222       VLOG(4) << "use of a (" << use << ") not before b is defined";
    223       return false;
    224     }
    225   }
    226 
    227   return true;
    228 }
    229 
    230 bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
    231                                const HloDataflowAnalysis& dataflow) const {
    232   // Buffers without disjoint liveness may interfere.
    233   return !LiveRangeStrictlyBefore(a, b, dataflow) &&
    234          !LiveRangeStrictlyBefore(b, a, dataflow);
    235 }
    236 
    237 HloOrderingProto HloOrdering::ToProto() const {
    238   HloOrderingProto proto;
    239   for (const auto& computation : module_->computations()) {
    240     const std::vector<const HloInstruction*>* sequence =
    241         SequentialOrder(*computation);
    242     if (sequence != nullptr) {
    243       HloOrderingProto::SequentialComputation* proto_computation =
    244           proto.add_sequential_computations();
    245       proto_computation->set_computation_name(computation->name());
    246       for (const HloInstruction* instruction : *sequence) {
    247         *proto_computation->add_instruction_names() = instruction->name();
    248       }
    249     }
    250   }
    251   return proto;
    252 }
    253 
    254 PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
    255     : HloOrdering(module) {}
    256 
    257 bool PredecessorHloOrdering::ExecutesBeforeInSameComputation(
    258     const HloInstruction* a, const HloInstruction* b) const {
    259   CHECK_EQ(a->parent(), b->parent());
    260 
    261   // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'.
    262   return a != b && predecessors_.at(a->parent())->IsReachable(a, b);
    263 }
    264 
    265 string PredecessorHloOrdering::ToStringHelper(const string& name) const {
    266   std::vector<string> pieces;
    267   pieces.push_back(name);
    268   for (auto* computation : module_->MakeNonfusionComputations()) {
    269     pieces.push_back(tensorflow::strings::Printf("computation %s:",
    270                                                  computation->name().c_str()));
    271     const auto all = computation->MakeInstructionPostOrder();
    272     for (auto instruction : all) {
    273       pieces.push_back(tensorflow::strings::Printf(
    274           "  %s predecessors:", instruction->name().c_str()));
    275       for (auto predecessor : all) {
    276         if (predecessors_.at(computation)
    277                 ->IsReachable(predecessor, instruction)) {
    278           pieces.push_back(
    279               tensorflow::strings::Printf("  %s", predecessor->name().c_str()));
    280         }
    281       }
    282     }
    283   }
    284   return tensorflow::str_util::Join(pieces, "\n");
    285 }
    286 
    287 DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
    288     : PredecessorHloOrdering(module) {
    289   // Compute predecessor relationships between all instructions to determine
    290   // ordering based on dependencies. ExecutesBefore will return true iff there
    291   // exists a path in the HLO computation graph from 'a' to 'b'.
    292   for (auto* computation : module->MakeNonfusionComputations()) {
    293     predecessors_.emplace(computation, computation->ComputeReachability());
    294   }
    295 }
    296 
    297 string DependencyHloOrdering::ToString() const {
    298   return ToStringHelper("DependencyHloOrdering");
    299 }
    300 
    301 SequentialHloOrdering::SequentialHloOrdering(
    302     const HloModule* module, const HloModuleSequence& module_sequence)
    303     : HloOrdering(module), module_sequence_(module_sequence) {
    304   // Create a map from instruction to its order position.
    305   for (auto computation_order : module_sequence_) {
    306     const std::vector<const HloInstruction*>& order = computation_order.second;
    307     for (int i = 0; i < order.size(); ++i) {
    308       DCHECK_EQ(0, order_position_.count(order[i]));
    309       order_position_.emplace(order[i], i);
    310     }
    311   }
    312 }
    313 
    314 bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
    315     const HloInstruction* a, const HloInstruction* b) const {
    316   CHECK_EQ(a->parent(), b->parent());
    317   // If either instruction is not in the order, then 'a' and 'b' are unordered.
    318   if (order_position_.count(a) == 0 || order_position_.count(b) == 0) {
    319     return false;
    320   }
    321   return order_position_.at(a) < order_position_.at(b);
    322 }
    323 
    324 const std::vector<const HloInstruction*>*
    325 SequentialHloOrdering::SequentialOrder(
    326     const HloComputation& computation) const {
    327   auto find_it = module_sequence_.find(&computation);
    328   return find_it == module_sequence_.end() ? nullptr : &find_it->second;
    329 }
    330 
    331 string SequentialHloOrdering::ToString() const {
    332   std::vector<string> pieces;
    333   pieces.push_back("SequentialHloOrdering");
    334   for (auto* computation : module_->computations()) {
    335     pieces.push_back(tensorflow::strings::Printf("computation %s order:",
    336                                                  computation->name().c_str()));
    337     // Gather all instructions in the module sequence for this computation and
    338     // sort them by their position.
    339     std::vector<const HloInstruction*> instructions;
    340     for (auto& instruction_position : order_position_) {
    341       const HloInstruction* instruction = instruction_position.first;
    342       if (instruction->parent() == computation) {
    343         instructions.push_back(instruction);
    344       }
    345     }
    346     std::sort(instructions.begin(), instructions.end(),
    347               [this](const HloInstruction* a, const HloInstruction* b) {
    348                 return order_position_.at(a) < order_position_.at(b);
    349               });
    350     for (auto instruction : instructions) {
    351       pieces.push_back(
    352           tensorflow::strings::Printf("  %s", instruction->name().c_str()));
    353     }
    354   }
    355   return tensorflow::str_util::Join(pieces, "\n");
    356 }
    357 
    358 std::ostream& operator<<(
    359     std::ostream& out,
    360     const SequentialHloOrdering::HloModuleSequence& module_sequence) {
    361   for (auto computation_pair : module_sequence) {
    362     const HloComputation* computation = computation_pair.first;
    363     const std::vector<const HloInstruction*>& computation_sequence =
    364         computation_pair.second;
    365     out << "Computation " << computation->name() << ":\n";
    366     for (auto* instruction : computation_sequence) {
    367       out << "  " << instruction->name() << "\n";
    368     }
    369   }
    370   return out;
    371 }
    372 
    373 }  // namespace xla
    374