Skip to content

Commit

Permalink
Merge pull request #401 from gyorilab/sympy_odes
Browse files Browse the repository at this point in the history
Improve sympy extraction and ODE simulation
  • Loading branch information
bgyori authored Dec 4, 2024
2 parents aed36f9 + 7461065 commit 9b3ef9f
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 47 deletions.
158 changes: 115 additions & 43 deletions mira/modeling/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ class OdeModel:
"""A class representing an ODE model."""

def __init__(self, model: Model, initialized: bool):
self.model = model
self.y = sympy.MatrixSymbol('y', len(model.variables), 1)
self.vmap = {variable.key: idx for idx, variable
in enumerate(model.variables.values())}
self.vname_map = {idx: variable.concept.name for idx, variable
in enumerate(model.variables.values())}
self.observable_map = {obs_key: idx for idx, obs_key
in enumerate(model.observables)}
real_params = {k: v for k, v in model.parameters.items()
Expand Down Expand Up @@ -100,10 +103,58 @@ def get_interpretable_kinetics(self):
# based on vmap and pmap
subs = {self.y[v]: sympy.Symbol(k) if isinstance(k, str) else k[0] for k, v in self.vmap.items()}
subs.update({self.p[p]: sympy.Symbol(k) for k, p in self.pmap.items()})
return sympy.Matrix([
rhs = sympy.Matrix([
k.subs(subs) for k in self.kinetics
])

lhs = sympy.Matrix([sympy.Derivative(sympy.Symbol(self.model.variables[k].concept.name),
sympy.Symbol('t')) for k, v in self.vmap.items()])

equations = sympy.Matrix([[lhs[i], sympy.Symbol('='), rhs[i]] for i in range(len(lhs))])
return equations

def get_interpretable_observables(self):
subs = {self.y[v]: sympy.Symbol(k) if isinstance(k, str) else k[0] for k, v in self.vmap.items()}
subs.update({self.p[p]: sympy.Symbol(k) for k, p in self.pmap.items()})

lhs = sympy.Matrix([sympy.Symbol(k) for k in self.observable_map.keys()])

obs = sympy.Matrix([
k.subs(subs) for k in self.observables
])

equations = sympy.Matrix([[lhs[i], sympy.Symbol('='), obs[i]] for i in range(len(lhs))])
return equations

def plot_simulation_results(self, times, res):
"""Plot the results of a simulation.
Parameters
----------
times :
A one-dimensional array of time values
res :
A two-dimensional array with the first axis being time
and the second axis being the agents in the ODE model.
"""
import matplotlib.pyplot as plt
num_vars = self.y.shape[0]
num_obs = len(self.observable_map)
num_cols = num_vars + num_obs
num_rows = num_cols // 4 + 1
fig, axes = plt.subplots(num_rows, 4, figsize=(16, 18))
for idx, ax in enumerate(axes.flat):
if idx >= num_cols:
ax.axis('off')
continue
if idx < num_vars:
ax.plot(times, res[:, idx])
ax.set_title(self.vname_map[idx])
else:
ax.plot(times, res[:, idx])
ax.set_title(list(self.observable_map.keys())[idx - num_vars])
plt.show()

def set_parameters(self, params):
"""Set the parameters of the model."""
for p, v in params.items():
Expand All @@ -120,14 +171,69 @@ def rhs(t, y):

return rhs

# TODO is there a way to get the variable names in
# order out of this, e.g., for adding a legend to plots?
def simulate_model(self, times, initials=None,
parameters=None, with_observables=False):
"""Simulate the ODE model given initial conditions, parameters and a
time span.
Parameters
----------
initials :
A one-dimensional array describing the initial values
for the agents in the ODE model
parameters :
A dictionary of keys for parameters to their values
times :
A one-dimensional array of time values, typically from
a linear space like ``numpy.linspace(0, 25, 100)``
with_observables :
A boolean indicating whether to return the observables
as well as the variables.
Returns
-------
A two-dimensional array with the first axis being time
and the second axis being the agents in the ODE model.
"""
rhs = self.get_rhs()

# If parameters and initial values for agents have already been
# initialized before calling the simulate_ode method
if parameters is None and initials is None:
parameters = {}
parameter_name_list = self.pmap.keys()
for parameter, parameter_value in zip(parameter_name_list,
self.parameter_values):
parameters[parameter] = parameter_value

initials = self.variable_values
for index, expression in enumerate(initials):
# Only substitute if this is an expression. Once the model
# has been simulated, this is actually a float.
if isinstance(expression, sympy.Expr):
initials[index] = float(expression.subs(parameters).args[0])

self.set_parameters(parameters)
solver = scipy.integrate.ode(f=rhs)

solver.set_initial_value(initials)
num_vars = self.y.shape[0]
num_obs = len(self.observable_map)
num_cols = num_vars + (num_obs if with_observables else 0)
res = numpy.zeros((len(times), num_cols))
res[0, :num_vars] = initials
for idx, time in enumerate(times[1:]):
res[idx + 1, :num_vars] = solver.integrate(time)

if with_observables:
for tidx, t in enumerate(times):
obs_res = \
self.observables_lmbd(res[tidx, :num_vars][:, None])
for idx, val in enumerate(obs_res):
res[tidx, num_vars + idx] = obs_res[idx]
return res


# Make it such that the method can accept params and initial values from the model itself rather than having them being
# passed to method when called
# ode_model: OdeModel, times, initials=None,
# parameters=None
def simulate_ode_model(ode_model: OdeModel, times, initials=None,
parameters=None, with_observables=False):
"""Simulate an ODE model given initial conditions, parameters and a
Expand All @@ -154,39 +260,5 @@ def simulate_ode_model(ode_model: OdeModel, times, initials=None,
A two-dimensional array with the first axis being time
and the second axis being the agents in the ODE model.
"""

rhs = ode_model.get_rhs()

# If parameters and initial values for agents have already been initialized before calling the simulate_ode method
if parameters is None and initials is None:
parameters = {}
parameter_name_list = ode_model.pmap.keys()
for parameter, parameter_value in zip(parameter_name_list, ode_model.parameter_values):
parameters[parameter] = parameter_value

initials = ode_model.variable_values
for index, expression in enumerate(initials):
# Only substitute if this is an expression. Once the model has been
# simulated, this is actually a float.
if isinstance(expression, sympy.Expr):
initials[index] = float(expression.subs(parameters).args[0])

ode_model.set_parameters(parameters)
solver = scipy.integrate.ode(f=rhs)

solver.set_initial_value(initials)
num_vars = ode_model.y.shape[0]
num_obs = len(ode_model.observable_map)
num_cols = num_vars + (num_obs if with_observables else 0)
res = numpy.zeros((len(times), num_cols))
res[0, :num_vars] = initials
for idx, time in enumerate(times[1:]):
res[idx + 1, :num_vars] = solver.integrate(time)

if with_observables:
for tidx, t in enumerate(times):
obs_res = \
ode_model.observables_lmbd(res[tidx, :num_vars][:, None])
for idx, val in enumerate(obs_res):
res[tidx, num_vars + idx] = obs_res[idx]
return res
return ode_model.simulate_model(times, initials, parameters,
with_observables)
11 changes: 7 additions & 4 deletions mira/sources/sympy_ode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def template_model_from_sympy_odes(odes, concept_data=None, param_data=None):
# Break up the RHS into a sum of terms
terms = rhs.as_ordered_terms()
for term in terms:
neg = is_negative(term)
neg = is_negative(term, time_variable)
term_parameters = term.free_symbols - {time_variable}
parameters |= term_parameters
funcs = term.atoms(Function)
Expand Down Expand Up @@ -123,6 +123,7 @@ def template_model_from_sympy_odes(odes, concept_data=None, param_data=None):
if not effects['produces']:
if len(effects['consumes']) == 1:
cons = effects['consumes'][0]
rate_law = -rate_law
if not controllers:
template = NaturalDegradation(subject=make_concept(cons, concept_data),
rate_law=rate_law)
Expand Down Expand Up @@ -193,13 +194,15 @@ def template_model_from_sympy_odes(odes, concept_data=None, param_data=None):
return tm


def is_negative(term):
def is_negative(term, time):
# Replace any parameters with 0.1, assuming positivity
term = term.subs({s: 0.1 for s in term.free_symbols})
term = term.subs({s: 0.1 for s in term.free_symbols
if s != time})
# Now look at the variables appearing in the term and differentiate
funcs = term.atoms(Function)
for func in funcs:
term = sympy.diff(term, func)
# Replace the function with 1, assuming positivity
term = term.subs(func, 1)
# Whatever is left is the ultimate sign of the term with respect
# to its variables
return term.is_negative
Expand Down

0 comments on commit 9b3ef9f

Please sign in to comment.