// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASI,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_
#define SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <string>
#include <vector>

#include "source/opt/tree_iterator.h"

namespace spvtools {
namespace opt {

class Loop;
class ScalarEvolutionAnalysis;
class SEConstantNode;
class SERecurrentNode;
class SEAddNode;
class SEMultiplyNode;
class SENegative;
class SEValueUnknown;
class SECantCompute;

// Abstract class representing a node in the scalar evolution DAG. Each node
// contains a vector of pointers to its children and each subclass of SENode
// implements GetType and an As method to allow casting. SENodes can be hashed
// using the SENodeHash functor. The vector of children is sorted when a node is
// added. This is important as it allows the hash of X+Y to be the same as Y+X.
43 class SENode { 44 public: 45 enum SENodeType { 46 Constant, 47 RecurrentAddExpr, 48 Add, 49 Multiply, 50 Negative, 51 ValueUnknown, 52 CanNotCompute 53 }; 54 55 using ChildContainerType = std::vector<SENode*>; 56 57 explicit SENode(ScalarEvolutionAnalysis* parent_analysis) 58 : parent_analysis_(parent_analysis), unique_id_(++NumberOfNodes) {} 59 60 virtual SENodeType GetType() const = 0; 61 62 virtual ~SENode() {} 63 64 virtual inline void AddChild(SENode* child) { 65 // If this is a constant node, assert. 66 if (AsSEConstantNode()) { 67 assert(false && "Trying to add a child node to a constant!"); 68 } 69 70 // Find the first point in the vector where |child| is greater than the node 71 // currently in the vector. 72 auto find_first_less_than = [child](const SENode* node) { 73 return child->unique_id_ <= node->unique_id_; 74 }; 75 76 auto position = std::find_if_not(children_.begin(), children_.end(), 77 find_first_less_than); 78 // Children are sorted so the hashing and equality operator will be the same 79 // for a node with the same children. X+Y should be the same as Y+X. 80 children_.insert(position, child); 81 } 82 83 // Get the type as an std::string. This is used to represent the node in the 84 // dot output and is used to hash the type as well. 85 std::string AsString() const; 86 87 // Dump the SENode and its immediate children, if |recurse| is true then it 88 // will recurse through all children to print the DAG starting from this node 89 // as a root. 90 void DumpDot(std::ostream& out, bool recurse = false) const; 91 92 // Checks if two nodes are the same by hashing them. 93 bool operator==(const SENode& other) const; 94 95 // Checks if two nodes are not the same by comparing the hashes. 96 bool operator!=(const SENode& other) const; 97 98 // Return the child node at |index|. 
99 inline SENode* GetChild(size_t index) { return children_[index]; } 100 inline const SENode* GetChild(size_t index) const { return children_[index]; } 101 102 // Iterator to iterate over the child nodes. 103 using iterator = ChildContainerType::iterator; 104 using const_iterator = ChildContainerType::const_iterator; 105 106 // Iterate over immediate child nodes. 107 iterator begin() { return children_.begin(); } 108 iterator end() { return children_.end(); } 109 110 // Constant overloads for iterating over immediate child nodes. 111 const_iterator begin() const { return children_.cbegin(); } 112 const_iterator end() const { return children_.cend(); } 113 const_iterator cbegin() { return children_.cbegin(); } 114 const_iterator cend() { return children_.cend(); } 115 116 // Collect all the recurrent nodes in this SENode 117 std::vector<SERecurrentNode*> CollectRecurrentNodes() { 118 std::vector<SERecurrentNode*> recurrent_nodes{}; 119 120 if (auto recurrent_node = AsSERecurrentNode()) { 121 recurrent_nodes.push_back(recurrent_node); 122 } 123 124 for (auto child : GetChildren()) { 125 auto child_recurrent_nodes = child->CollectRecurrentNodes(); 126 recurrent_nodes.insert(recurrent_nodes.end(), 127 child_recurrent_nodes.begin(), 128 child_recurrent_nodes.end()); 129 } 130 131 return recurrent_nodes; 132 } 133 134 // Collect all the value unknown nodes in this SENode 135 std::vector<SEValueUnknown*> CollectValueUnknownNodes() { 136 std::vector<SEValueUnknown*> value_unknown_nodes{}; 137 138 if (auto value_unknown_node = AsSEValueUnknown()) { 139 value_unknown_nodes.push_back(value_unknown_node); 140 } 141 142 for (auto child : GetChildren()) { 143 auto child_value_unknown_nodes = child->CollectValueUnknownNodes(); 144 value_unknown_nodes.insert(value_unknown_nodes.end(), 145 child_value_unknown_nodes.begin(), 146 child_value_unknown_nodes.end()); 147 } 148 149 return value_unknown_nodes; 150 } 151 152 // Iterator to iterate over the entire DAG. 
Even though we are using the tree 153 // iterator it should still be safe to iterate over. However, nodes with 154 // multiple parents will be visited multiple times, unlike in a tree. 155 using dag_iterator = TreeDFIterator<SENode>; 156 using const_dag_iterator = TreeDFIterator<const SENode>; 157 158 // Iterate over all child nodes in the graph. 159 dag_iterator graph_begin() { return dag_iterator(this); } 160 dag_iterator graph_end() { return dag_iterator(); } 161 const_dag_iterator graph_begin() const { return graph_cbegin(); } 162 const_dag_iterator graph_end() const { return graph_cend(); } 163 const_dag_iterator graph_cbegin() const { return const_dag_iterator(this); } 164 const_dag_iterator graph_cend() const { return const_dag_iterator(); } 165 166 // Return the vector of immediate children. 167 const ChildContainerType& GetChildren() const { return children_; } 168 ChildContainerType& GetChildren() { return children_; } 169 170 // Return true if this node is a cant compute node. 171 bool IsCantCompute() const { return GetType() == CanNotCompute; } 172 173 // Implements a casting method for each type. 174 #define DeclareCastMethod(target) \ 175 virtual target* As##target() { return nullptr; } \ 176 virtual const target* As##target() const { return nullptr; } 177 DeclareCastMethod(SEConstantNode); 178 DeclareCastMethod(SERecurrentNode); 179 DeclareCastMethod(SEAddNode); 180 DeclareCastMethod(SEMultiplyNode); 181 DeclareCastMethod(SENegative); 182 DeclareCastMethod(SEValueUnknown); 183 DeclareCastMethod(SECantCompute); 184 #undef DeclareCastMethod 185 186 // Get the analysis which has this node in its cache. 187 inline ScalarEvolutionAnalysis* GetParentAnalysis() const { 188 return parent_analysis_; 189 } 190 191 protected: 192 ChildContainerType children_; 193 194 ScalarEvolutionAnalysis* parent_analysis_; 195 196 // The unique id of this node, assigned on creation by incrementing the static 197 // node count. 
198 uint32_t unique_id_; 199 200 // The number of nodes created. 201 static uint32_t NumberOfNodes; 202 }; 203 204 // Function object to handle the hashing of SENodes. Hashing algorithm hashes 205 // the type (as a string), the literal value of any constants, and the child 206 // pointers which are assumed to be unique. 207 struct SENodeHash { 208 size_t operator()(const std::unique_ptr<SENode>& node) const; 209 size_t operator()(const SENode* node) const; 210 }; 211 212 // A node representing a constant integer. 213 class SEConstantNode : public SENode { 214 public: 215 SEConstantNode(ScalarEvolutionAnalysis* parent_analysis, int64_t value) 216 : SENode(parent_analysis), literal_value_(value) {} 217 218 SENodeType GetType() const final { return Constant; } 219 220 int64_t FoldToSingleValue() const { return literal_value_; } 221 222 SEConstantNode* AsSEConstantNode() override { return this; } 223 const SEConstantNode* AsSEConstantNode() const override { return this; } 224 225 inline void AddChild(SENode*) final { 226 assert(false && "Attempting to add a child to a constant node!"); 227 } 228 229 protected: 230 int64_t literal_value_; 231 }; 232 233 // A node representing a recurrent expression in the code. A recurrent 234 // expression is an expression whose value can be expressed as a linear 235 // expression of the loop iterations. Such as an induction variable. The actual 236 // value of a recurrent expression is coefficent_ * iteration + offset_, hence 237 // an induction variable i=0, i++ becomes a recurrent expression with an offset 238 // of zero and a coefficient of one. 
239 class SERecurrentNode : public SENode { 240 public: 241 SERecurrentNode(ScalarEvolutionAnalysis* parent_analysis, const Loop* loop) 242 : SENode(parent_analysis), loop_(loop) {} 243 244 SENodeType GetType() const final { return RecurrentAddExpr; } 245 246 inline void AddCoefficient(SENode* child) { 247 coefficient_ = child; 248 SENode::AddChild(child); 249 } 250 251 inline void AddOffset(SENode* child) { 252 offset_ = child; 253 SENode::AddChild(child); 254 } 255 256 inline const SENode* GetCoefficient() const { return coefficient_; } 257 inline SENode* GetCoefficient() { return coefficient_; } 258 259 inline const SENode* GetOffset() const { return offset_; } 260 inline SENode* GetOffset() { return offset_; } 261 262 // Return the loop which this recurrent expression is recurring within. 263 const Loop* GetLoop() const { return loop_; } 264 265 SERecurrentNode* AsSERecurrentNode() override { return this; } 266 const SERecurrentNode* AsSERecurrentNode() const override { return this; } 267 268 private: 269 SENode* coefficient_; 270 SENode* offset_; 271 const Loop* loop_; 272 }; 273 274 // A node representing an addition operation between child nodes. 275 class SEAddNode : public SENode { 276 public: 277 explicit SEAddNode(ScalarEvolutionAnalysis* parent_analysis) 278 : SENode(parent_analysis) {} 279 280 SENodeType GetType() const final { return Add; } 281 282 SEAddNode* AsSEAddNode() override { return this; } 283 const SEAddNode* AsSEAddNode() const override { return this; } 284 }; 285 286 // A node representing a multiply operation between child nodes. 
287 class SEMultiplyNode : public SENode { 288 public: 289 explicit SEMultiplyNode(ScalarEvolutionAnalysis* parent_analysis) 290 : SENode(parent_analysis) {} 291 292 SENodeType GetType() const final { return Multiply; } 293 294 SEMultiplyNode* AsSEMultiplyNode() override { return this; } 295 const SEMultiplyNode* AsSEMultiplyNode() const override { return this; } 296 }; 297 298 // A node representing a unary negative operation. 299 class SENegative : public SENode { 300 public: 301 explicit SENegative(ScalarEvolutionAnalysis* parent_analysis) 302 : SENode(parent_analysis) {} 303 304 SENodeType GetType() const final { return Negative; } 305 306 SENegative* AsSENegative() override { return this; } 307 const SENegative* AsSENegative() const override { return this; } 308 }; 309 310 // A node representing a value which we do not know the value of, such as a load 311 // instruction. 312 class SEValueUnknown : public SENode { 313 public: 314 // SEValueUnknowns must come from an instruction |unique_id| is the unique id 315 // of that instruction. This is so we cancompare value unknowns and have a 316 // unique value unknown for each instruction. 317 SEValueUnknown(ScalarEvolutionAnalysis* parent_analysis, uint32_t result_id) 318 : SENode(parent_analysis), result_id_(result_id) {} 319 320 SENodeType GetType() const final { return ValueUnknown; } 321 322 SEValueUnknown* AsSEValueUnknown() override { return this; } 323 const SEValueUnknown* AsSEValueUnknown() const override { return this; } 324 325 inline uint32_t ResultId() const { return result_id_; } 326 327 private: 328 uint32_t result_id_; 329 }; 330 331 // A node which we cannot reason about at all. 
332 class SECantCompute : public SENode { 333 public: 334 explicit SECantCompute(ScalarEvolutionAnalysis* parent_analysis) 335 : SENode(parent_analysis) {} 336 337 SENodeType GetType() const final { return CanNotCompute; } 338 339 SECantCompute* AsSECantCompute() override { return this; } 340 const SECantCompute* AsSECantCompute() const override { return this; } 341 }; 342 343 } // namespace opt 344 } // namespace spvtools 345 #endif // SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_ 346