Home | History | Annotate | Download | only in platform
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Imports unittest as a replacement for testing.pybase.googletest."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import atexit
     22 import itertools
     23 import os
     24 import sys
     25 import tempfile
     26 
     27 # go/tf-wildcard-import
     28 # pylint: disable=wildcard-import
     29 from unittest import *
     30 # pylint: enable=wildcard-import
     31 
     32 from tensorflow.python.framework import errors
     33 from tensorflow.python.lib.io import file_io
     34 from tensorflow.python.platform import app
     35 from tensorflow.python.platform import benchmark
     36 from tensorflow.python.platform import tf_logging as logging
     37 from tensorflow.python.util import tf_decorator
     38 from tensorflow.python.util import tf_inspect
     39 
     40 
# Re-export benchmark.TensorFlowBenchmark under the shorter name `Benchmark`.
Benchmark = benchmark.TensorFlowBenchmark  # pylint: disable=invalid-name

# Keep a handle on the `main` brought in by the wildcard unittest import: the
# name `main` is redefined further down in this module.
unittest_main = main

# We keep a global variable in this module to make sure we create the temporary
# directory only once per test binary invocation.
_googletest_temp_dir = ''
     48 
     49 
     50 # pylint: disable=invalid-name
     51 # pylint: disable=undefined-variable
     52 def g_main(argv):
     53   """Delegate to unittest.main after redefining testLoader."""
     54   if 'TEST_SHARD_STATUS_FILE' in os.environ:
     55     try:
     56       f = None
     57       try:
     58         f = open(os.environ['TEST_SHARD_STATUS_FILE'], 'w')
     59         f.write('')
     60       except IOError:
     61         sys.stderr.write('Error opening TEST_SHARD_STATUS_FILE (%s). Exiting.'
     62                          % os.environ['TEST_SHARD_STATUS_FILE'])
     63         sys.exit(1)
     64     finally:
     65       if f is not None: f.close()
     66 
     67   if ('TEST_TOTAL_SHARDS' not in os.environ or
     68       'TEST_SHARD_INDEX' not in os.environ):
     69     return unittest_main(argv=argv)
     70 
     71   total_shards = int(os.environ['TEST_TOTAL_SHARDS'])
     72   shard_index = int(os.environ['TEST_SHARD_INDEX'])
     73   base_loader = TestLoader()
     74 
     75   delegate_get_names = base_loader.getTestCaseNames
     76   bucket_iterator = itertools.cycle(range(total_shards))
     77 
     78   def getShardedTestCaseNames(testCaseClass):
     79     filtered_names = []
     80     for testcase in sorted(delegate_get_names(testCaseClass)):
     81       bucket = next(bucket_iterator)
     82       if bucket == shard_index:
     83         filtered_names.append(testcase)
     84     return filtered_names
     85 
     86   # Override getTestCaseNames
     87   base_loader.getTestCaseNames = getShardedTestCaseNames
     88 
     89   unittest_main(argv=argv, testLoader=base_loader)
     90 
     91 
     92 # Redefine main to allow running benchmarks
     93 def main(argv=None):  # pylint: disable=function-redefined
     94   def main_wrapper():
     95     args = argv
     96     if args is None:
     97       args = sys.argv
     98     return app.run(main=g_main, argv=args)
     99   benchmark.benchmarks_main(true_main=main_wrapper)
    100 
    101 
    102 def GetTempDir():
    103   """Return a temporary directory for tests to use."""
    104   global _googletest_temp_dir
    105   if not _googletest_temp_dir:
    106     first_frame = tf_inspect.stack()[-1][0]
    107     temp_dir = os.path.join(tempfile.gettempdir(),
    108                             os.path.basename(tf_inspect.getfile(first_frame)))
    109     temp_dir = tempfile.mkdtemp(prefix=temp_dir.rstrip('.py'))
    110 
    111     def delete_temp_dir(dirname=temp_dir):
    112       try:
    113         file_io.delete_recursively(dirname)
    114       except errors.OpError as e:
    115         logging.error('Error removing %s: %s', dirname, e)
    116 
    117     atexit.register(delete_temp_dir)
    118     _googletest_temp_dir = temp_dir
    119 
    120   return _googletest_temp_dir
    121 
    122 
    123 def test_src_dir_path(relative_path):
    124   """Creates an absolute test srcdir path given a relative path.
    125 
    126   Args:
    127     relative_path: a path relative to tensorflow root.
    128       e.g. "contrib/session_bundle/example".
    129 
    130   Returns:
    131     An absolute path to the linked in runfiles.
    132   """
    133   return os.path.join(os.environ['TEST_SRCDIR'],
    134                       'org_tensorflow/tensorflow', relative_path)
    135 
    136 
    137 def StatefulSessionAvailable():
    138   return False
    139 
    140 
    141 class StubOutForTesting(object):
    142   """Support class for stubbing methods out for unit testing.
    143 
    144   Sample Usage:
    145 
    146   You want os.path.exists() to always return true during testing.
    147 
    148      stubs = StubOutForTesting()
    149      stubs.Set(os.path, 'exists', lambda x: 1)
    150        ...
    151      stubs.CleanUp()
    152 
    153   The above changes os.path.exists into a lambda that returns 1.  Once
    154   the ... part of the code finishes, the CleanUp() looks up the old
    155   value of os.path.exists and restores it.
    156   """
    157 
    158   def __init__(self):
    159     self.cache = []
    160     self.stubs = []
    161 
    162   def __del__(self):
    163     """Do not rely on the destructor to undo your stubs.
    164 
    165     You cannot guarantee exactly when the destructor will get called without
    166     relying on implementation details of a Python VM that may change.
    167     """
    168     self.CleanUp()
    169 
    170   # __enter__ and __exit__ allow use as a context manager.
    171   def __enter__(self):
    172     return self
    173 
    174   def __exit__(self, unused_exc_type, unused_exc_value, unused_tb):
    175     self.CleanUp()
    176 
    177   def CleanUp(self):
    178     """Undoes all SmartSet() & Set() calls, restoring original definitions."""
    179     self.SmartUnsetAll()
    180     self.UnsetAll()
    181 
    182   def SmartSet(self, obj, attr_name, new_attr):
    183     """Replace obj.attr_name with new_attr.
    184 
    185     This method is smart and works at the module, class, and instance level
    186     while preserving proper inheritance. It will not stub out C types however
    187     unless that has been explicitly allowed by the type.
    188 
    189     This method supports the case where attr_name is a staticmethod or a
    190     classmethod of obj.
    191 
    192     Notes:
    193       - If obj is an instance, then it is its class that will actually be
    194         stubbed. Note that the method Set() does not do that: if obj is
    195         an instance, it (and not its class) will be stubbed.
    196       - The stubbing is using the builtin getattr and setattr. So, the __get__
    197         and __set__ will be called when stubbing (TODO: A better idea would
    198         probably be to manipulate obj.__dict__ instead of getattr() and
    199         setattr()).
    200 
    201     Args:
    202       obj: The object whose attributes we want to modify.
    203       attr_name: The name of the attribute to modify.
    204       new_attr: The new value for the attribute.
    205 
    206     Raises:
    207       AttributeError: If the attribute cannot be found.
    208     """
    209     _, obj = tf_decorator.unwrap(obj)
    210     if (tf_inspect.ismodule(obj) or
    211         (not tf_inspect.isclass(obj) and attr_name in obj.__dict__)):
    212       orig_obj = obj
    213       orig_attr = getattr(obj, attr_name)
    214     else:
    215       if not tf_inspect.isclass(obj):
    216         mro = list(tf_inspect.getmro(obj.__class__))
    217       else:
    218         mro = list(tf_inspect.getmro(obj))
    219 
    220       mro.reverse()
    221 
    222       orig_attr = None
    223       found_attr = False
    224 
    225       for cls in mro:
    226         try:
    227           orig_obj = cls
    228           orig_attr = getattr(obj, attr_name)
    229           found_attr = True
    230         except AttributeError:
    231           continue
    232 
    233       if not found_attr:
    234         raise AttributeError('Attribute not found.')
    235 
    236     # Calling getattr() on a staticmethod transforms it to a 'normal' function.
    237     # We need to ensure that we put it back as a staticmethod.
    238     old_attribute = obj.__dict__.get(attr_name)
    239     if old_attribute is not None and isinstance(old_attribute, staticmethod):
    240       orig_attr = staticmethod(orig_attr)
    241 
    242     self.stubs.append((orig_obj, attr_name, orig_attr))
    243     setattr(orig_obj, attr_name, new_attr)
    244 
    245   def SmartUnsetAll(self):
    246     """Reverses SmartSet() calls, restoring things to original definitions.
    247 
    248     This method is automatically called when the StubOutForTesting()
    249     object is deleted; there is no need to call it explicitly.
    250 
    251     It is okay to call SmartUnsetAll() repeatedly, as later calls have
    252     no effect if no SmartSet() calls have been made.
    253     """
    254     for args in reversed(self.stubs):
    255       setattr(*args)
    256 
    257     self.stubs = []
    258 
    259   def Set(self, parent, child_name, new_child):
    260     """In parent, replace child_name's old definition with new_child.
    261 
    262     The parent could be a module when the child is a function at
    263     module scope.  Or the parent could be a class when a class' method
    264     is being replaced.  The named child is set to new_child, while the
    265     prior definition is saved away for later, when UnsetAll() is
    266     called.
    267 
    268     This method supports the case where child_name is a staticmethod or a
    269     classmethod of parent.
    270 
    271     Args:
    272       parent: The context in which the attribute child_name is to be changed.
    273       child_name: The name of the attribute to change.
    274       new_child: The new value of the attribute.
    275     """
    276     old_child = getattr(parent, child_name)
    277 
    278     old_attribute = parent.__dict__.get(child_name)
    279     if old_attribute is not None and isinstance(old_attribute, staticmethod):
    280       old_child = staticmethod(old_child)
    281 
    282     self.cache.append((parent, old_child, child_name))
    283     setattr(parent, child_name, new_child)
    284 
    285   def UnsetAll(self):
    286     """Reverses Set() calls, restoring things to their original definitions.
    287 
    288     This method is automatically called when the StubOutForTesting()
    289     object is deleted; there is no need to call it explicitly.
    290 
    291     It is okay to call UnsetAll() repeatedly, as later calls have no
    292     effect if no Set() calls have been made.
    293     """
    294     # Undo calls to Set() in reverse order, in case Set() was called on the
    295     # same arguments repeatedly (want the original call to be last one undone)
    296     for (parent, old_child, child_name) in reversed(self.cache):
    297       setattr(parent, child_name, old_child)
    298     self.cache = []
    299