# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Densely Connected Convolutional Networks.

Reference:
  [Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993)
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

l2 = tf.keras.regularizers.l2


class ConvBlock(tf.keras.Model):
  """Convolutional block consisting of (batchnorm -> relu -> conv).

  Arguments:
    num_filters: number of filters passed to a convolutional layer.
    data_format: "channels_first" or "channels_last".
    bottleneck: if True, a 1x1 conv is performed before the 3x3 conv.
    weight_decay: weight decay.
    dropout_rate: dropout rate.
  """

  def __init__(self, num_filters, data_format, bottleneck, weight_decay=1e-4,
               dropout_rate=0):
    super(ConvBlock, self).__init__()
    self.bottleneck = bottleneck

    axis = -1 if data_format == "channels_last" else 1
    inter_filter = num_filters * 4
    # don't forget to set use_bias=False when using batchnorm
    self.conv2 = tf.keras.layers.Conv2D(num_filters,
                                        (3, 3),
                                        padding="same",
                                        use_bias=False,
                                        data_format=data_format,
                                        kernel_initializer="he_normal",
                                        kernel_regularizer=l2(weight_decay))
    self.batchnorm1 = tf.keras.layers.BatchNormalization(axis=axis)
    self.dropout = tf.keras.layers.Dropout(dropout_rate)

    if self.bottleneck:
      self.conv1 = tf.keras.layers.Conv2D(inter_filter,
                                          (1, 1),
                                          padding="same",
                                          use_bias=False,
                                          data_format=data_format,
                                          kernel_initializer="he_normal",
                                          kernel_regularizer=l2(weight_decay))
      self.batchnorm2 = tf.keras.layers.BatchNormalization(axis=axis)

  def call(self, x, training=True):
    output = self.batchnorm1(x, training=training)

    if self.bottleneck:
      output = self.conv1(tf.nn.relu(output))
      output = self.batchnorm2(output, training=training)

    output = self.conv2(tf.nn.relu(output))
    output = self.dropout(output, training=training)

    return output
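
# Shape note: for a "channels_last" input of shape (N, H, W, C), a ConvBlock
# built with num_filters=k returns (N, H, W, k); the 3x3 conv has k filters
# and "same" padding, so spatial dimensions are preserved. With
# bottleneck=True, the 1x1 conv first widens the features to 4 * k channels
# before the 3x3 conv projects them back down to k.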


class TransitionBlock(tf.keras.Model):
  """Transition block to reduce the number of features.

  Arguments:
    num_filters: number of filters passed to a convolutional layer.
    data_format: "channels_first" or "channels_last".
    weight_decay: weight decay.
    dropout_rate: dropout rate.
  """

  def __init__(self, num_filters, data_format,
               weight_decay=1e-4, dropout_rate=0):
    super(TransitionBlock, self).__init__()
    axis = -1 if data_format == "channels_last" else 1

    self.batchnorm = tf.keras.layers.BatchNormalization(axis=axis)
    self.conv = tf.keras.layers.Conv2D(num_filters,
                                       (1, 1),
                                       padding="same",
                                       use_bias=False,
                                       data_format=data_format,
                                       kernel_initializer="he_normal",
                                       kernel_regularizer=l2(weight_decay))
    self.avg_pool = tf.keras.layers.AveragePooling2D(data_format=data_format)

  def call(self, x, training=True):
    output = self.batchnorm(x, training=training)
    output = self.conv(tf.nn.relu(output))
    output = self.avg_pool(output)
    return output


class DenseBlock(tf.keras.Model):
  """Dense block consisting of ConvBlocks, where each block's output is
  concatenated with its input.

  Arguments:
    num_layers: number of layers in the block.
    growth_rate: number of filters to add per conv block.
    data_format: "channels_first" or "channels_last".
    bottleneck: boolean, whether the ConvBlocks use a 1x1 bottleneck conv.
    weight_decay: weight decay.
    dropout_rate: dropout rate.
  """

  def __init__(self, num_layers, growth_rate, data_format, bottleneck,
               weight_decay=1e-4, dropout_rate=0):
    super(DenseBlock, self).__init__()
    self.num_layers = num_layers
    self.axis = -1 if data_format == "channels_last" else 1

    self.blocks = []
    for _ in range(int(self.num_layers)):
      self.blocks.append(ConvBlock(growth_rate,
                                   data_format,
                                   bottleneck,
                                   weight_decay,
                                   dropout_rate))

  def call(self, x, training=True):
    for i in range(int(self.num_layers)):
      output = self.blocks[i](x, training=training)
      x = tf.concat([x, output], axis=self.axis)

    return x
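
# Shape note: a DenseBlock with num_layers=L and growth_rate=k maps an input
# with C channels to C + L * k channels, since each ConvBlock emits k
# channels that are concatenated onto the running feature map. For example,
# L=2 and k=12 on a (N, 32, 32, 16) input yield (N, 32, 32, 40).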


class DenseNet(tf.keras.Model):
  """Creates the DenseNet architecture.

  Arguments:
    depth_of_model: number of layers in the model.
    growth_rate: number of filters to add per conv block.
    num_of_blocks: number of dense blocks.
    output_classes: number of output classes.
    num_layers_in_each_block: number of layers in each block.
      If -1, it is computed as (depth_of_model - 4) / 3, halved again
      when bottleneck is True.
      If a positive integer, it is used as the number of layers per block.
      If a list or tuple, it is used directly, one entry per block.
    data_format: "channels_first" or "channels_last".
    bottleneck: boolean, whether the ConvBlocks use a 1x1 bottleneck conv.
    compression: compression factor that reduces the number of inputs
      (filters) to the transition blocks.
    weight_decay: weight decay.
    dropout_rate: dropout rate.
    pool_initial: if True, the initial layer is a 7x7 conv with stride 2
      followed by a 3x3 max pool; otherwise, a 3x3 conv with stride 1.
    include_top: if True, a GlobalAveragePooling2D layer and a Dense
      classifier layer are appended.
  """

  def __init__(self, depth_of_model, growth_rate, num_of_blocks,
               output_classes, num_layers_in_each_block, data_format,
               bottleneck=True, compression=0.5, weight_decay=1e-4,
               dropout_rate=0, pool_initial=False, include_top=True):
    super(DenseNet, self).__init__()
    self.depth_of_model = depth_of_model
    self.growth_rate = growth_rate
    self.num_of_blocks = num_of_blocks
    self.output_classes = output_classes
    self.num_layers_in_each_block = num_layers_in_each_block
    self.data_format = data_format
    self.bottleneck = bottleneck
    self.compression = compression
    self.weight_decay = weight_decay
    self.dropout_rate = dropout_rate
    self.pool_initial = pool_initial
    self.include_top = include_top

    # deciding on the number of layers in each block
    if isinstance(self.num_layers_in_each_block, (list, tuple)):
      self.num_layers_in_each_block = list(self.num_layers_in_each_block)
    else:
      if self.num_layers_in_each_block == -1:
        if self.num_of_blocks != 3:
          raise ValueError(
              "Number of blocks must be 3 if num_layers_in_each_block is -1")
        if (self.depth_of_model - 4) % 3 == 0:
          num_layers = (self.depth_of_model - 4) // 3
          if self.bottleneck:
            num_layers //= 2
          self.num_layers_in_each_block = [num_layers] * self.num_of_blocks
        else:
          raise ValueError(
              "Depth must be 3N+4 if num_layers_in_each_block is -1")
      else:
        self.num_layers_in_each_block = [
            self.num_layers_in_each_block] * self.num_of_blocks

    axis = -1 if self.data_format == "channels_last" else 1

    # setting the kernel size and stride of the initial conv layer
    if self.pool_initial:
      init_filters = (7, 7)
      stride = (2, 2)
    else:
      init_filters = (3, 3)
      stride = (1, 1)

    self.num_filters = 2 * self.growth_rate

    # first conv and pool layer
    self.conv1 = tf.keras.layers.Conv2D(self.num_filters,
                                        init_filters,
                                        strides=stride,
                                        padding="same",
                                        use_bias=False,
                                        data_format=self.data_format,
                                        kernel_initializer="he_normal",
                                        kernel_regularizer=l2(
                                            self.weight_decay))
    if self.pool_initial:
      self.pool1 = tf.keras.layers.MaxPooling2D(pool_size=(3, 3),
                                                strides=(2, 2),
                                                padding="same",
                                                data_format=self.data_format)
      self.batchnorm1 = tf.keras.layers.BatchNormalization(axis=axis)

    self.batchnorm2 = tf.keras.layers.BatchNormalization(axis=axis)

    # last pooling and fc layer
    if self.include_top:
      self.last_pool = tf.keras.layers.GlobalAveragePooling2D(
          data_format=self.data_format)
      self.classifier = tf.keras.layers.Dense(self.output_classes)

    # calculating the number of filters after each block
    num_filters_after_each_block = [self.num_filters]
    for i in range(1, self.num_of_blocks):
      temp_num_filters = num_filters_after_each_block[i - 1] + (
          self.growth_rate * self.num_layers_in_each_block[i - 1])
      # using compression to reduce the number of inputs to the
      # transition block
      temp_num_filters = int(temp_num_filters * compression)
      num_filters_after_each_block.append(temp_num_filters)

    # dense block initialization
    self.dense_blocks = []
    self.transition_blocks = []
    for i in range(self.num_of_blocks):
      self.dense_blocks.append(DenseBlock(self.num_layers_in_each_block[i],
                                          self.growth_rate,
                                          self.data_format,
                                          self.bottleneck,
                                          self.weight_decay,
                                          self.dropout_rate))
      if i + 1 < self.num_of_blocks:
        self.transition_blocks.append(
            TransitionBlock(num_filters_after_each_block[i + 1],
                            self.data_format,
                            self.weight_decay,
                            self.dropout_rate))

  def call(self, x, training=True):
    output = self.conv1(x)

    if self.pool_initial:
      output = self.batchnorm1(output, training=training)
      output = tf.nn.relu(output)
      output = self.pool1(output)

    for i in range(self.num_of_blocks - 1):
      output = self.dense_blocks[i](output, training=training)
      output = self.transition_blocks[i](output, training=training)

    output = self.dense_blocks[
        self.num_of_blocks - 1](output, training=training)
    output = self.batchnorm2(output, training=training)
    output = tf.nn.relu(output)

    if self.include_top:
      output = self.last_pool(output)
      output = self.classifier(output)

    return output
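

# Minimal usage sketch: build a CIFAR-10-sized DenseNet-BC and run a single
# forward pass. The hyperparameters (depth 40, growth rate 12, 32x32x3
# inputs) are illustrative choices, not values prescribed by this module,
# and the snippet assumes a TF version where tf.random.normal is available
# (TF 1.13+ or 2.x) with eager execution enabled.
if __name__ == "__main__":
  model = DenseNet(depth_of_model=40,
                   growth_rate=12,
                   num_of_blocks=3,
                   output_classes=10,
                   num_layers_in_each_block=-1,
                   data_format="channels_last")
  images = tf.random.normal([2, 32, 32, 3])
  logits = model(images, training=False)
  print(logits.shape)  # (2, 10)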