* integrate GRB
* integrate GRB
* add README_GRB
* modify code style to pass flake8
* modify trainer.py
* hasattr(graph, 'grb_adj')
* add grb pytest
* modify grb pytest
* modify grb pytest: decrease time
* black cogdl
* increase test coverage
* increase test coverage
Showing 76 changed files with 7,636 additions and 461 deletions.
@@ -0,0 +1,201 @@
## Usage of the GRB part

### 1. Attack

An example of training a Graph Convolutional Network ([GCN](https://arxiv.org/abs/1609.02907)) as the surrogate model and another GCN as the target model on the _grb-cora_ dataset, then applying the FGSM injection attack to both.

#### 1) Load Dataset

```python
import copy

from cogdl.datasets.grb_data import Cora_GRBDataset

dataset = Cora_GRBDataset()
graph = copy.deepcopy(dataset.get(0))
device = "cuda:0"
graph.to(device)
test_mask = graph.test_mask
```

#### 2) Train Surrogate Model

```python
import torch

from cogdl.models.nn import GCN
from cogdl.trainer import Trainer
from cogdl.wrappers import fetch_model_wrapper, fetch_data_wrapper

model = GCN(
    in_feats=graph.num_features,
    hidden_size=64,
    out_feats=graph.num_classes,
    num_layers=2,
    dropout=0.5,
    activation=None,
)
mw_class = fetch_model_wrapper("node_classification_mw")
dw_class = fetch_data_wrapper("node_classification_dw")
optimizer_cfg = dict(lr=0.01, weight_decay=0)
model_wrapper = mw_class(model, optimizer_cfg)
dataset_wrapper = dw_class(dataset)
trainer = Trainer(
    epochs=2000,
    early_stopping=True,
    patience=500,
    cpu=device == "cpu",
    device_ids=[0],
)
trainer.run(model_wrapper, dataset_wrapper)
# load the best checkpoint saved during training
model.load_state_dict(torch.load("./checkpoints/model.pt"), strict=False)
model.to(device)
```

#### 3) Train Target Model

```python
# reuses mw_class and dw_class fetched in step 2
model_target = GCN(
    in_feats=graph.num_features,
    hidden_size=64,
    out_feats=graph.num_classes,
    num_layers=3,
    dropout=0.5,
    activation="relu",
)
optimizer_cfg = dict(lr=0.01, weight_decay=0)
model_wrapper = mw_class(model_target, optimizer_cfg)
dataset_wrapper = dw_class(dataset)
trainer = Trainer(
    epochs=2000,
    early_stopping=True,
    patience=500,
    cpu=device == "cpu",
    device_ids=[0],
)
trainer.run(model_wrapper, dataset_wrapper)
# load the best checkpoint saved during training
model_target.load_state_dict(torch.load("./checkpoints/model.pt"), strict=False)
model_target.to(device)
```

#### 4) Adversarial Attack

```python
# FGSM injection attack, generated against the surrogate model
from cogdl.attack.injection import FGSM
from cogdl.utils.grb_utils import GCNAdjNorm

attack = FGSM(
    epsilon=0.01,
    n_epoch=1000,
    n_inject_max=100,
    n_edge_max=200,
    feat_lim_min=-1,
    feat_lim_max=1,
    device=device,
)
graph_attack = attack.attack(
    model=model,  # the surrogate model trained in step 2
    graph=graph,
    adj_norm_func=GCNAdjNorm,
)
```

#### 5) Evaluate

```python
from cogdl.utils.grb_utils import evaluate

test_score = evaluate(model, graph, mask=test_mask, device=device)
print("Test score of surrogate model before attack: {:.4f}".format(test_score))
test_score = evaluate(model, graph_attack, mask=test_mask, device=device)
print("Test score of surrogate model after attack: {:.4f}".format(test_score))
test_score = evaluate(model_target, graph, mask=test_mask, device=device)
print("Test score of target model before attack: {:.4f}".format(test_score))
test_score = evaluate(model_target, graph_attack, mask=test_mask, device=device)
print("Test score of target model after attack: {:.4f}".format(test_score))
```

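The four printouts can be condensed into attack-induced score drops; a small convenience helper (not part of the original example, built only on the `evaluate` call above):

```python
# hypothetical helper: how much test score the attack costs a model
def score_drop(m, clean_graph, attacked_graph, mask, device):
    clean = evaluate(m, clean_graph, mask=mask, device=device)
    attacked = evaluate(m, attacked_graph, mask=mask, device=device)
    return clean - attacked

print("Surrogate model drop: {:.4f}".format(score_drop(model, graph, graph_attack, test_mask, device)))
print("Target model drop: {:.4f}".format(score_drop(model_target, graph, graph_attack, test_mask, device)))
```
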
### 2. Adversarial Training

An example of adversarial training for a Graph Convolutional Network ([GCN](https://arxiv.org/abs/1609.02907)).

```python
from cogdl.attack.injection import FGSM

device = "cuda:0"
model = GCN(
    in_feats=graph.num_features,
    hidden_size=64,
    out_feats=graph.num_classes,
    num_layers=3,
    dropout=0.5,
    activation=None,
    norm="layernorm",
)
attack = FGSM(
    epsilon=0.01,
    n_epoch=10,
    n_inject_max=10,
    n_edge_max=20,
    feat_lim_min=-1,
    feat_lim_max=1,
    device=device,
    verbose=False,
)
mw_class = fetch_model_wrapper("node_classification_mw")
dw_class = fetch_data_wrapper("node_classification_dw")
optimizer_cfg = dict(lr=0.01, weight_decay=0)
model_wrapper = mw_class(model, optimizer_cfg)
dataset_wrapper = dw_class(dataset)
# pass the attack and attack_mode arguments to enable adversarial training
trainer = Trainer(
    epochs=200,
    early_stopping=True,
    patience=50,
    cpu=device == "cpu",
    attack=attack,
    attack_mode="injection",
    device_ids=[0],
)
trainer.run(model_wrapper, dataset_wrapper)
model.load_state_dict(torch.load("./checkpoints/model.pt"), strict=False)
model.to(device)
```

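To see what adversarial training buys, the robust model can be evaluated the same way as in section 1; a minimal sketch, assuming `graph` and `test_mask` from section 1 and the FGSM instance defined above are still in scope:

```python
from cogdl.utils.grb_utils import GCNAdjNorm, evaluate

# attack the adversarially trained model, then compare clean vs. attacked scores
graph_attack = attack.attack(model=model, graph=graph, adj_norm_func=GCNAdjNorm)
print("Robust model on clean graph: {:.4f}".format(evaluate(model, graph, mask=test_mask, device=device)))
print("Robust model under attack: {:.4f}".format(evaluate(model, graph_attack, mask=test_mask, device=device)))
```
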
### 3. Defense Models

An example of GATGuard (a defense model).

```python
# defense model: GATGuard
from cogdl.models.defense import GATGuard

model = GATGuard(
    in_feats=graph.num_features,
    hidden_size=64,
    out_feats=graph.num_classes,
    num_layers=3,
    activation="relu",
    num_heads=4,
    drop=True,
)
print(model)
```

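A defense model is trained like any other CogDL model; a minimal sketch, assuming GATGuard works with the same node-classification wrappers and trainer used in section 1:

```python
# hedged sketch: reuse the wrapper/trainer pipeline from section 1
mw_class = fetch_model_wrapper("node_classification_mw")
dw_class = fetch_data_wrapper("node_classification_dw")
model_wrapper = mw_class(model, dict(lr=0.01, weight_decay=0))
dataset_wrapper = dw_class(dataset)
trainer = Trainer(epochs=2000, early_stopping=True, patience=500, cpu=device == "cpu", device_ids=[0])
trainer.run(model_wrapper, dataset_wrapper)
```
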
## Todo

- [ ] RobustGCN has issues
- [ ] betweenness Flip modification attack gets stuck
- [ ] PRBCD modification attack does not support CUDA, is very slow on CPU, and outputs inf
- [ ] The GRB implementation of the FGSM injection attack seems to use iterative gradient descent
- [x] inject_mode = "random-iter" and "multi-layer" for the SPEIT injection attack seem to be unimplemented in GRB
- [ ] leaderboards
@@ -0,0 +1,2 @@
"""Attack Module for implementation of graph adversarial attacks""" | ||
from .base import Attack, InjectionAttack, ModificationAttack |
@@ -0,0 +1,149 @@
from abc import ABCMeta, abstractmethod


class Attack(metaclass=ABCMeta):
    r"""
    Description
    -----------
    Abstract class for graph adversarial attack.
    """

    @abstractmethod
    def attack(self, model, adj, features, **kwargs):
        r"""
        Parameters
        ----------
        model : torch.nn.Module
            Model implemented based on ``torch.nn.Module``.
        adj : scipy.sparse.csr.csr_matrix
            Adjacency matrix in form of ``N * N`` sparse matrix.
        features : torch.FloatTensor
            Features in form of ``N * D`` torch float tensor.
        kwargs :
            Keyword-only arguments.
        """


class ModificationAttack(Attack):
    r"""
    Description
    -----------
    Abstract class for graph modification attack.
    """

    @abstractmethod
    def attack(self, **kwargs):
        """
        Parameters
        ----------
        kwargs :
            Keyword-only arguments.
        """

    @abstractmethod
    def modification(self, **kwargs):
        """
        Parameters
        ----------
        kwargs :
            Keyword-only arguments.
        """


class InjectionAttack(Attack):
    r"""
    Description
    -----------
    Abstract class for graph injection attack.
    """

    @abstractmethod
    def attack(self, **kwargs):
        """
        Parameters
        ----------
        kwargs :
            Keyword-only arguments.
        """

    @abstractmethod
    def injection(self, **kwargs):
        """
        Parameters
        ----------
        kwargs :
            Keyword-only arguments.
        """

    @abstractmethod
    def update_features(self, **kwargs):
        """
        Parameters
        ----------
        kwargs :
            Keyword-only arguments.
        """


class EarlyStop(object):
    r"""
    Description
    -----------
    Strategy to early stop the attack process.
    """

    def __init__(self, patience=1000, epsilon=1e-4):
        r"""
        Parameters
        ----------
        patience : int, optional
            Number of epochs to wait if no further improvement. Default: ``1000``.
        epsilon : float, optional
            Tolerance range of improvement. Default: ``1e-4``.
        """
        self.patience = patience
        self.epsilon = epsilon
        self.min_score = None
        self.stop = False
        self.count = 0

    def __call__(self, score):
        r"""
        Parameters
        ----------
        score : float
            Value of attack score.
        """
        if self.min_score is None:
            self.min_score = score
        elif self.min_score - score > self.epsilon:
            # score improved by more than the tolerance: reset the counter
            self.count = 0
            self.min_score = score
        else:
            # no significant improvement: count towards patience
            self.count += 1
            if self.count > self.patience:
                self.stop = True
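A concrete injection attack is expected to fill in the three abstract hooks and can use `EarlyStop` to terminate its inner optimization loop; a hypothetical sketch (`MyInjection` is illustrative, not part of this commit):

```python
# hypothetical subclass, for illustration only
class MyInjection(InjectionAttack):
    def __init__(self, n_inject_max, n_edge_max, device="cpu"):
        self.n_inject_max = n_inject_max
        self.n_edge_max = n_edge_max
        self.device = device
        self.early_stop = EarlyStop(patience=100, epsilon=1e-4)

    def attack(self, model, adj, features, **kwargs):
        # inject new nodes first, then optimize their features
        adj_attack = self.injection(adj)
        features_attack = self.update_features(model, adj_attack, features)
        return adj_attack, features_attack

    def injection(self, adj, **kwargs):
        # wire up to n_edge_max edges from each of n_inject_max new nodes
        raise NotImplementedError

    def update_features(self, model, adj_attack, features, **kwargs):
        # gradient-based feature optimization; call self.early_stop(score)
        # each epoch and break once self.early_stop.stop becomes True
        raise NotImplementedError
```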
@@ -0,0 +1,6 @@
"""Graph injection attacks""" | ||
from .fgsm import FGSM | ||
from .pgd import PGD | ||
from .rand import RAND | ||
from .speit import SPEIT | ||
from .tdgia import TDGIA |
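The package exposes five injection attacks, so the FGSM example in README_GRB can in principle be rerun with any of them; a hedged sketch (the constructor arguments are assumed to mirror FGSM's, which may not hold for every attack):

```python
from cogdl.attack.injection import PGD

# assumption: PGD accepts the same core arguments as FGSM in the README example
attack = PGD(
    epsilon=0.01,
    n_epoch=1000,
    n_inject_max=100,
    n_edge_max=200,
    feat_lim_min=-1,
    feat_lim_max=1,
    device="cuda:0",
)
```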