Skip to content

Commit

Permalink
fix: remove @equation
Browse files Browse the repository at this point in the history
  • Loading branch information
CallumJHays committed Oct 22, 2021
1 parent 0e6dbe4 commit d140be5
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 82 deletions.
9 changes: 0 additions & 9 deletions examples/braking_deceleration.ipynbx

This file was deleted.

2 changes: 1 addition & 1 deletion examples/cart_spring.ipynb

Large diffs are not rendered by default.

Binary file added examples/imgs/cart_spring.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion mathpad/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .units import *
from .solve import solve, Solution
from .equation import equation, Equation
from .equation import Equation
from ._quality_of_life import t, g, pi, frac
from .algebra import subs, simplify, factor, expand
from .display import tabulate
Expand Down
1 change: 0 additions & 1 deletion mathpad/elec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from mathpad import *


@equation
def resistance_resistivity(
*,
R: Q[Impedance], # resistance
Expand Down
34 changes: 4 additions & 30 deletions mathpad/equation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, get_type_hints
from typing import TYPE_CHECKING
import sympy
from sympy.physics.vector.printing import vlatex

Expand All @@ -24,7 +24,9 @@ def __init__(self, lhs: "Q[GPhysicalQuantity]", rhs: "Q[GPhysicalQuantity]"):
# sanity check
assert lhs.units == self.units == rhs.units

self.lhs: AbstractPhysicalQuantity = lhs if lhs_is_pqty else pqty_cls(self.units, rhs)
self.lhs: AbstractPhysicalQuantity = (
lhs if lhs_is_pqty else pqty_cls(self.units, rhs)
)
self.rhs: AbstractPhysicalQuantity = (
rhs if rhs_is_pqty else pqty_cls(self.units, rhs)
)
Expand Down Expand Up @@ -60,31 +62,3 @@ def _repr_latex_(self):
spacer_ltx = "\\hspace{1.25em}"

return f"$\\displaystyle {lhs_ltx} = {rhs_ltx} {spacer_ltx} {units_ltx}$"


def equation(fn):
# TODO: check input types and constraints
def wrap(**kwargs):
return fn(**kwargs)

wrap.__name__ = fn.__name__

try:
type_hints = get_type_hints(fn, include_extras=True)
# TODO: verify this
wrap.__doc__ = f"{fn.__doc__}\n\n" + "\n".join(
f"{argname} [{ann.__metadata__[0]}]: {ann.__metadata__[1]}"
for argname, ann in type_hints.items()
if argname != "return"
)

except TypeError as e:
assert "unexpected keyword argument 'include_extras'" in str(e)
type_hints = get_type_hints(fn)
wrap.__doc__ = f"{fn.__doc__}\n\n" + "\n".join(
f"{argname} [{ann.__args__[0]}]"
for argname, ann in type_hints.items()
if argname != "return"
)

return wrap
19 changes: 7 additions & 12 deletions mathpad/mech.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,39 +25,34 @@ def euler_lagrange(
return diff(diff(L, ds)) + diff(L, state) == sum_non_conservative_forces


@equation
def impulse_momentum(
m: Q[Mass], # mass
v1: Q[Velocity], # initial velocity
F: Q[Force], # impulse force required
t: Q[Time], # impulse duration in seconds
v2: Q[Velocity], # final velocity
):
) -> Equation:
"The force required in an instant to change an object's velocity"
return m * v1 + integral(F, t) == m * v2


def velocity_acceleration(
a: Q[Acceleration], v: Q[Velocity], t: Q[Time] # symbol for time
):
v: Q[Velocity], t: Q[Time] # symbol for time
) -> Acceleration:
return a == integral(v, t)


@equation
def force_momentum(
F: Q[Force], # resulting force
m: Q[Mass], # mass of object
v: Q[Velocity],
t: Q[Time] = t,
):
return F == diff(m * v, t)
) -> Force:
return diff(m * v, t)


@equation
def angular_momentum(
Ho: Q[AngularMomentum], # Angular momentum around a unit vector
r_p_o: Q[Length], # unit vector of rotation axis (anti-clockwise)
m: Q[Mass], # mass of point object
v: Q[Velocity], # velocity of point object
):
return Ho == r_p_o.cross(m * v)
) -> AngularMomentum:
return r_p_o.cross(m * v)
75 changes: 51 additions & 24 deletions mathpad/simulate_dynamic_system.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Collection, Set, List, Optional, Tuple

import sympy
from sympy.core.function import Function, AppliedUndef
from sympy import Derivative
import sympy
import plotly.graph_objects as go
from sympy.utilities.lambdify import lambdify
from scipy.integrate import RK45
Expand All @@ -23,43 +23,51 @@ def simulate_dynamic_system(
max_step: Optional[float],
substitute: SubstitutionMap = {},
x_axis: AbstractPhysicalQuantity = t,
) -> List[Tuple[float, List[float]]]:
display_equations: bool = True,
display_plots: bool = True,
all_solutions: bool = True,
) -> List[List[Tuple[float, List[float]]]]:
"simulates a differential system specified by dynamics_equations from initial conditions at x_axis=0 (typically t=0) to x_final"

# TODO: support integrals
if max_step is None:
max_step = float("inf")

# pre-substitute and simplify the input equations before further processing
sympy_eqns = [
simplify(subs(eqn, substitute)).as_sympy_eq() for eqn in dynamics_equations
]
problem_eqns = [simplify(subs(eqn, substitute)) for eqn in dynamics_equations]

# collect derivatives and any unspecified unkowns
derivatives: Set[Tuple[Function, float]] = set()

for eqn in sympy_eqns:
for eqn in problem_eqns:
sympy_eqn = eqn.as_sympy_eq()
# TODO: properly check x_axis for derivative collection (usually t)
derivatives.update(
{
(d.args[0], d.args[1][1] if isinstance(d.args[1], sympy.Tuple) else 1)
for d in eqn.atoms(Derivative)
for d in sympy_eqn.atoms(Derivative)
}
)
derivatives.update(
{(f, 0) for f in eqn.atoms(Function) if isinstance(f, AppliedUndef)}
{(f, 0) for f in sympy_eqn.atoms(Function) if isinstance(f, AppliedUndef)}
)

highest_derivatives = {}
lowest_derivatives = {}
for f, lvl in derivatives:

if f in highest_derivatives:
if highest_derivatives[f] < lvl:
lowest_derivatives[f] = highest_derivatives[f]
highest_derivatives[f] = lvl
else:
lowest_derivatives[f] = lvl
else:
highest_derivatives[f] = lvl

if f in lowest_derivatives:
if lowest_derivatives[f] > lvl:
lowest_derivatives[f] = lvl
else:
lowest_derivatives[f] = lvl

solve_for_highest_derivatives = [
fn if lvl == 0 else sympy.diff(fn, (x_axis.val, lvl))
for fn, lvl in highest_derivatives.items()
Expand All @@ -69,10 +77,21 @@ def simulate_dynamic_system(

solve_for = solve_for_highest_derivatives + solve_for_recorded_data

solutions = sympy.solve(sympy_eqns, solve_for_highest_derivatives, dict=True)
if display_equations:
print("Solving Equations:")
for eqn in problem_eqns:
display(eqn)

solutions = sympy.solve(
[eqn.as_sympy_eq() for eqn in problem_eqns],
solve_for_highest_derivatives,
dict=True,
)

assert any(solutions), "sympy solving failed!"

all_data = []

for solution_idx, solution in enumerate(solutions):

# in dict mode, if a solution is equal to the query the solution is not included in the dict
Expand Down Expand Up @@ -156,16 +175,24 @@ def step(x, state):
print(f"integration completed with failed status: {msg}")
break

display(
go.Figure(
[
go.Scatter(
x=[t for t, _ in data],
y=[frame[idx] for _, frame in data],
name=str(sym),
)
for idx, sym in enumerate(record)
],
layout=dict(title=f"Solution #{solution_idx + 1}"),
if display_plots:
display(
go.Figure(
[
go.Scatter(
x=[t for t, _ in data],
y=[frame[idx] for _, frame in data],
name=str(sym),
)
for idx, sym in enumerate(record)
],
layout=dict(title=f"Solution #{solution_idx + 1}"),
)
)
)

all_data.extend(data)

if not all_solutions:
break

return all_data
6 changes: 2 additions & 4 deletions mathpad/trigonometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,20 @@ def tan(x: Q[Angle]):
# TODO: add more trig functions


@equation
def sine_rule(
a: Q[Length], # length of side a
b: Q[Length], # length of side b
alpha: Q[Angle], # internal angle opposite to side a
beta: Q[Angle], # internal angle opposite to side b
):
) -> Equation:
"Relates lengths of two sides of any triangle the internal angle opposite"
return a == b * sin(alpha) / sin(beta)


@equation
def cosine_rule(
a: Q[Length], # length of side a
b: Q[Length], # length of side b
c: Q[Length], # length of side c
C: Q[Angle], # internal angle opposite to side c
): #
) -> Equation:
return c ** 2 == a ** 2 + b ** 2 - 2 * a * b * cos(C)

0 comments on commit d140be5

Please sign in to comment.