Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix infinite variance handling in GaussEst, DiscreteEst, LinEst and ReLUEst #13

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions test/test_estim/test_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def discrete_test(zshape=(1000,10), verbose=False, nvals=10,\
raise vp.common.TestException(\
"Initial variance does not match expected value")

# Infinite variance case should match the initial estimate
hat, zhatvar = est.est(r,np.Inf,return_cost=False)
if not np.allclose(hat, zmean) or not np.allclose(zhatvar, zvar):
raise vp.common.TestException("Infinite variance estimate does not match initial estimate")

# Get posterior estimate
zhat, zhatvar, cost = est.est(r,rvar,return_cost=True)

Expand Down
5 changes: 5 additions & 0 deletions test/test_estim/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ def gauss_test(zshape=(1000,10), verbose=False, tol=0.1):
raise vp.common.TestException("Initial estimate Gaussian error "+
"does not match predicted value")

# Infinite variance case should match the initial estimate
zhat, zhatvar = est.est(r,np.Inf,return_cost=False)
if not np.allclose(zhat, zmean1) or not np.allclose(zhatvar, zvar1):
raise vp.common.TestException("Infinite variance estimate does not match initial estimate")

# Posterior estimate
zhat, zhatvar, cost = est.est(r,rvar,return_cost=True)
zerr = np.mean(np.abs(z-zhat)**2)
Expand Down
5 changes: 5 additions & 0 deletions test/test_estim/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def lin_test(zshape=(500,10),Ashape=(1000,500),verbose=False,tol=0.1):
raise vp.common.TestException(\
"est_init does not produce the correct shape")

# Infinite variance case should match the initial estimate
zhat_inf, zhatvar_inf = est.est(r,np.Inf,return_cost=False)
if not np.allclose(zhat_inf, zhat) or not np.allclose(zhatvar_inf, zhatvar):
raise vp.common.TestException("Infinite variance estimate does not match initial estimate")

# Posterior estimate
zhat, zhatvar, cost = est.est(r,rvar,return_cost=True)
zerr = np.mean(np.abs(z-zhat)**2)
Expand Down
8 changes: 6 additions & 2 deletions test/test_estim/test_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import vampyre as vp


def relu_test(zshape=(1000,10),tol=0.15,verbose=False,map_est=False):
def relu_test(zshape=(1000,10),tol=0.15,verbose=False,map_est=False,rvar1=None):
"""
ReLUEstim unit test.

Expand All @@ -33,7 +33,8 @@ def relu_test(zshape=(1000,10),tol=0.15,verbose=False,map_est=False):

# Set random parameters
rvar0 = np.power(10,np.random.uniform(-2,1,ns))
rvar1 = np.power(10,np.random.uniform(-2,1,ns))
if rvar1 is None:
rvar1 = np.power(10,np.random.uniform(-2,1,ns))

# Construct random input
r0 = np.random.normal(0,1,zshape)
Expand Down Expand Up @@ -85,6 +86,9 @@ def test_relu(self):

# MAP test. we use a higher tolerance, since the error is approximate
relu_test(verbose=verbose,map_est=True,tol=1)

# MMSE test with infinite rvar1 variance
relu_test(verbose=verbose,map_est=False,tol=0.15,zshape=(1000,10),rvar1=np.inf*np.ones(10))

if __name__ == '__main__':
unittest.main()
Expand Down
2 changes: 1 addition & 1 deletion vampyre/estim/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def est(self,r,rvar,return_cost=False,ind_out=None,avg_var_cost=True):

# Infinite variance case
if np.any(rvar==np.Inf):
return self.est_init(return_cost, avg_var_cost)
return self.est_init(return_cost=return_cost, ind_out=ind_out, avg_var_cost=avg_var_cost)


# Convert to 1D vectors
Expand Down
2 changes: 1 addition & 1 deletion vampyre/estim/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def est(self,r,rvar,return_cost=False,ind_out=None,avg_var_cost=True):

# Infinite variance case
if np.any(rvar==np.Inf):
return self.est_init(return_cost, avg_var_cost)
return self.est_init(return_cost=return_cost, ind_out=ind_out, avg_var_cost=avg_var_cost)

zhatvar = rvar*self.zvar/(rvar + self.zvar)
gain = self.zvar/(rvar + self.zvar)
Expand Down
4 changes: 4 additions & 0 deletions vampyre/estim/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ def est(self,r,rvar,return_cost=False, ind_out=None,\
if not avg_var_cost:
raise ValueError("disabling variance averaging not supported for LinEst")

# Infinite variance case
if np.any(rvar==np.Inf):
return self.est_init(return_cost=return_cost, ind_out=ind_out, avg_var_cost=avg_var_cost)


# Get the diagonal parameters
s, sshape, srep_axes = self.A.get_svd_diag()
Expand Down
10 changes: 5 additions & 5 deletions vampyre/estim/relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ def est_mmse(self,r,rvar,return_cost,ind_out):
rvar0 = common.repeat_axes(rvar0,self.shape[0],self.var_axes[0])
rvar1 = common.repeat_axes(rvar1,self.shape[1],self.var_axes[1])

# Compute the MAP estimate
zhat_map, zvar_map = self.est_map(r,rvar,return_cost=False,ind_out=[0,1])
zhat0_map, zhat1_map = zhat_map
zvar0_map, zvar1_map = zvar_map

if np.any(rvar1 == np.Inf):
# Infinite variance case.
zvarp = rvar0
Expand All @@ -204,11 +209,6 @@ def est_mmse(self,r,rvar,return_cost,ind_out):
Amax = 0

else:

# Compute the MAP estimate
zhat_map, zvar_map = self.est_map(r,rvar,return_cost=False,ind_out=[0,1])
zhat0_map, zhat1_map = zhat_map
zvar0_map, zvar1_map = zvar_map

# Compute the conditional Gaussian terms for z > 0 and z < 0
zvarp = rvar0*rvar1/(rvar0+rvar1)
Expand Down