Home | History | Annotate | Download | only in layers
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.contrib.layers.python.layers.utils."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.contrib.layers.python.layers import utils
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.framework import tensor_shape
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops import variables
     30 from tensorflow.python.platform import test
     31 
     32 
     33 class ConstantValueTest(test.TestCase):
     34 
     35   def test_value(self):
     36     for v in [True, False, 1, 0, 1.0]:
     37       value = utils.constant_value(v)
     38       self.assertEqual(value, v)
     39 
     40   def test_constant(self):
     41     for v in [True, False, 1, 0, 1.0]:
     42       c = constant_op.constant(v)
     43       value = utils.constant_value(c)
     44       self.assertEqual(value, v)
     45       with self.test_session():
     46         self.assertEqual(c.eval(), v)
     47 
     48   def test_variable(self):
     49     for v in [True, False, 1, 0, 1.0]:
     50       with ops.Graph().as_default() as g, self.test_session(g) as sess:
     51         x = variables.Variable(v)
     52         value = utils.constant_value(x)
     53         self.assertEqual(value, None)
     54         sess.run(variables.global_variables_initializer())
     55         self.assertEqual(x.eval(), v)
     56 
     57   def test_placeholder(self):
     58     for v in [True, False, 1, 0, 1.0]:
     59       p = array_ops.placeholder(np.dtype(type(v)), [])
     60       x = array_ops.identity(p)
     61       value = utils.constant_value(p)
     62       self.assertEqual(value, None)
     63       with self.test_session():
     64         self.assertEqual(x.eval(feed_dict={p: v}), v)
     65 
     66 
     67 class StaticCondTest(test.TestCase):
     68 
     69   def test_value(self):
     70     fn1 = lambda: 'fn1'
     71     fn2 = lambda: 'fn2'
     72     expected = lambda v: 'fn1' if v else 'fn2'
     73     for v in [True, False, 1, 0]:
     74       o = utils.static_cond(v, fn1, fn2)
     75       self.assertEqual(o, expected(v))
     76 
     77   def test_constant(self):
     78     fn1 = lambda: constant_op.constant('fn1')
     79     fn2 = lambda: constant_op.constant('fn2')
     80     expected = lambda v: b'fn1' if v else b'fn2'
     81     for v in [True, False, 1, 0]:
     82       o = utils.static_cond(v, fn1, fn2)
     83       with self.test_session():
     84         self.assertEqual(o.eval(), expected(v))
     85 
     86   def test_variable(self):
     87     fn1 = lambda: variables.Variable('fn1')
     88     fn2 = lambda: variables.Variable('fn2')
     89     expected = lambda v: b'fn1' if v else b'fn2'
     90     for v in [True, False, 1, 0]:
     91       o = utils.static_cond(v, fn1, fn2)
     92       with self.test_session() as sess:
     93         sess.run(variables.global_variables_initializer())
     94         self.assertEqual(o.eval(), expected(v))
     95 
     96   def test_tensors(self):
     97     fn1 = lambda: constant_op.constant(0) - constant_op.constant(1)
     98     fn2 = lambda: constant_op.constant(0) - constant_op.constant(2)
     99     expected = lambda v: -1 if v else -2
    100     for v in [True, False, 1, 0]:
    101       o = utils.static_cond(v, fn1, fn2)
    102       with self.test_session():
    103         self.assertEqual(o.eval(), expected(v))
    104 
    105 
    106 class SmartCondStaticTest(test.TestCase):
    107 
    108   def test_value(self):
    109     fn1 = lambda: 'fn1'
    110     fn2 = lambda: 'fn2'
    111     expected = lambda v: 'fn1' if v else 'fn2'
    112     for v in [True, False, 1, 0]:
    113       o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
    114       self.assertEqual(o, expected(v))
    115 
    116   def test_constant(self):
    117     fn1 = lambda: constant_op.constant('fn1')
    118     fn2 = lambda: constant_op.constant('fn2')
    119     expected = lambda v: b'fn1' if v else b'fn2'
    120     for v in [True, False, 1, 0]:
    121       o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
    122       with self.test_session():
    123         self.assertEqual(o.eval(), expected(v))
    124 
    125   def test_variable(self):
    126     fn1 = lambda: variables.Variable('fn1')
    127     fn2 = lambda: variables.Variable('fn2')
    128     expected = lambda v: b'fn1' if v else b'fn2'
    129     for v in [True, False, 1, 0]:
    130       o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
    131       with self.test_session() as sess:
    132         sess.run(variables.global_variables_initializer())
    133         self.assertEqual(o.eval(), expected(v))
    134 
    135   def test_tensors(self):
    136     fn1 = lambda: constant_op.constant(0) - constant_op.constant(1)
    137     fn2 = lambda: constant_op.constant(0) - constant_op.constant(2)
    138     expected = lambda v: -1 if v else -2
    139     for v in [True, False, 1, 0]:
    140       o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
    141       with self.test_session():
    142         self.assertEqual(o.eval(), expected(v))
    143 
    144 
    145 class SmartCondDynamicTest(test.TestCase):
    146 
    147   def test_value(self):
    148     fn1 = lambda: ops.convert_to_tensor('fn1')
    149     fn2 = lambda: ops.convert_to_tensor('fn2')
    150     expected = lambda v: b'fn1' if v else b'fn2'
    151     p = array_ops.placeholder(dtypes.bool, [])
    152     for v in [True, False, 1, 0]:
    153       o = utils.smart_cond(p, fn1, fn2)
    154       with self.test_session():
    155         self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
    156 
    157   def test_constant(self):
    158     fn1 = lambda: constant_op.constant('fn1')
    159     fn2 = lambda: constant_op.constant('fn2')
    160     expected = lambda v: b'fn1' if v else b'fn2'
    161     p = array_ops.placeholder(dtypes.bool, [])
    162     for v in [True, False, 1, 0]:
    163       o = utils.smart_cond(p, fn1, fn2)
    164       with self.test_session():
    165         self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
    166 
    167   def test_variable(self):
    168     fn1 = lambda: variables.Variable('fn1')
    169     fn2 = lambda: variables.Variable('fn2')
    170     expected = lambda v: b'fn1' if v else b'fn2'
    171     p = array_ops.placeholder(dtypes.bool, [])
    172     for v in [True, False, 1, 0]:
    173       o = utils.smart_cond(p, fn1, fn2)
    174       with self.test_session() as sess:
    175         sess.run(variables.global_variables_initializer())
    176         self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
    177 
    178   def test_tensors(self):
    179     fn1 = lambda: constant_op.constant(0) - constant_op.constant(1)
    180     fn2 = lambda: constant_op.constant(0) - constant_op.constant(2)
    181     expected = lambda v: -1 if v else -2
    182     p = array_ops.placeholder(dtypes.bool, [])
    183     for v in [True, False, 1, 0]:
    184       o = utils.smart_cond(p, fn1, fn2)
    185       with self.test_session():
    186         self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
    187 
    188 
    189 class CollectNamedOutputsTest(test.TestCase):
    190 
    191   def test_collect(self):
    192     t1 = constant_op.constant(1.0, name='t1')
    193     t2 = constant_op.constant(2.0, name='t2')
    194     utils.collect_named_outputs('end_points', 'a1', t1)
    195     utils.collect_named_outputs('end_points', 'a2', t2)
    196     self.assertEqual(ops.get_collection('end_points'), [t1, t2])
    197 
    198   def test_aliases(self):
    199     t1 = constant_op.constant(1.0, name='t1')
    200     t2 = constant_op.constant(2.0, name='t2')
    201     utils.collect_named_outputs('end_points', 'a1', t1)
    202     utils.collect_named_outputs('end_points', 'a2', t2)
    203     self.assertEqual(t1.aliases, ['a1'])
    204     self.assertEqual(t2.aliases, ['a2'])
    205 
    206   def test_multiple_aliases(self):
    207     t1 = constant_op.constant(1.0, name='t1')
    208     t2 = constant_op.constant(2.0, name='t2')
    209     utils.collect_named_outputs('end_points', 'a11', t1)
    210     utils.collect_named_outputs('end_points', 'a12', t1)
    211     utils.collect_named_outputs('end_points', 'a21', t2)
    212     utils.collect_named_outputs('end_points', 'a22', t2)
    213     self.assertEqual(t1.aliases, ['a11', 'a12'])
    214     self.assertEqual(t2.aliases, ['a21', 'a22'])
    215 
    216   def test_gather_aliases(self):
    217     t1 = constant_op.constant(1.0, name='t1')
    218     t2 = constant_op.constant(2.0, name='t2')
    219     t3 = constant_op.constant(2.0, name='t3')
    220     utils.collect_named_outputs('end_points', 'a1', t1)
    221     utils.collect_named_outputs('end_points', 'a2', t2)
    222     ops.add_to_collection('end_points', t3)
    223     aliases = utils.gather_tensors_aliases(ops.get_collection('end_points'))
    224     self.assertEqual(aliases, ['a1', 'a2', 't3'])
    225 
    226   def test_convert_collection_to_dict(self):
    227     t1 = constant_op.constant(1.0, name='t1')
    228     t2 = constant_op.constant(2.0, name='t2')
    229     utils.collect_named_outputs('end_points', 'a1', t1)
    230     utils.collect_named_outputs('end_points', 'a21', t2)
    231     utils.collect_named_outputs('end_points', 'a22', t2)
    232     end_points = utils.convert_collection_to_dict('end_points')
    233     self.assertEqual(end_points['a1'], t1)
    234     self.assertEqual(end_points['a21'], t2)
    235     self.assertEqual(end_points['a22'], t2)
    236 
    237   def test_convert_collection_to_dict_clear_collection(self):
    238     t1 = constant_op.constant(1.0, name='t1')
    239     t2 = constant_op.constant(2.0, name='t2')
    240     utils.collect_named_outputs('end_points', 'a1', t1)
    241     utils.collect_named_outputs('end_points', 'a21', t2)
    242     utils.collect_named_outputs('end_points', 'a22', t2)
    243     utils.convert_collection_to_dict('end_points', clear_collection=True)
    244     self.assertEqual(ops.get_collection('end_points'), [])
    245 
    246 
    247 class NPositiveIntegersTest(test.TestCase):
    248 
    249   def test_invalid_input(self):
    250     with self.assertRaises(ValueError):
    251       utils.n_positive_integers('3', [1])
    252 
    253     with self.assertRaises(ValueError):
    254       utils.n_positive_integers(3.3, [1])
    255 
    256     with self.assertRaises(ValueError):
    257       utils.n_positive_integers(-1, [1])
    258 
    259     with self.assertRaises(ValueError):
    260       utils.n_positive_integers(0, [1])
    261 
    262     with self.assertRaises(ValueError):
    263       utils.n_positive_integers(1, [1, 2])
    264 
    265     with self.assertRaises(ValueError):
    266       utils.n_positive_integers(1, [-1])
    267 
    268     with self.assertRaises(ValueError):
    269       utils.n_positive_integers(1, [0])
    270 
    271     with self.assertRaises(ValueError):
    272       utils.n_positive_integers(1, [0])
    273 
    274     with self.assertRaises(ValueError):
    275       utils.n_positive_integers(2, [1])
    276 
    277     with self.assertRaises(ValueError):
    278       utils.n_positive_integers(2, [1, 2, 3])
    279 
    280     with self.assertRaises(ValueError):
    281       utils.n_positive_integers(2, ['hello', 2])
    282 
    283     with self.assertRaises(ValueError):
    284       utils.n_positive_integers(2, tensor_shape.TensorShape([2, 3, 1]))
    285 
    286     with self.assertRaises(ValueError):
    287       utils.n_positive_integers(3, tensor_shape.TensorShape([2, None, 1]))
    288 
    289     with self.assertRaises(ValueError):
    290       utils.n_positive_integers(3, tensor_shape.TensorShape(None))
    291 
    292   def test_valid_input(self):
    293     self.assertEqual(utils.n_positive_integers(1, 2), (2,))
    294     self.assertEqual(utils.n_positive_integers(2, 2), (2, 2))
    295     self.assertEqual(utils.n_positive_integers(2, (2, 3)), (2, 3))
    296     self.assertEqual(utils.n_positive_integers(3, (2, 3, 1)), (2, 3, 1))
    297     self.assertEqual(utils.n_positive_integers(3, (2, 3, 1)), (2, 3, 1))
    298     self.assertEqual(
    299         utils.n_positive_integers(3, tensor_shape.TensorShape([2, 3, 1])),
    300         (2, 3, 1))
    301 
    302 
    303 if __name__ == '__main__':
    304   test.main()
    305