1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 18 19 #include <list> 20 #include <vector> 21 22 #include "tensorflow/compiler/xla/map_util.h" 23 #include "tensorflow/compiler/xla/types.h" 24 #include "tensorflow/core/lib/core/status.h" 25 #include "tensorflow/core/lib/gtl/array_slice.h" 26 #include "tensorflow/core/lib/gtl/flatmap.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace xla { 30 31 class HloInstruction; 32 33 // A class for representing reachability between HloInstructions. 34 // 35 // !!! THIS CLASS DOES NOT COMPUTE REACHABILITY !!! It has an adjacency matrix 36 // and it is up to the user of the class to set the adjacency matrix such that 37 // it represents reachability, i.e. such that it is transitive. That the graph 38 // be transitive is thus not an invariant of this class, but it is required for 39 // the name of the class and its methods to make sense. 40 class HloReachabilityMap { 41 public: 42 // Sets up a graph with no edges and where the nodes correspond to the given 43 // instructions. 44 explicit HloReachabilityMap(const std::list<HloInstruction*>& instructions); 45 46 // Set the reachability set of 'instruction' to the union of the reachability 47 // sets of 'inputs'. Upon return, IsReachable(x, instruction) where 48 // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true 49 // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from 50 // itself. Returns whether the reachability set of 'instruction' changed. 51 // 52 // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency 53 // vector in the internal graph of this HloReachabilityMap for the given 54 // instruction and does not transitively update any other part of the 55 // adjacency matrix. 56 bool SetReachabilityToUnion( 57 tensorflow::gtl::ArraySlice<const HloInstruction*> inputs, 58 const HloInstruction* instruction); 59 60 // Sets entry so that IsReachable(a, b) will return true 61 // 62 // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency 63 // matrix in the internal graph of this HloReachabilityMap to have an edge 64 // from a to b and does not transitively update any other part of the 65 // adjacency matrix. 66 void SetReachable(const HloInstruction* a, const HloInstruction* b); 67 68 // Returns true if "b" is reachable from "a" 69 // 70 // Note that this function only correctly answers queries about reachability 71 // if the set of edges that have been provided to this class are transitive. 72 bool IsReachable(const HloInstruction* a, const HloInstruction* b) const; 73 74 // Returns true if "b" is reachable from "a" or "a" is reachable from "b" 75 // 76 // Note that this function only correctly answers queries about reachability 77 // if the set of edges that have been provided to this class are transitive. 78 bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; 79 80 private: 81 // A bit-vector implementation specialized for this use case which provides a 82 // fast bitwise OR operation not available in tensorflow::gtl::BitMap. 83 class BitVector { 84 public: 85 BitVector() = default; 86 BitVector(size_t size) 87 : size_(size), vector_((size + kBits - 1) / kBits, 0) {} 88 89 // Return the bit at the given index. 90 bool Get(size_t index) const { 91 DCHECK(index >= 0 && index < size_); 92 return vector_[index / kBits] & (1ull << (index % kBits)); 93 } 94 95 // Set the bit at the given index. 96 void Set(size_t index) { 97 DCHECK(index >= 0 && index < size_); 98 vector_[index / kBits] |= 1ull << (index % kBits); 99 } 100 101 // Set this bitvector to the Logical OR of this bitvector and 'other'. 102 void OrWith(const BitVector& other) { 103 for (size_t i = 0; i < vector_.size(); ++i) { 104 vector_[i] |= other.vector_[i]; 105 } 106 } 107 108 // Set the bitvector to all zeros. 109 void SetToZero() { std::fill(vector_.begin(), vector_.end(), 0); } 110 111 bool operator==(const BitVector& other) const { 112 return vector_ == other.vector_; 113 } 114 bool operator!=(const BitVector& other) const { 115 return vector_ != other.vector_; 116 } 117 118 private: 119 using Word = uint64; 120 static const size_t kBits = 64; 121 122 // Number of bits in the bitvector. 123 size_t size_; 124 125 std::vector<Word> vector_; 126 }; 127 128 // Return the bitvector storing the reachability-to of the given instruction. 129 const BitVector& GetBitVector(const HloInstruction* instruction) const { 130 return bit_vectors_[GetIndex(instruction)]; 131 } 132 BitVector& GetBitVector(const HloInstruction* instruction) { 133 return bit_vectors_[GetIndex(instruction)]; 134 } 135 136 // Return the index of the given instruction. The value is used to index into 137 // the vector of BitVectors and the BitVectors themselves. 138 int GetIndex(const HloInstruction* instruction) const { 139 return FindOrDie(indices_, instruction); 140 } 141 142 // The number of instructions in the reachability map. 143 const size_t size_; 144 145 // Dense assignment from HloInstruction* to number. These numbers index 146 // into the bit_vectors_ vector and into the bits within a BitVector. 147 tensorflow::gtl::FlatMap<const HloInstruction*, int> indices_; 148 149 // Bitvectors holding the reachability to each instruction. The bit vector for 150 // instruction X includes ones for each instruction which X is reachable from. 151 std::vector<BitVector> bit_vectors_; 152 153 // A temporary used by SetReachabilityToUnion to avoid an allocation with each 154 // call to the method. 155 BitVector tmp_bit_vector_; 156 }; 157 158 } // namespace xla 159 160 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_ 161