/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_

#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"

namespace xla {

class HloRematerialization {
 public:
  using ShapeSizeFunction = std::function<int64(const Shape&)>;
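  //
  // For example, a backend might build its size function on top of
  // ShapeUtil::ByteSizeOf (a sketch, not a prescribed implementation; the
  // pointer size of 8 is an assumption for 64-bit targets):
  //
  //   HloRematerialization::ShapeSizeFunction size_fn =
  //       [](const Shape& shape) {
  //         return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/8);
  //       };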

  // Helper struct that communicates the before / after sizes for the
  // rematerialization process.
  struct RematerializationSizes {
    int64 before_bytes;
    int64 after_bytes;
  };
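  //
  // For example, a caller that wants to log the savings might do (a sketch;
  // the surrounding rematerialization call is elided):
  //
  //   HloRematerialization::RematerializationSizes sizes;
  //   ... run rematerialization, passing &sizes ...
  //   VLOG(1) << "Peak memory reduced by "
  //           << sizes.before_bytes - sizes.after_bytes << " bytes";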

  // Rematerialize HLO instructions in the given module to reduce peak memory
  // use below memory_limit_bytes, where memory use is defined as the total
  // size of all live HLO instruction values. Parameters and constants are
  // included in memory use estimates. Method parameters:
  //
  //   size_function: Function which returns the size in bytes of the
  //     top-level buffer of the given shape.
  //
  //   memory_limit_bytes: The target number of bytes; rematerialization
  //     attempts to reduce peak memory use below this threshold.
  //
  //   hlo_module: HLO module to rematerialize instructions in.
  //
  //   sequence: Should point to an empty HloModuleSequence. Upon return,
  //     contains the HLO instruction order which was used for
  //     rematerialization. This is the order in which HLO instructions should
  //     be emitted to minimize memory use.
  //
  //   sizes: Optional outparam that indicates the peak memory usage of the
  //     HLO module before and after rematerialization.
  //
  // Returns whether any instructions were rematerialized. If memory use is
  // already below the given limit, then no instructions are rematerialized
  // and false is returned.
  //
  // CSE will undo the effects of this optimization and should not be run
  // after this pass. In general, this pass should be run very late,
  // immediately before code generation.
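  //
  // Example usage (a sketch; `module`, `size_fn`, and `algorithm` stand in
  // for whatever the backend actually uses, and the 1 GiB limit is
  // arbitrary):
  //
  //   SequentialHloOrdering::HloModuleSequence sequence;
  //   TF_ASSIGN_OR_RETURN(
  //       bool changed,
  //       HloRematerialization::RematerializeAndSchedule(
  //           size_fn, /*memory_limit_bytes=*/1LL << 30, module.get(),
  //           algorithm, &sequence));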
  static StatusOr<bool> RematerializeAndSchedule(
      const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
      HloModule* hlo_module, SchedulerAlgorithm scheduler_algorithm,
      SequentialHloOrdering::HloModuleSequence* sequence,
      RematerializationSizes* sizes = nullptr);

 protected:
  HloRematerialization(SchedulerAlgorithm scheduler_algorithm,
                       const ShapeSizeFunction& size_function)
      : scheduler_algorithm_(scheduler_algorithm),
        size_function_(size_function) {}
  ~HloRematerialization() {}

  // Runs rematerialization on the given module. Returns whether the module
  // was changed. memory_limit is the target maximum peak memory usage for
  // the module. sequence should be an empty HloModuleSequence. Upon return,
  // sequence contains the memory-minimizing order in which to emit the HLO
  // instructions.
  StatusOr<bool> Run(HloModule* module,
                     SequentialHloOrdering::HloModuleSequence* sequence,
                     int64 memory_limit, RematerializationSizes* sizes);

  // Rematerializes instructions within the given computation. 'sequence' is
  // the order in which the computation's instructions will be emitted in the
  // backend. Rematerialized instructions will be added to the HLO computation
  // and inserted into 'sequence'.
  StatusOr<bool> RematerializeComputation(
      HloComputation* computation,
      SequentialHloOrdering::HloModuleSequence* sequence,
      int64 computation_memory_limit);

  // Computes and returns the peak memory used by the given computation. The
  // peak memory is the maximum total size of all live HLO instruction values
  // at any program point. 'order' is the order in which the HLO instructions
  // will be emitted, which is used to determine the lifespans of HLO values.
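  //
  // For example, given order {a, b, c} where c consumes both a and b, all
  // three values are simultaneously live at c, so the peak memory is at
  // least size(a) + size(b) + size(c).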
  StatusOr<int64> ComputePeakMemory(
      const HloComputation* computation,
      const std::vector<const HloInstruction*>& order) const;

  // Returns the peak memory usage of the called computations for the given
  // instruction. Zero is returned if the instruction calls no computations.
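  //
  // For example, a kWhile instruction calls its condition and body
  // computations, so their peak usage is accounted for here; an elementwise
  // kAdd calls no computations, so zero is returned.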
  StatusOr<int64> CalledComputationsMemoryUsage(
      const HloInstruction* instruction) const;

  // Selects an algorithm to use for HLO scheduling.
  SchedulerAlgorithm scheduler_algorithm_;

  // Function which computes the size of the top-level buffer of a shape.
  const ShapeSizeFunction size_function_;

  // Call graph of the hlo_module.
  std::unique_ptr<CallGraph> call_graph_;

  // The peak memory usage of each computation. The map contains only those
  // computations called from a sequential context
  // (CallContext::kSequential). These values are updated as rematerialization
  // occurs.
  tensorflow::gtl::FlatMap<const HloComputation*, int64>
      computation_peak_memory_;

  // Points-to analysis of the module, used to determine which values are
  // live at each program point.
  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;

  // Set of computations which have had rematerialization
  // applied. Rematerialization is only applied once per computation.
  tensorflow::gtl::FlatSet<const HloComputation*> rematerialized_computations_;

  // Total count of instructions rematerialized.
  int64 instructions_rematerialized_ = 0;

  // Count of the net instructions added to the HLO module by
  // rematerialization. This can be different from instructions_rematerialized_
  // because some rematerializations are effectively moves in the HLO
  // schedule. In these cases, the rematerialization instruction replaces all
  // uses of the original instruction and the original instruction becomes
  // dead. Hence, no net instructions were added.
  int64 net_instructions_added_ = 0;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_