Skip to content

Commit

Permalink
[misc] Ensure succeeded variable is properly initialized in matrix-fr…
Browse files Browse the repository at this point in the history
…ee solvers (#8484)

The `succeeded` variable was not properly initialized in the
`MatrixFreeCG` and `MatrixFreeBICGSTAB` functions, leading to potential
issues with the convergence check. By initializing the `succeeded`
variable at the beginning of the `solve` function, we ensure that the
variable is correctly set and returned at the end of the function,
improving the reliability of the solvers.

Issue: #

### Brief Summary

copilot:summary

### Walkthrough

copilot:walkthrough
  • Loading branch information
liblaf committed Jun 23, 2024
1 parent b649d14 commit 55d8e36
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions python/taichi/linalg/matrixfree_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def MatrixFreeCG(A, b, x, tol=1e-6, maxiter=5000, quiet=True):
beta = ti.field(dtype=solver_dtype)
scalar_builder.place(alpha, beta)
scalar_snode_tree = scalar_builder.finalize()
succeeded = True

@ti.kernel
def init():
Expand Down Expand Up @@ -96,6 +95,7 @@ def update_p():
p[I] = r[I] + beta[None] * p[I]

def solve():
succeeded = True
A._matvec(x, Ax)
init()
initial_rTr = reduce(r, r)
Expand Down Expand Up @@ -129,8 +129,9 @@ def solve():
f">>> Conjugate Gradient method failed to converge in {maxiter} iterations: Residual = {sqrt(new_rTr):e}"
)
succeeded = False
return succeeded

solve()
succeeded = solve()
vector_fields_snode_tree.destroy()
scalar_snode_tree.destroy()
return succeeded
Expand Down Expand Up @@ -252,6 +253,7 @@ def update_r():
r[I] = s[I] - omega[None] * t[I]

def solve():
succeeded = True
A._matvec(x, Ax)
init()
initial_rTr = reduce(r, r)
Expand Down Expand Up @@ -296,8 +298,9 @@ def solve():
if not quiet:
print(f">>> BICGSTAB failed to converge in {maxiter} iterations: Residual = {sqrt(rTr):e}")
succeeded = False
return succeeded

solve()
succeeded = solve()
vector_fields_snode_tree.destroy()
scalar_snode_tree.destroy()
return succeeded

0 comments on commit 55d8e36

Please sign in to comment.