-
Notifications
You must be signed in to change notification settings - Fork 17
/
jacobi_kan.py
76 lines (67 loc) · 3.16 KB
/
jacobi_kan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn as nn
from typing import List
# code modified from https://github.com/SpaceLearner/JacobiKAN
# This is inspired by Kolmogorov-Arnold Networks but uses Jacobi polynomials instead of spline coefficients
class JacobiKANLayer(nn.Module):
    """Single KAN layer whose learnable edge functions are Jacobi polynomials.

    Each input->output edge carries ``degree + 1`` learnable coefficients over
    the Jacobi basis P_0..P_degree with parameters (a, b). Inputs are squashed
    to [-1, 1] with tanh, the natural domain of the Jacobi polynomials.
    """

    def __init__(self, input_dim, output_dim, degree, a=1.0, b=1.0):
        super(JacobiKANLayer, self).__init__()
        self.inputdim = input_dim
        self.outdim = output_dim
        self.a = a  # Jacobi alpha parameter
        self.b = b  # Jacobi beta parameter
        self.degree = degree
        # One coefficient vector per (input, output) edge: (in, out, degree + 1).
        self.jacobi_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
        # Variance-scaled init keeps the summed edge activations near unit scale.
        nn.init.normal_(self.jacobi_coeffs, mean=0.0, std=1/(input_dim * (degree + 1)))

    def forward(self, x):
        """Evaluate the layer; any leading shape is flattened to (batch, inputdim)."""
        x = torch.reshape(x, (-1, self.inputdim))  # shape = (batch_size, inputdim)
        # Jacobi polynomials are defined on [-1, 1]; tanh maps inputs there.
        x = torch.tanh(x)
        # jacobi[:, :, k] holds P_k(x); P_0 = 1 everywhere.
        # BUGFIX: match x's dtype — the original defaulted to float32 and
        # silently down-cast float64 inputs through the in-place assignments.
        jacobi = torch.ones(
            x.shape[0], self.inputdim, self.degree + 1,
            device=x.device, dtype=x.dtype,
        )
        if self.degree > 0:
            # P_1(x) = ((a - b) + (a + b + 2) * x) / 2
            jacobi[:, :, 1] = ((self.a-self.b) + (self.a+self.b+2) * x) / 2
        for i in range(2, self.degree + 1):
            # Standard three-term recurrence for Jacobi polynomials:
            #   P_i = (theta_k * x + theta_k1) * P_{i-1} - theta_k2 * P_{i-2}
            theta_k = (2*i+self.a+self.b)*(2*i+self.a+self.b-1) / (2*i*(i+self.a+self.b))
            theta_k1 = (2*i+self.a+self.b-1)*(self.a*self.a-self.b*self.b) / (2*i*(i+self.a+self.b)*(2*i+self.a+self.b-2))
            theta_k2 = (i+self.a-1)*(i+self.b-1)*(2*i+self.a+self.b) / (i*(i+self.a+self.b)*(2*i+self.a+self.b-2))
            # .clone() avoids autograd errors from in-place use of prior columns.
            jacobi[:, :, i] = (theta_k * x + theta_k1) * jacobi[:, :, i - 1].clone() - theta_k2 * jacobi[:, :, i - 2].clone()
        # Contract basis values with coefficients: (b,i,d) x (i,o,d) -> (b,o).
        # Cast coefficients to x's dtype so non-float32 inputs work end-to-end.
        y = torch.einsum('bid,iod->bo', jacobi, self.jacobi_coeffs.to(x.dtype))  # shape = (batch_size, outdim)
        y = y.view(-1, self.outdim)
        return y
# To avoid gradient vanishing caused by tanh
class JacobiKANLayerWithNorm(nn.Module):
    """A JacobiKANLayer followed by LayerNorm over the output features.

    The LayerNorm re-centres and re-scales activations, counteracting the
    saturation of the tanh used inside JacobiKANLayer.
    """

    def __init__(self, input_dim, output_dim, degree, a, b):
        super(JacobiKANLayerWithNorm, self).__init__()
        # Attribute names kept as-is so state_dict keys stay compatible.
        self.layer = JacobiKANLayer(input_dim, output_dim, degree, a, b)
        self.layer_norm = nn.LayerNorm(output_dim)  # To avoid gradient vanishing caused by tanh

    def forward(self, x):
        return self.layer_norm(self.layer(x))
class Jacobi_KAN(nn.Module):
    """A stack of JacobiKANLayerWithNorm layers.

    ``layers_hidden`` lists the width of every layer, e.g. ``[in, h1, out]``;
    each consecutive pair becomes one Jacobi KAN layer. ``grid_size`` and
    ``spline_order`` are accepted only for interface parity with spline-based
    KANs and are ignored.
    """

    def __init__(
        self,
        layers_hidden: List[int],
        degree: int = 3,
        a=1.0,
        b=1.0,
        grid_size: int = 8,  # placeholder
        spline_order=0.  # placeholder
    ) -> None:
        super().__init__()
        # Pair up consecutive widths: (w0, w1), (w1, w2), ...
        width_pairs = zip(layers_hidden, layers_hidden[1:])
        self.layers = nn.ModuleList(
            JacobiKANLayerWithNorm(
                input_dim=d_in,
                output_dim=d_out,
                degree=degree,
                a=a,
                b=b,
            )
            for d_in, d_out in width_pairs
        )

    def forward(self, x):
        """Feed x through every layer in order."""
        out = x
        for block in self.layers:
            out = block(out)
        return out