-
Notifications
You must be signed in to change notification settings - Fork 2
/
__init__.py
65 lines (46 loc) · 1.94 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import sys
import math
import operator as O
import itertools as I
import functools as F
# ---------------------------------------------------------------------------- #
# JAX #
# ---------------------------------------------------------------------------- #
import jax
import flax
import optax
import jax.numpy as np
import flax.linen as nn
# ---------------------------------------------------------------------------- #
# TYPE #
# ---------------------------------------------------------------------------- #
from abc import *
from typing import *
from jax import Array
from flax import struct
# Recursive array container: a jax Array, or a tuple/list whose elements are
# themselves X (the quoted "X" is a forward reference to this same alias).
X = Union[Tuple["X", ...], List["X"], Array]
# Parameters: a flax struct.PyTreeNode, an X container, or None.
# NOTE: Python NFKC-normalizes identifiers, so `ϴ` (U+03F4) is stored as the
# ordinary Greek capital theta `Θ` (U+0398); both spellings name this alias.
ϴ = Union[struct.PyTreeNode, "X", None]
Fx = Callable[..., "X"] # function of arbitrary arguments returning an X (intended as real-valued)
Fϴ = Callable[..., "Fx"] # parametrized function: a callable (presumably taking ϴ) that returns an Fx
# ---------------------------------------------------------------------------- #
# CONST #
# ---------------------------------------------------------------------------- #
# Euler's number and pi, re-exported from jax.numpy.
e, π = np.e, np.pi
# Trace over the two trailing axes (sum of the diagonal), batched over any
# leading axes; applied to a Hessian this yields the Laplacian.  Any extra
# positional/keyword arguments are forwarded to ``np.einsum``.
Δ = F.partial(np.einsum, "...ii -> ...")
# ---------------------------------------------------------------------------- #
# RANDOM #
# ---------------------------------------------------------------------------- #
from jax import random
# A bundle of PRNG keys indexed by stream name.  ``jax.Array`` is used as the
# key type because ``jax.random.KeyArray`` was deprecated and later removed
# from JAX (referencing it fails at import time on current releases).
RNG = Dict[str, Array]


class RNGS(RNG):
    """Named PRNG streams derived from a single root key.

    The instance is itself a ``dict`` mapping each name to its own subkey
    produced by ``random.split``.  Derived bundles are obtained with
    :meth:`fold_in` or ``next(rngs)``; neither modifies the stored keys.
    """

    def __init__(self, prng: Array, name: List[str]):
        """Split ``prng`` into one subkey per entry of ``name``.

        Args:
            prng: root PRNG key to split.
            name: stream names, in order; each receives its own subkey.
        """
        keys = random.split(prng, len(name))
        super().__init__(zip(name, keys))

    def __next__(self) -> RNG:
        """Return the key bundle for the next step.

        A counter stored on the instance (``self.it``, equal to 1 on the
        first call) is folded into every key, so successive calls yield
        distinct but reproducible bundles.
        """
        self.it = getattr(self, "it", 0) + 1
        return self.fold_in(self.it)

    def fold_in(self, key: Any) -> RNG:
        """Derive a new bundle by folding ``key`` into every stored key.

        Args:
            key: data identifying the derived bundle.  ``str``/``bytes`` are
                folded via CRC-32; integers and integer-like objects
                (``__index__``) are reduced modulo 2**32; any other object
                falls back to ``hash()`` reduced modulo 2**32.

        Returns:
            A plain ``dict`` with the same names mapped to the derived keys.

        Raises:
            TypeError: if ``key`` is neither str/bytes, integer-like, nor
                hashable (e.g. a traced array).
        """
        data = self._fold_data(key)
        return {name: random.fold_in(prng, data) for name, prng in self.items()}

    @staticmethod
    def _fold_data(key: Any) -> int:
        """Map ``key`` to a reproducible unsigned 32-bit integer."""
        import zlib

        # ``hash(str)`` is salted per process (PYTHONHASHSEED), which would make
        # string-keyed streams differ between runs, and its 64-bit result does
        # not fit the uint32 that ``random.fold_in`` expects.
        if isinstance(key, str):
            return zlib.crc32(key.encode("utf-8"))
        if isinstance(key, bytes):
            return zlib.crc32(key)
        try:
            value = O.index(key)
        except TypeError:
            # Other hashables keep the original hash-based behaviour;
            # unhashable objects still raise TypeError here.
            value = hash(key)
        # Wrap into uint32 range; small non-negative ints are unchanged.
        return value & 0xFFFFFFFF