Refactor ARMAConv example (#5271)
* move data outside Net class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
rfdavid and pre-commit-ci[bot] authored Aug 24, 2022
1 parent 8bcc77c commit b2d7f09
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions examples/arma.py
@@ -14,18 +14,17 @@
 
 
 class Net(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, in_channels, hidden_channels, out_channels):
         super().__init__()
 
-        self.conv1 = ARMAConv(dataset.num_features, 16, num_stacks=3,
+        self.conv1 = ARMAConv(in_channels, hidden_channels, num_stacks=3,
                               num_layers=2, shared_weights=True, dropout=0.25)
 
-        self.conv2 = ARMAConv(16, dataset.num_classes, num_stacks=3,
+        self.conv2 = ARMAConv(hidden_channels, out_channels, num_stacks=3,
                               num_layers=2, shared_weights=True, dropout=0.25,
                               act=lambda x: x)
 
-    def forward(self):
-        x, edge_index = data.x, data.edge_index
+    def forward(self, x, edge_index):
         x = F.dropout(x, training=self.training)
         x = F.relu(self.conv1(x, edge_index))
         x = F.dropout(x, training=self.training)
@@ -34,20 +33,23 @@ def forward(self):
 
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-model, data = Net().to(device), data.to(device)
+model, data = Net(dataset.num_features, 16,
+                  dataset.num_classes).to(device), data.to(device)
 optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
 
 
 def train():
     model.train()
     optimizer.zero_grad()
-    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
+    out = model(data.x, data.edge_index)
+    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
+    loss.backward()
     optimizer.step()
 
 
 def test():
     model.eval()
-    logits, accs = model(), []
+    logits, accs = model(data.x, data.edge_index), []
     for _, mask in data('train_mask', 'val_mask', 'test_mask'):
         pred = logits[mask].max(1)[1]
         acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
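For readers who want to try the refactor without checking out the branch, here is a minimal sketch of how the updated examples/arma.py fits together. Only the two hunks above come from this commit; the Planetoid/Cora loading block, the tail of forward() (second conv plus log-softmax, which falls between the hunks), and the epoch loop at the bottom are assumptions based on the unshown parts of the file and the usual layout of the other PyG examples.

import os.path as osp

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import ARMAConv

# Assumed dataset setup (not part of the shown hunks).
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Cora')
dataset = Planetoid(path, 'Cora', transform=T.NormalizeFeatures())
data = dataset[0]


class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Channel sizes now arrive as constructor arguments instead of being
        # read from the global `dataset` object.
        self.conv1 = ARMAConv(in_channels, hidden_channels, num_stacks=3,
                              num_layers=2, shared_weights=True, dropout=0.25)
        self.conv2 = ARMAConv(hidden_channels, out_channels, num_stacks=3,
                              num_layers=2, shared_weights=True, dropout=0.25,
                              act=lambda x: x)

    def forward(self, x, edge_index):
        # The graph tensors are passed in explicitly instead of being read
        # from the global `data` object.
        x = F.dropout(x, training=self.training)
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)  # assumed unchanged (between the hunks)
        return F.log_softmax(x, dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, 16, dataset.num_classes).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)


def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()


def test():
    model.eval()
    logits, accs = model(data.x, data.edge_index), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        accs.append(pred.eq(data.y[mask]).sum().item() / mask.sum().item())
    return accs


# Assumed driver loop, mirroring the other examples in the repository.
for epoch in range(1, 201):
    train()
    train_acc, val_acc, test_acc = test()
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '
          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')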
