Simplify Training for Graph Convolutional Networks
Managing, evaluating, and tuning hyperparameters and complex graph models directly in code can be both time-consuming and overwhelming. This article introduces and implements a unified JSON-based declarative interface that streamlines every stage—building, training, tuning, and testing—of a graph neural network.
Introduction
Graph Convolutional Neural Networks
Overview
Scope
Types
Graph Samplers
Training Configurability
Training environment
Graph Convolutional Model
Graph Node Sampling Method
Hands-on with Python
Environment
Defining hyperparameters
Setting up Training
Evaluating the Model
References
What you will learn: How to define and utilize a JSON-based representation of model and training parameters to design, train, and validate a graph neural network.
✅ The complete article, featuring design principles, detailed implementation, in-depth analysis, and exercises, is available at Plug & Play for Training Graph Convolutional Networks
Introduction
Manually cataloging, updating, and tuning all possible hyper-parameters and model settings directly in code is both error-prone and time-consuming. Adopting a unified, JSON-based declarative format for all configurations significantly streamlines the process and reduces complexity.
⚠️ I strongly recommend reviewing my articles introducing PyTorch Geometric, Taming PyTorch Geometric for Graph Neural Networks, and graph loaders, Demystifying Graph Sampling & Walk Methods.
Graph Convolutional Neural Networks
Graph Neural Networks have been discussed in depth in previous articles in this newsletter [ref 1, 2].
Overview
A Graph Neural Network (GNN) is an optimizable transformation on all attributes of the graph (nodes, edges, global context) that preserves graph symmetries (permutation invariance). A GNN takes a graph as input and generates or predicts a graph as output.
Data on manifolds can often be represented as a graph, where the manifold's local structure is approximated by connections between nearby points. GNNs and their variants, such as Graph Convolutional Networks (GCNs), extend neural networks to process data on non-Euclidean domains by leveraging the graph structure, which may approximate the underlying manifold [ref 3, 4, 5 & 6].
📌 The remainder of this section reviews topics already discussed in previous articles [ref 1, 2] and can be skipped.
Scope
There are three types of tasks to be performed on a GNN:
- Graph-level tasks: predict a property of the entire graph, such as the toxicity of a molecule.
- Node-level tasks: predict a property or class for each node, such as the community a user belongs to.
- Edge-level tasks: predict the existence or a property of an edge, such as link prediction in a social network.
Types
Applications such as social network analysis, molecular structure prediction, and 3D point cloud processing can all be modeled using GNNs. Here is an overview of the major types of GNNs:
- Graph Convolutional Networks (GCNs): apply convolution-like aggregation over each node's neighborhood.
- Graph Attention Networks (GATs): weight the contribution of each neighbor with learned attention coefficients.
- GraphSAGE: samples and aggregates neighbor features, enabling inductive learning on large graphs.
- Graph Isomorphism Networks (GINs): use sum aggregation to maximize discriminative power.
Graph Samplers
The generation of universal embeddings that apply across different applications remains a significant challenge. PyTorch Geometric simplifies this process by encapsulating these complexities into specialized data loaders, while seamlessly integrating with PyTorch's existing deep learning modules.
PyTorch Geometric supports a large variety of graph data loaders, including:
- NeighborLoader: multi-hop neighbor sampling in the style of GraphSAGE
- LinkNeighborLoader: neighbor sampling seeded from edges, for link prediction
- ClusterLoader: partition-based sampling as in Cluster-GCN
- GraphSAINTRandomWalkSampler: subgraph sampling driven by random walks
- RandomNodeLoader: random partitions of nodes
- ShaDowKHopSampler: subgraphs of bounded depth around seed nodes
Training Configurability
The goal is to design a modular framework that streamlines the configuration of graph neural networks, sampling strategies, and hyper-parameters, enabling an automated training, validation, and tuning pipeline for selected GNN models, as illustrated below.
It is critical to keep the configuration attributes for hyper-parameters, model definition, and sampling method as consistent as possible. To that end, we select the ubiquitous JSON notation.
Adopting a declarative format to define models, hyper-parameters, and the training environment offers several advantages:
- Consistency: a single notation describes models, sampling methods, and hyper-parameters.
- Reproducibility: a configuration file fully documents an experiment and can be versioned.
- Productivity: models can be tuned or swapped without touching the training code.
- Readability: the structure of the network is visible at a glance.
Below are examples demonstrating how JSON notation can be used to configure the entire training and validation pipeline.
Training environment
The following JSON descriptor defines the parameters required for training, validating, and tuning your graph neural network model.
{
    'dataset_name': 'Sonar',
    'learning_rate': 0.0005,
    'batch_size': 64,
    'loss_function': nn.NLLLoss(weight=class_weights.to('cuda:0')),
    'momentum': 0.90,
    'encoding_len': 8,
    'train_eval_ratio': 0.9,
    'weight_initialization': 'xavier',
    'optim_label': 'adam',
    'drop_out': 0.25,
    'is_class_imbalance': True,
    'class_weights': [0.25, 0.3, 0.2, 0.2, 0.05],
    'metrics_list': ['Accuracy', 'Precision', 'Recall', 'F1'],
    'plot_parameters': [
        {'title': 'Accuracy', 'x_label': 'epochs', 'y_label': 'accuracy'},
        {'title': 'Precision', 'x_label': 'epochs', 'y_label': 'precision'},
        {'title': 'Recall', 'x_label': 'epochs', 'y_label': 'recall'},
        {'title': 'F1', 'x_label': 'epochs', 'y_label': 'F1'},
    ]
}
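Note that the descriptor above is a Python dictionary rather than strict JSON: the loss_function entry holds a live PyTorch object. In a pure JSON file such fields would be stored as strings and resolved in code. The sketch below illustrates one way to do that; the resolve_loss and resolve_optimizer helpers, their mappings, and the training_env.json file name are assumptions for illustration, not the article's actual implementation.

import json
import torch
import torch.nn as nn

# Hypothetical helpers mapping JSON string values to PyTorch objects
def resolve_loss(label: str, class_weights: list | None = None) -> nn.Module:
    weight = torch.tensor(class_weights) if class_weights else None
    return {'nll': nn.NLLLoss(weight=weight),
            'cross_entropy': nn.CrossEntropyLoss(weight=weight)}[label]

def resolve_optimizer(label: str, model: nn.Module, lr: float, momentum: float):
    if label == 'adam':
        return torch.optim.Adam(model.parameters(), lr=lr)
    if label == 'sgd':
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    raise ValueError(f'Unknown optimizer: {label}')

# Load a pure-JSON training descriptor (hypothetical file) and
# resolve its string-valued fields into PyTorch objects
with open('training_env.json') as f:
    cfg = json.load(f)
loss_fn = resolve_loss('nll', cfg.get('class_weights'))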
📌 We selected these parameters somewhat arbitrarily, so your list may differ slightly.
Graph Convolutional Model
Neural blocks serve as the fundamental components of deep neural networks [ref 7]. A Graph Convolutional Neural Network (GCN) consists of a series of graph convolutional blocks followed by a sequence of fully connected blocks, each fully specified with an activation function, layer configuration, and optional components such as batch normalization, pooling, and dropout regularization.
{
    'model_id': 'MyModel',
    'gconv_blocks': [
        {
            'block_id': 'conv_block_1',
            'conv_layer': GraphConv(in_channels=num_node_features,
                                    out_channels=hidden_channels),
            'num_channels': hidden_channels,
            'activation': nn.ReLU(),
            'batch_norm': None,
            'pooling': None,
            'dropout': 0.25
        },
        .....
    ],
    'mlp_blocks': [
        {
            'block_id': 'mlp_block',
            'in_features': hidden_channels,
            'out_features': _num_classes,
            'activation': nn.LogSoftmax(dim=-1),
            'dropout': 0.0
        }
    ]
}
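To illustrate how one such block descriptor could be turned into a runnable module, here is a minimal sketch that assembles the layers declared in a 'gconv_blocks' entry. The GConvBlock class is a hypothetical illustration, not the article's implementation, and it omits the pooling component, which is None in the descriptor above.

import torch.nn as nn
from torch_geometric.nn import GraphConv

# Hypothetical builder: wraps one 'gconv_blocks' entry as a module
class GConvBlock(nn.Module):
    def __init__(self, block: dict):
        super().__init__()
        self.conv = block['conv_layer']          # e.g., GraphConv(...)
        modules = [block['activation']]
        if block.get('batch_norm') is not None:
            modules.append(block['batch_norm'])
        if block.get('dropout', 0.0) > 0.0:
            modules.append(nn.Dropout(block['dropout']))
        self.post = nn.Sequential(*modules)

    def forward(self, x, edge_index):
        # The graph convolution needs the edge index; the remaining
        # activation/dropout modules are node-wise operations
        return self.post(self.conv(x, edge_index))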
📌 A modular description of the graph convolutional network is given in an earlier section, Graph Neural Network Components.
Graph Node Sampling Method
Finally, we specify the sampling method as neighbor node sampling [ref 8].
{
    'id': 'NeighborLoader',
    'num_neighbors': [12, 6, 3],
    'batch_size': 64,
    'replace': True,
    'num_workers': 1
}
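These attributes map almost one-to-one onto PyTorch Geometric's NeighborLoader. A minimal sketch of that mapping, assuming the Flickr dataset used later in this article as the sampled graph:

from torch_geometric.datasets import Flickr
from torch_geometric.loader import NeighborLoader

data = Flickr(root='./data/flickr')[0]

# Instantiate the neighbor sampler from the attributes above:
# 12 neighbors at hop 1, 6 at hop 2, 3 at hop 3
train_loader = NeighborLoader(
    data,
    num_neighbors=[12, 6, 3],
    batch_size=64,
    replace=True,
    num_workers=1,
    input_nodes=data.train_mask   # restrict seed nodes to the training set
)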
The PyTorch Geometric library offers a wide array of sampling methods, as described in the previous section.
Hands-on with Python
Environment
⚠️ Warning: Some sampling methods in PyTorch Geometric rely on additional modules: torch-sparse, torch-scatter, torch-spline-conv, and torch-cluster. These are dependencies of torch-geometric but they may not be compatible with the latest versions of PyTorch, particularly across different operating systems.
For macOS, we recommend the following version setup for best compatibility:
These modules currently support only CPU and CUDA execution — MPS (Metal) is not directly supported.
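Since these extension modules run only on CPU or CUDA, a safe device selection at the start of your script might look like this short sketch:

import torch

# torch-sparse / torch-scatter kernels are unavailable on MPS (Metal),
# so fall back to CPU on Apple Silicon
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')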
Defining hyperparameters
We begin by implementing the HyperParams wrapper class, which encapsulates the hyper-parameters used for training and tuning the Graph Convolutional Network. The constructor defines commonly used configuration attributes, including the choice of optimizer, optim_label, and, if applicable, the class weight distribution, class_weights, to address class imbalance (code snippet 1).
Our plug-and-play approach relies on a detailed configuration for training, implemented through the alternative class method constructor build, which generates a HyperParams instance from a JSON configuration file (code snippet 2).
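Since the numbered snippets live in the full article, here is a minimal sketch of what such a wrapper might look like. The attribute names follow the training descriptor shown earlier, but the class body itself is an assumption, not the article's code.

import json

# Hypothetical wrapper around the training descriptor shown earlier
class HyperParams:
    def __init__(self,
                 learning_rate: float,
                 batch_size: int,
                 momentum: float,
                 optim_label: str,
                 drop_out: float,
                 train_eval_ratio: float,
                 class_weights: list | None = None):
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.momentum = momentum
        self.optim_label = optim_label
        self.drop_out = drop_out
        self.train_eval_ratio = train_eval_ratio
        self.class_weights = class_weights

    @classmethod
    def build(cls, json_file: str) -> 'HyperParams':
        # Alternative constructor: load the attributes from a JSON file
        with open(json_file) as f:
            cfg = json.load(f)
        return cls(cfg['learning_rate'], cfg['batch_size'],
                   cfg['momentum'], cfg['optim_label'], cfg['drop_out'],
                   cfg['train_eval_ratio'], cfg.get('class_weights'))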
Setting up Training
The GNNTraining class encapsulates the training and validation process, with its default constructor accepting hyper-parameters and metric-related attributes (Code snippet 3). Notably, the alternative build constructor instantiates the training configuration directly from a JSON-formatted string.
The training and validation process across epochs, whose number is controlled by one of the hyper-parameters, is handled by the train method; its arguments are illustrated in the sketch further below.
The main loop calls the private method __train_epoch, which processes batches from the training data loader using the specified hyper-parameters and records the training loss for each epoch.
📌 The train_mask has to be applied to both the predicted data and the label data, data.y. Similarly, the val_mask is applied to the predicted and label validation data.
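A minimal sketch of such an epoch loop follows. It assumes the model takes node features and edge indices as arguments, and all names are illustrative rather than the article's exact code; note the train_mask applied to both the model output and the labels.

import torch

# Hypothetical epoch loop over mini-batches from a PyG loader
def train_epoch(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0.0
    for data in loader:                      # batches from, e.g., NeighborLoader
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        mask = data.train_mask               # mask BOTH predictions and labels
        loss = loss_fn(out[mask], data.y[mask])
        loss.backward()
        optimizer.step()
        total_loss += float(loss)
    return total_loss / len(loader)          # average training loss per epoch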
The implementation of the validation method for each epoch is omitted here. For more details, please refer to the source available at Github.com/patnicolas/dl/training/eval_gconv.
Evaluating the Model
The objective is to evaluate a Graph Convolutional Network that can properly classify images from the Flickr dataset [ref 10]. The Flickr dataset is a graph where nodes represent images and edges signify similarities between them. It includes 89,250 images and 899,756 relationships. Node features consist of image descriptions and shared properties.
Setup
We begin by defining the EvalGConv class to manage model evaluation (Code snippet 6). As expected, its constructor takes two arguments; a plausible pair is sketched after the method descriptions below.
The primary method, start_training, sequentially calls the private methods __get_loaders, __get_training_env, and __get_eval_model, described below.
The __get_loaders method leverages the GraphDataLoader class introduced in Demystifying Graph Sampling & Walk Methods: Graph Data Loader.
The __get_training_env method defines the remaining three parameters needed to initiate training and set up the environment: the loss function, the number of output classes (encoding_len), and the class weight distribution to address class imbalance.
Our Graph Convolutional Neural Network consists of two graph convolutional blocks followed by a fully connected multilayer perceptron block, as illustrated below.
The __get_eval_model method loads the dataset using the auxiliary class PyGDatasets [ref 11]. Each of the two graph convolutional blocks includes a hidden layer with 384 units, a ReLU activation function, and dropout regularization set to 0.25—without pooling or batch normalization. The fully connected block consists of a standard linear layer followed by a LogSoftmax activation module.
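Putting these pieces together, here is a plausible skeleton of the evaluation class. The constructor's two arguments are assumed to be the sampling attributes and training attributes dictionaries shown earlier, and every method body below is an illustrative sketch rather than the repository's code.

# Hypothetical skeleton of the evaluation driver; assumes the
# GraphDataLoader, GNNTraining, and HyperParams classes from the
# referenced articles
class EvalGConv:
    def __init__(self, sampling_attributes: dict, training_attributes: dict):
        self.sampling_attributes = sampling_attributes
        self.training_attributes = training_attributes

    def start_training(self) -> None:
        train_loader, val_loader = self.__get_loaders()
        training_env = self.__get_training_env()
        model = self.__get_eval_model()
        training_env.train(model, train_loader, val_loader)

    def __get_loaders(self):
        ...  # wrap the Flickr dataset in a GraphDataLoader (neighbor sampling)

    def __get_training_env(self):
        ...  # add loss function, encoding_len, class weights; build GNNTraining

    def __get_eval_model(self):
        ...  # two GraphConv blocks (384 units, ReLU, dropout 0.25)
             # followed by a linear layer with LogSoftmax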
📌 The output layer uses a LogSoftmax activation because the chosen loss function is the negative log-likelihood (NLLLoss in PyTorch). In contrast, PyTorch's standard cross-entropy loss expects raw logits and applies log-softmax internally, so it would require no output activation.
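The equivalence is easy to verify: applying NLLLoss to log-softmax outputs gives the same value as applying cross-entropy to raw logits, as in this quick check.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 5)            # 4 samples, 5 classes
labels = torch.tensor([0, 2, 1, 4])

# NLLLoss on log-probabilities == CrossEntropyLoss on raw logits
nll = F.nll_loss(F.log_softmax(logits, dim=-1), labels)
ce = F.cross_entropy(logits, labels)
assert torch.allclose(nll, ce)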
Results
Finally, training and validation are carried out over 32 epochs using Neighbor Node Sampling, with 12 nodes sampled at the first hop, 6 at the second, and 3 at the third.
sampling_attributes = {
    'id': 'NeighborLoader',
    'num_neighbors': [12, 6, 3],
    'batch_size': 64,
    'replace': True,
    'num_workers': 4
}
Accuracy, Precision, Recall, F1 score, along with training and validation losses, are visualized using the widely adopted Matplotlib library.
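A minimal sketch of how the plot_parameters entries from the training descriptor could drive those charts is shown below; the metric_history dictionary and its dummy values are assumptions for illustration only.

import matplotlib.pyplot as plt

# Assumed structure: one list of per-epoch values per metric (dummy values)
metric_history = {
    'Accuracy': [0.61, 0.72, 0.78, 0.81],
    'Precision': [0.55, 0.66, 0.71, 0.75],
}
plot_parameters = [
    {'title': 'Accuracy', 'x_label': 'epochs', 'y_label': 'accuracy'},
    {'title': 'Precision', 'x_label': 'epochs', 'y_label': 'precision'},
]

# One sub-plot per metric, labeled from the plot_parameters entries
fig, axes = plt.subplots(1, len(plot_parameters), figsize=(10, 4))
for ax, params in zip(axes, plot_parameters):
    ax.plot(metric_history[params['title']])
    ax.set_title(params['title'])
    ax.set_xlabel(params['x_label'])
    ax.set_ylabel(params['y_label'])
plt.show()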
✅ The full article, Plug & Play for Training Graph Convolutional Networks, includes a quiz, a summary, and additional evaluation.
References
Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design, and end-to-end deployment and support, with extensive knowledge in machine learning. He has been director of data engineering at Aideo Technologies since 2017, and he is the author of "Scala for Machine Learning" (Packt Publishing, ISBN 978-1-78712-238-3) and of the Hands-on Geometric Deep Learning newsletter.