Home | History | Annotate | Download | only in nets
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Contains definitions for the original form of Residual Networks.
     16 
     17 The 'v1' residual networks (ResNets) implemented in this module were proposed
     18 by:
     19 [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
     20     Deep Residual Learning for Image Recognition. arXiv:1512.03385
     21 
     22 Other variants were introduced in:
     23 [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
     24     Identity Mappings in Deep Residual Networks. arXiv: 1603.05027
     25 
     26 The networks defined in this module utilize the bottleneck building block of
     27 [1] with projection shortcuts only for increasing depths. They employ batch
     28 normalization *after* every weight layer. This is the architecture used by
     29 MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and
     30 ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1'
     31 architecture and the alternative 'v2' architecture of [2] which uses batch
     32 normalization *before* every weight layer in the so-called full pre-activation
     33 units.
     34 
     35 Typical use:
     36 
   from tensorflow.contrib.slim.python.slim.nets import resnet_v1
     39 
     40 ResNet-101 for image classification into 1000 classes:
     41 
     42    # inputs has shape [batch, 224, 224, 3]
     43    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
     44       net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False)
     45 
     46 ResNet-101 for semantic segmentation into 21 classes:
     47 
     48    # inputs has shape [batch, 513, 513, 3]
     49    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
     50       net, end_points = resnet_v1.resnet_v1_101(inputs,
     51                                                 21,
     52                                                 is_training=False,
     53                                                 global_pool=False,
     54                                                 output_stride=16)
     55 """
     56 
     57 from __future__ import absolute_import
     58 from __future__ import division
     59 from __future__ import print_function
     60 
     61 from tensorflow.contrib import layers
     62 from tensorflow.contrib.framework.python.ops import add_arg_scope
     63 from tensorflow.contrib.framework.python.ops import arg_scope
     64 from tensorflow.contrib.layers.python.layers import layers as layers_lib
     65 from tensorflow.contrib.layers.python.layers import utils
     66 from tensorflow.contrib.slim.python.slim.nets import resnet_utils
     67 from tensorflow.python.ops import math_ops
     68 from tensorflow.python.ops import nn_ops
     69 from tensorflow.python.ops import variable_scope
     70 
# Re-export the arg scope defined in resnet_utils under this module's namespace
# so callers can write `resnet_v1.resnet_arg_scope()`, as the usage examples in
# the module docstring do.
resnet_arg_scope = resnet_utils.resnet_arg_scope
     72 
     73 
     74 @add_arg_scope
     75 def bottleneck(inputs,
     76                depth,
     77                depth_bottleneck,
     78                stride,
     79                rate=1,
     80                outputs_collections=None,
     81                scope=None):
     82   """Bottleneck residual unit variant with BN after convolutions.
     83 
     84   This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for
     85   its definition. Note that we use here the bottleneck variant which has an
     86   extra bottleneck layer.
     87 
     88   When putting together two consecutive ResNet blocks that use this unit, one
     89   should use stride = 2 in the last unit of the first block.
     90 
     91   Args:
     92     inputs: A tensor of size [batch, height, width, channels].
     93     depth: The depth of the ResNet unit output.
     94     depth_bottleneck: The depth of the bottleneck layers.
     95     stride: The ResNet unit's stride. Determines the amount of downsampling of
     96       the units output compared to its input.
     97     rate: An integer, rate for atrous convolution.
     98     outputs_collections: Collection to add the ResNet unit output.
     99     scope: Optional variable_scope.
    100 
    101   Returns:
    102     The ResNet unit's output.
    103   """
    104   with variable_scope.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc:
    105     depth_in = utils.last_dimension(inputs.get_shape(), min_rank=4)
    106     if depth == depth_in:
    107       shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
    108     else:
    109       shortcut = layers.conv2d(
    110           inputs,
    111           depth, [1, 1],
    112           stride=stride,
    113           activation_fn=None,
    114           scope='shortcut')
    115 
    116     residual = layers.conv2d(
    117         inputs, depth_bottleneck, [1, 1], stride=1, scope='conv1')
    118     residual = resnet_utils.conv2d_same(
    119         residual, depth_bottleneck, 3, stride, rate=rate, scope='conv2')
    120     residual = layers.conv2d(
    121         residual, depth, [1, 1], stride=1, activation_fn=None, scope='conv3')
    122 
    123     output = nn_ops.relu(shortcut + residual)
    124 
    125     return utils.collect_named_outputs(outputs_collections, sc.name, output)
    126 
    127 
    128 def resnet_v1(inputs,
    129               blocks,
    130               num_classes=None,
    131               is_training=True,
    132               global_pool=True,
    133               output_stride=None,
    134               include_root_block=True,
    135               reuse=None,
    136               scope=None):
    137   """Generator for v1 ResNet models.
    138 
    139   This function generates a family of ResNet v1 models. See the resnet_v1_*()
    140   methods for specific model instantiations, obtained by selecting different
    141   block instantiations that produce ResNets of various depths.
    142 
    143   Training for image classification on Imagenet is usually done with [224, 224]
    144   inputs, resulting in [7, 7] feature maps at the output of the last ResNet
    145   block for the ResNets defined in [1] that have nominal stride equal to 32.
    146   However, for dense prediction tasks we advise that one uses inputs with
    147   spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
    148   this case the feature maps at the ResNet output will have spatial shape
    149   [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
    150   and corners exactly aligned with the input image corners, which greatly
    151   facilitates alignment of the features to the image. Using as input [225, 225]
    152   images results in [8, 8] feature maps at the output of the last ResNet block.
    153 
    154   For dense prediction tasks, the ResNet needs to run in fully-convolutional
    155   (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
    156   have nominal stride equal to 32 and a good choice in FCN mode is to use
    157   output_stride=16 in order to increase the density of the computed features at
    158   small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.
    159 
    160   Args:
    161     inputs: A tensor of size [batch, height_in, width_in, channels].
    162     blocks: A list of length equal to the number of ResNet blocks. Each element
    163       is a resnet_utils.Block object describing the units in the block.
    164     num_classes: Number of predicted classes for classification tasks. If None
    165       we return the features before the logit layer.
    166     is_training: whether batch_norm layers are in training mode.
    167     global_pool: If True, we perform global average pooling before computing the
    168       logits. Set to True for image classification, False for dense prediction.
    169     output_stride: If None, then the output will be computed at the nominal
    170       network stride. If output_stride is not None, it specifies the requested
    171       ratio of input to output spatial resolution.
    172     include_root_block: If True, include the initial convolution followed by
    173       max-pooling, if False excludes it.
    174     reuse: whether or not the network and its variables should be reused. To be
    175       able to reuse 'scope' must be given.
    176     scope: Optional variable_scope.
    177 
    178   Returns:
    179     net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
    180       If global_pool is False, then height_out and width_out are reduced by a
    181       factor of output_stride compared to the respective height_in and width_in,
    182       else both height_out and width_out equal one. If num_classes is None, then
    183       net is the output of the last ResNet block, potentially after global
    184       average pooling. If num_classes is not None, net contains the pre-softmax
    185       activations.
    186     end_points: A dictionary from components of the network to the corresponding
    187       activation.
    188 
    189   Raises:
    190     ValueError: If the target output_stride is not valid.
    191   """
    192   with variable_scope.variable_scope(
    193       scope, 'resnet_v1', [inputs], reuse=reuse) as sc:
    194     end_points_collection = sc.original_name_scope + '_end_points'
    195     with arg_scope(
    196         [layers.conv2d, bottleneck, resnet_utils.stack_blocks_dense],
    197         outputs_collections=end_points_collection):
    198       with arg_scope([layers.batch_norm], is_training=is_training):
    199         net = inputs
    200         if include_root_block:
    201           if output_stride is not None:
    202             if output_stride % 4 != 0:
    203               raise ValueError('The output_stride needs to be a multiple of 4.')
    204             output_stride /= 4
    205           net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
    206           net = layers_lib.max_pool2d(net, [3, 3], stride=2, scope='pool1')
    207         net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
    208         if global_pool:
    209           # Global average pooling.
    210           net = math_ops.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
    211         if num_classes is not None:
    212           net = layers.conv2d(
    213               net,
    214               num_classes, [1, 1],
    215               activation_fn=None,
    216               normalizer_fn=None,
    217               scope='logits')
    218         # Convert end_points_collection into a dictionary of end_points.
    219         end_points = utils.convert_collection_to_dict(end_points_collection)
    220         if num_classes is not None:
    221           end_points['predictions'] = layers_lib.softmax(
    222               net, scope='predictions')
    223         return net, end_points
    224 resnet_v1.default_image_size = 224
    225 
    226 
    227 def resnet_v1_block(scope, base_depth, num_units, stride):
    228   """Helper function for creating a resnet_v1 bottleneck block.
    229 
    230   Args:
    231     scope: The scope of the block.
    232     base_depth: The depth of the bottleneck layer for each unit.
    233     num_units: The number of units in the block.
    234     stride: The stride of the block, implemented as a stride in the last unit.
    235       All other units have stride=1.
    236 
    237   Returns:
    238     A resnet_v1 bottleneck block.
    239   """
    240   return resnet_utils.Block(scope, bottleneck, [{
    241       'depth': base_depth * 4,
    242       'depth_bottleneck': base_depth,
    243       'stride': 1
    244   }] * (num_units - 1) + [{
    245       'depth': base_depth * 4,
    246       'depth_bottleneck': base_depth,
    247       'stride': stride
    248   }])
    249 
    250 
    251 def resnet_v1_50(inputs,
    252                  num_classes=None,
    253                  is_training=True,
    254                  global_pool=True,
    255                  output_stride=None,
    256                  reuse=None,
    257                  scope='resnet_v1_50'):
    258   """ResNet-50 model of [1]. See resnet_v1() for arg and return description."""
    259   blocks = [
    260       resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
    261       resnet_v1_block('block2', base_depth=128, num_units=4, stride=2),
    262       resnet_v1_block('block3', base_depth=256, num_units=6, stride=2),
    263       resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
    264   ]
    265   return resnet_v1(
    266       inputs,
    267       blocks,
    268       num_classes,
    269       is_training,
    270       global_pool,
    271       output_stride,
    272       include_root_block=True,
    273       reuse=reuse,
    274       scope=scope)
    275 
    276 
    277 def resnet_v1_101(inputs,
    278                   num_classes=None,
    279                   is_training=True,
    280                   global_pool=True,
    281                   output_stride=None,
    282                   reuse=None,
    283                   scope='resnet_v1_101'):
    284   """ResNet-101 model of [1]. See resnet_v1() for arg and return description."""
    285   blocks = [
    286       resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
    287       resnet_v1_block('block2', base_depth=128, num_units=4, stride=2),
    288       resnet_v1_block('block3', base_depth=256, num_units=23, stride=2),
    289       resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
    290   ]
    291   return resnet_v1(
    292       inputs,
    293       blocks,
    294       num_classes,
    295       is_training,
    296       global_pool,
    297       output_stride,
    298       include_root_block=True,
    299       reuse=reuse,
    300       scope=scope)
    301 
    302 
    303 def resnet_v1_152(inputs,
    304                   num_classes=None,
    305                   is_training=True,
    306                   global_pool=True,
    307                   output_stride=None,
    308                   reuse=None,
    309                   scope='resnet_v1_152'):
    310   """ResNet-152 model of [1]. See resnet_v1() for arg and return description."""
    311   blocks = [
    312       resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
    313       resnet_v1_block('block2', base_depth=128, num_units=8, stride=2),
    314       resnet_v1_block('block3', base_depth=256, num_units=36, stride=2),
    315       resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
    316   ]
    317   return resnet_v1(
    318       inputs,
    319       blocks,
    320       num_classes,
    321       is_training,
    322       global_pool,
    323       output_stride,
    324       include_root_block=True,
    325       reuse=reuse,
    326       scope=scope)
    327 
    328 
    329 def resnet_v1_200(inputs,
    330                   num_classes=None,
    331                   is_training=True,
    332                   global_pool=True,
    333                   output_stride=None,
    334                   reuse=None,
    335                   scope='resnet_v1_200'):
    336   """ResNet-200 model of [2]. See resnet_v1() for arg and return description."""
    337   blocks = [
    338       resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
    339       resnet_v1_block('block2', base_depth=128, num_units=24, stride=2),
    340       resnet_v1_block('block3', base_depth=256, num_units=36, stride=2),
    341       resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
    342   ]
    343   return resnet_v1(
    344       inputs,
    345       blocks,
    346       num_classes,
    347       is_training,
    348       global_pool,
    349       output_stride,
    350       include_root_block=True,
    351       reuse=reuse,
    352       scope=scope)
    353