Home | History | Annotate | Download | only in framework
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Fake summary writer for unit tests."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.core.framework import summary_pb2
     21 from tensorflow.python.summary.writer import writer
     22 from tensorflow.python.summary.writer import writer_cache
     23 
     24 
     25 # TODO(ptucker): Replace with mock framework.
     26 class FakeSummaryWriter(object):
     27   """Fake summary writer."""
     28 
     29   _replaced_summary_writer = None
     30 
     31   @classmethod
     32   def install(cls):
     33     if cls._replaced_summary_writer:
     34       raise ValueError('FakeSummaryWriter already installed.')
     35     cls._replaced_summary_writer = writer.FileWriter
     36     writer.FileWriter = FakeSummaryWriter
     37     writer_cache.FileWriter = FakeSummaryWriter
     38 
     39   @classmethod
     40   def uninstall(cls):
     41     if not cls._replaced_summary_writer:
     42       raise ValueError('FakeSummaryWriter not installed.')
     43     writer.FileWriter = cls._replaced_summary_writer
     44     writer_cache.FileWriter = cls._replaced_summary_writer
     45     cls._replaced_summary_writer = None
     46 
     47   def __init__(self, logdir, graph=None):
     48     self._logdir = logdir
     49     self._graph = graph
     50     self._summaries = {}
     51     self._added_graphs = []
     52     self._added_meta_graphs = []
     53     self._added_session_logs = []
     54 
     55   @property
     56   def summaries(self):
     57     return self._summaries
     58 
     59   def assert_summaries(self,
     60                        test_case,
     61                        expected_logdir=None,
     62                        expected_graph=None,
     63                        expected_summaries=None,
     64                        expected_added_graphs=None,
     65                        expected_added_meta_graphs=None,
     66                        expected_session_logs=None):
     67     """Assert expected items have been added to summary writer."""
     68     if expected_logdir is not None:
     69       test_case.assertEqual(expected_logdir, self._logdir)
     70     if expected_graph is not None:
     71       test_case.assertTrue(expected_graph is self._graph)
     72     expected_summaries = expected_summaries or {}
     73     for step in expected_summaries:
     74       test_case.assertTrue(
     75           step in self._summaries,
     76           msg='Missing step %s from %s.' % (step, self._summaries.keys()))
     77       actual_simple_values = {}
     78       for step_summary in self._summaries[step]:
     79         for v in step_summary.value:
     80           # Ignore global_step/sec since it's written by Supervisor in a
     81           # separate thread, so it's non-deterministic how many get written.
     82           if 'global_step/sec' != v.tag:
     83             actual_simple_values[v.tag] = v.simple_value
     84       test_case.assertEqual(expected_summaries[step], actual_simple_values)
     85     if expected_added_graphs is not None:
     86       test_case.assertEqual(expected_added_graphs, self._added_graphs)
     87     if expected_added_meta_graphs is not None:
     88       test_case.assertEqual(expected_added_meta_graphs, self._added_meta_graphs)
     89     if expected_session_logs is not None:
     90       test_case.assertEqual(expected_session_logs, self._added_session_logs)
     91 
     92   def add_summary(self, summ, current_global_step):
     93     """Add summary."""
     94     if isinstance(summ, bytes):
     95       summary_proto = summary_pb2.Summary()
     96       summary_proto.ParseFromString(summ)
     97       summ = summary_proto
     98     if current_global_step in self._summaries:
     99       step_summaries = self._summaries[current_global_step]
    100     else:
    101       step_summaries = []
    102       self._summaries[current_global_step] = step_summaries
    103     step_summaries.append(summ)
    104 
    105   # NOTE: Ignore global_step since its value is non-deterministic.
    106   def add_graph(self, graph, global_step=None, graph_def=None):
    107     """Add graph."""
    108     if (global_step is not None) and (global_step < 0):
    109       raise ValueError('Invalid global_step %s.' % global_step)
    110     if graph_def is not None:
    111       raise ValueError('Unexpected graph_def %s.' % graph_def)
    112     self._added_graphs.append(graph)
    113 
    114   def add_meta_graph(self, meta_graph_def, global_step=None):
    115     """Add metagraph."""
    116     if (global_step is not None) and (global_step < 0):
    117       raise ValueError('Invalid global_step %s.' % global_step)
    118     self._added_meta_graphs.append(meta_graph_def)
    119 
    120   # NOTE: Ignore global_step since its value is non-deterministic.
    121   def add_session_log(self, session_log, global_step=None):
    122     # pylint: disable=unused-argument
    123     self._added_session_logs.append(session_log)
    124 
    125   def flush(self):
    126     pass
    127 
    128   def reopen(self):
    129     pass
    130 
    131   def close(self):
    132     pass
    133