Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
     18 
#include <algorithm>
#include <list>
#include <vector>

#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/types.h"
     28 
     29 namespace xla {
     30 
     31 class HloInstruction;
     32 
     33 // A class for representing reachability between HloInstructions.
     34 //
     35 // !!! THIS CLASS DOES NOT COMPUTE REACHABILITY !!! It has an adjacency matrix
     36 // and it is up to the user of the class to set the adjacency matrix such that
     37 // it represents reachability, i.e. such that it is transitive. That the graph
     38 // be transitive is thus not an invariant of this class, but it is required for
     39 // the name of the class and its methods to make sense.
     40 class HloReachabilityMap {
     41  public:
     42   // Sets up a graph with no edges and where the nodes correspond to the given
     43   // instructions.
     44   explicit HloReachabilityMap(const std::list<HloInstruction*>& instructions);
     45 
     46   // Set the reachability set of 'instruction' to the union of the reachability
     47   // sets of 'inputs'. Upon return, IsReachable(x, instruction) where
     48   // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true
     49   // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from
     50   // itself. Returns whether the reachability set of 'instruction' changed.
     51   //
     52   // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
     53   // vector in the internal graph of this HloReachabilityMap for the given
     54   // instruction and does not transitively update any other part of the
     55   // adjacency matrix.
     56   bool SetReachabilityToUnion(
     57       tensorflow::gtl::ArraySlice<const HloInstruction*> inputs,
     58       const HloInstruction* instruction);
     59 
     60   // Sets entry so that IsReachable(a, b) will return true
     61   //
     62   // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency
     63   // matrix in the internal graph of this HloReachabilityMap to have an edge
     64   // from a to b and does not transitively update any other part of the
     65   // adjacency matrix.
     66   void SetReachable(const HloInstruction* a, const HloInstruction* b);
     67 
     68   // Returns true if "b" is reachable from "a"
     69   //
     70   // Note that this function only correctly answers queries about reachability
     71   // if the set of edges that have been provided to this class are transitive.
     72   bool IsReachable(const HloInstruction* a, const HloInstruction* b) const;
     73 
     74   // Returns true if "b" is reachable from "a" or "a" is reachable from "b"
     75   //
     76   // Note that this function only correctly answers queries about reachability
     77   // if the set of edges that have been provided to this class are transitive.
     78   bool IsConnected(const HloInstruction* a, const HloInstruction* b) const;
     79 
     80  private:
     81   // A bit-vector implementation specialized for this use case which provides a
     82   // fast bitwise OR operation not available in tensorflow::gtl::BitMap.
     83   class BitVector {
     84    public:
     85     BitVector() = default;
     86     BitVector(size_t size)
     87         : size_(size), vector_((size + kBits - 1) / kBits, 0) {}
     88 
     89     // Return the bit at the given index.
     90     bool Get(size_t index) const {
     91       DCHECK(index >= 0 && index < size_);
     92       return vector_[index / kBits] & (1ull << (index % kBits));
     93     }
     94 
     95     // Set the bit at the given index.
     96     void Set(size_t index) {
     97       DCHECK(index >= 0 && index < size_);
     98       vector_[index / kBits] |= 1ull << (index % kBits);
     99     }
    100 
    101     // Set this bitvector to the Logical OR of this bitvector and 'other'.
    102     void OrWith(const BitVector& other) {
    103       for (size_t i = 0; i < vector_.size(); ++i) {
    104         vector_[i] |= other.vector_[i];
    105       }
    106     }
    107 
    108     // Set the bitvector to all zeros.
    109     void SetToZero() { std::fill(vector_.begin(), vector_.end(), 0); }
    110 
    111     bool operator==(const BitVector& other) const {
    112       return vector_ == other.vector_;
    113     }
    114     bool operator!=(const BitVector& other) const {
    115       return vector_ != other.vector_;
    116     }
    117 
    118    private:
    119     using Word = uint64;
    120     static const size_t kBits = 64;
    121 
    122     // Number of bits in the bitvector.
    123     size_t size_;
    124 
    125     std::vector<Word> vector_;
    126   };
    127 
    128   // Return the bitvector storing the reachability-to of the given instruction.
    129   const BitVector& GetBitVector(const HloInstruction* instruction) const {
    130     return bit_vectors_[GetIndex(instruction)];
    131   }
    132   BitVector& GetBitVector(const HloInstruction* instruction) {
    133     return bit_vectors_[GetIndex(instruction)];
    134   }
    135 
    136   // Return the index of the given instruction. The value is used to index into
    137   // the vector of BitVectors and the BitVectors themselves.
    138   int GetIndex(const HloInstruction* instruction) const {
    139     return FindOrDie(indices_, instruction);
    140   }
    141 
    142   // The number of instructions in the reachability map.
    143   const size_t size_;
    144 
    145   // Dense assignment from HloInstruction* to number. These numbers index
    146   // into the bit_vectors_ vector and into the bits within a BitVector.
    147   tensorflow::gtl::FlatMap<const HloInstruction*, int> indices_;
    148 
    149   // Bitvectors holding the reachability to each instruction. The bit vector for
    150   // instruction X includes ones for each instruction which X is reachable from.
    151   std::vector<BitVector> bit_vectors_;
    152 
    153   // A temporary used by SetReachabilityToUnion to avoid an allocation with each
    154   // call to the method.
    155   BitVector tmp_bit_vector_;
    156 };
    157 
    158 }  // namespace xla
    159 
    160 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REACHABILITY_H_
    161