1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" 16 17 #include "tensorflow/contrib/tensorboard/db/schema.h" 18 #include "tensorflow/core/framework/function.pb.h" 19 #include "tensorflow/core/framework/graph.pb.h" 20 #include "tensorflow/core/framework/node_def.pb.h" 21 #include "tensorflow/core/framework/summary.pb.h" 22 #include "tensorflow/core/lib/core/status_test_util.h" 23 #include "tensorflow/core/lib/db/sqlite.h" 24 #include "tensorflow/core/lib/strings/strcat.h" 25 #include "tensorflow/core/platform/env.h" 26 #include "tensorflow/core/platform/test.h" 27 #include "tensorflow/core/util/event.pb.h" 28 29 namespace tensorflow { 30 namespace { 31 32 Tensor MakeScalarInt64(int64 x) { 33 Tensor t(DT_INT64, TensorShape({})); 34 t.scalar<int64>()() = x; 35 return t; 36 } 37 38 class FakeClockEnv : public EnvWrapper { 39 public: 40 FakeClockEnv() : EnvWrapper(Env::Default()), current_millis_(0) {} 41 void AdvanceByMillis(const uint64 millis) { current_millis_ += millis; } 42 uint64 NowMicros() override { return current_millis_ * 1000; } 43 uint64 NowSeconds() override { return current_millis_ * 1000; } 44 45 private: 46 uint64 current_millis_; 47 }; 48 49 class SummaryDbWriterTest : public ::testing::Test { 50 protected: 51 void SetUp() override { 52 TF_ASSERT_OK(Sqlite::Open(":memory:", SQLITE_OPEN_READWRITE, &db_)); 53 TF_ASSERT_OK(SetupTensorboardSqliteDb(db_)); 54 } 55 56 void TearDown() override { 57 if (writer_ != nullptr) { 58 writer_->Unref(); 59 writer_ = nullptr; 60 } 61 db_->Unref(); 62 db_ = nullptr; 63 } 64 65 int64 QueryInt(const string& sql) { 66 SqliteStatement stmt = db_->PrepareOrDie(sql); 67 bool is_done; 68 Status s = stmt.Step(&is_done); 69 if (!s.ok() || is_done) { 70 LOG(ERROR) << s << " due to " << sql; 71 return -1; 72 } 73 return stmt.ColumnInt(0); 74 } 75 76 double QueryDouble(const string& sql) { 77 SqliteStatement stmt = db_->PrepareOrDie(sql); 78 bool is_done; 79 Status s = stmt.Step(&is_done); 80 if (!s.ok() || is_done) { 81 LOG(ERROR) << s << " due to " << sql; 82 return -1; 83 } 84 return stmt.ColumnDouble(0); 85 } 86 87 string QueryString(const string& sql) { 88 SqliteStatement stmt = db_->PrepareOrDie(sql); 89 bool is_done; 90 Status s = stmt.Step(&is_done); 91 if (!s.ok() || is_done) { 92 LOG(ERROR) << s << " due to " << sql; 93 return "MISSINGNO"; 94 } 95 return stmt.ColumnString(0); 96 } 97 98 FakeClockEnv env_; 99 Sqlite* db_ = nullptr; 100 SummaryWriterInterface* writer_ = nullptr; 101 }; 102 103 TEST_F(SummaryDbWriterTest, NothingWritten_NoRowsCreated) { 104 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_, 105 &writer_)); 106 TF_ASSERT_OK(writer_->Flush()); 107 writer_->Unref(); 108 writer_ = nullptr; 109 EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Ids")); 110 EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users")); 111 EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments")); 112 EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs")); 113 EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Tags")); 114 EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Tensors")); 115 } 116 117 TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) { 118 SummaryMetadata metadata; 119 metadata.set_display_name("display_name"); 120 metadata.set_summary_description("description"); 121 metadata.mutable_plugin_data()->set_plugin_name("plugin_name"); 122 metadata.mutable_plugin_data()->set_content("plugin_data"); 123 SummaryMetadata metadata_nope; 124 metadata_nope.set_display_name("nope"); 125 metadata_nope.set_summary_description("nope"); 126 metadata_nope.mutable_plugin_data()->set_plugin_name("nope"); 127 metadata_nope.mutable_plugin_data()->set_content("nope"); 128 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_, 129 &writer_)); 130 env_.AdvanceByMillis(23); 131 TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy", 132 metadata.SerializeAsString())); 133 env_.AdvanceByMillis(23); 134 TF_ASSERT_OK(writer_->WriteTensor(2, MakeScalarInt64(314LL), "taggy", 135 metadata_nope.SerializeAsString())); 136 TF_ASSERT_OK(writer_->Flush()); 137 138 ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Users")); 139 ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Experiments")); 140 ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs")); 141 ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); 142 ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); 143 144 int64 user_id = QueryInt("SELECT user_id FROM Users"); 145 int64 experiment_id = QueryInt("SELECT experiment_id FROM Experiments"); 146 int64 run_id = QueryInt("SELECT run_id FROM Runs"); 147 int64 tag_id = QueryInt("SELECT tag_id FROM Tags"); 148 EXPECT_LT(0LL, user_id); 149 EXPECT_LT(0LL, experiment_id); 150 EXPECT_LT(0LL, run_id); 151 EXPECT_LT(0LL, tag_id); 152 153 EXPECT_EQ("jart", QueryString("SELECT user_name FROM Users")); 154 EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Users")); 155 156 EXPECT_EQ(user_id, QueryInt("SELECT user_id FROM Experiments")); 157 EXPECT_EQ("mad-science", 158 QueryString("SELECT experiment_name FROM Experiments")); 159 EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Experiments")); 160 161 EXPECT_EQ(experiment_id, QueryInt("SELECT experiment_id FROM Runs")); 162 EXPECT_EQ("train", QueryString("SELECT run_name FROM Runs")); 163 EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Runs")); 164 165 EXPECT_EQ(run_id, QueryInt("SELECT run_id FROM Tags")); 166 EXPECT_EQ("taggy", QueryString("SELECT tag_name FROM Tags")); 167 EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Tags")); 168 169 EXPECT_EQ("display_name", QueryString("SELECT display_name FROM Tags")); 170 EXPECT_EQ("plugin_name", QueryString("SELECT plugin_name FROM Tags")); 171 EXPECT_EQ("plugin_data", QueryString("SELECT plugin_data FROM Tags")); 172 EXPECT_EQ("description", QueryString("SELECT description FROM Descriptions")); 173 174 EXPECT_EQ(tag_id, QueryInt("SELECT series FROM Tensors WHERE step = 1")); 175 EXPECT_EQ(0.023, 176 QueryDouble("SELECT computed_time FROM Tensors WHERE step = 1")); 177 178 EXPECT_EQ(tag_id, QueryInt("SELECT series FROM Tensors WHERE step = 2")); 179 EXPECT_EQ(0.046, 180 QueryDouble("SELECT computed_time FROM Tensors WHERE step = 2")); 181 } 182 183 TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) { 184 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_)); 185 TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy", "")); 186 TF_ASSERT_OK(writer_->Flush()); 187 ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users")); 188 ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments")); 189 ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs")); 190 ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); 191 ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); 192 } 193 194 TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { 195 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_)); 196 std::unique_ptr<Event> e{new Event}; 197 e->set_step(7); 198 e->set_wall_time(123.456); 199 Summary::Value* s = e->mutable_summary()->add_value(); 200 s->set_tag(""); 201 s->set_simple_value(3.14f); 202 s = e->mutable_summary()->add_value(); 203 s->set_tag(""); 204 s->set_simple_value(1.61f); 205 TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); 206 TF_ASSERT_OK(writer_->Flush()); 207 ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags")); 208 ASSERT_EQ(20000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); 209 int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = ''"); 210 int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = ''"); 211 EXPECT_GT(tag1_id, 0LL); 212 EXPECT_GT(tag2_id, 0LL); 213 EXPECT_EQ(123.456, QueryDouble(strings::StrCat( 214 "SELECT computed_time FROM Tensors WHERE series = ", 215 tag1_id, " AND step = 7"))); 216 EXPECT_EQ(123.456, QueryDouble(strings::StrCat( 217 "SELECT computed_time FROM Tensors WHERE series = ", 218 tag2_id, " AND step = 7"))); 219 } 220 221 TEST_F(SummaryDbWriterTest, WriteGraph) { 222 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "R", "", &env_, &writer_)); 223 env_.AdvanceByMillis(23); 224 GraphDef graph; 225 graph.mutable_library()->add_gradient()->set_function_name("funk"); 226 NodeDef* node = graph.add_node(); 227 node->set_name("x"); 228 node->set_op("Placeholder"); 229 node = graph.add_node(); 230 node->set_name("y"); 231 node->set_op("Placeholder"); 232 node = graph.add_node(); 233 node->set_name("z"); 234 node->set_op("Love"); 235 node = graph.add_node(); 236 node->set_name("+"); 237 node->set_op("Add"); 238 node->add_input("x"); 239 node->add_input("y"); 240 node->add_input("^z"); 241 node->set_device("tpu/lol"); 242 std::unique_ptr<Event> e{new Event}; 243 graph.SerializeToString(e->mutable_graph_def()); 244 TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); 245 TF_ASSERT_OK(writer_->Flush()); 246 ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs")); 247 ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Graphs")); 248 ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Nodes")); 249 ASSERT_EQ(3LL, QueryInt("SELECT COUNT(*) FROM NodeInputs")); 250 251 ASSERT_EQ(QueryInt("SELECT run_id FROM Runs"), 252 QueryInt("SELECT run_id FROM Graphs")); 253 254 int64 graph_id = QueryInt("SELECT graph_id FROM Graphs"); 255 EXPECT_GT(graph_id, 0LL); 256 EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Graphs")); 257 258 GraphDef graph2; 259 graph2.ParseFromString(QueryString("SELECT graph_def FROM Graphs")); 260 EXPECT_EQ(0, graph2.node_size()); 261 EXPECT_EQ("funk", graph2.library().gradient(0).function_name()); 262 263 EXPECT_EQ("x", QueryString("SELECT node_name FROM Nodes WHERE node_id = 0")); 264 EXPECT_EQ("y", QueryString("SELECT node_name FROM Nodes WHERE node_id = 1")); 265 EXPECT_EQ("z", QueryString("SELECT node_name FROM Nodes WHERE node_id = 2")); 266 EXPECT_EQ("+", QueryString("SELECT node_name FROM Nodes WHERE node_id = 3")); 267 268 EXPECT_EQ("Placeholder", 269 QueryString("SELECT op FROM Nodes WHERE node_id = 0")); 270 EXPECT_EQ("Placeholder", 271 QueryString("SELECT op FROM Nodes WHERE node_id = 1")); 272 EXPECT_EQ("Love", QueryString("SELECT op FROM Nodes WHERE node_id = 2")); 273 EXPECT_EQ("Add", QueryString("SELECT op FROM Nodes WHERE node_id = 3")); 274 275 EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 0")); 276 EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 1")); 277 EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 2")); 278 EXPECT_EQ("tpu/lol", 279 QueryString("SELECT device FROM Nodes WHERE node_id = 3")); 280 281 EXPECT_EQ(graph_id, 282 QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 0")); 283 EXPECT_EQ(graph_id, 284 QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 1")); 285 EXPECT_EQ(graph_id, 286 QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 2")); 287 288 EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 0")); 289 EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 1")); 290 EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 2")); 291 292 EXPECT_EQ(0LL, 293 QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 0")); 294 EXPECT_EQ(1LL, 295 QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 1")); 296 EXPECT_EQ(2LL, 297 QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 2")); 298 299 EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 0")); 300 EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 1")); 301 EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2")); 302 } 303 304 TEST_F(SummaryDbWriterTest, UsesIdsTable) { 305 SummaryMetadata metadata; 306 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_, 307 &writer_)); 308 env_.AdvanceByMillis(23); 309 TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy", 310 metadata.SerializeAsString())); 311 TF_ASSERT_OK(writer_->Flush()); 312 ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Ids")); 313 EXPECT_EQ(4LL, QueryInt(strings::StrCat( 314 "SELECT COUNT(*) FROM Ids WHERE id IN (", 315 QueryInt("SELECT user_id FROM Users"), ", ", 316 QueryInt("SELECT experiment_id FROM Experiments"), ", ", 317 QueryInt("SELECT run_id FROM Runs"), ", ", 318 QueryInt("SELECT tag_id FROM Tags"), ")"))); 319 } 320 321 TEST_F(SummaryDbWriterTest, SetsRunFinishedTime) { 322 SummaryMetadata metadata; 323 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_, 324 &writer_)); 325 env_.AdvanceByMillis(23); 326 TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy", 327 metadata.SerializeAsString())); 328 TF_ASSERT_OK(writer_->Flush()); 329 ASSERT_EQ(0.023, QueryDouble("SELECT started_time FROM Runs")); 330 ASSERT_EQ(0.0, QueryDouble("SELECT finished_time FROM Runs")); 331 env_.AdvanceByMillis(23); 332 writer_->Unref(); 333 writer_ = nullptr; 334 ASSERT_EQ(0.023, QueryDouble("SELECT started_time FROM Runs")); 335 ASSERT_EQ(0.046, QueryDouble("SELECT finished_time FROM Runs")); 336 } 337 338 } // namespace 339 } // namespace tensorflow 340