1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" 17 18 #include <memory> 19 #include <string> 20 #include <utility> 21 #include <vector> 22 23 #include "tensorflow/compiler/xla/layout_util.h" 24 #include "tensorflow/compiler/xla/literal_util.h" 25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" 26 #include "tensorflow/compiler/xla/service/hlo_computation.h" 27 #include "tensorflow/compiler/xla/service/hlo_evaluator.h" 28 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 29 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 30 #include "tensorflow/compiler/xla/service/hlo_query.h" 31 #include "tensorflow/compiler/xla/shape_util.h" 32 #include "tensorflow/compiler/xla/types.h" 33 #include "tensorflow/core/lib/core/errors.h" 34 35 namespace xla { 36 namespace { 37 38 HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) { 39 if (hlo->shape().element_type() != type) { 40 Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type); 41 hlo = hlo->parent()->AddInstruction( 42 HloInstruction::CreateConvert(shape, hlo)); 43 } 44 CHECK_EQ(hlo->shape().element_type(), type); 45 return hlo; 46 } 47 48 bool HasOperandType(HloInstruction* hlo, PrimitiveType type) { 49 for (HloInstruction* operand : hlo->operands()) { 50 if (operand->shape().element_type() == type) { 51 return true; 52 } 53 } 54 return false; 55 } 56 57 // Finds out the Tuple Shape of the new instruction after converting the element 58 // type of the operands of the original instruction from `from_type` to 59 // `to_type`. 60 // 61 // This routine assumes the resulting `shape` of the original instruction is a 62 // non-nested tuple. This assumption is currently safe as only kTuple, kInfeed, 63 // kOutfeed, kCall, kCustomCall and kBatchNorm* HLO instructions can produce 64 // results with tuple shapes, and this routine is only called to convert the 65 // result shapes of kBatchNorm* HLO instructions, which are non-nested tuples. 66 Shape GetConvertedTupleShape(const Shape& shape, PrimitiveType from_type, 67 PrimitiveType to_type) { 68 std::vector<Shape> new_tuple_subshapes; 69 for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { 70 Shape subshape = ShapeUtil::GetTupleElementShape(shape, i); 71 CHECK(!ShapeUtil::IsTuple(subshape)); 72 if (subshape.element_type() == from_type) { 73 subshape = ShapeUtil::ChangeElementType(subshape, to_type); 74 } 75 new_tuple_subshapes.push_back(subshape); 76 } 77 return ShapeUtil::MakeTupleShape(new_tuple_subshapes); 78 } 79 80 // Converts the elements of the result of `hlo` to produce a new tuple with 81 // shape `to_shape`. 82 // 83 // This routine assumes `hlo` is an instruction that produces a non-nested Tuple 84 // as a result. 85 HloInstruction* ConvertTupleElements(HloInstruction* hlo, 86 const Shape& to_shape) { 87 const Shape& shape = hlo->shape(); 88 HloComputation* computation = hlo->parent(); 89 std::vector<HloInstruction*> tuple_elements; 90 for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { 91 const Shape& ele_shape = ShapeUtil::GetTupleElementShape(shape, i); 92 HloInstruction* element = computation->AddInstruction( 93 HloInstruction::CreateGetTupleElement(ele_shape, hlo, i)); 94 const Shape& to_ele_shape = ShapeUtil::GetTupleElementShape(to_shape, i); 95 CHECK(!ShapeUtil::IsTuple(ele_shape)); 96 if (ele_shape.element_type() != to_ele_shape.element_type()) { 97 element = computation->AddInstruction( 98 HloInstruction::CreateConvert(to_ele_shape, element)); 99 } 100 tuple_elements.push_back(element); 101 } 102 return computation->AddInstruction( 103 HloInstruction::CreateTuple(tuple_elements)); 104 } 105 106 } // namespace 107 108 HloElementTypeConverter::HloElementTypeConverter( 109 PrimitiveType eliminate_type, PrimitiveType replace_with_type) 110 : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {} 111 112 // This routine converts the arithmetic operations in the given module that use 113 // eliminate_type_ to operations that use replace_with_type_. 114 StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) { 115 XLA_VLOG_LINES( 116 3, "HloElementTypeConverter::Run(), before:\n" + module->ToString()); 117 118 if (eliminate_type_ == replace_with_type_) { 119 return false; 120 } 121 122 bool changed = false; 123 for (auto* computation : module->computations()) { 124 for (auto* hlo : computation->MakeInstructionPostOrder()) { 125 const auto opcode = hlo->opcode(); 126 // These are ops where it does not make sense to convert them. 127 if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || 128 opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert || 129 opcode == HloOpcode::kGetTupleElement || 130 opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) { 131 continue; 132 } 133 134 // We cannot change a CustomCall since we have no way of adjusting the 135 // called binary to expect the updated type. 136 if (opcode == HloOpcode::kCustomCall) { 137 continue; 138 } 139 140 // These are ops with embedded computations where it suffices to convert 141 // the embedded computations instead of converting the ops themselves. 142 if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || 143 opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || 144 opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || 145 opcode == HloOpcode::kSelectAndScatter || 146 opcode == HloOpcode::kConditional) { 147 continue; 148 } 149 TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); 150 151 if (!HasOperandType(hlo, eliminate_type_)) { 152 // If this CHECK fires, then this was an instruction that does not take 153 // the elimination type as an operand but it does return it. This pass 154 // does not have a feature to change the output type in that case, so 155 // instead of silently failing to eliminate the type, it fails loudly. 156 TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_); 157 continue; 158 } 159 160 // Handle instructions that perform arithmetic operations and contain 161 // operands with eliminate_type_. 162 // 163 // First, convert the operands with eliminate_type_ to operands with 164 // replace_with_type_. 165 std::vector<HloInstruction*> new_operands; 166 for (HloInstruction* operand : hlo->operands()) { 167 if (operand->shape().element_type() == eliminate_type_) { 168 operand = ToElementType(operand, replace_with_type_); 169 } 170 new_operands.push_back(operand); 171 } 172 173 // Then find out the result type of the new instruction with the same 174 // opcode but using the converted operands, create the new instruction, 175 // and convert the result of the new instruction back to match the result 176 // type of the original instruction. 177 HloInstruction* new_hlo; 178 if (hlo->shape().element_type() == eliminate_type_) { 179 Shape shape = 180 ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_); 181 new_hlo = computation->AddInstruction( 182 hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule())); 183 new_hlo = ToElementType(new_hlo, eliminate_type_); 184 } else if (ShapeUtil::IsTuple(hlo->shape())) { 185 Shape old_shape = hlo->shape(); 186 Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_, 187 replace_with_type_); 188 new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( 189 new_shape, new_operands, hlo->GetModule())); 190 // Convert the elements of the result of `new_hlo` to produce a new 191 // tuple with shape `old_shape`. 192 new_hlo = ConvertTupleElements(new_hlo, old_shape); 193 } else { 194 new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( 195 hlo->shape(), new_operands, hlo->GetModule())); 196 } 197 198 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, new_hlo)); 199 changed = true; 200 } 201 } 202 XLA_VLOG_LINES( 203 2, "HloElementTypeConverter::Run(), after:\n" + module->ToString()); 204 return changed; 205 } 206 207 } // namespace xla 208