1 /* 2 * Copyright 2017, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "pass_queue.h" 18 19 #include "file_utils.h" 20 #include "spirit.h" 21 #include "test_utils.h" 22 #include "transformer.h" 23 #include "gtest/gtest.h" 24 25 #include <stdint.h> 26 27 namespace android { 28 namespace spirit { 29 30 namespace { 31 32 class MulToAddTransformer : public Transformer { 33 public: 34 Instruction *transform(IMulInst *mul) override { 35 auto ret = new IAddInst(mul->mResultType, mul->mOperand1, mul->mOperand2); 36 ret->setId(mul->getId()); 37 return ret; 38 } 39 }; 40 41 class AddToDivTransformer : public Transformer { 42 public: 43 Instruction *transform(IAddInst *add) override { 44 auto ret = new SDivInst(add->mResultType, add->mOperand1, add->mOperand2); 45 ret->setId(add->getId()); 46 return ret; 47 } 48 }; 49 50 class AddMulAfterAddTransformer : public Transformer { 51 public: 52 Instruction *transform(IAddInst *add) override { 53 insert(add); 54 auto ret = new IMulInst(add->mResultType, add, add); 55 ret->setId(add->getId()); 56 return ret; 57 } 58 }; 59 60 class Deleter : public Transformer { 61 public: 62 Instruction *transform(IMulInst *) override { return nullptr; } 63 }; 64 65 class InPlaceModifyingPass : public Pass { 66 public: 67 Module *run(Module *m, int *error) override { 68 m->getFloatType(64); 69 if (error) { 70 *error = 0; 71 } 72 return m; 73 } 74 }; 75 76 } // annonymous namespace 77 78 class PassQueueTest : public ::testing::Test { 79 protected: 80 virtual void SetUp() { mWordsGreyscale = readWords("greyscale.spv"); } 81 82 std::vector<uint32_t> mWordsGreyscale; 83 84 private: 85 std::vector<uint32_t> readWords(const char *testFile) { 86 static const std::string testDataPath( 87 "frameworks/rs/rsov/compiler/spirit/test_data/"); 88 const std::string &fullPath = getAbsolutePath(testDataPath + testFile); 89 return readFile<uint32_t>(fullPath); 90 } 91 }; 92 93 TEST_F(PassQueueTest, testMulToAdd) { 94 std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale)); 95 96 ASSERT_NE(nullptr, m); 97 98 EXPECT_EQ(1, countEntity<IAddInst>(m.get())); 99 EXPECT_EQ(1, countEntity<IMulInst>(m.get())); 100 101 PassQueue passes; 102 passes.append(new MulToAddTransformer()); 103 auto m1 = passes.run(m.get()); 104 105 ASSERT_NE(nullptr, m1); 106 107 ASSERT_TRUE(m1->resolveIds()); 108 109 EXPECT_EQ(2, countEntity<IAddInst>(m1)); 110 EXPECT_EQ(0, countEntity<IMulInst>(m1)); 111 } 112 113 TEST_F(PassQueueTest, testInPlaceModifying) { 114 std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale)); 115 116 ASSERT_NE(nullptr, m); 117 118 EXPECT_EQ(1, countEntity<IAddInst>(m.get())); 119 EXPECT_EQ(1, countEntity<IMulInst>(m.get())); 120 EXPECT_EQ(1, countEntity<TypeFloatInst>(m.get())); 121 122 PassQueue passes; 123 passes.append(new InPlaceModifyingPass()); 124 auto m1 = passes.run(m.get()); 125 126 ASSERT_NE(nullptr, m1); 127 128 ASSERT_TRUE(m1->resolveIds()); 129 130 EXPECT_EQ(1, countEntity<IAddInst>(m1)); 131 EXPECT_EQ(1, countEntity<IMulInst>(m1)); 132 EXPECT_EQ(2, countEntity<TypeFloatInst>(m1)); 133 } 134 135 TEST_F(PassQueueTest, testDeletion) { 136 std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale)); 137 138 ASSERT_NE(nullptr, m.get()); 139 140 EXPECT_EQ(1, countEntity<IMulInst>(m.get())); 141 142 PassQueue passes; 143 passes.append(new Deleter()); 144 auto m1 = passes.run(m.get()); 145 146 // One of the ids from the input module is missing now. 147 ASSERT_EQ(nullptr, m1); 148 } 149 150 TEST_F(PassQueueTest, testMulToAddToDiv) { 151 std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale)); 152 153 ASSERT_NE(nullptr, m); 154 155 EXPECT_EQ(1, countEntity<IAddInst>(m.get())); 156 EXPECT_EQ(1, countEntity<IMulInst>(m.get())); 157 158 PassQueue passes; 159 passes.append(new MulToAddTransformer()); 160 passes.append(new AddToDivTransformer()); 161 auto m1 = passes.run(m.get()); 162 163 ASSERT_NE(nullptr, m1); 164 165 ASSERT_TRUE(m1->resolveIds()); 166 167 EXPECT_EQ(0, countEntity<IAddInst>(m1)); 168 EXPECT_EQ(0, countEntity<IMulInst>(m1)); 169 EXPECT_EQ(2, countEntity<SDivInst>(m1)); 170 } 171 172 TEST_F(PassQueueTest, testAMix) { 173 std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale)); 174 175 ASSERT_NE(nullptr, m); 176 177 EXPECT_EQ(1, countEntity<IAddInst>(m.get())); 178 EXPECT_EQ(1, countEntity<IMulInst>(m.get())); 179 EXPECT_EQ(0, countEntity<SDivInst>(m.get())); 180 EXPECT_EQ(1, countEntity<TypeFloatInst>(m.get())); 181 182 PassQueue passes; 183 passes.append(new MulToAddTransformer()); 184 passes.append(new AddToDivTransformer()); 185 passes.append(new InPlaceModifyingPass()); 186 187 std::unique_ptr<Module> m1(passes.run(m.get())); 188 189 ASSERT_NE(nullptr, m1); 190 191 ASSERT_TRUE(m1->resolveIds()); 192 193 EXPECT_EQ(0, countEntity<IAddInst>(m1.get())); 194 EXPECT_EQ(0, countEntity<IMulInst>(m1.get())); 195 EXPECT_EQ(2, countEntity<SDivInst>(m1.get())); 196 EXPECT_EQ(2, countEntity<TypeFloatInst>(m1.get())); 197 } 198 199 TEST_F(PassQueueTest, testAnotherMix) { 200 std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale)); 201 202 ASSERT_NE(nullptr, m); 203 204 EXPECT_EQ(1, countEntity<IAddInst>(m.get())); 205 EXPECT_EQ(1, countEntity<IMulInst>(m.get())); 206 EXPECT_EQ(0, countEntity<SDivInst>(m.get())); 207 EXPECT_EQ(1, countEntity<TypeFloatInst>(m.get())); 208 209 PassQueue passes; 210 passes.append(new InPlaceModifyingPass()); 211 passes.append(new MulToAddTransformer()); 212 passes.append(new AddToDivTransformer()); 213 auto outputWords = passes.runAndSerialize(m.get()); 214 215 std::unique_ptr<Module> m1(Deserialize<Module>(outputWords)); 216 217 ASSERT_NE(nullptr, m1); 218 219 ASSERT_TRUE(m1->resolveIds()); 220 221 EXPECT_EQ(0, countEntity<IAddInst>(m1.get())); 222 EXPECT_EQ(0, countEntity<IMulInst>(m1.get())); 223 EXPECT_EQ(2, countEntity<SDivInst>(m1.get())); 224 EXPECT_EQ(2, countEntity<TypeFloatInst>(m1.get())); 225 } 226 227 TEST_F(PassQueueTest, testMulToAddToDivFromWords) { 228 PassQueue passes; 229 passes.append(new MulToAddTransformer()); 230 passes.append(new AddToDivTransformer()); 231 auto outputWords = passes.run(std::move(mWordsGreyscale)); 232 233 std::unique_ptr<Module> m1(Deserialize<Module>(outputWords)); 234 235 ASSERT_NE(nullptr, m1); 236 237 ASSERT_TRUE(m1->resolveIds()); 238 239 EXPECT_EQ(0, countEntity<IAddInst>(m1.get())); 240 EXPECT_EQ(0, countEntity<IMulInst>(m1.get())); 241 EXPECT_EQ(2, countEntity<SDivInst>(m1.get())); 242 } 243 244 TEST_F(PassQueueTest, testMulToAddToDivToWords) { 245 std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale)); 246 247 ASSERT_NE(nullptr, m); 248 249 EXPECT_EQ(1, countEntity<IAddInst>(m.get())); 250 EXPECT_EQ(1, countEntity<IMulInst>(m.get())); 251 252 PassQueue passes; 253 passes.append(new MulToAddTransformer()); 254 passes.append(new AddToDivTransformer()); 255 auto outputWords = passes.runAndSerialize(m.get()); 256 257 std::unique_ptr<Module> m1(Deserialize<Module>(outputWords)); 258 259 ASSERT_NE(nullptr, m1); 260 261 ASSERT_TRUE(m1->resolveIds()); 262 263 EXPECT_EQ(0, countEntity<IAddInst>(m1.get())); 264 EXPECT_EQ(0, countEntity<IMulInst>(m1.get())); 265 EXPECT_EQ(2, countEntity<SDivInst>(m1.get())); 266 } 267 268 TEST_F(PassQueueTest, testAddMulAfterAdd) { 269 std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale)); 270 271 ASSERT_NE(nullptr, m); 272 273 EXPECT_EQ(1, countEntity<IAddInst>(m.get())); 274 EXPECT_EQ(1, countEntity<IMulInst>(m.get())); 275 276 constexpr int kNumMulToAdd = 100; 277 278 PassQueue passes; 279 for (int i = 0; i < kNumMulToAdd; i++) { 280 passes.append(new AddMulAfterAddTransformer()); 281 } 282 auto outputWords = passes.runAndSerialize(m.get()); 283 284 std::unique_ptr<Module> m1(Deserialize<Module>(outputWords)); 285 286 ASSERT_NE(nullptr, m1); 287 288 ASSERT_TRUE(m1->resolveIds()); 289 290 EXPECT_EQ(1, countEntity<IAddInst>(m1.get())); 291 EXPECT_EQ(1 + kNumMulToAdd, countEntity<IMulInst>(m1.get())); 292 } 293 294 } // namespace spirit 295 } // namespace android 296