Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
     18 
#include <functional>
#include <map>
#include <unordered_map>

#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
     29 
     30 namespace xla {
     31 
     32 // HloCostAnalysis traverses an HLO graph and calculates the amount of
     33 // computations required for the graph. Each HLO instruction handler provides
     34 // the computation cost of the instruction, and the values are accumulated
     35 // during the traversal for the entire graph. We treat normal floating point
     36 // operations separately from transcendental operations.
     37 class HloCostAnalysis : public ConstDfsHloVisitor {
     38  public:
     39   // Each HLO is associated to a vector of properties with the indices given
     40   // below. Sub-classes can add further properties.
     41   typedef std::map<string, float> Properties;
     42   static constexpr char kFlopsKey[] = "flops";
     43   static constexpr char kTranscendentalsKey[] = "transcendentals";
     44   static constexpr char kBytesAccessedKey[] = "bytes accessed";
     45   static constexpr char kOptimalSecondsKey[] = "optimal_seconds";
     46 
     47   // shape_size is a function which returns the size in bytes of the top-level
     48   // buffer of a shape.
     49   using ShapeSizeFunction = std::function<int64(const Shape&)>;
     50   explicit HloCostAnalysis(const ShapeSizeFunction& shape_size);
     51 
     52   Status HandleElementwiseUnary(const HloInstruction* hlo) override;
     53   Status HandleElementwiseBinary(const HloInstruction* hlo) override;
     54   Status HandleConstant(const HloInstruction* constant) override;
     55   Status HandleGetTupleElement(
     56       const HloInstruction* get_tuple_element) override;
     57   Status HandleSelect(const HloInstruction* select) override;
     58   Status HandleCompare(const HloInstruction* compare) override;
     59   Status HandleClamp(const HloInstruction* clamp) override;
     60   Status HandleReducePrecision(const HloInstruction* hlo) override;
     61   Status HandleConcatenate(const HloInstruction* concatenate) override;
     62   Status HandleSend(const HloInstruction* send) override;
     63   Status HandleSendDone(const HloInstruction* send_done) override;
     64   Status HandleRecv(const HloInstruction* recv) override;
     65   Status HandleRecvDone(const HloInstruction* recv_done) override;
     66   Status HandleConvert(const HloInstruction* convert) override;
     67   Status HandleCopy(const HloInstruction* copy) override;
     68   Status HandleDot(const HloInstruction* dot) override;
     69   Status HandleConvolution(const HloInstruction* convolution) override;
     70   Status HandleFft(const HloInstruction* fft) override;
     71   Status HandleCrossReplicaSum(const HloInstruction* crs) override;
     72   Status HandleInfeed(const HloInstruction* infeed) override;
     73   Status HandleOutfeed(const HloInstruction* outfeed) override;
     74   Status HandleHostCompute(const HloInstruction* host_compute) override;
     75   Status HandleRng(const HloInstruction* random) override;
     76   Status HandleReverse(const HloInstruction* reverse) override;
     77   Status HandleSort(const HloInstruction* sort) override;
     78   Status HandleParameter(const HloInstruction* parameter) override;
     79   Status HandleReduce(const HloInstruction* reduce) override;
     80   Status HandleBatchNormTraining(
     81       const HloInstruction* batch_norm_training) override;
     82   Status HandleBatchNormInference(
     83       const HloInstruction* batch_norm_inference) override;
     84   Status HandleBatchNormGrad(const HloInstruction* batch_norm_grad) override;
     85   Status HandleFusion(const HloInstruction* fusion) override;
     86   Status HandleCall(const HloInstruction* call) override;
     87   Status HandleCustomCall(const HloInstruction* custom_call) override;
     88   Status HandleSlice(const HloInstruction* slice) override;
     89   Status HandleDynamicSlice(const HloInstruction* dynamic_slice) override;
     90   Status HandleDynamicUpdateSlice(
     91       const HloInstruction* dynamic_update_slice) override;
     92   Status HandleTuple(const HloInstruction* tuple) override;
     93   Status HandleMap(const HloInstruction* map) override;
     94   Status HandleReduceWindow(const HloInstruction* reduce_window) override;
     95   Status HandleSelectAndScatter(const HloInstruction* instruction) override;
     96   Status HandleBitcast(const HloInstruction* bitcast) override;
     97   Status HandleBroadcast(const HloInstruction* broadcast) override;
     98   Status HandlePad(const HloInstruction* pad) override;
     99   Status HandleReshape(const HloInstruction* reshape) override;
    100   Status HandleTranspose(const HloInstruction* transpose) override;
    101   Status HandleWhile(const HloInstruction* xla_while) override;
    102   Status HandleConditional(const HloInstruction* conditional) override;
    103   Status HandleGather(const HloInstruction* gather) override;
    104   Status FinishVisit(const HloInstruction* root) override;
    105 
    106   Status Preprocess(const HloInstruction* hlo) override;
    107   Status Postprocess(const HloInstruction* hlo) override;
    108 
    109   // Set the rates used to calculate the time taken by the computation. These
    110   // need to be set before visiting starts.
    111   void set_flops_per_second(float value) {
    112     per_second_rates_[kFlopsKey] = value;
    113   }
    114   void set_transcendentals_per_second(float value) {
    115     per_second_rates_[kTranscendentalsKey] = value;
    116   }
    117   void set_bytes_per_second(float value) {
    118     per_second_rates_[kBytesAccessedKey] = value;
    119   }
    120 
    121   // Returns properties for the computation.
    122   float flop_count() const;
    123   float transcendental_count() const;
    124   float bytes_accessed() const;
    125   float optimal_seconds() const;
    126 
    127   // Returns the respective cost computed for a particular HLO instruction, or 0
    128   // if the HLO was not found to have a cost in the analysis.
    129   int64 flop_count(const HloInstruction& hlo) const;
    130   int64 transcendental_count(const HloInstruction& hlo) const;
    131   int64 bytes_accessed(const HloInstruction& hlo) const;
    132   float optimal_seconds(const HloInstruction& hlo) const;
    133 
    134   const Properties& properties() const { return properties_sum_; }
    135   const float property(const string& key) const {
    136     return GetProperty(key, properties());
    137   }
    138 
    139  protected:
    140   typedef std::unordered_map<const HloInstruction*, Properties> HloToProperties;
    141 
    142   // An FMA counts as two floating point operations in these analyzes.
    143   static constexpr int64 kFmaFlops = 2;
    144 
    145   HloCostAnalysis(const ShapeSizeFunction& shape_size,
    146                   const Properties& per_second_rates);
    147 
    148   // Returns the properties computed from visiting the computation rooted at the
    149   // given hlo. Uses shape_size_ to calculate shape sizes if shape_size is null,
    150   // otherwise uses shape_size_.
    151   StatusOr<Properties> ProcessSubcomputation(
    152       HloComputation* computation,
    153       const ShapeSizeFunction* shape_size = nullptr);
    154 
    155   // Utility function to handle all element-wise operations.
    156   Status HandleElementwiseOp(const HloInstruction* hlo_instruction);
    157 
    158   // Returns the default value if the key is not present in the
    159   // properties. Otherwise, returns the value that the key maps to from the
    160   // properties parameter.
    161   static float GetProperty(const string& key, const Properties& properties,
    162                            float default_value = 0.0f);
    163 
    164   // Returns 0.0f if the hlo is not present in hlo_to_properties or if the key
    165   // is not present in hlo_to_properties[hlo]. Otherwise, returns the value that
    166   // the key maps to in the properties of the given hlo.
    167   static float GetPropertyForHlo(const HloInstruction& hlo, const string& key,
    168                                  const HloToProperties& hlo_to_properties);
    169 
    170   // Function which computes the size of the top-level of a given shape (not
    171   // including nested elements, if any). If null then bytes_accessed methods
    172   // return an error.
    173   const ShapeSizeFunction shape_size_;
    174 
    175   HloToProperties hlo_properties_;
    176 
    177   // If true, the time taken will be computed from the rates for each property
    178   // and the total time will be the maximum time, which is the time of the
    179   // bottleneck.
    180   bool current_should_compute_bottleneck_time_;
    181 
    182   // The properties of the currently visited instruction. A HandleFoo method can
    183   // modify these to change the default values computed in Preprocess.
    184   Properties current_properties_;
    185 
    186   // The sum of the properties of all HLOs in the computation.
    187   Properties properties_sum_;
    188 
    189   // How much of each property can be processed per second. E.g. if the property
    190   // is bytes accessed, this is the number of bytes that can be processed per
    191   // second. Is empty if no rates have been set.
    192   Properties per_second_rates_;
    193 
    194   TF_DISALLOW_COPY_AND_ASSIGN(HloCostAnalysis);
    195 };
    196 
    197 }  // namespace xla
    198 
    199 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
    200