Home | History | Annotate | Download | only in training
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Functional test for slot_creator."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.framework import constant_op
     22 from tensorflow.python.framework import dtypes
     23 from tensorflow.python.framework import ops
     24 from tensorflow.python.framework import test_util
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.ops import random_ops
     27 from tensorflow.python.ops import variable_scope
     28 from tensorflow.python.ops import variables
     29 from tensorflow.python.platform import test
     30 from tensorflow.python.training import slot_creator
     31 
     32 
     33 @test_util.with_c_api
     34 class SlotCreatorTest(test.TestCase):
     35 
     36   def testCreateSlotFromVariable(self):
     37     with self.test_session():
     38       v = variables.Variable([1.0, 2.5], name="var")
     39       slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
     40 
     41       variables.global_variables_initializer().run()
     42 
     43       self.assertEqual("var/slot", slot.op.name)
     44       self.assertEqual([2], slot.get_shape().as_list())
     45       self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
     46       self.assertAllEqual([1.0, 2.5], slot.eval())
     47 
     48   def testCreateSlotFromTensor(self):
     49     with self.test_session():
     50       v = constant_op.constant([1.0, 2.5], name="const")
     51       slot = slot_creator.create_slot(v, v * 2, name="slot")
     52 
     53       variables.global_variables_initializer().run()
     54 
     55       self.assertEqual("const/slot", slot.op.name)
     56       self.assertEqual([2], slot.get_shape().as_list())
     57       self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
     58       self.assertAllEqual([2.0, 5.0], slot.eval())
     59 
     60   def testCreateZerosSlotFromVariable(self):
     61     with self.test_session():
     62       v = variables.Variable([1.0, 2.5], name="var")
     63       with ops.control_dependencies(None):
     64         slot = slot_creator.create_zeros_slot(
     65             v, name="slot", dtype=dtypes.float64)
     66 
     67       variables.global_variables_initializer().run()
     68 
     69       self.assertEqual("var/slot", slot.op.name)
     70       self.assertEqual([2], slot.get_shape().as_list())
     71       self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
     72       self.assertAllEqual([0.0, 0.0], slot.eval())
     73 
     74   def testCreateZerosSlotFromDynamicShapedVariable(self):
     75     with self.test_session():
     76       dyn_shape = constant_op.constant([2], dtype=dtypes.int32)
     77       dyn_shape = array_ops.placeholder_with_default(dyn_shape,
     78                                                      shape=[None])
     79       v = variable_scope.get_variable(
     80           "var",
     81           initializer=random_ops.random_uniform(dyn_shape,
     82                                                 dtype=dtypes.float64),
     83           validate_shape=False)
     84       with ops.control_dependencies(None):
     85         slot = slot_creator.create_zeros_slot(
     86             v, name="slot", dtype=dtypes.float64)
     87 
     88       variables.global_variables_initializer().run()
     89 
     90       self.assertEqual("var/slot", slot.op.name)
     91       self.assertEqual([2], array_ops.shape(slot).eval())
     92       self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
     93       self.assertAllEqual([0.0, 0.0], slot.eval())
     94 
     95   def testCreateZerosSlotFromTensor(self):
     96     with self.test_session():
     97       v = constant_op.constant([1.0, 2.5], name="const")
     98       with ops.control_dependencies(None):
     99         slot = slot_creator.create_zeros_slot(v, name="slot")
    100 
    101       variables.global_variables_initializer().run()
    102 
    103       self.assertEqual("const/slot", slot.op.name)
    104       self.assertEqual([2], slot.get_shape().as_list())
    105       self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
    106       self.assertAllEqual([0.0, 0.0], slot.eval())
    107 
    108   def testCreateZerosSlotFromDynamicShapedTensor(self):
    109     with self.test_session():
    110       v = random_ops.random_uniform([2], dtype=dtypes.float64)
    111       v = array_ops.placeholder_with_default(v, shape=[None], name="const")
    112       with ops.control_dependencies(None):
    113         slot = slot_creator.create_zeros_slot(
    114             v, name="slot", dtype=dtypes.float64)
    115 
    116       variables.global_variables_initializer().run()
    117 
    118       self.assertEqual("const/slot", slot.op.name)
    119       self.assertEqual([2], array_ops.shape(slot).eval())
    120       self.assertEqual(dtypes.float64, slot.dtype.base_dtype)
    121       self.assertAllEqual([0.0, 0.0], slot.eval())
    122 
    123   def testCreateSlotFromVariableRespectsScope(self):
    124     # See discussion on #2740.
    125     with self.test_session():
    126       with variable_scope.variable_scope("scope"):
    127         v = variables.Variable([1.0, 2.5], name="var")
    128         slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
    129         self.assertEqual("scope/scope/var/slot", slot.op.name)
    130 
    131 
    132 if __name__ == "__main__":
    133   test.main()
    134