diff --git a/flexsolve/iterative_solvers.py b/flexsolve/iterative_solvers.py index e1ae4fb..8200173 100644 --- a/flexsolve/iterative_solvers.py +++ b/flexsolve/iterative_solvers.py @@ -19,7 +19,7 @@ @register_jitable(cache=True) def fixed_point(f, x, xtol=5e-8, args=(), maxiter=50, checkiter=True, - checkconvergence=True, convergenceiter=0): + checkconvergence=True, convergenceiter=0, subset=0): """Iterative fixed-point solver.""" x0 = x1 = x errors = np.zeros(convergenceiter) @@ -27,7 +27,7 @@ def fixed_point(f, x, xtol=5e-8, args=(), maxiter=50, checkiter=True, for iter in range(maxiter): x1 = f(x0, *args) e = np.abs(x1 - x0) - if fixedpoint_converged(e, xtol): return x1 + if fixedpoint_converged(e, xtol, subset): return x1 if convergenceiter: mean = utils.mean(e) if iter > convergenceiter and mean > errors.mean(): @@ -51,7 +51,7 @@ def conditional_fixed_point(f, x): @register_jitable(cache=True) def wegstein(f, x, xtol=5e-8, args=(), maxiter=50, checkiter=True, - checkconvergence=True, convergenceiter=0): + checkconvergence=True, convergenceiter=0, subset=0): """Iterative Wegstein solver.""" errors = np.zeros(convergenceiter) x0 = x @@ -65,7 +65,7 @@ def wegstein(f, x, xtol=5e-8, args=(), maxiter=50, checkiter=True, x1 = g0 g1 = f(x1, *args) e = np.abs(g1 - x1) - if fixedpoint_converged(e, xtol): return g1 + if fixedpoint_converged(e, xtol, subset): return g1 x0 = x1 x1 = wegstein_iter(x1, dx, g1, g0) g0 = g1 @@ -100,7 +100,7 @@ def conditional_wegstein(f, x): @register_jitable(cache=True) def aitken(f, x, xtol=5e-8, args=(), maxiter=50, checkiter=True, - checkconvergence=True, convergenceiter=0): + checkconvergence=True, convergenceiter=0, subset=0): """Iterative Aitken solver.""" gg = x errors = np.zeros(convergenceiter) @@ -113,10 +113,10 @@ def aitken(f, x, xtol=5e-8, args=(), maxiter=50, checkiter=True, g = f(x, *args) dxg = x - g e = np.abs(dxg) - if fixedpoint_converged(e, xtol): return g + if fixedpoint_converged(e, xtol, subset): return g gg = f(g, *args) dgg_g = gg - g - if fixedpoint_converged(np.abs(dgg_g), xtol): return gg + if fixedpoint_converged(np.abs(dgg_g), xtol, subset): return gg x = aitken_iter(x, gg, dxg, dgg_g) if convergenceiter: mean = utils.mean(e) @@ -143,4 +143,4 @@ def conditional_aitken(f, x): if not condition: return g gg, condition = f(g) x = aitken_iter(x, gg, x - g, gg - g) - return x \ No newline at end of file + return x