15 Generative Adversarial Networks (GANs)
The idea of a generative adversarial network (GAN) is that two neural networks compete against each other in a “game”. One network creates data and tries to “trick” the other network into classifying the generated data as real. The generator (similar to the decoder in an autoencoder) creates new images from noise. The discriminator receives a mix of true images (from the data set) and artificial images produced by the generator. The loss of the generator increases when the discriminator identifies fakes as fakes (a simple binary cross entropy loss on the labels 0/1), and the loss of the discriminator increases when it classifies fakes (class 0) as real images or real images (class 1) as fakes, again measured with binary cross entropy.
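Formally, this game can be written as the minimax objective of the original GAN paper (Goodfellow et al., 2014), where the generator \(G\) maps noise \(z\) to images and the discriminator \(D\) returns the probability that an image is real: \[ \underset{G}{\mathrm{min}}~\underset{D}{\mathrm{max}}~V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}} \left[ \mathrm{log}~D(x) \right] + \mathbb{E}_{z \sim p_{z}} \left[ \mathrm{log} \left( 1 - D(G(z)) \right) \right] \] The discriminator tries to maximize this value (high probabilities for real images, low probabilities for fakes), while the generator tries to minimize it; in practice both parts are implemented as binary cross entropy losses.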
Binary cross entropy: In information theory, entropy or Shannon entropy (named after Claude Shannon), written \(\mathbf{H}\) (the uppercase Greek letter “eta”), is the expected information content of an “event”, i.e. the mean/average information content over all possible outcomes. Encountering an event with low probability holds more information than encountering an event with high probability.
Binary cross entropy is a measure of the similarity of two (discrete) probability distributions \(A~(\mathrm{true~distribution})\) and \(B~(\mathrm{predicted~distribution})\) in terms of their information content.
In general, it is not (!) symmetric: \(\textbf{H}_{A}(B) \neq \textbf{H}_{B}(A)\). The minimum value over \(B\) depends on the distribution of \(A\) and equals the entropy of \(A\): \[\underset{B}{\mathrm{min}}~\textbf{H}_{A}(B) = \textbf{H}_{A}(B = A) = \textbf{H}_{A}(A) = \textbf{H}(A)\]
The setup:
- Outcomes \(y_{i} \in \{0, 1\}\) (labels).
- Predictions \(\hat{y}_{i} \in[0, 1]\) (probabilities).
The binary cross entropy or log loss of a system of outcomes/predictions is then defined as follows: \[ \textbf{H}_{A}(B) = -\frac{1}{N} \sum_{i = 1}^{N} \left[ y_{i} \cdot \mathrm{log} \left( p(y_{i}) \right) + (1 - y_{i}) \cdot \mathrm{log} \left( 1 - p(y_{i}) \right) \right] = -\frac{1}{N} \sum_{i = 1}^{N} \left[ y_{i} \cdot \mathrm{log} (\hat{y}_{i}) + (1 - y_{i}) \cdot \mathrm{log} \left( 1 - \hat{y}_{i} \right) \right] \] Predicting a high probability for observations whose true label is 1 yields a low loss, and so does predicting a low probability for observations whose true label is 0. Mind the properties of probabilities and the logarithm.
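As a small numerical check with made-up labels and predictions, the formula above can be computed by hand and compared against torch's nnf_binary_cross_entropy() (whose default reduction is the mean):

library(torch)

y = torch_tensor(c(1, 1, 0, 0))             # true labels
yhat = torch_tensor(c(0.9, 0.6, 0.2, 0.4))  # predicted probabilities

# Binary cross entropy computed "by hand" ...
-torch_mean(y * torch_log(yhat) + (1 - y) * torch_log(1 - yhat))

# ... gives the same value as the built-in loss function:
nnf_binary_cross_entropy(yhat, y)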
A possible application of generative adversarial networks is to create pictures that look like real photographs, e.g. https://thispersondoesnotexist.com/ (visit that site several times!). However, the applications of generative adversarial networks today go far beyond generating realistic photographs. For example, GANs can also be used to “augment” data, i.e. to create additional training data and thereby improve the fitted model.
Helper function and hyperparameters:

library(torch)

# DCGAN-style initialization: convolution weights ~ Normal(mean 0, sd 0.02),
# batch norm weights ~ Normal(mean 1, sd 0.02), batch norm biases = 0.
weights_init = function(m) {
  if(inherits(m, "nn_conv_nd")) {
    nn_init_normal_(m$weight$data(), 0.0, 0.02)
  }
  if(inherits(m, "nn_batch_norm_")) {
    nn_init_normal_(m$weight$data(), 1.0, 0.02)
    nn_init_constant_(m$bias$data(), 0)
  }
}

batch_size = 128
image_size = 80L   # the flower images are 80 x 80 pixels
nc = 3             # number of color channels
nz = 100           # size of the latent (noise) vector
ngf = 80L          # base number of generator feature maps
ndf = 80L
lr = 0.01
device = "cpu"
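The weights_init function follows the common DCGAN convention of drawing convolution weights from a normal distribution with mean 0 and standard deviation 0.02, and batch norm weights with mean 1 and standard deviation 0.02. It will later be passed to $apply(), which calls it on every submodule of a network. A tiny illustration on a stand-alone convolutional layer (created here purely for demonstration):

layer = nn_conv2d(3, 8, kernel_size = 3)
weights_init(layer)   # the layer inherits from "nn_conv_nd", so its weights are re-drawn
round(as_array(layer$weight$data()$std()), 3)   # close to 0.02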
Our Models:
- The generator creates images from noise.
- The discriminator classifies images as real or fake.
Generator = nn_module(
  initialize = function() {
    self$main = nn_sequential(
      nn_conv_transpose2d(nz, ngf * 4, kernel_size = 4, 1, 0, bias = FALSE),
      nn_batch_norm2d(ngf * 4),
      nn_relu(),
      nn_conv_transpose2d(ngf * 4, ngf * 2, kernel_size = 5, 2, 1, bias = FALSE),
      nn_batch_norm2d(ngf * 2),
      nn_relu(),
      nn_conv_transpose2d(ngf * 2, ngf, kernel_size = 5, 3, 1, bias = FALSE),
      nn_batch_norm2d(ngf),
      nn_relu(),
      nn_conv_transpose2d(ngf, nc, kernel_size = 6, 3, 2, bias = FALSE),
      nn_sigmoid()
    )
  },
  forward = function(input) self$main(input)
)
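To see how these layers produce an image of the right size, recall that a transposed convolution maps a spatial input size \(n\) to \((n - 1) \cdot \mathrm{stride} - 2 \cdot \mathrm{padding} + \mathrm{kernel}\). A small helper function (purely illustrative) traces the generator's spatial sizes from the 1x1 latent “pixel” up to the final image:

conv_transpose_size = function(input, kernel, stride, padding) {
  (input - 1) * stride - 2 * padding + kernel
}
conv_transpose_size(1, 4, 1, 0)    # 4
conv_transpose_size(4, 5, 2, 1)    # 9
conv_transpose_size(9, 5, 3, 1)    # 27
conv_transpose_size(27, 6, 3, 2)   # 80, matches image_size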
Flatten = nn_module(
  forward = function(input) return(input$view(list(input$size(1L), -1)))
)
Discriminator = nn_module(
  initialize = function() {
    self$main = nn_sequential(
      nn_conv2d(nc, 40, 4, 2, 1, bias = FALSE),
      nn_leaky_relu(0.2),
      nn_conv2d(40, 80, 4, 2, 1, bias = FALSE),
      nn_batch_norm2d(80),
      nn_leaky_relu(0.2),
      nn_conv2d(80, 80, 4, 2, 1, bias = FALSE),
      nn_batch_norm2d(80),
      nn_leaky_relu(0.2),
      nn_conv2d(80, 80, 4, 2, 1, bias = FALSE),
      nn_batch_norm2d(80),
      nn_leaky_relu(0.2),
      Flatten(),
      nn_linear(2000, 1),
      nn_sigmoid()
    )
  },
  forward = function(input) self$main(input)
)
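As a quick sanity check, random noise can be pushed through untrained instances of both networks (throwaway objects g and d, used only here) to confirm that the shapes fit: the generator turns a latent vector of shape nz x 1 x 1 into a 3 x 80 x 80 image, and the discriminator reduces that image to 80 feature maps of size 5 x 5 (flattened to the 2000 inputs of the linear layer) and finally to a single probability:

g = Generator()
d = Discriminator()
z = torch_randn(2, nz, 1, 1)   # a mini batch of two latent vectors
img = g(z)
img$shape      # 2 x 3 x 80 x 80
d(img)$shape   # 2 x 1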
Prepare data:
data = EcoData::dataset_flower()
train = data$train/255                 # scale pixel values to [0, 1]
labels = data$labels
train = aperm(train, c(1, 4, 2, 3))    # reorder to channels-first (n, channels, height, width)

dataset = torch::tensor_dataset(torch_tensor(train))
dataLoader = torch::dataloader(dataset, batch_size = 50L, shuffle = TRUE, pin_memory = TRUE)
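A quick way to check the data pipeline is to draw a single batch from the dataloader and look at its dimensions:

iter = dataloader_make_iter(dataLoader)
batch = dataloader_next(iter)
batch[[1]]$shape   # e.g. 50 x 3 x 80 x 80 (batch, channels, height, width)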
Create our models and initialize optimizers:
disc = Discriminator()
disc$apply(weights_init)

gen = Generator()
gen$apply(weights_init)

disc$main$to(device = device)
gen$main$to(device = device)

loss = nnf_binary_cross_entropy
fixed_noise = torch_randn(64, nz, 1, 1, device = device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim_adam(disc$parameters, lr = lr)
optimizerG = optim_adam(gen$parameters, lr = lr)
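As a rough, optional check of the relative capacity of the two players, the trainable parameters of each network can be counted:

sum(sapply(gen$parameters, function(p) p$numel()))    # generator
sum(sapply(disc$parameters, function(p) p$numel()))   # discriminator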
for(e in 1:1000) {
  errG_batch = errD_batch = NULL
  counter = 1
  coro::loop(for (b in dataLoader) {

    disc$zero_grad()
    gen$zero_grad()

    real = b[[1]]$to(device = device)
    b_size = real$size(1L)

    # Discriminator step: real images should be classified as real (label 1) ...
    label = torch_full(c(b_size), real_label, dtype = torch_float(), device = device)
    output = disc(real)$view(list(-1))
    errD_real = nnf_binary_cross_entropy(output, label, reduction = "sum")

    # ... and generated images as fake (label 0). The fakes are detached so that
    # this step only updates the discriminator.
    noise = torch_randn(b_size, nz, 1, 1, device = device)
    fake = gen(noise)
    label = torch_full(c(b_size), fake_label, dtype = torch_float(), device = device)
    output = disc(fake$detach())$view(list(-1))
    errD_fake = nnf_binary_cross_entropy(output, label, reduction = "sum")

    errD = errD_real + errD_fake
    errD$backward()
    optimizerD$step()
    #optimizerD$zero_grad()

    # Generator step: the generator is rewarded when the discriminator
    # classifies its fakes as real (label 1, i.e. 1 - fake_label).
    gen$zero_grad()
    fake = gen(noise)
    output = disc(fake)$view(list(-1L))
    errG = nnf_binary_cross_entropy(output, 1.0 - label, reduction = "sum")
    errG$backward()
    #optimizerD$step()
    optimizerG$step()

    errG_batch[counter] <- errG$item()
    errD_batch[counter] <- errD$item()
    counter = counter + 1
  })
  cat("Epoch: ", e, " loss D: ", mean(errD_batch), " loss G: ", mean(errG_batch), "\n")
}

predictions = gen(noise)                 # images from the last noise batch (fixed_noise could be used instead)
images = as_array(predictions$cpu())
images = aperm(images, c(1, 3, 4, 2))    # back to channels-last for plotting
oldpar = par(no.readonly = TRUE)
par(mfrow = c(4, 5), mar = rep(0, 4), oma = rep(0, 4))
for(i in 1:20) {
  images[i,,,] %>%
    keras3::image_to_array() %>%
    as.raster() %>%
    plot()
}
par(oldpar)   # restore the previous graphics settings