Skip to content

Commit

Permalink
Complete prune load api (PaddlePaddle#430)
Browse files Browse the repository at this point in the history
* add a new argument for easily loading a pruned program

* add an argument to make loading a pruned program easier

* change save function arg

* fix test

* fix format
  • Loading branch information
yukavio authored Aug 31, 2020
1 parent 81db340 commit a412b6f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 11 deletions.
19 changes: 10 additions & 9 deletions paddleslim/prune/prune_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

_logger = get_logger(__name__, level=logging.INFO)

_PARAMS_FILE = "__params__"
_SHAPES_FILE = "__shapes__"


Expand All @@ -30,8 +29,8 @@ def save_model(exe, graph, dirname):
executor=exe,
dirname=dirname,
main_program=graph.program,
filename=_PARAMS_FILE)
weights_file = os.path.join(dirname, _PARAMS_FILE)
filename=None)
weights_file = dirname
_logger.info("Save model weights into {}".format(weights_file))
shapes = {}
for var in fluid.io.get_program_persistable_vars(graph.program):
Expand All @@ -57,17 +56,19 @@ def load_model(exe, graph, dirname):
_logger.info("Load shapes of weights from {}".format(SHAPES_FILE))
with open(SHAPES_FILE, "r") as f:
shapes = json.load(f)
for param, shape in shapes.items():
graph.var(param).set_shape(shape)
for param_name, shape in shapes.items():
param = graph.var(param_name)
if param is not None:
param.set_shape(shape)
else:
_logger.info('{} is not loaded'.format(param_name))

_logger.info("Load shapes of weights from {}".format(SHAPES_FILE))

fluid.io.load_persistables(
executor=exe,
dirname=dirname,
main_program=graph.program,
filename=_PARAMS_FILE)
filename=None)
graph.update_groups_of_conv()
graph.infer_shape()
_logger.info("Load weights from {}".format(
os.path.join(dirname, _PARAMS_FILE)))
_logger.info("Load weights from {}".format(dirname))
22 changes: 20 additions & 2 deletions tests/test_pruned_model_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ def test_prune(self):
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
feature = fluid.layers.reshape(conv6, [-1, 128, 16])
predict = fluid.layers.fc(input=feature, size=10, act='softmax')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
print(label.shape)
print(predict.shape)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
adam_optimizer = fluid.optimizer.AdamOptimizer(0.01)
adam_optimizer.minimize(avg_cost)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
Expand All @@ -55,9 +64,11 @@ def test_prune(self):
param_shape_backup=None)

x = numpy.random.random(size=(10, 3, 16, 16)).astype('float32')
label = numpy.random.random(size=(10, 1)).astype('int64')
loss_data, = exe.run(train_program,
feed={"image": x},
fetch_list=[conv6.name])
feed={"image": x,
"label": label},
fetch_list=[cost.name])

save_model(exe, main_program, 'model_file')
pruned_program = fluid.Program()
Expand All @@ -72,8 +83,10 @@ def test_prune(self):
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
pruned_test_program = pruned_program.clone(for_test=True)
exe.run(pruned_startup_program)
load_model(exe, pruned_program, 'model_file')
load_model(exe, pruned_test_program, 'model_file')
shapes = {
"conv1_weights": (4, 3, 3, 3),
"conv2_weights": (4, 4, 3, 3),
Expand All @@ -88,6 +101,11 @@ def test_prune(self):
print("param: {}; param shape: {}".format(param.name,
param.shape))
self.assertTrue(param.shape == shapes[param.name])
for param in pruned_test_program.global_block().all_parameters():
if "weights" in param.name:
print("param: {}; param shape: {}".format(param.name,
param.shape))
self.assertTrue(param.shape == shapes[param.name])


if __name__ == '__main__':
Expand Down

0 comments on commit a412b6f

Please sign in to comment.