Skip to content
This repository has been archived by the owner on Feb 6, 2023. It is now read-only.

Commit

Permalink
Merge pull request #280 from disktnk/fix/snapshot-object
Browse files Browse the repository at this point in the history
Remove _snapshot_object calling
  • Loading branch information
ofk authored May 17, 2019
2 parents 53b5503 + 00bd3e4 commit eb077e0
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions chainerui/extensions/commands_extension.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
import os
import shutil

from chainer.serializers import npz
from chainer.training import extension
from chainer.training.extensions._snapshot import _snapshot_object
from chainer.training import trigger as trigger_module
from chainer.training.triggers import IntervalTrigger
import six

from chainerui.utils.command_item import CommandItem
from chainerui.utils.commands_state import CommandsState
from chainerui.utils.tempdir import tempdir


def take_snapshot(trainer, body):
filename = 'snapshot_iter_{.updater.iteration}'
savefun = npz.save_npz
_snapshot_object(trainer, trainer, filename.format(trainer), savefun)
filename = 'snapshot_iter_{.updater.iteration}'.format(trainer)
out_path = trainer.out
# same with SimpleWriter, supported from Chainer v6
with tempdir(prefix='snapshot', dir=out_path) as tempd:
path = os.path.join(tempd, filename)
npz.save_npz(path, trainer)
shutil.move(path, os.path.join(out_path, filename))


def adjust_hyperparams(trainer, body):
Expand Down

0 comments on commit eb077e0

Please sign in to comment.