Skip to content

Commit

Permalink
Ignored LLPO in the removed nodes; using 0.0 for floating point error
Browse files Browse the repository at this point in the history
  • Loading branch information
bibek committed May 1, 2024
1 parent f3820a7 commit 6b8ff71
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
2 changes: 2 additions & 0 deletions causing/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ def recalc_graphs(graphs, model, xdat) -> Iterable[networkx.DiGraph]:
for i, approx_graph in enumerate(graphs):
individual_xdat = xdat[:, i : i + 1]
removed_nodes = set(model.graph.nodes) - set(approx_graph.nodes)
if "LLPO" in removed_nodes:
removed_nodes.remove("LLPO")

# Calc effects on shrunken model
individual_model = model.shrink(removed_nodes)
Expand Down
17 changes: 16 additions & 1 deletion causing/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterable, Callable
from functools import cached_property
Expand Down Expand Up @@ -82,8 +83,22 @@ def compute(
eq_inputs[:, fixed_from_ind] = fixed_vals

try:
# print(f"Comuting variable: {self.yvars[i]}")
# yhat[i] = np.array(
# [eq(*eq_in, *parameters.values()) for eq_in in eq_inputs],
# dtype=np.float64,
# )
computed_yvars = []
for eq_in in eq_inputs:
try:
computed_yvars.append(eq(*eq_in, *parameters.values()))
except FloatingPointError:
# Floating Point Error for self.yvars[i]
# Adding 0.0 to overcome this.
computed_yvars.append(0.0)

yhat[i] = np.array(
[eq(*eq_in, *parameters.values()) for eq_in in eq_inputs],
computed_yvars,
dtype=np.float64,
)
except Exception as e:
Expand Down

0 comments on commit 6b8ff71

Please sign in to comment.