/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"

#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace cpu {

namespace {

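// Returns the number of bytes spanned by the given dimension of `shape`:
// the element byte size times the dimension's extent.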
int64 BytesInDimension(const Shape& shape, int64 dimension) {
  return ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()) *
         shape.dimensions(dimension);
}

bool CanBeLoopFused(const HloInstruction& hlo) {
  // These are the only opcodes we loop-fuse, since we rely on them having
  // efficient elemental IR generation.
  return hlo.IsElementwise() ||  //
         hlo.opcode() == HloOpcode::kBitcast ||
         hlo.opcode() == HloOpcode::kBroadcast ||
         hlo.opcode() == HloOpcode::kConcatenate ||
         hlo.opcode() == HloOpcode::kDynamicSlice ||
         hlo.opcode() == HloOpcode::kDynamicUpdateSlice ||
         hlo.opcode() == HloOpcode::kPad ||
         hlo.opcode() == HloOpcode::kReshape ||
         hlo.opcode() == HloOpcode::kReverse ||
         hlo.opcode() == HloOpcode::kSlice ||
         hlo.opcode() == HloOpcode::kTranspose;
}

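// Returns true if `hlo` is a rank-2 dot whose result has a size-1 row or
// column, i.e. a matrix-vector or vector-matrix product.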
bool IsMatrixVectorDot(const HloInstruction* hlo) {
  const Shape& hlo_shape = hlo->shape();
  return hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() == 2 &&
         (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1);
}

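// A matrix-vector dot whose only user is `consumer` (an add) can be
// output-fused with that add.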
bool CanBeOutputFused(const HloInstruction* producer,
                      const HloInstruction* consumer) {
  return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) &&
         producer->user_count() == 1;
}

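// Returns true if `consumer` is an add with at least one operand that
// qualifies for output fusion into it.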
bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) {
  return consumer->opcode() == HloOpcode::kAdd &&
         (CanBeOutputFused(consumer->operand(0), consumer) ||
          CanBeOutputFused(consumer->operand(1), consumer));
}

}  // namespace

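// Decides whether the producer at `operand_index` should be fused into
// `consumer`. Output fusion (dot+add) is considered first; otherwise we
// loop-fuse producers with efficient elemental IR generation, with special
// handling for small matrix-vector dot consumers.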
bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
                                      int64 operand_index) {
  HloInstruction* producer = consumer->mutable_operand(operand_index);
  VLOG(2) << "Considering for fusion: operand " << operand_index << " of "
          << consumer->ToString();

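  // Threshold below which a column or row tensor involved in a matrix-vector
  // dot is considered small enough (i.e. L1-resident) to fuse; see the dot
  // handling below.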
  constexpr int kFusionThresholdBytes = 16 * 1024;

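  // Prefer output fusion for dot+add pairs: fuse the dot into the add here,
  // and decline to loop-fuse an add that can instead be output-fused into
  // one of its own operands.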
  if (CanBeOutputFused(producer, consumer)) {
    return true;
  }

  if (CanBeOutputFusedIntoSomeOperand(producer)) {
    return false;
  }

  if (!CanBeLoopFused(*producer)) {
    VLOG(2) << "Producer is not fusible.";
    return false;
  }


  // Cost condition: don't fuse an expensive producer (unless it is already a
  // fusion node) into a consumer that reuses the producer's elements, since
  // the expensive computation would then be re-executed for every reuse.
  if (producer->opcode() != HloOpcode::kFusion &&
      consumer->ReusesOperandElements(operand_index) &&
      is_expensive(*producer)) {
    VLOG(2) << "Fusion is not profitable.";
    return false;
  }

  // TODO(b/28644064): see if the "producer->operand_count() == 0" check is
  // necessary.
  if (producer->operand_count() == 0 ||
      !InstructionFusion::ShouldFuse(consumer, operand_index)) {
    VLOG(2)
        << "Not fusing: producer has no operands, or !ShouldFuse(consumer).";
    return false;
  }

  // Don't fuse a producer that is itself a fusion node; fusing fusion nodes
  // into other instructions is not currently supported on CPUs.
  if (producer->opcode() == HloOpcode::kFusion) {
    VLOG(2) << "Not fusing: producer is itself a fusion node.";
    return false;
  }

  if (consumer->opcode() == HloOpcode::kDot) {
    // In the general case we call out to optimized "black box" GEMM routines
    // for Dot, which precludes fusion.  However, in very specific cases, we
    // try to fuse Dot operations by generating an elemental dot
    // implementation.
    //
    // We need to be careful and conservative here, since any benefit we get
    // from fusion can easily be overshadowed by the overhead of a naive GEMM
    // algorithm in the IR.
    const Shape& output_shape = consumer->shape();
    if (output_shape.dimensions_size() == 2) {
      // We fuse in cases where we have dot([A,B],[B,1]) or dot([1,A],[A,B])
      // and fusion can get rid of the larger tensor.  We assume that a naive
      // traversal of a small enough (to fit in L1) column or row tensor is
      // "good enough" from the perspective of cache management, and that
      // calling out to an optimized GEMM kernel is not a huge win in that
      // case.
      if (output_shape.dimensions(0) == 1 && operand_index == 1 &&
          BytesInDimension(output_shape, 1) < kFusionThresholdBytes) {
        VLOG(2) << "Fusing small matrix-vector product.";
        return true;
      } else if (output_shape.dimensions(1) == 1 && operand_index == 0 &&
                 BytesInDimension(output_shape, 0) < kFusionThresholdBytes) {
        VLOG(2) << "Fusing small matrix-vector product.";
        return true;
      }
    }
  }

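  // A loop-fusible producer can always be merged into an existing loop
  // fusion node or into a loop-fusible consumer.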
  if (consumer->opcode() == HloOpcode::kFusion &&
      consumer->fusion_kind() == HloInstruction::FusionKind::kLoop) {
    VLOG(2) << "Fusing: consumer is a fusion node.";
    return true;
  }

  if (CanBeLoopFused(*consumer)) {
    VLOG(2) << "Fusing: consumer is elementwise or fusible.";
    return true;
  }

  VLOG(2) << "Not fusing.";
  return false;
}

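// Dot+add pairs that satisfy CanBeOutputFused become output fusions; every
// other fusion formed by this pass is a loop fusion.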
HloInstruction::FusionKind CpuInstructionFusion::ChooseKind(
    const HloInstruction* producer, const HloInstruction* consumer) {
  return CanBeOutputFused(producer, consumer)
             ? HloInstruction::FusionKind::kOutput
             : HloInstruction::FusionKind::kLoop;
}

}  // namespace cpu
}  // namespace xla