From bb32b1c9e1af2dd77029287132f5d3768a750397 Mon Sep 17 00:00:00 2001 From: brentyi Date: Fri, 13 Sep 2024 01:02:33 -0700 Subject: [PATCH] Fix (critical!) variable ordering bug --- src/jaxls/_variables.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/jaxls/_variables.py b/src/jaxls/_variables.py index 21b9d34..fb4b823 100644 --- a/src/jaxls/_variables.py +++ b/src/jaxls/_variables.py @@ -47,7 +47,9 @@ def ordered_dict_items[T]( self, var_type_mapping: dict[type[Var[Any]], T], ) -> list[tuple[type[Var[Any]], T]]: - return sorted(var_type_mapping.items(), key=lambda x: x[0]) + return sorted( + var_type_mapping.items(), key=lambda x: self.order_from_type[x[0]] + ) @jdc.pytree_dataclass