1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/while_util.h" 17 18 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 19 #include "tensorflow/compiler/xla/test.h" 20 #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" 21 22 namespace xla { 23 namespace { 24 25 namespace op = ::xla::testing::opcode_matchers; 26 27 StatusOr<std::unique_ptr<HloModule>> GetParsedModule( 28 HloComputation** entry_computation, HloInstruction** param0, 29 HloInstruction** param1, HloInstruction** param2) { 30 const char* const hlo_string = R"( 31 HloModule ModuleWithWhile 32 33 while_body { 34 ROOT p_body = (f32[32,32]{1,0}, f32[32,32]{1,0}) parameter(0) 35 } 36 37 while_condition { 38 p_cond = f32[32,32]{1,0} parameter(0) 39 ROOT result = pred[] constant(true) 40 } 41 42 ENTRY entry { 43 p_entry_0 = f32[32,32]{1,0} parameter(0) 44 p_entry_1 = s32[32,32]{1,0} parameter(1) 45 p_entry_2 = s64[32,32]{1,0} parameter(2) 46 while_init = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p_entry_0, p_entry_0) 47 ROOT while = (f32[32,32]{1,0}, f32[32,32]{1,0}) while(while_init), condition=while_condition, body=while_body 48 } 49 )"; 50 51 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module, 52 tools::Parse(hlo_string)); 53 54 *entry_computation = module->entry_computation(); 55 *param0 = (*entry_computation)->parameter_instruction(0); 56 *param1 = (*entry_computation)->parameter_instruction(1); 57 *param2 = (*entry_computation)->parameter_instruction(2); 58 59 return std::move(module); 60 } 61 62 TEST(WhileUtil, MakeZeroInstructionsLiveOp) { 63 HloInstruction *param0, *param1, *param2; 64 HloComputation* entry_computation; 65 66 TF_ASSERT_OK_AND_ASSIGN( 67 std::unique_ptr<HloModule> module, 68 GetParsedModule(&entry_computation, ¶m0, ¶m1, ¶m2)); 69 70 HloInstruction* while_instr = entry_computation->root_instruction(); 71 ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile); 72 73 TF_ASSERT_OK_AND_ASSIGN( 74 WhileUtil::MakeInstructionsLiveInResult make_live_in_result, 75 WhileUtil::MakeInstructionsLiveIn(while_instr, /*instructions=*/{})); 76 77 HloInstruction* new_while_instr = make_live_in_result.new_while_instr; 78 79 EXPECT_THAT( 80 entry_computation->root_instruction(), 81 op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0), 82 op::GetTupleElement(::testing::Eq(new_while_instr), 1))); 83 84 auto param_reconstructed = 85 op::Tuple(op::GetTupleElement(op::Parameter(0), 0), 86 op::GetTupleElement(op::Parameter(0), 1)); 87 88 EXPECT_THAT(new_while_instr->while_body()->root_instruction(), 89 op::Tuple(op::GetTupleElement(param_reconstructed, 0), 90 op::GetTupleElement(param_reconstructed, 1))); 91 } 92 93 TEST(WhileUtilTest, MakeTwoInstructionsLive) { 94 HloInstruction *param0, *param1, *param2; 95 HloComputation* entry_computation; 96 97 TF_ASSERT_OK_AND_ASSIGN( 98 std::unique_ptr<HloModule> module, 99 GetParsedModule(&entry_computation, ¶m0, ¶m1, ¶m2)); 100 101 HloInstruction* while_instr = entry_computation->root_instruction(); 102 ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile); 103 104 TF_ASSERT_OK_AND_ASSIGN( 105 WhileUtil::MakeInstructionsLiveInResult make_live_in_result, 106 WhileUtil::MakeInstructionsLiveIn(while_instr, 107 /*instructions=*/{param0, param1})); 108 109 HloInstruction* new_while_instr = make_live_in_result.new_while_instr; 110 111 XLA_VLOG_LINES(3, module->ToString()); 112 113 EXPECT_THAT( 114 entry_computation->root_instruction(), 115 op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0), 116 op::GetTupleElement(::testing::Eq(new_while_instr), 1))); 117 118 auto first_half_param_reconstructed = 119 op::Tuple(op::GetTupleElement(op::Parameter(0), 0), 120 op::GetTupleElement(op::Parameter(0), 1)); 121 122 EXPECT_THAT(new_while_instr->while_body()->root_instruction(), 123 op::Tuple(op::GetTupleElement(first_half_param_reconstructed, 0), 124 op::GetTupleElement(first_half_param_reconstructed, 1), 125 op::GetTupleElement(op::Parameter(0), 2), 126 op::GetTupleElement(op::Parameter(0), 3))); 127 } 128 129 } // namespace 130 } // namespace xla 131