1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_ordering.h" 17 18 #include <memory> 19 #include <string> 20 21 #include "tensorflow/compiler/xla/service/hlo_computation.h" 22 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" 23 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 24 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 25 #include "tensorflow/compiler/xla/service/hlo_scheduling.h" 26 #include "tensorflow/compiler/xla/shape_util.h" 27 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 28 #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" 29 #include "tensorflow/compiler/xla/types.h" 30 #include "tensorflow/compiler/xla/xla_data.pb.h" 31 32 namespace xla { 33 namespace { 34 35 class HloOrderingTest : public HloTestBase {}; 36 37 TEST_F(HloOrderingTest, LastUseScheduledFirst) { 38 // Tests scheduling of the following HLO code: 39 // 40 // %ab = abs(%param) 41 // %exp = exp(%param) 42 // %add = add(%ab, %exp) 43 // %negate = negate(%exp) 44 // %sub = subtract(%add, %negate) 45 // 46 // %add should be scheduled before %negate because %add is the last (and only) 47 // use of %ab. Scheduling %add first then frees up %ab's buffer. 48 const Shape vec = ShapeUtil::MakeShape(xla::F32, {42}); 49 auto builder = HloComputation::Builder(TestName()); 50 auto param = 51 builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param")); 52 auto ab = builder.AddInstruction( 53 HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param)); 54 auto exp = builder.AddInstruction( 55 HloInstruction::CreateUnary(vec, HloOpcode::kExp, param)); 56 57 auto add = builder.AddInstruction( 58 HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp)); 59 auto negate = builder.AddInstruction( 60 HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp)); 61 auto sub = builder.AddInstruction( 62 HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); 63 64 auto module = CreateNewModule(); 65 module->AddEntryComputation(builder.Build()); 66 67 TF_ASSERT_OK_AND_ASSIGN( 68 SequentialHloOrdering::HloModuleSequence sequence, 69 CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) { 70 return ShapeUtil::ByteSizeOf(buffer.shape()); 71 })); 72 // Verify that all instructions are in the sequence. 73 EXPECT_EQ(module->entry_computation()->instruction_count(), 74 sequence.at(module->entry_computation()).size()); 75 76 // The first instruction should be the parameter and the last the root "sub". 77 EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); 78 EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); 79 80 SequentialHloOrdering ordering(module.get(), sequence); 81 EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); 82 } 83 84 TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { 85 // Tests the ordering of instructions in different computations using the 86 // following HLO code: 87 // 88 // Entry computation: 89 // %x = Call(A, {}) 90 // %y = Call(B, {%x}) 91 // 92 // Computation A: 93 // %a = Call(C, {}) 94 // 95 // Computation B: 96 // %b = Call(C, {}) 97 // 98 // Computation C: 99 // %c = Constant(42.0f) 100 // 101 // This results in a diamond-shaped callgraph. 102 auto module = CreateNewModule(); 103 const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); 104 105 auto builder_c = HloComputation::Builder("C"); 106 HloInstruction* c = builder_c.AddInstruction( 107 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 108 HloComputation* computation_c = 109 module->AddEmbeddedComputation(builder_c.Build()); 110 111 auto builder_b = HloComputation::Builder("B"); 112 builder_b.AddInstruction( 113 HloInstruction::CreateParameter(0, scalar_shape, "param")); 114 HloInstruction* b = builder_b.AddInstruction( 115 HloInstruction::CreateCall(scalar_shape, {}, computation_c)); 116 HloComputation* computation_b = 117 module->AddEmbeddedComputation(builder_b.Build()); 118 119 auto builder_a = HloComputation::Builder("A"); 120 HloInstruction* a = builder_a.AddInstruction( 121 HloInstruction::CreateCall(scalar_shape, {}, computation_c)); 122 HloComputation* computation_a = 123 module->AddEmbeddedComputation(builder_a.Build()); 124 125 auto builder = HloComputation::Builder(TestName()); 126 HloInstruction* x = builder.AddInstruction( 127 HloInstruction::CreateCall(scalar_shape, {}, computation_a)); 128 HloInstruction* y = builder.AddInstruction( 129 HloInstruction::CreateCall(scalar_shape, {x}, computation_b)); 130 module->AddEntryComputation(builder.Build()); 131 132 DependencyHloOrdering ordering(module.get()); 133 EXPECT_TRUE(ordering.ExecutesBefore(x, y)); 134 EXPECT_FALSE(ordering.ExecutesBefore(y, x)); 135 136 EXPECT_TRUE(ordering.ExecutesBefore(a, b)); 137 EXPECT_FALSE(ordering.ExecutesBefore(b, a)); 138 139 EXPECT_FALSE(ordering.ExecutesBefore(a, x)); 140 EXPECT_TRUE(ordering.ExecutesBefore(a, y)); 141 EXPECT_FALSE(ordering.ExecutesBefore(x, a)); 142 EXPECT_FALSE(ordering.ExecutesBefore(y, a)); 143 144 EXPECT_FALSE(ordering.ExecutesBefore(b, x)); 145 EXPECT_FALSE(ordering.ExecutesBefore(b, y)); 146 EXPECT_TRUE(ordering.ExecutesBefore(x, b)); 147 EXPECT_FALSE(ordering.ExecutesBefore(y, b)); 148 149 // Instruction 'c' is called from multiple callsites and should be unordered 150 // relative to all other instructions in the module. 151 EXPECT_FALSE(ordering.ExecutesBefore(c, a)); 152 EXPECT_FALSE(ordering.ExecutesBefore(c, b)); 153 EXPECT_FALSE(ordering.ExecutesBefore(c, x)); 154 EXPECT_FALSE(ordering.ExecutesBefore(c, y)); 155 EXPECT_FALSE(ordering.ExecutesBefore(a, c)); 156 EXPECT_FALSE(ordering.ExecutesBefore(b, c)); 157 EXPECT_FALSE(ordering.ExecutesBefore(x, c)); 158 EXPECT_FALSE(ordering.ExecutesBefore(y, c)); 159 } 160 161 TEST_F(HloOrderingTest, InstructionsInWhileComputations) { 162 // Tests the ordering of instructions in the body and condition of a while 163 // instruction. HLO code: 164 // 165 // body(F32[]) %param): 166 // %negate = Negate(%param) 167 // 168 // condition(F32[] %param): 169 // %convert = Convert<PRED>(%param) 170 // 171 // entry: 172 // %constant = Constant(1.0) 173 // return While(%constant, body, condition) 174 // 175 auto module = CreateNewModule(); 176 const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); 177 178 auto body_builder = HloComputation::Builder("body"); 179 auto body_param = body_builder.AddInstruction( 180 HloInstruction::CreateParameter(0, scalar_shape, "body_param")); 181 auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( 182 scalar_shape, HloOpcode::kNegate, body_param)); 183 HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); 184 185 auto cond_builder = HloComputation::Builder("condition"); 186 auto cond_param = cond_builder.AddInstruction( 187 HloInstruction::CreateParameter(0, scalar_shape, "cond_param")); 188 auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert( 189 ShapeUtil::MakeShape(xla::PRED, {}), cond_param)); 190 HloComputation* condition = 191 module->AddEmbeddedComputation(cond_builder.Build()); 192 193 auto builder = HloComputation::Builder(TestName()); 194 auto constant = builder.AddInstruction( 195 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 196 auto xla_while = builder.AddInstruction( 197 HloInstruction::CreateWhile(scalar_shape, condition, body, constant)); 198 module->AddEntryComputation(builder.Build()); 199 200 DependencyHloOrdering ordering(module.get()); 201 EXPECT_TRUE(ordering.ExecutesBefore(constant, xla_while)); 202 EXPECT_TRUE(ordering.ExecutesBefore(constant, cond_param)); 203 EXPECT_TRUE(ordering.ExecutesBefore(constant, convert)); 204 EXPECT_TRUE(ordering.ExecutesBefore(constant, body_param)); 205 EXPECT_TRUE(ordering.ExecutesBefore(constant, negate)); 206 207 // The while should be unordered relative to the body and condition 208 // instructions. 209 EXPECT_FALSE(ordering.ExecutesBefore(xla_while, body_param)); 210 EXPECT_FALSE(ordering.ExecutesBefore(xla_while, cond_param)); 211 EXPECT_FALSE(ordering.ExecutesBefore(body_param, xla_while)); 212 EXPECT_FALSE(ordering.ExecutesBefore(cond_param, xla_while)); 213 214 // Condition instructions should be ordered before body instructions. 215 EXPECT_TRUE(ordering.ExecutesBefore(cond_param, body_param)); 216 EXPECT_TRUE(ordering.ExecutesBefore(convert, body_param)); 217 EXPECT_TRUE(ordering.ExecutesBefore(cond_param, negate)); 218 EXPECT_TRUE(ordering.ExecutesBefore(convert, negate)); 219 220 EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param)); 221 } 222 223 TEST_F(HloOrderingTest, ValuesInWhileComputations) { 224 // Tests the ordering of values (defined by dataflow analysis) in the body and 225 // condition of a while instruction. HLO code: 226 // 227 // body(F32[]) %param): 228 // %negate = Negate(%param) 229 // 230 // condition(F32[] %param): 231 // %convert = Convert<PRED>(%param) 232 // 233 // entry: 234 // %constant = Constant(1.0) 235 // %while = While(%constant, body, condition) 236 // %add = Add(%constant, %while) 237 // 238 auto module = CreateNewModule(); 239 const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); 240 241 auto body_builder = HloComputation::Builder("body"); 242 auto body_param = body_builder.AddInstruction( 243 HloInstruction::CreateParameter(0, scalar_shape, "body_param")); 244 auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( 245 scalar_shape, HloOpcode::kNegate, body_param)); 246 HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); 247 248 auto cond_builder = HloComputation::Builder("condition"); 249 auto cond_param = cond_builder.AddInstruction( 250 HloInstruction::CreateParameter(0, scalar_shape, "cond_param")); 251 auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert( 252 ShapeUtil::MakeShape(xla::PRED, {}), cond_param)); 253 HloComputation* condition = 254 module->AddEmbeddedComputation(cond_builder.Build()); 255 256 auto builder = HloComputation::Builder(TestName()); 257 auto constant = builder.AddInstruction( 258 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 259 auto xla_while = builder.AddInstruction( 260 HloInstruction::CreateWhile(scalar_shape, condition, body, constant)); 261 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 262 scalar_shape, HloOpcode::kAdd, constant, xla_while)); 263 module->AddEntryComputation(builder.Build()); 264 265 TF_ASSERT_OK_AND_ASSIGN(auto dataflow, 266 HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); 267 DependencyHloOrdering ordering(module.get()); 268 269 // Init value is defined before the while, but live range is not before the 270 // while because of the use of the init value in the add. 271 EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant), 272 dataflow->GetValueDefinedAt(xla_while))); 273 EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( 274 dataflow->GetValueDefinedAt(constant), 275 dataflow->GetValueDefinedAt(xla_while), *dataflow)); 276 EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant), 277 dataflow->GetValueDefinedAt(xla_while), 278 *dataflow)); 279 280 // Any value defined in the body or condition is defined before the while, and 281 // has a live range strictly before the while. 282 EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(negate), 283 dataflow->GetValueDefinedAt(xla_while))); 284 EXPECT_TRUE(ordering.LiveRangeStrictlyBefore( 285 dataflow->GetValueDefinedAt(negate), 286 dataflow->GetValueDefinedAt(xla_while), *dataflow)); 287 EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate), 288 dataflow->GetValueDefinedAt(xla_while), 289 *dataflow)); 290 291 EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert), 292 dataflow->GetValueDefinedAt(xla_while))); 293 EXPECT_TRUE(ordering.LiveRangeStrictlyBefore( 294 dataflow->GetValueDefinedAt(convert), 295 dataflow->GetValueDefinedAt(xla_while), *dataflow)); 296 EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(convert), 297 dataflow->GetValueDefinedAt(xla_while), 298 *dataflow)); 299 300 // The live range of the while should be before the add. 301 EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(xla_while), 302 dataflow->GetValueDefinedAt(add))); 303 ASSERT_EQ(dataflow->GetValueDefinedAt(xla_while).uses().size(), 1); 304 305 const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0]; 306 EXPECT_EQ(while_use.instruction, add); 307 EXPECT_TRUE(ordering.UseIsBeforeValueDefinition( 308 while_use, dataflow->GetValueDefinedAt(add), *dataflow)); 309 EXPECT_TRUE(ordering.LiveRangeStrictlyBefore( 310 dataflow->GetValueDefinedAt(xla_while), dataflow->GetValueDefinedAt(add), 311 *dataflow)); 312 } 313 314 // Regression test for HloOrdering::ToString() crashing when fed a computation 315 // containing a fusion node. 316 TEST_F(HloOrderingTest, ToStringDoesNotCrash) { 317 const char* module_str = R"( 318 HloModule test_module 319 320 body.v8 { 321 prev.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0) 322 get-tuple-element.4 = s32[] get-tuple-element(prev.1), index=0 323 constant.1 = s32[] constant(1) 324 add = s32[] add(get-tuple-element.4, constant.1) 325 get-tuple-element.5 = f32[3]{0} get-tuple-element(prev.1), index=3 326 get-tuple-element.6 = f32[3]{0} get-tuple-element(prev.1), index=1 327 get-tuple-element.7 = f32[3]{0} get-tuple-element(prev.1), index=2 328 ROOT tuple = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(add, get-tuple-element.5, get-tuple-element.6, get-tuple-element.7) 329 } 330 331 condition.v4 { 332 constant.2 = s32[] constant(2) 333 prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0) 334 get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0 335 ROOT greater-than = pred[] greater-than(constant.2, get-tuple-element.8) 336 } 337 338 fused_computation { 339 get-tuple-element.5.param_1 = f32[3]{0} parameter(1) 340 get-tuple-element.6.param_2 = f32[3]{0} parameter(2) 341 add.4 = f32[3]{0} add(get-tuple-element.5.param_1, get-tuple-element.6.param_2) 342 get-tuple-element.7.param_1.1 = f32[3]{0} parameter(0) 343 ROOT add.5 = f32[3]{0} add(add.4, get-tuple-element.7.param_1.1) 344 } 345 346 ENTRY while.v11 { 347 constant.5 = s32[] constant(0) 348 constant.6 = f32[3]{0} constant({1, 1, 1}) 349 constant.7 = f32[3]{0} constant({2, 2, 2}) 350 constant.8 = f32[3]{0} constant({3, 3, 3}) 351 tuple.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(constant.5, constant.6, constant.7, constant.8) 352 while = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) while(tuple.1), condition=condition.v4, body=body.v8 353 get-tuple-element.9 = f32[3]{0} get-tuple-element(while), index=3 354 get-tuple-element.10 = f32[3]{0} get-tuple-element(while), index=1 355 get-tuple-element.11 = f32[3]{0} get-tuple-element(while), index=2 356 ROOT fusion = f32[3]{0} fusion(get-tuple-element.9, get-tuple-element.10, get-tuple-element.11), kind=kLoop, calls=fused_computation 357 })"; 358 359 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, 360 tools::Parse(module_str)); 361 DependencyHloOrdering ordering(module.get()); 362 ordering.ToString(); // Shouldn't crash. 363 } 364 365 } // namespace 366 } // namespace xla 367