From 8512383f70cd635518faad4aea34f6aa02e8eaab Mon Sep 17 00:00:00 2001 From: SuperSashka Date: Thu, 25 Jul 2024 11:49:14 +0300 Subject: [PATCH] PSO fix --- tedeous/optimizers/pso.py | 20 +++++++++++++++++--- tedeous/version.py | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/tedeous/optimizers/pso.py b/tedeous/optimizers/pso.py index d2fa656b..5dc97d2e 100644 --- a/tedeous/optimizers/pso.py +++ b/tedeous/optimizers/pso.py @@ -114,8 +114,9 @@ def build_swarm(self): matrix = torch.cat(matrix) variance = torch.FloatTensor(self.pop_size, self.vec_shape).uniform_( -self.variance, self.variance).to(device_type()) - swarm = (matrix + variance).clone().detach().requires_grad_(True) - return swarm + swarm = matrix + variance + swarm[0] = matrix[0] + return swarm.clone().detach().requires_grad_(True) def update_pso_params(self) -> None: """Method for updating pso parameters if c_decrease=True. @@ -176,7 +177,10 @@ def gradient_descent(self) -> torch.Tensor: self.m1 = self.beta1 * self.m1 + (1 - self.beta1) * self.grads_swarm self.m2 = self.beta2 * self.m2 + (1 - self.beta2) * torch.square( self.grads_swarm) - return self.lr * self.m1 / torch.sqrt(self.m2) + self.epsilon + + update = self.lr * self.m1 / (torch.sqrt(torch.abs(self.m2)) + self.epsilon) + + return update def step(self, closure=None) -> torch.Tensor: """ It runs ONE step on the particle swarm optimization. @@ -186,6 +190,16 @@ def step(self, closure=None) -> torch.Tensor: """ self.loss_swarm, self.grads_swarm = closure() + + fix_attempt=0 + + while torch.any(self.loss_swarm!=self.loss_swarm): + self.swarm=self.swarm+0.001*torch.rand(size=self.swarm.shape) + self.loss_swarm, self.grads_swarm = closure() + fix_attempt+=1 + if fix_attempt>5: + break + if self.indicator: self.f_p = copy(self.loss_swarm).detach() self.g_best = self.p[torch.argmin(self.f_p)] diff --git a/tedeous/version.py b/tedeous/version.py index 222c11cf..b703f5c9 100644 --- a/tedeous/version.py +++ b/tedeous/version.py @@ -1 +1 @@ -__version__ = '0.4.0' \ No newline at end of file +__version__ = '0.4.1' \ No newline at end of file