White-box Cartoonization

(CVPR 2020)

Paper: Learning to Cartoonize Using White-box Cartoon Representations

Supplementary Paper: Learning to Cartoonize Using White-box Cartoon Representations, Supplementary Materials

Official Github (TensorFlow): https://github.com/SystemErrorWang/White-box-Cartoonization

Github (PyTorch): https://github.com/zhen8838/AnimeStylized

Github (My PyTorch implementation): https://github.com/vinesmsuic/WBCartoonization-PyTorch

As its name suggests, this paper performs image cartoonization. It also makes use of GANs and, in my opinion, performs better than CartoonGAN. The paper describes three properties of cartoon images:

  • (1) Global structures composed of sparse color blocks
  • (2) Details outlined by sharp and clear edges
  • (3) Flat and smooth surfaces

A black-box model can also perform cartoonization, but its stylization quality and generality are neither optimal nor stable. The white-box model decomposes images into a surface representation, a structure representation, and a texture representation. The objective weights the losses on these three representations, and adjusting the weights produces different artistic styles of cartoonized images.

Key features of White-box Cartoonization:

  • Trains on unpaired images
  • Produces high-quality cartoon stylization (compared to CartoonGAN)
  • Produces significantly fewer artifacts than CartoonGAN
  • Unlike previous black-box models that guide network training with
    loss terms, this model decomposes images into several representations, which forces the network to learn different features with separate objectives, making the learning process controllable and tunable.
  • Identified three white-box representations from cartoon images:
    • surface representation to represent the smooth surface of images
    • structure representation to represent the sparse color blocks and flattened global content in the celluloid style
    • texture representation to represent high-frequency texture, contours, and details of images
  • Proposed a GAN framework with 1 generator $G$ and 2 discriminators $D_s$ and $D_t$
    • $D_s$ aims to distinguish between surface representations extracted from model outputs and from cartoons
    • $D_t$ aims to distinguish between texture representations extracted from model outputs and from cartoons
  • Pre-trains the generator network with only the content loss (same as CartoonGAN)

The 3 White-box Representations

The representations are extracted through traditional hand-crafted methods (non-network methods).

“The separately extracted cartoon representations enable the cartoonization problem to be optimized end-to-end within a Generative Neural Networks (GAN) framework, making it scalable and controllable for practical use cases and easy to meet diversified artistic demands with task-specific fine-tuning.”

Surface representation

  • Extract a weighted low-frequency component from an image.
    • Preserve color composition and surface texture
    • Ignore edges, textures and details

“This Surface representation design is inspired by the cartoon painting behavior where artists usually draw composition drafts before the details are retouched, and is used to achieve a flexible and learnable feature representation for smoothed surfaces.”

In practice, this is done with a differentiable guided filter, which extracts a smooth surface with textures and details removed.

Implementation of differentiable guided filter:

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

def box_filter(x, r):
    # x: (Batch, Channel, H, W); per-channel mean over a (2r+1) x (2r+1) window
    channel = x.shape[1]
    kernel_size = 2 * r + 1
    weight = 1 / (kernel_size ** 2)
    box_kernel = weight * torch.ones((channel, 1, kernel_size, kernel_size), dtype=x.dtype, device=x.device)
    # Depthwise convolution, equivalent to tf.nn.depthwise_conv2d(x, box_kernel, [1, 1, 1, 1], 'SAME')
    output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel)

    return output

def guided_filter(x, y, r, eps=1e-2):
    # x: guide image, y: image to be filtered, both (Batch, Channel, H, W)
    _, _, H, W = x.shape

    # N rescales the box mean near the borders, where zero padding covers part of the window
    N = box_filter(torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), r)

    mean_x = box_filter(x, r) / N
    mean_y = box_filter(y, r) / N
    cov_xy = box_filter(x * y, r) / N - mean_x * mean_y
    var_x = box_filter(x * x, r) / N - mean_x * mean_x

    # Per-window linear model: y ≈ A * x + b
    A = cov_xy / (var_x + eps)
    b = mean_y - A * mean_x

    mean_A = box_filter(A, r) / N
    mean_b = box_filter(b, r) / N

    output = mean_A * x + mean_b
    return output

img_path = "Vivy.jpg"
image = Image.open(img_path).convert("RGB")
image = transforms.ToTensor()(image).unsqueeze(0)       # (1, 3, H, W) in [0, 1]
result = guided_filter(image, image, r=5)               # F_dgf(I, I)
result = torch.cat((image, result.clamp(0, 1)), dim=3)  # input | surface, side by side
PIL_image = transforms.ToPILImage()(result.squeeze(0))
PIL_image.save("differentiable_guided_filter.png")

Results:

Structure representation

  • Apply an adaptive coloring algorithm to each segmented region to generate sparse visual effects.
    • Captures global structural information
    • Produces sparse color blocks in the celluloid cartoon style

“This Structure representation design is motivated to emulate the celluloid cartoon style, which is featured by clear boundaries and sparse color blocks.”

In practice, this is done by superpixel segmentation (Felzenszwalb's algorithm), followed by Selective Search to merge the segmented regions and extract a sparse segmentation map.

The paper uses an adaptive coloring algorithm instead of the standard one. Standard superpixel algorithms color each segmented region with the average of its pixel values; the authors found that this lowers global contrast, darkens images, and causes a hazing effect in the final results. The adaptive coloring algorithm is formulated as:

$$
\begin{aligned}
\boldsymbol{S}_{i, j} &= \left(\theta_{1} \times \overline{\boldsymbol{S}} + \theta_{2} \times \tilde{\boldsymbol{S}}\right)^{\mu} \\
\left(\theta_{1}, \theta_{2}\right) &= \begin{cases} (0, 1) & \sigma(\boldsymbol{S}) < \gamma_{1} \\ (0.5, 0.5) & \gamma_{1} < \sigma(\boldsymbol{S}) < \gamma_{2} \\ (1, 0) & \gamma_{2} < \sigma(\boldsymbol{S}) \end{cases}
\end{aligned}
$$

where $\overline{\boldsymbol{S}}$ and $\tilde{\boldsymbol{S}}$ are the mean and median pixel values of a segmented region $\boldsymbol{S}$, and $\sigma(\boldsymbol{S})$ is its standard deviation.

The paper found that the setting $\gamma_1 = 20$, $\gamma_2 = 40$ and $\mu = 1.2$ effectively enhances image contrast and reduces the hazing effect on their processed dataset.

from skimage import segmentation
from PIL import Image
import numpy as np

def label2rgb(label_field, image):
    # Adaptive coloring with gamma1 = 20 and gamma2 = 40; the mu = 1.2 exponent is not applied here.
    # Note: this code uses the mean for low-std regions and the median for high-std regions,
    # the reverse of the (theta1, theta2) table above.
    out = np.zeros_like(image)
    labels = np.unique(label_field)
    for label in labels:
        mask = (label_field == label).nonzero()
        region = image[mask]  # (num_pixels, 3)
        std = np.std(region)
        if std < 20:
            color = region.mean(axis=0)
        elif std < 40:
            # 20 <= std < 40; boundaries are included so every region gets a color
            color = 0.5 * region.mean(axis=0) + 0.5 * np.median(region, axis=0)
        else:
            color = np.median(region, axis=0)
        out[mask] = color
    return out


img_path = "Vivy.jpg"
image = np.array(Image.open(img_path).convert("RGB"))
# Superpixel segmentation only; the Selective Search merging step is not included here
seg_labels = segmentation.felzenszwalb(image, scale=10, sigma=0.8, min_size=100)
img_cvtcolor = label2rgb(seg_labels, image)
result = np.concatenate((image, img_cvtcolor), axis=1)
print(result.shape)
PIL_image = Image.fromarray(result)
PIL_image.save("felzenszwalb.png")


Texture representation

  • Shift the color of the image to generate random intensity maps with luminance and color information removed
    • Retains high-frequency textures
    • Decreases the influence of color and luminance

“This Texture representation design is motivated by a cartoon painting method where artists firstly draw a line sketch with contours and details, and then apply color on it. It guides the network to learn the high-frequency textural details independently with the color and luminance patterns excluded.”

In practice, this is done by a random color shift. The paper proposes a random color shift algorithm $\mathcal{F}_{rcs}$ to extract a single-channel texture representation from color images.

$$\mathcal{F}_{rcs}\left(\boldsymbol{I}_{rgb}\right) = (1 - \alpha)\left(\beta_{1} \times \boldsymbol{I}_{r} + \beta_{2} \times \boldsymbol{I}_{g} + \beta_{3} \times \boldsymbol{I}_{b}\right) + \alpha \times \boldsymbol{Y}$$

Where:

  • $\boldsymbol{I}$ is an image
  • $r, g, b$ denote the red, green, and blue color channels
  • $\boldsymbol{Y}$ is a standard grayscale image converted from the RGB color image $\boldsymbol{I}$

The paper sets $\alpha = 0.8$ and $\beta_1, \beta_2, \beta_3 \sim U(-1, 1)$.
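To make the formula concrete, here is a minimal PyTorch sketch that applies $\mathcal{F}_{rcs}$ literally with $\alpha = 0.8$ and $\beta_i \sim U(-1, 1)$. The helper names are hypothetical, and computing $\boldsymbol{Y}$ with the ITU-R BT.601 luma weights (0.299, 0.587, 0.114) is an assumption about what "standard grayscale" means. Note that the `ColorShift` class below takes a different route: it samples weights near those luma coefficients and normalizes them, rather than following this α/β formula.

```python
import torch

def sample_betas(generator=None):
    """Draw beta_1, beta_2, beta_3 ~ U(-1, 1)."""
    return torch.rand(3, generator=generator) * 2 - 1

def random_color_shift(img, betas, alpha=0.8):
    """F_rcs(I) = (1 - alpha) * (b1*I_r + b2*I_g + b3*I_b) + alpha * Y.

    img: (Batch, 3, H, W) RGB tensor; returns a (Batch, 1, H, W) tensor.
    Assumption: Y uses the ITU-R BT.601 luma weights (0.299, 0.587, 0.114).
    """
    betas = betas.to(img)  # match dtype and device of the image
    r, g, b = img[:, 0:1], img[:, 1:2], img[:, 2:3]
    shifted = betas[0] * r + betas[1] * g + betas[2] * b
    gray = 0.299 * r + 0.587 * g + 0.114 * b
    return (1 - alpha) * shifted + alpha * gray
```

In a training loop one would presumably draw `betas` once per iteration and pass the same values for the cartoon batch and the generated batch, so both are shifted identically (an assumption, not stated above).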

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

class ColorShift(nn.Module):
    # Samples (R, G, B) weights uniformly within ±0.1 of the luma coefficients
    # (0.299, 0.587, 0.114) and normalizes them; this is not the alpha/beta formula above
    def __init__(self):
        super().__init__()
        self.dist = torch.distributions.Uniform(
            torch.tensor((0.199, 0.487, 0.014)),
            torch.tensor((0.399, 0.687, 0.214)))

    def forward(self, *imgs):
        # One weight sample is shared by all inputs, so a cartoon batch and a
        # generated batch passed together are shifted identically
        weights = self.dist.sample()
        outputs = []
        for img in imgs:  # each img: (Batch, 3, H, W)
            w = weights.to(img)
            r, g, b = torch.split(img, split_size_or_sections=1, dim=1)
            outputs.append((w[0] * r + w[1] * g + w[2] * b) / w.sum())  # (Batch, 1, H, W)
        return outputs[0] if len(outputs) == 1 else tuple(outputs)

colorshift = ColorShift()
img_path = "Vivy.jpg"
image = Image.open(img_path).convert("RGB")
image = transforms.ToTensor()(image).unsqueeze(0)  # (1, 3, H, W)
result = colorshift(image)                         # F_rcs(I): (1, 1, H, W)
result = result.repeat(1, 3, 1, 1)                 # gray to 3 channels for display
result = torch.cat((image, result), dim=3)         # input | texture, side by side
PIL_image = transforms.ToPILImage()(result.squeeze(0))
PIL_image.save("color_shift.png")

Results:

Loss functions of White-box Cartoonization

As mentioned, the paper proposes a GAN framework with one generator $G$ and two discriminators $D_s$ and $D_t$, where:

  • Generator $G$ converts an input photo into a cartoon image by
    • learning the information stored in the extracted surface representations
    • learning the clear contours and fine textures stored in the texture representations
  • Discriminator $D_s$ aims to distinguish between surface representations extracted from model outputs and from cartoons
  • Discriminator $D_t$ aims to distinguish between texture representations extracted from model outputs and from cartoons

The framework also uses a pre-trained VGG network to extract high-level features and impose spatial constraints on global content, both between the extracted structure representations and the outputs, and between the input photos and the outputs.

Here we denote:

  • $D_s$ is the surface discriminator
  • $D_t$ is the texture discriminator
  • $I_c$ is a cartoon image
  • $I_p$ is a photo image
  • $G(I_p)$ is a fake cartoon image generated from a photo image

Surface loss of White-box Cartoonization

The surface loss guides the generator to learn the information stored in the extracted surface representations, with the help of the surface discriminator.

$$\mathcal{L}_{\text{surface}}(G, D_{s}) = \log D_{s}\left(\mathcal{F}_{dgf}\left(\boldsymbol{I}_{c}, \boldsymbol{I}_{c}\right)\right) + \log\left(1 - D_{s}\left(\mathcal{F}_{dgf}\left(G(\boldsymbol{I}_{p}), G(\boldsymbol{I}_{p})\right)\right)\right)$$

Where:

  • $\mathcal{F}_{dgf}(\boldsymbol{I}, \boldsymbol{I})$ is the output of the differentiable guided filter described above. The filter takes an image as input and uses the image itself as the guide map, returning the extracted surface representation (textures and details removed).
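A minimal PyTorch sketch of the surface loss in the log form written above, assuming $D_s$ outputs raw logits (for example the patch scores of the discriminator further down) and that `surface_fn` stands in for $\mathcal{F}_{dgf}(\boldsymbol{I}, \boldsymbol{I})$, e.g. `lambda x: guided_filter(x, x, r=5)` with the guided filter defined earlier. The function names are hypothetical. The maximization over $D_s$ is written as a binary cross-entropy to minimize, and for the generator the common non-saturating form $-\log D_s(\cdot)$ is swapped in for $\log(1 - D_s(\cdot))$.

```python
import torch
import torch.nn.functional as F

def surface_d_loss(D_s, surface_fn, cartoon, fake):
    """Discriminator side: maximize log D_s(F_dgf(I_c)) + log(1 - D_s(F_dgf(G(I_p)))).

    Written as a BCE loss to minimize; D_s is assumed to return logits.
    `fake` is detached so this loss only updates D_s.
    """
    real_logits = D_s(surface_fn(cartoon))
    fake_logits = D_s(surface_fn(fake.detach()))
    real_loss = F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
    fake_loss = F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
    return real_loss + fake_loss

def surface_g_loss(D_s, surface_fn, fake):
    """Generator side, non-saturating form: minimize -log D_s(F_dgf(G(I_p)))."""
    fake_logits = D_s(surface_fn(fake))
    return F.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))
```

Because the guided filter is differentiable, gradients from `surface_g_loss` flow through `surface_fn` back into the generator output.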

Structure loss of White-box Cartoonization

The structure loss enforces a spatial constraint between the results and the extracted structure representation. It is computed on high-level features extracted by a pre-trained VGG16 network.

$$\mathcal{L}_{\text{structure}} = \left\| VGG_{n}\left(G(\boldsymbol{I}_{p})\right) - VGG_{n}\left(\mathcal{F}_{st}\left(G(\boldsymbol{I}_{p})\right)\right) \right\|$$

Where:

  • $\mathcal{F}_{st}(\boldsymbol{I})$ is the extracted structure representation, i.e. the output of Felzenszwalb's algorithm followed by Selective Search.
  • An $\ell_1$ norm (sparse regularization) is used here
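A minimal sketch of the structure loss. It assumes a hypothetical `structure_fn` that maps one H×W×3 uint8 NumPy image to its structure representation (for example the Felzenszwalb segmentation plus `label2rgb` coloring shown earlier, optionally followed by Selective Search), a `features` module standing in for $VGG_n$, and generator outputs in $[-1, 1]$. Because segmentation is not differentiable, the target $\mathcal{F}_{st}(G(\boldsymbol{I}_p))$ is computed on a detached copy of the output and treated as a constant; gradients flow only through $VGG_n(G(\boldsymbol{I}_p))$.

```python
import numpy as np
import torch
import torch.nn.functional as F

def to_uint8(batch):
    """(B, 3, H, W) tensor in [-1, 1] -> list of (H, W, 3) uint8 arrays (no gradient)."""
    arr = ((batch.detach().clamp(-1, 1) + 1) * 127.5).round().byte()
    return [a.permute(1, 2, 0).cpu().numpy() for a in arr]

def from_uint8(images, like):
    """List of (H, W, 3) uint8 arrays -> (B, 3, H, W) tensor in [-1, 1] on like's device/dtype."""
    arr = np.stack(images).astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(arr).permute(0, 3, 1, 2).to(like)

def structure_loss(features, structure_fn, fake):
    """Mean L1 distance between VGG_n(G(I_p)) and VGG_n(F_st(G(I_p))).

    features: module mapping images to feature maps (stand-in for VGG_n).
    structure_fn: hypothetical (H, W, 3) uint8 -> (H, W, 3) uint8 function.
    """
    with torch.no_grad():  # the structure target is a constant
        target_img = from_uint8([structure_fn(img) for img in to_uint8(fake)], like=fake)
        target_feat = features(target_img)
    return F.l1_loss(features(fake), target_feat)
```

Using `F.l1_loss` with its default mean reduction is a choice made for this sketch; the paper only specifies an $\ell_1$ norm.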

Texture loss of White-box Cartoonization

The texture loss guides the generator to learn the clear contours and fine textures stored in the texture representations, with the help of the texture discriminator.

$$\mathcal{L}_{\text{texture}}(G, D_{t}) = \log D_{t}\left(\mathcal{F}_{rcs}(\boldsymbol{I}_{c})\right) + \log\left(1 - D_{t}\left(\mathcal{F}_{rcs}\left(G(\boldsymbol{I}_{p})\right)\right)\right)$$

Where:

  • $\mathcal{F}_{rcs}(\boldsymbol{I})$ is the output of the random color shift algorithm described above.
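A minimal sketch of the texture branch. It makes two details explicit: the same random color-shift weights are applied to the cartoon batch and the generated batch in a given step (an assumption that keeps the comparison fair), and the texture representation is single-channel, so $D_t$ takes one input channel. This sketch also swaps the log terms above for the least-squares (LSGAN) form that the `#No sigmoid for LSGAN adv loss` comment in the discriminator code refers to. Function names are hypothetical.

```python
import torch

def color_shift_pair(cartoon, fake, generator=None):
    """Apply one random weighting to both batches (stand-in for F_rcs).

    Weights are drawn within +-0.1 of (0.299, 0.587, 0.114) and normalized,
    like the ColorShift class above. Inputs (B, 3, H, W) -> outputs (B, 1, H, W).
    """
    low = torch.tensor([0.199, 0.487, 0.014])
    w = low + 0.2 * torch.rand(3, generator=generator)
    w = (w / w.sum()).view(1, 3, 1, 1)
    gray_cartoon = (cartoon * w.to(cartoon)).sum(dim=1, keepdim=True)
    gray_fake = (fake * w.to(fake)).sum(dim=1, keepdim=True)
    return gray_cartoon, gray_fake

def texture_lsgan_losses(D_t, cartoon, fake, generator=None):
    """Least-squares adversarial losses on texture representations.

    D_t takes 1-channel input and returns raw scores. Returns (d_loss, g_loss).
    """
    gray_cartoon, gray_fake = color_shift_pair(cartoon, fake, generator)
    real_scores = D_t(gray_cartoon)
    fake_scores = D_t(gray_fake.detach())          # D update must not reach G
    d_loss = ((real_scores - 1) ** 2).mean() + (fake_scores ** 2).mean()
    g_loss = ((D_t(gray_fake) - 1) ** 2).mean()    # G wants fakes scored as real
    return d_loss, g_loss
```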

Content loss of White-box Cartoonization

“The content loss is used to ensure that the cartoonized results and input photos are semantically invariant, and the sparsity of L1 norm allows for local features to be cartoonized. Similar to the structure loss, it is calculated on pre-trained VGG16 feature space.”

$$\mathcal{L}_{\text{content}} = \left\| VGG_{n}\left(G(\boldsymbol{I}_{p})\right) - VGG_{n}(\boldsymbol{I}_{p}) \right\|$$

  • An $\ell_1$ norm (sparse regularization) is used here
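A minimal sketch of the content loss, assuming the generator outputs images in $[-1, 1]$ (as the `tanh` in the generator code does) and that `features` is a stand-in for $VGG_n$. With torchvision, one possible choice is `torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features[:23]`, which ends after the ReLU following conv4_3; that layer cut-off is an assumption, not taken from this post. The `VGGFeatures` wrapper is a hypothetical helper that could also serve the structure loss; during pre-training this content loss is the only term.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# ImageNet statistics expected by torchvision's pre-trained VGG models
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

class VGGFeatures(nn.Module):
    """Frozen feature extractor; assumes inputs in [-1, 1]."""
    def __init__(self, features):
        super().__init__()
        self.features = features.eval()
        for p in self.features.parameters():
            p.requires_grad_(False)  # only the generator is trained through it

    def forward(self, x):
        x = (x + 1) / 2                                      # [-1, 1] -> [0, 1]
        x = (x - IMAGENET_MEAN.to(x)) / IMAGENET_STD.to(x)  # ImageNet normalization
        return self.features(x)

def content_loss(vgg, photo, fake):
    """Mean L1 distance between VGG_n(G(I_p)) and VGG_n(I_p)."""
    with torch.no_grad():  # the photo branch needs no gradient
        target = vgg(photo)
    return F.l1_loss(vgg(fake), target)
```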

Total-variation loss of White-box Cartoonization

The total-variation loss imposes spatial smoothness on generated images and reduces high-frequency noise such as salt-and-pepper noise.

$$\mathcal{L}_{tv} = \frac{1}{H \times W \times C} \left\| \nabla_{x}\left(G(\boldsymbol{I}_{p})\right) + \nabla_{y}\left(G(\boldsymbol{I}_{p})\right) \right\|$$

Where:

  • $H$, $W$, $C$ are the height, width, and number of channels of the image
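A minimal sketch of the total-variation loss. Since $\nabla_x$ and $\nabla_y$ produce maps of different sizes, it sums the two gradient magnitudes separately; using absolute differences and also dividing by the batch size are assumptions, as the norm is not specified above.

```python
import torch

def tv_loss(img):
    """Total-variation loss for a (B, C, H, W) batch, normalized by B*C*H*W."""
    B, C, H, W = img.shape
    dx = img[:, :, :, 1:] - img[:, :, :, :-1]  # horizontal differences, (B, C, H, W-1)
    dy = img[:, :, 1:, :] - img[:, :, :-1, :]  # vertical differences, (B, C, H-1, W)
    return (dx.abs().sum() + dy.abs().sum()) / (B * C * H * W)
```

For example, `tv_loss(torch.tensor([[[[0., 1.], [0., 1.]]]]))` gives 0.5: the two horizontal differences are 1, the vertical ones are 0, and 2 / (1 × 2 × 2) = 0.5.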

Full Objective of White-box Cartoonization

$$\mathcal{L}_{\text{total}} = \lambda_{1} \times \mathcal{L}_{\text{surface}} + \lambda_{2} \times \mathcal{L}_{\text{texture}} + \lambda_{3} \times \mathcal{L}_{\text{structure}} + \lambda_{4} \times \mathcal{L}_{\text{content}} + \lambda_{5} \times \mathcal{L}_{tv}$$

where $\lambda_1, \lambda_2, \lambda_3, \lambda_4, \lambda_5$ are weighting factors. The paper uses:

  • $\lambda_1 = 1$
  • $\lambda_2 = 10$
  • $\lambda_3 = 2 \times 10^3$
  • $\lambda_4 = 2 \times 10^3$
  • $\lambda_5 = 10^4$
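The weighted sum itself is simple; this sketch just pins the paper's λ values to hypothetical loss names. It covers only the generator-side objective; the discriminators would be updated on their own adversarial terms.

```python
# Weights reported in the paper for the full objective (key names are hypothetical)
LAMBDAS = {
    "surface": 1.0,
    "texture": 10.0,
    "structure": 2e3,
    "content": 2e3,
    "tv": 1e4,
}

def total_generator_loss(losses, lambdas=LAMBDAS):
    """lambda_1 * L_surface + lambda_2 * L_texture + ... + lambda_5 * L_tv.

    `losses` maps the same keys to floats or 0-dim tensors; a missing key
    raises KeyError, so no term is silently dropped.
    """
    return sum(lambdas[name] * losses[name] for name in lambdas)
```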

Pretraining of Generator in White-box Cartoonization

The generator is pre-trained with only the content loss for $N = 50000$ iterations before the adversarial training stage begins.

Postprocessing of White-box Cartoonization

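$$\boldsymbol{I}_{\text{interp}} = \delta \times \mathcal{F}_{\text{dgf}}\left(\boldsymbol{I}_{\text{in}}, G(\boldsymbol{I}_{\text{in}})\right) + (1 - \delta) \times G(\boldsymbol{I}_{\text{in}})$$

The generator output $G(\boldsymbol{I}_{\text{in}})$ is passed through the differentiable guided filter with the input photo $\boldsymbol{I}_{\text{in}}$ as the guide map, and $\delta$ blends this filtered result with the raw output.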
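A minimal sketch of this post-processing step, using the same per-channel guided filter as the surface-representation code (redefined here so the block runs on its own). `photo` is the guide $\boldsymbol{I}_{\text{in}}$ and `cartoon` is $G(\boldsymbol{I}_{\text{in}})$; the default `delta`, `r`, and `eps` values are assumptions, not values from this post.

```python
import torch
import torch.nn.functional as F

def box_filter(x, r):
    # Per-channel mean over a (2r+1) x (2r+1) window with zero padding
    c, k = x.shape[1], 2 * r + 1
    kernel = torch.full((c, 1, k, k), 1.0 / k ** 2, dtype=x.dtype, device=x.device)
    return F.conv2d(x, kernel, padding=r, groups=c)

def guided_filter(guide, src, r, eps):
    # Filters `src` using `guide`; N corrects the window size at the borders
    _, _, H, W = guide.shape
    N = box_filter(torch.ones(1, 1, H, W, dtype=guide.dtype, device=guide.device), r)
    mean_I = box_filter(guide, r) / N
    mean_p = box_filter(src, r) / N
    cov_Ip = box_filter(guide * src, r) / N - mean_I * mean_p
    var_I = box_filter(guide * guide, r) / N - mean_I * mean_I
    A = cov_Ip / (var_I + eps)
    b = mean_p - A * mean_I
    return (box_filter(A, r) / N) * guide + box_filter(b, r) / N

def postprocess(photo, cartoon, delta=1.0, r=1, eps=5e-3):
    """I_interp = delta * F_dgf(I_in, G(I_in)) + (1 - delta) * G(I_in)."""
    filtered = guided_filter(photo, cartoon, r=r, eps=eps)
    return delta * filtered + (1 - delta) * cartoon
```

With `delta=0` the raw generator output is returned unchanged; with `delta=1` only the filtered result is kept.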

Architecture of White-box Cartoonization

Architecture of Generator and Discriminator in White-box Cartoonization

Refer to figure 1 of the Supplementary Paper.

  • Discriminators: PatchGAN (similar to the discriminator in CycleGAN and Pix2Pix)
  • Generator: fully convolutional U-Net-like network (similar to the generator in Pix2Pix)

Implementation of Generator

import torch
import torch.nn as nn
import torch.nn.functional as F

# PyTorch implementation by vinesmsuic
# Referenced from official tensorflow implementation: https://github.com/SystemErrorWang/White-box-Cartoonization/blob/master/train_code/network.py
# slim.convolution2d uses constant padding (zeros).


class ResidualBlock(nn.Module):
    def __init__(self, channels, kernel_size, stride, padding):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode="zeros"),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode="zeros"),
        )

    def forward(self, x):
        #Elementwise Sum (ES)
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, img_channels, num_features=32, num_residuals=4):
        super().__init__()
        self.initial_down = nn.Sequential(
            #k7n32s1
            nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="zeros"),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

        #Down-convolution
        self.down1 = nn.Sequential(
            #k3n32s2
            nn.Conv2d(num_features, num_features, kernel_size=3, stride=2, padding=1, padding_mode="zeros"),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            #k3n64s1
            nn.Conv2d(num_features, num_features*2, kernel_size=3, stride=1, padding=1, padding_mode="zeros"),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

        self.down2 = nn.Sequential(
            #k3n64s2
            nn.Conv2d(num_features*2, num_features*2, kernel_size=3, stride=2, padding=1, padding_mode="zeros"),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            #k3n128s1
            nn.Conv2d(num_features*2, num_features*4, kernel_size=3, stride=1, padding=1, padding_mode="zeros"),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

        #Bottleneck: 4 residual blocks => 4 times [K3n128s1]
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(num_features*4, kernel_size=3, stride=1, padding=1) for _ in range(num_residuals)]
        )

        #Up-convolution
        self.up1 = nn.Sequential(
            #k3n128s1 (should be k3n64s1?)
            nn.Conv2d(num_features*4, num_features*2, kernel_size=3, stride=1, padding=1, padding_mode="zeros"),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

        self.up2 = nn.Sequential(
            #k3n64s1
            nn.Conv2d(num_features*2, num_features*2, kernel_size=3, stride=1, padding=1, padding_mode="zeros"),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            #k3n64s1 (should be k3n32s1?)
            nn.Conv2d(num_features*2, num_features, kernel_size=3, stride=1, padding=1, padding_mode="zeros"),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

        self.last = nn.Sequential(
            #k3n32s1
            nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1, padding_mode="zeros"),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            #k7n3s1
            nn.Conv2d(num_features, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="zeros")
        )

    def forward(self, x):
        x1 = self.initial_down(x)
        x2 = self.down1(x1)
        x = self.down2(x2)
        x = self.res_blocks(x)
        x = self.up1(x)
        #Resize Bilinear
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        x = self.up2(x + x2)
        #Resize Bilinear
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        x = self.last(x + x1)
        #TanH
        return torch.tanh(x)

Implementation of Discriminator

import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

# PyTorch implementation by vinesmsuic
# Referenced from official tensorflow implementation: https://github.com/SystemErrorWang/White-box-Cartoonization/blob/master/train_code/network.py
# slim.convolution2d uses constant padding (zeros).
# Paper used spectral_norm

class Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.sn_conv = spectral_norm(nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            padding_mode="zeros" # Author's code used slim.convolution2d, which is using SAME padding (zero padding in pytorch)
        ))

        self.LReLU = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        x = self.sn_conv(x)
        x = self.LReLU(x)

        return x


class Discriminator(nn.Module):
    # in_channels=3 for D_s (surface representation); the color-shifted
    # texture representation is single-channel, so D_t uses in_channels=1
    def __init__(self, in_channels=3, out_channels=1, features=(32, 64, 128)):
        super().__init__()
        self.model = nn.Sequential(
            #k3n32s2
            Block(in_channels, features[0], kernel_size=3, stride=2, padding=1),
            #k3n32s1
            Block(features[0], features[0], kernel_size=3, stride=1, padding=1),

            #k3n64s2
            Block(features[0], features[1], kernel_size=3, stride=2, padding=1),
            #k3n64s1
            Block(features[1], features[1], kernel_size=3, stride=1, padding=1),

            #k3n128s2
            Block(features[1], features[2], kernel_size=3, stride=2, padding=1),
            #k3n128s1
            Block(features[2], features[2], kernel_size=3, stride=1, padding=1),

            #k1n1s1: final 1x1 projection to per-patch scores, with no activation
            spectral_norm(nn.Conv2d(features[2], out_channels, kernel_size=1, stride=1, padding=0))
        )

    def forward(self, x):
        x = self.model(x)

        #No sigmoid for LSGAN adv loss
        #return torch.sigmoid(x)
        return x

Training Details of White-box Cartoonization

  • Adam optimizer for both the generator and the discriminators
  • Learning rate = $2 \times 10^{-4}$
  • Batch size = 16
  • The generator is pre-trained with only the content loss for $N = 50000$ iterations, and then the GAN-based framework is optimized jointly. Training
    is stopped after 100000 iterations or on convergence.
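A sketch of the schedule above in PyTorch. The learning rate, batch size, and iteration counts come from the list; leaving Adam's betas at the PyTorch defaults and sharing one optimizer between $D_s$ and $D_t$ are assumptions. It is also not stated whether the 100000 iterations include pre-training, so this sketch counts the two stages separately. `make_optimizers` and `training_schedule` are hypothetical names.

```python
import torch

PRETRAIN_ITERS = 50_000  # generator pre-training with the content loss only
GAN_ITERS = 100_000      # joint training; assumed to be counted after pre-training
BATCH_SIZE = 16
LR = 2e-4

def make_optimizers(generator, d_surface, d_texture):
    """Adam for G, and one Adam shared by both discriminators (an assumption)."""
    opt_g = torch.optim.Adam(generator.parameters(), lr=LR)
    d_params = list(d_surface.parameters()) + list(d_texture.parameters())
    opt_d = torch.optim.Adam(d_params, lr=LR)
    return opt_g, opt_d

def training_schedule():
    """Yield (stage, step) pairs: pre-training first, then joint GAN training."""
    for step in range(PRETRAIN_ITERS):
        yield "pretrain", step
    for step in range(GAN_ITERS):
        yield "gan", step
```

A loop would then run `for stage, step in training_schedule():`, doing a content-loss-only generator update while `stage == "pretrain"` and alternating discriminator and generator updates afterwards, breaking early once training has converged.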

Dataset

Human face and landscape data are collected for generalization on diverse scenes. For real-world photos, we collect 10000 images from the FFHQ dataset for the human face and 5000 images from the dataset in for landscape. For cartoon images, we collect 10000 images from animations for the human face and 10000 images for landscape. Producers of collected animations include Kyoto animation, P.A.Works, Shinkai Makoto, Hosoda Mamoru, and Miyazaki Hayao. For the validation set, we collect 3011 animation images and 1978 real-world photos. Images shown in the main paper are collected from the DIV2K dataset, and images in user study are collected from the Internet and Microsoft COCO dataset. During training, all images are resized to 256×256 resolution, and face images are fed only once in every five iterations.

  • Training images are resized to $256 \times 256$
  • Face images are fed only once in every five iterations
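One way to realize the face/landscape mixing is to choose the data domain from the iteration counter. This is a hypothetical sketch: the dictionaries of file paths and the helper names are illustrative, and photos and cartoons are drawn independently because training is unpaired.

```python
import random

def pick_domain(iteration, face_every=5):
    """Face data once every `face_every` iterations, landscape data otherwise."""
    return "face" if iteration % face_every == 0 else "landscape"

def sample_batch(iteration, photos, cartoons, batch_size=16, rng=random):
    """Draw an unpaired (photo, cartoon) batch of file paths from one domain.

    photos / cartoons: hypothetical dicts mapping "face" and "landscape"
    to lists of file paths.
    """
    domain = pick_domain(iteration)
    photo_batch = rng.sample(photos[domain], batch_size)
    cartoon_batch = rng.sample(cartoons[domain], batch_size)
    return domain, photo_batch, cartoon_batch
```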