Update MultKAN.py #383

Open
wants to merge 11 commits into master
2 changes: 1 addition & 1 deletion README.md
@@ -9,9 +9,9 @@ For pypi users, this is the most recent version 0.2.1.

New functionalities include (documentation later):
* including multiplications in KANs. [Tutorial](https://github.com/KindXiaoming/pykan/blob/master/tutorials/Interp_1_Hello%2C%20MultKAN.ipynb)
* the speed mode. Speed up your KAN using `model = model.speed()` if you never use the symbolic functionalities. [Tutorial](https://github.com/KindXiaoming/pykan/blob/master/tutorials/Example_2_speed_up.ipynb)
* Compiling symbolic formulas into KANs. [Tutorial](https://github.com/KindXiaoming/pykan/blob/master/tutorials/Interp_3_KAN_Compiler.ipynb)
* Feature attribution and pruning inputs. [Tutorial](https://github.com/KindXiaoming/pykan/blob/master/tutorials/Interp_4_feature_attribution.ipynb)
* Using GPUs. It works on my end: please try the [tutorial](https://github.com/KindXiaoming/pykan/blob/master/tutorials/API_10_device.ipynb) and let me know if it works on your end. CUDA runs 20x faster than CPU.
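
A quick sketch of the speed-mode bullet above (a sketch, not part of this diff; the `KAN` constructor follows the pykan README):

```python
from kan import KAN

model = KAN(width=[2, 5, 1], grid=5, k=3)
# ... train as usual ...
model = model.speed()  # drops symbolic bookkeeping; only safe if you never use symbolic APIs
```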

# Kolmogorov-Arnold Networks (KANs)

57 changes: 40 additions & 17 deletions kan/KANLayer.py
@@ -192,7 +192,7 @@ def forward(self, x):
y = torch.sum(y, dim=1) # shape (batch, out_dim)
return y, preacts, postacts, postspline

def update_grid_from_samples(self, x):
def update_grid_from_samples(self, x, mode='sample'):
'''
update grid from samples

@@ -215,21 +215,32 @@ def update_grid_from_samples(self, x):
tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]])
tensor([[-3.0002, -1.7882, -0.5763, 0.6357, 1.8476, 3.0002]])
'''

batch = x.shape[0]
#x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
x_pos = torch.sort(x, dim=0)[0]
y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
num_interval = self.grid.shape[1] - 1 - 2*self.k
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[ids, :].permute(1,0)
margin = 0.01
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive

def get_grid(num_interval):
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[ids, :].permute(1,0)
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
return grid

grid = get_grid(num_interval)

if mode == 'grid':
sample_grid = get_grid(2*num_interval)
x_pos = sample_grid.permute(1,0)
y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)

self.grid.data = extend_grid(grid, k_extend=self.k)
self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

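A minimal sketch of the new `mode` switch (assuming the `KANLayer` constructor from the class docstrings; with `mode='grid'` the coefficients are refit on a twice-finer grid built from the samples instead of at the raw samples):

```python
import torch
from kan.KANLayer import KANLayer

layer = KANLayer(in_dim=3, out_dim=5, num=5, k=3)
x = torch.normal(0, 1, size=(100, 3))
layer.update_grid_from_samples(x)               # 'sample': refit coef at the sorted samples
layer.update_grid_from_samples(x, mode='grid')  # 'grid': refit coef on a 2x-dense proxy grid
```
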
def initialize_grid_from_parent(self, parent, x):
def initialize_grid_from_parent(self, parent, x, mode='sample'):
'''
update grid from a parent KANLayer & samples

@@ -257,19 +268,31 @@ def initialize_grid_from_parent(self, parent, x):
tensor([[-1.0000, -0.8000, -0.6000, -0.4000, -0.2000, 0.0000, 0.2000, 0.4000,
0.6000, 0.8000, 1.0000]])
'''

batch = x.shape[0]
# preacts: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
#x_eval = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
x_eval = x
pgrid = parent.grid # (in_dim, G+2*k+1)
pk = parent.k
y_eval = coef2curve(x_eval, pgrid, parent.coef, pk)

h = (pgrid[:,[-pk]] - pgrid[:,[pk]])/self.num
grid = pgrid[:,[pk]] + torch.arange(self.num+1,) * h
x_pos = torch.sort(x, dim=0)[0]
y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
num_interval = self.grid.shape[1] - 1 - 2*self.k

def get_grid(num_interval):
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[ids, :].permute(1,0)
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
return grid

grid = get_grid(num_interval)

if mode == 'grid':
sample_grid = get_grid(2*num_interval)
x_pos = sample_grid.permute(1,0)
y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)

grid = extend_grid(grid, k_extend=self.k)
self.grid.data = grid
self.coef.data = curve2coef(x_eval, y_eval, self.grid, self.k)
self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

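`initialize_grid_from_parent` now builds the child grid from sample quantiles (blended with a uniform grid via `grid_eps`), mirroring `update_grid_from_samples` above, rather than spacing it uniformly over the parent's range. A sketch under the docstring's dimensions:

```python
import torch
from kan.KANLayer import KANLayer

parent = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
child = KANLayer(in_dim=1, out_dim=1, num=10, k=3)
x = torch.normal(0, 1, size=(100, 1))
child.initialize_grid_from_parent(parent, x)  # child coef fitted to the parent's curve
print(child.grid.shape)  # (1, 10 + 1 + 2*3) = (1, 17) after extend_grid
```
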
def get_subset(self, in_id, out_id):
'''
10 changes: 7 additions & 3 deletions kan/MLP.py
@@ -10,7 +10,7 @@

class MLP(nn.Module):

def __init__(self, width, act='silu', save_act=True, seed=0, device='cpu'):
def __init__(self, width, act='identity', save_act=True, seed=0, device='cpu'):
super(MLP, self).__init__()

torch.manual_seed(seed)
@@ -22,8 +22,12 @@ def __init__(self, width, act='silu', save_act=True, seed=0, device='cpu'):
linears.append(nn.Linear(width[i], width[i+1]))
self.linears = nn.ModuleList(linears)

#if activation == 'silu':
self.act_fun = torch.nn.SiLU()
if act == 'silu':
self.act_fun = torch.nn.SiLU()
elif act == 'relu':
self.act_fun = torch.nn.ReLU()
elif act == 'identity':
self.act_fun = torch.nn.Identity()
self.save_act = save_act
self.acts = None
self.device = device
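
The activation is now selectable, and the default changes from `silu` to `identity`. A usage sketch (assuming the `width` list convention of the constructor):

```python
from kan.MLP import MLP

mlp = MLP(width=[2, 8, 1], act='relu')  # accepted: 'silu', 'relu', 'identity'
# caveat: any other string falls through the chain and leaves self.act_fun unset
```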
63 changes: 44 additions & 19 deletions kan/MultKAN.py
@@ -155,7 +155,7 @@ def initialize_from_another_model(self, another_model, x):
# spb = spb_parent
preacts = another_model.spline_preacts[l]
postsplines = another_model.spline_postsplines[l]
self.act_fun[l].coef.data = curve2coef(preacts[:,0,:], postsplines.permute(0,2,1), spb.grid, k=spb.k)
#self.act_fun[l].coef.data = curve2coef(preacts[:,0,:], postsplines.permute(0,2,1), spb.grid, k=spb.k)
self.act_fun[l].scale_base.data = another_model.act_fun[l].scale_base.data
self.act_fun[l].scale_sp.data = another_model.act_fun[l].scale_sp.data
self.act_fun[l].mask.data = another_model.act_fun[l].mask.data
@@ -170,7 +170,7 @@ def initialize_from_another_model(self, another_model, x):
for l in range(self.depth):
self.symbolic_fun[l] = another_model.symbolic_fun[l]

return self.to(device)
return self

def log_history(self, method_name):

@@ -290,6 +290,11 @@ def loadckpt(path='model'):
model_load.symbolic_fun[l].funs_avoid_singularity[j][i] = SYMBOLIC_LIB[fun_name][3]
return model_load

def copy(self):
path='copy_temp'
self.saveckpt(path)
return KAN.loadckpt(path)
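
A usage note for the new `copy()`: it round-trips through `saveckpt`/`loadckpt`, so it writes `copy_temp*` checkpoint files to the working directory as a side effect.

```python
model2 = model.copy()  # independent KAN restored from the temporary checkpoint
```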

def rewind(self, model_id):

self.round += 1
@@ -793,14 +798,17 @@ def reg(self, reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff):
if reg_metric == 'edge_forward_n':
acts_scale = self.acts_scale_spline

if reg_metric == 'edge_forward_u':
elif reg_metric == 'edge_forward_u':
acts_scale = self.edge_actscale

if reg_metric == 'edge_backward':
elif reg_metric == 'edge_backward':
acts_scale = self.edge_scores

if reg_metric == 'node_backward':
elif reg_metric == 'node_backward':
acts_scale = self.node_attribute_scores

else:
raise Exception(f'reg_metric = {reg_metric} not recognized!')

reg_ = 0.
for i in range(len(acts_scale)):
@@ -830,12 +838,8 @@ def reg(self, reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff):
def get_reg(self, reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff):
return self.reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)

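With the `elif` chain, an unrecognized metric now raises immediately instead of failing later with an undefined `acts_scale`. A sketch of a valid call (signature from `get_reg` above):

```python
reg_loss = model.get_reg('edge_forward_n', lamb_l1=1.0, lamb_entropy=2.0,
                         lamb_coef=0.0, lamb_coefdiff=0.0)
# recognized metrics: 'edge_forward_n', 'edge_forward_u', 'edge_backward', 'node_backward'
```
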
def disable_symbolic_in_fit(self, lamb):

old_save_act = self.save_act
if lamb == 0.:
self.save_act = False

def disable_symbolic_in_fit(self):

# skip symbolic if no symbolic is turned on
depth = len(self.symbolic_fun)
no_symbolic = True
@@ -847,7 +851,22 @@ def disable_symbolic_in_fit(self, lamb):
if no_symbolic:
self.symbolic_enabled = False

return old_save_act, old_symbolic_enabled
return old_symbolic_enabled

def disable_save_act_in_fit(self, lamb):

old_save_act = self.save_act
if lamb == 0.:
self.save_act = False

return old_save_act

def recover_symbolic_in_fit(self, old_symbolic_enabled):
self.symbolic_enabled = old_symbolic_enabled

def recover_save_act_in_fit(self, old_save_act):
if old_save_act == True:
self.save_act = True

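The old `disable_symbolic_in_fit(lamb)` is split into symmetric disable/recover pairs. A sketch of the protocol `fit()` now follows (helper names as introduced in this PR):

```python
old_save_act = model.disable_save_act_in_fit(lamb)  # lamb == 0. turns act saving off
old_symbolic = model.disable_symbolic_in_fit()      # off when no symbolic edge is active
# ... optimization steps ...
model.recover_save_act_in_fit(old_save_act)
model.recover_symbolic_in_fit(old_symbolic)
```
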
def get_params(self):
return self.parameters()
@@ -859,7 +878,8 @@ def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_
if lamb > 0. and not self.save_act:
print('setting lamb=0. If you want to set lamb > 0, set self.save_act=True')

old_save_act, old_symbolic_enabled = self.disable_symbolic_in_fit(lamb)
old_save_act = self.disable_save_act_in_fit(lamb)
old_symbolic_enabled = self.disable_symbolic_in_fit()

pbar = tqdm(range(steps), desc='description', ncols=100)

@@ -915,8 +935,8 @@ def closure():

for _ in pbar:

if _ == steps-1 and old_save_act:
self.save_act = True
if _ == steps-1:
self.recover_save_act_in_fit(old_save_act)

train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
@@ -977,7 +997,7 @@ def closure():

self.log_history('fit')
# revert back to original state
self.symbolic_enabled = old_symbolic_enabled
self.recover_symbolic_in_fit(old_symbolic_enabled)
return results

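A minimal end-to-end `fit()` call exercising the changed bookkeeping (a sketch; `create_dataset` is pykan's data utility, and `lamb > 0` requires `model.save_act == True`):

```python
from kan.utils import create_dataset

dataset = create_dataset(lambda x: x[:, [0]] * x[:, [1]], n_var=2)
results = model.fit(dataset, opt='LBFGS', steps=20, lamb=1e-3)
```
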
def prune_node(self, threshold=1e-2, mode="auto", active_neurons_id=None, log_history=True):
@@ -1197,12 +1217,13 @@ def score_node2subnode(node_score, width, mult_arity, out_dim):
subnode_score = node_score[:,:width[0]]
if isinstance(mult_arity, int):
#subnode_score[:,width[0]:] = node_score[:,width[0]:][:,:,None].expand(out_dim, node_score[width[0]:].shape[0], mult_arity).reshape(out_dim,-1)
subnode_score = torch.cat([subnode_score, node_score[:,width[0]:][:,:,None].expand(out_dim, node_score[width[0]:].shape[0], mult_arity).reshape(out_dim,-1)], dim=1)
subnode_score = torch.cat([subnode_score, node_score[:,width[0]:][:,:,None].expand(out_dim, node_score[:,width[0]:].shape[1], mult_arity).reshape(out_dim,-1)], dim=1)
else:
acml = width[0]
for i in range(len(mult_arity)):
#subnode_score[:, acml:acml+mult_arity[i]] = node_score[:, width[0]+i]
subnode_score = torch.cat([subnode_score, node_score[:, width[0]+i]].expand(out_dim, mult_arity[i]), dim=1)

subnode_score = torch.cat([subnode_score, node_score[:, width[0]+i].expand(out_dim, mult_arity[i])], dim=1)
acml += mult_arity[i]
return subnode_score
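
The fix replaces `node_score[width[0]:]`, which sliced the output axis, with `node_score[:, width[0]:]`, which slices the node axis. A minimal shape check of the fixed expression (hypothetical sizes):

```python
import torch

out_dim, w0, arity = 2, 3, 2             # 3 sum nodes, 2 mult nodes of arity 2
node_score = torch.rand(out_dim, w0 + 2)
# buggy: node_score[w0:]     slices rows    -> shape (0, 5) here
# fixed: node_score[:, w0:]  slices columns -> shape (2, 2)
sub = node_score[:, w0:][:, :, None].expand(out_dim, node_score[:, w0:].shape[1], arity)
print(sub.reshape(out_dim, -1).shape)  # torch.Size([2, 4])
```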

@@ -1234,7 +1255,7 @@ def score_node2subnode(node_score, width, mult_arity, out_dim):
subnode_score = score_node2subnode(node_score, self.width[l], self.mult_arity, out_dim=out_dim)
else:
mult_arity = self.mult_arity[l]
subnode_score = score_node2subnode(node_score, self.width[l], mult_arity)
subnode_score = score_node2subnode(node_score, self.width[l], mult_arity, out_dim=out_dim)

subnode_scores.append(subnode_score)
# subnode to edge
@@ -1512,6 +1533,8 @@ def expand_depth(self):
self.node_scale.append(torch.nn.Parameter(torch.ones(dim_out,)).requires_grad_(self.affine_trainable))
self.subnode_bias.append(torch.nn.Parameter(torch.zeros(dim_out,)).requires_grad_(self.affine_trainable))
self.subnode_scale.append(torch.nn.Parameter(torch.ones(dim_out,)).requires_grad_(self.affine_trainable))

self.log_history('expand_depth')

def expand_width(self, layer_id, n_added_nodes, sum_bool=True, mult_arity=2):

@@ -1647,6 +1670,8 @@ def _expand(layer_id, n_added_nodes, sum_bool=True, mult_arity=2, added_dim='out
self.width[layer_id][1] += n_added_nodes
self.mult_arity[layer_id] += mult_arity

self.log_history('expand_width')

def perturb(self, mag=0.02, mode='all'):
if mode == 'all':
for i in range(self.depth):
9 changes: 7 additions & 2 deletions kan/compiler.py
@@ -42,7 +42,7 @@ def next_nontrivial_operation(expr, scale=1, bias=0):


#def sf2kan(input_variables, expr, grid=3, k=3, noise_scale=0.1, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun=torch.nn.SiLU(), symbolic_enabled=True, affine_trainable=False, grid_eps=1.0, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, device='cpu', seed=0):
def sf2kan(input_variables, expr, grid=5, k=3, auto_save=False):
def expr2kan(input_variables, expr, grid=5, k=3, auto_save=True):

class Node:
def __init__(self, expr, mult_bool, depth, scale, bias, parent=None, mult_arity=None):
@@ -344,7 +344,7 @@ def create_node(expr, parent=None, n_layer=None):
width[0][0] = len(input_variables)

# allow pass in other parameters (probably as a dictionary) in sf2kan, including grid k etc.
model = MultKAN(width=width, mult_arity=mult_arities, grid=grid, k=k, auto_save=auto_save)
model = MultKAN(width=width, mult_arity=mult_arities, grid=grid, k=k, auto_save=False)

# clean the graph
for l in range(model.depth):
@@ -442,4 +442,9 @@ def create_node(expr, parent=None, n_layer=None):
model.fix_symbolic(kc_depth, kc_i, kc_j, kfun_name, fit_params_bool=False)
model.symbolic_fun[kc_depth].affine.data.reshape(self.width_out[kc_depth+1], self.width_in[kc_depth], 4)[kc_j][kc_i] = torch.tensor(connection.affine)

model.auto_save = auto_save
model.log_history('kanpiler')

return model

kanpiler = expr2kan
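
A usage sketch for the renamed entry point (`kanpiler` stays available as the alias defined above; sympy usage follows the pykan compiler tutorial):

```python
from sympy import symbols, sin
from kan.compiler import kanpiler

x, y = symbols('x y')
model = kanpiler([x, y], sin(x) + y**2)  # auto_save now defaults to True; logs 'kanpiler'
```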