# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Densely Connected Convolutional Networks.
     16 
     17 Reference [
     18 Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993)
     19 
     20 """
     21 
     22 from __future__ import absolute_import
     23 from __future__ import division
     24 from __future__ import print_function
     25 
     26 import tensorflow as tf
     27 l2 = tf.keras.regularizers.l2
     28 

class ConvBlock(tf.keras.Model):
  """Convolutional block consisting of (batchnorm -> relu -> conv).

  Arguments:
    num_filters: number of filters passed to a convolutional layer.
    data_format: "channels_first" or "channels_last".
    bottleneck: if True, a 1x1 convolution producing 4 * num_filters maps is
      applied before the 3x3 convolution (the DenseNet-B bottleneck).
    weight_decay: coefficient of the l2 kernel regularizer.
    dropout_rate: dropout rate applied after the final convolution.
  """

  def __init__(self, num_filters, data_format, bottleneck, weight_decay=1e-4,
               dropout_rate=0):
    super(ConvBlock, self).__init__()
    self.bottleneck = bottleneck

    axis = -1 if data_format == "channels_last" else 1
    inter_filter = num_filters * 4
    # use_bias=False because each convolution is followed by batch
    # normalization, which makes a bias term redundant.
    self.conv2 = tf.keras.layers.Conv2D(num_filters,
                                        (3, 3),
                                        padding="same",
                                        use_bias=False,
                                        data_format=data_format,
                                        kernel_initializer="he_normal",
                                        kernel_regularizer=l2(weight_decay))
    self.batchnorm1 = tf.keras.layers.BatchNormalization(axis=axis)
    self.dropout = tf.keras.layers.Dropout(dropout_rate)

    if self.bottleneck:
      self.conv1 = tf.keras.layers.Conv2D(inter_filter,
                                          (1, 1),
                                          padding="same",
                                          use_bias=False,
                                          data_format=data_format,
                                          kernel_initializer="he_normal",
                                          kernel_regularizer=l2(weight_decay))
      self.batchnorm2 = tf.keras.layers.BatchNormalization(axis=axis)

  def call(self, x, training=True):
    output = self.batchnorm1(x, training=training)

    if self.bottleneck:
      output = self.conv1(tf.nn.relu(output))
      output = self.batchnorm2(output, training=training)

    output = self.conv2(tf.nn.relu(output))
    output = self.dropout(output, training=training)

    return output

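
# Shape sketch (an illustrative example, not part of the original module; run
# in eager mode): a ConvBlock always emits `num_filters` feature maps no
# matter how deep its input is, which is what lets a DenseBlock grow linearly.
#
#   block = ConvBlock(num_filters=12, data_format="channels_last",
#                     bottleneck=True)
#   y = block(tf.zeros([1, 32, 32, 64]), training=False)
#   assert y.shape == (1, 32, 32, 12)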

class TransitionBlock(tf.keras.Model):
  """Transition block to reduce the number of feature maps between dense
  blocks.

  Arguments:
    num_filters: number of filters passed to a convolutional layer.
    data_format: "channels_first" or "channels_last".
    weight_decay: coefficient of the l2 kernel regularizer.
    dropout_rate: dropout rate (accepted for API symmetry; unused here).
  """

  def __init__(self, num_filters, data_format,
               weight_decay=1e-4, dropout_rate=0):
    super(TransitionBlock, self).__init__()
    axis = -1 if data_format == "channels_last" else 1

    self.batchnorm = tf.keras.layers.BatchNormalization(axis=axis)
    self.conv = tf.keras.layers.Conv2D(num_filters,
                                       (1, 1),
                                       padding="same",
                                       use_bias=False,
                                       data_format=data_format,
                                       kernel_initializer="he_normal",
                                       kernel_regularizer=l2(weight_decay))
    # AveragePooling2D defaults to pool_size=(2, 2), halving the spatial
    # dimensions.
    self.avg_pool = tf.keras.layers.AveragePooling2D(data_format=data_format)

  def call(self, x, training=True):
    output = self.batchnorm(x, training=training)
    output = self.conv(tf.nn.relu(output))
    output = self.avg_pool(output)
    return output

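
# Shape sketch (illustrative, not part of the original module): a transition
# block both compresses channels (via the 1x1 conv) and halves the spatial
# resolution (via the average pool).
#
#   trans = TransitionBlock(num_filters=84, data_format="channels_last")
#   y = trans(tf.zeros([1, 32, 32, 168]), training=False)
#   assert y.shape == (1, 16, 16, 84)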

class DenseBlock(tf.keras.Model):
  """Dense block consisting of ConvBlocks, where each block's output is
  concatenated with its input.

  Arguments:
    num_layers: number of ConvBlocks in the block.
    growth_rate: number of filters added per ConvBlock.
    data_format: "channels_first" or "channels_last".
    bottleneck: boolean, whether each ConvBlock uses a 1x1 bottleneck
      convolution.
    weight_decay: coefficient of the l2 kernel regularizer.
    dropout_rate: dropout rate.
  """

  def __init__(self, num_layers, growth_rate, data_format, bottleneck,
               weight_decay=1e-4, dropout_rate=0):
    super(DenseBlock, self).__init__()
    self.num_layers = num_layers
    self.axis = -1 if data_format == "channels_last" else 1

    self.blocks = []
    for _ in range(int(self.num_layers)):
      self.blocks.append(ConvBlock(growth_rate,
                                   data_format,
                                   bottleneck,
                                   weight_decay,
                                   dropout_rate))

  def call(self, x, training=True):
    for i in range(int(self.num_layers)):
      output = self.blocks[i](x, training=training)
      x = tf.concat([x, output], axis=self.axis)

    return x

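
# Channel-count sketch (illustrative, not part of the original module): a
# DenseBlock emits input_channels + num_layers * growth_rate feature maps,
# since every ConvBlock appends growth_rate maps to the running concatenation.
#
#   dense = DenseBlock(num_layers=12, growth_rate=12,
#                      data_format="channels_last", bottleneck=True)
#   y = dense(tf.zeros([1, 32, 32, 24]), training=False)
#   assert y.shape == (1, 32, 32, 24 + 12 * 12)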

class DenseNet(tf.keras.Model):
  """The DenseNet architecture.

  Arguments:
    depth_of_model: number of layers in the model; only consulted when
                    num_layers_in_each_block is -1.
    growth_rate: number of filters added per ConvBlock.
    num_of_blocks: number of dense blocks.
    output_classes: number of output classes.
    num_layers_in_each_block: number of layers in each dense block.
                              If -1, it is computed as (depth - 4) / 3,
                                halved again when bottleneck is True.
                              If a positive integer, it is used as the
                                number of layers per block.
                              If a list or tuple, it is used directly.
    data_format: "channels_first" or "channels_last".
    bottleneck: boolean, whether each ConvBlock uses a 1x1 bottleneck
                convolution.
    compression: fraction in (0, 1] of feature maps retained by each
                 transition block.
    weight_decay: coefficient of the l2 kernel regularizer.
    dropout_rate: dropout rate.
    pool_initial: if True, start with a 7x7 conv with stride 2 followed by a
                  3x3 max pool; otherwise, start with a 3x3 conv with
                  stride 1.
    include_top: if True, a GlobalAveragePooling2D layer and a Dense
                 classification layer are appended.
  """

  def __init__(self, depth_of_model, growth_rate, num_of_blocks,
               output_classes, num_layers_in_each_block, data_format,
               bottleneck=True, compression=0.5, weight_decay=1e-4,
               dropout_rate=0, pool_initial=False, include_top=True):
    super(DenseNet, self).__init__()
    self.depth_of_model = depth_of_model
    self.growth_rate = growth_rate
    self.num_of_blocks = num_of_blocks
    self.output_classes = output_classes
    self.num_layers_in_each_block = num_layers_in_each_block
    self.data_format = data_format
    self.bottleneck = bottleneck
    self.compression = compression
    self.weight_decay = weight_decay
    self.dropout_rate = dropout_rate
    self.pool_initial = pool_initial
    self.include_top = include_top

    # deciding on the number of layers in each block
    if isinstance(self.num_layers_in_each_block, (list, tuple)):
      self.num_layers_in_each_block = list(self.num_layers_in_each_block)
    else:
      if self.num_layers_in_each_block == -1:
        if self.num_of_blocks != 3:
          raise ValueError(
              "Number of blocks must be 3 if num_layers_in_each_block is -1")
        if (self.depth_of_model - 4) % 3 == 0:
          num_layers = (self.depth_of_model - 4) // 3
          if self.bottleneck:
            num_layers //= 2
          self.num_layers_in_each_block = [num_layers] * self.num_of_blocks
        else:
          raise ValueError(
              "Depth must be 3N+4 if num_layers_in_each_block is -1")
      else:
        self.num_layers_in_each_block = [
            self.num_layers_in_each_block] * self.num_of_blocks
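    # Worked example (illustrative): depth_of_model=40 with num_of_blocks=3
    # and bottleneck=False gives (40 - 4) // 3 = 12 ConvBlocks per dense
    # block; with bottleneck=True each ConvBlock holds two convolutions, so
    # 6 ConvBlocks per block preserve the overall depth of 40.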

    axis = -1 if self.data_format == "channels_last" else 1

    # setting the kernel size and stride of the initial conv layer
    if self.pool_initial:
      init_kernel = (7, 7)
      stride = (2, 2)
    else:
      init_kernel = (3, 3)
      stride = (1, 1)

    self.num_filters = 2 * self.growth_rate

    # first conv and pool layer
    self.conv1 = tf.keras.layers.Conv2D(self.num_filters,
                                        init_kernel,
                                        strides=stride,
                                        padding="same",
                                        use_bias=False,
                                        data_format=self.data_format,
                                        kernel_initializer="he_normal",
                                        kernel_regularizer=l2(
                                            self.weight_decay))
    if self.pool_initial:
      self.pool1 = tf.keras.layers.MaxPooling2D(pool_size=(3, 3),
                                                strides=(2, 2),
                                                padding="same",
                                                data_format=self.data_format)
      self.batchnorm1 = tf.keras.layers.BatchNormalization(axis=axis)

    self.batchnorm2 = tf.keras.layers.BatchNormalization(axis=axis)

    # last pooling and fc layer
    if self.include_top:
      self.last_pool = tf.keras.layers.GlobalAveragePooling2D(
          data_format=self.data_format)
      self.classifier = tf.keras.layers.Dense(self.output_classes)

    # calculating the number of filters entering each dense block
    num_filters_after_each_block = [self.num_filters]
    for i in range(1, self.num_of_blocks):
      temp_num_filters = num_filters_after_each_block[i-1] + (
          self.growth_rate * self.num_layers_in_each_block[i-1])
      # compression reduces the number of feature maps passed to the
      # transition block
      temp_num_filters = int(temp_num_filters * compression)
      num_filters_after_each_block.append(temp_num_filters)
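    # Worked example (illustrative): growth_rate=12, bottleneck=True and
    # 6 layers per block start from 2 * 12 = 24 maps; the first dense block
    # then ends at 24 + 6 * 12 = 96 maps, and compression=0.5 hands
    # int(96 * 0.5) = 48 maps to the next dense block.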

    # dense and transition block initialization
    self.dense_blocks = []
    self.transition_blocks = []
    for i in range(self.num_of_blocks):
      self.dense_blocks.append(DenseBlock(self.num_layers_in_each_block[i],
                                          self.growth_rate,
                                          self.data_format,
                                          self.bottleneck,
                                          self.weight_decay,
                                          self.dropout_rate))
      if i + 1 < self.num_of_blocks:
        self.transition_blocks.append(
            TransitionBlock(num_filters_after_each_block[i+1],
                            self.data_format,
                            self.weight_decay,
                            self.dropout_rate))

  def call(self, x, training=True):
    output = self.conv1(x)

    if self.pool_initial:
      output = self.batchnorm1(output, training=training)
      output = tf.nn.relu(output)
      output = self.pool1(output)

    # alternate dense and transition blocks; the final dense block is not
    # followed by a transition block
    for i in range(self.num_of_blocks - 1):
      output = self.dense_blocks[i](output, training=training)
      output = self.transition_blocks[i](output, training=training)

    output = self.dense_blocks[
        self.num_of_blocks - 1](output, training=training)
    output = self.batchnorm2(output, training=training)
    output = tf.nn.relu(output)

    if self.include_top:
      output = self.last_pool(output)
      output = self.classifier(output)

    return output
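

# A minimal smoke test (an illustrative sketch, not part of the original
# module): builds the paper's DenseNet-121 configuration and runs one forward
# pass. Assumes a TF 1.x release that provides tf.enable_eager_execution; the
# batch size and input resolution are arbitrary.
if __name__ == "__main__":
  tf.enable_eager_execution()
  # depth_of_model is only consulted when num_layers_in_each_block is -1;
  # here the per-block layer counts are given explicitly.
  model = DenseNet(depth_of_model=121,
                   growth_rate=32,
                   num_of_blocks=4,
                   output_classes=10,
                   num_layers_in_each_block=[6, 12, 24, 16],
                   data_format="channels_last",
                   bottleneck=True,
                   compression=0.5,
                   weight_decay=1e-4,
                   dropout_rate=0,
                   pool_initial=True,
                   include_top=True)
  logits = model(tf.zeros([2, 224, 224, 3]), training=False)
  print(logits.shape)  # expected: (2, 10)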