Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_cse.h"
     17 
     18 #include <list>
     19 #include <map>
     20 #include <memory>
     21 #include <set>
     22 #include <string>
     23 #include <utility>
     24 #include <vector>
     25 
     26 #include "tensorflow/compiler/xla/layout_util.h"
     27 #include "tensorflow/compiler/xla/literal_util.h"
     28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     31 #include "tensorflow/compiler/xla/shape_util.h"
     32 #include "tensorflow/compiler/xla/types.h"
     33 #include "tensorflow/compiler/xla/xla_data.pb.h"
     34 #include "tensorflow/core/lib/core/errors.h"
     35 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     36 
     37 namespace xla {
     38 
     39 namespace {
     40 
     41 // Find and combine identical constants. Constants are identical if they have
     42 // the same type and value.
     43 bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) {
     44   bool changed = false;
     45 
     46   // Map from ShortDebugString of the layoutless shape of the constant to the
     47   // set of constant instructions with that shape. Layoutless shape is used to
     48   // bin possible common constants together to reduce number of constant
     49   // comparisons. If we end up having too many constant comparisons, a more
     50   // precise binning might have to be used.
     51   std::multimap<string, HloInstruction*> constants;
     52 
     53   auto inst_it = computation->instructions().begin();
     54   while (inst_it != computation->instructions().end()) {
     55     HloInstruction* instruction = *inst_it;
     56 
     57     // Advance list iterator before loop body because iterator may be
     58     // invalidated due to deletion.
     59     ++inst_it;
     60 
     61     if (instruction->opcode() == HloOpcode::kConstant) {
     62       Shape shape = instruction->shape();
     63       if (!is_layout_sensitive) {
     64         LayoutUtil::ClearLayout(&shape);
     65       }
     66       string shape_string = shape.ShortDebugString();
     67 
     68       // Compare against all constants with the same shape
     69       auto range = constants.equal_range(shape_string);
     70       HloInstruction* match = nullptr;
     71       for (auto it = range.first; it != range.second; ++it) {
     72         if (instruction->literal() == it->second->literal()) {
     73           match = it->second;
     74           break;
     75         }
     76       }
     77       if (match == nullptr) {
     78         constants.emplace(shape_string, instruction);
     79       } else {
     80         // Match found, replace this instruction with the one in the multimap.
     81         TF_CHECK_OK(instruction->ReplaceAllUsesWith(match));
     82         TF_CHECK_OK(computation->RemoveInstruction(instruction));
     83         changed = true;
     84       }
     85     }
     86   }
     87 
     88   return changed;
     89 }
     90 
     91 }  // namespace
     92 
     93 StatusOr<bool> HloCSE::Run(HloModule* module) {
     94   bool changed = false;
     95   const std::function<bool(const HloInstruction*, const HloInstruction*)>
     96       eq_instructions = std::equal_to<const HloInstruction*>();
     97   const std::function<bool(const HloComputation*, const HloComputation*)>
     98       eq_computations = std::equal_to<const HloComputation*>();
     99   for (auto* computation : module->computations()) {
    100     changed |= CombineConstants(computation, is_layout_sensitive_);
    101 
    102     std::list<HloInstruction*> post_order =
    103         computation->MakeInstructionPostOrder();
    104     std::set<HloInstruction*> removed_instructions;
    105     for (auto instruction : post_order) {
    106       // If the instruction has already been removed by CSE skip over it.
    107       if (removed_instructions.count(instruction) > 0 ||
    108           instruction->operand_count() == 0) {
    109         continue;
    110       }
    111 
    112       // An instruction is considered to be equivalent to another only if they
    113       // share the exact same set of operands. So to find equivalent
    114       // instructions, we just search among instructions which share operand(0)
    115       // of this instruction.
    116       const HloInstruction* operand = instruction->operand(0);
    117 
    118       tensorflow::gtl::InlinedVector<HloInstruction*, 8>
    119           equivalent_instructions;
    120       for (HloInstruction* user : operand->users()) {
    121         if (user != instruction &&
    122             user->Identical(*instruction, eq_instructions, eq_computations,
    123                             is_layout_sensitive_)) {
    124           equivalent_instructions.push_back(user);
    125         }
    126       }
    127 
    128       // Replace all equivalent instructions with this instruction.
    129       for (HloInstruction* equivalent_instruction : equivalent_instructions) {
    130         TF_RETURN_IF_ERROR(
    131             equivalent_instruction->ReplaceAllUsesWith(instruction));
    132         TF_RETURN_IF_ERROR(
    133             computation->RemoveInstruction(equivalent_instruction));
    134         removed_instructions.insert(equivalent_instruction);
    135         changed = true;
    136       }
    137     }
    138   }
    139   return changed;
    140 }
    141 
    142 }  // namespace xla
    143