1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Unit tests for tf_should_use.""" 16 17 # pylint: disable=unused-import 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import contextlib 23 import gc 24 import sys 25 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.framework import test_util 28 from tensorflow.python.platform import test 29 from tensorflow.python.platform import tf_logging 30 from tensorflow.python.util import tf_should_use 31 32 33 @contextlib.contextmanager 34 def reroute_error(): 35 """Temporarily reroute errors written to tf_logging.error into `captured`.""" 36 with test.mock.patch.object(tf_should_use.tf_logging, 'error') as error: 37 with test.mock.patch.object(tf_should_use.tf_logging, 'fatal') as fatal: 38 yield error, fatal 39 40 41 class TfShouldUseTest(test.TestCase): 42 43 @test_util.run_deprecated_v1 44 def testAddShouldUseWarningWhenNotUsed(self): 45 c = constant_op.constant(0, name='blah0') 46 def in_this_function(): 47 h = tf_should_use._add_should_use_warning(c) 48 del h 49 with reroute_error() as (error, _): 50 in_this_function() 51 msg = '\n'.join(error.call_args[0]) 52 self.assertIn('Object was never used', msg) 53 self.assertIn('blah0:0', msg) 54 self.assertIn('in_this_function', msg) 55 self.assertFalse(gc.garbage) 56 57 @test_util.run_deprecated_v1 58 def testAddShouldUseFatalWhenNotUsed(self): 59 c = constant_op.constant(0, name='blah0') 60 def in_this_function(): 61 h = tf_should_use._add_should_use_warning(c, fatal_error=True) 62 del h 63 with reroute_error() as (_, fatal): 64 in_this_function() 65 msg = '\n'.join(fatal.call_args[0]) 66 self.assertIn('Object was never used', msg) 67 self.assertIn('blah0:0', msg) 68 self.assertIn('in_this_function', msg) 69 self.assertFalse(gc.garbage) 70 71 def _testAddShouldUseWarningWhenUsed(self, fn, name): 72 c = constant_op.constant(0, name=name) 73 with reroute_error() as (error, fatal): 74 h = tf_should_use._add_should_use_warning(c) 75 fn(h) 76 del h 77 error.assert_not_called() 78 fatal.assert_not_called() 79 80 @test_util.run_deprecated_v1 81 def testAddShouldUseWarningWhenUsedWithAdd(self): 82 def add(h): 83 _ = h + 1 84 self._testAddShouldUseWarningWhenUsed(add, name='blah_add') 85 gc.collect() 86 self.assertFalse(gc.garbage) 87 88 @test_util.run_deprecated_v1 89 def testAddShouldUseWarningWhenUsedWithGetName(self): 90 def get_name(h): 91 _ = h.name 92 self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name') 93 gc.collect() 94 self.assertFalse(gc.garbage) 95 96 @test_util.run_deprecated_v1 97 def testShouldUseResult(self): 98 @tf_should_use.should_use_result 99 def return_const(value): 100 return constant_op.constant(value, name='blah2') 101 with reroute_error() as (error, _): 102 return_const(0.0) 103 msg = '\n'.join(error.call_args[0]) 104 self.assertIn('Object was never used', msg) 105 self.assertIn('blah2:0', msg) 106 self.assertIn('return_const', msg) 107 gc.collect() 108 self.assertFalse(gc.garbage) 109 110 @test_util.run_deprecated_v1 111 def testShouldUseResultWhenNotReallyUsed(self): 112 @tf_should_use.should_use_result 113 def return_const(value): 114 return constant_op.constant(value, name='blah3') 115 with reroute_error() as (error, _): 116 with self.cached_session(): 117 return_const(0.0) 118 # Creating another op and executing it does not mark the 119 # unused op as being "used". 120 v = constant_op.constant(1.0, name='meh') 121 self.evaluate(v) 122 msg = '\n'.join(error.call_args[0]) 123 self.assertIn('Object was never used', msg) 124 self.assertIn('blah3:0', msg) 125 self.assertIn('return_const', msg) 126 gc.collect() 127 self.assertFalse(gc.garbage) 128 129 # Tests that mark_used is available in the API. 130 def testMarkUsed(self): 131 @tf_should_use.should_use_result 132 def return_const(value): 133 return constant_op.constant(value, name='blah3') 134 135 with self.cached_session(): 136 return_const(0.0).mark_used() 137 138 if __name__ == '__main__': 139 test.main() 140