"""Reported (reproduced) results of FT-Transformer
https://arxiv.org/abs/2106.11959.
adult 85.9 (85.5)
helena 39.1 (39.2)
jannis 73.2 (72.2)
california_housing 0.459 (0.537)
--------
Reported (reproduced) results of ResNet
https://arxiv.org/abs/2106.11959
adult 85.7 (85.4)
helena 39.6 (39.1)
jannis 72.8 (72.5)
california_housing 0.486 (0.523)
"""
import argparse
import os.path as osp
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_frame import stype
from torch_frame.data import DataLoader
from torch_frame.datasets import Yandex
from torch_frame.nn import (
EmbeddingEncoder,
FTTransformer,
LinearBucketEncoder,
LinearEncoder,
LinearPeriodicEncoder,
ResNet,
)
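
# Command-line configuration; the defaults run FT-Transformer on the
# `adult` dataset with a plain linear numerical encoder.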
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='adult')
parser.add_argument('--numerical_encoder_type', type=str, default='linear',
choices=['linear', 'linearbucket', 'linearperiodic'])
parser.add_argument('--model_type', type=str, default='fttransformer',
choices=['fttransformer', 'resnet'])
parser.add_argument('--channels', type=int, default=256)
parser.add_argument('--num_layers', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=512)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--compile', action='store_true')
args = parser.parse_args()
torch.manual_seed(args.seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Prepare datasets
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
args.dataset)
dataset = Yandex(root=path, name=args.dataset)
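# Materialization converts the raw columns into a `TensorFrame` and computes
# the per-column statistics (`col_stats`) that the encoders consume below.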
dataset.materialize()
is_classification = dataset.task_type.is_classification
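# Split along the dataset's built-in train/val/test split.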
train_dataset, val_dataset, test_dataset = dataset.split()
# Set up data loaders
train_tensor_frame = train_dataset.tensor_frame
val_tensor_frame = val_dataset.tensor_frame
test_tensor_frame = test_dataset.tensor_frame
train_loader = DataLoader(train_tensor_frame, batch_size=args.batch_size,
shuffle=True)
val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size)
test_loader = DataLoader(test_tensor_frame, batch_size=args.batch_size)
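
# Select how numerical columns are embedded; categorical columns always go
# through an embedding table via `EmbeddingEncoder` below.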
if args.numerical_encoder_type == 'linear':
numerical_encoder = LinearEncoder()
elif args.numerical_encoder_type == 'linearbucket':
numerical_encoder = LinearBucketEncoder()
elif args.numerical_encoder_type == 'linearperiodic':
numerical_encoder = LinearPeriodicEncoder()
else:
raise ValueError(
f'Unsupported encoder type: {args.numerical_encoder_type}')
stype_encoder_dict = {
stype.categorical: EmbeddingEncoder(),
stype.numerical: numerical_encoder,
}
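
# Classification needs one output logit per class; regression needs a
# single scalar output.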
if is_classification:
output_channels = dataset.num_classes
else:
output_channels = 1
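
# Build the backbone. Both models take the dataset's column statistics so
# their encoders can normalize/embed each column; in this script only the
# FTTransformer is given the custom `stype_encoder_dict`, while the ResNet
# uses its default encoders.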
if args.model_type == 'fttransformer':
model = FTTransformer(
channels=args.channels,
out_channels=output_channels,
num_layers=args.num_layers,
col_stats=dataset.col_stats,
col_names_dict=train_tensor_frame.col_names_dict,
stype_encoder_dict=stype_encoder_dict,
).to(device)
elif args.model_type == 'resnet':
model = ResNet(
channels=args.channels,
out_channels=output_channels,
num_layers=args.num_layers,
col_stats=dataset.col_stats,
col_names_dict=train_tensor_frame.col_names_dict,
).to(device)
else:
raise ValueError(f'Unsupported model type: {args.model_type}')
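
# Optionally compile the model; dynamic=True avoids recompilation when batch
# sizes vary (e.g. the final, smaller batch of each epoch).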
model = torch.compile(model, dynamic=True) if args.compile else model
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
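

# Train for one epoch; returns the mean loss per training example.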
def train(epoch: int) -> float:
model.train()
loss_accum = total_count = 0
for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'):
tf = tf.to(device)
pred = model(tf)
if is_classification:
loss = F.cross_entropy(pred, tf.y)
else:
loss = F.mse_loss(pred.view(-1), tf.y.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_accum += float(loss) * len(tf.y)
        total_count += len(tf.y)
return loss_accum / total_count
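

# Evaluate on a loader: accuracy for classification, RMSE for regression.
# Squared errors are summed per batch and averaged over all rows before the
# square root, so the RMSE is exact regardless of batch size.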
@torch.no_grad()
def test(loader: DataLoader) -> float:
model.eval()
accum = total_count = 0
for tf in loader:
tf = tf.to(device)
pred = model(tf)
if is_classification:
pred_class = pred.argmax(dim=-1)
accum += float((tf.y == pred_class).sum())
else:
accum += float(
F.mse_loss(pred.view(-1), tf.y.view(-1), reduction='sum'))
total_count += len(tf.y)
if is_classification:
accuracy = accum / total_count
return accuracy
else:
rmse = (accum / total_count)**0.5
return rmse
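
# Model selection: keep the test score from the epoch with the best
# validation score (max accuracy, min RMSE).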
if is_classification:
metric = 'Acc'
best_val_metric = 0
best_test_metric = 0
else:
metric = 'RMSE'
best_val_metric = float('inf')
best_test_metric = float('inf')
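
# Main loop: one training pass per epoch, then evaluation on all splits.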
for epoch in range(1, args.epochs + 1):
train_loss = train(epoch)
train_metric = test(train_loader)
val_metric = test(val_loader)
test_metric = test(test_loader)
if is_classification and val_metric > best_val_metric:
best_val_metric = val_metric
best_test_metric = test_metric
elif not is_classification and val_metric < best_val_metric:
best_val_metric = val_metric
best_test_metric = test_metric
print(f'Train Loss: {train_loss:.4f}, Train {metric}: {train_metric:.4f}, '
f'Val {metric}: {val_metric:.4f}, Test {metric}: {test_metric:.4f}')
print(f'Best Val {metric}: {best_val_metric:.4f}, '
f'Best Test {metric}: {best_test_metric:.4f}')