1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/summary/summary_db_writer.h" 16 17 #include <deque> 18 19 #include "tensorflow/core/summary/summary_converter.h" 20 #include "tensorflow/core/framework/graph.pb.h" 21 #include "tensorflow/core/framework/node_def.pb.h" 22 #include "tensorflow/core/framework/register_types.h" 23 #include "tensorflow/core/framework/summary.pb.h" 24 #include "tensorflow/core/lib/core/stringpiece.h" 25 #include "tensorflow/core/lib/db/sqlite.h" 26 #include "tensorflow/core/lib/random/random.h" 27 #include "tensorflow/core/util/event.pb.h" 28 29 // TODO(jart): Break this up into multiple files with excellent unit tests. 30 // TODO(jart): Make decision to write in separate op. 31 // TODO(jart): Add really good busy handling. 32 33 // clang-format off 34 #define CALL_SUPPORTED_TYPES(m) \ 35 TF_CALL_string(m) \ 36 TF_CALL_half(m) \ 37 TF_CALL_float(m) \ 38 TF_CALL_double(m) \ 39 TF_CALL_complex64(m) \ 40 TF_CALL_complex128(m) \ 41 TF_CALL_int8(m) \ 42 TF_CALL_int16(m) \ 43 TF_CALL_int32(m) \ 44 TF_CALL_int64(m) \ 45 TF_CALL_uint8(m) \ 46 TF_CALL_uint16(m) \ 47 TF_CALL_uint32(m) \ 48 TF_CALL_uint64(m) 49 // clang-format on 50 51 namespace tensorflow { 52 namespace { 53 54 // https://www.sqlite.org/fileformat.html#record_format 55 const uint64 kIdTiers[] = { 56 0x7fffffULL, // 23-bit (3 bytes on disk) 57 0x7fffffffULL, // 31-bit (4 bytes on disk) 58 0x7fffffffffffULL, // 47-bit (5 bytes on disk) 59 // remaining bits for future use 60 }; 61 const int kMaxIdTier = sizeof(kIdTiers) / sizeof(uint64); 62 const int kIdCollisionDelayMicros = 10; 63 const int kMaxIdCollisions = 21; // sum(2**i*10s for i in range(21))~=21s 64 const int64 kAbsent = 0LL; 65 66 const char* kScalarPluginName = "scalars"; 67 const char* kImagePluginName = "images"; 68 const char* kAudioPluginName = "audio"; 69 const char* kHistogramPluginName = "histograms"; 70 71 const int64 kReserveMinBytes = 32; 72 const double kReserveMultiplier = 1.5; 73 const int64 kPreallocateRows = 1000; 74 75 // Flush is a misnomer because what we're actually doing is having lots 76 // of commits inside any SqliteTransaction that writes potentially 77 // hundreds of megs but doesn't need the transaction to maintain its 78 // invariants. This ensures the WAL read penalty is small and might 79 // allow writers in other processes a chance to schedule. 80 const uint64 kFlushBytes = 1024 * 1024; 81 82 double DoubleTime(uint64 micros) { 83 // TODO(@jart): Follow precise definitions for time laid out in schema. 84 // TODO(@jart): Use monotonic clock from gRPC codebase. 85 return static_cast<double>(micros) / 1.0e6; 86 } 87 88 string StringifyShape(const TensorShape& shape) { 89 string result; 90 bool first = true; 91 for (const auto& dim : shape) { 92 if (first) { 93 first = false; 94 } else { 95 strings::StrAppend(&result, ","); 96 } 97 strings::StrAppend(&result, dim.size); 98 } 99 return result; 100 } 101 102 Status CheckSupportedType(const Tensor& t) { 103 #define CASE(T) \ 104 case DataTypeToEnum<T>::value: \ 105 break; 106 switch (t.dtype()) { 107 CALL_SUPPORTED_TYPES(CASE) 108 default: 109 return errors::Unimplemented(DataTypeString(t.dtype()), 110 " tensors unsupported on platform"); 111 } 112 return Status::OK(); 113 #undef CASE 114 } 115 116 Tensor AsScalar(const Tensor& t) { 117 Tensor t2{t.dtype(), {}}; 118 #define CASE(T) \ 119 case DataTypeToEnum<T>::value: \ 120 t2.scalar<T>()() = t.flat<T>()(0); \ 121 break; 122 switch (t.dtype()) { 123 CALL_SUPPORTED_TYPES(CASE) 124 default: 125 t2 = {DT_FLOAT, {}}; 126 t2.scalar<float>()() = NAN; 127 break; 128 } 129 return t2; 130 #undef CASE 131 } 132 133 void PatchPluginName(SummaryMetadata* metadata, const char* name) { 134 if (metadata->plugin_data().plugin_name().empty()) { 135 metadata->mutable_plugin_data()->set_plugin_name(name); 136 } 137 } 138 139 Status SetDescription(Sqlite* db, int64 id, const StringPiece& markdown) { 140 const char* sql = R"sql( 141 INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?) 142 )sql"; 143 SqliteStatement insert_desc; 144 TF_RETURN_IF_ERROR(db->Prepare(sql, &insert_desc)); 145 insert_desc.BindInt(1, id); 146 insert_desc.BindText(2, markdown); 147 return insert_desc.StepAndReset(); 148 } 149 150 /// \brief Generates unique IDs randomly in the [1,2**63-1] range. 151 /// 152 /// This class starts off generating IDs in the [1,2**23-1] range, 153 /// because it's human friendly and occupies 4 bytes max on disk with 154 /// SQLite's zigzag varint encoding. Then, each time a collision 155 /// happens, the random space is increased by 8 bits. 156 /// 157 /// This class uses exponential back-off so writes gradually slow down 158 /// as IDs become exhausted but reads are still possible. 159 /// 160 /// This class is thread safe. 161 class IdAllocator { 162 public: 163 IdAllocator(Env* env, Sqlite* db) : env_{env}, db_{db} { 164 DCHECK(env_ != nullptr); 165 DCHECK(db_ != nullptr); 166 } 167 168 Status CreateNewId(int64* id) LOCKS_EXCLUDED(mu_) { 169 mutex_lock lock(mu_); 170 Status s; 171 SqliteStatement stmt; 172 TF_RETURN_IF_ERROR(db_->Prepare("INSERT INTO Ids (id) VALUES (?)", &stmt)); 173 for (int i = 0; i < kMaxIdCollisions; ++i) { 174 int64 tid = MakeRandomId(); 175 stmt.BindInt(1, tid); 176 s = stmt.StepAndReset(); 177 if (s.ok()) { 178 *id = tid; 179 break; 180 } 181 // SQLITE_CONSTRAINT maps to INVALID_ARGUMENT in sqlite.cc 182 if (s.code() != error::INVALID_ARGUMENT) break; 183 if (tier_ < kMaxIdTier) { 184 LOG(INFO) << "IdAllocator collision at tier " << tier_ << " (of " 185 << kMaxIdTier << ") so auto-adjusting to a higher tier"; 186 ++tier_; 187 } else { 188 LOG(WARNING) << "IdAllocator (attempt #" << i << ") " 189 << "resulted in a collision at the highest tier; this " 190 "is problematic if it happens often; you can try " 191 "pruning the Ids table; you can also file a bug " 192 "asking for the ID space to be increased; otherwise " 193 "writes will gradually slow down over time until they " 194 "become impossible"; 195 } 196 env_->SleepForMicroseconds((1 << i) * kIdCollisionDelayMicros); 197 } 198 return s; 199 } 200 201 private: 202 int64 MakeRandomId() EXCLUSIVE_LOCKS_REQUIRED(mu_) { 203 int64 id = static_cast<int64>(random::New64() & kIdTiers[tier_]); 204 if (id == kAbsent) ++id; 205 return id; 206 } 207 208 mutex mu_; 209 Env* const env_; 210 Sqlite* const db_; 211 int tier_ GUARDED_BY(mu_) = 0; 212 213 TF_DISALLOW_COPY_AND_ASSIGN(IdAllocator); 214 }; 215 216 class GraphWriter { 217 public: 218 static Status Save(Sqlite* db, SqliteTransaction* txn, IdAllocator* ids, 219 GraphDef* graph, uint64 now, int64 run_id, int64* graph_id) 220 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) { 221 TF_RETURN_IF_ERROR(ids->CreateNewId(graph_id)); 222 GraphWriter saver{db, txn, graph, now, *graph_id}; 223 saver.MapNameToNodeId(); 224 TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodeInputs(), "SaveNodeInputs"); 225 TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodes(), "SaveNodes"); 226 TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveGraph(run_id), "SaveGraph"); 227 return Status::OK(); 228 } 229 230 private: 231 GraphWriter(Sqlite* db, SqliteTransaction* txn, GraphDef* graph, uint64 now, 232 int64 graph_id) 233 : db_(db), txn_(txn), graph_(graph), now_(now), graph_id_(graph_id) {} 234 235 void MapNameToNodeId() { 236 size_t toto = static_cast<size_t>(graph_->node_size()); 237 name_copies_.reserve(toto); 238 name_to_node_id_.reserve(toto); 239 for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { 240 // Copy name into memory region, since we call clear_name() later. 241 // Then wrap in StringPiece so we can compare slices without copy. 242 name_copies_.emplace_back(graph_->node(node_id).name()); 243 name_to_node_id_.emplace(name_copies_.back(), node_id); 244 } 245 } 246 247 Status SaveNodeInputs() { 248 const char* sql = R"sql( 249 INSERT INTO NodeInputs ( 250 graph_id, 251 node_id, 252 idx, 253 input_node_id, 254 input_node_idx, 255 is_control 256 ) VALUES (?, ?, ?, ?, ?, ?) 257 )sql"; 258 SqliteStatement insert; 259 TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert)); 260 for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { 261 const NodeDef& node = graph_->node(node_id); 262 for (int idx = 0; idx < node.input_size(); ++idx) { 263 StringPiece name = node.input(idx); 264 int64 input_node_id; 265 int64 input_node_idx = 0; 266 int64 is_control = 0; 267 size_t i = name.rfind(':'); 268 if (i != StringPiece::npos) { 269 if (!strings::safe_strto64(name.substr(i + 1, name.size() - i - 1), 270 &input_node_idx)) { 271 return errors::DataLoss("Bad NodeDef.input: ", name); 272 } 273 name.remove_suffix(name.size() - i); 274 } 275 if (!name.empty() && name[0] == '^') { 276 name.remove_prefix(1); 277 is_control = 1; 278 } 279 auto e = name_to_node_id_.find(name); 280 if (e == name_to_node_id_.end()) { 281 return errors::DataLoss("Could not find node: ", name); 282 } 283 input_node_id = e->second; 284 insert.BindInt(1, graph_id_); 285 insert.BindInt(2, node_id); 286 insert.BindInt(3, idx); 287 insert.BindInt(4, input_node_id); 288 insert.BindInt(5, input_node_idx); 289 insert.BindInt(6, is_control); 290 unflushed_bytes_ += insert.size(); 291 TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(), 292 " -> ", name); 293 TF_RETURN_IF_ERROR(MaybeFlush()); 294 } 295 } 296 return Status::OK(); 297 } 298 299 Status SaveNodes() { 300 const char* sql = R"sql( 301 INSERT INTO Nodes ( 302 graph_id, 303 node_id, 304 node_name, 305 op, 306 device, 307 node_def) 308 VALUES (?, ?, ?, ?, ?, ?) 309 )sql"; 310 SqliteStatement insert; 311 TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert)); 312 for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { 313 NodeDef* node = graph_->mutable_node(node_id); 314 insert.BindInt(1, graph_id_); 315 insert.BindInt(2, node_id); 316 insert.BindText(3, node->name()); 317 insert.BindText(4, node->op()); 318 insert.BindText(5, node->device()); 319 node->clear_name(); 320 node->clear_op(); 321 node->clear_device(); 322 node->clear_input(); 323 string node_def; 324 if (node->SerializeToString(&node_def)) { 325 insert.BindBlobUnsafe(6, node_def); 326 } 327 unflushed_bytes_ += insert.size(); 328 TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name()); 329 TF_RETURN_IF_ERROR(MaybeFlush()); 330 } 331 return Status::OK(); 332 } 333 334 Status SaveGraph(int64 run_id) { 335 const char* sql = R"sql( 336 INSERT OR REPLACE INTO Graphs ( 337 run_id, 338 graph_id, 339 inserted_time, 340 graph_def 341 ) VALUES (?, ?, ?, ?) 342 )sql"; 343 SqliteStatement insert; 344 TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert)); 345 if (run_id != kAbsent) insert.BindInt(1, run_id); 346 insert.BindInt(2, graph_id_); 347 insert.BindDouble(3, DoubleTime(now_)); 348 graph_->clear_node(); 349 string graph_def; 350 if (graph_->SerializeToString(&graph_def)) { 351 insert.BindBlobUnsafe(4, graph_def); 352 } 353 return insert.StepAndReset(); 354 } 355 356 Status MaybeFlush() { 357 if (unflushed_bytes_ >= kFlushBytes) { 358 TF_RETURN_WITH_CONTEXT_IF_ERROR(txn_->Commit(), "flushing ", 359 unflushed_bytes_, " bytes"); 360 unflushed_bytes_ = 0; 361 } 362 return Status::OK(); 363 } 364 365 Sqlite* const db_; 366 SqliteTransaction* const txn_; 367 uint64 unflushed_bytes_ = 0; 368 GraphDef* const graph_; 369 const uint64 now_; 370 const int64 graph_id_; 371 std::vector<string> name_copies_; 372 std::unordered_map<StringPiece, int64, StringPieceHasher> name_to_node_id_; 373 374 TF_DISALLOW_COPY_AND_ASSIGN(GraphWriter); 375 }; 376 377 /// \brief Run metadata manager. 378 /// 379 /// This class gives us Tag IDs we can pass to SeriesWriter. In order 380 /// to do that, rows are created in the Ids, Tags, Runs, Experiments, 381 /// and Users tables. 382 /// 383 /// This class is thread safe. 384 class RunMetadata { 385 public: 386 RunMetadata(IdAllocator* ids, const string& experiment_name, 387 const string& run_name, const string& user_name) 388 : ids_{ids}, 389 experiment_name_{experiment_name}, 390 run_name_{run_name}, 391 user_name_{user_name} { 392 DCHECK(ids_ != nullptr); 393 } 394 395 const string& experiment_name() { return experiment_name_; } 396 const string& run_name() { return run_name_; } 397 const string& user_name() { return user_name_; } 398 399 int64 run_id() LOCKS_EXCLUDED(mu_) { 400 mutex_lock lock(mu_); 401 return run_id_; 402 } 403 404 Status SetGraph(Sqlite* db, uint64 now, double computed_time, 405 std::unique_ptr<GraphDef> g) SQLITE_TRANSACTIONS_EXCLUDED(*db) 406 LOCKS_EXCLUDED(mu_) { 407 int64 run_id; 408 { 409 mutex_lock lock(mu_); 410 TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time)); 411 run_id = run_id_; 412 } 413 int64 graph_id; 414 SqliteTransaction txn(*db); // only to increase performance 415 TF_RETURN_IF_ERROR( 416 GraphWriter::Save(db, &txn, ids_, g.get(), now, run_id, &graph_id)); 417 return txn.Commit(); 418 } 419 420 Status GetTagId(Sqlite* db, uint64 now, double computed_time, 421 const string& tag_name, int64* tag_id, 422 const SummaryMetadata& metadata) LOCKS_EXCLUDED(mu_) { 423 mutex_lock lock(mu_); 424 TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time)); 425 auto e = tag_ids_.find(tag_name); 426 if (e != tag_ids_.end()) { 427 *tag_id = e->second; 428 return Status::OK(); 429 } 430 TF_RETURN_IF_ERROR(ids_->CreateNewId(tag_id)); 431 tag_ids_[tag_name] = *tag_id; 432 TF_RETURN_IF_ERROR( 433 SetDescription(db, *tag_id, metadata.summary_description())); 434 const char* sql = R"sql( 435 INSERT INTO Tags ( 436 run_id, 437 tag_id, 438 tag_name, 439 inserted_time, 440 display_name, 441 plugin_name, 442 plugin_data 443 ) VALUES ( 444 :run_id, 445 :tag_id, 446 :tag_name, 447 :inserted_time, 448 :display_name, 449 :plugin_name, 450 :plugin_data 451 ) 452 )sql"; 453 SqliteStatement insert; 454 TF_RETURN_IF_ERROR(db->Prepare(sql, &insert)); 455 if (run_id_ != kAbsent) insert.BindInt(":run_id", run_id_); 456 insert.BindInt(":tag_id", *tag_id); 457 insert.BindTextUnsafe(":tag_name", tag_name); 458 insert.BindDouble(":inserted_time", DoubleTime(now)); 459 insert.BindTextUnsafe(":display_name", metadata.display_name()); 460 insert.BindTextUnsafe(":plugin_name", metadata.plugin_data().plugin_name()); 461 insert.BindBlobUnsafe(":plugin_data", metadata.plugin_data().content()); 462 return insert.StepAndReset(); 463 } 464 465 private: 466 Status InitializeUser(Sqlite* db, uint64 now) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 467 if (user_id_ != kAbsent || user_name_.empty()) return Status::OK(); 468 const char* get_sql = R"sql( 469 SELECT user_id FROM Users WHERE user_name = ? 470 )sql"; 471 SqliteStatement get; 472 TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get)); 473 get.BindText(1, user_name_); 474 bool is_done; 475 TF_RETURN_IF_ERROR(get.Step(&is_done)); 476 if (!is_done) { 477 user_id_ = get.ColumnInt(0); 478 return Status::OK(); 479 } 480 TF_RETURN_IF_ERROR(ids_->CreateNewId(&user_id_)); 481 const char* insert_sql = R"sql( 482 INSERT INTO Users ( 483 user_id, 484 user_name, 485 inserted_time 486 ) VALUES (?, ?, ?) 487 )sql"; 488 SqliteStatement insert; 489 TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert)); 490 insert.BindInt(1, user_id_); 491 insert.BindText(2, user_name_); 492 insert.BindDouble(3, DoubleTime(now)); 493 TF_RETURN_IF_ERROR(insert.StepAndReset()); 494 return Status::OK(); 495 } 496 497 Status InitializeExperiment(Sqlite* db, uint64 now, double computed_time) 498 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 499 if (experiment_name_.empty()) return Status::OK(); 500 if (experiment_id_ == kAbsent) { 501 TF_RETURN_IF_ERROR(InitializeUser(db, now)); 502 const char* get_sql = R"sql( 503 SELECT 504 experiment_id, 505 started_time 506 FROM 507 Experiments 508 WHERE 509 user_id IS ? 510 AND experiment_name = ? 511 )sql"; 512 SqliteStatement get; 513 TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get)); 514 if (user_id_ != kAbsent) get.BindInt(1, user_id_); 515 get.BindText(2, experiment_name_); 516 bool is_done; 517 TF_RETURN_IF_ERROR(get.Step(&is_done)); 518 if (!is_done) { 519 experiment_id_ = get.ColumnInt(0); 520 experiment_started_time_ = get.ColumnInt(1); 521 } else { 522 TF_RETURN_IF_ERROR(ids_->CreateNewId(&experiment_id_)); 523 experiment_started_time_ = computed_time; 524 const char* insert_sql = R"sql( 525 INSERT INTO Experiments ( 526 user_id, 527 experiment_id, 528 experiment_name, 529 inserted_time, 530 started_time, 531 is_watching 532 ) VALUES (?, ?, ?, ?, ?, ?) 533 )sql"; 534 SqliteStatement insert; 535 TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert)); 536 if (user_id_ != kAbsent) insert.BindInt(1, user_id_); 537 insert.BindInt(2, experiment_id_); 538 insert.BindText(3, experiment_name_); 539 insert.BindDouble(4, DoubleTime(now)); 540 insert.BindDouble(5, computed_time); 541 insert.BindInt(6, 0); 542 TF_RETURN_IF_ERROR(insert.StepAndReset()); 543 } 544 } 545 if (computed_time < experiment_started_time_) { 546 experiment_started_time_ = computed_time; 547 const char* update_sql = R"sql( 548 UPDATE 549 Experiments 550 SET 551 started_time = ? 552 WHERE 553 experiment_id = ? 554 )sql"; 555 SqliteStatement update; 556 TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update)); 557 update.BindDouble(1, computed_time); 558 update.BindInt(2, experiment_id_); 559 TF_RETURN_IF_ERROR(update.StepAndReset()); 560 } 561 return Status::OK(); 562 } 563 564 Status InitializeRun(Sqlite* db, uint64 now, double computed_time) 565 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 566 if (run_name_.empty()) return Status::OK(); 567 TF_RETURN_IF_ERROR(InitializeExperiment(db, now, computed_time)); 568 if (run_id_ == kAbsent) { 569 TF_RETURN_IF_ERROR(ids_->CreateNewId(&run_id_)); 570 run_started_time_ = computed_time; 571 const char* insert_sql = R"sql( 572 INSERT OR REPLACE INTO Runs ( 573 experiment_id, 574 run_id, 575 run_name, 576 inserted_time, 577 started_time 578 ) VALUES (?, ?, ?, ?, ?) 579 )sql"; 580 SqliteStatement insert; 581 TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert)); 582 if (experiment_id_ != kAbsent) insert.BindInt(1, experiment_id_); 583 insert.BindInt(2, run_id_); 584 insert.BindText(3, run_name_); 585 insert.BindDouble(4, DoubleTime(now)); 586 insert.BindDouble(5, computed_time); 587 TF_RETURN_IF_ERROR(insert.StepAndReset()); 588 } 589 if (computed_time < run_started_time_) { 590 run_started_time_ = computed_time; 591 const char* update_sql = R"sql( 592 UPDATE 593 Runs 594 SET 595 started_time = ? 596 WHERE 597 run_id = ? 598 )sql"; 599 SqliteStatement update; 600 TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update)); 601 update.BindDouble(1, computed_time); 602 update.BindInt(2, run_id_); 603 TF_RETURN_IF_ERROR(update.StepAndReset()); 604 } 605 return Status::OK(); 606 } 607 608 mutex mu_; 609 IdAllocator* const ids_; 610 const string experiment_name_; 611 const string run_name_; 612 const string user_name_; 613 int64 experiment_id_ GUARDED_BY(mu_) = kAbsent; 614 int64 run_id_ GUARDED_BY(mu_) = kAbsent; 615 int64 user_id_ GUARDED_BY(mu_) = kAbsent; 616 double experiment_started_time_ GUARDED_BY(mu_) = 0.0; 617 double run_started_time_ GUARDED_BY(mu_) = 0.0; 618 std::unordered_map<string, int64> tag_ids_ GUARDED_BY(mu_); 619 620 TF_DISALLOW_COPY_AND_ASSIGN(RunMetadata); 621 }; 622 623 /// \brief Tensor writer for a single series, e.g. Tag. 624 /// 625 /// This class is thread safe. 626 class SeriesWriter { 627 public: 628 SeriesWriter(int64 series, RunMetadata* meta) : series_{series}, meta_{meta} { 629 DCHECK(series_ > 0); 630 } 631 632 Status Append(Sqlite* db, int64 step, uint64 now, double computed_time, 633 const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) 634 LOCKS_EXCLUDED(mu_) { 635 mutex_lock lock(mu_); 636 if (rowids_.empty()) { 637 Status s = Reserve(db, t); 638 if (!s.ok()) { 639 rowids_.clear(); 640 return s; 641 } 642 } 643 int64 rowid = rowids_.front(); 644 Status s = Write(db, rowid, step, computed_time, t); 645 if (s.ok()) { 646 ++count_; 647 } 648 rowids_.pop_front(); 649 return s; 650 } 651 652 Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) 653 LOCKS_EXCLUDED(mu_) { 654 mutex_lock lock(mu_); 655 // Delete unused pre-allocated Tensors. 656 if (!rowids_.empty()) { 657 SqliteTransaction txn(*db); 658 const char* sql = R"sql( 659 DELETE FROM Tensors WHERE rowid = ? 660 )sql"; 661 SqliteStatement deleter; 662 TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter)); 663 for (size_t i = count_; i < rowids_.size(); ++i) { 664 deleter.BindInt(1, rowids_.front()); 665 TF_RETURN_IF_ERROR(deleter.StepAndReset()); 666 rowids_.pop_front(); 667 } 668 TF_RETURN_IF_ERROR(txn.Commit()); 669 rowids_.clear(); 670 } 671 return Status::OK(); 672 } 673 674 private: 675 Status Write(Sqlite* db, int64 rowid, int64 step, double computed_time, 676 const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) { 677 if (t.dtype() == DT_STRING) { 678 if (t.dims() == 0) { 679 return Update(db, step, computed_time, t, t.scalar<string>()(), rowid); 680 } else { 681 SqliteTransaction txn(*db); 682 TF_RETURN_IF_ERROR( 683 Update(db, step, computed_time, t, StringPiece(), rowid)); 684 TF_RETURN_IF_ERROR(UpdateNdString(db, t, rowid)); 685 return txn.Commit(); 686 } 687 } else { 688 return Update(db, step, computed_time, t, t.tensor_data(), rowid); 689 } 690 } 691 692 Status Update(Sqlite* db, int64 step, double computed_time, const Tensor& t, 693 const StringPiece& data, int64 rowid) { 694 const char* sql = R"sql( 695 UPDATE OR REPLACE 696 Tensors 697 SET 698 step = ?, 699 computed_time = ?, 700 dtype = ?, 701 shape = ?, 702 data = ? 703 WHERE 704 rowid = ? 705 )sql"; 706 SqliteStatement stmt; 707 TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt)); 708 stmt.BindInt(1, step); 709 stmt.BindDouble(2, computed_time); 710 stmt.BindInt(3, t.dtype()); 711 stmt.BindText(4, StringifyShape(t.shape())); 712 stmt.BindBlobUnsafe(5, data); 713 stmt.BindInt(6, rowid); 714 TF_RETURN_IF_ERROR(stmt.StepAndReset()); 715 return Status::OK(); 716 } 717 718 Status UpdateNdString(Sqlite* db, const Tensor& t, int64 tensor_rowid) 719 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) { 720 DCHECK_EQ(t.dtype(), DT_STRING); 721 DCHECK_GT(t.dims(), 0); 722 const char* deleter_sql = R"sql( 723 DELETE FROM TensorStrings WHERE tensor_rowid = ? 724 )sql"; 725 SqliteStatement deleter; 726 TF_RETURN_IF_ERROR(db->Prepare(deleter_sql, &deleter)); 727 deleter.BindInt(1, tensor_rowid); 728 TF_RETURN_WITH_CONTEXT_IF_ERROR(deleter.StepAndReset(), tensor_rowid); 729 const char* inserter_sql = R"sql( 730 INSERT INTO TensorStrings ( 731 tensor_rowid, 732 idx, 733 data 734 ) VALUES (?, ?, ?) 735 )sql"; 736 SqliteStatement inserter; 737 TF_RETURN_IF_ERROR(db->Prepare(inserter_sql, &inserter)); 738 auto flat = t.flat<string>(); 739 for (int64 i = 0; i < flat.size(); ++i) { 740 inserter.BindInt(1, tensor_rowid); 741 inserter.BindInt(2, i); 742 inserter.BindBlobUnsafe(3, flat(i)); 743 TF_RETURN_WITH_CONTEXT_IF_ERROR(inserter.StepAndReset(), "i=", i); 744 } 745 return Status::OK(); 746 } 747 748 Status Reserve(Sqlite* db, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) 749 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 750 SqliteTransaction txn(*db); // only for performance 751 unflushed_bytes_ = 0; 752 if (t.dtype() == DT_STRING) { 753 if (t.dims() == 0) { 754 TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.scalar<string>()().size())); 755 } else { 756 TF_RETURN_IF_ERROR(ReserveTensors(db, &txn, kReserveMinBytes)); 757 } 758 } else { 759 TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.tensor_data().size())); 760 } 761 return txn.Commit(); 762 } 763 764 Status ReserveData(Sqlite* db, SqliteTransaction* txn, size_t size) 765 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) 766 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 767 int64 space = 768 static_cast<int64>(static_cast<double>(size) * kReserveMultiplier); 769 if (space < kReserveMinBytes) space = kReserveMinBytes; 770 return ReserveTensors(db, txn, space); 771 } 772 773 Status ReserveTensors(Sqlite* db, SqliteTransaction* txn, 774 int64 reserved_bytes) 775 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) 776 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 777 const char* sql = R"sql( 778 INSERT INTO Tensors ( 779 series, 780 data 781 ) VALUES (?, ZEROBLOB(?)) 782 )sql"; 783 SqliteStatement insert; 784 TF_RETURN_IF_ERROR(db->Prepare(sql, &insert)); 785 // TODO(jart): Maybe preallocate index pages by setting step. This 786 // is tricky because UPDATE OR REPLACE can have a side 787 // effect of deleting preallocated rows. 788 for (int64 i = 0; i < kPreallocateRows; ++i) { 789 insert.BindInt(1, series_); 790 insert.BindInt(2, reserved_bytes); 791 TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i); 792 rowids_.push_back(db->last_insert_rowid()); 793 unflushed_bytes_ += reserved_bytes; 794 TF_RETURN_IF_ERROR(MaybeFlush(db, txn)); 795 } 796 return Status::OK(); 797 } 798 799 Status MaybeFlush(Sqlite* db, SqliteTransaction* txn) 800 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) 801 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 802 if (unflushed_bytes_ >= kFlushBytes) { 803 TF_RETURN_WITH_CONTEXT_IF_ERROR(txn->Commit(), "flushing ", 804 unflushed_bytes_, " bytes"); 805 unflushed_bytes_ = 0; 806 } 807 return Status::OK(); 808 } 809 810 mutex mu_; 811 const int64 series_; 812 RunMetadata* const meta_; 813 uint64 count_ GUARDED_BY(mu_) = 0; 814 std::deque<int64> rowids_ GUARDED_BY(mu_); 815 uint64 unflushed_bytes_ GUARDED_BY(mu_) = 0; 816 817 TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter); 818 }; 819 820 /// \brief Tensor writer for a single Run. 821 /// 822 /// This class farms out tensors to SeriesWriter instances. It also 823 /// keeps track of whether or not someone is watching the TensorBoard 824 /// GUI, so it can avoid writes when possible. 825 /// 826 /// This class is thread safe. 827 class RunWriter { 828 public: 829 explicit RunWriter(RunMetadata* meta) : meta_{meta} {} 830 831 Status Append(Sqlite* db, int64 tag_id, int64 step, uint64 now, 832 double computed_time, const Tensor& t) 833 SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { 834 SeriesWriter* writer = GetSeriesWriter(tag_id); 835 return writer->Append(db, step, now, computed_time, t); 836 } 837 838 Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) 839 LOCKS_EXCLUDED(mu_) { 840 mutex_lock lock(mu_); 841 if (series_writers_.empty()) return Status::OK(); 842 for (auto i = series_writers_.begin(); i != series_writers_.end(); ++i) { 843 if (!i->second) continue; 844 TF_RETURN_WITH_CONTEXT_IF_ERROR(i->second->Finish(db), 845 "finish tag_id=", i->first); 846 i->second.reset(); 847 } 848 return Status::OK(); 849 } 850 851 private: 852 SeriesWriter* GetSeriesWriter(int64 tag_id) LOCKS_EXCLUDED(mu_) { 853 mutex_lock sl(mu_); 854 auto spot = series_writers_.find(tag_id); 855 if (spot == series_writers_.end()) { 856 SeriesWriter* writer = new SeriesWriter(tag_id, meta_); 857 series_writers_[tag_id].reset(writer); 858 return writer; 859 } else { 860 return spot->second.get(); 861 } 862 } 863 864 mutex mu_; 865 RunMetadata* const meta_; 866 std::unordered_map<int64, std::unique_ptr<SeriesWriter>> series_writers_ 867 GUARDED_BY(mu_); 868 869 TF_DISALLOW_COPY_AND_ASSIGN(RunWriter); 870 }; 871 872 /// \brief SQLite implementation of SummaryWriterInterface. 873 /// 874 /// This class is thread safe. 875 class SummaryDbWriter : public SummaryWriterInterface { 876 public: 877 SummaryDbWriter(Env* env, Sqlite* db, const string& experiment_name, 878 const string& run_name, const string& user_name) 879 : SummaryWriterInterface(), 880 env_{env}, 881 db_{db}, 882 ids_{env_, db_}, 883 meta_{&ids_, experiment_name, run_name, user_name}, 884 run_{&meta_} { 885 DCHECK(env_ != nullptr); 886 db_->Ref(); 887 } 888 889 ~SummaryDbWriter() override { 890 core::ScopedUnref unref(db_); 891 Status s = run_.Finish(db_); 892 if (!s.ok()) { 893 // TODO(jart): Retry on transient errors here. 894 LOG(ERROR) << s.ToString(); 895 } 896 int64 run_id = meta_.run_id(); 897 if (run_id == kAbsent) return; 898 const char* sql = R"sql( 899 UPDATE Runs SET finished_time = ? WHERE run_id = ? 900 )sql"; 901 SqliteStatement update; 902 s = db_->Prepare(sql, &update); 903 if (s.ok()) { 904 update.BindDouble(1, DoubleTime(env_->NowMicros())); 905 update.BindInt(2, run_id); 906 s = update.StepAndReset(); 907 } 908 if (!s.ok()) { 909 LOG(ERROR) << "Failed to set Runs[" << run_id 910 << "].finish_time: " << s.ToString(); 911 } 912 } 913 914 Status Flush() override { return Status::OK(); } 915 916 Status WriteTensor(int64 global_step, Tensor t, const string& tag, 917 const string& serialized_metadata) override { 918 TF_RETURN_IF_ERROR(CheckSupportedType(t)); 919 SummaryMetadata metadata; 920 if (!metadata.ParseFromString(serialized_metadata)) { 921 return errors::InvalidArgument("Bad serialized_metadata"); 922 } 923 return Write(global_step, t, tag, metadata); 924 } 925 926 Status WriteScalar(int64 global_step, Tensor t, const string& tag) override { 927 TF_RETURN_IF_ERROR(CheckSupportedType(t)); 928 SummaryMetadata metadata; 929 PatchPluginName(&metadata, kScalarPluginName); 930 return Write(global_step, AsScalar(t), tag, metadata); 931 } 932 933 Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override { 934 uint64 now = env_->NowMicros(); 935 return meta_.SetGraph(db_, now, DoubleTime(now), std::move(g)); 936 } 937 938 Status WriteEvent(std::unique_ptr<Event> e) override { 939 return MigrateEvent(std::move(e)); 940 } 941 942 Status WriteHistogram(int64 global_step, Tensor t, 943 const string& tag) override { 944 uint64 now = env_->NowMicros(); 945 std::unique_ptr<Event> e{new Event}; 946 e->set_step(global_step); 947 e->set_wall_time(DoubleTime(now)); 948 TF_RETURN_IF_ERROR( 949 AddTensorAsHistogramToSummary(t, tag, e->mutable_summary())); 950 return MigrateEvent(std::move(e)); 951 } 952 953 Status WriteImage(int64 global_step, Tensor t, const string& tag, 954 int max_images, Tensor bad_color) override { 955 uint64 now = env_->NowMicros(); 956 std::unique_ptr<Event> e{new Event}; 957 e->set_step(global_step); 958 e->set_wall_time(DoubleTime(now)); 959 TF_RETURN_IF_ERROR(AddTensorAsImageToSummary(t, tag, max_images, bad_color, 960 e->mutable_summary())); 961 return MigrateEvent(std::move(e)); 962 } 963 964 Status WriteAudio(int64 global_step, Tensor t, const string& tag, 965 int max_outputs, float sample_rate) override { 966 uint64 now = env_->NowMicros(); 967 std::unique_ptr<Event> e{new Event}; 968 e->set_step(global_step); 969 e->set_wall_time(DoubleTime(now)); 970 TF_RETURN_IF_ERROR(AddTensorAsAudioToSummary( 971 t, tag, max_outputs, sample_rate, e->mutable_summary())); 972 return MigrateEvent(std::move(e)); 973 } 974 975 string DebugString() const override { return "SummaryDbWriter"; } 976 977 private: 978 Status Write(int64 step, const Tensor& t, const string& tag, 979 const SummaryMetadata& metadata) { 980 uint64 now = env_->NowMicros(); 981 double computed_time = DoubleTime(now); 982 int64 tag_id; 983 TF_RETURN_IF_ERROR( 984 meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata)); 985 TF_RETURN_WITH_CONTEXT_IF_ERROR( 986 run_.Append(db_, tag_id, step, now, computed_time, t), 987 meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(), 988 "/", tag, "@", step); 989 return Status::OK(); 990 } 991 992 Status MigrateEvent(std::unique_ptr<Event> e) { 993 switch (e->what_case()) { 994 case Event::WhatCase::kSummary: { 995 uint64 now = env_->NowMicros(); 996 auto summaries = e->mutable_summary(); 997 for (int i = 0; i < summaries->value_size(); ++i) { 998 Summary::Value* value = summaries->mutable_value(i); 999 TF_RETURN_WITH_CONTEXT_IF_ERROR( 1000 MigrateSummary(e.get(), value, now), meta_.user_name(), "/", 1001 meta_.experiment_name(), "/", meta_.run_name(), "/", value->tag(), 1002 "@", e->step()); 1003 } 1004 break; 1005 } 1006 case Event::WhatCase::kGraphDef: 1007 TF_RETURN_WITH_CONTEXT_IF_ERROR( 1008 MigrateGraph(e.get(), e->graph_def()), meta_.user_name(), "/", 1009 meta_.experiment_name(), "/", meta_.run_name(), "/__graph__@", 1010 e->step()); 1011 break; 1012 default: 1013 // TODO(@jart): Handle other stuff. 1014 break; 1015 } 1016 return Status::OK(); 1017 } 1018 1019 Status MigrateGraph(const Event* e, const string& graph_def) { 1020 uint64 now = env_->NowMicros(); 1021 std::unique_ptr<GraphDef> graph{new GraphDef}; 1022 if (!ParseProtoUnlimited(graph.get(), graph_def)) { 1023 return errors::InvalidArgument("bad proto"); 1024 } 1025 return meta_.SetGraph(db_, now, e->wall_time(), std::move(graph)); 1026 } 1027 1028 Status MigrateSummary(const Event* e, Summary::Value* s, uint64 now) { 1029 switch (s->value_case()) { 1030 case Summary::Value::ValueCase::kTensor: 1031 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateTensor(e, s, now), "tensor"); 1032 break; 1033 case Summary::Value::ValueCase::kSimpleValue: 1034 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateScalar(e, s, now), "scalar"); 1035 break; 1036 case Summary::Value::ValueCase::kHisto: 1037 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateHistogram(e, s, now), "histo"); 1038 break; 1039 case Summary::Value::ValueCase::kImage: 1040 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateImage(e, s, now), "image"); 1041 break; 1042 case Summary::Value::ValueCase::kAudio: 1043 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateAudio(e, s, now), "audio"); 1044 break; 1045 default: 1046 break; 1047 } 1048 return Status::OK(); 1049 } 1050 1051 Status MigrateTensor(const Event* e, Summary::Value* s, uint64 now) { 1052 Tensor t; 1053 if (!t.FromProto(s->tensor())) return errors::InvalidArgument("bad proto"); 1054 TF_RETURN_IF_ERROR(CheckSupportedType(t)); 1055 int64 tag_id; 1056 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), 1057 &tag_id, s->metadata())); 1058 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); 1059 } 1060 1061 // TODO(jart): Refactor Summary -> Tensor logic into separate file. 1062 1063 Status MigrateScalar(const Event* e, Summary::Value* s, uint64 now) { 1064 // See tensorboard/plugins/scalar/summary.py and data_compat.py 1065 Tensor t{DT_FLOAT, {}}; 1066 t.scalar<float>()() = s->simple_value(); 1067 int64 tag_id; 1068 PatchPluginName(s->mutable_metadata(), kScalarPluginName); 1069 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), 1070 &tag_id, s->metadata())); 1071 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); 1072 } 1073 1074 Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) { 1075 const HistogramProto& histo = s->histo(); 1076 int k = histo.bucket_size(); 1077 if (k != histo.bucket_limit_size()) { 1078 return errors::InvalidArgument("size mismatch"); 1079 } 1080 // See tensorboard/plugins/histogram/summary.py and data_compat.py 1081 Tensor t{DT_DOUBLE, {k, 3}}; 1082 auto data = t.flat<double>(); 1083 for (int i = 0, j = 0; i < k; ++i) { 1084 // TODO(nickfelt): reconcile with TensorBoard's data_compat.py 1085 // From summary.proto 1086 // Parallel arrays encoding the bucket boundaries and the bucket values. 1087 // bucket(i) is the count for the bucket i. The range for 1088 // a bucket is: 1089 // i == 0: -DBL_MAX .. bucket_limit(0) 1090 // i != 0: bucket_limit(i-1) .. bucket_limit(i) 1091 double left_edge = (i == 0) ? std::numeric_limits<double>::min() 1092 : histo.bucket_limit(i - 1); 1093 1094 data(j++) = left_edge; 1095 data(j++) = histo.bucket_limit(i); 1096 data(j++) = histo.bucket(i); 1097 } 1098 int64 tag_id; 1099 PatchPluginName(s->mutable_metadata(), kHistogramPluginName); 1100 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), 1101 &tag_id, s->metadata())); 1102 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); 1103 } 1104 1105 Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) { 1106 // See tensorboard/plugins/image/summary.py and data_compat.py 1107 Tensor t{DT_STRING, {3}}; 1108 auto img = s->mutable_image(); 1109 t.flat<string>()(0) = strings::StrCat(img->width()); 1110 t.flat<string>()(1) = strings::StrCat(img->height()); 1111 t.flat<string>()(2) = std::move(*img->mutable_encoded_image_string()); 1112 int64 tag_id; 1113 PatchPluginName(s->mutable_metadata(), kImagePluginName); 1114 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), 1115 &tag_id, s->metadata())); 1116 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); 1117 } 1118 1119 Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) { 1120 // See tensorboard/plugins/audio/summary.py and data_compat.py 1121 Tensor t{DT_STRING, {1, 2}}; 1122 auto wav = s->mutable_audio(); 1123 t.flat<string>()(0) = std::move(*wav->mutable_encoded_audio_string()); 1124 t.flat<string>()(1) = ""; 1125 int64 tag_id; 1126 PatchPluginName(s->mutable_metadata(), kAudioPluginName); 1127 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), 1128 &tag_id, s->metadata())); 1129 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); 1130 } 1131 1132 Env* const env_; 1133 Sqlite* const db_; 1134 IdAllocator ids_; 1135 RunMetadata meta_; 1136 RunWriter run_; 1137 }; 1138 1139 } // namespace 1140 1141 Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name, 1142 const string& run_name, const string& user_name, 1143 Env* env, SummaryWriterInterface** result) { 1144 *result = new SummaryDbWriter(env, db, experiment_name, run_name, user_name); 1145 return Status::OK(); 1146 } 1147 1148 } // namespace tensorflow 1149