Skip to content

Commit

Permalink
[Feature] Integrate GRB (#347)
Browse files Browse the repository at this point in the history
* integrate GRB

* integrate GRB

* add README_GRB

* modify code style to pass flake8

* modify trainer.py

* hasattr(graph, 'grb_adj')

* add grb pytest

* modify grb pytest

* modify grb pytest: decrease time

* black cogdl

* increase test coverage

* increase test coverage
  • Loading branch information
xll2001 authored Jun 23, 2022
1 parent a5c2877 commit 438f5a9
Show file tree
Hide file tree
Showing 76 changed files with 7,636 additions and 461 deletions.
201 changes: 201 additions & 0 deletions README_GRB.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
## Usage of GRB part

### 1. Attack

An example that trains a Graph Convolutional Network ([GCN](https://arxiv.org/abs/1609.02907)) as the surrogate model and another GCN as the target model on the _grb-cora_ dataset, and then applies the FGSM injection attack to both the surrogate model and the target model.

#### 1) Load Dataset

```python
import copy
from cogdl.datasets.grb_data import Cora_GRBDataset
dataset = Cora_GRBDataset()
graph = copy.deepcopy(dataset.get(0))
device = "cuda:0"
graph.to(device)
test_mask = graph.test_mask
```

#### 2) Train Surrogate Model

```python
from cogdl.models.nn import GCN
from cogdl.trainer import Trainer
from cogdl.wrappers import fetch_model_wrapper, fetch_data_wrapper
import torch
model = GCN(
in_feats=graph.num_features,
hidden_size=64,
out_feats=graph.num_classes,
num_layers=2,
dropout=0.5,
activation=None
)
mw_class = fetch_model_wrapper("node_classification_mw")
dw_class = fetch_data_wrapper("node_classification_dw")
optimizer_cfg = dict(
lr=0.01,
weight_decay=0
)
model_wrapper = mw_class(model, optimizer_cfg)
dataset_wrapper = dw_class(dataset)
trainer = Trainer(epochs=2000,
early_stopping=True,
patience=500,
cpu=device=="cpu",
device_ids=[0])
trainer.run(model_wrapper, dataset_wrapper)
# load best model
model.load_state_dict(torch.load("./checkpoints/model.pt"), False)
model.to(device)
```

#### 3) Train Target Model

```python
model_target = GCN(
in_feats=graph.num_features,
hidden_size=64,
out_feats=graph.num_classes,
num_layers=3,
dropout=0.5,
activation="relu"
)
mw_class = fetch_model_wrapper("node_classification_mw")
dw_class = fetch_data_wrapper("node_classification_dw")
optimizer_cfg = dict(
lr=0.01,
weight_decay=0
)
model_wrapper = mw_class(model_target, optimizer_cfg)
dataset_wrapper = dw_class(dataset)
trainer = Trainer(epochs=2000,
early_stopping=True,
patience=500,
cpu=device=="cpu",
device_ids=[0])
trainer.run(model_wrapper, dataset_wrapper)
# load best model
model_target.load_state_dict(torch.load("./checkpoints/model.pt"), False)
model_target.to(device)
```

#### 4) Adversarial attack

```python
# FGSM attack
from cogdl.attack.injection import FGSM
from cogdl.utils.grb_utils import GCNAdjNorm
attack = FGSM(epsilon=0.01,
n_epoch=1000,
n_inject_max=100,
n_edge_max=200,
feat_lim_min=-1,
feat_lim_max=1,
device=device)
graph_attack = attack.attack(model=model,
graph=graph,
adj_norm_func=GCNAdjNorm)
```

#### 5) Evaluate

```python
from cogdl.utils.grb_utils import evaluate
test_score = evaluate(model,
graph,
mask=test_mask,
device=device)
print("Test score before attack for surrogate model: {:.4f}.".format(test_score))
test_score = evaluate(model,
graph_attack,
mask=test_mask,
device=device)
print("After attack, test score of surrogate model: {:.4f}".format(test_score))
test_score = evaluate(model_target,
graph,
mask=test_mask,
device=device)
print("Test score before attack for target model: {:.4f}.".format(test_score))
test_score = evaluate(model_target,
graph_attack,
mask=test_mask,
device=device)
print("After attack, test score of target model: {:.4f}".format(test_score))
```



### 2. Adversarial training

An example of adversarial training for Graph Convolutional Network ([GCN](https://arxiv.org/abs/1609.02907)).

```python
device = "cuda:0"
model = GCN(
in_feats=graph.num_features,
hidden_size=64,
out_feats=graph.num_classes,
num_layers=3,
dropout=0.5,
activation=None,
norm="layernorm"
)
from cogdl.attack.injection import FGSM
attack = FGSM(epsilon=0.01,
n_epoch=10,
n_inject_max=10,
n_edge_max=20,
feat_lim_min=-1,
feat_lim_max=1,
device=device,
verbose=False)
mw_class = fetch_model_wrapper("node_classification_mw")
dw_class = fetch_data_wrapper("node_classification_dw")
optimizer_cfg = dict(
lr=0.01,
weight_decay=0
)
model_wrapper = mw_class(model, optimizer_cfg)
dataset_wrapper = dw_class(dataset)
# add argument of attack and attack_mode for adversarial training
trainer = Trainer(epochs=200,
early_stopping=True,
patience=50,
cpu=device=="cpu",
attack=attack,
attack_mode="injection",
device_ids=[0])
trainer.run(model_wrapper, dataset_wrapper)
model.load_state_dict(torch.load("./checkpoints/model.pt"), False)
model.to(device)
```



### 3. Defense models

An example of GATGuard (a defense model).

```python
# defense model: GATGuard
from cogdl.models.defense import GATGuard
model = GATGuard(in_feats=graph.num_features,
hidden_size=64,
out_feats=graph.num_classes,
num_layers=3,
activation="relu",
num_heads=4,
drop=True)
print(model)
```



## Todo

- [ ] RobustGCN 存在问题
- [ ] betweenness Flip modification attack 卡住
- [ ] PRBCD modification attack 不支持CUDA,且在cpu上非常慢,输出inf
- [ ] FGSM injection attack 的 GRB 实现似乎采用了迭代梯度下降
- [x] SPEIT injection attack 中 inject_mode = "random-iter" 与 "multi-layer" GRB中似乎没有实现
- [ ] leaderboards
2 changes: 2 additions & 0 deletions cogdl/attack/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""Attack Module for implementation of graph adversarial attacks"""
from .base import Attack, InjectionAttack, ModificationAttack
149 changes: 149 additions & 0 deletions cogdl/attack/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from abc import ABCMeta, abstractmethod


class Attack(metaclass=ABCMeta):
    r"""
    Description
    -----------
    Root abstract class shared by all graph adversarial attacks.

    ``attack`` is declared abstract, so this class cannot be instantiated
    directly; a subclass has to override it first.
    """

    @abstractmethod
    def attack(self, model, adj, features, **kwargs):
        r"""
        Description
        -----------
        Entry point of an attack. The base version has no body beyond this
        docstring and therefore returns ``None``.

        Parameters
        ----------
        model : torch.nn.module
            Model implemented based on ``torch.nn.module``.
        adj : scipy.sparse.csr.csr_matrix
            Adjacency matrix given as an ``N * N`` sparse matrix.
        features : torch.FloatTensor
            Node features given as an ``N * D`` torch float tensor.
        kwargs :
            Additional keyword arguments accepted by the concrete attack.
        """


class ModificationAttack(Attack):
    r"""
    Description
    -----------
    Abstract base for graph modification attacks.

    Concrete subclasses have to provide both ``attack`` and ``modification``
    before they can be instantiated.
    """

    @abstractmethod
    def attack(self, **kwargs):
        """
        Abstract entry point of the attack; no behavior is defined here.

        Parameters
        ----------
        kwargs :
            Arbitrary keyword arguments defined by the concrete attack.
        """

    @abstractmethod
    def modification(self, **kwargs):
        """
        Abstract modification step; no behavior is defined here.

        Parameters
        ----------
        kwargs :
            Arbitrary keyword arguments defined by the concrete attack.
        """


class InjectionAttack(Attack):
    r"""
    Description
    -----------
    Abstract base for graph injection attacks.

    Concrete subclasses have to provide ``attack``, ``injection`` and
    ``update_features`` before they can be instantiated.
    """

    @abstractmethod
    def attack(self, **kwargs):
        """
        Abstract entry point of the attack; no behavior is defined here.

        Parameters
        ----------
        kwargs :
            Arbitrary keyword arguments defined by the concrete attack.
        """

    @abstractmethod
    def injection(self, **kwargs):
        """
        Abstract injection step; no behavior is defined here.

        Parameters
        ----------
        kwargs :
            Arbitrary keyword arguments defined by the concrete attack.
        """

    @abstractmethod
    def update_features(self, **kwargs):
        """
        Abstract feature-update step; no behavior is defined here.

        Parameters
        ----------
        kwargs :
            Arbitrary keyword arguments defined by the concrete attack.
        """


class EarlyStop(object):
    r"""
    Description
    -----------
    Early-stopping tracker for an attack loop.

    The instance is called once per step with the current score. It remembers
    the lowest score seen so far and sets ``stop`` to ``True`` once more than
    ``patience`` calls have been counted without a new minimum.
    """

    def __init__(self, patience=1000, epsilon=1e-4):
        r"""
        Parameters
        ----------
        patience : int, optional
            Number of epochs to wait when there is no further improvement.
            Default: ``1000``.
        epsilon : float, optional
            Tolerance range of improvement. Default: ``1e-4``.
        """
        # Mutable tracking state.
        self.count = 0
        self.stop = False
        self.min_score = None
        # Configuration.
        self.epsilon = epsilon
        self.patience = patience

    def __call__(self, score):
        r"""
        Description
        -----------
        Record one score and update the stopping state.

        Parameters
        ----------
        score : float
            Value of the attack score for the current step.
        """
        # The first score simply becomes the baseline.
        if self.min_score is None:
            self.min_score = score
            return

        gap = self.min_score - score

        # Any strictly lower score is a new minimum and resets the counter.
        # ``stop`` is never cleared once it has been set.
        if gap > 0:
            self.count = 0
            self.min_score = score
            return

        # NOTE(review): here ``gap <= 0``, so for any positive ``epsilon`` this
        # test always passes and ``epsilon`` does not influence the outcome.
        # Kept as-is to preserve existing behavior.
        if gap < self.epsilon:
            self.count += 1
            # Strict comparison: stopping needs ``patience + 1`` counted calls.
            if self.count > self.patience:
                self.stop = True
6 changes: 6 additions & 0 deletions cogdl/attack/injection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Graph injection attacks"""
from .fgsm import FGSM
from .pgd import PGD
from .rand import RAND
from .speit import SPEIT
from .tdgia import TDGIA
Loading

0 comments on commit 438f5a9

Please sign in to comment.