12  Graph Neural Networks (GNNs)

Graph neural networks (GNNs) are a young member of the deep neural network family, but they have received more and more attention in recent years because of their ability to process non-Euclidean data such as graphs.

Currently, there is no R package for GNNs available. However, we can use the ‘reticulate’ package to access the Python packages ‘torch’ (the Python version) and ‘torch_geometric’.

The following example is mostly adapted from the ‘Graph Classification with Graph Neural Networks’ example from the torch_geometric documentation.

The dataset (MUTAG) is also provided by the ‘torch_geometric’ package. It consists of molecules (nitroaromatic compounds) represented as graphs, and the task is to predict whether a molecule has a mutagenic effect on the bacterium Salmonella typhimurium or not (binary classification).

library(reticulate)

# Python modules, accessed from R through reticulate
torch <- import("torch")
torch_geometric <- import("torch_geometric")

# Layer and pooling helpers from the torch_geometric.nn module
GCNConv <- torch_geometric$nn$GCNConv
global_mean_pool <- torch_geometric$nn$global_mean_pool

# Download (into data/TUDataset) and load the MUTAG TUDataset
dataset <- torch_geometric$datasets$TUDataset(
  root = "data/TUDataset",
  name = "MUTAG"
)

# Shuffled mini-batches of 64 graphs for training
dataloader <- torch_geometric$loader$DataLoader(
  dataset,
  batch_size = 64L,
  shuffle = TRUE
)
# The model: a Python subclass of torch.nn.Module defined from R with
# reticulate::PyClass. One GCN layer, mean pooling over each graph's nodes,
# dropout, then a linear layer producing one score per class
# (the response variable has two classes).
GCN <- PyClass(
  "GCN",
  inherit = torch$nn$Module,
  defs = list(
    # hidden_channels: output width of the GCN layer; pass an integer
    # (e.g. 64L), since it is forwarded to Python layer constructors.
    `__init__` = function(self, hidden_channels) {
      super()$`__init__`()
      # Seed torch's RNG so the weight initialisation below is reproducible.
      torch$manual_seed(42L)
      n_inputs <- dataset$num_node_features
      n_outputs <- dataset$num_classes
      # conv is created before linear, so both draw from the seeded RNG
      # in a fixed order.
      self$conv <- GCNConv(n_inputs, hidden_channels)
      self$linear <- torch$nn$Linear(hidden_channels, n_outputs)
      NULL
    },
    # x: node features, edge_index: graph connectivity, batch: graph
    # membership of each node. Returns the output of the linear layer,
    # i.e. unnormalised class scores, one row per graph.
    forward = function(self, x, edge_index, batch) {
      node_repr <- self$conv(x, edge_index)$relu()
      graph_repr <- global_mean_pool(node_repr, batch)
      # Dropout is only active while the module is in training mode.
      regularised <- torch$nn$functional$dropout(
        graph_repr,
        p = 0.5,
        training = self$training
      )
      self$linear(regularised)
    }
  )
)

Training loop:

# Create the model, the optimizer and the loss function.
model <- GCN(hidden_channels = 64L)
optimizer <- torch$optim$Adamax(model$parameters(), lr = 0.01)
loss_func <- torch$nn$CrossEntropyLoss()
# Set the model into training mode (this activates the dropout layer).
model$train()

n_epochs <- 50L
print_every <- 10L
for (epoch in seq_len(n_epochs)) {
  # Accumulators for the mean training loss over all graphs in this epoch.
  epoch_loss <- 0
  n_graphs <- 0L
  iterator <- reticulate::as_iterator(dataloader)
  # coro::loop evaluates the body in this environment, so the accumulators
  # above are updated in place.
  coro::loop(for (b in iterator) {
    pred <- model(b$x, b$edge_index, b$batch)
    loss <- loss_func(pred, b$y)
    loss$backward()
    optimizer$step()
    optimizer$zero_grad()
    # CrossEntropyLoss returns the batch mean, so weight it by the number of
    # graphs in the batch (the last batch is smaller) before averaging.
    epoch_loss <- epoch_loss + loss$item() * b$num_graphs
    n_graphs <- n_graphs + b$num_graphs
  })
  # Report the epoch mean rather than the noisy loss of the last mini-batch.
  if (epoch %% print_every == 0L) {
    cat("Epoch: ", epoch, " Loss: ", round(epoch_loss / n_graphs, 4), "\n",
        sep = "")
  }
}
## Epoch: 10 Loss: 0.6151
## Epoch: 20 Loss: 0.6163
## Epoch: 30 Loss: 0.5745
## Epoch: 40 Loss: 0.5362
## Epoch: 50 Loss: 0.5829

Make predictions:

# Predict class probabilities for every graph in `dataset`.
# shuffle = FALSE keeps the rows of the result in dataset order.
test <- torch_geometric$loader$DataLoader(
  dataset,
  batch_size = 64L,
  shuffle = FALSE
)
iterator <- reticulate::as_iterator(test)
# Evaluation mode switches off dropout (forward() passes self$training to it).
model$eval()
preds <- list()
coro::loop(for (b in iterator) {
  # The loop body runs in this environment, so appending to `preds`
  # needs neither `<<-` nor a separate counter.
  preds[[length(preds) + 1L]] <- model(b$x, b$edge_index, b$batch)
})
# The model returns one unnormalised score per class and was trained with
# CrossEntropyLoss, so a softmax over the class dimension (dim 1 in Python's
# 0-based indexing) gives per-graph probabilities that sum to 1. An
# element-wise sigmoid would not.
head(torch$cat(preds)$softmax(dim = 1L)$detach()$cpu()$numpy(), n = 3)
##          [,1]      [,2]
## [1,] 0.3076028 0.6427078
## [2,] 0.4121239 0.5515330
## [3,] 0.4119514 0.5516798