/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/dot_decomposer.h"

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace {

// Rewrites the batch Dot instruction 'dot' into an equivalent sequence of
// non-batch R2 Dots in the same computation:
//   (1) reshape lhs/rhs to rank-3 shapes [batch_size, rows, cols], where
//       batch_size is the product of all batch dimensions;
//   (2) for each batch index i, slice out a [1, rows, cols] piece of each
//       operand, reshape it to rank 2, transpose it if the contracting
//       dimension is not in the standard Dot position (last for lhs, first
//       for rhs), and emit a plain R2 Dot;
//   (3) concatenate the per-batch R3 results along dimension 0 and reshape
//       back to the original dot shape.
// Finally replaces 'dot' with the rewritten expression and returns the
// status of that replacement.
//
// Preconditions (checked): lhs/rhs have matching batch dimensions and exactly
// one contracting dimension each; ShapeInference guarantees batch dimensions
// are most-major, so [rows, cols] immediately follow them.
//
// TODO(b/69062148) Remove this code when all backends support BatchDot
// natively.
Status DecomposeBatchDot(HloInstruction* dot) {
  auto computation = dot->parent();
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  HloInstruction* lhs = dot->mutable_operand(0);
  HloInstruction* rhs = dot->mutable_operand(1);
  const Shape& lhs_shape = lhs->shape();
  const Shape& rhs_shape = rhs->shape();
  const Shape& dot_shape = dot->shape();

  // ShapeInference should guarantee that lhs/rhs batch dimensions match.
  CHECK_EQ(dnums.lhs_batch_dimensions_size(),
           dnums.rhs_batch_dimensions_size());
  const int64 num_batch_dims = dnums.lhs_batch_dimensions_size();
  // Calculate total batch size (note that ShapeInference requires that
  // the batch dimensions are most-major).
  int64 batch_size = 1;
  for (int i = 0; i < num_batch_dims; ++i) {
    CHECK_EQ(lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)),
             rhs_shape.dimensions(dnums.rhs_batch_dimensions(i)));
    batch_size *= lhs_shape.dimensions(dnums.lhs_batch_dimensions(i));
  }

  // Set lhs/rhs_transpose. A standard R2 Dot contracts lhs dimension 1
  // against rhs dimension 0; if the single contracting dimension sits at the
  // other relative position (0 for lhs, 1 for rhs), the R2 slice must be
  // transposed before the Dot.
  CHECK_EQ(1, dnums.lhs_contracting_dimensions_size());
  const int64 lhs_contracting_dim_number = dnums.lhs_contracting_dimensions(0);
  const bool lhs_transpose = (lhs_contracting_dim_number - num_batch_dims) == 0;

  CHECK_EQ(1, dnums.rhs_contracting_dimensions_size());
  const int64 rhs_contracting_dim_number = dnums.rhs_contracting_dimensions(0);
  const bool rhs_transpose = (rhs_contracting_dim_number - num_batch_dims) == 1;

  // Compute R3 and R2 shapes for lhs: the flattened-batch R3 shape, a
  // single-batch R3 slice shape, and the R2 shape of one batch element.
  PrimitiveType lhs_type = lhs_shape.element_type();
  const int64 lhs_rows = lhs_shape.dimensions(num_batch_dims + 0);
  const int64 lhs_cols = lhs_shape.dimensions(num_batch_dims + 1);
  Shape lhs_shape_r3 =
      ShapeUtil::MakeShape(lhs_type, {batch_size, lhs_rows, lhs_cols});
  Shape lhs_slice_shape_r3 =
      ShapeUtil::MakeShape(lhs_type, {1, lhs_rows, lhs_cols});
  Shape lhs_slice_shape_r2 =
      ShapeUtil::MakeShape(lhs_type, {lhs_rows, lhs_cols});

  // Compute R3 and R2 shapes for rhs (same scheme as lhs).
  PrimitiveType rhs_type = rhs_shape.element_type();
  const int64 rhs_rows = rhs_shape.dimensions(num_batch_dims + 0);
  const int64 rhs_cols = rhs_shape.dimensions(num_batch_dims + 1);
  Shape rhs_shape_r3 =
      ShapeUtil::MakeShape(rhs_type, {batch_size, rhs_rows, rhs_cols});
  Shape rhs_slice_shape_r3 =
      ShapeUtil::MakeShape(rhs_type, {1, rhs_rows, rhs_cols});
  Shape rhs_slice_shape_r2 =
      ShapeUtil::MakeShape(rhs_type, {rhs_rows, rhs_cols});

  // Compute R2 and R3 shapes for the dot output: one per-batch R2 result,
  // its R3 form for concatenation, and the concatenated R3 shape.
  PrimitiveType dot_type = dot_shape.element_type();
  const int64 dot_rows = dot_shape.dimensions(num_batch_dims + 0);
  const int64 dot_cols = dot_shape.dimensions(num_batch_dims + 1);
  Shape dot_shape_r2 = ShapeUtil::MakeShape(dot_type, {dot_rows, dot_cols});
  Shape dot_shape_r3 = ShapeUtil::MakeShape(dot_type, {1, dot_rows, dot_cols});
  Shape concat_shape_r3 =
      ShapeUtil::MakeShape(dot_type, {batch_size, dot_rows, dot_cols});

  // Reshape lhs/rhs into R3, collapsing all batch dimensions into one.
  auto lhs_r3 = computation->AddInstruction(
      HloInstruction::CreateReshape(lhs_shape_r3, lhs));
  auto rhs_r3 = computation->AddInstruction(
      HloInstruction::CreateReshape(rhs_shape_r3, rhs));

  // Loop through batch size, slicing out required lhs/rhs to compute each Dot.
  std::vector<HloInstruction*> output_slices(batch_size);
  for (int64 i = 0; i < batch_size; ++i) {
    // Slice R3 shape from 'lhs' and reshape to R2.
    auto lhs_slice_r3 = computation->AddInstruction(
        HloInstruction::CreateSlice(lhs_slice_shape_r3, lhs_r3, {i, 0, 0},
                                    {i + 1, lhs_rows, lhs_cols}, {1, 1, 1}));
    auto lhs_slice_r2 = computation->AddInstruction(
        HloInstruction::CreateReshape(lhs_slice_shape_r2, lhs_slice_r3));

    // Slice R3 shape from 'rhs' and reshape to R2.
    auto rhs_slice_r3 = computation->AddInstruction(
        HloInstruction::CreateSlice(rhs_slice_shape_r3, rhs_r3, {i, 0, 0},
                                    {i + 1, rhs_rows, rhs_cols}, {1, 1, 1}));
    auto rhs_slice_r2 = computation->AddInstruction(
        HloInstruction::CreateReshape(rhs_slice_shape_r2, rhs_slice_r3));

    // Transpose lhs/rhs (if needed) so the contracting dimension lands in
    // the standard R2 Dot position.
    if (lhs_transpose) {
      Shape lhs_slice_shape_r2_transpose =
          ShapeUtil::MakeShape(lhs_type, {lhs_cols, lhs_rows});
      lhs_slice_r2 =
          computation->AddInstruction(HloInstruction::CreateTranspose(
              lhs_slice_shape_r2_transpose, lhs_slice_r2, {1, 0}));
    }
    if (rhs_transpose) {
      Shape rhs_slice_shape_r2_transpose =
          ShapeUtil::MakeShape(rhs_type, {rhs_cols, rhs_rows});
      rhs_slice_r2 =
          computation->AddInstruction(HloInstruction::CreateTranspose(
              rhs_slice_shape_r2_transpose, rhs_slice_r2, {1, 0}));
    }

    // Compute Dot of lhs/rhs R2 slices, contracting lhs dim 1 with rhs dim 0
    // (the canonical, batch-free matmul form).
    DotDimensionNumbers dot_dnums;
    dot_dnums.add_lhs_contracting_dimensions(1);
    dot_dnums.add_rhs_contracting_dimensions(0);
    auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot(
        dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums));

    // Reshape Dot to R3 so we can concat along batch dimension.
    auto dot_r3 = computation->AddInstruction(
        HloInstruction::CreateReshape(dot_shape_r3, dot_r2));

    output_slices[i] = dot_r3;
  }

  // Concatenate slices from 'output_slices' along batch dimension.
  auto concat = computation->AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape_r3, output_slices, 0));
  // Reshape output 'new_dot' to original dimensions.
  auto new_dot = computation->AddInstruction(
      HloInstruction::CreateReshape(dot_shape, concat));

  // Replace all uses of 'dot' in 'computation' with 'new_dot'.
  return computation->ReplaceInstruction(dot, new_dot);
}

}  // namespace

// Pass entry point: finds every kDot instruction with batch dimensions in
// the module's non-fusion computations and, if the pass was configured to do
// so (decompose_batch_dot_), rewrites each one via DecomposeBatchDot.
// Returns true iff at least one dot was rewritten.
StatusOr<bool> DotDecomposer::Run(HloModule* module) {
  XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString());
  // Gather all batch Dot operations. Collecting first, then rewriting,
  // avoids mutating a computation's instruction list while iterating it.
  std::vector<HloInstruction*> batch_dots;
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      if (instruction->opcode() != HloOpcode::kDot) {
        continue;
      }
      const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
      if (dnums.lhs_batch_dimensions_size() > 0 && decompose_batch_dot_) {
        batch_dots.push_back(instruction);
      }
    }
  }
  // Decompose each batch Dot in 'batch_dots'.
  bool changed = false;
  for (auto* dot : batch_dots) {
    TF_RETURN_IF_ERROR(DecomposeBatchDot(dot));
    changed = true;
  }
  XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString());
  return changed;
}

}  // namespace xla