Home | History | Annotate | Download | only in tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 #
     15 # ==============================================================================
     16 """TensorFlow API compatibility tests.
     17 
     18 This test ensures all changes to the public API of TensorFlow are intended.
     19 
     20 If this test fails, it means a change has been made to the public API. Backwards
     21 incompatible changes are not allowed. You can run the test with
     22 "--update_goldens" flag set to "True" to update goldens when making changes to
     23 the public TF python API.
     24 """
     25 
     26 from __future__ import absolute_import
     27 from __future__ import division
     28 from __future__ import print_function
     29 
     30 import argparse
     31 import os
     32 import re
     33 import sys
     34 import unittest
     35 
     36 import tensorflow as tf
     37 
     38 from google.protobuf import text_format
     39 
     40 from tensorflow.python.lib.io import file_io
     41 from tensorflow.python.platform import resource_loader
     42 from tensorflow.python.platform import test
     43 from tensorflow.python.platform import tf_logging as logging
     44 from tensorflow.tools.api.lib import api_objects_pb2
     45 from tensorflow.tools.api.lib import python_object_to_proto_visitor
     46 from tensorflow.tools.common import public_api
     47 from tensorflow.tools.common import traverse
     48 
# Parsed command-line flags (an argparse Namespace). Assigned only in the
# `if __name__ == '__main__':` block at the bottom of this file, so it stays
# None when the module is imported rather than run as a script.
FLAGS = None
# Help text for the --update_goldens command-line flag (boolean, default False).
_UPDATE_GOLDENS_HELP = """
     Update stored golden files if API is updated. WARNING: All API changes
     have to be authorized by TensorFlow leads.
"""

# Help text for the --verbose_diffs command-line flag (boolean, default False).
_VERBOSE_DIFFS_HELP = """
     If set to true, print line by line diffs on all libraries. If set to
     false, only print which libraries have differences.
"""

# Relative locations of the golden files and of the message files the test
# logs. Reads join these with resource_loader.get_root_dir_with_all_resources();
# note that goldens written in update mode use _KeyToFilePath() as-is, i.e.
# relative to the current working directory.
_API_GOLDEN_FOLDER = 'tensorflow/tools/api/golden'
_TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt'
_UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt'
     66 
     67 
     68 def _KeyToFilePath(key):
     69   """From a given key, construct a filepath."""
     70   def _ReplaceCapsWithDash(matchobj):
     71     match = matchobj.group(0)
     72     return '-%s' % (match.lower())
     73 
     74   case_insensitive_key = re.sub('([A-Z]{1})', _ReplaceCapsWithDash, key)
     75   return os.path.join(_API_GOLDEN_FOLDER, '%s.pbtxt' % case_insensitive_key)
     76 
     77 
     78 def _FileNameToKey(filename):
     79   """From a given filename, construct a key we use for api objects."""
     80   def _ReplaceDashWithCaps(matchobj):
     81     match = matchobj.group(0)
     82     return match[1].upper()
     83 
     84   base_filename = os.path.basename(filename)
     85   base_filename_without_ext = os.path.splitext(base_filename)[0]
     86   api_object_key = re.sub(
     87       '((-[a-z]){1})', _ReplaceDashWithCaps, base_filename_without_ext)
     88   return api_object_key
     89 
     90 
     91 class ApiCompatibilityTest(test.TestCase):
     92 
     93   def __init__(self, *args, **kwargs):
     94     super(ApiCompatibilityTest, self).__init__(*args, **kwargs)
     95 
     96     golden_update_warning_filename = os.path.join(
     97         resource_loader.get_root_dir_with_all_resources(),
     98         _UPDATE_WARNING_FILE)
     99     self._update_golden_warning = file_io.read_file_to_string(
    100         golden_update_warning_filename)
    101 
    102     test_readme_filename = os.path.join(
    103         resource_loader.get_root_dir_with_all_resources(),
    104         _TEST_README_FILE)
    105     self._test_readme_message = file_io.read_file_to_string(
    106         test_readme_filename)
    107 
    108   def _AssertProtoDictEquals(self,
    109                              expected_dict,
    110                              actual_dict,
    111                              verbose=False,
    112                              update_goldens=False):
    113     """Diff given dicts of protobufs and report differences a readable way.
    114 
    115     Args:
    116       expected_dict: a dict of TFAPIObject protos constructed from golden
    117           files.
    118       actual_dict: a ict of TFAPIObject protos constructed by reading from the
    119           TF package linked to the test.
    120       verbose: Whether to log the full diffs, or simply report which files were
    121           different.
    122       update_goldens: Whether to update goldens when there are diffs found.
    123     """
    124     diffs = []
    125     verbose_diffs = []
    126 
    127     expected_keys = set(expected_dict.keys())
    128     actual_keys = set(actual_dict.keys())
    129     only_in_expected = expected_keys - actual_keys
    130     only_in_actual = actual_keys - expected_keys
    131     all_keys = expected_keys | actual_keys
    132 
    133     # This will be populated below.
    134     updated_keys = []
    135 
    136     for key in all_keys:
    137       diff_message = ''
    138       verbose_diff_message = ''
    139       # First check if the key is not found in one or the other.
    140       if key in only_in_expected:
    141         diff_message = 'Object %s expected but not found (removed).' % key
    142         verbose_diff_message = diff_message
    143       elif key in only_in_actual:
    144         diff_message = 'New object %s found (added).' % key
    145         verbose_diff_message = diff_message
    146       else:
    147         # Now we can run an actual proto diff.
    148         try:
    149           self.assertProtoEquals(expected_dict[key], actual_dict[key])
    150         except AssertionError as e:
    151           updated_keys.append(key)
    152           diff_message = 'Change detected in python object: %s.' % key
    153           verbose_diff_message = str(e)
    154 
    155       # All difference cases covered above. If any difference found, add to the
    156       # list.
    157       if diff_message:
    158         diffs.append(diff_message)
    159         verbose_diffs.append(verbose_diff_message)
    160 
    161     # If diffs are found, handle them based on flags.
    162     if diffs:
    163       diff_count = len(diffs)
    164       logging.error(self._test_readme_message)
    165       logging.error('%d differences found between API and golden.', diff_count)
    166       messages = verbose_diffs if verbose else diffs
    167       for i in range(diff_count):
    168         logging.error('Issue %d\t: %s', i + 1, messages[i])
    169 
    170       if update_goldens:
    171         # Write files if requested.
    172         logging.warning(self._update_golden_warning)
    173 
    174         # If the keys are only in expected, some objects are deleted.
    175         # Remove files.
    176         for key in only_in_expected:
    177           filepath = _KeyToFilePath(key)
    178           file_io.delete_file(filepath)
    179 
    180         # If the files are only in actual (current library), these are new
    181         # modules. Write them to files. Also record all updates in files.
    182         for key in only_in_actual | set(updated_keys):
    183           filepath = _KeyToFilePath(key)
    184           file_io.write_string_to_file(
    185               filepath, text_format.MessageToString(actual_dict[key]))
    186       else:
    187         # Fail if we cannot fix the test by updating goldens.
    188         self.fail('%d differences found between API and golden.' % diff_count)
    189 
    190     else:
    191       logging.info('No differences found between API and golden.')
    192 
    193   @unittest.skipUnless(
    194       sys.version_info.major == 2,
    195       'API compabitility test goldens are generated using python2.')
    196   def testAPIBackwardsCompatibility(self):
    197     # Extract all API stuff.
    198     visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    199 
    200     public_api_visitor = public_api.PublicAPIVisitor(visitor)
    201     public_api_visitor.do_not_descend_map['tf'].append('contrib')
    202     public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    203     traverse.traverse(tf, public_api_visitor)
    204 
    205     proto_dict = visitor.GetProtos()
    206 
    207     # Read all golden files.
    208     expression = os.path.join(
    209         resource_loader.get_root_dir_with_all_resources(),
    210         _KeyToFilePath('*'))
    211     golden_file_list = file_io.get_matching_files(expression)
    212 
    213     def _ReadFileToProto(filename):
    214       """Read a filename, create a protobuf from its contents."""
    215       ret_val = api_objects_pb2.TFAPIObject()
    216       text_format.Merge(file_io.read_file_to_string(filename), ret_val)
    217       return ret_val
    218 
    219     golden_proto_dict = {
    220         _FileNameToKey(filename): _ReadFileToProto(filename)
    221         for filename in golden_file_list
    222     }
    223 
    224     # Diff them. Do not fail if called with update.
    225     # If the test is run to update goldens, only report diffs but do not fail.
    226     self._AssertProtoDictEquals(
    227         golden_proto_dict,
    228         proto_dict,
    229         verbose=FLAGS.verbose_diffs,
    230         update_goldens=FLAGS.update_goldens)
    231 
    232 
    233 if __name__ == '__main__':
    234   parser = argparse.ArgumentParser()
    235   parser.add_argument(
    236       '--update_goldens', type=bool, default=False, help=_UPDATE_GOLDENS_HELP)
    237   parser.add_argument(
    238       '--verbose_diffs', type=bool, default=False, help=_VERBOSE_DIFFS_HELP)
    239   FLAGS, unparsed = parser.parse_known_args()
    240 
    241   # Now update argv, so that unittest library does not get confused.
    242   sys.argv = [sys.argv[0]] + unparsed
    243   test.main()
    244