-
Notifications
You must be signed in to change notification settings - Fork 11
/
main.py
119 lines (92 loc) · 3.49 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import rich.logging
import torch
import hydra
import warnings
import logging
from tracklab.utils import monkeypatch_hydra, \
progress # needed to avoid complex hydra stacktraces when errors occur in "instantiate(...)"
from hydra.utils import instantiate
from omegaconf import OmegaConf
from tracklab.datastruct import TrackerState
from tracklab.pipeline import Pipeline
from tracklab.utils import wandb
# Ask Hydra to print complete stack traces instead of its shortened error output.
os.environ["HYDRA_FULL_ERROR"] = "1"
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Silence every warning issued through the `warnings` module from this point on.
warnings.filterwarnings("ignore")
@hydra.main(version_base=None, config_path="pkg://tracklab.configs", config_name="config")
def main(cfg):
    """Run the configured pipeline: build modules, train them, then track and evaluate.

    Steps, in order:
      1. Initialise the environment and obtain the device to run on.
      2. Instantiate the dataset, the evaluator and every module named in
         ``cfg.pipeline`` (looked up in ``cfg.modules``).
      3. Call ``train()`` on each module whose ``training_enabled`` is truthy.
      4. When ``cfg.test_tracking`` is truthy, run the tracking engine on the
         ``cfg.dataset.eval_set`` split, evaluate, and log where the tracker
         state was saved (if it was).
      5. Close the environment.

    Args:
        cfg: The Hydra/OmegaConf configuration injected by ``@hydra.main``.

    Returns:
        Always ``0``.
    """
    device = init_environment(cfg)

    # Dataset first: the evaluator and every module receive it.
    tracking_dataset = instantiate(cfg.dataset)
    evaluator = instantiate(cfg.eval, tracking_dataset=tracking_dataset)

    # A missing pipeline definition simply yields an empty module list.
    stage_names = cfg.pipeline if cfg.pipeline is not None else ()
    modules = [
        instantiate(cfg.modules[stage], device=device, tracking_dataset=tracking_dataset)
        for stage in stage_names
    ]
    pipeline = Pipeline(models=modules)

    # Training phase: only modules that opted in are trained.
    for candidate in modules:
        if not candidate.training_enabled:
            continue
        candidate.train()

    # Tracking / evaluation phase.
    if cfg.test_tracking:
        eval_split = cfg.dataset.eval_set
        log.info(f"Starting tracking operation on {eval_split} set.")

        # The tracker state wraps the chosen split; the engine drives the pipeline over it.
        tracker_state = TrackerState(
            tracking_dataset.sets[eval_split],
            pipeline=pipeline,
            **cfg.state,
        )
        engine = instantiate(cfg.engine, modules=pipeline, tracker_state=tracker_state)

        # Run tracking (and visualization), then score the results.
        engine.track_dataset()
        evaluate(cfg, evaluator, tracker_state)

        # Report the save location only when a state file is set.
        state_path = tracker_state.save_file
        if state_path is not None:
            log.info(f"Saved state at : {state_path.resolve()}")

    close_enviroment()
    return 0
def set_sharing_strategy():
    """Select the ``"file_system"`` sharing strategy for ``torch.multiprocessing``."""
    strategy = "file_system"
    torch.multiprocessing.set_sharing_strategy(strategy)
def init_environment(cfg):
# For Hydra and Slurm compatibility
progress.use_rich = cfg.use_rich