Skip to content

Commit

Permalink
test added
Browse files Browse the repository at this point in the history
  • Loading branch information
JoseLuisC99 committed Jun 10, 2023
1 parent 6668fa3 commit 5b78b34
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 9 deletions.
4 changes: 2 additions & 2 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,6 @@ experiment with PyTorch.
Graph Convolutional Network
^^^^^^^^^^^^^^^^^^^^^^^^^^^

This example implements the `Semi-Supervised Classification with Graph Convolutional Networks <https://arxiv.org/pdf/1609.02907.pdf>` __ paper on the `CORA <https://en.wikipedia.org/wiki/MNIST_database>`__ database.
This example implements the `Semi-Supervised Classification with Graph Convolutional Networks <https://arxiv.org/pdf/1609.02907.pdf>`__ paper on the CORA database.

`GO TO EXAMPLE <https://github.com/pytorch/examples/tree/main/graph_conv_network>`__ :opticon:`link-external`
`GO TO EXAMPLE <https://github.com/pytorch/examples/tree/main/gcn>`__ :opticon:`link-external`
26 changes: 22 additions & 4 deletions gcn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,29 @@ def test(model, criterion, input, target, mask):
help='dimension of the hidden representation (default: 16)')
parser.add_argument('--val-every', type=int, default=20,
help='epochs to wait for print training and validation evaluation (default: 20)')
parser.add_argument('--include-bias', type=bool, default=False,
parser.add_argument('--include-bias', action='store_true', default=False,
help='use bias term in convolutions (default: False)')
parser.add_argument('--no-cuda', type=bool, default=False,
help='disable CUDA training (default: False)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--no-mps', action='store_true', default=False,
help='disables macOS GPU training')
parser.add_argument('--dry-run', action='store_true', default=False,
help='quickly check a single pass')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
args = parser.parse_args()

device = 'cpu' if args.no_cuda else device
use_cuda = not args.no_cuda and torch.cuda.is_available()
use_mps = not args.no_mps and torch.backends.mps.is_available()

torch.manual_seed(args.seed)

if use_cuda:
device = torch.device('cuda')
elif use_mps:
device = torch.device('mps')
else:
device = torch.device('cpu')
print(f'Using {device} device')

cora_url = 'https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz'
Expand All @@ -239,6 +255,8 @@ def test(model, criterion, input, target, mask):

for epoch in range(args.epochs):
train_iter(epoch + 1, gcn, optimizer, criterion, (features, adj_mat), labels, idx_train, idx_val, args.val_every)
if args.dry_run:
break

loss_test, acc_test = test(gcn, criterion, (features, adj_mat), labels, idx_test)
print(f'Test set results: loss {loss_test:.4f} accuracy {acc_test:.4f}')
5 changes: 3 additions & 2 deletions gcn/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numpy
requests
torch
torchvision
requests
numpy
9 changes: 8 additions & 1 deletion run_python_examples.sh
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ function word_language_model() {
python main.py --epochs 1 --dry-run $CUDA_FLAG --mps || error "word_language_model failed"
}

function gcn() {
start
python main.py --epochs 1 --dry-run || error "graph convolutional network failed"
}

function clean() {
cd $BASE_DIR
echo "running clean to remove cruft"
Expand All @@ -192,7 +197,8 @@ function clean() {
super_resolution/model_epoch_1.pth \
time_sequence_prediction/predict*.pdf \
time_sequence_prediction/traindata.pt \
word_language_model/model.pt || error "couldn't clean up some files"
word_language_model/model.pt \
gcn/cora/ || error "couldn't clean up some files"

git checkout fast_neural_style/images/output-images/amber-candy.jpg || error "couldn't clean up fast neural style image"
}
Expand All @@ -217,6 +223,7 @@ function run_all() {
vision_transformer
word_language_model
fx
gcn
}

# by default, run all examples
Expand Down

0 comments on commit 5b78b34

Please sign in to comment.