-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
52 lines (35 loc) · 1.49 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""Train a ResNet or CNN classifier on the custom dataset.

Usage::

    python main.py -m {resnet,cnn} -n NUM_EPOCHS

The model name and epoch count come from this script's own command-line
flags. Dataset location, batch size, train/validation split ratio, learning
rate and weight decay are read from ``config.args``.

NOTE(review): if ``config.args`` is itself built by argparse from
``sys.argv``, the two parsers will each reject the other's flags -- confirm
how ``config`` is populated.
"""
import argparse

import torch
from torch import nn  # `from torch.nn import nn` raises ImportError
from torch.utils.data import DataLoader, random_split

from config import args
from data import dataloader
from models import CNN, ResNet
from utils import plot, transform
import train

# Model constructors selectable with -m/--model_name (matched case-insensitively).
MODELS = {"resnet": ResNet, "cnn": CNN}

# --- Command line ------------------------------------------------------------
# Parsed before any data is touched so a bad flag fails immediately instead of
# after the whole dataset has been loaded.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-m",
    "--model_name",
    help="This is the name of the model",
    # Lower-case before the choices check so "ResNet", "CNN", ... are accepted;
    # unknown names are rejected here instead of leaving `model` undefined.
    type=str.lower,
    choices=sorted(MODELS),
    required=True,
)
parser.add_argument(
    "-n",
    "--num_epochs",
    help="This is the number of epochs",
    type=int,
    required=True,
)
mains_args = vars(parser.parse_args())
num_epochs = mains_args["num_epochs"]

# --- Data --------------------------------------------------------------------
metadata = args.metadata
path = args.path
data = dataloader.CustomDataSet(metadata, path, transform=transform)
batch_size = args.bs
# args.train_size scales the dataset length, so it is presumably the training
# fraction; the validation size is computed by subtraction so the two lengths
# always sum to len(data), as random_split requires.
train_size = int(args.train_size * len(data))
val_size = len(data) - train_size
train_dataset, val_dataset = random_split(data, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Validation order does not need randomising; keep it fixed across epochs.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# --- Model, loss, optimiser --------------------------------------------------
model = MODELS[mains_args["model_name"]]()
# NOTE(review): BCELoss expects probabilities in [0, 1] (a final sigmoid in the
# model); if ResNet/CNN output raw logits, use nn.BCEWithLogitsLoss instead.
criterion = nn.BCELoss()
# lr / weight_decay are passed by keyword: SGD's third positional parameter is
# `momentum`, so passing args.wd positionally used it as momentum and applied
# no weight decay at all.
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Training ----------------------------------------------------------------
model_trained, percent, val_loss, val_acc, train_loss, train_acc = train.train(
    model, criterion, train_loader, val_loader, optimizer, num_epochs, DEVICE
)
plot(train_loss, val_loss)