1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" 16 17 #include "tensorflow/contrib/tensorboard/db/summary_converter.h" 18 #include "tensorflow/core/framework/graph.pb.h" 19 #include "tensorflow/core/framework/node_def.pb.h" 20 #include "tensorflow/core/framework/register_types.h" 21 #include "tensorflow/core/framework/summary.pb.h" 22 #include "tensorflow/core/lib/core/stringpiece.h" 23 #include "tensorflow/core/lib/db/sqlite.h" 24 #include "tensorflow/core/lib/random/random.h" 25 #include "tensorflow/core/util/event.pb.h" 26 27 // TODO(jart): Break this up into multiple files with excellent unit tests. 28 // TODO(jart): Make decision to write in separate op. 29 // TODO(jart): Add really good busy handling. 30 31 // clang-format off 32 #define CALL_SUPPORTED_TYPES(m) \ 33 TF_CALL_string(m) \ 34 TF_CALL_half(m) \ 35 TF_CALL_float(m) \ 36 TF_CALL_double(m) \ 37 TF_CALL_complex64(m) \ 38 TF_CALL_complex128(m) \ 39 TF_CALL_int8(m) \ 40 TF_CALL_int16(m) \ 41 TF_CALL_int32(m) \ 42 TF_CALL_int64(m) \ 43 TF_CALL_uint8(m) \ 44 TF_CALL_uint16(m) \ 45 TF_CALL_uint32(m) \ 46 TF_CALL_uint64(m) 47 // clang-format on 48 49 namespace tensorflow { 50 namespace { 51 52 // https://www.sqlite.org/fileformat.html#record_format 53 const uint64 kIdTiers[] = { 54 0x7fffffULL, // 23-bit (3 bytes on disk) 55 0x7fffffffULL, // 31-bit (4 bytes on disk) 56 0x7fffffffffffULL, // 47-bit (5 bytes on disk) 57 // remaining bits for future use 58 }; 59 const int kMaxIdTier = sizeof(kIdTiers) / sizeof(uint64); 60 const int kIdCollisionDelayMicros = 10; 61 const int kMaxIdCollisions = 21; // sum(2**i*10s for i in range(21))~=21s 62 const int64 kAbsent = 0LL; 63 64 const char* kScalarPluginName = "scalars"; 65 const char* kImagePluginName = "images"; 66 const char* kAudioPluginName = "audio"; 67 const char* kHistogramPluginName = "histograms"; 68 69 const int kScalarSlots = 10000; 70 const int kImageSlots = 10; 71 const int kAudioSlots = 10; 72 const int kHistogramSlots = 1; 73 const int kTensorSlots = 10; 74 75 const int64 kReserveMinBytes = 32; 76 const double kReserveMultiplier = 1.5; 77 78 // Flush is a misnomer because what we're actually doing is having lots 79 // of commits inside any SqliteTransaction that writes potentially 80 // hundreds of megs but doesn't need the transaction to maintain its 81 // invariants. This ensures the WAL read penalty is small and might 82 // allow writers in other processes a chance to schedule. 83 const uint64 kFlushBytes = 1024 * 1024; 84 85 double DoubleTime(uint64 micros) { 86 // TODO(@jart): Follow precise definitions for time laid out in schema. 87 // TODO(@jart): Use monotonic clock from gRPC codebase. 88 return static_cast<double>(micros) / 1.0e6; 89 } 90 91 string StringifyShape(const TensorShape& shape) { 92 string result; 93 bool first = true; 94 for (const auto& dim : shape) { 95 if (first) { 96 first = false; 97 } else { 98 strings::StrAppend(&result, ","); 99 } 100 strings::StrAppend(&result, dim.size); 101 } 102 return result; 103 } 104 105 Status CheckSupportedType(const Tensor& t) { 106 #define CASE(T) \ 107 case DataTypeToEnum<T>::value: \ 108 break; 109 switch (t.dtype()) { 110 CALL_SUPPORTED_TYPES(CASE) 111 default: 112 return errors::Unimplemented(DataTypeString(t.dtype()), 113 " tensors unsupported on platform"); 114 } 115 return Status::OK(); 116 #undef CASE 117 } 118 119 Tensor AsScalar(const Tensor& t) { 120 Tensor t2{t.dtype(), {}}; 121 #define CASE(T) \ 122 case DataTypeToEnum<T>::value: \ 123 t2.scalar<T>()() = t.flat<T>()(0); \ 124 break; 125 switch (t.dtype()) { 126 CALL_SUPPORTED_TYPES(CASE) 127 default: 128 t2 = {DT_FLOAT, {}}; 129 t2.scalar<float>()() = NAN; 130 break; 131 } 132 return t2; 133 #undef CASE 134 } 135 136 void PatchPluginName(SummaryMetadata* metadata, const char* name) { 137 if (metadata->plugin_data().plugin_name().empty()) { 138 metadata->mutable_plugin_data()->set_plugin_name(name); 139 } 140 } 141 142 int GetSlots(const Tensor& t, const SummaryMetadata& metadata) { 143 if (metadata.plugin_data().plugin_name() == kScalarPluginName) { 144 return kScalarSlots; 145 } else if (metadata.plugin_data().plugin_name() == kImagePluginName) { 146 return kImageSlots; 147 } else if (metadata.plugin_data().plugin_name() == kAudioPluginName) { 148 return kAudioSlots; 149 } else if (metadata.plugin_data().plugin_name() == kHistogramPluginName) { 150 return kHistogramSlots; 151 } else if (t.dims() == 0 && t.dtype() != DT_STRING) { 152 return kScalarSlots; 153 } else { 154 return kTensorSlots; 155 } 156 } 157 158 Status SetDescription(Sqlite* db, int64 id, const StringPiece& markdown) { 159 const char* sql = R"sql( 160 INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?) 161 )sql"; 162 SqliteStatement insert_desc; 163 TF_RETURN_IF_ERROR(db->Prepare(sql, &insert_desc)); 164 insert_desc.BindInt(1, id); 165 insert_desc.BindText(2, markdown); 166 return insert_desc.StepAndReset(); 167 } 168 169 /// \brief Generates unique IDs randomly in the [1,2**63-1] range. 170 /// 171 /// This class starts off generating IDs in the [1,2**23-1] range, 172 /// because it's human friendly and occupies 4 bytes max on disk with 173 /// SQLite's zigzag varint encoding. Then, each time a collision 174 /// happens, the random space is increased by 8 bits. 175 /// 176 /// This class uses exponential back-off so writes gradually slow down 177 /// as IDs become exhausted but reads are still possible. 178 /// 179 /// This class is thread safe. 180 class IdAllocator { 181 public: 182 IdAllocator(Env* env, Sqlite* db) : env_{env}, db_{db} { 183 DCHECK(env_ != nullptr); 184 DCHECK(db_ != nullptr); 185 } 186 187 Status CreateNewId(int64* id) LOCKS_EXCLUDED(mu_) { 188 mutex_lock lock(mu_); 189 Status s; 190 SqliteStatement stmt; 191 TF_RETURN_IF_ERROR(db_->Prepare("INSERT INTO Ids (id) VALUES (?)", &stmt)); 192 for (int i = 0; i < kMaxIdCollisions; ++i) { 193 int64 tid = MakeRandomId(); 194 stmt.BindInt(1, tid); 195 s = stmt.StepAndReset(); 196 if (s.ok()) { 197 *id = tid; 198 break; 199 } 200 // SQLITE_CONSTRAINT maps to INVALID_ARGUMENT in sqlite.cc 201 if (s.code() != error::INVALID_ARGUMENT) break; 202 if (tier_ < kMaxIdTier) { 203 LOG(INFO) << "IdAllocator collision at tier " << tier_ << " (of " 204 << kMaxIdTier << ") so auto-adjusting to a higher tier"; 205 ++tier_; 206 } else { 207 LOG(WARNING) << "IdAllocator (attempt #" << i << ") " 208 << "resulted in a collision at the highest tier; this " 209 "is problematic if it happens often; you can try " 210 "pruning the Ids table; you can also file a bug " 211 "asking for the ID space to be increased; otherwise " 212 "writes will gradually slow down over time until they " 213 "become impossible"; 214 } 215 env_->SleepForMicroseconds((1 << i) * kIdCollisionDelayMicros); 216 } 217 return s; 218 } 219 220 private: 221 int64 MakeRandomId() EXCLUSIVE_LOCKS_REQUIRED(mu_) { 222 int64 id = static_cast<int64>(random::New64() & kIdTiers[tier_]); 223 if (id == kAbsent) ++id; 224 return id; 225 } 226 227 mutex mu_; 228 Env* const env_; 229 Sqlite* const db_; 230 int tier_ GUARDED_BY(mu_) = 0; 231 232 TF_DISALLOW_COPY_AND_ASSIGN(IdAllocator); 233 }; 234 235 class GraphWriter { 236 public: 237 static Status Save(Sqlite* db, SqliteTransaction* txn, IdAllocator* ids, 238 GraphDef* graph, uint64 now, int64 run_id, int64* graph_id) 239 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) { 240 TF_RETURN_IF_ERROR(ids->CreateNewId(graph_id)); 241 GraphWriter saver{db, txn, graph, now, *graph_id}; 242 saver.MapNameToNodeId(); 243 TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodeInputs(), "SaveNodeInputs"); 244 TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodes(), "SaveNodes"); 245 TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveGraph(run_id), "SaveGraph"); 246 return Status::OK(); 247 } 248 249 private: 250 GraphWriter(Sqlite* db, SqliteTransaction* txn, GraphDef* graph, uint64 now, 251 int64 graph_id) 252 : db_(db), txn_(txn), graph_(graph), now_(now), graph_id_(graph_id) {} 253 254 void MapNameToNodeId() { 255 size_t toto = static_cast<size_t>(graph_->node_size()); 256 name_copies_.reserve(toto); 257 name_to_node_id_.reserve(toto); 258 for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { 259 // Copy name into memory region, since we call clear_name() later. 260 // Then wrap in StringPiece so we can compare slices without copy. 261 name_copies_.emplace_back(graph_->node(node_id).name()); 262 name_to_node_id_.emplace(name_copies_.back(), node_id); 263 } 264 } 265 266 Status SaveNodeInputs() { 267 const char* sql = R"sql( 268 INSERT INTO NodeInputs ( 269 graph_id, 270 node_id, 271 idx, 272 input_node_id, 273 input_node_idx, 274 is_control 275 ) VALUES (?, ?, ?, ?, ?, ?) 276 )sql"; 277 SqliteStatement insert; 278 TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert)); 279 for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { 280 const NodeDef& node = graph_->node(node_id); 281 for (int idx = 0; idx < node.input_size(); ++idx) { 282 StringPiece name = node.input(idx); 283 int64 input_node_id; 284 int64 input_node_idx = 0; 285 int64 is_control = 0; 286 size_t i = name.rfind(':'); 287 if (i != StringPiece::npos) { 288 if (!strings::safe_strto64(name.substr(i + 1, name.size() - i - 1), 289 &input_node_idx)) { 290 return errors::DataLoss("Bad NodeDef.input: ", name); 291 } 292 name.remove_suffix(name.size() - i); 293 } 294 if (!name.empty() && name[0] == '^') { 295 name.remove_prefix(1); 296 is_control = 1; 297 } 298 auto e = name_to_node_id_.find(name); 299 if (e == name_to_node_id_.end()) { 300 return errors::DataLoss("Could not find node: ", name); 301 } 302 input_node_id = e->second; 303 insert.BindInt(1, graph_id_); 304 insert.BindInt(2, node_id); 305 insert.BindInt(3, idx); 306 insert.BindInt(4, input_node_id); 307 insert.BindInt(5, input_node_idx); 308 insert.BindInt(6, is_control); 309 unflushed_bytes_ += insert.size(); 310 TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(), 311 " -> ", name); 312 TF_RETURN_IF_ERROR(MaybeFlush()); 313 } 314 } 315 return Status::OK(); 316 } 317 318 Status SaveNodes() { 319 const char* sql = R"sql( 320 INSERT INTO Nodes ( 321 graph_id, 322 node_id, 323 node_name, 324 op, 325 device, 326 node_def) 327 VALUES (?, ?, ?, ?, ?, ?) 328 )sql"; 329 SqliteStatement insert; 330 TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert)); 331 for (int node_id = 0; node_id < graph_->node_size(); ++node_id) { 332 NodeDef* node = graph_->mutable_node(node_id); 333 insert.BindInt(1, graph_id_); 334 insert.BindInt(2, node_id); 335 insert.BindText(3, node->name()); 336 insert.BindText(4, node->op()); 337 insert.BindText(5, node->device()); 338 node->clear_name(); 339 node->clear_op(); 340 node->clear_device(); 341 node->clear_input(); 342 string node_def; 343 if (node->SerializeToString(&node_def)) { 344 insert.BindBlobUnsafe(6, node_def); 345 } 346 unflushed_bytes_ += insert.size(); 347 TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name()); 348 TF_RETURN_IF_ERROR(MaybeFlush()); 349 } 350 return Status::OK(); 351 } 352 353 Status SaveGraph(int64 run_id) { 354 const char* sql = R"sql( 355 INSERT OR REPLACE INTO Graphs ( 356 run_id, 357 graph_id, 358 inserted_time, 359 graph_def 360 ) VALUES (?, ?, ?, ?) 361 )sql"; 362 SqliteStatement insert; 363 TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert)); 364 if (run_id != kAbsent) insert.BindInt(1, run_id); 365 insert.BindInt(2, graph_id_); 366 insert.BindDouble(3, DoubleTime(now_)); 367 graph_->clear_node(); 368 string graph_def; 369 if (graph_->SerializeToString(&graph_def)) { 370 insert.BindBlobUnsafe(4, graph_def); 371 } 372 return insert.StepAndReset(); 373 } 374 375 Status MaybeFlush() { 376 if (unflushed_bytes_ >= kFlushBytes) { 377 TF_RETURN_WITH_CONTEXT_IF_ERROR(txn_->Commit(), "flushing ", 378 unflushed_bytes_, " bytes"); 379 unflushed_bytes_ = 0; 380 } 381 return Status::OK(); 382 } 383 384 Sqlite* const db_; 385 SqliteTransaction* const txn_; 386 uint64 unflushed_bytes_ = 0; 387 GraphDef* const graph_; 388 const uint64 now_; 389 const int64 graph_id_; 390 std::vector<string> name_copies_; 391 std::unordered_map<StringPiece, int64, StringPieceHasher> name_to_node_id_; 392 393 TF_DISALLOW_COPY_AND_ASSIGN(GraphWriter); 394 }; 395 396 /// \brief Run metadata manager. 397 /// 398 /// This class gives us Tag IDs we can pass to SeriesWriter. In order 399 /// to do that, rows are created in the Ids, Tags, Runs, Experiments, 400 /// and Users tables. 401 /// 402 /// This class is thread safe. 403 class RunMetadata { 404 public: 405 RunMetadata(IdAllocator* ids, const string& experiment_name, 406 const string& run_name, const string& user_name) 407 : ids_{ids}, 408 experiment_name_{experiment_name}, 409 run_name_{run_name}, 410 user_name_{user_name} { 411 DCHECK(ids_ != nullptr); 412 } 413 414 const string& experiment_name() { return experiment_name_; } 415 const string& run_name() { return run_name_; } 416 const string& user_name() { return user_name_; } 417 418 int64 run_id() LOCKS_EXCLUDED(mu_) { 419 mutex_lock lock(mu_); 420 return run_id_; 421 } 422 423 Status SetGraph(Sqlite* db, uint64 now, double computed_time, 424 std::unique_ptr<GraphDef> g) SQLITE_TRANSACTIONS_EXCLUDED(*db) 425 LOCKS_EXCLUDED(mu_) { 426 int64 run_id; 427 { 428 mutex_lock lock(mu_); 429 TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time)); 430 run_id = run_id_; 431 } 432 int64 graph_id; 433 SqliteTransaction txn(*db); // only to increase performance 434 TF_RETURN_IF_ERROR( 435 GraphWriter::Save(db, &txn, ids_, g.get(), now, run_id, &graph_id)); 436 return txn.Commit(); 437 } 438 439 Status GetTagId(Sqlite* db, uint64 now, double computed_time, 440 const string& tag_name, int64* tag_id, 441 const SummaryMetadata& metadata) LOCKS_EXCLUDED(mu_) { 442 mutex_lock lock(mu_); 443 TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time)); 444 auto e = tag_ids_.find(tag_name); 445 if (e != tag_ids_.end()) { 446 *tag_id = e->second; 447 return Status::OK(); 448 } 449 TF_RETURN_IF_ERROR(ids_->CreateNewId(tag_id)); 450 tag_ids_[tag_name] = *tag_id; 451 TF_RETURN_IF_ERROR( 452 SetDescription(db, *tag_id, metadata.summary_description())); 453 const char* sql = R"sql( 454 INSERT INTO Tags ( 455 run_id, 456 tag_id, 457 tag_name, 458 inserted_time, 459 display_name, 460 plugin_name, 461 plugin_data 462 ) VALUES ( 463 :run_id, 464 :tag_id, 465 :tag_name, 466 :inserted_time, 467 :display_name, 468 :plugin_name, 469 :plugin_data 470 ) 471 )sql"; 472 SqliteStatement insert; 473 TF_RETURN_IF_ERROR(db->Prepare(sql, &insert)); 474 if (run_id_ != kAbsent) insert.BindInt(":run_id", run_id_); 475 insert.BindInt(":tag_id", *tag_id); 476 insert.BindTextUnsafe(":tag_name", tag_name); 477 insert.BindDouble(":inserted_time", DoubleTime(now)); 478 insert.BindTextUnsafe(":display_name", metadata.display_name()); 479 insert.BindTextUnsafe(":plugin_name", metadata.plugin_data().plugin_name()); 480 insert.BindBlobUnsafe(":plugin_data", metadata.plugin_data().content()); 481 return insert.StepAndReset(); 482 } 483 484 Status GetIsWatching(Sqlite* db, bool* is_watching) 485 SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { 486 mutex_lock lock(mu_); 487 if (experiment_id_ == kAbsent) { 488 *is_watching = true; 489 return Status::OK(); 490 } 491 const char* sql = R"sql( 492 SELECT is_watching FROM Experiments WHERE experiment_id = ? 493 )sql"; 494 SqliteStatement stmt; 495 TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt)); 496 stmt.BindInt(1, experiment_id_); 497 TF_RETURN_IF_ERROR(stmt.StepOnce()); 498 *is_watching = stmt.ColumnInt(0) != 0; 499 return Status::OK(); 500 } 501 502 private: 503 Status InitializeUser(Sqlite* db, uint64 now) EXCLUSIVE_LOCKS_REQUIRED(mu_) { 504 if (user_id_ != kAbsent || user_name_.empty()) return Status::OK(); 505 const char* get_sql = R"sql( 506 SELECT user_id FROM Users WHERE user_name = ? 507 )sql"; 508 SqliteStatement get; 509 TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get)); 510 get.BindText(1, user_name_); 511 bool is_done; 512 TF_RETURN_IF_ERROR(get.Step(&is_done)); 513 if (!is_done) { 514 user_id_ = get.ColumnInt(0); 515 return Status::OK(); 516 } 517 TF_RETURN_IF_ERROR(ids_->CreateNewId(&user_id_)); 518 const char* insert_sql = R"sql( 519 INSERT INTO Users ( 520 user_id, 521 user_name, 522 inserted_time 523 ) VALUES (?, ?, ?) 524 )sql"; 525 SqliteStatement insert; 526 TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert)); 527 insert.BindInt(1, user_id_); 528 insert.BindText(2, user_name_); 529 insert.BindDouble(3, DoubleTime(now)); 530 TF_RETURN_IF_ERROR(insert.StepAndReset()); 531 return Status::OK(); 532 } 533 534 Status InitializeExperiment(Sqlite* db, uint64 now, double computed_time) 535 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 536 if (experiment_name_.empty()) return Status::OK(); 537 if (experiment_id_ == kAbsent) { 538 TF_RETURN_IF_ERROR(InitializeUser(db, now)); 539 const char* get_sql = R"sql( 540 SELECT 541 experiment_id, 542 started_time 543 FROM 544 Experiments 545 WHERE 546 user_id IS ? 547 AND experiment_name = ? 548 )sql"; 549 SqliteStatement get; 550 TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get)); 551 if (user_id_ != kAbsent) get.BindInt(1, user_id_); 552 get.BindText(2, experiment_name_); 553 bool is_done; 554 TF_RETURN_IF_ERROR(get.Step(&is_done)); 555 if (!is_done) { 556 experiment_id_ = get.ColumnInt(0); 557 experiment_started_time_ = get.ColumnInt(1); 558 } else { 559 TF_RETURN_IF_ERROR(ids_->CreateNewId(&experiment_id_)); 560 experiment_started_time_ = computed_time; 561 const char* insert_sql = R"sql( 562 INSERT INTO Experiments ( 563 user_id, 564 experiment_id, 565 experiment_name, 566 inserted_time, 567 started_time, 568 is_watching 569 ) VALUES (?, ?, ?, ?, ?, ?) 570 )sql"; 571 SqliteStatement insert; 572 TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert)); 573 if (user_id_ != kAbsent) insert.BindInt(1, user_id_); 574 insert.BindInt(2, experiment_id_); 575 insert.BindText(3, experiment_name_); 576 insert.BindDouble(4, DoubleTime(now)); 577 insert.BindDouble(5, computed_time); 578 insert.BindInt(6, 0); 579 TF_RETURN_IF_ERROR(insert.StepAndReset()); 580 } 581 } 582 if (computed_time < experiment_started_time_) { 583 experiment_started_time_ = computed_time; 584 const char* update_sql = R"sql( 585 UPDATE 586 Experiments 587 SET 588 started_time = ? 589 WHERE 590 experiment_id = ? 591 )sql"; 592 SqliteStatement update; 593 TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update)); 594 update.BindDouble(1, computed_time); 595 update.BindInt(2, experiment_id_); 596 TF_RETURN_IF_ERROR(update.StepAndReset()); 597 } 598 return Status::OK(); 599 } 600 601 Status InitializeRun(Sqlite* db, uint64 now, double computed_time) 602 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 603 if (run_name_.empty()) return Status::OK(); 604 TF_RETURN_IF_ERROR(InitializeExperiment(db, now, computed_time)); 605 if (run_id_ == kAbsent) { 606 TF_RETURN_IF_ERROR(ids_->CreateNewId(&run_id_)); 607 run_started_time_ = computed_time; 608 const char* insert_sql = R"sql( 609 INSERT OR REPLACE INTO Runs ( 610 experiment_id, 611 run_id, 612 run_name, 613 inserted_time, 614 started_time 615 ) VALUES (?, ?, ?, ?, ?) 616 )sql"; 617 SqliteStatement insert; 618 TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert)); 619 if (experiment_id_ != kAbsent) insert.BindInt(1, experiment_id_); 620 insert.BindInt(2, run_id_); 621 insert.BindText(3, run_name_); 622 insert.BindDouble(4, DoubleTime(now)); 623 insert.BindDouble(5, computed_time); 624 TF_RETURN_IF_ERROR(insert.StepAndReset()); 625 } 626 if (computed_time < run_started_time_) { 627 run_started_time_ = computed_time; 628 const char* update_sql = R"sql( 629 UPDATE 630 Runs 631 SET 632 started_time = ? 633 WHERE 634 run_id = ? 635 )sql"; 636 SqliteStatement update; 637 TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update)); 638 update.BindDouble(1, computed_time); 639 update.BindInt(2, run_id_); 640 TF_RETURN_IF_ERROR(update.StepAndReset()); 641 } 642 return Status::OK(); 643 } 644 645 mutex mu_; 646 IdAllocator* const ids_; 647 const string experiment_name_; 648 const string run_name_; 649 const string user_name_; 650 int64 experiment_id_ GUARDED_BY(mu_) = kAbsent; 651 int64 run_id_ GUARDED_BY(mu_) = kAbsent; 652 int64 user_id_ GUARDED_BY(mu_) = kAbsent; 653 double experiment_started_time_ GUARDED_BY(mu_) = 0.0; 654 double run_started_time_ GUARDED_BY(mu_) = 0.0; 655 std::unordered_map<string, int64> tag_ids_ GUARDED_BY(mu_); 656 657 TF_DISALLOW_COPY_AND_ASSIGN(RunMetadata); 658 }; 659 660 /// \brief Tensor writer for a single series, e.g. Tag. 661 /// 662 /// This class can be used to write an infinite stream of Tensors to the 663 /// database in a fixed block of contiguous disk space. This is 664 /// accomplished using Algorithm R reservoir sampling. 665 /// 666 /// The reservoir consists of a fixed number of rows, which are inserted 667 /// using ZEROBLOB upon receiving the first sample, which is used to 668 /// predict how big the other ones are likely to be. This is done 669 /// transactionally in a way that tries to be mindful of other processes 670 /// that might be trying to access the same DB. 671 /// 672 /// Once the reservoir fills up, rows are replaced at random, and writes 673 /// gradually become no-ops. This allows long training to go fast 674 /// without configuration. The exception is when someone is actually 675 /// looking at TensorBoard. When that happens, the "keep last" behavior 676 /// is turned on and Append() will always result in a write. 677 /// 678 /// If no one is watching training, this class still holds on to the 679 /// most recent "dangling" Tensor, so if Finish() is called, the most 680 /// recent training state can be written to disk. 681 /// 682 /// The randomly selected sampling points should be consistent across 683 /// multiple instances. 684 /// 685 /// This class is thread safe. 686 class SeriesWriter { 687 public: 688 SeriesWriter(int64 series, int slots, RunMetadata* meta) 689 : series_{series}, 690 slots_{slots}, 691 meta_{meta}, 692 rng_{std::mt19937_64::default_seed} { 693 DCHECK(series_ > 0); 694 DCHECK(slots_ > 0); 695 } 696 697 Status Append(Sqlite* db, int64 step, uint64 now, double computed_time, 698 Tensor t) SQLITE_TRANSACTIONS_EXCLUDED(*db) 699 LOCKS_EXCLUDED(mu_) { 700 mutex_lock lock(mu_); 701 if (rowids_.empty()) { 702 Status s = Reserve(db, t); 703 if (!s.ok()) { 704 rowids_.clear(); 705 return s; 706 } 707 } 708 DCHECK(rowids_.size() == slots_); 709 int64 rowid; 710 size_t i = count_; 711 if (i < slots_) { 712 rowid = last_rowid_ = rowids_[i]; 713 } else { 714 i = rng_() % (i + 1); 715 if (i < slots_) { 716 rowid = last_rowid_ = rowids_[i]; 717 } else { 718 bool keep_last; 719 TF_RETURN_IF_ERROR(meta_->GetIsWatching(db, &keep_last)); 720 if (!keep_last) { 721 ++count_; 722 dangling_tensor_.reset(new Tensor(std::move(t))); 723 dangling_step_ = step; 724 dangling_computed_time_ = computed_time; 725 return Status::OK(); 726 } 727 rowid = last_rowid_; 728 } 729 } 730 Status s = Write(db, rowid, step, computed_time, t); 731 if (s.ok()) { 732 ++count_; 733 dangling_tensor_.reset(); 734 } 735 return s; 736 } 737 738 Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) 739 LOCKS_EXCLUDED(mu_) { 740 mutex_lock lock(mu_); 741 // Short runs: Delete unused pre-allocated Tensors. 742 if (count_ < rowids_.size()) { 743 SqliteTransaction txn(*db); 744 const char* sql = R"sql( 745 DELETE FROM Tensors WHERE rowid = ? 746 )sql"; 747 SqliteStatement deleter; 748 TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter)); 749 for (size_t i = count_; i < rowids_.size(); ++i) { 750 deleter.BindInt(1, rowids_[i]); 751 TF_RETURN_IF_ERROR(deleter.StepAndReset()); 752 } 753 TF_RETURN_IF_ERROR(txn.Commit()); 754 rowids_.clear(); 755 } 756 // Long runs: Make last sample be the very most recent one. 757 if (dangling_tensor_) { 758 DCHECK(last_rowid_ != kAbsent); 759 TF_RETURN_IF_ERROR(Write(db, last_rowid_, dangling_step_, 760 dangling_computed_time_, *dangling_tensor_)); 761 dangling_tensor_.reset(); 762 } 763 return Status::OK(); 764 } 765 766 private: 767 Status Write(Sqlite* db, int64 rowid, int64 step, double computed_time, 768 const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) { 769 if (t.dtype() == DT_STRING) { 770 if (t.dims() == 0) { 771 return Update(db, step, computed_time, t, t.scalar<string>()(), rowid); 772 } else { 773 SqliteTransaction txn(*db); 774 TF_RETURN_IF_ERROR( 775 Update(db, step, computed_time, t, StringPiece(), rowid)); 776 TF_RETURN_IF_ERROR(UpdateNdString(db, t, rowid)); 777 return txn.Commit(); 778 } 779 } else { 780 return Update(db, step, computed_time, t, t.tensor_data(), rowid); 781 } 782 } 783 784 Status Update(Sqlite* db, int64 step, double computed_time, const Tensor& t, 785 const StringPiece& data, int64 rowid) { 786 // TODO(jart): How can we ensure reservoir fills on replace? 787 const char* sql = R"sql( 788 UPDATE OR REPLACE 789 Tensors 790 SET 791 step = ?, 792 computed_time = ?, 793 dtype = ?, 794 shape = ?, 795 data = ? 796 WHERE 797 rowid = ? 798 )sql"; 799 SqliteStatement stmt; 800 TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt)); 801 stmt.BindInt(1, step); 802 stmt.BindDouble(2, computed_time); 803 stmt.BindInt(3, t.dtype()); 804 stmt.BindText(4, StringifyShape(t.shape())); 805 stmt.BindBlobUnsafe(5, data); 806 stmt.BindInt(6, rowid); 807 TF_RETURN_IF_ERROR(stmt.StepAndReset()); 808 return Status::OK(); 809 } 810 811 Status UpdateNdString(Sqlite* db, const Tensor& t, int64 tensor_rowid) 812 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) { 813 DCHECK_EQ(t.dtype(), DT_STRING); 814 DCHECK_GT(t.dims(), 0); 815 const char* deleter_sql = R"sql( 816 DELETE FROM TensorStrings WHERE tensor_rowid = ? 817 )sql"; 818 SqliteStatement deleter; 819 TF_RETURN_IF_ERROR(db->Prepare(deleter_sql, &deleter)); 820 deleter.BindInt(1, tensor_rowid); 821 TF_RETURN_WITH_CONTEXT_IF_ERROR(deleter.StepAndReset(), tensor_rowid); 822 const char* inserter_sql = R"sql( 823 INSERT INTO TensorStrings ( 824 tensor_rowid, 825 idx, 826 data 827 ) VALUES (?, ?, ?) 828 )sql"; 829 SqliteStatement inserter; 830 TF_RETURN_IF_ERROR(db->Prepare(inserter_sql, &inserter)); 831 auto flat = t.flat<string>(); 832 for (int64 i = 0; i < flat.size(); ++i) { 833 inserter.BindInt(1, tensor_rowid); 834 inserter.BindInt(2, i); 835 inserter.BindBlobUnsafe(3, flat(i)); 836 TF_RETURN_WITH_CONTEXT_IF_ERROR(inserter.StepAndReset(), "i=", i); 837 } 838 return Status::OK(); 839 } 840 841 Status Reserve(Sqlite* db, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) 842 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 843 SqliteTransaction txn(*db); // only for performance 844 unflushed_bytes_ = 0; 845 if (t.dtype() == DT_STRING) { 846 if (t.dims() == 0) { 847 TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.scalar<string>()().size())); 848 } else { 849 TF_RETURN_IF_ERROR(ReserveTensors(db, &txn, kReserveMinBytes)); 850 } 851 } else { 852 TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.tensor_data().size())); 853 } 854 return txn.Commit(); 855 } 856 857 Status ReserveData(Sqlite* db, SqliteTransaction* txn, size_t size) 858 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) 859 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 860 int64 space = 861 static_cast<int64>(static_cast<double>(size) * kReserveMultiplier); 862 if (space < kReserveMinBytes) space = kReserveMinBytes; 863 return ReserveTensors(db, txn, space); 864 } 865 866 Status ReserveTensors(Sqlite* db, SqliteTransaction* txn, 867 int64 reserved_bytes) 868 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) 869 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 870 const char* sql = R"sql( 871 INSERT INTO Tensors ( 872 series, 873 data 874 ) VALUES (?, ZEROBLOB(?)) 875 )sql"; 876 SqliteStatement insert; 877 TF_RETURN_IF_ERROR(db->Prepare(sql, &insert)); 878 // TODO(jart): Maybe preallocate index pages by setting step. This 879 // is tricky because UPDATE OR REPLACE can have a side 880 // effect of deleting preallocated rows. 881 for (int64 i = 0; i < slots_; ++i) { 882 insert.BindInt(1, series_); 883 insert.BindInt(2, reserved_bytes); 884 TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i); 885 rowids_.push_back(db->last_insert_rowid()); 886 unflushed_bytes_ += reserved_bytes; 887 TF_RETURN_IF_ERROR(MaybeFlush(db, txn)); 888 } 889 return Status::OK(); 890 } 891 892 Status MaybeFlush(Sqlite* db, SqliteTransaction* txn) 893 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) 894 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 895 if (unflushed_bytes_ >= kFlushBytes) { 896 TF_RETURN_WITH_CONTEXT_IF_ERROR(txn->Commit(), "flushing ", 897 unflushed_bytes_, " bytes"); 898 unflushed_bytes_ = 0; 899 } 900 return Status::OK(); 901 } 902 903 mutex mu_; 904 const int64 series_; 905 const int slots_; 906 RunMetadata* const meta_; 907 std::mt19937_64 rng_ GUARDED_BY(mu_); 908 uint64 count_ GUARDED_BY(mu_) = 0; 909 int64 last_rowid_ GUARDED_BY(mu_) = kAbsent; 910 std::vector<int64> rowids_ GUARDED_BY(mu_); 911 uint64 unflushed_bytes_ GUARDED_BY(mu_) = 0; 912 std::unique_ptr<Tensor> dangling_tensor_ GUARDED_BY(mu_); 913 int64 dangling_step_ GUARDED_BY(mu_) = 0; 914 double dangling_computed_time_ GUARDED_BY(mu_) = 0.0; 915 916 TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter); 917 }; 918 919 /// \brief Tensor writer for a single Run. 920 /// 921 /// This class farms out tensors to SeriesWriter instances. It also 922 /// keeps track of whether or not someone is watching the TensorBoard 923 /// GUI, so it can avoid writes when possible. 924 /// 925 /// This class is thread safe. 926 class RunWriter { 927 public: 928 explicit RunWriter(RunMetadata* meta) : meta_{meta} {} 929 930 Status Append(Sqlite* db, int64 tag_id, int64 step, uint64 now, 931 double computed_time, Tensor t, int slots) 932 SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { 933 SeriesWriter* writer = GetSeriesWriter(tag_id, slots); 934 return writer->Append(db, step, now, computed_time, std::move(t)); 935 } 936 937 Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) 938 LOCKS_EXCLUDED(mu_) { 939 mutex_lock lock(mu_); 940 if (series_writers_.empty()) return Status::OK(); 941 for (auto i = series_writers_.begin(); i != series_writers_.end(); ++i) { 942 if (!i->second) continue; 943 TF_RETURN_WITH_CONTEXT_IF_ERROR(i->second->Finish(db), 944 "finish tag_id=", i->first); 945 i->second.reset(); 946 } 947 return Status::OK(); 948 } 949 950 private: 951 SeriesWriter* GetSeriesWriter(int64 tag_id, int slots) LOCKS_EXCLUDED(mu_) { 952 mutex_lock sl(mu_); 953 auto spot = series_writers_.find(tag_id); 954 if (spot == series_writers_.end()) { 955 SeriesWriter* writer = new SeriesWriter(tag_id, slots, meta_); 956 series_writers_[tag_id].reset(writer); 957 return writer; 958 } else { 959 return spot->second.get(); 960 } 961 } 962 963 mutex mu_; 964 RunMetadata* const meta_; 965 std::unordered_map<int64, std::unique_ptr<SeriesWriter>> series_writers_ 966 GUARDED_BY(mu_); 967 968 TF_DISALLOW_COPY_AND_ASSIGN(RunWriter); 969 }; 970 971 /// \brief SQLite implementation of SummaryWriterInterface. 972 /// 973 /// This class is thread safe. 974 class SummaryDbWriter : public SummaryWriterInterface { 975 public: 976 SummaryDbWriter(Env* env, Sqlite* db, const string& experiment_name, 977 const string& run_name, const string& user_name) 978 : SummaryWriterInterface(), 979 env_{env}, 980 db_{db}, 981 ids_{env_, db_}, 982 meta_{&ids_, experiment_name, run_name, user_name}, 983 run_{&meta_} { 984 DCHECK(env_ != nullptr); 985 db_->Ref(); 986 } 987 988 ~SummaryDbWriter() override { 989 core::ScopedUnref unref(db_); 990 Status s = run_.Finish(db_); 991 if (!s.ok()) { 992 // TODO(jart): Retry on transient errors here. 993 LOG(ERROR) << s.ToString(); 994 } 995 int64 run_id = meta_.run_id(); 996 if (run_id == kAbsent) return; 997 const char* sql = R"sql( 998 UPDATE Runs SET finished_time = ? WHERE run_id = ? 999 )sql"; 1000 SqliteStatement update; 1001 s = db_->Prepare(sql, &update); 1002 if (s.ok()) { 1003 update.BindDouble(1, DoubleTime(env_->NowMicros())); 1004 update.BindInt(2, run_id); 1005 s = update.StepAndReset(); 1006 } 1007 if (!s.ok()) { 1008 LOG(ERROR) << "Failed to set Runs[" << run_id 1009 << "].finish_time: " << s.ToString(); 1010 } 1011 } 1012 1013 Status Flush() override { return Status::OK(); } 1014 1015 Status WriteTensor(int64 global_step, Tensor t, const string& tag, 1016 const string& serialized_metadata) override { 1017 TF_RETURN_IF_ERROR(CheckSupportedType(t)); 1018 SummaryMetadata metadata; 1019 if (!metadata.ParseFromString(serialized_metadata)) { 1020 return errors::InvalidArgument("Bad serialized_metadata"); 1021 } 1022 return Write(global_step, t, tag, metadata); 1023 } 1024 1025 Status WriteScalar(int64 global_step, Tensor t, const string& tag) override { 1026 TF_RETURN_IF_ERROR(CheckSupportedType(t)); 1027 SummaryMetadata metadata; 1028 PatchPluginName(&metadata, kScalarPluginName); 1029 return Write(global_step, AsScalar(t), tag, metadata); 1030 } 1031 1032 Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override { 1033 uint64 now = env_->NowMicros(); 1034 return meta_.SetGraph(db_, now, DoubleTime(now), std::move(g)); 1035 } 1036 1037 Status WriteEvent(std::unique_ptr<Event> e) override { 1038 return MigrateEvent(std::move(e)); 1039 } 1040 1041 Status WriteHistogram(int64 global_step, Tensor t, 1042 const string& tag) override { 1043 uint64 now = env_->NowMicros(); 1044 std::unique_ptr<Event> e{new Event}; 1045 e->set_step(global_step); 1046 e->set_wall_time(DoubleTime(now)); 1047 TF_RETURN_IF_ERROR( 1048 AddTensorAsHistogramToSummary(t, tag, e->mutable_summary())); 1049 return MigrateEvent(std::move(e)); 1050 } 1051 1052 Status WriteImage(int64 global_step, Tensor t, const string& tag, 1053 int max_images, Tensor bad_color) override { 1054 uint64 now = env_->NowMicros(); 1055 std::unique_ptr<Event> e{new Event}; 1056 e->set_step(global_step); 1057 e->set_wall_time(DoubleTime(now)); 1058 TF_RETURN_IF_ERROR(AddTensorAsImageToSummary(t, tag, max_images, bad_color, 1059 e->mutable_summary())); 1060 return MigrateEvent(std::move(e)); 1061 } 1062 1063 Status WriteAudio(int64 global_step, Tensor t, const string& tag, 1064 int max_outputs, float sample_rate) override { 1065 uint64 now = env_->NowMicros(); 1066 std::unique_ptr<Event> e{new Event}; 1067 e->set_step(global_step); 1068 e->set_wall_time(DoubleTime(now)); 1069 TF_RETURN_IF_ERROR(AddTensorAsAudioToSummary( 1070 t, tag, max_outputs, sample_rate, e->mutable_summary())); 1071 return MigrateEvent(std::move(e)); 1072 } 1073 1074 string DebugString() override { return "SummaryDbWriter"; } 1075 1076 private: 1077 Status Write(int64 step, const Tensor& t, const string& tag, 1078 const SummaryMetadata& metadata) { 1079 uint64 now = env_->NowMicros(); 1080 double computed_time = DoubleTime(now); 1081 int64 tag_id; 1082 TF_RETURN_IF_ERROR( 1083 meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata)); 1084 TF_RETURN_WITH_CONTEXT_IF_ERROR( 1085 run_.Append(db_, tag_id, step, now, computed_time, t, 1086 GetSlots(t, metadata)), 1087 meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(), 1088 "/", tag, "@", step); 1089 return Status::OK(); 1090 } 1091 1092 Status MigrateEvent(std::unique_ptr<Event> e) { 1093 switch (e->what_case()) { 1094 case Event::WhatCase::kSummary: { 1095 uint64 now = env_->NowMicros(); 1096 auto summaries = e->mutable_summary(); 1097 for (int i = 0; i < summaries->value_size(); ++i) { 1098 Summary::Value* value = summaries->mutable_value(i); 1099 TF_RETURN_WITH_CONTEXT_IF_ERROR( 1100 MigrateSummary(e.get(), value, now), meta_.user_name(), "/", 1101 meta_.experiment_name(), "/", meta_.run_name(), "/", value->tag(), 1102 "@", e->step()); 1103 } 1104 break; 1105 } 1106 case Event::WhatCase::kGraphDef: 1107 TF_RETURN_WITH_CONTEXT_IF_ERROR( 1108 MigrateGraph(e.get(), e->graph_def()), meta_.user_name(), "/", 1109 meta_.experiment_name(), "/", meta_.run_name(), "/__graph__@", 1110 e->step()); 1111 break; 1112 default: 1113 // TODO(@jart): Handle other stuff. 1114 break; 1115 } 1116 return Status::OK(); 1117 } 1118 1119 Status MigrateGraph(const Event* e, const string& graph_def) { 1120 uint64 now = env_->NowMicros(); 1121 std::unique_ptr<GraphDef> graph{new GraphDef}; 1122 if (!ParseProtoUnlimited(graph.get(), graph_def)) { 1123 return errors::InvalidArgument("bad proto"); 1124 } 1125 return meta_.SetGraph(db_, now, e->wall_time(), std::move(graph)); 1126 } 1127 1128 Status MigrateSummary(const Event* e, Summary::Value* s, uint64 now) { 1129 switch (s->value_case()) { 1130 case Summary::Value::ValueCase::kTensor: 1131 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateTensor(e, s, now), "tensor"); 1132 break; 1133 case Summary::Value::ValueCase::kSimpleValue: 1134 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateScalar(e, s, now), "scalar"); 1135 break; 1136 case Summary::Value::ValueCase::kHisto: 1137 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateHistogram(e, s, now), "histo"); 1138 break; 1139 case Summary::Value::ValueCase::kImage: 1140 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateImage(e, s, now), "image"); 1141 break; 1142 case Summary::Value::ValueCase::kAudio: 1143 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateAudio(e, s, now), "audio"); 1144 break; 1145 default: 1146 break; 1147 } 1148 return Status::OK(); 1149 } 1150 1151 Status MigrateTensor(const Event* e, Summary::Value* s, uint64 now) { 1152 Tensor t; 1153 if (!t.FromProto(s->tensor())) return errors::InvalidArgument("bad proto"); 1154 TF_RETURN_IF_ERROR(CheckSupportedType(t)); 1155 int64 tag_id; 1156 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), 1157 &tag_id, s->metadata())); 1158 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t, 1159 GetSlots(t, s->metadata())); 1160 } 1161 1162 // TODO(jart): Refactor Summary -> Tensor logic into separate file. 1163 1164 Status MigrateScalar(const Event* e, Summary::Value* s, uint64 now) { 1165 // See tensorboard/plugins/scalar/summary.py and data_compat.py 1166 Tensor t{DT_FLOAT, {}}; 1167 t.scalar<float>()() = s->simple_value(); 1168 int64 tag_id; 1169 PatchPluginName(s->mutable_metadata(), kScalarPluginName); 1170 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), 1171 &tag_id, s->metadata())); 1172 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), 1173 std::move(t), kScalarSlots); 1174 } 1175 1176 Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) { 1177 const HistogramProto& histo = s->histo(); 1178 int k = histo.bucket_size(); 1179 if (k != histo.bucket_limit_size()) { 1180 return errors::InvalidArgument("size mismatch"); 1181 } 1182 // See tensorboard/plugins/histogram/summary.py and data_compat.py 1183 Tensor t{DT_DOUBLE, {k, 3}}; 1184 auto data = t.flat<double>(); 1185 for (int i = 0; i < k; ++i) { 1186 double left_edge = ((i - 1 >= 0) ? histo.bucket_limit(i - 1) 1187 : std::numeric_limits<double>::min()); 1188 double right_edge = ((i + 1 < k) ? histo.bucket_limit(i + 1) 1189 : std::numeric_limits<double>::max()); 1190 data(i + 0) = left_edge; 1191 data(i + 1) = right_edge; 1192 data(i + 2) = histo.bucket(i); 1193 } 1194 int64 tag_id; 1195 PatchPluginName(s->mutable_metadata(), kHistogramPluginName); 1196 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), 1197 &tag_id, s->metadata())); 1198 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), 1199 std::move(t), kHistogramSlots); 1200 } 1201 1202 Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) { 1203 // See tensorboard/plugins/image/summary.py and data_compat.py 1204 Tensor t{DT_STRING, {3}}; 1205 auto img = s->mutable_image(); 1206 t.flat<string>()(0) = strings::StrCat(img->width()); 1207 t.flat<string>()(1) = strings::StrCat(img->height()); 1208 t.flat<string>()(2) = std::move(*img->mutable_encoded_image_string()); 1209 int64 tag_id; 1210 PatchPluginName(s->mutable_metadata(), kImagePluginName); 1211 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), 1212 &tag_id, s->metadata())); 1213 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), 1214 std::move(t), kImageSlots); 1215 } 1216 1217 Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) { 1218 // See tensorboard/plugins/audio/summary.py and data_compat.py 1219 Tensor t{DT_STRING, {1, 2}}; 1220 auto wav = s->mutable_audio(); 1221 t.flat<string>()(0) = std::move(*wav->mutable_encoded_audio_string()); 1222 t.flat<string>()(1) = ""; 1223 int64 tag_id; 1224 PatchPluginName(s->mutable_metadata(), kAudioPluginName); 1225 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), 1226 &tag_id, s->metadata())); 1227 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), 1228 std::move(t), kAudioSlots); 1229 } 1230 1231 Env* const env_; 1232 Sqlite* const db_; 1233 IdAllocator ids_; 1234 RunMetadata meta_; 1235 RunWriter run_; 1236 }; 1237 1238 } // namespace 1239 1240 Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name, 1241 const string& run_name, const string& user_name, 1242 Env* env, SummaryWriterInterface** result) { 1243 *result = new SummaryDbWriter(env, db, experiment_name, run_name, user_name); 1244 return Status::OK(); 1245 } 1246 1247 } // namespace tensorflow 1248