Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for summary ops."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 import six
     23 
     24 from tensorflow.core.framework import summary_pb2
     25 from tensorflow.python.framework import constant_op
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.framework import tensor_util
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops import summary_ops
     30 from tensorflow.python.platform import test
     31 
     32 
     33 class SummaryOpsTest(test.TestCase):
     34 
     35   def _SummarySingleValue(self, s):
     36     summ = summary_pb2.Summary()
     37     summ.ParseFromString(s)
     38     self.assertEqual(len(summ.value), 1)
     39     return summ.value[0]
     40 
     41   def _AssertNumpyEq(self, actual, expected):
     42     self.assertTrue(np.array_equal(actual, expected))
     43 
     44   def testTags(self):
     45     with self.test_session() as sess:
     46       c = constant_op.constant(1)
     47       s1 = summary_ops.tensor_summary("s1", c)
     48       with ops.name_scope("foo"):
     49         s2 = summary_ops.tensor_summary("s2", c)
     50         with ops.name_scope("zod"):
     51           s3 = summary_ops.tensor_summary("s3", c)
     52           s4 = summary_ops.tensor_summary("TensorSummary", c)
     53       summ1, summ2, summ3, summ4 = sess.run([s1, s2, s3, s4])
     54 
     55     v1 = self._SummarySingleValue(summ1)
     56     self.assertEqual(v1.tag, "s1")
     57 
     58     v2 = self._SummarySingleValue(summ2)
     59     self.assertEqual(v2.tag, "foo/s2")
     60 
     61     v3 = self._SummarySingleValue(summ3)
     62     self.assertEqual(v3.tag, "foo/zod/s3")
     63 
     64     v4 = self._SummarySingleValue(summ4)
     65     self.assertEqual(v4.tag, "foo/zod/TensorSummary")
     66 
     67   def testScalarSummary(self):
     68     with self.test_session() as sess:
     69       const = constant_op.constant(10.0)
     70       summ = summary_ops.tensor_summary("foo", const)
     71       result = sess.run(summ)
     72 
     73     value = self._SummarySingleValue(result)
     74     n = tensor_util.MakeNdarray(value.tensor)
     75     self._AssertNumpyEq(n, 10)
     76 
     77   def testStringSummary(self):
     78     s = six.b("foobar")
     79     with self.test_session() as sess:
     80       const = constant_op.constant(s)
     81       summ = summary_ops.tensor_summary("foo", const)
     82       result = sess.run(summ)
     83 
     84     value = self._SummarySingleValue(result)
     85     n = tensor_util.MakeNdarray(value.tensor)
     86     self._AssertNumpyEq(n, s)
     87 
     88   def testManyScalarSummary(self):
     89     with self.test_session() as sess:
     90       const = array_ops.ones([5, 5, 5])
     91       summ = summary_ops.tensor_summary("foo", const)
     92       result = sess.run(summ)
     93     value = self._SummarySingleValue(result)
     94     n = tensor_util.MakeNdarray(value.tensor)
     95     self._AssertNumpyEq(n, np.ones([5, 5, 5]))
     96 
     97   def testManyStringSummary(self):
     98     strings = [[six.b("foo bar"), six.b("baz")], [six.b("zoink"), six.b("zod")]]
     99     with self.test_session() as sess:
    100       const = constant_op.constant(strings)
    101       summ = summary_ops.tensor_summary("foo", const)
    102       result = sess.run(summ)
    103     value = self._SummarySingleValue(result)
    104     n = tensor_util.MakeNdarray(value.tensor)
    105     self._AssertNumpyEq(n, strings)
    106 
    107   def testManyBools(self):
    108     bools = [True, True, True, False, False, False]
    109     with self.test_session() as sess:
    110       const = constant_op.constant(bools)
    111       summ = summary_ops.tensor_summary("foo", const)
    112       result = sess.run(summ)
    113 
    114     value = self._SummarySingleValue(result)
    115     n = tensor_util.MakeNdarray(value.tensor)
    116     self._AssertNumpyEq(n, bools)
    117 
    118   def testSummaryDescriptionAndDisplayName(self):
    119     with self.test_session() as sess:
    120 
    121       def get_description(summary_op):
    122         summ_str = sess.run(summary_op)
    123         summ = summary_pb2.Summary()
    124         summ.ParseFromString(summ_str)
    125         return summ.value[0].metadata
    126 
    127       const = constant_op.constant(1)
    128       # Default case; no description or display name
    129       simple_summary = summary_ops.tensor_summary("simple", const)
    130 
    131       descr = get_description(simple_summary)
    132       self.assertEqual(descr.display_name, "")
    133       self.assertEqual(descr.summary_description, "")
    134 
    135       # Values are provided via function args
    136       with_values = summary_ops.tensor_summary(
    137           "simple",
    138           const,
    139           display_name="my name",
    140           summary_description="my description")
    141 
    142       descr = get_description(with_values)
    143       self.assertEqual(descr.display_name, "my name")
    144       self.assertEqual(descr.summary_description, "my description")
    145 
    146       # Values are provided via the SummaryMetadata arg
    147       metadata = summary_pb2.SummaryMetadata()
    148       metadata.display_name = "my name"
    149       metadata.summary_description = "my description"
    150 
    151       with_metadata = summary_ops.tensor_summary(
    152           "simple", const, summary_metadata=metadata)
    153       descr = get_description(with_metadata)
    154       self.assertEqual(descr.display_name, "my name")
    155       self.assertEqual(descr.summary_description, "my description")
    156 
    157       # If both SummaryMetadata and explicit args are provided, the args win
    158       overwrite = summary_ops.tensor_summary(
    159           "simple",
    160           const,
    161           summary_metadata=metadata,
    162           display_name="overwritten",
    163           summary_description="overwritten")
    164       descr = get_description(overwrite)
    165       self.assertEqual(descr.display_name, "overwritten")
    166       self.assertEqual(descr.summary_description, "overwritten")
    167 
    168 
    169 if __name__ == "__main__":
    170   test.main()
    171