The Vanishing Gradient and ResNet

Problem of Vanishing Gradient

  • Ideally, the deeper the NN is, the better the performance we can obtain.
  • In reality, however, as the network depth increases the accuracy saturates and then degrades rapidly.
    • What happened? => Vanishing / exploding gradients
      • How? Chain-rule multiplication: backpropagation multiplies one Jacobian factor per layer, so the gradient shrinks (or blows up) exponentially with depth (see the sketch below).
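
A minimal numerical sketch of the chain-rule effect (the layer width, depth, and sigmoid activations below are illustrative choices, not from the source): stacking many sigmoid layers multiplies many small Jacobian factors, so the gradient reaching the input collapses, while adding identity skip connections keeps it at a healthy scale.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
depth, width = 30, 64
layers = [nn.Sequential(nn.Linear(width, width), nn.Sigmoid()) for _ in range(depth)]

x = torch.randn(8, width, requires_grad=True)

# Plain network: h = f_L(... f_1(x) ...)  -> gradient shrinks multiplicatively
h = x
for layer in layers:
    h = layer(h)
h.sum().backward()
print("plain net, gradient norm at input:   ", x.grad.norm().item())

# Same layers with identity skips: h = h + f(h)  -> identity path preserves the gradient
x.grad = None
h = x
for layer in layers:
    h = h + layer(h)
h.sum().backward()
print("residual net, gradient norm at input:", x.grad.norm().item())
```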

ResNet (2015)

  • A major breakthrough in training deeper networks: by isolating the residual to be learned, ResNet reached 152 layers and surpassed human-level recognition accuracy on ImageNet.
  • Residual learning: a skip connection (also called a shortcut or jump connection), y = H(x) = F(x) + x
    • Avoids the vanishing gradient problem
  • A residual block contains 2 Conv layers and 2 activations.
    • The residual F(x) is taken after the 2 Conv layers and 1 activation (the second activation is applied after the addition), because the activation layer lets the network learn a sparse residual, which helps generalization and makes the model feasible in real applications.
  • ResNet applies residual learning to all layers
    • Extremely simple arrangement
    • Not only useful for ResNet, but for nearly all deep learning architectures
  • If we add up all the previous residual outputs, the final output is x_L = x_1 + \sum_{i=1}^{L-1} F(x_i), so the error accumulates additively. A direct mapping instead accumulates the error multiplicatively (see the expansion below).
    • The multiplicative error is the cause of the vanishing / exploding gradient problem
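
A short expansion of the additive-vs-multiplicative claim (standard residual-network analysis, stated here as a sketch): unrolling the recursion x_{l+1} = x_l + F(x_l) gives the additive sum above, whereas a plain network composes its layers, so its input-output Jacobian is a product of per-layer Jacobians:

x_L = H_{L-1}\big(H_{L-2}(\cdots H_1(x_1))\big) \quad\Rightarrow\quad \frac{\partial x_L}{\partial x_1} = \prod_{i=1}^{L-1} \frac{\partial H_i}{\partial x_i}

If most factors in this product are smaller (or larger) than 1, the gradient vanishes (or explodes) exponentially with depth.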

Residual Network (ResNet)

Making a network deeper does not necessarily bring better performance because of the vanishing gradient problem.

  • The main idea of ResNet is to use an identity shortcut connection that skips one or more layers.
    • This trick alleviates the gradient vanishing problem, leading to networks with 100+ layers
img
  • If the desired mapping is H(x), it is easier to train a feed-forward network (enclosed by the red-dashed rectangle) to fit a residual mapping F(x) = H(x) - x

img

This is because the shortcut contributes one extra term to the gradient, reducing the chance of the gradient becoming small: the error gradient can be passed directly to the lower layers (see the expansion below).
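
To make the "extra term" explicit (a standard derivation for identity shortcuts, shown as a sketch rather than taken from the source): differentiating x_L = x_l + \sum_{i=l}^{L-1} F(x_i) with respect to an earlier output x_l and applying the chain rule to the loss \mathcal{L} gives

\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L} \left( 1 + \frac{\partial}{\partial x_l} \sum_{i=l}^{L-1} F(x_i) \right)

The leading 1 lets the error gradient \partial\mathcal{L}/\partial x_L reach layer l unattenuated, no matter how small the derivative of the residual branch becomes.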

PyTorch Example

import torch
import torch.nn as nn

# Batchsize = 50, single-channel image of size 23 x 1234
x = torch.zeros([50,1,23,1234])

# First weight layer: Conv1x1+BN+ReLU
h = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=1, stride=1)(x)
h = nn.BatchNorm2d(num_features=32)(h)
h = nn.ReLU()(h)

# Second weight layer: Conv3x3+BN+ReLU (padding=1 keeps the 23 x 1234 size so the shortcut addition below works)
h = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)(h)
h = nn.BatchNorm2d(num_features=32)(h)
h = nn.ReLU()(h)
print(h.shape)

# Third weight layer: Conv1x1+BN+ReLU
h = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=1, stride=1)(h)
h = nn.BatchNorm2d(num_features=1)(h)
h = nn.ReLU()(h)


# Add
out = h + x
out = nn.ReLU()(out)
print(out.shape)

#torch.Size([50,32,23,1234])
#torch.Size([50,1,23,1234])
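
The functional calls above build fresh, untrained layers on every invocation, which is fine for checking shapes but not for training. A minimal module-style sketch of the same kind of block (the class name ResidualBlock and the channel sizes are illustrative assumptions, with the ReLU placed after the addition as in standard ResNet blocks):

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Conv1x1 -> Conv3x3 -> Conv1x1, plus an identity shortcut."""
    def __init__(self, channels, hidden=32):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, hidden, kernel_size=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, kernel_size=1),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        # y = ReLU(F(x) + x): residual branch plus identity shortcut
        return self.relu(self.body(x) + x)

block = ResidualBlock(channels=1)
x = torch.zeros([50, 1, 23, 1234])
print(block(x).shape)   # torch.Size([50, 1, 23, 1234])
```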

Densely Connected CNN (DenseNet)

  • DenseNet further exploits the effect of shortcut connections
    • The input of each layer consists of the feature maps of all earlier layers, and its output is passed to each subsequent layer.
  • DenseNet not only alleviates the gradient vanishing problem, but also encourages feature reuse, i.e., the network can perform well with fewer parameters.
    • Feature reuse is achieved by concatenating feature maps rather than adding them as in ResNet.

img

PyTorch Example

import torch
import torch.nn as nn


# Batchsize = 50, 23 MFCCs per frame, 1234 frames
x0 = torch.zeros([50, 23, 1234])

# First weight layer (BN+ReLU+Conv1D)x2
h = nn.BatchNorm1d(num_features=23)(x0)
h = nn.ReLU()(h)
h = nn.Conv1d(in_channels=23, out_channels=160, kernel_size=1, stride=1)(h)
h = nn.BatchNorm1d(num_features=160)(h)  # matches the 160 output channels of the Conv1d above
h = nn.ReLU()(h)
x1 = nn.Conv1d(in_channels=160, out_channels=40, kernel_size=3, stride=1, padding=1)(h)

#Concatenate
c1 = torch.cat([x1, x0], dim=1)
print(c1.shape)

# Second weight layer (BN+ReLU+Conv1D)x2
h = nn.BatchNorm1d(num_features=63)(c1)
h = nn.ReLU()(h)
h = nn.Conv1d(in_channels=63, out_channels=160, kernel_size=1, stride=1)(h)
h = nn.BatchNorm1d(num_features=160)(h)
h = nn.ReLU()(h)
x2 = nn.Conv1d(in_channels=160, out_channels=40, kernel_size=3, stride=1, padding=1)(h)

#Concatenate
c2 = torch.cat([x2, x1, x0], dim=1)
print(c2.shape)


#torch.Size([50, 63, 1234])
#torch.Size([50, 103, 1234])
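
As with the ResNet example, the dense block can be packaged as a reusable module. A minimal sketch with a growth rate of 40 new channels per layer (the class name DenseLayer and the bottleneck width of 160 mirror the numbers above but are otherwise illustrative):

```python
import torch
import torch.nn as nn

class DenseLayer(nn.Module):
    """(BN+ReLU+Conv1x1) -> (BN+ReLU+Conv3x3); output is concatenated with the input."""
    def __init__(self, in_channels, growth_rate=40, bottleneck=160):
        super().__init__()
        self.body = nn.Sequential(
            nn.BatchNorm1d(in_channels),
            nn.ReLU(),
            nn.Conv1d(in_channels, bottleneck, kernel_size=1),
            nn.BatchNorm1d(bottleneck),
            nn.ReLU(),
            nn.Conv1d(bottleneck, growth_rate, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # Concatenate the new feature maps with all incoming ones along the channel dimension
        return torch.cat([self.body(x), x], dim=1)

x = torch.zeros([50, 23, 1234])
layer1 = DenseLayer(in_channels=23)    # 23 -> 23 + 40 = 63 channels
layer2 = DenseLayer(in_channels=63)    # 63 -> 63 + 40 = 103 channels
print(layer2(layer1(x)).shape)         # torch.Size([50, 103, 1234])
```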