// tensorflow/compiler/xla/service/flatten_call_graph.cc
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
     17 
     18 #include "tensorflow/compiler/xla/service/call_graph.h"
     19 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     20 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     21 #include "tensorflow/compiler/xla/service/hlo_module.h"
     22 #include "tensorflow/compiler/xla/util.h"
     23 #include "tensorflow/core/lib/core/errors.h"
     24 
     25 namespace xla {
     26 
     27 namespace {
     28 
     29 // Helper to replace the called computation at a while-, call-, or
     30 // conditional-instruction. This function replaces exactly one instance of
     31 // 'computation' with 'new_computation' even if 'instruction' calls
     32 // 'computation' more than once.
     33 void ReplaceCalledComputation(HloInstruction* instruction,
     34                               HloComputation* computation,
     35                               HloComputation* new_computation) {
     36   switch (instruction->opcode()) {
     37     case HloOpcode::kWhile: {
     38       if (computation == instruction->while_condition()) {
     39         instruction->set_while_condition(new_computation);
     40       } else {
     41         CHECK_EQ(computation, instruction->while_body());
     42         instruction->set_while_body(new_computation);
     43       }
     44       break;
     45     }
     46     case HloOpcode::kCall: {
     47       CHECK_EQ(instruction->to_apply(), computation);
     48       instruction->set_to_apply(new_computation);
     49       break;
     50     }
     51     case HloOpcode::kConditional: {
     52       if (computation == instruction->true_computation()) {
     53         instruction->set_true_computation(new_computation);
     54       } else {
     55         CHECK_EQ(computation, instruction->false_computation());
     56         instruction->set_false_computation(new_computation);
     57       }
     58       break;
     59     }
     60     default:
     61       LOG(FATAL) << "unexpected opcode: "
     62                  << HloOpcodeString(instruction->opcode());
     63   }
     64 }
     65 
     66 // Flatten a single call graph node. Expects to visit nodes in postorder.
     67 Status FlattenNode(const CallGraphNode& node) {
     68   HloComputation* computation = node.computation();
     69   HloModule* module = computation->parent();
     70   // Clone callee for all call-sites except the first one.
     71   for (int i = 0; i < node.caller_callsites().size(); ++i) {
     72     CallSite call_site = node.caller_callsites()[i];
     73     // Only consider sequential call contexts.
     74     if (call_site.context() == CallContext::kParallel) {
     75       continue;
     76     }
     77     CHECK_EQ(call_site.context(), CallContext::kSequential);
     78 
     79     // Skip first element if this computation is only called from a sequential
     80     // context.
     81     if (node.context() != CallContext::kBoth && i == 0) {
     82       continue;
     83     }
     84 
     85     // Clone computation for the remaining sequential context call sites.
     86     HloComputation* clone =
     87         module->AddEmbeddedComputation(computation->Clone());
     88     ReplaceCalledComputation(call_site.instruction(), computation, clone);
     89     // Clone the sub-tree of all computations called from this node.
     90     std::vector<HloComputation*> worklist;
     91     worklist.push_back(clone);
     92     while (!worklist.empty()) {
     93       auto current = worklist.back();
     94       worklist.pop_back();
     95       for (auto* instruction : current->instructions()) {
     96         if (GetInstructionCallContext(instruction) !=
     97             CallContext::kSequential) {
     98           continue;
     99         }
    100         for (auto callee : instruction->called_computations()) {
    101           HloComputation* callee_clone =
    102               module->AddEmbeddedComputation(callee->Clone());
    103           ReplaceCalledComputation(instruction, callee, callee_clone);
    104           worklist.push_back(callee_clone);
    105         }
    106       }
    107     }
    108   }
    109   return Status::OK();
    110 }
    111 
    112 }  // namespace
    113 
    114 StatusOr<bool> FlattenCallGraph::Run(HloModule* module) {
    115   XLA_VLOG_LINES(3, "Before flatten call graph:\n" + module->ToString());
    116 
    117   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
    118   TF_RETURN_IF_ERROR(call_graph->VisitNodes(FlattenNode));
    119 
    120   XLA_VLOG_LINES(3, "After flatten call graph:\n" + module->ToString());
    121   return true;
    122 }
    123 
    124 }  // namespace xla
    125