/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
#define TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_

#include <array>
#include <functional>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/monitoring/collected_metrics.h"
#include "tensorflow/core/lib/monitoring/metric_def.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace monitoring {

namespace test_util {
class CollectionRegistryTestAccess;
}  // namespace test_util

namespace internal {
class Collector;
}  // namespace internal

// Metric implementations would get an instance of this class using the
// MetricCollectorGetter in the collection-function lambda, so that their values
// can be collected.
//
// Read the documentation on CollectionRegistry::Register() for more details.
48 // 49 // For example: 50 // auto metric_collector = metric_collector_getter->Get(&metric_def); 51 // metric_collector.CollectValue(some_labels, some_value); 52 // metric_collector.CollectValue(others_labels, other_value); 53 // 54 // This class is NOT thread-safe. 55 template <MetricKind metric_kind, typename Value, int NumLabels> 56 class MetricCollector { 57 public: 58 ~MetricCollector() = default; 59 60 // Collects the value with these labels. 61 void CollectValue(const std::array<string, NumLabels>& labels, 62 const Value& value); 63 64 private: 65 friend class internal::Collector; 66 67 MetricCollector( 68 const MetricDef<metric_kind, Value, NumLabels>* const metric_def, 69 const uint64 registration_time_millis, 70 internal::Collector* const collector, PointSet* const point_set) 71 : metric_def_(metric_def), 72 registration_time_millis_(registration_time_millis), 73 collector_(collector), 74 point_set_(point_set) { 75 point_set_->metric_name = metric_def->name().ToString(); 76 } 77 78 const MetricDef<metric_kind, Value, NumLabels>* const metric_def_; 79 const uint64 registration_time_millis_; 80 internal::Collector* const collector_; 81 PointSet* const point_set_; 82 83 // This is made copyable because we can't hand out references of this class 84 // from MetricCollectorGetter because this class is templatized, and we need 85 // MetricCollectorGetter not to be templatized and hence MetricCollectorGetter 86 // can't own an instance of this class. 87 }; 88 89 // Returns a MetricCollector with the same template parameters as the 90 // metric-definition, so that the values of a metric can be collected. 91 // 92 // The collection-function defined by a metric takes this as a parameter. 93 // 94 // Read the documentation on CollectionRegistry::Register() for more details. 95 class MetricCollectorGetter { 96 public: 97 // Returns the MetricCollector with the same template parameters as the 98 // metric_def. 
99 template <MetricKind metric_kind, typename Value, int NumLabels> 100 MetricCollector<metric_kind, Value, NumLabels> Get( 101 const MetricDef<metric_kind, Value, NumLabels>* const metric_def); 102 103 private: 104 friend class internal::Collector; 105 106 MetricCollectorGetter(internal::Collector* const collector, 107 const AbstractMetricDef* const allowed_metric_def, 108 const uint64 registration_time_millis) 109 : collector_(collector), 110 allowed_metric_def_(allowed_metric_def), 111 registration_time_millis_(registration_time_millis) {} 112 113 internal::Collector* const collector_; 114 const AbstractMetricDef* const allowed_metric_def_; 115 const uint64 registration_time_millis_; 116 }; 117 118 // A collection registry for metrics. 119 // 120 // Metrics are registered here so that their state can be collected later and 121 // exported. 122 // 123 // This class is thread-safe. 124 class CollectionRegistry { 125 public: 126 ~CollectionRegistry() = default; 127 128 // Returns the default registry for the process. 129 // 130 // This registry belongs to this library and should never be deleted. 131 static CollectionRegistry* Default(); 132 133 using CollectionFunction = std::function<void(MetricCollectorGetter getter)>; 134 135 // Registers the metric and the collection-function which can be used to 136 // collect its values. Returns a Registration object, which when upon 137 // destruction would cause the metric to be unregistered from this registry. 138 // 139 // IMPORTANT: Delete the handle before the metric-def is deleted. 
140 // 141 // Example usage; 142 // CollectionRegistry::Default()->Register( 143 // &metric_def, 144 // [&](MetricCollectorGetter getter) { 145 // auto metric_collector = getter.Get(&metric_def); 146 // for (const auto& cell : cells) { 147 // metric_collector.CollectValue(cell.labels(), cell.value()); 148 // } 149 // }); 150 class RegistrationHandle; 151 std::unique_ptr<RegistrationHandle> Register( 152 const AbstractMetricDef* metric_def, 153 const CollectionFunction& collection_function) 154 LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT; 155 156 // Options for collecting metrics. 157 struct CollectMetricsOptions { 158 CollectMetricsOptions() {} 159 bool collect_metric_descriptors = true; 160 }; 161 // Goes through all the registered metrics, collects their definitions 162 // (optionally) and current values and returns them in a standard format. 163 std::unique_ptr<CollectedMetrics> CollectMetrics( 164 const CollectMetricsOptions& options) const; 165 166 private: 167 friend class test_util::CollectionRegistryTestAccess; 168 friend class internal::Collector; 169 170 CollectionRegistry(Env* env); 171 172 // Unregisters the metric from this registry. This is private because the 173 // public interface provides a Registration handle which automatically calls 174 // this upon destruction. 175 void Unregister(const AbstractMetricDef* metric_def) LOCKS_EXCLUDED(mu_); 176 177 // TF environment, mainly used for timestamping. 178 Env* const env_; 179 180 mutable mutex mu_; 181 182 // Information required for collection. 183 struct CollectionInfo { 184 const AbstractMetricDef* const metric_def; 185 CollectionFunction collection_function; 186 uint64 registration_time_millis; 187 }; 188 std::map<StringPiece, CollectionInfo> registry_ GUARDED_BY(mu_); 189 190 TF_DISALLOW_COPY_AND_ASSIGN(CollectionRegistry); 191 }; 192 193 //// 194 // Implementation details follow. API readers may skip. 
195 //// 196 197 class CollectionRegistry::RegistrationHandle { 198 public: 199 RegistrationHandle(CollectionRegistry* const export_registry, 200 const AbstractMetricDef* const metric_def) 201 : export_registry_(export_registry), metric_def_(metric_def) {} 202 203 ~RegistrationHandle() { export_registry_->Unregister(metric_def_); } 204 205 private: 206 CollectionRegistry* const export_registry_; 207 const AbstractMetricDef* const metric_def_; 208 }; 209 210 namespace internal { 211 212 template <typename Value> 213 void CollectValue(const Value& value, Point* point); 214 215 template <> 216 inline void CollectValue(const int64& value, Point* const point) { 217 point->value_type = ValueType::kInt64; 218 point->int64_value = value; 219 } 220 221 template <> 222 inline void CollectValue(const string& value, Point* const point) { 223 point->value_type = ValueType::kString; 224 point->string_value = value; 225 } 226 227 template <> 228 inline void CollectValue(const bool& value, Point* const point) { 229 point->value_type = ValueType::kBool; 230 point->bool_value = value; 231 } 232 233 template <> 234 inline void CollectValue(const HistogramProto& value, Point* const point) { 235 point->value_type = ValueType::kHistogram; 236 // This is inefficient. If and when we hit snags, we can change the API to do 237 // this more efficiently. 238 point->histogram_value = value; 239 } 240 241 // Used by the CollectionRegistry class to collect all the values of all the 242 // metrics in the registry. This is an implementation detail of the 243 // CollectionRegistry class, please do not depend on this. 244 // 245 // This cannot be a private nested class because we need to forward declare this 246 // so that the MetricCollector and MetricCollectorGetter classes can be friends 247 // with it. 248 // 249 // This class is thread-safe. 
250 class Collector { 251 public: 252 Collector(const uint64 collection_time_millis) 253 : collected_metrics_(new CollectedMetrics()), 254 collection_time_millis_(collection_time_millis) {} 255 256 template <MetricKind metric_kind, typename Value, int NumLabels> 257 MetricCollector<metric_kind, Value, NumLabels> GetMetricCollector( 258 const MetricDef<metric_kind, Value, NumLabels>* const metric_def, 259 const uint64 registration_time_millis, 260 internal::Collector* const collector) LOCKS_EXCLUDED(mu_) { 261 auto* const point_set = [&]() { 262 mutex_lock l(mu_); 263 return collected_metrics_->point_set_map 264 .insert(std::make_pair(metric_def->name().ToString(), 265 std::unique_ptr<PointSet>(new PointSet()))) 266 .first->second.get(); 267 }(); 268 return MetricCollector<metric_kind, Value, NumLabels>( 269 metric_def, registration_time_millis, collector, point_set); 270 } 271 272 uint64 collection_time_millis() const { return collection_time_millis_; } 273 274 void CollectMetricDescriptor(const AbstractMetricDef* const metric_def) 275 LOCKS_EXCLUDED(mu_); 276 277 void CollectMetricValues( 278 const CollectionRegistry::CollectionInfo& collection_info); 279 280 std::unique_ptr<CollectedMetrics> ConsumeCollectedMetrics() 281 LOCKS_EXCLUDED(mu_); 282 283 private: 284 mutable mutex mu_; 285 std::unique_ptr<CollectedMetrics> collected_metrics_ GUARDED_BY(mu_); 286 const uint64 collection_time_millis_; 287 288 TF_DISALLOW_COPY_AND_ASSIGN(Collector); 289 }; 290 291 // Write the timestamps for the point based on the MetricKind. 292 // 293 // Gauge metrics will have start and end timestamps set to the collection time. 294 // 295 // Cumulative metrics will have the start timestamp set to the time when the 296 // collection function was registered, while the end timestamp will be set to 297 // the collection time. 
298 template <MetricKind kind> 299 void WriteTimestamps(const uint64 registration_time_millis, 300 const uint64 collection_time_millis, Point* const point); 301 302 template <> 303 inline void WriteTimestamps<MetricKind::kGauge>( 304 const uint64 registration_time_millis, const uint64 collection_time_millis, 305 Point* const point) { 306 point->start_timestamp_millis = collection_time_millis; 307 point->end_timestamp_millis = collection_time_millis; 308 } 309 310 template <> 311 inline void WriteTimestamps<MetricKind::kCumulative>( 312 const uint64 registration_time_millis, const uint64 collection_time_millis, 313 Point* const point) { 314 point->start_timestamp_millis = registration_time_millis; 315 // There's a chance that the clock goes backwards on the same machine, so we 316 // protect ourselves against that. 317 point->end_timestamp_millis = 318 registration_time_millis < collection_time_millis 319 ? collection_time_millis 320 : registration_time_millis; 321 } 322 323 } // namespace internal 324 325 template <MetricKind metric_kind, typename Value, int NumLabels> 326 void MetricCollector<metric_kind, Value, NumLabels>::CollectValue( 327 const std::array<string, NumLabels>& labels, const Value& value) { 328 point_set_->points.emplace_back(new Point()); 329 auto* const point = point_set_->points.back().get(); 330 const std::vector<string> label_descriptions = 331 metric_def_->label_descriptions(); 332 point->labels.reserve(NumLabels); 333 for (int i = 0; i < NumLabels; ++i) { 334 point->labels.push_back({}); 335 auto* const label = &point->labels.back(); 336 label->name = label_descriptions[i]; 337 label->value = labels[i]; 338 } 339 internal::CollectValue(value, point); 340 internal::WriteTimestamps<metric_kind>( 341 registration_time_millis_, collector_->collection_time_millis(), point); 342 } 343 344 template <MetricKind metric_kind, typename Value, int NumLabels> 345 MetricCollector<metric_kind, Value, NumLabels> MetricCollectorGetter::Get( 346 const 
MetricDef<metric_kind, Value, NumLabels>* const metric_def) { 347 if (allowed_metric_def_ != metric_def) { 348 LOG(FATAL) << "Expected collection for: " << allowed_metric_def_->name() 349 << " but instead got: " << metric_def->name(); 350 } 351 352 return collector_->GetMetricCollector(metric_def, registration_time_millis_, 353 collector_); 354 } 355 356 } // namespace monitoring 357 } // namespace tensorflow 358 359 #endif // TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_ 360