Home | History | Annotate | Download | only in service
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/while_util.h"
     17 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     18 #include "tensorflow/compiler/xla/service/tuple_util.h"
     19 
     20 namespace xla {
     21 
     22 static StatusOr<HloComputation*> WidenWhileCondition(
     23     HloComputation* narrow_condition, const Shape& wide_shape) {
     24   const Shape& narrow_shape =
     25       narrow_condition->parameter_instruction(0)->shape();
     26 
     27   HloComputation* wide_while_cond = [&]() {
     28     HloComputation::Builder builder(
     29         tensorflow::strings::StrCat("wide.", narrow_condition->name()));
     30     builder.AddInstruction(
     31         HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
     32 
     33     // This is needed so that the root instruction is shaped as a PRED[] -- we
     34     // need to get this right to begin with since we can't mutate the type of
     35     // the root instruction later.  We later change the root instruction to
     36     // something more appropriate.
     37     builder.AddInstruction(
     38         HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
     39     return narrow_condition->parent()->AddEmbeddedComputation(builder.Build());
     40   }();
     41 
     42   HloInstruction* truncated_parameter =
     43       TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0),
     44                                narrow_shape.tuple_shapes_size());
     45   HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction(
     46       HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}),
     47                                  {truncated_parameter}, narrow_condition));
     48 
     49   wide_while_cond->set_root_instruction(call_narrow_cond);
     50 
     51   TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status());
     52   return wide_while_cond;
     53 }
     54 
     55 static StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>>
     56 WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) {
     57   const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape();
     58 
     59   HloComputation* wide_while_body = [&]() {
     60     HloComputation::Builder builder(
     61         tensorflow::strings::StrCat("wide.", narrow_body->name()));
     62     builder.AddInstruction(
     63         HloInstruction::CreateParameter(0, wide_shape, "wide_param"));
     64     return narrow_body->parent()->AddEmbeddedComputation(builder.Build());
     65   }();
     66 
     67   HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0);
     68   HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix(
     69       wide_parameter, narrow_shape.tuple_shapes_size());
     70   HloInstruction* call_narrow_body =
     71       wide_while_body->AddInstruction(HloInstruction::CreateCall(
     72           narrow_shape, {truncated_parameter}, narrow_body));
     73 
     74   std::vector<HloInstruction*> live_through_values;
     75   for (int i = narrow_shape.tuple_shapes_size();
     76        i < wide_shape.tuple_shapes_size(); i++) {
     77     live_through_values.push_back(
     78         wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
     79             wide_shape.tuple_shapes(i), wide_parameter, i)));
     80   }
     81 
     82   wide_while_body->set_root_instruction(
     83       TupleUtil::AppendSuffix(call_narrow_body, live_through_values));
     84 
     85   TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
     86                       CallInliner::Inline(call_narrow_body));
     87   return {{wide_while_body, std::move(inlined_instructions_map)}};
     88 }
     89 
     90 /*static*/ StatusOr<WhileUtil::MakeInstructionsLiveInResult>
     91 WhileUtil::MakeInstructionsLiveIn(
     92     HloInstruction* while_instr,
     93     tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
     94   CHECK(ShapeUtil::IsTuple(while_instr->shape()));
     95 
     96   int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size();
     97   Shape new_while_shape = while_instr->shape();
     98   for (auto* instruction : instructions) {
     99     *new_while_shape.add_tuple_shapes() = instruction->shape();
    100   }
    101 
    102   TF_ASSIGN_OR_RETURN(
    103       HloComputation * new_while_condition,
    104       WidenWhileCondition(while_instr->while_condition(), new_while_shape));
    105 
    106   HloComputation* new_while_body;
    107   CallInliner::InlinedInstructionMap inlined_instructions_map;
    108   TF_ASSIGN_OR_RETURN(
    109       std::tie(new_while_body, inlined_instructions_map),
    110       WidenWhileBody(while_instr->while_body(), new_while_shape));
    111 
    112   HloInstruction* new_while_init =
    113       TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions);
    114   HloComputation* containing_computation = while_instr->parent();
    115   HloInstruction* new_while = containing_computation->AddInstruction(
    116       HloInstruction::CreateWhile(new_while_shape, new_while_condition,
    117                                   new_while_body, new_while_init));
    118   TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction(
    119       while_instr, TupleUtil::ExtractPrefix(
    120                        new_while, while_instr->shape().tuple_shapes_size())));
    121 
    122   HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
    123   std::vector<HloInstruction*> live_in_instructions;
    124   for (int64 i = elements_in_old_while_shape;
    125        i < new_while_shape.tuple_shapes_size(); i++) {
    126     live_in_instructions.push_back(
    127         new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement(
    128             instructions[i - elements_in_old_while_shape]->shape(),
    129             while_body_param, i)));
    130   }
    131 
    132   WhileUtil::MakeInstructionsLiveInResult result;
    133 
    134   result.new_while_instr = new_while;
    135   result.while_body_live_in_values = std::move(live_in_instructions);
    136   result.while_body_instruction_map = std::move(inlined_instructions_map);
    137 
    138   return std::move(result);
    139 }
    140 }  // namespace xla
    141