generate diverse outputs from a given source domain image
2 Generators and 2 Discriminators
Each generator consists of
encoders (content and style) and a decoder for the image representation (the latent space of images)
MUNIT assumes that the image representation (the latent space of images) can be decomposed into:
a content code that is domain-invariant, and
a style code that captures domain-specific properties.
MUNIT also assumes:
images in different domains share a common content space but not the style space
To translate an image to another domain, MUNIT recombines its content code with a random style code sampled from the style space of the target domain. Different style codes lead to different outputs => diverse and multimodal outputs.
In many scenarios, the cross-domain mapping of interest is multimodal. For example, a winter scene could have many possible appearances during summer due to weather, timing, lighting, etc. Unfortunately, existing techniques (e.g. CycleGAN) usually assume a deterministic or unimodal mapping. As a result, they fail to capture the full distribution of possible outputs.
MUNIT can perform Example-guided image translation:
the style of the translation output is controlled by a user-provided example image in the target domain
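A rough sketch of both translation modes is shown below; the module names content_enc_A, style_enc_B and dec_B are hypothetical placeholders for trained MUNIT components, and style_dim = 8 is an assumed style-code size, not the official API:

import torch

style_dim = 8                                 # assumed style-code size
x_a = torch.randn(1, 3, 256, 256)             # an image from the source domain A

# Placeholders: content_enc_A, style_enc_B, dec_B stand in for trained MUNIT modules
content = content_enc_A(x_a)                  # domain-invariant content code

# 1) Sampled-style translation: draw a style code from the prior N(0, I)
style_random = torch.randn(1, style_dim)
x_ab_random = dec_B(content, style_random)    # different samples give different outputs

# 2) Example-guided translation: take the style code from a user-provided domain-B image
x_b_ref = torch.randn(1, 3, 256, 256)         # placeholder for the reference image
style_ref = style_enc_B(x_b_ref).view(1, -1)  # flatten to a style vector
x_ab_guided = dec_B(content, style_ref)       # output adopts the reference image's style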
Architecture of MUNIT
Generator of MUNIT
Content encoder
Decomposes an image into its content code
Implementation:
several strided convolutional layers to downsample the input, followed by residual blocks (down-sampling convolutional blocks + residual convolutional blocks)
All the convolutional layers are followed by Instance Normalization (IN)
# Res blocks
self.res_blocks = nn.Sequential(
    *[ResidualBlock(down_features, kernel_size=3, stride=1, padding=1, norm_type="in")
      for _ in range(num_residuals)]
)

def forward(self, x):
    x = self.initial(x)
    x = self.down_sampling(x)
    x = self.res_blocks(x)
    return x
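The snippet above references ResidualBlock, self.initial and self.down_sampling without showing them; a minimal, self-contained sketch of an equivalent content encoder (class names and channel sizes are illustrative assumptions, not the exact repository code) could look like:

import torch
import torch.nn as nn

class ResidualBlockSketch(nn.Module):
    """Two 3x3 convs with Instance Normalization and a skip connection."""
    def __init__(self, features):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(features),
        )

    def forward(self, x):
        return x + self.block(x)

class ContentEncoderSketch(nn.Module):
    """Initial conv + strided down-sampling convs + residual blocks, all followed by IN."""
    def __init__(self, in_channels=3, features=64, num_downsample=2, num_residuals=4):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, features, kernel_size=7, stride=1, padding=3),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
        ]
        # Each strided conv halves the spatial resolution and doubles the channels
        for _ in range(num_downsample):
            layers += [
                nn.Conv2d(features, features * 2, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(features * 2),
                nn.ReLU(inplace=True),
            ]
            features *= 2
        # Residual blocks at the bottleneck resolution
        layers += [ResidualBlockSketch(features) for _ in range(num_residuals)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Returns the content code, e.g. (batch, 256, H/4, W/4) with the defaults above
        return self.model(x)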
Style encoder
Decomposes an image into its style code
Implementation:
several strided convolutional layers, followed by a global average pooling (GAP) layer and a fully connected (FC) layer.
Instance Normalization (IN) layers are not used in the style encoder, because IN removes the original feature mean and variance that represent important style information.
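No code is shown for the style encoder; a minimal sketch consistent with the description above (strided convs without IN, global average pooling, then a projection to the style code) might look like the following, where the layer sizes and style_dim = 8 are assumptions:

import torch.nn as nn

class StyleEncoderSketch(nn.Module):
    """Strided convs (no Instance Normalization), global average pooling,
    then a 1x1 conv acting as the fully connected layer to the style code."""
    def __init__(self, in_channels=3, features=64, style_dim=8):
        super().__init__()
        layers = [nn.Conv2d(in_channels, features, kernel_size=7, stride=1, padding=3),
                  nn.ReLU(inplace=True)]
        # Several strided convolutions; IN is deliberately omitted so the feature
        # mean/variance (style information) is preserved
        for _ in range(4):
            out_features = min(features * 2, 256)
            layers += [nn.Conv2d(features, out_features, kernel_size=4, stride=2, padding=1),
                       nn.ReLU(inplace=True)]
            features = out_features
        layers += [nn.AdaptiveAvgPool2d(1),                        # global average pooling
                   nn.Conv2d(features, style_dim, kernel_size=1)]  # FC layer as a 1x1 conv
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Returns the style code with shape (batch, style_dim, 1, 1)
        return self.model(x)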
Decoder
Reconstructs the input image from its content and style codes
Implementation:
processes the content code with residual convolutional blocks and then up-samples it through convolutional layers
the style code is passed through a multilayer perceptron (MLP) to generate the parameters of the Adaptive Instance Normalization (AdaIN) layers in the residual convolutional blocks
def get_num_adain_params(self):
    """Return the number of AdaIN parameters needed by the model"""
    num_adain_params = 0
    for m in self.modules():
        if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
            num_adain_params += 2 * m.num_features
    return num_adain_params

def assign_adain_params(self, adain_params):
    """Assign the adain_params to the AdaIN layers in model"""
    for m in self.modules():
        if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
            # Extract mean and std predictions
            mean = adain_params[:, : m.num_features]
            std = adain_params[:, m.num_features : 2 * m.num_features]
            # Update bias and weight
            m.bias = mean.contiguous().view(-1)
            m.weight = std.contiguous().view(-1)
            # Move pointer
            if adain_params.size(1) > 2 * m.num_features:
                adain_params = adain_params[:, 2 * m.num_features :]

def forward(self, content_code, style_code):
    # Update AdaIN parameters by MLP prediction based off style code
    self.assign_adain_params(self.mlp(style_code))
    content_code = self.res_blocks(content_code)
    content_code = self.upsampling(content_code)
    content_code = self.last(content_code)
    img = torch.tanh(content_code)
    return img
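AdaptiveInstanceNorm2d itself is not shown above; a minimal sketch that is consistent with assign_adain_params (which assigns the style-predicted mean to bias and the style-predicted std to weight) could be:

import torch
import torch.nn as nn

class AdaptiveInstanceNorm2d(nn.Module):
    """AdaIN layer whose affine parameters are assigned externally (per batch)
    from the MLP prediction instead of being learned."""
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        # Filled in by assign_adain_params before each forward pass:
        self.weight = None  # style-predicted std, flattened to (batch * num_features,)
        self.bias = None    # style-predicted mean, flattened to (batch * num_features,)

    def forward(self, x):
        assert self.weight is not None and self.bias is not None, \
            "Assign AdaIN parameters before calling forward"
        b, c, h, w = x.size()
        x_flat = x.view(b, c, -1)
        # Instance-normalize the content features
        mean = x_flat.mean(dim=2, keepdim=True)
        std = x_flat.std(dim=2, keepdim=True) + self.eps
        x_norm = (x_flat - mean) / std
        # Re-scale and shift with the style-predicted statistics
        weight = self.weight.view(b, c, 1)
        bias = self.bias.view(b, c, 1)
        return (x_norm * weight + bias).view(b, c, h, w)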
Discriminator of MUNIT
A multi-scale discriminator: the same image is scored by several discriminators, each operating on a progressively downsampled copy, and the least-squares GAN objective (MSE against the real/fake label) is used.
Implementation:

def compute_loss(self, x, gt):
    """Computes the MSE between model output and scalar gt"""
    loss = sum([torch.mean((out - gt) ** 2) for out in self.forward(x)])
    return loss

def forward(self, x):
    outputs = []
    for m in self.models:
        outputs.append(m(x))
        x = self.downsample(x)
    return outputs
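The forward pass above relies on self.models and self.downsample, which are not shown; a self-contained sketch of how such a multi-scale discriminator might be assembled (the number of scales and layer sizes are illustrative assumptions) is:

import torch.nn as nn

class MultiScaleDiscriminatorSketch(nn.Module):
    """Several PatchGAN-style discriminators, each applied to a progressively
    downsampled copy of the input image."""
    def __init__(self, in_channels=3, num_scales=3):
        super().__init__()

        def block(in_f, out_f, normalize=True):
            layers = [nn.Conv2d(in_f, out_f, kernel_size=4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_f))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.models = nn.ModuleList([
            nn.Sequential(
                *block(in_channels, 64, normalize=False),
                *block(64, 128),
                *block(128, 256),
                nn.Conv2d(256, 1, kernel_size=3, padding=1),  # patch-level real/fake scores
            )
            for _ in range(num_scales)
        ])
        # Halve the resolution between scales
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)

    def forward(self, x):
        # One tensor of patch predictions per scale, as in the forward above
        outputs = []
        for m in self.models:
            outputs.append(m(x))
            x = self.downsample(x)
        return outputs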
Loss functions of MUNIT
Bidirectional reconstruction loss of MUNIT
The bidirectional reconstruction loss encourages the encoders and decoders to be inverses of each other:
in both the image → latent → image and latent → image → latent directions.
Image reconstruction
Given an image sampled from the data distribution, we should be able to reconstruct the image after encoding and decoding.
The L1 reconstruction loss is used, as it encourages sharp output images.
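Following the MUNIT paper's formulation, the image reconstruction loss for domain 1 can be written as (the domain-2 term is symmetric):

\mathcal{L}_{\mathrm{recon}}^{x_1} = \mathbb{E}_{x_1 \sim p(x_1)}\left[\, \left\lVert G_1\!\left(E_1^{c}(x_1),\, E_1^{s}(x_1)\right) - x_1 \right\rVert_1 \,\right]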
Where
E^c is the content encoder
E^s is the style encoder
G is the decoder
Latent reconstruction
Given a latent code (style and content) sampled from the latent distribution at translation time, we should be able to reconstruct the latent code (style and content) after decoding and encoding.
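Following the MUNIT paper's formulation, for the domain 1 → domain 2 translation direction the content and style reconstruction losses are written as below (with q(s_2) the prior \mathcal{N}(0, I) and p(c_1) induced by c_1 = E_1^{c}(x_1)); the 2 → 1 direction is symmetric:

\mathcal{L}_{\mathrm{recon}}^{c_1} = \mathbb{E}_{c_1 \sim p(c_1),\, s_2 \sim q(s_2)}\left[\, \left\lVert E_2^{c}\!\left(G_2(c_1, s_2)\right) - c_1 \right\rVert_1 \,\right]

\mathcal{L}_{\mathrm{recon}}^{s_2} = \mathbb{E}_{c_1 \sim p(c_1),\, s_2 \sim q(s_2)}\left[\, \left\lVert E_2^{s}\!\left(G_2(c_1, s_2)\right) - s_2 \right\rVert_1 \,\right]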