Skip to content

Commit

Permalink
Fix optional array arguments in class constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanocortinovis committed Nov 4, 2024
1 parent d92f222 commit 5560f00
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
3 changes: 2 additions & 1 deletion gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,8 @@ def __init__(
"""
super().__init__(prior=prior, likelihood=likelihood, jitter=jitter)

latent = latent or jr.normal(key, shape=(self.likelihood.num_datapoints, 1))
if latent is None:
latent = jr.normal(key, shape=(self.likelihood.num_datapoints, 1))

# TODO: static or intermediate?
self.latent = latent if isinstance(latent, Parameter) else Real(latent)
Expand Down
43 changes: 24 additions & 19 deletions gpjax/variational_families.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,14 @@ def __init__(
):
super().__init__(posterior, inducing_inputs, jitter)

self.variational_mean = Real(
variational_mean or jnp.zeros((self.num_inducing, 1))
)
self.variational_root_covariance = LowerTriangular(
variational_root_covariance or jnp.eye(self.num_inducing)
)
if variational_mean is None:
variational_mean = jnp.zeros((self.num_inducing, 1))

if variational_root_covariance is None:
variational_root_covariance = jnp.eye(self.num_inducing)

self.variational_mean = Real(variational_mean)
self.variational_root_covariance = LowerTriangular(variational_root_covariance)

def prior_kl(self) -> ScalarFloat:
r"""Compute the prior KL divergence.
Expand Down Expand Up @@ -378,12 +380,14 @@ def __init__(
):
super().__init__(posterior, inducing_inputs, jitter)

self.natural_vector = Static(
natural_vector or jnp.zeros((self.num_inducing, 1))
)
self.natural_matrix = Static(
natural_matrix or -0.5 * jnp.eye(self.num_inducing)
)
if natural_vector is None:
natural_vector = jnp.zeros((self.num_inducing, 1))

if natural_matrix is None:
natural_matrix = -0.5 * jnp.eye(self.num_inducing)

self.natural_vector = Static(natural_vector)
self.natural_matrix = Static(natural_matrix)

def prior_kl(self) -> ScalarFloat:
r"""Compute the KL-divergence between our current variational approximation
Expand Down Expand Up @@ -540,13 +544,14 @@ def __init__(
):
super().__init__(posterior, inducing_inputs, jitter)

# must come after super().__init__
self.expectation_vector = Static(
expectation_vector or jnp.zeros((self.num_inducing, 1))
)
self.expectation_matrix = Static(
expectation_matrix or jnp.eye(self.num_inducing)
)
if expectation_vector is None:
expectation_vector = jnp.zeros((self.num_inducing, 1))

if expectation_matrix is None:
expectation_matrix = jnp.eye(self.num_inducing)

self.expectation_vector = Static(expectation_vector)
self.expectation_matrix = Static(expectation_matrix)

def prior_kl(self) -> ScalarFloat:
r"""Evaluate the prior KL-divergence.
Expand Down

0 comments on commit 5560f00

Please sign in to comment.