1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for IdentityOp.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.framework import constant_op 24 from tensorflow.python.framework import dtypes 25 from tensorflow.python.ops import array_ops 26 from tensorflow.python.ops import gen_array_ops 27 from tensorflow.python.ops import variables 28 from tensorflow.python.platform import test 29 30 31 class IdentityOpTest(test.TestCase): 32 33 def testInt32_6(self): 34 with self.test_session(): 35 value = array_ops.identity([1, 2, 3, 4, 5, 6]).eval() 36 self.assertAllEqual(np.array([1, 2, 3, 4, 5, 6]), value) 37 38 def testInt32_2_3(self): 39 with self.test_session(): 40 inp = constant_op.constant([10, 20, 30, 40, 50, 60], shape=[2, 3]) 41 value = array_ops.identity(inp).eval() 42 self.assertAllEqual(np.array([[10, 20, 30], [40, 50, 60]]), value) 43 44 def testString(self): 45 source = [b"A", b"b", b"C", b"d", b"E", b"f"] 46 with self.test_session(): 47 value = array_ops.identity(source).eval() 48 self.assertAllEqual(source, value) 49 50 def testIdentityShape(self): 51 with self.test_session(): 52 shape = [2, 3] 53 array_2x3 = [[1, 2, 3], [6, 5, 4]] 54 tensor = constant_op.constant(array_2x3) 55 self.assertEquals(shape, tensor.get_shape()) 56 self.assertEquals(shape, array_ops.identity(tensor).get_shape()) 57 self.assertEquals(shape, array_ops.identity(array_2x3).get_shape()) 58 self.assertEquals(shape, 59 array_ops.identity(np.array(array_2x3)).get_shape()) 60 61 def testRefIdentityShape(self): 62 with self.test_session(): 63 shape = [2, 3] 64 tensor = variables.Variable( 65 constant_op.constant( 66 [[1, 2, 3], [6, 5, 4]], dtype=dtypes.int32)) 67 self.assertEquals(shape, tensor.get_shape()) 68 self.assertEquals(shape, gen_array_ops._ref_identity(tensor).get_shape()) 69 70 71 if __name__ == "__main__": 72 test.main() 73