1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # 15 # ============================================================================== 16 """TensorFlow API compatibility tests. 17 18 This test ensures all changes to the public API of TensorFlow are intended. 19 20 If this test fails, it means a change has been made to the public API. Backwards 21 incompatible changes are not allowed. You can run the test with 22 "--update_goldens" flag set to "True" to update goldens when making changes to 23 the public TF python API. 24 """ 25 26 from __future__ import absolute_import 27 from __future__ import division 28 from __future__ import print_function 29 30 import argparse 31 import os 32 import re 33 import sys 34 import unittest 35 36 import tensorflow as tf 37 38 from google.protobuf import text_format 39 40 from tensorflow.python.lib.io import file_io 41 from tensorflow.python.platform import resource_loader 42 from tensorflow.python.platform import test 43 from tensorflow.python.platform import tf_logging as logging 44 from tensorflow.tools.api.lib import api_objects_pb2 45 from tensorflow.tools.api.lib import python_object_to_proto_visitor 46 from tensorflow.tools.common import public_api 47 from tensorflow.tools.common import traverse 48 49 # FLAGS defined at the bottom: 50 FLAGS = None 51 # DEFINE_boolean, update_goldens, default False: 52 _UPDATE_GOLDENS_HELP = """ 53 Update stored golden files if API is updated. WARNING: All API changes 54 have to be authorized by TensorFlow leads. 55 """ 56 57 # DEFINE_boolean, verbose_diffs, default False: 58 _VERBOSE_DIFFS_HELP = """ 59 If set to true, print line by line diffs on all libraries. If set to 60 false, only print which libraries have differences. 61 """ 62 63 _API_GOLDEN_FOLDER = 'tensorflow/tools/api/golden' 64 _TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt' 65 _UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt' 66 67 68 def _KeyToFilePath(key): 69 """From a given key, construct a filepath.""" 70 def _ReplaceCapsWithDash(matchobj): 71 match = matchobj.group(0) 72 return '-%s' % (match.lower()) 73 74 case_insensitive_key = re.sub('([A-Z]{1})', _ReplaceCapsWithDash, key) 75 return os.path.join(_API_GOLDEN_FOLDER, '%s.pbtxt' % case_insensitive_key) 76 77 78 def _FileNameToKey(filename): 79 """From a given filename, construct a key we use for api objects.""" 80 def _ReplaceDashWithCaps(matchobj): 81 match = matchobj.group(0) 82 return match[1].upper() 83 84 base_filename = os.path.basename(filename) 85 base_filename_without_ext = os.path.splitext(base_filename)[0] 86 api_object_key = re.sub( 87 '((-[a-z]){1})', _ReplaceDashWithCaps, base_filename_without_ext) 88 return api_object_key 89 90 91 class ApiCompatibilityTest(test.TestCase): 92 93 def __init__(self, *args, **kwargs): 94 super(ApiCompatibilityTest, self).__init__(*args, **kwargs) 95 96 golden_update_warning_filename = os.path.join( 97 resource_loader.get_root_dir_with_all_resources(), 98 _UPDATE_WARNING_FILE) 99 self._update_golden_warning = file_io.read_file_to_string( 100 golden_update_warning_filename) 101 102 test_readme_filename = os.path.join( 103 resource_loader.get_root_dir_with_all_resources(), 104 _TEST_README_FILE) 105 self._test_readme_message = file_io.read_file_to_string( 106 test_readme_filename) 107 108 def _AssertProtoDictEquals(self, 109 expected_dict, 110 actual_dict, 111 verbose=False, 112 update_goldens=False): 113 """Diff given dicts of protobufs and report differences a readable way. 114 115 Args: 116 expected_dict: a dict of TFAPIObject protos constructed from golden 117 files. 118 actual_dict: a ict of TFAPIObject protos constructed by reading from the 119 TF package linked to the test. 120 verbose: Whether to log the full diffs, or simply report which files were 121 different. 122 update_goldens: Whether to update goldens when there are diffs found. 123 """ 124 diffs = [] 125 verbose_diffs = [] 126 127 expected_keys = set(expected_dict.keys()) 128 actual_keys = set(actual_dict.keys()) 129 only_in_expected = expected_keys - actual_keys 130 only_in_actual = actual_keys - expected_keys 131 all_keys = expected_keys | actual_keys 132 133 # This will be populated below. 134 updated_keys = [] 135 136 for key in all_keys: 137 diff_message = '' 138 verbose_diff_message = '' 139 # First check if the key is not found in one or the other. 140 if key in only_in_expected: 141 diff_message = 'Object %s expected but not found (removed).' % key 142 verbose_diff_message = diff_message 143 elif key in only_in_actual: 144 diff_message = 'New object %s found (added).' % key 145 verbose_diff_message = diff_message 146 else: 147 # Now we can run an actual proto diff. 148 try: 149 self.assertProtoEquals(expected_dict[key], actual_dict[key]) 150 except AssertionError as e: 151 updated_keys.append(key) 152 diff_message = 'Change detected in python object: %s.' % key 153 verbose_diff_message = str(e) 154 155 # All difference cases covered above. If any difference found, add to the 156 # list. 157 if diff_message: 158 diffs.append(diff_message) 159 verbose_diffs.append(verbose_diff_message) 160 161 # If diffs are found, handle them based on flags. 162 if diffs: 163 diff_count = len(diffs) 164 logging.error(self._test_readme_message) 165 logging.error('%d differences found between API and golden.', diff_count) 166 messages = verbose_diffs if verbose else diffs 167 for i in range(diff_count): 168 logging.error('Issue %d\t: %s', i + 1, messages[i]) 169 170 if update_goldens: 171 # Write files if requested. 172 logging.warning(self._update_golden_warning) 173 174 # If the keys are only in expected, some objects are deleted. 175 # Remove files. 176 for key in only_in_expected: 177 filepath = _KeyToFilePath(key) 178 file_io.delete_file(filepath) 179 180 # If the files are only in actual (current library), these are new 181 # modules. Write them to files. Also record all updates in files. 182 for key in only_in_actual | set(updated_keys): 183 filepath = _KeyToFilePath(key) 184 file_io.write_string_to_file( 185 filepath, text_format.MessageToString(actual_dict[key])) 186 else: 187 # Fail if we cannot fix the test by updating goldens. 188 self.fail('%d differences found between API and golden.' % diff_count) 189 190 else: 191 logging.info('No differences found between API and golden.') 192 193 @unittest.skipUnless( 194 sys.version_info.major == 2, 195 'API compabitility test goldens are generated using python2.') 196 def testAPIBackwardsCompatibility(self): 197 # Extract all API stuff. 198 visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor() 199 200 public_api_visitor = public_api.PublicAPIVisitor(visitor) 201 public_api_visitor.do_not_descend_map['tf'].append('contrib') 202 public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental'] 203 traverse.traverse(tf, public_api_visitor) 204 205 proto_dict = visitor.GetProtos() 206 207 # Read all golden files. 208 expression = os.path.join( 209 resource_loader.get_root_dir_with_all_resources(), 210 _KeyToFilePath('*')) 211 golden_file_list = file_io.get_matching_files(expression) 212 213 def _ReadFileToProto(filename): 214 """Read a filename, create a protobuf from its contents.""" 215 ret_val = api_objects_pb2.TFAPIObject() 216 text_format.Merge(file_io.read_file_to_string(filename), ret_val) 217 return ret_val 218 219 golden_proto_dict = { 220 _FileNameToKey(filename): _ReadFileToProto(filename) 221 for filename in golden_file_list 222 } 223 224 # Diff them. Do not fail if called with update. 225 # If the test is run to update goldens, only report diffs but do not fail. 226 self._AssertProtoDictEquals( 227 golden_proto_dict, 228 proto_dict, 229 verbose=FLAGS.verbose_diffs, 230 update_goldens=FLAGS.update_goldens) 231 232 233 if __name__ == '__main__': 234 parser = argparse.ArgumentParser() 235 parser.add_argument( 236 '--update_goldens', type=bool, default=False, help=_UPDATE_GOLDENS_HELP) 237 parser.add_argument( 238 '--verbose_diffs', type=bool, default=False, help=_VERBOSE_DIFFS_HELP) 239 FLAGS, unparsed = parser.parse_known_args() 240 241 # Now update argv, so that unittest library does not get confused. 242 sys.argv = [sys.argv[0]] + unparsed 243 test.main() 244