Home | History | Annotate | Download | only in util
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Unit tests for tf_should_use."""
     16 
     17 # pylint: disable=unused-import
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import contextlib
     23 import gc
     24 import sys
     25 
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import test_util
     28 from tensorflow.python.platform import test
     29 from tensorflow.python.platform import tf_logging
     30 from tensorflow.python.util import tf_should_use
     31 
     32 
     33 @contextlib.contextmanager
     34 def reroute_error():
     35   """Temporarily reroute errors written to tf_logging.error into `captured`."""
     36   with test.mock.patch.object(tf_should_use.tf_logging, 'error') as error:
     37     with test.mock.patch.object(tf_should_use.tf_logging, 'fatal') as fatal:
     38       yield error, fatal
     39 
     40 
     41 class TfShouldUseTest(test.TestCase):
     42 
     43   @test_util.run_deprecated_v1
     44   def testAddShouldUseWarningWhenNotUsed(self):
     45     c = constant_op.constant(0, name='blah0')
     46     def in_this_function():
     47       h = tf_should_use._add_should_use_warning(c)
     48       del h
     49     with reroute_error() as (error, _):
     50       in_this_function()
     51     msg = '\n'.join(error.call_args[0])
     52     self.assertIn('Object was never used', msg)
     53     self.assertIn('blah0:0', msg)
     54     self.assertIn('in_this_function', msg)
     55     self.assertFalse(gc.garbage)
     56 
     57   @test_util.run_deprecated_v1
     58   def testAddShouldUseFatalWhenNotUsed(self):
     59     c = constant_op.constant(0, name='blah0')
     60     def in_this_function():
     61       h = tf_should_use._add_should_use_warning(c, fatal_error=True)
     62       del h
     63     with reroute_error() as (_, fatal):
     64       in_this_function()
     65     msg = '\n'.join(fatal.call_args[0])
     66     self.assertIn('Object was never used', msg)
     67     self.assertIn('blah0:0', msg)
     68     self.assertIn('in_this_function', msg)
     69     self.assertFalse(gc.garbage)
     70 
     71   def _testAddShouldUseWarningWhenUsed(self, fn, name):
     72     c = constant_op.constant(0, name=name)
     73     with reroute_error() as (error, fatal):
     74       h = tf_should_use._add_should_use_warning(c)
     75       fn(h)
     76       del h
     77     error.assert_not_called()
     78     fatal.assert_not_called()
     79 
     80   @test_util.run_deprecated_v1
     81   def testAddShouldUseWarningWhenUsedWithAdd(self):
     82     def add(h):
     83       _ = h + 1
     84     self._testAddShouldUseWarningWhenUsed(add, name='blah_add')
     85     gc.collect()
     86     self.assertFalse(gc.garbage)
     87 
     88   @test_util.run_deprecated_v1
     89   def testAddShouldUseWarningWhenUsedWithGetName(self):
     90     def get_name(h):
     91       _ = h.name
     92     self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name')
     93     gc.collect()
     94     self.assertFalse(gc.garbage)
     95 
     96   @test_util.run_deprecated_v1
     97   def testShouldUseResult(self):
     98     @tf_should_use.should_use_result
     99     def return_const(value):
    100       return constant_op.constant(value, name='blah2')
    101     with reroute_error() as (error, _):
    102       return_const(0.0)
    103     msg = '\n'.join(error.call_args[0])
    104     self.assertIn('Object was never used', msg)
    105     self.assertIn('blah2:0', msg)
    106     self.assertIn('return_const', msg)
    107     gc.collect()
    108     self.assertFalse(gc.garbage)
    109 
    110   @test_util.run_deprecated_v1
    111   def testShouldUseResultWhenNotReallyUsed(self):
    112     @tf_should_use.should_use_result
    113     def return_const(value):
    114       return constant_op.constant(value, name='blah3')
    115     with reroute_error() as (error, _):
    116       with self.cached_session():
    117         return_const(0.0)
    118         # Creating another op and executing it does not mark the
    119         # unused op as being "used".
    120         v = constant_op.constant(1.0, name='meh')
    121         self.evaluate(v)
    122     msg = '\n'.join(error.call_args[0])
    123     self.assertIn('Object was never used', msg)
    124     self.assertIn('blah3:0', msg)
    125     self.assertIn('return_const', msg)
    126     gc.collect()
    127     self.assertFalse(gc.garbage)
    128 
    129   # Tests that mark_used is available in the API.
    130   def testMarkUsed(self):
    131     @tf_should_use.should_use_result
    132     def return_const(value):
    133       return constant_op.constant(value, name='blah3')
    134 
    135     with self.cached_session():
    136       return_const(0.0).mark_used()
    137 
# When this file is executed directly, hand off to TensorFlow's test runner
# (`tensorflow.python.platform.test.main`).
if __name__ == '__main__':
  test.main()
    140