Home | History | Annotate | Download | only in framework
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Test utilities."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 import glob
     21 import os
     22 import numpy as np
     23 from tensorflow.core.framework import summary_pb2
     24 from tensorflow.python.training import summary_io
     25 
     26 
     27 def assert_summary(expected_tags, expected_simple_values, summary_proto):
     28   """Asserts summary contains the specified tags and values.
     29 
     30   Args:
     31     expected_tags: All tags in summary.
     32     expected_simple_values: Simply values for some tags.
     33     summary_proto: Summary to validate.
     34 
     35   Raises:
     36     ValueError: if expectations are not met.
     37   """
     38   actual_tags = set()
     39   for value in summary_proto.value:
     40     actual_tags.add(value.tag)
     41     if value.tag in expected_simple_values:
     42       expected = expected_simple_values[value.tag]
     43       actual = value.simple_value
     44       np.testing.assert_almost_equal(
     45           actual, expected, decimal=2, err_msg=value.tag)
     46   expected_tags = set(expected_tags)
     47   if expected_tags != actual_tags:
     48     raise ValueError('Expected tags %s, got %s.' % (expected_tags, actual_tags))
     49 
     50 
     51 def to_summary_proto(summary_str):
     52   """Create summary based on latest stats.
     53 
     54   Args:
     55     summary_str: Serialized summary.
     56   Returns:
     57     summary_pb2.Summary.
     58   Raises:
     59     ValueError: if tensor is not a valid summary tensor.
     60   """
     61   summary = summary_pb2.Summary()
     62   summary.ParseFromString(summary_str)
     63   return summary
     64 
     65 
     66 # TODO(ptucker): Move to a non-test package?
     67 def latest_event_file(base_dir):
     68   """Find latest event file in `base_dir`.
     69 
     70   Args:
     71     base_dir: Base directory in which TF event flies are stored.
     72   Returns:
     73     File path, or `None` if none exists.
     74   """
     75   file_paths = glob.glob(os.path.join(base_dir, 'events.*'))
     76   return sorted(file_paths)[-1] if file_paths else None
     77 
     78 
     79 def latest_events(base_dir):
     80   """Parse events from latest event file in base_dir.
     81 
     82   Args:
     83     base_dir: Base directory in which TF event flies are stored.
     84   Returns:
     85     Iterable of event protos.
     86   Raises:
     87     ValueError: if no event files exist under base_dir.
     88   """
     89   file_path = latest_event_file(base_dir)
     90   return summary_io.summary_iterator(file_path) if file_path else []
     91 
     92 
     93 def latest_summaries(base_dir):
     94   """Parse summary events from latest event file in base_dir.
     95 
     96   Args:
     97     base_dir: Base directory in which TF event flies are stored.
     98   Returns:
     99     List of event protos.
    100   Raises:
    101     ValueError: if no event files exist under base_dir.
    102   """
    103   return [e for e in latest_events(base_dir) if e.HasField('summary')]
    104 
    105 
    106 def simple_values_from_events(events, tags):
    107   """Parse summaries from events with simple_value.
    108 
    109   Args:
    110     events: List of tensorflow.Event protos.
    111     tags: List of string event tags corresponding to simple_value summaries.
    112   Returns:
    113     dict of tag:value.
    114   Raises:
    115    ValueError: if a summary with a specified tag does not contain simple_value.
    116   """
    117   step_by_tag = {}
    118   value_by_tag = {}
    119   for e in events:
    120     if e.HasField('summary'):
    121       for v in e.summary.value:
    122         tag = v.tag
    123         if tag in tags:
    124           if not v.HasField('simple_value'):
    125             raise ValueError('Summary for %s is not a simple_value.' % tag)
    126           # The events are mostly sorted in step order, but we explicitly check
    127           # just in case.
    128           if tag not in step_by_tag or e.step > step_by_tag[tag]:
    129             step_by_tag[tag] = e.step
    130             value_by_tag[tag] = v.simple_value
    131   return value_by_tag
    132