Skip to content

Commit

Permalink
add subset param
Browse files Browse the repository at this point in the history
  • Loading branch information
yoelcortes authored May 24, 2024
1 parent a54dc67 commit 665a04a
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions flexsolve/iterative_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@

@register_jitable(cache=True)
def fixed_point(f, x, xtol=5e-8, args=(), maxiter=50, checkiter=True,
checkconvergence=True, convergenceiter=0):
checkconvergence=True, convergenceiter=0, subset=0):
"""Iterative fixed-point solver."""
x0 = x1 = x
errors = np.zeros(convergenceiter)
fixedpoint_converged = utils.fixedpoint_converged
for iter in range(maxiter):
x1 = f(x0, *args)
e = np.abs(x1 - x0)
if fixedpoint_converged(e, xtol): return x1
if fixedpoint_converged(e, xtol, subset): return x1
if convergenceiter:
mean = utils.mean(e)
if iter > convergenceiter and mean > errors.mean():
Expand All @@ -51,7 +51,7 @@ def conditional_fixed_point(f, x):

@register_jitable(cache=True)
def wegstein(f, x, xtol=5e-8, args=(), maxiter=50, checkiter=True,
checkconvergence=True, convergenceiter=0):
checkconvergence=True, convergenceiter=0, subset=0):
"""Iterative Wegstein solver."""
errors = np.zeros(convergenceiter)
x0 = x
Expand All @@ -65,7 +65,7 @@ def wegstein(f, x, xtol=5e-8, args=(), maxiter=50, checkiter=True,
x1 = g0
g1 = f(x1, *args)
e = np.abs(g1 - x1)
if fixedpoint_converged(e, xtol): return g1
if fixedpoint_converged(e, xtol, subset): return g1
x0 = x1
x1 = wegstein_iter(x1, dx, g1, g0)
g0 = g1
Expand Down Expand Up @@ -100,7 +100,7 @@ def conditional_wegstein(f, x):

@register_jitable(cache=True)
def aitken(f, x, xtol=5e-8, args=(), maxiter=50, checkiter=True,
checkconvergence=True, convergenceiter=0):
checkconvergence=True, convergenceiter=0, subset=0):
"""Iterative Aitken solver."""
gg = x
errors = np.zeros(convergenceiter)
Expand All @@ -113,10 +113,10 @@ def aitken(f, x, xtol=5e-8, args=(), maxiter=50, checkiter=True,
g = f(x, *args)
dxg = x - g
e = np.abs(dxg)
if fixedpoint_converged(e, xtol): return g
if fixedpoint_converged(e, xtol, subset): return g
gg = f(g, *args)
dgg_g = gg - g
if fixedpoint_converged(np.abs(dgg_g), xtol): return gg
if fixedpoint_converged(np.abs(dgg_g), xtol, subset): return gg
x = aitken_iter(x, gg, dxg, dgg_g)
if convergenceiter:
mean = utils.mean(e)
Expand All @@ -143,4 +143,4 @@ def conditional_aitken(f, x):
if not condition: return g
gg, condition = f(g)
x = aitken_iter(x, gg, x - g, gg - g)
return x
return x

0 comments on commit 665a04a

Please sign in to comment.