/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/transpose_folding.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

namespace {

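// Collects the operand indices of `dot` that are rank-2 transposes, then
// lets the `transposable_gemm_operands` callback decide which of those
// candidates may actually be folded; returns the surviving indices.
// Returns an empty list if `dot` is not a kDot instruction.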
TransposeFolding::OperandIndices CanFoldOperandsIntoDot(
    const HloInstruction& dot,
    const TransposeFolding::TransposableGemmOperandsFn&
        transposable_gemm_operands) {
  if (HloOpcode::kDot != dot.opcode()) {
    return {};
  }

  TransposeFolding::OperandIndices operand_indices;
  for (int64 i = 0; i < dot.operand_count(); ++i) {
    auto& operand = *dot.operand(i);
    if (operand.IsRank2Transpose()) {
      operand_indices.push_back(i);
    }
  }

  return transposable_gemm_operands(dot, operand_indices);
}

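// Like CanFoldOperandsIntoDot, but for convolutions: collects the operand
// indices of `convolution` that are kTranspose instructions and lets the
// `transposable_conv_operands` callback decide which of them may be folded.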
TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
    const HloInstruction& convolution,
    const TransposeFolding::TransposableConvOperandsFn&
        transposable_conv_operands) {
  if (HloOpcode::kConvolution != convolution.opcode()) {
    return {};
  }

  TransposeFolding::OperandIndices operand_indices;
  for (int64 i = 0; i < convolution.operand_count(); ++i) {
    auto& operand = *convolution.operand(i);
    if (operand.opcode() == HloOpcode::kTranspose) {
      operand_indices.push_back(i);
    }
  }

  return transposable_conv_operands(convolution, operand_indices);
}

using InstructionOperandsPair =
    std::pair<HloInstruction*, TransposeFolding::OperandIndices>;

// Folds the transpose operands of `pair.first` (a dot instruction) listed
// in `pair.second` into a kTransposeDot fusion within the dot's parent
// computation.
//
// Returns whether the computation was changed.
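//
// A sketch of the rewrite (shapes and HLO names are illustrative):
//
//   %t = f32[32,64] transpose(f32[64,32] %p), dimensions={1,0}
//   %d = f32[32,16] dot(%t, f32[64,16] %q)
//
// becomes a kTransposeDot fusion wrapping %d and %t, which a backend may
// then emit as a GEMM that reads %p directly with a transposed-operand
// flag instead of materializing %t.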
bool FoldTransposeIntoDot(InstructionOperandsPair pair) {
  auto* dot = pair.first;
  std::vector<HloInstruction*> instructions_to_fuse(1, dot);
  for (const int64 operand_index : pair.second) {
    instructions_to_fuse.push_back(dot->mutable_operand(operand_index));
  }

  // Early-exit if no operands are foldable.
  if (instructions_to_fuse.size() == 1) {
    return false;
  }

  dot->parent()->CreateFusionInstruction(
      instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot);
  return true;
}

// Folds the transpose operands of `pair.first` (a convolution) listed in
// `pair.second` into the convolution's dimension numbers, replacing the
// convolution in its parent computation.
//
// Returns whether the computation was changed.
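//
// A sketch of the rewrite (shapes and names are illustrative): given
//
//   %p = f32[8,224,224,3] parameter(0)                     // NHWC
//   %t = f32[8,3,224,224] transpose(%p), dimensions={0,3,1,2}
//   %c = convolution(%t, %k), ...
//
// the transpose is dropped and %c reads %p directly, with the input
// dimension numbers remapped through the transpose's permutation.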
bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
  auto& convolution = *pair.first;
  auto& operand_indices = pair.second;

  if (operand_indices.empty()) {
    return false;
  }

  const ConvolutionDimensionNumbers& dnums =
      convolution.convolution_dimension_numbers();
  ConvolutionDimensionNumbers new_dnums = dnums;

  HloInstruction* new_lhs;
  const int64 kLhsIdx = 0;
  if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) !=
      operand_indices.end()) {
    HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
    const auto& transpose_dimensions = transpose.dimensions();
    HloInstruction& transpose_operand = *transpose.mutable_operand(0);

    // Everything remains the same except for the input/output dimension
    // numbers. We need to apply the transpose permutation to the original
    // shape to figure out what the new logical dimensions are.
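    //
    // For example (a sketch): with transpose dimensions={0,3,1,2}, an old
    // input_batch_dimension of 0 stays transpose_dimensions[0] == 0, while
    // an old input_feature_dimension of 1 becomes
    // transpose_dimensions[1] == 3; each dimension number is traced back
    // through the permutation to the transpose's operand.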
    new_dnums.set_input_batch_dimension(
        transpose_dimensions[dnums.input_batch_dimension()]);
    new_dnums.set_input_feature_dimension(
        transpose_dimensions[dnums.input_feature_dimension()]);
    for (auto& input_spatial_dimension :
         *new_dnums.mutable_input_spatial_dimensions()) {
      input_spatial_dimension = transpose_dimensions[input_spatial_dimension];
    }
    new_lhs = &transpose_operand;
  } else {
    new_lhs = convolution.mutable_operand(kLhsIdx);
  }

  HloInstruction* new_rhs;
  const int64 kRhsIdx = 1;
  if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) !=
      operand_indices.end()) {
    HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
    const auto& transpose_dimensions = transpose.dimensions();
    HloInstruction& transpose_operand = *transpose.mutable_operand(0);

    // Everything remains the same except for the kernel dimension numbers. We
    // need to apply the transpose permutation to the original shape to figure
    // out what the new logical dimensions are.
    new_dnums.set_kernel_input_feature_dimension(
        transpose_dimensions[dnums.kernel_input_feature_dimension()]);
    new_dnums.set_kernel_output_feature_dimension(
        transpose_dimensions[dnums.kernel_output_feature_dimension()]);
    for (auto& kernel_spatial_dimension :
         *new_dnums.mutable_kernel_spatial_dimensions()) {
      kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
    }
    new_rhs = &transpose_operand;
  } else {
    new_rhs = convolution.mutable_operand(kRhsIdx);
  }

  auto new_conv = HloInstruction::CreateConvolve(
      convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
  TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
      &convolution, std::move(new_conv)));

  return true;
}

}  // namespace

TransposeFolding::TransposeFolding(
    TransposableGemmOperandsFn transposable_gemm_operands,
    TransposableConvOperandsFn transposable_conv_operands)
    : transposable_gemm_operands_(std::move(transposable_gemm_operands)),
      transposable_conv_operands_(std::move(transposable_conv_operands)) {}

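// A sketch of typical usage (the callbacks shown simply accept every
// candidate index; real backends filter down to the indices they can
// actually handle):
//
//   TransposeFolding fold(
//       [](const HloInstruction& dot,
//          const TransposeFolding::OperandIndices& ids) { return ids; },
//       [](const HloInstruction& conv,
//          const TransposeFolding::OperandIndices& ids) { return ids; });
//   TF_ASSIGN_OR_RETURN(bool changed, fold.Run(module));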
StatusOr<bool> TransposeFolding::Run(HloModule* module) {
  // Modifying the graph while traversing is dangerous, so we find all folding
  // opportunities before actually folding them.
  std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_dots;
  std::vector<std::pair<HloInstruction*, OperandIndices>> foldable_convolutions;
  auto visit_fn = [this, &foldable_dots,
                   &foldable_convolutions](HloInstruction* instruction) {
    {
      OperandIndices operand_indices =
          CanFoldOperandsIntoDot(*instruction, transposable_gemm_operands_);
      if (!operand_indices.empty()) {
        foldable_dots.emplace_back(instruction, operand_indices);
      }
    }
    {
      OperandIndices operand_indices = CanFoldOperandsIntoConvolution(
          *instruction, transposable_conv_operands_);
      if (!operand_indices.empty()) {
        foldable_convolutions.emplace_back(instruction, operand_indices);
      }
    }
    return tensorflow::Status::OK();
  };

  for (auto* comp : module->MakeNonfusionComputations()) {
    TF_RETURN_IF_ERROR(comp->Accept(visit_fn));
  }

  bool changed = false;
  for (InstructionOperandsPair& pair : foldable_dots) {
    changed |= FoldTransposeIntoDot(pair);
  }
  for (InstructionOperandsPair& pair : foldable_convolutions) {
    changed |= FoldTransposeIntoConvolution(pair);
  }
  return changed;
}

}  // namespace xla