From ece912828a87196463f2741ca748818fb8972763 Mon Sep 17 00:00:00 2001 From: james <81617086+je-cook@users.noreply.github.com> Date: Fri, 17 Nov 2023 13:02:55 +0000 Subject: [PATCH] broadcasting fix --- src/pyvmcon/vmcon.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pyvmcon/vmcon.py b/src/pyvmcon/vmcon.py index 8c9c5e9..1687a32 100644 --- a/src/pyvmcon/vmcon.py +++ b/src/pyvmcon/vmcon.py @@ -398,8 +398,12 @@ def _derivative_lagrangian( ) -> np.ndarray: ind_eq = min(lamda_equality.shape[0], result.deq.shape[0]) ind_ieq = min(lamda_inequality.shape[0], result.die.shape[0]) - c_equality_prime = (lamda_equality[:ind_eq] * result.deq[:ind_eq]).sum(axis=0) - c_inequality_prime = (lamda_inequality[:ind_ieq] * result.die[:ind_ieq]).sum(axis=0) + c_equality_prime = (lamda_equality[:ind_eq, None] * result.deq[:ind_eq]).sum( + axis=None if ind_eq == 0 else 0 + ) + c_inequality_prime = (lamda_inequality[:ind_ieq, None] * result.die[:ind_ieq]).sum( + axis=None if ind_ieq == 0 else 0 + ) return result.df - c_equality_prime - c_inequality_prime