1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_GRAPPLER_COSTS_COST_ESTIMATOR_H_ 17 #define TENSORFLOW_GRAPPLER_COSTS_COST_ESTIMATOR_H_ 18 19 #include <chrono> 20 #include <unordered_map> 21 #include "tensorflow/core/lib/core/status.h" 22 23 namespace tensorflow { 24 class GraphDef; 25 class CostGraphDef; 26 27 namespace grappler { 28 struct GrapplerItem; 29 30 constexpr int64 kMemoryUnknown = -1ll; 31 constexpr int64 kZeroMemory = 0ll; 32 33 // Holds the set of things we might want to estimate or measure in Grappler. 34 // Always produce execution time. Other fields are optional depending on the 35 // estimator being used. 36 struct Costs { 37 // Returns a Costs structure with default values for all of the fields. 38 inline Costs(); 39 40 // Builds a Costs structure with all zero values, rather than unknowns. 41 static inline Costs ZeroCosts(); 42 43 struct MilliSeconds : std::chrono::milliseconds { 44 MilliSeconds() : std::chrono::milliseconds(0) {} 45 MilliSeconds(double d) : std::chrono::milliseconds(static_cast<int64>(d)) {} 46 MilliSeconds(const std::chrono::milliseconds& d) 47 : std::chrono::milliseconds(d) {} 48 MilliSeconds& operator=(const std::chrono::milliseconds& d) { 49 std::chrono::milliseconds::operator=(d); 50 return *this; 51 } 52 }; 53 struct MicroSeconds : std::chrono::microseconds { 54 MicroSeconds() : std::chrono::microseconds(0) {} 55 MicroSeconds(double d) : std::chrono::microseconds(static_cast<int64>(d)) {} 56 MicroSeconds(const std::chrono::microseconds& d) 57 : std::chrono::microseconds(d) {} 58 MicroSeconds& operator=(const std::chrono::microseconds& d) { 59 std::chrono::microseconds::operator=(d); 60 return *this; 61 } 62 MilliSeconds asMilliSeconds() const { 63 return std::chrono::duration_cast<std::chrono::milliseconds>(*this); 64 } 65 }; 66 struct NanoSeconds : std::chrono::nanoseconds { 67 NanoSeconds() : std::chrono::nanoseconds(0) {} 68 NanoSeconds(double d) : std::chrono::nanoseconds(static_cast<int64>(d)) {} 69 NanoSeconds(const std::chrono::nanoseconds& d) 70 : std::chrono::nanoseconds(d) {} 71 NanoSeconds& operator=(const std::chrono::nanoseconds& d) { 72 std::chrono::nanoseconds::operator=(d); 73 return *this; 74 } 75 MicroSeconds asMicroSeconds() const { 76 return std::chrono::duration_cast<std::chrono::microseconds>(*this); 77 } 78 MilliSeconds asMilliSeconds() const { 79 return std::chrono::duration_cast<std::chrono::milliseconds>(*this); 80 } 81 static NanoSeconds infinity() { 82 return NanoSeconds(std::chrono::nanoseconds::max()); 83 } 84 }; 85 86 // We store all our times in nanoseconds. If needs be, we can always switch to 87 // picoseconds in the future by updating this typedef. 88 typedef NanoSeconds Duration; 89 90 // Overall cost of running the graph; latency. 91 Duration execution_time; 92 93 // Computation cost of running the graph. 94 Duration compute_time; 95 96 // Memory access cost of running the graph. 97 Duration memory_time; 98 99 // This field can be a very pessimistic estimate of the main memory 100 // requirements of a graph. For example, it might assume that all activations 101 // are live for all of a graph's execution. 102 int64 max_memory; // Maximum main memory requirement in bytes over all ops. 103 int64 persistent_memory; 104 int64 temporary_memory; 105 106 // These fields are used for TPU-related estimations. They are per-op 107 // maximums, so each op is evaluated independently, but we want the maximum of 108 // the value over all ops. 109 int64 max_per_op_buffers; // Sum of all buffers used by the ops. 110 int64 max_per_op_streaming; // Ignore largest input buffer, assuming it 111 // streams from main memory. 112 // If the time estimation is inaccurate. 113 bool inaccurate = false; 114 115 // Max possible memory usage per device. 116 std::unordered_map<string, uint64> estimated_max_memory_per_device; 117 }; 118 119 inline std::ostream& operator<<(std::ostream& os, const Costs::MilliSeconds d) { 120 os << d.count() << "ms"; 121 return os; 122 } 123 inline std::ostream& operator<<(std::ostream& os, const Costs::MicroSeconds d) { 124 os << d.count() << "us"; 125 return os; 126 } 127 inline std::ostream& operator<<(std::ostream& os, const Costs::NanoSeconds d) { 128 os << d.count() << "ns"; 129 return os; 130 } 131 132 Costs::Costs() { 133 execution_time = Duration::zero(); 134 compute_time = Duration::zero(); 135 memory_time = Duration::zero(); 136 max_memory = kMemoryUnknown; 137 persistent_memory = kMemoryUnknown; 138 temporary_memory = kMemoryUnknown; 139 max_per_op_buffers = kMemoryUnknown; 140 max_per_op_streaming = kMemoryUnknown; 141 } 142 143 Costs Costs::ZeroCosts() { 144 Costs costs; 145 costs.execution_time = Duration::zero(); 146 costs.compute_time = Duration::zero(); 147 costs.memory_time = Duration::zero(); 148 costs.max_memory = kZeroMemory; 149 costs.persistent_memory = kZeroMemory; 150 costs.temporary_memory = kZeroMemory; 151 costs.max_per_op_buffers = kZeroMemory; 152 costs.max_per_op_streaming = kZeroMemory; 153 return costs; 154 } 155 156 // Given a GrapperItem and an optimized implementation of the corresponding 157 // TensorFlow graph, the CostEstimator attempts to predicts the actual cost of 158 // running the graph. 159 class CostEstimator { 160 public: 161 virtual ~CostEstimator() {} 162 163 // Initializes the estimator for the specified grappler item. 164 // The estimator shouldn't be used if this function returns any status other 165 // that OK. 166 virtual Status Initialize(const GrapplerItem& item) = 0; 167 168 // Predicts the cost of running the given optimized version of the grappler 169 // item. 170 // If a CostGraphDef is passed, it will be populated with detailed information 171 // about the cost of running each operation of the optimized graph. 172 // if a double value is passed, it will be set to a value that reflects the 173 // overall cost of running the graph (e.g. the latency of the computation). 174 // Returns a status that indicate is the performance could be estimated or 175 // not. 176 virtual Status PredictCosts(const GraphDef& optimized_graph, 177 CostGraphDef* cost_graph, Costs* cost) const = 0; 178 }; 179 180 } // end namespace grappler 181 } // end namespace tensorflow 182 183 #endif // TENSORFLOW_GRAPPLER_COSTS_COST_ESTIMATOR_H_ 184