Docs fixes (#21)
* Docs fixes

* More fixes

* More fixes

* Rename variable

* Improve equation format

* Fix equations

* Minor edits

* Trivial edit
bwohlberg authored Oct 5, 2021
1 parent bcd4f0f commit b375953
Showing 3 changed files with 25 additions and 25 deletions.
42 changes: 21 additions & 21 deletions scico/admm.py
@@ -120,8 +120,9 @@ class LinearSubproblemSolver(SubproblemSolver):
\mb{x}^{(k+1)} = \argmin_{\mb{x}} \; \frac{1}{2} \norm{\mb{y} - A x}_W^2 +
\sum_i \frac{\rho_i}{2} \norm{\mb{z}^{(k)}_i - \mb{u}^{(k)}_i - C_i \mb{x}}_2^2 \;,
-where :math:`W` is the weighting :class:`.LinearOperator` from the :class:`.WeightedSquaredL2Loss`
-instance. This update step reduces to the solution of the linear system
+where :math:`W` is the weighting :class:`.LinearOperator` from the
+:class:`.WeightedSquaredL2Loss` instance. This update step reduces to the solution
+of the linear system
.. math::
\left(A^* W A + \sum_{i=1}^N \rho_i C_i^* C_i \right) \mb{x}^{(k+1)} = \;
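
The system above can be solved matrix-free, e.g. by conjugate gradient. A minimal sketch, assuming small dense JAX arrays stand in for the SCICO operators; every name here (`A`, `W`, `y`, `C_list`, `rho_list`, `z_list`, `u_list`) is illustrative, not the `LinearSubproblemSolver` interface:

```python
# Hedged sketch: dense arrays stand in for SCICO LinearOperators.
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def x_update(A, W, y, C_list, rho_list, z_list, u_list):
    # Left-hand side (A^* W A + sum_i rho_i C_i^* C_i) applied as a
    # function, so the matrix is never formed explicitly.
    def lhs(x):
        out = A.conj().T @ (W @ (A @ x))
        for rho, C in zip(rho_list, C_list):
            out = out + rho * (C.conj().T @ (C @ x))
        return out

    # Right-hand side A^* W y + sum_i rho_i C_i^* (z_i - u_i), obtained
    # by setting the gradient of the x-update objective to zero.
    rhs = A.conj().T @ (W @ y)
    for rho, C, z, u in zip(rho_list, C_list, z_list, u_list):
        rhs = rhs + rho * (C.conj().T @ (z - u))

    x, _ = cg(lhs, rhs)  # matrix-free conjugate gradient
    return x
```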
@@ -173,8 +174,8 @@ def internal_init(self, admm):

super().internal_init(admm)

-# set lhs_op = \sum_i rho_i * Ci.H @ CircularConvolve
-# use reduce as the initialization of this sum is messy otherwise
+# Set lhs_op = \sum_i rho_i * Ci.H @ CircularConvolve
+# Use reduce as the initialization of this sum is messy otherwise
lhs_op = reduce(
lambda a, b: a + b, [rhoi * Ci.gram_op for rhoi, Ci in zip(admm.rho_list, admm.C_list)]
)
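
The `reduce` here sums the weighted gram operators :math:`\rho_i C_i^* C_i` without needing a correctly-shaped zero operator as an initializer. A toy sketch of the same pattern with dense matrices (values are hypothetical):

```python
# Hedged sketch of the reduce pattern with dense stand-ins for gram operators.
from functools import reduce
import jax.numpy as jnp

rho_list = [1.0, 0.5]
C_list = [jnp.eye(3), jnp.diag(jnp.array([1.0, 2.0, 3.0]))]

# The gram operator of C is C^* C; reduce pairwise-sums the weighted terms.
lhs = reduce(
    lambda a, b: a + b,
    [rho * (C.conj().T @ C) for rho, C in zip(rho_list, C_list)],
)
```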
@@ -295,21 +296,21 @@ def solve(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]:


class ADMM:
r"""Basic Alternating Direction Method of Multipliers (ADMM)
algorithm :cite:`boyd-2010-distributed`.
r"""Basic Alternating Direction Method of Multipliers (ADMM) algorithm
:cite:`boyd-2010-distributed`.
|
Solve an optimization problem of the form
.. math::
-\argmin_{\mb{x}} \; f(\mb{x}) + \sum_{i=1}^N g_i(C_i \mb{x}),
+\argmin_{\mb{x}} \; f(\mb{x}) + \sum_{i=1}^N g_i(C_i \mb{x}) \;,
where :math:`f` is an instance of :class:`.Loss`, the :math:`g_i` are :class:`.Functional`,
and the :math:`C_i` are :class:`.LinearOperator`.
-The optimization problem is solved by introducing the splitting :math:`\mb{z}_i = C_i \mb{x}`
-and solving
+The optimization problem is solved by introducing the splitting :math:`\mb{z}_i =
+C_i \mb{x}` and solving
.. math::
\argmin_{\mb{x}, \mb{z}_i} \; f(\mb{x}) + \sum_{i=1}^N g_i(\mb{z}_i) \;
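
Written out, one scaled-form ADMM iteration for this splitting looks like the sketch below; `x_step` and the entries of `prox_list` are placeholders (the actual x-step is delegated to the chosen :class:`SubproblemSolver`):

```python
# Hedged sketch of a single scaled-form ADMM iteration.
def admm_iteration(x, z_list, u_list, C_list, rho_list, prox_list, x_step):
    # x-step: minimize the augmented Lagrangian over x.
    x = x_step(x, z_list, u_list)
    # z- and u-steps, one splitting variable at a time.
    for i, (rho, C, prox_gi) in enumerate(zip(rho_list, C_list, prox_list)):
        Cx = C(x)
        z_list[i] = prox_gi(Cx + u_list[i], 1.0 / rho)  # prox of g_i
        u_list[i] = u_list[i] + Cx - z_list[i]          # multiplier update
    return x, z_list, u_list
```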
@@ -423,7 +424,7 @@ def itstat_func(i, obj):
)

else:
-# at least one 'g' can't be evaluated, so drop objective from the default itstat
+# At least one 'g' can't be evaluated, so drop objective from the default itstat
itstat_dict = {"Iter": "%d", "Primal Rsdl": "%8.3e", "Dual Rsdl": "%8.3e"}

def itstat_func(i, admm):
@@ -502,7 +503,7 @@ def norm_dual_residual(self) -> float:
r"""Compute the :math:`\ell_2` norm of the dual residual.
.. math::
-\left(\sum_{i=1}^N \norm{\mb{z}^{(k)} - \mb{z}^{(k-1)}}_2^2\right)^{1/2}
+\left(\sum_{i=1}^N \norm{\mb{z}^{(k)}_i - \mb{z}^{(k-1)}_i}_2^2\right)^{1/2}
Returns:
Current value of dual residual
@@ -514,12 +515,12 @@ def norm_dual_residual(self) -> float:
return snp.sqrt(out)
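
A plain-array sketch of this norm, with lists of JAX arrays standing in for the block variables (illustrative, not the class method):

```python
# Hedged sketch of the dual residual norm from the formula above.
import jax.numpy as jnp

def dual_residual_norm(z_list, z_list_old):
    out = sum(
        jnp.sum(jnp.abs(z - z_old) ** 2)
        for z, z_old in zip(z_list, z_list_old)
    )
    return jnp.sqrt(out)
```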

def z_init(self, x0: Union[JaxArray, BlockArray]):
r"""Initialize auxiliary variables :math:`\mb{z}`.
r"""Initialize auxiliary variables :math:`\mb{z}_i`.
Initializes to
.. math::
-\mb{z}_i = C_i \mb{x}_0
+\mb{z}_i = C_i \mb{x}^{(0)}
:code:`z_list` and :code:`z_list_old` are initialized to the same value.
@@ -531,12 +532,12 @@ def z_init(self, x0: Union[JaxArray, BlockArray]):
return z_list, z_list_old

def u_init(self, x0: Union[JaxArray, BlockArray]):
r"""Initialize scaled Lagrange multipliers :math:`\mb{u}`.
r"""Initialize scaled Lagrange multipliers :math:`\mb{u}_i`.
Initializes to
.. math::
-\mb{u}_i = C_i \mb{x}_0
+\mb{u}_i = C_i \mb{x}^{(0)}
Args:
@@ -556,8 +557,8 @@ def x_step(self, x):
return self.subproblem_solver.solve(x)

def z_and_u_step(self, u_list, z_list):
r""" Update the auxiliary variables :math:`\mb{z}` and scaled Lagrange multipliers
:math:`\mb{u}`.
r"""Update the auxiliary variables :math:`\mb{z}_i` and scaled Lagrange multipliers
:math:`\mb{u}_i`.
The auxiliary variables are updated according to
@@ -576,15 +577,15 @@ def z_and_u_step(self, u_list, z_list):
"""
z_list_old = z_list.copy()

-# unpack the arrays that will be changing to prevent side-effects
+# Unpack the arrays that will be changing to prevent side-effects
z_list = self.z_list
u_list = self.u_list

-for i, (rhoi, fi, Ci, zi, ui) in enumerate(
+for i, (rhoi, gi, Ci, zi, ui) in enumerate(
zip(self.rho_list, self.g_list, self.C_list, z_list, u_list)
):
Cix = Ci(self.x)
-zi = fi.prox(Cix + ui, 1 / rhoi)
+zi = gi.prox(Cix + ui, 1 / rhoi)
ui = ui + Cix - zi
z_list[i] = zi
u_list[i] = ui
@@ -594,7 +595,6 @@ def step(self):
"""Perform a single ADMM iteration.
Equivalent to calling :meth:`.x_step` followed by :meth:`.z_and_u_step`.
"""
self.x = self.x_step(self.x)
self.u_list, self.z_list, self.z_list_old = self.z_and_u_step(self.u_list, self.z_list)
2 changes: 1 addition & 1 deletion scico/objax.py
@@ -159,5 +159,5 @@ def __call__(self, x: JaxArray, training: bool) -> JaxArray:
x = ly(x, training)

x = self.post_conv(x)
-# residual-like output
+# Residual-like output
return base - x
6 changes: 3 additions & 3 deletions scico/pgm.py
@@ -90,7 +90,7 @@ def update(self, v: Union[JaxArray, BlockArray]) -> float:

if self.xprev is None:
# Solution and gradient of previous iterate are required.
-# For first iteration these variables are stored and current estimation is returned.
+# For first iteration these variables are stored and current estimate is returned.
self.xprev = v
self.gradprev = self.pgm.f.grad(self.xprev)
L = self.pgm.L
@@ -197,7 +197,7 @@ def update(self, v: Union[JaxArray, BlockArray]) -> float:
# Store current state and gradient for next update.
self.xprev = v
self.gradprev = gradv
-# Store current estimations of Barzilai-Borwein 1 (Lbb1) and Barzilai-Borwein 2 (Lbb2).
+# Store current estimates of Barzilai-Borwein 1 (Lbb1) and Barzilai-Borwein 2 (Lbb2).
self.Lbb1prev = Lbb1
self.Lbb2prev = Lbb2

@@ -333,7 +333,7 @@ class PGM:
The function :math:`f` must be smooth and :math:`g` must have a defined prox.
-Uses helper :class:`StepSize` to provide an estimation of the Lipschitz constant :math:`L` of :math:`f`. The step size :math:`\alpha` is the reciprocal of :math:`L`, i.e.: :math:`\alpha = 1 / L`.
+Uses helper :class:`StepSize` to provide an estimate of the Lipschitz constant :math:`L` of :math:`f`. The step size :math:`\alpha` is the reciprocal of :math:`L`, i.e.: :math:`\alpha = 1 / L`.
"""

def __init__(
