Skip to content

Commit

Permalink
integrate torchdiffeq with snopt by applying snopt_integration.patch
Browse files Browse the repository at this point in the history
  • Loading branch information
ghliu committed Nov 5, 2021
1 parent 3eb93b3 commit 2a31f24
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
11 changes: 10 additions & 1 deletion torchdiffeq/_impl/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from .misc import _check_inputs, _flat_to_shape
from .misc import _mixed_norm

from snopt import SNOptAdjointCollector


class OdeintAdjointMethod(torch.autograd.Function):

Expand Down Expand Up @@ -109,6 +111,8 @@ def augmented_dynamics(t, y_aug):
# Solve adjoint ODE #
##################################

snopt_collector = SNOptAdjointCollector(func) if adjoint_options['use_snopt'] else None

if t_requires_grad:
time_vjps = torch.empty(len(t), dtype=t.dtype, device=t.device)
else:
Expand All @@ -123,9 +127,14 @@ def augmented_dynamics(t, y_aug):
time_vjps[i] = dLd_cur_t

# Run the augmented system backwards in time.
samp_t = t[i - 1:i + 1].flip(0)

if snopt_collector:
adjoint_options, samp_t = snopt_collector.check_inputs(adjoint_options, samp_t)

aug_state = odeint(
augmented_dynamics, tuple(aug_state),
t[i - 1:i + 1].flip(0),
samp_t,
rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
)
aug_state = [a[1] for a in aug_state] # extract just the t[i - 1] value
Expand Down
10 changes: 8 additions & 2 deletions torchdiffeq/_impl/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@


class AdaptiveStepsizeODESolver(metaclass=abc.ABCMeta):
def __init__(self, dtype, y0, norm, **unused_kwargs):
def __init__(self, dtype, y0, norm, snopt_collector=None, **unused_kwargs):
_handle_unused_kwargs(self, unused_kwargs)
del unused_kwargs

self.y0 = y0
self.dtype = dtype

self.norm = norm
self.snopt_collector = snopt_collector

def _before_integrate(self, t):
pass
Expand All @@ -28,6 +29,8 @@ def integrate(self, t):
self._before_integrate(t)
for i in range(1, len(t)):
solution[i] = self._advance(t[i])
if self.snopt_collector:
self.snopt_collector.call_invoke(self.func, t[i], solution[i])
return solution


Expand All @@ -48,7 +51,7 @@ def integrate_until_event(self, t0, event_fn):
class FixedGridODESolver(metaclass=abc.ABCMeta):
order: int

def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="linear", perturb=False, **unused_kwargs):
def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="linear", perturb=False, snopt_collector=None, **unused_kwargs):
self.atol = unused_kwargs.pop('atol')
unused_kwargs.pop('rtol', None)
unused_kwargs.pop('norm', None)
Expand All @@ -62,6 +65,7 @@ def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="line
self.step_size = step_size
self.interp = interp
self.perturb = perturb
self.snopt_collector = snopt_collector

if step_size is None:
if grid_constructor is None:
Expand Down Expand Up @@ -113,6 +117,8 @@ def integrate(self, t):
solution[j] = self._cubic_hermite_interp(t0, y0, f0, t1, y1, f1, t[j])
else:
raise ValueError(f"Unknown interpolation method {self.interp}")
if self.snopt_collector:
self.snopt_collector.call_invoke(self.func, t[j], solution[j])
j += 1
y0 = y1

Expand Down

0 comments on commit 2a31f24

Please sign in to comment.