test.py
import argparse

import equinox as eqx
import wandb

from neptune.callbacks import error_callback, visualization_callback
from neptune.datasets import DatasetConfig
from neptune.models import dict_to_model_config
from neptune.pdes import DiffusionSorption1D, NavierStokes2D
from neptune.trainers import Trainer, restore_trainer_state
from neptune.trainers.metrics import MetricsLogger
from neptune.trainers.optimizers import OptimizerConfig
from neptune.trainers.train_utils import TrainingConfig


def convert_dict_to_tuples(d):
    """Recursively convert any lists in a nested dict/list structure to tuples."""
    if isinstance(d, dict):
        return {k: convert_dict_to_tuples(v) for k, v in d.items()}
    elif isinstance(d, list) and not isinstance(d, str):
        return tuple(convert_dict_to_tuples(v) for v in d)
    else:
        return d
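
# Illustrative example (the config keys below are made up, not from a real run):
#   convert_dict_to_tuples({'model': {'hidden_sizes': [64, 64]}, 'lr': 1e-3})
#   returns {'model': {'hidden_sizes': (64, 64)}, 'lr': 1e-3}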


def main():
    parser = argparse.ArgumentParser(description='Resume a training run')
    parser.add_argument('run_id', type=str, help='wandb run id')
    parser.add_argument('project', type=str, help='wandb project name')
    parser.add_argument('--batch_size', type=int, default=-1,
                        help='Manually set batch size')
    args = parser.parse_args()

    run_id = args.run_id
    project = args.project

    # Resume the existing wandb run and pull its stored config.
    wandb.init(entity='ml-pde', project=project, id=run_id, resume='must')
    cfg = wandb.run.config.as_dict()

    pde_name = cfg['pde']
    if pde_name == 'Navier-Stokes2D':
        pde = NavierStokes2D()
    elif pde_name == 'Diffusion-Sorption1D':
        pde = DiffusionSorption1D()
    else:
        raise NotImplementedError(f'PDE {pde_name} not implemented')

    if args.batch_size != -1:
        cfg['dataset']['batch_size'] = args.batch_size

    # wandb serializes config values to JSON, so tuples come back as lists;
    # convert them back before rebuilding the config objects.
    cfg = convert_dict_to_tuples(cfg)

    training_cfg = cfg['training']
    training_cfg['optimizer'] = OptimizerConfig(**training_cfg['optimizer'])
    training_cfg = TrainingConfig(**training_cfg)
    dataset_cfg = DatasetConfig(**cfg['dataset'])
    model_cfg = dict_to_model_config(cfg['model'])

    cb = [visualization_callback, error_callback]
    end_cb = []
    logger = MetricsLogger(project=args.project,
                           pde=pde,
                           keys=[],
                           model_args=model_cfg,
                           train_args=training_cfg,
                           data_config=dataset_cfg,
                           wandb_init=False)
    trainer = Trainer(model_cfg,
                      dataset_cfg,
                      training_cfg,
                      pde=pde,
                      logger=logger,
                      callbacks=cb,
                      finish_callbacks=end_cb)

    # Fetch the latest model checkpoint artifact for this run and restore
    # the trainer state from it before validating.
    ckpt_artifact = wandb.run.use_artifact(
        f'{wandb.run.id}_model.ckpt:latest', type='model')
    ckpt_file = ckpt_artifact.file()
    trainer = restore_trainer_state(trainer, ckpt_file)
    trainer._validiation()


if __name__ == '__main__':
    main()
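
# Example invocation (run id and project name below are placeholders, and
# --batch_size is optional):
#   python test.py <wandb_run_id> <wandb_project> --batch_size 32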