Home | History | Annotate | Download | only in db
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
     16 
     17 #include "tensorflow/contrib/tensorboard/db/summary_converter.h"
     18 #include "tensorflow/core/framework/graph.pb.h"
     19 #include "tensorflow/core/framework/node_def.pb.h"
     20 #include "tensorflow/core/framework/register_types.h"
     21 #include "tensorflow/core/framework/summary.pb.h"
     22 #include "tensorflow/core/lib/core/stringpiece.h"
     23 #include "tensorflow/core/lib/db/sqlite.h"
     24 #include "tensorflow/core/lib/random/random.h"
     25 #include "tensorflow/core/util/event.pb.h"
     26 
     27 // TODO(jart): Break this up into multiple files with excellent unit tests.
     28 // TODO(jart): Make decision to write in separate op.
     29 // TODO(jart): Add really good busy handling.
     30 
// clang-format off
// X-macro listing every element type this writer supports: each TF_CALL_*
// line applies `m` to one C++ type. The CASE macros in CheckSupportedType()
// and AsScalar() use it to generate one switch label per type; any dtype not
// listed here is reported as Unimplemented by CheckSupportedType().
#define CALL_SUPPORTED_TYPES(m) \
  TF_CALL_string(m)             \
  TF_CALL_half(m)               \
  TF_CALL_float(m)              \
  TF_CALL_double(m)             \
  TF_CALL_complex64(m)          \
  TF_CALL_complex128(m)         \
  TF_CALL_int8(m)               \
  TF_CALL_int16(m)              \
  TF_CALL_int32(m)              \
  TF_CALL_int64(m)              \
  TF_CALL_uint8(m)              \
  TF_CALL_uint16(m)             \
  TF_CALL_uint32(m)             \
  TF_CALL_uint64(m)
// clang-format on
     48 
     49 namespace tensorflow {
     50 namespace {
     51 
// Masks applied to random 64-bit values by IdAllocator, narrowest first.
// Byte counts refer to how SQLite stores a positive integer inside a record;
// see https://www.sqlite.org/fileformat.html#record_format (serial types 1-6
// occupy 1, 2, 3, 4, 6 and 8 bytes respectively).
const uint64 kIdTiers[] = {
    0x7fffffULL,        // 23-bit (3 bytes in a record, serial type 3)
    0x7fffffffULL,      // 31-bit (4 bytes in a record, serial type 4)
    0x7fffffffffffULL,  // 47-bit (6 bytes in a record, serial type 5)
                        // remaining bits for future use
};
// Number of entries in kIdTiers (NOT the highest index): the last valid
// index into kIdTiers is kMaxIdTier - 1.
const int kMaxIdTier = sizeof(kIdTiers) / sizeof(uint64);
// Base back-off between ID collision retries, in microseconds. After failed
// attempt i, IdAllocator sleeps up to (1 << i) times this value.
const int kIdCollisionDelayMicros = 10;
// Worst-case total back-off is bounded by
// sum(2**i * 10us for i in range(21)) ~= 21s.
const int kMaxIdCollisions = 21;
// Sentinel meaning "no ID assigned"; IdAllocator never hands out this value.
const int64 kAbsent = 0LL;

// Plugin names recognized by GetSlots() (and filled in by PatchPluginName()
// callers when a summary has none).
const char* kScalarPluginName = "scalars";
const char* kImagePluginName = "images";
const char* kAudioPluginName = "audio";
const char* kHistogramPluginName = "histograms";

// Reservoir sizes (rows kept per series) returned by GetSlots().
const int kScalarSlots = 10000;
const int kImageSlots = 10;
const int kAudioSlots = 10;
const int kHistogramSlots = 1;
const int kTensorSlots = 10;

// NOTE(review): not used in this view; presumably sizing hints for the
// pre-allocated reservoir rows (SeriesWriter::Reserve) — confirm there.
const int64 kReserveMinBytes = 32;
const double kReserveMultiplier = 1.5;

// Flush is a misnomer because what we're actually doing is having lots
// of commits inside any SqliteTransaction that writes potentially
// hundreds of megs but doesn't need the transaction to maintain its
// invariants. This ensures the WAL read penalty is small and might
// allow writers in other processes a chance to schedule.
const uint64 kFlushBytes = 1024 * 1024;
     84 
     85 double DoubleTime(uint64 micros) {
     86   // TODO(@jart): Follow precise definitions for time laid out in schema.
     87   // TODO(@jart): Use monotonic clock from gRPC codebase.
     88   return static_cast<double>(micros) / 1.0e6;
     89 }
     90 
     91 string StringifyShape(const TensorShape& shape) {
     92   string result;
     93   bool first = true;
     94   for (const auto& dim : shape) {
     95     if (first) {
     96       first = false;
     97     } else {
     98       strings::StrAppend(&result, ",");
     99     }
    100     strings::StrAppend(&result, dim.size);
    101   }
    102   return result;
    103 }
    104 
    105 Status CheckSupportedType(const Tensor& t) {
    106 #define CASE(T)                  \
    107   case DataTypeToEnum<T>::value: \
    108     break;
    109   switch (t.dtype()) {
    110     CALL_SUPPORTED_TYPES(CASE)
    111     default:
    112       return errors::Unimplemented(DataTypeString(t.dtype()),
    113                                    " tensors unsupported on platform");
    114   }
    115   return Status::OK();
    116 #undef CASE
    117 }
    118 
    119 Tensor AsScalar(const Tensor& t) {
    120   Tensor t2{t.dtype(), {}};
    121 #define CASE(T)                        \
    122   case DataTypeToEnum<T>::value:       \
    123     t2.scalar<T>()() = t.flat<T>()(0); \
    124     break;
    125   switch (t.dtype()) {
    126     CALL_SUPPORTED_TYPES(CASE)
    127     default:
    128       t2 = {DT_FLOAT, {}};
    129       t2.scalar<float>()() = NAN;
    130       break;
    131   }
    132   return t2;
    133 #undef CASE
    134 }
    135 
    136 void PatchPluginName(SummaryMetadata* metadata, const char* name) {
    137   if (metadata->plugin_data().plugin_name().empty()) {
    138     metadata->mutable_plugin_data()->set_plugin_name(name);
    139   }
    140 }
    141 
    142 int GetSlots(const Tensor& t, const SummaryMetadata& metadata) {
    143   if (metadata.plugin_data().plugin_name() == kScalarPluginName) {
    144     return kScalarSlots;
    145   } else if (metadata.plugin_data().plugin_name() == kImagePluginName) {
    146     return kImageSlots;
    147   } else if (metadata.plugin_data().plugin_name() == kAudioPluginName) {
    148     return kAudioSlots;
    149   } else if (metadata.plugin_data().plugin_name() == kHistogramPluginName) {
    150     return kHistogramSlots;
    151   } else if (t.dims() == 0 && t.dtype() != DT_STRING) {
    152     return kScalarSlots;
    153   } else {
    154     return kTensorSlots;
    155   }
    156 }
    157 
    158 Status SetDescription(Sqlite* db, int64 id, const StringPiece& markdown) {
    159   const char* sql = R"sql(
    160     INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?)
    161   )sql";
    162   SqliteStatement insert_desc;
    163   TF_RETURN_IF_ERROR(db->Prepare(sql, &insert_desc));
    164   insert_desc.BindInt(1, id);
    165   insert_desc.BindText(2, markdown);
    166   return insert_desc.StepAndReset();
    167 }
    168 
    169 /// \brief Generates unique IDs randomly in the [1,2**63-1] range.
    170 ///
    171 /// This class starts off generating IDs in the [1,2**23-1] range,
    172 /// because it's human friendly and occupies 4 bytes max on disk with
    173 /// SQLite's zigzag varint encoding. Then, each time a collision
    174 /// happens, the random space is increased by 8 bits.
    175 ///
    176 /// This class uses exponential back-off so writes gradually slow down
    177 /// as IDs become exhausted but reads are still possible.
    178 ///
    179 /// This class is thread safe.
    180 class IdAllocator {
    181  public:
    182   IdAllocator(Env* env, Sqlite* db) : env_{env}, db_{db} {
    183     DCHECK(env_ != nullptr);
    184     DCHECK(db_ != nullptr);
    185   }
    186 
    187   Status CreateNewId(int64* id) LOCKS_EXCLUDED(mu_) {
    188     mutex_lock lock(mu_);
    189     Status s;
    190     SqliteStatement stmt;
    191     TF_RETURN_IF_ERROR(db_->Prepare("INSERT INTO Ids (id) VALUES (?)", &stmt));
    192     for (int i = 0; i < kMaxIdCollisions; ++i) {
    193       int64 tid = MakeRandomId();
    194       stmt.BindInt(1, tid);
    195       s = stmt.StepAndReset();
    196       if (s.ok()) {
    197         *id = tid;
    198         break;
    199       }
    200       // SQLITE_CONSTRAINT maps to INVALID_ARGUMENT in sqlite.cc
    201       if (s.code() != error::INVALID_ARGUMENT) break;
    202       if (tier_ < kMaxIdTier) {
    203         LOG(INFO) << "IdAllocator collision at tier " << tier_ << " (of "
    204                   << kMaxIdTier << ") so auto-adjusting to a higher tier";
    205         ++tier_;
    206       } else {
    207         LOG(WARNING) << "IdAllocator (attempt #" << i << ") "
    208                      << "resulted in a collision at the highest tier; this "
    209                         "is problematic if it happens often; you can try "
    210                         "pruning the Ids table; you can also file a bug "
    211                         "asking for the ID space to be increased; otherwise "
    212                         "writes will gradually slow down over time until they "
    213                         "become impossible";
    214       }
    215       env_->SleepForMicroseconds((1 << i) * kIdCollisionDelayMicros);
    216     }
    217     return s;
    218   }
    219 
    220  private:
    221   int64 MakeRandomId() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    222     int64 id = static_cast<int64>(random::New64() & kIdTiers[tier_]);
    223     if (id == kAbsent) ++id;
    224     return id;
    225   }
    226 
    227   mutex mu_;
    228   Env* const env_;
    229   Sqlite* const db_;
    230   int tier_ GUARDED_BY(mu_) = 0;
    231 
    232   TF_DISALLOW_COPY_AND_ASSIGN(IdAllocator);
    233 };
    234 
    235 class GraphWriter {
    236  public:
    237   static Status Save(Sqlite* db, SqliteTransaction* txn, IdAllocator* ids,
    238                      GraphDef* graph, uint64 now, int64 run_id, int64* graph_id)
    239       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) {
    240     TF_RETURN_IF_ERROR(ids->CreateNewId(graph_id));
    241     GraphWriter saver{db, txn, graph, now, *graph_id};
    242     saver.MapNameToNodeId();
    243     TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodeInputs(), "SaveNodeInputs");
    244     TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodes(), "SaveNodes");
    245     TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveGraph(run_id), "SaveGraph");
    246     return Status::OK();
    247   }
    248 
    249  private:
    250   GraphWriter(Sqlite* db, SqliteTransaction* txn, GraphDef* graph, uint64 now,
    251               int64 graph_id)
    252       : db_(db), txn_(txn), graph_(graph), now_(now), graph_id_(graph_id) {}
    253 
    254   void MapNameToNodeId() {
    255     size_t toto = static_cast<size_t>(graph_->node_size());
    256     name_copies_.reserve(toto);
    257     name_to_node_id_.reserve(toto);
    258     for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
    259       // Copy name into memory region, since we call clear_name() later.
    260       // Then wrap in StringPiece so we can compare slices without copy.
    261       name_copies_.emplace_back(graph_->node(node_id).name());
    262       name_to_node_id_.emplace(name_copies_.back(), node_id);
    263     }
    264   }
    265 
    266   Status SaveNodeInputs() {
    267     const char* sql = R"sql(
    268       INSERT INTO NodeInputs (
    269         graph_id,
    270         node_id,
    271         idx,
    272         input_node_id,
    273         input_node_idx,
    274         is_control
    275       ) VALUES (?, ?, ?, ?, ?, ?)
    276     )sql";
    277     SqliteStatement insert;
    278     TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
    279     for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
    280       const NodeDef& node = graph_->node(node_id);
    281       for (int idx = 0; idx < node.input_size(); ++idx) {
    282         StringPiece name = node.input(idx);
    283         int64 input_node_id;
    284         int64 input_node_idx = 0;
    285         int64 is_control = 0;
    286         size_t i = name.rfind(':');
    287         if (i != StringPiece::npos) {
    288           if (!strings::safe_strto64(name.substr(i + 1, name.size() - i - 1),
    289                                      &input_node_idx)) {
    290             return errors::DataLoss("Bad NodeDef.input: ", name);
    291           }
    292           name.remove_suffix(name.size() - i);
    293         }
    294         if (!name.empty() && name[0] == '^') {
    295           name.remove_prefix(1);
    296           is_control = 1;
    297         }
    298         auto e = name_to_node_id_.find(name);
    299         if (e == name_to_node_id_.end()) {
    300           return errors::DataLoss("Could not find node: ", name);
    301         }
    302         input_node_id = e->second;
    303         insert.BindInt(1, graph_id_);
    304         insert.BindInt(2, node_id);
    305         insert.BindInt(3, idx);
    306         insert.BindInt(4, input_node_id);
    307         insert.BindInt(5, input_node_idx);
    308         insert.BindInt(6, is_control);
    309         unflushed_bytes_ += insert.size();
    310         TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(),
    311                                         " -> ", name);
    312         TF_RETURN_IF_ERROR(MaybeFlush());
    313       }
    314     }
    315     return Status::OK();
    316   }
    317 
    318   Status SaveNodes() {
    319     const char* sql = R"sql(
    320       INSERT INTO Nodes (
    321         graph_id,
    322         node_id,
    323         node_name,
    324         op,
    325         device,
    326         node_def)
    327       VALUES (?, ?, ?, ?, ?, ?)
    328     )sql";
    329     SqliteStatement insert;
    330     TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
    331     for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
    332       NodeDef* node = graph_->mutable_node(node_id);
    333       insert.BindInt(1, graph_id_);
    334       insert.BindInt(2, node_id);
    335       insert.BindText(3, node->name());
    336       insert.BindText(4, node->op());
    337       insert.BindText(5, node->device());
    338       node->clear_name();
    339       node->clear_op();
    340       node->clear_device();
    341       node->clear_input();
    342       string node_def;
    343       if (node->SerializeToString(&node_def)) {
    344         insert.BindBlobUnsafe(6, node_def);
    345       }
    346       unflushed_bytes_ += insert.size();
    347       TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name());
    348       TF_RETURN_IF_ERROR(MaybeFlush());
    349     }
    350     return Status::OK();
    351   }
    352 
    353   Status SaveGraph(int64 run_id) {
    354     const char* sql = R"sql(
    355       INSERT OR REPLACE INTO Graphs (
    356         run_id,
    357         graph_id,
    358         inserted_time,
    359         graph_def
    360       ) VALUES (?, ?, ?, ?)
    361     )sql";
    362     SqliteStatement insert;
    363     TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
    364     if (run_id != kAbsent) insert.BindInt(1, run_id);
    365     insert.BindInt(2, graph_id_);
    366     insert.BindDouble(3, DoubleTime(now_));
    367     graph_->clear_node();
    368     string graph_def;
    369     if (graph_->SerializeToString(&graph_def)) {
    370       insert.BindBlobUnsafe(4, graph_def);
    371     }
    372     return insert.StepAndReset();
    373   }
    374 
    375   Status MaybeFlush() {
    376     if (unflushed_bytes_ >= kFlushBytes) {
    377       TF_RETURN_WITH_CONTEXT_IF_ERROR(txn_->Commit(), "flushing ",
    378                                       unflushed_bytes_, " bytes");
    379       unflushed_bytes_ = 0;
    380     }
    381     return Status::OK();
    382   }
    383 
    384   Sqlite* const db_;
    385   SqliteTransaction* const txn_;
    386   uint64 unflushed_bytes_ = 0;
    387   GraphDef* const graph_;
    388   const uint64 now_;
    389   const int64 graph_id_;
    390   std::vector<string> name_copies_;
    391   std::unordered_map<StringPiece, int64, StringPieceHasher> name_to_node_id_;
    392 
    393   TF_DISALLOW_COPY_AND_ASSIGN(GraphWriter);
    394 };
    395 
    396 /// \brief Run metadata manager.
    397 ///
    398 /// This class gives us Tag IDs we can pass to SeriesWriter. In order
    399 /// to do that, rows are created in the Ids, Tags, Runs, Experiments,
    400 /// and Users tables.
    401 ///
    402 /// This class is thread safe.
    403 class RunMetadata {
    404  public:
    405   RunMetadata(IdAllocator* ids, const string& experiment_name,
    406               const string& run_name, const string& user_name)
    407       : ids_{ids},
    408         experiment_name_{experiment_name},
    409         run_name_{run_name},
    410         user_name_{user_name} {
    411     DCHECK(ids_ != nullptr);
    412   }
    413 
    414   const string& experiment_name() { return experiment_name_; }
    415   const string& run_name() { return run_name_; }
    416   const string& user_name() { return user_name_; }
    417 
    418   int64 run_id() LOCKS_EXCLUDED(mu_) {
    419     mutex_lock lock(mu_);
    420     return run_id_;
    421   }
    422 
    423   Status SetGraph(Sqlite* db, uint64 now, double computed_time,
    424                   std::unique_ptr<GraphDef> g) SQLITE_TRANSACTIONS_EXCLUDED(*db)
    425       LOCKS_EXCLUDED(mu_) {
    426     int64 run_id;
    427     {
    428       mutex_lock lock(mu_);
    429       TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time));
    430       run_id = run_id_;
    431     }
    432     int64 graph_id;
    433     SqliteTransaction txn(*db);  // only to increase performance
    434     TF_RETURN_IF_ERROR(
    435         GraphWriter::Save(db, &txn, ids_, g.get(), now, run_id, &graph_id));
    436     return txn.Commit();
    437   }
    438 
    439   Status GetTagId(Sqlite* db, uint64 now, double computed_time,
    440                   const string& tag_name, int64* tag_id,
    441                   const SummaryMetadata& metadata) LOCKS_EXCLUDED(mu_) {
    442     mutex_lock lock(mu_);
    443     TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time));
    444     auto e = tag_ids_.find(tag_name);
    445     if (e != tag_ids_.end()) {
    446       *tag_id = e->second;
    447       return Status::OK();
    448     }
    449     TF_RETURN_IF_ERROR(ids_->CreateNewId(tag_id));
    450     tag_ids_[tag_name] = *tag_id;
    451     TF_RETURN_IF_ERROR(
    452         SetDescription(db, *tag_id, metadata.summary_description()));
    453     const char* sql = R"sql(
    454       INSERT INTO Tags (
    455         run_id,
    456         tag_id,
    457         tag_name,
    458         inserted_time,
    459         display_name,
    460         plugin_name,
    461         plugin_data
    462       ) VALUES (
    463         :run_id,
    464         :tag_id,
    465         :tag_name,
    466         :inserted_time,
    467         :display_name,
    468         :plugin_name,
    469         :plugin_data
    470       )
    471     )sql";
    472     SqliteStatement insert;
    473     TF_RETURN_IF_ERROR(db->Prepare(sql, &insert));
    474     if (run_id_ != kAbsent) insert.BindInt(":run_id", run_id_);
    475     insert.BindInt(":tag_id", *tag_id);
    476     insert.BindTextUnsafe(":tag_name", tag_name);
    477     insert.BindDouble(":inserted_time", DoubleTime(now));
    478     insert.BindTextUnsafe(":display_name", metadata.display_name());
    479     insert.BindTextUnsafe(":plugin_name", metadata.plugin_data().plugin_name());
    480     insert.BindBlobUnsafe(":plugin_data", metadata.plugin_data().content());
    481     return insert.StepAndReset();
    482   }
    483 
    484   Status GetIsWatching(Sqlite* db, bool* is_watching)
    485       SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) {
    486     mutex_lock lock(mu_);
    487     if (experiment_id_ == kAbsent) {
    488       *is_watching = true;
    489       return Status::OK();
    490     }
    491     const char* sql = R"sql(
    492       SELECT is_watching FROM Experiments WHERE experiment_id = ?
    493     )sql";
    494     SqliteStatement stmt;
    495     TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
    496     stmt.BindInt(1, experiment_id_);
    497     TF_RETURN_IF_ERROR(stmt.StepOnce());
    498     *is_watching = stmt.ColumnInt(0) != 0;
    499     return Status::OK();
    500   }
    501 
    502  private:
    503   Status InitializeUser(Sqlite* db, uint64 now) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    504     if (user_id_ != kAbsent || user_name_.empty()) return Status::OK();
    505     const char* get_sql = R"sql(
    506       SELECT user_id FROM Users WHERE user_name = ?
    507     )sql";
    508     SqliteStatement get;
    509     TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get));
    510     get.BindText(1, user_name_);
    511     bool is_done;
    512     TF_RETURN_IF_ERROR(get.Step(&is_done));
    513     if (!is_done) {
    514       user_id_ = get.ColumnInt(0);
    515       return Status::OK();
    516     }
    517     TF_RETURN_IF_ERROR(ids_->CreateNewId(&user_id_));
    518     const char* insert_sql = R"sql(
    519       INSERT INTO Users (
    520         user_id,
    521         user_name,
    522         inserted_time
    523       ) VALUES (?, ?, ?)
    524     )sql";
    525     SqliteStatement insert;
    526     TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
    527     insert.BindInt(1, user_id_);
    528     insert.BindText(2, user_name_);
    529     insert.BindDouble(3, DoubleTime(now));
    530     TF_RETURN_IF_ERROR(insert.StepAndReset());
    531     return Status::OK();
    532   }
    533 
    534   Status InitializeExperiment(Sqlite* db, uint64 now, double computed_time)
    535       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    536     if (experiment_name_.empty()) return Status::OK();
    537     if (experiment_id_ == kAbsent) {
    538       TF_RETURN_IF_ERROR(InitializeUser(db, now));
    539       const char* get_sql = R"sql(
    540         SELECT
    541           experiment_id,
    542           started_time
    543         FROM
    544           Experiments
    545         WHERE
    546           user_id IS ?
    547           AND experiment_name = ?
    548       )sql";
    549       SqliteStatement get;
    550       TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get));
    551       if (user_id_ != kAbsent) get.BindInt(1, user_id_);
    552       get.BindText(2, experiment_name_);
    553       bool is_done;
    554       TF_RETURN_IF_ERROR(get.Step(&is_done));
    555       if (!is_done) {
    556         experiment_id_ = get.ColumnInt(0);
    557         experiment_started_time_ = get.ColumnInt(1);
    558       } else {
    559         TF_RETURN_IF_ERROR(ids_->CreateNewId(&experiment_id_));
    560         experiment_started_time_ = computed_time;
    561         const char* insert_sql = R"sql(
    562           INSERT INTO Experiments (
    563             user_id,
    564             experiment_id,
    565             experiment_name,
    566             inserted_time,
    567             started_time,
    568             is_watching
    569           ) VALUES (?, ?, ?, ?, ?, ?)
    570         )sql";
    571         SqliteStatement insert;
    572         TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
    573         if (user_id_ != kAbsent) insert.BindInt(1, user_id_);
    574         insert.BindInt(2, experiment_id_);
    575         insert.BindText(3, experiment_name_);
    576         insert.BindDouble(4, DoubleTime(now));
    577         insert.BindDouble(5, computed_time);
    578         insert.BindInt(6, 0);
    579         TF_RETURN_IF_ERROR(insert.StepAndReset());
    580       }
    581     }
    582     if (computed_time < experiment_started_time_) {
    583       experiment_started_time_ = computed_time;
    584       const char* update_sql = R"sql(
    585         UPDATE
    586           Experiments
    587         SET
    588           started_time = ?
    589         WHERE
    590           experiment_id = ?
    591       )sql";
    592       SqliteStatement update;
    593       TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update));
    594       update.BindDouble(1, computed_time);
    595       update.BindInt(2, experiment_id_);
    596       TF_RETURN_IF_ERROR(update.StepAndReset());
    597     }
    598     return Status::OK();
    599   }
    600 
    601   Status InitializeRun(Sqlite* db, uint64 now, double computed_time)
    602       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    603     if (run_name_.empty()) return Status::OK();
    604     TF_RETURN_IF_ERROR(InitializeExperiment(db, now, computed_time));
    605     if (run_id_ == kAbsent) {
    606       TF_RETURN_IF_ERROR(ids_->CreateNewId(&run_id_));
    607       run_started_time_ = computed_time;
    608       const char* insert_sql = R"sql(
    609         INSERT OR REPLACE INTO Runs (
    610           experiment_id,
    611           run_id,
    612           run_name,
    613           inserted_time,
    614           started_time
    615         ) VALUES (?, ?, ?, ?, ?)
    616       )sql";
    617       SqliteStatement insert;
    618       TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
    619       if (experiment_id_ != kAbsent) insert.BindInt(1, experiment_id_);
    620       insert.BindInt(2, run_id_);
    621       insert.BindText(3, run_name_);
    622       insert.BindDouble(4, DoubleTime(now));
    623       insert.BindDouble(5, computed_time);
    624       TF_RETURN_IF_ERROR(insert.StepAndReset());
    625     }
    626     if (computed_time < run_started_time_) {
    627       run_started_time_ = computed_time;
    628       const char* update_sql = R"sql(
    629         UPDATE
    630           Runs
    631         SET
    632           started_time = ?
    633         WHERE
    634           run_id = ?
    635       )sql";
    636       SqliteStatement update;
    637       TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update));
    638       update.BindDouble(1, computed_time);
    639       update.BindInt(2, run_id_);
    640       TF_RETURN_IF_ERROR(update.StepAndReset());
    641     }
    642     return Status::OK();
    643   }
    644 
    645   mutex mu_;
    646   IdAllocator* const ids_;
    647   const string experiment_name_;
    648   const string run_name_;
    649   const string user_name_;
    650   int64 experiment_id_ GUARDED_BY(mu_) = kAbsent;
    651   int64 run_id_ GUARDED_BY(mu_) = kAbsent;
    652   int64 user_id_ GUARDED_BY(mu_) = kAbsent;
    653   double experiment_started_time_ GUARDED_BY(mu_) = 0.0;
    654   double run_started_time_ GUARDED_BY(mu_) = 0.0;
    655   std::unordered_map<string, int64> tag_ids_ GUARDED_BY(mu_);
    656 
    657   TF_DISALLOW_COPY_AND_ASSIGN(RunMetadata);
    658 };
    659 
    660 /// \brief Tensor writer for a single series, e.g. Tag.
    661 ///
    662 /// This class can be used to write an infinite stream of Tensors to the
    663 /// database in a fixed block of contiguous disk space. This is
    664 /// accomplished using Algorithm R reservoir sampling.
    665 ///
    666 /// The reservoir consists of a fixed number of rows, which are inserted
    667 /// using ZEROBLOB upon receiving the first sample, which is used to
    668 /// predict how big the other ones are likely to be. This is done
    669 /// transactionally in a way that tries to be mindful of other processes
    670 /// that might be trying to access the same DB.
    671 ///
    672 /// Once the reservoir fills up, rows are replaced at random, and writes
    673 /// gradually become no-ops. This allows long training to go fast
    674 /// without configuration. The exception is when someone is actually
    675 /// looking at TensorBoard. When that happens, the "keep last" behavior
    676 /// is turned on and Append() will always result in a write.
    677 ///
    678 /// If no one is watching training, this class still holds on to the
    679 /// most recent "dangling" Tensor, so if Finish() is called, the most
    680 /// recent training state can be written to disk.
    681 ///
    682 /// The randomly selected sampling points should be consistent across
    683 /// multiple instances.
    684 ///
    685 /// This class is thread safe.
    686 class SeriesWriter {
    687  public:
    688   SeriesWriter(int64 series, int slots, RunMetadata* meta)
    689       : series_{series},
    690         slots_{slots},
    691         meta_{meta},
    692         rng_{std::mt19937_64::default_seed} {
    693     DCHECK(series_ > 0);
    694     DCHECK(slots_ > 0);
    695   }
    696 
  /// \brief Offers one sample to the reservoir and writes it if selected.
  ///
  /// The first slots_ samples fill the reservoir in order. After that, each
  /// sample replaces a uniformly chosen slot with probability
  /// slots_ / (count_ + 1) (Algorithm R). A sample that is not selected is
  /// still written over the most recently written row if GetIsWatching()
  /// reports true. Otherwise it is kept as the "dangling" tensor for
  /// Finish(). `now` is not used in this method body.
  Status Append(Sqlite* db, int64 step, uint64 now, double computed_time,
                Tensor t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
      LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    // Lazily pre-allocate the reservoir rows on the first sample. On
    // failure, clear rowids_ so the next call attempts Reserve() again.
    if (rowids_.empty()) {
      Status s = Reserve(db, t);
      if (!s.ok()) {
        rowids_.clear();
        return s;
      }
    }
    DCHECK(rowids_.size() == slots_);
    int64 rowid;
    size_t i = count_;
    if (i < slots_) {
      // Filling phase: sample number i gets its own pre-allocated row.
      rowid = last_rowid_ = rowids_[i];
    } else {
      // Sampling phase: draw a slot index in [0, count_].
      i = rng_() % (i + 1);
      if (i < slots_) {
        rowid = last_rowid_ = rowids_[i];
      } else {
        // Not selected. Only write anyway ("keep last") when watched.
        bool keep_last;
        TF_RETURN_IF_ERROR(meta_->GetIsWatching(db, &keep_last));
        if (!keep_last) {
          // Remember the latest sample so Finish() can persist it.
          ++count_;
          dangling_tensor_.reset(new Tensor(std::move(t)));
          dangling_step_ = step;
          dangling_computed_time_ = computed_time;
          return Status::OK();
        }
        // Keep-last mode: overwrite the most recently written row.
        rowid = last_rowid_;
      }
    }
    Status s = Write(db, rowid, step, computed_time, t);
    // Count the sample (and drop any older dangling one) only on success.
    if (s.ok()) {
      ++count_;
      dangling_tensor_.reset();
    }
    return s;
  }
    737 
    738   Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
    739       LOCKS_EXCLUDED(mu_) {
    740     mutex_lock lock(mu_);
    741     // Short runs: Delete unused pre-allocated Tensors.
    742     if (count_ < rowids_.size()) {
    743       SqliteTransaction txn(*db);
    744       const char* sql = R"sql(
    745         DELETE FROM Tensors WHERE rowid = ?
    746       )sql";
    747       SqliteStatement deleter;
    748       TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter));
    749       for (size_t i = count_; i < rowids_.size(); ++i) {
    750         deleter.BindInt(1, rowids_[i]);
    751         TF_RETURN_IF_ERROR(deleter.StepAndReset());
    752       }
    753       TF_RETURN_IF_ERROR(txn.Commit());
    754       rowids_.clear();
    755     }
    756     // Long runs: Make last sample be the very most recent one.
    757     if (dangling_tensor_) {
    758       DCHECK(last_rowid_ != kAbsent);
    759       TF_RETURN_IF_ERROR(Write(db, last_rowid_, dangling_step_,
    760                                dangling_computed_time_, *dangling_tensor_));
    761       dangling_tensor_.reset();
    762     }
    763     return Status::OK();
    764   }
    765 
    766  private:
    767   Status Write(Sqlite* db, int64 rowid, int64 step, double computed_time,
    768                const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) {
    769     if (t.dtype() == DT_STRING) {
    770       if (t.dims() == 0) {
    771         return Update(db, step, computed_time, t, t.scalar<string>()(), rowid);
    772       } else {
    773         SqliteTransaction txn(*db);
    774         TF_RETURN_IF_ERROR(
    775             Update(db, step, computed_time, t, StringPiece(), rowid));
    776         TF_RETURN_IF_ERROR(UpdateNdString(db, t, rowid));
    777         return txn.Commit();
    778       }
    779     } else {
    780       return Update(db, step, computed_time, t, t.tensor_data(), rowid);
    781     }
    782   }
    783 
    784   Status Update(Sqlite* db, int64 step, double computed_time, const Tensor& t,
    785                 const StringPiece& data, int64 rowid) {
    786     // TODO(jart): How can we ensure reservoir fills on replace?
    787     const char* sql = R"sql(
    788       UPDATE OR REPLACE
    789         Tensors
    790       SET
    791         step = ?,
    792         computed_time = ?,
    793         dtype = ?,
    794         shape = ?,
    795         data = ?
    796       WHERE
    797         rowid = ?
    798     )sql";
    799     SqliteStatement stmt;
    800     TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
    801     stmt.BindInt(1, step);
    802     stmt.BindDouble(2, computed_time);
    803     stmt.BindInt(3, t.dtype());
    804     stmt.BindText(4, StringifyShape(t.shape()));
    805     stmt.BindBlobUnsafe(5, data);
    806     stmt.BindInt(6, rowid);
    807     TF_RETURN_IF_ERROR(stmt.StepAndReset());
    808     return Status::OK();
    809   }
    810 
    811   Status UpdateNdString(Sqlite* db, const Tensor& t, int64 tensor_rowid)
    812       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) {
    813     DCHECK_EQ(t.dtype(), DT_STRING);
    814     DCHECK_GT(t.dims(), 0);
    815     const char* deleter_sql = R"sql(
    816       DELETE FROM TensorStrings WHERE tensor_rowid = ?
    817     )sql";
    818     SqliteStatement deleter;
    819     TF_RETURN_IF_ERROR(db->Prepare(deleter_sql, &deleter));
    820     deleter.BindInt(1, tensor_rowid);
    821     TF_RETURN_WITH_CONTEXT_IF_ERROR(deleter.StepAndReset(), tensor_rowid);
    822     const char* inserter_sql = R"sql(
    823       INSERT INTO TensorStrings (
    824         tensor_rowid,
    825         idx,
    826         data
    827       ) VALUES (?, ?, ?)
    828     )sql";
    829     SqliteStatement inserter;
    830     TF_RETURN_IF_ERROR(db->Prepare(inserter_sql, &inserter));
    831     auto flat = t.flat<string>();
    832     for (int64 i = 0; i < flat.size(); ++i) {
    833       inserter.BindInt(1, tensor_rowid);
    834       inserter.BindInt(2, i);
    835       inserter.BindBlobUnsafe(3, flat(i));
    836       TF_RETURN_WITH_CONTEXT_IF_ERROR(inserter.StepAndReset(), "i=", i);
    837     }
    838     return Status::OK();
    839   }
    840 
    841   Status Reserve(Sqlite* db, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
    842       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    843     SqliteTransaction txn(*db);  // only for performance
    844     unflushed_bytes_ = 0;
    845     if (t.dtype() == DT_STRING) {
    846       if (t.dims() == 0) {
    847         TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.scalar<string>()().size()));
    848       } else {
    849         TF_RETURN_IF_ERROR(ReserveTensors(db, &txn, kReserveMinBytes));
    850       }
    851     } else {
    852       TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.tensor_data().size()));
    853     }
    854     return txn.Commit();
    855   }
    856 
    857   Status ReserveData(Sqlite* db, SqliteTransaction* txn, size_t size)
    858       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
    859           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    860     int64 space =
    861         static_cast<int64>(static_cast<double>(size) * kReserveMultiplier);
    862     if (space < kReserveMinBytes) space = kReserveMinBytes;
    863     return ReserveTensors(db, txn, space);
    864   }
    865 
    866   Status ReserveTensors(Sqlite* db, SqliteTransaction* txn,
    867                         int64 reserved_bytes)
    868       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
    869           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    870     const char* sql = R"sql(
    871       INSERT INTO Tensors (
    872         series,
    873         data
    874       ) VALUES (?, ZEROBLOB(?))
    875     )sql";
    876     SqliteStatement insert;
    877     TF_RETURN_IF_ERROR(db->Prepare(sql, &insert));
    878     // TODO(jart): Maybe preallocate index pages by setting step. This
    879     //             is tricky because UPDATE OR REPLACE can have a side
    880     //             effect of deleting preallocated rows.
    881     for (int64 i = 0; i < slots_; ++i) {
    882       insert.BindInt(1, series_);
    883       insert.BindInt(2, reserved_bytes);
    884       TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i);
    885       rowids_.push_back(db->last_insert_rowid());
    886       unflushed_bytes_ += reserved_bytes;
    887       TF_RETURN_IF_ERROR(MaybeFlush(db, txn));
    888     }
    889     return Status::OK();
    890   }
    891 
    892   Status MaybeFlush(Sqlite* db, SqliteTransaction* txn)
    893       SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
    894           EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    895     if (unflushed_bytes_ >= kFlushBytes) {
    896       TF_RETURN_WITH_CONTEXT_IF_ERROR(txn->Commit(), "flushing ",
    897                                       unflushed_bytes_, " bytes");
    898       unflushed_bytes_ = 0;
    899     }
    900     return Status::OK();
    901   }
    902 
    903   mutex mu_;
    904   const int64 series_;
    905   const int slots_;
    906   RunMetadata* const meta_;
    907   std::mt19937_64 rng_ GUARDED_BY(mu_);
    908   uint64 count_ GUARDED_BY(mu_) = 0;
    909   int64 last_rowid_ GUARDED_BY(mu_) = kAbsent;
    910   std::vector<int64> rowids_ GUARDED_BY(mu_);
    911   uint64 unflushed_bytes_ GUARDED_BY(mu_) = 0;
    912   std::unique_ptr<Tensor> dangling_tensor_ GUARDED_BY(mu_);
    913   int64 dangling_step_ GUARDED_BY(mu_) = 0;
    914   double dangling_computed_time_ GUARDED_BY(mu_) = 0.0;
    915 
    916   TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter);
    917 };
    918 
    919 /// \brief Tensor writer for a single Run.
    920 ///
    921 /// This class farms out tensors to SeriesWriter instances. It also
    922 /// keeps track of whether or not someone is watching the TensorBoard
    923 /// GUI, so it can avoid writes when possible.
    924 ///
    925 /// This class is thread safe.
    926 class RunWriter {
    927  public:
    928   explicit RunWriter(RunMetadata* meta) : meta_{meta} {}
    929 
    930   Status Append(Sqlite* db, int64 tag_id, int64 step, uint64 now,
    931                 double computed_time, Tensor t, int slots)
    932       SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) {
    933     SeriesWriter* writer = GetSeriesWriter(tag_id, slots);
    934     return writer->Append(db, step, now, computed_time, std::move(t));
    935   }
    936 
    937   Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
    938       LOCKS_EXCLUDED(mu_) {
    939     mutex_lock lock(mu_);
    940     if (series_writers_.empty()) return Status::OK();
    941     for (auto i = series_writers_.begin(); i != series_writers_.end(); ++i) {
    942       if (!i->second) continue;
    943       TF_RETURN_WITH_CONTEXT_IF_ERROR(i->second->Finish(db),
    944                                       "finish tag_id=", i->first);
    945       i->second.reset();
    946     }
    947     return Status::OK();
    948   }
    949 
    950  private:
    951   SeriesWriter* GetSeriesWriter(int64 tag_id, int slots) LOCKS_EXCLUDED(mu_) {
    952     mutex_lock sl(mu_);
    953     auto spot = series_writers_.find(tag_id);
    954     if (spot == series_writers_.end()) {
    955       SeriesWriter* writer = new SeriesWriter(tag_id, slots, meta_);
    956       series_writers_[tag_id].reset(writer);
    957       return writer;
    958     } else {
    959       return spot->second.get();
    960     }
    961   }
    962 
    963   mutex mu_;
    964   RunMetadata* const meta_;
    965   std::unordered_map<int64, std::unique_ptr<SeriesWriter>> series_writers_
    966       GUARDED_BY(mu_);
    967 
    968   TF_DISALLOW_COPY_AND_ASSIGN(RunWriter);
    969 };
    970 
    971 /// \brief SQLite implementation of SummaryWriterInterface.
    972 ///
    973 /// This class is thread safe.
    974 class SummaryDbWriter : public SummaryWriterInterface {
    975  public:
    976   SummaryDbWriter(Env* env, Sqlite* db, const string& experiment_name,
    977                   const string& run_name, const string& user_name)
    978       : SummaryWriterInterface(),
    979         env_{env},
    980         db_{db},
    981         ids_{env_, db_},
    982         meta_{&ids_, experiment_name, run_name, user_name},
    983         run_{&meta_} {
    984     DCHECK(env_ != nullptr);
    985     db_->Ref();
    986   }
    987 
    988   ~SummaryDbWriter() override {
    989     core::ScopedUnref unref(db_);
    990     Status s = run_.Finish(db_);
    991     if (!s.ok()) {
    992       // TODO(jart): Retry on transient errors here.
    993       LOG(ERROR) << s.ToString();
    994     }
    995     int64 run_id = meta_.run_id();
    996     if (run_id == kAbsent) return;
    997     const char* sql = R"sql(
    998       UPDATE Runs SET finished_time = ? WHERE run_id = ?
    999     )sql";
   1000     SqliteStatement update;
   1001     s = db_->Prepare(sql, &update);
   1002     if (s.ok()) {
   1003       update.BindDouble(1, DoubleTime(env_->NowMicros()));
   1004       update.BindInt(2, run_id);
   1005       s = update.StepAndReset();
   1006     }
   1007     if (!s.ok()) {
   1008       LOG(ERROR) << "Failed to set Runs[" << run_id
   1009                  << "].finish_time: " << s.ToString();
   1010     }
   1011   }
   1012 
   1013   Status Flush() override { return Status::OK(); }
   1014 
   1015   Status WriteTensor(int64 global_step, Tensor t, const string& tag,
   1016                      const string& serialized_metadata) override {
   1017     TF_RETURN_IF_ERROR(CheckSupportedType(t));
   1018     SummaryMetadata metadata;
   1019     if (!metadata.ParseFromString(serialized_metadata)) {
   1020       return errors::InvalidArgument("Bad serialized_metadata");
   1021     }
   1022     return Write(global_step, t, tag, metadata);
   1023   }
   1024 
   1025   Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
   1026     TF_RETURN_IF_ERROR(CheckSupportedType(t));
   1027     SummaryMetadata metadata;
   1028     PatchPluginName(&metadata, kScalarPluginName);
   1029     return Write(global_step, AsScalar(t), tag, metadata);
   1030   }
   1031 
   1032   Status WriteGraph(int64 global_step, std::unique_ptr<GraphDef> g) override {
   1033     uint64 now = env_->NowMicros();
   1034     return meta_.SetGraph(db_, now, DoubleTime(now), std::move(g));
   1035   }
   1036 
   1037   Status WriteEvent(std::unique_ptr<Event> e) override {
   1038     return MigrateEvent(std::move(e));
   1039   }
   1040 
   1041   Status WriteHistogram(int64 global_step, Tensor t,
   1042                         const string& tag) override {
   1043     uint64 now = env_->NowMicros();
   1044     std::unique_ptr<Event> e{new Event};
   1045     e->set_step(global_step);
   1046     e->set_wall_time(DoubleTime(now));
   1047     TF_RETURN_IF_ERROR(
   1048         AddTensorAsHistogramToSummary(t, tag, e->mutable_summary()));
   1049     return MigrateEvent(std::move(e));
   1050   }
   1051 
   1052   Status WriteImage(int64 global_step, Tensor t, const string& tag,
   1053                     int max_images, Tensor bad_color) override {
   1054     uint64 now = env_->NowMicros();
   1055     std::unique_ptr<Event> e{new Event};
   1056     e->set_step(global_step);
   1057     e->set_wall_time(DoubleTime(now));
   1058     TF_RETURN_IF_ERROR(AddTensorAsImageToSummary(t, tag, max_images, bad_color,
   1059                                                  e->mutable_summary()));
   1060     return MigrateEvent(std::move(e));
   1061   }
   1062 
   1063   Status WriteAudio(int64 global_step, Tensor t, const string& tag,
   1064                     int max_outputs, float sample_rate) override {
   1065     uint64 now = env_->NowMicros();
   1066     std::unique_ptr<Event> e{new Event};
   1067     e->set_step(global_step);
   1068     e->set_wall_time(DoubleTime(now));
   1069     TF_RETURN_IF_ERROR(AddTensorAsAudioToSummary(
   1070         t, tag, max_outputs, sample_rate, e->mutable_summary()));
   1071     return MigrateEvent(std::move(e));
   1072   }
   1073 
   1074   string DebugString() override { return "SummaryDbWriter"; }
   1075 
   1076  private:
   1077   Status Write(int64 step, const Tensor& t, const string& tag,
   1078                const SummaryMetadata& metadata) {
   1079     uint64 now = env_->NowMicros();
   1080     double computed_time = DoubleTime(now);
   1081     int64 tag_id;
   1082     TF_RETURN_IF_ERROR(
   1083         meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata));
   1084     TF_RETURN_WITH_CONTEXT_IF_ERROR(
   1085         run_.Append(db_, tag_id, step, now, computed_time, t,
   1086                     GetSlots(t, metadata)),
   1087         meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(),
   1088         "/", tag, "@", step);
   1089     return Status::OK();
   1090   }
   1091 
   1092   Status MigrateEvent(std::unique_ptr<Event> e) {
   1093     switch (e->what_case()) {
   1094       case Event::WhatCase::kSummary: {
   1095         uint64 now = env_->NowMicros();
   1096         auto summaries = e->mutable_summary();
   1097         for (int i = 0; i < summaries->value_size(); ++i) {
   1098           Summary::Value* value = summaries->mutable_value(i);
   1099           TF_RETURN_WITH_CONTEXT_IF_ERROR(
   1100               MigrateSummary(e.get(), value, now), meta_.user_name(), "/",
   1101               meta_.experiment_name(), "/", meta_.run_name(), "/", value->tag(),
   1102               "@", e->step());
   1103         }
   1104         break;
   1105       }
   1106       case Event::WhatCase::kGraphDef:
   1107         TF_RETURN_WITH_CONTEXT_IF_ERROR(
   1108             MigrateGraph(e.get(), e->graph_def()), meta_.user_name(), "/",
   1109             meta_.experiment_name(), "/", meta_.run_name(), "/__graph__@",
   1110             e->step());
   1111         break;
   1112       default:
   1113         // TODO(@jart): Handle other stuff.
   1114         break;
   1115     }
   1116     return Status::OK();
   1117   }
   1118 
   1119   Status MigrateGraph(const Event* e, const string& graph_def) {
   1120     uint64 now = env_->NowMicros();
   1121     std::unique_ptr<GraphDef> graph{new GraphDef};
   1122     if (!ParseProtoUnlimited(graph.get(), graph_def)) {
   1123       return errors::InvalidArgument("bad proto");
   1124     }
   1125     return meta_.SetGraph(db_, now, e->wall_time(), std::move(graph));
   1126   }
   1127 
   1128   Status MigrateSummary(const Event* e, Summary::Value* s, uint64 now) {
   1129     switch (s->value_case()) {
   1130       case Summary::Value::ValueCase::kTensor:
   1131         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateTensor(e, s, now), "tensor");
   1132         break;
   1133       case Summary::Value::ValueCase::kSimpleValue:
   1134         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateScalar(e, s, now), "scalar");
   1135         break;
   1136       case Summary::Value::ValueCase::kHisto:
   1137         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateHistogram(e, s, now), "histo");
   1138         break;
   1139       case Summary::Value::ValueCase::kImage:
   1140         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateImage(e, s, now), "image");
   1141         break;
   1142       case Summary::Value::ValueCase::kAudio:
   1143         TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateAudio(e, s, now), "audio");
   1144         break;
   1145       default:
   1146         break;
   1147     }
   1148     return Status::OK();
   1149   }
   1150 
   1151   Status MigrateTensor(const Event* e, Summary::Value* s, uint64 now) {
   1152     Tensor t;
   1153     if (!t.FromProto(s->tensor())) return errors::InvalidArgument("bad proto");
   1154     TF_RETURN_IF_ERROR(CheckSupportedType(t));
   1155     int64 tag_id;
   1156     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
   1157                                       &tag_id, s->metadata()));
   1158     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t,
   1159                        GetSlots(t, s->metadata()));
   1160   }
   1161 
   1162   // TODO(jart): Refactor Summary -> Tensor logic into separate file.
   1163 
   1164   Status MigrateScalar(const Event* e, Summary::Value* s, uint64 now) {
   1165     // See tensorboard/plugins/scalar/summary.py and data_compat.py
   1166     Tensor t{DT_FLOAT, {}};
   1167     t.scalar<float>()() = s->simple_value();
   1168     int64 tag_id;
   1169     PatchPluginName(s->mutable_metadata(), kScalarPluginName);
   1170     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
   1171                                       &tag_id, s->metadata()));
   1172     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
   1173                        std::move(t), kScalarSlots);
   1174   }
   1175 
   1176   Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) {
   1177     const HistogramProto& histo = s->histo();
   1178     int k = histo.bucket_size();
   1179     if (k != histo.bucket_limit_size()) {
   1180       return errors::InvalidArgument("size mismatch");
   1181     }
   1182     // See tensorboard/plugins/histogram/summary.py and data_compat.py
   1183     Tensor t{DT_DOUBLE, {k, 3}};
   1184     auto data = t.flat<double>();
   1185     for (int i = 0; i < k; ++i) {
   1186       double left_edge = ((i - 1 >= 0) ? histo.bucket_limit(i - 1)
   1187                                        : std::numeric_limits<double>::min());
   1188       double right_edge = ((i + 1 < k) ? histo.bucket_limit(i + 1)
   1189                                        : std::numeric_limits<double>::max());
   1190       data(i + 0) = left_edge;
   1191       data(i + 1) = right_edge;
   1192       data(i + 2) = histo.bucket(i);
   1193     }
   1194     int64 tag_id;
   1195     PatchPluginName(s->mutable_metadata(), kHistogramPluginName);
   1196     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
   1197                                       &tag_id, s->metadata()));
   1198     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
   1199                        std::move(t), kHistogramSlots);
   1200   }
   1201 
   1202   Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) {
   1203     // See tensorboard/plugins/image/summary.py and data_compat.py
   1204     Tensor t{DT_STRING, {3}};
   1205     auto img = s->mutable_image();
   1206     t.flat<string>()(0) = strings::StrCat(img->width());
   1207     t.flat<string>()(1) = strings::StrCat(img->height());
   1208     t.flat<string>()(2) = std::move(*img->mutable_encoded_image_string());
   1209     int64 tag_id;
   1210     PatchPluginName(s->mutable_metadata(), kImagePluginName);
   1211     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
   1212                                       &tag_id, s->metadata()));
   1213     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
   1214                        std::move(t), kImageSlots);
   1215   }
   1216 
   1217   Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) {
   1218     // See tensorboard/plugins/audio/summary.py and data_compat.py
   1219     Tensor t{DT_STRING, {1, 2}};
   1220     auto wav = s->mutable_audio();
   1221     t.flat<string>()(0) = std::move(*wav->mutable_encoded_audio_string());
   1222     t.flat<string>()(1) = "";
   1223     int64 tag_id;
   1224     PatchPluginName(s->mutable_metadata(), kAudioPluginName);
   1225     TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
   1226                                       &tag_id, s->metadata()));
   1227     return run_.Append(db_, tag_id, e->step(), now, e->wall_time(),
   1228                        std::move(t), kAudioSlots);
   1229   }
   1230 
   1231   Env* const env_;
   1232   Sqlite* const db_;
   1233   IdAllocator ids_;
   1234   RunMetadata meta_;
   1235   RunWriter run_;
   1236 };
   1237 
   1238 }  // namespace
   1239 
   1240 Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name,
   1241                              const string& run_name, const string& user_name,
   1242                              Env* env, SummaryWriterInterface** result) {
   1243   *result = new SummaryDbWriter(env, db, experiment_name, run_name, user_name);
   1244   return Status::OK();
   1245 }
   1246 
   1247 }  // namespace tensorflow
   1248