/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_computation.h"

#include <stddef.h>
#include <algorithm>
#include <functional>
#include <list>
#include <queue>
#include <set>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using ::tensorflow::strings::StrCat;

std::unique_ptr<HloComputation> HloComputation::Builder::Build(
    HloInstruction* root_instruction) {
  int parameter_count = 0;
  for (auto& instruction : instructions_) {
    if (instruction->opcode() == HloOpcode::kParameter) {
      parameter_count++;
    }
  }
  // If root_instruction is not specified, use the last added instruction.
  HloInstruction* root =
      root_instruction ? root_instruction : last_added_instruction_;
  CHECK_NE(nullptr, root);
  return WrapUnique(new HloComputation(name_, parameter_count, &instructions_,
                                       root, fusion_instruction_));
}
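
// Usage sketch (illustrative only, not part of the original file): building a
// small computation with Builder. The F32 scalar shape and the enclosing
// scope are assumptions here.
//
//   HloComputation::Builder builder("add_self");
//   const Shape scalar = ShapeUtil::MakeShape(F32, {});
//   HloInstruction* x = builder.AddInstruction(
//       HloInstruction::CreateParameter(0, scalar, "x"));
//   builder.AddInstruction(
//       HloInstruction::CreateBinary(scalar, HloOpcode::kAdd, x, x));
//   // With no explicit root, Build() picks the last added instruction.
//   std::unique_ptr<HloComputation> computation = builder.Build();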

HloComputation::HloComputation(
    const string& name, int parameter_count,
    std::vector<std::unique_ptr<HloInstruction>>* instructions,
    HloInstruction* root_instruction, HloInstruction* fusion_instruction)
    : name_(name),
      root_instruction_(root_instruction),
      fusion_instruction_(fusion_instruction) {
  param_instructions_.resize(parameter_count, nullptr);
  bool root_found = false;
  for (auto& instruction : *instructions) {
    if (instruction->opcode() == HloOpcode::kParameter) {
      int64 param_no = instruction->parameter_number();
      CHECK(param_no >= 0 && param_no < parameter_count)
          << "\nERROR: invalid parameter number.  Expected [0, "
          << parameter_count << "), got " << param_no;
      CHECK(param_instructions_[param_no] == nullptr)
          << "\nERROR: parameter number " << param_no
          << " already allocated in this computation";
      param_instructions_[param_no] = instruction.get();
    }
    root_found |= instruction.get() == root_instruction_;
    AddInstructionInternal(std::move(instruction));
  }
  CHECK(root_found)
      << "\nERROR: root instruction is not present in computation.";
}

HloInstruction* HloComputation::AddInstruction(
    std::unique_ptr<HloInstruction> instruction) {
  CHECK(instruction->opcode() != HloOpcode::kParameter)
      << "Parameter instructions cannot be added to a computation after "
      << "it has been built";
  return AddInstructionInternal(std::move(instruction));
}

HloInstruction* HloComputation::AddInstructionInternal(
    std::unique_ptr<HloInstruction> instruction) {
  if (parent() != nullptr) {
    instruction->UniquifyName(&parent()->instruction_name_uniquer());
    instruction->SetUniqueId(parent()->NewUniqueInstructionId());
  }
  Reparent(instruction.get());
  HloInstruction* pinst = instruction.get();
  instruction_iterators_[pinst] =
      instructions_.insert(instructions_.end(), std::move(instruction));
  return pinst;
}

HloInstruction* HloComputation::AddParameter(
    std::unique_ptr<HloInstruction> instruction) {
  CHECK(instruction->opcode() == HloOpcode::kParameter);
  CHECK(IsFusionComputation());
  CHECK(fusion_instruction_->operand_count() == param_instructions_.size());
  instruction->set_parent(this);
  param_instructions_.push_back(instruction.get());
  AddInstructionInternal(std::move(instruction));
  return instructions_.back().get();
}

Status HloComputation::RemoveParameter(int64 param_no) {
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  CHECK(IsFusionComputation());
  HloInstruction* param_instruction = param_instructions_[param_no];
  auto param_instruction_iterator = param_instructions_.begin() + param_no;
  param_instructions_.erase(param_instruction_iterator);
  // Throw the removed fused parameter instruction away.
  TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));

  while (param_no < param_instructions_.size()) {
    param_instruction = param_instructions_[param_no];
    string param_name = param_instruction->name();
    // Fusion parameters are named foo.param_1, bar.param_2, etc. We are
    // renumbering the parameters, so replace the final number in the name with
    // the updated value.
    const string param_underscore = ".param_";
    size_t index = param_name.rfind(param_underscore);
    if (index != string::npos) {
      string after_param = param_name.substr(index + param_underscore.size());
      int64 numeric_suffix;
      if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
        param_name =
            StrCat(param_name.substr(0, index), param_underscore, param_no);
      }
    }

    HloInstruction* new_instr =
        AddInstructionInternal(HloInstruction::CreateParameter(
            param_no, param_instruction->shape(), param_name));
    TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
    param_instructions_[param_no] = new_instr;
    TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
    param_no++;
  }

  return Status::OK();
}
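
// Illustrative behavior note (not part of the original file): given fusion
// parameters "x.param_0", "y.param_1", and "z.param_2", RemoveParameter(1)
// deletes "y.param_1" and recreates the trailing parameter as "z.param_1",
// so parameter numbers (and their ".param_N" suffixes) stay dense.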

void HloComputation::Reparent(HloInstruction* instruction) {
  instruction->set_parent(this);
}

bool HloComputation::IsRemovable(const HloInstruction* instruction) {
  // If the instruction has control predecessors or successors then we cannot
  // remove the instruction without violating ordering constraints (added, for
  // example, to avert interference due to buffer aliasing).
  if (!instruction->control_predecessors().empty() ||
      !instruction->control_successors().empty()) {
    return false;
  }

  if (instruction->opcode() == HloOpcode::kParameter &&
      !IsFusionComputation()) {
    return false;
  }

  return true;
}

bool HloComputation::HasSideEffect() const {
  for (auto* instruction : instructions()) {
    if (instruction->HasSideEffect()) {
      return true;
    }
  }
  return false;
}

Status HloComputation::RemoveInstructionAndUnusedOperands(
    HloInstruction* instruction) {
  TF_RET_CHECK(root_instruction() != instruction);

  TF_RET_CHECK(instruction->user_count() == 0);
  TF_RET_CHECK(IsRemovable(instruction))
      << "Cannot remove instruction: " << instruction->ToString();
  std::unordered_set<HloInstruction*> removed;
  std::queue<HloInstruction*> worklist;
  worklist.push(instruction);
  while (!worklist.empty()) {
    HloInstruction* item = worklist.front();
    worklist.pop();

    if (removed.count(item) != 0 || item->user_count() != 0 ||
        item == root_instruction() || !IsRemovable(item) ||
        item->HasSideEffect()) {
      continue;
    }
    for (int i = 0; i < item->operand_count(); ++i) {
      worklist.push(item->mutable_operand(i));
    }

    TF_RETURN_IF_ERROR(RemoveInstruction(item));
    removed.insert(item);
  }
  return Status::OK();
}
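
// Usage sketch (illustrative only): once all users of an instruction have
// been rewired, the dead expression feeding it can be reclaimed in one call;
// `computation` and `dead` are assumed to exist in the caller's scope.
//
//   // Removes `dead` and then any operands that became unused, stopping at
//   // the root, parameters, side-effecting ops, and still-used instructions.
//   TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(dead));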

Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
  VLOG(2) << "Removing instruction " << instruction->name()
          << " from computation " << name();
  TF_RET_CHECK(IsRemovable(instruction))
      << "cannot remove instruction: " << instruction->ToString();
  TF_RET_CHECK(root_instruction() != instruction)
      << "cannot remove root instruction " << instruction->name();
  TF_RET_CHECK(instruction->user_count() == 0)
      << "instruction " << instruction->name()
      << " has users and cannot be removed";
  TF_RET_CHECK(instruction->control_predecessors().empty())
      << "instruction " << instruction->name()
      << " has control predecessors and cannot be removed";
  TF_RET_CHECK(instruction->control_successors().empty())
      << "instruction " << instruction->name()
      << " has control successors and cannot be removed";

  TF_RET_CHECK(instruction_iterators_.count(instruction) != 0);
  auto inst_it = instruction_iterators_.at(instruction);
  (*inst_it)->set_parent(nullptr);
  instruction->DetachFromOperands();
  instructions_.erase(inst_it);
  return Status::OK();
}

void HloComputation::set_root_instruction(
    HloInstruction* new_root_instruction) {
  // The shape of the root (ignoring layout) is an invariant of the computation
  // for non-fusion cases.
  if (!IsFusionComputation()) {
    CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
                                root_instruction_->shape()))
        << new_root_instruction->shape().ShortDebugString()
        << " is incompatible with "
        << root_instruction_->shape().ShortDebugString();
  }
  bool root_found = false;
  for (auto& instruction : instructions_) {
    if (new_root_instruction == instruction.get()) {
      root_found = true;
      break;
    }
  }
  DCHECK(root_found);

  root_instruction_ = new_root_instruction;
}

namespace {

// Helper class which computes the post order of an expression rooted at a
// particular instruction.
class InstructionPostOrderer : public DfsHloVisitorWithDefault {
 public:
  // added_instructions is the set of instructions which have already been
  // accounted for in the post order in previous invocations of
  // GetOrder. Without this mechanism, instructions which are predecessors of
  // multiple root instructions of the computation can be added to the post
  // order more than once.
  static std::list<HloInstruction*> GetOrder(
      HloInstruction* root,
      tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) {
    InstructionPostOrderer orderer(added_instructions);
    TF_CHECK_OK(root->Accept(&orderer));
    return std::move(orderer.post_order_);
  }

 private:
  explicit InstructionPostOrderer(
      tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions)
      : added_instructions_(added_instructions) {}
  ~InstructionPostOrderer() override {}

  Status DefaultAction(HloInstruction* hlo_instruction) override {
    if (added_instructions_->count(hlo_instruction) == 0) {
      post_order_.push_back(hlo_instruction);
      added_instructions_->insert(hlo_instruction);
    }
    return Status::OK();
  }

  std::list<HloInstruction*> post_order_;
  tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions_;
};

// Helper which builds a post order of the HLO call graph.
void ComputeComputationPostOrder(
    HloComputation* computation,
    tensorflow::gtl::FlatSet<HloComputation*>* visited,
    std::list<HloComputation*>* post_order) {
  if (visited->count(computation) > 0) {
    return;
  }

  for (auto* instruction : computation->instructions()) {
    for (HloComputation* called_computation :
         instruction->called_computations()) {
      ComputeComputationPostOrder(called_computation, visited, post_order);
    }
  }

  visited->insert(computation);
  post_order->push_back(computation);
}

}  // namespace

std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
  std::list<HloInstruction*> post_order;
  std::list<HloInstruction*> trace_instructions;
  tensorflow::gtl::FlatSet<HloInstruction*> added_instructions;
  for (auto& instruction : instructions_) {
    if (instruction->opcode() == HloOpcode::kTrace) {
      // Trace instructions aren't handled by the DFS visitor. Add trace
      // instructions to the post order at the end (necessarily they have no
      // users).
      trace_instructions.push_back(instruction.get());
    } else if (instruction->users().empty()) {
      post_order.splice(post_order.end(),
                        InstructionPostOrderer::GetOrder(instruction.get(),
                                                         &added_instructions));
    }
  }
  post_order.splice(post_order.end(), trace_instructions);
  CHECK_EQ(instructions_.size(), post_order.size())
      << "number of instructions does not match post order size";
  return post_order;
}
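
// Usage sketch (illustrative only): post order guarantees every instruction
// appears after all of its operands, which is the standard iteration order
// for analysis and rewrite passes.
//
//   for (HloInstruction* instruction :
//        computation->MakeInstructionPostOrder()) {
//     VLOG(2) << "visiting " << instruction->name();
//   }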

std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
    const {
  tensorflow::gtl::FlatSet<HloComputation*> visited;
  std::list<HloComputation*> post_order;

  // To avoid special handling of this computation, cast away const of
  // 'this'. 'this' is immediately removed from the post order after
  // construction.
  ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
                              &post_order);

  // We don't want to include this computation in the post order.
  CHECK_EQ(this, post_order.back());
  post_order.pop_back();

  return post_order;
}

string HloComputation::ToString(const HloPrintOptions& options) const {
  std::ostringstream s;
  for (int i = 0; i < options.indent_amount(); i++) {
    s << "    ";
  }
  if (options.print_percent()) {
    s << "%";
  }
  s << name();
  if (options.print_program_shape()) {
    s << " " << ShapeUtil::HumanString(ComputeProgramShape());
  }
  s << " {\n";
  for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
    for (int i = 0; i < options.indent_amount(); i++) {
      s << "    ";
    }
    s << "  " << (instruction == root_instruction_ ? "ROOT " : "")
      << instruction->ToString(options) << "\n";
  }
  for (int i = 0; i < options.indent_amount(); i++) {
    s << "    ";
  }
  s << "}";
  return s.str();
}

HloComputationProto HloComputation::ToProto() const {
  HloComputationProto proto;
  proto.set_name(name_);
  for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
    HloInstructionProto instruction_proto = instruction->ToProto();
    proto.add_instructions()->Swap(&instruction_proto);
  }
  proto.set_root_name(root_instruction()->name());
  return proto;
}

/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
    HloModule* module, const HloComputationProto& proto,
    const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
    const std::function<void(std::unique_ptr<HloComputation>)>&
        add_fused_computation,
    HloInstruction* fusion_instruction) {
  std::vector<std::unique_ptr<HloInstruction>> instructions;
  tensorflow::gtl::FlatMap<string, HloInstruction*> instruction_map;
  int64 parameter_count = 0;
  for (const HloInstructionProto& instruction_proto : proto.instructions()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
                        HloInstruction::CreateFromProto(
                            module, instruction_proto, instruction_map,
                            computation_map, add_fused_computation));
    if (instruction->opcode() == HloOpcode::kParameter) {
      parameter_count++;
    }
    TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name()));
    instruction_map[instruction->name()] = instruction.get();
    instructions.push_back(std::move(instruction));
  }

  TF_RET_CHECK(!proto.root_name().empty());
  TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name()));
  HloInstruction* root = instruction_map.at(proto.root_name());
  return WrapUnique(new HloComputation(
      proto.name(), parameter_count, &instructions, root, fusion_instruction));
}

void HloComputation::FuseInstructionsInto(
    tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
    HloInstruction* fusion_instruction) {
  CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
  HloInstruction* root = instructions_to_fuse.front();
  TF_CHECK_OK(root->ReplaceAllUsesWith(fusion_instruction));
  if (root == root_instruction()) {
    set_root_instruction(fusion_instruction);
  }
  TF_CHECK_OK(RemoveInstruction(root));
  for (size_t i = 1; i < instructions_to_fuse.size(); ++i) {
    HloInstruction* instruction = instructions_to_fuse[i];
    fusion_instruction->FuseInstruction(instruction);
    if (instruction->user_count() == 0) {
      TF_CHECK_OK(RemoveInstruction(instruction));
    }
  }
}

HloInstruction* HloComputation::CreateFusionInstruction(
    tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
    HloInstruction::FusionKind fusion_kind) {
  HloInstruction* root = instructions_to_fuse.front();
  HloInstruction* fusion_instruction = AddInstruction(
      HloInstruction::CreateFusion(root->shape(), fusion_kind, root));
  FuseInstructionsInto(instructions_to_fuse, fusion_instruction);
  return fusion_instruction;
}
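
// Usage sketch (illustrative only): fusing a multiply into its consuming add.
// Instructions must be listed in reverse topological order with the fusion
// root first; `add` and `multiply` are assumed instructions in this
// computation, where `add` uses `multiply`.
//
//   HloInstruction* fusion = computation->CreateFusionInstruction(
//       /*instructions_to_fuse=*/{add, multiply},
//       HloInstruction::FusionKind::kLoop);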

StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
    HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
    ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index) {
  if (ShapeUtil::IsArray(instruction->shape())) {
    if (indices_to_copy == nullptr || indices_to_copy->element(*index)) {
      // Use kCopy to copy array elements.
      HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary(
          instruction->shape(), HloOpcode::kCopy, instruction));
      if (copies_added != nullptr) {
        *copies_added->mutable_element(*index) = copy;
      }
      return copy;
    } else {
      // Array elements which are not to be copied are passed through
      // transparently.
      return instruction;
    }
  } else if (ShapeUtil::IsTuple(instruction->shape())) {
    std::vector<HloInstruction*> elements;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         i++) {
      HloInstruction* gte =
          AddInstruction(HloInstruction::CreateGetTupleElement(
              ShapeUtil::GetTupleElementShape(instruction->shape(), i),
              instruction, i));

      index->push_back(i);
      TF_ASSIGN_OR_RETURN(
          HloInstruction * element,
          DeepCopyHelper(gte, indices_to_copy, copies_added, index));
      elements.push_back(element);
      index->pop_back();
    }
    return AddInstruction(HloInstruction::CreateTuple(elements));
  } else {
    return FailedPrecondition(
        "Can only copy array and tuple shaped instructions");
  }
}

StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
    HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
    ShapeTree<HloInstruction*>* copies_added) {
  if (instruction->parent() != this) {
    return FailedPrecondition(
        "Can't deep copy instruction %s: instruction is not in computation %s",
        instruction->name().c_str(), name().c_str());
  }
  if (indices_to_copy != nullptr &&
      !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
    return FailedPrecondition(
        "Can't deep copy instruction %s: given shape tree of indices to copy "
        "has incompatible shapes: %s vs. %s",
        instruction->name().c_str(),
        ShapeUtil::HumanString(instruction->shape()).c_str(),
        ShapeUtil::HumanString(indices_to_copy->shape()).c_str());
  }

  ShapeIndex index;
  return DeepCopyHelper(instruction, indices_to_copy, copies_added, &index);
}
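
// Usage sketch (illustrative only): deep-copying just one leaf of a
// tuple-shaped value; `tuple` is assumed to be a two-element tuple
// instruction in this computation.
//
//   // Copy leaf {0}; pass leaf {1} through without a kCopy.
//   ShapeTree<bool> indices_to_copy(tuple->shape(), /*init_value=*/false);
//   *indices_to_copy.mutable_element({0}) = true;
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction * copy,
//       computation->DeepCopyInstruction(tuple, &indices_to_copy));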

ProgramShape HloComputation::ComputeProgramShape() const {
  ProgramShape program_shape;

  for (auto* param_instruction : param_instructions_) {
    *program_shape.add_parameters() = param_instruction->shape();
    *program_shape.add_parameter_names() = param_instruction->name();
  }
  *program_shape.mutable_result() = root_instruction_->shape();

  LayoutUtil::ClearLayout(&program_shape);
  return program_shape;
}

bool HloComputation::operator==(const HloComputation& other) const {
  std::set<std::pair<const HloInstruction*, const HloInstruction*>> visited;
  std::function<bool(const HloInstruction*, const HloInstruction*)> eq =
      [&visited, &eq](const HloInstruction* a, const HloInstruction* b) {
        // If <a,b> are visited but not identical, the recursion should have
        // been aborted. So, if <a,b> are visited at this point, they must be
        // identical.
        if (visited.count(std::make_pair(a, b)) > 0) {
          return true;
        }
        visited.emplace(a, b);
        return a->Identical(
            *b, eq, [](const HloComputation* a, const HloComputation* b) {
              return *a == *b;
            });
      };
  return eq(root_instruction(), other.root_instruction());
}

Status HloComputation::ReplaceWithNewInstruction(
    HloInstruction* old_instruction,
    std::unique_ptr<HloInstruction> new_instruction) {
  return ReplaceInstruction(old_instruction,
                            AddInstruction(std::move(new_instruction)));
}

Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
                                          HloInstruction* new_instruction) {
  TF_RET_CHECK(
      ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
      << ShapeUtil::HumanString(old_instruction->shape()) << " vs "
      << ShapeUtil::HumanString(new_instruction->shape());

  VLOG(10) << "transformed " << old_instruction->ToString() << " to "
           << new_instruction->ToString();
  // Try to add metadata for HLO instructions that are created to replace
  // existing HLO instructions (e.g. during optimizations). The assumption is
  // that the old instruction and the new instruction would perform the same
  // function, and that they would be correlated to the same TF op. This might
  // not always be correct since HLO optimizations can cross TF op boundaries.
  // But still this seems to be better than nothing.
  if (new_instruction->metadata().op_name().empty()) {
    new_instruction->set_metadata(old_instruction->metadata());
  }
  TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction));
  return RemoveInstructionAndUnusedOperands(old_instruction);
}
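
// Usage sketch (illustrative only): a pass replacing `old_instr` with a
// freshly created copy of its first operand. `old_instr` is assumed to be an
// instruction whose shape is compatible with its operand 0 (e.g. x + 0); the
// old instruction and any newly dead operands are removed afterwards.
//
//   TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
//       old_instr,
//       HloInstruction::CreateUnary(old_instr->shape(), HloOpcode::kCopy,
//                                   old_instr->mutable_operand(0))));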

std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
    const {
  const std::list<HloInstruction*> all = MakeInstructionPostOrder();
  auto result = MakeUnique<HloReachabilityMap>(all);

  std::vector<HloInstruction*> inputs;
  for (const HloInstruction* hlo : all) {
    inputs.assign(hlo->operands().begin(), hlo->operands().end());
    inputs.insert(inputs.end(), hlo->control_predecessors().begin(),
                  hlo->control_predecessors().end());
    result->SetReachabilityToUnion(inputs, hlo);
  }
  return result;
}
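
// Usage sketch (illustrative only): the returned map answers transitive
// dependency queries, e.g. to check whether a new control edge would create
// a cycle; `a` and `b` are assumed instructions in this computation.
//
//   std::unique_ptr<HloReachabilityMap> reachability =
//       computation->ComputeReachability();
//   if (reachability->IsReachable(a, b)) {
//     // `b` depends on `a` through data and/or control edges.
//   }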

void HloComputation::UpdateReachabilityThroughInstruction(
    const HloInstruction* instruction, HloReachabilityMap* reachability_map) {
  std::queue<const HloInstruction*> worklist;
  worklist.push(instruction);

  std::vector<HloInstruction*> inputs;

  while (!worklist.empty()) {
    const HloInstruction* item = worklist.front();
    worklist.pop();

    inputs.assign(item->operands().begin(), item->operands().end());
    inputs.insert(inputs.end(), item->control_predecessors().begin(),
                  item->control_predecessors().end());

    if (reachability_map->SetReachabilityToUnion(inputs, item)) {
      // Add immediate successors to worklist.
      for (const HloInstruction* user : item->users()) {
        worklist.push(user);
      }
      for (const HloInstruction* succ : item->control_successors()) {
        worklist.push(succ);
      }
    }
  }
}

std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
  std::vector<HloInstruction*> unreachable_roots;
  for (auto* instruction : instructions()) {
    if (instruction->user_count() == 0 &&
        instruction->control_successors().empty() &&
        instruction != root_instruction()) {
      unreachable_roots.push_back(instruction);
    }
  }
  VLOG(3) << "Unreachable roots:"
          << tensorflow::str_util::Join(
                 unreachable_roots, "\n\t",
                 [](string* out, const HloInstruction* hlo) {
                   tensorflow::strings::StrAppend(out, hlo->ToString());
                 });
  return unreachable_roots;
}

template <typename HloInstructionPtr>
Status HloComputation::Accept(
    DfsHloVisitorBase<HloInstructionPtr>* visitor) const {
  // Visit unreachable roots. Beware that the visitor might delete the
  // currently visited root, which would invalidate iterators if the
  // unreachable roots weren't computed ahead of time.
  for (HloInstruction* root : CollectUnreachableRoots()) {
    VLOG(3) << "Traversing unreachable root: " << root->ToString();
    // Call FinishVisit only at the end.
    TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false));
  }
  // Visit the computation root instruction last.
  return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
}

// Explicit instantiations.
template Status HloComputation::Accept(DfsHloVisitor* visitor) const;
template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const;

Status HloComputation::AcceptWithOperandOrder(
    DfsHloVisitor* visitor,
    const HloInstruction::CompareFunction& operand_order) const {
  // Visit unreachable roots. Beware that the visitor might delete the
  // currently visited root, which would invalidate iterators if the
  // unreachable roots weren't computed ahead of time.
  for (HloInstruction* root : CollectUnreachableRoots()) {
    TF_RETURN_IF_ERROR(
        root->AcceptWithOperandOrder(visitor, operand_order,
                                     /*call_finish_visit=*/false));
  }
  // Visit the computation root instruction last.
  return root_instruction()->AcceptWithOperandOrder(visitor, operand_order,
                                                    /*call_finish_visit=*/true);
}

template <typename HloInstructionPtr>
Status HloComputation::AcceptOrdered(
    DfsHloVisitorBase<HloInstructionPtr>* visitor,
    const std::vector<const HloInstruction*>& order) const {
  VLOG(3) << "Accepting visitor with order.";
  for (HloInstruction* root : CollectUnreachableRoots()) {
    TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end())
        << root->ToString();
  }
  TF_RET_CHECK(order.size() == instruction_count());
  std::unordered_set<const HloInstruction*> visited;
  for (const HloInstruction* instruction : order) {
    VLOG(3) << "Visiting ordered: " << instruction->ToString();
    TF_RET_CHECK(instruction_iterators_.count(instruction) == 1)
        << "Instruction " << instruction->name() << " is not in computation "
        << name();
    TF_RET_CHECK(visited.count(instruction) == 0)
        << "Instruction " << instruction->name()
        << " appears more than once in order";
    HloInstruction* mutable_instruction =
        const_cast<HloInstruction*>(instruction);
    TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction));
    TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor));
    visitor->SetVisited(*mutable_instruction);
    TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction));
    visited.insert(instruction);
  }
  TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction()));
  return Status::OK();
}

// Explicit instantiations.
template Status HloComputation::AcceptOrdered(
    DfsHloVisitor*, const std::vector<const HloInstruction*>&) const;
template Status HloComputation::AcceptOrdered(
    ConstDfsHloVisitor*, const std::vector<const HloInstruction*>&) const;

Status HloComputation::Accept(
    const std::function<Status(HloInstruction*)>& visitor_func) {
  FunctionVisitor visitor(visitor_func);
  return this->Accept(&visitor);
}

Status HloComputation::Accept(
    const std::function<Status(const HloInstruction*)>& visitor_func) const {
  ConstFunctionVisitor visitor(visitor_func);
  return this->Accept(&visitor);
}

std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix,
                                                      HloModule* module) {
  return CloneWithReplacements(
      /*replacements=*/std::unordered_map<const HloInstruction*,
                                          std::unique_ptr<HloInstruction>>(),
      module, suffix);
}

std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
    std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements,
    HloModule* module, const string& suffix) {
  // Look up instr in the replacements map, and return either the replacement,
  // or instr, if the replacement isn't present.
  //
  // Note: This can return null, indicating that instr should not be present in
  // the new computation.
  auto replace = [&](HloInstruction* instr) {
    auto it = replacements.find(instr);
    if (it == replacements.end()) {
      return instr;
    }
    return it->second.get();
  };

  VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";
  std::vector<HloInstruction*> postorder;
  for (HloInstruction* instr : MakeInstructionPostOrder()) {
    if (HloInstruction* replacement = replace(instr)) {
      postorder.push_back(replacement);
    }
  }

  std::unordered_map<HloInstruction*, HloInstruction*> clone_map;
  std::vector<std::unique_ptr<HloInstruction>> instructions;
  std::unique_ptr<HloInstruction> new_instr = nullptr;
  for (auto instr : postorder) {
    std::vector<HloInstruction*> new_operands;
    for (auto operand : instr->operands()) {
      auto replaced_operand = replace(operand);
      // If replaced_operand is null, that means 'replacements' asked us not to
      // include operand in the new computation.  But we can't do that, because
      // operand is used by instr.
      CHECK_NE(replaced_operand, nullptr)
          << "replacements map tried to eliminate a used instruction "
          << operand->ToString() << ", used by " << instr->ToString();
      new_operands.push_back(FindOrDie(clone_map, replaced_operand));
    }
    new_instr =
        instr->CloneWithNewOperands(instr->shape(), new_operands, module);
    InsertOrDie(&clone_map, instr, new_instr.get());
    instructions.push_back(std::move(new_instr));
  }
  Builder builder(name() + "." + suffix);
  for (auto& instr : instructions) {
    builder.AddInstruction(std::move(instr));
  }
  auto result = builder.Build(
      /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction())));

  // Clone control dependencies.
  for (auto instr : postorder) {
    HloInstruction* new_instr = FindOrDie(clone_map, instr);
    for (auto successor : instr->control_successors()) {
      auto replaced_successor = replace(successor);

      // successor may not be in clone_map, because it might have been
      // removed by the replacements map.
      if (replaced_successor == nullptr) {
        continue;
      }

      TF_CHECK_OK(new_instr->AddControlDependencyTo(
          FindOrDie(clone_map, replaced_successor)));
    }
  }

  // We cloned the elements of 'replacements', so they're all going to be
  // destroyed.  HloInstructions need to be detached from their operands before
  // they're destroyed, otherwise they stick around in the operands' users lists
  // and cause use-after-frees.
  for (auto& kv : replacements) {
    if (std::unique_ptr<HloInstruction>& new_instr = kv.second) {
      new_instr->DetachFromOperands();
    }
  }

  return result;
}
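
// Usage sketch (illustrative only): cloning while substituting one
// instruction. Mapping an instruction to nullptr instead drops it from the
// clone (it must be unused there). `old_instr` is assumed to be an F32
// scalar instruction in `computation`, and `module` its HloModule.
//
//   std::unordered_map<const HloInstruction*,
//                      std::unique_ptr<HloInstruction>> replacements;
//   replacements.emplace(old_instr, HloInstruction::CreateConstant(
//                                       Literal::CreateR0<float>(1.0f)));
//   std::unique_ptr<HloComputation> clone = computation->CloneWithReplacements(
//       std::move(replacements), module, "clone");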

void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
  name_ = name_uniquer->GetUniqueName(name_);
}

}  // namespace xla