1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/liveness_util.h" 17 18 #include <memory> 19 20 #include "tensorflow/compiler/xla/service/hlo_module.h" 21 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" 22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 23 24 namespace xla { 25 namespace { 26 27 class PointsToAnalysisTestBase : public HloTestBase { 28 protected: 29 void BuildModule(std::unique_ptr<HloComputation> computation) { 30 module_ = CreateNewModule(); 31 computation_ = module_->AddEntryComputation(std::move(computation)); 32 } 33 34 void RunAnalysis() { 35 CHECK_NOTNULL(module_.get()); 36 points_to_analysis_ = 37 TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); 38 dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); 39 } 40 41 void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) { 42 BuildModule(std::move(computation)); 43 RunAnalysis(); 44 } 45 46 std::unique_ptr<HloModule> module_; 47 HloComputation* computation_ = nullptr; 48 std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_; 49 std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_; 50 }; 51 52 class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {}; 53 54 TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { 55 auto builder = HloComputation::Builder(TestName()); 56 57 Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); 58 auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( 59 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); 60 auto gte0 = builder.AddInstruction( 61 HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); 62 auto gte1 = builder.AddInstruction( 63 HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); 64 builder.AddInstruction( 65 HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); 66 67 BuildModuleAndRunAnalysis(builder.Build()); 68 69 // GetTupleElement instructions only access the top-level buffer of their 70 // operand. 71 EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *points_to_analysis_)); 72 EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_)); 73 EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_)); 74 EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_)); 75 76 EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *dataflow_analysis_)); 77 EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *dataflow_analysis_)); 78 EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *dataflow_analysis_)); 79 EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *dataflow_analysis_)); 80 } 81 82 TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { 83 auto builder = HloComputation::Builder(TestName()); 84 85 Shape data_shape = ShapeUtil::MakeShape(F32, {8}); 86 auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( 87 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); 88 auto gte0 = builder.AddInstruction( 89 HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); 90 auto gte1 = builder.AddInstruction( 91 HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); 92 93 // Create a DynamicUpdateSlice instruction of tuple element 1. 94 auto starts = builder.AddInstruction( 95 HloInstruction::CreateConstant(Literal::CreateR1<int32>({2}))); 96 auto update = builder.AddInstruction(HloInstruction::CreateConstant( 97 Literal::CreateR1<float>({2.f, 2.f, 2.f}))); 98 auto dynamic_update_slice = 99 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( 100 data_shape, gte1, update, starts)); 101 builder.AddInstruction( 102 HloInstruction::CreateTuple({gte0, dynamic_update_slice})); 103 104 BuildModule(builder.Build()); 105 auto fusion = computation_->CreateFusionInstruction( 106 {dynamic_update_slice, starts, update, gte1}, 107 HloInstruction::FusionKind::kLoop); 108 RunAnalysis(); 109 110 // The fusion instruction never uses tuple element 0, but does use element 1. 111 EXPECT_TRUE( 112 DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_)); 113 EXPECT_FALSE( 114 DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_)); 115 116 EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, fusion, *dataflow_analysis_)); 117 EXPECT_FALSE( 118 DoesNotUseOperandBuffer(tuple, {1}, fusion, *dataflow_analysis_)); 119 } 120 121 class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {}; 122 123 TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { 124 auto builder = HloComputation::Builder(TestName()); 125 126 Shape shape = ShapeUtil::MakeShape(F32, {8}); 127 auto param = builder.AddInstruction( 128 HloInstruction::CreateParameter(0, shape, "param")); 129 auto exp = builder.AddInstruction( 130 HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); 131 auto log = builder.AddInstruction( 132 HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); 133 134 BuildModuleAndRunAnalysis(builder.Build()); 135 136 EXPECT_TRUE( 137 CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); 138 EXPECT_TRUE( 139 CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_)); 140 141 EXPECT_TRUE( 142 CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_)); 143 EXPECT_TRUE( 144 CanShareOperandBufferWithUser(exp, {}, log, {}, *dataflow_analysis_)); 145 } 146 147 TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { 148 auto builder = HloComputation::Builder(TestName()); 149 150 Shape in_shape = ShapeUtil::MakeShape(F32, {8}); 151 Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); 152 auto param0 = builder.AddInstruction( 153 HloInstruction::CreateParameter(0, in_shape, "param0")); 154 auto param1 = builder.AddInstruction( 155 HloInstruction::CreateParameter(1, in_shape, "param1")); 156 auto result = builder.AddInstruction( 157 HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); 158 159 BuildModuleAndRunAnalysis(builder.Build()); 160 161 EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, 162 *points_to_analysis_)); 163 EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, 164 *points_to_analysis_)); 165 166 EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, 167 *dataflow_analysis_)); 168 EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, 169 *dataflow_analysis_)); 170 } 171 172 TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { 173 auto builder = HloComputation::Builder(TestName()); 174 175 Shape shape = ShapeUtil::MakeShape(F32, {8}); 176 auto param = builder.AddInstruction( 177 HloInstruction::CreateParameter(0, shape, "param")); 178 auto exp = builder.AddInstruction( 179 HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); 180 auto copy = builder.AddInstruction( 181 HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); 182 183 BuildModuleAndRunAnalysis(builder.Build()); 184 185 EXPECT_TRUE( 186 CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); 187 EXPECT_TRUE( 188 CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_)); 189 190 EXPECT_TRUE( 191 CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_)); 192 EXPECT_TRUE( 193 CanShareOperandBufferWithUser(exp, {}, copy, {}, *dataflow_analysis_)); 194 } 195 196 TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { 197 auto builder = HloComputation::Builder(TestName()); 198 199 Shape data_shape = ShapeUtil::MakeShape(F32, {8}); 200 auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( 201 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); 202 auto gte0 = builder.AddInstruction( 203 HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); 204 auto gte1 = builder.AddInstruction( 205 HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); 206 207 // Create a DynamicUpdateSlice instruction of tuple element 1. 208 auto starts = builder.AddInstruction( 209 HloInstruction::CreateConstant(Literal::CreateR1<int32>({2}))); 210 auto update = builder.AddInstruction(HloInstruction::CreateConstant( 211 Literal::CreateR1<float>({2.f, 2.f, 2.f}))); 212 auto dynamic_update_slice = 213 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( 214 data_shape, gte1, update, starts)); 215 builder.AddInstruction( 216 HloInstruction::CreateTuple({gte0, dynamic_update_slice})); 217 218 BuildModule(builder.Build()); 219 auto fusion = computation_->CreateFusionInstruction( 220 {dynamic_update_slice, starts, update, gte1}, 221 HloInstruction::FusionKind::kLoop); 222 RunAnalysis(); 223 224 // The fusion instruction can share with tuple element 1. 225 EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, 226 *points_to_analysis_)); 227 EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, 228 *points_to_analysis_)); 229 230 EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, 231 *dataflow_analysis_)); 232 EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, 233 *dataflow_analysis_)); 234 } 235 236 TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { 237 auto builder = HloComputation::Builder(TestName()); 238 239 Shape data_shape = ShapeUtil::MakeShape(F32, {8}); 240 Shape update_shape = ShapeUtil::MakeShape(F32, {4}); 241 Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); 242 auto data = builder.AddInstruction( 243 HloInstruction::CreateParameter(0, data_shape, "data")); 244 auto update = builder.AddInstruction( 245 HloInstruction::CreateParameter(1, update_shape, "update")); 246 auto starts = builder.AddInstruction( 247 HloInstruction::CreateParameter(2, starts_shape, "starts")); 248 auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( 249 data_shape, data, update, starts)); 250 251 BuildModuleAndRunAnalysis(builder.Build()); 252 253 // The DynamicUpdateSlice instruction can share with the data operand, but not 254 // with update or starts. 255 EXPECT_TRUE( 256 CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_)); 257 EXPECT_FALSE( 258 CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_)); 259 EXPECT_FALSE( 260 CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_)); 261 262 EXPECT_TRUE( 263 CanShareOperandBufferWithUser(data, {}, dus, {}, *dataflow_analysis_)); 264 EXPECT_FALSE( 265 CanShareOperandBufferWithUser(update, {}, dus, {}, *dataflow_analysis_)); 266 EXPECT_FALSE( 267 CanShareOperandBufferWithUser(starts, {}, dus, {}, *dataflow_analysis_)); 268 } 269 270 TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { 271 auto builder = HloComputation::Builder(TestName()); 272 Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); 273 274 auto a = builder.AddInstruction(HloInstruction::CreateConstant( 275 Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}}))); 276 auto b = builder.AddInstruction(HloInstruction::CreateConstant( 277 Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}}))); 278 279 DotDimensionNumbers dot_dnums; 280 dot_dnums.add_lhs_contracting_dimensions(1); 281 dot_dnums.add_rhs_contracting_dimensions(0); 282 auto dot = builder.AddInstruction( 283 HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); 284 285 auto one = builder.AddInstruction( 286 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 287 auto add_operand = builder.AddInstruction( 288 HloInstruction::CreateBroadcast(data_shape, one, {1})); 289 290 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 291 data_shape, HloOpcode::kAdd, dot, add_operand)); 292 293 BuildModule(builder.Build()); 294 auto fusion = computation_->CreateFusionInstruction( 295 {add, dot}, HloInstruction::FusionKind::kOutput); 296 RunAnalysis(); 297 298 // Output fused dot add should be able to share buffer with 'add_operand'. 299 EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, 300 *points_to_analysis_)); 301 302 EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, 303 *dataflow_analysis_)); 304 } 305 306 TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { 307 auto builder = HloComputation::Builder(TestName()); 308 Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); 309 310 auto a = builder.AddInstruction(HloInstruction::CreateConstant( 311 Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}}))); 312 auto b = builder.AddInstruction(HloInstruction::CreateConstant( 313 Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}}))); 314 auto b_t = builder.AddInstruction( 315 HloInstruction::CreateTranspose(data_shape, b, {1, 0})); 316 317 DotDimensionNumbers dot_dnums; 318 dot_dnums.add_lhs_contracting_dimensions(1); 319 dot_dnums.add_rhs_contracting_dimensions(0); 320 auto dot = builder.AddInstruction( 321 HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); 322 323 auto one = builder.AddInstruction( 324 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 325 auto add_operand = builder.AddInstruction( 326 HloInstruction::CreateBroadcast(data_shape, one, {1})); 327 328 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 329 data_shape, HloOpcode::kAdd, dot, add_operand)); 330 331 BuildModule(builder.Build()); 332 333 auto nested_fusion = computation_->CreateFusionInstruction( 334 {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); 335 336 auto fusion = computation_->CreateFusionInstruction( 337 {add, nested_fusion}, HloInstruction::FusionKind::kOutput); 338 RunAnalysis(); 339 340 // Output fused transpose-dot-add should be share buffer with 'add_operand'. 341 EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, 342 *points_to_analysis_)); 343 344 EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, 345 *dataflow_analysis_)); 346 } 347 348 TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { 349 auto builder = HloComputation::Builder(TestName()); 350 Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); 351 352 auto one = builder.AddInstruction( 353 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 354 auto operand = builder.AddInstruction( 355 HloInstruction::CreateBroadcast(data_shape, one, {1})); 356 357 auto reverse = builder.AddInstruction( 358 HloInstruction::CreateReverse(data_shape, operand, {0, 1})); 359 360 auto two = builder.AddInstruction(HloInstruction::CreateConstant( 361 Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}}))); 362 363 auto add = builder.AddInstruction( 364 HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); 365 366 BuildModule(builder.Build()); 367 auto fusion = computation_->CreateFusionInstruction( 368 {add, two, reverse}, HloInstruction::FusionKind::kOutput); 369 RunAnalysis(); 370 371 // Output fused operand->reverse->add cannot alias operand buffer 'operand'. 372 EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, 373 *points_to_analysis_)); 374 375 EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, 376 *dataflow_analysis_)); 377 } 378 379 TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { 380 Shape data_shape = ShapeUtil::MakeShape(F32, {8}); 381 382 auto make_cond = [this, &data_shape]() { 383 auto builder = HloComputation::Builder(TestName() + ".Cond"); 384 auto data = builder.AddInstruction( 385 HloInstruction::CreateParameter(0, data_shape, "data")); 386 builder.AddInstruction(HloInstruction::CreateBinary( 387 ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); 388 return builder.Build(); 389 }; 390 391 auto make_body = [this, &data_shape]() { 392 auto builder = HloComputation::Builder(TestName() + ".Body"); 393 auto data = builder.AddInstruction( 394 HloInstruction::CreateParameter(0, data_shape, "data")); 395 builder.AddInstruction( 396 HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); 397 return builder.Build(); 398 }; 399 400 module_ = CreateNewModule(); 401 HloComputation* cond_computation = 402 module_->AddEmbeddedComputation(make_cond()); 403 HloComputation* body_computation = 404 module_->AddEmbeddedComputation(make_body()); 405 406 auto builder = HloComputation::Builder(TestName()); 407 auto data = builder.AddInstruction( 408 HloInstruction::CreateParameter(0, data_shape, "data")); 409 auto whil = builder.AddInstruction(HloInstruction::CreateWhile( 410 data_shape, cond_computation, body_computation, data)); 411 computation_ = module_->AddEntryComputation(builder.Build()); 412 413 RunAnalysis(); 414 415 // The While instruction can share with the data operand. 416 EXPECT_TRUE( 417 CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_)); 418 419 EXPECT_TRUE( 420 CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_)); 421 } 422 423 // Tests that Call can alias operand buffer if the only use of the operand 424 // in the called computation is an elementwise instruction. 425 TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { 426 Shape shape = ShapeUtil::MakeShape(F32, {8}); 427 // Build sub-computation with fusion root. 428 auto sub_builder = HloComputation::Builder(TestName() + "_sub"); 429 auto sub_param = sub_builder.AddInstruction( 430 HloInstruction::CreateParameter(0, shape, "sub_param")); 431 auto one = sub_builder.AddInstruction( 432 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); 433 auto ones = sub_builder.AddInstruction( 434 HloInstruction::CreateBroadcast(shape, one, {1})); 435 auto add = sub_builder.AddInstruction( 436 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); 437 438 module_ = CreateNewModule(); 439 auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); 440 sub_computation->CreateFusionInstruction({add, ones}, 441 HloInstruction::FusionKind::kLoop); 442 443 // Build entry-computation with kCall which calls 'sub_computation'. 444 auto builder = HloComputation::Builder(TestName()); 445 446 auto param = builder.AddInstruction( 447 HloInstruction::CreateParameter(0, shape, "param")); 448 auto reverse = 449 builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); 450 auto call = builder.AddInstruction( 451 HloInstruction::CreateCall(shape, {reverse}, sub_computation)); 452 computation_ = module_->AddEntryComputation(builder.Build()); 453 454 RunAnalysis(); 455 456 EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, 457 *points_to_analysis_)); 458 EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, 459 *dataflow_analysis_)); 460 } 461 462 } // namespace 463 } // namespace xla 464