-
Notifications
You must be signed in to change notification settings - Fork 917
/
flatten.py
33 lines (29 loc) · 1.09 KB
/
flatten.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""
Handy functions for flattening nested containers containing numpy
arrays. The main purpose is to make examples and optimizers simpler.
"""
from autograd import make_vjp
from autograd.builtins import type
import autograd.numpy as np
def flatten(value):
    """Flatten a nested structure of tuples, lists and dicts into one array.

    The leaves of ``value`` may be numpy arrays or scalars.

    Returns:
        A pair ``(flat, unflatten)``. ``flat`` is a 1D numpy array holding
        every leaf's entries in traversal order, with dict values visited
        in sorted-key order. ``unflatten`` maps an array shaped like
        ``flat`` back into the original nesting.

    Caveats: mixed numeric leaf types (e.g. ints next to floats) are not
    preserved, and dict keys must be mutually sortable.
    """
    # make_vjp(f)(x) yields (vjp, f(x)). _flatten only ravels and
    # concatenates, so its vector-Jacobian product scatters a flat vector
    # back into the input's container layout -- exactly an "unflatten".
    vjp_of_flatten, flat_array = make_vjp(_flatten)(value)
    return flat_array, vjp_of_flatten
def _flatten(value):
    """Recursively ravel ``value`` into a single 1D array.

    Lists and tuples are flattened element by element, in order. Dicts are
    flattened value by value, in sorted-key order. Anything else is handed
    to ``np.ravel``.

    Dispatch compares the exact type, so subclasses of list/tuple/dict are
    not recursed into and fall through to ``np.ravel``.
    """
    # `type` here is the one imported from autograd.builtins, presumably so
    # the real container type is seen even when autograd has wrapped
    # `value` during tracing -- NOTE(review): confirm.
    kind = type(value)
    if kind is dict:
        return _concatenate(_flatten(value[key]) for key in sorted(value))
    if kind in (list, tuple):
        return _concatenate(_flatten(item) for item in value)
    return np.ravel(value)
def _concatenate(lst):
    """Concatenate an iterable of arrays along their first axis.

    ``lst`` may be any iterable, including a lazy ``map`` or generator, so
    it is materialised before use. An empty input yields an empty array
    instead of being passed to ``np.concatenate``, which rejects an empty
    sequence.
    """
    arrays = list(lst)
    if not arrays:
        return np.array([])
    return np.concatenate(arrays)
def flatten_func(func, example):
    """Wrap ``func`` so that it takes and returns flat 1D arrays.

    Args:
        func: callable whose first argument is a structured value nested
            like ``example``; further positional arguments are forwarded.
        example: structured value whose layout defines how flat inputs
            are unflattened.

    Returns:
        ``(flat_func, unflatten, flat_example)``:

        - ``flat_func(flat_x, *args)`` unflattens ``flat_x``, calls
          ``func`` on it together with ``*args``, and returns the
          flattened result.
        - ``unflatten`` is the inverse built from ``example``.
        - ``flat_example`` is ``example`` flattened.
    """
    flat_example, unflatten = flatten(example)

    def flat_func(flat_x, *args):
        # Rebuild the structured input, run the user function, then keep
        # only the flat array (the output's own unflatten is discarded).
        structured_result = func(unflatten(flat_x), *args)
        return flatten(structured_result)[0]

    return flat_func, unflatten, flat_example