/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"

#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
namespace cpu {

namespace {

// Returns the number of bytes spanned by `shape` along `dimension`, i.e. the
// element size times the dimension's extent.
int64 BytesInDimension(const Shape& shape, int64 dimension) {
  return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) *
         shape.dimensions(dimension);
}

bool CanBeLoopFused(const HloInstruction& hlo) {
  // These are the only ones we fuse since we rely on effective elemental IR
  // generation.
  return hlo.IsElementwise() ||  //
         hlo.opcode() == HloOpcode::kBitcast ||
         hlo.opcode() == HloOpcode::kBroadcast ||
         hlo.opcode() == HloOpcode::kConcatenate ||
         hlo.opcode() == HloOpcode::kDynamicSlice ||
         hlo.opcode() == HloOpcode::kDynamicUpdateSlice ||
         hlo.opcode() == HloOpcode::kPad ||
         hlo.opcode() == HloOpcode::kReshape ||
         hlo.opcode() == HloOpcode::kReverse ||
         hlo.opcode() == HloOpcode::kSlice ||
         hlo.opcode() == HloOpcode::kTranspose;
}

// Returns true if `hlo` is a rank-2 dot producing a row or column vector.
bool IsMatrixVectorDot(const HloInstruction* hlo) {
  const Shape& hlo_shape = hlo->shape();
  return hlo->opcode() == HloOpcode::kDot &&
         hlo_shape.dimensions_size() == 2 &&
         (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1);
}

bool CanBeOutputFused(const HloInstruction* producer,
                      const HloInstruction* consumer) {
  return consumer->opcode() == HloOpcode::kAdd &&
         IsMatrixVectorDot(producer) && producer->user_count() == 1;
}

bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) {
  return consumer->opcode() == HloOpcode::kAdd &&
         (CanBeOutputFused(consumer->operand(0), consumer) ||
          CanBeOutputFused(consumer->operand(1), consumer));
}

}  // namespace

bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
                                      int64 operand_index) {
  HloInstruction* producer = consumer->mutable_operand(operand_index);
  VLOG(2) << "Considering for fusion: operand " << operand_index << " of "
          << consumer->ToString();

  constexpr int kFusionThresholdBytes = 16 * 1024;

  if (CanBeOutputFused(producer, consumer)) {
    return true;
  }

  // If the producer could instead be output-fused into one of its own
  // operands, prefer that: don't loop-fuse it into this consumer.
  if (CanBeOutputFusedIntoSomeOperand(producer)) {
    return false;
  }

  if (!CanBeLoopFused(*producer)) {
    VLOG(2) << "Producer is not fusible.";
    return false;
  }

  // Cost condition: don't fuse an expensive (non-fusion) producer into a
  // consumer that reuses the producer's elements, since the expensive
  // computation would then be re-executed for every reuse.
  if (producer->opcode() != HloOpcode::kFusion &&
      consumer->ReusesOperandElements(operand_index) &&
      is_expensive(*producer)) {
    VLOG(2) << "Fusion is not profitable.";
    return false;
  }

  // TODO(b/28644064): see if the "producer->operand_count() == 0" check is
  // necessary.
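  // (A producer with no operands is a leaf HLO such as a constant or
  // parameter; presumably fusing such a producer buys nothing, but see the
  // TODO above.)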
  if (producer->operand_count() == 0 ||
      !InstructionFusion::ShouldFuse(consumer, operand_index)) {
    VLOG(2)
        << "Not fusing: producer has no operands, or !ShouldFuse(consumer).";
    return false;
  }

  // Don't fuse a producer that is itself a fusion node; merging existing
  // fusion nodes is not currently supported on the CPU backend.
  if (producer->opcode() == HloOpcode::kFusion) {
    VLOG(2) << "Not fusing: producer is itself a fusion node.";
    return false;
  }

  if (consumer->opcode() == HloOpcode::kDot) {
    // In the general case we call out to optimized "black box" GEMM routines
    // for Dot, which precludes fusion. However, in very specific cases, we try
    // to fuse Dot operations by generating an elemental dot implementation.
    //
    // We need to be careful and conservative here since any benefit we get
    // from fusion can easily be overshadowed by the overhead of a naive GEMM
    // algorithm in the IR.
    const Shape& output_shape = consumer->shape();
    if (output_shape.dimensions_size() == 2) {
      // We fuse in cases where we have dot([A,B],[B,1]) or dot([1,A],[A,B])
      // and fusion can get rid of the larger tensor. We assume that a naive
      // traversal of a small enough (to fit in L1) column or row tensor is
      // "good enough" from the perspective of cache management, and calling
      // out to an optimized GEMM kernel is not a huge win. For example, an
      // f32 output of shape [1, N] is fused when 4 * N < 16 KiB, i.e. when
      // N < 4096.
      if (output_shape.dimensions(0) == 1 && operand_index == 1 &&
          BytesInDimension(output_shape, 1) < kFusionThresholdBytes) {
        VLOG(2) << "Fusing small matrix-vector product.";
        return true;
      } else if (output_shape.dimensions(1) == 1 && operand_index == 0 &&
                 BytesInDimension(output_shape, 0) < kFusionThresholdBytes) {
        VLOG(2) << "Fusing small matrix-vector product.";
        return true;
      }
    }
  }

  if (consumer->opcode() == HloOpcode::kFusion &&
      consumer->fusion_kind() == HloInstruction::FusionKind::kLoop) {
    VLOG(2) << "Fusing: consumer is a fusion node.";
    return true;
  }

  if (CanBeLoopFused(*consumer)) {
    VLOG(2) << "Fusing: consumer is elementwise or fusible.";
    return true;
  }

  VLOG(2) << "Not fusing.";
  return false;
}

HloInstruction::FusionKind CpuInstructionFusion::ChooseKind(
    const HloInstruction* producer, const HloInstruction* consumer) {
  return CanBeOutputFused(producer, consumer)
             ? HloInstruction::FusionKind::kOutput
             : HloInstruction::FusionKind::kLoop;
}

}  // namespace cpu
}  // namespace xla