1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" 17 18 #include <memory> 19 #include <utility> 20 21 #include "tensorflow/compiler/xla/layout_util.h" 22 #include "tensorflow/compiler/xla/literal_util.h" 23 #include "tensorflow/compiler/xla/ptr_util.h" 24 #include "tensorflow/compiler/xla/service/hlo_computation.h" 25 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 26 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 27 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 28 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" 29 #include "tensorflow/compiler/xla/shape_util.h" 30 #include "tensorflow/compiler/xla/test.h" 31 #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" 32 #include "tensorflow/compiler/xla/types.h" 33 #include "tensorflow/compiler/xla/window_util.h" 34 #include "tensorflow/compiler/xla/xla_data.pb.h" 35 #include "tensorflow/core/lib/core/status_test_util.h" 36 #include "tensorflow/core/lib/strings/str_util.h" 37 38 namespace xla { 39 namespace { 40 41 namespace op = xla::testing::opcode_matchers; 42 43 AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { 44 return [](const Shape&, const Shape&) { return true; }; 45 } 46 47 AlgebraicSimplifier::ValidBitcastCallback 
non_bitcasting_callback() { 48 return [](const Shape&, const Shape&) { return false; }; 49 } 50 51 class AlgebraicSimplifierTest : public HloVerifiedTestBase {}; 52 53 // Test that A + 0 is simplified to A 54 TEST_F(AlgebraicSimplifierTest, AddZero) { 55 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 56 HloComputation::Builder builder(TestName()); 57 HloInstruction* param0 = builder.AddInstruction( 58 HloInstruction::CreateParameter(0, r0f32, "param0")); 59 HloInstruction* zero = builder.AddInstruction( 60 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); 61 builder.AddInstruction( 62 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); 63 64 auto computation = module().AddEntryComputation(builder.Build()); 65 HloInstruction* root = computation->root_instruction(); 66 EXPECT_EQ(root->opcode(), HloOpcode::kAdd); 67 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 68 non_bitcasting_callback()); 69 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 70 root = computation->root_instruction(); 71 EXPECT_EQ(root, param0); 72 } 73 74 // Test that Const + A is canonicalized to A + Const. 
75 TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { 76 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 77 HloComputation::Builder builder(TestName()); 78 HloInstruction* param0 = builder.AddInstruction( 79 HloInstruction::CreateParameter(0, r0f32, "param0")); 80 HloInstruction* constant = builder.AddInstruction( 81 HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); 82 builder.AddInstruction( 83 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0)); 84 85 auto computation = module().AddEntryComputation(builder.Build()); 86 HloInstruction* root = computation->root_instruction(); 87 EXPECT_EQ(root->opcode(), HloOpcode::kAdd); 88 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 89 non_bitcasting_callback()); 90 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 91 root = computation->root_instruction(); 92 EXPECT_THAT(root, op::Add(param0, op::Constant())); 93 } 94 95 // Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2. 96 TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { 97 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 98 HloComputation::Builder builder(TestName()); 99 HloInstruction* param0 = builder.AddInstruction( 100 HloInstruction::CreateParameter(0, r0f32, "param0")); 101 HloInstruction* constant1 = builder.AddInstruction( 102 HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); 103 HloInstruction* constant2 = builder.AddInstruction( 104 HloInstruction::CreateConstant(Literal::CreateR0(3.14159f))); 105 106 HloInstruction* add1 = builder.AddInstruction( 107 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1)); 108 builder.AddInstruction( 109 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2)); 110 111 auto computation = module().AddEntryComputation(builder.Build()); 112 HloInstruction* root = computation->root_instruction(); 113 EXPECT_EQ(root->opcode(), HloOpcode::kAdd); 114 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 115 
non_bitcasting_callback()); 116 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 117 root = computation->root_instruction(); 118 EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2))); 119 } 120 121 TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { 122 Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); 123 HloComputation::Builder builder(TestName()); 124 HloInstruction* param0 = builder.AddInstruction( 125 HloInstruction::CreateParameter(0, r2f32, "param0")); 126 HloInstruction* zero = builder.AddInstruction( 127 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); 128 HloInstruction* bcast = builder.AddInstruction( 129 HloInstruction::CreateBroadcast(r2f32, zero, {0, 1})); 130 builder.AddInstruction( 131 HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); 132 133 auto computation = module().AddEntryComputation(builder.Build()); 134 HloInstruction* root = computation->root_instruction(); 135 EXPECT_EQ(root->opcode(), HloOpcode::kAdd); 136 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 137 non_bitcasting_callback()); 138 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 139 root = computation->root_instruction(); 140 EXPECT_EQ(root, param0); 141 } 142 143 TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { 144 Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); 145 HloComputation::Builder builder(TestName()); 146 HloInstruction* param0 = builder.AddInstruction( 147 HloInstruction::CreateParameter(0, r2f32, "param0")); 148 HloInstruction* zero = builder.AddInstruction( 149 HloInstruction::CreateConstant(Literal::CreateR1<float>({0, 0, 0}))); 150 HloInstruction* bcast = 151 builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1})); 152 builder.AddInstruction( 153 HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); 154 155 auto computation = module().AddEntryComputation(builder.Build()); 156 HloInstruction* root = computation->root_instruction(); 157 
EXPECT_EQ(root->opcode(), HloOpcode::kAdd); 158 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 159 non_bitcasting_callback()); 160 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 161 root = computation->root_instruction(); 162 EXPECT_EQ(root, param0); 163 } 164 165 // Test that A - 0 is simplified to A 166 TEST_F(AlgebraicSimplifierTest, SubZero) { 167 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 168 HloComputation::Builder builder(TestName()); 169 HloInstruction* param0 = builder.AddInstruction( 170 HloInstruction::CreateParameter(0, r0f32, "param0")); 171 HloInstruction* zero = builder.AddInstruction( 172 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); 173 builder.AddInstruction( 174 HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); 175 176 auto computation = module().AddEntryComputation(builder.Build()); 177 HloInstruction* root = computation->root_instruction(); 178 EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); 179 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 180 non_bitcasting_callback()); 181 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 182 root = computation->root_instruction(); 183 EXPECT_EQ(root, param0); 184 } 185 186 // Test that A - Const is canonicalized to A + (-Const). 
187 TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { 188 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 189 HloComputation::Builder builder(TestName()); 190 HloInstruction* param0 = builder.AddInstruction( 191 HloInstruction::CreateParameter(0, r0f32, "param0")); 192 HloInstruction* constant = builder.AddInstruction( 193 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 194 builder.AddInstruction(HloInstruction::CreateBinary( 195 r0f32, HloOpcode::kSubtract, param0, constant)); 196 197 auto computation = module().AddEntryComputation(builder.Build()); 198 HloInstruction* root = computation->root_instruction(); 199 EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); 200 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 201 non_bitcasting_callback()); 202 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 203 root = computation->root_instruction(); 204 EXPECT_THAT(root, op::Add(param0, op::Negate(constant))); 205 } 206 207 // Test that (A/B)/C is simplified to A/(B*C). 
208 TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { 209 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 210 HloComputation::Builder builder(TestName()); 211 HloInstruction* param0 = builder.AddInstruction( 212 HloInstruction::CreateParameter(0, r0f32, "param0")); 213 HloInstruction* param1 = builder.AddInstruction( 214 HloInstruction::CreateParameter(1, r0f32, "param1")); 215 HloInstruction* param2 = builder.AddInstruction( 216 HloInstruction::CreateParameter(2, r0f32, "param2")); 217 HloInstruction* div = builder.AddInstruction( 218 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, param1)); 219 builder.AddInstruction( 220 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2)); 221 222 auto computation = module().AddEntryComputation(builder.Build()); 223 224 EXPECT_THAT(computation->root_instruction(), 225 op::Divide(op::Divide(param0, param1), param2)); 226 227 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 228 non_bitcasting_callback()); 229 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 230 231 EXPECT_THAT(computation->root_instruction(), 232 op::Divide(param0, op::Multiply(param1, param2))); 233 } 234 235 // Test that A/(B/C) is simplified to (A*C)/B. 
236 TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { 237 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 238 HloComputation::Builder builder(TestName()); 239 HloInstruction* param0 = builder.AddInstruction( 240 HloInstruction::CreateParameter(0, r0f32, "param0")); 241 HloInstruction* param1 = builder.AddInstruction( 242 HloInstruction::CreateParameter(1, r0f32, "param1")); 243 HloInstruction* param2 = builder.AddInstruction( 244 HloInstruction::CreateParameter(2, r0f32, "param2")); 245 HloInstruction* div = builder.AddInstruction( 246 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param1, param2)); 247 builder.AddInstruction( 248 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div)); 249 250 auto computation = module().AddEntryComputation(builder.Build()); 251 252 EXPECT_THAT(computation->root_instruction(), 253 op::Divide(param0, op::Divide(param1, param2))); 254 255 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 256 non_bitcasting_callback()); 257 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 258 259 EXPECT_THAT(computation->root_instruction(), 260 op::Divide(op::Multiply(param0, param2), param1)); 261 } 262 263 // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). 
264 TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { 265 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 266 Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123}); 267 HloComputation::Builder builder(TestName()); 268 HloInstruction* param0 = builder.AddInstruction( 269 HloInstruction::CreateParameter(0, r0f32, "param0")); 270 HloInstruction* param1 = builder.AddInstruction( 271 HloInstruction::CreateParameter(1, r2f32, "param1")); 272 HloInstruction* param2 = builder.AddInstruction( 273 HloInstruction::CreateParameter(2, r2f32, "param2")); 274 HloInstruction* param3 = builder.AddInstruction( 275 HloInstruction::CreateParameter(3, r0f32, "param3")); 276 HloInstruction* div0 = builder.AddInstruction( 277 HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, param1)); 278 HloInstruction* div1 = builder.AddInstruction( 279 HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param2, param3)); 280 builder.AddInstruction( 281 HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1)); 282 283 auto computation = module().AddEntryComputation(builder.Build()); 284 285 EXPECT_THAT( 286 computation->root_instruction(), 287 op::Divide(op::Divide(param0, param1), op::Divide(param2, param3))); 288 289 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 290 non_bitcasting_callback()); 291 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 292 293 EXPECT_THAT( 294 computation->root_instruction(), 295 op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2))); 296 EXPECT_TRUE( 297 ShapeUtil::Compatible(computation->root_instruction()->shape(), r2f32)); 298 } 299 300 // Test that A/exp(B) is simplified to A*exp(-B). 
301 TEST_F(AlgebraicSimplifierTest, DivOfExp) { 302 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 303 HloComputation::Builder builder(TestName()); 304 HloInstruction* param0 = builder.AddInstruction( 305 HloInstruction::CreateParameter(0, r0f32, "param0")); 306 HloInstruction* param1 = builder.AddInstruction( 307 HloInstruction::CreateParameter(1, r0f32, "param1")); 308 HloInstruction* exp = builder.AddInstruction( 309 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1)); 310 builder.AddInstruction( 311 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp)); 312 313 auto computation = module().AddEntryComputation(builder.Build()); 314 315 EXPECT_THAT(computation->root_instruction(), 316 op::Divide(param0, op::Exp(param1))); 317 318 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 319 non_bitcasting_callback()); 320 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 321 322 EXPECT_THAT(computation->root_instruction(), 323 op::Multiply(param0, op::Exp(op::Negate(param1)))); 324 } 325 326 // Test that A/pow(B,C) is simplified to A*pow(B,-C). 
327 TEST_F(AlgebraicSimplifierTest, DivOfPower) { 328 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 329 HloComputation::Builder builder(TestName()); 330 HloInstruction* param0 = builder.AddInstruction( 331 HloInstruction::CreateParameter(0, r0f32, "param0")); 332 HloInstruction* param1 = builder.AddInstruction( 333 HloInstruction::CreateParameter(1, r0f32, "param1")); 334 HloInstruction* param2 = builder.AddInstruction( 335 HloInstruction::CreateParameter(2, r0f32, "param2")); 336 HloInstruction* power = builder.AddInstruction( 337 HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param1, param2)); 338 builder.AddInstruction( 339 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power)); 340 341 auto computation = module().AddEntryComputation(builder.Build()); 342 343 EXPECT_THAT(computation->root_instruction(), 344 op::Divide(param0, op::Power(param1, param2))); 345 346 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 347 non_bitcasting_callback()); 348 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 349 350 EXPECT_THAT(computation->root_instruction(), 351 op::Multiply(param0, op::Power(param1, op::Negate(param2)))); 352 } 353 354 // Test that broadcasting is done on the right step when simplifying A/pow(B,C) 355 // to A*pow(B,-C). 
356 TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { 357 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 358 Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); 359 HloComputation::Builder builder(TestName()); 360 HloInstruction* param0 = builder.AddInstruction( 361 HloInstruction::CreateParameter(0, r1f32, "param0")); 362 HloInstruction* param1 = builder.AddInstruction( 363 HloInstruction::CreateParameter(1, r1f32, "param1")); 364 HloInstruction* param2 = builder.AddInstruction( 365 HloInstruction::CreateParameter(2, r0f32, "param2")); 366 HloInstruction* power = builder.AddInstruction( 367 HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param1, param2)); 368 builder.AddInstruction( 369 HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power)); 370 371 auto computation = module().AddEntryComputation(builder.Build()); 372 373 EXPECT_THAT(computation->root_instruction(), 374 op::Divide(param0, op::Power(param1, param2))); 375 376 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 377 non_bitcasting_callback()); 378 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 379 380 ASSERT_THAT(computation->root_instruction(), 381 op::Multiply(param0, op::Power(param1, op::Negate(param2)))); 382 383 const HloInstruction* negate = 384 computation->root_instruction()->operand(1)->operand(1); 385 const Shape& negate_shape = negate->shape(); 386 EXPECT_EQ(0, negate_shape.dimensions_size()); 387 } 388 389 // A / Const => A * (1 / Const) 390 TEST_F(AlgebraicSimplifierTest, DivideByConstant) { 391 Shape r1f32 = ShapeUtil::MakeShape(F32, {3}); 392 HloComputation::Builder builder(TestName()); 393 HloInstruction* param0 = builder.AddInstruction( 394 HloInstruction::CreateParameter(0, r1f32, "param0")); 395 HloInstruction* constant = 396 builder.AddInstruction(HloInstruction::CreateConstant( 397 Literal::CreateR1<float>({0.f, 1.f, 2.f}))); 398 builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, 399 param0, constant)); 400 401 
auto computation = module().AddEntryComputation(builder.Build()); 402 403 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 404 non_bitcasting_callback()); 405 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 406 407 EXPECT_THAT(computation->root_instruction(), 408 op::Multiply(param0, op::Divide(op::Constant(), constant))); 409 } 410 411 // pow(pow(A, X), Y) => pow(A, X*Y) 412 TEST_F(AlgebraicSimplifierTest, PowerOfPower) { 413 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 414 Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); 415 HloComputation::Builder builder(TestName()); 416 HloInstruction* base = builder.AddInstruction( 417 HloInstruction::CreateParameter(0, r1f32, "param0")); 418 HloInstruction* exp1 = builder.AddInstruction( 419 HloInstruction::CreateParameter(1, r0f32, "param1")); 420 HloInstruction* exp2 = builder.AddInstruction( 421 HloInstruction::CreateParameter(2, r0f32, "param2")); 422 HloInstruction* inner_power = builder.AddInstruction( 423 HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1)); 424 builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, 425 inner_power, exp2)); 426 427 auto computation = module().AddEntryComputation(builder.Build()); 428 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 429 non_bitcasting_callback()); 430 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 431 EXPECT_THAT(computation->root_instruction(), 432 op::Power(base, op::Multiply(exp1, exp2))); 433 } 434 435 // Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex 436 // numbers. 
437 TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { 438 Shape r0c64 = ShapeUtil::MakeShape(C64, {}); 439 Shape r1c64 = ShapeUtil::MakeShape(C64, {7}); 440 HloComputation::Builder builder(TestName()); 441 HloInstruction* base = builder.AddInstruction( 442 HloInstruction::CreateParameter(0, r1c64, "param0")); 443 HloInstruction* exp1 = builder.AddInstruction( 444 HloInstruction::CreateParameter(1, r0c64, "param1")); 445 HloInstruction* exp2 = builder.AddInstruction( 446 HloInstruction::CreateParameter(2, r0c64, "param2")); 447 HloInstruction* inner_power = builder.AddInstruction( 448 HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1)); 449 builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, 450 inner_power, exp2)); 451 452 module().AddEntryComputation(builder.Build()); 453 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 454 non_bitcasting_callback()); 455 ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); 456 } 457 458 // Test that A/1 is simplified to A for a scalar. 
459 TEST_F(AlgebraicSimplifierTest, DivOneScalar) { 460 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 461 HloComputation::Builder builder(TestName()); 462 HloInstruction* param0 = builder.AddInstruction( 463 HloInstruction::CreateParameter(0, r0f32, "param0")); 464 HloInstruction* one = builder.AddInstruction( 465 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); 466 HloInstruction* div = builder.AddInstruction( 467 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); 468 469 auto computation = module().AddEntryComputation(builder.Build()); 470 HloInstruction* root = computation->root_instruction(); 471 EXPECT_EQ(root, div); 472 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 473 non_bitcasting_callback()); 474 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 475 root = computation->root_instruction(); 476 EXPECT_EQ(root, param0); 477 } 478 479 // Test that A/1 is simplified to A for an array. 480 TEST_F(AlgebraicSimplifierTest, DivOneArray) { 481 Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); 482 HloComputation::Builder builder(TestName()); 483 HloInstruction* param0 = builder.AddInstruction( 484 HloInstruction::CreateParameter(0, r2f32, "param0")); 485 HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant( 486 Literal::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}}))); 487 HloInstruction* div = builder.AddInstruction( 488 HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); 489 490 auto computation = module().AddEntryComputation(builder.Build()); 491 HloInstruction* root = computation->root_instruction(); 492 EXPECT_EQ(root, div); 493 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 494 non_bitcasting_callback()); 495 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 496 root = computation->root_instruction(); 497 EXPECT_EQ(root, param0); 498 } 499 500 // Test that complex(real(c), imag(c)) is simplified to c. 
501 TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { 502 Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); 503 Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2}); 504 HloComputation::Builder builder(TestName()); 505 HloInstruction* param0 = builder.AddInstruction( 506 HloInstruction::CreateParameter(0, r2c64, "param0")); 507 HloInstruction* real = builder.AddInstruction( 508 HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, param0)); 509 HloInstruction* imag = builder.AddInstruction( 510 HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, param0)); 511 HloInstruction* cplx = builder.AddInstruction( 512 HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag)); 513 514 auto computation = module().AddEntryComputation(builder.Build()); 515 HloInstruction* root = computation->root_instruction(); 516 EXPECT_EQ(root, cplx); 517 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 518 non_bitcasting_callback()); 519 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 520 root = computation->root_instruction(); 521 EXPECT_EQ(root, param0); 522 } 523 524 // Test that real(complex(r,i)) is simplified to r. 
525 TEST_F(AlgebraicSimplifierTest, RealOfComplex) { 526 Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); 527 HloComputation::Builder builder(TestName()); 528 HloInstruction* param0 = builder.AddInstruction( 529 HloInstruction::CreateParameter(0, r2f32, "param0")); 530 HloInstruction* param1 = builder.AddInstruction( 531 HloInstruction::CreateParameter(1, r2f32, "param1")); 532 HloInstruction* cplx = builder.AddInstruction( 533 HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64), 534 HloOpcode::kComplex, param0, param1)); 535 HloInstruction* real = builder.AddInstruction( 536 HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx)); 537 538 auto computation = module().AddEntryComputation(builder.Build()); 539 HloInstruction* root = computation->root_instruction(); 540 EXPECT_EQ(root, real); 541 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 542 non_bitcasting_callback()); 543 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 544 root = computation->root_instruction(); 545 EXPECT_EQ(root, param0); 546 } 547 548 // Test that imag(complex(r,i)) is simplified to i. 
549 TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { 550 Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); 551 HloComputation::Builder builder(TestName()); 552 HloInstruction* param0 = builder.AddInstruction( 553 HloInstruction::CreateParameter(0, r2f32, "param0")); 554 HloInstruction* param1 = builder.AddInstruction( 555 HloInstruction::CreateParameter(1, r2f32, "param1")); 556 HloInstruction* cplx = builder.AddInstruction( 557 HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64), 558 HloOpcode::kComplex, param0, param1)); 559 HloInstruction* imag = builder.AddInstruction( 560 HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx)); 561 562 auto computation = module().AddEntryComputation(builder.Build()); 563 HloInstruction* root = computation->root_instruction(); 564 EXPECT_EQ(root, imag); 565 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 566 non_bitcasting_callback()); 567 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 568 root = computation->root_instruction(); 569 EXPECT_EQ(root, param1); 570 } 571 572 // Test that get_element(make_tuple({A,B}),1) is simplified to B 573 TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { 574 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 575 HloComputation::Builder builder(TestName()); 576 HloInstruction* param0 = builder.AddInstruction( 577 HloInstruction::CreateParameter(0, r0f32, "param0")); 578 HloInstruction* param1 = builder.AddInstruction( 579 HloInstruction::CreateParameter(1, r0f32, "param1")); 580 HloInstruction* param2 = builder.AddInstruction( 581 HloInstruction::CreateParameter(2, r0f32, "param2")); 582 HloInstruction* tuple = 583 builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); 584 HloInstruction* get = builder.AddInstruction( 585 HloInstruction::CreateGetTupleElement(r0f32, tuple, 1)); 586 HloInstruction* add = builder.AddInstruction( 587 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2)); 588 589 auto computation = 
module().AddEntryComputation(builder.Build()); 590 HloInstruction* root = computation->root_instruction(); 591 EXPECT_EQ(root, add); 592 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 593 non_bitcasting_callback()); 594 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 595 root = computation->root_instruction(); 596 EXPECT_THAT(root, op::Add(param1, param2)); 597 } 598 599 // Test that exp(A)/exp(B) is simplified to exp(A-B) 600 TEST_F(AlgebraicSimplifierTest, ExpDiv) { 601 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 602 HloComputation::Builder builder(TestName()); 603 HloInstruction* param0 = builder.AddInstruction( 604 HloInstruction::CreateParameter(0, r0f32, "param0")); 605 HloInstruction* param1 = builder.AddInstruction( 606 HloInstruction::CreateParameter(1, r0f32, "param1")); 607 HloInstruction* exp0 = builder.AddInstruction( 608 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0)); 609 HloInstruction* exp1 = builder.AddInstruction( 610 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1)); 611 builder.AddInstruction( 612 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1)); 613 614 auto computation = module().AddEntryComputation(builder.Build()); 615 616 EXPECT_THAT(computation->root_instruction(), 617 op::Divide(op::Exp(param0), op::Exp(param1))); 618 619 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 620 non_bitcasting_callback()); 621 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 622 623 EXPECT_THAT(computation->root_instruction(), 624 op::Exp(op::Subtract(param0, param1))); 625 } 626 627 // Test that exp(A)*exp(B) is simplified to exp(A+B) 628 TEST_F(AlgebraicSimplifierTest, ExpMul) { 629 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 630 HloComputation::Builder builder(TestName()); 631 HloInstruction* param0 = builder.AddInstruction( 632 HloInstruction::CreateParameter(0, r0f32, "param0")); 633 HloInstruction* param1 = builder.AddInstruction( 634 HloInstruction::CreateParameter(1, 
r0f32, "param1")); 635 HloInstruction* exp0 = builder.AddInstruction( 636 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0)); 637 HloInstruction* exp1 = builder.AddInstruction( 638 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1)); 639 builder.AddInstruction( 640 HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1)); 641 642 auto computation = module().AddEntryComputation(builder.Build()); 643 644 EXPECT_THAT(computation->root_instruction(), 645 op::Multiply(op::Exp(param0), op::Exp(param1))); 646 647 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 648 non_bitcasting_callback()); 649 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 650 651 EXPECT_THAT(computation->root_instruction(), 652 op::Exp(op::Add(param0, param1))); 653 } 654 655 // Test that pow(exp(A), B) is simplified to exp(A*B) 656 TEST_F(AlgebraicSimplifierTest, PowExp) { 657 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 658 HloComputation::Builder builder(TestName()); 659 HloInstruction* param0 = builder.AddInstruction( 660 HloInstruction::CreateParameter(0, r0f32, "param0")); 661 HloInstruction* param1 = builder.AddInstruction( 662 HloInstruction::CreateParameter(1, r0f32, "param1")); 663 HloInstruction* exp0 = builder.AddInstruction( 664 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0)); 665 builder.AddInstruction( 666 HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1)); 667 668 auto computation = module().AddEntryComputation(builder.Build()); 669 670 EXPECT_THAT(computation->root_instruction(), 671 op::Power(op::Exp(param0), param1)); 672 673 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 674 non_bitcasting_callback()); 675 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 676 677 EXPECT_THAT(computation->root_instruction(), 678 op::Exp(op::Multiply(param0, param1))); 679 } 680 681 // Test that ln(pow(A, B)) is simplified to ln(A)*B 682 TEST_F(AlgebraicSimplifierTest, LnPow) { 683 Shape r0f32 = 
ShapeUtil::MakeShape(F32, {}); 684 HloComputation::Builder builder(TestName()); 685 HloInstruction* param0 = builder.AddInstruction( 686 HloInstruction::CreateParameter(0, r0f32, "param0")); 687 HloInstruction* param1 = builder.AddInstruction( 688 HloInstruction::CreateParameter(1, r0f32, "param1")); 689 HloInstruction* pow = builder.AddInstruction( 690 HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, param1)); 691 builder.AddInstruction( 692 HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow)); 693 694 auto computation = module().AddEntryComputation(builder.Build()); 695 696 EXPECT_THAT(computation->root_instruction(), 697 op::Log(op::Power(param0, param1))); 698 699 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 700 non_bitcasting_callback()); 701 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 702 703 EXPECT_THAT(computation->root_instruction(), 704 op::Multiply(op::Log(param0), param1)); 705 } 706 707 // Test that ln(exp(A)) is simplified to A 708 TEST_F(AlgebraicSimplifierTest, LnExp) { 709 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 710 HloComputation::Builder builder(TestName()); 711 HloInstruction* param0 = builder.AddInstruction( 712 HloInstruction::CreateParameter(0, r0f32, "param0")); 713 HloInstruction* exp0 = builder.AddInstruction( 714 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0)); 715 builder.AddInstruction( 716 HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0)); 717 718 auto computation = module().AddEntryComputation(builder.Build()); 719 720 EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); 721 722 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 723 non_bitcasting_callback()); 724 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 725 726 EXPECT_EQ(computation->root_instruction(), param0); 727 } 728 729 // Test that ln(exp(A)/exp(B)) is simplified to A-B 730 TEST_F(AlgebraicSimplifierTest, LnExpDiv) { 731 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 
732 HloComputation::Builder builder(TestName()); 733 HloInstruction* param0 = builder.AddInstruction( 734 HloInstruction::CreateParameter(0, r0f32, "param0")); 735 HloInstruction* param1 = builder.AddInstruction( 736 HloInstruction::CreateParameter(1, r0f32, "param1")); 737 HloInstruction* exp0 = builder.AddInstruction( 738 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0)); 739 HloInstruction* exp1 = builder.AddInstruction( 740 HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1)); 741 HloInstruction* div = builder.AddInstruction( 742 HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1)); 743 builder.AddInstruction( 744 HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div)); 745 746 auto computation = module().AddEntryComputation(builder.Build()); 747 748 EXPECT_THAT(computation->root_instruction(), 749 op::Log(op::Divide(op::Exp(param0), op::Exp(param1)))); 750 751 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 752 non_bitcasting_callback()); 753 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 754 755 EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1)); 756 } 757 758 // Test that pow(A, 0) where A is a scalar is simplified to the scalar 759 // constant 1. 
760 TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { 761 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 762 HloComputation::Builder builder(TestName()); 763 HloInstruction* param0 = builder.AddInstruction( 764 HloInstruction::CreateParameter(0, r0f32, "param0")); 765 HloInstruction* zero = builder.AddInstruction( 766 HloInstruction::CreateConstant(Literal::CreateR0<float>(0))); 767 builder.AddInstruction( 768 HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); 769 770 auto computation = module().AddEntryComputation(builder.Build()); 771 772 EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); 773 774 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 775 non_bitcasting_callback()); 776 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 777 778 HloInstruction* root = computation->root_instruction(); 779 EXPECT_THAT(root, op::Constant()); 780 EXPECT_EQ(root->literal().GetFirstElement<float>(), 1); 781 } 782 783 // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1). 
784 TEST_F(AlgebraicSimplifierTest, Pow0Vector) { 785 Shape r1f32 = ShapeUtil::MakeShape(F32, {42}); 786 HloComputation::Builder builder(TestName()); 787 HloInstruction* param0 = builder.AddInstruction( 788 HloInstruction::CreateParameter(0, r1f32, "param0")); 789 HloInstruction* zero = builder.AddInstruction( 790 HloInstruction::CreateConstant(Literal::CreateR0<float>(0))); 791 builder.AddInstruction( 792 HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); 793 794 auto computation = module().AddEntryComputation(builder.Build()); 795 796 EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); 797 798 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 799 non_bitcasting_callback()); 800 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 801 802 HloInstruction* root = computation->root_instruction(); 803 EXPECT_THAT(root, op::Broadcast()); 804 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32)) 805 << ShapeUtil::HumanString(root->shape()); 806 EXPECT_EQ(root->dimensions().size(), 0); 807 EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape())); 808 EXPECT_EQ(root->operand(0)->literal().GetFirstElement<float>(), 1); 809 } 810 811 // Test that pow(A, 1) is simplified to A. 
812 TEST_F(AlgebraicSimplifierTest, Pow1) { 813 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 814 HloComputation::Builder builder(TestName()); 815 HloInstruction* param0 = builder.AddInstruction( 816 HloInstruction::CreateParameter(0, r0f32, "param0")); 817 HloInstruction* one = builder.AddInstruction( 818 HloInstruction::CreateConstant(Literal::CreateR0<float>(1))); 819 builder.AddInstruction( 820 HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); 821 822 auto computation = module().AddEntryComputation(builder.Build()); 823 824 EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); 825 826 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 827 non_bitcasting_callback()); 828 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 829 830 EXPECT_EQ(computation->root_instruction(), param0); 831 } 832 833 // Test that pow(A, 2) is simplified to A*A. 834 TEST_F(AlgebraicSimplifierTest, Pow2) { 835 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 836 HloComputation::Builder builder(TestName()); 837 HloInstruction* param0 = builder.AddInstruction( 838 HloInstruction::CreateParameter(0, r0f32, "param0")); 839 HloInstruction* two = builder.AddInstruction( 840 HloInstruction::CreateConstant(Literal::CreateR0<float>(2))); 841 builder.AddInstruction( 842 HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); 843 844 auto computation = module().AddEntryComputation(builder.Build()); 845 846 EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); 847 848 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 849 non_bitcasting_callback()); 850 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 851 852 EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0)); 853 } 854 855 // Test that pow(A, -1) is simplified to 1/A. 
856 TEST_F(AlgebraicSimplifierTest, PowNegative1) { 857 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 858 HloComputation::Builder builder(TestName()); 859 HloInstruction* param0 = builder.AddInstruction( 860 HloInstruction::CreateParameter(0, r0f32, "param0")); 861 HloInstruction* negative_one = builder.AddInstruction( 862 HloInstruction::CreateConstant(Literal::CreateR0<float>(-1))); 863 builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, 864 param0, negative_one)); 865 866 auto computation = module().AddEntryComputation(builder.Build()); 867 868 EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); 869 870 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 871 non_bitcasting_callback()); 872 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 873 874 HloInstruction* root = computation->root_instruction(); 875 EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); 876 EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast); 877 EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement<float>(), 878 1); 879 } 880 881 TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { 882 auto builder = HloComputation::Builder(TestName()); 883 HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter( 884 0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs")); 885 886 HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter( 887 1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs")); 888 889 ConvolutionDimensionNumbers dnums; 890 dnums.set_input_batch_dimension(0); 891 dnums.add_input_spatial_dimensions(1); 892 dnums.set_input_feature_dimension(2); 893 894 dnums.set_output_batch_dimension(0); 895 dnums.add_output_spatial_dimensions(1); 896 dnums.set_output_feature_dimension(2); 897 898 dnums.add_kernel_spatial_dimensions(0); 899 dnums.set_kernel_input_feature_dimension(1); 900 dnums.set_kernel_output_feature_dimension(2); 901 Window window; 902 WindowDimension* dim = 
window.add_dimensions(); 903 dim->set_size(3); 904 dim->set_padding_low(0); 905 dim->set_padding_high(0); 906 dim->set_stride(1); 907 dim->set_window_dilation(1); 908 dim->set_base_dilation(1); 909 dim->set_window_reversal(false); 910 // Create add computation. 911 builder.AddInstruction(HloInstruction::CreateConvolve( 912 ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums)); 913 module().AddEntryComputation(builder.Build()); 914 HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false, 915 non_bitcasting_callback()); 916 EXPECT_THAT(module().entry_computation()->root_instruction(), 917 op::Convolution(lhs, rhs)); 918 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 919 EXPECT_THAT(module().entry_computation()->root_instruction(), 920 op::Broadcast(op::Constant())); 921 } 922 923 TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { 924 auto builder = HloComputation::Builder(TestName()); 925 HloInstruction* param = 926 builder.AddInstruction(HloInstruction::CreateParameter( 927 0, ShapeUtil::MakeShape(F32, {3, 0}), "op")); 928 Window window; 929 for (int64 i = 0; i < 2; ++i) { 930 WindowDimension* dim = window.add_dimensions(); 931 dim->set_size(1); 932 dim->set_padding_low(1); 933 dim->set_padding_high(1); 934 dim->set_window_dilation(1); 935 dim->set_base_dilation(1); 936 } 937 // Create add computation. 
938 HloComputation* add_computation = nullptr; 939 { 940 HloComputation::Builder builder(TestName() + ".add"); 941 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); 942 HloInstruction* p0 = builder.AddInstruction( 943 HloInstruction::CreateParameter(0, scalar_shape, "p0")); 944 HloInstruction* p1 = builder.AddInstruction( 945 HloInstruction::CreateParameter(1, scalar_shape, "p1")); 946 builder.AddInstruction( 947 HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); 948 add_computation = module().AddEmbeddedComputation(builder.Build()); 949 } 950 builder.AddInstruction(HloInstruction::CreateReduceWindow( 951 ShapeUtil::MakeShape(F32, {5, 2}), param, 952 builder.AddInstruction( 953 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))), 954 window, add_computation)); 955 module().AddEntryComputation(builder.Build()); 956 HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false, 957 non_bitcasting_callback()); 958 EXPECT_THAT(module().entry_computation()->root_instruction(), 959 op::ReduceWindow(param, op::Constant())); 960 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 961 EXPECT_THAT(module().entry_computation()->root_instruction(), 962 op::Broadcast(op::Constant())); 963 } 964 965 TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { 966 auto builder = HloComputation::Builder(TestName()); 967 HloInstruction* param = 968 builder.AddInstruction(HloInstruction::CreateParameter( 969 0, ShapeUtil::MakeShape(F32, {3, 0}), "op")); 970 PaddingConfig padding; 971 for (int i = 0; i < 2; ++i) { 972 PaddingConfig::PaddingConfigDimension* dimension = padding.add_dimensions(); 973 dimension->set_edge_padding_low(1); 974 dimension->set_edge_padding_high(1); 975 dimension->set_interior_padding(0); 976 } 977 builder.AddInstruction(HloInstruction::CreatePad( 978 ShapeUtil::MakeShape(F32, {5, 2}), param, 979 builder.AddInstruction( 980 HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), 981 padding)); 982 
module().AddEntryComputation(builder.Build()); 983 EXPECT_THAT(module().entry_computation()->root_instruction(), 984 op::Pad(param, op::Constant())); 985 HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false, 986 non_bitcasting_callback()); 987 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 988 EXPECT_THAT(module().entry_computation()->root_instruction(), 989 op::Broadcast(op::Constant())); 990 } 991 992 TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { 993 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 994 995 auto builder = HloComputation::Builder(TestName()); 996 auto op = builder.AddInstruction(HloInstruction::CreateParameter( 997 0, ShapeUtil::MakeShape(F32, {3, 2}), "op")); 998 auto reshape1 = builder.AddInstruction( 999 HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), op)); 1000 auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( 1001 ShapeUtil::MakeShape(F32, {1, 6}), reshape1, {1})); 1002 builder.AddInstruction(HloInstruction::CreateReshape( 1003 ShapeUtil::MakeShape(F32, {3, 2}), broadcast)); 1004 1005 auto computation = builder.Build(); 1006 module().AddEntryComputation(std::move(computation)); 1007 1008 EXPECT_THAT(module().entry_computation()->root_instruction(), 1009 op::Reshape(op::Broadcast(op::Reshape(op)))); 1010 1011 HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false, 1012 non_bitcasting_callback()); 1013 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1014 1015 EXPECT_THAT(module().entry_computation()->root_instruction(), op); 1016 } 1017 1018 // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE. 
1019 TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { 1020 HloComputation::Builder builder(TestName()); 1021 HloInstruction* input = builder.AddInstruction( 1022 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 1023 builder.AddInstruction( 1024 HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); 1025 1026 auto computation = module().AddEntryComputation(builder.Build()); 1027 1028 EXPECT_THAT(computation->root_instruction(), op::Convert(input)); 1029 1030 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1031 non_bitcasting_callback()); 1032 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1033 1034 EXPECT_THAT(computation->root_instruction(), input); 1035 } 1036 1037 // Test that copies are removed. 1038 TEST_F(AlgebraicSimplifierTest, RemoveCopy) { 1039 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 1040 HloComputation::Builder builder(TestName()); 1041 HloInstruction* param0 = builder.AddInstruction( 1042 HloInstruction::CreateParameter(0, r0f32, "param0")); 1043 builder.AddInstruction( 1044 HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); 1045 1046 auto computation = module().AddEntryComputation(builder.Build()); 1047 1048 EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); 1049 1050 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1051 non_bitcasting_callback()); 1052 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1053 1054 EXPECT_THAT(computation->root_instruction(), param0); 1055 } 1056 1057 // Test that unary concatenates are removed. 
1058 TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { 1059 Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); 1060 HloComputation::Builder builder(TestName()); 1061 HloInstruction* param0 = builder.AddInstruction( 1062 HloInstruction::CreateParameter(0, r1f32, "param0")); 1063 builder.AddInstruction( 1064 HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0)); 1065 1066 auto computation = module().AddEntryComputation(builder.Build()); 1067 1068 EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); 1069 1070 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1071 non_bitcasting_callback()); 1072 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1073 1074 EXPECT_THAT(computation->root_instruction(), param0); 1075 } 1076 1077 // Test that empty operands of concatenates are removed. 1078 TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { 1079 const int kParamLength = 100; 1080 Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); 1081 HloComputation::Builder builder(TestName()); 1082 HloInstruction* param0 = builder.AddInstruction( 1083 HloInstruction::CreateParameter(0, r1f32, "param0")); 1084 HloInstruction* param1 = builder.AddInstruction( 1085 HloInstruction::CreateParameter(1, r1f32, "param1")); 1086 HloInstruction* empty_literal = builder.AddInstruction( 1087 HloInstruction::CreateConstant(Literal::CreateR1<float>({}))); 1088 HloInstruction* empty_slice = 1089 builder.AddInstruction(HloInstruction::CreateSlice( 1090 ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1})); 1091 Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength}); 1092 builder.AddInstruction(HloInstruction::CreateConcatenate( 1093 result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0)); 1094 1095 auto computation = module().AddEntryComputation(builder.Build()); 1096 1097 EXPECT_THAT( 1098 computation->root_instruction(), 1099 op::Concatenate(empty_literal, param0, param0, empty_slice, param1)); 1100 
1101 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1102 non_bitcasting_callback()); 1103 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1104 1105 EXPECT_THAT(computation->root_instruction(), 1106 op::Concatenate(param0, param0, param1)); 1107 } 1108 1109 // Test a concatenate with only empty operands is removed. 1110 TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { 1111 const int kParamLength = 100; 1112 Shape r1f32 = ShapeUtil::MakeShape(F32, {kParamLength}); 1113 HloComputation::Builder builder(TestName()); 1114 HloInstruction* param0 = builder.AddInstruction( 1115 HloInstruction::CreateParameter(0, r1f32, "param0")); 1116 HloInstruction* empty_literal = builder.AddInstruction( 1117 HloInstruction::CreateConstant(Literal::CreateR1<float>({}))); 1118 HloInstruction* empty_slice = 1119 builder.AddInstruction(HloInstruction::CreateSlice( 1120 ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1})); 1121 Shape result_shape = ShapeUtil::MakeShape(F32, {0}); 1122 builder.AddInstruction(HloInstruction::CreateConcatenate( 1123 result_shape, {empty_literal, empty_slice}, 0)); 1124 1125 auto computation = module().AddEntryComputation(builder.Build()); 1126 1127 EXPECT_THAT(computation->root_instruction(), 1128 op::Concatenate(empty_literal, empty_slice)); 1129 1130 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1131 non_bitcasting_callback()); 1132 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1133 1134 EXPECT_EQ(computation->root_instruction(), empty_literal); 1135 } 1136 1137 // Test that concat with a scalar broadcast becomes a pad. 
1138 TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { 1139 Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); 1140 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 1141 HloComputation::Builder builder(TestName()); 1142 HloInstruction* param0 = builder.AddInstruction( 1143 HloInstruction::CreateParameter(0, r1f32, "param0")); 1144 HloInstruction* param1 = builder.AddInstruction( 1145 HloInstruction::CreateParameter(1, r0f32, "param1")); 1146 HloInstruction* broadcast = builder.AddInstruction( 1147 HloInstruction::CreateBroadcast(r1f32, param1, {})); 1148 builder.AddInstruction(HloInstruction::CreateConcatenate( 1149 ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0)); 1150 1151 auto computation = module().AddEntryComputation(builder.Build()); 1152 1153 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1154 non_bitcasting_callback()); 1155 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1156 EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1)); 1157 } 1158 1159 // Test that a simplification which changes layouts is not performed if layout 1160 // sensitive is true. 1161 TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { 1162 HloComputation::Builder builder(TestName()); 1163 HloInstruction* param0 = 1164 builder.AddInstruction(HloInstruction::CreateParameter( 1165 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0")); 1166 HloInstruction* copy = builder.AddInstruction( 1167 HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); 1168 1169 auto computation = module().AddEntryComputation(builder.Build()); 1170 1171 // Set to different layouts. 
1172 *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); 1173 *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); 1174 1175 EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); 1176 1177 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, 1178 non_bitcasting_callback()); 1179 EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); 1180 1181 // Copy has not been removed. 1182 EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); 1183 } 1184 1185 // Test that a simplification which preserves layouts is performed if layout 1186 // sensitive is true. 1187 TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { 1188 HloComputation::Builder builder(TestName()); 1189 HloInstruction* param0 = 1190 builder.AddInstruction(HloInstruction::CreateParameter( 1191 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0")); 1192 HloInstruction* copy = builder.AddInstruction( 1193 HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); 1194 1195 auto computation = module().AddEntryComputation(builder.Build()); 1196 1197 // Set to same layouts. 1198 *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); 1199 *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); 1200 1201 EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); 1202 1203 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, 1204 non_bitcasting_callback()); 1205 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1206 1207 // Copy has been removed. 1208 EXPECT_THAT(computation->root_instruction(), param0); 1209 } 1210 1211 // Test that a reshape which could be replaced with a bitcast is not if 1212 // add_bitcasts is false. 
1213 TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { 1214 HloComputation::Builder builder(TestName()); 1215 HloInstruction* param0 = 1216 builder.AddInstruction(HloInstruction::CreateParameter( 1217 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0")); 1218 HloInstruction* reshape = 1219 builder.AddInstruction(HloInstruction::CreateReshape( 1220 ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0)); 1221 1222 *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); 1223 *reshape->mutable_shape()->mutable_layout() = 1224 LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5}); 1225 1226 auto computation = module().AddEntryComputation(builder.Build()); 1227 1228 EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); 1229 1230 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, 1231 non_bitcasting_callback()); 1232 EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); 1233 1234 // Reshape is not replaced with a bitcast. 1235 EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); 1236 } 1237 1238 // Test transforming reshapes to bitcasts under various conditions. 1239 TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { 1240 HloComputation::Builder builder(TestName()); 1241 HloInstruction* param0 = 1242 builder.AddInstruction(HloInstruction::CreateParameter( 1243 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0")); 1244 *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); 1245 1246 // Reshape which can be transformed into a bitcast. 1247 HloInstruction* transformable_reshape = 1248 builder.AddInstruction(HloInstruction::CreateReshape( 1249 ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0)); 1250 *transformable_reshape->mutable_shape()->mutable_layout() = 1251 LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5}); 1252 1253 // Reshape does not just add degenerate dimensions. 
1254 HloInstruction* dimensions_wrong_reshape = 1255 builder.AddInstruction(HloInstruction::CreateReshape( 1256 ShapeUtil::MakeShape(F32, {1, 4, 1, 1, 1, 1}), param0)); 1257 *dimensions_wrong_reshape->mutable_shape()->mutable_layout() = 1258 LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5}); 1259 1260 // Reshape has wrong layout. 1261 HloInstruction* layout_wrong_reshape = 1262 builder.AddInstruction(HloInstruction::CreateReshape( 1263 ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0)); 1264 *layout_wrong_reshape->mutable_shape()->mutable_layout() = 1265 LayoutUtil::MakeLayout({5, 4, 3, 2, 1, 0}); 1266 1267 // Collect all the reshapes into a tuple so they are not dead. 1268 builder.AddInstruction(HloInstruction::CreateTuple( 1269 {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape})); 1270 1271 auto computation = module().AddEntryComputation(builder.Build()); 1272 1273 EXPECT_THAT(computation->root_instruction(), 1274 op::Tuple(transformable_reshape, dimensions_wrong_reshape, 1275 layout_wrong_reshape)); 1276 1277 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, 1278 bitcasting_callback()); 1279 simplifier.Run(&module()).ValueOrDie(); 1280 1281 // Verify that only the first reshape is replaced. 
1282 EXPECT_THAT( 1283 computation->root_instruction(), 1284 op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); 1285 } 1286 1287 TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { 1288 HloComputation::Builder builder(TestName()); 1289 HloInstruction* param = 1290 builder.AddInstruction(HloInstruction::CreateParameter( 1291 0, ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), "param")); 1292 HloInstruction* movable_reshape = 1293 builder.AddInstruction(HloInstruction::CreateReshape( 1294 ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param)); 1295 HloInstruction* zero = builder.AddInstruction( 1296 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); 1297 builder.AddInstruction( 1298 HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), 1299 HloOpcode::kMaximum, movable_reshape, zero)); 1300 auto computation = module().AddEntryComputation(builder.Build()); 1301 1302 EXPECT_THAT(computation->root_instruction(), 1303 op::Maximum(op::Reshape(param), zero)); 1304 1305 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1306 bitcasting_callback()); 1307 1308 simplifier.Run(&module()).ValueOrDie(); 1309 EXPECT_THAT(computation->root_instruction(), 1310 op::Reshape(op::Maximum(param, zero))); 1311 } 1312 1313 // Regression test for a bug in the reshape sinking transformation, where 1314 // moving a reshape to a scalar led to a crash. 
1315 TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { 1316 HloComputation::Builder builder(TestName()); 1317 HloInstruction* param = 1318 builder.AddInstruction(HloInstruction::CreateParameter( 1319 0, ShapeUtil::MakeShape(F32, {1, 1}), "param")); 1320 HloInstruction* reshape = builder.AddInstruction( 1321 HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param)); 1322 HloInstruction* zero = builder.AddInstruction( 1323 HloInstruction::CreateConstant(Literal::CreateR1<float>({1., 2., 3.}))); 1324 builder.AddInstruction(HloInstruction::CreateBinary( 1325 ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero)); 1326 auto computation = module().AddEntryComputation(builder.Build()); 1327 1328 EXPECT_THAT(computation->root_instruction(), 1329 op::Maximum(op::Reshape(param), zero)); 1330 1331 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1332 bitcasting_callback()); 1333 1334 simplifier.Run(&module()).ValueOrDie(); 1335 1336 EXPECT_THAT(computation->root_instruction(), 1337 op::Maximum(op::Reshape(param), zero)); 1338 } 1339 1340 // Regression test for a bug where if we failed to sink a reshape, we'd set the 1341 // 'changed' bit in AlgebraicSimplifier to false. 1342 TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { 1343 HloComputation::Builder builder(TestName()); 1344 1345 // This add (param0 + 0) can be simplified. 
1346 Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); 1347 HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( 1348 shape, HloOpcode::kAdd, 1349 builder.AddInstruction( 1350 HloInstruction::CreateParameter(0, shape, "param0")), 1351 builder.AddInstruction(HloInstruction::CreateConstant( 1352 Literal::CreateR2<float>({{0, 0}, {0, 0}}))))); 1353 1354 builder.AddInstruction( 1355 HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add)); 1356 1357 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1358 bitcasting_callback()); 1359 module().AddEntryComputation(builder.Build()); 1360 EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1361 } 1362 1363 // Regression test for a bug where if we failed to sink a reshape, we'd set the 1364 // 'changed' bit in AlgebraicSimplifier to false. 1365 TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { 1366 HloComputation::Builder builder(TestName()); 1367 1368 // This add (param0 + 0) can be simplified. 
1369 Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); 1370 HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( 1371 shape, HloOpcode::kAdd, 1372 builder.AddInstruction( 1373 HloInstruction::CreateParameter(0, shape, "param0")), 1374 builder.AddInstruction(HloInstruction::CreateConstant( 1375 Literal::CreateR2<float>({{0, 0}, {0, 0}}))))); 1376 1377 builder.AddInstruction( 1378 HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add, 1379 /*broadcast_dimensions=*/{0, 1})); 1380 1381 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1382 bitcasting_callback()); 1383 module().AddEntryComputation(builder.Build()); 1384 EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1385 } 1386 1387 TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { 1388 HloComputation::Builder builder(TestName()); 1389 HloInstruction* param = 1390 builder.AddInstruction(HloInstruction::CreateParameter( 1391 0, ShapeUtil::MakeShape(F32, {50, 14, 14, 64}), "param")); 1392 *param->mutable_shape()->mutable_layout() = 1393 LayoutUtil::MakeLayout({1, 2, 0, 3}); 1394 1395 HloInstruction* transpose = 1396 builder.AddInstruction(HloInstruction::CreateTranspose( 1397 ShapeUtil::MakeShape(F32, {14, 14, 50, 64}), param, {1, 2, 0, 3})); 1398 *transpose->mutable_shape()->mutable_layout() = 1399 LayoutUtil::MakeLayout({0, 1, 2, 3}); 1400 1401 auto computation = module().AddEntryComputation(builder.Build()); 1402 1403 EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); 1404 1405 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, 1406 bitcasting_callback()); 1407 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1408 1409 // Verify that the reshape is replaced. 
1410 EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); 1411 } 1412 1413 TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { 1414 HloComputation::Builder builder(TestName()); 1415 HloInstruction* param = 1416 builder.AddInstruction(HloInstruction::CreateParameter( 1417 0, ShapeUtil::MakeShape(F32, {5, 2, 3, 4}), "param")); 1418 *param->mutable_shape()->mutable_layout() = 1419 LayoutUtil::MakeLayout({1, 2, 3, 0}); 1420 1421 HloInstruction* transpose = 1422 builder.AddInstruction(HloInstruction::CreateTranspose( 1423 ShapeUtil::MakeShape(F32, {5, 3, 4, 2}), param, {0, 2, 3, 1})); 1424 *transpose->mutable_shape()->mutable_layout() = 1425 LayoutUtil::MakeLayout({3, 1, 2, 0}); 1426 1427 auto computation = module().AddEntryComputation(builder.Build()); 1428 1429 EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); 1430 1431 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, 1432 bitcasting_callback()); 1433 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1434 1435 // Verify that the reshape is replaced. 
1436 EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); 1437 } 1438 1439 TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { 1440 HloComputation::Builder builder(TestName()); 1441 HloInstruction* param0 = 1442 builder.AddInstruction(HloInstruction::CreateParameter( 1443 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0")); 1444 1445 HloInstruction* reshape1 = 1446 builder.AddInstruction(HloInstruction::CreateReshape( 1447 ShapeUtil::MakeShape(F32, {2, 1, 2}), param0)); 1448 1449 builder.AddInstruction(HloInstruction::CreateReshape( 1450 ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1)); 1451 1452 auto computation = module().AddEntryComputation(builder.Build()); 1453 1454 EXPECT_THAT(computation->root_instruction(), 1455 op::Reshape(op::Reshape(param0))); 1456 1457 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1458 non_bitcasting_callback()); 1459 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1460 1461 EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); 1462 } 1463 1464 TEST_F(AlgebraicSimplifierTest, CopiesMerged) { 1465 HloComputation::Builder builder(TestName()); 1466 HloInstruction* param0 = 1467 builder.AddInstruction(HloInstruction::CreateParameter( 1468 0, ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2}), 1469 "param0")); 1470 1471 HloInstruction* copy1 = builder.AddInstruction(HloInstruction::CreateUnary( 1472 ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}), 1473 HloOpcode::kCopy, param0)); 1474 1475 builder.AddInstruction(HloInstruction::CreateUnary( 1476 ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}), 1477 HloOpcode::kCopy, copy1)); 1478 1479 auto computation = module().AddEntryComputation(builder.Build()); 1480 1481 EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0))); 1482 1483 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, 1484 non_bitcasting_callback()); 1485 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1486 1487 
EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); 1488 } 1489 1490 TEST_F(AlgebraicSimplifierTest, TransposesMerged) { 1491 HloComputation::Builder builder(TestName()); 1492 HloInstruction* param0 = 1493 builder.AddInstruction(HloInstruction::CreateParameter( 1494 0, ShapeUtil::MakeShape(F32, {2, 3, 4}), "param0")); 1495 1496 HloInstruction* transpose1 = 1497 builder.AddInstruction(HloInstruction::CreateTranspose( 1498 ShapeUtil::MakeShape(F32, {3, 4, 2}), param0, {1, 2, 0})); 1499 1500 builder.AddInstruction(HloInstruction::CreateTranspose( 1501 ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2})); 1502 1503 auto computation = module().AddEntryComputation(builder.Build()); 1504 1505 EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1)); 1506 1507 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1508 non_bitcasting_callback()); 1509 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1510 1511 EXPECT_THAT(computation->root_instruction(), op::Transpose(param0)); 1512 EXPECT_EQ(std::vector<int64>({2, 1, 0}), 1513 computation->root_instruction()->dimensions()); 1514 } 1515 1516 // Test merging reshape and broadcast. 
1517 TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { 1518 HloComputation::Builder builder(TestName()); 1519 auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 1520 0, ShapeUtil::MakeShape(F32, {5}), "param0")); 1521 auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( 1522 ShapeUtil::MakeShape(F32, {1, 5, 1}), param0)); 1523 builder.AddInstruction(HloInstruction::CreateBroadcast( 1524 ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2})); 1525 1526 auto computation = module().AddEntryComputation(builder.Build()); 1527 1528 EXPECT_THAT(computation->root_instruction(), 1529 op::Broadcast(op::Reshape(param0))); 1530 1531 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1532 non_bitcasting_callback()); 1533 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1534 1535 EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); 1536 } 1537 1538 // Test merging broadcast and reshape. 1539 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { 1540 HloComputation::Builder builder(TestName()); 1541 auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 1542 0, ShapeUtil::MakeShape(F32, {2, 3}), "param0")); 1543 auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast( 1544 ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), param0, {1, 2})); 1545 builder.AddInstruction(HloInstruction::CreateReshape( 1546 ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1)); 1547 1548 auto computation = module().AddEntryComputation(builder.Build()); 1549 1550 EXPECT_THAT(computation->root_instruction(), 1551 op::Reshape(op::Broadcast(param0))); 1552 1553 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1554 non_bitcasting_callback()); 1555 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1556 1557 EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); 1558 } 1559 1560 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { 1561 
HloComputation::Builder builder(TestName()); 1562 auto param = builder.AddInstruction(HloInstruction::CreateParameter( 1563 0, ShapeUtil::MakeShape(F32, {1}), "param")); 1564 auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( 1565 ShapeUtil::MakeShape(F32, {3, 1}), param, {1})); 1566 builder.AddInstruction( 1567 HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast)); 1568 1569 auto computation = module().AddEntryComputation(builder.Build()); 1570 1571 EXPECT_THAT(computation->root_instruction(), 1572 op::Reshape(op::Broadcast(param))); 1573 1574 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1575 non_bitcasting_callback()); 1576 EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); 1577 1578 EXPECT_THAT(computation->root_instruction(), 1579 op::Reshape(op::Broadcast(param))); 1580 } 1581 1582 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { 1583 HloComputation::Builder builder(TestName()); 1584 auto param = builder.AddInstruction(HloInstruction::CreateParameter( 1585 0, ShapeUtil::MakeShape(F32, {4}), "param")); 1586 auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( 1587 ShapeUtil::MakeShape(F32, {3, 2, 4}), param, {2})); 1588 builder.AddInstruction(HloInstruction::CreateReshape( 1589 ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast)); 1590 1591 HloComputation* computation = module().AddEntryComputation(builder.Build()); 1592 1593 EXPECT_THAT(computation->root_instruction(), 1594 op::Reshape(op::Broadcast(param))); 1595 1596 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1597 non_bitcasting_callback()); 1598 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1599 1600 EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); 1601 EXPECT_THAT(computation->root_instruction()->dimensions(), 1602 ::testing::ElementsAre(3)); 1603 } 1604 1605 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { 1606 HloComputation::Builder 
builder(TestName()); 1607 auto param = builder.AddInstruction(HloInstruction::CreateParameter( 1608 0, ShapeUtil::MakeShape(F32, {1}), "param")); 1609 auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( 1610 ShapeUtil::MakeShape(F32, {3, 2, 1}), param, {2})); 1611 builder.AddInstruction(HloInstruction::CreateReshape( 1612 ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast)); 1613 1614 HloComputation* computation = module().AddEntryComputation(builder.Build()); 1615 1616 EXPECT_THAT(computation->root_instruction(), 1617 op::Reshape(op::Broadcast(param))); 1618 1619 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1620 non_bitcasting_callback()); 1621 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 1622 1623 EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); 1624 const std::vector<int64> broadcast_dims = 1625 computation->root_instruction()->dimensions(); 1626 EXPECT_EQ(1, broadcast_dims.size()); 1627 EXPECT_THAT(broadcast_dims[0], ::testing::AnyOf(1, 2, 3)); 1628 } 1629 1630 TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { 1631 HloComputation::Builder builder(TestName()); 1632 auto param = builder.AddInstruction(HloInstruction::CreateParameter( 1633 0, ShapeUtil::MakeShape(F32, {4}), "param")); 1634 auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast( 1635 ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), param, {2})); 1636 builder.AddInstruction(HloInstruction::CreateReshape( 1637 ShapeUtil::MakeShape(F32, {6, 8}), broadcast)); 1638 1639 HloComputation* computation = module().AddEntryComputation(builder.Build()); 1640 1641 EXPECT_THAT(computation->root_instruction(), 1642 op::Reshape(op::Broadcast(param))); 1643 1644 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 1645 non_bitcasting_callback()); 1646 EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); 1647 1648 EXPECT_THAT(computation->root_instruction(), 1649 op::Reshape(op::Broadcast(param))); 1650 } 1651 1652 
// A pad whose low, high and interior padding are all zero is replaced by its
// operand.
TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
  PaddingConfig no_padding;
  for (int i = 0; i < 2; ++i) {
    auto dimension = no_padding.add_dimensions();
    dimension->set_edge_padding_low(0);
    dimension->set_edge_padding_high(0);
    dimension->set_interior_padding(0);
  }
  builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding));

  HloModule module(TestName());
  HloComputation* computation = module.AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));

  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                 non_bitcasting_callback());
  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(), param);
}

TEST_F(AlgebraicSimplifierTest, NegativePadding) {
  // Verify that a pad instruction with negative padding is replaced with a
  // pad with non-negative padding followed by a slice.
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
  PaddingConfig padding;
  int64 low_padding[2] = {-1, -2};
  int64 high_padding[2] = {2, -3};
  for (int i = 0; i < 2; ++i) {
    auto dimension = padding.add_dimensions();
    dimension->set_edge_padding_low(low_padding[i]);
    dimension->set_edge_padding_high(high_padding[i]);
    dimension->set_interior_padding(0);
  }
  // Result extents: 10 - 1 + 2 = 11 and 10 - 2 - 3 = 5.
  HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));

  HloModule module(TestName());
  HloComputation* computation = module.AddEntryComputation(builder.Build());

  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                 non_bitcasting_callback());

  // True if any dimension of `pad` has negative low or high edge padding.
  auto has_negative_padding = [](const HloInstruction* pad) {
    for (auto& padding_dimension : pad->padding_config().dimensions()) {
      if (padding_dimension.edge_padding_low() < 0 ||
          padding_dimension.edge_padding_high() < 0) {
        return true;
      }
    }
    return false;
  };

  EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
  EXPECT_TRUE(has_negative_padding(pad));

  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero)));
  EXPECT_FALSE(
      has_negative_padding(computation->root_instruction()->operand(0)));
}

// A reshape to the operand's own shape is replaced by the operand.
TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {2, 3}), "param"));
  builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param));

  HloModule module(TestName());
  HloComputation* computation = module.AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(), op::Reshape(param));

  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                 non_bitcasting_callback());
  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(), param);
}

// A unit-stride slice that spans the whole operand is replaced by the operand.
TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
  HloComputation::Builder builder(TestName());
  const int64 dim0 = 2;
  const int64 dim1 = 3;
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));
  builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0},
      /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1}));

  HloModule module(TestName());
  HloComputation* computation = module.AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(), op::Slice(param));

  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                 non_bitcasting_callback());
  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(), param);
}

// Builds convolutions from a table of options and checks, case by case,
// whether the layout-sensitive simplifier turns them into bitcast(dot(...))
// or leaves them alone.
TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
  struct ConvTestOptions {
    int in_batch = 10;
    int in_height = 2;
    int in_width = 2;
    int in_channels = 3;
    int f_width = 1;
    int f_height = 1;
    int f_output_channels = 10;
    int row_stride = 1;
    int row_padding = 0;
    int col_stride = 1;
    int col_padding = 0;
    bool input_minor_to_major_layout = false;
    bool filter_minor_to_major_layout = false;
    bool output_minor_to_major_layout = false;

    const char* dim_order = "NHWC";         // can use chars NHWC in any order.
    const char* kernel_dim_order = "HWIO";  // can use chars HWIO in any order.

    // Restores every field to its default and returns *this so that a single
    // field can be overridden in the same statement.
    ConvTestOptions& Reset() {
      *this = ConvTestOptions();
      return *this;
    }
  };

  ConvTestOptions options;

  // Builds a convolution from <options> and runs algebraic simplification on
  // the computation. Returns a string description of the result of
  // simplification.
  auto build_and_simplify = [&options, this]() -> string {
    HloComputation::Builder b(TestName());

    Window window;
    auto* f_dim_1 = window.add_dimensions();
    f_dim_1->set_size(options.f_height);
    f_dim_1->set_stride(options.row_stride);
    f_dim_1->set_padding_low(options.row_padding);
    f_dim_1->set_padding_high(options.row_padding);
    f_dim_1->set_window_dilation(1);
    f_dim_1->set_base_dilation(1);
    auto* f_dim_2 = window.add_dimensions();
    f_dim_2->set_size(options.f_width);
    f_dim_2->set_stride(options.col_stride);
    f_dim_2->set_padding_low(options.col_padding);
    f_dim_2->set_padding_high(options.col_padding);
    f_dim_2->set_window_dilation(1);
    f_dim_2->set_base_dilation(1);

    ConvolutionDimensionNumbers dnums;
    std::vector<int64> in_dims;
    int in_channel_idx = -1;
    // filled in later
    dnums.add_input_spatial_dimensions(-1);
    dnums.add_output_spatial_dimensions(-1);
    dnums.add_input_spatial_dimensions(-1);
    dnums.add_output_spatial_dimensions(-1);
    // Walk the NUL-terminated order string; position i is the dimension index
    // of the letter found there.
    for (int i = 0; options.dim_order[i] != '\0'; ++i) {
      char ch = options.dim_order[i];
      if (ch == 'N') {
        dnums.set_input_batch_dimension(i);
        dnums.set_output_batch_dimension(i);
        in_dims.push_back(options.in_batch);
      } else if (ch == 'H') {
        dnums.set_input_spatial_dimensions(0, i);
        dnums.set_output_spatial_dimensions(0, i);
        in_dims.push_back(options.in_height);
      } else if (ch == 'W') {
        dnums.set_input_spatial_dimensions(1, i);
        dnums.set_output_spatial_dimensions(1, i);
        in_dims.push_back(options.in_width);
      } else if (ch == 'C') {
        dnums.set_input_feature_dimension(i);
        dnums.set_output_feature_dimension(i);
        in_dims.push_back(options.in_channels);
        in_channel_idx = i;
      }
    }

    std::vector<int64> f_dims;
    dnums.add_kernel_spatial_dimensions(-1);  // filled in later
    dnums.add_kernel_spatial_dimensions(-1);  // filled in later
    for (int i = 0; options.kernel_dim_order[i] != '\0'; ++i) {
      char ch = options.kernel_dim_order[i];
      if (ch == 'H') {
        dnums.set_kernel_spatial_dimensions(0, i);
        f_dims.push_back(options.f_height);
      } else if (ch == 'W') {
        dnums.set_kernel_spatial_dimensions(1, i);
        f_dims.push_back(options.f_width);
      } else if (ch == 'I') {
        dnums.set_kernel_input_feature_dimension(i);
        f_dims.push_back(options.in_channels);
      } else if (ch == 'O') {
        dnums.set_kernel_output_feature_dimension(i);
        f_dims.push_back(options.f_output_channels);
      }
    }

    // The output has the input's extents except along the feature dimension.
    auto out_dims = in_dims;
    out_dims[in_channel_idx] = options.f_output_channels;

    // Builds an f32 shape, using minor_to_major {0, 1, 2, 3} when requested
    // and the default layout of ShapeUtil::MakeShape otherwise.
    auto make_shape = [](tensorflow::gtl::ArraySlice<int64> dims,
                         bool minor_to_major_layout) {
      if (minor_to_major_layout) {
        return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});
      } else {
        return ShapeUtil::MakeShape(F32, dims);
      }
    };
    auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout);
    auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout);
    auto out_shape = make_shape(out_dims, options.output_minor_to_major_layout);

    HloInstruction* input =
        b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input"));
    HloInstruction* filter =
        b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));

    b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
                                                    window, dnums));

    // Every invocation gets its own module, so the cases stay independent.
    HloModule module(TestName());
    auto* computation = module.AddEntryComputation(b.Build());

    AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
                                   bitcasting_callback());
    if (!simplifier.Run(&module).ValueOrDie()) {
      return "NO_CHANGE";
    }
    auto* root = computation->root_instruction();
    if (root->opcode() == HloOpcode::kBitcast &&
        root->operand(0)->opcode() == HloOpcode::kDot) {
      // Report the dot as "<lhs dims> DOT <rhs dims>", dims joined by 'x'.
      auto lhs_shape = root->operand(0)->operand(0)->shape();
      auto rhs_shape = root->operand(0)->operand(1)->shape();
      return tensorflow::strings::StrCat(
          tensorflow::str_util::Join(lhs_shape.dimensions(), "x"), " DOT ",
          tensorflow::str_util::Join(rhs_shape.dimensions(), "x"));
    }
    return "UNEXPECTED CHANGE";
  };

  // Default options are the simplest case and succeed.
  options.Reset();
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());

  // Swapping dim spatial and batch order works.
  options.Reset().dim_order = "NWHC";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  options.Reset().dim_order = "WHNC";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  // Channel dimension earlier fails.
  options.Reset().dim_order = "HWCN";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().dim_order = "CHWN";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // Filtering dims spatial dims can be anywhere, since they are 1x1.
  options.Reset().kernel_dim_order = "WHIO";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  options.Reset().kernel_dim_order = "IWOH";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  options.Reset().kernel_dim_order = "IWHO";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
  // But moving output channel before input channel fails.
  options.Reset().kernel_dim_order = "HWOI";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().kernel_dim_order = "WHOI";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().kernel_dim_order = "OWIH";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().kernel_dim_order = "OWHI";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // Combine different dim and kernel dim orders.
  options.Reset().kernel_dim_order = "IWHO";
  options.dim_order = "WHNC";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());

  // Test invalid cases from wrong filter size, strides, or padding.
  options.Reset().f_width = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().f_height = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().row_stride = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().col_stride = 2;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().col_padding = 1;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
  options.Reset().row_padding = 1;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // The default dim_order is "NHWC". Col-major layout makes C the most major.
  options.Reset().input_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // The input and output have different layouts.
  options.Reset().input_minor_to_major_layout = true;
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // C is most minor, and I is more major than O.
  options.Reset().input_minor_to_major_layout = true;
  options.filter_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  options.dim_order = "CHWN";
  options.kernel_dim_order = "OIHW";
  EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());

  // C is not the most minor dimension.
  options.Reset().input_minor_to_major_layout = true;
  options.filter_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  options.dim_order = "HWNC";
  options.kernel_dim_order = "OIHW";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());

  // I is more minor than O.
  options.Reset().input_minor_to_major_layout = true;
  options.filter_minor_to_major_layout = true;
  options.output_minor_to_major_layout = true;
  options.dim_order = "CHWN";
  options.kernel_dim_order = "IOHW";
  EXPECT_EQ("NO_CHANGE", build_and_simplify());
}

// Test that max(min(A, x), y) is transformed to clamp(y, A, x)
TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  // Naming: `min_value` is the constant fed to the Minimum op and `max_value`
  // the constant fed to the Maximum op.
  HloInstruction* min_value = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
  HloInstruction* max_value = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
  HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32, HloOpcode::kMinimum, param0, min_value));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value));

  HloModule module(TestName());
  auto computation = module.AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Maximum(op::Minimum(param0, min_value), max_value));

  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                 non_bitcasting_callback());
  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Clamp(max_value, param0, min_value));
}

// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for scalar
// values.
TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "param0"));
  // Naming: `min_value` is the constant fed to the Minimum op and `max_value`
  // the constant fed to the Maximum op.
  HloInstruction* min_value = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
  HloInstruction* max_value = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
  HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32, HloOpcode::kMaximum, param0, max_value));
  builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));

  HloModule module(TestName());
  auto computation = module.AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Minimum(op::Maximum(param0, max_value), min_value));

  AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                 non_bitcasting_callback());
  ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Clamp(max_value, param0, min_value));
}

// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for
// broadcasted scalar values.
2060 TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { 2061 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 2062 Shape r1f32 = ShapeUtil::MakeShape(F32, {100}); 2063 HloComputation::Builder builder(TestName()); 2064 HloInstruction* param0 = builder.AddInstruction( 2065 HloInstruction::CreateParameter(0, r1f32, "param0")); 2066 HloInstruction* min_value = builder.AddInstruction( 2067 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); 2068 HloInstruction* max_value = builder.AddInstruction( 2069 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); 2070 HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( 2071 r1f32, HloOpcode::kMaximum, param0, max_value)); 2072 builder.AddInstruction( 2073 HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value)); 2074 2075 HloModule module(TestName()); 2076 auto computation = module.AddEntryComputation(builder.Build()); 2077 2078 EXPECT_THAT(computation->root_instruction(), 2079 op::Minimum(op::Maximum(param0, max_value), min_value)); 2080 2081 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2082 non_bitcasting_callback()); 2083 ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); 2084 2085 EXPECT_THAT(computation->root_instruction(), 2086 op::Clamp(max_value, param0, min_value)); 2087 } 2088 2089 // Test that min(max(A, non-constant1), non-constant2) is not canonicalized to 2090 // clamp(non-constant1, A, non-constant2) 2091 TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { 2092 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 2093 HloComputation::Builder builder(TestName()); 2094 HloInstruction* param0 = builder.AddInstruction( 2095 HloInstruction::CreateParameter(0, r0f32, "param0")); 2096 HloInstruction* min_value = builder.AddInstruction( 2097 HloInstruction::CreateParameter(1, r0f32, "param1")); 2098 HloInstruction* max_value = builder.AddInstruction( 2099 HloInstruction::CreateParameter(2, r0f32, "param2")); 2100 HloInstruction* max = 
builder.AddInstruction(HloInstruction::CreateBinary( 2101 r0f32, HloOpcode::kMaximum, param0, max_value)); 2102 builder.AddInstruction( 2103 HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); 2104 2105 HloModule module(TestName()); 2106 auto computation = module.AddEntryComputation(builder.Build()); 2107 2108 EXPECT_THAT(computation->root_instruction(), 2109 op::Minimum(op::Maximum(param0, max_value), min_value)); 2110 2111 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2112 non_bitcasting_callback()); 2113 EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); 2114 2115 EXPECT_THAT(computation->root_instruction(), 2116 op::Minimum(op::Maximum(param0, max_value), min_value)); 2117 } 2118 2119 // Test that min(f(max(A, constant1)), constant2) is not transformed to 2120 // clamp(constant1, A, constant2) 2121 TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { 2122 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 2123 HloComputation::Builder builder(TestName()); 2124 HloInstruction* param0 = builder.AddInstruction( 2125 HloInstruction::CreateParameter(0, r0f32, "param0")); 2126 HloInstruction* min_value = builder.AddInstruction( 2127 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); 2128 HloInstruction* max_value = builder.AddInstruction( 2129 HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f))); 2130 HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( 2131 r0f32, HloOpcode::kMaximum, param0, max_value)); 2132 HloInstruction* fmax = builder.AddInstruction( 2133 HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, max, max_value)); 2134 builder.AddInstruction(HloInstruction::CreateBinary( 2135 r0f32, HloOpcode::kMinimum, fmax, min_value)); 2136 2137 HloModule module(TestName()); 2138 auto computation = module.AddEntryComputation(builder.Build()); 2139 2140 EXPECT_THAT(computation->root_instruction(), 2141 op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), 2142 
min_value)); 2143 2144 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2145 non_bitcasting_callback()); 2146 EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); 2147 2148 EXPECT_THAT(computation->root_instruction(), 2149 op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), 2150 min_value)); 2151 } 2152 2153 // Test that slice(broadcast(/*scalar value*/)) simplifies to a single 2154 // broadcast. 2155 TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { 2156 Shape r0f32 = ShapeUtil::MakeShape(F32, {}); 2157 HloComputation::Builder builder(TestName()); 2158 HloInstruction* scalar_param = builder.AddInstruction( 2159 HloInstruction::CreateParameter(0, r0f32, "scalar_param")); 2160 2161 Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); 2162 HloInstruction* broadcast = 2163 builder.AddInstruction(HloInstruction::CreateBroadcast( 2164 broadcast_shape, scalar_param, 2165 AsInt64Slice(broadcast_shape.dimensions()))); 2166 2167 Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3}); 2168 HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( 2169 slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1})); 2170 2171 HloModule module(TestName()); 2172 auto computation = module.AddEntryComputation(builder.Build()); 2173 2174 HloInstruction* root = computation->root_instruction(); 2175 EXPECT_EQ(root, slice); 2176 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape)); 2177 2178 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2179 non_bitcasting_callback()); 2180 2181 ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); 2182 2183 // Running simplification again should not result in any further changes. 
2184 ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); 2185 2186 root = computation->root_instruction(); 2187 EXPECT_THAT(root, op::Broadcast(scalar_param)); 2188 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape)); 2189 } 2190 2191 // Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a 2192 // single broadcast. 2193 TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { 2194 HloComputation::Builder builder(TestName()); 2195 HloInstruction* forty_two = builder.AddInstruction( 2196 HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); 2197 2198 Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); 2199 HloInstruction* broadcast = 2200 builder.AddInstruction(HloInstruction::CreateBroadcast( 2201 broadcast_shape, forty_two, 2202 AsInt64Slice(broadcast_shape.dimensions()))); 2203 2204 HloInstruction* transpose = 2205 builder.AddInstruction(HloInstruction::CreateTranspose( 2206 ShapeUtil::MakeShape(F32, {6, 5, 4}), broadcast, {2, 1, 0})); 2207 2208 Shape reshape_shape = ShapeUtil::MakeShape(F32, {30, 1, 4}); 2209 HloInstruction* reshape = builder.AddInstruction( 2210 HloInstruction::CreateReshape(reshape_shape, transpose)); 2211 2212 HloModule module(TestName()); 2213 auto computation = module.AddEntryComputation(builder.Build()); 2214 2215 HloInstruction* root = computation->root_instruction(); 2216 EXPECT_EQ(root, reshape); 2217 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); 2218 2219 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2220 non_bitcasting_callback()); 2221 ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); 2222 2223 root = computation->root_instruction(); 2224 EXPECT_THAT(root, op::Broadcast(forty_two)); 2225 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape)); 2226 } 2227 2228 // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). 
2229 TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { 2230 HloModule module(TestName()); 2231 HloComputation::Builder builder(TestName()); 2232 2233 // Create operand to the pad. 2234 HloInstruction* operand = 2235 builder.AddInstruction(HloInstruction::CreateParameter( 2236 0, ShapeUtil::MakeShape(F32, {1, 2, 3, 4}), "p0")); 2237 2238 // Create the pad. 2239 PaddingConfig padding = MakeNoPaddingConfig(4); 2240 padding.mutable_dimensions(1)->set_edge_padding_low(1); 2241 padding.mutable_dimensions(3)->set_edge_padding_high(2); 2242 2243 HloInstruction* pad_value = builder.AddInstruction( 2244 HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f))); 2245 HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( 2246 ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding)); 2247 2248 // Create add computation. 2249 HloComputation* add_computation = nullptr; 2250 { 2251 HloComputation::Builder builder(TestName() + ".add"); 2252 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); 2253 HloInstruction* p0 = builder.AddInstruction( 2254 HloInstruction::CreateParameter(0, scalar_shape, "p0")); 2255 HloInstruction* p1 = builder.AddInstruction( 2256 HloInstruction::CreateParameter(1, scalar_shape, "p1")); 2257 builder.AddInstruction( 2258 HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); 2259 add_computation = module.AddEmbeddedComputation(builder.Build()); 2260 } 2261 2262 // Create the reduce-window. 
2263 Window window; 2264 for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) { 2265 auto* dim = window.add_dimensions(); 2266 dim->set_size(1); 2267 dim->set_padding_low(10); 2268 dim->set_padding_high(100); 2269 dim->set_window_dilation(1); 2270 dim->set_base_dilation(1); 2271 } 2272 const Shape reduce_window_shape = 2273 ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); 2274 HloInstruction* reduce_init_value = builder.AddInstruction( 2275 HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f))); 2276 HloInstruction* reduce_window = 2277 builder.AddInstruction(HloInstruction::CreateReduceWindow( 2278 reduce_window_shape, pad, reduce_init_value, window, 2279 add_computation)); 2280 2281 // Build the computation and run the simplifier. 2282 auto computation = module.AddEntryComputation(builder.Build()); 2283 HloInstruction* root = computation->root_instruction(); 2284 EXPECT_EQ(root, reduce_window); 2285 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2286 non_bitcasting_callback()); 2287 ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); 2288 2289 // Running simplification again should not result in any further changes. 
2290 ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); 2291 2292 // Verify the result 2293 root = computation->root_instruction(); 2294 EXPECT_THAT(root, op::ReduceWindow(operand, op::Constant())); 2295 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape)) 2296 << ShapeUtil::HumanString(root->shape()) << " vs " 2297 << ShapeUtil::HumanString(reduce_window_shape); 2298 EXPECT_EQ(root->window().dimensions(0).padding_low(), 10); 2299 EXPECT_EQ(root->window().dimensions(1).padding_low(), 11); 2300 EXPECT_EQ(root->window().dimensions(2).padding_low(), 10); 2301 EXPECT_EQ(root->window().dimensions(3).padding_low(), 10); 2302 EXPECT_EQ(root->window().dimensions(0).padding_high(), 100); 2303 EXPECT_EQ(root->window().dimensions(1).padding_high(), 100); 2304 EXPECT_EQ(root->window().dimensions(2).padding_high(), 100); 2305 EXPECT_EQ(root->window().dimensions(3).padding_high(), 102); 2306 } 2307 2308 TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { 2309 HloComputation::Builder builder(TestName()); 2310 const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1}); 2311 HloInstruction* a = 2312 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); 2313 builder.AddInstruction( 2314 HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3})); 2315 2316 HloModule module(TestName()); 2317 auto computation = module.AddEntryComputation(builder.Build()); 2318 2319 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2320 non_bitcasting_callback()); 2321 ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); 2322 2323 HloInstruction* root = computation->root_instruction(); 2324 EXPECT_EQ(a, root); 2325 EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape)); 2326 } 2327 2328 TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { 2329 // Dots add computations to the parent module. 
Test that, when the HloModule's 2330 // computations are updated, then iterator invalidation doesn't occur 2331 // when running on subsequent computations. 2332 Shape r1f32 = ShapeUtil::MakeShape(F32, {1}); 2333 HloComputation::Builder builder(TestName() + ".Dot"); 2334 HloInstruction* x = 2335 builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x")); 2336 HloInstruction* y = 2337 builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); 2338 DotDimensionNumbers dot_dnums; 2339 dot_dnums.add_lhs_contracting_dimensions(1); 2340 dot_dnums.add_rhs_contracting_dimensions(0); 2341 builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums)); 2342 std::unique_ptr<HloComputation> dot_computation(builder.Build()); 2343 2344 HloComputation::Builder call_builder(TestName() + ".Call"); 2345 HloInstruction* zero = call_builder.AddInstruction( 2346 HloInstruction::CreateConstant(Literal::CreateR1<float>({0.0f}))); 2347 HloInstruction* one = call_builder.AddInstruction( 2348 HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0f}))); 2349 call_builder.AddInstruction( 2350 HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); 2351 2352 module().AddEmbeddedComputation(std::move(dot_computation)); 2353 module().AddEntryComputation(call_builder.Build()); 2354 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2355 non_bitcasting_callback()); 2356 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 2357 } 2358 2359 // Test that a constant with tuple shape becomes a tuple of constants. 
2360 TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { 2361 HloComputation::Builder builder(TestName()); 2362 const float constant_scalar = 7.3f; 2363 std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f}; 2364 std::unique_ptr<Literal> value = 2365 Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(), 2366 Literal::CreateR1<float>(constant_vector).get()}); 2367 builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); 2368 2369 auto computation = module().AddEntryComputation(builder.Build()); 2370 2371 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2372 non_bitcasting_callback()); 2373 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 2374 EXPECT_THAT(computation->root_instruction(), 2375 op::Tuple(op::Constant(), op::Constant())); 2376 } 2377 2378 // A dynamic-slice is trivial if its start indices are all zeroes and the size 2379 // of its input equals the size of its output. In this case, the dynamic slice 2380 // is equal to its input. 2381 TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { 2382 HloComputation::Builder builder(TestName()); 2383 2384 Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); 2385 builder.AddInstruction(HloInstruction::CreateDynamicSlice( 2386 shape, 2387 builder.AddInstruction( 2388 HloInstruction::CreateParameter(0, shape, "slice_from")), 2389 builder.AddInstruction( 2390 HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0}))), 2391 /*slice_sizes=*/{10, 100, 1000})); 2392 2393 auto computation = module().AddEntryComputation(builder.Build()); 2394 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2395 non_bitcasting_callback()); 2396 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 2397 EXPECT_THAT(computation->root_instruction(), op::Parameter()); 2398 } 2399 2400 // A dynamic-update-slice is trivial if its start indices are all zeroes and the 2401 // size of its "update" equals the size of its output. 
In this case, the 2402 // dynamic-update-slice is equal to its update. 2403 TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { 2404 HloComputation::Builder builder(TestName()); 2405 2406 Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); 2407 Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000}); 2408 2409 HloInstruction* slice = 2410 builder.AddInstruction(HloInstruction::CreateDynamicSlice( 2411 slice_shape, 2412 builder.AddInstruction( 2413 HloInstruction::CreateParameter(0, full_shape, "slice_from")), 2414 builder.AddInstruction(HloInstruction::CreateParameter( 2415 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), 2416 /*slice_sizes=*/{10, 1, 1000})); 2417 2418 builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( 2419 slice_shape, 2420 builder.AddInstruction( 2421 HloInstruction::CreateParameter(2, slice_shape, "to_update")), 2422 slice, 2423 builder.AddInstruction( 2424 HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0}))))); 2425 2426 auto computation = module().AddEntryComputation(builder.Build()); 2427 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2428 non_bitcasting_callback()); 2429 ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); 2430 EXPECT_THAT(computation->root_instruction(), 2431 op::DynamicSlice(op::Parameter(), op::Parameter())); 2432 } 2433 2434 struct PadReduceWindowEffectiveBroadcastCase { 2435 std::vector<int64> input_spatials; 2436 std::vector<int64> symmetric_pad_spatials; 2437 std::vector<int64> reduce_window_spatials; 2438 // Whether to use `B F S0 S1` form vs `B S0 S1 F` form. 2439 // 2440 // This doesn't test any different functionality but is useful for making sure 2441 // kBroadcast nodes are well formed. 
2442 bool prepend_a; 2443 bool should_become_broadcast; 2444 2445 string ToTestCaseName() const { 2446 return tensorflow::strings::StrCat( 2447 tensorflow::str_util::Join(input_spatials, ","), ";", 2448 tensorflow::str_util::Join(symmetric_pad_spatials, ","), ";", 2449 tensorflow::str_util::Join(reduce_window_spatials, ","), ";", prepend_a, 2450 ";", should_become_broadcast); 2451 } 2452 }; 2453 2454 void PrintTo(const PadReduceWindowEffectiveBroadcastCase& c, std::ostream* os) { 2455 *os << c.ToTestCaseName(); 2456 } 2457 2458 class PadReduceWindowEffectiveBroadcastTest 2459 : public AlgebraicSimplifierTest, 2460 public ::testing::WithParamInterface< 2461 PadReduceWindowEffectiveBroadcastCase> {}; 2462 2463 TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { 2464 const auto& param = GetParam(); 2465 2466 // a and b are parallel bounds we can either turn into a B F S0 S1 or 2467 // `B S0 S1 F` kind of pattern. 2468 auto decorate_spatials = [¶m](tensorflow::gtl::ArraySlice<int64> spatials, 2469 int64 a, int64 b) { 2470 std::vector<int64> result; 2471 if (param.prepend_a) { 2472 result.push_back(a); 2473 } 2474 for (int64 s : spatials) { 2475 result.push_back(s); 2476 } 2477 if (!param.prepend_a) { 2478 result.push_back(a); 2479 } 2480 result.push_back(b); 2481 return result; 2482 }; 2483 2484 HloComputation::Builder builder(TestName()); 2485 const Shape input_shape = ShapeUtil::MakeShape( 2486 F32, decorate_spatials(param.input_spatials, 128, 2048)); 2487 HloInstruction* input = builder.AddInstruction( 2488 HloInstruction::CreateParameter(0, input_shape, "input")); 2489 2490 PaddingConfig padding = window_util::MakeSymmetricPadding( 2491 decorate_spatials(param.symmetric_pad_spatials, 0, 0)); 2492 TF_ASSERT_OK_AND_ASSIGN( 2493 const Shape pad_shape, 2494 ShapeInference::InferPadShape(input->shape(), 2495 ShapeUtil::MakeShape(F32, {}), padding)); 2496 HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( 2497 pad_shape, input, 2498 
builder.AddInstruction( 2499 HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), 2500 padding)); 2501 2502 HloComputation* add_computation = nullptr; 2503 { 2504 HloComputation::Builder builder(TestName() + ".add"); 2505 const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); 2506 HloInstruction* p0 = builder.AddInstruction( 2507 HloInstruction::CreateParameter(0, scalar_shape, "p0")); 2508 HloInstruction* p1 = builder.AddInstruction( 2509 HloInstruction::CreateParameter(1, scalar_shape, "p1")); 2510 builder.AddInstruction( 2511 HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); 2512 add_computation = module().AddEmbeddedComputation(builder.Build()); 2513 } 2514 2515 Window window = window_util::MakeWindow( 2516 decorate_spatials(param.reduce_window_spatials, 1, 1)); 2517 auto zero = builder.AddInstruction( 2518 HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))); 2519 TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, 2520 ShapeInference::InferReduceWindowShape( 2521 pad->shape(), zero->shape(), window, 2522 add_computation->ComputeProgramShape())); 2523 builder.AddInstruction(HloInstruction::CreateReduceWindow( 2524 output_shape, pad, zero, window, add_computation)); 2525 2526 auto computation = module().AddEntryComputation(builder.Build()); 2527 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2528 non_bitcasting_callback()); 2529 TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); 2530 ASSERT_TRUE(run_successful); 2531 2532 EXPECT_TRUE( 2533 ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape)); 2534 2535 if (param.should_become_broadcast) { 2536 EXPECT_THAT(computation->root_instruction(), op::Broadcast(::testing::_)); 2537 } else { 2538 EXPECT_THAT(computation->root_instruction(), 2539 op::ReduceWindow(::testing::_, zero)); 2540 } 2541 } 2542 2543 const std::vector<PadReduceWindowEffectiveBroadcastCase>& 2544 PadReduceWindowEffectiveBroadcastCases() { 2545 static auto* 
cases = new std::vector<PadReduceWindowEffectiveBroadcastCase>{ 2546 {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6}, 2547 /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true, 2548 /*should_become_broadcast=*/true}, // 2549 {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6}, 2550 /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/false, 2551 /*should_become_broadcast=*/true}, // 2552 {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6}, 2553 /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true, 2554 /*should_become_broadcast=*/false}, // 2555 {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2}, 2556 /*reduce_window_spatials=*/{5, 5}, /*prepend_a=*/true, 2557 /*should_become_broadcast=*/true}, // 2558 {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2}, 2559 /*reduce_window_spatials=*/{1, 1}, /*prepend_a=*/true, 2560 /*should_become_broadcast=*/false}, // 2561 {/*input_spatials=*/{5, 1}, /*symmetric_pad_amount=*/{0, 2}, 2562 /*reduce_window_spatials=*/{2, 5}, /*prepend_a=*/true, 2563 /*should_become_broadcast=*/false}, // 2564 }; 2565 return *cases; 2566 } 2567 2568 INSTANTIATE_TEST_CASE_P( 2569 PadReduceWindowEffectiveBroadcastInstantiation, 2570 PadReduceWindowEffectiveBroadcastTest, 2571 ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases())); 2572 2573 class DotStrengthReductionTest 2574 : public AlgebraicSimplifierTest, 2575 public ::testing::WithParamInterface< 2576 ::testing::tuple<int, int, int, bool, bool>> {}; 2577 TEST_P(DotStrengthReductionTest, DotStrengthReduction) { 2578 int m, k, n; 2579 bool transpose_lhs, transpose_rhs; 2580 std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam(); 2581 2582 Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); 2583 Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); 2584 Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m}); 2585 Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); 2586 Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k}); 
2587 HloComputation::Builder builder(TestName()); 2588 2589 auto lhs = builder.AddInstruction(HloInstruction::CreateParameter( 2590 0, transpose_lhs ? transposed_lhs_shape : lhs_shape, "lhs")); 2591 if (transpose_lhs) { 2592 lhs = builder.AddInstruction( 2593 HloInstruction::CreateTranspose(lhs_shape, lhs, {1, 0})); 2594 } 2595 auto rhs = builder.AddInstruction(HloInstruction::CreateParameter( 2596 1, transpose_rhs ? transposed_rhs_shape : rhs_shape, "rhs")); 2597 if (transpose_rhs) { 2598 rhs = builder.AddInstruction( 2599 HloInstruction::CreateTranspose(rhs_shape, rhs, {1, 0})); 2600 } 2601 DotDimensionNumbers dot_dnums; 2602 dot_dnums.add_lhs_contracting_dimensions(1); 2603 dot_dnums.add_rhs_contracting_dimensions(0); 2604 builder.AddInstruction( 2605 HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); 2606 auto computation = module().AddEntryComputation(builder.Build()); 2607 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2608 non_bitcasting_callback()); 2609 TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(&module())); 2610 const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; 2611 const bool computation_should_be_modified = 2612 dot_should_be_transformed || (transpose_lhs && transpose_rhs); 2613 EXPECT_EQ(changed, computation_should_be_modified); 2614 bool has_no_dot = true; 2615 for (const auto& hlo : computation->instructions()) { 2616 if (hlo->opcode() == HloOpcode::kDot) { 2617 has_no_dot = false; 2618 break; 2619 } 2620 } 2621 EXPECT_EQ(has_no_dot, dot_should_be_transformed); 2622 } 2623 2624 INSTANTIATE_TEST_CASE_P( 2625 DotStrengthReductionTestInstantiation, DotStrengthReductionTest, 2626 ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), 2627 ::testing::Values(1, 2), ::testing::Bool(), 2628 ::testing::Bool())); 2629 2630 struct DotOfConcatTestSpec { 2631 int64 m; 2632 int64 k; 2633 int64 n; 2634 }; 2635 2636 class DotOfConcatSimplificationTest 2637 : public HloVerifiedTestBase, 2638 public 
::testing::WithParamInterface<DotOfConcatTestSpec> {}; 2639 2640 // Test that we transform 2641 // dot(const, concat(A, B, C)) 2642 // to 2643 // add(dot(const_0, A), dot(const_1, B), dot(const_2, C)) 2644 TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { 2645 HloComputation::Builder builder(TestName()); 2646 2647 DotOfConcatTestSpec spec = GetParam(); 2648 2649 ASSERT_GE(spec.k, 3); 2650 2651 int64 k0 = spec.k / 3; 2652 int64 k1 = spec.k / 3; 2653 int64 k2 = spec.k - k0 - k1; 2654 2655 Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); 2656 auto* lhs = builder.AddInstruction( 2657 HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( 2658 /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k))); 2659 2660 Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n}); 2661 Shape rhs1_shape = ShapeUtil::MakeShape(F32, {k1, spec.n}); 2662 Shape rhs2_shape = ShapeUtil::MakeShape(F32, {k2, spec.n}); 2663 2664 HloInstruction* rhs0 = builder.AddInstruction( 2665 HloInstruction::CreateParameter(0, rhs0_shape, "rhs0")); 2666 HloInstruction* rhs1 = builder.AddInstruction( 2667 HloInstruction::CreateParameter(1, rhs1_shape, "rhs1")); 2668 HloInstruction* rhs2 = builder.AddInstruction( 2669 HloInstruction::CreateParameter(2, rhs2_shape, "rhs2")); 2670 2671 Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n}); 2672 HloInstruction* rhs = builder.AddInstruction( 2673 HloInstruction::CreateConcatenate(rhs_shape, {rhs0, rhs1, rhs2}, 0)); 2674 2675 DotDimensionNumbers dot_dnums; 2676 dot_dnums.add_lhs_contracting_dimensions(1); 2677 dot_dnums.add_rhs_contracting_dimensions(0); 2678 2679 Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); 2680 builder.AddInstruction( 2681 HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); 2682 2683 auto computation = module().AddEntryComputation(builder.Build()); 2684 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2685 non_bitcasting_callback()); 2686 
TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); 2687 ASSERT_TRUE(run_successful); 2688 2689 EXPECT_TRUE( 2690 ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); 2691 2692 auto match_dot_0 = op::Dot(op::Slice(op::Constant()), op::Parameter(0)); 2693 auto match_dot_1 = op::Dot(op::Slice(op::Constant()), op::Parameter(1)); 2694 auto match_dot_2 = op::Dot(op::Slice(op::Constant()), op::Parameter(2)); 2695 EXPECT_THAT(computation->root_instruction(), 2696 op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2)); 2697 } 2698 2699 // Test that we transform 2700 // dot(concat(A, B, C), const) 2701 // to 2702 // add(dot(A, const_0), dot(B, const_1), dot(C, const_2)) 2703 TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { 2704 HloComputation::Builder builder(TestName()); 2705 2706 DotOfConcatTestSpec spec = GetParam(); 2707 2708 ASSERT_GE(spec.k, 4); 2709 2710 int64 k0 = spec.k / 4; 2711 int64 k1 = spec.k / 4; 2712 int64 k2 = spec.k / 4; 2713 int64 k3 = spec.k - k0 - k1 - k2; 2714 2715 Shape lhs0_shape = ShapeUtil::MakeShape(F32, {spec.m, k0}); 2716 Shape lhs1_shape = ShapeUtil::MakeShape(F32, {spec.m, k1}); 2717 Shape lhs2_shape = ShapeUtil::MakeShape(F32, {spec.m, k2}); 2718 Shape lhs3_shape = ShapeUtil::MakeShape(F32, {spec.m, k3}); 2719 2720 HloInstruction* lhs0 = builder.AddInstruction( 2721 HloInstruction::CreateParameter(0, lhs0_shape, "lhs0")); 2722 HloInstruction* lhs1 = builder.AddInstruction( 2723 HloInstruction::CreateParameter(1, lhs1_shape, "lhs1")); 2724 HloInstruction* lhs2 = builder.AddInstruction( 2725 HloInstruction::CreateParameter(2, lhs2_shape, "lhs2")); 2726 HloInstruction* lhs3 = builder.AddInstruction( 2727 HloInstruction::CreateParameter(3, lhs3_shape, "lhs3")); 2728 2729 Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); 2730 HloInstruction* lhs = 2731 builder.AddInstruction(HloInstruction::CreateConcatenate( 2732 lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1)); 2733 2734 Shape rhs_shape = 
ShapeUtil::MakeShape(F32, {spec.k, spec.n}); 2735 auto* rhs = builder.AddInstruction( 2736 HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( 2737 /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n))); 2738 2739 DotDimensionNumbers dot_dnums; 2740 dot_dnums.add_lhs_contracting_dimensions(1); 2741 dot_dnums.add_rhs_contracting_dimensions(0); 2742 2743 Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); 2744 builder.AddInstruction( 2745 HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); 2746 2747 auto computation = module().AddEntryComputation(builder.Build()); 2748 AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, 2749 non_bitcasting_callback()); 2750 TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); 2751 ASSERT_TRUE(run_successful); 2752 EXPECT_TRUE( 2753 ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); 2754 2755 auto match_dot_0 = op::Dot(op::Parameter(0), op::Slice(op::Constant())); 2756 auto match_dot_1 = op::Dot(op::Parameter(1), op::Slice(op::Constant())); 2757 auto match_dot_2 = op::Dot(op::Parameter(2), op::Slice(op::Constant())); 2758 auto match_dot_3 = op::Dot(op::Parameter(3), op::Slice(op::Constant())); 2759 EXPECT_THAT(computation->root_instruction(), 2760 op::Add(op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2), 2761 match_dot_3)); 2762 } 2763 2764 DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { 2765 {/*m=*/3, /*k=*/9, /*n=*/3}, // 2766 {/*m=*/3, /*k=*/20, /*n=*/3}, // 2767 {/*m=*/1, /*k=*/18, /*n=*/5}, // 2768 {/*m=*/20, /*k=*/20, /*n=*/1}, // 2769 {/*m=*/1, /*k=*/16, /*n=*/1}, // 2770 }; 2771 2772 INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, 2773 DotOfConcatSimplificationTest, 2774 ::testing::ValuesIn(kDotOfConcatTestSpecs)); 2775 } // namespace 2776 } // namespace xla 2777