diff --git a/prototype_source/fx_graph_mode_ptq_dynamic.py b/prototype_source/fx_graph_mode_ptq_dynamic.py index a256c602c7..27a7116d52 100644 --- a/prototype_source/fx_graph_mode_ptq_dynamic.py +++ b/prototype_source/fx_graph_mode_ptq_dynamic.py @@ -59,6 +59,7 @@ from io import open import time import copy +import marshal import torch import torch.nn as nn @@ -238,9 +239,9 @@ def evaluate(model_, data_source): .set_object_type(nn.LSTM, default_dynamic_qconfig) .set_object_type(nn.Linear, default_dynamic_qconfig) ) -# Deepcopying the original model because quantization api changes the model inplace and we want +# Deepcopying the original using native python method marshal model because quantization api changes the model inplace and we want # to keep the original model for future comparison -model_to_quantize = copy.deepcopy(model) +model_to_quantize = marshal.loads(marshal.dumps(model)) prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs) print("prepared model:", prepared_model) quantized_model = convert_fx(prepared_model)