forked from Louisbu/hybrid_ivp
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconstant_current.py
31 lines (24 loc) · 977 Bytes
/
constant_current.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from functools import partial
from typing import Tuple
import jax.numpy as jnp
from jax import jit
from hybrid_routing.vectorfields.base import Vectorfield
class ConstantCurrent(Vectorfield):
    """Constant vector field, implements Vectorfield class.

    Vectorfield defined by:
    W: (x, y) -> (u, v), u(x, y) = 0.2, v(x, y) = -0.2
    with:
    du/dx = 0, du/dy = 0
    dv/dx = 0, dv/dy = 0

    Every method takes coordinate arrays ``x`` and ``y`` whose shapes must be
    broadcast-compatible. Outputs are filled to the broadcast shape of
    ``(x, y)``, so each value corresponds to one evaluation point.
    """

    @staticmethod
    def _field_shape(x: jnp.ndarray, y: jnp.ndarray) -> Tuple[int, ...]:
        """Return the broadcast shape of the coordinate arrays ``x`` and ``y``.

        ``jnp.shape`` also accepts Python scalars, and under ``jit`` the shapes
        are static, so this is resolved at trace time.
        """
        return jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))

    @partial(jit, static_argnums=(0,))
    def dv(self, x: jnp.ndarray, y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(dv/dx, dv/dy)`` evaluated at ``(x, y)``.

        Both partial derivatives are identically zero for a constant field.
        """
        shape = self._field_shape(x, y)
        return (jnp.full(shape, 0.0), jnp.full(shape, 0.0))

    @partial(jit, static_argnums=(0,))
    def du(self, x: jnp.ndarray, y: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Return ``(du/dx, du/dy)`` evaluated at ``(x, y)``.

        Both partial derivatives are identically zero for a constant field.
        """
        shape = self._field_shape(x, y)
        return (jnp.full(shape, 0.0), jnp.full(shape, 0.0))

    @partial(jit, static_argnums=(0,))
    def get_current(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
        """Return the current ``(u, v)`` at ``(x, y)`` stacked along axis 0.

        The result has shape ``(2, *S)``, where ``S`` is the broadcast shape of
        ``x`` and ``y``. Row 0 holds ``u = 0.2`` and row 1 holds ``v = -0.2``.
        """
        # Previously ``v`` was sized from ``x`` alone. Using the common shape
        # keeps both components aligned with every (x, y) evaluation point.
        shape = self._field_shape(x, y)
        u = jnp.full(shape, 0.2)
        v = jnp.full(shape, -0.2)
        return jnp.stack([u, v])