flow_matching lib
kashif committed Dec 13, 2024
1 parent b9b5a45 commit 469fde3
Showing 2 changed files with 112 additions and 48 deletions.
3 changes: 3 additions & 0 deletions src/gluonts/torch/model/seg_diff/estimator.py
@@ -113,6 +113,7 @@ def __init__(
nhead: int = 4,
dim_feedforward: int = 128,
num_feat_dynamic_real: int = 0,
n_steps: int = 10,
dropout: float = 0.1,
activation: str = "relu",
norm_first: bool = False,
@@ -138,6 +139,7 @@ def __init__(
self.context_length = patch_len * context_length_multiplier
self.context_length_multiplier = context_length_multiplier
self.prediction_length = prediction_length
self.n_steps = n_steps

self.lr = lr
self.weight_decay = weight_decay
@@ -203,6 +205,7 @@ def create_lightning_module(self) -> pl.LightningModule:
"num_decoder_layers": self.num_decoder_layers,
# "distr_output": self.distr_output,
"scaling": self.scaling,
"n_steps": self.n_steps,
},
)

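For orientation, a minimal usage sketch of the new `n_steps` argument threaded through the estimator in this commit. The class name `SegDiffEstimator` and every other constructor argument below are assumptions inferred from the parameters visible in this hunk; only `n_steps` is confirmed by the diff, and it is forwarded to the lightning module via `create_lightning_module`.

# Hypothetical sketch -- class name and most arguments are assumed; only
# `n_steps` (number of flow/ODE integration steps at sampling time) is new here.
from gluonts.torch.model.seg_diff.estimator import SegDiffEstimator

estimator = SegDiffEstimator(
    prediction_length=24,
    patch_len=8,                    # assumed; context_length = patch_len * context_length_multiplier
    context_length_multiplier=4,
    nhead=4,
    dim_feedforward=128,
    dropout=0.1,
    n_steps=10,                     # new in this commit; passed on to the module
)
predictor = estimator.train(training_data)  # training_data: a GluonTS dataset (standard estimator API)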
157 changes: 109 additions & 48 deletions src/gluonts/torch/model/seg_diff/module.py
@@ -21,10 +21,12 @@

from gluonts.core.component import validated
from gluonts.model import Input, InputSpec
from gluonts.torch.distributions import StudentTOutput
from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler
from gluonts.torch.util import take_last, unsqueeze_expand, weighted_average
from gluonts.torch.model.simple_feedforward import make_linear_layer
from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, PolynomialConvexScheduler, VPScheduler, LinearVPScheduler
from flow_matching.path import AffineProbPath
from flow_matching.solver import ODESolver


class ClassInstantier(OrderedDict):
@@ -108,51 +110,116 @@ def forward(self, x: torch.Tensor):
return self.layer_norm(out)
return out

class VelocityModel(nn.Module):
def __init__(self, cond_dim: int, out_dim: int, h: int, time_embed_dim: int = 8):
super().__init__()
# Time embedding network
self.time_embed = nn.Sequential(
nn.Linear(1, time_embed_dim),
nn.GELU(),
nn.Linear(time_embed_dim, time_embed_dim),
)

# Conditioning network for better feature extraction
self.cond_net = nn.Sequential(
nn.Linear(cond_dim, h),
nn.GELU(),
nn.Linear(h, h)
)

# Main velocity network with skip connections
self.net = nn.Sequential(
nn.Linear(out_dim + h + time_embed_dim, h),
nn.GELU(),
nn.Linear(h, h),
nn.GELU(),
nn.Dropout(0.1), # Add some regularization
nn.Linear(h, h),
nn.GELU(),
nn.Linear(h, out_dim)
)

# Initialize weights for better gradient flow
self.apply(self._init_weights)

def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
module.bias.data.zero_()

def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor):
# Handle different time tensor shapes
if t.ndim == 0: # scalar time
t = t.view(1)

# Expand t to match batch dimensions of x
while t.ndim < x.ndim:
t = t.unsqueeze(-1)
t = t.expand(*x.shape[:-1], 1)

# Get time embeddings
t_embed = self.time_embed(t)

# Process conditioning
cond_features = self.cond_net(cond)

# Concatenate all inputs and compute velocity
inputs = torch.cat([x, t_embed, cond_features], dim=-1)
return self.net(inputs)

class Flow(nn.Module):
def __init__(self, cond_dim: int, out_dim: int, h: int):
super().__init__()

self.linear = nn.Linear(out_dim + cond_dim + 1, h)
self.act = ACT2FN["gelu"]
self.output_layer = nn.Linear(h, out_dim)

def forward(self, x_t: torch.Tensor, t: torch.Tensor, cond: torch.Tensor):
x = torch.cat((x_t, t, cond), dim=-1)
x = self.linear(x)
x = self.act(x)
x = self.output_layer(x)
return x

# Define MLP for velocity field
self.velocity_model = VelocityModel(cond_dim, out_dim, h)

# Flow matching components
scheduler = CondOTScheduler()
self.prob_path = AffineProbPath(scheduler=scheduler)
self.solver = ODESolver(self.velocity_model)

def compute_loss(self, x_0: torch.Tensor, x_1: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
"""Compute flow matching loss."""
# Sample time uniformly
t = torch.rand(x_0.shape[0], x_0.shape[1], device=x_0.device)

# Get path sample from probability path with scheduler outputs
path_sample = self.prob_path.sample(
t=t,
x_0=x_0,
x_1=x_1,
)

# Get velocity field prediction
v_t = self.velocity_model(path_sample.x_t, path_sample.t, cond)

# Flow matching loss
return F.mse_loss(v_t, path_sample.dx_t)

@torch.inference_mode()
def step(
self,
x_t: torch.Tensor,
t_start: float,
t_end: float,
cond: torch.Tensor,
) -> torch.Tensor:
"""Performs one step of the flow matching process.
Args:
x_t: Input tensor to evolve
t_start: Starting time
t_end: Ending time
cond: Conditioning tensor from transformer decoder
"""
# Expand t_start to match batch dimension
t_start = torch.full((x_t.shape[0], 1), t_start, device=x_t.device)
t_mid = t_start + (t_end - t_start) / 2

# First half step
v1 = self(x_t=x_t, t=t_start, cond=cond)
x_mid = x_t + v1 * (t_end - t_start) / 2

# Second half step
v2 = self(x_t=x_mid, t=t_mid, cond=cond)
x_end = x_t + v2 * (t_end - t_start)

return x_end
"""Performs one step of the flow matching process using ODE solver."""
# Create time grid for integration
T = torch.tensor([t_start, t_end], device=x_t.device)

# Solve ODE
sol = self.solver.sample(
x_init=x_t,
time_grid=T,
method='midpoint',
step_size=t_end - t_start,
cond=cond # Pass conditioning as context
)

return sol


class SegDiffModel(nn.Module):
@@ -188,6 +255,7 @@ def __init__(
dropout_rate: float = 0.1,
num_parallel_samples: int = 100,
flow_hidden_dim: int = 64,
n_steps: int = 10,
) -> None:
super().__init__()

@@ -198,6 +266,7 @@ def __init__(
self.d_model = d_model
self.num_feat_dynamic_real = num_feat_dynamic_real
self.num_parallel_samples = num_parallel_samples
self.n_steps = n_steps

if scaling == "mean":
self.scaler = MeanScaler(keepdim=True)
@@ -354,16 +423,8 @@ def loss(
# Flow matching loss
x_1 = target[:, 1:, :] # Target patches
x_0 = torch.randn_like(x_1) # Random noise source distribution
# t is a tensor of shape (batch_size, num_patches, 1)
t = torch.rand((x_1.shape[0], x_1.shape[1], 1), device=x_1.device)

x_t = (1 - t) * x_0 + t * x_1
dx_t = x_1 - x_0

# Condition flow on decoder output
flow_out = self.flow(t=t, x_t=x_t, cond=flow_cond[:, :-1, :])

return F.mse_loss(flow_out, dx_t)

return self.flow.compute_loss(x_0=x_0, x_1=x_1, cond=flow_cond[:, :-1, :])

def forward(
self,
@@ -394,16 +455,16 @@ def forward(
)

# Setup time steps for flow
n_steps = 8
time_steps = torch.linspace(0, 1.0, n_steps + 1, device=x.device)

time_steps = torch.linspace(0, 1.0, self.n_steps + 1, device=x.device)

# Get last decoder output and repeat for parallel samples
last_cond = flow_cond[:, -1, :].repeat_interleave(
num_parallel_samples, dim=0
)

# Evolve the samples through time using the flow
for i in range(n_steps):
for i in range(self.n_steps):
x = self.flow.step(
x_t=x,
t_start=time_steps[i],
@@ -468,7 +529,7 @@ def forward(
last_cond = flow_cond[:, -1, :]

# Evolve the new samples
for i in range(n_steps):
for i in range(self.n_steps):
x = self.flow.step(
x_t=x,
t_start=time_steps[i],
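To make the new flow-matching pieces easier to follow in isolation, here is a hedged, self-contained sketch of the flow_matching API pattern this commit wires into `Flow`: a `CondOTScheduler`/`AffineProbPath` pair for the training loss (the conditional-OT path is the linear interpolation x_t = (1 - t) * x_0 + t * x_1 with target velocity x_1 - x_0, i.e. exactly what the removed hand-rolled loss computed), and an `ODESolver` for sampling. The toy velocity net, tensor shapes, and hyperparameters below are made up for illustration, and the example is unconditional; the commit additionally threads a `cond` tensor through the same calls as an extra keyword argument.

# Standalone toy sketch of the flow_matching pattern used above (assumptions noted).
import torch
import torch.nn as nn
import torch.nn.functional as F
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.path import AffineProbPath
from flow_matching.solver import ODESolver


class ToyVelocity(nn.Module):
    """Tiny velocity field v(x, t); stands in for the commit's VelocityModel."""

    def __init__(self, dim: int = 2, h: int = 64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, h), nn.GELU(), nn.Linear(h, dim))

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
        if t.ndim == 0:              # the ODE solver passes a scalar time
            t = t.expand(x.shape[0])
        return self.net(torch.cat([x, t.unsqueeze(-1)], dim=-1))


model = ToyVelocity()
path = AffineProbPath(scheduler=CondOTScheduler())  # linear path x_t = (1 - t) * x_0 + t * x_1

# Training objective: regress the model onto the path's target velocity dx_t.
x_1 = torch.randn(32, 2)             # "data" batch
x_0 = torch.randn_like(x_1)          # noise source
t = torch.rand(x_1.shape[0])         # per-sample times in [0, 1]
sample = path.sample(t=t, x_0=x_0, x_1=x_1)
loss = F.mse_loss(model(sample.x_t, sample.t), sample.dx_t)

# Sampling: integrate the learned ODE from noise (t=0) to data (t=1),
# mirroring the step() / solver.sample() call added in module.py.
solver = ODESolver(velocity_model=model)
x_gen = solver.sample(
    x_init=torch.randn(16, 2),
    time_grid=torch.tensor([0.0, 1.0]),
    method="midpoint",
    step_size=0.1,
)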
