Skip to content

Commit

Permalink
Expose activity energy gradient for potential advanced use
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Jul 4, 2024
1 parent 2d5cd42 commit a385868
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 5 deletions.
4 changes: 4 additions & 0 deletions docs/api/Gradients.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Gradients

::: jpc.neg_activity_grad

---

::: jpc.compute_pc_param_grads
1 change: 1 addition & 0 deletions jpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
init_activities_from_gaussian as init_activities_from_gaussian,
init_activities_with_amort as init_activities_with_amort,
pc_energy_fn as pc_energy_fn,
neg_activity_grad as neg_activity_grad,
solve_pc_activities as solve_pc_activities,
compute_pc_param_grads as compute_pc_param_grads,
linear_equilib_energy as linear_equilib_energy
Expand Down
5 changes: 4 additions & 1 deletion jpc/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
init_activities_with_amort as init_activities_with_amort
)
from ._energies import pc_energy_fn as pc_energy_fn
from ._grads import (
neg_activity_grad as neg_activity_grad,
compute_pc_param_grads as compute_pc_param_grads
)
from ._infer import solve_pc_activities as solve_pc_activities
from ._grads import compute_pc_param_grads as compute_pc_param_grads
from ._analytical import linear_equilib_energy_batch as linear_equilib_energy
4 changes: 2 additions & 2 deletions jpc/_core/_grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ._energies import pc_energy_fn


def _neg_activity_grad(
def neg_activity_grad(
t: float | int,
activities: PyTree[ArrayLike],
args: Tuple[PyTree[Callable], ArrayLike, Optional[ArrayLike]],
Expand Down Expand Up @@ -54,7 +54,7 @@ def compute_pc_param_grads(
y: ArrayLike,
x: Optional[ArrayLike] = None
) -> PyTree[Array]:
"""Computes the gradient of the energy with respect to network parameters $\partial \mathcal{F} / \partial θ$.
"""Computes the gradient of the energy with respect to model parameters $\partial \mathcal{F} / \partial θ$.
**Main arguments:**
Expand Down
4 changes: 2 additions & 2 deletions jpc/_core/_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from jaxtyping import PyTree, ArrayLike, Array
from typing import Callable, Optional
from ._grads import _neg_activity_grad
from ._grads import neg_activity_grad
from diffrax import (
AbstractSolver,
AbstractStepSizeController,
Expand Down Expand Up @@ -65,7 +65,7 @@ def solve_pc_activities(
"""
sol = diffeqsolve(
terms=ODETerm(_neg_activity_grad),
terms=ODETerm(neg_activity_grad),
solver=solver,
t0=0,
t1=t1,
Expand Down

0 comments on commit a385868

Please sign in to comment.