From 6b0ecee66c29c38ebabc2d802827ea6b5b2c9e6c Mon Sep 17 00:00:00 2001 From: chicm-ms <38930155+chicm-ms@users.noreply.github.com> Date: Wed, 5 Feb 2020 14:47:09 +0800 Subject: [PATCH] fix compressor ut (#1997) --- src/sdk/pynni/tests/test_compressor.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/sdk/pynni/tests/test_compressor.py b/src/sdk/pynni/tests/test_compressor.py index 778f4341e9..168b949021 100644 --- a/src/sdk/pynni/tests/test_compressor.py +++ b/src/sdk/pynni/tests/test_compressor.py @@ -135,12 +135,12 @@ def test_torch_fpgm_pruner(self): model.conv2.weight.data = torch.tensor(w).float() layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2) - masks = pruner.calc_mask(layer, config_list[0]) + masks = pruner.calc_mask(layer, config_list[0], if_calculated=torch.tensor(0)) assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.])) pruner.update_epoch(1) model.conv2.weight.data = torch.tensor(w).float() - masks = pruner.calc_mask(layer, config_list[1]) + masks = pruner.calc_mask(layer, config_list[1], if_calculated=torch.tensor(0)) assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.])) @tf2 @@ -187,9 +187,9 @@ def test_torch_l1filter_pruner(self): model.conv1.weight.data = torch.tensor(w).float() model.conv2.weight.data = torch.tensor(w).float() layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1) - mask1 = pruner.calc_mask(layer1, config_list[0]) + mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0)) layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2) - mask2 = pruner.calc_mask(layer2, config_list[1]) + mask2 = pruner.calc_mask(layer2, config_list[1], if_calculated=torch.tensor(0)) assert all(torch.sum(mask1['weight'], (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.])) assert all(torch.sum(mask2['weight'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.])) @@ -215,9 +215,9 @@ def test_torch_slim_pruner(self): pruner = torch_compressor.SlimPruner(model, config_list) layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1) - mask1 = pruner.calc_mask(layer1, config_list[0]) + mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0)) layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2) - mask2 = pruner.calc_mask(layer2, config_list[0]) + mask2 = pruner.calc_mask(layer2, config_list[0], if_calculated=torch.tensor(0)) assert all(mask1['weight'].numpy() == np.array([0., 1., 1., 1., 1.])) assert all(mask2['weight'].numpy() == np.array([0., 1., 1., 1., 1.])) assert all(mask1['bias'].numpy() == np.array([0., 1., 1., 1., 1.])) @@ -229,9 +229,9 @@ def test_torch_slim_pruner(self): pruner = torch_compressor.SlimPruner(model, config_list) layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1) - mask1 = pruner.calc_mask(layer1, config_list[0]) + mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0)) layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2) - mask2 = pruner.calc_mask(layer2, config_list[0]) + mask2 = pruner.calc_mask(layer2, config_list[0], if_calculated=torch.tensor(0)) assert all(mask1['weight'].numpy() == np.array([0., 0., 0., 1., 1.])) assert all(mask2['weight'].numpy() == np.array([0., 0., 0., 1., 1.])) assert all(mask1['bias'].numpy() == np.array([0., 0., 0., 1., 1.])) @@ -268,14 +268,14 @@ def test_torch_QAT_quantizer(self): # test ema x = torch.tensor([[-0.2, 0], [0.1, 0.2]]) out = model.relu(x) - assert math.isclose(model.relu.tracked_min_biased, 0, abs_tol=eps) - assert math.isclose(model.relu.tracked_max_biased, 0.002, abs_tol=eps) + assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps) + assert math.isclose(model.relu.module.tracked_max_biased, 0.002, abs_tol=eps) quantizer.step() x = torch.tensor([[0.2, 0.4], [0.6, 0.8]]) out = model.relu(x) - assert math.isclose(model.relu.tracked_min_biased, 0.002, abs_tol=eps) - assert math.isclose(model.relu.tracked_max_biased, 0.00998, abs_tol=eps) + assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps) + assert math.isclose(model.relu.module.tracked_max_biased, 0.00998, abs_tol=eps) if __name__ == '__main__':