/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Tests for gpu::CanTransformWhileToFor, which pattern-matches a kWhile
// instruction and, on success, yields the (start, limit, increment) triple of
// its induction variable.

#include "tensorflow/compiler/xla/service/gpu/while_transformer.h"

#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

using ::testing::Eq;
using ::testing::HasSubstr;

// Fixture that builds counted while loops in the canonical shape the
// transformer recognizes: the loop state is a 2-tuple of an S32 scalar
// induction variable and an F32[8] data array, with the induction variable at
// a configurable tuple index (0 or 1).
class WhileTransformerTest : public HloTestBase {
 protected:
  WhileTransformerTest()
      : module_(CreateNewModule()),
        induction_variable_shape_(ShapeUtil::MakeShape(S32, {})),
        data_shape_(ShapeUtil::MakeShape(F32, {8})),
        condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {}

  // Builds the loop condition computation:
  //   GetTupleElement(loop_state, tuple_index) < limit
  std::unique_ptr<HloComputation> BuildConditionComputation(
      const int64 tuple_index, const int64 limit) {
    auto builder = HloComputation::Builder(TestName() + ".Condition");
    auto limit_const = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<int32>(limit)));
    auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
        0, GetLoopStateShape(tuple_index), "loop_state"));
    auto induction_variable =
        builder.AddInstruction(HloInstruction::CreateGetTupleElement(
            limit_const->shape(), loop_state, tuple_index));
    builder.AddInstruction(
        HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt,
                                     induction_variable, limit_const));
    return builder.Build();
  }

  // Builds the loop body computation: adds `increment` to the induction
  // variable at `ind_var_tuple_index` and updates the data element at
  // `data_tuple_index`, re-emitting the state tuple with the induction
  // variable in the same position.
  std::unique_ptr<HloComputation> BuildBodyComputation(
      const int64 ind_var_tuple_index, const int64 data_tuple_index,
      const int64 increment) {
    auto builder = HloComputation::Builder(TestName() + ".Body");
    // Create param instruction to access loop state.
    auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
        0, GetLoopStateShape(ind_var_tuple_index), "loop_state"));
    // Update the induction variable GTE(ind_var_tuple_index).
    auto induction_variable =
        builder.AddInstruction(HloInstruction::CreateGetTupleElement(
            induction_variable_shape_, loop_state, ind_var_tuple_index));
    auto inc = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<int32>(increment)));
    auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
        induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
    // Update data GTE(data_tuple_index).
    auto data = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
        data_shape_, loop_state, data_tuple_index));
    // Use 'induction_variable' in computation with no path to output tuple.
    // The broadcast feeds only the data element, exercising the transformer
    // against extra (non-update) uses of the induction variable.
    auto update = builder.AddInstruction(
        HloInstruction::CreateBroadcast(data_shape_, induction_variable, {}));
    auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
        data_shape_, HloOpcode::kAdd, data, update));
    // Create output Tuple, preserving the element order of the input state.
    ind_var_tuple_index == 0
        ? builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}))
        : builder.AddInstruction(HloInstruction::CreateTuple({add1, add0}));
    return builder.Build();
  }

  // Creates the kWhile instruction in a new entry computation. The induction
  // variable is initialized to `ind_var_init` and the data element to zeros;
  // `ind_var_tuple_index` selects the state-tuple layout and must match the
  // index used to build `condition` and `body`.
  HloInstruction* BuildWhileInstruction(HloComputation* condition,
                                        HloComputation* body,
                                        const int64 ind_var_tuple_index,
                                        const int64 ind_var_init) {
    auto builder = HloComputation::Builder(TestName() + ".While");
    auto induction_var_init = builder.AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateR0<int32>(ind_var_init)));
    auto data_init = builder.AddInstruction(HloInstruction::CreateConstant(
        Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
    auto loop_state_init =
        ind_var_tuple_index == 0
            ? builder.AddInstruction(
                  HloInstruction::CreateTuple({induction_var_init, data_init}))
            : builder.AddInstruction(
                  HloInstruction::CreateTuple({data_init, induction_var_init}));
    auto while_hlo = builder.AddInstruction(
        HloInstruction::CreateWhile(GetLoopStateShape(ind_var_tuple_index),
                                    condition, body, loop_state_init));
    module_->AddEntryComputation(builder.Build());
    return while_hlo;
  }

  void RunFusionPasses() {
    // Run standard fusion passes.
    // EXPECT_TRUE: both fusion passes are expected to report that they
    // changed the module built by this fixture.
    EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/false)
                    .Run(module_.get())
                    .ValueOrDie());
    EXPECT_TRUE(gpu::GpuInstructionFusion(/*may_duplicate=*/true)
                    .Run(module_.get())
                    .ValueOrDie());
  }

  // Verifies the module, then runs copy insertion; both must succeed.
  void RunCopyInsertionPass() {
    HloVerifier verifier;
    TF_ASSERT_OK(verifier.Run(module_.get()).status());
    CopyInsertion copy_insertion;
    TF_ASSERT_OK(copy_insertion.Run(module_.get()).status());
  }

  // Returns (ind_var, data) when the induction variable lives at tuple
  // index 0, otherwise (data, ind_var).
  Shape GetLoopStateShape(const int64 ind_var_tuple_index) {
    if (ind_var_tuple_index == 0) {
      return ShapeUtil::MakeTupleShape(
          {induction_variable_shape_, data_shape_});
    } else {
      return ShapeUtil::MakeTupleShape(
          {data_shape_, induction_variable_shape_});
    }
  }

  std::unique_ptr<HloModule> module_;
  Shape induction_variable_shape_;  // S32 scalar.
  Shape data_shape_;                // F32[8].
  Shape condition_result_shape_;    // PRED scalar.
};

// TODO(b/68830972): The while transformer is far too fragile. It pattern
// matches the exact expressions of opcodes. Re-enable when the transformation
// is more general.
TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) {
  // Build computation with induction variable at tuple element 0.
  auto condition =
      module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));
  auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1));
  auto while_hlo = BuildWhileInstruction(condition, body, 0, 0);
  // Run HLO Optimization passes.
  RunFusionPasses();
  RunCopyInsertionPass();
  // Run WhileTransformer.
  auto result = gpu::CanTransformWhileToFor(while_hlo);
  TF_ASSERT_OK(result.status());
  // Check results: (start, limit, increment) == (0, 10, 1).
  EXPECT_THAT(result.ConsumeValueOrDie(),
              Eq(std::tuple<int64, int64, int64>(0, 10, 1)));
}

// TODO(b/68830972): The while transformer is far too fragile. It pattern
// matches the exact expressions of opcodes. Re-enable when the transformation
// is more general.
TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) {
  // Build computation with induction variable at tuple element 1.
  auto condition =
      module_->AddEmbeddedComputation(BuildConditionComputation(1, 10));
  auto body = module_->AddEmbeddedComputation(BuildBodyComputation(1, 0, 1));
  auto while_hlo = BuildWhileInstruction(condition, body, 1, 0);
  // Run HLO Optimization passes.
  RunFusionPasses();
  RunCopyInsertionPass();
  // Run WhileTransformer.
  auto result = gpu::CanTransformWhileToFor(while_hlo);
  TF_ASSERT_OK(result.status());
  // Check results: same trip count as element-0 case — only the tuple
  // layout differs.
  EXPECT_THAT(result.ConsumeValueOrDie(),
              Eq(std::tuple<int64, int64, int64>(0, 10, 1)));
}

// TODO(b/68830972): The while transformer is far too fragile. It pattern
// matches the exact expressions of opcodes. Re-enable when the transformation
// is more general.
TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) {
  // Build computation with invalid loop limit: start (10) >= limit (5).
  auto condition =
      module_->AddEmbeddedComputation(BuildConditionComputation(0, 5));
  auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1));
  auto while_hlo = BuildWhileInstruction(condition, body, 0, 10);
  // Run HLO Optimization passes.
  RunFusionPasses();
  RunCopyInsertionPass();
  // Run WhileTransformer.
  auto result = gpu::CanTransformWhileToFor(while_hlo);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Loop start must be less than loop limit."));
}

// TODO(b/68830972): The while transformer is far too fragile. It pattern
// matches the exact expressions of opcodes. Re-enable when the transformation
// is more general.
TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) {
  // Build computation with invalid loop increment (negative).
  auto condition =
      module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));
  auto body = module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, -1));
  auto while_hlo = BuildWhileInstruction(condition, body, 0, 0);
  // Run HLO Optimization passes.
  RunFusionPasses();
  RunCopyInsertionPass();
  // Run WhileTransformer.
  auto result = gpu::CanTransformWhileToFor(while_hlo);
  ASSERT_FALSE(result.ok());
  // NOTE(review): the expected substring below is ungrammatical ("must
  // greater than") but presumably mirrors the exact message produced by
  // CanTransformWhileToFor — if the production message is ever fixed, this
  // literal must change in lockstep. Do not "fix" it here alone.
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Loop increment must greater than zero."));
}

}  // namespace
}  // namespace xla