finetune_emotion.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This file defines a script for a third and final iteration of fine-tuning. This step fine-tunes the model on the
emotions downloaded from Google Images using the download_data.py file. The base model will be the results from the
previous stage of fine-tuning (on FER data). Unlike the previous step, the entire model will not be trainable (only the
last bottleneck and fully connected layers are trainable for the resnet model).
"""
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, RandomAffine, RandomHorizontalFlip, ToTensor
from model import init_grayscale_resnet, train_epoch, val_epoch, checkpoint
# ----------------------------------------------------------------------------------------------------------------------
# Main script
# ----------------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    NUM_EMOTIONS = 5
    P_TRAIN = 0.7
    BATCH_SIZE = 360
    NUM_WORKERS = 10
    LEARNING_RATE = 0.00005
    NUM_EPOCHS = 100
    CHECKPOINT_RATE = 5  # number of epochs after which to checkpoint the model
    IMG_DIR = '/home/mchobanyan/data/emotion/images/gray/'
    BASE_MODEL_PATH = '/home/mchobanyan/data/emotion/models/emotion_detect/fer-finetune-all/model_10.pt'
    MODEL_DIR = '/home/mchobanyan/data/emotion/models/emotion_detect/gray-base/'
    os.makedirs(MODEL_DIR, exist_ok=True)  # make sure the checkpoint directory exists before saving
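    # Load the grayscale ResNet, swap in a new head for the five target emotions,
    # and restore the weights produced by the previous (FER) fine-tuning stage.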
    model = init_grayscale_resnet()
    conv_out_features = model.fc.in_features
    model.fc = nn.Linear(conv_out_features, NUM_EMOTIONS)
    model.load_state_dict(torch.load(BASE_MODEL_PATH))
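    # Freeze everything except the model's last three children; per the docstring,
    # this leaves only the final bottleneck block and the fully connected head trainable.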
    layers = list(model.children())
    for i in range(len(layers) - 3):
        for param in layers[i].parameters():
            param.requires_grad = False
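    # Light augmentation: random horizontal flips plus a small affine jitter.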
    transforms = Compose([RandomHorizontalFlip(),
                          RandomAffine(degrees=10, translate=(0.25, 0.25), scale=(0.5, 1)),
                          ToTensor()])
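    # Image.open is passed as the loader so the grayscale images keep a single
    # channel (the default ImageFolder loader converts every image to RGB).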
    dataset = ImageFolder(IMG_DIR, transform=transforms, loader=Image.open)
    labels = dataset.classes
    train_size = int(len(dataset) * P_TRAIN)
    val_size = len(dataset) - train_size
    train_data, val_data = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
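    # Move the model to the GPU and set up the loss and optimizer. The frozen
    # parameters never receive gradients, so Adam leaves them untouched.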
    device = torch.device('cuda')
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
    train_losses = []
    val_losses = []
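    # Train for NUM_EPOCHS epochs, tracking the losses and periodically
    # checkpointing the model.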
    for epoch in range(NUM_EPOCHS):
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        message = f'Epoch: {epoch}\tTrainLoss: {train_loss}'
        if len(val_data) > 0:  # only run a validation epoch if the validation dataset is not empty
            val_loss, val_acc = val_epoch(model, val_loader, criterion, device)
            val_losses.append(val_loss)
            message += f'\tValLoss: {val_loss}\tValAcc: {val_acc}'
        print(message)
        if epoch % CHECKPOINT_RATE == 0:
            print('Checkpointing model...')
            checkpoint(model, os.path.join(MODEL_DIR, f'model_{epoch}.pt'))