1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_cse.h" 17 18 #include <list> 19 #include <map> 20 #include <memory> 21 #include <set> 22 #include <string> 23 #include <utility> 24 #include <vector> 25 26 #include "tensorflow/compiler/xla/layout_util.h" 27 #include "tensorflow/compiler/xla/literal_util.h" 28 #include "tensorflow/compiler/xla/service/hlo_computation.h" 29 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 30 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 31 #include "tensorflow/compiler/xla/shape_util.h" 32 #include "tensorflow/compiler/xla/types.h" 33 #include "tensorflow/compiler/xla/xla_data.pb.h" 34 #include "tensorflow/core/lib/core/errors.h" 35 #include "tensorflow/core/lib/gtl/inlined_vector.h" 36 37 namespace xla { 38 39 namespace { 40 41 // Find and combine identical constants. Constants are identical if they have 42 // the same type and value. 43 bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { 44 bool changed = false; 45 46 // Map from ShortDebugString of the layoutless shape of the constant to the 47 // set of constant instructions with that shape. Layoutless shape is used to 48 // bin possible common constants together to reduce number of constant 49 // comparisons. If we end up having too many constant comparisons, a more 50 // precise binning might have to be used. 51 std::multimap<string, HloInstruction*> constants; 52 53 auto inst_it = computation->instructions().begin(); 54 while (inst_it != computation->instructions().end()) { 55 HloInstruction* instruction = *inst_it; 56 57 // Advance list iterator before loop body because iterator may be 58 // invalidated due to deletion. 59 ++inst_it; 60 61 if (instruction->opcode() == HloOpcode::kConstant) { 62 Shape shape = instruction->shape(); 63 if (!is_layout_sensitive) { 64 LayoutUtil::ClearLayout(&shape); 65 } 66 string shape_string = shape.ShortDebugString(); 67 68 // Compare against all constants with the same shape 69 auto range = constants.equal_range(shape_string); 70 HloInstruction* match = nullptr; 71 for (auto it = range.first; it != range.second; ++it) { 72 if (instruction->literal() == it->second->literal()) { 73 match = it->second; 74 break; 75 } 76 } 77 if (match == nullptr) { 78 constants.emplace(shape_string, instruction); 79 } else { 80 // Match found, replace this instruction with the one in the multimap. 81 TF_CHECK_OK(instruction->ReplaceAllUsesWith(match)); 82 TF_CHECK_OK(computation->RemoveInstruction(instruction)); 83 changed = true; 84 } 85 } 86 } 87 88 return changed; 89 } 90 91 } // namespace 92 93 StatusOr<bool> HloCSE::Run(HloModule* module) { 94 bool changed = false; 95 const std::function<bool(const HloInstruction*, const HloInstruction*)> 96 eq_instructions = std::equal_to<const HloInstruction*>(); 97 const std::function<bool(const HloComputation*, const HloComputation*)> 98 eq_computations = std::equal_to<const HloComputation*>(); 99 for (auto* computation : module->computations()) { 100 changed |= CombineConstants(computation, is_layout_sensitive_); 101 102 std::list<HloInstruction*> post_order = 103 computation->MakeInstructionPostOrder(); 104 std::set<HloInstruction*> removed_instructions; 105 for (auto instruction : post_order) { 106 // If the instruction has already been removed by CSE skip over it. 107 if (removed_instructions.count(instruction) > 0 || 108 instruction->operand_count() == 0) { 109 continue; 110 } 111 112 // An instruction is considered to be equivalent to another only if they 113 // share the exact same set of operands. So to find equivalent 114 // instructions, we just search among instructions which share operand(0) 115 // of this instruction. 116 const HloInstruction* operand = instruction->operand(0); 117 118 tensorflow::gtl::InlinedVector<HloInstruction*, 8> 119 equivalent_instructions; 120 for (HloInstruction* user : operand->users()) { 121 if (user != instruction && 122 user->Identical(*instruction, eq_instructions, eq_computations, 123 is_layout_sensitive_)) { 124 equivalent_instructions.push_back(user); 125 } 126 } 127 128 // Replace all equivalent instructions with this instruction. 129 for (HloInstruction* equivalent_instruction : equivalent_instructions) { 130 TF_RETURN_IF_ERROR( 131 equivalent_instruction->ReplaceAllUsesWith(instruction)); 132 TF_RETURN_IF_ERROR( 133 computation->RemoveInstruction(equivalent_instruction)); 134 removed_instructions.insert(equivalent_instruction); 135 changed = true; 136 } 137 } 138 } 139 return changed; 140 } 141 142 } // namespace xla 143