/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_
#define TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_

#include <chrono>
#include <cmath>
#include <ostream>
#include <unordered_map>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {
class GraphDef;
class CostGraphDef;

namespace grappler {
struct GrapplerItem;

constexpr int64 kMemoryUnknown = -1ll;
constexpr int64 kZeroMemory = 0ll;

struct DeviceInfo {
  // Billions of operations executed per second.
  double gigaops;

  // Bandwidth to main memory in GB per second.
  double gb_per_sec;

  // Read bandwidth to intermediate memory in GB per second.
  double intermediate_read_gb_per_sec;

  // Write bandwidth to intermediate memory in GB per second.
  double intermediate_write_gb_per_sec;

  DeviceInfo()
      : gigaops(INFINITY),
        gb_per_sec(INFINITY),
        intermediate_read_gb_per_sec(INFINITY),
        intermediate_write_gb_per_sec(INFINITY) {}

  DeviceInfo(const DeviceInfo& input)
      : gigaops(input.gigaops),
        gb_per_sec(input.gb_per_sec),
        intermediate_read_gb_per_sec(input.intermediate_read_gb_per_sec),
        intermediate_write_gb_per_sec(input.intermediate_write_gb_per_sec) {}

  DeviceInfo(double gigaops, double gb_per_sec,
             double intermediate_read_gb_per_sec = INFINITY,
             double intermediate_write_gb_per_sec = INFINITY)
      : gigaops(gigaops),
        gb_per_sec(gb_per_sec),
        intermediate_read_gb_per_sec(intermediate_read_gb_per_sec),
        intermediate_write_gb_per_sec(intermediate_write_gb_per_sec) {}
};
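
// Example (illustrative only; the figures describe a hypothetical accelerator,
// not any real device, and `op_flops` / `op_bytes` are placeholders for an
// op's operation and byte counts):
//
//   DeviceInfo device(/*gigaops=*/10000.0, /*gb_per_sec=*/900.0);
//   // A rough roofline-style estimate for a single op:
//   double compute_secs = op_flops / (device.gigaops * 1e9);
//   double memory_secs = op_bytes / (device.gb_per_sec * 1e9);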

// Holds the set of things we might want to estimate or measure in Grappler.
// Execution time is always produced; the other fields are optional and depend
// on the estimator being used.
struct Costs {
  // Returns a Costs structure with default values for all of the fields.
  inline Costs();

  // Builds a Costs structure with all zero values, rather than unknowns.
  static inline Costs ZeroCosts();

  struct MilliSeconds : std::chrono::milliseconds {
    MilliSeconds() : std::chrono::milliseconds(0) {}
    MilliSeconds(double d) : std::chrono::milliseconds(static_cast<int64>(d)) {}
    MilliSeconds(const std::chrono::milliseconds& d)
        : std::chrono::milliseconds(d) {}
    MilliSeconds& operator=(const std::chrono::milliseconds& d) {
      std::chrono::milliseconds::operator=(d);
      return *this;
    }
  };
  struct MicroSeconds : std::chrono::microseconds {
    MicroSeconds() : std::chrono::microseconds(0) {}
    MicroSeconds(double d) : std::chrono::microseconds(static_cast<int64>(d)) {}
    MicroSeconds(const std::chrono::microseconds& d)
        : std::chrono::microseconds(d) {}
    MicroSeconds& operator=(const std::chrono::microseconds& d) {
      std::chrono::microseconds::operator=(d);
      return *this;
    }
    MilliSeconds asMilliSeconds() const {
      return std::chrono::duration_cast<std::chrono::milliseconds>(*this);
    }
  };
  struct NanoSeconds : std::chrono::nanoseconds {
    NanoSeconds() : std::chrono::nanoseconds(0) {}
    NanoSeconds(double d) : std::chrono::nanoseconds(static_cast<int64>(d)) {}
    NanoSeconds(const std::chrono::nanoseconds& d)
        : std::chrono::nanoseconds(d) {}
    NanoSeconds& operator=(const std::chrono::nanoseconds& d) {
      std::chrono::nanoseconds::operator=(d);
      return *this;
    }
    MicroSeconds asMicroSeconds() const {
      return std::chrono::duration_cast<std::chrono::microseconds>(*this);
    }
    MilliSeconds asMilliSeconds() const {
      return std::chrono::duration_cast<std::chrono::milliseconds>(*this);
    }
    static NanoSeconds infinity() {
      return NanoSeconds(std::chrono::nanoseconds::max());
    }
  };

  // We store all our times in nanoseconds. If need be, we can always switch to
  // picoseconds in the future by updating this typedef.
  typedef NanoSeconds Duration;

  // Overall cost of running the graph; latency.
  Duration execution_time;

  // Computation cost of running the graph.
  Duration compute_time;

  // Memory access cost of running the graph.
  Duration memory_time;

  // Intermediate memory access cost of running the graph.
  Duration intermediate_memory_time;
  Duration intermediate_memory_read_time;   // Intermediate memory read cost.
  Duration intermediate_memory_write_time;  // Intermediate memory write cost.

  // This field can be a very pessimistic estimate of the main memory
  // requirements of a graph. For example, it might assume that all activations
  // are live for all of a graph's execution.
  int64 max_memory;  // Maximum main memory requirement in bytes over all ops.
  int64 persistent_memory;
  int64 temporary_memory;

  // These fields are used for TPU-related estimations. They are per-op
  // maximums, so each op is evaluated independently, but we want the maximum
  // of the value over all ops.
  int64 max_per_op_buffers;    // Sum of all buffers used by the ops.
  int64 max_per_op_streaming;  // Ignore largest input buffer, assuming it
                               // streams from main memory.

  // Total number of ops accounted for in this Costs structure.
  // Defaults to one.
  int64 num_ops_total = 1;
  // Whether the time estimation is inaccurate.
  bool inaccurate = false;
  // Number of ops that were estimated with unknown shapes.
  int64 num_ops_with_unknown_shapes = 0;
  // TODO(pcma): include a counter for total inaccurate ops and counters for
  // other reasons causing the inaccuracy.

  // Maximum possible memory usage per device.
  std::unordered_map<string, uint64> estimated_max_memory_per_device;
};
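
// A minimal sketch of populating a Costs value by hand (illustrative only;
// real estimators derive these fields from graph analysis, and the numbers
// below are made up):
//
//   Costs costs = Costs::ZeroCosts();
//   costs.compute_time = Costs::NanoSeconds(2000);
//   costs.memory_time = Costs::NanoSeconds(3000);
//   costs.execution_time = costs.compute_time + costs.memory_time;
//   costs.max_memory = 1 << 20;  // 1 MiB peak, hypothetical.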

inline std::ostream& operator<<(std::ostream& os, const Costs::MilliSeconds d) {
  os << d.count() << "ms";
  return os;
}
inline std::ostream& operator<<(std::ostream& os, const Costs::MicroSeconds d) {
  os << d.count() << "us";
  return os;
}
inline std::ostream& operator<<(std::ostream& os, const Costs::NanoSeconds d) {
  os << d.count() << "ns";
  return os;
}
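
// The unit wrappers convert and print as follows (a usage sketch):
//
//   Costs::NanoSeconds t(1500000.0);               // 1.5 ms, expressed in ns.
//   Costs::MicroSeconds us = t.asMicroSeconds();   // 1500 us
//   Costs::MilliSeconds ms = t.asMilliSeconds();   // 1 ms (truncating cast)
//   std::cout << t << " " << us << " " << ms;      // "1500000ns 1500us 1ms"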

Costs::Costs() {
  execution_time = Duration::zero();
  compute_time = Duration::zero();
  memory_time = Duration::zero();
  intermediate_memory_time = Duration::zero();
  max_memory = kMemoryUnknown;
  persistent_memory = kMemoryUnknown;
  temporary_memory = kMemoryUnknown;
  max_per_op_buffers = kMemoryUnknown;
  max_per_op_streaming = kMemoryUnknown;
}

Costs Costs::ZeroCosts() {
  Costs costs;
  costs.execution_time = Duration::zero();
  costs.compute_time = Duration::zero();
  costs.memory_time = Duration::zero();
  costs.intermediate_memory_time = Duration::zero();
  costs.max_memory = kZeroMemory;
  costs.persistent_memory = kZeroMemory;
  costs.temporary_memory = kZeroMemory;
  costs.max_per_op_buffers = kZeroMemory;
  costs.max_per_op_streaming = kZeroMemory;
  return costs;
}

// Aggregates the costs of two ops or subgraphs into a single Costs.
Costs CombineCosts(const Costs& left, const Costs& right);

// Multiplies Costs by a scalar.
// Equivalent to applying CombineCosts "multiplier" times.
Costs MultiplyCosts(const Costs& costs, int multiplier);
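
// For example (a sketch of the intended semantics; the definitions live in the
// corresponding implementation file): if `body_cost` describes one iteration
// of a loop body that runs 10 times, the aggregate loop cost can be written
// either way below, and the two are expected to agree:
//
//   Costs loop_cost = MultiplyCosts(body_cost, 10);
//
//   Costs accumulated = Costs::ZeroCosts();
//   for (int i = 0; i < 10; ++i) {
//     accumulated = CombineCosts(accumulated, body_cost);
//   }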

// Given a GrapplerItem and an optimized implementation of the corresponding
// TensorFlow graph, the CostEstimator attempts to predict the actual cost of
// running the graph.
class CostEstimator {
 public:
  virtual ~CostEstimator() {}

  // Initializes the estimator for the specified grappler item.
  // The estimator shouldn't be used if this function returns any status other
  // than OK.
  virtual Status Initialize(const GrapplerItem& item) = 0;

  // Predicts the cost of running the given optimized version of the grappler
  // item.
  // If a RunMetadata is passed, it will be populated with detailed information
  // about the cost of running each operation of the optimized graph.
  // If a Costs pointer is passed, it will be set to values that reflect the
  // overall cost of running the graph (e.g. the latency of the computation).
  // Returns a status that indicates whether the performance could be estimated
  // or not.
  virtual Status PredictCosts(const GraphDef& optimized_graph,
                              RunMetadata* run_metadata, Costs* cost) const = 0;
};
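
// A typical caller drives a concrete estimator as sketched below (the factory
// function `MakeSomeCostEstimator` is a placeholder, not part of this API):
//
//   std::unique_ptr<CostEstimator> estimator = MakeSomeCostEstimator();
//   TF_RETURN_IF_ERROR(estimator->Initialize(item));
//   RunMetadata run_metadata;
//   Costs costs;
//   TF_RETURN_IF_ERROR(
//       estimator->PredictCosts(optimized_graph, &run_metadata, &costs));
//   VLOG(1) << "Predicted latency: " << costs.execution_time;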

}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_