forked from Facebear-ljx/PROTO
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ensemble.py
37 lines (30 loc) · 1.06 KB
/
ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from typing import Type
import flax.linen as nn
import jax
import jax.numpy as jnp
class Ensemble(nn.Module):
    """Run ``num`` independent copies of ``net_cls`` on the same inputs.

    The wrapped module class is lifted with ``nn.vmap`` so that all copies
    are evaluated in one vectorized call.

    Attributes:
        net_cls: Module class (or factory) instantiated with no arguments
            to build each ensemble member.
        num: Number of ensemble members (size of the mapped axis).
    """

    net_cls: Type[nn.Module]
    num: int = 2

    @nn.compact
    def __call__(self, *args):
        """Apply every ensemble member to ``args``.

        The inputs are not mapped (``in_axes=None``), so every member
        receives the same ``args``. Member outputs are stacked along
        axis 0 of the result.
        """
        vmapped_cls = nn.vmap(
            self.net_cls,
            # Each member owns its own parameters, stacked on axis 0.
            variable_axes={"params": 0},
            # Split the RNG streams so members get distinct init/dropout keys.
            split_rngs={"params": True, "dropout": True},
            in_axes=None,
            out_axes=0,
            axis_size=self.num,
        )
        members = vmapped_cls()
        return members(*args)
def subsample_ensemble(key: jax.random.PRNGKey, params, num_sample: int, num_qs: int):
    """Pick a random subset of ensemble members out of a parameter tree.

    ``num_sample`` distinct indices are drawn from ``range(num_qs)`` and
    used to index the leading axis of every selected leaf.

    Args:
        key: JAX PRNG key used to draw the member indices.
        params: Parameter tree (a plain ``dict`` or a mapping exposing
            ``copy(add_or_replace=...)``, e.g. a Flax ``FrozenDict``).
            If it has an ``"Ensemble_0"`` entry, only that subtree is
            subsampled and the other entries are carried over untouched;
            otherwise every leaf of ``params`` is subsampled.
        num_sample: Number of members to keep, or ``None`` to disable
            subsampling.
        num_qs: Total number of members; indices are drawn from
            ``[0, num_qs)``. Presumably matches the leading axis of the
            stacked leaves -- verify against the caller.

    Returns:
        ``params`` itself when ``num_sample`` is ``None``; otherwise a new
        tree whose selected leaves have leading dimension ``num_sample``.
        The input tree is never mutated.
    """
    if num_sample is None:
        return params

    all_indx = jnp.arange(0, num_qs)
    # replace=False guarantees the chosen member indices are distinct.
    indx = jax.random.choice(key, a=all_indx, shape=(num_sample,), replace=False)

    def take_members(param):
        return param[indx]

    if "Ensemble_0" not in params:
        return jax.tree_util.tree_map(take_members, params)

    ens_params = jax.tree_util.tree_map(take_members, params["Ensemble_0"])
    if isinstance(params, dict):
        # Plain dicts have no ``copy(add_or_replace=...)``; build a new
        # dict instead so the caller's tree is left unmodified.
        return {**params, "Ensemble_0": ens_params}
    return params.copy(add_or_replace={"Ensemble_0": ens_params})