1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Unit tests for tf_should_use.""" 16 17 # pylint: disable=unused-import 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import contextlib 23 import gc 24 import sys 25 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.platform import test 28 from tensorflow.python.platform import tf_logging 29 from tensorflow.python.util import tf_should_use 30 31 32 @contextlib.contextmanager 33 def reroute_error(captured): 34 """Temporarily reroute errors written to tf_logging.error into `captured`.""" 35 del captured[:] 36 true_logger = tf_logging.error 37 def capture_errors(*args, **unused_kwargs): 38 captured.extend(args) 39 tf_logging.error = capture_errors 40 try: 41 yield 42 finally: 43 tf_logging.error = true_logger 44 45 46 class TfShouldUseTest(test.TestCase): 47 48 def testAddShouldUseWarningWhenNotUsed(self): 49 self.skipTest('b/65412899') 50 c = constant_op.constant(0, name='blah0') 51 captured = [] 52 with reroute_error(captured): 53 def in_this_function(): 54 h = tf_should_use._add_should_use_warning(c) 55 del h 56 in_this_function() 57 self.assertIn('Object was never used', '\n'.join(captured)) 58 self.assertIn('blah0:0', '\n'.join(captured)) 59 self.assertIn('in_this_function', '\n'.join(captured)) 60 gc.collect() 61 self.assertFalse(gc.garbage) 62 63 def _testAddShouldUseWarningWhenUsed(self, fn, name): 64 c = constant_op.constant(0, name=name) 65 captured = [] 66 with reroute_error(captured): 67 h = tf_should_use._add_should_use_warning(c) 68 fn(h) 69 del h 70 self.assertNotIn('Object was never used', '\n'.join(captured)) 71 self.assertNotIn('%s:0' % name, '\n'.join(captured)) 72 73 def testAddShouldUseWarningWhenUsedWithAdd(self): 74 self.skipTest('b/65412899') 75 def add(h): 76 _ = h + 1 77 self._testAddShouldUseWarningWhenUsed(add, name='blah_add') 78 gc.collect() 79 self.assertFalse(gc.garbage) 80 81 def testAddShouldUseWarningWhenUsedWithGetName(self): 82 self.skipTest('b/65412899') 83 def get_name(h): 84 _ = h.name 85 self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name') 86 gc.collect() 87 self.assertFalse(gc.garbage) 88 89 def testShouldUseResult(self): 90 self.skipTest('b/65412899') 91 @tf_should_use.should_use_result 92 def return_const(value): 93 return constant_op.constant(value, name='blah2') 94 captured = [] 95 with reroute_error(captured): 96 return_const(0.0) 97 self.assertIn('Object was never used', '\n'.join(captured)) 98 self.assertIn('blah2:0', '\n'.join(captured)) 99 self.assertIn('return_const', '\n'.join(captured)) 100 gc.collect() 101 self.assertFalse(gc.garbage) 102 103 def testShouldUseResultWhenNotReallyUsed(self): 104 self.skipTest('b/65412899') 105 @tf_should_use.should_use_result 106 def return_const(value): 107 return constant_op.constant(value, name='blah3') 108 captured = [] 109 with reroute_error(captured): 110 with self.test_session(): 111 return_const(0.0) 112 # Creating another op and executing it does not mark the 113 # unused op as being "used". 114 v = constant_op.constant(1.0, name='meh') 115 v.eval() 116 self.assertIn('Object was never used', '\n'.join(captured)) 117 self.assertIn('blah3:0', '\n'.join(captured)) 118 self.assertIn('return_const', '\n'.join(captured)) 119 gc.collect() 120 self.assertFalse(gc.garbage) 121 122 # Tests that mark_used is available in the API. 123 def testMarkUsed(self): 124 @tf_should_use.should_use_result 125 def return_const(value): 126 return constant_op.constant(value, name='blah3') 127 with self.test_session(): 128 return_const(0.0).mark_used() 129 130 if __name__ == '__main__': 131 test.main() 132