From 8cb1a7a24abe8e640fc6730d015361c97051f236 Mon Sep 17 00:00:00 2001
From: Kamil Oster
Date: Mon, 15 Jul 2024 10:50:03 +0100
Subject: [PATCH] methods default lr unless specified

---
 kan/KAN.py     | 13 ++++++++++---
 kan/MultKAN.py | 13 ++++++++++---
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/kan/KAN.py b/kan/KAN.py
index da708b60..72e681bc 100644
--- a/kan/KAN.py
+++ b/kan/KAN.py
@@ -756,7 +756,7 @@ def score2alpha(score):
         if title != None:
             plt.gcf().get_axes()[0].text(0.5, y0 * (len(self.width) - 1) + 0.2, title, fontsize=40 * scale, horizontalalignment='center', verticalalignment='center')
 
-    def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=0., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., stop_grid_update_step=50, batch=-1,
+    def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=0., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=None, stop_grid_update_step=50, batch=-1,
               small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, sglr_avoid=False, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video'):
         '''
         training
@@ -846,9 +846,16 @@ def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):
         grid_update_freq = int(stop_grid_update_step / grid_update_num)
 
         if opt == "Adam":
-            optimizer = torch.optim.Adam(self.parameters(), lr=lr)
+            optimizer = torch.optim.Adam(self.parameters())
+
+            if lr is not None:
+                optimizer.param_groups[0]['lr'] = lr
+
         elif opt == "LBFGS":
-            optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
+            optimizer = LBFGS(self.parameters(), history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
+
+            if lr is not None:
+                optimizer.param_groups[0]['lr'] = lr
 
         results = {}
         results['train_loss'] = []
diff --git a/kan/MultKAN.py b/kan/MultKAN.py
index 97d66ca8..8dc9e74e 100644
--- a/kan/MultKAN.py
+++ b/kan/MultKAN.py
@@ -713,7 +713,7 @@ def score2alpha(score):
         if title != None:
             plt.gcf().get_axes()[0].text(0.5, (y0+z0) * (len(self.width) - 1) + 0.3, title, fontsize=40 * scale, horizontalalignment='center', verticalalignment='center')
 
-    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1.,start_grid_update_step=-1, stop_grid_update_step=50, batch=-1,
+    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=None,start_grid_update_step=-1, stop_grid_update_step=50, batch=-1,
               small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', device='cpu', singularity_avoiding=False, y_th=1000., reg_metric='fa'):
 
         if lamb > 0. and not self.save_plot_data:
@@ -754,12 +754,19 @@ def reg(acts_scale):
         grid_update_freq = int(stop_grid_update_step / grid_update_num)
 
         if opt == "Adam":
-            optimizer = torch.optim.Adam(self.parameters(), lr=lr)
+            optimizer = torch.optim.Adam(self.parameters())
+
+            if lr is not None:
+                optimizer.param_groups[0]['lr'] = lr
+
         elif opt == "LBFGS":
-            optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
+            optimizer = LBFGS(self.parameters(), history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
             #optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
             #optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, debug=True)
 
+            if lr is not None:
+                optimizer.param_groups[0]['lr'] = lr
+
         results = {}
         results['train_loss'] = []
         results['test_loss'] = []
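
For context, a minimal usage sketch of the behaviour this patch introduces, assuming the usual pykan entry points (KAN and create_dataset from the kan package) and a synthetic dataset chosen only for illustration: when lr is omitted, each optimizer keeps its own default (1e-3 for torch.optim.Adam; the bundled LBFGS appears to default to lr=1), while an explicit lr is still written into optimizer.param_groups[0]['lr']. MultKAN.fit follows the same pattern.

    import torch
    from kan import KAN, create_dataset  # assumed pykan entry points

    # Synthetic regression target, used only to build an illustrative dataset.
    f = lambda x: torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)
    dataset = create_dataset(f, n_var=2)

    model = KAN(width=[2, 5, 1], grid=5, k=3)

    # lr omitted: Adam keeps torch.optim.Adam's default of 1e-3
    # instead of the previously hard-coded lr=1.
    model.train(dataset, opt="Adam", steps=20)

    # lr given explicitly: the value overrides the optimizer default as before.
    model.train(dataset, opt="LBFGS", steps=20, lr=0.5)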