Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_value.h"
     17 
     18 #include <algorithm>
     19 #include <utility>
     20 
     21 #include "tensorflow/compiler/xla/map_util.h"
     22 #include "tensorflow/compiler/xla/ptr_util.h"
     23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     25 #include "tensorflow/compiler/xla/service/hlo_module.h"
     26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     27 #include "tensorflow/compiler/xla/shape_util.h"
     28 #include "tensorflow/compiler/xla/status.h"
     29 #include "tensorflow/compiler/xla/types.h"
     30 #include "tensorflow/compiler/xla/util.h"
     31 #include "tensorflow/core/lib/core/errors.h"
     32 #include "tensorflow/core/lib/strings/str_util.h"
     33 #include "tensorflow/core/lib/strings/strcat.h"
     34 #include "tensorflow/core/platform/logging.h"
     35 
     36 namespace xla {
     37 
     38 using ::tensorflow::str_util::Join;
     39 using ::tensorflow::strings::StrAppend;
     40 using ::tensorflow::strings::StrCat;
     41 
     42 const Shape& HloPosition::shape() const {
     43   return ShapeUtil::GetSubshape(instruction->shape(), index);
     44 }
     45 
     46 string HloPosition::ToString() const {
     47   string index_str =
     48       ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : "";
     49   return StrCat(instruction->name(), index_str);
     50 }
     51 
     52 std::ostream& operator<<(std::ostream& out, const HloPosition& position) {
     53   out << position.ToString();
     54   return out;
     55 }
     56 
     57 string HloUse::ToString() const {
     58   string index_str =
     59       ShapeUtil::IsTuple(instruction->operand(operand_number)->shape())
     60           ? (" " + operand_index.ToString())
     61           : "";
     62   return StrCat(instruction->name(), ", operand ", operand_number, index_str);
     63 }
     64 
     65 std::ostream& operator<<(std::ostream& out, const HloUse& use) {
     66   out << use.ToString();
     67   return out;
     68 }
     69 
     70 HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
     71                    const ShapeIndex& index, bool is_phi)
     72     : id_(id), is_phi_(is_phi) {
     73   // The defining position is always the first element in the positions_ vector.
     74   positions_.push_back(HloPosition{instruction, index});
     75 }
     76 
     77 bool HloValue::operator==(const HloValue& other) const {
     78   bool equal = defining_instruction() == other.defining_instruction() &&
     79                defining_index() == other.defining_index();
     80   // If the values are equal they most both be phi (or non phi).
     81   CHECK(!(equal && is_phi() != other.is_phi()));
     82   return equal;
     83 }
     84 
     85 bool HloValue::operator!=(const HloValue& other) const {
     86   return !(*this == other);
     87 }
     88 
     89 string HloValue::ToShortString() const {
     90   string index_str = ShapeUtil::IsTuple(defining_instruction()->shape())
     91                          ? defining_index().ToString()
     92                          : "";
     93   return StrCat(id_, " ", is_phi_ ? "PHI " : "", defining_instruction()->name(),
     94                 index_str);
     95 }
     96 
     97 string HloValue::ToString(int indent) const {
     98   string indentation(indent, ' ');
     99   string out = StrCat(indentation, ToShortString(), ", positions:\n");
    100   for (const HloPosition& position : positions()) {
    101     StrAppend(&out, indentation, "  ", position.ToString(), "\n");
    102   }
    103   StrAppend(&out, indentation, " uses:\n");
    104   for (const HloUse& use : uses()) {
    105     StrAppend(&out, indentation, "  ", use.ToString(), "\n");
    106   }
    107   return out;
    108 }
    109 
    110 namespace {
    111 
    112 // Returns true if the instruction 'user' may use the value at the given
    113 // ShapeIndex in the given operand. Generally, instruction which pass through
    114 // values transparently without reading the value are not considered to use the
    115 // value.
    116 bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
    117                         const HloInstruction* user) {
    118   switch (user->opcode()) {
    119     case HloOpcode::kGetTupleElement:
    120     case HloOpcode::kCopy:
    121       // These instructions only access the top-level values of their
    122       // operand. Non-top-level (nested) values are passed through
    123       // transparently.
    124       CHECK_EQ(operand_number, 0);
    125       return index.empty();
    126     case HloOpcode::kSelect:
    127       // Select does not use any nested elements of its selected-from operands
    128       // (operand 1 and 2)
    129       CHECK_GE(operand_number, 0);
    130       CHECK_LE(operand_number, 2);
    131       return operand_number == 0 || index.empty();
    132 
    133     case HloOpcode::kTuple:
    134       // These instructions always pass through their operands transparently.
    135       return false;
    136 
    137     case HloOpcode::kCall:
    138     case HloOpcode::kWhile:
    139       // Although call and while instructions pass through their operands, they
    140       // are considered uses.
    141       return true;
    142 
    143     default:
    144       return true;
    145   }
    146 }
    147 
    148 }  // namespace
    149 
    150 void HloValue::SetPositionsAndComputeUses(
    151     tensorflow::gtl::ArraySlice<HloPosition> positions) {
    152   CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once.";
    153 
    154   // The positions must be unique and should not contain the defining position
    155   // as this is added at construction time.
    156   for (const HloPosition& position_a : positions) {
    157     DCHECK_NE(position_a, defining_position());
    158     for (const HloPosition& position_b : positions) {
    159       if (&position_a != &position_b) {
    160         DCHECK_NE(position_a, position_b);
    161       }
    162     }
    163   }
    164 
    165   positions_.insert(positions_.end(), positions.begin(), positions.end());
    166 
    167   // Gather the computation roots at which this value appears.
    168   tensorflow::gtl::FlatSet<HloInstruction*> root_positions;
    169   for (const HloPosition& position : positions_) {
    170     if (position.instruction ==
    171         position.instruction->parent()->root_instruction()) {
    172       root_positions.insert(position.instruction);
    173     }
    174   }
    175 
    176   // Build vector of HloUses for the value.
    177   for (const HloPosition& position : positions_) {
    178     for (HloInstruction* user : position.instruction->users()) {
    179       for (int64 operand_number : user->OperandIndices(position.instruction)) {
    180         // Root instructions of computations are considered to be uses whether
    181         // or not the root instruction itself actually uses the value.
    182         if (MayUseOperandValue(operand_number, position.index, user) ||
    183             ContainsKey(root_positions, user)) {
    184           HloUse new_use{user, operand_number, position.index};
    185 
    186           // The new use must not already exist in uses_.
    187           for (const HloUse& use : uses_) {
    188             DCHECK_NE(use, new_use);
    189           }
    190 
    191           uses_.push_back(std::move(new_use));
    192         }
    193       }
    194     }
    195 
    196     // Update liveout status of this HloValue.
    197     const HloModule& module = *position.instruction->parent()->parent();
    198     if (position.instruction ==
    199         module.entry_computation()->root_instruction()) {
    200       live_out_of_module_ = true;
    201     }
    202   }
    203 }
    204 
    205 std::ostream& operator<<(std::ostream& out, const HloValue& value) {
    206   out << value.ToShortString();
    207   return out;
    208 }
    209 
    210 void HloValueSet::SortAndUniquifyValues() {
    211   std::sort(values_.begin(), values_.end(), HloValue::IdLessThan);
    212   values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
    213                 values_.end());
    214 }
    215 
    216 string HloValueSet::ToString() const {
    217   return StrCat("HloValueSet: ",
    218                 Join(values_, ", ", [](string* result, const HloValue* value) {
    219                   result->append(value->ToShortString());
    220                 }));
    221 }
    222 
    223 bool HloValueSet::AssignUnionOf(
    224     tensorflow::gtl::ArraySlice<const HloValueSet*> inputs) {
    225   HloValueSet union_set;
    226   for (const HloValueSet* input : inputs) {
    227     for (const HloValue* value : input->values()) {
    228       union_set.values_.push_back(value);
    229     }
    230   }
    231   union_set.SortAndUniquifyValues();
    232   if (*this != union_set) {
    233     *this = union_set;
    234     return true;
    235   }
    236   return false;
    237 }
    238 
    239 bool HloValueSet::AddValue(const HloValue* value) {
    240   auto it = std::lower_bound(values_.begin(), values_.end(), value,
    241                              HloValue::IdLessThan);
    242   if (it == values_.end() || (*it)->id() != value->id()) {
    243     values_.insert(it, value);
    244     return true;
    245   }
    246   return false;  // already exists
    247 }
    248 
    249 std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) {
    250   out << value_set.ToString();
    251   return out;
    252 }
    253 
    254 bool InstructionValueSet::AssignUnionOf(
    255     tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
    256   CHECK_GT(inputs.size(), 0);
    257   for (int i = 1; i < inputs.size(); ++i) {
    258     DCHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()));
    259   }
    260   bool changed = false;
    261   for (auto& pair : *this) {
    262     const ShapeIndex& index = pair.first;
    263     HloValueSet& value_set = pair.second;
    264 
    265     std::vector<const HloValueSet*> input_value_sets;
    266     for (const InstructionValueSet* input : inputs) {
    267       input_value_sets.push_back(&input->element(index));
    268     }
    269     changed |= value_set.AssignUnionOf(input_value_sets);
    270   }
    271 
    272   return changed;
    273 }
    274 
    275 std::ostream& operator<<(std::ostream& out,
    276                          const InstructionValueSet& instruction_value_set) {
    277   out << instruction_value_set.ToString();
    278   return out;
    279 }
    280 
    281 string InstructionValueSet::ToString() const {
    282   string out =
    283       StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n");
    284   ForEachElement([this, &out](const ShapeIndex& index,
    285                               const HloValueSet& value_set) {
    286     StrAppend(&out, "  ", index.ToString(), " : ", value_set.ToString(), "\n");
    287   });
    288   return out;
    289 }
    290 
    291 }  // namespace xla
    292