diff --git a/src/efficient_kan/kan.py b/src/efficient_kan/kan.py index 7d36495..5779890 100644 --- a/src/efficient_kan/kan.py +++ b/src/efficient_kan/kan.py @@ -151,14 +151,19 @@ def scaled_spline_weight(self): ) def forward(self, x: torch.Tensor): - assert x.dim() == 2 and x.size(1) == self.in_features + assert x.size(-1) == self.in_features + original_shape = x.shape + x = x.view(-1, self.in_features) base_output = F.linear(self.base_activation(x), self.base_weight) spline_output = F.linear( self.b_splines(x).view(x.size(0), -1), self.scaled_spline_weight.view(self.out_features, -1), ) - return base_output + spline_output + output = base_output + spline_output + + output = output.view(*original_shape[:-1], self.out_features) + return output @torch.no_grad() def update_grid(self, x: torch.Tensor, margin=0.01):