# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Tests for slim.nets.overfeat."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.framework.python.ops import variables as variables_lib
     22 from tensorflow.contrib.slim.python.slim.nets import overfeat
     23 from tensorflow.python.ops import math_ops
     24 from tensorflow.python.ops import random_ops
     25 from tensorflow.python.ops import variable_scope
     26 from tensorflow.python.ops import variables
     27 from tensorflow.python.platform import test
     28 
     29 
     30 class OverFeatTest(test.TestCase):
     31 
     32   def testBuild(self):
     33     batch_size = 5
     34     height, width = 231, 231
     35     num_classes = 1000
     36     with self.test_session():
     37       inputs = random_ops.random_uniform((batch_size, height, width, 3))
     38       logits, _ = overfeat.overfeat(inputs, num_classes)
     39       self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed')
     40       self.assertListEqual(logits.get_shape().as_list(),
     41                            [batch_size, num_classes])
     42 
     43   def testFullyConvolutional(self):
     44     batch_size = 1
     45     height, width = 281, 281
     46     num_classes = 1000
     47     with self.test_session():
     48       inputs = random_ops.random_uniform((batch_size, height, width, 3))
     49       logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False)
     50       self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd')
     51       self.assertListEqual(logits.get_shape().as_list(),
     52                            [batch_size, 2, 2, num_classes])
     53 
     54   def testEndPoints(self):
     55     batch_size = 5
     56     height, width = 231, 231
     57     num_classes = 1000
     58     with self.test_session():
     59       inputs = random_ops.random_uniform((batch_size, height, width, 3))
     60       _, end_points = overfeat.overfeat(inputs, num_classes)
     61       expected_names = [
     62           'overfeat/conv1', 'overfeat/pool1', 'overfeat/conv2',
     63           'overfeat/pool2', 'overfeat/conv3', 'overfeat/conv4',
     64           'overfeat/conv5', 'overfeat/pool5', 'overfeat/fc6', 'overfeat/fc7',
     65           'overfeat/fc8'
     66       ]
     67       self.assertSetEqual(set(end_points.keys()), set(expected_names))
     68 
     69   def testModelVariables(self):
     70     batch_size = 5
     71     height, width = 231, 231
     72     num_classes = 1000
     73     with self.test_session():
     74       inputs = random_ops.random_uniform((batch_size, height, width, 3))
     75       overfeat.overfeat(inputs, num_classes)
     76       expected_names = [
     77           'overfeat/conv1/weights',
     78           'overfeat/conv1/biases',
     79           'overfeat/conv2/weights',
     80           'overfeat/conv2/biases',
     81           'overfeat/conv3/weights',
     82           'overfeat/conv3/biases',
     83           'overfeat/conv4/weights',
     84           'overfeat/conv4/biases',
     85           'overfeat/conv5/weights',
     86           'overfeat/conv5/biases',
     87           'overfeat/fc6/weights',
     88           'overfeat/fc6/biases',
     89           'overfeat/fc7/weights',
     90           'overfeat/fc7/biases',
     91           'overfeat/fc8/weights',
     92           'overfeat/fc8/biases',
     93       ]
     94       model_variables = [v.op.name for v in variables_lib.get_model_variables()]
     95       self.assertSetEqual(set(model_variables), set(expected_names))
     96 
     97   def testEvaluation(self):
     98     batch_size = 2
     99     height, width = 231, 231
    100     num_classes = 1000
    101     with self.test_session():
    102       eval_inputs = random_ops.random_uniform((batch_size, height, width, 3))
    103       logits, _ = overfeat.overfeat(eval_inputs, is_training=False)
    104       self.assertListEqual(logits.get_shape().as_list(),
    105                            [batch_size, num_classes])
    106       predictions = math_ops.argmax(logits, 1)
    107       self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
    108 
    109   def testTrainEvalWithReuse(self):
    110     train_batch_size = 2
    111     eval_batch_size = 1
    112     train_height, train_width = 231, 231
    113     eval_height, eval_width = 281, 281
    114     num_classes = 1000
    115     with self.test_session():
    116       train_inputs = random_ops.random_uniform(
    117           (train_batch_size, train_height, train_width, 3))
    118       logits, _ = overfeat.overfeat(train_inputs)
    119       self.assertListEqual(logits.get_shape().as_list(),
    120                            [train_batch_size, num_classes])
    121       variable_scope.get_variable_scope().reuse_variables()
    122       eval_inputs = random_ops.random_uniform(
    123           (eval_batch_size, eval_height, eval_width, 3))
    124       logits, _ = overfeat.overfeat(
    125           eval_inputs, is_training=False, spatial_squeeze=False)
    126       self.assertListEqual(logits.get_shape().as_list(),
    127                            [eval_batch_size, 2, 2, num_classes])
    128       logits = math_ops.reduce_mean(logits, [1, 2])
    129       predictions = math_ops.argmax(logits, 1)
    130       self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
    131 
    132   def testForward(self):
    133     batch_size = 1
    134     height, width = 231, 231
    135     with self.test_session() as sess:
    136       inputs = random_ops.random_uniform((batch_size, height, width, 3))
    137       logits, _ = overfeat.overfeat(inputs)
    138       sess.run(variables.global_variables_initializer())
    139       output = sess.run(logits)
    140       self.assertTrue(output.any())
    141 
    142 
    143 if __name__ == '__main__':
    144   test.main()
    145