Merge pull request #2 from Chriskuei/hotfix/save_load_trainer
Fix save and load trainer error && Support clip grad norm
caiyinqiong authored Jun 18, 2019
2 parents b7ceba8 + 0db1c84 commit 5225711
Showing 3 changed files with 84 additions and 27 deletions.
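
For reference, a minimal sketch of how the arguments this commit introduces (clip_norm, save_dir, checkpoint) might be passed when constructing a Trainer. Only these three parameters come from the diff below; the model, the dataloaders, and the trainloader/validloader/run() names are assumed from the surrounding MatchZoo-py API and act as placeholders, not as part of this change.

import torch
import matchzoo as mz

# Hypothetical helper: model and dataloaders are assumed to be prepared
# elsewhere with the usual MatchZoo-py preprocessing / DataLoader pipeline.
model, train_loader, valid_loader = build_model_and_loaders()

trainer = mz.trainers.Trainer(
    model=model,
    optimizer=torch.optim.Adam(model.parameters()),
    trainloader=train_loader,
    validloader=valid_loader,
    epochs=10,
    clip_norm=5.0,             # new: max gradient norm applied before each optimizer step
    save_dir='save/my_model',  # new: replaces the old `save_path` argument
    checkpoint=None,           # or a path to a previously saved trainer checkpoint
)
trainer.run()
trainer.save()                 # writes trainer-epoch_{n}-best_{metric}.pt under save_dir
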
28 changes: 28 additions & 0 deletions .gitignore
@@ -0,0 +1,28 @@
*.pyc
*.log
*.swp
*.bak
*.weights
*.trec
*.ranklist
*.DS_Store
.vscode
.coverage
.ipynb_checkpoints/
predict.*
build/
dist/
data/
save/
log/*
.ipynb_checkpoints/
matchzoo/log/*
matchzoo/querydecision/
log/*
.idea/
.pytest_cache/
MatchZoo.egg-info/
notebooks/wikiqa/.ipynb_checkpoints/*
.cache
.tmpdir
htmlcov/
2 changes: 1 addition & 1 deletion matchzoo/dataloader/dataloader.py
@@ -51,7 +51,7 @@ def __init__(
        sort: bool = True,
        callbacks: typing.List[Callback] = None
    ):
        """"Init."""
        """Init."""
        if stage not in ('train', 'dev', 'test'):
            raise ValueError(f"{stage} is not a valid stage type."
                             f"Must be one of `train`, `dev`, `test`.")
81 changes: 55 additions & 26 deletions matchzoo/trainers/trainer.py
@@ -39,14 +39,15 @@ class Trainer:
    :param validate_interval: Int. Interval of validation.
    :param scheduler: LR scheduler used to adjust the learning rate
        based on the number of epochs.
    :param clip_norm: Max norm of the gradients to be clipped.
    :param patience: Number of events to wait if no improvement and
        then stop the training.
    :param data_parallel: Bool. Whether to support data parallel.
    :param checkpoint: A checkpoint from which to continue training.
        If None, training starts from scratch. Defaults to None.
        Should be a file-like object (has to implement read, readline,
        tell, and seek), or a string containing a file name.
    :param save_path: Path to save trainer.
    :param save_dir: Directory to save trainer.
    :param verbose: 0, 1, or 2. Verbosity mode. 0 = silent,
        1 = verbose, 2 = one log line per epoch.
    """
@@ -62,10 +62,11 @@ def __init__(
        epochs: int = 10,
        validate_interval: typing.Optional[int] = None,
        scheduler: typing.Any = None,
        clip_norm: typing.Union[float, int] = None,
        patience: typing.Optional[int] = None,
        data_parallel: bool = True,
        checkpoint: typing.Union[str, typing.Any] = None,
        save_path: typing.Optional[str] = None,
        checkpoint: typing.Union[str, Path] = None,
        save_dir: typing.Union[str, Path] = None,
        verbose: int = 1,
        **kwargs
    ):
@@ -78,25 +80,18 @@ def __init__(
        self._task = self._model.params['task']
        self._optimizer = optimizer
        self._scheduler = scheduler
        self._clip_norm = clip_norm
        self._criterions = self._task.losses
        self._early_stopping = EarlyStopping(
            patience=patience,
            key=self._task.metrics[0]
        )

        self._start_epoch = start_epoch
        self._epochs = epochs
        self._iteration = 0

        self._verbose = verbose

        if checkpoint:
            self.restore(checkpoint)
        # TODO: Change Path
        if save_path:
            self._save_path = save_path
        else:
            self._save_path = './save'
        self._load_path(save_dir, checkpoint)

def _load_dataloader(
self,
@@ -148,13 +143,42 @@ def _load_model(
                'model should be a `BaseModel` instance.'
                f' But got {type(model)}.'
            )
        if device is None or not isinstance(device, torch.device):
            device = torch.device(
                'cuda:0' if torch.cuda.is_available() else 'cpu'
            )
        self.device = device
        self._model = model.to(self.device)
        if (("cuda" in str(self.device)) and (
        if (('cuda' in str(self.device)) and (
                torch.cuda.device_count() > 1) and (
                data_parallel is True)):
            self._model = torch.nn.DataParallel(self._model)

    def _load_path(
        self,
        save_dir: typing.Union[str, Path],
        checkpoint: typing.Union[str, Path]
    ):
        """
        Load save_dir and restore from checkpoint.

        :param checkpoint: A checkpoint from which to continue training.
            If None, training starts from scratch. Defaults to None.
            Should be a file-like object (has to implement read, readline,
            tell, and seek), or a string containing a file name.
        :param save_dir: Directory to save trainer.
        """
        if save_dir:
            self._save_dir = Path(save_dir)
        else:
            save_dir = Path('.').joinpath('save')
            if not save_dir.exists():
                save_dir.mkdir(parents=True)
            self._save_dir = save_dir
        # Restore from checkpoint
        if checkpoint:
            self.restore(checkpoint)

    def _backward(self, loss):
        """
        Computes the gradient of current `loss` graph leaves.
@@ -163,6 +187,10 @@ def _backward(self, loss):
"""
self._optimizer.zero_grad()
loss.backward()
if self._clip_norm:
nn.utils.clip_grad_norm_(
self._model.parameters(), self._clip_norm
)
self._optimizer.step()

def _run_scheduler(self):
@@ -320,19 +348,19 @@ def save(self):
"""
best_so_far = self._early_stopping.best_so_far
save_path = Path(self._save_path).joinpath(
f'model_epoch-{self._epoch}_best-{best_so_far:.4f}.pt'
path = self._save_dir.joinpath(
f'trainer-epoch_{self._epoch}-best_{best_so_far:.4f}.pt'
)
checkpoint = {
state = {
'epoch': self._epoch,
'best_so_far': best_so_far,
'model': self._model.state_dict(),
'optimizer': self._optimizer.state_dict(),
'early_stopping': self._optimizer.state_dict(),
'early_stopping': self._early_stopping.state_dict(),
}
torch.save(checkpoint, save_path)
if self._scheduler:
state['scheduler'] = self._scheduler.state_dict()
torch.save(state, path)

    # TODO: model save
    def restore(self, checkpoint: typing.Union[str, typing.Any]):
        """
        Restore trainer.
@@ -343,9 +371,10 @@ def restore(self, checkpoint: typing.Union[str, typing.Any]):
            tell, and seek), or a string containing a file name.
        """
        checkpoint = torch.load(checkpoint)
        self._model.load_state_dict(checkpoint['model'])
        self._optimizer.load_state_dict(checkpoint['optimizer'])
        self._start_epoch = checkpoint['epoch']
        self._early_stopping.load_state_dict(
            checkpoint['early_stopping'])
        state = torch.load(checkpoint)
        self._model.load_state_dict(state['model'])
        self._optimizer.load_state_dict(state['optimizer'])
        self._start_epoch = state['epoch'] + 1
        self._early_stopping.load_state_dict(state['early_stopping'])
        if self._scheduler:
            self._scheduler.load_state_dict(state['scheduler'])
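
Taken outside MatchZoo, the two behaviours this commit adds (clipping the gradient norm before each optimizer step, and saving/restoring a single trainer state dict so training resumes at the following epoch) can be sketched in plain PyTorch as below; every name here is illustrative, not MatchZoo API.

import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5)
clip_norm = 5.0
epoch = 3

# One training step with gradient clipping, mirroring Trainer._backward.
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
if clip_norm:
    nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
optimizer.step()

# Save everything needed to resume in one state dict, mirroring Trainer.save().
state = {
    'epoch': epoch,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
}
torch.save(state, 'trainer.pt')

# Restore and continue from the next epoch, mirroring the fixed Trainer.restore().
state = torch.load('trainer.pt')
model.load_state_dict(state['model'])
optimizer.load_state_dict(state['optimizer'])
scheduler.load_state_dict(state['scheduler'])
start_epoch = state['epoch'] + 1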
