Home | History | Annotate | Download | only in db
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
     16 
     17 #include "tensorflow/contrib/tensorboard/db/schema.h"
     18 #include "tensorflow/core/framework/function.pb.h"
     19 #include "tensorflow/core/framework/graph.pb.h"
     20 #include "tensorflow/core/framework/node_def.pb.h"
     21 #include "tensorflow/core/framework/summary.pb.h"
     22 #include "tensorflow/core/lib/core/status_test_util.h"
     23 #include "tensorflow/core/lib/db/sqlite.h"
     24 #include "tensorflow/core/lib/strings/strcat.h"
     25 #include "tensorflow/core/platform/env.h"
     26 #include "tensorflow/core/platform/test.h"
     27 #include "tensorflow/core/util/event.pb.h"
     28 
     29 namespace tensorflow {
     30 namespace {
     31 
     32 Tensor MakeScalarInt64(int64 x) {
     33   Tensor t(DT_INT64, TensorShape({}));
     34   t.scalar<int64>()() = x;
     35   return t;
     36 }
     37 
     38 class FakeClockEnv : public EnvWrapper {
     39  public:
     40   FakeClockEnv() : EnvWrapper(Env::Default()), current_millis_(0) {}
     41   void AdvanceByMillis(const uint64 millis) { current_millis_ += millis; }
     42   uint64 NowMicros() override { return current_millis_ * 1000; }
     43   uint64 NowSeconds() override { return current_millis_ * 1000; }
     44 
     45  private:
     46   uint64 current_millis_;
     47 };
     48 
     49 class SummaryDbWriterTest : public ::testing::Test {
     50  protected:
     51   void SetUp() override {
     52     TF_ASSERT_OK(Sqlite::Open(":memory:", SQLITE_OPEN_READWRITE, &db_));
     53     TF_ASSERT_OK(SetupTensorboardSqliteDb(db_));
     54   }
     55 
     56   void TearDown() override {
     57     if (writer_ != nullptr) {
     58       writer_->Unref();
     59       writer_ = nullptr;
     60     }
     61     db_->Unref();
     62     db_ = nullptr;
     63   }
     64 
     65   int64 QueryInt(const string& sql) {
     66     SqliteStatement stmt = db_->PrepareOrDie(sql);
     67     bool is_done;
     68     Status s = stmt.Step(&is_done);
     69     if (!s.ok() || is_done) {
     70       LOG(ERROR) << s << " due to " << sql;
     71       return -1;
     72     }
     73     return stmt.ColumnInt(0);
     74   }
     75 
     76   double QueryDouble(const string& sql) {
     77     SqliteStatement stmt = db_->PrepareOrDie(sql);
     78     bool is_done;
     79     Status s = stmt.Step(&is_done);
     80     if (!s.ok() || is_done) {
     81       LOG(ERROR) << s << " due to " << sql;
     82       return -1;
     83     }
     84     return stmt.ColumnDouble(0);
     85   }
     86 
     87   string QueryString(const string& sql) {
     88     SqliteStatement stmt = db_->PrepareOrDie(sql);
     89     bool is_done;
     90     Status s = stmt.Step(&is_done);
     91     if (!s.ok() || is_done) {
     92       LOG(ERROR) << s << " due to " << sql;
     93       return "MISSINGNO";
     94     }
     95     return stmt.ColumnString(0);
     96   }
     97 
     98   FakeClockEnv env_;
     99   Sqlite* db_ = nullptr;
    100   SummaryWriterInterface* writer_ = nullptr;
    101 };
    102 
    103 TEST_F(SummaryDbWriterTest, NothingWritten_NoRowsCreated) {
    104   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_,
    105                                      &writer_));
    106   TF_ASSERT_OK(writer_->Flush());
    107   writer_->Unref();
    108   writer_ = nullptr;
    109   EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Ids"));
    110   EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users"));
    111   EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
    112   EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs"));
    113   EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Tags"));
    114   EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
    115 }
    116 
    117 TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) {
    118   SummaryMetadata metadata;
    119   metadata.set_display_name("display_name");
    120   metadata.set_summary_description("description");
    121   metadata.mutable_plugin_data()->set_plugin_name("plugin_name");
    122   metadata.mutable_plugin_data()->set_content("plugin_data");
    123   SummaryMetadata metadata_nope;
    124   metadata_nope.set_display_name("nope");
    125   metadata_nope.set_summary_description("nope");
    126   metadata_nope.mutable_plugin_data()->set_plugin_name("nope");
    127   metadata_nope.mutable_plugin_data()->set_content("nope");
    128   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_,
    129                                      &writer_));
    130   env_.AdvanceByMillis(23);
    131   TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy",
    132                                     metadata.SerializeAsString()));
    133   env_.AdvanceByMillis(23);
    134   TF_ASSERT_OK(writer_->WriteTensor(2, MakeScalarInt64(314LL), "taggy",
    135                                     metadata_nope.SerializeAsString()));
    136   TF_ASSERT_OK(writer_->Flush());
    137 
    138   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Users"));
    139   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
    140   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs"));
    141   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
    142   ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
    143 
    144   int64 user_id = QueryInt("SELECT user_id FROM Users");
    145   int64 experiment_id = QueryInt("SELECT experiment_id FROM Experiments");
    146   int64 run_id = QueryInt("SELECT run_id FROM Runs");
    147   int64 tag_id = QueryInt("SELECT tag_id FROM Tags");
    148   EXPECT_LT(0LL, user_id);
    149   EXPECT_LT(0LL, experiment_id);
    150   EXPECT_LT(0LL, run_id);
    151   EXPECT_LT(0LL, tag_id);
    152 
    153   EXPECT_EQ("jart", QueryString("SELECT user_name FROM Users"));
    154   EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Users"));
    155 
    156   EXPECT_EQ(user_id, QueryInt("SELECT user_id FROM Experiments"));
    157   EXPECT_EQ("mad-science",
    158             QueryString("SELECT experiment_name FROM Experiments"));
    159   EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Experiments"));
    160 
    161   EXPECT_EQ(experiment_id, QueryInt("SELECT experiment_id FROM Runs"));
    162   EXPECT_EQ("train", QueryString("SELECT run_name FROM Runs"));
    163   EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Runs"));
    164 
    165   EXPECT_EQ(run_id, QueryInt("SELECT run_id FROM Tags"));
    166   EXPECT_EQ("taggy", QueryString("SELECT tag_name FROM Tags"));
    167   EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Tags"));
    168 
    169   EXPECT_EQ("display_name", QueryString("SELECT display_name FROM Tags"));
    170   EXPECT_EQ("plugin_name", QueryString("SELECT plugin_name FROM Tags"));
    171   EXPECT_EQ("plugin_data", QueryString("SELECT plugin_data FROM Tags"));
    172   EXPECT_EQ("description", QueryString("SELECT description FROM Descriptions"));
    173 
    174   EXPECT_EQ(tag_id, QueryInt("SELECT series FROM Tensors WHERE step = 1"));
    175   EXPECT_EQ(0.023,
    176             QueryDouble("SELECT computed_time FROM Tensors WHERE step = 1"));
    177 
    178   EXPECT_EQ(tag_id, QueryInt("SELECT series FROM Tensors WHERE step = 2"));
    179   EXPECT_EQ(0.046,
    180             QueryDouble("SELECT computed_time FROM Tensors WHERE step = 2"));
    181 }
    182 
    183 TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) {
    184   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
    185   TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy", ""));
    186   TF_ASSERT_OK(writer_->Flush());
    187   ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users"));
    188   ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
    189   ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs"));
    190   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
    191   ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
    192 }
    193 
    194 TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
    195   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
    196   std::unique_ptr<Event> e{new Event};
    197   e->set_step(7);
    198   e->set_wall_time(123.456);
    199   Summary::Value* s = e->mutable_summary()->add_value();
    200   s->set_tag("");
    201   s->set_simple_value(3.14f);
    202   s = e->mutable_summary()->add_value();
    203   s->set_tag("");
    204   s->set_simple_value(1.61f);
    205   TF_ASSERT_OK(writer_->WriteEvent(std::move(e)));
    206   TF_ASSERT_OK(writer_->Flush());
    207   ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags"));
    208   ASSERT_EQ(20000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
    209   int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = ''");
    210   int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = ''");
    211   EXPECT_GT(tag1_id, 0LL);
    212   EXPECT_GT(tag2_id, 0LL);
    213   EXPECT_EQ(123.456, QueryDouble(strings::StrCat(
    214                          "SELECT computed_time FROM Tensors WHERE series = ",
    215                          tag1_id, " AND step = 7")));
    216   EXPECT_EQ(123.456, QueryDouble(strings::StrCat(
    217                          "SELECT computed_time FROM Tensors WHERE series = ",
    218                          tag2_id, " AND step = 7")));
    219 }
    220 
    221 TEST_F(SummaryDbWriterTest, WriteGraph) {
    222   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "R", "", &env_, &writer_));
    223   env_.AdvanceByMillis(23);
    224   GraphDef graph;
    225   graph.mutable_library()->add_gradient()->set_function_name("funk");
    226   NodeDef* node = graph.add_node();
    227   node->set_name("x");
    228   node->set_op("Placeholder");
    229   node = graph.add_node();
    230   node->set_name("y");
    231   node->set_op("Placeholder");
    232   node = graph.add_node();
    233   node->set_name("z");
    234   node->set_op("Love");
    235   node = graph.add_node();
    236   node->set_name("+");
    237   node->set_op("Add");
    238   node->add_input("x");
    239   node->add_input("y");
    240   node->add_input("^z");
    241   node->set_device("tpu/lol");
    242   std::unique_ptr<Event> e{new Event};
    243   graph.SerializeToString(e->mutable_graph_def());
    244   TF_ASSERT_OK(writer_->WriteEvent(std::move(e)));
    245   TF_ASSERT_OK(writer_->Flush());
    246   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs"));
    247   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Graphs"));
    248   ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Nodes"));
    249   ASSERT_EQ(3LL, QueryInt("SELECT COUNT(*) FROM NodeInputs"));
    250 
    251   ASSERT_EQ(QueryInt("SELECT run_id FROM Runs"),
    252             QueryInt("SELECT run_id FROM Graphs"));
    253 
    254   int64 graph_id = QueryInt("SELECT graph_id FROM Graphs");
    255   EXPECT_GT(graph_id, 0LL);
    256   EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Graphs"));
    257 
    258   GraphDef graph2;
    259   graph2.ParseFromString(QueryString("SELECT graph_def FROM Graphs"));
    260   EXPECT_EQ(0, graph2.node_size());
    261   EXPECT_EQ("funk", graph2.library().gradient(0).function_name());
    262 
    263   EXPECT_EQ("x", QueryString("SELECT node_name FROM Nodes WHERE node_id = 0"));
    264   EXPECT_EQ("y", QueryString("SELECT node_name FROM Nodes WHERE node_id = 1"));
    265   EXPECT_EQ("z", QueryString("SELECT node_name FROM Nodes WHERE node_id = 2"));
    266   EXPECT_EQ("+", QueryString("SELECT node_name FROM Nodes WHERE node_id = 3"));
    267 
    268   EXPECT_EQ("Placeholder",
    269             QueryString("SELECT op FROM Nodes WHERE node_id = 0"));
    270   EXPECT_EQ("Placeholder",
    271             QueryString("SELECT op FROM Nodes WHERE node_id = 1"));
    272   EXPECT_EQ("Love", QueryString("SELECT op FROM Nodes WHERE node_id = 2"));
    273   EXPECT_EQ("Add", QueryString("SELECT op FROM Nodes WHERE node_id = 3"));
    274 
    275   EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 0"));
    276   EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 1"));
    277   EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 2"));
    278   EXPECT_EQ("tpu/lol",
    279             QueryString("SELECT device FROM Nodes WHERE node_id = 3"));
    280 
    281   EXPECT_EQ(graph_id,
    282             QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 0"));
    283   EXPECT_EQ(graph_id,
    284             QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 1"));
    285   EXPECT_EQ(graph_id,
    286             QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 2"));
    287 
    288   EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 0"));
    289   EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 1"));
    290   EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 2"));
    291 
    292   EXPECT_EQ(0LL,
    293             QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 0"));
    294   EXPECT_EQ(1LL,
    295             QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 1"));
    296   EXPECT_EQ(2LL,
    297             QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 2"));
    298 
    299   EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 0"));
    300   EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 1"));
    301   EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2"));
    302 }
    303 
    304 TEST_F(SummaryDbWriterTest, UsesIdsTable) {
    305   SummaryMetadata metadata;
    306   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_,
    307                                      &writer_));
    308   env_.AdvanceByMillis(23);
    309   TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy",
    310                                     metadata.SerializeAsString()));
    311   TF_ASSERT_OK(writer_->Flush());
    312   ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Ids"));
    313   EXPECT_EQ(4LL, QueryInt(strings::StrCat(
    314                      "SELECT COUNT(*) FROM Ids WHERE id IN (",
    315                      QueryInt("SELECT user_id FROM Users"), ", ",
    316                      QueryInt("SELECT experiment_id FROM Experiments"), ", ",
    317                      QueryInt("SELECT run_id FROM Runs"), ", ",
    318                      QueryInt("SELECT tag_id FROM Tags"), ")")));
    319 }
    320 
    321 TEST_F(SummaryDbWriterTest, SetsRunFinishedTime) {
    322   SummaryMetadata metadata;
    323   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_,
    324                                      &writer_));
    325   env_.AdvanceByMillis(23);
    326   TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy",
    327                                     metadata.SerializeAsString()));
    328   TF_ASSERT_OK(writer_->Flush());
    329   ASSERT_EQ(0.023, QueryDouble("SELECT started_time FROM Runs"));
    330   ASSERT_EQ(0.0, QueryDouble("SELECT finished_time FROM Runs"));
    331   env_.AdvanceByMillis(23);
    332   writer_->Unref();
    333   writer_ = nullptr;
    334   ASSERT_EQ(0.023, QueryDouble("SELECT started_time FROM Runs"));
    335   ASSERT_EQ(0.046, QueryDouble("SELECT finished_time FROM Runs"));
    336 }
    337 
    338 }  // namespace
    339 }  // namespace tensorflow
    340