From 00bd3e42b93512df67e881eb38236bc820053990 Mon Sep 17 00:00:00 2001 From: Daisuke Tanaka Date: Thu, 16 May 2019 18:59:35 +0900 Subject: [PATCH] fix calling no support snapshot object --- chainerui/extensions/commands_extension.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/chainerui/extensions/commands_extension.py b/chainerui/extensions/commands_extension.py index 0cd5fcd5..4b66cf34 100644 --- a/chainerui/extensions/commands_extension.py +++ b/chainerui/extensions/commands_extension.py @@ -1,18 +1,25 @@ +import os +import shutil + from chainer.serializers import npz from chainer.training import extension -from chainer.training.extensions._snapshot import _snapshot_object from chainer.training import trigger as trigger_module from chainer.training.triggers import IntervalTrigger import six from chainerui.utils.command_item import CommandItem from chainerui.utils.commands_state import CommandsState +from chainerui.utils.tempdir import tempdir def take_snapshot(trainer, body): - filename = 'snapshot_iter_{.updater.iteration}' - savefun = npz.save_npz - _snapshot_object(trainer, trainer, filename.format(trainer), savefun) + filename = 'snapshot_iter_{.updater.iteration}'.format(trainer) + out_path = trainer.out + # same with SimpleWriter, supported from Chainer v6 + with tempdir(prefix='snapshot', dir=out_path) as tempd: + path = os.path.join(tempd, filename) + npz.save_npz(path, trainer) + shutil.move(path, os.path.join(out_path, filename)) def adjust_hyperparams(trainer, body):