Home | History | Annotate | Download | only in spirit
      1 /*
      2  * Copyright 2017, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "pass_queue.h"
     18 
     19 #include "file_utils.h"
     20 #include "spirit.h"
     21 #include "test_utils.h"
     22 #include "transformer.h"
     23 #include "gtest/gtest.h"
     24 
     25 #include <stdint.h>
     26 
     27 namespace android {
     28 namespace spirit {
     29 
     30 namespace {
     31 
     32 class MulToAddTransformer : public Transformer {
     33 public:
     34   Instruction *transform(IMulInst *mul) override {
     35     auto ret = new IAddInst(mul->mResultType, mul->mOperand1, mul->mOperand2);
     36     ret->setId(mul->getId());
     37     return ret;
     38   }
     39 };
     40 
     41 class AddToDivTransformer : public Transformer {
     42 public:
     43   Instruction *transform(IAddInst *add) override {
     44     auto ret = new SDivInst(add->mResultType, add->mOperand1, add->mOperand2);
     45     ret->setId(add->getId());
     46     return ret;
     47   }
     48 };
     49 
     50 class AddMulAfterAddTransformer : public Transformer {
     51 public:
     52   Instruction *transform(IAddInst *add) override {
     53     insert(add);
     54     auto ret = new IMulInst(add->mResultType, add, add);
     55     ret->setId(add->getId());
     56     return ret;
     57   }
     58 };
     59 
     60 class Deleter : public Transformer {
     61 public:
     62   Instruction *transform(IMulInst *) override { return nullptr; }
     63 };
     64 
     65 class InPlaceModifyingPass : public Pass {
     66 public:
     67   Module *run(Module *m, int *error) override {
     68     m->getFloatType(64);
     69     if (error) {
     70       *error = 0;
     71     }
     72     return m;
     73   }
     74 };
     75 
} // anonymous namespace
     77 
     78 class PassQueueTest : public ::testing::Test {
     79 protected:
     80   virtual void SetUp() { mWordsGreyscale = readWords("greyscale.spv"); }
     81 
     82   std::vector<uint32_t> mWordsGreyscale;
     83 
     84 private:
     85   std::vector<uint32_t> readWords(const char *testFile) {
     86     static const std::string testDataPath(
     87         "frameworks/rs/rsov/compiler/spirit/test_data/");
     88     const std::string &fullPath = getAbsolutePath(testDataPath + testFile);
     89     return readFile<uint32_t>(fullPath);
     90   }
     91 };
     92 
     93 TEST_F(PassQueueTest, testMulToAdd) {
     94   std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
     95 
     96   ASSERT_NE(nullptr, m);
     97 
     98   EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
     99   EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
    100 
    101   PassQueue passes;
    102   passes.append(new MulToAddTransformer());
    103   auto m1 = passes.run(m.get());
    104 
    105   ASSERT_NE(nullptr, m1);
    106 
    107   ASSERT_TRUE(m1->resolveIds());
    108 
    109   EXPECT_EQ(2, countEntity<IAddInst>(m1));
    110   EXPECT_EQ(0, countEntity<IMulInst>(m1));
    111 }
    112 
    113 TEST_F(PassQueueTest, testInPlaceModifying) {
    114   std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
    115 
    116   ASSERT_NE(nullptr, m);
    117 
    118   EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
    119   EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
    120   EXPECT_EQ(1, countEntity<TypeFloatInst>(m.get()));
    121 
    122   PassQueue passes;
    123   passes.append(new InPlaceModifyingPass());
    124   auto m1 = passes.run(m.get());
    125 
    126   ASSERT_NE(nullptr, m1);
    127 
    128   ASSERT_TRUE(m1->resolveIds());
    129 
    130   EXPECT_EQ(1, countEntity<IAddInst>(m1));
    131   EXPECT_EQ(1, countEntity<IMulInst>(m1));
    132   EXPECT_EQ(2, countEntity<TypeFloatInst>(m1));
    133 }
    134 
    135 TEST_F(PassQueueTest, testDeletion) {
    136   std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
    137 
    138   ASSERT_NE(nullptr, m.get());
    139 
    140   EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
    141 
    142   PassQueue passes;
    143   passes.append(new Deleter());
    144   auto m1 = passes.run(m.get());
    145 
    146   // One of the ids from the input module is missing now.
    147   ASSERT_EQ(nullptr, m1);
    148 }
    149 
    150 TEST_F(PassQueueTest, testMulToAddToDiv) {
    151   std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
    152 
    153   ASSERT_NE(nullptr, m);
    154 
    155   EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
    156   EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
    157 
    158   PassQueue passes;
    159   passes.append(new MulToAddTransformer());
    160   passes.append(new AddToDivTransformer());
    161   auto m1 = passes.run(m.get());
    162 
    163   ASSERT_NE(nullptr, m1);
    164 
    165   ASSERT_TRUE(m1->resolveIds());
    166 
    167   EXPECT_EQ(0, countEntity<IAddInst>(m1));
    168   EXPECT_EQ(0, countEntity<IMulInst>(m1));
    169   EXPECT_EQ(2, countEntity<SDivInst>(m1));
    170 }
    171 
    172 TEST_F(PassQueueTest, testAMix) {
    173   std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
    174 
    175   ASSERT_NE(nullptr, m);
    176 
    177   EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
    178   EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
    179   EXPECT_EQ(0, countEntity<SDivInst>(m.get()));
    180   EXPECT_EQ(1, countEntity<TypeFloatInst>(m.get()));
    181 
    182   PassQueue passes;
    183   passes.append(new MulToAddTransformer());
    184   passes.append(new AddToDivTransformer());
    185   passes.append(new InPlaceModifyingPass());
    186 
    187   std::unique_ptr<Module> m1(passes.run(m.get()));
    188 
    189   ASSERT_NE(nullptr, m1);
    190 
    191   ASSERT_TRUE(m1->resolveIds());
    192 
    193   EXPECT_EQ(0, countEntity<IAddInst>(m1.get()));
    194   EXPECT_EQ(0, countEntity<IMulInst>(m1.get()));
    195   EXPECT_EQ(2, countEntity<SDivInst>(m1.get()));
    196   EXPECT_EQ(2, countEntity<TypeFloatInst>(m1.get()));
    197 }
    198 
    199 TEST_F(PassQueueTest, testAnotherMix) {
    200   std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
    201 
    202   ASSERT_NE(nullptr, m);
    203 
    204   EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
    205   EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
    206   EXPECT_EQ(0, countEntity<SDivInst>(m.get()));
    207   EXPECT_EQ(1, countEntity<TypeFloatInst>(m.get()));
    208 
    209   PassQueue passes;
    210   passes.append(new InPlaceModifyingPass());
    211   passes.append(new MulToAddTransformer());
    212   passes.append(new AddToDivTransformer());
    213   auto outputWords = passes.runAndSerialize(m.get());
    214 
    215   std::unique_ptr<Module> m1(Deserialize<Module>(outputWords));
    216 
    217   ASSERT_NE(nullptr, m1);
    218 
    219   ASSERT_TRUE(m1->resolveIds());
    220 
    221   EXPECT_EQ(0, countEntity<IAddInst>(m1.get()));
    222   EXPECT_EQ(0, countEntity<IMulInst>(m1.get()));
    223   EXPECT_EQ(2, countEntity<SDivInst>(m1.get()));
    224   EXPECT_EQ(2, countEntity<TypeFloatInst>(m1.get()));
    225 }
    226 
    227 TEST_F(PassQueueTest, testMulToAddToDivFromWords) {
    228   PassQueue passes;
    229   passes.append(new MulToAddTransformer());
    230   passes.append(new AddToDivTransformer());
    231   auto outputWords = passes.run(std::move(mWordsGreyscale));
    232 
    233   std::unique_ptr<Module> m1(Deserialize<Module>(outputWords));
    234 
    235   ASSERT_NE(nullptr, m1);
    236 
    237   ASSERT_TRUE(m1->resolveIds());
    238 
    239   EXPECT_EQ(0, countEntity<IAddInst>(m1.get()));
    240   EXPECT_EQ(0, countEntity<IMulInst>(m1.get()));
    241   EXPECT_EQ(2, countEntity<SDivInst>(m1.get()));
    242 }
    243 
    244 TEST_F(PassQueueTest, testMulToAddToDivToWords) {
    245   std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
    246 
    247   ASSERT_NE(nullptr, m);
    248 
    249   EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
    250   EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
    251 
    252   PassQueue passes;
    253   passes.append(new MulToAddTransformer());
    254   passes.append(new AddToDivTransformer());
    255   auto outputWords = passes.runAndSerialize(m.get());
    256 
    257   std::unique_ptr<Module> m1(Deserialize<Module>(outputWords));
    258 
    259   ASSERT_NE(nullptr, m1);
    260 
    261   ASSERT_TRUE(m1->resolveIds());
    262 
    263   EXPECT_EQ(0, countEntity<IAddInst>(m1.get()));
    264   EXPECT_EQ(0, countEntity<IMulInst>(m1.get()));
    265   EXPECT_EQ(2, countEntity<SDivInst>(m1.get()));
    266 }
    267 
    268 TEST_F(PassQueueTest, testAddMulAfterAdd) {
    269   std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
    270 
    271   ASSERT_NE(nullptr, m);
    272 
    273   EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
    274   EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
    275 
    276   constexpr int kNumMulToAdd = 100;
    277 
    278   PassQueue passes;
    279   for (int i = 0; i < kNumMulToAdd; i++) {
    280     passes.append(new AddMulAfterAddTransformer());
    281   }
    282   auto outputWords = passes.runAndSerialize(m.get());
    283 
    284   std::unique_ptr<Module> m1(Deserialize<Module>(outputWords));
    285 
    286   ASSERT_NE(nullptr, m1);
    287 
    288   ASSERT_TRUE(m1->resolveIds());
    289 
    290   EXPECT_EQ(1, countEntity<IAddInst>(m1.get()));
    291   EXPECT_EQ(1 + kNumMulToAdd, countEntity<IMulInst>(m1.get()));
    292 }
    293 
    294 } // namespace spirit
    295 } // namespace android
    296