Skip to content

Commit

Permalink
PSO fix: keep the epsilon term inside the Adam-style denominator in `gradient_descent`, preserve the unperturbed best particle when building the swarm, and retry with small random jitter when the swarm loss contains NaN values in `step`
Browse files Browse the repository at this point in the history
  • Loading branch information
SuperSashka committed Jul 25, 2024
1 parent 94820b8 commit 8512383
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
20 changes: 17 additions & 3 deletions tedeous/optimizers/pso.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,9 @@ def build_swarm(self):
matrix = torch.cat(matrix)
variance = torch.FloatTensor(self.pop_size, self.vec_shape).uniform_(
-self.variance, self.variance).to(device_type())
swarm = (matrix + variance).clone().detach().requires_grad_(True)
return swarm
swarm = matrix + variance
swarm[0] = matrix[0]
return swarm.clone().detach().requires_grad_(True)

def update_pso_params(self) -> None:
"""Method for updating pso parameters if c_decrease=True.
Expand Down Expand Up @@ -176,7 +177,10 @@ def gradient_descent(self) -> torch.Tensor:
self.m1 = self.beta1 * self.m1 + (1 - self.beta1) * self.grads_swarm
self.m2 = self.beta2 * self.m2 + (1 - self.beta2) * torch.square(
self.grads_swarm)
return self.lr * self.m1 / torch.sqrt(self.m2) + self.epsilon

update = self.lr * self.m1 / (torch.sqrt(torch.abs(self.m2)) + self.epsilon)

return update

def step(self, closure=None) -> torch.Tensor:
""" It runs ONE step on the particle swarm optimization.
Expand All @@ -186,6 +190,16 @@ def step(self, closure=None) -> torch.Tensor:
"""

self.loss_swarm, self.grads_swarm = closure()

fix_attempt=0

while torch.any(self.loss_swarm!=self.loss_swarm):
self.swarm=self.swarm+0.001*torch.rand(size=self.swarm.shape)
self.loss_swarm, self.grads_swarm = closure()
fix_attempt+=1
if fix_attempt>5:
break

if self.indicator:
self.f_p = copy(self.loss_swarm).detach()
self.g_best = self.p[torch.argmin(self.f_p)]
Expand Down
2 changes: 1 addition & 1 deletion tedeous/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.4.0'
__version__ = '0.4.1'

0 comments on commit 8512383

Please sign in to comment.