1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 """Imports unittest as a replacement for testing.pybase.googletest.""" 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import atexit 22 import itertools 23 import os 24 import sys 25 import tempfile 26 27 # go/tf-wildcard-import 28 # pylint: disable=wildcard-import 29 from unittest import * 30 # pylint: enable=wildcard-import 31 32 from tensorflow.python.framework import errors 33 from tensorflow.python.lib.io import file_io 34 from tensorflow.python.platform import app 35 from tensorflow.python.platform import benchmark 36 from tensorflow.python.platform import tf_logging as logging 37 from tensorflow.python.util import tf_decorator 38 from tensorflow.python.util import tf_inspect 39 40 41 Benchmark = benchmark.TensorFlowBenchmark # pylint: disable=invalid-name 42 43 unittest_main = main 44 45 # We keep a global variable in this module to make sure we create the temporary 46 # directory only once per test binary invocation. 47 _googletest_temp_dir = '' 48 49 50 # pylint: disable=invalid-name 51 # pylint: disable=undefined-variable 52 def g_main(argv): 53 """Delegate to unittest.main after redefining testLoader.""" 54 if 'TEST_SHARD_STATUS_FILE' in os.environ: 55 try: 56 f = None 57 try: 58 f = open(os.environ['TEST_SHARD_STATUS_FILE'], 'w') 59 f.write('') 60 except IOError: 61 sys.stderr.write('Error opening TEST_SHARD_STATUS_FILE (%s). Exiting.' 62 % os.environ['TEST_SHARD_STATUS_FILE']) 63 sys.exit(1) 64 finally: 65 if f is not None: f.close() 66 67 if ('TEST_TOTAL_SHARDS' not in os.environ or 68 'TEST_SHARD_INDEX' not in os.environ): 69 return unittest_main(argv=argv) 70 71 total_shards = int(os.environ['TEST_TOTAL_SHARDS']) 72 shard_index = int(os.environ['TEST_SHARD_INDEX']) 73 base_loader = TestLoader() 74 75 delegate_get_names = base_loader.getTestCaseNames 76 bucket_iterator = itertools.cycle(range(total_shards)) 77 78 def getShardedTestCaseNames(testCaseClass): 79 filtered_names = [] 80 for testcase in sorted(delegate_get_names(testCaseClass)): 81 bucket = next(bucket_iterator) 82 if bucket == shard_index: 83 filtered_names.append(testcase) 84 return filtered_names 85 86 # Override getTestCaseNames 87 base_loader.getTestCaseNames = getShardedTestCaseNames 88 89 unittest_main(argv=argv, testLoader=base_loader) 90 91 92 # Redefine main to allow running benchmarks 93 def main(argv=None): # pylint: disable=function-redefined 94 def main_wrapper(): 95 args = argv 96 if args is None: 97 args = sys.argv 98 return app.run(main=g_main, argv=args) 99 benchmark.benchmarks_main(true_main=main_wrapper) 100 101 102 def GetTempDir(): 103 """Return a temporary directory for tests to use.""" 104 global _googletest_temp_dir 105 if not _googletest_temp_dir: 106 first_frame = tf_inspect.stack()[-1][0] 107 temp_dir = os.path.join(tempfile.gettempdir(), 108 os.path.basename(tf_inspect.getfile(first_frame))) 109 temp_dir = tempfile.mkdtemp(prefix=temp_dir.rstrip('.py')) 110 111 def delete_temp_dir(dirname=temp_dir): 112 try: 113 file_io.delete_recursively(dirname) 114 except errors.OpError as e: 115 logging.error('Error removing %s: %s', dirname, e) 116 117 atexit.register(delete_temp_dir) 118 _googletest_temp_dir = temp_dir 119 120 return _googletest_temp_dir 121 122 123 def test_src_dir_path(relative_path): 124 """Creates an absolute test srcdir path given a relative path. 125 126 Args: 127 relative_path: a path relative to tensorflow root. 128 e.g. "contrib/session_bundle/example". 129 130 Returns: 131 An absolute path to the linked in runfiles. 132 """ 133 return os.path.join(os.environ['TEST_SRCDIR'], 134 'org_tensorflow/tensorflow', relative_path) 135 136 137 def StatefulSessionAvailable(): 138 return False 139 140 141 class StubOutForTesting(object): 142 """Support class for stubbing methods out for unit testing. 143 144 Sample Usage: 145 146 You want os.path.exists() to always return true during testing. 147 148 stubs = StubOutForTesting() 149 stubs.Set(os.path, 'exists', lambda x: 1) 150 ... 151 stubs.CleanUp() 152 153 The above changes os.path.exists into a lambda that returns 1. Once 154 the ... part of the code finishes, the CleanUp() looks up the old 155 value of os.path.exists and restores it. 156 """ 157 158 def __init__(self): 159 self.cache = [] 160 self.stubs = [] 161 162 def __del__(self): 163 """Do not rely on the destructor to undo your stubs. 164 165 You cannot guarantee exactly when the destructor will get called without 166 relying on implementation details of a Python VM that may change. 167 """ 168 self.CleanUp() 169 170 # __enter__ and __exit__ allow use as a context manager. 171 def __enter__(self): 172 return self 173 174 def __exit__(self, unused_exc_type, unused_exc_value, unused_tb): 175 self.CleanUp() 176 177 def CleanUp(self): 178 """Undoes all SmartSet() & Set() calls, restoring original definitions.""" 179 self.SmartUnsetAll() 180 self.UnsetAll() 181 182 def SmartSet(self, obj, attr_name, new_attr): 183 """Replace obj.attr_name with new_attr. 184 185 This method is smart and works at the module, class, and instance level 186 while preserving proper inheritance. It will not stub out C types however 187 unless that has been explicitly allowed by the type. 188 189 This method supports the case where attr_name is a staticmethod or a 190 classmethod of obj. 191 192 Notes: 193 - If obj is an instance, then it is its class that will actually be 194 stubbed. Note that the method Set() does not do that: if obj is 195 an instance, it (and not its class) will be stubbed. 196 - The stubbing is using the builtin getattr and setattr. So, the __get__ 197 and __set__ will be called when stubbing (TODO: A better idea would 198 probably be to manipulate obj.__dict__ instead of getattr() and 199 setattr()). 200 201 Args: 202 obj: The object whose attributes we want to modify. 203 attr_name: The name of the attribute to modify. 204 new_attr: The new value for the attribute. 205 206 Raises: 207 AttributeError: If the attribute cannot be found. 208 """ 209 _, obj = tf_decorator.unwrap(obj) 210 if (tf_inspect.ismodule(obj) or 211 (not tf_inspect.isclass(obj) and attr_name in obj.__dict__)): 212 orig_obj = obj 213 orig_attr = getattr(obj, attr_name) 214 else: 215 if not tf_inspect.isclass(obj): 216 mro = list(tf_inspect.getmro(obj.__class__)) 217 else: 218 mro = list(tf_inspect.getmro(obj)) 219 220 mro.reverse() 221 222 orig_attr = None 223 found_attr = False 224 225 for cls in mro: 226 try: 227 orig_obj = cls 228 orig_attr = getattr(obj, attr_name) 229 found_attr = True 230 except AttributeError: 231 continue 232 233 if not found_attr: 234 raise AttributeError('Attribute not found.') 235 236 # Calling getattr() on a staticmethod transforms it to a 'normal' function. 237 # We need to ensure that we put it back as a staticmethod. 238 old_attribute = obj.__dict__.get(attr_name) 239 if old_attribute is not None and isinstance(old_attribute, staticmethod): 240 orig_attr = staticmethod(orig_attr) 241 242 self.stubs.append((orig_obj, attr_name, orig_attr)) 243 setattr(orig_obj, attr_name, new_attr) 244 245 def SmartUnsetAll(self): 246 """Reverses SmartSet() calls, restoring things to original definitions. 247 248 This method is automatically called when the StubOutForTesting() 249 object is deleted; there is no need to call it explicitly. 250 251 It is okay to call SmartUnsetAll() repeatedly, as later calls have 252 no effect if no SmartSet() calls have been made. 253 """ 254 for args in reversed(self.stubs): 255 setattr(*args) 256 257 self.stubs = [] 258 259 def Set(self, parent, child_name, new_child): 260 """In parent, replace child_name's old definition with new_child. 261 262 The parent could be a module when the child is a function at 263 module scope. Or the parent could be a class when a class' method 264 is being replaced. The named child is set to new_child, while the 265 prior definition is saved away for later, when UnsetAll() is 266 called. 267 268 This method supports the case where child_name is a staticmethod or a 269 classmethod of parent. 270 271 Args: 272 parent: The context in which the attribute child_name is to be changed. 273 child_name: The name of the attribute to change. 274 new_child: The new value of the attribute. 275 """ 276 old_child = getattr(parent, child_name) 277 278 old_attribute = parent.__dict__.get(child_name) 279 if old_attribute is not None and isinstance(old_attribute, staticmethod): 280 old_child = staticmethod(old_child) 281 282 self.cache.append((parent, old_child, child_name)) 283 setattr(parent, child_name, new_child) 284 285 def UnsetAll(self): 286 """Reverses Set() calls, restoring things to their original definitions. 287 288 This method is automatically called when the StubOutForTesting() 289 object is deleted; there is no need to call it explicitly. 290 291 It is okay to call UnsetAll() repeatedly, as later calls have no 292 effect if no Set() calls have been made. 293 """ 294 # Undo calls to Set() in reverse order, in case Set() was called on the 295 # same arguments repeatedly (want the original call to be last one undone) 296 for (parent, old_child, child_name) in reversed(self.cache): 297 setattr(parent, child_name, old_child) 298 self.cache = [] 299