13 Graph Neural Networks (GNNs)
Graph neural networks (GNNs) are a relatively young member of the deep neural network family, but they have received increasing attention in recent years because of their ability to process non-Euclidean data such as graphs.
Currently there is no R package for GNNs available. However, we can use the ‘reticulate’ package to access the Python packages ‘torch’ (i.e. PyTorch, not the R torch package) and ‘torch_geometric’.
The following example was mostly adapted from the ‘Graph Classification with Graph Neural Networks’ example from the torch_geometric documentation.
The dataset (MUTAG) is also provided by the ‘torch_geometric’ package and consists of 188 small molecules represented as graphs; the task is to predict whether a molecule has a mutagenic effect on a bacterium (binary classification).
library(reticulate)
# Load python packages torch and torch_geometric via the reticulate R package
torch = import("torch")
torch_geometric = import("torch_geometric")
# helper functions from the torch_geometric modules
GCNConv = torch_geometric$nn$GCNConv
global_mean_pool = torch_geometric$nn$global_mean_pool
# Download the MUTAG TUDataset
dataset = torch_geometric$datasets$TUDataset(root='data/TUDataset',
                                             name='MUTAG')
dataloader = torch_geometric$loader$DataLoader(dataset,
                                               batch_size=64L,
                                               shuffle=TRUE)
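Before defining the model, it can help to briefly inspect the dataset object. The following is a minimal sketch (it relies on the standard TUDataset attributes and on reticulate's py_len() and py_get_item() helpers; the name first_graph is chosen only for illustration):
# Sketch: number of graphs, node features and classes in MUTAG
cat("Graphs:", py_len(dataset),
    " Node features:", dataset$num_node_features,
    " Classes:", dataset$num_classes, "\n")
# look at the first molecule (Python indexing starts at 0)
first_graph = py_get_item(dataset, 0L)
print(first_graph)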
# Create the model with a python class
# There are two classes in the response variable
GCN = PyClass(
  "GCN",
  inherit = torch$nn$Module,
  defs = list(
    `__init__` = function(self, hidden_channels) {
      super()$`__init__`()
      torch$manual_seed(42L)
      # one graph convolution layer followed by a linear output layer
      self$conv = GCNConv(dataset$num_node_features, hidden_channels)
      self$linear = torch$nn$Linear(hidden_channels, dataset$num_classes)
      NULL
    },
    forward = function(self, x, edge_index, batch) {
      x = self$conv(x, edge_index)       # 1. compute node embeddings
      x = x$relu()
      x = global_mean_pool(x, batch)     # 2. readout: average node embeddings per graph
      x = torch$nn$functional$dropout(x, p = 0.5, training = self$training)
      x = self$linear(x)                 # 3. graph-level classifier
      return(x)
    }
  ))
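To check that the class definition works, we can push one batch through an untrained model and verify that the output has one row per graph and one column per class. This is only a quick sketch; the names tmp_model, b and out are chosen for illustration:
# Sketch: one forward pass to check the output shape
tmp_model = GCN(hidden_channels = 64L)
b = reticulate::iter_next(reticulate::as_iterator(dataloader))
out = tmp_model(b$x, b$edge_index, b$batch)
out$shape  # expected: (graphs in the batch, number of classes)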
Training loop:
# create model object
model = GCN(hidden_channels = 64L)
# get optimizer and loss function
optimizer = torch$optim$Adamax(model$parameters(), lr = 0.01)
loss_func = torch$nn$CrossEntropyLoss()  # expects raw logits and integer class labels
# set model into training mode (because of the dropout layer)
model$train()
# train model
for(e in 1:50) {
  iterator = reticulate::as_iterator(dataloader)
  coro::loop(for (b in iterator) {
    pred = model(b$x, b$edge_index, b$batch)  # forward pass on one batch of graphs
    loss = loss_func(pred, b$y)
    loss$backward()        # compute gradients
    optimizer$step()       # update weights
    optimizer$zero_grad()  # reset gradients for the next batch
  })
  if(e %% 10 == 0) cat(paste0("Epoch: ", e, " Loss: ", round(loss$item()[1], 4), "\n"))
}
## Epoch: 10 Loss: 0.6151
## Epoch: 20 Loss: 0.6163
## Epoch: 30 Loss: 0.5745
## Epoch: 40 Loss: 0.5362
## Epoch: 50 Loss: 0.5829
Make predictions:
preds = list()
# new DataLoader without shuffling so that predictions keep the dataset order
test = torch_geometric$loader$DataLoader(dataset, batch_size=64L, shuffle=FALSE)
iterator = reticulate::as_iterator(test)
# set model into evaluation mode (disables the dropout layer)
model$eval()
counter = 1
coro::loop(for (b in iterator) {
  preds[[counter]] = model(b$x, b$edge_index, b$batch)
  counter <<- counter + 1
})
# sigmoid() scores each logit separately; a softmax over the two logits
# would yield class probabilities that sum to 1 per molecule
head(torch$concat(preds)$sigmoid()$data$cpu()$numpy(), n = 3)
## [,1] [,2]
## [1,] 0.3076028 0.6427078
## [2,] 0.4121239 0.5515330
## [3,] 0.4119514 0.5516798
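Because the example does not use a train/test split, we can at least compute the in-sample accuracy from the collected logits. The following is a minimal sketch (taking the argmax of the raw logits is equivalent to picking the class with the highest softmax probability; the names logits, pred_classes and true_classes are chosen for illustration):
# Sketch: in-sample accuracy
logits = torch$concat(preds)
pred_classes = logits$argmax(dim = 1L)$cpu()$numpy()
true_classes = unlist(lapply(reticulate::iterate(reticulate::as_iterator(test)),
                             function(b) as.vector(b$y$cpu()$numpy())))
mean(pred_classes == true_classes)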