1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_cse.h" 17 18 #include <memory> 19 #include <string> 20 #include <utility> 21 #include <vector> 22 23 #include "tensorflow/compiler/xla/layout_util.h" 24 #include "tensorflow/compiler/xla/literal_util.h" 25 #include "tensorflow/compiler/xla/ptr_util.h" 26 #include "tensorflow/compiler/xla/service/hlo_computation.h" 27 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 28 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 29 #include "tensorflow/compiler/xla/service/hlo_module.h" 30 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 31 #include "tensorflow/compiler/xla/shape_util.h" 32 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 33 #include "tensorflow/compiler/xla/tests/literal_test_util.h" 34 #include "tensorflow/compiler/xla/tests/test_utils.h" 35 #include "tensorflow/compiler/xla/util.h" 36 #include "tensorflow/compiler/xla/xla_data.pb.h" 37 38 #include "tensorflow/compiler/xla/types.h" 39 #include "tensorflow/core/platform/types.h" 40 41 namespace op = xla::testing::opcode_matchers; 42 43 namespace xla { 44 namespace { 45 46 class HloCseTest : public HloTestBase { 47 protected: 48 HloCseTest() {} 49 }; 50 51 TEST_F(HloCseTest, CombineTwoConstants) { 52 // Test that two identical constants are commoned. 53 auto builder = HloComputation::Builder(TestName()); 54 auto constant1 = builder.AddInstruction( 55 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 56 auto constant2 = builder.AddInstruction( 57 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 58 builder.AddInstruction(HloInstruction::CreateBinary( 59 constant1->shape(), HloOpcode::kAdd, constant1, constant2)); 60 61 auto module = CreateNewModule(); 62 auto computation = module->AddEntryComputation(builder.Build()); 63 64 EXPECT_EQ(3, computation->instruction_count()); 65 66 HloCSE cse(/*is_layout_sensitive=*/false); 67 EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); 68 69 EXPECT_EQ(2, computation->instruction_count()); 70 HloInstruction* constant = *computation->instructions().begin(); 71 EXPECT_EQ(42.0f, constant->literal().Get<float>({})); 72 73 auto result = ExecuteAndTransfer(std::move(module), {}); 74 auto expected = Literal::CreateR0<float>(84.0); 75 LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); 76 } 77 78 TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { 79 // Test that two identical constants with different layouts are commoned if 80 // the pass is not layout sensitive. 81 auto builder = HloComputation::Builder(TestName()); 82 auto constant1 = builder.AddInstruction( 83 HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( 84 {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); 85 auto constant2 = builder.AddInstruction( 86 HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( 87 {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); 88 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 89 constant1->shape(), HloOpcode::kAdd, constant1, constant2)); 90 91 auto module = CreateNewModule(); 92 auto computation = module->AddEntryComputation(builder.Build()); 93 94 EXPECT_EQ(3, computation->instruction_count()); 95 EXPECT_THAT(add, op::Add(constant1, constant2)); 96 97 HloCSE cse(/*is_layout_sensitive=*/false); 98 EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); 99 100 EXPECT_EQ(2, computation->instruction_count()); 101 auto first_operand = add->operand(0); 102 EXPECT_THAT(first_operand, ::testing::AnyOf(constant1, constant2)); 103 EXPECT_THAT(add, op::Add(first_operand, first_operand)); 104 105 auto result = ExecuteAndTransfer(std::move(module), {}); 106 auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); 107 LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); 108 } 109 110 TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { 111 // Test that two identical constants with different layouts are *not* commoned 112 // if the pass is layout sensitive. 113 auto builder = HloComputation::Builder(TestName()); 114 auto constant1 = builder.AddInstruction( 115 HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( 116 {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); 117 auto constant2 = builder.AddInstruction( 118 HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>( 119 {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); 120 auto add = builder.AddInstruction(HloInstruction::CreateBinary( 121 constant1->shape(), HloOpcode::kAdd, constant1, constant2)); 122 123 auto module = CreateNewModule(); 124 auto computation = module->AddEntryComputation(builder.Build()); 125 126 EXPECT_EQ(3, computation->instruction_count()); 127 EXPECT_THAT(add, op::Add(constant1, constant2)); 128 129 HloCSE cse(/*is_layout_sensitive=*/true); 130 EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); 131 132 EXPECT_EQ(3, computation->instruction_count()); 133 EXPECT_THAT(add, op::Add(constant1, constant2)); 134 135 auto result = ExecuteAndTransfer(std::move(module), {}); 136 auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}); 137 LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); 138 } 139 140 TEST_F(HloCseTest, ConstantsSameValueDifferentType) { 141 // Test that constants with the same value but different type are *not* 142 // commoned. 143 auto builder = HloComputation::Builder(TestName()); 144 builder.AddInstruction( 145 HloInstruction::CreateConstant(Literal::CreateR0<uint32>(42))); 146 builder.AddInstruction( 147 HloInstruction::CreateConstant(Literal::CreateR0<int32>(42))); 148 builder.AddInstruction( 149 HloInstruction::CreateConstant(Literal::CreateR0<uint64>(42.0))); 150 builder.AddInstruction( 151 HloInstruction::CreateConstant(Literal::CreateR0<int64>(42.0))); 152 builder.AddInstruction( 153 HloInstruction::CreateConstant(Literal::CreateR0<double>(42.0))); 154 builder.AddInstruction( 155 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 156 // Duplicate the float constant to verify something happens. 157 builder.AddInstruction( 158 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 159 160 auto module = CreateNewModule(); 161 auto computation = module->AddEntryComputation(builder.Build()); 162 163 EXPECT_EQ(7, computation->instruction_count()); 164 165 HloCSE cse(/*is_layout_sensitive=*/false); 166 EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); 167 168 EXPECT_EQ(6, computation->instruction_count()); 169 } 170 171 TEST_F(HloCseTest, NonscalarConstants) { 172 // Test that identical nonscalar constants are merged. 173 auto builder = HloComputation::Builder(TestName()); 174 auto common_constant1 = builder.AddInstruction(HloInstruction::CreateConstant( 175 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); 176 auto common_constant2 = builder.AddInstruction(HloInstruction::CreateConstant( 177 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); 178 // Create a constant which has the same shape but a different value. 179 auto uncommon_constant = 180 builder.AddInstruction(HloInstruction::CreateConstant( 181 Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}}))); 182 183 // Tie the constants together with a tuple. This makes it easier to refer to 184 // the constant instructions via their use. 185 auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( 186 {common_constant1, common_constant2, uncommon_constant})); 187 188 auto module = CreateNewModule(); 189 auto computation = module->AddEntryComputation(builder.Build()); 190 191 EXPECT_EQ(4, computation->instruction_count()); 192 EXPECT_THAT(tuple, 193 op::Tuple(common_constant1, common_constant2, uncommon_constant)); 194 195 HloCSE cse(/*is_layout_sensitive=*/false); 196 EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); 197 198 EXPECT_EQ(3, computation->instruction_count()); 199 auto first_operand = tuple->operand(0); 200 EXPECT_THAT(first_operand, 201 ::testing::AnyOf(common_constant1, common_constant2)); 202 EXPECT_THAT(tuple, 203 op::Tuple(first_operand, first_operand, uncommon_constant)); 204 } 205 206 TEST_F(HloCseTest, IdenticalInstructions) { 207 // Test that three identical instructions are commoned. 208 auto builder = HloComputation::Builder(TestName()); 209 auto constant = builder.AddInstruction( 210 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0))); 211 auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( 212 constant->shape(), HloOpcode::kExp, constant)); 213 auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( 214 constant->shape(), HloOpcode::kExp, constant)); 215 auto exp3 = builder.AddInstruction(HloInstruction::CreateUnary( 216 constant->shape(), HloOpcode::kExp, constant)); 217 auto tuple = 218 builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2, exp3})); 219 220 auto module = CreateNewModule(); 221 auto computation = module->AddEntryComputation(builder.Build()); 222 223 EXPECT_EQ(5, computation->instruction_count()); 224 EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3)); 225 226 HloCSE cse(/*is_layout_sensitive=*/false); 227 EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); 228 229 EXPECT_EQ(3, computation->instruction_count()); 230 auto first_operand = tuple->operand(0); 231 EXPECT_THAT(first_operand, ::testing::AnyOf(exp1, exp2, exp3)); 232 EXPECT_THAT(tuple, op::Tuple(first_operand, first_operand, first_operand)); 233 } 234 235 TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { 236 // Test that two identical instructions with different layouts are *not* 237 // commoned if the pass is layout sensitive. 238 auto builder = HloComputation::Builder(TestName()); 239 auto constant = builder.AddInstruction(HloInstruction::CreateConstant( 240 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); 241 242 auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( 243 constant->shape(), HloOpcode::kExp, constant)); 244 *exp1->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); 245 246 auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( 247 constant->shape(), HloOpcode::kExp, constant)); 248 *exp2->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); 249 250 auto tuple = 251 builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); 252 253 auto module = CreateNewModule(); 254 auto computation = module->AddEntryComputation(builder.Build()); 255 256 EXPECT_EQ(4, computation->instruction_count()); 257 EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); 258 259 HloCSE cse(/*is_layout_sensitive=*/true); 260 EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); 261 262 EXPECT_EQ(4, computation->instruction_count()); 263 EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); 264 } 265 266 TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { 267 // Test that two identical instructions with different layouts are commoned if 268 // the pass is layout insensitive. 269 auto builder = HloComputation::Builder(TestName()); 270 auto constant = builder.AddInstruction(HloInstruction::CreateConstant( 271 Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}))); 272 273 auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( 274 constant->shape(), HloOpcode::kExp, constant)); 275 *exp1->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); 276 277 auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( 278 constant->shape(), HloOpcode::kExp, constant)); 279 *exp2->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); 280 281 auto tuple = 282 builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2})); 283 284 auto module = CreateNewModule(); 285 auto computation = module->AddEntryComputation(builder.Build()); 286 287 EXPECT_EQ(4, computation->instruction_count()); 288 EXPECT_THAT(tuple, op::Tuple(exp1, exp2)); 289 290 HloCSE cse(/*is_layout_sensitive=*/false); 291 EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); 292 293 EXPECT_EQ(3, computation->instruction_count()); 294 auto first_operand = tuple->operand(0); 295 EXPECT_THAT(first_operand, ::testing::AnyOf(exp1, exp2)); 296 EXPECT_THAT(tuple, op::Tuple(first_operand, first_operand)); 297 } 298 299 TEST_F(HloCseTest, FusionInternalCSE) { 300 // Test that we can CSE expressions that live within a fusion node 301 // computation. 302 auto module = CreateNewModule(); 303 auto builder = HloComputation::Builder(TestName()); 304 305 const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); 306 auto param0 = builder.AddInstruction( 307 HloInstruction::CreateParameter(0, shape_r0, "p0")); 308 auto param1 = builder.AddInstruction( 309 HloInstruction::CreateParameter(1, shape_r0, "p1")); 310 auto add1 = builder.AddInstruction( 311 HloInstruction::CreateBinary(shape_r0, HloOpcode::kAdd, param0, param1)); 312 auto add2 = builder.AddInstruction( 313 HloInstruction::CreateBinary(shape_r0, HloOpcode::kAdd, param0, param1)); 314 auto mul = builder.AddInstruction( 315 HloInstruction::CreateBinary(shape_r0, HloOpcode::kMultiply, add1, add2)); 316 317 auto computation = module->AddEntryComputation(builder.Build()); 318 auto fused_computation = 319 computation 320 ->CreateFusionInstruction({mul, add1, add2}, 321 HloInstruction::FusionKind::kLoop) 322 ->fused_instructions_computation(); 323 324 EXPECT_EQ(5, fused_computation->instruction_count()); 325 HloCSE cse(/*is_layout_sensitive=*/false); 326 EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); 327 EXPECT_EQ(4, fused_computation->instruction_count()); 328 329 auto root = fused_computation->root_instruction(); 330 EXPECT_THAT(root, op::Multiply(root->operand(0), root->operand(0))); 331 } 332 333 TEST_F(HloCseTest, IdenticalExpressions) { 334 // Test that two identical expressions are commoned. Build the following 335 // computation: 336 // 337 // constant = 42.0 338 // negate1 = neg(constant) 339 // exp1 = exp(constant) 340 // add1 = add(negate1, exp1) 341 // negate2 = neg(constant) 342 // exp2 = exp(constant) 343 // add2 = add(negate2, exp2) 344 // tuple = tuple(add1, add2) 345 // 346 // The *1 instructions should be merged with the *2 instructions. 347 auto builder = HloComputation::Builder(TestName()); 348 auto constant = builder.AddInstruction( 349 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0))); 350 351 auto negate1 = builder.AddInstruction(HloInstruction::CreateUnary( 352 constant->shape(), HloOpcode::kNegate, constant)); 353 auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( 354 constant->shape(), HloOpcode::kExp, constant)); 355 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( 356 constant->shape(), HloOpcode::kAdd, negate1, exp1)); 357 358 auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary( 359 constant->shape(), HloOpcode::kNegate, constant)); 360 auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( 361 constant->shape(), HloOpcode::kExp, constant)); 362 auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( 363 constant->shape(), HloOpcode::kAdd, negate2, exp2)); 364 365 auto tuple = 366 builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); 367 368 auto module = CreateNewModule(); 369 auto computation = module->AddEntryComputation(builder.Build()); 370 371 EXPECT_EQ(8, computation->instruction_count()); 372 EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2))); 373 374 HloCSE cse(/*is_layout_sensitive=*/false); 375 EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); 376 377 EXPECT_EQ(5, computation->instruction_count()); 378 auto operand = tuple->operand(0); 379 EXPECT_THAT(tuple, op::Tuple(operand, operand)); 380 EXPECT_THAT(operand, op::Add(op::Negate(), op::Exp())); 381 } 382 383 TEST_F(HloCseTest, DoNotCombineRng) { 384 // Test that two RNG ops are not commoned. 385 auto builder = HloComputation::Builder(TestName()); 386 auto constant1 = builder.AddInstruction( 387 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); 388 auto constant2 = builder.AddInstruction( 389 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); 390 auto rng1 = builder.AddInstruction(HloInstruction::CreateRng( 391 ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, 392 {constant1, constant2})); 393 auto rng2 = builder.AddInstruction(HloInstruction::CreateRng( 394 ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, 395 {constant1, constant2})); 396 397 builder.AddInstruction(HloInstruction::CreateBinary( 398 constant1->shape(), HloOpcode::kAdd, rng1, rng2)); 399 400 auto module = CreateNewModule(); 401 auto computation = module->AddEntryComputation(builder.Build()); 402 403 HloInstruction* root = computation->root_instruction(); 404 EXPECT_THAT(root, op::Add(rng1, rng2)); 405 406 uint32 count_before = computation->instruction_count(); 407 408 HloCSE cse(/*is_layout_sensitive=*/false); 409 EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); 410 411 uint32 count_after = computation->instruction_count(); 412 EXPECT_EQ(count_before, count_after); 413 root = computation->root_instruction(); 414 EXPECT_THAT(root, op::Add(rng1, rng2)); 415 } 416 417 // TODO(b/28245743): Handle impure functions correctly in CSE. 418 TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { 419 // Test that two calls to an impure function are not commoned. RNG 420 // is the source of the impurity. 421 422 auto module = CreateNewModule(); 423 424 // rng_function is an impure function because it does RNG. 425 HloComputation* rng_function = nullptr; 426 { 427 Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); 428 auto builder = HloComputation::Builder(TestName() + "_rng_fun"); 429 auto constant1 = builder.AddInstruction( 430 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); 431 auto constant2 = builder.AddInstruction( 432 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); 433 auto rng = builder.AddInstruction(HloInstruction::CreateRng( 434 scalar_shape, RandomDistribution::RNG_UNIFORM, {constant1, constant2})); 435 auto param = builder.AddInstruction(HloInstruction::CreateParameter( 436 0, ShapeUtil::MakeShape(F32, {}), "param")); 437 builder.AddInstruction(HloInstruction::CreateBinary( 438 scalar_shape, HloOpcode::kAdd, rng, param)); 439 rng_function = module->AddEmbeddedComputation(builder.Build()); 440 } 441 442 // Computation calls rng_function twice with the same parameter. 443 HloComputation* computation = nullptr; 444 { 445 auto builder = HloComputation::Builder(TestName()); 446 auto constant = builder.AddInstruction( 447 HloInstruction::CreateConstant(Literal::CreateR1<float>({5.0f}))); 448 auto rng1 = builder.AddInstruction( 449 HloInstruction::CreateMap(constant->shape(), {constant}, rng_function)); 450 auto rng2 = builder.AddInstruction( 451 HloInstruction::CreateMap(constant->shape(), {constant}, rng_function)); 452 builder.AddInstruction(HloInstruction::CreateBinary( 453 constant->shape(), HloOpcode::kAdd, rng1, rng2)); 454 computation = module->AddEntryComputation(builder.Build()); 455 } 456 457 EXPECT_EQ(4, computation->instruction_count()); 458 HloInstruction* root = computation->root_instruction(); 459 EXPECT_THAT(root, op::Add(op::Map(), op::Map())); 460 461 HloCSE cse(/*is_layout_sensitive=*/false); 462 EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); 463 464 EXPECT_EQ(4, computation->instruction_count()); 465 root = computation->root_instruction(); 466 auto operand = root->operand(0)->operand(0); 467 EXPECT_THAT(operand, op::Map()); 468 EXPECT_THAT(root, op::Add(operand, operand)); 469 } 470 471 } // namespace 472 } // namespace xla 473