# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""arg_scope tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.framework.python.ops import arg_scope
from tensorflow.contrib.framework.python.ops import arg_scoped_arguments
from tensorflow.python.platform import test


@add_arg_scope
def func1(*args, **kwargs):
  """Returns the received positional and keyword arguments unchanged."""
  return (args, kwargs)


@add_arg_scope
def func2(*args, **kwargs):
  """Returns the received positional and keyword arguments unchanged."""
  return (args, kwargs)


# NOTE: testDocString asserts on this exact docstring; do not edit it.
@add_arg_scope
def func3(args, a=None, b=1, c=2):
  """Some cool doc string."""
  return (args, a, b, c)


def _key_op(op):
  """Returns `op._key_op` when that attribute exists, otherwise `str(op)`."""
  return getattr(op, '_key_op', str(op))


class ArgScopeTest(test.TestCase):
  """Tests for arg_scope, add_arg_scope and arg_scoped_arguments."""

  def testEmptyArgScope(self):
    """A scope opened over an empty op list yields an empty dict."""
    with self.test_session():
      with arg_scope([]) as sc:
        self.assertEqual(sc, {})

  def testClearArgScope(self):
    """Entering `arg_scope({})` yields {}; the outer scope returns after it."""
    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    key_op = _key_op(func1)
    func1_scope = {key_op: func1_kwargs.copy()}
    with self.test_session():
      with arg_scope([func1], a=1, b=None, c=[1]) as sc1:
        self.assertEqual(sc1, func1_scope)
        with arg_scope({}) as sc2:
          self.assertEqual(sc2, {})
        # Once the empty-dict scope has exited, func1's defaults are visible
        # again.
        with arg_scope([]) as current_arg_scope:
          self.assertEqual(current_arg_scope, func1_scope)

  def testNonDecorated(self):
    """Scoping a function not wrapped by add_arg_scope raises ValueError."""

    def my_func(t, a=None):
      return (t, a)

    with self.assertRaises(ValueError):
      with arg_scope([my_func], a=1):
        pass

  def testUnexpectedArg(self):
    """Scoping a keyword (`d`) that func3 does not accept raises TypeError."""
    with self.assertRaises(TypeError):
      with arg_scope([func3], d=1):
        func3(1)

  def testCurrentArgScope(self):
    """The yielded scope maps func1's key to the scoped keyword values."""
    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    key_op = _key_op(func1)
    current_scope = {key_op: func1_kwargs.copy()}
    with self.test_session():
      with arg_scope([func1], a=1, b=None, c=[1]) as scope:
        self.assertDictEqual(scope, current_scope)

  def testArgScopedArguments(self):
    """arg_scoped_arguments(func3) returns func3's keyword names as a tuple."""
    func3_kwargs = ('a', 'b', 'c')
    # assertEqual replaces the deprecated assertEquals alias, which was removed
    # in Python 3.12.
    self.assertEqual(arg_scoped_arguments(func3), func3_kwargs)

  def testCurrentArgScopeNested(self):
    """A nested scope contains the entries of both the outer and inner scope."""
    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    func2_kwargs = {'b': 2, 'd': [2]}
    key = _key_op
    current_scope = {
        key(func1): func1_kwargs.copy(),
        key(func2): func2_kwargs.copy()
    }
    with self.test_session():
      with arg_scope([func1], a=1, b=None, c=[1]):
        with arg_scope([func2], b=2, d=[2]) as scope:
          self.assertDictEqual(scope, current_scope)

  def testReuseArgScope(self):
    """A captured scope can be passed back to arg_scope after it has exited."""
    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    key_op = _key_op(func1)
    current_scope = {key_op: func1_kwargs.copy()}
    with self.test_session():
      with arg_scope([func1], a=1, b=None, c=[1]) as scope1:
        pass
      with arg_scope(scope1) as scope:
        self.assertDictEqual(scope, current_scope)

  def testReuseArgScopeNested(self):
    """Captured outer and inner scopes each restore their own contents."""
    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    func2_kwargs = {'b': 2, 'd': [2]}
    key = _key_op
    current_scope1 = {key(func1): func1_kwargs.copy()}
    current_scope2 = {
        key(func1): func1_kwargs.copy(),
        key(func2): func2_kwargs.copy()
    }
    with self.test_session():
      with arg_scope([func1], a=1, b=None, c=[1]) as scope1:
        with arg_scope([func2], b=2, d=[2]) as scope2:
          pass
      # arg_scope([]) is used to read back whatever scope is currently active.
      with arg_scope(scope1):
        with arg_scope([]) as current_arg_scope:
          self.assertDictEqual(current_arg_scope, current_scope1)
      with arg_scope(scope2):
        with arg_scope([]) as current_arg_scope:
          self.assertDictEqual(current_arg_scope, current_scope2)

  def testSimpleArgScope(self):
    """Scoped keyword values are passed to func1 when it is called."""
    func1_args = (0,)
    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    with self.test_session():
      with arg_scope([func1], a=1, b=None, c=[1]):
        args, kwargs = func1(0)
        self.assertTupleEqual(args, func1_args)
        self.assertDictEqual(kwargs, func1_kwargs)

  def testSimpleArgScopeWithTuple(self):
    """Same as testSimpleArgScope, with the ops given as a tuple."""
    func1_args = (0,)
    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    with self.test_session():
      with arg_scope((func1,), a=1, b=None, c=[1]):
        args, kwargs = func1(0)
        self.assertTupleEqual(args, func1_args)
        self.assertDictEqual(kwargs, func1_kwargs)

  def testOverwriteArgScope(self):
    """A keyword given at the call site overrides the scoped value."""
    func1_args = (0,)
    func1_kwargs = {'a': 1, 'b': 2, 'c': [1]}
    with arg_scope([func1], a=1, b=None, c=[1]):
      args, kwargs = func1(0, b=2)
      self.assertTupleEqual(args, func1_args)
      self.assertDictEqual(kwargs, func1_kwargs)

  def testNestedArgScope(self):
    """An inner scope overrides `b` while keeping the outer `a` and `c`."""
    func1_args = (0,)
    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    with arg_scope([func1], a=1, b=None, c=[1]):
      args, kwargs = func1(0)
      self.assertTupleEqual(args, func1_args)
      self.assertDictEqual(kwargs, func1_kwargs)
      # Update the expectation to match the override applied by the inner scope.
      func1_kwargs['b'] = 2
      with arg_scope([func1], b=2):
        args, kwargs = func1(0)
        self.assertTupleEqual(args, func1_args)
        self.assertDictEqual(kwargs, func1_kwargs)

  def testSharedArgScope(self):
    """One scope over a list of ops applies the same values to func1 and func2."""
    func1_args = (0,)
    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    with arg_scope([func1, func2], a=1, b=None, c=[1]):
      args, kwargs = func1(0)
      self.assertTupleEqual(args, func1_args)
      self.assertDictEqual(kwargs, func1_kwargs)
      args, kwargs = func2(0)
      self.assertTupleEqual(args, func1_args)
      self.assertDictEqual(kwargs, func1_kwargs)

  def testSharedArgScopeTuple(self):
    """Same as testSharedArgScope, with the ops given as a tuple."""
    func1_args = (0,)
    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    with arg_scope((func1, func2), a=1, b=None, c=[1]):
      args, kwargs = func1(0)
      self.assertTupleEqual(args, func1_args)
      self.assertDictEqual(kwargs, func1_kwargs)
      args, kwargs = func2(0)
      self.assertTupleEqual(args, func1_args)
      self.assertDictEqual(kwargs, func1_kwargs)

  def testPartiallySharedArgScope(self):
    """Shared values combine with per-op values set by nested scopes."""
    func1_args = (0,)
    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    func2_args = (1,)
    func2_kwargs = {'a': 1, 'b': None, 'd': [2]}
    with arg_scope([func1, func2], a=1, b=None):
      with arg_scope([func1], c=[1]):
        with arg_scope([func2], d=[2]):
          args, kwargs = func1(0)
          self.assertTupleEqual(args, func1_args)
          self.assertDictEqual(kwargs, func1_kwargs)
          args, kwargs = func2(1)
          self.assertTupleEqual(args, func2_args)
          self.assertDictEqual(kwargs, func2_kwargs)

  def testDocString(self):
    """add_arg_scope leaves the decorated function's docstring intact."""
    self.assertEqual(func3.__doc__, 'Some cool doc string.')


if __name__ == '__main__':
  test.main()