Implement ResNet with TensorFlow2
This tutorial shows you how to build ResNet yourself.
Increasing network depth does not work by simply stacking layers together. Deep networks are hard to train because of the notorious “vanishing gradient problem”: as the gradient is back-propagated to earlier layers, repeated multiplication can make it vanishingly small.
ResNet uses a technique called “residual learning” to deal with the vanishing gradient problem. When stacking layers, we add a “shortcut” that links non-adjacent layers. In other words, the input of a group of layers skips over them and is added to their output.
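The idea can be sketched in a few lines of NumPy. This toy version is an illustration only: it assumes two small fully connected layers (instead of convolutions) as the residual branch F and computes y = ReLU(F(x) + x). F must return the same shape as x so the two can be added. Because the shortcut adds x unchanged, the derivative of F(x) + x with respect to x always contains an identity term, which gives the gradient a path back to earlier layers that does not pass through the branch's weights.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def residual_block(x, w1, w2):
    """Toy residual block: y = ReLU(F(x) + x), with F = two dense layers."""
    f = relu(x @ w1) @ w2  # residual branch F(x), same shape as x
    return relu(f + x)     # shortcut: add the unchanged input back


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))  # batch of 4 vectors with 8 features
w1 = rng.normal(size=(8, 8))
w2 = rng.normal(size=(8, 8))

y = residual_block(x, w1, w2)
# If the branch contributes nothing (all-zero weights), the block passes ReLU(x) through
y0 = residual_block(x, np.zeros((8, 8)), np.zeros((8, 8)))
print(y.shape)                   # (4, 8)
print(np.allclose(y0, relu(x)))  # True
```

The all-zero case shows why stacking residual blocks is safe: a block that has not learned anything useful yet behaves almost like a pass-through instead of destroying the signal.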
This article assumes you already know what convolutional and fully connected networks are, and that you are familiar with Python and TensorFlow 2.
Conv2D in TensorFlow
Let’s see how to use Conv2D in TensorFlow Keras.
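```python
from tensorflow import keras
from tensorflow.keras import layers

# General form (defaults: strides=(1, 1), padding='valid'):
# layers.Conv2D(filters, kernel_size, strides, padding)
conv = layers.Conv2D(filters=64, kernel_size=3, strides=1, padding='same')
```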
- filters: Integer, the dimensionality of the output space (i.e., the number of output channels). For example, if your input is (h, w, c) and you set filters=64, the output is (h’, w’, 64). In TensorFlow you do not need to specify the number of input channels, because Keras infers it from the input the first time the layer is called; in PyTorch, by contrast, nn.Conv2d requires the number of input channels (in_channels) explicitly.
- kernel_size: An integer or tuple/list of 2 integers, specifying the height and width of the 2D convolution window. This tutorial only uses 1, 3, and 7.
- strides: An integer or tuple/list of 2 integers, specifying the strides of the convolution along the height and width. Can be a single integer to specify the same value for all spatial dimensions.
- padding: one of "valid" or "same" (case-insensitive). "valid" means no padding. "same" results in padding with zeros evenly to the left/right or up/down of the input. When padding="same" and strides=1, the output has the same size as the input.
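To see these parameters in action, the sketch below applies a few Conv2D layers to an all-zero dummy batch (an arbitrary choice of input) and prints the output shapes. Only the number of output channels (filters) is given; the 3 input channels are never specified.

```python
import tensorflow as tf
from tensorflow.keras import layers

x = tf.zeros((1, 224, 224, 3))  # one 224x224 RGB image, channels_last

# 7x7 kernel, stride 2: 'same' vs 'valid' padding
y_same = layers.Conv2D(64, 7, strides=2, padding='same')(x)
y_valid = layers.Conv2D(64, 7, strides=2, padding='valid')(x)
# 3x3 kernel, stride 1, 'same': spatial size is preserved
y_keep = layers.Conv2D(64, 3, strides=1, padding='same')(x)

print(tuple(y_same.shape))   # (1, 112, 112, 64)
print(tuple(y_valid.shape))  # (1, 109, 109, 64)
print(tuple(y_keep.shape))   # (1, 224, 224, 64)
```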
Calculate the output size
In TensorFlow, we usually use the channels_last format, in which a tensor has the shape (b, h, w, c), where
- b is the batch size
- h is the height (or rows) of the picture (feature map)
- w is the width (or columns) of the picture (feature map)
- c is the number of channels (or features, depth) of the picture (feature map).
When we use padding="same", we can calculate the output height and width as follows:
output = ceil(input / strides)
If we use padding="valid", the output size depends on the kernel size (not on the number of filters):
output = ceil((input - kernel_size + 1) / strides)
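The two formulas can be wrapped in a small helper, which is handy for checking shapes before building a model. conv_output_size is a hypothetical name, not a TensorFlow function. The same arithmetic also applies to Keras pooling layers such as MaxPool2D, with the pool size in place of kernel_size.

```python
import math


def conv_output_size(input_size, kernel_size, strides, padding):
    """Output height (or width) of a Conv2D/pooling layer along one spatial axis."""
    padding = padding.lower()  # Keras accepts 'same'/'valid' case-insensitively
    if padding == 'same':
        return math.ceil(input_size / strides)
    if padding == 'valid':
        return math.ceil((input_size - kernel_size + 1) / strides)
    raise ValueError(f"unknown padding: {padding!r}")


print(conv_output_size(224, 7, 2, 'same'))   # 112
print(conv_output_size(224, 7, 2, 'valid'))  # 109
print(conv_output_size(56, 3, 1, 'same'))    # 56
```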
ResNet18, 34
There are many kinds of ResNet, so let’s start with the simplest one, ResNet18. Assume that our input is a 224×224 RGB image and the output has 1000 classes.
There are 3 main components that make up the network.
- input layer (conv1 + max pooling) (usually referred to as layer0)
- ResBlocks (conv2 without max pooling ~ conv5) (usually referred to as layer1 ~ layer4)
- final layer
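Using the padding="same" formula, we can trace how a 224×224 input shrinks as it passes through these components. The sketch below is plain Python arithmetic, not a model; the channel counts per stage (64, 128, 256, 512) and the fact that only the first block of layer2 ~ layer4 downsamples are taken from the ResNet18 code later in this article.

```python
import math


def same_out(size, strides):
    # padding='same': output = ceil(input / strides)
    return math.ceil(size / strides)


size = 224
size = same_out(size, 2)  # layer0: 7x7 conv, stride 2 -> 112
size = same_out(size, 2)  # layer0: 3x3 max pooling, stride 2 -> 56
trace = {'layer0': (size, size, 64)}

# (name, output channels, stride of the first block) for ResNet18
for name, channels, stride in [('layer1', 64, 1), ('layer2', 128, 2),
                               ('layer3', 256, 2), ('layer4', 512, 2)]:
    size = same_out(size, stride)  # only the first block of a layer can downsample
    trace[name] = (size, size, channels)

trace['gap'] = (512,)   # one mean value per channel
trace['fc'] = (1000,)   # one score per class

for name, shape in trace.items():
    print(name, shape)
```

Running it shows the familiar ResNet progression: 56×56 after layer0 and layer1, then 28, 14, and finally 7×7×512 before global average pooling.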
STEP0: ResBlocks (layer1~layer4)
The most important component is the ResBlock. Let’s see how to build it!
A straightforward first implementation of the residual block has 3 problems:
- We need to downsample (i.e., reduce the height and width of the feature map) at conv3_1, conv4_1, and conv5_1.
- We can use a variable to control the number of output filters.
- We should apply batch normalization and a ReLU function between the layers.
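The ResBlock class that ResNet18 and ResNet34 depend on is not defined in the text. The following is a minimal sketch of a basic block that addresses the three problems, written in the same style as the ResBottleneckBlock later in this article. Its exact layout (stride 2 in the first 3×3 convolution, a 1×1 projection shortcut with batch normalization when downsampling, an identity shortcut otherwise) is an assumption based on the standard ResNet design, not necessarily the author’s original code. The identity case assumes the input already has `filters` channels, which holds in ResNet18/34 because channels only change when downsampling.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class ResBlock(keras.Model):
    """Basic residual block: two 3x3 convolutions plus a shortcut (a sketch)."""

    def __init__(self, filters, downsample):
        super().__init__()
        # Problem 1: downsample with stride 2 in the first 3x3 convolution.
        strides = 2 if downsample else 1
        # Problem 2: `filters` sets the number of output channels.
        self.conv1 = layers.Conv2D(filters, 3, strides, padding='same')
        self.conv2 = layers.Conv2D(filters, 3, 1, padding='same')
        # Problem 3: batch normalization (and ReLU) between the layers.
        self.bn1 = layers.BatchNormalization()
        self.bn2 = layers.BatchNormalization()
        self.relu = layers.ReLU()
        if downsample:
            # Projection shortcut: 1x1 conv, stride 2, so shapes match for the addition.
            self.shortcut = keras.Sequential([
                layers.Conv2D(filters, 1, 2, padding='same'),
                layers.BatchNormalization(),
            ])
        else:
            # Identity shortcut (assumes the input already has `filters` channels).
            self.shortcut = layers.Lambda(lambda t: t)

    def call(self, x):
        shortcut = self.shortcut(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return self.relu(y + shortcut)  # ReLU after the addition


x = tf.random.normal((2, 56, 56, 64))
same_shape = ResBlock(64, downsample=False)(x)
downsampled = ResBlock(128, downsample=True)(x)
print(tuple(same_shape.shape))   # (2, 56, 56, 64)
print(tuple(downsampled.shape))  # (2, 28, 28, 128)
```

Creating the BatchNormalization layers once in __init__ (rather than inside call) matters: it is what makes Keras track and train their weights.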
STEP1: Input layer (layer0)
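Layer0 consists of a 7×7 convolution with stride 2 and a 3×3 max pooling with stride 2. Following the ResNet paper, batch normalization is applied right after the convolution and before the activation, so the order is convolution, batch normalization, ReLU, and then max pooling.
```python
self.layer0 = keras.Sequential([
    layers.Conv2D(64, 7, 2, padding='same'),
    layers.BatchNormalization(),
    layers.ReLU(),
    layers.MaxPool2D(pool_size=3, strides=2, padding='same')
], name='layer0')
```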
STEP2: Final layer
The final layer consists of a global average pooling (GAP) layer and a fully connected (FC) layer. GAP computes the mean of each (h, w) feature map, one value per channel, and collects these values into a vector, so a (b, h, w, c) tensor becomes a (b, c) tensor.
```python
self.gap = layers.GlobalAveragePooling2D()
self.fc = layers.Dense(1000, activation='softmax')
```
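GAP is just a mean over the height and width axes. The sketch below checks this against tf.reduce_mean on a random tensor standing in for the 7×7×512 output of layer4 (the values are arbitrary; only the shapes matter).

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Random stand-in for the output of layer4 on a batch of two 224x224 images
x = tf.random.normal((2, 7, 7, 512))

gap = layers.GlobalAveragePooling2D()(x)  # Keras layer
manual = tf.reduce_mean(x, axis=[1, 2])   # mean over height and width

print(tuple(gap.shape))                                     # (2, 512)
print(np.allclose(gap.numpy(), manual.numpy(), atol=1e-5))  # True
```

The resulting 512-dimensional vector is what the Dense layer maps to the 1000 class probabilities.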
STEP3: Done!
```python
from tensorflow import keras
from tensorflow.keras import layers


# Assumes a ResBlock(filters, downsample) class (the basic residual block from STEP0) is defined.
class ResNet18(keras.Model):
    def __init__(self, outputs=1000):
        super().__init__()
        self.layer0 = keras.Sequential([
            layers.Conv2D(64, 7, 2, padding='same'),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.MaxPool2D(pool_size=3, strides=2, padding='same')
        ], name='layer0')
        self.layer1 = keras.Sequential([
            ResBlock(64, downsample=False),
            ResBlock(64, downsample=False)
        ], name='layer1')
        self.layer2 = keras.Sequential([
            ResBlock(128, downsample=True),
            ResBlock(128, downsample=False)
        ], name='layer2')
        self.layer3 = keras.Sequential([
            ResBlock(256, downsample=True),
            ResBlock(256, downsample=False)
        ], name='layer3')
        self.layer4 = keras.Sequential([
            ResBlock(512, downsample=True),
            ResBlock(512, downsample=False)
        ], name='layer4')
        self.gap = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(outputs, activation='softmax')

    def call(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.gap(x)
        return self.fc(x)
```
To build ResNet34, we only need to change the number of ResBlocks in each layer of ResNet18, from [2, 2, 2, 2] to [3, 4, 6, 3].
```python
from tensorflow import keras
from tensorflow.keras import layers


# Same structure as ResNet18, with [3, 4, 6, 3] ResBlocks per layer
class ResNet34(keras.Model):
    def __init__(self, outputs=1000):
        super().__init__()
        self.layer0 = keras.Sequential([
            layers.Conv2D(64, 7, 2, padding='same'),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.MaxPool2D(pool_size=3, strides=2, padding='same')
        ], name='layer0')
        self.layer1 = keras.Sequential([
            ResBlock(64, downsample=False),
            ResBlock(64, downsample=False),
            ResBlock(64, downsample=False)
        ], name='layer1')
        self.layer2 = keras.Sequential([
            ResBlock(128, downsample=True),
            ResBlock(128, downsample=False),
            ResBlock(128, downsample=False),
            ResBlock(128, downsample=False)
        ], name='layer2')
        self.layer3 = keras.Sequential([
            ResBlock(256, downsample=True),
            ResBlock(256, downsample=False),
            ResBlock(256, downsample=False),
            ResBlock(256, downsample=False),
            ResBlock(256, downsample=False),
            ResBlock(256, downsample=False)
        ], name='layer3')
        self.layer4 = keras.Sequential([
            ResBlock(512, downsample=True),
            ResBlock(512, downsample=False),
            ResBlock(512, downsample=False)
        ], name='layer4')
        self.gap = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(outputs, activation='softmax')

    def call(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.gap(x)
        return self.fc(x)
```
ResNet50, 101, 152
STEP0: ResBottleneckBlock
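The biggest difference between ResNet34 and ResNet50 is the residual block. Instead of two 3×3 convolutions, the new block stacks a 1×1, a 3×3, and a 1×1 convolution, and the last 1×1 convolution expands the number of channels to 4 × filters. We call this new version “ResBottleneckBlock”.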
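Why a 1×1, 3×3, 1×1 stack? A rough weight count helps (biases and batch normalization are ignored, and conv_weights is a hypothetical helper for illustration). A bottleneck block that takes and returns 256 channels has slightly fewer weights than a basic block with 64 channels, even though it works on 4× as many channels, because the expensive 3×3 convolution only sees the reduced 64 channels.

```python
def conv_weights(k, c_in, c_out):
    # Number of weights in a k x k convolution from c_in to c_out channels (no bias)
    return k * k * c_in * c_out


# Basic block with 64 channels: two 3x3 convolutions, 64 -> 64 -> 64
basic_64 = conv_weights(3, 64, 64) * 2

# Bottleneck block on 256 channels: 1x1 (256 -> 64), 3x3 (64 -> 64), 1x1 (64 -> 256)
bottleneck_256 = (conv_weights(1, 256, 64)
                  + conv_weights(3, 64, 64)
                  + conv_weights(1, 64, 256))

print(basic_64)        # 73728
print(bottleneck_256)  # 69632
```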
There are 3 types of ResBottleneckBlock.
We can reduce these 3 types to 2 cases: either the shortcut does nothing (identity), or it has to transform its input. When the features do not need downsampling and the number of input channels equals the number of output channels, the shortcut does nothing. When the features need downsampling, or their number of channels has to change, the shortcut applies a 1×1 convolution so that its output matches the shape of the residual branch.
```python
# Part of ResBottleneckBlock: the three convolutions of the residual branch
def __init__(self, filters, downsample):
    super().__init__()
    self.filters = filters
    self.downsample = downsample
    self.conv1 = layers.Conv2D(filters, 1, 1, padding='same')  # 1x1 conv to `filters` channels
    if downsample:
        self.conv2 = layers.Conv2D(filters, 3, 2, padding='same')  # 3x3, stride 2
    else:
        self.conv2 = layers.Conv2D(filters, 3, 1, padding='same')  # 3x3, stride 1
    self.conv3 = layers.Conv2D(filters*4, 1, 1, padding='same')  # 1x1 expands to filters*4
```
Implementing conv1, conv2, and conv3 is straightforward. Now let’s see how to implement the shortcut!
To implement the shortcut, we need to know the number of input channels. This is not difficult: we can override the build() method of layers.Layer (or keras.Model), which Keras calls with input_shape, the shape of the input, the first time the layer is used. As mentioned before, TensorFlow usually uses the channels_last format, so the number of input channels is the last element of input_shape, i.e., input_shape[-1].
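```python
def build(self, input_shape):
    # Projection shortcut when downsampling or when the input channels differ from filters*4
    if self.downsample or self.filters * 4 != input_shape[-1]:
        self.shortcut = layers.Conv2D(self.filters*4, 1,
                                      2 if self.downsample else 1, padding='same')
    else:
        # Identity shortcut (an empty keras.Sequential() cannot be built in Keras 3)
        self.shortcut = layers.Lambda(lambda x: x)
```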
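To see that build() really receives the input shape, here is a tiny layer that only records the channel count it sees. ChannelProbe is a hypothetical toy layer for illustration; it passes its input through unchanged.

```python
import tensorflow as tf
from tensorflow.keras import layers


class ChannelProbe(layers.Layer):
    """Hypothetical toy layer: remembers the number of input channels."""

    def build(self, input_shape):
        # channels_last: the last entry of input_shape is the channel count
        self.in_channels = input_shape[-1]
        super().build(input_shape)

    def call(self, x):
        return x  # identity; only build() matters here


probe = ChannelProbe()
probe(tf.zeros((1, 56, 56, 256)))  # build() runs on the first call
print(probe.in_channels)           # 256
```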
```python
from tensorflow import keras
from tensorflow.keras import layers


class ResBottleneckBlock(keras.Model):
    def __init__(self, filters, downsample):
        super().__init__()
        self.downsample = downsample
        self.filters = filters
        # 1x1 conv, 3x3 conv (stride 2 when downsampling), 1x1 conv expanding to filters*4
        self.conv1 = layers.Conv2D(filters, 1, 1)
        self.conv2 = layers.Conv2D(filters, 3, 2 if downsample else 1, padding='same')
        self.conv3 = layers.Conv2D(filters*4, 1, 1)
        # Create BN layers once here (not inside call) so their weights are tracked and trained
        self.bn1 = layers.BatchNormalization()
        self.bn2 = layers.BatchNormalization()
        self.bn3 = layers.BatchNormalization()
        self.relu = layers.ReLU()

    def build(self, input_shape):
        # Projection shortcut when downsampling or when the input channels differ from filters*4
        if self.downsample or self.filters * 4 != input_shape[-1]:
            self.shortcut = keras.Sequential([
                layers.Conv2D(
                    self.filters*4, 1, 2 if self.downsample else 1, padding='same'),
                layers.BatchNormalization()
            ])
        else:
            self.shortcut = layers.Lambda(lambda x: x)  # identity shortcut

    def call(self, x):
        shortcut = self.shortcut(x)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))  # no ReLU before the addition
        return self.relu(x + shortcut)
```
We can create 3 types of ResBottleneckBlock separately, as shown in Figure 2.
- Left: ResBottleneckBlock(64, downsample=False)
- Middle: ResBottleneckBlock(64, downsample=False)
- Right: ResBottleneckBlock(128, downsample=True)
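The left and middle blocks differ only in their number of input channels. The left block receives 64 channels (e.g., the output of layer0), so its shortcut needs a 1×1 convolution to produce 256 channels; the middle block already receives 256 channels, so its shortcut is the identity.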
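The shortcut decision can be checked in isolation with a small function that mirrors the condition used in ResBottleneckBlock.build(). needs_projection is a hypothetical helper, and the input channel counts below assume the left block follows layer0 (64 channels) and the middle and right blocks receive the 256-channel output of a bottleneck block.

```python
def needs_projection(in_channels, filters, downsample):
    """Hypothetical helper mirroring the condition in ResBottleneckBlock.build().

    The block outputs filters * 4 channels, so the shortcut can be the identity
    only if there is no downsampling and the input already has filters * 4 channels.
    """
    return downsample or in_channels != filters * 4


print(needs_projection(64, 64, downsample=False))   # left:   True  (64 -> 256 channels)
print(needs_projection(256, 64, downsample=False))  # middle: False (identity)
print(needs_projection(256, 128, downsample=True))  # right:  True  (stride 2, 256 -> 512)
```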
STEP1: Done!
```python
from tensorflow import keras
from tensorflow.keras import layers


# Uses the ResBottleneckBlock class defined above.
class ResNet(keras.Model):
    """Bottleneck ResNet; repeat[i] is the number of blocks in layer i+1."""

    def __init__(self, repeat, outputs=1000):
        super().__init__()
        self.layer0 = keras.Sequential([
            layers.Conv2D(64, 7, 2, padding='same'),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.MaxPool2D(pool_size=3, strides=2, padding='same')
        ], name='layer0')
        # layer1 does not downsample (the max pooling in layer0 already did)
        self.layer1 = keras.Sequential([
            ResBottleneckBlock(64, downsample=False) for _ in range(repeat[0])
        ], name='layer1')
        # In layer2 ~ layer4, only the first block downsamples
        self.layer2 = keras.Sequential([
            ResBottleneckBlock(128, downsample=True)
        ] + [
            ResBottleneckBlock(128, downsample=False) for _ in range(1, repeat[1])
        ], name='layer2')
        self.layer3 = keras.Sequential([
            ResBottleneckBlock(256, downsample=True)
        ] + [
            ResBottleneckBlock(256, downsample=False) for _ in range(1, repeat[2])
        ], name='layer3')
        self.layer4 = keras.Sequential([
            ResBottleneckBlock(512, downsample=True)
        ] + [
            ResBottleneckBlock(512, downsample=False) for _ in range(1, repeat[3])
        ], name='layer4')
        self.gap = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(outputs, activation='softmax')

    def call(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.gap(x)
        return self.fc(x)
```
```python
# call() is inherited from ResNet, so the subclasses only set the block counts.
class ResNet50(ResNet):
    def __init__(self, outputs=1000):
        super().__init__([3, 4, 6, 3], outputs)


class ResNet101(ResNet):
    def __init__(self, outputs=1000):
        super().__init__([3, 4, 23, 3], outputs)


class ResNet152(ResNet):
    def __init__(self, outputs=1000):
        super().__init__([3, 8, 36, 3], outputs)
```
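The number in each name counts the weighted layers: the 7×7 convolution in layer0, the convolutions inside the residual blocks (2 per ResBlock, 3 per ResBottleneckBlock), and the final fully connected layer; the 1×1 projection shortcuts are not counted. The arithmetic can be checked with a small hypothetical helper, using the block counts from the code above.

```python
def resnet_depth(repeat, convs_per_block):
    """Hypothetical helper: count the weighted layers that give ResNet its name."""
    # 1 (7x7 conv in layer0) + convolutions inside all residual blocks + 1 (fc)
    return 1 + convs_per_block * sum(repeat) + 1


print(resnet_depth([2, 2, 2, 2], 2))   # ResNet18  (ResBlock: 2 convs)   -> 18
print(resnet_depth([3, 4, 6, 3], 2))   # ResNet34                        -> 34
print(resnet_depth([3, 4, 6, 3], 3))   # ResNet50  (bottleneck: 3 convs) -> 50
print(resnet_depth([3, 4, 23, 3], 3))  # ResNet101                       -> 101
print(resnet_depth([3, 8, 36, 3], 3))  # ResNet152                       -> 152
```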
See full code here: https://github.com/ksw2000/ML-Notebook/blob/main/ResNet/ResNet_TensorFlow.ipynb