# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for py2tf.utils.tensor_list.TensorList."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.py2tf.utils import tensor_list as tl
from tensorflow.python.client.session import Session
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.platform import test


class TensorListTest(test.TestCase):
  """Checks TensorList append/pop/count and item get/set.

  Each behavior is exercised twice: once under `context.eager_mode()` (the
  `*_python` tests, which read results with `.numpy()`) and once by building
  ops and evaluating them through a session (the `*_tf` tests).
  """

  def test_list_append_python(self):
    # Eager mode: count() must follow every append() and pop(), and the last
    # element popped must equal the value that was appended.
    with context.eager_mode():
      value = constant(3.0)
      tensor_list = tl.TensorList(value.shape, value.dtype)

      tensor_list.append(value)
      self.assertEqual(tensor_list.count().numpy(), 1)

      tensor_list.append(value)
      self.assertEqual(tensor_list.count().numpy(), 2)

      # The first popped element is not inspected; only the count matters.
      tensor_list.pop()
      self.assertEqual(tensor_list.count().numpy(), 1)

      popped = tensor_list.pop()
      self.assertEqual(tensor_list.count().numpy(), 0)
      self.assertEqual(value.numpy(), popped.numpy())

  def test_list_index_python(self):
    # Eager mode: reading index 0 returns the appended value, and assigning
    # to index 0 replaces it.
    with context.eager_mode():
      initial = constant(3.0)
      replacement = constant(2.0)
      tensor_list = tl.TensorList(initial.shape, initial.dtype)

      tensor_list.append(initial)
      self.assertEqual(tensor_list[0].numpy(), initial.numpy())

      tensor_list[0] = ops.convert_to_tensor(replacement)
      self.assertEqual(tensor_list[0].numpy(), replacement.numpy())

  def test_list_append_tf(self):
    # Session-evaluated variant of test_list_append_python: every count is
    # captured as a tensor right after the corresponding append()/pop(), and
    # all of them are fetched in a single run() call.
    # NOTE(review): this test opens a raw Session while test_list_index_tf
    # uses self.test_session() -- confirm whether the difference is intended.
    value = constant(3.0)
    tensor_list = tl.TensorList(value.shape, value.dtype)

    tensor_list.append(value)
    count_after_first_append = tensor_list.count()
    tensor_list.append(value)
    count_after_second_append = tensor_list.count()
    tensor_list.pop()
    count_after_first_pop = tensor_list.count()
    popped = tensor_list.pop()
    count_after_second_pop = tensor_list.count()

    fetches = [
        count_after_first_append,
        count_after_second_append,
        count_after_first_pop,
        count_after_second_pop,
        value,
        popped,
    ]
    with Session() as sess:
      (n_first_append, n_second_append, n_first_pop, n_second_pop,
       value_out, popped_out) = sess.run(fetches)
      self.assertEqual(n_first_append, 1)
      self.assertEqual(n_second_append, 2)
      self.assertEqual(n_first_pop, 1)
      self.assertEqual(n_second_pop, 0)
      self.assertEqual(value_out, popped_out)

  def test_list_index_tf(self):
    # Session-evaluated variant of test_list_index_python: element 0 is read
    # once before and once after it is overwritten.
    initial = constant(3.0)
    replacement = constant(2.0)
    tensor_list = tl.TensorList(initial.shape, initial.dtype)

    tensor_list.append(initial)
    item_before_assign = tensor_list[0]
    tensor_list[0] = replacement
    item_after_assign = tensor_list[0]

    fetches = [item_before_assign, item_after_assign, initial, replacement]
    with self.test_session() as sess:
      before_out, after_out, initial_out, replacement_out = sess.run(fetches)
      self.assertEqual(before_out, initial_out)
      self.assertEqual(after_out, replacement_out)


if __name__ == '__main__':
  test.main()