PyGDA is a Python library for graph domain adaptation (GDA), built on PyTorch and PyG, that makes it easy to train GDA models in a scikit-learn style. PyGDA includes 20+ graph domain adaptation models. See the examples below!
Graph Domain Adaptation Using PyGDA with 5 Lines of Code
from pygda.models import A2GNN
# choose a graph domain adaptation model
model = A2GNN(in_dim=num_features, hid_dim=args.nhid, num_classes=num_classes, device=args.device)
# train the model
model.fit(source_data, target_data)
# evaluate the performance
logits, labels = model.predict(target_data)
Key features of PyGDA:
- Consistent APIs and comprehensive documentation.
- Coverage of 20+ graph domain adaptation models.
- A scalable architecture that efficiently handles large graph datasets through mini-batching and sampling techniques.
- Data processing seamlessly integrated with PyG, ensuring full compatibility with PyG data structures.
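The idea behind sampling-based mini-batching can be illustrated with uniform multi-hop neighbor sampling: for each batch of seed nodes, keep at most a fixed number of neighbors per hop, so memory is bounded by the fan-outs rather than by node degrees. This is a plain-Python sketch under assumed inputs (a hypothetical adjacency dict and a hypothetical `sample_neighbors` helper). PyG provides this technique as `torch_geometric.loader.NeighborLoader`, whose `num_neighbors` argument plays the role of `fanouts` here; we do not claim this is how PyGDA samples internally.

```python
import random


def sample_neighbors(adj, seeds, fanouts, rng=None):
    """Uniform multi-hop neighbor sampling for one mini-batch of seed nodes.

    adj: hypothetical format, dict mapping node id -> list of neighbor ids.
    fanouts: max neighbors kept per node at each hop, e.g. [10, 5].
    Returns (nodes, edges): every node touched, and sampled (src, dst) edges
    that carry messages toward the seeds.
    """
    rng = rng if rng is not None else random.Random(0)
    nodes = set(seeds)
    frontier = list(seeds)
    edges = []
    for fanout in fanouts:
        next_frontier = []
        for dst in frontier:
            neighbors = adj.get(dst, [])
            if len(neighbors) > fanout:
                # Keep only `fanout` neighbors, chosen uniformly without replacement.
                neighbors = rng.sample(neighbors, fanout)
            for src in neighbors:
                edges.append((src, dst))
                if src not in nodes:
                    nodes.add(src)
                    next_frontier.append(src)
        frontier = next_frontier
    return nodes, edges


# A star graph: node 0 has 100 neighbors, but one hop with fanout 3 touches only 4 nodes.
star = {0: list(range(1, 101))}
for leaf in range(1, 101):
    star[leaf] = [0]
nodes, edges = sample_neighbors(star, seeds=[0], fanouts=[3])
print(len(nodes), len(edges))
```

With fan-outs such as `[10, 5]`, each seed touches at most 1 + 10 + 50 nodes regardless of how dense the graph is.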
[12/2024]. We now support the source-free setting of graph domain adaptation.
- 3 recent models, GTrans, SOGA and GraphCTA, are supported.
[08/2024]. We now support graph-level domain adaptation tasks.
- 7 models are supported: A2GNN, AdaGCN, CWGCN, DANE, GRADE, SAGDA and UDAGCN.
- Various TUDatasets are supported, including FRANKENSTEIN, Mutagenicity and PROTEINS.
- To perform a graph-level domain adaptation task, add a single parameter, mode='graph', to the model:
model = A2GNN(in_dim=num_features, hid_dim=args.nhid, num_classes=num_classes, mode='graph', device=args.device)
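In graph-level adaptation each graph (for example, one molecule in Mutagenicity) is a single sample, so a GNN typically pools its node embeddings into one vector per graph before classifying it. We do not assume which readout PyGDA uses with `mode='graph'`; the plain-Python sketch below shows mean pooling, the same operation as PyG's `global_mean_pool`, where `batch[i]` gives the graph id of node `i`. The helper name `mean_readout` is hypothetical.

```python
def mean_readout(x, batch, num_graphs=None):
    """Average node embeddings per graph.

    x: list of node embedding rows; batch[i] is the graph id of node i,
    mirroring PyG's `batch` vector for a mini-batch of graphs.
    Returns one pooled row per graph (graphs with no nodes get zeros).
    """
    if num_graphs is None:
        num_graphs = max(batch) + 1
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for row, graph_id in zip(x, batch):
        counts[graph_id] += 1
        for j, value in enumerate(row):
            sums[graph_id][j] += value
    return [
        [s / count for s in total] if count else total
        for total, count in zip(sums, counts)
    ]


# Two graphs in one mini-batch: nodes 0 and 1 belong to graph 0, node 2 to graph 1.
pooled = mean_readout([[1.0, 2.0], [3.0, 4.0], [10.0, 10.0]], batch=[0, 0, 1])
print(pooled)  # [[2.0, 3.0], [10.0, 10.0]]
```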
Note: PyGDA depends on PyTorch, PyG, PyTorch Sparse and PyTorch Scatter, which it does not install automatically. Please install them separately before running PyGDA.
Required Dependencies:
- torch>=1.13.1
- torch_geometric>=2.4.0
- torch_sparse>=0.6.15
- torch_scatter>=2.1.0
- python3
- scipy
- sklearn
- numpy
- cvxpy
- tqdm
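Before installing PyGDA, you can check which of the dependencies above are already present. The following standard-library sketch looks up each package by its import name (an assumption: `torch`, `torch_geometric`, `torch_sparse`, `torch_scatter`, and `sklearn` for scikit-learn) and compares `__version__` against the minimums listed above. The helper names are hypothetical and the version parsing is deliberately simple (it ignores local tags like `+cu117` and pre-release suffixes).

```python
import importlib
import importlib.util
import re

# Minimum versions from the dependency list, keyed by import name.
MIN_VERSIONS = {
    "torch": "1.13.1",
    "torch_geometric": "2.4.0",
    "torch_sparse": "0.6.15",
    "torch_scatter": "2.1.0",
}
# Dependencies listed without a minimum version.
UNVERSIONED = ["scipy", "sklearn", "numpy", "cvxpy", "tqdm"]


def parse_version(text):
    """'2.1.0+pt113cu117' -> (2, 1, 0); stops at the first non-numeric part."""
    parts = []
    for piece in text.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.group() != piece:  # e.g. '0rc1': keep 0, ignore the suffix
            break
    return tuple(parts)


def at_least(found, minimum):
    """Compare dotted versions, padding the shorter one with zeros."""
    a, b = parse_version(found), parse_version(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_dependencies(min_versions=MIN_VERSIONS, unversioned=UNVERSIONED):
    """Return human-readable problems; an empty list means all checks passed."""
    problems = []
    for name, minimum in min_versions.items():
        if importlib.util.find_spec(name) is None:
            problems.append(f"{name}: not installed")
            continue
        try:
            module = importlib.import_module(name)
        except Exception as exc:  # e.g. a binary built for another torch version
            problems.append(f"{name}: import failed ({exc})")
            continue
        version = getattr(module, "__version__", None)
        if version is None:
            problems.append(f"{name}: installed, version unknown")
        elif not at_least(version, minimum):
            problems.append(f"{name}: {version} < required {minimum}")
    for name in unversioned:
        if importlib.util.find_spec(name) is None:
            problems.append(f"{name}: not installed")
    return problems


if __name__ == "__main__":
    for line in check_dependencies() or ["all listed dependencies found"]:
        print(line)
```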
Install with pip:
pip install pygda
or, for local development, install from source:
git clone https://github.com/pygda-team/pygda
cd pygda
pip install -e .
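A typical end-to-end workflow loads the source and target graphs, builds and trains a model, and evaluates it on the target graph:
# Step 1: load the source and target graphs
# (path and args, e.g. args.source, args.target, args.nhid, args.device, come from your own setup)
from pygda.datasets import CitationDataset
source_dataset = CitationDataset(path, args.source)
target_dataset = CitationDataset(path, args.target)
source_data = source_dataset[0].to(args.device)
target_data = target_dataset[0].to(args.device)
num_features = source_data.x.size(1)
num_classes = int(source_data.y.max()) + 1  # assumes labels 0..C-1
# Step 2: build and train the model
from pygda.models import A2GNN
model = A2GNN(in_dim=num_features, hid_dim=args.nhid, num_classes=num_classes, device=args.device)
model.fit(source_data, target_data)
# Step 3: evaluate on the target graph
from pygda.metrics import eval_micro_f1, eval_macro_f1
logits, labels = model.predict(target_data)
preds = logits.argmax(dim=1)
mi_f1 = eval_micro_f1(labels, preds)
ma_f1 = eval_macro_f1(labels, preds)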
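For intuition about the two metrics, the following plain-Python sketch computes micro- and macro-averaged F1 for single-label, multi-class predictions. It is not PyGDA's implementation; we only assume `eval_micro_f1` and `eval_macro_f1` follow the standard definitions (as in scikit-learn's `f1_score` with `average='micro'` and `average='macro'`). The function name `f1_scores` is hypothetical.

```python
from collections import Counter


def f1_scores(labels, preds):
    """Return (micro_f1, macro_f1) for single-label multi-class predictions."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for y, p in zip(labels, preds):
        if y == p:
            tp[y] += 1  # correct: true positive for class y
        else:
            fp[p] += 1  # false positive for the predicted class p
            fn[y] += 1  # false negative for the true class y

    def f1(t, f_pos, f_neg):
        denom = 2 * t + f_pos + f_neg
        return 2 * t / denom if denom else 0.0

    # Micro: pool the counts over all classes, then compute a single F1.
    micro = f1(sum(tp.values()), sum(fp.values()), sum(fn.values()))
    # Macro: F1 per class (classes seen in labels or preds), then an unweighted mean.
    classes = sorted(set(labels) | set(preds))
    macro = sum(f1(tp[c], fp[c], fn[c]) for c in classes) / len(classes)
    return micro, macro


micro, macro = f1_scores([0, 0, 1, 1, 2], [0, 1, 1, 1, 0])
print(micro, macro)
```

Because each wrong prediction adds one false positive and one false negative, micro-F1 equals accuracy for single-label data (0.6 here), while macro-F1 weighs the small class 2 as much as the others and drops to about 0.433. With PyGDA's tensor outputs, convert first, e.g. `f1_scores(labels.tolist(), preds.tolist())`.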
In addition to the easy application of existing GDA models, PyGDA makes it simple to implement custom models.
- The custom model should inherit from the BaseGDA class.
- Implement its fit(), forward_model(), and predict() methods.
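To make that contract concrete, here is a hypothetical, dependency-free sketch of a model exposing `fit(source_data, target_data)`, `forward_model(data)` and `predict(data)` returning `(logits, labels)`, as in the examples above. It does not subclass BaseGDA (we don't assume its constructor arguments) and uses plain dicts instead of PyG `Data` objects. The adaptation step is a deliberately simple swapped-in technique, first-moment (mean) feature alignment followed by nearest-centroid classification, and is not one of PyGDA's models.

```python
class MeanShiftCentroidGDA:
    """Hypothetical toy model following the fit/forward_model/predict pattern.

    Not a PyGDA model and not a BaseGDA subclass. Graphs are plain dicts
    {"x": [[float, ...], ...], "y": [int, ...]}; every class 0..num_classes-1
    is assumed to appear in the source graph.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.centroids = None  # one mean feature vector per source class
        self.shift = None      # offset that moves target features onto the source mean

    @staticmethod
    def _mean(rows):
        return [sum(col) / len(rows) for col in zip(*rows)]

    def fit(self, source_data, target_data):
        # Source labels are used; target labels are never read.
        xs, ys = source_data["x"], source_data["y"]
        self.centroids = [
            self._mean([x for x, y in zip(xs, ys) if y == c])
            for c in range(self.num_classes)
        ]
        # First-moment alignment: translate the target mean onto the source mean.
        src_mean = self._mean(xs)
        tgt_mean = self._mean(target_data["x"])
        self.shift = [s - t for s, t in zip(src_mean, tgt_mean)]
        return self

    def forward_model(self, data):
        # Logits are negative squared distances from the shifted features
        # to each class centroid.
        logits = []
        for x in data["x"]:
            aligned = [v + d for v, d in zip(x, self.shift)]
            logits.append([
                -sum((a - c) ** 2 for a, c in zip(aligned, centroid))
                for centroid in self.centroids
            ])
        return logits

    def predict(self, data):
        # Same (logits, labels) convention as model.predict() in the examples above.
        return self.forward_model(data), data["y"]


source = {"x": [[0.0, 0.0], [0.0, 1.0], [4.0, 4.0], [4.0, 5.0]], "y": [0, 0, 1, 1]}
target = {"x": [[10.0, 10.0], [10.0, 11.0], [14.0, 14.0], [14.0, 15.0]], "y": [0, 0, 1, 1]}
model = MeanShiftCentroidGDA(num_classes=2).fit(source, target)
logits, labels = model.predict(target)
preds = [row.index(max(row)) for row in logits]
print(preds, labels)  # [0, 0, 1, 1] [0, 0, 1, 1]
```

Without the shift, every target node in this toy example would be assigned to class 1, which is the kind of domain gap the real models in PyGDA are designed to close.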
If you compare with, build on, or use aspects of PyGDA, please consider citing "Revisiting, Benchmarking and Understanding Unsupervised Graph Domain Adaptation":
@inproceedings{liu2024revisiting,
title={Revisiting, Benchmarking and Understanding Unsupervised Graph Domain Adaptation},
author={Meihan Liu and Zhen Zhang and Jiachen Tang and Jiajun Bu and Bingsheng He and Sheng Zhou},
booktitle={The Thirty-eight Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
year={2024},
url={https://openreview.net/forum?id=ZsyFwzuDzD}
}