Home | History | Annotate | Download | only in applications
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 # pylint: disable=invalid-name
     16 # pylint: disable=unused-import
     17 """ResNet50 model for Keras.
     18 
     19 # Reference:
     20 
     21 - [Deep Residual Learning for Image
     22 Recognition](https://arxiv.org/abs/1512.03385)
     23 
     24 Adapted from code contributed by BigMoyan.
     25 """
     26 from __future__ import absolute_import
     27 from __future__ import division
     28 from __future__ import print_function
     29 
     30 import os
     31 
     32 from tensorflow.python.keras._impl.keras import backend as K
     33 from tensorflow.python.keras._impl.keras import layers
     34 from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
     35 from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
     36 from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input
     37 from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
     38 from tensorflow.python.keras._impl.keras.layers import Activation
     39 from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
     40 from tensorflow.python.keras._impl.keras.layers import BatchNormalization
     41 from tensorflow.python.keras._impl.keras.layers import Conv2D
     42 from tensorflow.python.keras._impl.keras.layers import Dense
     43 from tensorflow.python.keras._impl.keras.layers import Flatten
     44 from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
     45 from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
     46 from tensorflow.python.keras._impl.keras.layers import Input
     47 from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
     48 from tensorflow.python.keras._impl.keras.models import Model
     49 from tensorflow.python.keras._impl.keras.utils import layer_utils
     50 from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
     51 from tensorflow.python.platform import tf_logging as logging
     52 from tensorflow.python.util.tf_export import tf_export
     53 
     54 
     55 WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5'
     56 WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
     57 
     58 
     59 def identity_block(input_tensor, kernel_size, filters, stage, block):
     60   """The identity block is the block that has no conv layer at shortcut.
     61 
     62   Arguments:
     63       input_tensor: input tensor
     64       kernel_size: default 3, the kernel size of middle conv layer at main path
     65       filters: list of integers, the filters of 3 conv layer at main path
     66       stage: integer, current stage label, used for generating layer names
     67       block: 'a','b'..., current block label, used for generating layer names
     68 
     69   Returns:
     70       Output tensor for the block.
     71   """
     72   filters1, filters2, filters3 = filters
     73   if K.image_data_format() == 'channels_last':
     74     bn_axis = 3
     75   else:
     76     bn_axis = 1
     77   conv_name_base = 'res' + str(stage) + block + '_branch'
     78   bn_name_base = 'bn' + str(stage) + block + '_branch'
     79 
     80   x = Conv2D(filters1, (1, 1), name=conv_name_base + '2a')(input_tensor)
     81   x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
     82   x = Activation('relu')(x)
     83 
     84   x = Conv2D(
     85       filters2, kernel_size, padding='same', name=conv_name_base + '2b')(
     86           x)
     87   x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
     88   x = Activation('relu')(x)
     89 
     90   x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
     91   x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
     92 
     93   x = layers.add([x, input_tensor])
     94   x = Activation('relu')(x)
     95   return x
     96 
     97 
     98 def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2,
     99                                                                           2)):
    100   """A block that has a conv layer at shortcut.
    101 
    102   Arguments:
    103       input_tensor: input tensor
    104       kernel_size: default 3, the kernel size of middle conv layer at main path
    105       filters: list of integers, the filters of 3 conv layer at main path
    106       stage: integer, current stage label, used for generating layer names
    107       block: 'a','b'..., current block label, used for generating layer names
    108       strides: Strides for the first conv layer in the block.
    109 
    110   Returns:
    111       Output tensor for the block.
    112 
    113   Note that from stage 3,
    114   the first conv layer at main path is with strides=(2, 2)
    115   And the shortcut should have strides=(2, 2) as well
    116   """
    117   filters1, filters2, filters3 = filters
    118   if K.image_data_format() == 'channels_last':
    119     bn_axis = 3
    120   else:
    121     bn_axis = 1
    122   conv_name_base = 'res' + str(stage) + block + '_branch'
    123   bn_name_base = 'bn' + str(stage) + block + '_branch'
    124 
    125   x = Conv2D(
    126       filters1, (1, 1), strides=strides, name=conv_name_base + '2a')(
    127           input_tensor)
    128   x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    129   x = Activation('relu')(x)
    130 
    131   x = Conv2D(
    132       filters2, kernel_size, padding='same', name=conv_name_base + '2b')(
    133           x)
    134   x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    135   x = Activation('relu')(x)
    136 
    137   x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
    138   x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
    139 
    140   shortcut = Conv2D(
    141       filters3, (1, 1), strides=strides, name=conv_name_base + '1')(
    142           input_tensor)
    143   shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)
    144 
    145   x = layers.add([x, shortcut])
    146   x = Activation('relu')(x)
    147   return x
    148 
    149 
    150 @tf_export('keras.applications.ResNet50',
    151            'keras.applications.resnet50.ResNet50')
    152 def ResNet50(include_top=True,
    153              weights='imagenet',
    154              input_tensor=None,
    155              input_shape=None,
    156              pooling=None,
    157              classes=1000):
    158   """Instantiates the ResNet50 architecture.
    159 
    160   Optionally loads weights pre-trained
    161   on ImageNet. Note that when using TensorFlow,
    162   for best performance you should set
    163   `image_data_format='channels_last'` in your Keras config
    164   at ~/.keras/keras.json.
    165 
    166   The model and the weights are compatible with both
    167   TensorFlow and Theano. The data format
    168   convention used by the model is the one
    169   specified in your Keras config file.
    170 
    171   Arguments:
    172       include_top: whether to include the fully-connected
    173           layer at the top of the network.
    174       weights: one of `None` (random initialization),
    175             'imagenet' (pre-training on ImageNet),
    176             or the path to the weights file to be loaded.
    177       input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
    178           to use as image input for the model.
    179       input_shape: optional shape tuple, only to be specified
    180           if `include_top` is False (otherwise the input shape
    181           has to be `(224, 224, 3)` (with `channels_last` data format)
    182           or `(3, 224, 224)` (with `channels_first` data format).
    183           It should have exactly 3 inputs channels,
    184           and width and height should be no smaller than 197.
    185           E.g. `(200, 200, 3)` would be one valid value.
    186       pooling: Optional pooling mode for feature extraction
    187           when `include_top` is `False`.
    188           - `None` means that the output of the model will be
    189               the 4D tensor output of the
    190               last convolutional layer.
    191           - `avg` means that global average pooling
    192               will be applied to the output of the
    193               last convolutional layer, and thus
    194               the output of the model will be a 2D tensor.
    195           - `max` means that global max pooling will
    196               be applied.
    197       classes: optional number of classes to classify images
    198           into, only to be specified if `include_top` is True, and
    199           if no `weights` argument is specified.
    200 
    201   Returns:
    202       A Keras model instance.
    203 
    204   Raises:
    205       ValueError: in case of invalid argument for `weights`,
    206           or invalid input shape.
    207   """
    208   if not (weights in {'imagenet', None} or os.path.exists(weights)):
    209     raise ValueError('The `weights` argument should be either '
    210                      '`None` (random initialization), `imagenet` '
    211                      '(pre-training on ImageNet), '
    212                      'or the path to the weights file to be loaded.')
    213 
    214   if weights == 'imagenet' and include_top and classes != 1000:
    215     raise ValueError('If using `weights` as imagenet with `include_top`'
    216                      ' as true, `classes` should be 1000')
    217 
    218   # Determine proper input shape
    219   input_shape = _obtain_input_shape(
    220       input_shape,
    221       default_size=224,
    222       min_size=197,
    223       data_format=K.image_data_format(),
    224       require_flatten=include_top,
    225       weights=weights)
    226 
    227   if input_tensor is None:
    228     img_input = Input(shape=input_shape)
    229   else:
    230     if not K.is_keras_tensor(input_tensor):
    231       img_input = Input(tensor=input_tensor, shape=input_shape)
    232     else:
    233       img_input = input_tensor
    234   if K.image_data_format() == 'channels_last':
    235     bn_axis = 3
    236   else:
    237     bn_axis = 1
    238 
    239   x = Conv2D(
    240       64, (7, 7), strides=(2, 2), padding='same', name='conv1')(
    241           img_input)
    242   x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    243   x = Activation('relu')(x)
    244   x = MaxPooling2D((3, 3), strides=(2, 2))(x)
    245 
    246   x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    247   x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    248   x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
    249 
    250   x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    251   x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    252   x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    253   x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
    254 
    255   x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    256   x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    257   x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    258   x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    259   x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    260   x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')
    261 
    262   x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    263   x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    264   x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
    265 
    266   x = AveragePooling2D((7, 7), name='avg_pool')(x)
    267 
    268   if include_top:
    269     x = Flatten()(x)
    270     x = Dense(classes, activation='softmax', name='fc1000')(x)
    271   else:
    272     if pooling == 'avg':
    273       x = GlobalAveragePooling2D()(x)
    274     elif pooling == 'max':
    275       x = GlobalMaxPooling2D()(x)
    276 
    277   # Ensure that the model takes into account
    278   # any potential predecessors of `input_tensor`.
    279   if input_tensor is not None:
    280     inputs = get_source_inputs(input_tensor)
    281   else:
    282     inputs = img_input
    283   # Create model.
    284   model = Model(inputs, x, name='resnet50')
    285 
    286   # load weights
    287   if weights == 'imagenet':
    288     if include_top:
    289       weights_path = get_file(
    290           'resnet50_weights_tf_dim_ordering_tf_kernels.h5',
    291           WEIGHTS_PATH,
    292           cache_subdir='models',
    293           md5_hash='a7b3fe01876f51b976af0dea6bc144eb')
    294     else:
    295       weights_path = get_file(
    296           'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
    297           WEIGHTS_PATH_NO_TOP,
    298           cache_subdir='models',
    299           md5_hash='a268eb855778b3df3c7506639542a6af')
    300     model.load_weights(weights_path)
    301   elif weights is not None:
    302     model.load_weights(weights)
    303 
    304   return model
    305