1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 from __future__ import absolute_import 16 from __future__ import division 17 from __future__ import print_function 18 19 import tempfile 20 21 import numpy as np 22 import six 23 24 from tensorflow.contrib.summary import summary_ops 25 from tensorflow.contrib.summary import summary_test_util 26 from tensorflow.core.framework import graph_pb2 27 from tensorflow.core.framework import node_def_pb2 28 from tensorflow.core.framework import types_pb2 29 from tensorflow.python.eager import function 30 from tensorflow.python.eager import test 31 from tensorflow.python.framework import dtypes 32 from tensorflow.python.framework import test_util 33 from tensorflow.python.ops import array_ops 34 from tensorflow.python.ops import state_ops 35 from tensorflow.python.platform import gfile 36 from tensorflow.python.training import training_util 37 38 get_all = summary_test_util.get_all 39 get_one = summary_test_util.get_one 40 41 _NUMPY_NUMERIC_TYPES = { 42 types_pb2.DT_HALF: np.float16, 43 types_pb2.DT_FLOAT: np.float32, 44 types_pb2.DT_DOUBLE: np.float64, 45 types_pb2.DT_INT8: np.int8, 46 types_pb2.DT_INT16: np.int16, 47 types_pb2.DT_INT32: np.int32, 48 types_pb2.DT_INT64: np.int64, 49 types_pb2.DT_UINT8: np.uint8, 50 types_pb2.DT_UINT16: np.uint16, 51 types_pb2.DT_UINT32: np.uint32, 52 types_pb2.DT_UINT64: np.uint64, 53 types_pb2.DT_COMPLEX64: np.complex64, 54 types_pb2.DT_COMPLEX128: np.complex128, 55 types_pb2.DT_BOOL: np.bool_, 56 } 57 58 59 class TargetTest(test_util.TensorFlowTestCase): 60 61 def testShouldRecordSummary(self): 62 self.assertFalse(summary_ops.should_record_summaries()) 63 with summary_ops.always_record_summaries(): 64 self.assertTrue(summary_ops.should_record_summaries()) 65 66 def testSummaryOps(self): 67 training_util.get_or_create_global_step() 68 logdir = tempfile.mkdtemp() 69 with summary_ops.create_file_writer( 70 logdir, max_queue=0, 71 name='t0').as_default(), summary_ops.always_record_summaries(): 72 summary_ops.generic('tensor', 1, '') 73 summary_ops.scalar('scalar', 2.0) 74 summary_ops.histogram('histogram', [1.0]) 75 summary_ops.image('image', [[[[1.0]]]]) 76 summary_ops.audio('audio', [[1.0]], 1.0, 1) 77 # The working condition of the ops is tested in the C++ test so we just 78 # test here that we're calling them correctly. 79 self.assertTrue(gfile.Exists(logdir)) 80 81 def testDefunSummarys(self): 82 training_util.get_or_create_global_step() 83 logdir = tempfile.mkdtemp() 84 with summary_ops.create_file_writer( 85 logdir, max_queue=0, 86 name='t1').as_default(), summary_ops.always_record_summaries(): 87 88 @function.defun 89 def write(): 90 summary_ops.scalar('scalar', 2.0) 91 92 write() 93 events = summary_test_util.events_from_logdir(logdir) 94 self.assertEqual(len(events), 2) 95 self.assertEqual(events[1].summary.value[0].simple_value, 2.0) 96 97 def testSummaryName(self): 98 training_util.get_or_create_global_step() 99 logdir = tempfile.mkdtemp() 100 with summary_ops.create_file_writer( 101 logdir, max_queue=0, 102 name='t2').as_default(), summary_ops.always_record_summaries(): 103 104 summary_ops.scalar('scalar', 2.0) 105 106 events = summary_test_util.events_from_logdir(logdir) 107 self.assertEqual(len(events), 2) 108 self.assertEqual(events[1].summary.value[0].tag, 'scalar') 109 110 def testSummaryGlobalStep(self): 111 step = training_util.get_or_create_global_step() 112 logdir = tempfile.mkdtemp() 113 with summary_ops.create_file_writer( 114 logdir, max_queue=0, 115 name='t2').as_default(), summary_ops.always_record_summaries(): 116 117 summary_ops.scalar('scalar', 2.0, step=step) 118 119 events = summary_test_util.events_from_logdir(logdir) 120 self.assertEqual(len(events), 2) 121 self.assertEqual(events[1].summary.value[0].tag, 'scalar') 122 123 def testMaxQueue(self): 124 logs = tempfile.mkdtemp() 125 with summary_ops.create_file_writer( 126 logs, max_queue=2, flush_millis=999999, 127 name='lol').as_default(), summary_ops.always_record_summaries(): 128 get_total = lambda: len(summary_test_util.events_from_logdir(logs)) 129 # Note: First tf.Event is always file_version. 130 self.assertEqual(1, get_total()) 131 summary_ops.scalar('scalar', 2.0, step=1) 132 self.assertEqual(1, get_total()) 133 summary_ops.scalar('scalar', 2.0, step=2) 134 self.assertEqual(3, get_total()) 135 136 def testFlush(self): 137 logs = tempfile.mkdtemp() 138 with summary_ops.create_file_writer( 139 logs, max_queue=999999, flush_millis=999999, 140 name='lol').as_default(), summary_ops.always_record_summaries(): 141 get_total = lambda: len(summary_test_util.events_from_logdir(logs)) 142 # Note: First tf.Event is always file_version. 143 self.assertEqual(1, get_total()) 144 summary_ops.scalar('scalar', 2.0, step=1) 145 summary_ops.scalar('scalar', 2.0, step=2) 146 self.assertEqual(1, get_total()) 147 summary_ops.flush() 148 self.assertEqual(3, get_total()) 149 150 151 class DbTest(summary_test_util.SummaryDbTest): 152 153 def testIntegerSummaries(self): 154 step = training_util.create_global_step() 155 writer = self.create_db_writer() 156 157 def adder(x, y): 158 state_ops.assign_add(step, 1) 159 summary_ops.generic('x', x) 160 summary_ops.generic('y', y) 161 sum_ = x + y 162 summary_ops.generic('sum', sum_) 163 return sum_ 164 165 with summary_ops.always_record_summaries(): 166 with writer.as_default(): 167 self.assertEqual(5, adder(int64(2), int64(3)).numpy()) 168 169 six.assertCountEqual( 170 self, [1, 1, 1], 171 get_all(self.db, 'SELECT step FROM Tensors WHERE dtype IS NOT NULL')) 172 six.assertCountEqual(self, ['x', 'y', 'sum'], 173 get_all(self.db, 'SELECT tag_name FROM Tags')) 174 x_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "x"') 175 y_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "y"') 176 sum_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "sum"') 177 178 with summary_ops.always_record_summaries(): 179 with writer.as_default(): 180 self.assertEqual(9, adder(int64(4), int64(5)).numpy()) 181 182 six.assertCountEqual( 183 self, [1, 1, 1, 2, 2, 2], 184 get_all(self.db, 'SELECT step FROM Tensors WHERE dtype IS NOT NULL')) 185 six.assertCountEqual(self, [x_id, y_id, sum_id], 186 get_all(self.db, 'SELECT tag_id FROM Tags')) 187 self.assertEqual(2, get_tensor(self.db, x_id, 1)) 188 self.assertEqual(3, get_tensor(self.db, y_id, 1)) 189 self.assertEqual(5, get_tensor(self.db, sum_id, 1)) 190 self.assertEqual(4, get_tensor(self.db, x_id, 2)) 191 self.assertEqual(5, get_tensor(self.db, y_id, 2)) 192 self.assertEqual(9, get_tensor(self.db, sum_id, 2)) 193 six.assertCountEqual( 194 self, ['experiment'], 195 get_all(self.db, 'SELECT experiment_name FROM Experiments')) 196 six.assertCountEqual(self, ['run'], 197 get_all(self.db, 'SELECT run_name FROM Runs')) 198 six.assertCountEqual(self, ['user'], 199 get_all(self.db, 'SELECT user_name FROM Users')) 200 201 def testBadExperimentName(self): 202 with self.assertRaises(ValueError): 203 self.create_db_writer(experiment_name='\0') 204 205 def testBadRunName(self): 206 with self.assertRaises(ValueError): 207 self.create_db_writer(run_name='\0') 208 209 def testBadUserName(self): 210 with self.assertRaises(ValueError): 211 self.create_db_writer(user_name='-hi') 212 with self.assertRaises(ValueError): 213 self.create_db_writer(user_name='hi-') 214 with self.assertRaises(ValueError): 215 self.create_db_writer(user_name='@') 216 217 def testGraphSummary(self): 218 training_util.get_or_create_global_step() 219 name = 'hi' 220 graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),)) 221 with summary_ops.always_record_summaries(): 222 with self.create_db_writer().as_default(): 223 summary_ops.graph(graph) 224 six.assertCountEqual(self, [name], 225 get_all(self.db, 'SELECT node_name FROM Nodes')) 226 227 228 def get_tensor(db, tag_id, step): 229 cursor = db.execute( 230 'SELECT dtype, shape, data FROM Tensors WHERE series = ? AND step = ?', 231 (tag_id, step)) 232 dtype, shape, data = cursor.fetchone() 233 assert dtype in _NUMPY_NUMERIC_TYPES 234 buf = np.frombuffer(data, dtype=_NUMPY_NUMERIC_TYPES[dtype]) 235 if not shape: 236 return buf[0] 237 return buf.reshape([int(i) for i in shape.split(',')]) 238 239 240 def int64(x): 241 return array_ops.constant(x, dtypes.int64) 242 243 244 if __name__ == '__main__': 245 test.main() 246