Skip to content

Commit

Permalink
Rename inference I/O variables for clarity.
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-innocenti committed Nov 21, 2024
1 parent b47ee87 commit e1b4d2f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
10 changes: 5 additions & 5 deletions jpc/_core/_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
def solve_pc_inference(
params: Tuple[PyTree[Callable], Optional[PyTree[Callable]]],
activities: PyTree[ArrayLike],
y: ArrayLike,
x: Optional[ArrayLike] = None,
output: ArrayLike,
input: Optional[ArrayLike] = None,
loss_id: str = "MSE",
solver: AbstractSolver = Heun(),
max_t1: int = 20,
Expand Down Expand Up @@ -49,8 +49,8 @@ def solve_pc_inference(
- `params`: Tuple with callable model layers and optional skip connections.
- `activities`: List of activities for each layer free to vary.
- `y`: Observation or target of the generative model.
- `x`: Optional prior of the generative model.
- `output`: Observation or target of the generative model.
- `input`: Optional prior of the generative model.
**Other arguments:**
Expand Down Expand Up @@ -87,7 +87,7 @@ def solve_pc_inference(
t1=max_t1,
dt0=dt,
y0=activities,
args=(params, y, x, loss_id, stepsize_controller),
args=(params, output, input, loss_id, stepsize_controller),
stepsize_controller=stepsize_controller,
event=Event(steady_state_event_with_timeout),
saveat=saveat
Expand Down
6 changes: 3 additions & 3 deletions jpc/_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_generative_pc(
input_preds = solve_pc_inference(
params=params,
activities=activities,
y=output,
output=output,
solver=ode_solver,
max_t1=max_t1,
dt=dt,
Expand Down Expand Up @@ -207,7 +207,7 @@ def test_hpc(
hpc_preds = solve_pc_inference(
params=gen_params,
activities=amort_activities,
y=output,
output=output,
solver=ode_solver,
max_t1=max_t1,
dt=dt,
Expand All @@ -223,7 +223,7 @@ def test_hpc(
gen_preds = solve_pc_inference(
params=gen_params,
activities=activities,
y=output,
output=output,
solver=ode_solver,
max_t1=max_t1,
dt=dt,
Expand Down
8 changes: 4 additions & 4 deletions jpc/_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def make_pc_step(
equilib_activities = solve_pc_inference(
params=(model, skip_model),
activities=activities,
y=output,
x=input,
output=output,
input=input,
loss_id=loss_id,
solver=ode_solver,
max_t1=max_t1,
Expand Down Expand Up @@ -329,8 +329,8 @@ def make_hpc_step(
equilib_activities = solve_pc_inference(
params=gen_params,
activities=amort_activities[1:] if input is not None else amort_activities,
y=output,
x=input,
output=output,
input=input,
solver=ode_solver,
max_t1=max_t1,
dt=dt,
Expand Down

0 comments on commit e1b4d2f

Please sign in to comment.