Skip to content

Commit

Permalink
fix acc diff problem caused by pr PaddlePaddle#44116 (PaddlePaddle#44311
Browse files Browse the repository at this point in the history
)
  • Loading branch information
ZhangHandi authored and ceci3 committed Aug 4, 2022
1 parent 36e5e49 commit 3f459e3
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,11 +1440,18 @@ def apply(self, graph):
[core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
continue

scale_node = graph.create_persistable_node(
name=self._scale_name(in_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=in_node.dtype())
try:
graph._find_node_by_name(
graph.all_var_nodes(),
self._scale_name(in_node.name()))
continue
except:
scale_node = graph.create_persistable_node(
name=self._scale_name(in_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=in_node.dtype())

data_type = 'float64' if in_node.dtype() \
== core.VarDesc.VarType.FP64 else 'float32'
_init_var_node(scale_node, np.ones([1], dtype=data_type),
Expand Down

0 comments on commit 3f459e3

Please sign in to comment.