Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
     18 
     19 #include <ostream>
     20 #include <string>
     21 #include <vector>
     22 
     23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     24 #include "tensorflow/compiler/xla/shape_tree.h"
     25 #include "tensorflow/compiler/xla/types.h"
     26 #include "tensorflow/compiler/xla/xla_data.pb.h"
     27 #include "tensorflow/core/lib/gtl/array_slice.h"
     28 #include "tensorflow/core/platform/macros.h"
     29 
     30 namespace xla {
     31 
     32 // Abstraction which identifies a specific point in the XLA graph. An
     33 // HloPosition specifies a ShapeIndex within the output of a specific
     34 // instruction.
     35 struct HloPosition {
     36   HloInstruction* instruction;
     37   ShapeIndex index;
     38 
     39   // Returns the shape at this position.
     40   const Shape& shape() const;
     41 
     42   string ToString() const;
     43 
     44   bool operator==(const HloPosition& other) const {
     45     return instruction == other.instruction && index == other.index;
     46   }
     47   bool operator!=(const HloPosition& other) const { return !(*this == other); }
     48 
     49   // Stable less-than operator using instruction id and index.
     50   bool operator<(const HloPosition& other) const {
     51     return instruction->unique_id() < other.instruction->unique_id() ||
     52            (instruction->unique_id() == other.instruction->unique_id() &&
     53             index < other.index);
     54   }
     55 };
     56 
     57 std::ostream& operator<<(std::ostream& out, const HloPosition& position);
     58 
     59 // Defines a single use of an HLO value.
     60 struct HloUse {
     61   // Instruction at which the value is used.
     62   HloInstruction* instruction;
     63 
     64   // The operand number in which the value is appears.
     65   int64 operand_number;
     66 
     67   // The shape index within the operand in which the value appears.
     68   ShapeIndex operand_index;
     69 
     70   string ToString() const;
     71 
     72   bool operator==(const HloUse& other) const {
     73     return instruction == other.instruction &&
     74            operand_number == other.operand_number &&
     75            operand_index == other.operand_index;
     76   }
     77 
     78   bool operator!=(const HloUse& other) const { return !(*this == other); }
     79 };
     80 
     81 std::ostream& operator<<(std::ostream& out, const HloUse& use);
     82 
     83 // Class describing a value used by the dataflow analysis. XLA arrays are
     84 // trivially a single HloValue. Tuples are made up of more than one HloValue: an
     85 // HloValue for the pointer vector, and an HloValue for each child element.
     86 //
     87 // Every HloValue is defined by a particular instruction and most instructions
     88 // define only a single HloValue. Instructions which define a single HloValue
     89 // include array-shaped instructions such as Add but also includes Tuple-shaped
     90 // instructions such as Tuple. The Tuple instruction defines a single HloValue
     91 // which is a vector of pointers to the values containing the Tuple
     92 // instruction's operands. Though the result of the Tuple instruction includes
     93 // multiple values only the top-level HloValue (the vector of pointers) is
     94 // defined by the Tuple instruction. The values containing the tuple elements
     95 // are defined by earlier instructions, usually the operands of the Tuple
     96 // instruction.
     97 //
     98 // Instructions which construct both the tuple *and* the tuple elements define
     99 // more than one HloValue. This includes (at least) tuple-shaped Constant,
    100 // Parameter, Infeed and While instructions. These tuple-shaped instructions do
    101 // not assemble a tuple from existing HloValues like the Tuple instruction does,
    102 // but rather define all the HloValues in the tuple.
    103 class HloValue {
    104  public:
    105   using Id = int64;
    106 
    107   // Predicate comparing HloValues by increasing id, useful for std::sort.
    108   static bool IdLessThan(const HloValue* a, const HloValue* b) {
    109     return a->id() < b->id();
    110   }
    111 
    112   // Predicate comparing HloValues by equal id, useful for std::unique.
    113   static bool IdEqual(const HloValue* a, const HloValue* b) {
    114     return a->id() == b->id();
    115   }
    116 
    117   // Construct an HloValue defined by 'instruction' at shape index 'index'. If
    118   // is_phi is true, then this value is a phi value, for example, at the
    119   // parameter of a while body computation. Phi values are only used in the SSA
    120   // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true).
    121   HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index,
    122            bool is_phi = false);
    123 
    124   // Sets the positions in the module at which the HloValue appears. Updates
    125   // uses. Should be called once and only once. The defining position should not
    126   // be included in 'positions' as this is set at construction time.
    127   void SetPositionsAndComputeUses(
    128       tensorflow::gtl::ArraySlice<HloPosition> positions);
    129 
    130   // Return a unique identifier for this HloValue. This value is used for stable
    131   // sorting and iteration
    132   Id id() const { return id_; }
    133 
    134   // Returns whether this value is a phi value.
    135   bool is_phi() const { return is_phi_; }
    136 
    137   // Return the position where this value is defined.
    138   const HloPosition& defining_position() const { return positions_[0]; }
    139 
    140   // Return the instruction which defines this HloValue.
    141   HloInstruction* defining_instruction() const {
    142     return defining_position().instruction;
    143   }
    144 
    145   // Return the shape index at which this HloValue is defined in the output of
    146   // its defining instruction.
    147   const ShapeIndex& defining_index() const { return defining_position().index; }
    148 
    149   // Return the shape of this HloValue.
    150   const Shape& shape() const { return defining_position().shape(); }
    151 
    152   // Return all positions of the HloValue in the module.
    153   const std::vector<HloPosition>& positions() const { return positions_; }
    154 
    155   // Return all uses of the HloValue.
    156   const std::vector<HloUse>& uses() const { return uses_; }
    157 
    158   // Get whether this HloValue is live out of the module.
    159   bool live_out_of_module() const { return live_out_of_module_; }
    160 
    161   bool operator==(const HloValue& other) const;
    162   bool operator!=(const HloValue& other) const;
    163 
    164   // Return a single-line string representation of the value.
    165   string ToShortString() const;
    166 
    167   string ToString(int indent = 0) const;
    168 
    169  private:
    170   // Unique identifier for this HloValue. Used for stable sorting and iteration.
    171   const Id id_;
    172 
    173   // Whether this instruction is a phi value.
    174   const bool is_phi_;
    175 
    176   // The set of positions of this HloValue. The first element is always the
    177   // position of the definition.
    178   std::vector<HloPosition> positions_;
    179 
    180   // The set of uses of this HloValue.
    181   std::vector<HloUse> uses_;
    182 
    183   // Whether this value is live out of the HLO module.
    184   bool live_out_of_module_ = false;
    185 
    186   // Whether this value is live out of its computation.
    187   bool live_out_of_computation_ = false;
    188 };
    189 
    190 std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value);
    191 
    192 // A class representing the possible set of HloValues at a particular point
    193 // (shape index in the output of an instruction) in the XLA graph. This set
    194 // contains the set of reaching HloValue definitions. For a simple array-shaped
    195 // instruction like Add, the HloValueSet of the top-level of the instruction's
    196 // output trivially contains only the HloValue defined by the instruction. For
    197 // instructions which have non-trivial dataflow such as Tuple or Select, the
    198 // HloValueSets of the instruction's output contains one or more HloValues
    199 // defined by the instruction's operands or defined further up in the XLA graph.
    200 class HloValueSet {
    201  public:
    202   HloValueSet() = default;
    203 
    204   explicit HloValueSet(tensorflow::gtl::ArraySlice<const HloValue*> values)
    205       : values_(values.begin(), values.end()) {
    206     SortAndUniquifyValues();
    207   }
    208 
    209   // Sets this value set to the union of the given value sets. Returns whether
    210   // this value set changed.
    211   bool AssignUnionOf(tensorflow::gtl::ArraySlice<const HloValueSet*> inputs);
    212 
    213   // Return the vector of HloValues in the set. Values in the vector are unique
    214   // and stably sorted by value id.
    215   const std::vector<const HloValue*>& values() const { return values_; }
    216 
    217   // Adds the value to the set.  Returns true iff the value was added and didn't
    218   // already exist in the set.
    219   bool AddValue(const HloValue* value);
    220 
    221   // Clear all values from the set.
    222   void Clear() { values_.clear(); }
    223 
    224   // Return the unique HLO value in the set. CHECKs if the set does not contain
    225   // exactly one value.
    226   const HloValue& GetUniqueValue() const {
    227     CHECK_EQ(values_.size(), 1);
    228     return *values_[0];
    229   }
    230 
    231   bool operator==(const HloValueSet& other) const {
    232     if (values_.size() != other.values_.size()) return false;
    233     for (size_t i = 0; i < values_.size(); ++i) {
    234       if (values_[i]->id() != other.values_[i]->id()) {
    235         return false;
    236       }
    237     }
    238     return true;
    239   }
    240   bool operator!=(const HloValueSet& other) const { return !(*this == other); }
    241 
    242   string ToString() const;
    243 
    244  private:
    245   // Sorts value_ and removes duplicates. This should be called after adding any
    246   // elements to values_.
    247   void SortAndUniquifyValues();
    248 
    249   // HloValues sorted by HloValue::Id.
    250   std::vector<const HloValue*> values_;
    251 };
    252 
    253 std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value);
    254 
    255 // A class collecting the HloValues which might be contained in the output of
    256 // an HLO instruction. For array-shaped instructions, an InstructionValueSet
    257 // trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets
    258 // hold multiple HloValueSets.
    259 class InstructionValueSet : public ShapeTree<HloValueSet> {
    260  public:
    261   InstructionValueSet(const Shape& shape) : ShapeTree<HloValueSet>(shape) {}
    262 
    263   // Sets this value set to the union of the given value sets. Returns whether
    264   // this value set changed.
    265   bool AssignUnionOf(
    266       tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
    267 
    268   string ToString() const;
    269 };
    270 
    271 std::ostream& operator<<(std::ostream& out,
    272                          const InstructionValueSet& instruction_value_set);
    273 
    274 }  // namespace xla
    275 
    276 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_
    277