Home | History | Annotate | Download | only in costs
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_
     17 #define TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_
     18 
     19 #include <functional>
     20 #include <map>
     21 #include <string>
     22 
     23 #include "tensorflow/core/grappler/costs/cost_estimator.h"
     24 #include "tensorflow/core/grappler/costs/op_context.h"
     25 #include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
     26 #include "tensorflow/core/util/padding.h"
     27 
     28 namespace tensorflow {
     29 namespace grappler {
     30 
     31 class OpLevelCostEstimator {
     32  public:
     33   OpLevelCostEstimator();
     34   virtual ~OpLevelCostEstimator() {}
     35 
     36   virtual Costs PredictCosts(const OpContext& op_context) const;
     37 
     38   // Basic device performance info, sufficient for roofline estimate.
     39   struct DeviceInfo {
     40     double gigaops;     // Billions of operations executed per second.
     41     double gb_per_sec;  // Bandwidth to main memory in GB per second.
     42   };
     43 
     44   // Returns basic device performance info.
     45   virtual DeviceInfo GetDeviceInfo(const DeviceProperties& device) const;
     46 
     47  protected:
     48   // Predict cost of an op for which no accurate estimator is defined.
     49   Costs PredictCostOfAnUnknownOp(const OpContext& op_context) const;
     50 
     51   // Naive cost estimate based on operations divided by device ops/sec,
     52   // and input/output tensor sizes.
     53   Costs PredictOpCountBasedCost(double operations,
     54                                 const OpInfo& op_features) const;
     55 
     56   // This family of routines counts the number of operations to perform the
     57   // specified TensorFlow Op.
     58   struct MatMulDimensions {
     59     int m;
     60     int n;
     61     int k;
     62   };
     63   struct ConvolutionDimensions {
     64     int64 batch;      // Batch size.
     65     int64 ix;         // Input size x.
     66     int64 iy;         // Input size y.
     67     int64 iz;         // Input depth.
     68     int64 kx;         // Kernel x.
     69     int64 ky;         // Kernel y.
     70     int64 oz;         // Output depth.
     71     int64 ox;         // Output size x.
     72     int64 oy;         // Output size y.
     73     int64 sx;         // Stride x.
     74     int64 sy;         // Stride y.
     75     Padding padding;  // SAME or VALID.
     76   };
     77   int64 CountConv2DOperations(const OpInfo& op_features,
     78                               bool* found_unknown_shapes) const;
     79   int64 CountConv2DOperations(const OpInfo& op_features,
     80                               ConvolutionDimensions* conv_info,
     81                               bool* found_unknown_shapes) const;
     82   int64 CountMatMulOperations(const OpInfo& op_features,
     83                               bool* found_unknown_shapes) const;
     84   int64 CountMatMulOperations(const OpInfo& op_features,
     85                               MatMulDimensions* mat_mul,
     86                               bool* found_unknown_shapes) const;
     87   int64 CountBatchMatMulOperations(const OpInfo& op_features,
     88                                    bool* found_unknown_shapes) const;
     89   int64 CountConv2DBackpropInputOperations(const OpInfo& op_features,
     90                                            ConvolutionDimensions* conv_info,
     91                                            bool* found_unknown_shapes) const;
     92   int64 CountConv2DBackpropFilterOperations(const OpInfo& op_features,
     93                                             ConvolutionDimensions* conv_info,
     94                                             bool* found_unknown_shapes) const;
     95 
     96   // Calculate the element count of an input/output tensor.
     97   int64 CalculateTensorElementCount(const OpInfo::TensorProperties& tensor,
     98                                     bool* found_unknown_shapes) const;
     99 
    100   // Calculate the total size in bytes of an input/output tensor.
    101   int64 CalculateTensorSize(const OpInfo::TensorProperties& tensor,
    102                             bool* found_unknown_shapes) const;
    103 
    104   // Calculate the element count of the largest
    105   // input of specified TensorFlow op.
    106   int64 CalculateLargestInputCount(const OpInfo& op_features,
    107                                    bool* found_unknown_shapes) const;
    108 
    109   // Calculate the total size in bytes of the all
    110   // the inputs of specified TensorFlow op.
    111   int64 CalculateInputSize(const OpInfo& op_features,
    112                            bool* found_unknown_shapes) const;
    113 
    114   // Calculate the total size in bytes of the all
    115   // the outputs of specified TensorFlow op.
    116   int64 CalculateOutputSize(const OpInfo& op_features,
    117                             bool* found_unknown_shapes) const;
    118 
    119   // This family of routines predicts the costs to
    120   // perform the specified TensorFlow Op on the
    121   // device represented by a subclass. The default
    122   // implementation just divides the operations to
    123   // perform the op (from the "Count" routines,
    124   // above) by the device peak operations per
    125   // second. Override to supply a better estimate.
    126   // Implementation of costs other than
    127   // execution_time is optional, depending on the
    128   // device.
    129   Costs PredictConv2D(const OpContext& op_context) const;
    130   Costs PredictCwiseOp(const OpContext& op_context) const;
    131   Costs PredictConv2DBackpropInput(const OpContext& op_context) const;
    132   Costs PredictConv2DBackpropFilter(const OpContext& op_context) const;
    133   Costs PredictMatMul(const OpContext& op_context) const;
    134   Costs PredictNoOp(const OpContext& op_context) const;
    135   Costs PredictIdentity(const OpContext& op_context) const;
    136   Costs PredictVariable(const OpContext& op_context) const;
    137   Costs PredictBatchMatMul(const OpContext& op_context) const;
    138   Costs PredictMetadata(const OpContext& op_context) const;
    139 
    140   // Utility function for safe division. Returns 0
    141   // if rhs is 0 or negative.
    142   static double SafeDiv(const double lhs, const double rhs) {
    143     if (rhs > 0) {
    144       return lhs / rhs;
    145     } else {
    146       return 0.0;
    147     }
    148   }
    149 
    150   static ConvolutionDimensions ConvolutionDimensionsFromInputs(
    151       const TensorShapeProto& original_image_shape,
    152       const TensorShapeProto& original_filter_shape, const OpInfo& op_features,
    153       bool* found_unknown_shapes);
    154 
    155  protected:
    156   std::map<string, int> elementwise_ops_;
    157   typedef std::function<Costs(const OpContext& op_context)> CostImpl;
    158   std::map<string, CostImpl> device_cost_impl_;
    159   // If true, assume compute and memory overlap; hence, the op cost is max of
    160   // compute_time and memory_time, insteaf of sum of those two.
    161   bool compute_memory_overlap_;
    162 
    163  private:
    164   friend class OpLevelCostEstimatorTest;
    165 };
    166 
    167 }  // end namespace grappler
    168 }  // end namespace tensorflow
    169 #endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_
    170