From 7e04c9e67dca1303c9d1fa912047587efe5cd9c4 Mon Sep 17 00:00:00 2001
From: gileshd
Date: Fri, 20 Sep 2024 16:23:45 +0100
Subject: [PATCH] Add further type annotations to arhmm

---
 dynamax/hidden_markov_model/models/arhmm.py | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/dynamax/hidden_markov_model/models/arhmm.py b/dynamax/hidden_markov_model/models/arhmm.py
index e6614e10..b7d1fa61 100644
--- a/dynamax/hidden_markov_model/models/arhmm.py
+++ b/dynamax/hidden_markov_model/models/arhmm.py
@@ -1,8 +1,10 @@
+from typing import NamedTuple, Optional, Tuple, Union
 import jax.numpy as jnp
 import jax.random as jr
 from jax import lax
 from jax.tree_util import tree_map
-from jaxtyping import Float, Array
+from jaxtyping import Int, Float, Array
+
 from dynamax.hidden_markov_model.models.abstractions import HMM, HMMParameterSet, HMMPropertySet
 from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
 from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
@@ -11,7 +13,6 @@
 from dynamax.types import Scalar
 from dynamax.utils.bijectors import RealToPSDBijector
 from tensorflow_probability.substrates import jax as tfp
-from typing import NamedTuple, Optional, Tuple, Union
 
 tfd = tfp.distributions
 tfb = tfp.bijectors
@@ -25,21 +26,22 @@ class ParamsLinearAutoregressiveHMM(NamedTuple):
 
 class LinearAutoregressiveHMMEmissions(LinearRegressionHMMEmissions):
     def __init__(self,
-                 num_states,
-                 emission_dim,
-                 num_lags=1):
+                 num_states: int,
+                 emission_dim: int,
+                 num_lags: int=1):
         self.num_lags = num_lags
         self.emission_dim = emission_dim
         input_dim = num_lags * emission_dim
         super().__init__(num_states, input_dim, emission_dim)
 
     def initialize(self,
-                   key=jr.PRNGKey(0),
-                   method="prior",
-                   emission_weights=None,
-                   emission_biases=None,
-                   emission_covariances=None,
-                   emissions=None):
+                   key: Array=jr.PRNGKey(0),
+                   method: str="prior",
+                   emission_weights: Optional[Float[Array, "num_states emission_dim input_dim"]]=None,
+                   emission_biases: Optional[Float[Array, "num_states emission_dim"]]=None,
+                   emission_covariances: Optional[Float[Array, "num_states emission_dim emission_dim"]]=None,
+                   emissions: Optional[Float[Array, "num_timesteps emission_dim"]]=None
+                   ) -> Tuple[ParamsLinearRegressionHMMEmissions, ParamsLinearRegressionHMMEmissions]:
         if method.lower() == "kmeans":
             assert emissions is not None, "Need emissions to initialize the model with K-Means!"
             from sklearn.cluster import KMeans
@@ -166,7 +168,7 @@ def sample(self,
                key: Array,
                num_timesteps: int,
                prev_emissions: Optional[Float[Array, "num_lags emission_dim"]]=None,
-               ) -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]:
+               ) -> Tuple[Int[Array, " num_timesteps"], Float[Array, "num_timesteps emission_dim"]]:
         r"""Sample states $z_{1:T}$ and emissions $y_{1:T}$ given parameters $\theta$.
 
         Args:
@@ -211,7 +213,7 @@ def _step(carry, key):
     def compute_inputs(self,
                        emissions: Float[Array, "num_timesteps emission_dim"],
                        prev_emissions: Optional[Float[Array, "num_lags emission_dim"]]=None
-                       ) -> Float[Array, "num_timesteps emission_dim_times_num_lags"]:
+                       ) -> Float[Array, "num_timesteps {self.num_lags}*{self.emission_dim}"]:
         r"""Helper function to compute the matrix of lagged emissions.
 
         Args:
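
---
Usage sketch (appended for review, not part of the patch): the snippet below
exercises the signatures annotated above. It is a minimal sketch assuming
dynamax's public LinearAutoregressiveHMM API (construction, `initialize`,
`sample`, `compute_inputs`); the dimensions chosen here (3 states, 2-dim
emissions, 2 lags) are arbitrary illustrative values, and the asserts mirror
the jaxtyping shapes and dtypes added in this diff.

    import jax.numpy as jnp
    import jax.random as jr
    from dynamax.hidden_markov_model import LinearAutoregressiveHMM

    num_states, emission_dim, num_lags = 3, 2, 2
    model = LinearAutoregressiveHMM(num_states, emission_dim, num_lags=num_lags)

    # initialize() returns a (params, props) pair; both are pytrees of the
    # same NamedTuple type, hence the Tuple[..., ...] return annotation.
    key1, key2 = jr.split(jr.PRNGKey(0))
    params, props = model.initialize(key1)

    # sample() now advertises integer-valued states: Int[Array, " num_timesteps"].
    states, emissions = model.sample(params, key2, num_timesteps=100)
    assert jnp.issubdtype(states.dtype, jnp.integer)
    assert emissions.shape == (100, emission_dim)

    # compute_inputs() stacks the num_lags previous emissions at each timestep,
    # matching the annotated shape (num_timesteps, num_lags * emission_dim).
    inputs = model.compute_inputs(emissions)
    assert inputs.shape == (100, num_lags * emission_dim)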