diff --git a/src/gluonts/torch/model/seg_diff/estimator.py b/src/gluonts/torch/model/seg_diff/estimator.py
index 84fd14fc06..2ac76ab135 100644
--- a/src/gluonts/torch/model/seg_diff/estimator.py
+++ b/src/gluonts/torch/model/seg_diff/estimator.py
@@ -113,6 +113,7 @@ def __init__(
         nhead: int = 4,
         dim_feedforward: int = 128,
         num_feat_dynamic_real: int = 0,
+        n_steps: int = 10,
         dropout: float = 0.1,
         activation: str = "relu",
         norm_first: bool = False,
@@ -138,6 +139,7 @@ def __init__(
         self.context_length = patch_len * context_length_multiplier
         self.context_length_multiplier = context_length_multiplier
         self.prediction_length = prediction_length
+        self.n_steps = n_steps
         self.lr = lr
         self.weight_decay = weight_decay
 
@@ -203,6 +205,7 @@ def create_lightning_module(self) -> pl.LightningModule:
                 "num_decoder_layers": self.num_decoder_layers,
                 # "distr_output": self.distr_output,
                 "scaling": self.scaling,
+                "n_steps": self.n_steps,
             },
         )
 
diff --git a/src/gluonts/torch/model/seg_diff/module.py b/src/gluonts/torch/model/seg_diff/module.py
index fa8f126807..43729e3f96 100644
--- a/src/gluonts/torch/model/seg_diff/module.py
+++ b/src/gluonts/torch/model/seg_diff/module.py
@@ -21,10 +21,12 @@
 from gluonts.core.component import validated
 from gluonts.model import Input, InputSpec
-from gluonts.torch.distributions import StudentTOutput
 from gluonts.torch.scaler import StdScaler, MeanScaler, NOPScaler
 from gluonts.torch.util import take_last, unsqueeze_expand, weighted_average
 from gluonts.torch.model.simple_feedforward import make_linear_layer
+from flow_matching.path.scheduler import CondOTScheduler
+from flow_matching.path import AffineProbPath
+from flow_matching.solver import ODESolver
 
 
 class ClassInstantier(OrderedDict):
@@ -108,51 +110,116 @@ def forward(self, x: torch.Tensor):
             return self.layer_norm(out)
         return out
 
+
+class VelocityModel(nn.Module):
+    def __init__(self, cond_dim: int, out_dim: int, h: int, time_embed_dim: int = 8):
+        super().__init__()
+        # Time embedding network
+        self.time_embed = nn.Sequential(
+            nn.Linear(1, time_embed_dim),
+            nn.GELU(),
+            nn.Linear(time_embed_dim, time_embed_dim),
+        )
+
+        # Conditioning network for feature extraction
+        self.cond_net = nn.Sequential(
+            nn.Linear(cond_dim, h),
+            nn.GELU(),
+            nn.Linear(h, h),
+        )
+
+        # Main velocity network
+        self.net = nn.Sequential(
+            nn.Linear(out_dim + h + time_embed_dim, h),
+            nn.GELU(),
+            nn.Linear(h, h),
+            nn.GELU(),
+            nn.Dropout(0.1),  # light regularization
+            nn.Linear(h, h),
+            nn.GELU(),
+            nn.Linear(h, out_dim),
+        )
+
+        # Initialize weights for better gradient flow
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.xavier_uniform_(module.weight)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor):
+        # Handle different time tensor shapes
+        if t.ndim == 0:  # scalar time
+            t = t.view(1)
+
+        # Expand t to match the batch dimensions of x
+        while t.ndim < x.ndim:
+            t = t.unsqueeze(-1)
+        t = t.expand(*x.shape[:-1], 1)
+
+        # Get time embeddings
+        t_embed = self.time_embed(t)
+
+        # Process conditioning
+        cond_features = self.cond_net(cond)
+
+        # Concatenate all inputs and compute velocity
+        inputs = torch.cat([x, t_embed, cond_features], dim=-1)
+        return self.net(inputs)
+
 
 class Flow(nn.Module):
     def __init__(self, cond_dim: int, out_dim: int, h: int):
         super().__init__()
-
-        self.linear = nn.Linear(out_dim + cond_dim + 1, h)
-        self.act = ACT2FN["gelu"]
-        self.output_layer = nn.Linear(h, out_dim)
-
-    def forward(self, x_t: torch.Tensor, t: torch.Tensor, cond: torch.Tensor):
-        x = torch.cat((x_t, t, cond), dim=-1)
-        x = self.linear(x)
-        x = self.act(x)
-        x = self.output_layer(x)
-        return x
+
+        # MLP parameterizing the velocity field
+        self.velocity_model = VelocityModel(cond_dim, out_dim, h)
+
+        # Flow matching components
+        scheduler = CondOTScheduler()
+        self.prob_path = AffineProbPath(scheduler=scheduler)
+        self.solver = ODESolver(self.velocity_model)
+
+    def compute_loss(self, x_0: torch.Tensor, x_1: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
+        """Compute the flow matching loss."""
+        # Sample time uniformly per series and per patch
+        t = torch.rand(x_0.shape[0], x_0.shape[1], device=x_0.device)
+
+        # Sample x_t and the target velocity dx_t from the probability path
+        path_sample = self.prob_path.sample(
+            t=t,
+            x_0=x_0,
+            x_1=x_1,
+        )
+
+        # Predict the velocity at (x_t, t) given the decoder conditioning
+        v_t = self.velocity_model(path_sample.x_t, path_sample.t, cond)
+
+        # Flow matching loss
+        return F.mse_loss(v_t, path_sample.dx_t)
 
     @torch.inference_mode()
     def step(
         self,
-        x_t: torch.Tensor, 
+        x_t: torch.Tensor,
         t_start: float,
         t_end: float,
         cond: torch.Tensor,
     ) -> torch.Tensor:
-        """Performs one step of the flow matching process.
-
-        Args:
-            x_t: Input tensor to evolve
-            t_start: Starting time
-            t_end: Ending time
-            cond: Conditioning tensor from transformer decoder
-        """
-        # Expand t_start to match batch dimension
-        t_start = torch.full((x_t.shape[0], 1), t_start, device=x_t.device)
-        t_mid = t_start + (t_end - t_start) / 2
-
-        # First half step
-        v1 = self(x_t=x_t, t=t_start, cond=cond)
-        x_mid = x_t + v1 * (t_end - t_start) / 2
-
-        # Second half step
-        v2 = self(x_t=x_mid, t=t_mid, cond=cond)
-        x_end = x_t + v2 * (t_end - t_start)
-
-        return x_end
+        """Performs one step of the flow matching process using the ODE solver."""
+        # Time grid for this integration step
+        T = torch.tensor([t_start, t_end], device=x_t.device)
+
+        # Solve the ODE from t_start to t_end
+        sol = self.solver.sample(
+            x_init=x_t,
+            time_grid=T,
+            method="midpoint",
+            step_size=t_end - t_start,
+            cond=cond,  # forwarded to the velocity model as conditioning
+        )
+
+        return sol
 
 
 class SegDiffModel(nn.Module):
@@ -188,6 +255,7 @@ def __init__(
         dropout_rate: float = 0.1,
         num_parallel_samples: int = 100,
         flow_hidden_dim: int = 64,
+        n_steps: int = 10,
     ) -> None:
         super().__init__()
 
@@ -198,6 +266,7 @@ def __init__(
         self.d_model = d_model
         self.num_feat_dynamic_real = num_feat_dynamic_real
         self.num_parallel_samples = num_parallel_samples
+        self.n_steps = n_steps
 
         if scaling == "mean":
             self.scaler = MeanScaler(keepdim=True)
@@ -354,16 +423,8 @@ def loss(
         # Flow matching loss
         x_1 = target[:, 1:, :]  # Target patches
         x_0 = torch.randn_like(x_1)  # Random noise source distribution
-        # t is a tensor of shape (batch_size, num_patches, 1)
-        t = torch.rand((x_1.shape[0], x_1.shape[1], 1), device=x_1.device)
-
-        x_t = (1 - t) * x_0 + t * x_1
-        dx_t = x_1 - x_0
-
-        # Condition flow on decoder output
-        flow_out = self.flow(t=t, x_t=x_t, cond=flow_cond[:, :-1, :])
-
-        return F.mse_loss(flow_out, dx_t)
+
+        return self.flow.compute_loss(x_0=x_0, x_1=x_1, cond=flow_cond[:, :-1, :])
 
     def forward(
         self,
@@ -394,8 +455,8 @@ def forward(
         )
 
         # Setup time steps for flow
-        n_steps = 8
-        time_steps = torch.linspace(0, 1.0, n_steps + 1, device=x.device)
+
+        time_steps = torch.linspace(0, 1.0, self.n_steps + 1, device=x.device)
 
         # Get last decoder output and repeat for parallel samples
         last_cond = flow_cond[:, -1, :].repeat_interleave(
@@ -403,7 +464,7 @@ def forward(
         )
 
         # Evolve the samples through time using the flow
-        for i in range(n_steps):
+        for i in range(self.n_steps):
             x = self.flow.step(
                 x_t=x,
                 t_start=time_steps[i],
@@ -468,7 +529,7 @@ def forward(
         last_cond = flow_cond[:, -1, :]
 
         # Evolve the new samples
-        for i in range(n_steps):
+        for i in range(self.n_steps):
             x = self.flow.step(
                 x_t=x,
                 t_start=time_steps[i],
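
For reviewers unfamiliar with the `flow_matching` package (facebookresearch/flow_matching) that this diff adopts, here is a minimal, self-contained sketch of the train/sample pattern it follows. Only `AffineProbPath`, `CondOTScheduler`, `ODESolver`, and their calls mirror the PR; the `ToyVelocity` network, tensor shapes, and hyperparameters are illustrative stand-ins, not part of the change.

```python
# Minimal sketch of the flow-matching train/sample loop used by this PR.
# Assumes `pip install flow-matching`; ToyVelocity is a hypothetical,
# unconditional stand-in for the PR's VelocityModel.
import torch
import torch.nn as nn
import torch.nn.functional as F

from flow_matching.path import AffineProbPath
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.solver import ODESolver


class ToyVelocity(nn.Module):
    def __init__(self, dim: int = 2, h: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, h), nn.GELU(), nn.Linear(h, dim)
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
        # Broadcast scalar or per-sample t to a trailing feature of x
        t = t.reshape(-1, *([1] * (x.ndim - 1))).expand(*x.shape[:-1], 1)
        return self.net(torch.cat([x, t], dim=-1))


model = ToyVelocity()
path = AffineProbPath(scheduler=CondOTScheduler())
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# --- Training: regress the model onto the path's target velocity ---
for _ in range(100):
    x_1 = torch.randn(32, 2) * 0.1 + 1.0  # stand-in "data" batch
    x_0 = torch.randn_like(x_1)           # noise source
    t = torch.rand(x_1.shape[0])          # one time per sample
    sample = path.sample(t=t, x_0=x_0, x_1=x_1)
    loss = F.mse_loss(model(sample.x_t, sample.t), sample.dx_t)
    optim.zero_grad()
    loss.backward()
    optim.step()

# --- Sampling: integrate the learned velocity field from t=0 to t=1 ---
solver = ODESolver(velocity_model=model)
x_init = torch.randn(32, 2)
time_grid = torch.linspace(0.0, 1.0, 11)  # mirrors n_steps = 10
samples = solver.sample(
    x_init=x_init, time_grid=time_grid, method="midpoint", step_size=0.1
)
```

With `CondOTScheduler`, the path reduces to `x_t = (1 - t) * x_0 + t * x_1` with target velocity `dx_t = x_1 - x_0`, i.e. exactly the hand-rolled interpolation that `loss()` previously computed, so the refactor preserves the training objective while making the scheduler and ODE solver configurable.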