-
Notifications
You must be signed in to change notification settings - Fork 1
/
engine.py
74 lines (54 loc) · 2.35 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
# Set at import time, ahead of the project imports that follow in this file.
# NOTE(review): presumably a workaround for an MKL threading-layer conflict in
# the numeric libraries those imports load — confirm before moving this line.
os.environ['MKL_SERVICE_FORCE_INTEL'] = "1"
# Print a green (ANSI escape 92) banner when the DEBUG environment variable
# holds any non-empty value; values such as "0" or "false" also count as set.
if os.environ.get('DEBUG', False): print('\033[92m' + 'Running code in DEBUG mode' + '\033[0m')
import logging
from models import build_model
from processors import build_processor
from utils import set_seed
from runner.runner import Runner
# Module-level logger named after this module; its handler, format and level
# come from the logging.basicConfig call made in main().
logger = logging.getLogger(__name__)
def run(args, model, processor, optimizer, scheduler):
    """Build the train/dev/test data loaders and drive a ``Runner`` with them.

    Args:
        args: Run configuration. Mutated in place: ``train_invalid_num``,
            ``dev_invalid_num`` and ``test_invalid_num`` are set from the
            fourth value returned by ``processor.generate_dataloader``.
        model: Forwarded unchanged to ``Runner``.
        processor: Object exposing ``generate_dataloader(split)``; each call
            must return four values (examples, features, dataloader and an
            invalid count).
        optimizer: Forwarded unchanged to ``Runner``.
        scheduler: Forwarded unchanged to ``Runner``.
    """
    set_seed(args)

    # Split name paired with the exact message logged before it is built.
    split_banners = (
        ("train", "train dataloader generation"),
        ("dev", "dev dataloader generation"),
        ("test", "test dataloader generation"),
    )

    all_examples, all_features, all_loaders = [], [], []
    for split, banner in split_banners:
        logger.info(banner)
        examples, features, loader, invalid_num = processor.generate_dataloader(split)
        setattr(args, f"{split}_invalid_num", invalid_num)
        all_examples.append(examples)
        all_features.append(features)
        all_loaders.append(loader)

    # Lists are ordered train, dev, test.
    Runner(
        cfg=args,
        data_samples=all_examples,
        data_features=all_features,
        data_loaders=all_loaders,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        metric_fn_dict=None,
    ).run()
def main():
    """Parse the configuration, configure logging, build the model and run it.

    Unless ``args.inference_only`` is set, ``args.output_dir`` is created when
    missing and log records are written to ``log.txt`` inside it. In
    inference-only mode no log file is created and ``logging.basicConfig``
    falls back to its default stream handler.
    """
    # Kept as a function-level import, as in the original code.
    from config_parser import get_args_parser

    args = get_args_parser()

    # Options shared by both logging setups; only the file target differs.
    log_options = {
        "format": '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        "datefmt": '%m/%d/%Y %H:%M:%S',
        "level": logging.INFO,
    }
    if not args.inference_only:
        print(f"Output full path {os.path.join(os.getcwd(), args.output_dir)}")
        # exist_ok=True replaces the exists()-then-makedirs() pair, which could
        # fail if another process created the directory between the two calls.
        os.makedirs(args.output_dir, exist_ok=True)
        log_options["filename"] = os.path.join(args.output_dir, "log.txt")
    logging.basicConfig(**log_options)

    set_seed(args)
    model, tokenizer, optimizer, scheduler = build_model(args, args.model_type)
    model.to(args.device)
    processor = build_processor(args, tokenizer)

    logger.info("Training/evaluation parameters %s", args)
    # NOTE(review): the original carried this bare list here; its purpose is
    # not evident from this file:
    # DropoutRate, batch_size, learning_rate, epochs, LRpatience, seed
    run(args, model, processor, optimizer, scheduler)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()