1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BAvSIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for summary ops.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 import six 23 24 from tensorflow.core.framework import summary_pb2 25 from tensorflow.python.framework import constant_op 26 from tensorflow.python.framework import ops 27 from tensorflow.python.framework import tensor_util 28 from tensorflow.python.ops import array_ops 29 from tensorflow.python.ops import summary_ops 30 from tensorflow.python.platform import test 31 32 33 class SummaryOpsTest(test.TestCase): 34 35 def _SummarySingleValue(self, s): 36 summ = summary_pb2.Summary() 37 summ.ParseFromString(s) 38 self.assertEqual(len(summ.value), 1) 39 return summ.value[0] 40 41 def _AssertNumpyEq(self, actual, expected): 42 self.assertTrue(np.array_equal(actual, expected)) 43 44 def testTags(self): 45 with self.test_session() as sess: 46 c = constant_op.constant(1) 47 s1 = summary_ops.tensor_summary("s1", c) 48 with ops.name_scope("foo"): 49 s2 = summary_ops.tensor_summary("s2", c) 50 with ops.name_scope("zod"): 51 s3 = summary_ops.tensor_summary("s3", c) 52 s4 = summary_ops.tensor_summary("TensorSummary", c) 53 summ1, summ2, summ3, summ4 = sess.run([s1, s2, s3, s4]) 54 55 v1 = self._SummarySingleValue(summ1) 56 self.assertEqual(v1.tag, "s1") 57 58 v2 = self._SummarySingleValue(summ2) 59 self.assertEqual(v2.tag, "foo/s2") 60 61 v3 = self._SummarySingleValue(summ3) 62 self.assertEqual(v3.tag, "foo/zod/s3") 63 64 v4 = self._SummarySingleValue(summ4) 65 self.assertEqual(v4.tag, "foo/zod/TensorSummary") 66 67 def testScalarSummary(self): 68 with self.test_session() as sess: 69 const = constant_op.constant(10.0) 70 summ = summary_ops.tensor_summary("foo", const) 71 result = sess.run(summ) 72 73 value = self._SummarySingleValue(result) 74 n = tensor_util.MakeNdarray(value.tensor) 75 self._AssertNumpyEq(n, 10) 76 77 def testStringSummary(self): 78 s = six.b("foobar") 79 with self.test_session() as sess: 80 const = constant_op.constant(s) 81 summ = summary_ops.tensor_summary("foo", const) 82 result = sess.run(summ) 83 84 value = self._SummarySingleValue(result) 85 n = tensor_util.MakeNdarray(value.tensor) 86 self._AssertNumpyEq(n, s) 87 88 def testManyScalarSummary(self): 89 with self.test_session() as sess: 90 const = array_ops.ones([5, 5, 5]) 91 summ = summary_ops.tensor_summary("foo", const) 92 result = sess.run(summ) 93 value = self._SummarySingleValue(result) 94 n = tensor_util.MakeNdarray(value.tensor) 95 self._AssertNumpyEq(n, np.ones([5, 5, 5])) 96 97 def testManyStringSummary(self): 98 strings = [[six.b("foo bar"), six.b("baz")], [six.b("zoink"), six.b("zod")]] 99 with self.test_session() as sess: 100 const = constant_op.constant(strings) 101 summ = summary_ops.tensor_summary("foo", const) 102 result = sess.run(summ) 103 value = self._SummarySingleValue(result) 104 n = tensor_util.MakeNdarray(value.tensor) 105 self._AssertNumpyEq(n, strings) 106 107 def testManyBools(self): 108 bools = [True, True, True, False, False, False] 109 with self.test_session() as sess: 110 const = constant_op.constant(bools) 111 summ = summary_ops.tensor_summary("foo", const) 112 result = sess.run(summ) 113 114 value = self._SummarySingleValue(result) 115 n = tensor_util.MakeNdarray(value.tensor) 116 self._AssertNumpyEq(n, bools) 117 118 def testSummaryDescriptionAndDisplayName(self): 119 with self.test_session() as sess: 120 121 def get_description(summary_op): 122 summ_str = sess.run(summary_op) 123 summ = summary_pb2.Summary() 124 summ.ParseFromString(summ_str) 125 return summ.value[0].metadata 126 127 const = constant_op.constant(1) 128 # Default case; no description or display name 129 simple_summary = summary_ops.tensor_summary("simple", const) 130 131 descr = get_description(simple_summary) 132 self.assertEqual(descr.display_name, "") 133 self.assertEqual(descr.summary_description, "") 134 135 # Values are provided via function args 136 with_values = summary_ops.tensor_summary( 137 "simple", 138 const, 139 display_name="my name", 140 summary_description="my description") 141 142 descr = get_description(with_values) 143 self.assertEqual(descr.display_name, "my name") 144 self.assertEqual(descr.summary_description, "my description") 145 146 # Values are provided via the SummaryMetadata arg 147 metadata = summary_pb2.SummaryMetadata() 148 metadata.display_name = "my name" 149 metadata.summary_description = "my description" 150 151 with_metadata = summary_ops.tensor_summary( 152 "simple", const, summary_metadata=metadata) 153 descr = get_description(with_metadata) 154 self.assertEqual(descr.display_name, "my name") 155 self.assertEqual(descr.summary_description, "my description") 156 157 # If both SummaryMetadata and explicit args are provided, the args win 158 overwrite = summary_ops.tensor_summary( 159 "simple", 160 const, 161 summary_metadata=metadata, 162 display_name="overwritten", 163 summary_description="overwritten") 164 descr = get_description(overwrite) 165 self.assertEqual(descr.display_name, "overwritten") 166 self.assertEqual(descr.summary_description, "overwritten") 167 168 169 if __name__ == "__main__": 170 test.main() 171