1 // 2 // Copyright (C) 2015-2016 Google, Inc. 3 // 4 // All rights reserved. 5 // 6 // Redistribution and use in source and binary forms, with or without 7 // modification, are permitted provided that the following conditions 8 // are met: 9 // 10 // Redistributions of source code must retain the above copyright 11 // notice, this list of conditions and the following disclaimer. 12 // 13 // Redistributions in binary form must reproduce the above 14 // copyright notice, this list of conditions and the following 15 // disclaimer in the documentation and/or other materials provided 16 // with the distribution. 17 // 18 // Neither the name of Google Inc. nor the names of its 19 // contributors may be used to endorse or promote products derived 20 // from this software without specific prior written permission. 21 // 22 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 23 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 24 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 25 // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 26 // COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 27 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 28 // BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 29 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 30 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 31 // LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 32 // ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 33 // POSSIBILITY OF SUCH DAMAGE. 34 35 // 36 // Visit the nodes in the glslang intermediate tree representation to 37 // propagate the 'noContraction' qualifier. 38 // 39 40 #include "propagateNoContraction.h" 41 42 #include <cstdlib> 43 #include <string> 44 #include <tuple> 45 #include <unordered_map> 46 #include <unordered_set> 47 48 #include "localintermediate.h" 49 namespace { 50 51 // Use a string to hold the access chain information, as in most cases the 52 // access chain is short and may contain only one element, which is the symbol 53 // ID. 54 // Example: struct {float a; float b;} s; 55 // Object s.a will be represented with: <symbol ID of s>/0 56 // Object s.b will be represented with: <symbol ID of s>/1 57 // Object s will be represented with: <symbol ID of s> 58 // For members of vector, matrix and arrays, they will be represented with the 59 // same symbol ID of their container symbol objects. This is because their 60 // preciseness is always the same as their container symbol objects. 61 typedef std::string ObjectAccessChain; 62 63 // The delimiter used in the ObjectAccessChain string to separate symbol ID and 64 // different level of struct indices. 65 const char ObjectAccesschainDelimiter = '/'; 66 67 // Mapping from Symbol IDs of symbol nodes, to their defining operation 68 // nodes. 69 typedef std::unordered_multimap<ObjectAccessChain, glslang::TIntermOperator*> NodeMapping; 70 // Mapping from object nodes to their access chain info string. 71 typedef std::unordered_map<glslang::TIntermTyped*, ObjectAccessChain> AccessChainMapping; 72 73 // Set of object IDs. 74 typedef std::unordered_set<ObjectAccessChain> ObjectAccesschainSet; 75 // Set of return branch nodes. 76 typedef std::unordered_set<glslang::TIntermBranch*> ReturnBranchNodeSet; 77 78 // A helper function to tell whether a node is 'noContraction'. Returns true if 79 // the node has 'noContraction' qualifier, otherwise false. 80 bool isPreciseObjectNode(glslang::TIntermTyped* node) 81 { 82 return node->getType().getQualifier().noContraction; 83 } 84 85 // Returns true if the opcode is a dereferencing one. 86 bool isDereferenceOperation(glslang::TOperator op) 87 { 88 switch (op) { 89 case glslang::EOpIndexDirect: 90 case glslang::EOpIndexDirectStruct: 91 case glslang::EOpIndexIndirect: 92 case glslang::EOpVectorSwizzle: 93 case glslang::EOpMatrixSwizzle: 94 return true; 95 default: 96 return false; 97 } 98 } 99 100 // Returns true if the opcode leads to an assignment operation. 101 bool isAssignOperation(glslang::TOperator op) 102 { 103 switch (op) { 104 case glslang::EOpAssign: 105 case glslang::EOpAddAssign: 106 case glslang::EOpSubAssign: 107 case glslang::EOpMulAssign: 108 case glslang::EOpVectorTimesMatrixAssign: 109 case glslang::EOpVectorTimesScalarAssign: 110 case glslang::EOpMatrixTimesScalarAssign: 111 case glslang::EOpMatrixTimesMatrixAssign: 112 case glslang::EOpDivAssign: 113 case glslang::EOpModAssign: 114 case glslang::EOpAndAssign: 115 case glslang::EOpLeftShiftAssign: 116 case glslang::EOpRightShiftAssign: 117 case glslang::EOpInclusiveOrAssign: 118 case glslang::EOpExclusiveOrAssign: 119 120 case glslang::EOpPostIncrement: 121 case glslang::EOpPostDecrement: 122 case glslang::EOpPreIncrement: 123 case glslang::EOpPreDecrement: 124 return true; 125 default: 126 return false; 127 } 128 } 129 130 // A helper function to get the unsigned int from a given constant union node. 131 // Note the node should only hold a uint scalar. 132 unsigned getStructIndexFromConstantUnion(glslang::TIntermTyped* node) 133 { 134 assert(node->getAsConstantUnion() && node->getAsConstantUnion()->isScalar()); 135 unsigned struct_dereference_index = node->getAsConstantUnion()->getConstArray()[0].getUConst(); 136 return struct_dereference_index; 137 } 138 139 // A helper function to generate symbol_label. 140 ObjectAccessChain generateSymbolLabel(glslang::TIntermSymbol* node) 141 { 142 ObjectAccessChain symbol_id = 143 std::to_string(node->getId()) + "(" + node->getName().c_str() + ")"; 144 return symbol_id; 145 } 146 147 // Returns true if the operation is an arithmetic operation and valid for 148 // the 'NoContraction' decoration. 149 bool isArithmeticOperation(glslang::TOperator op) 150 { 151 switch (op) { 152 case glslang::EOpAddAssign: 153 case glslang::EOpSubAssign: 154 case glslang::EOpMulAssign: 155 case glslang::EOpVectorTimesMatrixAssign: 156 case glslang::EOpVectorTimesScalarAssign: 157 case glslang::EOpMatrixTimesScalarAssign: 158 case glslang::EOpMatrixTimesMatrixAssign: 159 case glslang::EOpDivAssign: 160 case glslang::EOpModAssign: 161 162 case glslang::EOpNegative: 163 164 case glslang::EOpAdd: 165 case glslang::EOpSub: 166 case glslang::EOpMul: 167 case glslang::EOpDiv: 168 case glslang::EOpMod: 169 170 case glslang::EOpVectorTimesScalar: 171 case glslang::EOpVectorTimesMatrix: 172 case glslang::EOpMatrixTimesVector: 173 case glslang::EOpMatrixTimesScalar: 174 case glslang::EOpMatrixTimesMatrix: 175 176 case glslang::EOpDot: 177 178 case glslang::EOpPostIncrement: 179 case glslang::EOpPostDecrement: 180 case glslang::EOpPreIncrement: 181 case glslang::EOpPreDecrement: 182 return true; 183 default: 184 return false; 185 } 186 } 187 188 // A helper class to help manage the populating_initial_no_contraction_ flag. 189 template <typename T> class StateSettingGuard { 190 public: 191 StateSettingGuard(T* state_ptr, T new_state_value) 192 : state_ptr_(state_ptr), previous_state_(*state_ptr) 193 { 194 *state_ptr = new_state_value; 195 } 196 StateSettingGuard(T* state_ptr) : state_ptr_(state_ptr), previous_state_(*state_ptr) {} 197 void setState(T new_state_value) { *state_ptr_ = new_state_value; } 198 ~StateSettingGuard() { *state_ptr_ = previous_state_; } 199 200 private: 201 T* state_ptr_; 202 T previous_state_; 203 }; 204 205 // A helper function to get the front element from a given ObjectAccessChain 206 ObjectAccessChain getFrontElement(const ObjectAccessChain& chain) 207 { 208 size_t pos_delimiter = chain.find(ObjectAccesschainDelimiter); 209 return pos_delimiter == std::string::npos ? chain : chain.substr(0, pos_delimiter); 210 } 211 212 // A helper function to get the access chain starting from the second element. 213 ObjectAccessChain subAccessChainFromSecondElement(const ObjectAccessChain& chain) 214 { 215 size_t pos_delimiter = chain.find(ObjectAccesschainDelimiter); 216 return pos_delimiter == std::string::npos ? "" : chain.substr(pos_delimiter + 1); 217 } 218 219 // A helper function to get the access chain after removing a given prefix. 220 ObjectAccessChain getSubAccessChainAfterPrefix(const ObjectAccessChain& chain, 221 const ObjectAccessChain& prefix) 222 { 223 size_t pos = chain.find(prefix); 224 if (pos != 0) 225 return chain; 226 return chain.substr(prefix.length() + sizeof(ObjectAccesschainDelimiter)); 227 } 228 229 // 230 // A traverser which traverses the whole AST and populates: 231 // 1) A mapping from symbol nodes' IDs to their defining operation nodes. 232 // 2) A set of access chains of the initial precise object nodes. 233 // 234 class TSymbolDefinitionCollectingTraverser : public glslang::TIntermTraverser { 235 public: 236 TSymbolDefinitionCollectingTraverser(NodeMapping* symbol_definition_mapping, 237 AccessChainMapping* accesschain_mapping, 238 ObjectAccesschainSet* precise_objects, 239 ReturnBranchNodeSet* precise_return_nodes); 240 241 bool visitUnary(glslang::TVisit, glslang::TIntermUnary*) override; 242 bool visitBinary(glslang::TVisit, glslang::TIntermBinary*) override; 243 void visitSymbol(glslang::TIntermSymbol*) override; 244 bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate*) override; 245 bool visitBranch(glslang::TVisit, glslang::TIntermBranch*) override; 246 247 protected: 248 TSymbolDefinitionCollectingTraverser& operator=(const TSymbolDefinitionCollectingTraverser&); 249 250 // The mapping from symbol node IDs to their defining nodes. This should be 251 // populated along traversing the AST. 252 NodeMapping& symbol_definition_mapping_; 253 // The set of symbol node IDs for precise symbol nodes, the ones marked as 254 // 'noContraction'. 255 ObjectAccesschainSet& precise_objects_; 256 // The set of precise return nodes. 257 ReturnBranchNodeSet& precise_return_nodes_; 258 // A temporary cache of the symbol node whose defining node is to be found 259 // currently along traversing the AST. 260 ObjectAccessChain current_object_; 261 // A map from object node to its access chain. This traverser stores 262 // the built access chains into this map for each object node it has 263 // visited. 264 AccessChainMapping& accesschain_mapping_; 265 // The pointer to the Function Definition node, so we can get the 266 // preciseness of the return expression from it when we traverse the 267 // return branch node. 268 glslang::TIntermAggregate* current_function_definition_node_; 269 }; 270 271 TSymbolDefinitionCollectingTraverser::TSymbolDefinitionCollectingTraverser( 272 NodeMapping* symbol_definition_mapping, AccessChainMapping* accesschain_mapping, 273 ObjectAccesschainSet* precise_objects, 274 std::unordered_set<glslang::TIntermBranch*>* precise_return_nodes) 275 : TIntermTraverser(true, false, false), symbol_definition_mapping_(*symbol_definition_mapping), 276 precise_objects_(*precise_objects), precise_return_nodes_(*precise_return_nodes), 277 current_object_(), accesschain_mapping_(*accesschain_mapping), 278 current_function_definition_node_(nullptr) {} 279 280 // Visits a symbol node, set the current_object_ to the 281 // current node symbol ID, and record a mapping from this node to the current 282 // current_object_, which is the just obtained symbol 283 // ID. 284 void TSymbolDefinitionCollectingTraverser::visitSymbol(glslang::TIntermSymbol* node) 285 { 286 current_object_ = generateSymbolLabel(node); 287 accesschain_mapping_[node] = current_object_; 288 } 289 290 // Visits an aggregate node, traverses all of its children. 291 bool TSymbolDefinitionCollectingTraverser::visitAggregate(glslang::TVisit, 292 glslang::TIntermAggregate* node) 293 { 294 // This aggregate node might be a function definition node, in which case we need to 295 // cache this node, so we can get the preciseness information of the return value 296 // of this function later. 297 StateSettingGuard<glslang::TIntermAggregate*> current_function_definition_node_setting_guard( 298 ¤t_function_definition_node_); 299 if (node->getOp() == glslang::EOpFunction) { 300 // This is function definition node, we need to cache this node so that we can 301 // get the preciseness of the return value later. 302 current_function_definition_node_setting_guard.setState(node); 303 } 304 // Traverse the items in the sequence. 305 glslang::TIntermSequence& seq = node->getSequence(); 306 for (int i = 0; i < (int)seq.size(); ++i) { 307 current_object_.clear(); 308 seq[i]->traverse(this); 309 } 310 return false; 311 } 312 313 bool TSymbolDefinitionCollectingTraverser::visitBranch(glslang::TVisit, 314 glslang::TIntermBranch* node) 315 { 316 if (node->getFlowOp() == glslang::EOpReturn && node->getExpression() && 317 current_function_definition_node_ && 318 current_function_definition_node_->getType().getQualifier().noContraction) { 319 // This node is a return node with an expression, and its function has a 320 // precise return value. We need to find the involved objects in its 321 // expression and add them to the set of initial precise objects. 322 precise_return_nodes_.insert(node); 323 node->getExpression()->traverse(this); 324 } 325 return false; 326 } 327 328 // Visits a unary node. This might be an implicit assignment like i++, i--. etc. 329 bool TSymbolDefinitionCollectingTraverser::visitUnary(glslang::TVisit /* visit */, 330 glslang::TIntermUnary* node) 331 { 332 current_object_.clear(); 333 node->getOperand()->traverse(this); 334 if (isAssignOperation(node->getOp())) { 335 // We should always be able to get an access chain of the operand node. 336 assert(!current_object_.empty()); 337 338 // If the operand node object is 'precise', we collect its access chain 339 // for the initial set of 'precise' objects. 340 if (isPreciseObjectNode(node->getOperand())) { 341 // The operand node is an 'precise' object node, add its 342 // access chain to the set of 'precise' objects. This is to collect 343 // the initial set of 'precise' objects. 344 precise_objects_.insert(current_object_); 345 } 346 // Gets the symbol ID from the object's access chain. 347 ObjectAccessChain id_symbol = getFrontElement(current_object_); 348 // Add a mapping from the symbol ID to this assignment operation node. 349 symbol_definition_mapping_.insert(std::make_pair(id_symbol, node)); 350 } 351 // A unary node is not a dereference node, so we clear the access chain which 352 // is under construction. 353 current_object_.clear(); 354 return false; 355 } 356 357 // Visits a binary node and updates the mapping from symbol IDs to the definition 358 // nodes. Also collects the access chains for the initial precise objects. 359 bool TSymbolDefinitionCollectingTraverser::visitBinary(glslang::TVisit /* visit */, 360 glslang::TIntermBinary* node) 361 { 362 // Traverses the left node to build the access chain info for the object. 363 current_object_.clear(); 364 node->getLeft()->traverse(this); 365 366 if (isAssignOperation(node->getOp())) { 367 // We should always be able to get an access chain for the left node. 368 assert(!current_object_.empty()); 369 370 // If the left node object is 'precise', it is an initial precise object 371 // specified in the shader source. Adds it to the initial work list to 372 // process later. 373 if (isPreciseObjectNode(node->getLeft())) { 374 // The left node is an 'precise' object node, add its access chain to 375 // the set of 'precise' objects. This is to collect the initial set 376 // of 'precise' objects. 377 precise_objects_.insert(current_object_); 378 } 379 // Gets the symbol ID from the object access chain, which should be the 380 // first element recorded in the access chain. 381 ObjectAccessChain id_symbol = getFrontElement(current_object_); 382 // Adds a mapping from the symbol ID to this assignment operation node. 383 symbol_definition_mapping_.insert(std::make_pair(id_symbol, node)); 384 385 // Traverses the right node, there may be other 'assignment' 386 // operations in the right. 387 current_object_.clear(); 388 node->getRight()->traverse(this); 389 390 } else if (isDereferenceOperation(node->getOp())) { 391 // The left node (parent node) is a struct type object. We need to 392 // record the access chain information of the current node into its 393 // object id. 394 if (node->getOp() == glslang::EOpIndexDirectStruct) { 395 unsigned struct_dereference_index = getStructIndexFromConstantUnion(node->getRight()); 396 current_object_.push_back(ObjectAccesschainDelimiter); 397 current_object_.append(std::to_string(struct_dereference_index)); 398 } 399 accesschain_mapping_[node] = current_object_; 400 401 // For a dereference node, there is no need to traverse the right child 402 // node as the right node should always be an integer type object. 403 404 } else { 405 // For other binary nodes, still traverse the right node. 406 current_object_.clear(); 407 node->getRight()->traverse(this); 408 } 409 return false; 410 } 411 412 // Traverses the AST and returns a tuple of four members: 413 // 1) a mapping from symbol IDs to the definition nodes (aka. assignment nodes) of these symbols. 414 // 2) a mapping from object nodes in the AST to the access chains of these objects. 415 // 3) a set of access chains of precise objects. 416 // 4) a set of return nodes with precise expressions. 417 std::tuple<NodeMapping, AccessChainMapping, ObjectAccesschainSet, ReturnBranchNodeSet> 418 getSymbolToDefinitionMappingAndPreciseSymbolIDs(const glslang::TIntermediate& intermediate) 419 { 420 auto result_tuple = std::make_tuple(NodeMapping(), AccessChainMapping(), ObjectAccesschainSet(), 421 ReturnBranchNodeSet()); 422 423 TIntermNode* root = intermediate.getTreeRoot(); 424 if (root == 0) 425 return result_tuple; 426 427 NodeMapping& symbol_definition_mapping = std::get<0>(result_tuple); 428 AccessChainMapping& accesschain_mapping = std::get<1>(result_tuple); 429 ObjectAccesschainSet& precise_objects = std::get<2>(result_tuple); 430 ReturnBranchNodeSet& precise_return_nodes = std::get<3>(result_tuple); 431 432 // Traverses the AST and populate the results. 433 TSymbolDefinitionCollectingTraverser collector(&symbol_definition_mapping, &accesschain_mapping, 434 &precise_objects, &precise_return_nodes); 435 root->traverse(&collector); 436 437 return result_tuple; 438 } 439 440 // 441 // A traverser that determine whether the left node (or operand node for unary 442 // node) of an assignment node is 'precise', containing 'precise' or not, 443 // according to the access chain a given precise object which share the same 444 // symbol as the left node. 445 // 446 // Post-orderly traverses the left node subtree of an binary assignment node and: 447 // 448 // 1) Propagates the 'precise' from the left object nodes to this object node. 449 // 450 // 2) Builds object access chain along the traversal, and also compares with 451 // the access chain of the given 'precise' object along with the traversal to 452 // tell if the node to be defined is 'precise' or not. 453 // 454 class TNoContractionAssigneeCheckingTraverser : public glslang::TIntermTraverser { 455 456 enum DecisionStatus { 457 // The object node to be assigned to may contain 'precise' objects and also not 'precise' objects. 458 Mixed = 0, 459 // The object node to be assigned to is either a 'precise' object or a struct objects whose members are all 'precise'. 460 Precise = 1, 461 // The object node to be assigned to is not a 'precise' object. 462 NotPreicse = 2, 463 }; 464 465 public: 466 TNoContractionAssigneeCheckingTraverser(const AccessChainMapping& accesschain_mapping) 467 : TIntermTraverser(true, false, false), accesschain_mapping_(accesschain_mapping), 468 precise_object_(nullptr) {} 469 470 // Checks the preciseness of a given assignment node with a precise object 471 // represented as access chain. The precise object shares the same symbol 472 // with the assignee of the given assignment node. Return a tuple of two: 473 // 474 // 1) The preciseness of the assignee node of this assignment node. True 475 // if the assignee contains 'precise' objects or is 'precise', false if 476 // the assignee is not 'precise' according to the access chain of the given 477 // precise object. 478 // 479 // 2) The incremental access chain from the assignee node to its nested 480 // 'precise' object, according to the access chain of the given precise 481 // object. This incremental access chain can be empty, which means the 482 // assignee is 'precise'. Otherwise it shows the path to the nested 483 // precise object. 484 std::tuple<bool, ObjectAccessChain> 485 getPrecisenessAndRemainedAccessChain(glslang::TIntermOperator* node, 486 const ObjectAccessChain& precise_object) 487 { 488 assert(isAssignOperation(node->getOp())); 489 precise_object_ = &precise_object; 490 ObjectAccessChain assignee_object; 491 if (glslang::TIntermBinary* BN = node->getAsBinaryNode()) { 492 // This is a binary assignment node, we need to check the 493 // preciseness of the left node. 494 assert(accesschain_mapping_.count(BN->getLeft())); 495 // The left node (assignee node) is an object node, traverse the 496 // node to let the 'precise' of nesting objects being transfered to 497 // nested objects. 498 BN->getLeft()->traverse(this); 499 // After traversing the left node, if the left node is 'precise', 500 // we can conclude this assignment should propagate 'precise'. 501 if (isPreciseObjectNode(BN->getLeft())) { 502 return make_tuple(true, ObjectAccessChain()); 503 } 504 // If the preciseness of the left node (assignee node) can not 505 // be determined by now, we need to compare the access chain string 506 // of the assignee object with the given precise object. 507 assignee_object = accesschain_mapping_.at(BN->getLeft()); 508 509 } else if (glslang::TIntermUnary* UN = node->getAsUnaryNode()) { 510 // This is a unary assignment node, we need to check the 511 // preciseness of the operand node. For unary assignment node, the 512 // operand node should always be an object node. 513 assert(accesschain_mapping_.count(UN->getOperand())); 514 // Traverse the operand node to let the 'precise' being propagated 515 // from lower nodes to upper nodes. 516 UN->getOperand()->traverse(this); 517 // After traversing the operand node, if the operand node is 518 // 'precise', this assignment should propagate 'precise'. 519 if (isPreciseObjectNode(UN->getOperand())) { 520 return make_tuple(true, ObjectAccessChain()); 521 } 522 // If the preciseness of the operand node (assignee node) can not 523 // be determined by now, we need to compare the access chain string 524 // of the assignee object with the given precise object. 525 assignee_object = accesschain_mapping_.at(UN->getOperand()); 526 } else { 527 // Not a binary or unary node, should not happen. 528 assert(false); 529 } 530 531 // Compare the access chain string of the assignee node with the given 532 // precise object to determine if this assignment should propagate 533 // 'precise'. 534 if (assignee_object.find(precise_object) == 0) { 535 // The access chain string of the given precise object is a prefix 536 // of assignee's access chain string. The assignee should be 537 // 'precise'. 538 return make_tuple(true, ObjectAccessChain()); 539 } else if (precise_object.find(assignee_object) == 0) { 540 // The assignee's access chain string is a prefix of the given 541 // precise object, the assignee object contains 'precise' object, 542 // and we need to pass the remained access chain to the object nodes 543 // in the right. 544 return make_tuple(true, getSubAccessChainAfterPrefix(precise_object, assignee_object)); 545 } else { 546 // The access chain strings do not match, the assignee object can 547 // not be labeled as 'precise' according to the given precise 548 // object. 549 return make_tuple(false, ObjectAccessChain()); 550 } 551 } 552 553 protected: 554 TNoContractionAssigneeCheckingTraverser& operator=(const TNoContractionAssigneeCheckingTraverser&); 555 556 bool visitBinary(glslang::TVisit, glslang::TIntermBinary* node) override; 557 void visitSymbol(glslang::TIntermSymbol* node) override; 558 559 // A map from object nodes to their access chain string (used as object ID). 560 const AccessChainMapping& accesschain_mapping_; 561 // A given precise object, represented in it access chain string. This 562 // precise object is used to be compared with the assignee node to tell if 563 // the assignee node is 'precise', contains 'precise' object or not 564 // 'precise'. 565 const ObjectAccessChain* precise_object_; 566 }; 567 568 // Visits a binary node. If the node is an object node, it must be a dereference 569 // node. In such cases, if the left node is 'precise', this node should also be 570 // 'precise'. 571 bool TNoContractionAssigneeCheckingTraverser::visitBinary(glslang::TVisit, 572 glslang::TIntermBinary* node) 573 { 574 // Traverses the left so that we transfer the 'precise' from nesting object 575 // to its nested object. 576 node->getLeft()->traverse(this); 577 // If this binary node is an object node, we should have it in the 578 // accesschain_mapping_. 579 if (accesschain_mapping_.count(node)) { 580 // A binary object node must be a dereference node. 581 assert(isDereferenceOperation(node->getOp())); 582 // If the left node is 'precise', this node should also be precise, 583 // otherwise, compare with the given precise_object_. If the 584 // access chain of this node matches with the given precise_object_, 585 // this node should be marked as 'precise'. 586 if (isPreciseObjectNode(node->getLeft())) { 587 node->getWritableType().getQualifier().noContraction = true; 588 } else if (accesschain_mapping_.at(node) == *precise_object_) { 589 node->getWritableType().getQualifier().noContraction = true; 590 } 591 } 592 return false; 593 } 594 595 // Visits a symbol node, if the symbol node ID (its access chain string) matches 596 // with the given precise object, this node should be 'precise'. 597 void TNoContractionAssigneeCheckingTraverser::visitSymbol(glslang::TIntermSymbol* node) 598 { 599 // A symbol node should always be an object node, and should have been added 600 // to the map from object nodes to their access chain strings. 601 assert(accesschain_mapping_.count(node)); 602 if (accesschain_mapping_.at(node) == *precise_object_) { 603 node->getWritableType().getQualifier().noContraction = true; 604 } 605 } 606 607 // 608 // A traverser that only traverses the right side of binary assignment nodes 609 // and the operand node of unary assignment nodes. 610 // 611 // 1) Marks arithmetic operations as 'NoContraction'. 612 // 613 // 2) Find the object which should be marked as 'precise' in the right and 614 // update the 'precise' object work list. 615 // 616 class TNoContractionPropagator : public glslang::TIntermTraverser { 617 public: 618 TNoContractionPropagator(ObjectAccesschainSet* precise_objects, 619 const AccessChainMapping& accesschain_mapping) 620 : TIntermTraverser(true, false, false), 621 precise_objects_(*precise_objects), added_precise_object_ids_(), 622 remained_accesschain_(), accesschain_mapping_(accesschain_mapping) {} 623 624 // Propagates 'precise' in the right nodes of a given assignment node with 625 // access chain record from the assignee node to a 'precise' object it 626 // contains. 627 void 628 propagateNoContractionInOneExpression(glslang::TIntermTyped* defining_node, 629 const ObjectAccessChain& assignee_remained_accesschain) 630 { 631 remained_accesschain_ = assignee_remained_accesschain; 632 if (glslang::TIntermBinary* BN = defining_node->getAsBinaryNode()) { 633 assert(isAssignOperation(BN->getOp())); 634 BN->getRight()->traverse(this); 635 if (isArithmeticOperation(BN->getOp())) { 636 BN->getWritableType().getQualifier().noContraction = true; 637 } 638 } else if (glslang::TIntermUnary* UN = defining_node->getAsUnaryNode()) { 639 assert(isAssignOperation(UN->getOp())); 640 UN->getOperand()->traverse(this); 641 if (isArithmeticOperation(UN->getOp())) { 642 UN->getWritableType().getQualifier().noContraction = true; 643 } 644 } 645 } 646 647 // Propagates 'precise' in a given precise return node. 648 void propagateNoContractionInReturnNode(glslang::TIntermBranch* return_node) 649 { 650 remained_accesschain_ = ""; 651 assert(return_node->getFlowOp() == glslang::EOpReturn && return_node->getExpression()); 652 return_node->getExpression()->traverse(this); 653 } 654 655 protected: 656 TNoContractionPropagator& operator=(const TNoContractionPropagator&); 657 658 // Visits an aggregate node. The node can be a initializer list, in which 659 // case we need to find the 'precise' or 'precise' containing object node 660 // with the access chain record. In other cases, just need to traverse all 661 // the children nodes. 662 bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate* node) override 663 { 664 if (!remained_accesschain_.empty() && node->getOp() == glslang::EOpConstructStruct) { 665 // This is a struct initializer node, and the remained 666 // access chain is not empty, we need to refer to the 667 // assignee_remained_access_chain_ to find the nested 668 // 'precise' object. And we don't need to visit other nodes in this 669 // aggregate node. 670 671 // Gets the struct dereference index that leads to 'precise' object. 672 ObjectAccessChain precise_accesschain_index_str = 673 getFrontElement(remained_accesschain_); 674 unsigned precise_accesschain_index = (unsigned)strtoul(precise_accesschain_index_str.c_str(), nullptr, 10); 675 // Gets the node pointed by the access chain index extracted before. 676 glslang::TIntermTyped* potential_precise_node = 677 node->getSequence()[precise_accesschain_index]->getAsTyped(); 678 assert(potential_precise_node); 679 // Pop the front access chain index from the path, and visit the nested node. 680 { 681 ObjectAccessChain next_level_accesschain = 682 subAccessChainFromSecondElement(remained_accesschain_); 683 StateSettingGuard<ObjectAccessChain> setup_remained_accesschain_for_next_level( 684 &remained_accesschain_, next_level_accesschain); 685 potential_precise_node->traverse(this); 686 } 687 return false; 688 } 689 return true; 690 } 691 692 // Visits a binary node. A binary node can be an object node, e.g. a dereference node. 693 // As only the top object nodes in the right side of an assignment needs to be visited 694 // and added to 'precise' work list, this traverser won't visit the children nodes of 695 // an object node. If the binary node does not represent an object node, it should 696 // go on to traverse its children nodes and if it is an arithmetic operation node, this 697 // operation should be marked as 'noContraction'. 698 bool visitBinary(glslang::TVisit, glslang::TIntermBinary* node) override 699 { 700 if (isDereferenceOperation(node->getOp())) { 701 // This binary node is an object node. Need to update the precise 702 // object set with the access chain of this node + remained 703 // access chain . 704 ObjectAccessChain new_precise_accesschain = accesschain_mapping_.at(node); 705 if (remained_accesschain_.empty()) { 706 node->getWritableType().getQualifier().noContraction = true; 707 } else { 708 new_precise_accesschain += ObjectAccesschainDelimiter + remained_accesschain_; 709 } 710 // Cache the access chain as added precise object, so we won't add the 711 // same object to the work list again. 712 if (!added_precise_object_ids_.count(new_precise_accesschain)) { 713 precise_objects_.insert(new_precise_accesschain); 714 added_precise_object_ids_.insert(new_precise_accesschain); 715 } 716 // Only the upper-most object nodes should be visited, so do not 717 // visit children of this object node. 718 return false; 719 } 720 // If this is an arithmetic operation, marks this node as 'noContraction'. 721 if (isArithmeticOperation(node->getOp()) && node->getBasicType() != glslang::EbtInt) { 722 node->getWritableType().getQualifier().noContraction = true; 723 } 724 // As this node is not an object node, need to traverse the children nodes. 725 return true; 726 } 727 728 // Visits a unary node. A unary node can not be an object node. If the operation 729 // is an arithmetic operation, need to mark this node as 'noContraction'. 730 bool visitUnary(glslang::TVisit /* visit */, glslang::TIntermUnary* node) override 731 { 732 // If this is an arithmetic operation, marks this with 'noContraction' 733 if (isArithmeticOperation(node->getOp())) { 734 node->getWritableType().getQualifier().noContraction = true; 735 } 736 return true; 737 } 738 739 // Visits a symbol node. A symbol node is always an object node. So we 740 // should always be able to find its in our collected mapping from object 741 // nodes to access chains. As an object node, a symbol node can be either 742 // 'precise' or containing 'precise' objects according to unused 743 // access chain information we have when we visit this node. 744 void visitSymbol(glslang::TIntermSymbol* node) override 745 { 746 // Symbol nodes are object nodes and should always have an 747 // access chain collected before matches with it. 748 assert(accesschain_mapping_.count(node)); 749 ObjectAccessChain new_precise_accesschain = accesschain_mapping_.at(node); 750 // If the unused access chain is empty, this symbol node should be 751 // marked as 'precise'. Otherwise, the unused access chain should be 752 // appended to the symbol ID to build a new access chain which points to 753 // the nested 'precise' object in this symbol object. 754 if (remained_accesschain_.empty()) { 755 node->getWritableType().getQualifier().noContraction = true; 756 } else { 757 new_precise_accesschain += ObjectAccesschainDelimiter + remained_accesschain_; 758 } 759 // Add the new 'precise' access chain to the work list and make sure we 760 // don't visit it again. 761 if (!added_precise_object_ids_.count(new_precise_accesschain)) { 762 precise_objects_.insert(new_precise_accesschain); 763 added_precise_object_ids_.insert(new_precise_accesschain); 764 } 765 } 766 767 // A set of precise objects, represented as access chains. 768 ObjectAccesschainSet& precise_objects_; 769 // Visited symbol nodes, should not revisit these nodes. 770 ObjectAccesschainSet added_precise_object_ids_; 771 // The left node of an assignment operation might be an parent of 'precise' objects. 772 // This means the left node might not be an 'precise' object node, but it may contains 773 // 'precise' qualifier which should be propagated to the corresponding child node in 774 // the right. So we need the path from the left node to its nested 'precise' node to 775 // tell us how to find the corresponding 'precise' node in the right. 776 ObjectAccessChain remained_accesschain_; 777 // A map from node pointers to their access chains. 778 const AccessChainMapping& accesschain_mapping_; 779 }; 780 } 781 782 namespace glslang { 783 784 void PropagateNoContraction(const glslang::TIntermediate& intermediate) 785 { 786 // First, traverses the AST, records symbols with their defining operations 787 // and collects the initial set of precise symbols (symbol nodes that marked 788 // as 'noContraction') and precise return nodes. 789 auto mappings_and_precise_objects = 790 getSymbolToDefinitionMappingAndPreciseSymbolIDs(intermediate); 791 792 // The mapping of symbol node IDs to their defining nodes. This enables us 793 // to get the defining node directly from a given symbol ID without 794 // traversing the tree again. 795 NodeMapping& symbol_definition_mapping = std::get<0>(mappings_and_precise_objects); 796 797 // The mapping of object nodes to their access chains recorded. 798 AccessChainMapping& accesschain_mapping = std::get<1>(mappings_and_precise_objects); 799 800 // The initial set of 'precise' objects which are represented as the 801 // access chain toward them. 802 ObjectAccesschainSet& precise_object_accesschains = std::get<2>(mappings_and_precise_objects); 803 804 // The set of 'precise' return nodes. 805 ReturnBranchNodeSet& precise_return_nodes = std::get<3>(mappings_and_precise_objects); 806 807 // Second, uses the initial set of precise objects as a work list, pops an 808 // access chain, extract the symbol ID from it. Then: 809 // 1) Check the assignee object, see if it is 'precise' object node or 810 // contains 'precise' object. Obtain the incremental access chain from the 811 // assignee node to its nested 'precise' node (if any). 812 // 2) If the assignee object node is 'precise' or it contains 'precise' 813 // objects, traverses the right side of the assignment operation 814 // expression to mark arithmetic operations as 'noContration' and update 815 // 'precise' access chain work list with new found object nodes. 816 // Repeat above steps until the work list is empty. 817 TNoContractionAssigneeCheckingTraverser checker(accesschain_mapping); 818 TNoContractionPropagator propagator(&precise_object_accesschains, accesschain_mapping); 819 820 // We have two initial precise work lists to handle: 821 // 1) precise return nodes 822 // 2) precise object access chains 823 // We should process the precise return nodes first and the involved 824 // objects in the return expression should be added to the precise object 825 // access chain set. 826 while (!precise_return_nodes.empty()) { 827 glslang::TIntermBranch* precise_return_node = *precise_return_nodes.begin(); 828 propagator.propagateNoContractionInReturnNode(precise_return_node); 829 precise_return_nodes.erase(precise_return_node); 830 } 831 832 while (!precise_object_accesschains.empty()) { 833 // Get the access chain of a precise object from the work list. 834 ObjectAccessChain precise_object_accesschain = *precise_object_accesschains.begin(); 835 // Get the symbol id from the access chain. 836 ObjectAccessChain symbol_id = getFrontElement(precise_object_accesschain); 837 // Get all the defining nodes of that symbol ID. 838 std::pair<NodeMapping::iterator, NodeMapping::iterator> range = 839 symbol_definition_mapping.equal_range(symbol_id); 840 // Visits all the assignment nodes of that symbol ID and 841 // 1) Check if the assignee node is 'precise' or contains 'precise' 842 // objects. 843 // 2) Propagate the 'precise' to the top layer object nodes 844 // in the right side of the assignment operation, update the 'precise' 845 // work list with new access chains representing the new 'precise' 846 // objects, and mark arithmetic operations as 'noContraction'. 847 for (NodeMapping::iterator defining_node_iter = range.first; 848 defining_node_iter != range.second; defining_node_iter++) { 849 TIntermOperator* defining_node = defining_node_iter->second; 850 // Check the assignee node. 851 auto checker_result = checker.getPrecisenessAndRemainedAccessChain( 852 defining_node, precise_object_accesschain); 853 bool& contain_precise = std::get<0>(checker_result); 854 ObjectAccessChain& remained_accesschain = std::get<1>(checker_result); 855 // If the assignee node is 'precise' or contains 'precise', propagate the 856 // 'precise' to the right. Otherwise just skip this assignment node. 857 if (contain_precise) { 858 propagator.propagateNoContractionInOneExpression(defining_node, 859 remained_accesschain); 860 } 861 } 862 // Remove the last processed 'precise' object from the work list. 863 precise_object_accesschains.erase(precise_object_accesschain); 864 } 865 } 866 }; 867