1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/summary/summary_db_writer.h" 16 17 #include "tensorflow/core/summary/schema.h" 18 #include "tensorflow/core/framework/function.pb.h" 19 #include "tensorflow/core/framework/graph.pb.h" 20 #include "tensorflow/core/framework/node_def.pb.h" 21 #include "tensorflow/core/framework/summary.pb.h" 22 #include "tensorflow/core/lib/core/status_test_util.h" 23 #include "tensorflow/core/lib/db/sqlite.h" 24 #include "tensorflow/core/lib/strings/strcat.h" 25 #include "tensorflow/core/platform/env.h" 26 #include "tensorflow/core/platform/test.h" 27 #include "tensorflow/core/util/event.pb.h" 28 29 namespace tensorflow { 30 namespace { 31 32 Tensor MakeScalarInt64(int64 x) { 33 Tensor t(DT_INT64, TensorShape({})); 34 t.scalar<int64>()() = x; 35 return t; 36 } 37 38 class FakeClockEnv : public EnvWrapper { 39 public: 40 FakeClockEnv() : EnvWrapper(Env::Default()), current_millis_(0) {} 41 void AdvanceByMillis(const uint64 millis) { current_millis_ += millis; } 42 uint64 NowMicros() override { return current_millis_ * 1000; } 43 uint64 NowSeconds() override { return current_millis_ * 1000; } 44 45 private: 46 uint64 current_millis_; 47 }; 48 49 class SummaryDbWriterTest : public ::testing::Test { 50 protected: 51 void SetUp() override { 52 TF_ASSERT_OK(Sqlite::Open(":memory:", SQLITE_OPEN_READWRITE, &db_)); 53 TF_ASSERT_OK(SetupTensorboardSqliteDb(db_)); 54 } 55 56 void TearDown() override { 57 if (writer_ != nullptr) { 58 writer_->Unref(); 59 writer_ = nullptr; 60 } 61 db_->Unref(); 62 db_ = nullptr; 63 } 64 65 int64 QueryInt(const string& sql) { 66 SqliteStatement stmt = db_->PrepareOrDie(sql); 67 bool is_done; 68 Status s = stmt.Step(&is_done); 69 if (!s.ok() || is_done) { 70 LOG(ERROR) << s << " due to " << sql; 71 return -1; 72 } 73 return stmt.ColumnInt(0); 74 } 75 76 double QueryDouble(const string& sql) { 77 SqliteStatement stmt = db_->PrepareOrDie(sql); 78 bool is_done; 79 Status s = stmt.Step(&is_done); 80 if (!s.ok() || is_done) { 81 LOG(ERROR) << s << " due to " << sql; 82 return -1; 83 } 84 return stmt.ColumnDouble(0); 85 } 86 87 string QueryString(const string& sql) { 88 SqliteStatement stmt = db_->PrepareOrDie(sql); 89 bool is_done; 90 Status s = stmt.Step(&is_done); 91 if (!s.ok() || is_done) { 92 LOG(ERROR) << s << " due to " << sql; 93 return "MISSINGNO"; 94 } 95 return stmt.ColumnString(0); 96 } 97 98 FakeClockEnv env_; 99 Sqlite* db_ = nullptr; 100 SummaryWriterInterface* writer_ = nullptr; 101 }; 102 103 TEST_F(SummaryDbWriterTest, WriteHistogram_VerifyTensorValues) { 104 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "histtest", "test1", "user1", &env_, 105 &writer_)); 106 int step = 0; 107 std::unique_ptr<Event> e{new Event}; 108 e->set_step(step); 109 e->set_wall_time(123); 110 Summary::Value* s = e->mutable_summary()->add_value(); 111 s->set_tag("normal/myhisto"); 112 113 double dummy_value = 10.123; 114 HistogramProto* proto = s->mutable_histo(); 115 proto->Clear(); 116 proto->set_min(dummy_value); 117 proto->set_max(dummy_value); 118 proto->set_num(dummy_value); 119 proto->set_sum(dummy_value); 120 proto->set_sum_squares(dummy_value); 121 122 int size = 3; 123 double bucket_limits[] = {-30.5, -10.5, -5.5}; 124 double bucket[] = {-10, 10, 20}; 125 for (int i = 0; i < size; i++) { 126 proto->add_bucket_limit(bucket_limits[i]); 127 proto->add_bucket(bucket[i]); 128 } 129 TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); 130 TF_ASSERT_OK(writer_->Flush()); 131 writer_->Unref(); 132 writer_ = nullptr; 133 134 // TODO(nickfelt): implement QueryTensor() to encapsulate this 135 // Verify the data 136 string result = QueryString("SELECT data FROM Tensors"); 137 const double* val = reinterpret_cast<const double*>(result.data()); 138 double histarray[] = {std::numeric_limits<double>::min(), 139 -30.5, 140 -10, 141 -30.5, 142 -10.5, 143 10, 144 -10.5, 145 -5.5, 146 20}; 147 int histarray_size = 9; 148 for (int i = 0; i < histarray_size; i++) { 149 EXPECT_EQ(histarray[i], val[i]); 150 } 151 } 152 153 TEST_F(SummaryDbWriterTest, NothingWritten_NoRowsCreated) { 154 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_, 155 &writer_)); 156 TF_ASSERT_OK(writer_->Flush()); 157 writer_->Unref(); 158 writer_ = nullptr; 159 EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Ids")); 160 EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users")); 161 EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments")); 162 EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs")); 163 EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Tags")); 164 EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Tensors")); 165 } 166 167 TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) { 168 SummaryMetadata metadata; 169 metadata.set_display_name("display_name"); 170 metadata.set_summary_description("description"); 171 metadata.mutable_plugin_data()->set_plugin_name("plugin_name"); 172 metadata.mutable_plugin_data()->set_content("plugin_data"); 173 SummaryMetadata metadata_nope; 174 metadata_nope.set_display_name("nope"); 175 metadata_nope.set_summary_description("nope"); 176 metadata_nope.mutable_plugin_data()->set_plugin_name("nope"); 177 metadata_nope.mutable_plugin_data()->set_content("nope"); 178 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_, 179 &writer_)); 180 env_.AdvanceByMillis(23); 181 TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy", 182 metadata.SerializeAsString())); 183 env_.AdvanceByMillis(23); 184 TF_ASSERT_OK(writer_->WriteTensor(2, MakeScalarInt64(314LL), "taggy", 185 metadata_nope.SerializeAsString())); 186 TF_ASSERT_OK(writer_->Flush()); 187 188 ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Users")); 189 ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Experiments")); 190 ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs")); 191 ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); 192 ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); 193 194 int64 user_id = QueryInt("SELECT user_id FROM Users"); 195 int64 experiment_id = QueryInt("SELECT experiment_id FROM Experiments"); 196 int64 run_id = QueryInt("SELECT run_id FROM Runs"); 197 int64 tag_id = QueryInt("SELECT tag_id FROM Tags"); 198 EXPECT_LT(0LL, user_id); 199 EXPECT_LT(0LL, experiment_id); 200 EXPECT_LT(0LL, run_id); 201 EXPECT_LT(0LL, tag_id); 202 203 EXPECT_EQ("jart", QueryString("SELECT user_name FROM Users")); 204 EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Users")); 205 206 EXPECT_EQ(user_id, QueryInt("SELECT user_id FROM Experiments")); 207 EXPECT_EQ("mad-science", 208 QueryString("SELECT experiment_name FROM Experiments")); 209 EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Experiments")); 210 211 EXPECT_EQ(experiment_id, QueryInt("SELECT experiment_id FROM Runs")); 212 EXPECT_EQ("train", QueryString("SELECT run_name FROM Runs")); 213 EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Runs")); 214 215 EXPECT_EQ(run_id, QueryInt("SELECT run_id FROM Tags")); 216 EXPECT_EQ("taggy", QueryString("SELECT tag_name FROM Tags")); 217 EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Tags")); 218 219 EXPECT_EQ("display_name", QueryString("SELECT display_name FROM Tags")); 220 EXPECT_EQ("plugin_name", QueryString("SELECT plugin_name FROM Tags")); 221 EXPECT_EQ("plugin_data", QueryString("SELECT plugin_data FROM Tags")); 222 EXPECT_EQ("description", QueryString("SELECT description FROM Descriptions")); 223 224 EXPECT_EQ(tag_id, QueryInt("SELECT series FROM Tensors WHERE step = 1")); 225 EXPECT_EQ(0.023, 226 QueryDouble("SELECT computed_time FROM Tensors WHERE step = 1")); 227 228 EXPECT_EQ(tag_id, QueryInt("SELECT series FROM Tensors WHERE step = 2")); 229 EXPECT_EQ(0.046, 230 QueryDouble("SELECT computed_time FROM Tensors WHERE step = 2")); 231 } 232 233 TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) { 234 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_)); 235 TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy", "")); 236 TF_ASSERT_OK(writer_->Flush()); 237 ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users")); 238 ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments")); 239 ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs")); 240 ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); 241 ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); 242 } 243 244 TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { 245 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_)); 246 std::unique_ptr<Event> e{new Event}; 247 e->set_step(7); 248 e->set_wall_time(123.456); 249 Summary::Value* s = e->mutable_summary()->add_value(); 250 s->set_tag(""); 251 s->set_simple_value(3.14f); 252 s = e->mutable_summary()->add_value(); 253 s->set_tag(""); 254 s->set_simple_value(1.61f); 255 TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); 256 TF_ASSERT_OK(writer_->Flush()); 257 ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags")); 258 ASSERT_EQ(2000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); 259 int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = ''"); 260 int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = ''"); 261 EXPECT_GT(tag1_id, 0LL); 262 EXPECT_GT(tag2_id, 0LL); 263 EXPECT_EQ(123.456, QueryDouble(strings::StrCat( 264 "SELECT computed_time FROM Tensors WHERE series = ", 265 tag1_id, " AND step = 7"))); 266 EXPECT_EQ(123.456, QueryDouble(strings::StrCat( 267 "SELECT computed_time FROM Tensors WHERE series = ", 268 tag2_id, " AND step = 7"))); 269 } 270 271 TEST_F(SummaryDbWriterTest, WriteGraph) { 272 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "R", "", &env_, &writer_)); 273 env_.AdvanceByMillis(23); 274 GraphDef graph; 275 graph.mutable_library()->add_gradient()->set_function_name("funk"); 276 NodeDef* node = graph.add_node(); 277 node->set_name("x"); 278 node->set_op("Placeholder"); 279 node = graph.add_node(); 280 node->set_name("y"); 281 node->set_op("Placeholder"); 282 node = graph.add_node(); 283 node->set_name("z"); 284 node->set_op("Love"); 285 node = graph.add_node(); 286 node->set_name("+"); 287 node->set_op("Add"); 288 node->add_input("x"); 289 node->add_input("y"); 290 node->add_input("^z"); 291 node->set_device("tpu/lol"); 292 std::unique_ptr<Event> e{new Event}; 293 graph.SerializeToString(e->mutable_graph_def()); 294 TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); 295 TF_ASSERT_OK(writer_->Flush()); 296 ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs")); 297 ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Graphs")); 298 ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Nodes")); 299 ASSERT_EQ(3LL, QueryInt("SELECT COUNT(*) FROM NodeInputs")); 300 301 ASSERT_EQ(QueryInt("SELECT run_id FROM Runs"), 302 QueryInt("SELECT run_id FROM Graphs")); 303 304 int64 graph_id = QueryInt("SELECT graph_id FROM Graphs"); 305 EXPECT_GT(graph_id, 0LL); 306 EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Graphs")); 307 308 GraphDef graph2; 309 graph2.ParseFromString(QueryString("SELECT graph_def FROM Graphs")); 310 EXPECT_EQ(0, graph2.node_size()); 311 EXPECT_EQ("funk", graph2.library().gradient(0).function_name()); 312 313 EXPECT_EQ("x", QueryString("SELECT node_name FROM Nodes WHERE node_id = 0")); 314 EXPECT_EQ("y", QueryString("SELECT node_name FROM Nodes WHERE node_id = 1")); 315 EXPECT_EQ("z", QueryString("SELECT node_name FROM Nodes WHERE node_id = 2")); 316 EXPECT_EQ("+", QueryString("SELECT node_name FROM Nodes WHERE node_id = 3")); 317 318 EXPECT_EQ("Placeholder", 319 QueryString("SELECT op FROM Nodes WHERE node_id = 0")); 320 EXPECT_EQ("Placeholder", 321 QueryString("SELECT op FROM Nodes WHERE node_id = 1")); 322 EXPECT_EQ("Love", QueryString("SELECT op FROM Nodes WHERE node_id = 2")); 323 EXPECT_EQ("Add", QueryString("SELECT op FROM Nodes WHERE node_id = 3")); 324 325 EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 0")); 326 EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 1")); 327 EXPECT_EQ("", QueryString("SELECT device FROM Nodes WHERE node_id = 2")); 328 EXPECT_EQ("tpu/lol", 329 QueryString("SELECT device FROM Nodes WHERE node_id = 3")); 330 331 EXPECT_EQ(graph_id, 332 QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 0")); 333 EXPECT_EQ(graph_id, 334 QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 1")); 335 EXPECT_EQ(graph_id, 336 QueryInt("SELECT graph_id FROM NodeInputs WHERE idx = 2")); 337 338 EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 0")); 339 EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 1")); 340 EXPECT_EQ(3LL, QueryInt("SELECT node_id FROM NodeInputs WHERE idx = 2")); 341 342 EXPECT_EQ(0LL, 343 QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 0")); 344 EXPECT_EQ(1LL, 345 QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 1")); 346 EXPECT_EQ(2LL, 347 QueryInt("SELECT input_node_id FROM NodeInputs WHERE idx = 2")); 348 349 EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 0")); 350 EXPECT_EQ(0LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 1")); 351 EXPECT_EQ(1LL, QueryInt("SELECT is_control FROM NodeInputs WHERE idx = 2")); 352 } 353 354 TEST_F(SummaryDbWriterTest, UsesIdsTable) { 355 SummaryMetadata metadata; 356 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_, 357 &writer_)); 358 env_.AdvanceByMillis(23); 359 TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy", 360 metadata.SerializeAsString())); 361 TF_ASSERT_OK(writer_->Flush()); 362 ASSERT_EQ(4LL, QueryInt("SELECT COUNT(*) FROM Ids")); 363 EXPECT_EQ(4LL, QueryInt(strings::StrCat( 364 "SELECT COUNT(*) FROM Ids WHERE id IN (", 365 QueryInt("SELECT user_id FROM Users"), ", ", 366 QueryInt("SELECT experiment_id FROM Experiments"), ", ", 367 QueryInt("SELECT run_id FROM Runs"), ", ", 368 QueryInt("SELECT tag_id FROM Tags"), ")"))); 369 } 370 371 TEST_F(SummaryDbWriterTest, SetsRunFinishedTime) { 372 SummaryMetadata metadata; 373 TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_, 374 &writer_)); 375 env_.AdvanceByMillis(23); 376 TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy", 377 metadata.SerializeAsString())); 378 TF_ASSERT_OK(writer_->Flush()); 379 ASSERT_EQ(0.023, QueryDouble("SELECT started_time FROM Runs")); 380 ASSERT_EQ(0.0, QueryDouble("SELECT finished_time FROM Runs")); 381 env_.AdvanceByMillis(23); 382 writer_->Unref(); 383 writer_ = nullptr; 384 ASSERT_EQ(0.023, QueryDouble("SELECT started_time FROM Runs")); 385 ASSERT_EQ(0.046, QueryDouble("SELECT finished_time FROM Runs")); 386 } 387 388 } // namespace 389 } // namespace tensorflow 390