"""tune_syncbohb_dcgan.py: tune a DCGAN workload with synchronous BOHB.

Forked from SymbioticLab/Fluid.
"""
from pathlib import Path

from ray import tune
from ray.tune.suggest.bohb import TuneBOHB

import workloads.common as com
from fluid.syncbohb import SyncBOHB
from fluid.trainer import TorchTrainer
from workloads.common import dcgan as workload

# Locate the dataset and results directories for the current machine.
DATA_PATH, RESULTS_PATH = com.detect_paths()
EXP_NAME = com.remove_prefix(Path(__file__).stem, "tune_")
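# Example: Path(__file__).stem is "tune_syncbohb_dcgan", so stripping the
# "tune_" prefix yields EXP_NAME == "syncbohb_dcgan".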


def setup_tune_scheduler(num_worker):
    # BOHB uses ConfigSpace for its hyperparameter search space
    # (see the sketch after this function).
    config_space = workload.create_ch()
    experiment_metrics = workload.exp_metric()
    bohb_search = TuneBOHB(config_space, **experiment_metrics)
    # Synchronous BOHB: with max_t=81 and reduction_factor=3, trials are
    # promoted at successive-halving rungs of 1, 3, 9, 27, and 81
    # training iterations.
    bohb_hyperband = SyncBOHB(
        time_attr="training_iteration",
        max_t=81,
        reduction_factor=3,
        **experiment_metrics
    )
    return dict(
        scheduler=bohb_hyperband,
        search_alg=bohb_search,
        resources_per_trial=com.detect_baseline_resource(),
    )
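

# `workload.create_ch()` is expected to return a ConfigSpace.ConfigurationSpace
# describing the DCGAN search space. The real definition lives in
# workloads.common.dcgan; the sketch below is only illustrative, with
# hypothetical hyperparameter names and ranges:
#
#     import ConfigSpace as CS
#     import ConfigSpace.hyperparameters as CSH
#
#     def create_ch():
#         cs = CS.ConfigurationSpace()
#         cs.add_hyperparameter(
#             CSH.UniformFloatHyperparameter("lr", lower=1e-5, upper=1e-2, log=True)
#         )
#         cs.add_hyperparameter(
#             CSH.UniformIntegerHyperparameter("batch_size", lower=16, upper=128)
#         )
#         return cs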


def main():
    num_worker, sd = com.init_ray()
    workload.init_dcgan()

    # Build a Tune trainable from the workload's creator functions.
    # (Renamed from MyTrainable_asha: the "_asha" suffix was a leftover
    # from the ASHA variant of this script; this one uses SyncBOHB.)
    MyTrainable = TorchTrainer.as_trainable(
        data_creator=workload.data_creator,
        model_creator=workload.model_creator,
        loss_creator=workload.loss_creator,
        optimizer_creator=workload.optimizer_creator,
        training_operator_cls=workload.GANOperator,
        config={
            "seed": sd,
            **workload.static_config(),
            "extra_fluid_trial_resources": {},
        },
    )
    params = {
        **com.run_options(__file__),
        "stop": workload.create_stopper(),
        **setup_tune_scheduler(num_worker),
    }

    analysis = tune.run(MyTrainable, **params)

    # Save each trial's per-iteration metrics next to its Tune logs
    # (filename fixed from the original "trail_dataframe.csv" typo).
    for logdir, df in analysis.trial_dataframes.items():
        df.to_csv(Path(logdir) / "trial_dataframe.csv")


if __name__ == "__main__":
    main()
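
# To run (assuming the Ray cluster and data paths are set up as for the other
# Fluid workload scripts; exact behavior depends on workloads.common.init_ray):
#
#     python tune_syncbohb_dcgan.py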