Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
     18 
     19 #include "absl/types/span.h"
     20 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
     21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     23 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     24 #include "tensorflow/compiler/xla/shape_util.h"
     25 #include "tensorflow/compiler/xla/statusor.h"
     26 #include "tensorflow/compiler/xla/xla_data.pb.h"
     27 #include "tensorflow/core/platform/macros.h"
     28 #include "tensorflow/core/platform/types.h"
     29 
     30 namespace xla {
     31 
     32 // HloCostAnalysis traverses an HLO graph and calculates the amount of
     33 // computations required for the graph. Each HLO instruction handler provides
     34 // the computation cost of the instruction, and the values are accumulated
     35 // during the traversal for the entire graph. We treat normal floating point
     36 // operations separately from transcendental operations.
     37 class HloCostAnalysis : public ConstDfsHloVisitor {
     38  public:
     39   // Each HLO is associated to a vector of properties with the indices given
     40   // below. Sub-classes can add further properties.
     41   typedef std::map<string, float> Properties;
     42   static constexpr char kFlopsKey[] = "flops";
     43   static constexpr char kTranscendentalsKey[] = "transcendentals";
     44   static constexpr char kBytesAccessedKey[] = "bytes accessed";
     45   static constexpr char kOptimalSecondsKey[] = "optimal_seconds";
     46 
     47   // shape_size is a function which returns the size in bytes of the top-level
     48   // buffer of a shape.
     49   using ShapeSizeFunction = std::function<int64(const Shape&)>;
     50   explicit HloCostAnalysis(const ShapeSizeFunction& shape_size);
     51 
     52   Status HandleElementwiseUnary(const HloInstruction* hlo) override;
     53   Status HandleElementwiseBinary(const HloInstruction* hlo) override;
     54   Status HandleConstant(const HloInstruction* constant) override;
     55   Status HandleIota(const HloInstruction* iota) override;
     56   Status HandleGetTupleElement(
     57       const HloInstruction* get_tuple_element) override;
     58   Status HandleSelect(const HloInstruction* hlo) override;
     59   Status HandleTupleSelect(const HloInstruction* hlo) override;
     60   Status HandleCompare(const HloInstruction* compare) override;
     61   Status HandleClamp(const HloInstruction* clamp) override;
     62   Status HandleReducePrecision(const HloInstruction* hlo) override;
     63   Status HandleConcatenate(const HloInstruction* concatenate) override;
     64   Status HandleSend(const HloInstruction* send) override;
     65   Status HandleSendDone(const HloInstruction* send_done) override;
     66   Status HandleRecv(const HloInstruction* recv) override;
     67   Status HandleRecvDone(const HloInstruction* recv_done) override;
     68   Status HandleConvert(const HloInstruction* convert) override;
     69   Status HandleCopy(const HloInstruction* copy) override;
     70   Status HandleDomain(const HloInstruction* domain) override;
     71   Status HandleDot(const HloInstruction* dot) override;
     72   Status HandleConvolution(const HloInstruction* convolution) override;
     73   Status HandleFft(const HloInstruction* fft) override;
     74   Status HandleTriangularSolve(const HloInstruction* hlo) override;
     75   Status HandleCholesky(const HloInstruction* hlo) override;
     76   Status HandleAllReduce(const HloInstruction* crs) override;
     77   Status HandleAllToAll(const HloInstruction* hlo) override;
     78   Status HandleCollectivePermute(const HloInstruction* hlo) override;
     79   Status HandleReplicaId(const HloInstruction* hlo) override;
     80   Status HandleInfeed(const HloInstruction* infeed) override;
     81   Status HandleOutfeed(const HloInstruction* outfeed) override;
     82   Status HandleRng(const HloInstruction* random) override;
     83   Status HandleReverse(const HloInstruction* reverse) override;
     84   Status HandleSort(const HloInstruction* sort) override;
     85   Status HandleParameter(const HloInstruction* parameter) override;
     86   Status HandleReduce(const HloInstruction* reduce) override;
     87   Status HandleBatchNormTraining(
     88       const HloInstruction* batch_norm_training) override;
     89   Status HandleBatchNormInference(
     90       const HloInstruction* batch_norm_inference) override;
     91   Status HandleBatchNormGrad(const HloInstruction* batch_norm_grad) override;
     92   Status HandleFusion(const HloInstruction* fusion) override;
     93   Status HandleCall(const HloInstruction* call) override;
     94   Status HandleCustomCall(const HloInstruction* custom_call) override;
     95   Status HandleSlice(const HloInstruction* slice) override;
     96   Status HandleDynamicSlice(const HloInstruction* dynamic_slice) override;
     97   Status HandleDynamicUpdateSlice(
     98       const HloInstruction* dynamic_update_slice) override;
     99   Status HandleTuple(const HloInstruction* tuple) override;
    100   Status HandleMap(const HloInstruction* map) override;
    101   Status HandleReduceWindow(const HloInstruction* reduce_window) override;
    102   Status HandleSelectAndScatter(const HloInstruction* instruction) override;
    103   Status HandleBitcast(const HloInstruction* bitcast) override;
    104   Status HandleBroadcast(const HloInstruction* broadcast) override;
    105   Status HandlePad(const HloInstruction* pad) override;
    106   Status HandleReshape(const HloInstruction* reshape) override;
    107   Status HandleAddDependency(const HloInstruction* add_dependency) override;
    108   Status HandleAfterAll(const HloInstruction* token) override;
    109   Status HandleTranspose(const HloInstruction* transpose) override;
    110   Status HandleWhile(const HloInstruction* xla_while) override;
    111   Status HandleConditional(const HloInstruction* conditional) override;
    112   Status HandleGather(const HloInstruction* gather) override;
    113   Status HandleScatter(const HloInstruction* scatter) override;
    114   Status HandleGetDimensionSize(const HloInstruction* get_size) override;
    115   Status FinishVisit(const HloInstruction* root) override;
    116 
    117   Status Preprocess(const HloInstruction* hlo) override;
    118   Status Postprocess(const HloInstruction* hlo) override;
    119 
    120   // Set the rates used to calculate the time taken by the computation. These
    121   // need to be set before visiting starts.
    122   void set_flops_per_second(float value) {
    123     per_second_rates_[kFlopsKey] = value;
    124   }
    125   void set_transcendentals_per_second(float value) {
    126     per_second_rates_[kTranscendentalsKey] = value;
    127   }
    128   void set_bytes_per_second(float value) {
    129     per_second_rates_[kBytesAccessedKey] = value;
    130   }
    131 
    132   // Returns properties for the computation.
    133   float flop_count() const;
    134   float transcendental_count() const;
    135   float bytes_accessed() const;
    136   float optimal_seconds() const;
    137 
    138   // Returns the respective cost computed for a particular HLO instruction, or 0
    139   // if the HLO was not found to have a cost in the analysis.
    140   int64 flop_count(const HloInstruction& hlo) const;
    141   int64 transcendental_count(const HloInstruction& hlo) const;
    142   int64 bytes_accessed(const HloInstruction& hlo) const;
    143   float optimal_seconds(const HloInstruction& hlo) const;
    144 
    145   const Properties& properties() const { return properties_sum_; }
    146   const float property(const string& key) const {
    147     return GetProperty(key, properties());
    148   }
    149 
    150  protected:
    151   typedef std::unordered_map<const HloInstruction*, Properties> HloToProperties;
    152 
    153   // An FMA counts as two floating point operations in these analyzes.
    154   static constexpr int64 kFmaFlops = 2;
    155 
    156   HloCostAnalysis(const ShapeSizeFunction& shape_size,
    157                   const Properties& per_second_rates);
    158 
    159   // Returns the properties computed from visiting the computation rooted at the
    160   // given hlo.
    161   //
    162   // The difference between ProcessNestedSubcomputation and
    163   // ProcessUnnestedSubcomputation is that we expect to get profile results for
    164   // an unnested subcomputation's individual instructions, while we expect that
    165   // a nested subcomputation is completely subsumed by its parent.
    166   //
    167   // For example, subcomputations inside kFusion and kMap are considered nested,
    168   // while subcomputations inside kWhile and kConditional are considered
    169   // unnested.
    170   //
    171   // Another way of thinking of this is, kFusion is implemented on the GPU
    172   // backend using just one GPU kernel, while kWhile's body is implemented as a
    173   // sequence of kernels, one for each HLO therein.  Backends don't necessarily
    174   // need to follow this same implementation strategy, but we assume they do for
    175   // the purposes of this platform-generic cost analysis.
    176   StatusOr<Properties> ProcessNestedSubcomputation(HloComputation* computation);
    177   StatusOr<Properties> ProcessUnnestedSubcomputation(
    178       HloComputation* computation);
    179 
    180   // Utility function to handle all element-wise operations.
    181   Status HandleElementwiseOp(const HloInstruction* hlo_instruction);
    182 
    183   // Returns the default value if the key is not present in the
    184   // properties. Otherwise, returns the value that the key maps to from the
    185   // properties parameter.
    186   static float GetProperty(const string& key, const Properties& properties,
    187                            float default_value = 0.0f);
    188 
    189   // Returns 0.0f if the hlo is not present in hlo_to_properties or if the key
    190   // is not present in hlo_to_properties[hlo]. Otherwise, returns the value that
    191   // the key maps to in the properties of the given hlo.
    192   static float GetPropertyForHlo(const HloInstruction& hlo, const string& key,
    193                                  const HloToProperties& hlo_to_properties);
    194 
    195   // Decorates shape_size_ by returning 0 immediately if the shape does not have
    196   // a layout.
    197   int64 GetShapeSize(const Shape& shape) const;
    198 
    199   // Function which computes the size of the top-level of a given shape (not
    200   // including nested elements, if any). If null then bytes_accessed methods
    201   // return an error.
    202   const ShapeSizeFunction shape_size_;
    203 
    204   HloToProperties hlo_properties_;
    205 
    206   // If true, the time taken will be computed from the rates for each property
    207   // and the total time will be the maximum time, which is the time of the
    208   // bottleneck.
    209   bool current_should_compute_bottleneck_time_;
    210 
    211   // The properties of the currently visited instruction. A HandleFoo method can
    212   // modify these to change the default values computed in Preprocess.
    213   Properties current_properties_;
    214 
    215   // The sum of the properties of all HLOs in the computation.
    216   Properties properties_sum_;
    217 
    218   // How much of each property can be processed per second. E.g. if the property
    219   // is bytes accessed, this is the number of bytes that can be processed per
    220   // second. Is empty if no rates have been set.
    221   Properties per_second_rates_;
    222 
    223   TF_DISALLOW_COPY_AND_ASSIGN(HloCostAnalysis);
    224 };
    225 
    226 }  // namespace xla
    227 
    228 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
    229