From 261b59c900c2d5441f637a2d72d8decedb8433db Mon Sep 17 00:00:00 2001 From: Kamil Oster Date: Thu, 11 Jul 2024 11:41:28 +0100 Subject: [PATCH 1/4] Modify KAN to use random seed, instead of 0 --- kan/KAN.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/kan/KAN.py b/kan/KAN.py index cdfeba3b..7f65ce91 100644 --- a/kan/KAN.py +++ b/kan/KAN.py @@ -78,7 +78,7 @@ class KAN(nn.Module): ''' def __init__(self, width=None, grid=3, k=3, noise_scale=0.1, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun=torch.nn.SiLU(), symbolic_enabled=True, bias_trainable=False, grid_eps=1.0, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, - device='cpu', seed=0): + device='cpu', seed=None): ''' initalize a KAN model @@ -123,6 +123,9 @@ def __init__(self, width=None, grid=3, k=3, noise_scale=0.1, scale_base_mu=0.0, ''' super(KAN, self).__init__() + if seed is None: + seed = np.random.randint(0, 1e3) + torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) From 9c42ba3f43d92ffc6977fbf0908c7b155dcb6640 Mon Sep 17 00:00:00 2001 From: Kamil Oster Date: Thu, 11 Jul 2024 11:41:57 +0100 Subject: [PATCH 2/4] Stop LBFGS from modifying torch seed to 0 --- kan/LBFGS.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/kan/LBFGS.py b/kan/LBFGS.py index a699b8c2..145bb631 100644 --- a/kan/LBFGS.py +++ b/kan/LBFGS.py @@ -302,8 +302,6 @@ def step(self, closure): and returns the loss. """ - torch.manual_seed(0) - assert len(self.param_groups) == 1 # Make sure the closure is always called with grad enabled From 2c0a6321196ecd186c5d9cdaff53a844bcc711e2 Mon Sep 17 00:00:00 2001 From: Kamil Oster Date: Thu, 11 Jul 2024 11:42:55 +0100 Subject: [PATCH 3/4] Modify MultKAN to use random seed, instead of 0 --- kan/MultKAN.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/kan/MultKAN.py b/kan/MultKAN.py index a32517e4..22329d48 100644 --- a/kan/MultKAN.py +++ b/kan/MultKAN.py @@ -23,10 +23,13 @@ class MultKAN(nn.Module): # include mult_ops = [] - def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=1.0, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=1.0, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, device='cpu', seed=0, save_plot_data=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0): + def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=1.0, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=1.0, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, device='cpu', seed=None, save_plot_data=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0): super(MultKAN, self).__init__() + if seed is None: + seed = np.random.randint(0, 1e3) + torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) From f526ed5d27dcf07555e20ff2f8d0d39de6d74804 Mon Sep 17 00:00:00 2001 From: Kamil Oster Date: Thu, 11 Jul 2024 11:44:23 +0100 Subject: [PATCH 4/4] Modify utils to use random seed, instead of 0 --- kan/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/kan/utils.py b/kan/utils.py index cb375012..f2cb4ba1 100644 --- a/kan/utils.py +++ b/kan/utils.py @@ -66,7 +66,7 @@ def create_dataset(f, normalize_input=False, normalize_label=False, device='cpu', - seed=0): + seed=None): ''' create dataset @@ -103,6 +103,9 @@ def create_dataset(f, torch.Size([100, 2]) ''' + if seed is None: + seed = np.random.randint(0, 1e3) + np.random.seed(seed) torch.manual_seed(seed)