# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that set DEBUG_SAVEALL and assert that no garbage was created.

The DEBUG_SAVEALL flag appears to remain set ("sticky") after such a test
finishes, so these tests are kept in a module of their own for now.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test


class NoReferenceCycleTests(test_util.TensorFlowTestCase):
  """Eager-mode checks that simple operations leave no collectable garbage."""

  @test_util.assert_no_garbage_created
  def testEagerResourceVariables(self):
    """Creating an eager ResourceVariable must not produce garbage."""
    with context.eager_mode():
      # The variable is discarded immediately; only its creation and
      # teardown are under test.
      resource_variable_ops.ResourceVariable(1.0, name="a")

  @test_util.assert_no_garbage_created
  def testTensorArrays(self):
    """Writing to and reading from an eager TensorArray must not produce garbage."""
    # Element i is written at index i and is also the value expected back
    # from read(i).
    elements = ([[4.0, 5.0]], [[1.0]], -3.0)

    with context.eager_mode():
      written = tensor_array_ops.TensorArray(
          dtype=dtypes.float32,
          tensor_array_name="foo",
          size=3,
          infer_shape=False)

      # Each write yields the TensorArray that the next operation is applied
      # to, so the result is threaded through the loop.
      for index, element in enumerate(elements):
        written = written.write(index, element)

      pending_reads = [written.read(index) for index in range(len(elements))]
      fetched = self.evaluate(pending_reads)

      for expected, actual in zip(elements, fetched):
        self.assertAllEqual(expected, actual)


if __name__ == "__main__":
  test.main()