1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/while_util.h" 17 #include "tensorflow/compiler/xla/service/hlo_computation.h" 18 #include "tensorflow/compiler/xla/service/tuple_util.h" 19 20 namespace xla { 21 22 static StatusOr<HloComputation*> WidenWhileCondition( 23 HloComputation* narrow_condition, const Shape& wide_shape) { 24 const Shape& narrow_shape = 25 narrow_condition->parameter_instruction(0)->shape(); 26 27 HloComputation* wide_while_cond = [&]() { 28 HloComputation::Builder builder( 29 tensorflow::strings::StrCat("wide.", narrow_condition->name())); 30 builder.AddInstruction( 31 HloInstruction::CreateParameter(0, wide_shape, "wide_param")); 32 33 // This is needed so that the root instruction is shaped as a PRED[] -- we 34 // need to get this right to begin with since we can't mutate the type of 35 // the root instruction later. We later change the root instruction to 36 // something more appropriate. 37 builder.AddInstruction( 38 HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); 39 return narrow_condition->parent()->AddEmbeddedComputation(builder.Build()); 40 }(); 41 42 HloInstruction* truncated_parameter = 43 TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0), 44 narrow_shape.tuple_shapes_size()); 45 HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction( 46 HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}), 47 {truncated_parameter}, narrow_condition)); 48 49 wide_while_cond->set_root_instruction(call_narrow_cond); 50 51 TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status()); 52 return wide_while_cond; 53 } 54 55 static StatusOr<std::pair<HloComputation*, CallInliner::InlinedInstructionMap>> 56 WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { 57 const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape(); 58 59 HloComputation* wide_while_body = [&]() { 60 HloComputation::Builder builder( 61 tensorflow::strings::StrCat("wide.", narrow_body->name())); 62 builder.AddInstruction( 63 HloInstruction::CreateParameter(0, wide_shape, "wide_param")); 64 return narrow_body->parent()->AddEmbeddedComputation(builder.Build()); 65 }(); 66 67 HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0); 68 HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix( 69 wide_parameter, narrow_shape.tuple_shapes_size()); 70 HloInstruction* call_narrow_body = 71 wide_while_body->AddInstruction(HloInstruction::CreateCall( 72 narrow_shape, {truncated_parameter}, narrow_body)); 73 74 std::vector<HloInstruction*> live_through_values; 75 for (int i = narrow_shape.tuple_shapes_size(); 76 i < wide_shape.tuple_shapes_size(); i++) { 77 live_through_values.push_back( 78 wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement( 79 wide_shape.tuple_shapes(i), wide_parameter, i))); 80 } 81 82 wide_while_body->set_root_instruction( 83 TupleUtil::AppendSuffix(call_narrow_body, live_through_values)); 84 85 TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, 86 CallInliner::Inline(call_narrow_body)); 87 return {{wide_while_body, std::move(inlined_instructions_map)}}; 88 } 89 90 /*static*/ StatusOr<WhileUtil::MakeInstructionsLiveInResult> 91 WhileUtil::MakeInstructionsLiveIn( 92 HloInstruction* while_instr, 93 tensorflow::gtl::ArraySlice<HloInstruction*> instructions) { 94 CHECK(ShapeUtil::IsTuple(while_instr->shape())); 95 96 int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size(); 97 Shape new_while_shape = while_instr->shape(); 98 for (auto* instruction : instructions) { 99 *new_while_shape.add_tuple_shapes() = instruction->shape(); 100 } 101 102 TF_ASSIGN_OR_RETURN( 103 HloComputation * new_while_condition, 104 WidenWhileCondition(while_instr->while_condition(), new_while_shape)); 105 106 HloComputation* new_while_body; 107 CallInliner::InlinedInstructionMap inlined_instructions_map; 108 TF_ASSIGN_OR_RETURN( 109 std::tie(new_while_body, inlined_instructions_map), 110 WidenWhileBody(while_instr->while_body(), new_while_shape)); 111 112 HloInstruction* new_while_init = 113 TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions); 114 HloComputation* containing_computation = while_instr->parent(); 115 HloInstruction* new_while = containing_computation->AddInstruction( 116 HloInstruction::CreateWhile(new_while_shape, new_while_condition, 117 new_while_body, new_while_init)); 118 TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction( 119 while_instr, TupleUtil::ExtractPrefix( 120 new_while, while_instr->shape().tuple_shapes_size()))); 121 122 HloInstruction* while_body_param = new_while_body->parameter_instruction(0); 123 std::vector<HloInstruction*> live_in_instructions; 124 for (int64 i = elements_in_old_while_shape; 125 i < new_while_shape.tuple_shapes_size(); i++) { 126 live_in_instructions.push_back( 127 new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement( 128 instructions[i - elements_in_old_while_shape]->shape(), 129 while_body_param, i))); 130 } 131 132 WhileUtil::MakeInstructionsLiveInResult result; 133 134 result.new_while_instr = new_while; 135 result.while_body_live_in_values = std::move(live_in_instructions); 136 result.while_body_instruction_map = std::move(inlined_instructions_map); 137 138 return std::move(result); 139 } 140 } // namespace xla 141