-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot.py
96 lines (80 loc) · 2.23 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import argparse
import os
import pytorch_lightning as pl
import sys
import torch
from adversarial.data import AdversarialLearningDataModule
from adversarial.model import AdversarialTrainingModule
from baseline.data import ArgsBase
from baseline.data import OLIDataModule
from baseline.data import OLIDataset
from baseline.model import ClassificationModule
from plot_utils.tsne import TSNE
from utils import generate_exp_name
def main(args):
    """Render a t-SNE visualization of a ClassificationModule over a dataset.

    The model is either freshly constructed from ``args`` or restored from the
    checkpoint at ``args.load_from``. It is put in eval mode and frozen, then
    handed to ``TSNE`` together with the dataset read from ``args.input_file``.

    Args:
        args: Parsed command-line namespace. This function reads ``device``,
            ``load_from``, ``input_file`` and ``bert``; the whole namespace is
            also forwarded to ``generate_exp_name`` and ``ClassificationModule``.
    """
    # Seed the RNGs through Lightning so repeated runs are reproducible.
    seed = 123
    pl.seed_everything(seed)

    run_name = generate_exp_name(args)

    # Prefer the requested CUDA device index; fall back to CPU without a GPU.
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.device}')
    else:
        device = torch.device('cpu')

    # Restore from a checkpoint when one was given, otherwise build a new model.
    # strict=False tolerates key mismatches between checkpoint and module.
    checkpoint = args.load_from
    if checkpoint is not None:
        model = ClassificationModule.load_from_checkpoint(
            checkpoint_path=checkpoint,
            args=args,
            strict=False,
        )
    else:
        model = ClassificationModule(args)
    model.eval()
    model.freeze()

    # NOTE(review): --input_file is declared with nargs='+', so this is a list
    # of paths; confirm OLIDataset expects a list for `filepath`.
    data = OLIDataset(filepath=args.input_file, enc_model=args.bert)

    TSNE(model, data, device, run_name).visualize()
if __name__ == "__main__":
    # Make the parent of this script's directory importable. Use the
    # __file__ variable (not the string literal "__file__"): the literal made
    # os.path.dirname() return "", so the appended path was the parent of the
    # current working directory instead of the parent of the script's folder.
    # NOTE: the module-level imports at the top of the file have already been
    # resolved by this point, so this only affects imports performed later.
    sys.path.append(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    )

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--task',
        type=str,
        default='plot'
    )
    # nargs='+' means args.input_file is parsed as a list of one or more paths.
    parser.add_argument(
        '--input_file',
        type=str,
        nargs='+'
    )
    # Optional checkpoint path; when omitted main() builds a fresh model.
    parser.add_argument(
        '--load_from',
        type=str
    )
    parser.add_argument(
        '--bert',
        type=str,
        default='mbert'
    )
    parser.add_argument(
        '--exp_name',
        type=str,
        default=''
    )
    # CUDA device index, used only when CUDA is available.
    parser.add_argument(
        '--device',
        type=int,
        default=0
    )

    # Let each project component register the options it consumes.
    parser = ArgsBase.add_model_specific_args(parser)
    parser = ClassificationModule.add_model_specific_args(parser)
    parser = AdversarialTrainingModule.add_model_specific_args(parser)
    parser = OLIDataModule.add_model_specific_args(parser)
    parser = AdversarialLearningDataModule.add_model_specific_args(parser)

    args = parser.parse_args()
    main(args)