add sde transport
kashif committed Dec 13, 2024
1 parent a36c562 commit 8d5f4b8
Showing 6 changed files with 1,066 additions and 35 deletions.
src/gluonts/torch/model/seg_diff/module.py (197 changes: 162 additions & 35 deletions)
@@ -24,15 +24,17 @@
from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler
from gluonts.torch.util import take_last, unsqueeze_expand, weighted_average
from gluonts.torch.model.simple_feedforward import make_linear_layer
from flow_matching.path.scheduler import (
CondOTScheduler,
CosineScheduler,
PolynomialConvexScheduler,
VPScheduler,
LinearVPScheduler,
)
from flow_matching.path import AffineProbPath
from flow_matching.solver import ODESolver

# from flow_matching.path.scheduler import (
# CondOTScheduler,
# CosineScheduler,
# PolynomialConvexScheduler,
# VPScheduler,
# LinearVPScheduler,
# )
# from flow_matching.path import AffineProbPath
# from flow_matching.solver import ODESolver
from .transport import Transport, ModelType, PathType, WeightType, Sampler
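The transport package added in this commit distinguishes what the network predicts (ModelType), which probability path is used (PathType), and how the loss is weighted (WeightType). A rough sketch of these enums, inferred only from the members referenced in this commit (illustrative; the real definitions live in the transport package):

import enum

class ModelType(enum.Enum):
    NOISE = enum.auto()     # network predicts the noise
    SCORE = enum.auto()     # network predicts the score
    VELOCITY = enum.auto()  # network predicts the velocity field

class PathType(enum.Enum):
    LINEAR = enum.auto()
    GVP = enum.auto()
    VP = enum.auto()

class WeightType(enum.Enum):
    NONE = enum.auto()
    VELOCITY = enum.auto()
    LIKELIHOOD = enum.auto()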


class ClassInstantier(OrderedDict):
@@ -176,33 +178,161 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor):
return self.net(inputs)


# class Flow(nn.Module):
# def __init__(self, cond_dim: int, out_dim: int, h: int):
# super().__init__()

# # Define MLP for velocity field
# self.velocity_model = VelocityModel(cond_dim, out_dim, h)

# # Flow matching components
# scheduler = CondOTScheduler()
# self.prob_path = AffineProbPath(scheduler=scheduler)
# self.solver = ODESolver(self.velocity_model)

# def compute_loss(
# self, x_0: torch.Tensor, x_1: torch.Tensor, cond: torch.Tensor
# ) -> torch.Tensor:
# """Compute flow matching loss."""
# # Sample time uniformly
# t = torch.rand(x_0.shape[0], x_0.shape[1], device=x_0.device)

# # Get path sample from probability path with scheduler outputs
# path_sample = self.prob_path.sample(t=t, x_0=x_0, x_1=x_1)

# # Get velocity field prediction
# v_t = self.velocity_model(path_sample.x_t, path_sample.t, cond)

# # Flow matching loss
# return F.mse_loss(v_t, path_sample.dx_t)


class Flow(nn.Module):
def __init__(self, cond_dim: int, out_dim: int, h: int):
super().__init__()

# Define MLP for velocity field
self.velocity_model = VelocityModel(cond_dim, out_dim, h)

# Create transport object with velocity model type and linear path
self.transport = Transport(
model_type=ModelType.VELOCITY,
path_type=PathType.LINEAR,
loss_type=WeightType.NONE,
train_eps=0.0,
sample_eps=0.0,
)

# Create sampler for generating samples
self.sampler = Sampler(self.transport)

def compute_loss(
self, x_0: torch.Tensor, x_1: torch.Tensor, cond: torch.Tensor
) -> torch.Tensor:
"""Compute flow matching loss."""
# Use transport training loss
terms = self.transport.training_losses(
model=self.velocity_model,
x1=x_1,
model_kwargs={"cond": cond},
)
return terms["loss"].mean()


# def sample(
# self,
# x_init: torch.Tensor,
# cond: torch.Tensor,
# method: str = "dopri5",
# step_size: float = 0.05,
# return_intermediates: bool = False,
# time_grid: Optional[torch.Tensor] = None,
# ) -> torch.Tensor:
# """
# Generate samples using the ODE solver.

# Args:
# x_init: Initial noise tensor
# cond: Conditioning tensor
# method: ODE solver method ('dopri5', 'euler', 'heun', etc.)
# step_size: Step size for fixed-step solvers
# return_intermediates: Whether to return intermediate states
# time_grid: Optional time points for sampling. If None, uses default grid

# Returns:
# Generated samples
# """
# if time_grid is None:
# time_grid = torch.linspace(0, 1.0, int(1.0/step_size) + 1, device=x_init.device)

# # Get ODE sampler with specified method
# ode_sampler = self.sampler.sample_ode(
# sampling_method=method,
# num_steps=len(time_grid) if method in ['euler', 'heun'] else 50,
# atol=1e-5,
# rtol=1e-5,
# )

# # Sample using the velocity model
# samples = ode_sampler(
# x_init,
# model=self.velocity_model,
# cond=cond,
# )

# if return_intermediates:
# return samples
# return samples[-1] # Return only final state if intermediates not requested

def sample(
self,
x_init: torch.Tensor,
cond: torch.Tensor,
method: str = "Euler",
step_size: float = 0.05,
return_intermediates: bool = False,
time_grid: Optional[torch.Tensor] = None,
diffusion_form: str = "linear",
diffusion_norm: float = 1.0,
) -> torch.Tensor:
"""
Generate samples using the SDE solver.
Args:
x_init: Initial noise tensor
cond: Conditioning tensor
method: SDE solver method ('Euler', 'Heun')
step_size: Step size for fixed-step solvers
return_intermediates: Whether to return intermediate states
time_grid: Optional time points for sampling. If None, uses default grid
diffusion_form: Form of diffusion coefficient ('linear', 'constant', 'SBDM', etc.)
diffusion_norm: Scale of the diffusion coefficient
Returns:
Generated samples
"""
num_steps = int(1.0/step_size) + 1 if time_grid is None else len(time_grid)

# Get SDE sampler with specified method
sde_sampler = self.sampler.sample_sde(
sampling_method=method,
diffusion_form=diffusion_form,
diffusion_norm=diffusion_norm,
last_step="Mean", # Use mean for last step correction
last_step_size=step_size,
num_steps=num_steps,
)

# Sample with the velocity model, adapting it to the sampler's
# (x, t, **kwargs) interface; the conditioning tensor is threaded
# through as a keyword argument
samples = sde_sampler(
x_init,
model=lambda x, t, **kwargs: self.velocity_model(x, t, kwargs["cond"]),
cond=cond,
)

if return_intermediates:
return samples
return samples[-1] # Return only final state if intermediates not requested
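For reference, with the linear path and velocity prediction configured in __init__, the transport training loss used by compute_loss above reduces to standard conditional flow matching: interpolate x_t = (1 - t) * x_0 + t * x_1 and regress the constant path velocity x_1 - x_0. A minimal self-contained sketch of that objective, assuming only plain PyTorch (the helper name is illustrative and not part of the transport API):

import torch
import torch.nn.functional as F

def linear_path_fm_loss(velocity_model, x_1, cond):
    """Conditional flow matching loss on the linear path x_t = (1 - t) * x_0 + t * x_1."""
    x_0 = torch.randn_like(x_1)  # noise endpoint of the path
    t = torch.rand(x_1.shape[0], x_1.shape[1], device=x_1.device)  # per-element times
    t_b = t.unsqueeze(-1)  # broadcast over the feature dimension
    x_t = (1 - t_b) * x_0 + t_b * x_1  # point on the path at time t
    target = x_1 - x_0  # the linear path has constant velocity
    v_t = velocity_model(x_t, t, cond)  # model's predicted velocity
    return F.mse_loss(v_t, target)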

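A usage sketch for the Flow module above (the sizes and tensor shapes below are illustrative assumptions; meaningful samples require a trained model):

import torch

batch, cond_dim, out_dim, hidden = 8, 32, 16, 64  # hypothetical sizes

flow = Flow(cond_dim=cond_dim, out_dim=out_dim, h=hidden)

cond = torch.randn(batch, 1, cond_dim)  # conditioning from the backbone
x_1 = torch.randn(batch, 1, out_dim)    # target values for training
x_0 = torch.randn_like(x_1)             # noise (the transport loss draws its own internally)

loss = flow.compute_loss(x_0=x_0, x_1=x_1, cond=cond)

samples = flow.sample(
    x_init=torch.randn(batch, 1, out_dim),  # start the SDE from Gaussian noise
    cond=cond,
    method="Heun",
    step_size=0.05,
    diffusion_form="linear",
    diffusion_norm=1.0,
)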

class SegDiffModel(nn.Module):
@@ -439,23 +569,20 @@ def forward(
device=past_target.device,
)

# Setup time steps for flow

time_grid = torch.linspace(0, 1.0, self.n_steps + 1, device=x.device)

# Get last decoder output and repeat for parallel samples
last_cond = flow_cond[:, -1, :].repeat_interleave(
num_parallel_samples, dim=0
)
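# Note: repeat_interleave duplicates each series' conditioning row
# num_parallel_samples times so that all sample paths are evolved in a
# single batched solve. Toy illustration (shapes are hypothetical):
#   cond = torch.tensor([[1.0], [2.0]])   # two series, cond dim 1
#   cond.repeat_interleave(3, dim=0)      # -> [[1.],[1.],[1.],[2.],[2.],[2.]]
# Rows belonging to the same series stay adjacent, so the generated samples
# can later be reshaped back to (num_series, num_parallel_samples, ...).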

# Evolve the samples through time using the flow
# previously: x = self.flow.solver.sample(time_grid=time_grid, method="midpoint", ...)
x = self.flow.sample(
x_init=x,
cond=last_cond,
method="Heun",
step_size=0.05,
return_intermediates=False,
time_grid=time_grid,
)

# Reshape and scale the samples
Expand Down Expand Up @@ -522,13 +649,13 @@ def forward(
# t_end=time_steps[i + 1],
# cond=last_cond,
# )
# previously: x = self.flow.solver.sample(time_grid=time_grid, method="midpoint", ...)
x = self.flow.sample(
x_init=x,
cond=last_cond,
method="Heun",
step_size=0.05,
return_intermediates=False,
time_grid=time_grid,
)

# Scale and store the samples
src/gluonts/torch/model/seg_diff/transport/__init__.py (67 changes: 67 additions & 0 deletions)
@@ -0,0 +1,67 @@
from .transport import Transport, ModelType, WeightType, PathType, Sampler


def create_transport(
path_type="Linear",
prediction="velocity",
loss_weight=None,
train_eps=None,
sample_eps=None,
):
"""function for creating Transport object
**Note**: model prediction defaults to velocity
Args:
- path_type: type of path to use; default to linear
- learn_score: set model prediction to score
- learn_noise: set model prediction to noise
- velocity_weighted: weight loss by velocity weight
- likelihood_weighted: weight loss by likelihood weight
- train_eps: small epsilon for avoiding instability during training
- sample_eps: small epsilon for avoiding instability during sampling
"""

if prediction == "noise":
model_type = ModelType.NOISE
elif prediction == "score":
model_type = ModelType.SCORE
else:
model_type = ModelType.VELOCITY

if loss_weight == "velocity":
loss_type = WeightType.VELOCITY
elif loss_weight == "likelihood":
loss_type = WeightType.LIKELIHOOD
else:
loss_type = WeightType.NONE

path_choice = {
"Linear": PathType.LINEAR,
"GVP": PathType.GVP,
"VP": PathType.VP,
}

path_type = path_choice[path_type]

if path_type in [PathType.VP]:
train_eps = 1e-5 if train_eps is None else train_eps
sample_eps = 1e-3 if sample_eps is None else sample_eps
elif (
path_type in [PathType.GVP, PathType.LINEAR]
and model_type != ModelType.VELOCITY
):
train_eps = 1e-3 if train_eps is None else train_eps
sample_eps = 1e-3 if sample_eps is None else sample_eps
else: # velocity & [GVP, LINEAR] is stable everywhere
train_eps = 0
sample_eps = 0

# create flow state
state = Transport(
model_type=model_type,
path_type=path_type,
loss_type=loss_type,
train_eps=train_eps,
sample_eps=sample_eps,
)

return state
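A usage sketch for the factory above, assuming the package import path introduced by this commit:

from gluonts.torch.model.seg_diff.transport import Sampler, create_transport

# Linear path with velocity prediction: both epsilons stay at 0 (stable everywhere).
transport = create_transport(path_type="Linear", prediction="velocity")

# VP path: defaults to train_eps=1e-5 and sample_eps=1e-3 for numerical stability.
vp_transport = create_transport(path_type="VP", prediction="velocity")

# Wrap in a Sampler to obtain ODE/SDE samplers, as module.py does.
sampler = Sampler(transport)
sde_sampler = sampler.sample_sde(
    sampling_method="Euler",
    diffusion_form="linear",
    diffusion_norm=1.0,
    last_step="Mean",
    last_step_size=0.05,
    num_steps=21,
)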
