1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 from __future__ import absolute_import 16 from __future__ import division 17 from __future__ import print_function 18 19 import tempfile 20 21 import six 22 23 from tensorflow.contrib.summary import summary_ops 24 from tensorflow.contrib.summary import summary_test_util 25 from tensorflow.core.framework import graph_pb2 26 from tensorflow.core.framework import node_def_pb2 27 from tensorflow.python.framework import constant_op 28 from tensorflow.python.framework import dtypes 29 from tensorflow.python.framework import ops 30 from tensorflow.python.ops import array_ops 31 from tensorflow.python.ops import control_flow_ops 32 from tensorflow.python.ops import state_ops 33 from tensorflow.python.platform import test 34 from tensorflow.python.training import training_util 35 36 get_all = summary_test_util.get_all 37 38 39 class DbTest(summary_test_util.SummaryDbTest): 40 41 def testGraphPassedToGraph_isForbiddenForThineOwnSafety(self): 42 with self.assertRaises(TypeError): 43 summary_ops.graph(ops.Graph()) 44 with self.assertRaises(TypeError): 45 summary_ops.graph('') 46 47 def testGraphSummary(self): 48 training_util.get_or_create_global_step() 49 name = 'hi' 50 graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),)) 51 with self.test_session(): 52 with self.create_db_writer().as_default(): 53 summary_ops.initialize(graph=graph) 54 six.assertCountEqual(self, [name], 55 get_all(self.db, 'SELECT node_name FROM Nodes')) 56 57 def testScalarSummary(self): 58 """Test record_summaries_every_n_global_steps and all_summaries().""" 59 with ops.Graph().as_default(), self.test_session() as sess: 60 global_step = training_util.get_or_create_global_step() 61 global_step.initializer.run() 62 with ops.device('/cpu:0'): 63 step_increment = state_ops.assign_add(global_step, 1) 64 sess.run(step_increment) # Increment global step from 0 to 1 65 66 logdir = tempfile.mkdtemp() 67 with summary_ops.create_file_writer(logdir, max_queue=0, 68 name='t2').as_default(): 69 with summary_ops.record_summaries_every_n_global_steps(2): 70 summary_ops.initialize() 71 summary_op = summary_ops.scalar('my_scalar', 2.0) 72 73 # Neither of these should produce a summary because 74 # global_step is 1 and "1 % 2 != 0" 75 sess.run(summary_ops.all_summary_ops()) 76 sess.run(summary_op) 77 events = summary_test_util.events_from_logdir(logdir) 78 self.assertEqual(len(events), 1) 79 80 # Increment global step from 1 to 2 and check that the summary 81 # is now written 82 sess.run(step_increment) 83 sess.run(summary_ops.all_summary_ops()) 84 events = summary_test_util.events_from_logdir(logdir) 85 self.assertEqual(len(events), 2) 86 self.assertEqual(events[1].summary.value[0].tag, 'my_scalar') 87 88 def testSummaryGraphModeCond(self): 89 with ops.Graph().as_default(), self.test_session(): 90 training_util.get_or_create_global_step() 91 logdir = tempfile.mkdtemp() 92 with summary_ops.create_file_writer( 93 logdir, max_queue=0, 94 name='t2').as_default(), summary_ops.always_record_summaries(): 95 summary_ops.initialize() 96 training_util.get_or_create_global_step().initializer.run() 97 def f(): 98 summary_ops.scalar('scalar', 2.0) 99 return constant_op.constant(True) 100 pred = array_ops.placeholder(dtypes.bool) 101 x = control_flow_ops.cond(pred, f, 102 lambda: constant_op.constant(False)) 103 x.eval(feed_dict={pred: True}) 104 105 events = summary_test_util.events_from_logdir(logdir) 106 self.assertEqual(len(events), 2) 107 self.assertEqual(events[1].summary.value[0].tag, 'cond/scalar') 108 109 def testSummaryGraphModeWhile(self): 110 with ops.Graph().as_default(), self.test_session(): 111 training_util.get_or_create_global_step() 112 logdir = tempfile.mkdtemp() 113 with summary_ops.create_file_writer( 114 logdir, max_queue=0, 115 name='t2').as_default(), summary_ops.always_record_summaries(): 116 summary_ops.initialize() 117 training_util.get_or_create_global_step().initializer.run() 118 def body(unused_pred): 119 summary_ops.scalar('scalar', 2.0) 120 return constant_op.constant(False) 121 def cond(pred): 122 return pred 123 pred = array_ops.placeholder(dtypes.bool) 124 x = control_flow_ops.while_loop(cond, body, [pred]) 125 x.eval(feed_dict={pred: True}) 126 127 events = summary_test_util.events_from_logdir(logdir) 128 self.assertEqual(len(events), 2) 129 self.assertEqual(events[1].summary.value[0].tag, 'while/scalar') 130 131 132 if __name__ == '__main__': 133 test.main() 134