-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
27 lines (20 loc) · 990 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from transformers import HfArgumentParser
from utils.configs import ModelArguments, DataArguments, TrainArguments, update_configs
from dataloader.data import get_dataloader
from networks.model import get_model
from approaches.train import Trainer
from approaches.eval import evaluate
def main():
    """Entry point: parse command-line options, then run evaluation or training.

    Parses the command line into ``ModelArguments``, ``DataArguments`` and
    ``TrainArguments`` via ``HfArgumentParser``, passes the three through
    ``update_configs``, and prints the resulting configs to stdout. It then
    builds the data loaders and the model. If ``training_args.eval`` is
    truthy, it calls ``evaluate(...)``; otherwise it constructs a
    ``Trainer`` and calls ``.train()`` on it.
    """
    arg_parser = HfArgumentParser((ModelArguments, DataArguments, TrainArguments))
    model_args, data_args, training_args = arg_parser.parse_args_into_dataclasses()
    # NOTE(review): update_configs (utils.configs) returns replacement configs;
    # what it derives or overrides is defined there — confirm before relying on it.
    model_args, data_args, training_args = update_configs(
        model_args, data_args, training_args
    )
    # Echo the effective configuration so each run's settings appear in the log.
    print(model_args, data_args, training_args)

    loaders = get_dataloader(data_args, training_args)
    net = get_model(model_args, data_args, training_args)

    if not training_args.eval:
        # Training mode: hand everything to the Trainer and run it.
        Trainer(net, loaders, model_args, data_args, training_args).train()
        return

    # Evaluation-only mode.
    evaluate(net, loaders, model_args, data_args, training_args)


if __name__ == '__main__':
    # Run only when executed as a script, not when imported.
    main()