Home | History | Annotate | Download | only in summary
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Utilities to test summaries."""
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import functools
     23 import os
     24 
     25 import sqlite3
     26 
     27 from tensorflow.contrib.summary import summary_ops
     28 from tensorflow.core.util import event_pb2
     29 from tensorflow.python.framework import test_util
     30 from tensorflow.python.lib.io import tf_record
     31 from tensorflow.python.platform import gfile
     32 
     33 
     34 class SummaryDbTest(test_util.TensorFlowTestCase):
     35   """Helper for summary database testing."""
     36 
     37   def setUp(self):
     38     super(SummaryDbTest, self).setUp()
     39     self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
     40     if os.path.exists(self.db_path):
     41       os.unlink(self.db_path)
     42     self.db = sqlite3.connect(self.db_path)
     43     self.create_db_writer = functools.partial(
     44         summary_ops.create_db_writer,
     45         db_uri=self.db_path,
     46         experiment_name='experiment',
     47         run_name='run',
     48         user_name='user')
     49 
     50   def tearDown(self):
     51     self.db.close()
     52     super(SummaryDbTest, self).tearDown()
     53 
     54 
     55 def events_from_file(filepath):
     56   """Returns all events in a single event file.
     57 
     58   Args:
     59     filepath: Path to the event file.
     60 
     61   Returns:
     62     A list of all tf.Event protos in the event file.
     63   """
     64   records = list(tf_record.tf_record_iterator(filepath))
     65   result = []
     66   for r in records:
     67     event = event_pb2.Event()
     68     event.ParseFromString(r)
     69     result.append(event)
     70   return result
     71 
     72 
     73 def events_from_logdir(logdir):
     74   """Returns all events in the single eventfile in logdir.
     75 
     76   Args:
     77     logdir: The directory in which the single event file is sought.
     78 
     79   Returns:
     80     A list of all tf.Event protos from the single event file.
     81 
     82   Raises:
     83     AssertionError: If logdir does not contain exactly one file.
     84   """
     85   assert gfile.Exists(logdir)
     86   files = gfile.ListDirectory(logdir)
     87   assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files
     88   return events_from_file(os.path.join(logdir, files[0]))
     89 
     90 
     91 def get_one(db, q, *p):
     92   return db.execute(q, p).fetchone()[0]
     93 
     94 
     95 def get_all(db, q, *p):
     96   return unroll(db.execute(q, p).fetchall())
     97 
     98 
     99 def unroll(list_of_tuples):
    100   return sum(list_of_tuples, ())
    101