15  Generative Adversarial Networks (GANs)

The idea of a generative adversarial network (GAN) is that two neural networks compete against each other in a “game”. One network creates data and tries to “trick” the other network into classifying the generated data as real. The generator (similar to the decoder of an autoencoder) creates new images from noise. The discriminator receives a mix of true images (from the data set) and artificially generated images from the generator. The loss of the generator rises when its fakes are identified as fakes by the discriminator (a simple binary cross entropy loss with labels 0/1). The loss of the discriminator rises when fakes (class 0) are classified as real or when real images (class 1) are classified as fakes, again measured with binary cross entropy.
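
Written compactly, the generator \(G\) and the discriminator \(D\) play a minimax game: the discriminator tries to maximize, the generator to minimize the value function \[ V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}} \left[ \mathrm{log}~D(x) \right] + \mathbb{E}_{z \sim p_{z}} \left[ \mathrm{log} \left( 1 - D(G(z)) \right) \right] \] where \(x\) are real images, \(z\) is the noise input of the generator, and \(D(\cdot)\) is the discriminator’s predicted probability of being real.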

Binary cross entropy: In the context of information theory, the entropy or Shannon entropy (named after Claude Shannon) \(\mathbf{H}\) (uppercase “eta”) is the expected information content, i.e. the mean/average information content of an “event” over all possible outcomes. Encountering an event with low probability holds more information than encountering an event with high probability.
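
For a discrete distribution \(A\) with outcome probabilities \(p(a_{i})\), this reads \[ \textbf{H}(A) = -\sum_{i} p(a_{i}) \cdot \mathrm{log} \left( p(a_{i}) \right) \]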

Binary cross entropy is a measure of the similarity of two (discrete) probability distributions, \(A~(\mathrm{true~distribution})\) and \(B~(\mathrm{predicted~distribution})\), in terms of their inherent information.
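
For discrete distributions, the cross entropy weights the information content of the predicted distribution \(B\) with the probabilities of the true distribution \(A\): \[ \textbf{H}_{A}(B) = -\sum_{i} p_{A}(x_{i}) \cdot \mathrm{log} \left( p_{B}(x_{i}) \right) \]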

It is not (!) symmetric, in general: \(\textbf{H}_{A}(B) \neq \textbf{H}_{B}(A)\). The minimum value depends on the distribution of \(A\) and is the entropy of \(A\), attained when the predicted distribution equals the true one: \[\underset{B}{\mathrm{min}}~\textbf{H}_{A}(B) = \textbf{H}_{A}(B = A) = \textbf{H}_{A}(A) = \textbf{H}(A)\]

The setup: \(N\) observations with true labels \(y_{i} \in \{0, 1\}\) and predicted probabilities \(p(y_{i}) = \hat{y}_{i}\) for class 1.

The binary cross entropy or log loss of such a system of outcomes/predictions is then defined as follows: \[ \textbf{H}_{A}(B) = -\frac{1}{N} \sum_{i = 1}^{N} \left[ y_{i} \cdot \mathrm{log} \left( p(y_{i}) \right) + (1 - y_{i}) \cdot \mathrm{log} \left( 1 - p(y_{i}) \right) \right] =\\ = -\frac{1}{N} \sum_{i = 1}^{N} \left[ y_{i} \cdot \mathrm{log} (\hat{y}_{i}) + (1 - y_{i}) \cdot \mathrm{log} \left( 1 - \hat{y}_{i} \right) \right] \] Predicting a high probability for class 1 when the true label is 1 yields a low loss, and so does predicting a low probability for class 1 when the true label is 0. Mind the properties of probabilities and the logarithm: probabilities lie in \([0, 1]\), so the logarithms are non-positive and the leading minus sign makes the loss non-negative.
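
As a small numerical sketch (the labels and probabilities are made up for illustration), the loss can be computed by hand and compared with torch’s built-in nnf_binary_cross_entropy():

library(torch)

# Hypothetical toy example: true labels y and predicted probabilities p for class 1.
y = torch_tensor(c(1, 0, 1))
p = torch_tensor(c(0.9, 0.2, 0.6))

# Binary cross entropy "by hand" (mean over the N = 3 observations).
bce_manual = -torch_mean(y * torch_log(p) + (1 - y) * torch_log(1 - p))

# torch's built-in version (reduction = "mean" is the default).
bce_torch = nnf_binary_cross_entropy(p, y)

bce_manual$item()
bce_torch$item()  # both should give the same value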

A possible application of generative adversarial networks is to create pictures that look like real photographs, e.g. https://thispersondoesnotexist.com/. Visit that site (several times)! However, the applications of generative adversarial networks today go far beyond the mere creation of realistic-looking data. For example, generative adversarial networks can also be used to “augment” data, i.e. to create additional training data and thereby improve the fitted model.

Helper function and hyperparameters

library(torch)

# Custom weight initialization: convolutional weights ~ N(0, 0.02),
# batch norm weights ~ N(1, 0.02) and batch norm biases = 0.
weights_init = function(m) {
  if(inherits(m, "nn_conv_nd")) {
    nn_init_normal_(m$weight$data(), 0.0, 0.02)
  }
  if(inherits(m, "nn_batch_norm_")) {
    nn_init_normal_(m$weight$data(), 1.0, 0.02)
    nn_init_constant_(m$bias$data(), 0)
  }
}

batch_size = 128   # batch size (the dataloader below uses 50)
image_size = 80L   # width/height of the images in pixels
nc = 3             # number of color channels
nz = 100           # length of the latent (noise) vector
ngf = 80L          # base number of feature maps in the generator
ndf = 80L          # base number of feature maps in the discriminator
lr = 0.01          # learning rate for both optimizers
device = "cpu"     # change to "cuda" if a GPU is available

Our Models:

Generator = nn_module(
  initialize = function() {
    self$main = nn_sequential(
      # latent vector (nz x 1 x 1) -> (ngf*4) x 4 x 4
      nn_conv_transpose2d(nz, ngf * 4, kernel_size = 4, stride = 1, padding = 0, bias = FALSE),
      nn_batch_norm2d(ngf * 4),
      nn_relu(),
      # -> (ngf*2) x 9 x 9
      nn_conv_transpose2d(ngf * 4, ngf * 2, kernel_size = 5, stride = 2, padding = 1, bias = FALSE),
      nn_batch_norm2d(ngf * 2),
      nn_relu(),
      # -> ngf x 27 x 27
      nn_conv_transpose2d(ngf * 2, ngf, kernel_size = 5, stride = 3, padding = 1, bias = FALSE),
      nn_batch_norm2d(ngf),
      nn_relu(),
      # -> nc x 80 x 80, sigmoid maps the pixel values to [0, 1]
      nn_conv_transpose2d(ngf, nc, kernel_size = 6, stride = 3, padding = 2, bias = FALSE),
      nn_sigmoid()
    )
  },
  forward = function(input) self$main(input)
)

# Helper module: flattens the feature maps into one vector per observation.
Flatten =
  nn_module(
    forward = function(input) return(input$view(list(input$size(1L), -1)))
  )

Discriminator = nn_module(
  initialize = function() {
    self$main = nn_sequential(
      # nc x 80 x 80 -> 40 x 40 x 40
      nn_conv2d(nc, 40, kernel_size = 4, stride = 2, padding = 1, bias = FALSE),
      nn_leaky_relu(0.2),
      # -> 80 x 20 x 20
      nn_conv2d(40, 80, kernel_size = 4, stride = 2, padding = 1, bias = FALSE),
      nn_batch_norm2d(80),
      nn_leaky_relu(0.2),
      # -> 80 x 10 x 10
      nn_conv2d(80, 80, kernel_size = 4, stride = 2, padding = 1, bias = FALSE),
      nn_batch_norm2d(80),
      nn_leaky_relu(0.2),
      # -> 80 x 5 x 5
      nn_conv2d(80, 80, kernel_size = 4, stride = 2, padding = 1, bias = FALSE),
      nn_batch_norm2d(80),
      nn_leaky_relu(0.2),
      # flatten to 80 * 5 * 5 = 2000 features, then predict P(real) with a sigmoid
      Flatten(),
      nn_linear(2000, 1),
      nn_sigmoid()
    )
  },
  forward = function(input) self$main(input)
)
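
For the transposed convolutions of the generator, the output size of each layer is \((\mathrm{input} - 1) \cdot \mathrm{stride} - 2 \cdot \mathrm{padding} + \mathrm{kernel}\), so the chosen kernels, strides and paddings upsample the latent vector step by step to 80 x 80 images. As a quick sanity check (a small sketch, not part of the training code), we can push a dummy batch through both networks and inspect the shapes:

# Dummy batch of 2 latent vectors.
gen_test = Generator()
disc_test = Discriminator()

fake_test = gen_test(torch_randn(2, nz, 1, 1))
fake_test$shape              # 2 3 80 80

disc_test(fake_test)$shape   # 2 1 (one probability per image)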

Prepare data:

data = EcoData::dataset_flower()
train = data$train/255                 # scale pixel values to [0, 1]
labels = data$labels                   # labels are not needed for the GAN
train = aperm(train, c(1, 4, 2, 3))    # reorder to channels-first (N, C, H, W)

dataset = torch::tensor_dataset(torch_tensor(train))
dataLoader = torch::dataloader(dataset, batch_size = 50L, shuffle = TRUE, pin_memory = TRUE)
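
Before training, it can be helpful to look at one of the real flower images that the generator should learn to imitate (a small sketch; the image is permuted back to channels-last for plotting):

# Plot the first real training image (values are already scaled to [0, 1]).
plot(as.raster(aperm(train[1,,,], c(2, 3, 1))))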

Create our models and initialize optimizers:

disc = Discriminator()
disc$apply(weights_init)

gen = Generator()
gen$apply(weights_init)

disc$main$to(device = device)
gen$main$to(device = device)
loss = nnf_binary_cross_entropy
# Fixed noise vector, used later to visualize the generator's progress.
fixed_noise = torch_randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G (generator) and D (discriminator)
optimizerD = optim_adam(disc$parameters, lr=lr)
optimizerG = optim_adam(gen$parameters, lr=lr)

for(e in 1:1000) {
  errG_batch = errD_batch = NULL
  counter = 1
  coro::loop(for (b in dataLoader) {
    
    # Reset gradients of both networks.
    disc$zero_grad()
    gen$zero_grad()
    
    # (1) Discriminator loss on real images (target label = 1).
    real = b[[1]]$to(device = device)
    b_size = real$size(1L)
    
    label = torch_full(c(b_size), real_label, dtype=torch_float(), device=device)
    
    output = disc(real)$view(list(-1))
    errD_real = nnf_binary_cross_entropy(output, label, reduction = "sum")
    
    # (2) Discriminator loss on fake images (target label = 0).
    # detach() stops the gradients from flowing back into the generator here.
    noise = torch_randn(b_size, nz, 1, 1, device=device)
    fake = gen(noise)
    label = torch_full(c(b_size), fake_label, dtype=torch_float(), device=device)
    output = disc(fake$detach())$view(list(-1))
    errD_fake = nnf_binary_cross_entropy(output, label, reduction = "sum")
    
    errD = errD_real + errD_fake
    errD$backward()
    optimizerD$step()
    
    # (3) Generator step: the generator wants the discriminator to classify
    # its fakes as real, so the target is 1 (= 1.0 - fake_label).
    gen$zero_grad()
    fake = gen(noise)
    output = disc(fake)$view(list(-1L))
    errG = nnf_binary_cross_entropy(output, 1.0-label, reduction = "sum")
    errG$backward()
    optimizerG$step()
    
    errG_batch[counter] <- errG$item()
    errD_batch[counter] <- errD$item()
    counter = counter + 1
  })
  cat("Epoch: ", e, " loss D: ", mean(errD_batch), " loss G: ", mean(errG_batch), "\n")
}
# Generate images from the fixed noise vector and plot the first 20 of them.
predictions = gen(fixed_noise)
images = as_array(predictions$detach()$cpu())   # detach from the graph before converting
images = aperm(images, c(1, 3, 4, 2))           # back to channels-last for plotting

oldpar = par(no.readonly = TRUE)
par(mfrow = c(4, 5), mar = rep(0, 4), oma = rep(0, 4))

for(i in 1:20) plot(as.raster(images[i,,,]))

par(oldpar)