Home | History | Annotate | Download | only in slim
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.contrib.slim.summaries."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import glob
     22 import os
     23 
     24 
     25 from tensorflow.contrib.slim.python.slim import summaries
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.platform import gfile
     29 from tensorflow.python.platform import test
     30 from tensorflow.python.summary import summary
     31 from tensorflow.python.summary import summary_iterator
     32 
     33 
     34 class SummariesTest(test.TestCase):
     35 
     36   def safe_create(self, output_dir):
     37     if gfile.Exists(output_dir):
     38       gfile.DeleteRecursively(output_dir)
     39     gfile.MakeDirs(output_dir)
     40 
     41   def assert_scalar_summary(self, output_dir, names_to_values):
     42     """Asserts that the given output directory contains written summaries.
     43 
     44     Args:
     45       output_dir: The output directory in which to look for even tfiles.
     46       names_to_values: A dictionary of summary names to values.
     47     """
     48     # The events file may have additional entries, e.g. the event version
     49     # stamp, so have to parse things a bit.
     50     output_filepath = glob.glob(os.path.join(output_dir, '*'))
     51     self.assertEqual(len(output_filepath), 1)
     52 
     53     events = summary_iterator.summary_iterator(output_filepath[0])
     54     summaries_list = [e.summary for e in events if e.summary.value]
     55     values = []
     56     for item in summaries_list:
     57       for value in item.value:
     58         values.append(value)
     59     saved_results = {v.tag: v.simple_value for v in values}
     60     for name in names_to_values:
     61       self.assertAlmostEqual(names_to_values[name], saved_results[name])
     62 
     63   def testScalarSummaryIsPartOfCollectionWithNoPrint(self):
     64     tensor = array_ops.ones([]) * 3
     65     name = 'my_score'
     66     prefix = 'eval'
     67     op = summaries.add_scalar_summary(tensor, name, prefix, print_summary=False)
     68     self.assertTrue(op in ops.get_collection(ops.GraphKeys.SUMMARIES))
     69 
     70   def testScalarSummaryIsPartOfCollectionWithPrint(self):
     71     tensor = array_ops.ones([]) * 3
     72     name = 'my_score'
     73     prefix = 'eval'
     74     op = summaries.add_scalar_summary(tensor, name, prefix, print_summary=True)
     75     self.assertTrue(op in ops.get_collection(ops.GraphKeys.SUMMARIES))
     76 
     77   def verify_scalar_summary_is_written(self, print_summary):
     78     value = 3
     79     tensor = array_ops.ones([]) * value
     80     name = 'my_score'
     81     prefix = 'eval'
     82     summaries.add_scalar_summary(tensor, name, prefix, print_summary)
     83 
     84     output_dir = os.path.join(self.get_temp_dir(),
     85                               'scalar_summary_no_print_test')
     86     self.safe_create(output_dir)
     87 
     88     summary_op = summary.merge_all()
     89 
     90     summary_writer = summary.FileWriter(output_dir)
     91     with self.test_session() as sess:
     92       new_summary = sess.run(summary_op)
     93       summary_writer.add_summary(new_summary, 1)
     94       summary_writer.flush()
     95 
     96     self.assert_scalar_summary(output_dir, {
     97         '%s/%s' % (prefix, name): value
     98     })
     99 
    100   def testScalarSummaryIsWrittenWithNoPrint(self):
    101     self.verify_scalar_summary_is_written(print_summary=False)
    102 
    103   def testScalarSummaryIsWrittenWithPrint(self):
    104     self.verify_scalar_summary_is_written(print_summary=True)
    105 
    106 
# Run the tests through the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  test.main()
    109