# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tempfile

import numpy as np
import six

from tensorflow.contrib.summary import summary_ops
from tensorflow.contrib.summary import summary_test_util
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import training_util

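# Convenience aliases for the database query helpers used by the DbTest cases.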
get_all = summary_test_util.get_all
get_one = summary_test_util.get_one

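# Maps TensorFlow dtype enums to the NumPy dtypes used to decode raw tensor
# bytes read back from the summary database.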
_NUMPY_NUMERIC_TYPES = {
    types_pb2.DT_HALF: np.float16,
    types_pb2.DT_FLOAT: np.float32,
    types_pb2.DT_DOUBLE: np.float64,
    types_pb2.DT_INT8: np.int8,
    types_pb2.DT_INT16: np.int16,
    types_pb2.DT_INT32: np.int32,
    types_pb2.DT_INT64: np.int64,
    types_pb2.DT_UINT8: np.uint8,
    types_pb2.DT_UINT16: np.uint16,
    types_pb2.DT_UINT32: np.uint32,
    types_pb2.DT_UINT64: np.uint64,
    types_pb2.DT_COMPLEX64: np.complex64,
    types_pb2.DT_COMPLEX128: np.complex128,
    types_pb2.DT_BOOL: np.bool_,
}


class TargetTest(test_util.TensorFlowTestCase):
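  """Tests for summary recording state and the event file writer."""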

  def testShouldRecordSummary(self):
    self.assertFalse(summary_ops.should_record_summaries())
    with summary_ops.always_record_summaries():
      self.assertTrue(summary_ops.should_record_summaries())

  def testSummaryOps(self):
    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()
    with summary_ops.create_file_writer(
        logdir, max_queue=0,
        name='t0').as_default(), summary_ops.always_record_summaries():
      summary_ops.generic('tensor', 1, '')
      summary_ops.scalar('scalar', 2.0)
      summary_ops.histogram('histogram', [1.0])
      summary_ops.image('image', [[[[1.0]]]])
      summary_ops.audio('audio', [[1.0]], 1.0, 1)
      # The underlying summary ops are exercised in C++ tests; here we only
      # check that they can be invoked correctly and that an events file
      # appears in the log directory.
      self.assertTrue(gfile.Exists(logdir))

  def testDefunSummaries(self):
    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()
    with summary_ops.create_file_writer(
        logdir, max_queue=0,
        name='t1').as_default(), summary_ops.always_record_summaries():

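      # Summaries written from inside a defun-compiled function should still
      # reach the default writer.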
      @function.defun
      def write():
        summary_ops.scalar('scalar', 2.0)

      write()
      events = summary_test_util.events_from_logdir(logdir)
      self.assertEqual(len(events), 2)
      self.assertEqual(events[1].summary.value[0].simple_value, 2.0)

  def testSummaryName(self):
    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()
    with summary_ops.create_file_writer(
        logdir, max_queue=0,
        name='t2').as_default(), summary_ops.always_record_summaries():

      summary_ops.scalar('scalar', 2.0)

      events = summary_test_util.events_from_logdir(logdir)
      self.assertEqual(len(events), 2)
      self.assertEqual(events[1].summary.value[0].tag, 'scalar')

  def testSummaryGlobalStep(self):
    step = training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()
    with summary_ops.create_file_writer(
        logdir, max_queue=0,
        name='t3').as_default(), summary_ops.always_record_summaries():

      summary_ops.scalar('scalar', 2.0, step=step)

      events = summary_test_util.events_from_logdir(logdir)
      self.assertEqual(len(events), 2)
      self.assertEqual(events[1].summary.value[0].tag, 'scalar')

  def testMaxQueue(self):
    logs = tempfile.mkdtemp()
    with summary_ops.create_file_writer(
        logs, max_queue=2, flush_millis=999999,
        name='lol').as_default(), summary_ops.always_record_summaries():
      get_total = lambda: len(summary_test_util.events_from_logdir(logs))
      # Note: First tf.Event is always file_version.
      self.assertEqual(1, get_total())
      summary_ops.scalar('scalar', 2.0, step=1)
      self.assertEqual(1, get_total())
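      # Writing a second scalar fills the queue (max_queue=2) and flushes both
      # buffered events to disk.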
      summary_ops.scalar('scalar', 2.0, step=2)
      self.assertEqual(3, get_total())

  def testFlush(self):
    logs = tempfile.mkdtemp()
    with summary_ops.create_file_writer(
        logs, max_queue=999999, flush_millis=999999,
        name='lol').as_default(), summary_ops.always_record_summaries():
      get_total = lambda: len(summary_test_util.events_from_logdir(logs))
      # Note: First tf.Event is always file_version.
      self.assertEqual(1, get_total())
      summary_ops.scalar('scalar', 2.0, step=1)
      summary_ops.scalar('scalar', 2.0, step=2)
      self.assertEqual(1, get_total())
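      # With a huge max_queue and flush interval, nothing reaches disk until
      # flush() is called explicitly.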
      summary_ops.flush()
      self.assertEqual(3, get_total())


class DbTest(summary_test_util.SummaryDbTest):
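  """Tests for summary ops backed by the summary database writer."""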

  def testIntegerSummaries(self):
    step = training_util.create_global_step()
    writer = self.create_db_writer()

    def adder(x, y):
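      # Bump the global step, then record both operands and their sum as
      # generic tensor summaries.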
      state_ops.assign_add(step, 1)
      summary_ops.generic('x', x)
      summary_ops.generic('y', y)
      sum_ = x + y
      summary_ops.generic('sum', sum_)
      return sum_

    with summary_ops.always_record_summaries():
      with writer.as_default():
        self.assertEqual(5, adder(int64(2), int64(3)).numpy())

    six.assertCountEqual(
        self, [1, 1, 1],
        get_all(self.db, 'SELECT step FROM Tensors WHERE dtype IS NOT NULL'))
    six.assertCountEqual(self, ['x', 'y', 'sum'],
                         get_all(self.db, 'SELECT tag_name FROM Tags'))
    x_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "x"')
    y_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "y"')
    sum_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "sum"')

    with summary_ops.always_record_summaries():
      with writer.as_default():
        self.assertEqual(9, adder(int64(4), int64(5)).numpy())

    six.assertCountEqual(
        self, [1, 1, 1, 2, 2, 2],
        get_all(self.db, 'SELECT step FROM Tensors WHERE dtype IS NOT NULL'))
    six.assertCountEqual(self, [x_id, y_id, sum_id],
                         get_all(self.db, 'SELECT tag_id FROM Tags'))
    self.assertEqual(2, get_tensor(self.db, x_id, 1))
    self.assertEqual(3, get_tensor(self.db, y_id, 1))
    self.assertEqual(5, get_tensor(self.db, sum_id, 1))
    self.assertEqual(4, get_tensor(self.db, x_id, 2))
    self.assertEqual(5, get_tensor(self.db, y_id, 2))
    self.assertEqual(9, get_tensor(self.db, sum_id, 2))
    six.assertCountEqual(
        self, ['experiment'],
        get_all(self.db, 'SELECT experiment_name FROM Experiments'))
    six.assertCountEqual(self, ['run'],
                         get_all(self.db, 'SELECT run_name FROM Runs'))
    six.assertCountEqual(self, ['user'],
                         get_all(self.db, 'SELECT user_name FROM Users'))

  def testBadExperimentName(self):
    with self.assertRaises(ValueError):
      self.create_db_writer(experiment_name='\0')

  def testBadRunName(self):
    with self.assertRaises(ValueError):
      self.create_db_writer(run_name='\0')

  def testBadUserName(self):
    with self.assertRaises(ValueError):
      self.create_db_writer(user_name='-hi')
    with self.assertRaises(ValueError):
      self.create_db_writer(user_name='hi-')
    with self.assertRaises(ValueError):
      self.create_db_writer(user_name='@')

  def testGraphSummary(self):
    training_util.get_or_create_global_step()
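    # A one-node GraphDef is enough to verify that graph summaries land in the
    # Nodes table.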
    name = 'hi'
    graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),))
    with summary_ops.always_record_summaries():
      with self.create_db_writer().as_default():
        summary_ops.graph(graph)
    six.assertCountEqual(self, [name],
                         get_all(self.db, 'SELECT node_name FROM Nodes'))


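# Reads the tensor recorded for (tag_id, step) out of the summary database and
# decodes it into a NumPy value.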
def get_tensor(db, tag_id, step):
  cursor = db.execute(
      'SELECT dtype, shape, data FROM Tensors WHERE series = ? AND step = ?',
      (tag_id, step))
  dtype, shape, data = cursor.fetchone()
  assert dtype in _NUMPY_NUMERIC_TYPES
  buf = np.frombuffer(data, dtype=_NUMPY_NUMERIC_TYPES[dtype])
  if not shape:
    return buf[0]
  return buf.reshape([int(i) for i in shape.split(',')])


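# Helper that wraps a Python value in an int64 constant tensor.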
def int64(x):
  return array_ops.constant(x, dtypes.int64)


if __name__ == '__main__':
  test.main()