Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """arg_scope tests."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.contrib.framework.python.ops import add_arg_scope
     21 from tensorflow.contrib.framework.python.ops import arg_scope
     22 from tensorflow.contrib.framework.python.ops import arg_scoped_arguments
     23 from tensorflow.python.platform import test
     24 
     25 
     26 @add_arg_scope
     27 def func1(*args, **kwargs):
     28   return (args, kwargs)
     29 
     30 
     31 @add_arg_scope
     32 def func2(*args, **kwargs):
     33   return (args, kwargs)
     34 
     35 
     36 @add_arg_scope
     37 def func3(args, a=None, b=1, c=2):
     38   """Some cool doc string."""
     39   return (args, a, b, c)
     40 
     41 
     42 def _key_op(op):
     43   return getattr(op, '_key_op', str(op))
     44 
     45 
     46 class ArgScopeTest(test.TestCase):
     47 
     48   def testEmptyArgScope(self):
     49     with self.test_session():
     50       with arg_scope([]) as sc:
     51         self.assertEqual(sc, {})
     52 
     53   def testClearArgScope(self):
     54     func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
     55     key_op = _key_op(func1)
     56     func1_scope = {key_op: func1_kwargs.copy()}
     57     with self.test_session():
     58       with arg_scope([func1], a=1, b=None, c=[1]) as sc1:
     59         self.assertEqual(sc1, func1_scope)
     60         with arg_scope({}) as sc2:
     61           self.assertEqual(sc2, {})
     62         with arg_scope([]) as current_arg_scope:
     63           self.assertEqual(current_arg_scope, func1_scope)
     64 
     65   def testNonDecorated(self):
     66 
     67     def my_func(t, a=None):
     68       return (t, a)
     69 
     70     with self.assertRaises(ValueError):
     71       with arg_scope([my_func], a=1):
     72         pass
     73 
     74   def testUnexpectedArg(self):
     75     with self.assertRaises(TypeError):
     76       with arg_scope([func3], d=1):
     77         func3(1)
     78 
     79   def testCurrentArgScope(self):
     80     func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
     81     key_op = _key_op(func1)
     82     current_scope = {key_op: func1_kwargs.copy()}
     83     with self.test_session():
     84       with arg_scope([func1], a=1, b=None, c=[1]) as scope:
     85         self.assertDictEqual(scope, current_scope)
     86 
     87   def testArgScopedArguments(self):
     88     func3_kwargs = ('a', 'b', 'c')
     89     self.assertEquals(arg_scoped_arguments(func3), func3_kwargs)
     90 
     91   def testCurrentArgScopeNested(self):
     92     func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
     93     func2_kwargs = {'b': 2, 'd': [2]}
     94     key = _key_op
     95     current_scope = {
     96         key(func1): func1_kwargs.copy(),
     97         key(func2): func2_kwargs.copy()
     98     }
     99     with self.test_session():
    100       with arg_scope([func1], a=1, b=None, c=[1]):
    101         with arg_scope([func2], b=2, d=[2]) as scope:
    102           self.assertDictEqual(scope, current_scope)
    103 
    104   def testReuseArgScope(self):
    105     func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    106     key_op = _key_op(func1)
    107     current_scope = {key_op: func1_kwargs.copy()}
    108     with self.test_session():
    109       with arg_scope([func1], a=1, b=None, c=[1]) as scope1:
    110         pass
    111       with arg_scope(scope1) as scope:
    112         self.assertDictEqual(scope, current_scope)
    113 
    114   def testReuseArgScopeNested(self):
    115     func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    116     func2_kwargs = {'b': 2, 'd': [2]}
    117     key = _key_op
    118     current_scope1 = {key(func1): func1_kwargs.copy()}
    119     current_scope2 = {
    120         key(func1): func1_kwargs.copy(),
    121         key(func2): func2_kwargs.copy()
    122     }
    123     with self.test_session():
    124       with arg_scope([func1], a=1, b=None, c=[1]) as scope1:
    125         with arg_scope([func2], b=2, d=[2]) as scope2:
    126           pass
    127       with arg_scope(scope1):
    128         with arg_scope([]) as current_arg_scope:
    129           self.assertDictEqual(current_arg_scope, current_scope1)
    130       with arg_scope(scope2):
    131         with arg_scope([]) as current_arg_scope:
    132           self.assertDictEqual(current_arg_scope, current_scope2)
    133 
    134   def testSimpleArgScope(self):
    135     func1_args = (0,)
    136     func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    137     with self.test_session():
    138       with arg_scope([func1], a=1, b=None, c=[1]):
    139         args, kwargs = func1(0)
    140         self.assertTupleEqual(args, func1_args)
    141         self.assertDictEqual(kwargs, func1_kwargs)
    142 
    143   def testSimpleArgScopeWithTuple(self):
    144     func1_args = (0,)
    145     func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    146     with self.test_session():
    147       with arg_scope((func1,), a=1, b=None, c=[1]):
    148         args, kwargs = func1(0)
    149         self.assertTupleEqual(args, func1_args)
    150         self.assertDictEqual(kwargs, func1_kwargs)
    151 
    152   def testOverwriteArgScope(self):
    153     func1_args = (0,)
    154     func1_kwargs = {'a': 1, 'b': 2, 'c': [1]}
    155     with arg_scope([func1], a=1, b=None, c=[1]):
    156       args, kwargs = func1(0, b=2)
    157       self.assertTupleEqual(args, func1_args)
    158       self.assertDictEqual(kwargs, func1_kwargs)
    159 
    160   def testNestedArgScope(self):
    161     func1_args = (0,)
    162     func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    163     with arg_scope([func1], a=1, b=None, c=[1]):
    164       args, kwargs = func1(0)
    165       self.assertTupleEqual(args, func1_args)
    166       self.assertDictEqual(kwargs, func1_kwargs)
    167       func1_kwargs['b'] = 2
    168       with arg_scope([func1], b=2):
    169         args, kwargs = func1(0)
    170         self.assertTupleEqual(args, func1_args)
    171         self.assertDictEqual(kwargs, func1_kwargs)
    172 
    173   def testSharedArgScope(self):
    174     func1_args = (0,)
    175     func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    176     with arg_scope([func1, func2], a=1, b=None, c=[1]):
    177       args, kwargs = func1(0)
    178       self.assertTupleEqual(args, func1_args)
    179       self.assertDictEqual(kwargs, func1_kwargs)
    180       args, kwargs = func2(0)
    181       self.assertTupleEqual(args, func1_args)
    182       self.assertDictEqual(kwargs, func1_kwargs)
    183 
    184   def testSharedArgScopeTuple(self):
    185     func1_args = (0,)
    186     func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    187     with arg_scope((func1, func2), a=1, b=None, c=[1]):
    188       args, kwargs = func1(0)
    189       self.assertTupleEqual(args, func1_args)
    190       self.assertDictEqual(kwargs, func1_kwargs)
    191       args, kwargs = func2(0)
    192       self.assertTupleEqual(args, func1_args)
    193       self.assertDictEqual(kwargs, func1_kwargs)
    194 
    195   def testPartiallySharedArgScope(self):
    196     func1_args = (0,)
    197     func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
    198     func2_args = (1,)
    199     func2_kwargs = {'a': 1, 'b': None, 'd': [2]}
    200     with arg_scope([func1, func2], a=1, b=None):
    201       with arg_scope([func1], c=[1]):
    202         with arg_scope([func2], d=[2]):
    203           args, kwargs = func1(0)
    204           self.assertTupleEqual(args, func1_args)
    205           self.assertDictEqual(kwargs, func1_kwargs)
    206           args, kwargs = func2(1)
    207           self.assertTupleEqual(args, func2_args)
    208           self.assertDictEqual(kwargs, func2_kwargs)
    209 
    210   def testDocString(self):
    211     self.assertEqual(func3.__doc__, 'Some cool doc string.')
    212 
    213 
# Only invoke `test.main()` when this file is executed directly as a script,
# not when it is imported as a module.
if __name__ == '__main__':
  test.main()
    216