# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for slim.nets.overfeat."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.slim.python.slim.nets import overfeat
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class OverFeatTest(test.TestCase):
  """Graph-construction and smoke tests for `overfeat.overfeat`."""

  def testBuild(self):
    """Default call yields squeezed logits of shape [batch, num_classes]."""
    batch_size = 5
    height, width = 231, 231
    num_classes = 1000
    with self.test_session():
      inputs = random_ops.random_uniform((batch_size, height, width, 3))
      logits, _ = overfeat.overfeat(inputs, num_classes)
      # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
      self.assertEqual(logits.op.name, 'overfeat/fc8/squeezed')
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])

  def testFullyConvolutional(self):
    """With spatial_squeeze=False a 281x281 input gives 2x2 spatial logits."""
    batch_size = 1
    height, width = 281, 281
    num_classes = 1000
    with self.test_session():
      inputs = random_ops.random_uniform((batch_size, height, width, 3))
      logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False)
      # Without the squeeze, the last op is fc8's bias add.
      self.assertEqual(logits.op.name, 'overfeat/fc8/BiasAdd')
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, 2, 2, num_classes])

  def testEndPoints(self):
    """The returned end_points dict has exactly the expected layer keys."""
    batch_size = 5
    height, width = 231, 231
    num_classes = 1000
    with self.test_session():
      inputs = random_ops.random_uniform((batch_size, height, width, 3))
      _, end_points = overfeat.overfeat(inputs, num_classes)
      expected_names = [
          'overfeat/conv1', 'overfeat/pool1', 'overfeat/conv2',
          'overfeat/pool2', 'overfeat/conv3', 'overfeat/conv4',
          'overfeat/conv5', 'overfeat/pool5', 'overfeat/fc6', 'overfeat/fc7',
          'overfeat/fc8'
      ]
      self.assertSetEqual(set(end_points.keys()), set(expected_names))

  def testModelVariables(self):
    """Each conv/fc layer registers exactly one weights and one biases var."""
    batch_size = 5
    height, width = 231, 231
    num_classes = 1000
    with self.test_session():
      inputs = random_ops.random_uniform((batch_size, height, width, 3))
      overfeat.overfeat(inputs, num_classes)
      expected_names = [
          'overfeat/conv1/weights',
          'overfeat/conv1/biases',
          'overfeat/conv2/weights',
          'overfeat/conv2/biases',
          'overfeat/conv3/weights',
          'overfeat/conv3/biases',
          'overfeat/conv4/weights',
          'overfeat/conv4/biases',
          'overfeat/conv5/weights',
          'overfeat/conv5/biases',
          'overfeat/fc6/weights',
          'overfeat/fc6/biases',
          'overfeat/fc7/weights',
          'overfeat/fc7/biases',
          'overfeat/fc8/weights',
          'overfeat/fc8/biases',
      ]
      model_variables = [v.op.name for v in variables_lib.get_model_variables()]
      self.assertSetEqual(set(model_variables), set(expected_names))

  def testEvaluation(self):
    """Inference mode (is_training=False) keeps logit and argmax shapes."""
    batch_size = 2
    height, width = 231, 231
    # Not passed to overfeat below; the assertion relies on its default
    # num_classes being 1000.
    num_classes = 1000
    with self.test_session():
      eval_inputs = random_ops.random_uniform((batch_size, height, width, 3))
      logits, _ = overfeat.overfeat(eval_inputs, is_training=False)
      self.assertListEqual(logits.get_shape().as_list(),
                           [batch_size, num_classes])
      predictions = math_ops.argmax(logits, 1)
      self.assertListEqual(predictions.get_shape().as_list(), [batch_size])

  def testTrainEvalWithReuse(self):
    """A training graph and a larger-input eval graph can share variables."""
    train_batch_size = 2
    eval_batch_size = 1
    train_height, train_width = 231, 231
    eval_height, eval_width = 281, 281
    num_classes = 1000
    with self.test_session():
      train_inputs = random_ops.random_uniform(
          (train_batch_size, train_height, train_width, 3))
      logits, _ = overfeat.overfeat(train_inputs)
      self.assertListEqual(logits.get_shape().as_list(),
                           [train_batch_size, num_classes])
      # Let the second overfeat() call pick up the variables created above
      # instead of creating new ones.
      variable_scope.get_variable_scope().reuse_variables()
      eval_inputs = random_ops.random_uniform(
          (eval_batch_size, eval_height, eval_width, 3))
      logits, _ = overfeat.overfeat(
          eval_inputs, is_training=False, spatial_squeeze=False)
      self.assertListEqual(logits.get_shape().as_list(),
                           [eval_batch_size, 2, 2, num_classes])
      # Average the 2x2 spatial logits before taking the class argmax.
      logits = math_ops.reduce_mean(logits, [1, 2])
      predictions = math_ops.argmax(logits, 1)
      # Shape is a list: use assertListEqual like every other shape check
      # here, and avoid the deprecated assertEquals alias.
      self.assertListEqual(predictions.get_shape().as_list(),
                           [eval_batch_size])

  def testForward(self):
    """After variable init, a forward pass produces some non-zero output."""
    batch_size = 1
    height, width = 231, 231
    with self.test_session() as sess:
      inputs = random_ops.random_uniform((batch_size, height, width, 3))
      logits, _ = overfeat.overfeat(inputs)
      sess.run(variables.global_variables_initializer())
      output = sess.run(logits)
      self.assertTrue(output.any())


if __name__ == '__main__':
  test.main()