/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/reshape_mover.h"

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace op = xla::testing::opcode_matchers;

namespace xla {
namespace {

using ReshapeMoverTest = HloVerifiedTestBase;

// Verifies that two reshapes whose operands have different input shapes
// ({1,8,1,7} vs {1,8,7,1}) are not sunk below the add: ReshapeMover reports
// no change and the graph is left intact.
TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) {
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 8, 7, 1}), "param1"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto computation = module().AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(param1)));

  // The pass must report "no change made".
  EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(param1)));
}

// For a graph that looks like:
//
// +- reshape0 - rng0
// |
// +- const1
// |
// add
//
// where rng0 has a different shape than reshape0.
//
// Verifies that the reshape is not moved, since rng0 is trivially reshapable
// and therefore there is no nontrivial reshapes to move.
TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) {
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  // The RNG result has a different (rank-5) shape, but RNG outputs are
  // considered trivially reshapable, so its reshape does not count as
  // "nontrivial" for the pass.
  auto rng0 = builder.AddInstruction(
      HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {1, 8, 1, 7, 1}),
                                RandomDistribution::RNG_UNIFORM, {}));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0));

  auto const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateFromShape(root_shape)));

  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, const1));

  auto computation = module().AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(rng0), const1));

  EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(rng0), const1));
}

// Verifies that reshapes to a scalar shape are never sunk: both operands are
// reshaped {1,1,1} -> {} but the pass must leave the graph unchanged.
TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) {
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 1, 1}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 1, 1}), "param1"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto computation = module().AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(param1)));

  EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie());

  EXPECT_THAT(
      computation->root_instruction(),
      op::Add(op::Reshape(op::Parameter()), op::Reshape(op::Parameter())));
}

// Verifies the basic success case: two equivalent reshapes ({1,8,1,7} ->
// {8,7}) feeding an add are sunk below it, leaving add(param0, param1)
// wrapped in a single reshape, with the root shape preserved.
TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) {
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto computation = module().AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(param1)));
  EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Reshape(op::Add(param0, param1)));
  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}

// For a graph that looks like:
//
// +- reshape2 - param2
// |
// +- reshape1 - param1
// |
// +- constant0
// |
// select
//
// Verifies that the reshape1 and reshape2 sink past select:
//
// +- param2
// |
// +- param1
// |
// +- reshape3(constant0)
// |
// select
// |
// reshape4
TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) {
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {2, 3});
  auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
      Literal::CreateR2<bool>({{true, true, false}, {false, false, true}})));

  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param1"));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));

  auto param2 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param2"));
  auto reshape2 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param2));

  builder.AddInstruction(HloInstruction::CreateTernary(
      root_shape, HloOpcode::kSelect, const0, reshape1, reshape2));

  auto computation = module().AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Select(const0, reshape1, reshape2));

  EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie());

  // The constant predicate gets a (trivial) reshape so all select operands
  // agree, and the nontrivial reshapes sink below the select.
  EXPECT_THAT(computation->root_instruction(),
              op::Reshape(op::Select(op::Reshape(const0), param1, param2)));

  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}

// For a graph that looks like:
//
// +- reshape0 - param0
// |
// +- param1
// |
// add
//
// Verifies that the reshape0 does not sink below add, because param1 is not
// trivially reshapable nor is a Reshape/Transpose.
218 TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { 219 HloComputation::Builder builder(TestName()); 220 auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); 221 auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 222 0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0")); 223 auto reshape0 = 224 builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); 225 auto param1 = builder.AddInstruction( 226 HloInstruction::CreateParameter(1, root_shape, "param1")); 227 builder.AddInstruction(HloInstruction::CreateBinary( 228 root_shape, HloOpcode::kAdd, reshape0, param1)); 229 230 auto computation = module().AddEntryComputation(builder.Build()); 231 232 EXPECT_THAT(computation->root_instruction(), 233 op::Add(op::Reshape(param0), param1)); 234 EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); 235 236 EXPECT_THAT(computation->root_instruction(), 237 op::Add(op::Reshape(param0), param1)); 238 EXPECT_EQ(root_shape.DebugString(), 239 computation->root_instruction()->shape().DebugString()); 240 } 241 242 // For a graph that looks like: 243 // 244 // +- pred 245 // | 246 // +- reshape0 - const0 247 // | 248 // +- reshape1 - const1 249 // | 250 // select 251 // 252 // Verifies that we don't unnecessarily sink reshapes, which are in fact 253 // trivial reshapes. 
254 TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { 255 HloComputation::Builder builder(TestName()); 256 auto root_shape = ShapeUtil::MakeShape(F32, {3, 2}); 257 auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( 258 Literal::CreateR2<float>({{1, 2, 3}, {4, 5, 6}}))); 259 auto reshape0 = 260 builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const0)); 261 262 auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( 263 Literal::CreateR2<float>({{1, 2, 3}, {4, 5, 6}}))); 264 auto reshape1 = 265 builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1)); 266 267 auto pred = builder.AddInstruction(HloInstruction::CreateParameter( 268 0, ShapeUtil::MakeShape(PRED, {3, 2}), "pred")); 269 270 builder.AddInstruction(HloInstruction::CreateTernary( 271 root_shape, HloOpcode::kSelect, pred, reshape0, reshape1)); 272 273 auto computation = module().AddEntryComputation(builder.Build()); 274 275 EXPECT_THAT(computation->root_instruction(), 276 op::Select(pred, op::Reshape(const0), op::Reshape(const1))); 277 278 EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); 279 280 EXPECT_THAT(computation->root_instruction(), 281 op::Select(pred, op::Reshape(const0), op::Reshape(const1))); 282 EXPECT_EQ(root_shape.DebugString(), 283 computation->root_instruction()->shape().DebugString()); 284 } 285 286 // For a graph that looks like: 287 // 288 // +- reshape0 - param0 289 // | 290 // +- const1 291 // | 292 // add 293 // 294 // where there is only 1 non-trivial reshape (reshape0), we sink the reshape 295 // here for canonicalization benefit: 296 // 297 // +- param0 298 // | 299 // +- reshape1 - const1 300 // | 301 // add 302 // | 303 // reshape2 304 // 305 // (note that reshape1 here is trivial). 
306 TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { 307 HloComputation::Builder builder(TestName()); 308 auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); 309 auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 310 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param0")); 311 auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( 312 Literal::CreateR2<float>({{1, 2, 3}, {4, 5, 6}}))); 313 auto reshape0 = 314 builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); 315 builder.AddInstruction(HloInstruction::CreateBinary( 316 root_shape, HloOpcode::kAdd, reshape0, const1)); 317 318 auto computation = module().AddEntryComputation(builder.Build()); 319 320 EXPECT_THAT(computation->root_instruction(), 321 op::Add(op::Reshape(param0), const1)); 322 323 EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); 324 325 EXPECT_THAT(computation->root_instruction(), 326 op::Reshape(op::Add(param0, op::Reshape(const1)))); 327 EXPECT_EQ(root_shape.DebugString(), 328 computation->root_instruction()->shape().DebugString()); 329 } 330 331 // For a graph that looks like: 332 // 333 // +- reshape0 - param0 (shape A) 334 // | 335 // +- reshape1 - const1 (shape B) 336 // | 337 // add 338 // 339 // There is 1 non-trivial reshape (reshape0). It's not clear whether reshape1 340 // should be trivial or not; conceptually it's trivial, but handling it would 341 // complicate the rest of our logic. 342 // 343 // For now we treat it as non-trivial, so we verify that we don't sink the 344 // reshapes in this case. 
TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) {
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 3}), "param0"));
  auto const1 = builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR1<float>({9, 8, 7})));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  // The reshape of the constant (shape B -> root) is treated as non-trivial,
  // so the pass must not sink anything here.
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1));

  builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto computation = module().AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(const1)));

  EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Add(op::Reshape(param0), op::Reshape(const1)));
  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}

// Verifies that equivalent reshapes also sink below a (loop) fusion
// instruction wrapping the add, not just plain elementwise ops.
TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) {
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      root_shape, HloOpcode::kAdd, reshape0, reshape1));

  auto computation = module().AddEntryComputation(builder.Build());
  // Fuse the add so the reshapes feed a fusion instruction.
  computation->CreateFusionInstruction({add},
                                       HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(computation->root_instruction(),
              op::Fusion(op::Reshape(param0), op::Reshape(param1)));

  EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Reshape(op::Fusion(param0, param1)));
  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}

// Verifies that equivalent reshapes on all three select operands (including
// the predicate) sink below the select.
TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) {
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
  auto pred_shape = ShapeUtil::MakeShape(PRED, {8, 7});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {1, 8, 1, 7}), "pred"));
  auto reshape0 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
  auto reshape1 =
      builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
  auto reshape_pred =
      builder.AddInstruction(HloInstruction::CreateReshape(pred_shape, pred));
  builder.AddInstruction(HloInstruction::CreateTernary(
      root_shape, HloOpcode::kSelect, reshape_pred, reshape0, reshape1));

  auto computation = module().AddEntryComputation(builder.Build());

  EXPECT_THAT(
      computation->root_instruction(),
      op::Select(op::Reshape(pred), op::Reshape(param0), op::Reshape(param1)));

  EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Reshape(op::Select(pred, param0, param1)));
  EXPECT_EQ(root_shape.DebugString(),
            computation->root_instruction()->shape().DebugString());
}

// Verifies that a reshape-to-scalar on the select predicate alone is not
// moved (scalar reshapes are never sunk).
TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) {
  HloComputation::Builder builder(TestName());
  auto root_shape = ShapeUtil::MakeShape(F32, {});
  auto pred_shape = ShapeUtil::MakeShape(PRED, {});
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {}), "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {1, 1, 1}), "pred"));
  auto reshape_pred =
      builder.AddInstruction(HloInstruction::CreateReshape(pred_shape, pred));
  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      root_shape, HloOpcode::kSelect, reshape_pred, param0, param1));

  auto computation = module().AddEntryComputation(builder.Build());
  EXPECT_THAT(computation->root_instruction(),
              op::Select(op::Reshape(pred), param0, param1));

  EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie());

  EXPECT_THAT(computation->root_instruction(),
              op::Select(op::Reshape(pred), param0, param1));
  EXPECT_EQ(select, computation->root_instruction());
}

// Tree looks like:
//
// param0 [1,128,1]
//   |
// reshape [128,1]   constant [128,1024]
//      \                /
//   multiply w/implicit broadcast [128,1024]
//
// The reshape mover would like to sink the reshape below the multiply.
//
// Previously we would attempt to insert a reshape of the constant to [1,128,1]
// (which is unsound, because it has a different number of elements) as
// preparation for sinking the reshape.
//
// To eliminate the unsoundness, we outlaw reshape sinking when one of the
// operands is implicitly broadcast in the elementwise consumer.
477 // 478 // TODO(b/37799338) However, it would be possible in this case to do a more 479 // in-depth analysis to get reshape movement to occur: 480 // 481 // 1. Note that the broadcast dimension (logical dimension 1) in the operands 482 // would map back to logical dimension 2 in the param0 node. 483 // 2. Match rank of the constant to the param0 node (by prepending a trivial 1 484 // dimension). 485 // 3. Reshape to [128,1024] at the root. 486 // 487 // But this is not currently done. 488 TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) { 489 HloComputation::Builder builder(TestName()); 490 auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 491 0, ShapeUtil::MakeShape(F32, {1, 128, 1}), "param0")); 492 auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( 493 ShapeUtil::MakeShape(F32, {128, 1}), param0)); 494 Array2D<float> a(128, 1024); 495 auto literal = Literal::CreateR2FromArray2D<float>(a); 496 auto constant = builder.AddInstruction( 497 HloInstruction::CreateConstant(std::move(literal))); 498 auto multiply = builder.AddInstruction(HloInstruction::CreateBinary( 499 constant->shape(), HloOpcode::kMultiply, constant, reshape)); 500 501 auto computation = module().AddEntryComputation(builder.Build()); 502 EXPECT_THAT(computation->root_instruction(), 503 op::Multiply(op::Constant(), op::Reshape(param0))); 504 505 EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); 506 507 EXPECT_THAT(computation->root_instruction(), 508 op::Multiply(op::Constant(), op::Reshape(param0))); 509 EXPECT_EQ(multiply, computation->root_instruction()); 510 } 511 512 // Tree looks like this: 513 // 514 // add1 515 // | 516 // +- reshape2 - param2 517 // | 518 // +- reshape3 - add0 519 // | 520 // + reshape0 - param0 521 // | 522 // + reshape1 - param1 523 // 524 // We expect reshape{0,1} AND reshape{2,3} to be lifted. 
525 TEST_F(ReshapeMoverTest, MultiplePasses) { 526 auto shape1 = ShapeUtil::MakeShape(F32, {1, 8, 1, 7}); 527 auto shape2 = ShapeUtil::MakeShape(F32, {8, 7, 1}); 528 auto shape3 = ShapeUtil::MakeShape(F32, {8, 7}); 529 HloComputation::Builder builder(TestName()); 530 auto param0 = builder.AddInstruction( 531 HloInstruction::CreateParameter(0, shape1, "param0")); 532 auto param1 = builder.AddInstruction( 533 HloInstruction::CreateParameter(1, shape1, "param1")); 534 auto param2 = builder.AddInstruction( 535 HloInstruction::CreateParameter(2, shape2, "param2")); 536 auto reshape0 = 537 builder.AddInstruction(HloInstruction::CreateReshape(shape2, param0)); 538 auto reshape1 = 539 builder.AddInstruction(HloInstruction::CreateReshape(shape2, param1)); 540 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( 541 shape2, HloOpcode::kAdd, reshape0, reshape1)); 542 auto reshape2 = 543 builder.AddInstruction(HloInstruction::CreateReshape(shape3, param2)); 544 auto reshape3 = 545 builder.AddInstruction(HloInstruction::CreateReshape(shape3, add0)); 546 builder.AddInstruction(HloInstruction::CreateBinary(shape3, HloOpcode::kAdd, 547 reshape2, reshape3)); 548 549 auto computation = module().AddEntryComputation(builder.Build()); 550 551 EXPECT_THAT( 552 computation->root_instruction(), 553 op::Add(op::Reshape(param2), 554 op::Reshape(op::Add(op::Reshape(param0), op::Reshape(param1))))); 555 556 EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); 557 558 EXPECT_THAT( 559 computation->root_instruction(), 560 op::Reshape(op::Add(param2, op::Reshape(op::Add(param0, param1))))); 561 } 562 563 } // namespace 564 } // namespace xla 565