diff --git a/test/run_tests_model_gen_and_load.py b/test/run_tests_model_gen_and_load.py index 2315d976d71..aff79e99163 100644 --- a/test/run_tests_model_gen_and_load.py +++ b/test/run_tests_model_gen_and_load.py @@ -179,7 +179,15 @@ def load_model( weights_dir = working_dir / "test_weights" weights_dir.mkdir(parents=True, exist_ok=True) weight_file = str(weights_dir / f"weights_{test_id}.json") + + # Load old weights old_weights = json.load(open(weight_file)) + + # Sort the weights inside both dictionaries based on 'index' + new_weights["weights"].sort(key=lambda x: x["index"]) + old_weights["weights"].sort(key=lambda x: x["index"]) + + # Assert if both sorted weights are equal assert new_weights == old_weights vw.finish() except Exception as e: