/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/call_inliner.h"

#include <deque>

#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace {

// Traverses the callee computation, inlining cloned nodes into the caller
// computation and connecting them to producers/consumers appropriately.
// When the traversal is complete, the provided call instruction is entirely
// replaced in the caller's graph.
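//
// For example (a hypothetical HLO snippet), given a callee
//
//   callee {
//     p = f32[] parameter(0)
//     ROOT m = f32[] multiply(p, p)
//   }
//
// and a caller containing
//
//   c = f32[] call(x), to_apply=callee
//
// the multiply is cloned into the caller with x substituted for p, and all
// consumers of c are rewired to the clone.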
class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
 public:
  // call is the call operation -- it will be replaced with the body of the
  // called computation.
  explicit SubcomputationInsertionVisitor(HloInstruction* call)
      : call_(call), outer_(call->parent()) {
    CHECK_EQ(HloOpcode::kCall, call_->opcode());
  }

  // Resolves the operands to the HLO instruction in the inlined (caller)
  // graph, and clones the HLO instruction into that graph with the new
  // operands. Nested kCall instructions are not supported here and cause an
  // error; CallInliner::Run inlines callees before callers, so none should
  // remain by the time a computation is visited.
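  //
  // For example (hypothetical): a callee instruction add(p0, p1), whose
  // operands have already been mapped to the call's operands x and y, is
  // cloned into the caller as add(x, y).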
  Status DefaultAction(HloInstruction* hlo) override {
    TF_RET_CHECK(hlo->opcode() != HloOpcode::kCall);
    std::vector<HloInstruction*> new_operands;
    for (HloInstruction* operand : hlo->operands()) {
      TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
      new_operands.push_back(new_operand);
    }
    VLOG(1) << "Cloning HLO and adding to caller: " << hlo->ToString();
    auto new_hlo = hlo->CloneWithNewOperands(hlo->shape(), new_operands);
    HloInstruction* new_hlo_pointer =
        outer_->AddInstruction(std::move(new_hlo));
    TF_RETURN_IF_ERROR(NoteMapping(hlo, new_hlo_pointer));

    // Account for control edges.
    for (HloInstruction* control_predecessor : hlo->control_predecessors()) {
      TF_ASSIGN_OR_RETURN(HloInstruction * new_control_predecessor,
                          Resolve(control_predecessor));
      TF_RETURN_IF_ERROR(
          new_control_predecessor->AddControlDependencyTo(new_hlo_pointer));
    }

    return Status::OK();
  }

  // Does not create new nodes for the parameter; rather, notes the mapping
  // from the subcomputation parameter node to the corresponding call operand
  // in the caller computation.
  Status HandleParameter(HloInstruction* parameter) override {
    TF_RETURN_IF_ERROR(NoteMapping(
        parameter, call_->mutable_operand(parameter->parameter_number())));
    return Status::OK();
  }

  // Wires the consumers of the call to instead point at the newly created root,
  // replacing the call operation in the caller computation.
  Status FinishVisit(HloInstruction* root) override {
    TF_ASSIGN_OR_RETURN(HloInstruction * new_root, Resolve(root));
    VLOG(1) << "Replacing all uses of " << call_->ToString()
            << " with new root " << new_root->ToString();
    call_->ClearCalledComputations();
    return outer_->ReplaceInstruction(call_, new_root);
  }

  CallInliner::InlinedInstructionMap ConsumeInstructionMap() {
    return std::move(subcomputation_hlo_to_new_hlo_);
  }

 private:
  // Resolves the callee subcomputation_hlo to the new (inline) HLO in the
  // caller computation, or returns a NotFound error if that subcomputation HLO
  // has not been mapped.
  StatusOr<HloInstruction*> Resolve(HloInstruction* subcomputation_hlo) {
    auto it = subcomputation_hlo_to_new_hlo_.find(subcomputation_hlo);
    if (it == subcomputation_hlo_to_new_hlo_.end()) {
      return NotFound("Could not find mapping from subcomputation HLO ",
                      subcomputation_hlo->ToString(), " to a cloned HLO.");
    }
    return it->second;
  }

  // Notes that the given subcomputation_hlo in the callee has been mapped to
  // the (inline) new_hlo in the caller computation.
  //
  // Returns an error status if the subcomputation_hlo is mapped more than
  // once.
  Status NoteMapping(HloInstruction* subcomputation_hlo,
                     HloInstruction* new_hlo) {
    auto result = subcomputation_hlo_to_new_hlo_.insert(
        std::make_pair(subcomputation_hlo, new_hlo));
    TF_RET_CHECK(result.second)
        << "A mapping for the subcomputation HLO is already present.";
    return Status::OK();
  }

  HloInstruction* call_;
  HloComputation* outer_;
  CallInliner::InlinedInstructionMap subcomputation_hlo_to_new_hlo_;
};

}  // namespace

/* static */ StatusOr<CallInliner::InlinedInstructionMap> CallInliner::Inline(
    HloInstruction* call) {
  TF_RET_CHECK(call->opcode() == HloOpcode::kCall)
      << "Instruction was not a call op: " << call->opcode();
  const auto& callees = call->called_computations();
  TF_RET_CHECK(callees.size() == 1);
  HloComputation* callee = callees[0];
  // We visit the callee, cloning its body into its caller.
  SubcomputationInsertionVisitor visitor(call);
  TF_RETURN_IF_ERROR(callee->Accept(&visitor));
  return visitor.ConsumeInstructionMap();
}
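
// Example usage (a sketch; `call` is assumed to be a kCall instruction that
// lives in some caller computation):
//
//   TF_ASSIGN_OR_RETURN(CallInliner::InlinedInstructionMap inline_map,
//                       CallInliner::Inline(call));
//   // inline_map maps each instruction of the former callee body to its
//   // clone in the caller.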

StatusOr<bool> CallInliner::Run(HloModule* module) {
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  // Because call graph nodes are visited in post-order (callees before
  // callers), we'll always inline kCalls into their callers in the
  // appropriate order.
  bool did_mutate = false;
  TF_RETURN_IF_ERROR(
      call_graph->VisitNodes([&](const CallGraphNode& node) -> Status {
        for (const CallSite& callsite : node.caller_callsites()) {
          VLOG(1) << "Visiting callsite: " << callsite.ToString();
          if (callsite.instruction()->opcode() == HloOpcode::kCall) {
            HloInstruction* call = callsite.instruction();
            TF_RETURN_IF_ERROR(Inline(call).status());
            did_mutate = true;
          }
        }
        return Status::OK();
      }));
  return did_mutate;
}
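
// Example usage (a sketch, assuming `module` points to a valid HloModule):
//
//   CallInliner inliner;
//   TF_ASSIGN_OR_RETURN(bool changed, inliner.Run(module));
//   // `changed` is true iff at least one kCall was inlined.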

}  // namespace xla