1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for regularizers.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.contrib.layers.python.layers import utils 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import ops 27 from tensorflow.python.framework import tensor_shape 28 from tensorflow.python.ops import array_ops 29 from tensorflow.python.ops import variables 30 from tensorflow.python.platform import test 31 32 33 class ConstantValueTest(test.TestCase): 34 35 def test_value(self): 36 for v in [True, False, 1, 0, 1.0]: 37 value = utils.constant_value(v) 38 self.assertEqual(value, v) 39 40 def test_constant(self): 41 for v in [True, False, 1, 0, 1.0]: 42 c = constant_op.constant(v) 43 value = utils.constant_value(c) 44 self.assertEqual(value, v) 45 with self.test_session(): 46 self.assertEqual(c.eval(), v) 47 48 def test_variable(self): 49 for v in [True, False, 1, 0, 1.0]: 50 with ops.Graph().as_default() as g, self.test_session(g) as sess: 51 x = variables.Variable(v) 52 value = utils.constant_value(x) 53 self.assertEqual(value, None) 54 sess.run(variables.global_variables_initializer()) 55 self.assertEqual(x.eval(), v) 56 57 def test_placeholder(self): 58 for v in [True, False, 1, 0, 1.0]: 59 p = array_ops.placeholder(np.dtype(type(v)), []) 60 x = array_ops.identity(p) 61 value = utils.constant_value(p) 62 self.assertEqual(value, None) 63 with self.test_session(): 64 self.assertEqual(x.eval(feed_dict={p: v}), v) 65 66 67 class StaticCondTest(test.TestCase): 68 69 def test_value(self): 70 fn1 = lambda: 'fn1' 71 fn2 = lambda: 'fn2' 72 expected = lambda v: 'fn1' if v else 'fn2' 73 for v in [True, False, 1, 0]: 74 o = utils.static_cond(v, fn1, fn2) 75 self.assertEqual(o, expected(v)) 76 77 def test_constant(self): 78 fn1 = lambda: constant_op.constant('fn1') 79 fn2 = lambda: constant_op.constant('fn2') 80 expected = lambda v: b'fn1' if v else b'fn2' 81 for v in [True, False, 1, 0]: 82 o = utils.static_cond(v, fn1, fn2) 83 with self.test_session(): 84 self.assertEqual(o.eval(), expected(v)) 85 86 def test_variable(self): 87 fn1 = lambda: variables.Variable('fn1') 88 fn2 = lambda: variables.Variable('fn2') 89 expected = lambda v: b'fn1' if v else b'fn2' 90 for v in [True, False, 1, 0]: 91 o = utils.static_cond(v, fn1, fn2) 92 with self.test_session() as sess: 93 sess.run(variables.global_variables_initializer()) 94 self.assertEqual(o.eval(), expected(v)) 95 96 def test_tensors(self): 97 fn1 = lambda: constant_op.constant(0) - constant_op.constant(1) 98 fn2 = lambda: constant_op.constant(0) - constant_op.constant(2) 99 expected = lambda v: -1 if v else -2 100 for v in [True, False, 1, 0]: 101 o = utils.static_cond(v, fn1, fn2) 102 with self.test_session(): 103 self.assertEqual(o.eval(), expected(v)) 104 105 106 class SmartCondStaticTest(test.TestCase): 107 108 def test_value(self): 109 fn1 = lambda: 'fn1' 110 fn2 = lambda: 'fn2' 111 expected = lambda v: 'fn1' if v else 'fn2' 112 for v in [True, False, 1, 0]: 113 o = utils.smart_cond(constant_op.constant(v), fn1, fn2) 114 self.assertEqual(o, expected(v)) 115 116 def test_constant(self): 117 fn1 = lambda: constant_op.constant('fn1') 118 fn2 = lambda: constant_op.constant('fn2') 119 expected = lambda v: b'fn1' if v else b'fn2' 120 for v in [True, False, 1, 0]: 121 o = utils.smart_cond(constant_op.constant(v), fn1, fn2) 122 with self.test_session(): 123 self.assertEqual(o.eval(), expected(v)) 124 125 def test_variable(self): 126 fn1 = lambda: variables.Variable('fn1') 127 fn2 = lambda: variables.Variable('fn2') 128 expected = lambda v: b'fn1' if v else b'fn2' 129 for v in [True, False, 1, 0]: 130 o = utils.smart_cond(constant_op.constant(v), fn1, fn2) 131 with self.test_session() as sess: 132 sess.run(variables.global_variables_initializer()) 133 self.assertEqual(o.eval(), expected(v)) 134 135 def test_tensors(self): 136 fn1 = lambda: constant_op.constant(0) - constant_op.constant(1) 137 fn2 = lambda: constant_op.constant(0) - constant_op.constant(2) 138 expected = lambda v: -1 if v else -2 139 for v in [True, False, 1, 0]: 140 o = utils.smart_cond(constant_op.constant(v), fn1, fn2) 141 with self.test_session(): 142 self.assertEqual(o.eval(), expected(v)) 143 144 145 class SmartCondDynamicTest(test.TestCase): 146 147 def test_value(self): 148 fn1 = lambda: ops.convert_to_tensor('fn1') 149 fn2 = lambda: ops.convert_to_tensor('fn2') 150 expected = lambda v: b'fn1' if v else b'fn2' 151 p = array_ops.placeholder(dtypes.bool, []) 152 for v in [True, False, 1, 0]: 153 o = utils.smart_cond(p, fn1, fn2) 154 with self.test_session(): 155 self.assertEqual(o.eval(feed_dict={p: v}), expected(v)) 156 157 def test_constant(self): 158 fn1 = lambda: constant_op.constant('fn1') 159 fn2 = lambda: constant_op.constant('fn2') 160 expected = lambda v: b'fn1' if v else b'fn2' 161 p = array_ops.placeholder(dtypes.bool, []) 162 for v in [True, False, 1, 0]: 163 o = utils.smart_cond(p, fn1, fn2) 164 with self.test_session(): 165 self.assertEqual(o.eval(feed_dict={p: v}), expected(v)) 166 167 def test_variable(self): 168 fn1 = lambda: variables.Variable('fn1') 169 fn2 = lambda: variables.Variable('fn2') 170 expected = lambda v: b'fn1' if v else b'fn2' 171 p = array_ops.placeholder(dtypes.bool, []) 172 for v in [True, False, 1, 0]: 173 o = utils.smart_cond(p, fn1, fn2) 174 with self.test_session() as sess: 175 sess.run(variables.global_variables_initializer()) 176 self.assertEqual(o.eval(feed_dict={p: v}), expected(v)) 177 178 def test_tensors(self): 179 fn1 = lambda: constant_op.constant(0) - constant_op.constant(1) 180 fn2 = lambda: constant_op.constant(0) - constant_op.constant(2) 181 expected = lambda v: -1 if v else -2 182 p = array_ops.placeholder(dtypes.bool, []) 183 for v in [True, False, 1, 0]: 184 o = utils.smart_cond(p, fn1, fn2) 185 with self.test_session(): 186 self.assertEqual(o.eval(feed_dict={p: v}), expected(v)) 187 188 189 class CollectNamedOutputsTest(test.TestCase): 190 191 def test_collect(self): 192 t1 = constant_op.constant(1.0, name='t1') 193 t2 = constant_op.constant(2.0, name='t2') 194 utils.collect_named_outputs('end_points', 'a1', t1) 195 utils.collect_named_outputs('end_points', 'a2', t2) 196 self.assertEqual(ops.get_collection('end_points'), [t1, t2]) 197 198 def test_aliases(self): 199 t1 = constant_op.constant(1.0, name='t1') 200 t2 = constant_op.constant(2.0, name='t2') 201 utils.collect_named_outputs('end_points', 'a1', t1) 202 utils.collect_named_outputs('end_points', 'a2', t2) 203 self.assertEqual(t1.aliases, ['a1']) 204 self.assertEqual(t2.aliases, ['a2']) 205 206 def test_multiple_aliases(self): 207 t1 = constant_op.constant(1.0, name='t1') 208 t2 = constant_op.constant(2.0, name='t2') 209 utils.collect_named_outputs('end_points', 'a11', t1) 210 utils.collect_named_outputs('end_points', 'a12', t1) 211 utils.collect_named_outputs('end_points', 'a21', t2) 212 utils.collect_named_outputs('end_points', 'a22', t2) 213 self.assertEqual(t1.aliases, ['a11', 'a12']) 214 self.assertEqual(t2.aliases, ['a21', 'a22']) 215 216 def test_gather_aliases(self): 217 t1 = constant_op.constant(1.0, name='t1') 218 t2 = constant_op.constant(2.0, name='t2') 219 t3 = constant_op.constant(2.0, name='t3') 220 utils.collect_named_outputs('end_points', 'a1', t1) 221 utils.collect_named_outputs('end_points', 'a2', t2) 222 ops.add_to_collection('end_points', t3) 223 aliases = utils.gather_tensors_aliases(ops.get_collection('end_points')) 224 self.assertEqual(aliases, ['a1', 'a2', 't3']) 225 226 def test_convert_collection_to_dict(self): 227 t1 = constant_op.constant(1.0, name='t1') 228 t2 = constant_op.constant(2.0, name='t2') 229 utils.collect_named_outputs('end_points', 'a1', t1) 230 utils.collect_named_outputs('end_points', 'a21', t2) 231 utils.collect_named_outputs('end_points', 'a22', t2) 232 end_points = utils.convert_collection_to_dict('end_points') 233 self.assertEqual(end_points['a1'], t1) 234 self.assertEqual(end_points['a21'], t2) 235 self.assertEqual(end_points['a22'], t2) 236 237 def test_convert_collection_to_dict_clear_collection(self): 238 t1 = constant_op.constant(1.0, name='t1') 239 t2 = constant_op.constant(2.0, name='t2') 240 utils.collect_named_outputs('end_points', 'a1', t1) 241 utils.collect_named_outputs('end_points', 'a21', t2) 242 utils.collect_named_outputs('end_points', 'a22', t2) 243 utils.convert_collection_to_dict('end_points', clear_collection=True) 244 self.assertEqual(ops.get_collection('end_points'), []) 245 246 247 class NPositiveIntegersTest(test.TestCase): 248 249 def test_invalid_input(self): 250 with self.assertRaises(ValueError): 251 utils.n_positive_integers('3', [1]) 252 253 with self.assertRaises(ValueError): 254 utils.n_positive_integers(3.3, [1]) 255 256 with self.assertRaises(ValueError): 257 utils.n_positive_integers(-1, [1]) 258 259 with self.assertRaises(ValueError): 260 utils.n_positive_integers(0, [1]) 261 262 with self.assertRaises(ValueError): 263 utils.n_positive_integers(1, [1, 2]) 264 265 with self.assertRaises(ValueError): 266 utils.n_positive_integers(1, [-1]) 267 268 with self.assertRaises(ValueError): 269 utils.n_positive_integers(1, [0]) 270 271 with self.assertRaises(ValueError): 272 utils.n_positive_integers(1, [0]) 273 274 with self.assertRaises(ValueError): 275 utils.n_positive_integers(2, [1]) 276 277 with self.assertRaises(ValueError): 278 utils.n_positive_integers(2, [1, 2, 3]) 279 280 with self.assertRaises(ValueError): 281 utils.n_positive_integers(2, ['hello', 2]) 282 283 with self.assertRaises(ValueError): 284 utils.n_positive_integers(2, tensor_shape.TensorShape([2, 3, 1])) 285 286 with self.assertRaises(ValueError): 287 utils.n_positive_integers(3, tensor_shape.TensorShape([2, None, 1])) 288 289 with self.assertRaises(ValueError): 290 utils.n_positive_integers(3, tensor_shape.TensorShape(None)) 291 292 def test_valid_input(self): 293 self.assertEqual(utils.n_positive_integers(1, 2), (2,)) 294 self.assertEqual(utils.n_positive_integers(2, 2), (2, 2)) 295 self.assertEqual(utils.n_positive_integers(2, (2, 3)), (2, 3)) 296 self.assertEqual(utils.n_positive_integers(3, (2, 3, 1)), (2, 3, 1)) 297 self.assertEqual(utils.n_positive_integers(3, (2, 3, 1)), (2, 3, 1)) 298 self.assertEqual( 299 utils.n_positive_integers(3, tensor_shape.TensorShape([2, 3, 1])), 300 (2, 3, 1)) 301 302 303 if __name__ == '__main__': 304 test.main() 305