-
Notifications
You must be signed in to change notification settings - Fork 6
/
main.py
executable file
·102 lines (78 loc) · 2.81 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
'''
SynST
--
Main entry point for training the SynST
'''
from __future__ import print_function
import sys
import threading
from contextlib import ExitStack
# comet_ml now fails to initialize if any torch module
# is loaded before it... so make sure it's loaded first
# pylint:disable=unused-import
import comet_ml
# pylint:enable=unused-import
import torch
from torch.autograd import profiler, set_detect_anomaly
from args import parse_args
from data.utils import get_dataloader
from models.utils import restore
from utils import profile
# Experiment class from comet_ml (the comet_ml package itself is imported above, before torch)
from comet_ml import Experiment
# NOTE(review): Experiment is not referenced below; main() uses args.experiment instead.
def main(argv=None):
    '''
    Main entry point

    Parses the arguments, builds the dataloader, the model and the action,
    optionally restores module state from a checkpoint, then runs the action.

    argv -- argument list passed straight through to parse_args
    '''
    args = parse_args(argv)
    print(f'Running torch {torch.version.__version__}')

    memory_profile_file = args.config.cuda.profile_cuda_memory
    # Pin memory only when the device type mentions 'cuda' and CUDA memory
    # profiling has not been requested
    use_pinned_memory = not (
        'cuda' not in args.device.type or memory_profile_file
    )

    loader = get_dataloader(
        args.config.data,
        args.seed_fn,
        use_pinned_memory,
        args.num_devices,
        shuffle=args.shuffle,
    )
    print(loader.dataset.stats)

    model = args.model(args.config.model, loader.dataset)
    action = args.action(args.action_config, model, loader, args.device)

    wants_validation = (
        args.action_type == 'train' and args.action_config.early_stopping
    )
    if wants_validation:
        # The data config is mutated in place so that the second dataloader
        # is built from the 'valid' split with max_examples set to 0
        args.config.data.split = 'valid'
        args.config.data.max_examples = 0
        action.validation_dataloader = get_dataloader(
            args.config.data,
            args.seed_fn,
            use_pinned_memory,
            args.num_devices,
            shuffle=args.shuffle,
        )

    if args.config.cuda.profile_cuda_memory:
        print('Profiling CUDA memory')
        tracer = profile.CUDAMemoryProfiler(
            action.modules.values(), filename=memory_profile_file
        )

        # Register the profiler as the trace function for this thread and
        # for threads subsequently started through the threading module
        sys.settrace(tracer)
        threading.settrace(tracer)

    epoch = step = 0
    if args.restore:
        # Modules listed in reset_parameters are excluded from the restore
        modules_to_restore = {}
        for name, module in action.modules.items():
            if name not in args.reset_parameters:
                modules_to_restore[name] = module

        # Loading is strict only when no parameters are being reset
        epoch, step = restore(
            args.restore,
            modules_to_restore,
            num_checkpoints=args.average_checkpoints,
            map_location=args.device.type,
            strict=not args.reset_parameters,
        )

        model.reset_named_parameters(args.reset_parameters)
        if 'step' in args.reset_parameters:
            # Discard the restored counters and start again from zero
            epoch = step = 0

    args.experiment.set_step(step)

    with ExitStack() as contexts:
        contexts.enter_context(
            profiler.emit_nvtx(args.config.cuda.profile_cuda)
        )
        contexts.enter_context(set_detect_anomaly(args.detect_anomalies))
        action(epoch, args.experiment, args.verbose)
# Invoke main() only when this file is executed as a script, not on import
if __name__ == '__main__':
    main()