Home | History | Annotate | Download | only in loop_optimizations
      1 // Copyright (c) 2017 Google Inc.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include <memory>
     16 #include <string>
     17 #include <unordered_set>
     18 #include <vector>
     19 
     20 #include "gmock/gmock.h"
     21 #include "source/opt/iterator.h"
     22 #include "source/opt/loop_descriptor.h"
     23 #include "source/opt/pass.h"
     24 #include "source/opt/tree_iterator.h"
     25 #include "test/opt/assembly_builder.h"
     26 #include "test/opt/function_utils.h"
     27 #include "test/opt/pass_fixture.h"
     28 #include "test/opt/pass_utils.h"
     29 
     30 namespace spvtools {
     31 namespace opt {
     32 namespace {
     33 
     34 using ::testing::UnorderedElementsAre;
     35 
     36 bool Validate(const std::vector<uint32_t>& bin) {
     37   spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
     38   spv_context spvContext = spvContextCreate(target_env);
     39   spv_diagnostic diagnostic = nullptr;
     40   spv_const_binary_t binary = {bin.data(), bin.size()};
     41   spv_result_t error = spvValidate(spvContext, &binary, &diagnostic);
     42   if (error != 0) spvDiagnosticPrint(diagnostic);
     43   spvDiagnosticDestroy(diagnostic);
     44   spvContextDestroy(spvContext);
     45   return error == 0;
     46 }
     47 
     48 using PassClassTest = PassTest<::testing::Test>;
     49 
     50 /*
     51 Generated from the following GLSL
     52 #version 330 core
     53 layout(location = 0) out vec4 c;
     54 void main() {
     55   int i = 0;
     56   for (; i < 10; ++i) {
     57     int j = 0;
     58     int k = 0;
     59     for (; j < 11; ++j) {}
     60     for (; k < 12; ++k) {}
     61   }
     62 }
     63 */
     64 TEST_F(PassClassTest, BasicVisitFromEntryPoint) {
     65   const std::string text = R"(
     66                OpCapability Shader
     67           %1 = OpExtInstImport "GLSL.std.450"
     68                OpMemoryModel Logical GLSL450
     69                OpEntryPoint Fragment %2 "main" %3
     70                OpExecutionMode %2 OriginUpperLeft
     71                OpSource GLSL 330
     72                OpName %2 "main"
     73                OpName %4 "i"
     74                OpName %5 "j"
     75                OpName %6 "k"
     76                OpName %3 "c"
     77                OpDecorate %3 Location 0
     78           %7 = OpTypeVoid
     79           %8 = OpTypeFunction %7
     80           %9 = OpTypeInt 32 1
     81          %10 = OpTypePointer Function %9
     82          %11 = OpConstant %9 0
     83          %12 = OpConstant %9 10
     84          %13 = OpTypeBool
     85          %14 = OpConstant %9 11
     86          %15 = OpConstant %9 1
     87          %16 = OpConstant %9 12
     88          %17 = OpTypeFloat 32
     89          %18 = OpTypeVector %17 4
     90          %19 = OpTypePointer Output %18
     91           %3 = OpVariable %19 Output
     92           %2 = OpFunction %7 None %8
     93          %20 = OpLabel
     94           %4 = OpVariable %10 Function
     95           %5 = OpVariable %10 Function
     96           %6 = OpVariable %10 Function
     97                OpStore %4 %11
     98                OpBranch %21
     99          %21 = OpLabel
    100                OpLoopMerge %22 %23 None
    101                OpBranch %24
    102          %24 = OpLabel
    103          %25 = OpLoad %9 %4
    104          %26 = OpSLessThan %13 %25 %12
    105                OpBranchConditional %26 %27 %22
    106          %27 = OpLabel
    107                OpStore %5 %11
    108                OpStore %6 %11
    109                OpBranch %28
    110          %28 = OpLabel
    111                OpLoopMerge %29 %30 None
    112                OpBranch %31
    113          %31 = OpLabel
    114          %32 = OpLoad %9 %5
    115          %33 = OpSLessThan %13 %32 %14
    116                OpBranchConditional %33 %34 %29
    117          %34 = OpLabel
    118                OpBranch %30
    119          %30 = OpLabel
    120          %35 = OpLoad %9 %5
    121          %36 = OpIAdd %9 %35 %15
    122                OpStore %5 %36
    123                OpBranch %28
    124          %29 = OpLabel
    125                OpBranch %37
    126          %37 = OpLabel
    127                OpLoopMerge %38 %39 None
    128                OpBranch %40
    129          %40 = OpLabel
    130          %41 = OpLoad %9 %6
    131          %42 = OpSLessThan %13 %41 %16
    132                OpBranchConditional %42 %43 %38
    133          %43 = OpLabel
    134                OpBranch %39
    135          %39 = OpLabel
    136          %44 = OpLoad %9 %6
    137          %45 = OpIAdd %9 %44 %15
    138                OpStore %6 %45
    139                OpBranch %37
    140          %38 = OpLabel
    141                OpBranch %23
    142          %23 = OpLabel
    143          %46 = OpLoad %9 %4
    144          %47 = OpIAdd %9 %46 %15
    145                OpStore %4 %47
    146                OpBranch %21
    147          %22 = OpLabel
    148                OpReturn
    149                OpFunctionEnd
    150   )";
    151   // clang-format on
    152   std::unique_ptr<IRContext> context =
    153       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
    154                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
    155   Module* module = context->module();
    156   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
    157                              << text << std::endl;
    158   const Function* f = spvtest::GetFunction(module, 2);
    159   LoopDescriptor& ld = *context->GetLoopDescriptor(f);
    160 
    161   EXPECT_EQ(ld.NumLoops(), 3u);
    162 
    163   // Invalid basic block id.
    164   EXPECT_EQ(ld[0u], nullptr);
    165   // Not a loop header.
    166   EXPECT_EQ(ld[20], nullptr);
    167 
    168   Loop& parent_loop = *ld[21];
    169   EXPECT_TRUE(parent_loop.HasNestedLoops());
    170   EXPECT_FALSE(parent_loop.IsNested());
    171   EXPECT_EQ(parent_loop.GetDepth(), 1u);
    172   EXPECT_EQ(std::distance(parent_loop.begin(), parent_loop.end()), 2u);
    173   EXPECT_EQ(parent_loop.GetHeaderBlock(), spvtest::GetBasicBlock(f, 21));
    174   EXPECT_EQ(parent_loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 23));
    175   EXPECT_EQ(parent_loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 22));
    176 
    177   Loop& child_loop_1 = *ld[28];
    178   EXPECT_FALSE(child_loop_1.HasNestedLoops());
    179   EXPECT_TRUE(child_loop_1.IsNested());
    180   EXPECT_EQ(child_loop_1.GetDepth(), 2u);
    181   EXPECT_EQ(std::distance(child_loop_1.begin(), child_loop_1.end()), 0u);
    182   EXPECT_EQ(child_loop_1.GetHeaderBlock(), spvtest::GetBasicBlock(f, 28));
    183   EXPECT_EQ(child_loop_1.GetLatchBlock(), spvtest::GetBasicBlock(f, 30));
    184   EXPECT_EQ(child_loop_1.GetMergeBlock(), spvtest::GetBasicBlock(f, 29));
    185 
    186   Loop& child_loop_2 = *ld[37];
    187   EXPECT_FALSE(child_loop_2.HasNestedLoops());
    188   EXPECT_TRUE(child_loop_2.IsNested());
    189   EXPECT_EQ(child_loop_2.GetDepth(), 2u);
    190   EXPECT_EQ(std::distance(child_loop_2.begin(), child_loop_2.end()), 0u);
    191   EXPECT_EQ(child_loop_2.GetHeaderBlock(), spvtest::GetBasicBlock(f, 37));
    192   EXPECT_EQ(child_loop_2.GetLatchBlock(), spvtest::GetBasicBlock(f, 39));
    193   EXPECT_EQ(child_loop_2.GetMergeBlock(), spvtest::GetBasicBlock(f, 38));
    194 }
    195 
    196 static void CheckLoopBlocks(Loop* loop,
    197                             std::unordered_set<uint32_t>* expected_ids) {
    198   SCOPED_TRACE("Check loop " + std::to_string(loop->GetHeaderBlock()->id()));
    199   for (uint32_t bb_id : loop->GetBlocks()) {
    200     EXPECT_EQ(expected_ids->count(bb_id), 1u);
    201     expected_ids->erase(bb_id);
    202   }
    203   EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
    204   EXPECT_EQ(expected_ids->size(), 0u);
    205 }
    206 
    207 /*
    208 Generated from the following GLSL
    209 #version 330 core
    210 layout(location = 0) out vec4 c;
    211 void main() {
    212   int i = 0;
    213   for (; i < 10; ++i) {
    214     for (int j = 0; j < 11; ++j) {
    215       if (j < 5) {
    216         for (int k = 0; k < 12; ++k) {}
    217       }
    218       else {}
    219       for (int k = 0; k < 12; ++k) {}
    220     }
    221   }
    222 }*/
    223 TEST_F(PassClassTest, TripleNestedLoop) {
    224   const std::string text = R"(
    225                OpCapability Shader
    226           %1 = OpExtInstImport "GLSL.std.450"
    227                OpMemoryModel Logical GLSL450
    228                OpEntryPoint Fragment %2 "main" %3
    229                OpExecutionMode %2 OriginUpperLeft
    230                OpSource GLSL 330
    231                OpName %2 "main"
    232                OpName %4 "i"
    233                OpName %5 "j"
    234                OpName %6 "k"
    235                OpName %7 "k"
    236                OpName %3 "c"
    237                OpDecorate %3 Location 0
    238           %8 = OpTypeVoid
    239           %9 = OpTypeFunction %8
    240          %10 = OpTypeInt 32 1
    241          %11 = OpTypePointer Function %10
    242          %12 = OpConstant %10 0
    243          %13 = OpConstant %10 10
    244          %14 = OpTypeBool
    245          %15 = OpConstant %10 11
    246          %16 = OpConstant %10 5
    247          %17 = OpConstant %10 12
    248          %18 = OpConstant %10 1
    249          %19 = OpTypeFloat 32
    250          %20 = OpTypeVector %19 4
    251          %21 = OpTypePointer Output %20
    252           %3 = OpVariable %21 Output
    253           %2 = OpFunction %8 None %9
    254          %22 = OpLabel
    255           %4 = OpVariable %11 Function
    256           %5 = OpVariable %11 Function
    257           %6 = OpVariable %11 Function
    258           %7 = OpVariable %11 Function
    259                OpStore %4 %12
    260                OpBranch %23
    261          %23 = OpLabel
    262                OpLoopMerge %24 %25 None
    263                OpBranch %26
    264          %26 = OpLabel
    265          %27 = OpLoad %10 %4
    266          %28 = OpSLessThan %14 %27 %13
    267                OpBranchConditional %28 %29 %24
    268          %29 = OpLabel
    269                OpStore %5 %12
    270                OpBranch %30
    271          %30 = OpLabel
    272                OpLoopMerge %31 %32 None
    273                OpBranch %33
    274          %33 = OpLabel
    275          %34 = OpLoad %10 %5
    276          %35 = OpSLessThan %14 %34 %15
    277                OpBranchConditional %35 %36 %31
    278          %36 = OpLabel
    279          %37 = OpLoad %10 %5
    280          %38 = OpSLessThan %14 %37 %16
    281                OpSelectionMerge %39 None
    282                OpBranchConditional %38 %40 %39
    283          %40 = OpLabel
    284                OpStore %6 %12
    285                OpBranch %41
    286          %41 = OpLabel
    287                OpLoopMerge %42 %43 None
    288                OpBranch %44
    289          %44 = OpLabel
    290          %45 = OpLoad %10 %6
    291          %46 = OpSLessThan %14 %45 %17
    292                OpBranchConditional %46 %47 %42
    293          %47 = OpLabel
    294                OpBranch %43
    295          %43 = OpLabel
    296          %48 = OpLoad %10 %6
    297          %49 = OpIAdd %10 %48 %18
    298                OpStore %6 %49
    299                OpBranch %41
    300          %42 = OpLabel
    301                OpBranch %39
    302          %39 = OpLabel
    303                OpStore %7 %12
    304                OpBranch %50
    305          %50 = OpLabel
    306                OpLoopMerge %51 %52 None
    307                OpBranch %53
    308          %53 = OpLabel
    309          %54 = OpLoad %10 %7
    310          %55 = OpSLessThan %14 %54 %17
    311                OpBranchConditional %55 %56 %51
    312          %56 = OpLabel
    313                OpBranch %52
    314          %52 = OpLabel
    315          %57 = OpLoad %10 %7
    316          %58 = OpIAdd %10 %57 %18
    317                OpStore %7 %58
    318                OpBranch %50
    319          %51 = OpLabel
    320                OpBranch %32
    321          %32 = OpLabel
    322          %59 = OpLoad %10 %5
    323          %60 = OpIAdd %10 %59 %18
    324                OpStore %5 %60
    325                OpBranch %30
    326          %31 = OpLabel
    327                OpBranch %25
    328          %25 = OpLabel
    329          %61 = OpLoad %10 %4
    330          %62 = OpIAdd %10 %61 %18
    331                OpStore %4 %62
    332                OpBranch %23
    333          %24 = OpLabel
    334                OpReturn
    335                OpFunctionEnd
    336   )";
    337   // clang-format on
    338   std::unique_ptr<IRContext> context =
    339       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
    340                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
    341   Module* module = context->module();
    342   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
    343                              << text << std::endl;
    344   const Function* f = spvtest::GetFunction(module, 2);
    345   LoopDescriptor& ld = *context->GetLoopDescriptor(f);
    346 
    347   EXPECT_EQ(ld.NumLoops(), 4u);
    348 
    349   // Invalid basic block id.
    350   EXPECT_EQ(ld[0u], nullptr);
    351   // Not in a loop.
    352   EXPECT_EQ(ld[22], nullptr);
    353 
    354   // Check that we can map basic block to the correct loop.
    355   // The following block ids do not belong to a loop.
    356   for (uint32_t bb_id : {22, 24}) EXPECT_EQ(ld[bb_id], nullptr);
    357 
    358   {
    359     std::unordered_set<uint32_t> basic_block_in_loop = {
    360         {23, 26, 29, 30, 33, 36, 40, 41, 44, 47, 43,
    361          42, 39, 50, 53, 56, 52, 51, 32, 31, 25}};
    362     Loop* loop = ld[23];
    363     CheckLoopBlocks(loop, &basic_block_in_loop);
    364 
    365     EXPECT_TRUE(loop->HasNestedLoops());
    366     EXPECT_FALSE(loop->IsNested());
    367     EXPECT_EQ(loop->GetDepth(), 1u);
    368     EXPECT_EQ(std::distance(loop->begin(), loop->end()), 1u);
    369     EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 22));
    370     EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 23));
    371     EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 25));
    372     EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 24));
    373     EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
    374     EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
    375   }
    376 
    377   {
    378     std::unordered_set<uint32_t> basic_block_in_loop = {
    379         {30, 33, 36, 40, 41, 44, 47, 43, 42, 39, 50, 53, 56, 52, 51, 32}};
    380     Loop* loop = ld[30];
    381     CheckLoopBlocks(loop, &basic_block_in_loop);
    382 
    383     EXPECT_TRUE(loop->HasNestedLoops());
    384     EXPECT_TRUE(loop->IsNested());
    385     EXPECT_EQ(loop->GetDepth(), 2u);
    386     EXPECT_EQ(std::distance(loop->begin(), loop->end()), 2u);
    387     EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 29));
    388     EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 30));
    389     EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 32));
    390     EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 31));
    391     EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
    392     EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
    393   }
    394 
    395   {
    396     std::unordered_set<uint32_t> basic_block_in_loop = {{41, 44, 47, 43}};
    397     Loop* loop = ld[41];
    398     CheckLoopBlocks(loop, &basic_block_in_loop);
    399 
    400     EXPECT_FALSE(loop->HasNestedLoops());
    401     EXPECT_TRUE(loop->IsNested());
    402     EXPECT_EQ(loop->GetDepth(), 3u);
    403     EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u);
    404     EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 40));
    405     EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 41));
    406     EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 43));
    407     EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 42));
    408     EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
    409     EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
    410   }
    411 
    412   {
    413     std::unordered_set<uint32_t> basic_block_in_loop = {{50, 53, 56, 52}};
    414     Loop* loop = ld[50];
    415     CheckLoopBlocks(loop, &basic_block_in_loop);
    416 
    417     EXPECT_FALSE(loop->HasNestedLoops());
    418     EXPECT_TRUE(loop->IsNested());
    419     EXPECT_EQ(loop->GetDepth(), 3u);
    420     EXPECT_EQ(std::distance(loop->begin(), loop->end()), 0u);
    421     EXPECT_EQ(loop->GetPreHeaderBlock(), spvtest::GetBasicBlock(f, 39));
    422     EXPECT_EQ(loop->GetHeaderBlock(), spvtest::GetBasicBlock(f, 50));
    423     EXPECT_EQ(loop->GetLatchBlock(), spvtest::GetBasicBlock(f, 52));
    424     EXPECT_EQ(loop->GetMergeBlock(), spvtest::GetBasicBlock(f, 51));
    425     EXPECT_FALSE(loop->IsInsideLoop(loop->GetMergeBlock()));
    426     EXPECT_FALSE(loop->IsInsideLoop(loop->GetPreHeaderBlock()));
    427   }
    428 
    429   // Make sure LoopDescriptor gives us the inner most loop when we query for
    430   // loops.
    431   for (const BasicBlock& bb : *f) {
    432     if (Loop* loop = ld[&bb]) {
    433       for (Loop& sub_loop :
    434            make_range(++TreeDFIterator<Loop>(loop), TreeDFIterator<Loop>())) {
    435         EXPECT_FALSE(sub_loop.IsInsideLoop(bb.id()));
    436       }
    437     }
    438   }
    439 }
    440 
    441 /*
    442 Generated from the following GLSL
    443 #version 330 core
    444 layout(location = 0) out vec4 c;
    445 void main() {
    446   for (int i = 0; i < 10; ++i) {
    447     for (int j = 0; j < 11; ++j) {
    448       for (int k = 0; k < 11; ++k) {}
    449     }
    450     for (int k = 0; k < 12; ++k) {}
    451   }
    452 }
    453 */
    454 TEST_F(PassClassTest, LoopParentTest) {
    455   const std::string text = R"(
    456                OpCapability Shader
    457           %1 = OpExtInstImport "GLSL.std.450"
    458                OpMemoryModel Logical GLSL450
    459                OpEntryPoint Fragment %2 "main" %3
    460                OpExecutionMode %2 OriginUpperLeft
    461                OpSource GLSL 330
    462                OpName %2 "main"
    463                OpName %4 "i"
    464                OpName %5 "j"
    465                OpName %6 "k"
    466                OpName %7 "k"
    467                OpName %3 "c"
    468                OpDecorate %3 Location 0
    469           %8 = OpTypeVoid
    470           %9 = OpTypeFunction %8
    471          %10 = OpTypeInt 32 1
    472          %11 = OpTypePointer Function %10
    473          %12 = OpConstant %10 0
    474          %13 = OpConstant %10 10
    475          %14 = OpTypeBool
    476          %15 = OpConstant %10 11
    477          %16 = OpConstant %10 1
    478          %17 = OpConstant %10 12
    479          %18 = OpTypeFloat 32
    480          %19 = OpTypeVector %18 4
    481          %20 = OpTypePointer Output %19
    482           %3 = OpVariable %20 Output
    483           %2 = OpFunction %8 None %9
    484          %21 = OpLabel
    485           %4 = OpVariable %11 Function
    486           %5 = OpVariable %11 Function
    487           %6 = OpVariable %11 Function
    488           %7 = OpVariable %11 Function
    489                OpStore %4 %12
    490                OpBranch %22
    491          %22 = OpLabel
    492                OpLoopMerge %23 %24 None
    493                OpBranch %25
    494          %25 = OpLabel
    495          %26 = OpLoad %10 %4
    496          %27 = OpSLessThan %14 %26 %13
    497                OpBranchConditional %27 %28 %23
    498          %28 = OpLabel
    499                OpStore %5 %12
    500                OpBranch %29
    501          %29 = OpLabel
    502                OpLoopMerge %30 %31 None
    503                OpBranch %32
    504          %32 = OpLabel
    505          %33 = OpLoad %10 %5
    506          %34 = OpSLessThan %14 %33 %15
    507                OpBranchConditional %34 %35 %30
    508          %35 = OpLabel
    509                OpStore %6 %12
    510                OpBranch %36
    511          %36 = OpLabel
    512                OpLoopMerge %37 %38 None
    513                OpBranch %39
    514          %39 = OpLabel
    515          %40 = OpLoad %10 %6
    516          %41 = OpSLessThan %14 %40 %15
    517                OpBranchConditional %41 %42 %37
    518          %42 = OpLabel
    519                OpBranch %38
    520          %38 = OpLabel
    521          %43 = OpLoad %10 %6
    522          %44 = OpIAdd %10 %43 %16
    523                OpStore %6 %44
    524                OpBranch %36
    525          %37 = OpLabel
    526                OpBranch %31
    527          %31 = OpLabel
    528          %45 = OpLoad %10 %5
    529          %46 = OpIAdd %10 %45 %16
    530                OpStore %5 %46
    531                OpBranch %29
    532          %30 = OpLabel
    533                OpStore %7 %12
    534                OpBranch %47
    535          %47 = OpLabel
    536                OpLoopMerge %48 %49 None
    537                OpBranch %50
    538          %50 = OpLabel
    539          %51 = OpLoad %10 %7
    540          %52 = OpSLessThan %14 %51 %17
    541                OpBranchConditional %52 %53 %48
    542          %53 = OpLabel
    543                OpBranch %49
    544          %49 = OpLabel
    545          %54 = OpLoad %10 %7
    546          %55 = OpIAdd %10 %54 %16
    547                OpStore %7 %55
    548                OpBranch %47
    549          %48 = OpLabel
    550                OpBranch %24
    551          %24 = OpLabel
    552          %56 = OpLoad %10 %4
    553          %57 = OpIAdd %10 %56 %16
    554                OpStore %4 %57
    555                OpBranch %22
    556          %23 = OpLabel
    557                OpReturn
    558                OpFunctionEnd
    559   )";
    560   // clang-format on
    561   std::unique_ptr<IRContext> context =
    562       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
    563                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
    564   Module* module = context->module();
    565   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
    566                              << text << std::endl;
    567   const Function* f = spvtest::GetFunction(module, 2);
    568   LoopDescriptor& ld = *context->GetLoopDescriptor(f);
    569 
    570   EXPECT_EQ(ld.NumLoops(), 4u);
    571 
    572   {
    573     Loop& loop = *ld[22];
    574     EXPECT_TRUE(loop.HasNestedLoops());
    575     EXPECT_FALSE(loop.IsNested());
    576     EXPECT_EQ(loop.GetDepth(), 1u);
    577     EXPECT_EQ(loop.GetParent(), nullptr);
    578   }
    579 
    580   {
    581     Loop& loop = *ld[29];
    582     EXPECT_TRUE(loop.HasNestedLoops());
    583     EXPECT_TRUE(loop.IsNested());
    584     EXPECT_EQ(loop.GetDepth(), 2u);
    585     EXPECT_EQ(loop.GetParent(), ld[22]);
    586   }
    587 
    588   {
    589     Loop& loop = *ld[36];
    590     EXPECT_FALSE(loop.HasNestedLoops());
    591     EXPECT_TRUE(loop.IsNested());
    592     EXPECT_EQ(loop.GetDepth(), 3u);
    593     EXPECT_EQ(loop.GetParent(), ld[29]);
    594   }
    595 
    596   {
    597     Loop& loop = *ld[47];
    598     EXPECT_FALSE(loop.HasNestedLoops());
    599     EXPECT_TRUE(loop.IsNested());
    600     EXPECT_EQ(loop.GetDepth(), 2u);
    601     EXPECT_EQ(loop.GetParent(), ld[22]);
    602   }
    603 }
    604 
    605 /*
    606 Generated from the following GLSL + --eliminate-local-multi-store
    607 The preheader of loop %33 and %41 were removed as well.
    608 
    609 #version 330 core
    610 void main() {
    611   int a = 0;
    612   for (int i = 0; i < 10; ++i) {
    613     if (i == 0) {
    614       a = 1;
    615     } else {
    616       a = 2;
    617     }
    618     for (int j = 0; j < 11; ++j) {
    619       a++;
    620     }
    621   }
    622   for (int k = 0; k < 12; ++k) {}
    623 }
    624 */
    625 TEST_F(PassClassTest, CreatePreheaderTest) {
    626   const std::string text = R"(
    627                OpCapability Shader
    628           %1 = OpExtInstImport "GLSL.std.450"
    629                OpMemoryModel Logical GLSL450
    630                OpEntryPoint Fragment %2 "main"
    631                OpExecutionMode %2 OriginUpperLeft
    632                OpSource GLSL 330
    633                OpName %2 "main"
    634           %3 = OpTypeVoid
    635           %4 = OpTypeFunction %3
    636           %5 = OpTypeInt 32 1
    637           %6 = OpTypePointer Function %5
    638           %7 = OpConstant %5 0
    639           %8 = OpConstant %5 10
    640           %9 = OpTypeBool
    641          %10 = OpConstant %5 1
    642          %11 = OpConstant %5 2
    643          %12 = OpConstant %5 11
    644          %13 = OpConstant %5 12
    645          %14 = OpUndef %5
    646           %2 = OpFunction %3 None %4
    647          %15 = OpLabel
    648                OpBranch %16
    649          %16 = OpLabel
    650          %17 = OpPhi %5 %7 %15 %18 %19
    651          %20 = OpPhi %5 %7 %15 %21 %19
    652          %22 = OpPhi %5 %14 %15 %23 %19
    653                OpLoopMerge %41 %19 None
    654                OpBranch %25
    655          %25 = OpLabel
    656          %26 = OpSLessThan %9 %20 %8
    657                OpBranchConditional %26 %27 %41
    658          %27 = OpLabel
    659          %28 = OpIEqual %9 %20 %7
    660                OpSelectionMerge %33 None
    661                OpBranchConditional %28 %30 %31
    662          %30 = OpLabel
    663                OpBranch %33
    664          %31 = OpLabel
    665                OpBranch %33
    666          %33 = OpLabel
    667          %18 = OpPhi %5 %10 %30 %11 %31 %34 %35
    668          %23 = OpPhi %5 %7 %30 %7 %31 %36 %35
    669                OpLoopMerge %37 %35 None
    670                OpBranch %38
    671          %38 = OpLabel
    672          %39 = OpSLessThan %9 %23 %12
    673                OpBranchConditional %39 %40 %37
    674          %40 = OpLabel
    675          %34 = OpIAdd %5 %18 %10
    676                OpBranch %35
    677          %35 = OpLabel
    678          %36 = OpIAdd %5 %23 %10
    679                OpBranch %33
    680          %37 = OpLabel
    681                OpBranch %19
    682          %19 = OpLabel
    683          %21 = OpIAdd %5 %20 %10
    684                OpBranch %16
    685          %41 = OpLabel
    686          %42 = OpPhi %5 %7 %25 %43 %44
    687                OpLoopMerge %45 %44 None
    688                OpBranch %46
    689          %46 = OpLabel
    690          %47 = OpSLessThan %9 %42 %13
    691                OpBranchConditional %47 %48 %45
    692          %48 = OpLabel
    693                OpBranch %44
    694          %44 = OpLabel
    695          %43 = OpIAdd %5 %42 %10
    696                OpBranch %41
    697          %45 = OpLabel
    698                OpReturn
    699                OpFunctionEnd
    700   )";
    701   // clang-format on
    702   std::unique_ptr<IRContext> context =
    703       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
    704                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
    705   Module* module = context->module();
    706   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
    707                              << text << std::endl;
    708   const Function* f = spvtest::GetFunction(module, 2);
    709   LoopDescriptor& ld = *context->GetLoopDescriptor(f);
    710   // No invalidation of the cfg should occur during this test.
    711   CFG* cfg = context->cfg();
    712 
    713   EXPECT_EQ(ld.NumLoops(), 3u);
    714 
    715   {
    716     Loop& loop = *ld[16];
    717     EXPECT_TRUE(loop.HasNestedLoops());
    718     EXPECT_FALSE(loop.IsNested());
    719     EXPECT_EQ(loop.GetDepth(), 1u);
    720     EXPECT_EQ(loop.GetParent(), nullptr);
    721   }
    722 
    723   {
    724     Loop& loop = *ld[33];
    725     EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr);
    726     EXPECT_NE(loop.GetOrCreatePreHeaderBlock(), nullptr);
    727     // Make sure the loop descriptor was properly updated.
    728     EXPECT_EQ(ld[loop.GetPreHeaderBlock()], ld[16]);
    729     {
    730       const std::vector<uint32_t>& preds =
    731           cfg->preds(loop.GetPreHeaderBlock()->id());
    732       std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
    733       EXPECT_EQ(pred_set.size(), 2u);
    734       EXPECT_TRUE(pred_set.count(30));
    735       EXPECT_TRUE(pred_set.count(31));
    736       // Check the phi instructions.
    737       loop.GetPreHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) {
    738         for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
    739           EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
    740         }
    741       });
    742     }
    743     {
    744       const std::vector<uint32_t>& preds =
    745           cfg->preds(loop.GetHeaderBlock()->id());
    746       std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
    747       EXPECT_EQ(pred_set.size(), 2u);
    748       EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id()));
    749       EXPECT_TRUE(pred_set.count(35));
    750       // Check the phi instructions.
    751       loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) {
    752         for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
    753           EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
    754         }
    755       });
    756     }
    757   }
    758 
    759   {
    760     Loop& loop = *ld[41];
    761     EXPECT_EQ(loop.GetPreHeaderBlock(), nullptr);
    762     EXPECT_NE(loop.GetOrCreatePreHeaderBlock(), nullptr);
    763     EXPECT_EQ(ld[loop.GetPreHeaderBlock()], nullptr);
    764     EXPECT_EQ(cfg->preds(loop.GetPreHeaderBlock()->id()).size(), 1u);
    765     EXPECT_EQ(cfg->preds(loop.GetPreHeaderBlock()->id())[0], 25u);
    766     // Check the phi instructions.
    767     loop.GetPreHeaderBlock()->ForEachPhiInst([](Instruction* phi) {
    768       EXPECT_EQ(phi->NumInOperands(), 2u);
    769       EXPECT_EQ(phi->GetSingleWordInOperand(1), 25u);
    770     });
    771     {
    772       const std::vector<uint32_t>& preds =
    773           cfg->preds(loop.GetHeaderBlock()->id());
    774       std::unordered_set<uint32_t> pred_set(preds.begin(), preds.end());
    775       EXPECT_EQ(pred_set.size(), 2u);
    776       EXPECT_TRUE(pred_set.count(loop.GetPreHeaderBlock()->id()));
    777       EXPECT_TRUE(pred_set.count(44));
    778       // Check the phi instructions.
    779       loop.GetHeaderBlock()->ForEachPhiInst([&pred_set](Instruction* phi) {
    780         for (uint32_t i = 1; i < phi->NumInOperands(); i += 2) {
    781           EXPECT_TRUE(pred_set.count(phi->GetSingleWordInOperand(i)));
    782         }
    783       });
    784     }
    785   }
    786 
    787   // Make sure pre-header insertion leaves the module valid.
    788   std::vector<uint32_t> bin;
    789   context->module()->ToBinary(&bin, true);
    790   EXPECT_TRUE(Validate(bin));
    791 }
    792 
    793 }  // namespace
    794 }  // namespace opt
    795 }  // namespace spvtools
    796