Multi-level Disentanglement Graph Neural Network (MD-GNN)
This is a PyTorch implementation of MD-GNN. The code includes the following modules:
Datasets (Cora, Citeseer, Pubmed, Synthetic, and ZINC)
Training paradigm for node classification, graph classification, and graph regression tasks
- main() — Train a new model for node classification task on the Cora, Citeseer, and Pubmed datasets
- evaluate() — Test the learned model for node classification task on the Cora, Citeseer, and Pubmed datasets
- main_synthetic() — Train a new model for graph classification task on the Synthetic dataset
- evaluate_synthetic() — Test the learned model for graph classification task on the Synthetic dataset
- main_zinc() — Train a new model for graph regression task on the ZINC dataset
- evaluate_zinc() — Test the learned model for graph regression task on the ZINC dataset
- load_data() — Load the data of the selected dataset
- MDGNN() — model and loss
- evaluate_att() — Evaluate attribute-level disentanglement with the visualization of relation-related attributes
- evaluate_corr() — Evaluate node-level disentanglement with the correlation analysis of latent features
- evaluate_graph() — Evaluate graph-level disentanglement with the visualization of disentangled relation graphs
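The three training entry points above are selected by dataset. A minimal sketch of that routing logic (the entry-point names main, main_synthetic, and main_zinc come from the list above; the dispatch helper itself is hypothetical and only illustrates which function handles which dataset):

```python
# Hypothetical dispatcher: maps a dataset name to the training entry
# point listed above. The entry-point names match this README; the
# helper itself is an illustrative assumption, not repository code.

NODE_CLS = {"cora", "citeseer", "pubmed"}  # node classification datasets

def select_entry_point(dataset):
    """Return the name of the training function for a given dataset."""
    name = dataset.lower()
    if name in NODE_CLS:
        return "main"            # node classification (Cora/Citeseer/Pubmed)
    if name == "synthetic":
        return "main_synthetic"  # graph classification
    if name == "zinc":
        return "main_zinc"       # graph regression
    raise ValueError(f"unknown dataset: {dataset}")
```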
Running the code
Install the required dependency packages and unzip the files in the data folder.
We use DGL to implement all the GNN models on the three citation datasets (Cora, Citeseer, and Pubmed). In order to evaluate the model with different splitting strategies (fewer and harder label rates), you need to replace the following file with the
- To get the results on a specific dataset, run the following with appropriate hyperparameters:
python train.py --dataset data_name
where data_name is one of the five datasets (cora, citeseer, pubmed, synthetic, or zinc). The model and the training log will be saved to the corresponding directory in ./log for evaluation.
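The command-line interface and log layout described above might look like the following sketch (the --dataset argument, the five dataset names, and the ./log directory come from the text; the parser details and per-dataset subdirectory layout are assumptions):

```python
import argparse
import os

# Dataset names accepted by train.py, per the README.
DATASETS = ("cora", "citeseer", "pubmed", "synthetic", "zinc")

def parse_args(argv=None):
    # Hypothetical reconstruction of train.py's argument parsing.
    parser = argparse.ArgumentParser(description="Train MD-GNN")
    parser.add_argument("--dataset", choices=DATASETS, required=True,
                        help="one of: " + ", ".join(DATASETS))
    return parser.parse_args(argv)

def log_dir(dataset):
    # Models and training logs go under ./log/<dataset>/ (assumed layout).
    return os.path.join("log", dataset)
```

For example, `parse_args(["--dataset", "cora"])` yields a namespace whose `dataset` field is `"cora"`, and passing a name outside the five datasets raises a parser error.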
- To evaluate the three-level disentanglement performance, run the evaluation functions evaluate_att(), evaluate_corr(), and evaluate_graph().
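As an illustration of the node-level check, correlation analysis of latent features typically verifies that different disentangled channels are weakly correlated. A minimal NumPy sketch (the concatenated channel layout and the helper name are assumptions, not the repository's evaluate_corr()):

```python
import numpy as np

def channel_correlation(z, num_channels):
    """Mean absolute Pearson correlation between latent channels.

    z: (num_nodes, num_channels * channel_dim) feature matrix, assumed
       to be a concatenation of per-channel embeddings (assumed layout).
    Lower values suggest better node-level disentanglement.
    """
    n, d = z.shape
    dim = d // num_channels
    # Collapse each channel's dimensions into one score per node.
    scores = z.reshape(n, num_channels, dim).mean(axis=2)  # (n, K)
    corr = np.corrcoef(scores, rowvar=False)               # (K, K)
    off_diag = corr[~np.eye(num_channels, dtype=bool)]
    return float(np.abs(off_diag).mean())
```

Perfectly redundant channels score 1.0, while independent channels score close to 0.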
MD-GNN is released under the MIT license.