Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix file like object save load #614

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions python/src/nnabla/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nnabla.logger import logger
import nnabla.utils.nnabla_pb2 as nnabla_pb2
from nnabla.utils.get_file_handle import get_file_handle_load
from nnabla.utils.get_file_handle import get_file_handle_save

# TODO temporary work around to suppress FutureWarning message.
import warnings
Expand Down Expand Up @@ -335,7 +336,7 @@ def set_parameter_from_proto(proto):
var.d = param


def load_parameters(path, proto=None, needs_proto=False, file_like_type=".nntxt"):
def load_parameters(path, proto=None, needs_proto=False, extension=".nntxt"):
"""Load parameters from a file with the specified format.

Args:
Expand All @@ -344,7 +345,7 @@ def load_parameters(path, proto=None, needs_proto=False, file_like_type=".nntxt"
if isinstance(path, str):
_, ext = os.path.splitext(path)
else:
ext = file_like_type
ext = extension

if ext == '.h5':
# TODO temporary work around to suppress FutureWarning message.
Expand Down Expand Up @@ -412,7 +413,7 @@ def _get_keys(name):
return proto


def save_parameters(path, params=None):
def save_parameters(path, params=None, extension=None):
"""Save all parameters into a file with the specified format.

Currently hdf5 and protobuf formats are supported.
Expand All @@ -421,14 +422,17 @@ def save_parameters(path, params=None):
path : path or file object
params (dict, optional): Parameters to be saved. Dictionary is of a parameter name (:obj:`str`) to :obj:`~nnabla.Variable`.
"""
_, ext = os.path.splitext(path)
if isinstance(path, str):
_, ext = os.path.splitext(path)
else:
ext = extension
params = get_parameters(grad_only=False) if params is None else params
if ext == '.h5':
# TODO temporary work around to suppress FutureWarning message.
import warnings
warnings.simplefilter('ignore', category=FutureWarning)
import h5py
with h5py.File(path, 'w') as hd:
with get_file_handle_save(path, ext) as hd:
for i, (k, v) in enumerate(iteritems(params)):
hd[k] = v.d
hd[k].attrs['need_grad'] = v.need_grad
Expand All @@ -443,7 +447,7 @@ def save_parameters(path, params=None):
parameter.data.extend(numpy.array(variable.d).flatten().tolist())
parameter.need_grad = variable.need_grad

with open(path, "wb") as f:
with get_file_handle_save(path, ext) as f:
f.write(proto.SerializeToString())
else:
logger.critical('Only supported hdf5 or protobuf.')
Expand Down
64 changes: 36 additions & 28 deletions python/src/nnabla/utils/get_file_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,51 +19,59 @@

@contextlib.contextmanager
def get_file_handle_load(path, ext):
if isinstance(path, str):
if ext in ['.nntxt', '.prototxt']:
if ext == '.nnp':
need_close = True
f = zipfile.ZipFile(path, 'r')
elif ext == '.h5':
need_close = True
f = h5py.File(path, 'r')
elif ext in ['.nntxt', '.prototxt']:
if hasattr(path, 'read'):
need_close = False
f = path
else:
need_close = True
f = open(path, 'r')
elif ext == '.protobuf':
need_close = True
f = open(path, 'rb')
elif ext == '.nnp':
need_close = True
f = zipfile.ZipFile(path, 'r')
elif ext == '.h5':
need_close = True
f = h5py.File(path, 'r')
else:
raise ValueError("Currently, ext == {} is not support".format(ext))
else:
elif ext == '.protobuf':
if hasattr(path, 'read'):
need_close = False
f = path
else:
need_close = True
f = open(path, 'rb')
else:
raise ValueError("Currently, ext == {} is not support".format(ext))

yield f
if need_close:
f.close()


@contextlib.contextmanager
def get_file_handle_save(path, ext):
if isinstance(path, str):
if ext in ['.nntxt', '.prototxt']:
if ext == '.nnp':
need_close = True
f = zipfile.ZipFile(path, 'w')
elif ext == '.h5':
need_close = True
f = h5py.File(path, 'w')
elif ext in ['.nntxt', '.prototxt']:
if hasattr(path, 'read'):
need_close = False
f = path
else:
need_close = True
f = open(path, 'w')
elif ext == '.protobuf':
elif ext == '.protobuf':
if hasattr(path, 'read'):
need_close = False
f = path
else:
need_close = True
f = open(path, 'wb')
elif ext == '.nnp':
need_close = True
f = zipfile.ZipFile(path, 'w')
elif ext == '.h5':
need_close = True
f = h5py.File(path, 'w')
else:
raise ValueError("Currently, ext == {} is not support".format(ext))
else:
if hasattr(path, 'write'):
need_close = False
f = path
raise ValueError("Currently, ext == {} is not support".format(ext))

yield f
if need_close:
f.close()
55 changes: 24 additions & 31 deletions python/src/nnabla/utils/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import re
import shutil
import tempfile
import zipfile

from nnabla.initializer import (
NormalInitializer, UniformInitializer, ConstantInitializer, RangeInitializer,
Expand All @@ -45,6 +44,7 @@
from nnabla.utils.network import Network
from nnabla.utils.progress import progress
from nnabla.utils.get_file_handle import get_file_handle_load

import nnabla as nn
import nnabla.function as F
import nnabla.solver as S
Expand Down Expand Up @@ -817,13 +817,13 @@ class Executor:
##########################################################################
# API
#
def load(filenames, prepare_data_iterator=True, batch_size=None, exclude_parameter=False, parameter_only=False, file_like_type=".nntxt"):
def load(filenames, prepare_data_iterator=True, batch_size=None, exclude_parameter=False, parameter_only=False, extension=".nntxt"):
'''load
Load network information from files.

Args:
filenames (list): file-like object or List of filenames.
file_like_type: if filenames is file-like object, file_like_type is one of ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".
extension: if filenames is file-like object, extension is one of ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".
Returns:
dict: Network information.
'''
Expand All @@ -842,7 +842,7 @@ class Info:
if isinstance(filename, str):
_, ext = os.path.splitext(filename)
else:
ext = file_like_type
ext = extension

# TODO: Here is some known problems.
# - Even when protobuf file includes network structure,
Expand All @@ -862,40 +862,33 @@ class Info:
raise
if len(proto.parameter) > 0:
if not exclude_parameter:
nn.load_parameters(filename)
nn.load_parameters(filename, extension=ext)
elif ext in ['.protobuf', '.h5']:
if not exclude_parameter:
nn.load_parameters(filename)
nn.load_parameters(filename, extension=ext)
else:
logger.info('Skip loading parameter.')

elif ext == '.nnp':
try:
tmpdir = tempfile.mkdtemp()
with get_file_handle_load(filename, ext) as nnp:
for name in nnp.namelist():
_, ext = os.path.splitext(name)
if name == 'nnp_version.txt':
nnp.extract(name, tmpdir)
with open(os.path.join(tmpdir, name), 'rt') as f:
pass # TODO currently do nothing with version.
elif ext in ['.nntxt', '.prototxt']:
nnp.extract(name, tmpdir)
if not parameter_only:
with open(os.path.join(tmpdir, name), 'rt') as f:
text_format.Merge(f.read(), proto)
if len(proto.parameter) > 0:
if not exclude_parameter:
nn.load_parameters(
os.path.join(tmpdir, name))
elif ext in ['.protobuf', '.h5']:
nnp.extract(name, tmpdir)
with get_file_handle_load(filename, ext) as nnp:
for name in nnp.namelist():
_, ext = os.path.splitext(name)
if name == 'nnp_version.txt':
pass # TODO currently do nothing with version.
elif ext in ['.nntxt', '.prototxt']:
if not parameter_only:
with nnp.open(name, 'r') as f:
text_format.Merge(f.read(), proto)
if len(proto.parameter) > 0:
if not exclude_parameter:
nn.load_parameters(os.path.join(tmpdir, name))
else:
logger.info('Skip loading parameter.')
finally:
shutil.rmtree(tmpdir)
with nnp.open(name, 'r') as f:
nn.load_parameters(f, extension=ext)
elif ext in ['.protobuf', '.h5']:
if not exclude_parameter:
with nnp.open(name, 'r') as f:
nn.load_parameters(f, extension=ext)
else:
logger.info('Skip loading parameter.')

default_context = None
if proto.HasField('global_config'):
Expand Down
44 changes: 22 additions & 22 deletions python/src/nnabla/utils/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

from collections import OrderedDict
import google.protobuf.text_format as text_format
import io
import numpy
import os
import re
import shutil
import tempfile
import zipfile

from nnabla import save_parameters
from nnabla.logger import logger
Expand Down Expand Up @@ -517,7 +517,7 @@ def create_proto(contents, include_params=False, variable_batch_size=True):
return proto


def save(filename, contents, include_params=False, variable_batch_size=True, file_like_type=".nntxt"):
def save(filename, contents, include_params=False, variable_batch_size=True, extension=".nnp"):
'''Save network definition, inference/training execution
configurations etc.

Expand All @@ -537,7 +537,7 @@ def save(filename, contents, include_params=False, variable_batch_size=True, fil
as batch size, and left as a placeholder
(more specifically ``-1``). The placeholder dimension will be
filled during/after loading.
file_like_type: if files is file-like object, file_like_type is one of ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".
extension: if files is file-like object, extension is one of ".nntxt", ".prototxt", ".protobuf", ".h5", ".nnp".

Example:
The following example creates a two inputs and two
Expand Down Expand Up @@ -633,7 +633,7 @@ def save(filename, contents, include_params=False, variable_batch_size=True, fil
if isinstance(filename, str):
_, ext = os.path.splitext(filename)
else:
ext = file_like_type
ext = extension
if ext == '.nntxt' or ext == '.prototxt':
logger.info("Saving {} as prototxt".format(filename))
proto = create_proto(contents, include_params, variable_batch_size)
Expand All @@ -646,21 +646,21 @@ def save(filename, contents, include_params=False, variable_batch_size=True, fil
file.write(proto.SerializeToString())
elif ext == '.nnp':
logger.info("Saving {} as nnp".format(filename))
try:
tmpdir = tempfile.mkdtemp()
save('{}/network.nntxt'.format(tmpdir),
contents, include_params=False, variable_batch_size=variable_batch_size)

with open('{}/nnp_version.txt'.format(tmpdir), 'w') as file:
file.write('{}\n'.format(nnp_version()))

save_parameters('{}/parameter.protobuf'.format(tmpdir))

with get_file_handle_save(filename, ext) as nnp:
nnp.write('{}/nnp_version.txt'.format(tmpdir),
'nnp_version.txt')
nnp.write('{}/network.nntxt'.format(tmpdir), 'network.nntxt')
nnp.write('{}/parameter.protobuf'.format(tmpdir),
'parameter.protobuf')
finally:
shutil.rmtree(tmpdir)

nntxt = io.StringIO()
save(nntxt, contents, include_params=False,
variable_batch_size=variable_batch_size, extension='.nntxt')
nntxt.seek(0)

version = io.StringIO()
version.write('{}\n'.format(nnp_version()))
version.seek(0)

param = io.BytesIO()
save_parameters(param, extension='.protobuf')
param.seek(0)

with get_file_handle_save(filename, ext) as nnp:
nnp.writestr('nnp_version.txt', version.read())
nnp.writestr('network.nntxt', nntxt.read())
nnp.writestr('parameter.protobuf', param.read())
9 changes: 4 additions & 5 deletions python/test/utils/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,7 @@ def test_save_load_with_file_object():
'network': 'net1',
'data': ['x0', 'x1'],
'output': ['y0', 'y1']}]}
import zipfile
with zipfile.ZipFile('tmp.nnp', 'w') as nnp:
nnabla.utils.save.save(nnp, contents, file_like_type='.nnp')
with zipfile.ZipFile('tmp.nnp', 'r') as nnp:
nnabla.utils.load.load(nnp, file_like_type='.nnp')
import io
nnpdata = io.BytesIO()
nnabla.utils.save.save(nnpdata, contents, extension='.nnp')
nnabla.utils.load.load(nnpdata, extension='.nnp')