feature: save config file (#26)
* feature: save config file

* feature: launch_runner

* fix:

* feature: init logger in launcher
cnstark authored Aug 3, 2021
1 parent 6a1c93a commit a0e0734
Showing 3 changed files with 42 additions and 11 deletions.
config.py: 15 additions & 0 deletions
@@ -72,6 +72,8 @@
 except the key is in `TRAINING_INDEPENDENT_KEYS` or `CFG._TRAINING_INDEPENDENT`
 """
 
+import os
+import shutil
 import types
 import copy
 import hashlib
@@ -209,6 +211,19 @@ def save_config(cfg: dict, file_path: str):
         f.write(content)
 
 
+def copy_config_file(cfg_file_path: str, save_dir: str):
+    """Copy config file to `save_dir`
+
+    Args:
+        cfg_file_path (str): config file path
+        save_dir (str): save directory
+    """
+
+    if os.path.isfile(cfg_file_path) and os.path.isdir(save_dir):
+        cfg_file_name = os.path.basename(cfg_file_path)
+        shutil.copyfile(cfg_file_path, os.path.join(save_dir, cfg_file_name))
+
+
 def import_config(path: str, verbose: bool = True) -> dict:
     """Import config by path
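For context, a minimal usage sketch of the new helper alongside the existing `import_config` and `config_md5`; the `easytorch.config` import path, the config filename, and the `experiments` directory are illustrative assumptions, not part of this commit:

import os

from easytorch.config import import_config, config_md5, copy_config_file  # import path assumed

cfg_path = 'configs/example_cfg.py'       # hypothetical config file
cfg = import_config(cfg_path)
save_dir = os.path.join('experiments', config_md5(cfg))
os.makedirs(save_dir, exist_ok=True)
copy_config_file(cfg_path, save_dir)      # copies example_cfg.py into save_dir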
core/launcher.py: 26 additions & 1 deletion
@@ -1,3 +1,4 @@
+import os
 import random
 from typing import Callable
 
@@ -6,7 +7,7 @@
 from torch.distributed import Backend
 from torch import multiprocessing as mp
 
-from ..config import import_config
+from ..config import import_config, config_md5, save_config, copy_config_file
 from ..utils import set_gpus, set_tf32_mode
 
 
@@ -31,6 +32,9 @@ def train(cfg: dict, use_gpu: bool, tf32_mode: bool):
     Runner = cfg['RUNNER']
     runner = Runner(cfg, use_gpu)
 
+    # init logger (after making ckpt save dir)
+    runner.init_logger(logger_name='easytorch-training', log_file_name='training_log')
+
     # train
     runner.train(cfg)
 
@@ -83,7 +87,10 @@ def launch_training(cfg: dict or str, gpus: str, tf32_mode: bool):
     """
 
     if isinstance(cfg, str):
+        cfg_path = cfg
         cfg = import_config(cfg)
+    else:
+        cfg_path = None
 
     use_gpu = cfg.get('USE_GPU', True)
     gpu_num = cfg.get('GPU_NUM', 0)
@@ -103,6 +110,17 @@
     if gpu_num != 0:
         raise RuntimeError('Easytorch is running in CPU mode, but cfg.GPU_NUM is not zero')
 
+    # convert ckpt save dir
+    cfg['TRAIN']['CKPT_SAVE_DIR'] = os.path.join(cfg['TRAIN']['CKPT_SAVE_DIR'], config_md5(cfg))
+
+    # save config
+    if not os.path.isdir(cfg['TRAIN']['CKPT_SAVE_DIR']):
+        os.makedirs(cfg['TRAIN']['CKPT_SAVE_DIR'])
+        if cfg_path is None:
+            save_config(cfg, os.path.join(cfg['TRAIN']['CKPT_SAVE_DIR'], 'param.txt'))
+        else:
+            copy_config_file(cfg_path, cfg['TRAIN']['CKPT_SAVE_DIR'])
+
     if gpu_num <= 1:
         train(cfg, use_gpu, tf32_mode)
     else:
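The net effect: the checkpoint directory is keyed by an MD5 hash of the config, so each distinct config gets its own subdirectory, and the config is recorded there once: as `param.txt` when launching from a dict, or as a copy of the original file when launching from a path. A rough sketch of the hashing idea follows; it is not the library's exact `config_md5`, which per the docstring in config.py also ignores keys in `TRAINING_INDEPENDENT_KEYS` / `CFG._TRAINING_INDEPENDENT`:

import hashlib

def config_md5_sketch(cfg: dict) -> str:
    # Serialize the config deterministically, then hash it; the real
    # config_md5 first drops training-independent keys.
    return hashlib.md5(repr(sorted(cfg.items())).encode('utf-8')).hexdigest()

# e.g. 'checkpoints/my_model' becomes 'checkpoints/my_model/1a2b3c...' (hash illustrative)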
@@ -143,6 +161,13 @@ def launch_runner(cfg: dict or str, fn: Callable, args: tuple = (), gpus: str =
     set_gpus(gpus)
     set_tf32_mode(tf32_mode)
 
+    # convert ckpt save dir
+    cfg['TRAIN']['CKPT_SAVE_DIR'] = os.path.join(cfg['TRAIN']['CKPT_SAVE_DIR'], config_md5(cfg))
+
+    # make ckpt save dir
+    if not os.path.isdir(cfg['TRAIN']['CKPT_SAVE_DIR']):
+        os.makedirs(cfg['TRAIN']['CKPT_SAVE_DIR'])
+
     Runner = cfg['RUNNER']
     runner = Runner(cfg, use_gpu)
 
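`launch_runner` prepares the same MD5-keyed directory but, unlike `launch_training`, only creates it without writing the config. A hedged usage sketch; the callback signature is an assumption, since how `fn` is invoked is not shown in this diff:

from easytorch import launch_runner  # import path assumed

def inference(cfg: dict, runner, *args):  # assumed callback shape
    # cfg['TRAIN']['CKPT_SAVE_DIR'] already exists by the time fn runs
    runner.init_logger(logger_name='easytorch-inference', log_file_name='inference_log')

launch_runner('configs/example_cfg.py', inference, args=(), gpus='0')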
core/runner.py: 1 addition & 10 deletions
@@ -14,7 +14,6 @@
 from .checkpoint import get_ckpt_dict, load_ckpt, save_ckpt, backup_last_ckpt, clear_ckpt
 from .data_loader import build_data_loader, build_data_loader_ddp
 from .optimizer_builder import build_optim, build_lr_scheduler
-from ..config import config_md5, save_config
 from ..utils import TimePredictor, get_logger, get_rank, is_master, master_only, setup_random_seed
 
 
@@ -32,7 +31,7 @@ def __init__(self, cfg: dict, use_gpu: bool = True):
         # param
         self.use_gpu = use_gpu
         self.model_name = cfg['MODEL']['NAME']
-        self.ckpt_save_dir = os.path.join(cfg['TRAIN']['CKPT_SAVE_DIR'], config_md5(cfg))
+        self.ckpt_save_dir = cfg['TRAIN']['CKPT_SAVE_DIR']
         self.logger.info('ckpt save dir: \'{}\''.format(self.ckpt_save_dir))
         self.ckpt_save_strategy = None
         self.num_epochs = None
@@ -359,14 +358,6 @@ def init_training(self, cfg: dict):
         self.start_epoch = 0
         self.ckpt_save_strategy = cfg['TRAIN'].get('CKPT_SAVE_STRATEGY')
 
-        # make ckpt_save_dir
-        if is_master() and not os.path.isdir(self.ckpt_save_dir):
-            os.makedirs(self.ckpt_save_dir)
-            save_config(cfg, os.path.join(self.ckpt_save_dir, 'param.txt'))
-
-        # init logger (after making ckpt save dir)
-        self.init_logger(logger_name='easytorch-training', log_file_name='training_log')
-
         # train data loader
         self.train_data_loader = self.build_train_data_loader(cfg)
         self.register_epoch_meter('train_time', 'train', '{:.2f} (s)', plt=False)
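With directory creation, config saving, and logger setup hoisted out of `init_training`, the runner no longer touches the filesystem during setup; the ordering is now driven from the launcher. Condensed below from this diff (directory first, since the training log file presumably lives inside it):

cfg['TRAIN']['CKPT_SAVE_DIR'] = os.path.join(cfg['TRAIN']['CKPT_SAVE_DIR'], config_md5(cfg))
os.makedirs(cfg['TRAIN']['CKPT_SAVE_DIR'], exist_ok=True)   # launcher side
runner = cfg['RUNNER'](cfg, use_gpu)                        # Runner now reads the final path
runner.init_logger(logger_name='easytorch-training', log_file_name='training_log')
runner.train(cfg)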
