Home | History | Annotate | Download | only in costs
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_GRAPPLER_COSTS_COST_ESTIMATOR_H_
     17 #define TENSORFLOW_GRAPPLER_COSTS_COST_ESTIMATOR_H_
     18 
#include <chrono>
#include <ostream>
#include <unordered_map>

#include "tensorflow/core/lib/core/status.h"
     22 
     23 namespace tensorflow {
     24 class GraphDef;
     25 class CostGraphDef;
     26 
     27 namespace grappler {
     28 struct GrapplerItem;
     29 
     30 constexpr int64 kMemoryUnknown = -1ll;
     31 constexpr int64 kZeroMemory = 0ll;
     32 
     33 // Holds the set of things we might want to estimate or measure in Grappler.
     34 // Always produce execution time. Other fields are optional depending on the
     35 // estimator being used.
     36 struct Costs {
     37   // Returns a Costs structure with default values for all of the fields.
     38   inline Costs();
     39 
     40   // Builds a Costs structure with all zero values, rather than unknowns.
     41   static inline Costs ZeroCosts();
     42 
     43   struct MilliSeconds : std::chrono::milliseconds {
     44     MilliSeconds() : std::chrono::milliseconds(0) {}
     45     MilliSeconds(double d) : std::chrono::milliseconds(static_cast<int64>(d)) {}
     46     MilliSeconds(const std::chrono::milliseconds& d)
     47         : std::chrono::milliseconds(d) {}
     48     MilliSeconds& operator=(const std::chrono::milliseconds& d) {
     49       std::chrono::milliseconds::operator=(d);
     50       return *this;
     51     }
     52   };
     53   struct MicroSeconds : std::chrono::microseconds {
     54     MicroSeconds() : std::chrono::microseconds(0) {}
     55     MicroSeconds(double d) : std::chrono::microseconds(static_cast<int64>(d)) {}
     56     MicroSeconds(const std::chrono::microseconds& d)
     57         : std::chrono::microseconds(d) {}
     58     MicroSeconds& operator=(const std::chrono::microseconds& d) {
     59       std::chrono::microseconds::operator=(d);
     60       return *this;
     61     }
     62     MilliSeconds asMilliSeconds() const {
     63       return std::chrono::duration_cast<std::chrono::milliseconds>(*this);
     64     }
     65   };
     66   struct NanoSeconds : std::chrono::nanoseconds {
     67     NanoSeconds() : std::chrono::nanoseconds(0) {}
     68     NanoSeconds(double d) : std::chrono::nanoseconds(static_cast<int64>(d)) {}
     69     NanoSeconds(const std::chrono::nanoseconds& d)
     70         : std::chrono::nanoseconds(d) {}
     71     NanoSeconds& operator=(const std::chrono::nanoseconds& d) {
     72       std::chrono::nanoseconds::operator=(d);
     73       return *this;
     74     }
     75     MicroSeconds asMicroSeconds() const {
     76       return std::chrono::duration_cast<std::chrono::microseconds>(*this);
     77     }
     78     MilliSeconds asMilliSeconds() const {
     79       return std::chrono::duration_cast<std::chrono::milliseconds>(*this);
     80     }
     81     static NanoSeconds infinity() {
     82       return NanoSeconds(std::chrono::nanoseconds::max());
     83     }
     84   };
     85 
     86   // We store all our times in nanoseconds. If needs be, we can always switch to
     87   // picoseconds in the future by updating this typedef.
     88   typedef NanoSeconds Duration;
     89 
     90   // Overall cost of running the graph; latency.
     91   Duration execution_time;
     92 
     93   // Computation cost of running the graph.
     94   Duration compute_time;
     95 
     96   // Memory access cost of running the graph.
     97   Duration memory_time;
     98 
     99   // This field can be a very pessimistic estimate of the main memory
    100   // requirements of a graph. For example, it might assume that all activations
    101   // are live for all of a graph's execution.
    102   int64 max_memory;  // Maximum main memory requirement in bytes over all ops.
    103   int64 persistent_memory;
    104   int64 temporary_memory;
    105 
    106   // These fields are used for TPU-related estimations. They are per-op
    107   // maximums, so each op is evaluated independently, but we want the maximum of
    108   // the value over all ops.
    109   int64 max_per_op_buffers;    // Sum of all buffers used by the ops.
    110   int64 max_per_op_streaming;  // Ignore largest input buffer, assuming it
    111                                // streams from main memory.
    112   // If the time estimation is inaccurate.
    113   bool inaccurate = false;
    114 
    115   // Max possible memory usage per device.
    116   std::unordered_map<string, uint64> estimated_max_memory_per_device;
    117 };
    118 
    119 inline std::ostream& operator<<(std::ostream& os, const Costs::MilliSeconds d) {
    120   os << d.count() << "ms";
    121   return os;
    122 }
    123 inline std::ostream& operator<<(std::ostream& os, const Costs::MicroSeconds d) {
    124   os << d.count() << "us";
    125   return os;
    126 }
    127 inline std::ostream& operator<<(std::ostream& os, const Costs::NanoSeconds d) {
    128   os << d.count() << "ns";
    129   return os;
    130 }
    131 
    132 Costs::Costs() {
    133   execution_time = Duration::zero();
    134   compute_time = Duration::zero();
    135   memory_time = Duration::zero();
    136   max_memory = kMemoryUnknown;
    137   persistent_memory = kMemoryUnknown;
    138   temporary_memory = kMemoryUnknown;
    139   max_per_op_buffers = kMemoryUnknown;
    140   max_per_op_streaming = kMemoryUnknown;
    141 }
    142 
    143 Costs Costs::ZeroCosts() {
    144   Costs costs;
    145   costs.execution_time = Duration::zero();
    146   costs.compute_time = Duration::zero();
    147   costs.memory_time = Duration::zero();
    148   costs.max_memory = kZeroMemory;
    149   costs.persistent_memory = kZeroMemory;
    150   costs.temporary_memory = kZeroMemory;
    151   costs.max_per_op_buffers = kZeroMemory;
    152   costs.max_per_op_streaming = kZeroMemory;
    153   return costs;
    154 }
    155 
    156 // Given a GrapperItem and an optimized implementation of the corresponding
    157 // TensorFlow graph, the CostEstimator attempts to predicts the actual cost of
    158 // running the graph.
    159 class CostEstimator {
    160  public:
    161   virtual ~CostEstimator() {}
    162 
    163   // Initializes the estimator for the specified grappler item.
    164   // The estimator shouldn't be used if this function returns any status other
    165   // that OK.
    166   virtual Status Initialize(const GrapplerItem& item) = 0;
    167 
    168   // Predicts the cost of running the given optimized version of the grappler
    169   // item.
    170   // If a CostGraphDef is passed, it will be populated with detailed information
    171   // about the cost of running each operation of the optimized graph.
    172   // if a double value is passed, it will be set to a value that reflects the
    173   // overall cost of running the graph (e.g. the latency of the computation).
    174   // Returns a status that indicate is the performance could be estimated or
    175   // not.
    176   virtual Status PredictCosts(const GraphDef& optimized_graph,
    177                               CostGraphDef* cost_graph, Costs* cost) const = 0;
    178 };
    179 
    180 }  // end namespace grappler
    181 }  // end namespace tensorflow
    182 
    183 #endif  // TENSORFLOW_GRAPPLER_COSTS_COST_ESTIMATOR_H_
    184