Home | History | Annotate | Download | only in summary
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
#include "tensorflow/core/summary/summary_db_writer.h"

#include <cstring>
#include <limits>
#include <memory>

#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/db/sqlite.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/summary/schema.h"
#include "tensorflow/core/util/event.pb.h"
     28 
     29 namespace tensorflow {
     30 namespace {
     31 
     32 Tensor MakeScalarInt64(int64 x) {
     33   Tensor t(DT_INT64, TensorShape({}));
     34   t.scalar<int64>()() = x;
     35   return t;
     36 }
     37 
     38 class FakeClockEnv : public EnvWrapper {
     39  public:
     40   FakeClockEnv() : EnvWrapper(Env::Default()), current_millis_(0) {}
     41   void AdvanceByMillis(const uint64 millis) { current_millis_ += millis; }
     42   uint64 NowMicros() override { return current_millis_ * 1000; }
     43   uint64 NowSeconds() override { return current_millis_ * 1000; }
     44 
     45  private:
     46   uint64 current_millis_;
     47 };
     48 
     49 class SummaryDbWriterTest : public ::testing::Test {
     50  protected:
     51   void SetUp() override {
     52     TF_ASSERT_OK(Sqlite::Open(":memory:", SQLITE_OPEN_READWRITE, &db_));
     53     TF_ASSERT_OK(SetupTensorboardSqliteDb(db_));
     54   }
     55 
     56   void TearDown() override {
     57     if (writer_ != nullptr) {
     58       writer_->Unref();
     59       writer_ = nullptr;
     60     }
     61     db_->Unref();
     62     db_ = nullptr;
     63   }
     64 
     65   int64 QueryInt(const string& sql) {
     66     SqliteStatement stmt = db_->PrepareOrDie(sql);
     67     bool is_done;
     68     Status s = stmt.Step(&is_done);
     69     if (!s.ok() || is_done) {
     70       LOG(ERROR) << s << " due to " << sql;
     71       return -1;
     72     }
     73     return stmt.ColumnInt(0);
     74   }
     75 
     76   double QueryDouble(const string& sql) {
     77     SqliteStatement stmt = db_->PrepareOrDie(sql);
     78     bool is_done;
     79     Status s = stmt.Step(&is_done);
     80     if (!s.ok() || is_done) {
     81       LOG(ERROR) << s << " due to " << sql;
     82       return -1;
     83     }
     84     return stmt.ColumnDouble(0);
     85   }
     86 
     87   string QueryString(const string& sql) {
     88     SqliteStatement stmt = db_->PrepareOrDie(sql);
     89     bool is_done;
     90     Status s = stmt.Step(&is_done);
     91     if (!s.ok() || is_done) {
     92       LOG(ERROR) << s << " due to " << sql;
     93       return "MISSINGNO";
     94     }
     95     return stmt.ColumnString(0);
     96   }
     97 
     98   FakeClockEnv env_;
     99   Sqlite* db_ = nullptr;
    100   SummaryWriterInterface* writer_ = nullptr;
    101 };
    102 
    103 TEST_F(SummaryDbWriterTest, WriteHistogram_VerifyTensorValues) {
    104   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "histtest", "test1", "user1", &env_,
    105                                      &writer_));
    106   int step = 0;
    107   std::unique_ptr<Event> e{new Event};
    108   e->set_step(step);
    109   e->set_wall_time(123);
    110   Summary::Value* s = e->mutable_summary()->add_value();
    111   s->set_tag("normal/myhisto");
    112 
    113   double dummy_value = 10.123;
    114   HistogramProto* proto = s->mutable_histo();
    115   proto->Clear();
    116   proto->set_min(dummy_value);
    117   proto->set_max(dummy_value);
    118   proto->set_num(dummy_value);
    119   proto->set_sum(dummy_value);
    120   proto->set_sum_squares(dummy_value);
    121 
    122   int size = 3;
    123   double bucket_limits[] = {-30.5, -10.5, -5.5};
    124   double bucket[] = {-10, 10, 20};
    125   for (int i = 0; i < size; i++) {
    126     proto->add_bucket_limit(bucket_limits[i]);
    127     proto->add_bucket(bucket[i]);
    128   }
    129   TF_ASSERT_OK(writer_->WriteEvent(std::move(e)));
    130   TF_ASSERT_OK(writer_->Flush());
    131   writer_->Unref();
    132   writer_ = nullptr;
    133 
    134   // TODO(nickfelt): implement QueryTensor() to encapsulate this
    135   // Verify the data
    136   string result = QueryString("SELECT data FROM Tensors");
    137   const double* val = reinterpret_cast<const double*>(result.data());
    138   double histarray[] = {std::numeric_limits<double>::min(),
    139                         -30.5,
    140                         -10,
    141                         -30.5,
    142                         -10.5,
    143                         10,
    144                         -10.5,
    145                         -5.5,
    146                         20};
    147   int histarray_size = 9;
    148   for (int i = 0; i < histarray_size; i++) {
    149     EXPECT_EQ(histarray[i], val[i]);
    150   }
    151 }
    152 
    153 TEST_F(SummaryDbWriterTest, NothingWritten_NoRowsCreated) {
    154   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_,
    155                                      &writer_));
    156   TF_ASSERT_OK(writer_->Flush());
    157   writer_->Unref();
    158   writer_ = nullptr;
    159   EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Ids"));
    160   EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users"));
    161   EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
    162   EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs"));
    163   EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Tags"));
    164   EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
    165 }
    166 
    167 TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) {
    168   SummaryMetadata metadata;
    169   metadata.set_display_name("display_name");
    170   metadata.set_summary_description("description");
    171   metadata.mutable_plugin_data()->set_plugin_name("plugin_name");
    172   metadata.mutable_plugin_data()->set_content("plugin_data");
    173   SummaryMetadata metadata_nope;
    174   metadata_nope.set_display_name("nope");
    175   metadata_nope.set_summary_description("nope");
    176   metadata_nope.mutable_plugin_data()->set_plugin_name("nope");
    177   metadata_nope.mutable_plugin_data()->set_content("nope");
    178   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_,
    179                                      &writer_));
    180   env_.AdvanceByMillis(23);
    181   TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy",
    182                                     metadata.SerializeAsString()));
    183   env_.AdvanceByMillis(23);
    184   TF_ASSERT_OK(writer_->WriteTensor(2, MakeScalarInt64(314LL), "taggy",
    185                                     metadata_nope.SerializeAsString()));
    186   TF_ASSERT_OK(writer_->Flush());
    187 
    188   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Users"));
    189   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
    190   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs"));
    191   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
    192   ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
    193 
    194   int64 user_id = QueryInt("SELECT user_id FROM Users");
    195   int64 experiment_id = QueryInt("SELECT experiment_id FROM Experiments");
    196   int64 run_id = QueryInt("SELECT run_id FROM Runs");
    197   int64 tag_id = QueryInt("SELECT tag_id FROM Tags");
    198   EXPECT_LT(0LL, user_id);
    199   EXPECT_LT(0LL, experiment_id);
    200   EXPECT_LT(0LL, run_id);
    201   EXPECT_LT(0LL, tag_id);
    202 
    203   EXPECT_EQ("jart", QueryString("SELECT user_name FROM Users"));
    204   EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Users"));
    205 
    206   EXPECT_EQ(user_id, QueryInt("SELECT user_id FROM Experiments"));
    207   EXPECT_EQ("mad-science",
    208             QueryString("SELECT experiment_name FROM Experiments"));
    209   EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Experiments"));
    210 
    211   EXPECT_EQ(experiment_id, QueryInt("SELECT experiment_id FROM Runs"));
    212   EXPECT_EQ("train", QueryString("SELECT run_name FROM Runs"));
    213   EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Runs"));
    214 
    215   EXPECT_EQ(run_id, QueryInt("SELECT run_id FROM Tags"));
    216   EXPECT_EQ("taggy", QueryString("SELECT tag_name FROM Tags"));
    217   EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Tags"));
    218 
    219   EXPECT_EQ("display_name", QueryString("SELECT display_name FROM Tags"));
    220   EXPECT_EQ("plugin_name", QueryString("SELECT plugin_name FROM Tags"));
    221   EXPECT_EQ("plugin_data", QueryString("SELECT plugin_data FROM Tags"));
    222   EXPECT_EQ("description", QueryString("SELECT description FROM Descriptions"));
    223 
    224   EXPECT_EQ(tag_id, QueryInt("SELECT series FROM Tensors WHERE step = 1"));
    225   EXPECT_EQ(0.023,
    226             QueryDouble("SELECT computed_time FROM Tensors WHERE step = 1"));
    227 
    228   EXPECT_EQ(tag_id, QueryInt("SELECT series FROM Tensors WHERE step = 2"));
    229   EXPECT_EQ(0.046,
    230             QueryDouble("SELECT computed_time FROM Tensors WHERE step = 2"));
    231 }
    232 
    233 TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) {
    234   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
    235   TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy", ""));
    236   TF_ASSERT_OK(writer_->Flush());
    237   ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users"));
    238   ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
    239   ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs"));
    240   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
    241   ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
    242 }
    243 
    244 TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
    245   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
    246   std::unique_ptr<Event> e{new Event};
    247   e->set_step(7);
    248   e->set_wall_time(123.456);
    249   Summary::Value* s = e->mutable_summary()->add_value();
    250   s->set_tag("");
    251   s->set_simple_value(3.14f);
    252   s = e->mutable_summary()->add_value();
    253   s->set_tag("");
    254   s->set_simple_value(1.61f);
    255   TF_ASSERT_OK(writer_->WriteEvent(std::move(e)));
    256   TF_ASSERT_OK(writer_->Flush());
    257   ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags"));
    258   ASSERT_EQ(2000LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
    259   int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = ''");
    260   int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = ''");
    261   EXPECT_GT(tag1_id, 0LL);
    262   EXPECT_GT(tag2_id, 0LL);
    263   EXPECT_EQ(123.456, QueryDouble(strings::StrCat(
    264                          "SELECT computed_time FROM Tensors WHERE series = ",
    265                          tag1_id, " AND step = 7")));
    266   EXPECT_EQ(123.456, QueryDouble(strings::StrCat(
    267                          "SELECT computed_time FROM Tensors WHERE series = ",
    268                          tag2_id, " AND step = 7")));
    269 }
    270 
    271 TEST_F(SummaryDbWriterTest, WriteGraph) {
    272   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "R", "", &env_, &writer_));
    273   env_.AdvanceByMillis(23);
    274   GraphDef graph;
    275   graph.mutable_library()->add_gradient()->set_function_name("funk");
    276   NodeDef* node = graph.add_node();
    277   node->set_name("x");
    278   node->set_op("Placeholder");
    279   node = graph.add_node();
    280   node->set_name("y");
    281   node->set_op("Placeholder");
    282   node = graph.add_node();
    283   node->set_name("z");
    284   node->set_op("Love");
    285   node = graph.add_node();
    286   node->set_name("+");
    287   node->set_op("Add");
    288   node->add_input("x");
    289   node->add_input("y");
    290   node->add_input("^z");
    291   node->set_device("tpu/lol");
    292   std::unique_ptr<Event> e{new Event};
    293   graph.SerializeToString(e->mutable_graph_def());
    294   TF_ASSERT_OK(writer_->WriteEvent(std::move(e)));
    295   TF_ASSERT_OK(writer_->Flush());
    296   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs"));
    297   ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Graphs"));
    298   ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Nodes"));
    299   ASSERT_EQ(3LL, QueryInt("SELECT COUNT(*) FROM NodeInputs"));
    300 
    301   ASSERT_EQ(QueryInt("SELECT run_id FROM Runs"),
    302             QueryInt("SELECT run_id FROM Graphs"));
    303 
    304   int64 graph_id = QueryInt("SELECT graph_id FROM Graphs");
    305   EXPECT_GT(graph_id, 0LL);
    306   EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Graphs"));
    307 
    308   GraphDef graph2;
    309   graph2.ParseFromString(QueryString("SELECT graph_def FROM Graphs"));
    310   EXPECT_EQ(0, graph2.node_size());
    311   EXPECT_EQ("funk", graph2.library().gradient(0).function_name());
    312 
    313   EXPECT_EQ("x", QueryString("SELECT node_name FROM Nodes WHERE node_id = 0"));
    314   EXPECT_EQ("y", QueryString("SELECT node_name FROM Nodes WHERE node_id = 1"));
    315   EXPECT_EQ("z", QueryString("SELECT node_name FROM Nodes WHERE node_id = 2"));
    316   EXPECT_EQ("+", QueryString("SELECT node_name FROM Nodes WHERE node_id = 3"));
    317 
    318   EXPECT_EQ("Placeholder",
    319             QueryString("SELECT op FROM Nodes WHERE node_id = 0"));
    320   EXPECT_EQ("Placeholder",
    321             QueryString("SELECT op FROM Nodes WHERE node_id = 1"));
    322   EXPECT_EQ("Love", QueryString("SELECT op FROM Nodes WHERE node_id = 2"));
    323   EXPECT_EQ("Add", QueryString("SELECT op FROM Nodes WHERE node_id = 3"));
    324 
    325   EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 0"));
    326   EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 1"));
    327   EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 2"));
    328   EXPECT_EQ("tpu/lol",
    329             QueryString("SELECT device FROM Nodes WHERE node_id = 3"));
    330 
    331   EXPECT_EQ(graph_id,
    332             QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 0"));
    333   EXPECT_EQ(graph_id,
    334             QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 1"));
    335   EXPECT_EQ(graph_id,
    336             QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 2"));
    337 
    338   EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 0"));
    339   EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 1"));
    340   EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 2"));
    341 
    342   EXPECT_EQ(0LL,
    343             QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 0"));
    344   EXPECT_EQ(1LL,
    345             QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 1"));
    346   EXPECT_EQ(2LL,
    347             QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 2"));
    348 
    349   EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 0"));
    350   EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 1"));
    351   EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2"));
    352 }
    353 
    354 TEST_F(SummaryDbWriterTest, UsesIdsTable) {
    355   SummaryMetadata metadata;
    356   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_,
    357                                      &writer_));
    358   env_.AdvanceByMillis(23);
    359   TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy",
    360                                     metadata.SerializeAsString()));
    361   TF_ASSERT_OK(writer_->Flush());
    362   ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Ids"));
    363   EXPECT_EQ(4LL, QueryInt(strings::StrCat(
    364                      "SELECT COUNT(*) FROM Ids WHERE id IN (",
    365                      QueryInt("SELECT user_id FROM Users"), ", ",
    366                      QueryInt("SELECT experiment_id FROM Experiments"), ", ",
    367                      QueryInt("SELECT run_id FROM Runs"), ", ",
    368                      QueryInt("SELECT tag_id FROM Tags"), ")")));
    369 }
    370 
    371 TEST_F(SummaryDbWriterTest, SetsRunFinishedTime) {
    372   SummaryMetadata metadata;
    373   TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_,
    374                                      &writer_));
    375   env_.AdvanceByMillis(23);
    376   TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy",
    377                                     metadata.SerializeAsString()));
    378   TF_ASSERT_OK(writer_->Flush());
    379   ASSERT_EQ(0.023, QueryDouble("SELECT started_time FROM Runs"));
    380   ASSERT_EQ(0.0, QueryDouble("SELECT finished_time FROM Runs"));
    381   env_.AdvanceByMillis(23);
    382   writer_->Unref();
    383   writer_ = nullptr;
    384   ASSERT_EQ(0.023, QueryDouble("SELECT started_time FROM Runs"));
    385   ASSERT_EQ(0.046, QueryDouble("SELECT finished_time FROM Runs"));
    386 }
    387 
    388 }  // namespace
    389 }  // namespace tensorflow
    390