1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_ 18 19 #include <cmath> 20 #include "tensorflow/core/lib/core/status.h" 21 #include "tensorflow/core/protobuf/config.pb.h" 22 23 namespace tensorflow { 24 class GraphDef; 25 class CostGraphDef; 26 27 namespace grappler { 28 struct GrapplerItem; 29 30 constexpr int64 kMemoryUnknown = -1ll; 31 constexpr int64 kZeroMemory = 0ll; 32 33 struct DeviceInfo { 34 // Billions of operations executed per second. 35 double gigaops; 36 37 // Bandwidth to main memory in GB per second. 38 double gb_per_sec; 39 40 // Read bandwidth to intermediate memory in GB per second. 41 double intermediate_read_gb_per_sec; 42 43 // Read bandwidth to intermediate memory in GB per second. 44 double intermediate_write_gb_per_sec; 45 46 DeviceInfo() 47 : gigaops(INFINITY), 48 gb_per_sec(INFINITY), 49 intermediate_read_gb_per_sec(INFINITY), 50 intermediate_write_gb_per_sec(INFINITY) {} 51 52 DeviceInfo(const DeviceInfo& input) 53 : gigaops(input.gigaops), 54 gb_per_sec(input.gb_per_sec), 55 intermediate_read_gb_per_sec(input.intermediate_read_gb_per_sec), 56 intermediate_write_gb_per_sec(input.intermediate_write_gb_per_sec) {} 57 58 DeviceInfo(double gigaops, double gb_per_sec, 59 double intermediate_read_gb_per_sec = INFINITY, 60 double intermediate_write_gb_per_sec = INFINITY) 61 : gigaops(gigaops), 62 gb_per_sec(gb_per_sec), 63 intermediate_read_gb_per_sec(intermediate_read_gb_per_sec), 64 intermediate_write_gb_per_sec(intermediate_write_gb_per_sec) {} 65 }; 66 67 // Holds the set of things we might want to estimate or measure in Grappler. 68 // Always produce execution time. Other fields are optional depending on the 69 // estimator being used. 70 struct Costs { 71 // Returns a Costs structure with default values for all of the fields. 72 inline Costs(); 73 74 // Builds a Costs structure with all zero values, rather than unknowns. 75 static inline Costs ZeroCosts(); 76 77 struct MilliSeconds : std::chrono::milliseconds { 78 MilliSeconds() : std::chrono::milliseconds(0) {} 79 MilliSeconds(double d) : std::chrono::milliseconds(static_cast<int64>(d)) {} 80 MilliSeconds(const std::chrono::milliseconds& d) 81 : std::chrono::milliseconds(d) {} 82 MilliSeconds& operator=(const std::chrono::milliseconds& d) { 83 std::chrono::milliseconds::operator=(d); 84 return *this; 85 } 86 }; 87 struct MicroSeconds : std::chrono::microseconds { 88 MicroSeconds() : std::chrono::microseconds(0) {} 89 MicroSeconds(double d) : std::chrono::microseconds(static_cast<int64>(d)) {} 90 MicroSeconds(const std::chrono::microseconds& d) 91 : std::chrono::microseconds(d) {} 92 MicroSeconds& operator=(const std::chrono::microseconds& d) { 93 std::chrono::microseconds::operator=(d); 94 return *this; 95 } 96 MilliSeconds asMilliSeconds() const { 97 return std::chrono::duration_cast<std::chrono::milliseconds>(*this); 98 } 99 }; 100 struct NanoSeconds : std::chrono::nanoseconds { 101 NanoSeconds() : std::chrono::nanoseconds(0) {} 102 NanoSeconds(double d) : std::chrono::nanoseconds(static_cast<int64>(d)) {} 103 NanoSeconds(const std::chrono::nanoseconds& d) 104 : std::chrono::nanoseconds(d) {} 105 NanoSeconds& operator=(const std::chrono::nanoseconds& d) { 106 std::chrono::nanoseconds::operator=(d); 107 return *this; 108 } 109 MicroSeconds asMicroSeconds() const { 110 return std::chrono::duration_cast<std::chrono::microseconds>(*this); 111 } 112 MilliSeconds asMilliSeconds() const { 113 return std::chrono::duration_cast<std::chrono::milliseconds>(*this); 114 } 115 static NanoSeconds infinity() { 116 return NanoSeconds(std::chrono::nanoseconds::max()); 117 } 118 }; 119 120 // We store all our times in nanoseconds. If needs be, we can always switch to 121 // picoseconds in the future by updating this typedef. 122 typedef NanoSeconds Duration; 123 124 // Overall cost of running the graph; latency. 125 Duration execution_time; 126 127 // Computation cost of running the graph. 128 Duration compute_time; 129 130 // Memory access cost of running the graph. 131 Duration memory_time; 132 133 // Intermediate memory access cost of running the graph 134 Duration intermediate_memory_time; 135 Duration intermediate_memory_read_time; // Intermediate memory read cost. 136 Duration intermediate_memory_write_time; // Intermediate memory write cost. 137 138 // This field can be a very pessimistic estimate of the main memory 139 // requirements of a graph. For example, it might assume that all activations 140 // are live for all of a graph's execution. 141 int64 max_memory; // Maximum main memory requirement in bytes over all ops. 142 int64 persistent_memory; 143 int64 temporary_memory; 144 145 // These fields are used for TPU-related estimations. They are per-op 146 // maximums, so each op is evaluated independently, but we want the maximum of 147 // the value over all ops. 148 int64 max_per_op_buffers; // Sum of all buffers used by the ops. 149 int64 max_per_op_streaming; // Ignore largest input buffer, assuming it 150 // streams from main memory. 151 152 // Number of ops included in this Costs in total. 153 // Default initialized to be one. 154 int64 num_ops_total = 1; 155 // If the time estimation is inaccurate. 156 bool inaccurate = false; 157 // Number of ops that are estimated with unknown shapes. 158 int64 num_ops_with_unknown_shapes = 0; 159 // TODO(pcma): include a counter for total inaccurate ops and counters for 160 // other reasons causing the inaccuracy 161 162 // Max possible memory usage per device. 163 std::unordered_map<string, uint64> estimated_max_memory_per_device; 164 }; 165 166 inline std::ostream& operator<<(std::ostream& os, const Costs::MilliSeconds d) { 167 os << d.count() << "ms"; 168 return os; 169 } 170 inline std::ostream& operator<<(std::ostream& os, const Costs::MicroSeconds d) { 171 os << d.count() << "us"; 172 return os; 173 } 174 inline std::ostream& operator<<(std::ostream& os, const Costs::NanoSeconds d) { 175 os << d.count() << "ns"; 176 return os; 177 } 178 179 Costs::Costs() { 180 execution_time = Duration::zero(); 181 compute_time = Duration::zero(); 182 memory_time = Duration::zero(); 183 intermediate_memory_time = Duration::zero(); 184 max_memory = kMemoryUnknown; 185 persistent_memory = kMemoryUnknown; 186 temporary_memory = kMemoryUnknown; 187 max_per_op_buffers = kMemoryUnknown; 188 max_per_op_streaming = kMemoryUnknown; 189 } 190 191 Costs Costs::ZeroCosts() { 192 Costs costs; 193 costs.execution_time = Duration::zero(); 194 costs.compute_time = Duration::zero(); 195 costs.memory_time = Duration::zero(); 196 costs.intermediate_memory_time = Duration::zero(); 197 costs.max_memory = kZeroMemory; 198 costs.persistent_memory = kZeroMemory; 199 costs.temporary_memory = kZeroMemory; 200 costs.max_per_op_buffers = kZeroMemory; 201 costs.max_per_op_streaming = kZeroMemory; 202 return costs; 203 } 204 205 Costs CombineCosts(const Costs& left, const Costs& right); 206 207 // Multiplies Costs by a scalar. 208 // Equivalent to applying CombineCosts "multiplier" times. 209 Costs MultiplyCosts(const Costs& costs, int multiplier); 210 211 // Given a GrapperItem and an optimized implementation of the corresponding 212 // TensorFlow graph, the CostEstimator attempts to predicts the actual cost of 213 // running the graph. 214 class CostEstimator { 215 public: 216 virtual ~CostEstimator() {} 217 218 // Initializes the estimator for the specified grappler item. 219 // The estimator shouldn't be used if this function returns any status other 220 // that OK. 221 virtual Status Initialize(const GrapplerItem& item) = 0; 222 223 // Predicts the cost of running the given optimized version of the grappler 224 // item. 225 // If a RunMetadata is passed, it will be populated with detailed information 226 // about the cost of running each operation of the optimized graph. 227 // if a double value is passed, it will be set to a value that reflects the 228 // overall cost of running the graph (e.g. the latency of the computation). 229 // Returns a status that indicate is the performance could be estimated or 230 // not. 231 virtual Status PredictCosts(const GraphDef& optimized_graph, 232 RunMetadata* run_metadata, Costs* cost) const = 0; 233 }; 234 235 } // end namespace grappler 236 } // end namespace tensorflow 237 238 #endif // TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_ 239