# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition of the Inception V1 classification network."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import layers
from tensorflow.contrib.framework.python.ops import arg_scope
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import layers as layers_lib
from tensorflow.contrib.layers.python.layers import regularizers
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope

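# Shorthand for a truncated-normal weights initializer with mean 0.0 and the
# given standard deviation.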
trunc_normal = lambda stddev: init_ops.truncated_normal_initializer(0.0, stddev)


def inception_v1_base(inputs, final_endpoint='Mixed_5c', scope='InceptionV1'):
  """Defines the Inception V1 base architecture.

  This architecture is defined in:
    Going deeper with convolutions
    Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
    Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
    http://arxiv.org/pdf/1409.4842v1.pdf.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    final_endpoint: specifies the endpoint to construct the network up to. It
      can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
      'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
      'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e',
      'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 'Mixed_5c'].
    scope: Optional variable_scope.

  Returns:
    A tuple (net, end_points), where net is the activation of the requested
    final endpoint and end_points is a dictionary from components of the
    network to the corresponding activations.

  Raises:
    ValueError: if final_endpoint is not set to one of the predefined values.
  """
  end_points = {}
  with variable_scope.variable_scope(scope, 'InceptionV1', [inputs]):
    with arg_scope(
        [layers.conv2d, layers_lib.fully_connected],
        weights_initializer=trunc_normal(0.01)):
      with arg_scope(
          [layers.conv2d, layers_lib.max_pool2d], stride=1, padding='SAME'):
        end_point = 'Conv2d_1a_7x7'
        net = layers.conv2d(inputs, 64, [7, 7], stride=2, scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points
        end_point = 'MaxPool_2a_3x3'
        net = layers_lib.max_pool2d(net, [3, 3], stride=2, scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points
        end_point = 'Conv2d_2b_1x1'
        net = layers.conv2d(net, 64, [1, 1], scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points
        end_point = 'Conv2d_2c_3x3'
        net = layers.conv2d(net, 192, [3, 3], scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points
        end_point = 'MaxPool_3a_3x3'
        net = layers_lib.max_pool2d(net, [3, 3], stride=2, scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

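        # Each 'Mixed' block below is a GoogLeNet Inception module: four
        # parallel branches (1x1 conv; 1x1 then 3x3 conv; 1x1 then 3x3 conv;
        # 3x3 max pool then 1x1 conv) concatenated along the channel axis.
        # NB: the paper uses 5x5 convolutions in the third branch; this
        # implementation uses 3x3, kept as released (presumably so variable
        # names match existing pretrained checkpoints).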
        end_point = 'Mixed_3b'
        with variable_scope.variable_scope(end_point):
          with variable_scope.variable_scope('Branch_0'):
            branch_0 = layers.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1')
          with variable_scope.variable_scope('Branch_1'):
            branch_1 = layers.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = layers.conv2d(
                branch_1, 128, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_2'):
            branch_2 = layers.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = layers.conv2d(
                branch_2, 32, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_3'):
            branch_3 = layers_lib.max_pool2d(
                net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = layers.conv2d(
                branch_3, 32, [1, 1], scope='Conv2d_0b_1x1')
          net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        end_point = 'Mixed_3c'
        with variable_scope.variable_scope(end_point):
          with variable_scope.variable_scope('Branch_0'):
            branch_0 = layers.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
          with variable_scope.variable_scope('Branch_1'):
            branch_1 = layers.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = layers.conv2d(
                branch_1, 192, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_2'):
            branch_2 = layers.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = layers.conv2d(
                branch_2, 96, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_3'):
            branch_3 = layers_lib.max_pool2d(
                net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = layers.conv2d(
                branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
          net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        end_point = 'MaxPool_4a_3x3'
        net = layers_lib.max_pool2d(net, [3, 3], stride=2, scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

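        # For 224x224 inputs, the Mixed_4b through Mixed_4f modules below
        # operate on 14x14 feature maps.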
        end_point = 'Mixed_4b'
        with variable_scope.variable_scope(end_point):
          with variable_scope.variable_scope('Branch_0'):
            branch_0 = layers.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1')
          with variable_scope.variable_scope('Branch_1'):
            branch_1 = layers.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = layers.conv2d(
                branch_1, 208, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_2'):
            branch_2 = layers.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = layers.conv2d(
                branch_2, 48, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_3'):
            branch_3 = layers_lib.max_pool2d(
                net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = layers.conv2d(
                branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
          net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        end_point = 'Mixed_4c'
        with variable_scope.variable_scope(end_point):
          with variable_scope.variable_scope('Branch_0'):
            branch_0 = layers.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
          with variable_scope.variable_scope('Branch_1'):
            branch_1 = layers.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = layers.conv2d(
                branch_1, 224, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_2'):
            branch_2 = layers.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = layers.conv2d(
                branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_3'):
            branch_3 = layers_lib.max_pool2d(
                net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = layers.conv2d(
                branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
          net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        end_point = 'Mixed_4d'
        with variable_scope.variable_scope(end_point):
          with variable_scope.variable_scope('Branch_0'):
            branch_0 = layers.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
          with variable_scope.variable_scope('Branch_1'):
            branch_1 = layers.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = layers.conv2d(
                branch_1, 256, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_2'):
            branch_2 = layers.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = layers.conv2d(
                branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_3'):
            branch_3 = layers_lib.max_pool2d(
                net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = layers.conv2d(
                branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
          net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        end_point = 'Mixed_4e'
        with variable_scope.variable_scope(end_point):
          with variable_scope.variable_scope('Branch_0'):
            branch_0 = layers.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1')
          with variable_scope.variable_scope('Branch_1'):
            branch_1 = layers.conv2d(net, 144, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = layers.conv2d(
                branch_1, 288, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_2'):
            branch_2 = layers.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = layers.conv2d(
                branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_3'):
            branch_3 = layers_lib.max_pool2d(
                net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = layers.conv2d(
                branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
          net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        end_point = 'Mixed_4f'
        with variable_scope.variable_scope(end_point):
          with variable_scope.variable_scope('Branch_0'):
            branch_0 = layers.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1')
          with variable_scope.variable_scope('Branch_1'):
            branch_1 = layers.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = layers.conv2d(
                branch_1, 320, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_2'):
            branch_2 = layers.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = layers.conv2d(
                branch_2, 128, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_3'):
            branch_3 = layers_lib.max_pool2d(
                net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = layers.conv2d(
                branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
          net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        end_point = 'MaxPool_5a_2x2'
        net = layers_lib.max_pool2d(net, [2, 2], stride=2, scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        end_point = 'Mixed_5b'
        with variable_scope.variable_scope(end_point):
          with variable_scope.variable_scope('Branch_0'):
            branch_0 = layers.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1')
          with variable_scope.variable_scope('Branch_1'):
            branch_1 = layers.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = layers.conv2d(
                branch_1, 320, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_2'):
            branch_2 = layers.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
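            # NB: the scope below repeats the '0a' suffix ('Conv2d_0a_3x3'
            # rather than '0b'); the inconsistent name is kept as-is,
            # presumably so variable names match existing checkpoints.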
            branch_2 = layers.conv2d(
                branch_2, 128, [3, 3], scope='Conv2d_0a_3x3')
          with variable_scope.variable_scope('Branch_3'):
            branch_3 = layers_lib.max_pool2d(
                net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = layers.conv2d(
                branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
          net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points

        end_point = 'Mixed_5c'
        with variable_scope.variable_scope(end_point):
          with variable_scope.variable_scope('Branch_0'):
            branch_0 = layers.conv2d(net, 384, [1, 1], scope='Conv2d_0a_1x1')
          with variable_scope.variable_scope('Branch_1'):
            branch_1 = layers.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = layers.conv2d(
                branch_1, 384, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_2'):
            branch_2 = layers.conv2d(net, 48, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = layers.conv2d(
                branch_2, 128, [3, 3], scope='Conv2d_0b_3x3')
          with variable_scope.variable_scope('Branch_3'):
            branch_3 = layers_lib.max_pool2d(
                net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = layers.conv2d(
                branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
          net = array_ops.concat([branch_0, branch_1, branch_2, branch_3], 3)
        end_points[end_point] = net
        if final_endpoint == end_point:
          return net, end_points
    raise ValueError('Unknown final endpoint %s' % final_endpoint)
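
# Example: building the base network up to an intermediate endpoint (a
# minimal sketch; `images` stands for a [batch, 224, 224, 3] float tensor
# supplied by the caller):
#
#   net, end_points = inception_v1_base(images, final_endpoint='Mixed_4f')
#   # `net` is the Mixed_4f activation; `end_points` also exposes every
#   # earlier endpoint, e.g. end_points['Conv2d_1a_7x7'].

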
def inception_v1(inputs,
                 num_classes=1000,
                 is_training=True,
                 dropout_keep_prob=0.8,
                 prediction_fn=layers_lib.softmax,
                 spatial_squeeze=True,
                 reuse=None,
                 scope='InceptionV1'):
  """Defines the Inception V1 architecture.

  This architecture is defined in:

    Going deeper with convolutions
    Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
    Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
    http://arxiv.org/pdf/1409.4842v1.pdf.

  The default image size used to train this network is 224x224.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    num_classes: number of predicted classes.
    is_training: whether the network is in training mode; controls batch norm
      and dropout behavior.
    dropout_keep_prob: the fraction of activations that are retained by the
      dropout layer.
    prediction_fn: a function to get predictions out of logits.
    spatial_squeeze: if True, logits is of shape [B, C]; if False, logits is
      of shape [B, 1, 1, C], where B is batch_size and C is the number of
      classes.
    reuse: whether or not the network and its variables should be reused. To
      be able to reuse, 'scope' must be given.
    scope: Optional variable_scope.

  Returns:
    logits: the pre-softmax activations, a tensor of size
      [batch_size, num_classes]
    end_points: a dictionary from components of the network to the
      corresponding activation.
  """
  # Final pooling and prediction
  with variable_scope.variable_scope(
      scope, 'InceptionV1', [inputs, num_classes], reuse=reuse) as scope:
    with arg_scope(
        [layers_lib.batch_norm, layers_lib.dropout], is_training=is_training):
      net, end_points = inception_v1_base(inputs, scope=scope)
      with variable_scope.variable_scope('Logits'):
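        # For 224x224 inputs the Mixed_5c output is 7x7, so the 7x7 average
        # pool below reduces it to 1x1. NB: the scope name 'MaxPool_0a_7x7'
        # is a misnomer for an average pool; it is kept as-is, presumably so
        # op and variable names match existing checkpoints.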
        net = layers_lib.avg_pool2d(
            net, [7, 7], stride=1, scope='MaxPool_0a_7x7')
        net = layers_lib.dropout(net, dropout_keep_prob, scope='Dropout_0b')
        logits = layers.conv2d(
            net,
            num_classes, [1, 1],
            activation_fn=None,
            normalizer_fn=None,
            scope='Conv2d_0c_1x1')
        if spatial_squeeze:
          logits = array_ops.squeeze(logits, [1, 2], name='SpatialSqueeze')

        end_points['Logits'] = logits
        end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
  return logits, end_points


inception_v1.default_image_size = 224
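
# Example: running the full classifier for inference (a minimal sketch;
# `images` stands for a [batch, 224, 224, 3] float tensor supplied by the
# caller):
#
#   with arg_scope(inception_v1_arg_scope()):
#     logits, end_points = inception_v1(images, num_classes=1000,
#                                       is_training=False)
#   # logits has shape [batch, 1000]; end_points['Predictions'] holds the
#   # softmax probabilities.

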
def inception_v1_arg_scope(weight_decay=0.00004,
                           use_batch_norm=True,
                           batch_norm_var_collection='moving_vars'):
  """Defines the default InceptionV1 arg scope.

  Note: Although the original paper didn't use batch_norm, we found it
  useful.

  Args:
    weight_decay: The weight decay to use for regularizing the model.
    use_batch_norm: If `True`, batch_norm is applied after each convolution.
    batch_norm_var_collection: The name of the collection for the batch norm
      variables.

  Returns:
    An `arg_scope` to use for the inception v1 model.
  """
  batch_norm_params = {
      # Decay for the moving averages.
      'decay': 0.9997,
      # epsilon to prevent 0s in variance.
      'epsilon': 0.001,
      # collection containing update_ops.
      'updates_collections': ops.GraphKeys.UPDATE_OPS,
      # collection containing the moving mean and moving variance.
      'variables_collections': {
          'beta': None,
          'gamma': None,
          'moving_mean': [batch_norm_var_collection],
          'moving_variance': [batch_norm_var_collection],
      }
  }
  if use_batch_norm:
    normalizer_fn = layers_lib.batch_norm
    normalizer_params = batch_norm_params
  else:
    normalizer_fn = None
    normalizer_params = {}
  # Set weight_decay for weights in Conv and FC layers.
  with arg_scope(
      [layers.conv2d, layers_lib.fully_connected],
      weights_regularizer=regularizers.l2_regularizer(weight_decay)):
    with arg_scope(
        [layers.conv2d],
        weights_initializer=initializers.variance_scaling_initializer(),
        activation_fn=nn_ops.relu,
        normalizer_fn=normalizer_fn,
        normalizer_params=normalizer_params) as sc:
      return sc
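
# Example: when training under this arg scope, batch norm collects its moving
# average updates in ops.GraphKeys.UPDATE_OPS, which must run with the train
# op (a minimal sketch; `images`, `total_loss` and `optimizer` are supplied
# by the caller):
#
#   with arg_scope(inception_v1_arg_scope()):
#     logits, _ = inception_v1(images, num_classes=1000, is_training=True)
#   update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
#   with ops.control_dependencies(update_ops):
#     train_op = optimizer.minimize(total_loss)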