# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional test for slot_creator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import slot_creator


@test_util.with_c_api
class SlotCreatorTest(test.TestCase):
  """Tests for `slot_creator.create_slot` and `create_zeros_slot`.

  Each test checks the slot's op name, shape, dtype and initial value
  for a particular kind of primary (variable, tensor, dynamically shaped).
  """

  def testCreateSlotFromVariable(self):
    """A slot built from a variable is named `<var>/<slot>` and copies it."""
    with self.test_session():
      v = variables.Variable([1.0, 2.5], name="var")
      slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")

      variables.global_variables_initializer().run()

      self.assertEqual("var/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
      self.assertAllEqual([1.0, 2.5], slot.eval())

  def testCreateSlotFromTensor(self):
    """A slot built from a constant tensor takes the given initial value."""
    with self.test_session():
      v = constant_op.constant([1.0, 2.5], name="const")
      slot = slot_creator.create_slot(v, v * 2, name="slot")

      variables.global_variables_initializer().run()

      self.assertEqual("const/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
      self.assertAllEqual([2.0, 5.0], slot.eval())

  def testCreateZerosSlotFromVariable(self):
    """A zeros slot for a variable honors an explicit `dtype` override."""
    with self.test_session():
      v = variables.Variable([1.0, 2.5], name="var")
      # Create the slot with any enclosing control dependencies cleared.
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      variables.global_variables_initializer().run()

      self.assertEqual("var/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], slot.eval())

  def testCreateZerosSlotFromDynamicShapedVariable(self):
    """A zeros slot matches the runtime shape of a dynamically shaped var."""
    with self.test_session():
      dyn_shape = constant_op.constant([2], dtype=dtypes.int32)
      # Hide the static shape so the variable's shape is only known at run
      # time.
      dyn_shape = array_ops.placeholder_with_default(dyn_shape,
                                                     shape=[None])
      v = variable_scope.get_variable(
          "var",
          initializer=random_ops.random_uniform(dyn_shape,
                                                dtype=dtypes.float64),
          validate_shape=False)
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      variables.global_variables_initializer().run()

      self.assertEqual("var/slot", slot.op.name)
      # The static shape is not fully known, so compare the evaluated runtime
      # shape element-wise. assertEqual against the evaluated array would
      # depend on the truthiness of an element-wise comparison result.
      self.assertAllEqual([2], array_ops.shape(slot).eval())
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], slot.eval())

  def testCreateZerosSlotFromTensor(self):
    """A zeros slot for a tensor inherits the tensor's dtype by default."""
    with self.test_session():
      v = constant_op.constant([1.0, 2.5], name="const")
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(v, name="slot")

      variables.global_variables_initializer().run()

      self.assertEqual("const/slot", slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], slot.eval())

  def testCreateZerosSlotFromDynamicShapedTensor(self):
    """A zeros slot matches the runtime shape of a dynamically shaped tensor."""
    with self.test_session():
      v = random_ops.random_uniform([2], dtype=dtypes.float64)
      # Wrap in a placeholder whose static shape is [None].
      v = array_ops.placeholder_with_default(v, shape=[None], name="const")
      with ops.control_dependencies(None):
        slot = slot_creator.create_zeros_slot(
            v, name="slot", dtype=dtypes.float64)

      variables.global_variables_initializer().run()

      self.assertEqual("const/slot", slot.op.name)
      # Compare the evaluated runtime shape element-wise (see the
      # dynamic-variable test for why assertAllEqual is used here).
      self.assertAllEqual([2], array_ops.shape(slot).eval())
      self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
      self.assertAllEqual([0.0, 0.0], slot.eval())

  def testCreateSlotFromVariableRespectsScope(self):
    """The slot name includes the enclosing variable scope."""
    # See discussion on #2740.
    with self.test_session():
      with variable_scope.variable_scope("scope"):
        v = variables.Variable([1.0, 2.5], name="var")
        slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
        # NOTE(review): "scope" appears twice. Presumably this is the active
        # variable scope followed by the primary's full op name ("scope/var").
        # Confirm against slot_creator before changing this expectation.
        self.assertEqual("scope/scope/var/slot", slot.op.name)


if __name__ == "__main__":
  test.main()