Simplify the training process by removing the batch grow parameter and the epoch concept #30
krikit committed Feb 10, 2019
1 parent ad99b4f commit d995b52
Showing 3 changed files with 131 additions and 139 deletions.
219 changes: 107 additions & 112 deletions src/main/python/khaiii/train/trainer.py
@@ -19,16 +19,15 @@
import os
import pathlib
import pprint
import sys
from typing import List, Tuple
from typing import Iterator, List, Tuple

from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from khaiii.train.dataset import PosDataset
from khaiii.train.dataset import PosDataset, PosSentTensor
from khaiii.train.evaluator import Evaluator
from khaiii.train.models import CnnModel
from khaiii.resource.resource import Resource
@@ -48,18 +47,17 @@ def __init__(self, cfg: Namespace):
"""
self.cfg = cfg
setattr(cfg, 'model_id', self.model_id(cfg))
setattr(cfg, 'out_dir', '{}/{}'.format(cfg.log_dir, cfg.model_id))
setattr(cfg, 'out_dir', '{}/{}'.format(cfg.logdir, cfg.model_id))
setattr(cfg, 'context_len', 2 * cfg.window + 1)
self.rsc = Resource(cfg)
self.model = CnnModel(cfg, self.rsc)
self.optimizer = torch.optim.Adam(self.model.parameters(), cfg.learning_rate)
self.criterion = nn.CrossEntropyLoss()
self.evaler = Evaluator()
self._load_dataset()
if 'epoch' not in cfg.__dict__:
setattr(cfg, 'epoch', 0)
setattr(cfg, 'iteration', 0)
setattr(cfg, 'iter_best', 0)
if 'step' not in cfg.__dict__:
setattr(cfg, 'step', 0)
setattr(cfg, 'best_step', 0)
self.log_file = None # tab separated log file
self.sum_wrt = None # tensorboard summary writer
self.loss_trains = []
@@ -68,7 +66,6 @@ def __init__(self, cfg: Namespace):
self.acc_words = []
self.f_scores = []
self.learning_rates = []
self.batch_sizes = []

@classmethod
def model_id(cls, cfg: Namespace) -> str:
@@ -88,8 +85,7 @@ def model_id(cls, cfg: Namespace) -> str:
'lr{}'.format(cfg.learning_rate),
'lrd{}'.format(cfg.lr_decay),
'bs{}'.format(cfg.batch_size),
'ci{}'.format(cfg.check_iter),
'bg{}'.format(cfg.batch_grow),
'cs{}'.format(cfg.check_step),
]
return '.'.join(model_cfgs)
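
The model ID above packs a run's hyper-parameters into a dotted directory name, so each configuration logs to its own folder. A minimal sketch of the convention, limited to the entries visible in this hunk (the real method prepends more fields):

    from argparse import Namespace

    def model_id(cfg: Namespace) -> str:
        # tail entries only; 'ci'/'bg' are dropped in favor of 'cs' (check_step)
        parts = ['lr{}'.format(cfg.learning_rate),
                 'lrd{}'.format(cfg.lr_decay),
                 'bs{}'.format(cfg.batch_size),
                 'cs{}'.format(cfg.check_step)]
        return '.'.join(parts)

    cfg = Namespace(learning_rate=0.001, lr_decay=0.9, batch_size=500, check_step=10000)
    print(model_id(cfg))  # lr0.001.lrd0.9.bs500.cs10000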

@@ -133,78 +129,91 @@ def _elapsed(cls, td_obj: timedelta) -> str:
seconds -= minutes * 60
return '{}:{:02d}:{:02d}'.format(hours, minutes, seconds)
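
Only the tail of _elapsed is visible here; a self-contained reconstruction of the H:MM:SS formatting it performs (the leading lines are an assumption, not the project's exact code):

    from datetime import timedelta

    def elapsed(td_obj: timedelta) -> str:
        seconds = int(td_obj.total_seconds())   # assumed; hidden by the hunk
        hours, seconds = divmod(seconds, 3600)
        minutes, seconds = divmod(seconds, 60)
        return '{}:{:02d}:{:02d}'.format(hours, minutes, seconds)

    print(elapsed(timedelta(hours=1, minutes=2, seconds=3)))  # 1:02:03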

def _check_continue(self):
def _restore_prev_train(self):
"""
Restore the previous state so that a training run stopped earlier can resume from where it left off.
"""
out_path = pathlib.Path(self.cfg.out_dir)
if not out_path.is_dir():
cfg_path = pathlib.Path('{}/config.json'.format(self.cfg.out_dir))
if not out_path.is_dir() or not cfg_path.is_file():
return
logging.info('==== continue learning: %s ====', self.cfg.model_id)
cfg = json.load(open('{}/config.json'.format(self.cfg.out_dir), 'r', encoding='UTF-8'))
logging.info('==== continue training: %s ====', self.cfg.model_id)
cfg = json.load(open(cfg_path, 'r', encoding='UTF-8'))
for key, val in cfg.items():
setattr(self.cfg, key, val)
self._revert(False)
self._revert_to_best(False)

f_score_best = 0.0
best_idx = -1
for idx, line in enumerate(open('{}/log.tsv'.format(self.cfg.out_dir))):
line = line.rstrip('\r\n')
if not line:
continue
(epoch, iteration, loss_train, loss_dev, acc_char, acc_word, f_score, learning_rate,
batch_size) = line.split('\t')
self.cfg.epoch = int(epoch) + 1
self.cfg.iteration = self.cfg.iter_best = int(iteration) * self.cfg.check_iter
(step, loss_train, loss_dev, acc_char, acc_word, f_score, learning_rate) \
= line.split('\t')
self.cfg.step = self.cfg.best_step = int(step) * self.cfg.check_step
self.loss_trains.append(float(loss_train))
self.loss_devs.append(float(loss_dev))
self.acc_chars.append(float(acc_char))
self.acc_words.append(float(acc_word))
self.f_scores.append(float(f_score))
self.learning_rates.append(float(learning_rate))
self.batch_sizes.append(int(batch_size))
if float(f_score) > f_score_best:
f_score_best = float(f_score)
best_idx = idx
logging.info('---- [%d|%d] loss(train/dev): %f / %f, acc(char/word): %f / %f, ' \
'f-score: %f, lr: %f, bs: %d ----', self.cfg.epoch,
self.cfg.iteration // self.cfg.check_iter, self.loss_trains[best_idx],
self.loss_devs[best_idx], self.acc_chars[best_idx], self.acc_words[best_idx],
self.f_scores[best_idx], self.learning_rates[-1], self.batch_sizes[-1])
logging.info('---- [%d] los(trn/dev): %.4f / %.4f, acc(chr/wrd): %.4f / %.4f, ' \
'f-score: %.4f, lr: %.8f ----', self.cfg.step // self.cfg.check_step,
self.loss_trains[best_idx], self.loss_devs[best_idx], self.acc_chars[best_idx],
self.acc_words[best_idx], self.f_scores[best_idx], self.learning_rates[-1])
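
After this change each log.tsv row carries seven tab-separated fields: check step, train loss, dev loss, character accuracy, word accuracy, f-score, and learning rate. A standalone sketch of a reader for that format (the dict keys are illustrative, not khaiii API):

    from typing import Dict, List

    def read_log(path: str) -> List[Dict[str, float]]:
        """Parse the 7-column log.tsv written at every check."""
        records = []
        for line in open(path, encoding='UTF-8'):
            line = line.rstrip('\r\n')
            if not line:
                continue
            step, loss_trn, loss_dev, acc_chr, acc_wrd, f_score, lr = line.split('\t')
            records.append({'check_id': int(step),
                            'loss_train': float(loss_trn), 'loss_dev': float(loss_dev),
                            'acc_char': float(acc_chr), 'acc_word': float(acc_wrd),
                            'f_score': float(f_score), 'learning_rate': float(lr)})
        return records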

@classmethod
def _inf_data_iterator(cls, dataset: PosDataset) -> Iterator[PosSentTensor]:
"""
a generator that yields sentences by iterating over the dataset endlessly
Args:
dataset: the dataset
Yields:
PosSentTensor objects
"""
for _ in range(1000000):
for sent in dataset:
yield sent
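
This generator removes the epoch boundary: sentences stream end-to-end, and the dataset restarts whenever it is exhausted. An equivalent idiom (a sketch, not the project's code) uses itertools; chain.from_iterable(repeat(...)) re-invokes __iter__ on every pass, so any per-pass shuffling in the dataset still applies, which itertools.cycle's caching would defeat:

    import itertools
    from typing import Iterable, Iterator, TypeVar

    T = TypeVar('T')

    def inf_iter(dataset: Iterable[T]) -> Iterator[T]:
        # re-iterates the dataset object each cycle instead of caching items
        return itertools.chain.from_iterable(itertools.repeat(dataset))

    print(list(itertools.islice(inf_iter([1, 2, 3]), 7)))  # [1, 2, 3, 1, 2, 3, 1]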

def train(self):
"""
train model with dataset
"""
self._check_continue()
self._restore_prev_train()
logging.info('config: %s', pprint.pformat(self.cfg.__dict__))

train_begin = datetime.now()
logging.info('{{{{ training begin: %s {{{{', self._dt_str(train_begin))
if torch.cuda.is_available():
self.model.cuda()
pathlib.Path(self.cfg.out_dir).mkdir(parents=True, exist_ok=True)
self.log_file = open('{}/log.tsv'.format(self.cfg.out_dir), 'at')
self.sum_wrt = SummaryWriter(self.cfg.out_dir)
check_start = (1 if self.cfg.step == 0 else (self.cfg.step // self.cfg.check_step + 1))
patience = self.cfg.patience
for _ in range(10000):
has_best = self.train_epoch()
if has_best:
train_iter = self._inf_data_iterator(self.dataset_train)
for check_id in range(check_start, 1000000):
is_best = self._train_and_check(check_id, train_iter)
if is_best:
patience = self.cfg.patience
else:
if patience > 0:
self._revert(True)
patience -= 1
logging.info('==== revert to iter: %d, f-score: %f, patience: %d ====',
self.cfg.iter_best, max(self.f_scores), patience)
else:
break
continue
if patience <= 0:
break
self._revert_to_best(True)
patience -= 1
tqdm.write('==== revert to check: {}, f-score: {:.4f}, patience: {} ===='.format( \
self.cfg.best_step // self.cfg.check_step, max(self.f_scores), patience))

train_end = datetime.now()
train_elapsed = self._elapsed(train_end - train_begin)
logging.info('}}}} training end: %s, elapsed: %s, epoch: %s, iter: %dk }}}}',
self._dt_str(train_end), train_elapsed, self.cfg.epoch,
self.cfg.iteration // 1000)
logging.info('}}}} training end: %s, elapsed: %s, step: %dk }}}}',
self._dt_str(train_end), train_elapsed, self.cfg.step // 1000)
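
The loop above is early stopping with a patience budget: a new best f-score refills the budget; otherwise the trainer reverts to the best checkpoint, decays the learning rate, and spends one unit of patience. A condensed, runnable sketch with stand-in stubs (train_one_check and revert_to_best are not the project's API):

    import random

    def train_one_check(check_id: int) -> bool:
        """Stand-in for one check_step-sized chunk of training plus evaluation."""
        return random.random() < 0.3       # pretend ~30% of checks hit a new best

    def revert_to_best(decay_lr: bool) -> None:
        """Stand-in for reloading the best model/optimizer state."""

    def train_loop(patience_limit: int) -> None:
        patience = patience_limit
        for check_id in range(1, 1000000):
            if train_one_check(check_id):
                patience = patience_limit  # a new best refills the budget
                continue
            if patience <= 0:              # budget exhausted: stop training
                break
            revert_to_best(decay_lr=True)  # fall back to the best checkpoint
            patience -= 1

    train_loop(patience_limit=10)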

def _revert(self, is_decay_lr: bool):
def _revert_to_best(self, is_decay_lr: bool):
"""
Revert to the previous best model.
Args:
@@ -215,79 +224,67 @@ def _revert(self, is_decay_lr: bool):
self.cfg.learning_rate *= self.cfg.lr_decay
self._load_optim('{}/optim.state'.format(self.cfg.out_dir), self.cfg.learning_rate)
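
Reloading the optimizer state restores the pre-decay learning rate along with it, so the decayed rate has to be written back into every param group. In plain PyTorch that amounts to the following sketch (path handling is illustrative):

    import torch

    def load_optim(optimizer: torch.optim.Optimizer, path: str, lr: float) -> None:
        optimizer.load_state_dict(torch.load(path))
        for group in optimizer.param_groups:
            group['lr'] = lr   # overwrite the restored, pre-decay rate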

def train_epoch(self) -> bool:
def _train_and_check(self, check_id: int, train_iter: Iterator[PosSentTensor]) -> bool:
"""
train for one epoch; batches are counted in characters
run cfg.check_step training steps, then run an evaluation
Args:
check_id: check ID
train_iter: iterator over the training data
Returns:
whether this iteration recorded the best f-score
whether this step recorded the best f-score
"""
epoch_begin = datetime.now()
logging.info('{{{{ epoch: %d, begin: %s {{{{', self.cfg.epoch, self._dt_str(epoch_begin))
iter_lap1 = datetime.now()
has_best = False
batches = []
start_step = self.cfg.step
loss_batch = torch.tensor(0.0) # pylint: disable=not-callable
batch_size = 0
loss_trains = []
batch_size = self.cfg.batch_size
if self.cfg.batch_grow > 0:
batch_size = self.cfg.batch_size + self.cfg.iteration // self.cfg.batch_grow
for train_sent in tqdm(self.dataset_train, 'EPOCH [{}]'.format(self.cfg.epoch),
len(self.dataset_train), mininterval=1, ncols=100):
train_sents = tqdm(train_iter, '[{}]'.format(check_id), mininterval=1, ncols=100)
for train_sent in train_sents:
train_labels, train_contexts = train_sent.to_tensor(self.cfg, self.rsc, True)
if torch.cuda.is_available():
train_labels = train_labels.cuda()
train_contexts = train_contexts.cuda()
loss_batch = loss_batch.cuda()

self.optimizer.zero_grad()
self.model.train()
train_outputs = self.model(train_contexts)
batches.append((train_labels, train_outputs))
if sum([batch[0].size(0) for batch in batches]) < batch_size:
train_outputs.requires_grad_()
loss_train = self.criterion(train_outputs, train_labels)
loss_train.backward()
loss_trains.append(loss_train.item())
loss_batch += loss_train
batch_size += len(train_labels)
if batch_size < self.cfg.batch_size:
continue
batch_label = torch.cat([x[0] for x in batches], 0) # pylint: disable=no-member
batch_output = torch.cat([x[1] for x in batches], 0) # pylint: disable=no-member
batches = []

batch_output.requires_grad_()
loss_train = self.criterion(batch_output, batch_label)
loss_trains.append(loss_train.item())
loss_train.backward()
self.optimizer.step()
self.cfg.iteration += 1
self.optimizer.zero_grad()
self.sum_wrt.add_scalar('loss-batch', loss_batch.item(), self.cfg.step)
self.cfg.step += 1
loss_batch = torch.tensor(0.0) # pylint: disable=not-callable
batch_size = 0

if self.cfg.iteration % self.cfg.check_iter == 0:
avg_loss_dev, acc_char, acc_word, f_score = self.evaluate()
iter_lap2 = datetime.now()
iter_elapsed = self._elapsed(iter_lap2 - iter_lap1)
iter_lap1 = iter_lap2
has_best |= self._check_iter(iter_elapsed, loss_trains, avg_loss_dev, acc_char,
acc_word, f_score, batch_size)
if self.cfg.batch_grow > 0 and self.cfg.iteration % self.cfg.batch_grow == 0:
batch_size = self.cfg.batch_size + self.cfg.iteration // self.cfg.batch_grow
if (self.cfg.step - start_step) >= self.cfg.check_step:
train_sents.close()
break

print(file=sys.stderr)
sys.stderr.flush()
epoch_end = datetime.now()
epoch_elapsed = self._elapsed(epoch_end - epoch_begin)
logging.info('}}}} epoch: %d, end: %s, elapsed: %s, %s best }}}}', self.cfg.epoch,
self._dt_str(epoch_begin), epoch_elapsed, 'hit' if has_best else 'did not hit')
self.cfg.epoch += 1
return has_best
avg_loss_dev, acc_char, acc_word, f_score = self.evaluate()
return self._check(check_id, loss_trains, avg_loss_dev, acc_char, acc_word, f_score)
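
The loop above is gradient accumulation keyed on character count: each sentence's loss is backpropagated immediately, and the optimizer steps once at least cfg.batch_size characters (one label per character) have accumulated. A generic sketch of the pattern, with illustrative names:

    import torch.nn as nn
    from torch.optim import Optimizer

    def accumulate_step(model: nn.Module, optimizer: Optimizer, criterion: nn.Module,
                        sents, batch_chars: int) -> float:
        """Run one optimizer step over at least batch_chars characters."""
        optimizer.zero_grad()
        loss_sum, seen = 0.0, 0
        for labels, contexts in sents:
            loss = criterion(model(contexts), labels)
            loss.backward()         # gradients accumulate across sentences
            loss_sum += loss.item()
            seen += len(labels)     # one label per character
            if seen >= batch_chars:
                break
        optimizer.step()
        return loss_sum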

def _check_iter(self, elapsed: str, loss_trains: List[float], avg_loss_dev: float,
acc_char: float, acc_word: float, f_score: float, batch_size: int) -> bool:
def _check(self, check_id: int, loss_trains: List[float], avg_loss_dev: float, acc_char: float,
acc_word: float, f_score: float) -> bool:
"""
the check performed every cfg.check_iter iterations
the check performed every cfg.check_step steps
Args:
elapsed: elapsed time
loss_trains: list of per-batch losses on the train corpus
check_id: check ID
loss_trains: list of per-batch losses on the train corpus
avg_loss_dev: average per-sentence loss on the dev corpus
acc_char: character (syllable) accuracy
acc_word: word (eojeol) accuracy
f_score: f-score
batch_size: batch size
Returns:
whether the current iteration achieved the best performance
whether the current step achieved the best performance
"""
assert check_id == self.cfg.step // self.cfg.check_step
avg_loss_train = sum(loss_trains) / len(loss_trains)
loss_trains.clear()
self.loss_trains.append(avg_loss_train)
@@ -296,36 +293,34 @@ def _check_iter(self, elapsed: str, loss_trains: List[float], avg_loss_dev: float,
self.acc_words.append(acc_word)
self.f_scores.append(f_score)
self.learning_rates.append(self.cfg.learning_rate)
self.batch_sizes.append(batch_size)
check_num = self.cfg.iteration // self.cfg.check_iter
logging.info('---- [%d|%d] elapsed: %s, loss(train/dev): %f / %f, acc(char/word): %f / %f' \
', f-score: %f (max: %f), lr: %f, bs: %d ----', self.cfg.epoch, check_num,
elapsed, avg_loss_train, avg_loss_dev, acc_char, acc_word, f_score,
max(self.f_scores), self.cfg.learning_rate, batch_size)
print('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format( \
self.cfg.epoch, check_num, avg_loss_train, avg_loss_dev, acc_char, acc_word, f_score,
self.cfg.learning_rate, batch_size), file=self.log_file)
is_best = self._is_best()
is_best_str = 'BEST' if is_best else '< {:.4f}'.format(max(self.f_scores))
tqdm.write(' [Los trn] [Los dev] [Acc chr] [Acc wrd] [F-score] [LR]')
tqdm.write(' {:9.4f} {:9.4f} {:9.4f} {:9.4f} {:9.4f} {:8} {:.8f}'.format( \
avg_loss_train, avg_loss_dev, acc_char, acc_word, f_score, is_best_str,
self.cfg.learning_rate))
print('{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(check_id, avg_loss_train, avg_loss_dev,
acc_char, acc_word, f_score,
self.cfg.learning_rate), file=self.log_file)
self.log_file.flush()
self.sum_wrt.add_scalar('loss-train', avg_loss_train, check_num)
self.sum_wrt.add_scalar('loss-dev', avg_loss_dev, check_num)
self.sum_wrt.add_scalar('acc-char', acc_char, check_num)
self.sum_wrt.add_scalar('acc-word', acc_word, check_num)
self.sum_wrt.add_scalar('f-score', f_score, check_num)
self.sum_wrt.add_scalar('learning-rate', self.cfg.learning_rate, check_num)
self.sum_wrt.add_scalar('batch-size', batch_size, check_num)
return self._check_best()
self.sum_wrt.add_scalar('loss-train', avg_loss_train, check_id)
self.sum_wrt.add_scalar('loss-dev', avg_loss_dev, check_id)
self.sum_wrt.add_scalar('acc-char', acc_char, check_id)
self.sum_wrt.add_scalar('acc-word', acc_word, check_id)
self.sum_wrt.add_scalar('f-score', f_score, check_id)
self.sum_wrt.add_scalar('learning-rate', self.cfg.learning_rate, check_id)
return is_best
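
Each check writes one TSV row and mirrors the same scalars to tensorboard through tensorboardX. Minimal usage matching the add_scalar calls above (the run directory is illustrative):

    from tensorboardX import SummaryWriter

    writer = SummaryWriter('logdir/demo-run')   # plays the role of cfg.out_dir
    for check_id in range(1, 4):
        writer.add_scalar('f-score', 0.9 + 0.01 * check_id, check_id)
    writer.close()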

def _check_best(self) -> bool:
def _is_best(self) -> bool:
"""
check whether this iteration achieved the best performance so far, and if so, save the current state
check whether this step achieved the best performance so far, and if so, save the current state
Returns:
best 여부
"""
if len(self.f_scores) > 1 and max(self.f_scores[:-1]) >= self.f_scores[-1]:
return False
# this iteration hits new max value
logging.info('==== best model: %f ====', self.f_scores[-1])
self.cfg.iter_best = self.cfg.iteration
# this step hits new max value
self.cfg.best_step = self.cfg.step
self.model.save('{}/model.state'.format(self.cfg.out_dir))
self._save_optim('{}/optim.state'.format(self.cfg.out_dir))
with open('{}/config.json'.format(self.cfg.out_dir), 'w', encoding='UTF-8') as fout: