diff --git a/prototype_source/fx_graph_mode_ptq_dynamic.py b/prototype_source/fx_graph_mode_ptq_dynamic.py
index eda88ff5c0..98ece5f3d3 100644
--- a/prototype_source/fx_graph_mode_ptq_dynamic.py
+++ b/prototype_source/fx_graph_mode_ptq_dynamic.py
@@ -239,9 +239,27 @@ def evaluate(model_, data_source):
     .set_object_type(nn.LSTM, default_dynamic_qconfig)
     .set_object_type(nn.Linear, default_dynamic_qconfig)
 )
-# Deepcopying the original model because quantization api changes the model inplace and we want
+# Reload the model from the saved checkpoint because the quantization api changes the model inplace and we want
 # to keep the original model for future comparison
-model_to_quantize = copy.deepcopy(model)
+
+
+model_to_quantize = LSTMModel(
+    ntoken=ntokens,
+    ninp=512,
+    nhid=256,
+    nlayers=5,
+)
+
+model_to_quantize.load_state_dict(
+    torch.load(
+        model_data_filepath + 'word_language_model_quantize.pth',
+        map_location=torch.device('cpu'),
+    )
+)
+
+model_to_quantize.eval()
+
+
 prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
 print("prepared model:", prepared_model)
 quantized_model = convert_fx(prepared_model)
@@ -289,4 +307,4 @@ def time_model_evaluation(model, test_data):
 # 3. Conclusion
 # -------------
 # This tutorial introduces the api for post training dynamic quantization in FX Graph Mode,
-# which dynamically quantizes the same modules as Eager Mode Quantization.
\ No newline at end of file
+# which dynamically quantizes the same modules as Eager Mode Quantization.
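
For context, the hunk above replaces model_to_quantize = copy.deepcopy(model) with rebuilding the network and reloading its checkpoint, because the quantization api changes the model in place and the tutorial still needs the untouched original for the later accuracy and timing comparison. Below is a minimal, self-contained sketch of the same pattern; TinyLM and tiny_lm.pth are hypothetical stand-ins for the tutorial's LSTMModel and word_language_model_quantize.pth, and the layer sizes are arbitrary.

import torch
import torch.nn as nn
from torch.ao.quantization import QConfigMapping, default_dynamic_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class TinyLM(nn.Module):
    # Hypothetical stand-in for the tutorial's LSTMModel:
    # an LSTM feeding a Linear decoder.
    def __init__(self, ninp=32, nhid=64, nout=10):
        super().__init__()
        self.rnn = nn.LSTM(ninp, nhid)
        self.decoder = nn.Linear(nhid, nout)

    def forward(self, x):
        out, _ = self.rnn(x)
        return self.decoder(out)

# Save a checkpoint; this file stands in for word_language_model_quantize.pth.
model = TinyLM()
torch.save(model.state_dict(), 'tiny_lm.pth')

# Rebuild a fresh instance and reload the weights instead of deepcopying,
# because prepare_fx/convert_fx change the module they are given in place.
model_to_quantize = TinyLM()
model_to_quantize.load_state_dict(
    torch.load('tiny_lm.pth', map_location=torch.device('cpu'))
)
model_to_quantize.eval()

# Dynamically quantize the same module types the tutorial targets.
qconfig_mapping = (QConfigMapping()
    .set_object_type(nn.LSTM, default_dynamic_qconfig)
    .set_object_type(nn.Linear, default_dynamic_qconfig)
)

example_inputs = (torch.randn(5, 1, 32),)  # (seq_len, batch, ninp)
prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
quantized_model = convert_fx(prepared_model)
print(quantized_model)

Reloading from disk guarantees that the reference model kept around for comparison is never touched by prepare_fx or convert_fx, since both operate on the instance they receive.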