1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Fake summary writer for unit tests.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 from tensorflow.core.framework import summary_pb2 21 from tensorflow.python.summary.writer import writer 22 from tensorflow.python.summary.writer import writer_cache 23 24 25 # TODO(ptucker): Replace with mock framework. 
class FakeSummaryWriter(object):
  """Fake summary writer.

  Records everything passed to the `add_*` methods in memory instead of
  writing it anywhere, so that unit tests can later verify what was added via
  `assert_summaries`. `install` / `uninstall` swap this class in and out for
  `writer.FileWriter` and `writer_cache.FileWriter`.
  """

  # Holds the original `writer.FileWriter` while the fake is installed, and
  # None otherwise.
  _replaced_summary_writer = None

  @classmethod
  def install(cls):
    """Replaces `FileWriter` in `writer` and `writer_cache` with the fake.

    Raises:
      ValueError: If the fake is already installed.
    """
    if cls._replaced_summary_writer:
      raise ValueError('FakeSummaryWriter already installed.')
    # Remember the real class so that `uninstall` can restore it.
    cls._replaced_summary_writer = writer.FileWriter
    writer.FileWriter = FakeSummaryWriter
    writer_cache.FileWriter = FakeSummaryWriter

  @classmethod
  def uninstall(cls):
    """Restores the `FileWriter` that `install` replaced.

    Raises:
      ValueError: If the fake is not currently installed.
    """
    if not cls._replaced_summary_writer:
      raise ValueError('FakeSummaryWriter not installed.')
    writer.FileWriter = cls._replaced_summary_writer
    writer_cache.FileWriter = cls._replaced_summary_writer
    cls._replaced_summary_writer = None

  def __init__(self, logdir, graph=None):
    """Creates a fake writer that only remembers its arguments.

    Args:
      logdir: Log directory; stored for later comparison, never accessed.
      graph: Optional graph; stored for later identity comparison.
    """
    self._logdir = logdir
    self._graph = graph
    self._summaries = {}  # Maps global step -> list of summaries added.
    self._added_graphs = []
    self._added_meta_graphs = []
    self._added_session_logs = []

  @property
  def summaries(self):
    """Dict mapping each global step to the list of summaries added for it."""
    return self._summaries

  def assert_summaries(self,
                       test_case,
                       expected_logdir=None,
                       expected_graph=None,
                       expected_summaries=None,
                       expected_added_graphs=None,
                       expected_added_meta_graphs=None,
                       expected_session_logs=None):
    """Assert expected items have been added to summary writer.

    Every `expected_*` argument left as `None` is skipped.

    Args:
      test_case: Object providing the `unittest.TestCase` assertion methods
        `assertEqual`, `assertIs` and `assertIn`.
      expected_logdir: Expected value of the `logdir` constructor argument.
      expected_graph: Object that must be identical (`is`) to the `graph`
        constructor argument.
      expected_summaries: Dict mapping a step to a dict of
        `{tag: simple_value}` expected for that step. Steps that were
        recorded but are absent from this dict are not checked.
      expected_added_graphs: Expected list of graphs passed to `add_graph`.
      expected_added_meta_graphs: Expected list of meta graphs passed to
        `add_meta_graph`.
      expected_session_logs: Expected list of session logs passed to
        `add_session_log`.
    """
    if expected_logdir is not None:
      test_case.assertEqual(expected_logdir, self._logdir)
    if expected_graph is not None:
      # assertIs reports both operands on failure, unlike assertTrue(a is b).
      test_case.assertIs(expected_graph, self._graph)
    expected_summaries = expected_summaries or {}
    for step in expected_summaries:
      test_case.assertIn(
          step,
          self._summaries,
          msg='Missing step %s from %s.' % (step, self._summaries.keys()))
      # Flatten all summaries of this step into {tag: simple_value}; a tag
      # that occurs more than once keeps the value added last.
      actual_simple_values = {}
      for step_summary in self._summaries[step]:
        for v in step_summary.value:
          # Ignore global_step/sec since it's written by Supervisor in a
          # separate thread, so it's non-deterministic how many get written.
          if 'global_step/sec' != v.tag:
            actual_simple_values[v.tag] = v.simple_value
      test_case.assertEqual(expected_summaries[step], actual_simple_values)
    if expected_added_graphs is not None:
      test_case.assertEqual(expected_added_graphs, self._added_graphs)
    if expected_added_meta_graphs is not None:
      test_case.assertEqual(expected_added_meta_graphs, self._added_meta_graphs)
    if expected_session_logs is not None:
      test_case.assertEqual(expected_session_logs, self._added_session_logs)

  def add_summary(self, summ, current_global_step):
    """Add summary.

    Args:
      summ: A summary object, or `bytes` holding a serialized
        `summary_pb2.Summary`, which is parsed before being stored.
      current_global_step: Step under which the summary is recorded.
    """
    if isinstance(summ, bytes):
      summary_proto = summary_pb2.Summary()
      summary_proto.ParseFromString(summ)
      summ = summary_proto
    self._summaries.setdefault(current_global_step, []).append(summ)

  # NOTE: Ignore global_step since its value is non-deterministic.
  def add_graph(self, graph, global_step=None, graph_def=None):
    """Add graph.

    Args:
      graph: Graph to record.
      global_step: Optional step; validated but not recorded.
      graph_def: Must be `None`.

    Raises:
      ValueError: If `global_step` is negative or `graph_def` is not `None`.
    """
    if (global_step is not None) and (global_step < 0):
      raise ValueError('Invalid global_step %s.' % global_step)
    if graph_def is not None:
      raise ValueError('Unexpected graph_def %s.' % graph_def)
    self._added_graphs.append(graph)

  def add_meta_graph(self, meta_graph_def, global_step=None):
    """Add metagraph.

    Args:
      meta_graph_def: Meta graph to record.
      global_step: Optional step; validated but not recorded.

    Raises:
      ValueError: If `global_step` is negative.
    """
    if (global_step is not None) and (global_step < 0):
      raise ValueError('Invalid global_step %s.' % global_step)
    self._added_meta_graphs.append(meta_graph_def)

  # NOTE: Ignore global_step since its value is non-deterministic.
  def add_session_log(self, session_log, global_step=None):
    """Add session log; `global_step` is accepted but ignored."""
    # pylint: disable=unused-argument
    self._added_session_logs.append(session_log)

  def flush(self):
    """No-op; nothing is buffered."""
    pass

  def reopen(self):
    """No-op; there is no underlying file."""
    pass

  def close(self):
    """No-op; there is no underlying file."""
    pass