Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
     18 
#include <functional>
#include <unordered_set>
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/platform/macros.h"
     24 
     25 namespace xla {
     26 
     27 // HLO pass which performs instruction fusion. Instructions are fused
     28 // "vertically", meaning producing instructions are fused into their consumers
     29 // with the intent that the loops which compute their values will be fused in
     30 // code generation. Derived classes define ShouldFuse method to select which
     31 // instructions to fuse.
     32 class InstructionFusion : public HloPassInterface {
     33  public:
     34   explicit InstructionFusion(
     35       std::function<bool(const HloInstruction& instruction)> is_expensive,
     36       bool may_duplicate = true)
     37       : is_expensive_(is_expensive), may_duplicate_(may_duplicate) {}
     38   ~InstructionFusion() override = default;
     39   tensorflow::StringPiece name() const override { return "fusion"; }
     40 
     41   // Run instruction fusion on the given computation. Returns whether the
     42   // computation was changed (instructions were fused).
     43   StatusOr<bool> Run(HloModule* module) override;
     44 
     45   // Returns true if the computation of the given instruction is significantly
     46   // more expensive than just writing all the values of the instructions' result
     47   // array. Expensive operations will not be duplicated.
     48   static bool IsExpensive(const HloInstruction& instruction);
     49 
     50  protected:
     51   // Returns whether the given producer instruction should be fused into the
     52   // given consumer instruction. producer is necessarily an operand of consumer.
     53   // Derived classes should define this method to specify which instructions
     54   // should be fused. `operand_index` is which operand of the consumer the
     55   // producer is.
     56   //
     57   // Instructions are traversed in reverse post order (computation root to
     58   // leaves). This method is called for each operand of the instruction (where
     59   // the operand is 'producer' and the instruction is 'consumer')
     60   //
     61   // Subtypes can override this with target-specific heuristics.
     62   virtual bool ShouldFuse(HloInstruction* consumer, int64 operand_index);
     63 
     64   // Chooses a fusion kind for `producer` and `consumer`.
     65   // Default method chooses `kLoop`.
     66   virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer,
     67                                                 const HloInstruction* consumer);
     68 
     69   // Fuses producer into consumer.
     70   virtual HloInstruction* Fuse(HloInstruction* producer,
     71                                HloInstruction* consumer);
     72 
     73   // An "effectively unary" operation is one that has one "large"
     74   // input with the others being negligible in terms of memory usage.
     75   // We use "has a smaller true rank than the output" as a heuristic
     76   // for "negligible" memory usage.
     77   bool EffectivelyUnary(HloInstruction* hlo);
     78 
     79   // Returns true if fusing producer into consumer would cause producer to be
     80   // duplicated. This is the case if producer has uses other than consumer.
     81   bool FusionWouldDuplicate(const HloInstruction& producer,
     82                             const HloInstruction& consumer) {
     83     return !(producer.users().size() == 1 && consumer.IsUserOf(&producer));
     84   }
     85 
     86   bool is_expensive(const HloInstruction& instruction) {
     87     return is_expensive_(instruction);
     88   }
     89 
     90   // Current HloComputation instance the loop fuser is traversing.
     91   HloComputation* computation_;
     92   HloModule* module_;
     93 
     94  private:
     95   // The set of producers whose consumers we cannot fuse into.
     96   using DoNotFuseSet = std::unordered_set<HloInstruction*>;
     97 
     98   // Whether or not we can fuse consumer into original_producer on all paths
     99   // from the producer to the consumer where nodes are HLOs and edges are uses.
    100   bool CanFuseOnAllPaths(const HloReachabilityMap& reachability_map,
    101                          HloInstruction* producer, HloInstruction* consumer,
    102                          DoNotFuseSet* do_not_fuse);
    103 
    104   // Used to determine if an HLO is expensive. Expensive operations will not be
    105   // duplicated.
    106   std::function<bool(const HloInstruction& instruction)> is_expensive_;
    107 
    108   // Returns whether we may duplicate an instruction if we want to fuse it.
    109   bool may_duplicate_;
    110 
    111   TF_DISALLOW_COPY_AND_ASSIGN(InstructionFusion);
    112 };
    113 
    114 }  // namespace xla
    115 
    116 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
    117