Home | History | Annotate | Download | only in util
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Unit tests for tf_should_use."""
     16 
     17 # pylint: disable=unused-import
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import contextlib
     23 import gc
     24 import sys
     25 
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.platform import test
     28 from tensorflow.python.platform import tf_logging
     29 from tensorflow.python.util import tf_should_use
     30 
     31 
     32 @contextlib.contextmanager
     33 def reroute_error(captured):
     34   """Temporarily reroute errors written to tf_logging.error into `captured`."""
     35   del captured[:]
     36   true_logger = tf_logging.error
     37   def capture_errors(*args, **unused_kwargs):
     38     captured.extend(args)
     39   tf_logging.error = capture_errors
     40   try:
     41     yield
     42   finally:
     43     tf_logging.error = true_logger
     44 
     45 
     46 class TfShouldUseTest(test.TestCase):
     47 
     48   def testAddShouldUseWarningWhenNotUsed(self):
     49     self.skipTest('b/65412899')
     50     c = constant_op.constant(0, name='blah0')
     51     captured = []
     52     with reroute_error(captured):
     53       def in_this_function():
     54         h = tf_should_use._add_should_use_warning(c)
     55         del h
     56       in_this_function()
     57     self.assertIn('Object was never used', '\n'.join(captured))
     58     self.assertIn('blah0:0', '\n'.join(captured))
     59     self.assertIn('in_this_function', '\n'.join(captured))
     60     gc.collect()
     61     self.assertFalse(gc.garbage)
     62 
     63   def _testAddShouldUseWarningWhenUsed(self, fn, name):
     64     c = constant_op.constant(0, name=name)
     65     captured = []
     66     with reroute_error(captured):
     67       h = tf_should_use._add_should_use_warning(c)
     68       fn(h)
     69       del h
     70     self.assertNotIn('Object was never used', '\n'.join(captured))
     71     self.assertNotIn('%s:0' % name, '\n'.join(captured))
     72 
     73   def testAddShouldUseWarningWhenUsedWithAdd(self):
     74     self.skipTest('b/65412899')
     75     def add(h):
     76       _ = h + 1
     77     self._testAddShouldUseWarningWhenUsed(add, name='blah_add')
     78     gc.collect()
     79     self.assertFalse(gc.garbage)
     80 
     81   def testAddShouldUseWarningWhenUsedWithGetName(self):
     82     self.skipTest('b/65412899')
     83     def get_name(h):
     84       _ = h.name
     85     self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name')
     86     gc.collect()
     87     self.assertFalse(gc.garbage)
     88 
     89   def testShouldUseResult(self):
     90     self.skipTest('b/65412899')
     91     @tf_should_use.should_use_result
     92     def return_const(value):
     93       return constant_op.constant(value, name='blah2')
     94     captured = []
     95     with reroute_error(captured):
     96       return_const(0.0)
     97     self.assertIn('Object was never used', '\n'.join(captured))
     98     self.assertIn('blah2:0', '\n'.join(captured))
     99     self.assertIn('return_const', '\n'.join(captured))
    100     gc.collect()
    101     self.assertFalse(gc.garbage)
    102 
    103   def testShouldUseResultWhenNotReallyUsed(self):
    104     self.skipTest('b/65412899')
    105     @tf_should_use.should_use_result
    106     def return_const(value):
    107       return constant_op.constant(value, name='blah3')
    108     captured = []
    109     with reroute_error(captured):
    110       with self.test_session():
    111         return_const(0.0)
    112         # Creating another op and executing it does not mark the
    113         # unused op as being "used".
    114         v = constant_op.constant(1.0, name='meh')
    115         v.eval()
    116     self.assertIn('Object was never used', '\n'.join(captured))
    117     self.assertIn('blah3:0', '\n'.join(captured))
    118     self.assertIn('return_const', '\n'.join(captured))
    119     gc.collect()
    120     self.assertFalse(gc.garbage)
    121 
    122   # Tests that mark_used is available in the API.
    123   def testMarkUsed(self):
    124     @tf_should_use.should_use_result
    125     def return_const(value):
    126       return constant_op.constant(value, name='blah3')
    127     with self.test_session():
    128       return_const(0.0).mark_used()
    129 
if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's test runner
  # (`test` is tensorflow.python.platform.test).
  test.main()
    132