/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_

#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/core/lib/gtl/flatmap.h"

namespace xla {

// HLO pass which inserts reduce-precision instructions into the HLO graph, for
// purposes of experimenting with the effects of reduced-precision storage of
// intermediate values.
class ReducePrecisionInsertion : public HloPassInterface {
  using InstructionFilterFunction = std::function<bool(const HloInstruction*)>;

 public:
  // The exponent_bits and mantissa_bits arguments specify the parameters of
  // the reduce-precision instructions to insert.  An instruction will be
  // inserted after each instruction for which instruction_filter_function
  // returns true and whose output type is F32.
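  //
  // Illustrative usage (a sketch, not taken from the original sources; the
  // opcode and bit widths below are arbitrary choices for the example):
  //
  //   ReducePrecisionInsertion pass(
  //       /*exponent_bits=*/5, /*mantissa_bits=*/10,
  //       HloReducePrecisionOptions::OP_OUTPUTS,
  //       [](const HloInstruction* inst) {
  //         return inst->opcode() == HloOpcode::kAdd;
  //       });
  //   TF_ASSIGN_OR_RETURN(bool changed, pass.Run(module));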
  explicit ReducePrecisionInsertion(
      const int exponent_bits, const int mantissa_bits,
      const HloReducePrecisionOptions::Location location,
      const InstructionFilterFunction& instruction_filter_function)
      : exponent_bits_(exponent_bits),
        mantissa_bits_(mantissa_bits),
        location_(location),
        instruction_filter_function_(instruction_filter_function) {}

  // Version of the constructor that takes an HloReducePrecisionOptions proto
  // rather than explicitly-enumerated parameters, for convenience when
  // creating passes based on DebugOptions.
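  //
  // Sketch (assumes `options` is an HloReducePrecisionOptions proto, e.g. one
  // taken from DebugOptions or built with make_options_proto() below):
  //
  //   ReducePrecisionInsertion pass(options);
  //   TF_ASSIGN_OR_RETURN(bool changed, pass.Run(module));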
  explicit ReducePrecisionInsertion(
      const HloReducePrecisionOptions& reduce_precision_options)
      : exponent_bits_(reduce_precision_options.exponent_bits()),
        mantissa_bits_(reduce_precision_options.mantissa_bits()),
        location_(reduce_precision_options.location()),
        instruction_filter_function_(
            make_filter_function(reduce_precision_options)) {}

  ~ReducePrecisionInsertion() override {}

  tensorflow::StringPiece name() const override {
    return "reduce-precision-insertion";
  }

  // Run the pass on the given module. Returns whether the module was changed
  // (reduce-precision instructions were inserted).
  StatusOr<bool> Run(HloModule* module) override;

  // Convert between the (inconvenient) xla.proto HloReducePrecisionOptions
  // representation and InstructionFilterFunction functions.
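  //
  // Illustrative round trip (a sketch; the opcode filter and bit widths shown
  // are arbitrary choices for the example):
  //
  //   HloReducePrecisionOptions options =
  //       ReducePrecisionInsertion::make_options_proto(
  //           HloReducePrecisionOptions::OP_OUTPUTS,
  //           /*exponent_bits=*/5, /*mantissa_bits=*/10,
  //           [](HloOpcode opcode) { return opcode == HloOpcode::kMultiply; });
  //   InstructionFilterFunction filter =
  //       ReducePrecisionInsertion::make_filter_function(options);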
  static InstructionFilterFunction make_filter_function(
      const HloReducePrecisionOptions& reduce_precision_options);
  static HloReducePrecisionOptions make_options_proto(
      const HloReducePrecisionOptions::Location location,
      const int exponent_bits, const int mantissa_bits,
      const std::function<bool(HloOpcode)>& opcode_filter_function,
      const std::vector<string>& opname_substring_list = {});

  // Enumeration to control which passes should be added.
  enum class PassTiming { BEFORE_OPTIMIZATION, AFTER_FUSION };

  // Add ReducePrecisionInsertion passes to an HloPassPipeline based on the list
  // of HloReducePrecisionOptions in a DebugOptions proto.  Returns true if any
  // passes were added.
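  //
  // Illustrative usage (a sketch; assumes the module's DebugOptions already
  // carry one or more HloReducePrecisionOptions entries):
  //
  //   HloPassPipeline pipeline("reduce-precision");
  //   bool added = ReducePrecisionInsertion::AddPasses(
  //       &pipeline, module->config().debug_options(),
  //       ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);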
  static bool AddPasses(HloPassPipeline* pipeline,
                        const DebugOptions& debug_options,
                        const PassTiming pass_timing);

 private:
  // Select the instructions that should have reduce-precision operations
  // attached to them.
  std::vector<HloInstruction*> instructions_to_modify(
      const HloComputation* computation);

  // Insert a reduce-precision operation into the graph on the output of the
  // given instruction.
  StatusOr<bool> insert_after(HloInstruction* instruction);

  // Insert reduce-precision operations into the graph on the inputs of the
  // given instructions.  (For fusion instructions, the operations will be
  // inserted inside the fusion computation, on the outputs of the relevant
  // input parameters.)
  StatusOr<bool> insert_on_inputs(
      const std::vector<HloInstruction*>& instructions);

  // Insert reduce-precision operations into the graph on the outputs of the
  // given instructions.  (For fusion instructions, the operations will be
  // inserted inside the fusion computation as a new root.)
  StatusOr<bool> insert_on_outputs(
      const std::vector<HloInstruction*>& instructions);

  // Is this shape valid for inserting a reduce-precision operation?
  bool is_valid_shape(const Shape& shape) {
    // For now, ReducePrecision is only implemented for F32 arrays, so this
    // ignores instructions that produce other data.  In particular, this
    // currently ignores instructions producing tuples, even if those tuples
    // contain F32 arrays inside them.  The assumption is that in most cases
    // equivalent behavior can be obtained by adding ReducePrecision
    // instructions after the instructions that pull the F32 arrays out of
    // the tuples.
    //
    // TODO(b/64093391): Remove the IsScalar check once this won't cause
    // failures on the GPU backend if the ReducePrecision instruction ends up
    // inserted between a scalar constant and the init_value argument of a
    // Reduce operation.
    return shape.element_type() == PrimitiveType::F32 &&
           !ShapeUtil::IsScalar(shape);
  }

  // Would a new reduce-precision operation inserted immediately before or
  // after this instruction be redundant?  (True if the instruction is already
  // a reduce-precision op at least as lossy as the one we would add.)
  bool is_redundant(const HloInstruction* instruction) {
    return instruction->opcode() == HloOpcode::kReducePrecision &&
           instruction->exponent_bits() <= exponent_bits_ &&
           instruction->mantissa_bits() <= mantissa_bits_;
  }

  // Parameters for the precision reduction to be added.
  const int exponent_bits_;
  const int mantissa_bits_;

  // Pass "timing" parameter.  This also controls aspects of how the pass
  // selects locations to insert instructions.
  const HloReducePrecisionOptions::Location location_;

  // User-provided function to determine whether a given instruction should
  // have a reduce-precision instruction inserted in its output stream.
  const InstructionFilterFunction instruction_filter_function_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_REDUCE_PRECISION_INSERTION_H_