/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/transpose_folding.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace {

// Returns the operand indices of `dot` that are rank-2 transposes and that
// the backend-provided predicate `transposable_gemm_operands` accepts for
// folding.
TransposeFolding::OperandIndices CanFoldOperandsIntoDot(
    const HloInstruction& dot,
    const TransposeFolding::TransposableGemmOperandsFn&
        transposable_gemm_operands) {
  if (HloOpcode::kDot != dot.opcode()) {
    return {};
  }

  TransposeFolding::OperandIndices operand_set;
  for (int64 i = 0; i < dot.operand_count(); ++i) {
    auto& operand = *dot.operand(i);
    if (operand.IsRank2Transpose()) {
      operand_set.push_back(i);
    }
  }

  return transposable_gemm_operands(dot, operand_set);
}

// Returns the operand indices of `convolution` that are transposes and that
// the backend-provided predicate `transposable_conv_operands` accepts for
// folding.
TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
    const HloInstruction& convolution,
    const TransposeFolding::TransposableConvOperandsFn&
        transposable_conv_operands) {
  if (HloOpcode::kConvolution != convolution.opcode()) {
    return {};
  }

  TransposeFolding::OperandIndices operand_set;
  for (int64 i = 0; i < convolution.operand_count(); ++i) {
    auto& operand = *convolution.operand(i);
    if (operand.opcode() == HloOpcode::kTranspose) {
      operand_set.push_back(i);
    }
  }

  return transposable_conv_operands(convolution, operand_set);
}

using InstructionOperandsPair =
    std::pair<HloInstruction*, TransposeFolding::OperandIndices>;

// Folds the transpose operands listed in `pair.second` into the dot
// instruction `pair.first` by fusing them together into a kTransposeDot
// fusion in the dot's parent computation.
//
// Returns whether the module is changed.
bool FoldTransposeIntoDot(InstructionOperandsPair pair) {
  auto* dot = pair.first;
  std::vector<HloInstruction*> instructions_to_fuse(1, dot);
  for (const int64 operand_index : pair.second) {
    instructions_to_fuse.push_back(dot->mutable_operand(operand_index));
  }

  // Early-exit if no operands are foldable.
  if (instructions_to_fuse.size() == 1) {
    return false;
  }

  dot->parent()->CreateFusionInstruction(
      instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot);
  return true;
}
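
// As a concrete illustration (hypothetical HLO, not taken from an actual
// test), FoldTransposeIntoDot rewrites
//
//   %x  = f32[3,2]{1,0} parameter(0)
//   %y  = f32[3,4]{1,0} parameter(1)
//   %xt = f32[2,3]{1,0} transpose(%x), dimensions={1,0}
//   %d  = f32[2,4]{1,0} dot(%xt, %y)
//
// into a single kTransposeDot fusion rooted at the dot, so a backend can emit
// one GEMM call with its transpose flag set instead of materializing %xt.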

// Folds the transpose operands listed in `pair.second` into the convolution
// instruction `pair.first` by rewriting its dimension numbers and replacing
// it with a new convolution in its parent computation.
//
// Returns whether the module is changed.
bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
  auto& convolution = *pair.first;
  auto& operand_indices = pair.second;

  if (operand_indices.empty()) {
    return false;
  }

  const ConvolutionDimensionNumbers& dnums =
      convolution.convolution_dimension_numbers();
  ConvolutionDimensionNumbers new_dnums = dnums;

  HloInstruction* new_lhs;
  const int64 kLhsIdx = 0;
  if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) !=
      operand_indices.end()) {
    HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
    const auto& transpose_dimensions = transpose.dimensions();
    HloInstruction& transpose_operand = *transpose.mutable_operand(0);

    // Everything remains the same except for the input/output dimension
    // numbers. We need to apply the transpose permutation to the original
    // shape to figure out what the new logical dimensions are.
    new_dnums.set_input_batch_dimension(
        transpose_dimensions[dnums.input_batch_dimension()]);
    new_dnums.set_input_feature_dimension(
        transpose_dimensions[dnums.input_feature_dimension()]);
    for (auto& input_spatial_dimension :
         *new_dnums.mutable_input_spatial_dimensions()) {
      input_spatial_dimension = transpose_dimensions[input_spatial_dimension];
    }
    new_lhs = &transpose_operand;
  } else {
    new_lhs = convolution.mutable_operand(kLhsIdx);
  }

  HloInstruction* new_rhs;
  const int64 kRhsIdx = 1;
  if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) !=
      operand_indices.end()) {
    HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
    const auto& transpose_dimensions = transpose.dimensions();
    HloInstruction& transpose_operand = *transpose.mutable_operand(0);

    // Everything remains the same except for the kernel dimension numbers. We
    // need to apply the transpose permutation to the original shape to figure
    // out what the new logical dimensions are.
    new_dnums.set_kernel_input_feature_dimension(
        transpose_dimensions[dnums.kernel_input_feature_dimension()]);
    new_dnums.set_kernel_output_feature_dimension(
        transpose_dimensions[dnums.kernel_output_feature_dimension()]);
    for (auto& kernel_spatial_dimension :
         *new_dnums.mutable_kernel_spatial_dimensions()) {
      kernel_spatial_dimension =
          transpose_dimensions[kernel_spatial_dimension];
    }
    new_rhs = &transpose_operand;
  } else {
    new_rhs = convolution.mutable_operand(kRhsIdx);
  }

  auto new_conv = HloInstruction::CreateConvolve(
      convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
  TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
      &convolution, std::move(new_conv)));

  return true;
}

}  // namespace

TransposeFolding::TransposeFolding(
    TransposableGemmOperandsFn transposable_gemm_operands,
    TransposableConvOperandsFn transposable_conv_operands)
    : transposable_gemm_operands_(std::move(transposable_gemm_operands)),
      transposable_conv_operands_(std::move(transposable_conv_operands)) {}
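
// A minimal sketch of how a backend might construct and run this pass,
// assuming it wants to fold every candidate operand (real backends typically
// install predicates that also inspect shapes and layouts; `module` is a
// hypothetical HloModule*):
//
//   TransposeFolding pass(
//       /*transposable_gemm_operands=*/
//       [](const HloInstruction& dot,
//          const TransposeFolding::OperandIndices& candidates) {
//         return candidates;  // Fold every rank-2 transpose operand.
//       },
//       /*transposable_conv_operands=*/
//       [](const HloInstruction& convolution,
//          const TransposeFolding::OperandIndices& candidates) {
//         return candidates;  // Fold every transposed operand.
//       });
//   TF_ASSIGN_OR_RETURN(bool changed, pass.Run(module));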

StatusOr<bool> TransposeFolding::Run(HloModule* module) {
  // Modifying the graph while traversing is dangerous, so we find all folding
  // opportunities before actually folding them.
  std::vector<InstructionOperandsPair> foldable_dots;
  std::vector<InstructionOperandsPair> foldable_convolutions;
  auto visit_fn = [this, &foldable_dots,
                   &foldable_convolutions](HloInstruction* instruction) {
    {
      OperandIndices operand_indices =
          CanFoldOperandsIntoDot(*instruction, transposable_gemm_operands_);
      if (!operand_indices.empty()) {
        foldable_dots.emplace_back(instruction, operand_indices);
      }
    }
    {
      OperandIndices operand_indices = CanFoldOperandsIntoConvolution(
          *instruction, transposable_conv_operands_);
      if (!operand_indices.empty()) {
        foldable_convolutions.emplace_back(instruction, operand_indices);
      }
    }
    return tensorflow::Status::OK();
  };

  for (auto* comp : module->MakeNonfusionComputations()) {
    TF_RETURN_IF_ERROR(comp->Accept(visit_fn));
  }

  bool changed = false;
  for (InstructionOperandsPair& pair : foldable_dots) {
    changed |= FoldTransposeIntoDot(pair);
  }
  for (InstructionOperandsPair& pair : foldable_convolutions) {
    changed |= FoldTransposeIntoConvolution(pair);
  }
  return changed;
}

}  // namespace xla