deep_classiflie.py
"""
deep_classiflie: Deep Classiflie is a framework for developing ML models that bolster fact-checking efficiency.
The initial alpha release of Deep Classiflie generates/analyzes a model that continuously classifies a single
individual's statements (Donald Trump)<sup id="a1">[1](#f1)</sup> using a single ground truth labeling source
(The Washington Post). See deepclassiflie.org for current predictions and to explore the model and its performance.
@author: Dan Dale, @speediedan
"""
import faulthandler
import logging
import os
import sys
from typing import MutableMapping, NoReturn, Optional

import utils.constants as constants
from dataprep.dataprep import DatasetCollection
from utils.core_utils import create_lock_file
from utils.envconfig import EnvConfig
from analysis.inference import Inference
from training.trainer import Trainer

faulthandler.enable()
logger = logging.getLogger(constants.APP_NAME)
def main() -> Optional[NoReturn]:
    config = EnvConfig().config
    if config.experiment.dataprep_only:
        _ = DatasetCollection(config)
    elif config.experiment.predict_only and config.inference.pred_inputs:
        Inference(config).init_predict()
    elif config.experiment.infsvc.enabled:
        init_dc_service(config, 'infsvc')
    elif config.experiment.tweetbot.enabled:
        init_dc_service(config, 'tweetbot')
    elif config.inference.report_mode:
        if not config.experiment.db_functionality_enabled:
            logger.error(f"{constants.DB_WARNING_START} Model analysis reports {constants.DB_WARNING_END}")
            sys.exit(0)
        from analysis.model_analysis_rpt import ModelAnalysisRpt
        ModelAnalysisRpt(config)
    else:
        core_flow(config)
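# The branches in main() dispatch on a handful of attributes resolved by EnvConfig():
#   experiment.dataprep_only                          -> build the dataset collection only
#   experiment.predict_only + inference.pred_inputs   -> run ad hoc predictions
#   experiment.infsvc.enabled / experiment.tweetbot.enabled -> start the corresponding service
#   inference.report_mode                             -> model analysis reports (requires DB functionality)
#   otherwise                                         -> core_flow() (train/test)
# The helper below is a hypothetical, minimal stand-in for local experimentation; the
# SimpleNamespace wrapper is an assumption, not the project's actual EnvConfig/YAML format,
# and the function is illustrative only (it is not used anywhere in this module).
def _demo_config_sketch():
    from types import SimpleNamespace  # local import keeps the sketch self-contained
    return SimpleNamespace(
        experiment=SimpleNamespace(dataprep_only=True, predict_only=False,
                                   infsvc=SimpleNamespace(enabled=False),
                                   tweetbot=SimpleNamespace(enabled=False),
                                   db_functionality_enabled=False,
                                   inference_ckpt=None),
        inference=SimpleNamespace(pred_inputs=None, report_mode=False),
        trainer=SimpleNamespace(restart_training_ckpt=None, build_swa_from_ckpts=None),
    )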
def init_dc_service(config: MutableMapping, service_type: str) -> NoReturn:
    if service_type == 'infsvc':
        svc_name = 'inference service'
        from utils.dc_infsvc import DCInfSvc
        svc_module = DCInfSvc
    else:
        svc_name = 'tweetbot'
        from utils.dc_tweetbot import DCTweetBot
        svc_module = DCTweetBot
    lock_file = None
    try:
        if not config.experiment.db_functionality_enabled:
            logger.error(f"{constants.DB_WARNING_START} The {svc_name} {constants.DB_WARNING_END}")
            sys.exit(0)
        lock_file = create_lock_file()
        svc_module(config)
        os.remove(lock_file)
    except KeyboardInterrupt:
        logger.warning(f'Interrupted {svc_name}, removing lock file and exiting...')
        if lock_file:  # guard: the interrupt may arrive before the lock file is created
            os.remove(lock_file)
        sys.exit(0)
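# For context: create_lock_file (imported from utils.core_utils) guards against running
# concurrent service instances. Its real implementation is not shown in this file; the
# hypothetical helper below sketches a typical PID-file approach (the lock path and the
# function itself are assumptions, not the project's utils.core_utils.create_lock_file).
def _create_lock_file_sketch(lock_path: str = '/tmp/deep_classiflie.lock') -> str:
    # write the current process id to a well-known path and return it for later removal
    with open(lock_path, 'w') as lock_fh:
        lock_fh.write(str(os.getpid()))
    return lock_path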
def core_flow(config: MutableMapping) -> None:
    dataset = DatasetCollection(config)
    trainer = Trainer(dataset, config)
    if config.experiment.inference_ckpt:
        # testing mode takes precedence over training if both ckpts are specified
        logger.info(f'Testing model weights loaded from {config.experiment.inference_ckpt}...')
        trainer.init_test(config.experiment.inference_ckpt)
    elif config.trainer.restart_training_ckpt:
        # restarting training takes precedence over just building custom swa checkpoints
        logger.info(f'Restarting model training from {config.trainer.restart_training_ckpt}...')
        trainer.train(config.trainer.restart_training_ckpt)
    elif config.trainer.build_swa_from_ckpts:
        logger.info(f'Building swa checkpoint from specified ckpts: {config.trainer.build_swa_from_ckpts}...')
        swa_ckpt = trainer.swa_ckpt_build(mode="custom", ckpt_list=config.trainer.build_swa_from_ckpts)
        logger.info(f'Successfully built SWA checkpoint ({swa_ckpt}) from provided list of checkpoints, '
                    f'proceeding with test')
        trainer.init_test(swa_ckpt)
    else:
        logger.info('Starting model training from scratch...')
        trainer.train()
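# Illustrative sketch, not the repo's Trainer.swa_ckpt_build implementation: building a
# custom SWA checkpoint essentially means averaging the parameter tensors of the supplied
# checkpoints. The 'state_dict' checkpoint layout and the torch.load call are assumptions,
# and the sketch assumes all checkpoints share the same architecture/keys.
def _swa_average_sketch(ckpt_paths):
    import torch  # local import keeps the sketch self-contained
    avg_state = None
    for n, path in enumerate(ckpt_paths, start=1):
        state = torch.load(path, map_location='cpu')['state_dict']  # assumed checkpoint layout
        if avg_state is None:
            avg_state = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg_state:
                # running mean over checkpoints: avg += (x - avg) / n
                avg_state[k] += (state[k].float() - avg_state[k]) / n
    return avg_state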
if __name__ == '__main__':
    repo_base = None
    main()