Skip to content

Commit

Permalink
use the stats for each segment
Browse files (browse the repository at this point in the history)
  • Loading branch information
kashif committed Jan 20, 2025
1 parent 27071bd commit 001b530
Showing 1 changed file with 32 additions and 28 deletions.
60 changes: 32 additions & 28 deletions src/gluonts/torch/model/seg_diff/module.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -235,7 +235,6 @@ def compute_loss(
) -> torch.Tensor:
"""Compute flow matching loss."""
batch, seq_len, feat_dim = x_0.shape
# Sample time uniformly
t = torch.rand((batch * seq_len,), device=x_0.device)

# Get path sample from probability path with scheduler outputs
Expand Down Expand Up @@ -538,11 +537,12 @@ def params_from_decoder_output(
(past_observed_values, future_observed_values), dim=1
)

patched_target = self.patch(past_target)

# scale the input
target_scaled, loc, scale = self.scaler(
past_target, past_observed_values
patched_target, ~patched_target.isnan()
)
patched_target = self.patch(target_scaled)

# do patching for time features as well
# if self.num_feat_dynamic_real > 0:
Expand All @@ -553,12 +553,13 @@ def params_from_decoder_output(
log_abs_loc = loc.sign() * loc.abs().log1p()
log_scale = scale.log()

expanded_static_feat = unsqueeze_expand(
torch.cat([log_abs_loc, log_scale], dim=-1),
dim=1,
size=patched_target.shape[1],
)
inputs = torch.cat((patched_target, expanded_static_feat), dim=-1)
# expanded_static_feat = unsqueeze_expand(
# torch.cat([log_abs_loc, log_scale], dim=-1),
# dim=1,
# size=patched_target.shape[1],
# )
static_feat = torch.cat([log_abs_loc, log_scale], dim=-1)
inputs = torch.cat((target_scaled, static_feat), dim=-1)

if future_time_feat is not None:
past_time_feat = torch.cat(
Expand Down Expand Up @@ -592,7 +593,7 @@ def params_from_decoder_output(

# # Project decoder output to condition the flow
# flow_cond = self.proj(dec_out)
return dec_out, loc, scale
return dec_out, target_scaled, loc, scale

def loss(
self,
Expand All @@ -603,7 +604,7 @@ def loss(
past_time_feat: Optional[torch.Tensor] = None,
future_time_feat: Optional[torch.Tensor] = None,
) -> torch.Tensor:
flow_cond, loc, scale = self.params_from_decoder_output(
flow_cond, target_scaled, _, _ = self.params_from_decoder_output(
past_target=past_target,
past_observed_values=past_observed_values,
future_target=future_target,
Expand All @@ -612,13 +613,16 @@ def loss(
future_time_feat=future_time_feat,
)

# Get patches for target
target = self.patch(
(torch.cat((past_target, future_target), dim=1) - loc) / scale
)
# # Get patches for target
# target = self.patch(
# (torch.cat((past_target, future_target), dim=1) - loc) / scale
# )

# Flow matching loss
x_1 = target[:, 1:, :] # Target patches
# Target patches
x_1 = target_scaled[:, 1:, :]

# source distribution
x_0 = torch.randn_like(x_1) # Random noise source distribution
# x_0 = target[:, :-1, :] + torch.randn_like(target[:, :-1, :]) * 0.7

Expand Down Expand Up @@ -648,7 +652,7 @@ def log_prob(
1,
).log_prob

flow_cond, loc, scale = self.params_from_decoder_output(
flow_cond, target_scaled, loc, scale = self.params_from_decoder_output(
past_target=past_target,
past_observed_values=past_observed_values,
past_time_feat=past_time_feat,
Expand All @@ -657,12 +661,12 @@ def log_prob(
future_observed_values=future_observed_values,
)
# Get patches for target
target = self.patch(
(torch.cat((past_target, future_target), dim=1) - loc) / scale
)
# target = self.patch(
# (torch.cat((past_target, future_target), dim=1) - loc) / scale
# )

# Flow matching loss
x_1 = target[:, 1:, :] # Target patches
x_1 = target_scaled[:, 1:, :] # Target patches
cond = flow_cond[:, :-1, :]

solver = ODESolver(self.flow.velocity_model)
Expand Down Expand Up @@ -690,14 +694,16 @@ def forward(
num_parallel_samples = self.num_parallel_samples

# Get initial flow conditioning from decoder
flow_cond, loc, scale = self.params_from_decoder_output(
flow_cond, _, past_loc, past_scale = self.params_from_decoder_output(
past_target=past_target,
past_observed_values=past_observed_values,
past_time_feat=past_time_feat,
future_time_feat=future_time_feat[:, : self.patch_len]
if future_time_feat is not None
else None,
)
loc = past_loc[:, -1, :]
scale = past_scale[:, -1, :]

# Initialize samples for each batch
batch_size = past_target.shape[0]
Expand Down Expand Up @@ -779,7 +785,7 @@ def forward(
else None
)

flow_cond, loc, scale = self.params_from_decoder_output(
flow_cond, _, _, _ = self.params_from_decoder_output(
past_target=repeat_past_target,
past_observed_values=repeat_past_observed_values,
past_time_feat=repeat_past_time_feat
Expand All @@ -790,7 +796,7 @@ def forward(
future_time_feat=current_future_time_feat,
)

# Sample new noise for next patch
# Sample new source sample for next patch
x = torch.randn(
batch_size * num_parallel_samples,
self.patch_len,
Expand Down Expand Up @@ -820,12 +826,10 @@ def forward(
time_grid=T,
)

# Scale and store the sampwg
# Scale and store the samples
next_sample = x.view(
batch_size, num_parallel_samples, self.patch_len
) * scale.view(batch_size, num_parallel_samples, -1) + loc.view(
batch_size, num_parallel_samples, -1
)
) * scale.view(batch_size, 1, -1) + loc.view(batch_size, 1, -1)
future_samples.append(next_sample)
total_samples += self.patch_len

Expand Down

0 comments on commit 001b530

Please sign in to comment.