MUNIT

Paper: Multimodal Unsupervised Image-to-Image Translation (ECCV 2018)

A very good paper to read; I strongly recommend it.

Official Github: https://github.com/NVlabs/imaginaire/tree/master/projects/munit

Key features:

  • Generates diverse outputs from a given source-domain image
  • 2 generators and 2 discriminators (one per domain)
  • Each generator consists of:
    • an encoder and a decoder that map between images and their latent representation (the latent space of images)
  • MUNIT assumes that the image representation (the latent space of images) can be decomposed into:
    • a content code that is domain-invariant, and
    • a style code that captures domain-specific properties.
  • MUNIT also assumes:
    • images in different domains share a common content space but not a common style space
  • To translate an image to another domain, MUNIT recombines its content code with a random style code sampled from the style space of the target domain. Different style codes lead to different outputs, which makes the translation diverse and multimodal.

In many scenarios, the cross-domain mapping of interest is multimodal. For example, a winter scene could have many possible appearances during summer due to weather, timing, lighting, etc. Unfortunately, existing techniques (e.g. CycleGAN) usually assume a deterministic or unimodal mapping. As a result, they fail to capture the full distribution of possible outputs.

  • MUNIT can perform example-guided image translation:
    • the style of the translation output is controlled by a user-provided example image from the target domain (see the sketch below)
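
For example, a rough sketch of both translation modes, assuming Enc1/Enc2 and Dec2 are trained encoder/decoder pairs named as in the training loop at the end of this post, x1 is an image from domain 1, and x_ref is a reference image from domain 2:

# Random-style translation (domain 1 -> domain 2):
# content from x1, style sampled from the Gaussian prior
content_1, _ = Enc1(x1)
s2 = torch.randn(x1.size(0), 8, 1, 1)   # style_dim = 8 by default
x_1to2 = Dec2(content_1, s2)

# Example-guided translation: same content code,
# but the style code comes from the reference image
_, style_ref = Enc2(x_ref)
x_1to2_guided = Dec2(content_1, style_ref)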

[Figure: MUNIT overview]

Architecture of MUNIT

Generator of MUNIT

[Figure: architecture of the MUNIT generator (content encoder, style encoder, decoder)]

Content encoder

  • Decomposes the image into its content code

Implementation:

  • several strided convolutional layers to downsample the input, followed by residual blocks (down-sampling convolutional blocks + residual convolutional blocks)
  • all the convolutional layers are followed by Instance Normalization (IN)
import torch.nn as nn
import torch

#=================================================
from norms import AdaptiveInstanceNorm2d, LayerNorm
#=================================================

#=================================================
# Encoder
#=================================================
class ResidualBlock(nn.Module):
    def __init__(self, channels, kernel_size=3, stride=1, padding=1, padding_mode="reflect", norm_type: str = "in"):
        super().__init__()
        norm_layer = AdaptiveInstanceNorm2d if norm_type == "adain" else nn.InstanceNorm2d
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=padding_mode),
            norm_layer(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=padding_mode),
            norm_layer(channels),
        )

    def forward(self, x):
        return x + self.block(x)


class ContentEncoder(nn.Module):
    def __init__(self, in_channels=3, num_features=64, num_residuals=3, num_downsample=2):
        super().__init__()

        # Initial convolution block
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=num_features,
                      kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
        )

        # Downsampling blocks
        down_layers = []
        down_features = num_features
        for _ in range(num_downsample):
            down_layers += [
                nn.Conv2d(in_channels=down_features, out_channels=down_features * 2,
                          kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(down_features * 2),
                nn.ReLU(inplace=True),
            ]
            down_features *= 2
        self.down_sampling = nn.Sequential(*down_layers)

        # Res blocks
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(down_features, kernel_size=3, stride=1, padding=1, norm_type="in") for _ in range(num_residuals)]
        )

    def forward(self, x):
        x = self.initial(x)
        x = self.down_sampling(x)
        x = self.res_blocks(x)
        return x

Style encoder

  • Decomposes the image into its style code

Implementation:

  • several strided convolutional layers, followed by a global average pooling (GAP) layer and a fully connected (FC) layer (implemented below as a 1×1 convolution, which is equivalent after global pooling)
  • Instance Normalization (IN) layers are not used in the style encoder, because IN removes the original feature mean and variance, which carry important style information.
class StyleEncoder(nn.Module):
    def __init__(self, in_channels=3, num_features=64, num_downsample=2, num_style_features=8):
        super().__init__()

        # Initial convolution block
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=num_features,
                      kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.ReLU(inplace=True),
        )

        # Downsampling blocks
        down_layers = []
        down_features = num_features
        for _ in range(2):
            down_layers += [
                nn.Conv2d(in_channels=down_features, out_channels=down_features * 2,
                          kernel_size=4, stride=2, padding=1),
                nn.ReLU(inplace=True),
            ]
            down_features *= 2
        # Downsampling with constant depth
        for _ in range(num_downsample - 2):
            down_layers += [
                nn.Conv2d(in_channels=down_features, out_channels=down_features,
                          kernel_size=4, stride=2, padding=1),
                nn.ReLU(inplace=True),
            ]
        self.down_sampling = nn.Sequential(*down_layers)

        # Average pool and output layer
        self.last_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Conv2d(in_channels=down_features, out_channels=num_style_features,
                      kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.initial(x)
        x = self.down_sampling(x)
        x = self.last_pool(x)
        return x

Putting the two together, the full Encoder:

class Encoder(nn.Module):
    def __init__(self, in_channels=3, features=64, num_residuals=3, num_downsample=2, num_style_features=8):
        super().__init__()
        self.content_encoder = ContentEncoder(in_channels, features, num_residuals, num_downsample)
        self.style_encoder = StyleEncoder(in_channels, features, num_downsample, num_style_features)

    def forward(self, x):
        content_code = self.content_encoder(x)
        style_code = self.style_encoder(x)
        return content_code, style_code
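
A quick shape check (assumed usage, with the default arguments and a 256×256 RGB input):

enc = Encoder(in_channels=3, features=64, num_residuals=3, num_downsample=2, num_style_features=8)
x = torch.randn(1, 3, 256, 256)
content_code, style_code = enc(x)
print(content_code.shape)  # torch.Size([1, 256, 64, 64]) - spatial, domain-invariant content
print(style_code.shape)    # torch.Size([1, 8, 1, 1])     - low-dimensional, domain-specific style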

Decoder

  • Reconstructs the input image from its content and style code

Implementation:

  • processes the content code with residual convolutional blocks and then upsamples it with convolutional layers
  • the style code is passed through a multilayer perceptron (MLP) that generates the parameters of the Adaptive Instance Normalization (AdaIN) layers inside the residual blocks
#=================================================
# Decoder
#=================================================

# MLP that predicts the AdaIN parameters from the style code
class MLP(nn.Module):
    def __init__(self, input_channels, out_channels, features=256, num_blocks=3):
        super().__init__()
        layers = [nn.Linear(input_channels, features), nn.ReLU(inplace=True)]
        for _ in range(num_blocks - 2):
            layers += [nn.Linear(features, features), nn.ReLU(inplace=True)]
        layers += [nn.Linear(features, out_channels)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x.view(x.size(0), -1))


class Decoder(nn.Module):
    def __init__(self, out_channels=3, num_features=64, num_residuals=3, num_upsample=2, num_style_features=8):
        super().__init__()
        features = num_features * 2 ** num_upsample
        # Res blocks (with AdaIN, whose parameters come from the MLP)
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(features, kernel_size=3, stride=1, padding=1, norm_type="adain") for _ in range(num_residuals)]
        )

        # Upsampling blocks
        up_layers = []
        for _ in range(num_upsample):
            up_layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(features, features // 2, kernel_size=5, stride=1, padding=2, padding_mode="zeros"),
                LayerNorm(num_features=features // 2),
                nn.ReLU(inplace=True),
            ]
            features = features // 2
        self.upsampling = nn.Sequential(*up_layers)

        # Output layer
        self.last = nn.Sequential(
            nn.Conv2d(in_channels=features, out_channels=out_channels,
                      kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
        )

        # Initiate MLP (predicts AdaIN parameters)
        num_adain_params = self.get_num_adain_params()
        self.mlp = MLP(num_style_features, num_adain_params)

    def get_num_adain_params(self):
        """Return the number of AdaIN parameters needed by the model"""
        num_adain_params = 0
        for m in self.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                num_adain_params += 2 * m.num_features
        return num_adain_params

    def assign_adain_params(self, adain_params):
        """Assign the adain_params to the AdaIN layers in the model"""
        for m in self.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                # Extract mean and std predictions
                mean = adain_params[:, : m.num_features]
                std = adain_params[:, m.num_features : 2 * m.num_features]
                # Update bias and weight
                m.bias = mean.contiguous().view(-1)
                m.weight = std.contiguous().view(-1)
                # Move pointer
                if adain_params.size(1) > 2 * m.num_features:
                    adain_params = adain_params[:, 2 * m.num_features :]

    def forward(self, content_code, style_code):
        # Update AdaIN parameters from the MLP prediction based on the style code
        self.assign_adain_params(self.mlp(style_code))
        content_code = self.res_blocks(content_code)
        content_code = self.upsampling(content_code)
        content_code = self.last(content_code)
        img = torch.tanh(content_code)
        return img
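
The AdaptiveInstanceNorm2d and LayerNorm layers are imported from a local norms module that is not shown in this post. A minimal sketch of what they are assumed to look like, loosely following the reference MUNIT implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveInstanceNorm2d(nn.Module):
    """Instance norm whose affine parameters are assigned externally
    (here by Decoder.assign_adain_params from the MLP output)."""
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.weight = None  # scale, shape (batch * num_features,)
        self.bias = None    # shift, shape (batch * num_features,)

    def forward(self, x):
        assert self.weight is not None and self.bias is not None, "Assign the AdaIN parameters first"
        b, c, h, w = x.size()
        # Fold the batch into the channel dimension so the per-sample
        # AdaIN parameters can be applied with a single instance_norm call
        x = x.contiguous().view(1, b * c, h, w)
        out = F.instance_norm(x, weight=self.weight, bias=self.bias, eps=self.eps)
        return out.view(b, c, h, w)

class LayerNorm(nn.Module):
    """Layer norm over (C, H, W) with per-channel affine parameters,
    used after the decoder's upsampling convolutions."""
    def __init__(self, num_features, eps=1e-5, affine=True):
        super().__init__()
        self.eps = eps
        self.affine = affine
        if affine:
            self.gamma = nn.Parameter(torch.ones(num_features))
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        mean = x.view(x.size(0), -1).mean(dim=1).view(-1, 1, 1, 1)
        std = x.view(x.size(0), -1).std(dim=1).view(-1, 1, 1, 1)
        x = (x - mean) / (std + self.eps)
        if self.affine:
            x = x * self.gamma.view(1, -1, 1, 1) + self.beta.view(1, -1, 1, 1)
        return x

The key point is that AdaIN has no learned affine parameters of its own: weight and bias are plain attributes that assign_adain_params overwrites with the MLP prediction before every forward pass, so the style code effectively rescales and shifts the content features channel by channel.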

Discriminator of MUNIT
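
MUNIT uses multi-scale discriminators: the same PatchGAN-style discriminator is applied to the real/translated image at three scales, obtained by repeatedly downsampling the input with average pooling. This helps the discriminator judge both global structure and local detail. The code below follows the PyTorch-GAN reference linked in the comment.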

import torch.nn as nn
import torch

# https://github.com/eriklindernoren/PyTorch-GAN/blob/36d3c77e5ff20ebe0aeefd322326a134a279b93e/implementations/munit/models.py#L197
class MultiDiscriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(MultiDiscriminator, self).__init__()

        def discriminator_block(in_channels, out_channels, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)]
            if normalize:
                layers += [nn.InstanceNorm2d(out_channels)]
            layers += [nn.LeakyReLU(0.2, inplace=True)]
            return layers

        # Three discriminators, each operating on a different image scale
        self.models = nn.ModuleList()
        for i in range(3):
            self.models.add_module(
                "disc_%d" % i,
                nn.Sequential(
                    *discriminator_block(in_channels, 64, normalize=False),
                    *discriminator_block(64, 128, normalize=True),
                    *discriminator_block(128, 256, normalize=True),
                    *discriminator_block(256, 512, normalize=True),
                    nn.Conv2d(512, 1, 3, padding=1)
                ),
            )

        self.downsample = nn.AvgPool2d(in_channels, stride=2, padding=[1, 1], count_include_pad=False)

    def compute_loss(self, x, gt):
        """Computes the MSE between model output and scalar gt"""
        loss = sum([torch.mean((out - gt) ** 2) for out in self.forward(x)])
        return loss

    def forward(self, x):
        outputs = []
        for m in self.models:
            outputs.append(m(x))
            x = self.downsample(x)
        return outputs
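
A quick shape check (assumed usage; sizes shown for a 256×256 input):

D = MultiDiscriminator(in_channels=3)
x = torch.randn(1, 3, 256, 256)
for out in D(x):
    print(out.shape)
# torch.Size([1, 1, 16, 16]), torch.Size([1, 1, 8, 8]), torch.Size([1, 1, 4, 4])
# one patch-level prediction map per scale; compute_loss sums the MSE over all of them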

Loss functions of MUNIT

Bidirectional reconstruction loss of MUNIT

The bidirectional reconstruction loss encourages the encoders and decoders to be inverses of each other in both directions:

  • image → latent → image, and
  • latent → image → latent.

Image reconstruction

  • Given an image sampled from the data distribution, we should be able to reconstruct the image after encoding and decoding.

$$\mathcal{L}_{\text{recon}}^{x_1} = \mathbb{E}_{x_1 \sim p(x_1)}\left[\left\| G_1\left(E_1^c(x_1), E_1^s(x_1)\right) - x_1 \right\|_1\right]$$

$$\mathcal{L}_{\text{recon}}^{x_2} = \mathbb{E}_{x_2 \sim p(x_2)}\left[\left\| G_2\left(E_2^c(x_2), E_2^s(x_2)\right) - x_2 \right\|_1\right]$$

The $\ell_1$ reconstruction loss is used as it encourages sharp output images.

Where

  • $E^c$ is the content encoder
  • $E^s$ is the style encoder
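
In code this is just an L1 loss between the input and its encode-decode reconstruction. A sketch using the variable names from the training loop at the end of this post:

criterion_recon = nn.L1Loss()
c1, s1 = Enc1(X1)                                  # E_1^c(x1), E_1^s(x1)
loss_recon_x1 = criterion_recon(Dec1(c1, s1), X1)  # || G_1(c1, s1) - x1 ||_1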

Latent reconstruction

  • Given a latent code (style and content) sampled from the latent distribution at translation time, we should be able to reconstruct the latent code (style and content) after decoding and encoding.

$$\begin{aligned} \mathcal{L}_{\text{recon}}^{c_1} &= \mathbb{E}_{c_1 \sim p(c_1),\, s_2 \sim q(s_2)}\left[\left\| E_2^c\left(G_2(c_1, s_2)\right) - c_1 \right\|_1\right] \\ \mathcal{L}_{\text{recon}}^{s_2} &= \mathbb{E}_{c_1 \sim p(c_1),\, s_2 \sim q(s_2)}\left[\left\| E_2^s\left(G_2(c_1, s_2)\right) - s_2 \right\|_1\right] \end{aligned}$$

$$\begin{aligned} \mathcal{L}_{\text{recon}}^{c_2} &= \mathbb{E}_{c_2 \sim p(c_2),\, s_1 \sim q(s_1)}\left[\left\| E_1^c\left(G_1(c_2, s_1)\right) - c_2 \right\|_1\right] \\ \mathcal{L}_{\text{recon}}^{s_1} &= \mathbb{E}_{c_2 \sim p(c_2),\, s_1 \sim q(s_1)}\left[\left\| E_1^s\left(G_1(c_2, s_1)\right) - s_1 \right\|_1\right] \end{aligned}$$

The $\ell_1$ loss is used for the latent reconstruction terms as well.

Where

  • $E^c$ is the content encoder
  • $E^s$ is the style encoder
  • $G_1$ and $G_2$ are the 2 generators
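
A sketch of the corresponding code for the 1 → 2 direction, reusing c1 and criterion_recon from the sketch above (the style prior $q(s_2)$ is a unit Gaussian, which is why the training loop samples style codes with torch.randn):

s2 = torch.randn(X1.size(0), 8, 1, 1)                  # s2 ~ q(s2), style_dim = 8
x_12 = Dec2(c1, s2)                                    # G_2(c1, s2)
c1_rec, s2_rec = Enc2(x_12)
loss_recon_c1 = criterion_recon(c1_rec, c1.detach())   # content code reconstruction
loss_recon_s2 = criterion_recon(s2_rec, s2)            # style code reconstruction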

Adversarial loss of MUNIT

$$\mathcal{L}_{\mathrm{GAN}}^{x_1} = \mathbb{E}_{x_1 \sim p(x_1)}\left[\log D_1(x_1)\right] + \mathbb{E}_{c_2 \sim p(c_2),\, s_1 \sim q(s_1)}\left[\log\left(1 - D_1\left(G_1(c_2, s_1)\right)\right)\right]$$

$$\mathcal{L}_{\mathrm{GAN}}^{x_2} = \mathbb{E}_{x_2 \sim p(x_2)}\left[\log D_2(x_2)\right] + \mathbb{E}_{c_1 \sim p(c_1),\, s_2 \sim q(s_2)}\left[\log\left(1 - D_2\left(G_2(c_1, s_2)\right)\right)\right]$$

Where:

  • $D_1$ and $D_2$ are the 2 discriminators; $G_1$ and $G_2$ are the 2 generators
  • the objective is the standard GAN loss: translated images should be indistinguishable from real images in the target domain
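
Note that the training code below follows the reference implementation and replaces this log loss with a least-squares GAN objective: MultiDiscriminator.compute_loss regresses the discriminator outputs towards 1 (valid) for real images and 0 (fake) for translated ones with an MSE loss.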

Total loss function of MUNIT

$$\min_{E_1, E_2, G_1, G_2}\; \max_{D_1, D_2}\; \mathcal{L}(E_1, E_2, G_1, G_2, D_1, D_2) = \mathcal{L}_{\mathrm{GAN}}^{x_1} + \mathcal{L}_{\mathrm{GAN}}^{x_2} + \lambda_x\left(\mathcal{L}_{\mathrm{recon}}^{x_1} + \mathcal{L}_{\mathrm{recon}}^{x_2}\right) + \lambda_c\left(\mathcal{L}_{\mathrm{recon}}^{c_1} + \mathcal{L}_{\mathrm{recon}}^{c_2}\right) + \lambda_s\left(\mathcal{L}_{\mathrm{recon}}^{s_1} + \mathcal{L}_{\mathrm{recon}}^{s_2}\right)$$

where

  • $\lambda_x$, $\lambda_c$ and $\lambda_s$ are the weighting factors that control the importance of each reconstruction term.
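
If I remember correctly, the paper uses $\lambda_x = 10$ and $\lambda_c = \lambda_s = 1$ (check the paper for the exact per-experiment settings). The training loop below puts all of these losses together; it also supports an optional style-augmented cycle-consistency term weighted by lambda_cyc.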
for epoch in range(opt.epoch, opt.n_epochs):
    for idx, batch_img in enumerate(dataloader):

        # Input domain
        X1 = batch_img["A"].to(device)
        # Output domain
        X2 = batch_img["B"].to(device)

        style_1 = torch.randn(X1.size(0), opt.style_dim, 1, 1).to(device)
        style_2 = torch.randn(X1.size(0), opt.style_dim, 1, 1).to(device)

        # -------------------------------
        #  Train Encoders and Generators
        # -------------------------------

        optim_G.zero_grad()

        # Get shared latent representation
        c_code_1, s_code_1 = Enc1(X1)
        c_code_2, s_code_2 = Enc2(X2)

        # Reconstruct images
        X11 = Dec1(c_code_1, s_code_1)
        X22 = Dec2(c_code_2, s_code_2)

        # Translate images
        X21 = Dec1(c_code_2, style_1)
        X12 = Dec2(c_code_1, style_2)

        # Cycle translation
        c_code_21, s_code_21 = Enc1(X21)
        c_code_12, s_code_12 = Enc2(X12)
        X121 = Dec1(c_code_12, s_code_1) if lambda_cyc > 0 else 0
        X212 = Dec2(c_code_21, s_code_2) if lambda_cyc > 0 else 0

        # Losses
        loss_GAN_1 = lambda_gan * D1.compute_loss(X21, valid)
        loss_GAN_2 = lambda_gan * D2.compute_loss(X12, valid)
        loss_ID_1 = lambda_id * criterion_recon(X11, X1)
        loss_ID_2 = lambda_id * criterion_recon(X22, X2)
        loss_s_1 = lambda_style * criterion_recon(s_code_21, style_1)
        loss_s_2 = lambda_style * criterion_recon(s_code_12, style_2)
        loss_c_1 = lambda_cont * criterion_recon(c_code_12, c_code_1.detach())
        loss_c_2 = lambda_cont * criterion_recon(c_code_21, c_code_2.detach())
        loss_cyc_1 = lambda_cyc * criterion_recon(X121, X1) if lambda_cyc > 0 else 0
        loss_cyc_2 = lambda_cyc * criterion_recon(X212, X2) if lambda_cyc > 0 else 0

        # Total loss
        loss_G = (
            loss_GAN_1
            + loss_GAN_2
            + loss_ID_1
            + loss_ID_2
            + loss_s_1
            + loss_s_2
            + loss_c_1
            + loss_c_2
            + loss_cyc_1
            + loss_cyc_2
        )

        loss_G.backward()
        optim_G.step()

        # -----------------------
        #  Train Discriminator 1
        # -----------------------

        optim_D1.zero_grad()

        loss_D1 = D1.compute_loss(X1, valid) + D1.compute_loss(X21.detach(), fake)

        loss_D1.backward()
        optim_D1.step()

        # -----------------------
        #  Train Discriminator 2
        # -----------------------

        optim_D2.zero_grad()

        loss_D2 = D2.compute_loss(X2, valid) + D2.compute_loss(X12.detach(), fake)

        loss_D2.backward()
        optim_D2.step()
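
At test time, diverse translations of a single input can be obtained by sampling several style codes for the same content code. A sketch reusing the names above:

# Sample several translations of the same image (domain 1 -> domain 2)
Enc1.eval(); Dec2.eval()
with torch.no_grad():
    c1, _ = Enc1(X1)
    samples = []
    for _ in range(5):
        s2 = torch.randn(X1.size(0), opt.style_dim, 1, 1).to(device)
        samples.append(Dec2(c1, s2))  # each style code yields a different plausible translation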