Skip to content

Commit

Permalink
add compute_likelihood helper
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Dec 23, 2024
1 parent c20232f commit 60b567e
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 11 deletions.
4 changes: 2 additions & 2 deletions src/gluonts/torch/model/seg_diff/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ def validation_step(self, batch, batch_idx: int): # type: ignore
"""
Execute validation step.
"""
val_loss = self.model.loss(
val_loss = self.model.log_prob(
**select(self.inputs, batch),
future_target=batch["future_target"],
future_observed_values=batch["future_observed_values"],
).mean()

self.log(
"val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True
"val_loss", val_loss, on_epoch=True, on_step=True, prog_bar=True
)
return val_loss

Expand Down
82 changes: 73 additions & 9 deletions src/gluonts/torch/model/seg_diff/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
from torch import nn
import torch.nn.functional as F
from torch.distributions import Independent, Normal

from gluonts.core.component import validated
from gluonts.model import Input, InputSpec
Expand Down Expand Up @@ -117,31 +118,34 @@ def __init__(
feat_dim: int,
hidden_dim: int,
time_embed_dim: int = 8,
act_fn_name: str = "gelu",
):
super().__init__()
act_fn = ACT2FN[act_fn_name]

# Time embedding network
self.time_embed = nn.Sequential(
nn.Linear(1, time_embed_dim),
nn.GELU(),
act_fn,
nn.Linear(time_embed_dim, time_embed_dim),
)

# Conditioning network for better feature extraction
self.cond_net = nn.Sequential(
nn.Linear(cond_dim, hidden_dim),
nn.GELU(),
act_fn,
nn.Linear(hidden_dim, hidden_dim),
)

# Main velocity network with skip connections
self.net = nn.Sequential(
nn.Linear(feat_dim + hidden_dim + time_embed_dim, hidden_dim),
nn.GELU(),
act_fn,
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
act_fn,
nn.Dropout(0.1), # Add some regularization
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
act_fn,
nn.Linear(hidden_dim, feat_dim),
)

Expand Down Expand Up @@ -176,11 +180,19 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor):


class Flow(nn.Module):
def __init__(self, cond_dim: int, feat_dim: int, hidden_dim: int):
def __init__(
self,
cond_dim: int,
feat_dim: int,
hidden_dim: int,
act_fn_name: str = "gelu",
):
super().__init__()

# Define MLP for velocity field
self.velocity_model = VelocityModel(cond_dim, feat_dim, hidden_dim)
self.velocity_model = VelocityModel(
cond_dim, feat_dim, hidden_dim, act_fn_name=act_fn_name
)

# Flow matching components
self.prob_path = CondOTProbPath()
Expand Down Expand Up @@ -425,7 +437,7 @@ def __init__(
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
activation="gelu",
layer_norm_eps=layer_norm_eps,
batch_first=True,
norm_first=norm_first,
Expand All @@ -437,7 +449,10 @@ def __init__(
)

self.flow = Flow(
cond_dim=d_model, feat_dim=patch_len, hidden_dim=flow_hidden_dim
cond_dim=d_model,
feat_dim=patch_len,
hidden_dim=flow_hidden_dim,
act_fn_name=activation,
)

def describe_inputs(self, batch_size=1) -> InputSpec:
Expand Down Expand Up @@ -560,6 +575,55 @@ def loss(
x_1=x_1, x_0=x_0, cond=flow_cond[:, :-1, :]
)

def log_prob(
    self,
    past_target: torch.Tensor,
    past_observed_values: torch.Tensor,
    future_target: torch.Tensor,
    future_observed_values: torch.Tensor,
    past_time_feat: Optional[torch.Tensor] = None,
    future_time_feat: Optional[torch.Tensor] = None,
    *,
    method: str = "midpoint",
    step_size: float = 0.05,
    exact_divergence: bool = True,
) -> torch.Tensor:
    """
    Compute the negative mean log-likelihood of the target patches under
    the conditional flow.

    NOTE: despite the name, the returned value is the *negated* mean of
    the per-patch log-likelihoods (a scalar), so lower is better and it
    can be logged directly as a loss.

    Parameters
    ----------
    past_target, past_observed_values, future_target,
    future_observed_values, past_time_feat, future_time_feat
        Forwarded unchanged to ``params_from_decoder_output``;
        ``past_target`` and ``future_target`` are additionally
        concatenated along ``dim=1``, standardized with the returned
        ``loc``/``scale`` and split into patches.
    method
        Integration method passed to the solver's ``compute_likelihood``.
    step_size
        Integration step size passed to the solver's
        ``compute_likelihood``.
    exact_divergence
        Passed to the solver's ``compute_likelihood``.
        NOTE(review): presumably selects exact divergence computation
        over a stochastic estimate -- confirm against the solver docs.

    Returns
    -------
    torch.Tensor
        ``-exact_log_p.mean()`` where ``exact_log_p`` is the second value
        returned by ``self.flow.solver.compute_likelihood``.
    """
    device = past_target.device

    # Base (source) density: standard normal over the ``patch_len``
    # entries of a patch. ``Independent(..., 1)`` sums the per-entry
    # log-densities so that ``log_prob`` yields one value per patch.
    base_dist = Independent(
        Normal(
            torch.zeros(self.patch_len, device=device),
            torch.ones(self.patch_len, device=device),
        ),
        1,
    )

    flow_cond, loc, scale = self.params_from_decoder_output(
        past_target=past_target,
        past_observed_values=past_observed_values,
        past_time_feat=past_time_feat,
        future_time_feat=future_time_feat,
        future_target=future_target,
        future_observed_values=future_observed_values,
    )

    # Standardize the full (past + future) series, then split it into
    # patches.
    target = self.patch(
        (torch.cat((past_target, future_target), dim=1) - loc) / scale
    )

    # Pair every patch after the first with the conditioning vector of
    # the preceding position.
    x_1 = target[:, 1:, :]
    cond = flow_cond[:, :-1, :]

    # Flatten the leading dimensions so each patch is scored on its own.
    _, exact_log_p = self.flow.solver.compute_likelihood(
        x_1=x_1.reshape(-1, self.patch_len),
        cond=cond.reshape(-1, self.d_model),
        method=method,
        step_size=step_size,
        exact_divergence=exact_divergence,
        log_p0=base_dist.log_prob,
    )
    return -exact_log_p.mean()

def forward(
self,
past_target: torch.Tensor,
Expand Down

0 comments on commit 60b567e

Please sign in to comment.