From 8f060e15c874ea26307e0395d4b795e2565fead0 Mon Sep 17 00:00:00 2001 From: Julian Winkler Date: Sat, 29 Jun 2024 18:00:03 +0200 Subject: [PATCH 1/2] fix infinite variance handling in GaussEst, DiscreteEst and LinEst The infinite variance handling in GaussEst and DiscreteEst was broken, because it passed avg_var_cost into the ind_out parameter of est_init(). Infinite variance handling was also added to LinEst and the unit tests for all three estimators have been extended to test the infinite variance cases. --- test/test_estim/test_discrete.py | 5 +++++ test/test_estim/test_gaussian.py | 5 +++++ test/test_estim/test_linear.py | 5 +++++ vampyre/estim/discrete.py | 2 +- vampyre/estim/gaussian.py | 2 +- vampyre/estim/linear.py | 4 ++++ 6 files changed, 21 insertions(+), 2 deletions(-) diff --git a/test/test_estim/test_discrete.py b/test/test_estim/test_discrete.py index 87f3198..c9ee57d 100644 --- a/test/test_estim/test_discrete.py +++ b/test/test_estim/test_discrete.py @@ -79,6 +79,11 @@ def discrete_test(zshape=(1000,10), verbose=False, nvals=10,\ raise vp.common.TestException(\ "Initial variance does not match expected value") + # Infinite variance case should match the initial estimate + hat, zhatvar = est.est(r,np.Inf,return_cost=False) + if not np.allclose(hat, zmean) or not np.allclose(zhatvar, zvar): + raise vp.common.TestException("Infinite variance estimate does not match initial estimate") + # Get posterior estimate zhat, zhatvar, cost = est.est(r,rvar,return_cost=True) diff --git a/test/test_estim/test_gaussian.py b/test/test_estim/test_gaussian.py index c497967..78fedcd 100644 --- a/test/test_estim/test_gaussian.py +++ b/test/test_estim/test_gaussian.py @@ -51,6 +51,11 @@ def gauss_test(zshape=(1000,10), verbose=False, tol=0.1): raise vp.common.TestException("Initial estimate Gaussian error "+ "does not match predicted value") + # Infinite variance case should match the initial estimate + zhat, zhatvar = est.est(r,np.Inf,return_cost=False) + if not np.allclose(zhat, zmean1) or not np.allclose(zhatvar, zvar1): + raise vp.common.TestException("Infinite variance estimate does not match initial estimate") + # Posterior estimate zhat, zhatvar, cost = est.est(r,rvar,return_cost=True) zerr = np.mean(np.abs(z-zhat)**2) diff --git a/test/test_estim/test_linear.py b/test/test_estim/test_linear.py index d33264f..ac92cf0 100644 --- a/test/test_estim/test_linear.py +++ b/test/test_estim/test_linear.py @@ -56,6 +56,11 @@ def lin_test(zshape=(500,10),Ashape=(1000,500),verbose=False,tol=0.1): raise vp.common.TestException(\ "est_init does not produce the correct shape") + # Infinite variance case should match the initial estimate + zhat_inf, zhatvar_inf = est.est(r,np.Inf,return_cost=False) + if not np.allclose(zhat_inf, zhat) or not np.allclose(zhatvar_inf, zhatvar): + raise vp.common.TestException("Infinite variance estimate does not match initial estimate") + # Posterior estimate zhat, zhatvar, cost = est.est(r,rvar,return_cost=True) zerr = np.mean(np.abs(z-zhat)**2) diff --git a/vampyre/estim/discrete.py b/vampyre/estim/discrete.py index 6323ef9..9aa0e26 100644 --- a/vampyre/estim/discrete.py +++ b/vampyre/estim/discrete.py @@ -118,7 +118,7 @@ def est(self,r,rvar,return_cost=False,ind_out=None,avg_var_cost=True): # Infinite variance case if np.any(rvar==np.Inf): - return self.est_init(return_cost, avg_var_cost) + return self.est_init(return_cost=return_cost, ind_out=ind_out, avg_var_cost=avg_var_cost) # Convert to 1D vectors diff --git a/vampyre/estim/gaussian.py b/vampyre/estim/gaussian.py index b545f79..2b22623 100644 --- a/vampyre/estim/gaussian.py +++ b/vampyre/estim/gaussian.py @@ -133,7 +133,7 @@ def est(self,r,rvar,return_cost=False,ind_out=None,avg_var_cost=True): # Infinite variance case if np.any(rvar==np.Inf): - return self.est_init(return_cost, avg_var_cost) + return self.est_init(return_cost=return_cost, ind_out=ind_out, avg_var_cost=avg_var_cost) zhatvar = rvar*self.zvar/(rvar + self.zvar) gain = self.zvar/(rvar + self.zvar) diff --git a/vampyre/estim/linear.py b/vampyre/estim/linear.py index d656599..ab52ff7 100644 --- a/vampyre/estim/linear.py +++ b/vampyre/estim/linear.py @@ -182,6 +182,10 @@ def est(self,r,rvar,return_cost=False, ind_out=None,\ if not avg_var_cost: raise ValueError("disabling variance averaging not supported for LinEst") + # Infinite variance case + if np.any(rvar==np.Inf): + return self.est_init(return_cost=return_cost, ind_out=ind_out, avg_var_cost=avg_var_cost) + # Get the diagonal parameters s, sshape, srep_axes = self.A.get_svd_diag() From 584ba5691547d06e88610b142fd9113fe9f37ac9 Mon Sep 17 00:00:00 2001 From: Julian Winkler Date: Sun, 30 Jun 2024 12:50:01 +0200 Subject: [PATCH 2/2] fix infinite variance handling of ReLUEst The infinite variance code path was broken, because the zhat0_map and related variables did not get set in this code path. The relu unit tests have been extendet to test this code path. --- test/test_estim/test_relu.py | 8 ++++++-- vampyre/estim/relu.py | 10 +++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/test/test_estim/test_relu.py b/test/test_estim/test_relu.py index 35cff50..441a117 100644 --- a/test/test_estim/test_relu.py +++ b/test/test_estim/test_relu.py @@ -13,7 +13,7 @@ import vampyre as vp -def relu_test(zshape=(1000,10),tol=0.15,verbose=False,map_est=False): +def relu_test(zshape=(1000,10),tol=0.15,verbose=False,map_est=False,rvar1=None): """ ReLUEstim unit test. @@ -33,7 +33,8 @@ def relu_test(zshape=(1000,10),tol=0.15,verbose=False,map_est=False): # Set random parameters rvar0 = np.power(10,np.random.uniform(-2,1,ns)) - rvar1 = np.power(10,np.random.uniform(-2,1,ns)) + if rvar1 is None: + rvar1 = np.power(10,np.random.uniform(-2,1,ns)) # Construct random input r0 = np.random.normal(0,1,zshape) @@ -85,6 +86,9 @@ def test_relu(self): # MAP test. we use a higher tolerance, since the error is approximate relu_test(verbose=verbose,map_est=True,tol=1) + + # MMSE test with infinite rvar1 variance + relu_test(verbose=verbose,map_est=False,tol=0.15,zshape=(1000,10),rvar1=np.inf*np.ones(10)) if __name__ == '__main__': unittest.main() diff --git a/vampyre/estim/relu.py b/vampyre/estim/relu.py index d2e14ed..95a9cff 100644 --- a/vampyre/estim/relu.py +++ b/vampyre/estim/relu.py @@ -193,6 +193,11 @@ def est_mmse(self,r,rvar,return_cost,ind_out): rvar0 = common.repeat_axes(rvar0,self.shape[0],self.var_axes[0]) rvar1 = common.repeat_axes(rvar1,self.shape[1],self.var_axes[1]) + # Compute the MAP estimate + zhat_map, zvar_map = self.est_map(r,rvar,return_cost=False,ind_out=[0,1]) + zhat0_map, zhat1_map = zhat_map + zvar0_map, zvar1_map = zvar_map + if np.any(rvar1 == np.Inf): # Infinite variance case. zvarp = rvar0 @@ -204,11 +209,6 @@ def est_mmse(self,r,rvar,return_cost,ind_out): Amax = 0 else: - - # Compute the MAP estimate - zhat_map, zvar_map = self.est_map(r,rvar,return_cost=False,ind_out=[0,1]) - zhat0_map, zhat1_map = zhat_map - zvar0_map, zvar1_map = zvar_map # Compute the conditional Gaussian terms for z > 0 and z < 0 zvarp = rvar0*rvar1/(rvar0+rvar1)