Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
     17 
     18 #include "tensorflow/compiler/xla/service/hlo_module.h"
     19 #include "tensorflow/compiler/xla/shape_util.h"
     20 #include "tensorflow/core/platform/logging.h"
     21 
     22 namespace xla {
     23 
     24 std::vector<HloInstruction*> ReducePrecisionInsertion::instructions_to_modify(
     25     const HloComputation* computation) {
     26   std::vector<HloInstruction*> instruction_list;
     27 
     28   switch (location_) {
     29     case HloReducePrecisionOptions::OP_INPUTS:
     30     case HloReducePrecisionOptions::OP_OUTPUTS:
     31     case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
     32       for (auto* instruction : computation->instructions()) {
     33         VLOG(4) << "Visited instruction: " << instruction->ToString();
     34         if (instruction_filter_function_(instruction)) {
     35           instruction_list.push_back(instruction);
     36         }
     37       }
     38       break;
     39 
     40     case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
     41     case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
     42       for (auto* instruction : computation->instructions()) {
     43         VLOG(4) << "Visited instruction: " << instruction->ToString();
     44         if (instruction->opcode() != HloOpcode::kFusion) {
     45           continue;
     46         }
     47         for (auto* fused_instruction :
     48              instruction->fused_instructions_computation()->instructions()) {
     49           VLOG(4) << "Checking sub-instruction: "
     50                   << fused_instruction->ToString();
     51           if (instruction_filter_function_(fused_instruction)) {
     52             instruction_list.push_back(instruction);
     53             break;
     54           }
     55         }
     56       }
     57       break;
     58 
     59     default:
     60       break;
     61   }
     62   VLOG(1) << "Found " << instruction_list.size()
     63           << " candidate instruction(s) for reduce-precision insertion";
     64 
     65   return instruction_list;
     66 }
     67 
     68 StatusOr<bool> ReducePrecisionInsertion::insert_after(
     69     HloInstruction* instruction) {
     70   // Check that this isn't already an equivalent operation.
     71   if (is_redundant(instruction)) {
     72     VLOG(2) << "Skipped: instruction is already an equivalent"
     73                " reduce-precision instruction:"
     74             << instruction->ToString();
     75     return false;
     76   }
     77 
     78   // Check that we haven't already inserted an equivalant reduce-precision
     79   // operation after this instruction.  (The zero-user case occurs when this is
     80   // the root instruction.)
     81   if (instruction->user_count() > 0) {
     82     bool redundant_followers = true;
     83     for (HloInstruction* user : instruction->users()) {
     84       if (!is_redundant(user)) {
     85         redundant_followers = false;
     86         break;
     87       }
     88     }
     89     if (redundant_followers) {
     90       VLOG(2) << "Skipped: instruction already followed by equivalent"
     91                  " reduce-precision instructions";
     92       return false;
     93     }
     94   }
     95 
     96   HloInstruction* reduced = instruction->parent()->AddInstruction(
     97       HloInstruction::CreateReducePrecision(instruction->shape(), instruction,
     98                                             exponent_bits_, mantissa_bits_));
     99   TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(reduced));
    100   return true;
    101 }
    102 
    103 StatusOr<bool> ReducePrecisionInsertion::insert_on_inputs(
    104     const std::vector<HloInstruction*>& instructions) {
    105   bool computation_changed = false;
    106   for (auto instruction : instructions) {
    107     VLOG(2) << "Adding reduce-precision operation to inputs of instruction: "
    108             << instruction->ToString();
    109     for (int64 i = 0; i < instruction->operand_count(); i++) {
    110       HloInstruction* operand = instruction->mutable_operand(i);
    111       VLOG(2) << "Adding to operand " << i << ": " << operand;
    112 
    113       if (!is_valid_shape(operand->shape())) {
    114         VLOG(2) << "Skipped: value is not an F32 vector";
    115         continue;
    116       }
    117 
    118       if (is_redundant(operand)) {
    119         VLOG(2) << "Skipped: operand is already an equivalent reduce-precision"
    120                    " instruction";
    121         continue;
    122       }
    123 
    124       if (instruction->opcode() == HloOpcode::kFusion &&
    125           (instruction->fusion_kind() == HloInstruction::FusionKind::kLoop ||
    126            instruction->fusion_kind() == HloInstruction::FusionKind::kInput)) {
    127         // Insert the reduce-precision operation inside the fusion computation,
    128         // after the corresponding parameter instruction.
    129         TF_ASSIGN_OR_RETURN(
    130             bool instruction_changed,
    131             insert_after(instruction->fused_instructions_computation()
    132                              ->parameter_instruction(i)));
    133         computation_changed |= instruction_changed;
    134       } else {
    135         // Look for an existing reduce-precision operation on the operand.  (We
    136         // need to be careful not to create a loop, though!)
    137         HloInstruction* reduced = nullptr;
    138         for (auto& user : operand->users()) {
    139           if (user != instruction &&
    140               user->opcode() == HloOpcode::kReducePrecision &&
    141               user->exponent_bits() == exponent_bits_ &&
    142               user->mantissa_bits() == mantissa_bits_) {
    143             reduced = user;
    144             break;
    145           }
    146         }
    147         // If there wasn't an existing reduce-precision operation, create one.
    148         if (!reduced) {
    149           reduced = instruction->parent()->AddInstruction(
    150               HloInstruction::CreateReducePrecision(
    151                   operand->shape(), operand, exponent_bits_, mantissa_bits_));
    152         }
    153         // Insert the reduce-precision operation before the operand.
    154         TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(i, reduced));
    155         computation_changed = true;
    156       }
    157     }
    158   }
    159 
    160   return computation_changed;
    161 }
    162 
    163 StatusOr<bool> ReducePrecisionInsertion::insert_on_outputs(
    164     const std::vector<HloInstruction*>& instructions) {
    165   bool computation_changed = false;
    166   for (const auto& instruction : instructions) {
    167     VLOG(2) << "Adding reduce-precision operation to output of instruction: "
    168             << instruction->ToString();
    169 
    170     if (!is_valid_shape(instruction->shape())) {
    171       VLOG(2) << "Skipped: value is not an F32 nonscalar array";
    172       continue;
    173     }
    174 
    175     if (instruction->opcode() == HloOpcode::kFusion &&
    176         (instruction->fusion_kind() == HloInstruction::FusionKind::kLoop ||
    177          instruction->fusion_kind() == HloInstruction::FusionKind::kOutput)) {
    178       // Insert the reduce-precision operation as the last operation inside
    179       // the fusion computation.
    180       HloInstruction* fusion_root = instruction->fused_expression_root();
    181       VLOG(2) << "Inserting new operation after existing fusion root: "
    182               << fusion_root->ToString();
    183 
    184       TF_ASSIGN_OR_RETURN(bool instruction_changed, insert_after(fusion_root));
    185       computation_changed |= instruction_changed;
    186     } else {
    187       // Insert the reduce-precision operation after the instruction.
    188       TF_ASSIGN_OR_RETURN(bool instruction_changed, insert_after(instruction));
    189       computation_changed |= instruction_changed;
    190     }
    191   }
    192 
    193   return computation_changed;
    194 }
    195 
    196 StatusOr<bool> ReducePrecisionInsertion::Run(HloModule* module) {
    197   bool changed = false;
    198   VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name();
    199 
    200   for (auto* computation : module->MakeNonfusionComputations()) {
    201     StatusOr<bool> computation_changed;
    202     switch (location_) {
    203       case HloReducePrecisionOptions::OP_INPUTS:
    204       case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
    205         computation_changed = ReducePrecisionInsertion::insert_on_inputs(
    206             instructions_to_modify(computation));
    207         break;
    208 
    209       case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
    210       case HloReducePrecisionOptions::OP_OUTPUTS:
    211       case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
    212         computation_changed = ReducePrecisionInsertion::insert_on_outputs(
    213             instructions_to_modify(computation));
    214         break;
    215       default:
    216         break;
    217     }
    218     TF_RETURN_IF_ERROR(computation_changed.status());
    219 
    220     if (computation_changed.ValueOrDie()) {
    221       changed = true;
    222       VLOG(3) << "Computation after reduce-precision insertion:";
    223       XLA_VLOG_LINES(3, computation->ToString());
    224     } else {
    225       VLOG(3) << "Computation " << computation->name() << " unchanged";
    226     }
    227   }
    228 
    229   return changed;
    230 }
    231 
    232 ReducePrecisionInsertion::InstructionFilterFunction
    233 ReducePrecisionInsertion::make_filter_function(
    234     const HloReducePrecisionOptions& reduce_precision_options) {
    235   // Implement the filter function with a lookup table.
    236   std::vector<bool> opcode_filter(HloOpcodeCount(), false);
    237   for (const auto& opcode : reduce_precision_options.opcodes_to_suffix()) {
    238     opcode_filter[opcode] = true;
    239   }
    240   if (reduce_precision_options.opname_substrings_to_suffix_size() == 0) {
    241     return [opcode_filter](const HloInstruction* instruction) {
    242       return opcode_filter[static_cast<unsigned int>(instruction->opcode())];
    243     };
    244   } else {
    245     std::vector<string> opname_substrings;
    246     for (const auto& substring :
    247          reduce_precision_options.opname_substrings_to_suffix()) {
    248       opname_substrings.push_back(substring);
    249     }
    250     return [opcode_filter,
    251             opname_substrings](const HloInstruction* instruction) {
    252       if (!opcode_filter[static_cast<unsigned int>(instruction->opcode())]) {
    253         return false;
    254       }
    255       const auto& opname = instruction->metadata().op_name();
    256       for (const auto& substring : opname_substrings) {
    257         if (opname.find(substring) != string::npos) {
    258           return true;
    259         }
    260       }
    261       return false;
    262     };
    263   }
    264 }
    265 
    266 HloReducePrecisionOptions ReducePrecisionInsertion::make_options_proto(
    267     const HloReducePrecisionOptions::Location location, const int exponent_bits,
    268     const int mantissa_bits,
    269     const std::function<bool(HloOpcode)>& opcode_filter_function,
    270     const std::vector<string>& opname_substring_list) {
    271   HloReducePrecisionOptions options;
    272   options.set_location(location);
    273   options.set_exponent_bits(exponent_bits);
    274   options.set_mantissa_bits(mantissa_bits);
    275   for (uint32_t opcode = 0; opcode < HloOpcodeCount(); opcode++) {
    276     if (opcode_filter_function(static_cast<HloOpcode>(opcode))) {
    277       options.add_opcodes_to_suffix(opcode);
    278     }
    279   }
    280   for (auto& string : opname_substring_list) {
    281     options.add_opname_substrings_to_suffix(string);
    282   }
    283   return options;
    284 }
    285 
    286 bool ReducePrecisionInsertion::AddPasses(HloPassPipeline* pipeline,
    287                                          const DebugOptions& debug_options,
    288                                          const PassTiming pass_timing) {
    289   bool passes_added = false;
    290   for (const auto& pass_options :
    291        debug_options.hlo_reduce_precision_options()) {
    292     bool add_pass;
    293     switch (pass_options.location()) {
    294       case HloReducePrecisionOptions::OP_INPUTS:
    295       case HloReducePrecisionOptions::OP_OUTPUTS:
    296         add_pass = pass_timing == PassTiming::BEFORE_OPTIMIZATION;
    297         break;
    298       case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS:
    299       case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT:
    300       case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT:
    301         add_pass = pass_timing == PassTiming::AFTER_FUSION;
    302         break;
    303       default:
    304         add_pass = false;
    305     }
    306     if (add_pass) {
    307       pipeline->AddPass<ReducePrecisionInsertion>(pass_options);
    308       passes_added = true;
    309     }
    310   }
    311   return passes_added;
    312 }
    313 
    314 }  // namespace xla
    315